Example #1
    // Handle the case where some centers go dry.  Rescue only one cluster
    // per iteration (because we only tracked the single worst row).
    boolean cleanupBadClusters(
        Lloyds task,
        final Vec[] vecs,
        final double[][] centers,
        final double[] means,
        final double[] mults,
        final int[] modes) {
      // Find any bad clusters
      int clu;
      for (clu = 0; clu < _parms._k; clu++) if (task._size[clu] == 0) break;
      if (clu == _parms._k) return false; // No bad clusters

      long row = task._worst_row;
      Log.warn("KMeans: Re-initializing cluster " + clu + " to row " + row);
      data(centers[clu] = task._cMeans[clu], vecs, row, means, mults, modes);
      // FIXME: PUBDEV-871 Some other cluster had its membership count reduced
      // by one! (which one?)
      task._size[clu] = 1;

      // Find any MORE bad clusters; we only fixed the first one
      for (clu = 0; clu < _parms._k; clu++) if (task._size[clu] == 0) break;
      if (clu == _parms._k) return false; // No MORE bad clusters

      // If we see two or more bad clusters, just re-run Lloyds to find the
      // next-worst row.  We don't count this as an iteration, because
      // we're not really adjusting the centers; we're trying to get
      // some centers *at all*.
      Log.warn("KMeans: Re-running Lloyds to re-init another cluster");
      if (_reinit_attempts++ < _parms._k) {
        return true; // Rerun Lloyds, and assign points to centroids
      } else {
        _reinit_attempts = 0;
        return false;
      }
    }
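A self-contained sketch of the same rescue idea outside H2O (everything below is illustrative, not H2O API): after each assignment pass, one empty cluster is re-seeded to the single worst-fit row and the pass is re-run, giving up after k attempts just like _reinit_attempts above.

import java.util.Random;

public class EmptyClusterRescue {
  public static void main(String[] args) {
    Random rng = new Random(42);
    double[] rows = new double[200];
    for (int i = 0; i < rows.length; i++) rows[i] = rng.nextGaussian();
    double[] centers = {0.0, 0.05, 1e6}; // the third center starts "dry"
    int k = centers.length, attempts = 0;
    while (true) {
      int[] size = new int[k];
      int worstRow = -1;
      double worstDist = -1;
      // Assignment pass: count members per center and track the worst-fit row.
      for (int i = 0; i < rows.length; i++) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < k; c++) {
          double d = Math.abs(rows[i] - centers[c]);
          if (d < bestDist) { bestDist = d; best = c; }
        }
        size[best]++;
        if (bestDist > worstDist) { worstDist = bestDist; worstRow = i; }
      }
      // Find a dry cluster; stop when none remain or we give up.
      int dry = -1;
      for (int c = 0; c < k; c++) if (size[c] == 0) { dry = c; break; }
      if (dry < 0 || attempts++ >= k) break;
      centers[dry] = rows[worstRow]; // rescue: re-seed to the worst row
      System.out.println("Re-seeded cluster " + dry + " to row " + worstRow);
    }
  }
}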
Example #2
    public void handle(
        String target,
        Request baseRequest,
        HttpServletRequest request,
        HttpServletResponse response)
        throws IOException, ServletException {
      if (!H2O.ARGS.ldap_login) {
        return;
      }

      String loginName = request.getUserPrincipal().getName();
      if (!loginName.equals(H2O.ARGS.user_name)) {
        Log.warn(
            "Login name ("
                + loginName
                + ") does not match cluster owner name ("
                + H2O.ARGS.user_name
                + ")");
        sendResponseError(
            response,
            HttpServletResponse.SC_UNAUTHORIZED,
            "Login name does not match cluster owner name");
        baseRequest.setHandled(true);
      }
    }
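A minimal sketch of how such a gate handler is typically wired up, assuming Jetty 9 on the classpath (the class, route, and port below are made up, not H2O code). Note the pattern from the H2O handler above: the request is marked handled only when it is rejected, so accepted requests fall through to the next handler in the chain.

import java.io.IOException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.handler.AbstractHandler;
import org.eclipse.jetty.server.handler.DefaultHandler;
import org.eclipse.jetty.server.handler.HandlerList;

public class GateHandlerExample extends AbstractHandler {
  @Override
  public void handle(String target, Request baseRequest,
                     HttpServletRequest request, HttpServletResponse response)
      throws IOException, ServletException {
    if ("/forbidden".equals(target)) {
      response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "rejected by gate");
      baseRequest.setHandled(true); // stop the chain, as in the H2O handler above
    } // otherwise leave the request unhandled so later handlers can serve it
  }

  public static void main(String[] args) throws Exception {
    Server server = new Server(8080);
    HandlerList chain = new HandlerList();
    chain.setHandlers(new Handler[] {new GateHandlerExample(), new DefaultHandler()});
    server.setHandler(chain);
    server.start();
    server.join();
  }
}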
Example #3
File: Job.java  Project: chouclee/h2o
 @Override
 protected void init() {
   super.init();
   // Reject request if classification is required and response column is float
   // Argument a4class = find("classification"); // get UI control
   // String p4class = input("classification");  // get value from HTTP requests
   // if there is a UI control and the classification field was passed
   // ROLLBACK: a4class != null ? p4class != null
   //   : /* we are not in UI, so expect that the parameter is specified correctly */ true
   final boolean classificationFieldSpecified = true;
   // can happen if a client sends a request that does not specify the classification parameter
   if (!classificationFieldSpecified) {
     classification = response.isEnum();
     Log.warn(
         "Classification field is not specified - deriving it from the response! The classification field is set to "
             + classification);
   } else {
     if (classification && response.isFloat())
       throw new H2OIllegalArgumentException(
           find("classification"), "Requested classification on float column!");
     if (!classification && response.isEnum())
       throw new H2OIllegalArgumentException(
           find("classification"), "Requested regression on enum column!");
   }
 }
Example #4
 /**
  * Helper to create the DataInfo object from training/validation frames and the DL parameters
  *
  * @param train Training frame
  * @param valid Validation frame
  * @param parms Model parameters
  * @param nClasses Number of response levels (1: regression, >=2: classification)
  * @return DataInfo
  */
 static DataInfo makeDataInfo(
     Frame train, Frame valid, DeepLearningParameters parms, int nClasses) {
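   // Probe the link function at an arbitrary value: link(x) == x at this x
   // (almost surely) only for the identity link, in which case the response
   // of a regression model can safely be standardized below.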
   double x = 0.782347234;
   boolean identityLink = new Distribution(parms._distribution, parms._tweedie_power).link(x) == x;
   DataInfo dinfo =
       new DataInfo(
           train,
           valid,
           parms._autoencoder ? 0 : 1, // nResponses
           parms._autoencoder
               || parms._use_all_factor_levels, // use all FactorLevels for auto-encoder
           parms._standardize
               ? (parms._autoencoder
                   ? DataInfo.TransformType.NORMALIZE
                   : parms._sparse
                       ? DataInfo.TransformType.DESCALE
                       : DataInfo.TransformType.STANDARDIZE)
               : DataInfo.TransformType.NONE, // transform predictors
           !parms._standardize || train.lastVec().isCategorical()
               ? DataInfo.TransformType.NONE
               : identityLink
                   ? DataInfo.TransformType.STANDARDIZE
                   : DataInfo.TransformType
                       .NONE, // transform response for regression with identity link
           parms._missing_values_handling
               == DeepLearningParameters.MissingValuesHandling.Skip, // whether to skip missing
           false, // do not replace NAs in numeric cols with mean
           true, // always add a bucket for missing values
           parms._weights_column != null, // observation weights
           parms._offset_column != null,
           parms._fold_column != null);
   // Checks and adjustments:
   // 1) observation weights (adjust mean/sigmas for predictors and response)
   // 2) NAs (check that there's enough rows left)
   GLMTask.YMUTask ymt =
       new GLMTask.YMUTask(
               dinfo,
               nClasses,
               true,
               !parms._autoencoder && nClasses == 1,
               false,
               !parms._autoencoder)
           .doAll(dinfo._adaptedFrame);
   if (ymt._wsum == 0
       && parms._missing_values_handling == DeepLearningParameters.MissingValuesHandling.Skip)
     throw new H2OIllegalArgumentException(
         "No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
   if (parms._weights_column != null && parms._offset_column != null) {
     Log.warn(
         "Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms.");
   }
   if (parms._weights_column != null
       && parms._offset_column == null /*FIXME: offset not yet implemented*/) {
     dinfo.updateWeightedSigmaAndMean(ymt._basicStats.sigma(), ymt._basicStats.mean());
     if (nClasses == 1)
       dinfo.updateWeightedSigmaAndMeanForResponse(
           ymt._basicStatsResponse.sigma(), ymt._basicStatsResponse.mean());
   }
   return dinfo;
 }
Example #5
  /**
   * Score a frame with the given model and return just the metrics.
   *
   * <p>NOTE: ModelMetrics are now always being created by model.score. . .
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelMetricsListSchemaV3 score(int version, ModelMetricsListSchemaV3 s) {
    // parameters checking:
    if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model);
    if (null == DKV.get(s.model.name))
      throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);

    if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame);
    if (null == DKV.get(s.frame.name))
      throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);

    ModelMetricsList parms = s.createAndFillImpl();
    parms
        ._model
        .score(parms._frame, parms._predictions_name)
        .remove(); // throw away predictions, keep metrics as a side-effect
    ModelMetricsListSchemaV3 mm = this.fetch(version, s);

    // TODO: for now only binary predictors write a ModelMetrics object.
    // For the others, construct one here to return the predictions frame.
    if (null == mm) mm = new ModelMetricsListSchemaV3();

    if (null == mm.model_metrics || 0 == mm.model_metrics.length) {
      Log.warn(
          "Score() did not return a ModelMetrics for model: " + s.model + " on frame: " + s.frame);
    }

    return mm;
  }
Example #6
  private static void sendErrorResponse(HttpServletResponse response, Exception e, String uri) {
    if (e instanceof H2OFailException) {
      H2OFailException ee = (H2OFailException) e;
      H2OError error = ee.toH2OError(uri);

      Log.fatal("Caught exception (fatal to the cluster): " + error.toString());
      throw (H2O.fail(error.toString()));
    } else if (e instanceof H2OAbstractRuntimeException) {
      H2OAbstractRuntimeException ee = (H2OAbstractRuntimeException) e;
      H2OError error = ee.toH2OError(uri);

      Log.warn("Caught exception: " + error.toString());
      setResponseStatus(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR);

      // Note: don't use Schema.schema(version, error) because we have to work at bootstrap:
      try {
        @SuppressWarnings("unchecked")
        String s = new H2OErrorV3().fillFromImpl(error).toJsonString();
        response.getWriter().write(s);
      } catch (Exception ignore) {
      }
    } else { // make sure that no Exception is ever thrown out from the request
      H2OError error = new H2OError(e, uri);

      // some special cases for which we return 400 because it's likely a problem with the client
      // request:
      if (e instanceof IllegalArgumentException)
        error._http_status = HttpResponseStatus.BAD_REQUEST.getCode();
      else if (e instanceof FileNotFoundException)
        error._http_status = HttpResponseStatus.BAD_REQUEST.getCode();
      else if (e instanceof MalformedURLException)
        error._http_status = HttpResponseStatus.BAD_REQUEST.getCode();
      setResponseStatus(response, error._http_status);

      Log.warn("Caught exception: " + error.toString());

      // Note: don't use Schema.schema(version, error) because we have to work at bootstrap:
      try {
        @SuppressWarnings("unchecked")
        String s = new H2OErrorV3().fillFromImpl(error).toJsonString();
        response.getWriter().write(s);
      } catch (Exception ignore) {
      }
    }
  }
Example #7
 @Override
 protected boolean toJavaCheckTooBig() {
   if (beta() != null && beta().length > 10000) {
     Log.warn(
         "toJavaCheckTooBig must be overridden for this model type to render it in the browser");
     return true;
   }
   return false;
 }
Example #8
 private void fillHelp() {
   this.help = new IcedHashMapGeneric.IcedHashMapStringString();
   try {
     Field[] dest_fields = Weaver.getWovenFields(this.getClass());
     for (Field f : dest_fields) {
       fillHelp(f);
     }
   } catch (Exception e) {
     Log.warn(e);
   }
 }
Example #9
 public static InputStream openStream(Key k, ProgressMonitor pmon) throws IOException {
   H2OHdfsInputStream res = null;
   Path p = new Path(k.toString());
   try {
     res = new H2OHdfsInputStream(p, 0, pmon);
   } catch (IOException e) {
     try {
       Thread.sleep(1000);
     } catch (Exception ex) {
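        // interruption of the back-off sleep is ignored; fall through and retry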
     }
     Log.warn("Error while opening HDFS key " + k.toString() + ", will wait and retry.");
     res = new H2OHdfsInputStream(p, 0, pmon);
   }
   return res;
 }
Example #10
 public ValidationMessage(
     ModelBuilder.ValidationMessage.MessageType message_type,
     String field_name,
     String message) {
   this.message_type = message_type;
   this.field_name = field_name;
   this.message = message;
   switch (message_type) {
     case INFO:
       Log.info(field_name + ": " + message);
       break;
     case WARN:
       Log.warn(field_name + ": " + message);
       break;
     case ERROR:
       Log.err(field_name + ": " + message);
       break;
   }
 }
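For illustration, a hypothetical call site (the field name and message text below are made up, not H2O code):

  // Log and record a WARN-level validation message for a model parameter.
  new ValidationMessage(ModelBuilder.ValidationMessage.MessageType.WARN,
      "ntrees", "Requested number of trees is very large; training may be slow.");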
Example #11
 @Override
 protected Frame rebalance(final Frame original_fr, boolean local, final String name) {
   if (original_fr == null) return null;
   if (_parms._force_load_balance) {
     int original_chunks = original_fr.anyVec().nChunks();
     _job.update(0, "Load balancing " + name.substring(name.length() - 5) + " data...");
     int chunks = desiredChunks(original_fr, local);
     if (!_parms._reproducible) {
       if (original_chunks >= chunks) {
         if (!_parms._quiet_mode)
           Log.info(
               "Dataset already contains " + original_chunks + " chunks. No need to rebalance.");
         return original_fr;
       }
     } else { // reproducible, set chunks to 1
       assert chunks == 1;
       if (!_parms._quiet_mode)
         Log.warn("Reproducibility enforced - using only 1 thread - can be slow.");
       if (original_chunks == 1) return original_fr;
     }
     if (!_parms._quiet_mode)
       Log.info(
           "Rebalancing "
               + name.substring(name.length() - 5)
               + " dataset into "
               + chunks
               + " chunks.");
     Key newKey = Key.make(name + ".chks" + chunks);
     RebalanceDataSet rb = new RebalanceDataSet(original_fr, newKey, chunks);
     H2O.submitTask(rb).join();
     Frame rebalanced_fr = DKV.get(newKey).get();
     Scope.track(rebalanced_fr);
     return rebalanced_fr;
   }
   return original_fr;
 }
Example #12
File: DRF.java  Project: rohit2412/h2o
 @SuppressWarnings("unused")
 @Override
 protected void init() {
   super.init();
   // Initialize local variables
    // classification: mtry = sqrt(_ncols); regression: mtry = _ncols / 3
    _mtry =
        (mtries == -1)
            ? (classification ? Math.max((int) Math.sqrt(_ncols), 1) : Math.max(_ncols / 3, 1))
            : mtries;
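    // e.g. _ncols = 100 with mtries == -1 gives mtry = floor(sqrt(100)) = 10
    // for classification and 100 / 3 = 33 for regression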
    if (!(1 <= _mtry && _mtry <= _ncols))
      throw new IllegalArgumentException(
          "Computed mtry should be in the interval [1, #cols] but it is " + _mtry);
    if (!(0.0 < sample_rate && sample_rate <= 1.0))
      throw new IllegalArgumentException(
          "Sample rate should be in the interval (0, 1] but it is " + sample_rate);
   if (DEBUG_DETERMINISTIC && seed == -1) _seed = 0x1321e74a0192470cL; // fixed version of seed
   else if (seed == -1) _seed = _seedGenerator.nextLong();
   else _seed = seed;
    if (sample_rate == 1f && validation != null)
      Log.warn(
          Sys.DRF__,
          "Sample rate is 100%: there is no out-of-bag data to perform OOB validation; the specified validation dataset will be used instead.");
 }
Example #13
 // Convert a filename string to a Key
 private static Key str2Key_impl(String s) {
   String key = s;
   byte[] kb = new byte[(key.length() - 1) / 2];
   int i = 0, j = 0;
   if ((key.length() > 2)
       && (key.charAt(0) == '%')
       && (key.charAt(1) >= '0')
       && (key.charAt(1) <= '9')) {
      // Decode hex digit pairs until the terminating '%'
     for (i = 1; i < key.length(); i += 2) {
       if (key.charAt(i) == '%') break;
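        // Map a hex digit to its value: '0'..'9' decode directly; for 'A'..'F'
        // the subtraction below yields a value > 9, and adding ('0' + 10 - 'A')
        // shifts it into the range 10..15.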
       char b0 = (char) (key.charAt(i) - '0');
       if (b0 > 9) b0 += '0' + 10 - 'A';
       char b1 = (char) (key.charAt(i + 1) - '0');
       if (b1 > 9) b1 += '0' + 10 - 'A';
        kb[j++] = (byte) ((b0 << 4) | b1); // decoded hex byte
     }
     i++; // Skip the trailing '%'
   }
   // a normal key - ASCII with special characters encoded after % sign
   for (; i < key.length(); ++i) {
     byte b = (byte) key.charAt(i);
     if (b == '%') {
       switch (key.charAt(++i)) {
         case '%':
           b = '%';
           break;
         case 'c':
           b = ':';
           break;
         case 'd':
           b = '.';
           break;
         case 'g':
           b = '>';
           break;
         case 'l':
           b = '<';
           break;
         case 'q':
           b = '"';
           break;
         case 's':
           b = '/';
           break;
         case 'b':
           b = '\\';
           break;
         case 'z':
           b = '\0';
           break;
         default:
           Log.warn("Invalid format of filename " + s + " at index " + i);
       }
     }
     if (j >= kb.length) kb = Arrays.copyOf(kb, Math.max(2, j * 2));
     kb[j++] = b;
   }
   // now in kb we have the key name
   return Key.make(Arrays.copyOf(kb, j));
 }
Example #14
    /**
     * Train a Deep Learning neural net model
     *
     * @param model Input model (e.g., from initModel(), or from a previous training run)
     * @return Trained model
     */
    public final DeepLearningModel trainModel(DeepLearningModel model) {
      Frame validScoreFrame = null;
      Frame train, trainScoreFrame;
      try {
        //      if (checkpoint == null && !quiet_mode) logStart(); //if checkpoint is given, some
        // Job's params might be uninitialized (but the restarted model's parameters are correct)
        if (model == null) {
          model = DKV.get(dest()).get();
        }
        Log.info(
            "Model category: "
                + (_parms._autoencoder
                    ? "Auto-Encoder"
                    : isClassifier() ? "Classification" : "Regression"));
        final long model_size = model.model_info().size();
        Log.info(
            "Number of model parameters (weights/biases): " + String.format("%,d", model_size));
        model.write_lock(_job);
        _job.update(0, "Setting up training data...");
        final DeepLearningParameters mp = model.model_info().get_params();

        // temporary frames of the same "name" as the orig _train/_valid (asking the parameter's
        // Key, not the actual frame)
        // Note: don't put into DKV or they would overwrite the _train/_valid frames!
        Frame tra_fr = new Frame(mp._train, _train.names(), _train.vecs());
        Frame val_fr = _valid != null ? new Frame(mp._valid, _valid.names(), _valid.vecs()) : null;

        train = tra_fr;
        if (model._output.isClassifier() && mp._balance_classes) {
          _job.update(0, "Balancing class distribution of training data...");
          // leave initialized to 0 -> will be filled up below
          float[] trainSamplingFactors = new float[train.lastVec().domain().length];
          if (mp._class_sampling_factors != null) {
            if (mp._class_sampling_factors.length != train.lastVec().domain().length)
              throw new IllegalArgumentException(
                  "class_sampling_factors must have "
                      + train.lastVec().domain().length
                      + " elements");
            trainSamplingFactors =
                mp._class_sampling_factors.clone(); // clone: don't modify the original
          }
          train =
              sampleFrameStratified(
                  train,
                  train.lastVec(),
                  train.vec(model._output.weightsName()),
                  trainSamplingFactors,
                  (long) (mp._max_after_balance_size * train.numRows()),
                  mp._seed,
                  true,
                  false);
          Vec l = train.lastVec();
          Vec w = train.vec(model._output.weightsName());
          MRUtils.ClassDist cd = new MRUtils.ClassDist(l);
          model._output._modelClassDist =
              _weights != null ? cd.doAll(l, w).rel_dist() : cd.doAll(l).rel_dist();
        }
        model.training_rows = train.numRows();
        if (_weights != null && _weights.min() == 0 && _weights.max() == 1 && _weights.isInt()) {
          model.training_rows = Math.round(train.numRows() * _weights.mean());
          Log.warn(
              "Not counting "
                  + (train.numRows() - model.training_rows)
                  + " rows with weight=0 towards an epoch.");
        }
        Log.info("One epoch corresponds to " + model.training_rows + " training data rows.");
        // training scoring dataset is always sampled uniformly from the training dataset
        trainScoreFrame = sampleFrame(train, mp._score_training_samples, mp._seed);
        if (trainScoreFrame != train) Scope.track(trainScoreFrame);

        if (!_parms._quiet_mode)
          Log.info("Number of chunks of the training data: " + train.anyVec().nChunks());
        if (val_fr != null) {
          model.validation_rows = val_fr.numRows();
          // validation scoring dataset can be sampled in multiple ways from the given validation
          // dataset
          if (model._output.isClassifier()
              && mp._balance_classes
              && mp._score_validation_sampling
                  == DeepLearningParameters.ClassSamplingMethod.Stratified) {
            _job.update(0, "Sampling validation data (stratified)...");
            validScoreFrame =
                sampleFrameStratified(
                    val_fr,
                    val_fr.lastVec(),
                    val_fr.vec(model._output.weightsName()),
                    null,
                    mp._score_validation_samples > 0
                        ? mp._score_validation_samples
                        : val_fr.numRows(),
                    mp._seed + 1,
                    false /* no oversampling */,
                    false);
          } else {
            _job.update(0, "Sampling validation data...");
            validScoreFrame = sampleFrame(val_fr, mp._score_validation_samples, mp._seed + 1);
            if (validScoreFrame != val_fr) Scope.track(validScoreFrame);
          }
          if (!_parms._quiet_mode)
            Log.info(
                "Number of chunks of the validation data: " + validScoreFrame.anyVec().nChunks());
        }

        // Set train_samples_per_iteration size (cannot be done earlier since this depends on
        // whether stratified sampling is done)
        model.actual_train_samples_per_iteration =
            computeTrainSamplesPerIteration(mp, model.training_rows, model);
        // Determine whether shuffling is enforced
        if (mp._replicate_training_data
            && (model.actual_train_samples_per_iteration
                == model.training_rows * (mp._single_node_mode ? 1 : H2O.CLOUD.size()))
            && !mp._shuffle_training_data
            && H2O.CLOUD.size() > 1
            && !mp._reproducible) {
          if (!mp._quiet_mode)
            Log.info(
                "Enabling training data shuffling, because all nodes train on the full dataset (replicated training data).");
          mp._shuffle_training_data = true;
        }
        if (!mp._shuffle_training_data
            && model.actual_train_samples_per_iteration == model.training_rows
            && train.anyVec().nChunks() == 1) {
          if (!mp._quiet_mode)
            Log.info(
                "Enabling training data shuffling to avoid training rows in the same order over and over (no Hogwild since there's only 1 chunk).");
          mp._shuffle_training_data = true;
        }

        //        if (!mp._quiet_mode) Log.info("Initial model:\n" + model.model_info());
        long now = System.currentTimeMillis();
        model._timeLastIterationEnter = now;
        if (_parms._autoencoder) {
          _job.update(0, "Scoring null model of autoencoder...");
          if (!mp._quiet_mode) Log.info("Scoring the null model of the autoencoder.");
          model.doScoring(
              trainScoreFrame,
              validScoreFrame,
              _job._key,
              0,
              false); // get the null model reconstruction error
        }
        // put the initial version of the model into DKV
        model.update(_job);
        model.total_setup_time_ms += now - _job.start_time();
        Log.info("Total setup time: " + PrettyPrint.msecs(model.total_setup_time_ms, true));
        Log.info("Starting to train the Deep Learning model.");
        _job.update(0, "Training...");

        // main loop
        for (; ; ) {
          model.iterations++;
          model.set_model_info(
              mp._epochs == 0
                  ? model.model_info()
                  : H2O.CLOUD.size() > 1 && mp._replicate_training_data
                      ? (mp._single_node_mode
                          ? new DeepLearningTask2(
                                  _job._key,
                                  train,
                                  model.model_info(),
                                  rowFraction(train, mp, model),
                                  model.iterations)
                              .doAll(Key.make(H2O.SELF))
                              .model_info()
                          : // replicated data + single node mode
                          new DeepLearningTask2(
                                  _job._key,
                                  train,
                                  model.model_info(),
                                  rowFraction(train, mp, model),
                                  model.iterations)
                              .doAllNodes()
                              .model_info())
                      : // replicated data + multi-node mode
                      new DeepLearningTask(
                              _job._key,
                              model.model_info(),
                              rowFraction(train, mp, model),
                              model.iterations)
                          .doAll(train)
                          .model_info()); // distributed data (always in multi-node mode)
          if (stop_requested() && !timeout()) break; // cancellation
          if (!model.doScoring(
              trainScoreFrame, validScoreFrame, _job._key, model.iterations, false))
            break; // finished training (or early stopping or convergence)
          if (timeout()) break; // stop after scoring
        }

        // replace the model with the best model so far (if it's better)
        if (!stop_requested()
            && _parms._overwrite_with_best_model
            && model.actual_best_model_key != null
            && _parms._nfolds == 0) {
          DeepLearningModel best_model = DKV.getGet(model.actual_best_model_key);
          if (best_model != null
              && best_model.loss() < model.loss()
              && Arrays.equals(best_model.model_info().units, model.model_info().units)) {
            if (!_parms._quiet_mode)
              Log.info("Setting the model to be the best model so far (based on scoring history).");
            DeepLearningModelInfo mi = best_model.model_info().deep_clone();
            // Don't cheat - count full amount of training samples, since that's the amount of
            // training it took to train (without finding anything better)
            mi.set_processed_global(model.model_info().get_processed_global());
            mi.set_processed_local(model.model_info().get_processed_local());
            model.set_model_info(mi);
            model.update(_job);
            model.doScoring(trainScoreFrame, validScoreFrame, _job._key, model.iterations, true);
            assert (best_model.loss() == model.loss());
          }
        }
        // store coefficient names for future use
        // possibly change
        model.model_info().data_info().coefNames();
        if (!_parms._quiet_mode) {
          Log.info(
              "==============================================================================================================================================================================");
          if (stop_requested()) {
            Log.info("Deep Learning model training was interrupted.");
          } else {
            Log.info("Finished training the Deep Learning model.");
            Log.info(model);
          }
          Log.info(
              "==============================================================================================================================================================================");
        }
      } finally {
        if (model != null) {
          model.deleteElasticAverageModels();
          model.unlock(_job);
          if (model.actual_best_model_key != null) {
            assert (model.actual_best_model_key != model._key);
            DKV.remove(model.actual_best_model_key);
          }
        }
      }
      return model;
    }
Example #15
    /**
     * Train a Deep Learning model, assumes that all members are populated If checkpoint == null,
     * then start training a new model, otherwise continue from a checkpoint
     */
    public final void buildModel() {
      DeepLearningModel cp = null;
      if (_parms._checkpoint == null) {
        cp =
            new DeepLearningModel(
                dest(),
                _parms,
                new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this),
                _train,
                _valid,
                nclasses());
        cp.model_info().initializeMembers();
      } else {
        final DeepLearningModel previous = DKV.getGet(_parms._checkpoint);
        if (previous == null) throw new IllegalArgumentException("Checkpoint not found.");
        Log.info("Resuming from checkpoint.");
        _job.update(0, "Resuming from checkpoint");

        if (isClassifier() != previous._output.isClassifier())
          throw new H2OIllegalArgumentException(
              "Response type must be the same as for the checkpointed model.");
        if (isSupervised() != previous._output.isSupervised())
          throw new H2OIllegalArgumentException(
              "Model type must be the same as for the checkpointed model.");

        // check the user-given arguments for consistency
        DeepLearningParameters oldP =
            previous._parms; // sanitized parameters for checkpointed model
        DeepLearningParameters newP = _parms; // user-given parameters for restart

        DeepLearningParameters oldP2 = (DeepLearningParameters) oldP.clone();
        DeepLearningParameters newP2 = (DeepLearningParameters) newP.clone();
        DeepLearningParameters.Sanity.modifyParms(
            oldP, oldP2, nclasses()); // sanitize the user-given parameters
        DeepLearningParameters.Sanity.modifyParms(
            newP, newP2, nclasses()); // sanitize the user-given parameters
        DeepLearningParameters.Sanity.checkpoint(oldP2, newP2);

        DataInfo dinfo;
        try {
          // PUBDEV-2513: Adapt _train and _valid (in-place) to match the frames that were used for
          // the previous model
          // This can add or remove dummy columns (can happen if the dataset is sparse and datasets
          // have different non-const columns)
          for (String st : previous.adaptTestForTrain(_train, true, false)) Log.warn(st);
          for (String st : previous.adaptTestForTrain(_valid, true, false)) Log.warn(st);
          dinfo = makeDataInfo(_train, _valid, _parms, nclasses());
          DKV.put(dinfo);
          cp = new DeepLearningModel(dest(), _parms, previous, false, dinfo);
          cp.write_lock(_job);

          if (!Arrays.equals(cp._output._names, previous._output._names)) {
            throw new H2OIllegalArgumentException(
                "The columns of the training data must be the same as for the checkpointed model. Check ignored columns (or disable ignore_const_cols).");
          }
          if (!Arrays.deepEquals(cp._output._domains, previous._output._domains)) {
            throw new H2OIllegalArgumentException(
                "Categorical factor levels of the training data must be the same as for the checkpointed model.");
          }
          if (dinfo.fullN() != previous.model_info().data_info().fullN()) {
            throw new H2OIllegalArgumentException(
                "Total number of predictors is different than for the checkpointed model.");
          }
          if (_parms._epochs <= previous.epoch_counter) {
            throw new H2OIllegalArgumentException(
                "Total number of epochs must be larger than the number of epochs already trained for the checkpointed model ("
                    + previous.epoch_counter
                    + ").");
          }

          // these are the mutable parameters that are to be used by the model (stored in
          // model_info._parms)
          final DeepLearningParameters actualNewP =
              cp.model_info()
                  .get_params(); // actually used parameters for model building (defaults filled in,
                                 // etc.)
          assert (actualNewP != previous.model_info().get_params());
          assert (actualNewP != newP);
          assert (actualNewP != oldP);
          DeepLearningParameters.Sanity.update(actualNewP, newP, nclasses());

          Log.info(
              "Continuing training after "
                  + String.format("%.3f", previous.epoch_counter)
                  + " epochs from the checkpointed model.");
          cp.update(_job);
        } catch (H2OIllegalArgumentException ex) {
          if (cp != null) {
            cp.unlock(_job);
            cp.delete();
            cp = null;
          }
          throw ex;
        } finally {
          if (cp != null) cp.unlock(_job);
        }
      }
      trainModel(cp);

      // clean up, but don't delete weights and biases if user asked for export
      List<Key> keep = new ArrayList<>();
      try {
        if (_parms._export_weights_and_biases
            && cp._output.weights != null
            && cp._output.biases != null) {
          for (Key k : Arrays.asList(cp._output.weights)) {
            keep.add(k);
            for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
              keep.add(vk._key);
            }
          }
          for (Key k : Arrays.asList(cp._output.biases)) {
            keep.add(k);
            for (Vec vk : ((Frame) DKV.getGet(k)).vecs()) {
              keep.add(vk._key);
            }
          }
        }
      } finally {
        Scope.exit(keep.toArray(new Key[keep.size()]));
      }
    }
Example #16
File: RPC.java  Project: liaochy/h2o-3
  // Handle traffic, from a client to this server asking for work to be done.
  // Called from either a F/J thread (generally with a UDP packet) or from the
  // TCPReceiver thread.
  static void remote_exec(AutoBuffer ab) {
    long lo = ab.get8(0), hi = ab._size >= 16 ? ab.get8(8) : 0;
    final int task = ab.getTask();
    final int flag = ab.getFlag();
    assert flag == CLIENT_UDP_SEND || flag == CLIENT_TCP_SEND; // Client-side send
    // Atomically record an instance of this task, one-time-only replacing a
    // null with an RPCCall, a placeholder while we work on a proper response -
    // and it serves to let us discard dup UDP requests.
    RPCCall old = ab._h2o.has_task(task);
    // This is a UDP packet requesting an answer back for a request sent via
    // TCP but the UDP packet has arrived ahead of the TCP.  Just drop the UDP
    // and wait for the TCP to appear.
    if (old == null && flag == CLIENT_TCP_SEND) {
      Log.warn(
          "got tcp with existing task #, FROM "
              + ab._h2o.toString()
              + " AB: " /* +  UDP.printx16(lo,hi)*/);
      assert !ab.hasTCP()
          : "ERROR: got tcp with existing task #, FROM "
              + ab._h2o.toString()
              + " AB: " /* + UDP.printx16(lo,hi)*/; // All the resends should be UDP only
      // DROP PACKET
    } else if (old == null) { // New task?
      RPCCall rpc;
      try {
        // Read the DTask Right Now.  If we are the TCPReceiver thread, then we
        // are reading in that thread... and thus TCP reads are single-threaded.
        rpc = new RPCCall(ab.get(water.DTask.class), ab._h2o, task);
      } catch (AutoBuffer.AutoBufferException e) {
        // Here we assume it's a TCP fail on read - and ignore the remote_exec
        // request.  The caller will send it again.  NOTE: this case is
        // indistinguishable from a broken short-writer/long-reader bug, except
        // that we'll re-send endlessly and fail endlessly.
        Log.info(
            "Network congestion OR short-writer/long-reader: TCP "
                + e._ioe.getMessage()
                + ",  AB="
                + ab
                + ", ignoring partial send");
        ab.drainClose();
        return;
      }
      RPCCall rpc2 = ab._h2o.record_task(rpc);
      if (rpc2 == null) { // Atomically insert (to avoid double-work)
        if (rpc._dt instanceof MRTask && rpc._dt.logVerbose())
          Log.debug("Start remote task#" + task + " " + rpc._dt.getClass() + " from " + ab._h2o);
        H2O.submitTask(rpc); // And execute!
      } else { // Else lost the task-insertion race
        if (ab.hasTCP()) ab.drainClose();
        // DROP PACKET
      }

    } else if (!old._computedAndReplied) {
      // This packet has not been fully computed.  Hence it's still a work-in-
      // progress locally.  We have no answer to reply but we do not want to
      // re-offer the packet for repeated work.  Send back a NACK, letting the
      // client know we're Working On It
      assert !ab.hasTCP()
          : "got tcp with existing task #, FROM "
              + ab._h2o.toString()
              + " AB: "
              + UDP.printx16(lo, hi)
              + ", position = "
              + ab._bb.position();
      ab.clearForWriting(udp.nack._prior).putTask(UDP.udp.nack.ordinal(), task);
      // DROP PACKET
    } else {
      // This is an old re-send of the same thing we've answered to before.
      // Send back the same old answer ACK.  If we sent via TCP before, then
      // we know the answer got there so just send a control-ACK back.  If we
      // sent via UDP, resend the whole answer.
      if (ab.hasTCP()) {
        Log.warn(
            "got tcp with existing task #, FROM "
                + ab._h2o.toString()
                + " AB: "
                + UDP.printx16(lo, hi)); // All the resends should be UDP only
        ab.drainClose();
      }
      if (old._dt != null) { // already ackacked
        ++old._ackResendCnt;
        if (old._ackResendCnt % 10 == 0)
          Log.err(
              "Possibly broken network, can not send ack through, got "
                  + old._ackResendCnt
                  + " for task # "
                  + old._tsknum
                  + ", dt == null?"
                  + (old._dt == null));
        old.resend_ack();
      }
    }
    ab.close();
  }
Example #17
  // Set K/V cache goals.
  // Allow (or disallow) allocations.
  // Called from the Cleaner, when "cacheUsed" has changed significantly.
  // Called from any FullGC notification, and HEAP/POJO_USED changed.
  // Called on any OOM allocation
  public static void set_goals(String msg, boolean oom, long bytes) {
    // Our best guess of free memory, as of the last GC cycle
    final long heapUsed = Boot.HEAP_USED_AT_LAST_GC;
    final long timeGC = Boot.TIME_AT_LAST_GC;
    final long freeHeap = MEM_MAX - heapUsed;
    assert freeHeap >= 0
        : "I am really confused about the heap usage; MEM_MAX=" + MEM_MAX + " heapUsed=" + heapUsed;
    // Current memory held in the K/V store.
    final long cacheUsage = myHisto.histo(false)._cached;
    // Our best guess of POJO object usage: Heap_used minus cache used
    final long pojoUsedGC = Math.max(heapUsed - cacheUsage, 0);

    // Block allocations if:
    // the cache is > 7/8 MEM_MAX, OR
    // we cannot allocate an equal amount of POJOs, pojoUsedGC > freeHeap.
    // Decay POJOS_USED by 1/8th every 5 sec: assume we got hit with a single
    // large allocation which is not repeating - so we do not need to have
    // double the POJO amount.
    // Keep at least 1/8th heap for caching.
    // Emergency-clean the cache down to the blocking level.
    long d = MEM_CRITICAL;
    // Decay POJO amount
    long p = pojoUsedGC;
    long age = (System.currentTimeMillis() - timeGC); // Age since last FullGC
    age = Math.min(age, 10 * 60 * 1000); // Clip at 10mins
    while ((age -= 5000) > 0) p = p - (p >> 3); // Decay effective POJO by 1/8th every 5sec
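    // e.g. 16s after the last FullGC gives three decay steps:
    // p ~ (7/8)^3 ~ 0.67 * pojoUsedGC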
    d -= 2 * p - bytes; // Allow for the effective POJO, and again to throttle GC rate
    d = Math.max(d, MEM_MAX >> 3); // Keep at least 1/8th heap
    H2O.Cleaner.DESIRED = d;

    String m = "";
    if (cacheUsage > H2O.Cleaner.DESIRED) {
      m = (CAN_ALLOC ? "Blocking!  " : "blocked:   ");
      if (oom) setMemLow(); // Stop allocations; trigger emergency clean
      Boot.kick_store_cleaner();
    } else { // Else we are not *emergency* cleaning, but may be lazily cleaning.
      if (!CAN_ALLOC) m = "Unblocking:";
      else m = "MemGood:   ";
      setMemGood();
      // Confused? The OOM should have triggered a FullGC, which should have set low-mem goals.
      if (oom)
        Log.warn(
            Sys.CLEAN,
            "OOM but no FullGC callback?  MEM_MAX = "
                + MEM_MAX
                + ", DESIRED = "
                + d
                + ", CACHE = "
                + cacheUsage
                + ", p = "
                + p
                + ", bytes = "
                + bytes);
    }

    // No logging if under memory pressure: can deadlock the cleaner thread
    if (Log.flag(Sys.CLEAN)) {
      String s =
          m
              + msg
              + ", HEAP_LAST_GC="
              + (heapUsed >> 20)
              + "M, KV="
              + (cacheUsage >> 20)
              + "M, POJO="
              + (pojoUsedGC >> 20)
              + "M, free="
              + (freeHeap >> 20)
              + "M, MAX="
              + (MEM_MAX >> 20)
              + "M, DESIRED="
              + (H2O.Cleaner.DESIRED >> 20)
              + "M"
              + (oom ? " OOM!" : " NO-OOM");
      if (CAN_ALLOC) Log.debug(Sys.CLEAN, s);
      else Log.unwrap(System.err, s);
    }
  }
  /**
   * Main constructor
   *
   * @param params Model parameters
   * @param dinfo Data Info
   * @param nClasses number of classes (1 for regression, 0 for autoencoder)
   * @param train User-given training data frame, prepared by AdaptTestTrain
   * @param valid User-specified validation data frame, prepared by AdaptTestTrain
   */
  public DeepLearningModelInfo(
      final DeepLearningParameters params,
      final DataInfo dinfo,
      int nClasses,
      Frame train,
      Frame valid) {
    _classification = nClasses > 1;
    _train = train;
    _valid = valid;
    data_info = dinfo;
    parameters =
        (DeepLearningParameters) params.clone(); // make a copy, don't change model's parameters
    DeepLearningParameters.Sanity.modifyParms(
        parameters, parameters, nClasses); // sanitize the model_info's parameters

    final int num_input = dinfo.fullN();
    final int num_output =
        get_params()._autoencoder
            ? num_input
            : (_classification ? train.lastVec().cardinality() : 1);
    if (!get_params()._autoencoder) assert (num_output == nClasses);

    _saw_missing_cats = dinfo._cats > 0 ? new boolean[data_info._cats] : null;
    assert (num_input > 0);
    assert (num_output > 0);
    if (has_momenta() && adaDelta())
      throw new IllegalArgumentException(
          "Cannot have non-zero momentum and adaptive rate at the same time.");
    final int layers = get_params()._hidden.length;
    // units (# neurons for each layer)
    units = new int[layers + 2];
    if (get_params()._max_categorical_features <= Integer.MAX_VALUE - dinfo._nums)
      units[0] = Math.min(dinfo._nums + get_params()._max_categorical_features, num_input);
    else units[0] = num_input;
    System.arraycopy(get_params()._hidden, 0, units, 1, layers);
    units[layers + 1] = num_output;
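    // e.g. _hidden = {200, 200} with 784 inputs and 10 classes gives
    // units = {784, 200, 200, 10} (units[0] may be capped via max_categorical_features)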

    boolean printLevels = units[0] > 1000L;
    boolean warn = units[0] > 100000L;
    if (printLevels) {
      final String[][] domains = dinfo._adaptedFrame.domains();
      int[] levels = new int[domains.length];
      for (int i = 0; i < levels.length; ++i) {
        levels[i] = domains[i] != null ? domains[i].length : 0;
      }
      Arrays.sort(levels);
      if (warn) {
        Log.warn(
            "===================================================================================================================================");
        Log.warn(
            num_input
                + " input features"
                + (dinfo._cats > 0 ? " (after categorical one-hot encoding)" : "")
                + ". Can be slow and require a lot of memory.");
      }
      if (levels[levels.length - 1] > 0) {
        int levelcutoff = levels[levels.length - 1 - Math.min(10, levels.length - 1)];
        int count = 0;
        for (int i = 0;
            i < dinfo._adaptedFrame.numCols() - (get_params()._autoencoder ? 0 : 1) && count < 10;
            ++i) {
          if (dinfo._adaptedFrame.domains()[i] != null
              && dinfo._adaptedFrame.domains()[i].length >= levelcutoff) {
            if (warn) {
              Log.warn(
                  "Categorical feature '"
                      + dinfo._adaptedFrame._names[i]
                      + "' has cardinality "
                      + dinfo._adaptedFrame.domains()[i].length
                      + ".");
            } else {
              Log.info(
                  "Categorical feature '"
                      + dinfo._adaptedFrame._names[i]
                      + "' has cardinality "
                      + dinfo._adaptedFrame.domains()[i].length
                      + ".");
            }
          }
          count++;
        }
      }
      if (warn) {
        Log.warn("Suggestions:");
        Log.warn(" *) Limit the size of the first hidden layer");
        if (dinfo._cats > 0) {
          Log.warn(
              " *) Limit the total number of one-hot encoded features with the parameter 'max_categorical_features'");
          Log.warn(
              " *) Run h2o.interaction(...,pairwise=F) on high-cardinality categorical columns to limit the factor count, see http://learn.h2o.ai");
        }
        Log.warn(
            "===================================================================================================================================");
      }
    }

    // weights (to connect layers)
    dense_row_weights = new Storage.DenseRowMatrix[layers + 1];
    dense_col_weights = new Storage.DenseColMatrix[layers + 1];

    // decide format of weight matrices row-major or col-major
    if (get_params()._col_major)
      dense_col_weights[0] = new Storage.DenseColMatrix(units[1], units[0]);
    else dense_row_weights[0] = new Storage.DenseRowMatrix(units[1], units[0]);
    for (int i = 1; i <= layers; ++i)
      dense_row_weights[i] = new Storage.DenseRowMatrix(units[i + 1] /*rows*/, units[i] /*cols*/);

    // biases (only for hidden layers and output layer)
    biases = new Storage.DenseVector[layers + 1];
    for (int i = 0; i <= layers; ++i) biases[i] = new Storage.DenseVector(units[i + 1]);
    // average activation (only for hidden layers)
    if (get_params()._autoencoder && get_params()._sparsity_beta > 0) {
      avg_activations = new Storage.DenseVector[layers];
      mean_a = new float[layers];
      for (int i = 0; i < layers; ++i) avg_activations[i] = new Storage.DenseVector(units[i + 1]);
    }
    allocateHelperArrays();
    // for diagnostics
    mean_rate = new float[units.length];
    rms_rate = new float[units.length];
    mean_bias = new float[units.length];
    rms_bias = new float[units.length];
    mean_weight = new float[units.length];
    rms_weight = new float[units.length];
  }