@Override protected void init() { if (validation != null && n_folds != 0) throw new UnsupportedOperationException( "Cannot specify a validation dataset and non-zero number of cross-validation folds."); if (n_folds < 0) throw new UnsupportedOperationException( "The number of cross-validation folds must be >= 0."); super.init(); xval_models = new Key[n_folds]; for (int i = 0; i < xval_models.length; ++i) xval_models[i] = Key.make(dest().toString() + "_xval" + i); int rIndex = 0; for (int i = 0; i < source.vecs().length; i++) if (source.vecs()[i] == response) { rIndex = i; break; } _responseName = source._names != null && rIndex >= 0 ? source._names[rIndex] : "response"; _train = selectVecs(source); _names = new String[cols.length]; for (int i = 0; i < cols.length; i++) _names[i] = source._names[cols[i]]; // Compute source response domain if (classification) _sourceResponseDomain = getVectorDomain(response); // Is validation specified? if (validation != null) { // Extract a validation response int idx = validation.find(source.names()[rIndex]); if (idx == -1) throw new IllegalArgumentException( "Validation set does not have a response column called " + _responseName); _validResponse = validation.vecs()[idx]; // Compute output confusion matrix domain for classification: // - if validation dataset is specified then CM domain is union of train and validation // response domains // else it is only domain of response column. 
if (classification) { _validResponseDomain = getVectorDomain(_validResponse); if (_validResponseDomain != null) { _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain); if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) { _fromModel2CM = Model.getDomainMapping( _cmDomain, _sourceResponseDomain, false); // transformation from model produced response ~> cmDomain _fromValid2CM = Model.getDomainMapping( _cmDomain, _validResponseDomain, false); // transformation from validation response domain ~> cmDomain } } else _cmDomain = _sourceResponseDomain; } /* end of if classification */ } else if (classification) _cmDomain = _sourceResponseDomain; }
/**
 * Adjusts an argument based on the input values that are already set.
 *
 * <p>After delegating to the superclass, the {@code n_folds} argument is disabled and the
 * {@code n_folds} field is reset to 0 whenever a validation dataset is present.
 *
 * @param arg the argument being queried
 * @param inputArgs the input properties, forwarded unchanged to the superclass
 */
@Override
protected void queryArgumentValueSet(Argument arg, java.util.Properties inputArgs) {
  super.queryArgumentValueSet(arg, inputArgs);

  final boolean isFoldsArgument = arg._name.equals("n_folds");
  if (!isFoldsArgument || validation == null) {
    return;
  }

  // Cross-validation folds are only offered when no validation dataset is supplied.
  arg.disable("Only if no validation dataset is provided.");
  n_folds = 0;
}
/**
 * Declares which arguments trigger a refresh when their value changes.
 *
 * <p>After delegating to the superclass, every argument named {@code validation} is marked as
 * refresh-on-change.
 *
 * @param ver the API version, forwarded unchanged to the superclass
 */
@Override
protected void registered(RequestServer.API_VERSION ver) {
  super.registered(ver);
  for (Argument candidate : _arguments) {
    if (!candidate._name.equals("validation")) {
      continue;
    }
    candidate.setRefreshOnChange();
  }
}