Ejemplo n.º 1
0
    /**
     * Validates the job parameters and prepares the training/validation bookkeeping.
     *
     * <p>After delegating to {@code super.init()} this method:
     *
     * <ul>
     *   <li>creates one cross-validation model key per fold, named {@code dest() + "_xval" + i};
     *   <li>resolves the response column name by looking for {@code response} among the vectors of
     *       {@code source}, falling back to {@code "response"} when the vector is not found or the
     *       source frame has no names;
     *   <li>selects the training vectors and records the names of the selected columns;
     *   <li>for classification, computes the response domain(s) and the confusion-matrix domain,
     *       including the mappings into that domain when the training and validation response
     *       domains differ.
     * </ul>
     *
     * @throws UnsupportedOperationException if both a validation dataset and a non-zero number of
     *     folds are given, or if the number of folds is negative
     * @throws IllegalArgumentException if a validation dataset is given but has no column with the
     *     response name
     */
    @Override
    protected void init() {
      if (validation != null && n_folds != 0) {
        throw new UnsupportedOperationException(
            "Cannot specify a validation dataset and non-zero number of cross-validation folds.");
      }
      if (n_folds < 0) {
        throw new UnsupportedOperationException(
            "The number of cross-validation folds must be >= 0.");
      }
      super.init();
      xval_models = new Key[n_folds];
      for (int i = 0; i < xval_models.length; ++i) {
        xval_models[i] = Key.make(dest().toString() + "_xval" + i);
      }

      // Locate the response vector inside the source frame. Start at -1 ("not found") so that the
      // fallback name below is actually reachable; starting at 0 would silently report the first
      // column as the response whenever the vector is missing.
      int rIndex = -1;
      for (int i = 0; i < source.vecs().length; i++) {
        if (source.vecs()[i] == response) {
          rIndex = i;
          break;
        }
      }
      _responseName = source._names != null && rIndex >= 0 ? source._names[rIndex] : "response";

      _train = selectVecs(source);
      _names = new String[cols.length];
      for (int i = 0; i < cols.length; i++) {
        _names[i] = source._names[cols[i]];
      }

      // Compute source response domain
      if (classification) {
        _sourceResponseDomain = getVectorDomain(response);
      }
      // Is validation specified?
      if (validation != null) {
        // Extract the validation response by the resolved response name, i.e. the same name that
        // is reported in the error message below.
        int idx = validation.find(_responseName);
        if (idx == -1) {
          throw new IllegalArgumentException(
              "Validation set does not have a response column called " + _responseName);
        }
        _validResponse = validation.vecs()[idx];
        // Compute output confusion matrix domain for classification:
        // - if a validation dataset is specified and its response has a domain, the CM domain is
        //   the union of the train and validation response domains
        // - otherwise it is only the domain of the training response column.
        if (classification) {
          _validResponseDomain = getVectorDomain(_validResponse);
          if (_validResponseDomain != null) {
            _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain);
            if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) {
              // transformation from model produced response ~> cmDomain
              _fromModel2CM = Model.getDomainMapping(_cmDomain, _sourceResponseDomain, false);
              // transformation from validation response domain ~> cmDomain
              _fromValid2CM = Model.getDomainMapping(_cmDomain, _validResponseDomain, false);
            }
          } else {
            _cmDomain = _sourceResponseDomain;
          }
        } /* end of if classification */
      } else if (classification) {
        _cmDomain = _sourceResponseDomain;
      }
    }
Ejemplo n.º 2
0
  /**
   * Recursively walks the directory {@code p} on {@code fs} and imports every file found.
   *
   * <p>Each non-directory entry is handled according to its name and size:
   *
   * <ul>
   *   <li>names ending with {@code Extensions.JSON} are parsed as a JSON model description; the
   *       model class named by the {@code Constants.TYPE} field is instantiated and populated via
   *       {@code fromJson}. No {@code Value} is created for these files.
   *   <li>names ending with {@code Extensions.HEX} have up to their first megabyte read and decoded
   *       into a {@code ValueArray};
   *   <li>files of at least {@code 2 * ValueArray.CHUNK_SZ} bytes are wrapped in a {@code
   *       ValueArray};
   *   <li>everything else becomes a plain {@code Value}.
   * </ul>
   *
   * <p>Every processed file is appended to {@code succeeded}. Any exception stops the walk of the
   * current folder, is logged, and is reported as a single entry for {@code p} in {@code failed}.
   *
   * @param fs file system to read from; when {@code null} the method does nothing
   * @param p directory to list
   * @param succeeded receives one JSON object (key, file, size) per imported file
   * @param failed receives one JSON object (file, error) when listing or importing fails
   */
  private static void addFolder(FileSystem fs, Path p, JsonArray succeeded, JsonArray failed) {
    try {
      if (fs == null) return;
      for (FileStatus file : fs.listStatus(p)) {
        Path pfs = file.getPath();
        if (file.isDir()) {
          addFolder(fs, pfs, succeeded, failed);
        } else {
          Key k = Key.make(pfs.toString());
          long size = file.getLen();
          Value val = null;
          if (pfs.getName().endsWith(Extensions.JSON)) {
            JsonObject json;
            // Close the stream once parsing is done so the file handle is not leaked.
            InputStreamReader reader = new InputStreamReader(fs.open(pfs));
            try {
              json = new JsonParser().parse(reader).getAsJsonObject();
            } finally {
              reader.close();
            }
            JsonElement v = json.get(Constants.VERSION);
            if (v == null) throw new InvalidDataException("Missing version");
            JsonElement type = json.get(Constants.TYPE);
            if (type == null) throw new InvalidDataException("Missing type");
            Class<?> c = Class.forName(type.getAsString());
            Model model = (Model) c.newInstance();
            // NOTE(review): the populated model is not stored anywhere in this method; confirm
            // whether fromJson registers it or whether it should be put into the store here.
            model.fromJson(json);
          } else if (pfs.getName().endsWith(Extensions.HEX)) { // Hex file?
            int sz = (int) Math.min(1L << 20, size); // Read up to the 1st meg
            byte[] mem = MemoryManager.malloc1(sz);
            FSDataInputStream s = fs.open(pfs);
            try {
              s.readFully(mem);
            } finally {
              s.close();
            }
            // Convert to a ValueArray (hope it fits in 1Meg!)
            ValueArray ary = new ValueArray(k, 0).read(new AutoBuffer(mem));
            val = new Value(k, ary, Value.HDFS);
          } else if (size >= 2 * ValueArray.CHUNK_SZ) {
            // ValueArray byte wrapper over a large file
            val = new Value(k, new ValueArray(k, size), Value.HDFS);
          } else {
            val = new Value(k, (int) size, Value.HDFS); // Plain Value
          }
          // The JSON branch produces no Value; dereferencing it unconditionally used to throw a
          // NullPointerException and fail the whole folder.
          if (val != null) {
            val.setdsk();
            DKV.put(k, val);
          }

          JsonObject o = new JsonObject();
          o.addProperty(Constants.KEY, k.toString());
          o.addProperty(Constants.FILE, pfs.toString());
          o.addProperty(Constants.VALUE_SIZE, file.getLen());
          succeeded.add(o);
        }
      }
    } catch (Exception e) {
      Log.err(e);
      JsonObject o = new JsonObject();
      o.addProperty(Constants.FILE, p.toString());
      o.addProperty(Constants.ERROR, e.getMessage());
      failed.add(o);
    }
  }
Ejemplo n.º 3
0
 /**
  * Prepares the validation data for scoring with the given model.
  *
  * <p>Does nothing when no validation frame is set. Otherwise the validation frame is passed
  * through {@code model.adapt}; the first returned frame is kept as the adapted validation frame
  * and the second one is handed to {@code gtrash} so it is deleted after the computation.
  *
  * <p>If a validation-to-confusion-matrix mapping exists, the validation response is converted to
  * an enum vector and transformed into the confusion-matrix domain. Both the transformed vector
  * and the intermediate enum vector are handed to {@code gtrash}.
  *
  * @param model model used to adapt the validation frame
  */
 protected final void prepareValidationWithModel(final Model model) {
   if (validation == null) {
     return;
   }

   final Frame[] adapted = model.adapt(validation, false);
   _adaptedValidation = adapted[0];
   gtrash(adapted[1]); // delete this after computation

   if (_fromValid2CM == null) {
     // No response transformation declared: nothing else to prepare.
     return;
   }

   assert classification
       : "Validation response transformation should be declared only for classification!";
   assert _fromModel2CM != null
       : "Model response transformation should exist if validation response transformation exists!";

   // Re-express the original validation response in the confusion-matrix domain.
   final Vec enumResponse = _validResponse.toEnum();
   _adaptedValidationResponse = enumResponse.makeTransf(_fromValid2CM, getCMDomain());

   // Register both temporaries on the clean-up list.
   gtrash(_adaptedValidationResponse);
   gtrash(enumResponse);
 }