Ejemplo n.º 1
0
    @Override
    protected void init() {
      if (validation != null && n_folds != 0)
        throw new UnsupportedOperationException(
            "Cannot specify a validation dataset and non-zero number of cross-validation folds.");
      if (n_folds < 0)
        throw new UnsupportedOperationException(
            "The number of cross-validation folds must be >= 0.");
      super.init();
      xval_models = new Key[n_folds];
      for (int i = 0; i < xval_models.length; ++i)
        xval_models[i] = Key.make(dest().toString() + "_xval" + i);

      int rIndex = 0;
      for (int i = 0; i < source.vecs().length; i++)
        if (source.vecs()[i] == response) {
          rIndex = i;
          break;
        }
      _responseName = source._names != null && rIndex >= 0 ? source._names[rIndex] : "response";

      _train = selectVecs(source);
      _names = new String[cols.length];
      for (int i = 0; i < cols.length; i++) _names[i] = source._names[cols[i]];

      // Compute source response domain
      if (classification) _sourceResponseDomain = getVectorDomain(response);
      // Is validation specified?
      if (validation != null) {
        // Extract a validation response
        int idx = validation.find(source.names()[rIndex]);
        if (idx == -1)
          throw new IllegalArgumentException(
              "Validation set does not have a response column called " + _responseName);
        _validResponse = validation.vecs()[idx];
        // Compute output confusion matrix domain for classification:
        // - if validation dataset is specified then CM domain is union of train and validation
        // response domains
        //   else it is only domain of response column.
        if (classification) {
          _validResponseDomain = getVectorDomain(_validResponse);
          if (_validResponseDomain != null) {
            _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain);
            if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) {
              _fromModel2CM =
                  Model.getDomainMapping(
                      _cmDomain,
                      _sourceResponseDomain,
                      false); // transformation from model produced response ~> cmDomain
              _fromValid2CM =
                  Model.getDomainMapping(
                      _cmDomain,
                      _validResponseDomain,
                      false); // transformation from validation response domain ~> cmDomain
            }
          } else _cmDomain = _sourceResponseDomain;
        } /* end of if classification */
      } else if (classification) _cmDomain = _sourceResponseDomain;
    }
Ejemplo n.º 2
0
 /**
  * Helper to handle arguments based on existing input values
  *
  * @param arg
  * @param inputArgs
  */
 @Override
 protected void queryArgumentValueSet(Argument arg, java.util.Properties inputArgs) {
   super.queryArgumentValueSet(arg, inputArgs);
   if (arg._name.equals("n_folds") && validation != null) {
     arg.disable("Only if no validation dataset is provided.");
     n_folds = 0;
   }
 }
Ejemplo n.º 3
0
 /**
  * Helper to specify which arguments trigger a refresh on change
  *
  * @param ver
  */
 @Override
 protected void registered(RequestServer.API_VERSION ver) {
   super.registered(ver);
   for (Argument arg : _arguments) {
     if (arg._name.equals("validation")) {
       arg.setRefreshOnChange();
     }
   }
 }