@Override
protected KMeans fit() {
    synchronized (fitLock) {

        if (null != labels) // already fit
            return this;

        final LogTimer timer = new LogTimer();
        final double[][] X = data.getData();
        final int n = data.getColumnDimension();
        final double nan = Double.NaN;

        // Corner case: K = 1 or all singular values
        if (1 == k) {
            labelFromSingularK(X);
            fitSummary.add(new Object[] {iter, converged, tss, tss, nan, timer.wallTime()});
            sayBye(timer);
            return this;
        }

        // Nearest centroid model to predict labels
        NearestCentroid model = null;
        EntryPair<int[], double[]> label_dist;

        // Keep track of TSS (sum of barycentric distances)
        double last_wss_sum = Double.POSITIVE_INFINITY, wss_sum = 0;
        ArrayList<double[]> new_centroids;

        for (iter = 0; iter < maxIter; iter++) {

            // Get labels for nearest centroids
            try {
                model = new NearestCentroid(
                        CentroidUtils.centroidsToMatrix(centroids, false),
                        VecUtils.arange(k),
                        new NearestCentroidParameters()
                            .setSeed(getSeed())
                            .setMetric(getSeparabilityMetric())
                            .setVerbose(false))
                    .fit();
            } catch (NaNException NaN) {
                /*
                 * If the metric used produces lots of infs or -infs, it
                 * makes it hard, if not impossible, to effectively segment the
                 * input space. Thus, the centroid assignment portion below can
                 * yield a zero count (denominator) for one or more of the centroids,
                 * which makes the entire row NaN. We should tell the user to
                 * try a different metric, if that's the case.
                 *
                 * error(new IllegalClusterStateException(dist_metric.getName() + " produced an entirely "
                 *     + "infinite distance matrix, making it difficult to segment the input space. Try a different "
                 *     + "metric."));
                 */
                this.k = 1;
                warn(
                    "(dis)similarity metric ("
                        + dist_metric
                        + ") cannot partition space without propagating Infs. Returning one cluster");

                labelFromSingularK(X);
                fitSummary.add(new Object[] {iter, converged, tss, tss, nan, timer.wallTime()});
                sayBye(timer);
                return this;
            }

            label_dist = model.predict(X);

            // unpack the EntryPair
            labels = label_dist.getKey();
            new_centroids = new ArrayList<>(k);

            int label;
            wss = new double[k];
            int[] centroid_counts = new int[k];
            double[] centroid;
            double[][] new_centroid_arrays = new double[k][n];

            // Accumulate per-cluster costs and centroid sums in a single pass over the rows
            for (int i = 0; i < m; i++) {
                label = labels[i];
                centroid = centroids.get(label);

                // increment count for this centroid
                double this_cost = 0;
                centroid_counts[label]++;

                for (int j = 0; j < centroid.length; j++) {
                    double diff = X[i][j] - centroid[j];
                    this_cost += (diff * diff);

                    // Add to the centroid sums
                    new_centroid_arrays[label][j] += X[i][j];
                }

                // add this cost to the WSS
                wss[label] += this_cost;
            }

            // one pass of K for some consolidation
            wss_sum = 0;
            for (int i = 0; i < k; i++) {
                wss_sum += wss[i];

                for (int j = 0; j < n; j++) // meanify
                    new_centroid_arrays[i][j] /= (double) centroid_counts[i];

                new_centroids.add(new_centroid_arrays[i]);
            }

            // update the BSS
            bss = tss - wss_sum;

            // Change in within-cluster sum of squares since the last iteration
            double diff = last_wss_sum - wss_sum;
            last_wss_sum = wss_sum;

            // Check for convergence and add summary:
            converged = FastMath.abs(diff) < tolerance; // first iter will be inf
            fitSummary.add(
                new Object[] {
                    converged ? iter++ : iter, converged, tss, wss_sum, bss, timer.wallTime()
                });

            if (converged) {
                break;
            } else {
                // otherwise, reassign centroids
                centroids = new_centroids;
            }
        } // end iterations

        // Reorder the labels, centroids and wss indices
        reorderLabelsAndCentroids();

        if (!converged) warn("algorithm did not converge");

        // wrap things up, create summary..
        sayBye(timer);

        return this;
    }
}