/**
 * Builds a deterministic batch of {@code window} random outcomes drawn from {@code [0, nIn)}.
 *
 * <p>The generator is seeded with a constant, so every call yields the same sequence. The
 * sampled indices are handed to {@code FeatureUtil.toOutcomeMatrix} together with {@code nIn}
 * (presumably producing one encoded row per outcome; FeatureUtil is not shown here).
 *
 * @return the outcome matrix for the sampled indices
 */
private static INDArray getData() {
  final Random rng = new Random(1);
  final int[] outcomes = new int[window];
  int pos = 0;
  while (pos < outcomes.length) {
    outcomes[pos] = rng.nextInt(nIn);
    pos++;
  }
  return FeatureUtil.toOutcomeMatrix(outcomes, nIn);
}
  /**
   * Converts a single record into a {@link DataSet}.
   *
   * <p>If a converter is set, it is applied exactly once to every value, the label included.
   * Every value other than the label column is written, in column order, into the feature vector.
   *
   * <p>The label column becomes a scalar when {@code regression} is set. Otherwise it becomes a
   * one-hot vector over {@code numPossibleLabels} classes. When there is no label column, the
   * feature vector is also used as the label. The pre-processor, if set, is applied before
   * returning.
   *
   * @param writables the values of one record; a {@code List} is used directly, and any other
   *     collection is copied
   * @return the features/label pair for this record
   * @throws IllegalArgumentException if the label index lies beyond the end of the record
   * @throws IllegalStateException if a label column is present but {@code numPossibleLabels < 1},
   *     or if a class label is outside {@code [0, numPossibleLabels)}
   * @throws Exception if the converter fails
   */
  @Override
  public DataSet call(Collection<Writable> writables) throws Exception {
    List<Writable> list;
    if (writables instanceof List) {
      list = (List<Writable>) writables;
    } else {
      list = new ArrayList<>(writables);
    }

    // A negative label index means "use the last column" when labels are configured.
    // A local copy is used so the configured field is never modified.
    int labelIndex = this.labelIndex;
    if (numPossibleLabels >= 1 && labelIndex < 0) {
      labelIndex = list.size() - 1;
    }
    if (labelIndex >= list.size()) {
      throw new IllegalArgumentException(
          "Label index " + labelIndex + " is out of range for a record of size " + list.size());
    }

    INDArray label = null;
    INDArray featureVector = Nd4j.create(labelIndex >= 0 ? list.size() - 1 : list.size());
    // Features are packed densely (the label column is skipped), so they need their own index.
    int featureCount = 0;
    for (int j = 0; j < list.size(); j++) {
      Writable current = list.get(j);
      // Convert exactly once per value. Converting the label a second time would corrupt
      // the value for any converter that is not idempotent.
      if (converter != null) {
        current = converter.convert(current);
      }
      if (labelIndex >= 0 && j == labelIndex) {
        // Current value is the label.
        if (numPossibleLabels < 1) {
          throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }

        if (regression) {
          label = Nd4j.scalar(current.toDouble());
        } else {
          // Classification: convert the class index to a one-hot vector.
          int curr = current.toInt();
          if (curr < 0 || curr >= numPossibleLabels) {
            throw new IllegalStateException(
                "Invalid input: class label is "
                    + curr
                    + " with numPossibleLabels = "
                    + numPossibleLabels
                    + " (class label must be 0 <= labelIdx < numPossibleLabels)");
          }
          label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
        }
      } else {
        // Current value is a feature.
        featureVector.putScalar(featureCount++, current.toDouble());
      }
    }

    DataSet ds = new DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
    if (preProcessor != null) {
      preProcessor.preProcess(ds);
    }
    return ds;
  }
  /**
   * Vectorizes the passed-in text, treating it as one document.
   *
   * <p>The text is tokenized with {@code tokenizerFactory}. For every token present in
   * {@code cache}, its column of a {@code 1 x cache.numWords()} row vector is set to
   * {@code cache.wordFrequency(token)}. Tokens missing from the cache are ignored. A repeated
   * token writes the same value again, so repeats do not accumulate.
   *
   * @param text the text to vectorize
   * @param label the label of the text; must be an element of {@code labels}
   * @return a dataset whose features are the per-token weights (relative to the implementation;
   *     could be word counts or tf-idf scores) and whose label is the outcome vector built by
   *     {@code FeatureUtil.toOutcomeVector} for the label's index
   * @throws IllegalArgumentException if {@code label} is not one of the known labels
   */
  @Override
  public DataSet vectorize(String text, String label) {
    // Fail fast: an unknown label would give index -1, which is not a valid outcome index.
    int labelIdx = labels.indexOf(label);
    if (labelIdx < 0) {
      throw new IllegalArgumentException("Unknown label: '" + label + "'");
    }

    Tokenizer tokenizer = tokenizerFactory.create(text);
    List<String> tokens = tokenizer.getTokens();
    INDArray input = Nd4j.create(1, cache.numWords());
    for (String token : tokens) {
      int idx = cache.indexOf(token);
      if (idx >= 0) {
        input.putScalar(idx, cache.wordFrequency(token));
      }
    }

    INDArray labelMatrix = FeatureUtil.toOutcomeVector(labelIdx, labels.size());
    return new DataSet(input, labelMatrix);
  }
  /**
   * Converts one record into a {@link DataSet}.
   *
   * <p>Non-label values are parsed with {@code Double.valueOf(value.toString())} and packed, in
   * column order, into the feature vector. An empty value leaves its slot at the vector's initial
   * value (presumably 0 for {@code Nd4j.create}).
   *
   * <p>The label column is handled as follows:
   *
   * <ul>
   *   <li>If it is empty, it is skipped, leaving the label {@code null}.
   *   <li>Otherwise it is run through the optional converter.
   *   <li>It then becomes a scalar for regression, or an outcome vector via
   *       {@code FeatureUtil.toOutcomeVector} for classification.
   * </ul>
   *
   * <p>Without a label column, the feature vector is also used as the label.
   *
   * @param record the values of one record; a {@code List} is used directly, and any other
   *     collection is copied
   * @return the features/label pair for this record
   * @throws IllegalArgumentException if the label index lies beyond the end of the record
   * @throws IllegalStateException if a label column is present but {@code numPossibleLabels < 1},
   *     if the converter fails (the cause is attached), or if a class label is outside
   *     {@code [0, numPossibleLabels)}
   */
  private DataSet getDataSet(Collection<Writable> record) {
    List<Writable> currList;
    if (record instanceof List) {
      currList = (List<Writable>) record;
    } else {
      currList = new ArrayList<>(record);
    }

    // A negative label index means "use the last column" when labels are configured.
    // NOTE(review): this writes the inferred index back to the labelIndex field, so it is
    // inferred once (from the first record seen) and reused afterwards. The write-back is kept
    // because other members may read the field.
    if (numPossibleLabels >= 1 && labelIndex < 0) {
      labelIndex = currList.size() - 1;
    }
    if (labelIndex >= currList.size()) {
      throw new IllegalArgumentException(
          "Label index " + labelIndex + " is out of range for a record of size " + currList.size());
    }

    INDArray label = null;
    INDArray featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
    // Features are packed densely (the label column is skipped), so they need their own index.
    // Indexing by column would leave a gap and overflow the vector whenever the label is not
    // the last column.
    int featureCount = 0;
    for (int j = 0; j < currList.size(); j++) {
      Writable current = currList.get(j);
      if (labelIndex >= 0 && j == labelIndex) {
        if (numPossibleLabels < 1) {
          throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }
        if (current.toString().isEmpty()) {
          continue;
        }
        if (converter != null) {
          try {
            current = converter.convert(current);
          } catch (WritableConverterException e) {
            // Parsing the unconverted value would only fail later and hide this cause.
            throw new IllegalStateException("Failed to convert label value '" + current + "'", e);
          }
        }
        if (regression) {
          label = Nd4j.scalar(Double.valueOf(current.toString()));
        } else {
          int curr = Double.valueOf(current.toString()).intValue();
          // Reject out-of-range classes instead of shifting them. Shifting silently merged
          // class numPossibleLabels into class numPossibleLabels - 1.
          if (curr < 0 || curr >= numPossibleLabels) {
            throw new IllegalStateException(
                "Invalid input: class label is "
                    + curr
                    + " with numPossibleLabels = "
                    + numPossibleLabels
                    + " (class label must be 0 <= labelIdx < numPossibleLabels)");
          }
          label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
        }
      } else {
        // Reserve the slot even for empty values so later features keep their positions.
        int featureIdx = featureCount++;
        if (current.toString().isEmpty()) {
          continue;
        }
        featureVector.putScalar(featureIdx, Double.valueOf(current.toString()));
      }
    }

    return new DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
  }
// Example #5
 /**
  * Fits the model to integer class labels.
  *
  * <p>The labels are expanded by {@code FeatureUtil.toOutcomeMatrix} into an outcome matrix
  * sized by {@code numLabels()}. That matrix is passed to the {@code fit(INDArray, INDArray)}
  * overload.
  *
  * @param examples the examples to classify (one example in each row)
  * @param labels the class label of each example; presumably one entry per row of
  *     {@code examples} (TODO: confirm against the delegated {@code fit})
  */
 @Override
 public void fit(INDArray examples, int[] labels) {
   fit(examples, FeatureUtil.toOutcomeMatrix(labels, numLabels()));
 }