示例#1
0
 /**
  * Driver entry point: runs {@code init(true)}, read-locks the input frames, trains the model via
  * {@code buildModel()} and finally signals completion with {@code tryComplete()}.
  *
  * <p>The input frames are read-unlocked and the scope is exited in the {@code finally} block,
  * whether or not training succeeded.
  */
 @Override
 public void compute2() {
   try {
     Scope.enter();
     // Snapshot of the parameter checksum, compared again once training has finished.
     final long checksumBefore = _parms.checksum();
     init(true);
     // Take the read lock on the input frames; it is released in the finally block.
     _parms.read_lock_frames(_job);
     // Abort if any errors were recorded so far.
     final boolean hasErrors = error_count() > 0;
     if (hasErrors) {
       throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DeepLearning.this);
     }
     buildModel();
     // _parms must not have been changed while the model was being trained.
     final long checksumAfter = _parms.checksum();
     assert (checksumBefore == checksumAfter);
   } finally {
     _parms.read_unlock_frames(_job);
     Scope.exit();
   }
   tryComplete();
 }
示例#2
0
    /**
     * Trains a Deep Learning model; all members of the builder are expected to be populated.
     *
     * <p>When {@code _parms._checkpoint} is {@code null} a new model is created. Otherwise the
     * checkpointed model is fetched from the DKV, the restart is checked against it, and training
     * continues from that model.
     *
     * <p>After training, {@code Scope.exit} is called. If weight/bias export was requested and
     * both key arrays are present, the exported frame keys and the keys of their vecs are passed
     * to it.
     * NOTE(review): no matching {@code Scope.enter()} is visible in this method; presumably the
     * caller has entered the scope — confirm.
     *
     * @throws IllegalArgumentException if the checkpoint key does not resolve to a model
     * @throws H2OIllegalArgumentException if the restart does not match the checkpointed model
     */
    public final void buildModel() {
      DeepLearningModel model = null;
      final boolean startFresh = _parms._checkpoint == null;
      if (startFresh) {
        model =
            new DeepLearningModel(
                dest(),
                _parms,
                new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this),
                _train,
                _valid,
                nclasses());
        model.model_info().initializeMembers();
      } else {
        final DeepLearningModel checkpointed = DKV.getGet(_parms._checkpoint);
        if (checkpointed == null) {
          throw new IllegalArgumentException("Checkpoint not found.");
        }
        Log.info("Resuming from checkpoint.");
        _job.update(0, "Resuming from checkpoint");

        // The restart must be the same kind of problem as the checkpointed model.
        if (isClassifier() != checkpointed._output.isClassifier()) {
          throw new H2OIllegalArgumentException(
              "Response type must be the same as for the checkpointed model.");
        }
        if (isSupervised() != checkpointed._output.isSupervised()) {
          throw new H2OIllegalArgumentException(
              "Model type must be the same as for the checkpointed model.");
        }

        // Consistency check between the checkpoint's parameters and the user-given restart
        // parameters. The check runs on clones of both parameter sets.
        // NOTE(review): presumably modifyParms writes sanitized values into its second
        // argument — confirm against DeepLearningParameters.Sanity.
        final DeepLearningParameters checkpointParms = checkpointed._parms;
        final DeepLearningParameters restartParms = _parms;
        final DeepLearningParameters checkpointParmsCopy =
            (DeepLearningParameters) checkpointParms.clone();
        final DeepLearningParameters restartParmsCopy =
            (DeepLearningParameters) restartParms.clone();
        DeepLearningParameters.Sanity.modifyParms(
            checkpointParms, checkpointParmsCopy, nclasses());
        DeepLearningParameters.Sanity.modifyParms(restartParms, restartParmsCopy, nclasses());
        DeepLearningParameters.Sanity.checkpoint(checkpointParmsCopy, restartParmsCopy);

        try {
          // PUBDEV-2513: adapt _train and _valid (in place) to the frames used for the previous
          // model. This can add or remove dummy columns (possible when the dataset is sparse and
          // the datasets have different non-constant columns).
          for (String warning : checkpointed.adaptTestForTrain(_train, true, false)) {
            Log.warn(warning);
          }
          for (String warning : checkpointed.adaptTestForTrain(_valid, true, false)) {
            Log.warn(warning);
          }
          final DataInfo dinfo = makeDataInfo(_train, _valid, _parms, nclasses());
          DKV.put(dinfo);
          model = new DeepLearningModel(dest(), _parms, checkpointed, false, dinfo);
          model.write_lock(_job);

          final boolean sameColumns =
              Arrays.equals(model._output._names, checkpointed._output._names);
          if (!sameColumns) {
            throw new H2OIllegalArgumentException(
                "The columns of the training data must be the same as for the checkpointed model. Check ignored columns (or disable ignore_const_cols).");
          }
          final boolean sameDomains =
              Arrays.deepEquals(model._output._domains, checkpointed._output._domains);
          if (!sameDomains) {
            throw new H2OIllegalArgumentException(
                "Categorical factor levels of the training data must be the same as for the checkpointed model.");
          }
          if (dinfo.fullN() != checkpointed.model_info().data_info().fullN()) {
            throw new H2OIllegalArgumentException(
                "Total number of predictors is different than for the checkpointed model.");
          }
          if (_parms._epochs <= checkpointed.epoch_counter) {
            throw new H2OIllegalArgumentException(
                "Total number of epochs must be larger than the number of epochs already trained for the checkpointed model ("
                    + checkpointed.epoch_counter
                    + ").");
          }

          // Mutable parameters stored in model_info; these are the ones actually used for model
          // building. They must be a distinct object from every other parameter set in play.
          final DeepLearningParameters effectiveParms = model.model_info().get_params();
          assert (effectiveParms != checkpointed.model_info().get_params());
          assert (effectiveParms != restartParms);
          assert (effectiveParms != checkpointParms);
          DeepLearningParameters.Sanity.update(effectiveParms, restartParms, nclasses());

          Log.info(
              "Continuing training after "
                  + String.format("%.3f", checkpointed.epoch_counter)
                  + " epochs from the checkpointed model.");
          model.update(_job);
        } catch (H2OIllegalArgumentException ex) {
          // Incompatible restart: drop the half-initialised model. Nulling the reference keeps
          // the finally block from unlocking it a second time.
          if (model != null) {
            model.unlock(_job);
            model.delete();
            model = null;
          }
          throw ex;
        } finally {
          if (model != null) {
            model.unlock(_job);
          }
        }
      }
      trainModel(model);

      // Clean up, but keep weights and biases if the user asked for them to be exported.
      final List<Key> keep = new ArrayList<>();
      try {
        final boolean exportRequested =
            _parms._export_weights_and_biases
                && model._output.weights != null
                && model._output.biases != null;
        if (exportRequested) {
          keepFrameAndVecKeys(Arrays.asList(model._output.weights), keep);
          keepFrameAndVecKeys(Arrays.asList(model._output.biases), keep);
        }
      } finally {
        Scope.exit(keep.toArray(new Key[keep.size()]));
      }
    }

    /** Appends every frame key, followed by the keys of that frame's vecs, to {@code keep}. */
    private void keepFrameAndVecKeys(Iterable<? extends Key> frameKeys, List<Key> keep) {
      for (Key frameKey : frameKeys) {
        keep.add(frameKey);
        for (Vec vec : ((Frame) DKV.getGet(frameKey)).vecs()) {
          keep.add(vec._key);
        }
      }
    }
示例#3
0
    /**
     * Shared driver for the tree model builders: prepares the model and the working frame, then
     * delegates to the algorithm-specific {@code buildModel()}.
     *
     * <p>In order: runs {@code init(true)}, read-locks the input frames, throws if
     * {@code error_count()} is positive, reuses the existing model (checkpoint) or creates and
     * write-locks a fresh one, optionally re-samples the training frame to balance classes,
     * appends the {@code Tree_}, {@code Work_} and {@code NIDs_} columns to the training frame,
     * and builds. The {@code finally} block unlocks the model and the input frames and exits the
     * scope, passing the model key and its two metrics keys to {@code Scope.exit}.
     *
     * <p>A failure is swallowed only when the job's state is {@code CANCELLED}; any other failure
     * is reported through {@code failed(t)} and rethrown.
     */
    @Override
    protected void compute2() {
      _model = null; // Resulting model!
      try {
        Scope.enter(); // Cleanup temp keys
        init(true); // Do any expensive tests & conversions now
        // Do lock even before checking the errors, since this block is finalized by unlock
        // (not the best solution, but the code is more readable)
        _parms.read_lock_frames(SharedTree.this); // Fetch & read-lock input frames
        if (error_count() > 0)
          throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(SharedTree.this);

        // New Model?  Or continuing from a checkpoint?
        if (_parms._checkpoint && DKV.get(_parms._model_id) != null) {
          _model = DKV.get(_dest).get();
          _model.write_lock(_key); // do not delete previous model; we are extending it
        } else { // New Model
          // Compute the zero-tree error - guessing only the class distribution.
          // MSE is stddev squared when guessing for regression.
          // For classification, guess the largest class.
          _model =
              makeModel(
                  _dest,
                  _parms,
                  initial_MSE(_response, _response),
                  _valid == null
                      ? Double.NaN
                      : initial_MSE(_response, _vresponse)); // Make a fresh model
          _model.delete_and_lock(_key); // and clear & write-lock it (smashing any prior)
          _model._output._init_f = _initialPrediction;
        }

        // Compute the response domain; makes for nicer printouts
        String[] domain = _response.domain();
        assert (_nclass > 1 && domain != null) || (_nclass == 1 && domain == null);
        if (_nclass == 1) domain = new String[] {"r"}; // For regression, give a name to class 0

        // Compute class distribution, used for initial guesses and to
        // upsample minority classes (if asked for).
        if (_nclass > 1) { // Classification?

          // Handle imbalanced classes by stratified over/under-sampling.
          // initWorkFrame sets the modeled class distribution, and
          // model.score() corrects the probabilities back using the
          // distribution ratios
          if (_model._output.isClassifier() && _parms._balance_classes) {

            float[] trainSamplingFactors =
                new float
                    [_train
                        .lastVec()
                        .domain()
                        .length]; // leave initialized to 0 -> will be filled up below
            if (_parms._class_sampling_factors != null) {
              if (_parms._class_sampling_factors.length != _train.lastVec().domain().length)
                throw new IllegalArgumentException(
                    "class_sampling_factors must have "
                        + _train.lastVec().domain().length
                        + " elements");
              trainSamplingFactors =
                  _parms._class_sampling_factors.clone(); // clone: don't modify the original
            }
            Frame stratified =
                water.util.MRUtils.sampleFrameStratified(
                    _train,
                    _train.lastVec(),
                    _train.vec(_model._output.weightsName()),
                    trainSamplingFactors,
                    (long) (_parms._max_after_balance_size * _train.numRows()),
                    _parms._seed,
                    true,
                    false);
            // A different frame instance means the data was re-sampled: switch the builder over
            // to the new frame and its response/weights vecs.
            if (stratified != _train) {
              _train = stratified;
              _response = stratified.vec(_parms._response_column);
              _weights = stratified.vec(_parms._weights_column);
              // Recompute distribution since the input frame was modified
              MRUtils.ClassDist cdmt2 =
                  _weights != null
                      ? new MRUtils.ClassDist(_nclass).doAll(_response, _weights)
                      : new MRUtils.ClassDist(_nclass).doAll(_response);
              _model._output._distribution = cdmt2.dist();
              _model._output._modelClassDist = cdmt2.rel_dist();
            }
          }
          Log.info("Prior class distribution: " + Arrays.toString(_model._output._priorClassDist));
          Log.info("Model class distribution: " + Arrays.toString(_model._output._modelClassDist));
        }

        // Also add to the basic working Frame these sets:
        //   nclass Vecs of current forest results (sum across all trees)
        //   nclass Vecs of working/temp data
        //   nclass Vecs of NIDs, allowing 1 tree per class

        // Current forest values: results of summing the prior M trees
        for (int i = 0; i < _nclass; i++) _train.add("Tree_" + domain[i], _response.makeZero());

        // Initial work columns.  Set-before-use in the algos.
        for (int i = 0; i < _nclass; i++) _train.add("Work_" + domain[i], _response.makeZero());

        // One Tree per class, each tree needs a NIDs.  For empty classes use a -1
        // NID signifying an empty regression tree.
        for (int i = 0; i < _nclass; i++)
          _train.add(
              "NIDs_" + domain[i],
              _response.makeCon(
                  _model._output._distribution == null
                      ? 0
                      : (_model._output._distribution[i] == 0 ? -1 : 0)));

        // Tag out rows missing the response column
        new ExcludeNAResponse().doAll(_train);

        // Variable importance: squared-error-improvement-per-variable-per-split
        _improvPerVar = new float[_ncols];

        // Sub-class tree-model-builder specific build code
        buildModel();
        done(); // Job done!
      } catch (Throwable t) {
        Job thisJob = DKV.getGet(_key);
        // Null guard: if the job entry cannot be fetched, treat the job as not cancelled so that
        // the original failure is reported instead of being replaced by a NullPointerException.
        if (thisJob != null && thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        if (_model != null) _model.unlock(_key);
        _parms.read_unlock_frames(SharedTree.this);
        // With a model present, its key and both metrics keys are handed to Scope.exit.
        if (_model == null) Scope.exit();
        else {
          Scope.exit(
              _model._key,
              ModelMetrics.buildKey(_model, _parms.train()),
              ModelMetrics.buildKey(_model, _parms.valid()));
        }
      }
      tryComplete();
    }
    /**
     * Driver that builds the {@code CoxPHModel} and stores it under the job's destination key.
     *
     * <p>After locking the input frames and the new model, coefficients are refined iteratively
     * for {@code i = 0 .. iter_max}. Each pass runs a {@code CoxPHTask} over the adapted frame
     * with the candidate coefficients {@code newCoef} and computes the log-likelihood:
     *
     * <ul>
     *   <li>If it improved, the model statistics are refreshed and {@code lre} is recomputed.
     *       Iteration stops once {@code lre >= lre_min}. Otherwise a new step is computed as
     *       {@code step = -(var_coef * gradient)}, and iteration stops if any component of that
     *       step is NaN or infinite.
     *   <li>If it did not improve, the previous step is halved.
     * </ul>
     *
     * The next candidate is always {@code oldCoef - step}.
     *
     * <p>A failure is swallowed only when the job's state is {@code CANCELLED}; any other failure
     * is reported through {@code failed(t)} and rethrown. The {@code finally} block releases the
     * model's write lock and the frame read locks and exits the scope.
     */
    @Override
    protected void compute2() {
      CoxPHModel model = null;
      try {
        Scope.enter();
        _parms.read_lock_frames(CoxPH.this);
        init(true);

        applyScoringFrameSideEffects();

        // The model to be built
        model = new CoxPHModel(dest(), _parms, new CoxPHModel.CoxPHOutput(CoxPH.this));
        model.delete_and_lock(_key);

        applyTrainingFrameSideEffects();

        int nResponses = 1;
        boolean useAllFactorLevels = false;
        final DataInfo dinfo =
            new DataInfo(
                Key.make(),
                _modelBuilderTrain,
                null,
                nResponses,
                useAllFactorLevels,
                DataInfo.TransformType.DEMEAN,
                TransformType.NONE,
                true,
                false,
                false,
                false,
                false,
                false);
        initStats(model, dinfo);

        final int n_offsets =
            (model._parms.offset_columns == null) ? 0 : model._parms.offset_columns.length;
        // Offset columns are counted in dinfo.fullN() but get no coefficient.
        final int n_coef = dinfo.fullN() - n_offsets;
        final double[] step = MemoryManager.malloc8d(n_coef);
        final double[] oldCoef = MemoryManager.malloc8d(n_coef);
        final double[] newCoef = MemoryManager.malloc8d(n_coef);
        Arrays.fill(step, Double.NaN);
        Arrays.fill(oldCoef, Double.NaN);
        for (int j = 0; j < n_coef; ++j) newCoef[j] = model._parms.init;
        double oldLoglik = -Double.MAX_VALUE;
        final int n_time = (int) (model._output.max_time - model._output.min_time + 1);
        final boolean has_start_column = (model._parms.start_column != null);
        final boolean has_weights_column = (model._parms.weights_column != null);
        for (int i = 0; i <= model._parms.iter_max; ++i) {
          model._output.iter = i;

          final CoxPHTask coxMR =
              new CoxPHTask(
                      self(),
                      dinfo,
                      newCoef,
                      model._output.min_time,
                      n_time,
                      n_offsets,
                      has_start_column,
                      has_weights_column)
                  .doAll(dinfo._adaptedFrame);

          final double newLoglik = calcLoglik(model, coxMR);
          if (newLoglik > oldLoglik) {
            if (i == 0) calcCounts(model, coxMR);

            calcModelStats(model, newCoef, newLoglik);
            calcCumhaz_0(model, coxMR);

            if (newLoglik == 0) model._output.lre = -Math.log10(Math.abs(oldLoglik - newLoglik));
            else model._output.lre = -Math.log10(Math.abs((oldLoglik - newLoglik) / newLoglik));
            if (model._output.lre >= model._parms.lre_min) break;

            // step = -(var_coef * gradient)
            Arrays.fill(step, 0);
            for (int j = 0; j < n_coef; ++j)
              for (int k = 0; k < n_coef; ++k)
                step[j] -= model._output.var_coef[j][k] * model._output.gradient[k];

            // Stop iterating when the step is not finite. The original `break` sat inside the
            // scanning loop and therefore only left that loop, so the check had no effect.
            boolean stepIsFinite = true;
            for (int j = 0; j < n_coef; ++j) {
              if (Double.isNaN(step[j]) || Double.isInfinite(step[j])) {
                stepIsFinite = false;
                break;
              }
            }
            if (!stepIsFinite) break;

            oldLoglik = newLoglik;
            System.arraycopy(newCoef, 0, oldCoef, 0, oldCoef.length);
          } else {
            // No improvement: retry from oldCoef with half the previous step.
            for (int j = 0; j < n_coef; ++j) step[j] /= 2;
          }

          for (int j = 0; j < n_coef; ++j) newCoef[j] = oldCoef[j] - step[j];
        }

        model.update(_key);
      } catch (Throwable t) {
        Job thisJob = DKV.getGet(_key);
        // Null guard: if the job entry cannot be fetched, treat the job as not cancelled so that
        // the original failure is reported instead of being replaced by a NullPointerException.
        if (thisJob != null && thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        updateModelOutput();
        // Release the write lock taken by delete_and_lock(_key); it was previously never
        // released on any path.
        if (model != null) model.unlock(_key);
        _parms.read_unlock_frames(CoxPH.this);
        Scope.exit();
        // NOTE(review): done() also runs after failed(t) on the error path — confirm intended.
        done(); // Job done!
      }
      tryComplete();
    }