Exemplo n.º 1
0
  /**
   * Adjusts this builder's parameters before the cross-validation main model is trained.
   *
   * <p>{@code _overwrite_with_best_model} is always switched off. If either convergence-based
   * stopping ({@code _stopping_rounds}) or a runtime limit ({@code _max_runtime_secs}) was
   * configured, both are reset to 0 and {@code _epochs} is replaced by the mean of the
   * {@code epoch_counter} values last scored by the given cross-validation models. Unless
   * {@code _quiet_mode} is set, one warning is issued per changed parameter.
   *
   * @param cvModelBuilders builders of the cross-validation models; each model is looked up in
   *     DKV under the builder's {@code dest()} key
   */
  @Override
  public void modifyParmsForCrossValidationMainModel(ModelBuilder[] cvModelBuilders) {
    _parms._overwrite_with_best_model = false;

    final boolean stoppingUnconfigured =
        _parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0;
    if (stoppingUnconfigured) {
      return; // Stopping conditions were never enabled, so there is nothing to adjust
    }

    // The main model runs for a fixed number of epochs instead of stopping early
    _parms._stopping_rounds = 0;
    _parms._max_runtime_secs = 0;

    // Average the epoch counters reached by the individual cross-validation models
    double epochTotal = 0;
    for (int i = 0; i < cvModelBuilders.length; i++) {
      DeepLearningModel cvModel = (DeepLearningModel) DKV.getGet(cvModelBuilders[i].dest());
      epochTotal += cvModel.last_scored().epoch_counter;
    }
    _parms._epochs = epochTotal / cvModelBuilders.length;

    if (_parms._quiet_mode) {
      return;
    }
    warn(
        "_epochs",
        "Setting optimal _epochs to "
            + _parms._epochs
            + " for cross-validation main model based on early stopping of cross-validation models.");
    warn(
        "_stopping_rounds",
        "Disabling convergence-based early stopping for cross-validation main model.");
    warn(
        "_max_runtime_secs",
        "Disabling maximum allowed runtime for cross-validation main model.");
  }
Exemplo n.º 2
0
    /**
     * Train a Deep Learning neural net model.
     *
     * <p>Steps, in order: resolve the model (fetched from DKV under {@code dest()} when
     * {@code null} is passed), write-lock it, wrap the training/validation data in temporary
     * frames (optionally re-sampled to balance classes), sample the scoring frames, decide whether
     * training-data shuffling has to be forced, run the training loop until scoring reports
     * completion, a stop is requested or a timeout occurs, and finally (if enabled) replace the
     * model's state with the best model seen so far. The {@code finally} block always deletes the
     * elastic-average models, unlocks the model and removes the temporary best-model key.
     *
     * @param model Input model (e.g., from initModel(), or from a previous training run); when
     *     {@code null}, the model stored in DKV under {@code dest()} is used
     * @return Trained model
     */
    public final DeepLearningModel trainModel(DeepLearningModel model) {
      Frame validScoreFrame = null;
      Frame train, trainScoreFrame;
      try {
        //      if (checkpoint == null && !quiet_mode) logStart(); //if checkpoint is given, some
        // Job's params might be uninitialized (but the restarted model's parameters are correct)
        if (model == null) {
          model = DKV.get(dest()).get();
        }
        Log.info(
            "Model category: "
                + (_parms._autoencoder
                    ? "Auto-Encoder"
                    : isClassifier() ? "Classification" : "Regression"));
        final long model_size = model.model_info().size();
        Log.info(
            "Number of model parameters (weights/biases): " + String.format("%,d", model_size));
        model.write_lock(_job);
        _job.update(0, "Setting up training data...");
        // mp: the parameters stored inside the model's model_info (not the builder's _parms)
        final DeepLearningParameters mp = model.model_info().get_params();

        // temporary frames of the same "name" as the orig _train/_valid (asking the parameter's
        // Key, not the actual frame)
        // Note: don't put into DKV or they would overwrite the _train/_valid frames!
        Frame tra_fr = new Frame(mp._train, _train.names(), _train.vecs());
        Frame val_fr = _valid != null ? new Frame(mp._valid, _valid.names(), _valid.vecs()) : null;

        train = tra_fr;
        // Optional class balancing: only for classifiers with balance_classes enabled
        if (model._output.isClassifier() && mp._balance_classes) {
          _job.update(0, "Balancing class distribution of training data...");
          // One sampling factor per level of the last column's domain
          float[] trainSamplingFactors =
              new float
                  [train
                      .lastVec()
                      .domain()
                      .length]; // leave initialized to 0 -> will be filled up below
          if (mp._class_sampling_factors != null) {
            // User-supplied factors must match the number of levels exactly
            if (mp._class_sampling_factors.length != train.lastVec().domain().length)
              throw new IllegalArgumentException(
                  "class_sampling_factors must have "
                      + train.lastVec().domain().length
                      + " elements");
            trainSamplingFactors =
                mp._class_sampling_factors.clone(); // clone: don't modify the original
          }
          // From here on "train" refers to the stratified sample, capped at
          // max_after_balance_size times the original row count
          train =
              sampleFrameStratified(
                  train,
                  train.lastVec(),
                  train.vec(model._output.weightsName()),
                  trainSamplingFactors,
                  (long) (mp._max_after_balance_size * train.numRows()),
                  mp._seed,
                  true,
                  false);
          // Record the relative class distribution of the re-sampled data (weighted if a weights
          // column is in use)
          Vec l = train.lastVec();
          Vec w = train.vec(model._output.weightsName());
          MRUtils.ClassDist cd = new MRUtils.ClassDist(l);
          model._output._modelClassDist =
              _weights != null ? cd.doAll(l, w).rel_dist() : cd.doAll(l).rel_dist();
        }
        model.training_rows = train.numRows();
        // Integer weights spanning exactly [0, 1]: count only numRows * mean(weight) rows per epoch
        if (_weights != null && _weights.min() == 0 && _weights.max() == 1 && _weights.isInt()) {
          model.training_rows = Math.round(train.numRows() * _weights.mean());
          Log.warn(
              "Not counting "
                  + (train.numRows() - model.training_rows)
                  + " rows with weight=0 towards an epoch.");
        }
        Log.info("One epoch corresponds to " + model.training_rows + " training data rows.");
        trainScoreFrame =
            sampleFrame(
                train,
                mp._score_training_samples,
                mp._seed); // training scoring dataset is always sampled uniformly from the training
                           // dataset
        // Only track the frame when sampleFrame returned something other than "train" itself
        // (presumably so the temporary copy is cleaned up when the Scope exits)
        if (trainScoreFrame != train) Scope.track(trainScoreFrame);

        if (!_parms._quiet_mode)
          Log.info("Number of chunks of the training data: " + train.anyVec().nChunks());
        if (val_fr != null) {
          model.validation_rows = val_fr.numRows();
          // validation scoring dataset can be sampled in multiple ways from the given validation
          // dataset
          if (model._output.isClassifier()
              && mp._balance_classes
              && mp._score_validation_sampling
                  == DeepLearningParameters.ClassSamplingMethod.Stratified) {
            _job.update(0, "Sampling validation data (stratified)...");
            // Stratified sample; target size is score_validation_samples, or all rows when that
            // parameter is not positive. Uses seed + 1 (the training sample uses seed).
            validScoreFrame =
                sampleFrameStratified(
                    val_fr,
                    val_fr.lastVec(),
                    val_fr.vec(model._output.weightsName()),
                    null,
                    mp._score_validation_samples > 0
                        ? mp._score_validation_samples
                        : val_fr.numRows(),
                    mp._seed + 1,
                    false /* no oversampling */,
                    false);
          } else {
            _job.update(0, "Sampling validation data...");
            validScoreFrame = sampleFrame(val_fr, mp._score_validation_samples, mp._seed + 1);
            if (validScoreFrame != val_fr) Scope.track(validScoreFrame);
          }
          if (!_parms._quiet_mode)
            Log.info(
                "Number of chunks of the validation data: " + validScoreFrame.anyVec().nChunks());
        }

        // Set train_samples_per_iteration size (cannot be done earlier since this depends on
        // whether stratified sampling is done)
        model.actual_train_samples_per_iteration =
            computeTrainSamplesPerIteration(mp, model.training_rows, model);
        // Determine whether shuffling is enforced
        // Case 1: replicated data on a multi-node cloud where one iteration covers the full data
        // on every participating node, shuffling is off and reproducibility was not requested
        if (mp._replicate_training_data
            && (model.actual_train_samples_per_iteration
                == model.training_rows * (mp._single_node_mode ? 1 : H2O.CLOUD.size()))
            && !mp._shuffle_training_data
            && H2O.CLOUD.size() > 1
            && !mp._reproducible) {
          if (!mp._quiet_mode)
            Log.info(
                "Enabling training data shuffling, because all nodes train on the full dataset (replicated training data).");
          mp._shuffle_training_data = true;
        }
        // Case 2: one iteration equals one epoch and the data lives in a single chunk
        if (!mp._shuffle_training_data
            && model.actual_train_samples_per_iteration == model.training_rows
            && train.anyVec().nChunks() == 1) {
          if (!mp._quiet_mode)
            Log.info(
                "Enabling training data shuffling to avoid training rows in the same order over and over (no Hogwild since there's only 1 chunk).");
          mp._shuffle_training_data = true;
        }

        //        if (!mp._quiet_mode) Log.info("Initial model:\n" + model.model_info());
        long now = System.currentTimeMillis();
        model._timeLastIterationEnter = now;
        if (_parms._autoencoder) {
          _job.update(0, "Scoring null model of autoencoder...");
          if (!mp._quiet_mode) Log.info("Scoring the null model of the autoencoder.");
          model.doScoring(
              trainScoreFrame,
              validScoreFrame,
              _job._key,
              0,
              false); // get the null model reconstruction error
        }
        // put the initial version of the model into DKV
        model.update(_job);
        // Setup time accumulates the time from job start up to "now" (taken before the optional
        // null-model scoring above)
        model.total_setup_time_ms += now - _job.start_time();
        Log.info("Total setup time: " + PrettyPrint.msecs(model.total_setup_time_ms, true));
        Log.info("Starting to train the Deep Learning model.");
        _job.update(0, "Training...");

        // main loop
        // Each pass runs one training iteration and stores the resulting model_info back into the
        // model. With epochs == 0 no training task is run and the model_info is kept as is.
        for (; ; ) {
          model.iterations++;
          model.set_model_info(
              mp._epochs == 0
                  ? model.model_info()
                  : H2O.CLOUD.size() > 1 && mp._replicate_training_data
                      ? (mp._single_node_mode
                          ? new DeepLearningTask2(
                                  _job._key,
                                  train,
                                  model.model_info(),
                                  rowFraction(train, mp, model),
                                  model.iterations)
                              .doAll(Key.make(H2O.SELF))
                              .model_info()
                          : // replicated data + single node mode
                          new DeepLearningTask2(
                                  _job._key,
                                  train,
                                  model.model_info(),
                                  rowFraction(train, mp, model),
                                  model.iterations)
                              .doAllNodes()
                              .model_info())
                      : // replicated data + multi-node mode
                      new DeepLearningTask(
                              _job._key,
                              model.model_info(),
                              rowFraction(train, mp, model),
                              model.iterations)
                          .doAll(train)
                          .model_info()); // distributed data (always in multi-node mode)
          if (stop_requested() && !timeout()) break; // cancellation
          // doScoring returning false ends training
          if (!model.doScoring(
              trainScoreFrame, validScoreFrame, _job._key, model.iterations, false))
            break; // finished training (or early stopping or convergence)
          if (timeout()) break; // stop after scoring
        }

        // replace the model with the best model so far (if it's better)
        // Skipped when a stop was requested, when the feature is disabled, when no best-model key
        // exists, or when n-fold cross-validation is in use.
        if (!stop_requested()
            && _parms._overwrite_with_best_model
            && model.actual_best_model_key != null
            && _parms._nfolds == 0) {
          DeepLearningModel best_model = DKV.getGet(model.actual_best_model_key);
          // Only swap when the stored model has a strictly lower loss and the same layer sizes
          if (best_model != null
              && best_model.loss() < model.loss()
              && Arrays.equals(best_model.model_info().units, model.model_info().units)) {
            if (!_parms._quiet_mode)
              Log.info("Setting the model to be the best model so far (based on scoring history).");
            DeepLearningModelInfo mi = best_model.model_info().deep_clone();
            // Don't cheat - count full amount of training samples, since that's the amount of
            // training it took to train (without finding anything better)
            mi.set_processed_global(model.model_info().get_processed_global());
            mi.set_processed_local(model.model_info().get_processed_local());
            model.set_model_info(mi);
            model.update(_job);
            // Re-score with the final argument set to true (it is false for every other
            // doScoring call in this method)
            model.doScoring(trainScoreFrame, validScoreFrame, _job._key, model.iterations, true);
            assert (best_model.loss() == model.loss());
          }
        }
        // store coefficient names for future use
        // possibly change
        model.model_info().data_info().coefNames();
        if (!_parms._quiet_mode) {
          Log.info(
              "==============================================================================================================================================================================");
          if (stop_requested()) {
            Log.info("Deep Learning model training was interrupted.");
          } else {
            Log.info("Finished training the Deep Learning model.");
            Log.info(model);
          }
          Log.info(
              "==============================================================================================================================================================================");
        }
      } finally {
        // Runs on success and on failure: drop helper models, release the lock and remove the
        // temporary best-model entry from DKV
        if (model != null) {
          model.deleteElasticAverageModels();
          model.unlock(_job);
          if (model.actual_best_model_key != null) {
            assert (model.actual_best_model_key != model._key);
            DKV.remove(model.actual_best_model_key);
          }
        }
      }
      return model;
    }
Exemplo n.º 3
0
    /**
     * Train a Deep Learning model, assumes that all members are populated. If checkpoint == null,
     * then start training a new model, otherwise continue from a checkpoint.
     *
     * <p>On the checkpoint path the previous model is loaded from DKV, the new parameters are
     * checked against it, {@code _train}/{@code _valid} are adapted to the previous model's
     * columns, and a new model is created from the previous one. Any
     * {@link H2OIllegalArgumentException} raised during that setup unlocks and deletes the
     * half-built model before being rethrown. After training, {@code Scope.exit} is called with
     * the keys of the exported weight/bias frames (and their Vecs) when
     * {@code _export_weights_and_biases} is set.
     *
     * @throws IllegalArgumentException if the checkpoint key does not resolve to a model
     * @throws H2OIllegalArgumentException if the new run is inconsistent with the checkpointed
     *     model (response/model type, columns, factor levels, predictor count, or epochs)
     */
    public final void buildModel() {
      DeepLearningModel cp = null;
      if (_parms._checkpoint == null) {
        // Fresh start: build a new model and initialize its model_info members
        cp =
            new DeepLearningModel(
                dest(),
                _parms,
                new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this),
                _train,
                _valid,
                nclasses());
        cp.model_info().initializeMembers();
      } else {
        // Restart: fetch the checkpointed model from DKV
        final DeepLearningModel previous = DKV.getGet(_parms._checkpoint);
        if (previous == null) throw new IllegalArgumentException("Checkpoint not found.");
        Log.info("Resuming from checkpoint.");
        _job.update(0, "Resuming from checkpoint");

        if (isClassifier() != previous._output.isClassifier())
          throw new H2OIllegalArgumentException(
              "Response type must be the same as for the checkpointed model.");
        if (isSupervised() != previous._output.isSupervised())
          throw new H2OIllegalArgumentException(
              "Model type must be the same as for the checkpointed model.");

        // check the user-given arguments for consistency
        DeepLearningParameters oldP =
            previous._parms; // sanitized parameters for checkpointed model
        DeepLearningParameters newP = _parms; // user-given parameters for restart

        // Run the sanity pass with clones as the second argument, then compare the two clones
        DeepLearningParameters oldP2 = (DeepLearningParameters) oldP.clone();
        DeepLearningParameters newP2 = (DeepLearningParameters) newP.clone();
        DeepLearningParameters.Sanity.modifyParms(
            oldP, oldP2, nclasses()); // sanitize the checkpointed model's parameters
        DeepLearningParameters.Sanity.modifyParms(
            newP, newP2, nclasses()); // sanitize the user-given parameters
        DeepLearningParameters.Sanity.checkpoint(oldP2, newP2);

        DataInfo dinfo;
        try {
          // PUBDEV-2513: Adapt _train and _valid (in-place) to match the frames that were used for
          // the previous model
          // This can add or remove dummy columns (can happen if the dataset is sparse and datasets
          // have different non-const columns)
          for (String st : previous.adaptTestForTrain(_train, true, false)) Log.warn(st);
          for (String st : previous.adaptTestForTrain(_valid, true, false)) Log.warn(st);
          dinfo = makeDataInfo(_train, _valid, _parms, nclasses());
          DKV.put(dinfo);
          // New model seeded from the previous one; locked right away so that the catch/finally
          // below can unlock it unconditionally
          cp = new DeepLearningModel(dest(), _parms, previous, false, dinfo);
          cp.write_lock(_job);

          if (!Arrays.equals(cp._output._names, previous._output._names)) {
            throw new H2OIllegalArgumentException(
                "The columns of the training data must be the same as for the checkpointed model. Check ignored columns (or disable ignore_const_cols).");
          }
          if (!Arrays.deepEquals(cp._output._domains, previous._output._domains)) {
            throw new H2OIllegalArgumentException(
                "Categorical factor levels of the training data must be the same as for the checkpointed model.");
          }
          if (dinfo.fullN() != previous.model_info().data_info().fullN()) {
            throw new H2OIllegalArgumentException(
                "Total number of predictors is different than for the checkpointed model.");
          }
          // The requested epoch total must exceed what the checkpoint has already trained
          if (_parms._epochs <= previous.epoch_counter) {
            throw new H2OIllegalArgumentException(
                "Total number of epochs must be larger than the number of epochs already trained for the checkpointed model ("
                    + previous.epoch_counter
                    + ").");
          }

          // these are the mutable parameters that are to be used by the model (stored in
          // model_info._parms)
          final DeepLearningParameters actualNewP =
              cp.model_info()
                  .get_params(); // actually used parameters for model building (defaults filled in,
                                 // etc.)
          // The new model must own a distinct parameter object, not share one with the previous
          // model or with the builder
          assert (actualNewP != previous.model_info().get_params());
          assert (actualNewP != newP);
          assert (actualNewP != oldP);
          DeepLearningParameters.Sanity.update(actualNewP, newP, nclasses());

          Log.info(
              "Continuing training after "
                  + String.format("%.3f", previous.epoch_counter)
                  + " epochs from the checkpointed model.");
          cp.update(_job);
        } catch (H2OIllegalArgumentException ex) {
          // Validation failed: release and delete the partially set-up model, then null it so the
          // finally block does not unlock it a second time
          if (cp != null) {
            cp.unlock(_job);
            cp.delete();
            cp = null;
          }
          throw ex;
        } finally {
          // trainModel() takes its own write lock, so release the one acquired above
          if (cp != null) cp.unlock(_job);
        }
      }
      trainModel(cp);

      // clean up, but don't delete weights and biases if user asked for export
      // NOTE(review): the matching Scope.enter() is not in this method; presumably the caller
      // opens the scope. Confirm.
      List<Key> keep = new ArrayList<>();
      try {
        if (_parms._export_weights_and_biases
            && cp._output.weights != null
            && cp._output.biases != null) {
          // Keep each exported frame key together with the keys of all of its Vecs
          for (Key k : Arrays.asList(cp._output.weights)) {
            keep.add(k);
            for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
              keep.add(vk._key);
            }
          }
          for (Key k : Arrays.asList(cp._output.biases)) {
            keep.add(k);
            for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
              keep.add(vk._key);
            }
          }
        }
      } finally {
        // Exit the scope even if collecting the keys failed, passing whatever keys were gathered
        Scope.exit(keep.toArray(new Key[keep.size()]));
      }
    }
Exemplo n.º 4
0
    /**
     * Main worker thread: builds the KMeans model.
     *
     * <p>Initializes the builder, read-locks the input frames, creates and write-locks the model,
     * picks initial centers and then repeats Lloyds passes until {@code isDone} reports
     * completion. A pass in which {@code cleanupBadClusters} returns true is repeated without
     * updating the model. The validation frame, if any, is scored once at the end. The finally
     * block refreshes the model output, unlocks the model and releases the frame locks.
     */
    @Override
    protected void compute2() {
      KMeansModel km = null;
      try {
        init(true);
        // Lock before the error check: the finally block unlocks unconditionally
        _parms.read_lock_frames(KMeans.this);
        if (error_count() > 0) {
          throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(KMeans.this);
        }

        // The model to be built
        km = new KMeansModel(dest(), _parms, new KMeansModel.KMeansOutput(KMeans.this));
        km.delete_and_lock(_key);

        final Vec[] columns = _train.vecs();
        // Means are used to impute NAs; multipliers only exist when standardizing
        final double[] colMeans = _train.means();
        final double[] colMults;
        if (_parms._standardize) {
          colMults = _train.mults();
        } else {
          colMults = null;
        }
        // -1 marks a numeric column; other columns get their DataInfo.imputeCat value
        final int[] catImpute = new int[columns.length];
        for (int c = 0; c < columns.length; c++) {
          if (columns[c].isNumeric()) {
            catImpute[c] = -1;
          } else {
            catImpute[c] = DataInfo.imputeCat(columns[c]);
          }
        }
        km._output._normSub = colMeans;
        km._output._normMul = colMults;

        // Initialize cluster centers and standardize if requested
        double[][] current = initial_centers(km, columns, colMeans, colMults, catImpute);
        if (current == null) {
          return; // Stopped/cancelled during center-finding
        }
        double[][] previous = null;

        // Main KMeans clustering loop; isDone decides when to stop
        km._output._iterations = 0;
        for (; ; ) {
          if (isDone(km, current, previous)) {
            break;
          }
          Lloyds pass =
              new Lloyds(current, colMeans, colMults, catImpute, _isCats, _parms._k, hasWeightCol())
                  .doAll(columns);
          // Pick the max categorical level for cluster center
          max_cats(pass._cMeans, pass._cats, _isCats);

          // Some centers may go dry; at most one cluster is rescued per pass, and a rescue means
          // the pass is repeated before any stats are recorded
          boolean rescued = cleanupBadClusters(pass, columns, current, colMeans, colMults, catImpute);
          if (!rescued) {
            // Compute model stats; update standardized cluster centers
            previous = current;
            current = computeStatsFillModel(pass, km, columns, colMeans, colMults, catImpute);

            km.update(_key); // Update model in K/V store
            update(1); // One unit of work
            if (km._parms._score_each_iteration) {
              Log.info(km._output._model_summary);
            }
          }
        }

        Log.info(km._output._model_summary);

        // At the end: validation scoring (no need to gather scoring history)
        if (_valid != null) {
          km.score(_parms.valid()).delete(); // this appends a ModelMetrics on the validation set
          km._output._validation_metrics = ModelMetrics.getFromDKV(km, _parms.valid());
          km.update(_key); // Update model in K/V store
        }
        done(); // Job done!

      } catch (Throwable t) {
        final Job current = DKV.getGet(_key);
        if (current._state != JobState.CANCELLED) {
          t.printStackTrace();
          failed(t);
          throw t;
        }
        Log.info("Job cancelled by user.");
      } finally {
        updateModelOutput();
        if (km != null) {
          km.unlock(_key);
        }
        _parms.read_unlock_frames(KMeans.this);
      }
      tryComplete();
    }
Exemplo n.º 5
0
 /**
  * Returns {@link #response()} when no {@code _vresponse_key} is set; otherwise returns the
  * cached {@code _vresponse}, loading it from DKV under {@code _vresponse_key} on first use.
  */
 @Override
 public Vec vresponse() {
   if (_vresponse_key == null) {
     return response();
   }
   if (_vresponse != null) {
     return _vresponse;
   }
   // First access: fetch from DKV and remember the result
   return _vresponse = DKV.getGet(_vresponse_key);
 }
Exemplo n.º 6
0
 /**
  * Returns the cached {@code _response}, loading it from DKV under {@code _response_key} the
  * first time it is requested.
  */
 @Override
 public Vec response() {
   if (_response != null) {
     return _response;
   }
   // First access: fetch from DKV and remember the result
   return _response = DKV.getGet(_response_key);
 }
Exemplo n.º 7
0
    /**
     * Main worker for the shared tree builder.
     *
     * <p>Initializes the builder, read-locks the input frames, creates a new model (or re-opens
     * the existing one when continuing from a checkpoint), optionally re-samples the training
     * frame to balance classes, appends the per-class working columns ("Tree_", "Work_", "NIDs_")
     * to {@code _train}, and then delegates to the subclass-specific {@code buildModel()}. The
     * finally block unlocks the model, releases the frame locks and exits the Scope while keeping
     * the model key and its training/validation metrics keys.
     */
    @Override
    protected void compute2() {
      _model = null; // Resulting model!
      try {
        Scope.enter(); // Cleanup temp keys
        init(true); // Do any expensive tests & conversions now
        // Do lock even before checking the errors, since this block is finalized by unlock
        // (not the best solution, but the code is more readable)
        _parms.read_lock_frames(SharedTree.this); // Fetch & read-lock input frames
        if (error_count() > 0)
          throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(SharedTree.this);

        // New Model?  Or continuing from a checkpoint?
        // NOTE(review): existence is tested under _parms._model_id while the model is loaded from
        // _dest; presumably both refer to the same key. Confirm.
        if (_parms._checkpoint && DKV.get(_parms._model_id) != null) {
          _model = DKV.get(_dest).get();
          _model.write_lock(_key); // do not delete previous model; we are extending it
        } else { // New Model
          // Compute the zero-tree error - guessing only the class distribution.
          // MSE is stddev squared when guessing for regression.
          // For classification, guess the largest class.
          // The validation value is NaN when there is no validation frame.
          _model =
              makeModel(
                  _dest,
                  _parms,
                  initial_MSE(_response, _response),
                  _valid == null
                      ? Double.NaN
                      : initial_MSE(_response, _vresponse)); // Make a fresh model
          _model.delete_and_lock(_key); // and clear & write-lock it (smashing any prior)
          _model._output._init_f = _initialPrediction;
        }

        // Compute the response domain; makes for nicer printouts
        String[] domain = _response.domain();
        // Classification must come with a domain, regression must not
        assert (_nclass > 1 && domain != null) || (_nclass == 1 && domain == null);
        if (_nclass == 1) domain = new String[] {"r"}; // For regression, give a name to class 0

        // Compute class distribution, used to for initial guesses and to
        // upsample minority classes (if asked for).
        if (_nclass > 1) { // Classification?

          // Handle imbalanced classes by stratified over/under-sampling.
          // initWorkFrame sets the modeled class distribution, and
          // model.score() corrects the probabilities back using the
          // distribution ratios
          if (_model._output.isClassifier() && _parms._balance_classes) {

            // One sampling factor per level of the last column's domain
            float[] trainSamplingFactors =
                new float
                    [_train
                        .lastVec()
                        .domain()
                        .length]; // leave initialized to 0 -> will be filled up below
            if (_parms._class_sampling_factors != null) {
              // User-supplied factors must match the number of levels exactly
              if (_parms._class_sampling_factors.length != _train.lastVec().domain().length)
                throw new IllegalArgumentException(
                    "class_sampling_factors must have "
                        + _train.lastVec().domain().length
                        + " elements");
              trainSamplingFactors =
                  _parms._class_sampling_factors.clone(); // clone: don't modify the original
            }
            // Result is capped at max_after_balance_size times the original row count
            Frame stratified =
                water.util.MRUtils.sampleFrameStratified(
                    _train,
                    _train.lastVec(),
                    _train.vec(_model._output.weightsName()),
                    trainSamplingFactors,
                    (long) (_parms._max_after_balance_size * _train.numRows()),
                    _parms._seed,
                    true,
                    false);
            // A different frame came back: switch the builder's train/response/weights references
            // over to it
            if (stratified != _train) {
              _train = stratified;
              _response = stratified.vec(_parms._response_column);
              _weights = stratified.vec(_parms._weights_column);
              // Recompute distribution since the input frame was modified
              MRUtils.ClassDist cdmt2 =
                  _weights != null
                      ? new MRUtils.ClassDist(_nclass).doAll(_response, _weights)
                      : new MRUtils.ClassDist(_nclass).doAll(_response);
              _model._output._distribution = cdmt2.dist();
              _model._output._modelClassDist = cdmt2.rel_dist();
            }
          }
          Log.info("Prior class distribution: " + Arrays.toString(_model._output._priorClassDist));
          Log.info("Model class distribution: " + Arrays.toString(_model._output._modelClassDist));
        }

        // Also add to the basic working Frame these sets:
        //   nclass Vecs of current forest results (sum across all trees)
        //   nclass Vecs of working/temp data
        //   nclass Vecs of NIDs, allowing 1 tree per class

        // Current forest values: results of summing the prior M trees
        for (int i = 0; i < _nclass; i++) _train.add("Tree_" + domain[i], _response.makeZero());

        // Initial work columns.  Set-before-use in the algos.
        for (int i = 0; i < _nclass; i++) _train.add("Work_" + domain[i], _response.makeZero());

        // One Tree per class, each tree needs a NIDs.  For empty classes use a -1
        // NID signifying an empty regression tree.
        // (When no distribution has been recorded on the model output, every class starts at 0.)
        for (int i = 0; i < _nclass; i++)
          _train.add(
              "NIDs_" + domain[i],
              _response.makeCon(
                  _model._output._distribution == null
                      ? 0
                      : (_model._output._distribution[i] == 0 ? -1 : 0)));

        // Tag out rows missing the response column
        new ExcludeNAResponse().doAll(_train);

        // Variable importance: squared-error-improvement-per-variable-per-split
        _improvPerVar = new float[_ncols];

        // Sub-class tree-model-builder specific build code
        buildModel();
        done(); // Job done!
      } catch (Throwable t) {
        // A cancelled job is only logged; anything else marks the job failed and is rethrown
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        if (_model != null) _model.unlock(_key);
        _parms.read_unlock_frames(SharedTree.this);
        // Exit the Scope; when a model exists, pass its key and its train/valid metrics keys
        // to Scope.exit
        if (_model == null) Scope.exit();
        else {
          Scope.exit(
              _model._key,
              ModelMetrics.buildKey(_model, _parms.train()),
              ModelMetrics.buildKey(_model, _parms.valid()));
        }
      }
      tryComplete();
    }
    /**
     * Main worker for the CoxPH builder.
     *
     * <p>Read-locks the input frames, initializes the builder, creates and write-locks the model,
     * builds a {@link DataInfo} over {@code _modelBuilderTrain} and then iterates up to
     * {@code iter_max + 1} times. Each iteration runs a {@code CoxPHTask} for the current
     * coefficients and computes the log-likelihood. If it improved, the model statistics are
     * refreshed, the {@code lre} convergence value is checked against {@code lre_min}, and a new
     * step {@code -var_coef * gradient} is computed; otherwise the previous step is halved. The
     * next coefficients are always {@code oldCoef - step}.
     *
     * <p>Iteration also stops as soon as a freshly computed step contains a NaN or infinite
     * component, since every later coefficient vector would be non-finite as well.
     *
     * <p>The finally block refreshes the model output, unlocks the model (if one was created),
     * releases the frame locks, exits the Scope and marks the job done.
     */
    @Override
    protected void compute2() {
      CoxPHModel model = null;
      try {
        Scope.enter();
        _parms.read_lock_frames(CoxPH.this);
        init(true);

        applyScoringFrameSideEffects();

        // The model to be built
        model = new CoxPHModel(dest(), _parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
        model.delete_and_lock(_key);

        applyTrainingFrameSideEffects();

        int nResponses = 1;
        boolean useAllFactorLevels = false;
        final DataInfo dinfo =
            new DataInfo(
                Key.make(),
                _modelBuilderTrain,
                null,
                nResponses,
                useAllFactorLevels,
                DataInfo.TransformType.DEMEAN,
                TransformType.NONE,
                true,
                false,
                false,
                false,
                false,
                false);
        initStats(model, dinfo);

        // Offset columns are part of dinfo.fullN() but do not get a coefficient
        final int n_offsets =
            (model._parms.offset_columns == null) ? 0 : model._parms.offset_columns.length;
        final int n_coef = dinfo.fullN() - n_offsets;
        final double[] step = MemoryManager.malloc8d(n_coef);
        final double[] oldCoef = MemoryManager.malloc8d(n_coef);
        final double[] newCoef = MemoryManager.malloc8d(n_coef);
        Arrays.fill(step, Double.NaN);
        Arrays.fill(oldCoef, Double.NaN);
        for (int j = 0; j < n_coef; ++j) newCoef[j] = model._parms.init;
        // Start below any attainable value so the first iteration always counts as an improvement
        double oldLoglik = -Double.MAX_VALUE;
        final int n_time = (int) (model._output.max_time - model._output.min_time + 1);
        final boolean has_start_column = (model._parms.start_column != null);
        final boolean has_weights_column = (model._parms.weights_column != null);
        iterations:
        for (int i = 0; i <= model._parms.iter_max; ++i) {
          model._output.iter = i;

          final CoxPHTask coxMR =
              new CoxPHTask(
                      self(),
                      dinfo,
                      newCoef,
                      model._output.min_time,
                      n_time,
                      n_offsets,
                      has_start_column,
                      has_weights_column)
                  .doAll(dinfo._adaptedFrame);

          final double newLoglik = calcLoglik(model, coxMR);
          if (newLoglik > oldLoglik) {
            if (i == 0) calcCounts(model, coxMR);

            calcModelStats(model, newCoef, newLoglik);
            calcCumhaz_0(model, coxMR);

            // lre = -log10 of the (relative, unless newLoglik is 0) change in log-likelihood
            if (newLoglik == 0) model._output.lre = -Math.log10(Math.abs(oldLoglik - newLoglik));
            else model._output.lre = -Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik));
            if (model._output.lre >= model._parms.lre_min) break;

            // step = -(var_coef * gradient)
            Arrays.fill(step, 0);
            for (int j = 0; j < n_coef; ++j)
              for (int k = 0; k < n_coef; ++k)
                step[j] -= model._output.var_coef[j][k] * model._output.gradient[k];
            // A non-finite step cannot produce usable coefficients: stop iterating. The label is
            // required because an unlabeled break would only leave this inner scan loop.
            for (int j = 0; j < n_coef; ++j)
              if (Double.isNaN(step[j]) || Double.isInfinite(step[j])) break iterations;

            oldLoglik = newLoglik;
            System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
          } else {
            // No improvement: retry from the last accepted coefficients with half the step
            for (int j = 0; j < n_coef; ++j) step[j] /= 2;
          }

          for (int j = 0; j < n_coef; ++j) newCoef[j] = oldCoef[j] - step[j];
        }

        model.update(_key);
      } catch (Throwable t) {
        // A cancelled job is only logged; anything else marks the job failed and is rethrown
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        updateModelOutput();
        // Release the write lock taken by delete_and_lock above
        if (model != null) model.unlock(_key);
        _parms.read_unlock_frames(CoxPH.this);
        Scope.exit();
        // NOTE(review): done() also runs after failed(t) on the error path; confirm that is
        // intended.
        done(); // Job done!
      }
      tryComplete();
    }