/**
 * Helper to create the DataInfo object from training/validation frames and the DL parameters.
 *
 * @param train Training frame
 * @param valid Validation frame
 * @param parms Model parameters
 * @param nClasses Number of response levels (1: regression, >=2: classification)
 * @return DataInfo
 */
static DataInfo makeDataInfo(
    Frame train, Frame valid, DeepLearningParameters parms, int nClasses) {
  // Probe an arbitrary value: the link function is the identity iff link(x) == x.
  double x = 0.782347234;
  boolean identityLink = new Distribution(parms._distribution, parms._tweedie_power).link(x) == x;
  DataInfo dinfo =
      new DataInfo(
          train,
          valid,
          parms._autoencoder ? 0 : 1, // nResponses
          parms._autoencoder || parms._use_all_factor_levels, // use all factor levels for the auto-encoder
          parms._standardize
              ? (parms._autoencoder
                  ? DataInfo.TransformType.NORMALIZE
                  : parms._sparse
                      ? DataInfo.TransformType.DESCALE
                      : DataInfo.TransformType.STANDARDIZE)
              : DataInfo.TransformType.NONE, // transform predictors
          !parms._standardize || train.lastVec().isCategorical()
              ? DataInfo.TransformType.NONE
              : identityLink
                  ? DataInfo.TransformType.STANDARDIZE
                  : DataInfo.TransformType.NONE, // transform response for regression with identity link
          parms._missing_values_handling
              == DeepLearningParameters.MissingValuesHandling.Skip, // whether to skip missing rows
          false, // do not replace NAs in numeric cols with mean
          true, // always add a bucket for missing values
          parms._weights_column != null, // observation weights
          parms._offset_column != null,
          parms._fold_column != null);
  // Checks and adjustments:
  // 1) observation weights (adjust mean/sigmas for predictors and response)
  // 2) NAs (check that there are enough rows left)
  GLMTask.YMUTask ymt =
      new GLMTask.YMUTask(
              dinfo, nClasses, true, !parms._autoencoder && nClasses == 1, false, !parms._autoencoder)
          .doAll(dinfo._adaptedFrame);
  if (ymt._wsum == 0
      && parms._missing_values_handling == DeepLearningParameters.MissingValuesHandling.Skip)
    throw new H2OIllegalArgumentException(
        "No rows left in the dataset after filtering out rows with missing values."
            + " Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
  if (parms._weights_column != null && parms._offset_column != null) {
    Log.warn(
        "Combination of offset and weights can lead to slight differences because Rollupstats"
            + " aren't weighted - need to re-calculate weighted mean/sigma of the response"
            + " including offset terms.");
  }
  if (parms._weights_column != null
      && parms._offset_column == null /* FIXME: offset not yet implemented */) {
    dinfo.updateWeightedSigmaAndMean(ymt._basicStats.sigma(), ymt._basicStats.mean());
    if (nClasses == 1)
      dinfo.updateWeightedSigmaAndMeanForResponse(
          ymt._basicStatsResponse.sigma(), ymt._basicStatsResponse.mean());
  }
  return dinfo;
}
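/*
 * Illustrative sketch (not part of the original source): one way makeDataInfo could be
 * exercised directly, e.g. from a test. The parameter settings below are assumptions chosen
 * to show which TransformType branch each flag selects; in normal operation buildModel()
 * makes this call internally.
 */
static DataInfo exampleMakeDataInfo(Frame train, Frame valid) {
  DeepLearningParameters p = new DeepLearningParameters();
  p._autoencoder = false; // supervised: one response column
  p._standardize = true;  // dense input + standardize => TransformType.STANDARDIZE for predictors
  p._sparse = false;      // a sparse input would select TransformType.DESCALE instead
  DataInfo dinfo = makeDataInfo(train, valid, p, /* nClasses */ 1); // 1 => regression
  DKV.put(dinfo); // publish to the DKV so distributed tasks can fetch it
  return dinfo;
}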
/**
 * Train a Deep Learning model; assumes that all members are populated. If checkpoint == null,
 * start training a new model, otherwise continue from a checkpoint.
 */
public final void buildModel() {
  DeepLearningModel cp = null;
  if (_parms._checkpoint == null) {
    cp =
        new DeepLearningModel(
            dest(),
            _parms,
            new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this),
            _train,
            _valid,
            nclasses());
    cp.model_info().initializeMembers();
  } else {
    final DeepLearningModel previous = DKV.getGet(_parms._checkpoint);
    if (previous == null) throw new IllegalArgumentException("Checkpoint not found.");
    Log.info("Resuming from checkpoint.");
    _job.update(0, "Resuming from checkpoint");

    if (isClassifier() != previous._output.isClassifier())
      throw new H2OIllegalArgumentException(
          "Response type must be the same as for the checkpointed model.");
    if (isSupervised() != previous._output.isSupervised())
      throw new H2OIllegalArgumentException(
          "Model type must be the same as for the checkpointed model.");

    // check the user-given arguments for consistency
    DeepLearningParameters oldP = previous._parms; // sanitized parameters of the checkpointed model
    DeepLearningParameters newP = _parms; // user-given parameters for restart

    DeepLearningParameters oldP2 = (DeepLearningParameters) oldP.clone();
    DeepLearningParameters newP2 = (DeepLearningParameters) newP.clone();
    DeepLearningParameters.Sanity.modifyParms(
        oldP, oldP2, nclasses()); // sanitize the checkpointed model's parameters
    DeepLearningParameters.Sanity.modifyParms(
        newP, newP2, nclasses()); // sanitize the user-given parameters
    DeepLearningParameters.Sanity.checkpoint(oldP2, newP2);

    DataInfo dinfo;
    try {
      // PUBDEV-2513: Adapt _train and _valid (in-place) to match the frames that were used for
      // the previous model.
      // This can add or remove dummy columns (can happen if the dataset is sparse and the
      // datasets have different non-const columns).
      for (String st : previous.adaptTestForTrain(_train, true, false)) Log.warn(st);
      for (String st : previous.adaptTestForTrain(_valid, true, false)) Log.warn(st);
      dinfo = makeDataInfo(_train, _valid, _parms, nclasses());
      DKV.put(dinfo);
      cp = new DeepLearningModel(dest(), _parms, previous, false, dinfo);
      cp.write_lock(_job);

      if (!Arrays.equals(cp._output._names, previous._output._names)) {
        throw new H2OIllegalArgumentException(
            "The columns of the training data must be the same as for the checkpointed model."
                + " Check ignored columns (or disable ignore_const_cols).");
      }
      if (!Arrays.deepEquals(cp._output._domains, previous._output._domains)) {
        throw new H2OIllegalArgumentException(
            "Categorical factor levels of the training data must be the same as for the"
                + " checkpointed model.");
      }
      if (dinfo.fullN() != previous.model_info().data_info().fullN()) {
        throw new H2OIllegalArgumentException(
            "Total number of predictors is different than for the checkpointed model.");
      }
      if (_parms._epochs <= previous.epoch_counter) {
        throw new H2OIllegalArgumentException(
            "Total number of epochs must be larger than the number of epochs already trained"
                + " for the checkpointed model ("
                + previous.epoch_counter
                + ").");
      }

      // these are the mutable parameters that are to be used by the model (stored in
      // model_info._parms)
      final DeepLearningParameters actualNewP =
          cp.model_info()
              .get_params(); // actually used parameters for model building (defaults filled in, etc.)
      assert (actualNewP != previous.model_info().get_params());
      assert (actualNewP != newP);
      assert (actualNewP != oldP);
      DeepLearningParameters.Sanity.update(actualNewP, newP, nclasses());
      Log.info(
          "Continuing training after "
              + String.format("%.3f", previous.epoch_counter)
              + " epochs from the checkpointed model.");
      cp.update(_job);
    } catch (H2OIllegalArgumentException ex) {
      if (cp != null) {
        cp.unlock(_job);
        cp.delete();
        cp = null;
      }
      throw ex;
    } finally {
      if (cp != null) cp.unlock(_job);
    }
  }
  trainModel(cp);

  // clean up, but don't delete weights and biases if user asked for export
  List<Key> keep = new ArrayList<>();
  try {
    if (_parms._export_weights_and_biases
        && cp._output.weights != null
        && cp._output.biases != null) {
      for (Key k : Arrays.asList(cp._output.weights)) {
        keep.add(k);
        for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
          keep.add(vk._key);
        }
      }
      for (Key k : Arrays.asList(cp._output.biases)) {
        keep.add(k);
        for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
          keep.add(vk._key);
        }
      }
    }
  } finally {
    Scope.exit(keep.toArray(new Key[keep.size()]));
  }
}
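/*
 * Illustrative sketch (not part of the original source): how a caller might trigger the
 * checkpoint-restart path handled above. The response column name and epoch count are
 * assumptions for illustration only.
 */
static DeepLearningModel exampleCheckpointRestart(Frame train, Key<DeepLearningModel> checkpoint) {
  DeepLearningParameters p = new DeepLearningParameters();
  p._train = train._key;
  p._response_column = "response"; // hypothetical response column name
  p._checkpoint = checkpoint;      // continue from the earlier model
  p._epochs = 20;                  // must exceed the epochs already trained (checked above)
  // Column names, categorical domains and the total number of predictors must match the
  // checkpointed model, otherwise buildModel() throws H2OIllegalArgumentException.
  return new DeepLearning(p).trainModel().get();
}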