/**
 * Scores {@code ary} against a KMeans model.
 *
 * <p>Builds a {@code KMeansScore} task that carries the model's clusters and normalization flag
 * plus a mapping from the array's columns onto the model's columns, then invokes it on the
 * array's key.
 *
 * @param model model supplying the clusters, the normalization flag and the column mapping
 * @param ary the data to score
 * @return the task after {@code invoke} has returned (presumably holding the scoring results;
 *     confirm against {@code KMeansScore})
 */
public static KMeansScore score(KMeansModel model, ValueArray ary) {
  final KMeansScore scorer = new KMeansScore();
  scorer._arykey = ary._key;
  // Translate the array's column names into the model's column layout.
  scorer._cols = model.columnMapping(ary.colNames());
  scorer._clusters = model._clusters;
  scorer._normalized = model._normalized;
  scorer.invoke(ary._key);
  return scorer;
}
/**
 * Creates the training-set {@link ModelMetricsClustering} for a trained model and registers it
 * on the model.
 *
 * <p>The statistics are copied directly from the model output; the training frame is not
 * rescored here.
 *
 * @param model trained model whose output must already hold valid training statistics
 *     ({@code _size}, {@code _withinss}, {@code _betweenss}, {@code _totss},
 *     {@code _tot_withinss})
 * @return the metrics object that was added to {@code model} via {@code addMetrics}
 */
private ModelMetricsClustering makeTrainingMetrics(KMeansModel model) {
  final ModelMetricsClustering metrics =
      new ModelMetricsClustering(model, model._parms.train());
  metrics._size = model._output._size;
  metrics._withinss = model._output._withinss;
  metrics._betweenss = model._output._betweenss;
  metrics._totss = model._output._totss;
  metrics._tot_withinss = model._output._tot_withinss;
  model.addMetrics(metrics);
  return metrics;
}
// Main worker thread @Override protected void compute2() { KMeansModel model = null; try { init(true); // Do lock even before checking the errors, since this block is finalized by unlock // (not the best solution, but the code is more readable) _parms.read_lock_frames(KMeans.this); // Fetch & read-lock input frames // Something goes wrong if (error_count() > 0) throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(KMeans.this); // The model to be built model = new KMeansModel(dest(), _parms, new KMeansModel.KMeansOutput(KMeans.this)); model.delete_and_lock(_key); // final Vec vecs[] = _train.vecs(); // mults & means for standardization final double[] means = _train.means(); // means are used to impute NAs final double[] mults = _parms._standardize ? _train.mults() : null; final int[] impute_cat = new int[vecs.length]; for (int i = 0; i < vecs.length; i++) impute_cat[i] = vecs[i].isNumeric() ? -1 : DataInfo.imputeCat(vecs[i]); model._output._normSub = means; model._output._normMul = mults; // Initialize cluster centers and standardize if requested double[][] centers = initial_centers(model, vecs, means, mults, impute_cat); if (centers == null) return; // Stopped/cancelled during center-finding double[][] oldCenters = null; // --- // Run the main KMeans Clustering loop // Stop after enough iterations or average_change < TOLERANCE model._output._iterations = 0; // Loop ends only when iterations > max_iterations with strict inequality while (!isDone(model, centers, oldCenters)) { Lloyds task = new Lloyds(centers, means, mults, impute_cat, _isCats, _parms._k, hasWeightCol()) .doAll(vecs); // Pick the max categorical level for cluster center max_cats(task._cMeans, task._cats, _isCats); // Handle the case where some centers go dry. 
Rescue only 1 cluster // per iteration ('cause we only tracked the 1 worst row) if (cleanupBadClusters(task, vecs, centers, means, mults, impute_cat)) continue; // Compute model stats; update standardized cluster centers oldCenters = centers; centers = computeStatsFillModel(task, model, vecs, means, mults, impute_cat); model.update(_key); // Update model in K/V store update(1); // One unit of work if (model._parms._score_each_iteration) Log.info(model._output._model_summary); } Log.info(model._output._model_summary); // Log.info(model._output._scoring_history); // // Log.info(((ModelMetricsClustering)model._output._training_metrics).createCentroidStatsTable().toString()); // At the end: validation scoring (no need to gather scoring history) if (_valid != null) { model.score(_parms.valid()).delete(); // this appends a ModelMetrics on the validation set model._output._validation_metrics = ModelMetrics.getFromDKV(model, _parms.valid()); model.update(_key); // Update model in K/V store } done(); // Job done! } catch (Throwable t) { Job thisJob = DKV.getGet(_key); if (thisJob._state == JobState.CANCELLED) { Log.info("Job cancelled by user."); } else { t.printStackTrace(); failed(t); throw t; } } finally { updateModelOutput(); if (model != null) model.unlock(_key); _parms.read_unlock_frames(KMeans.this); } tryComplete(); }
// Initialize cluster centers double[][] initial_centers( KMeansModel model, final Vec[] vecs, final double[] means, final double[] mults, final int[] modes) { // Categoricals use a different distance metric than numeric columns. model._output._categorical_column_count = 0; _isCats = new String[vecs.length][]; for (int v = 0; v < vecs.length; v++) { _isCats[v] = vecs[v].isCategorical() ? new String[0] : null; if (_isCats[v] != null) model._output._categorical_column_count++; } Random rand = water.util.RandomUtils.getRNG(_parms._seed - 1); double centers[][]; // Cluster centers if (null != _parms._user_points) { // User-specified starting points Frame user_points = _parms._user_points.get(); int numCenters = (int) user_points.numRows(); int numCols = model._output.nfeatures(); centers = new double[numCenters][numCols]; Vec[] centersVecs = user_points.vecs(); // Get the centers and standardize them if requested for (int r = 0; r < numCenters; r++) { for (int c = 0; c < numCols; c++) { centers[r][c] = centersVecs[c].at(r); centers[r][c] = data(centers[r][c], c, means, mults, modes); } } } else { // Random, Furthest, or PlusPlus initialization if (_parms._init == Initialization.Random) { // Initialize all cluster centers to random rows centers = new double[_parms._k][model._output.nfeatures()]; for (double[] center : centers) randomRow(vecs, rand, center, means, mults, modes); } else { centers = new double[1][model._output.nfeatures()]; // Initialize first cluster center to random row randomRow(vecs, rand, centers[0], means, mults, modes); model._output._iterations = 0; while (model._output._iterations < 5) { // Sum squares distances to cluster center SumSqr sqr = new SumSqr(centers, means, mults, modes, _isCats).doAll(vecs); // Sample with probability inverse to square distance Sampler sampler = new Sampler( centers, means, mults, modes, _isCats, sqr._sqr, _parms._k * 3, _parms._seed, hasWeightCol()) .doAll(vecs); centers = ArrayUtils.append(centers, 
sampler._sampled); // Fill in sample centers into the model if (!isRunning()) return null; // Stopped/cancelled model._output._centers_raw = destandardize(centers, _isCats, means, mults); model._output._tot_withinss = sqr._sqr / _train.numRows(); model._output._iterations++; // One iteration done model.update( _key); // Make early version of model visible, but don't update progress using // update(1) } // Recluster down to k cluster centers centers = recluster(centers, rand, _parms._k, _parms._init, _isCats); model._output._iterations = 0; // Reset iteration count } } return centers; }