@Override public DataSet call(Collection<Writable> writables) throws Exception { List<Writable> list; if (writables instanceof List) list = (List<Writable>) writables; else list = new ArrayList<>(writables); // allow people to specify label index as -1 and infer the last possible label int labelIndex = this.labelIndex; if (numPossibleLabels >= 1 && labelIndex < 0) { labelIndex = list.size() - 1; } INDArray label = null; INDArray featureVector = Nd4j.create(labelIndex >= 0 ? list.size() - 1 : list.size()); int featureCount = 0; for (int j = 0; j < list.size(); j++) { Writable current = list.get(j); if (converter != null) current = converter.convert(current); if (labelIndex >= 0 && j == labelIndex) { // Current value is the label if (converter != null) { try { current = converter.convert(current); } catch (WritableConverterException e) { e.printStackTrace(); } } if (numPossibleLabels < 1) throw new IllegalStateException("Number of possible labels invalid, must be >= 1"); if (regression) { label = Nd4j.scalar(current.toDouble()); } else { // Convert to one-hot vector for int curr = current.toInt(); if (curr >= numPossibleLabels) throw new IllegalStateException( "Invalid input: class label is " + curr + " with numPossibleLables = " + numPossibleLabels + " (class label must be 0 <= labelIdx < numPossibleLabels)"); label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels); } } else { // Current value is not the label featureVector.putScalar(featureCount++, current.toDouble()); } } DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector)); if (preProcessor != null) preProcessor.preProcess(ds); return ds; }
private DataSet getDataSet(Collection<Writable> record) { List<Writable> currList; if (record instanceof List) currList = (List<Writable>) record; else currList = new ArrayList<>(record); // allow people to specify label index as -1 and infer the last possible label if (numPossibleLabels >= 1 && labelIndex < 0) { labelIndex = record.size() - 1; } INDArray label = null; INDArray featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size()); for (int j = 0; j < currList.size(); j++) { if (labelIndex >= 0 && j == labelIndex) { if (numPossibleLabels < 1) throw new IllegalStateException("Number of possible labels invalid, must be >= 1"); Writable current = currList.get(j); if (current.toString().isEmpty()) continue; if (converter != null) try { current = converter.convert(current); } catch (WritableConverterException e) { e.printStackTrace(); } if (regression) { label = Nd4j.scalar(Double.valueOf(current.toString())); } else { int curr = Double.valueOf(current.toString()).intValue(); if (curr >= numPossibleLabels) curr--; label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels); } } else { Writable current = currList.get(j); if (current.toString().isEmpty()) continue; featureVector.putScalar(j, Double.valueOf(current.toString())); } } return new DataSet(featureVector, labelIndex >= 0 ? label : featureVector); }