コード例 #1
0
  /**
   * Strips the dataset down to the specified labels and remaps them.
   *
   * <p>The examples are first narrowed with {@code filterBy(labels)}. Each remaining example's
   * outcome is then replaced by the position of that outcome within {@code labels}, and a new
   * label matrix is built with one row per filtered example (each row produced by
   * {@code MatrixUtil.toOutcomeVectorFloat}). Finally this dataset's first and second matrices
   * are replaced with the filtered first matrix and the rebuilt label matrix.
   *
   * <p>If a label value occurs more than once in {@code labels}, its last position wins.
   *
   * @param labels the labels to strip down to; the index of a label in this array becomes its
   *     remapped outcome
   * @throws IllegalStateException if a filtered example has an outcome that is not present in
   *     {@code labels}
   */
  public void filterAndStrip(int[] labels) {
    FloatDataSet filtered = filterBy(labels);

    // Map each original label to its position in the passed-in array.
    Map<Integer, Integer> labelMap = new HashMap<>();
    for (int i = 0; i < labels.length; i++) {
      labelMap.put(labels[i], i);
    }

    int numRows = filtered.numExamples();
    FloatMatrix newLabelMatrix = new FloatMatrix(numRows, labels.length);

    for (int row = 0; row < numRows; row++) {
      int original = filtered.get(row).outcome();
      // Keep the lookup boxed: unboxing it straight into an int would raise a bare
      // NullPointerException for an unmapped outcome instead of the descriptive error below.
      Integer remapped = labelMap.get(original);
      if (remapped == null) {
        throw new IllegalStateException("Label not found on row " + row);
      }
      newLabelMatrix.putRow(row, MatrixUtil.toOutcomeVectorFloat(remapped, labels.length));
    }

    setFirst(filtered.getFirst());
    setSecond(newLabelMatrix);
  }
コード例 #2
0
 /**
  * Shuffles the contents of this dataset in place.
  *
  * <p>The list returned by {@code asList()} is put into random order with
  * {@code Collections.shuffle}, merged back into a single dataset via
  * {@code FloatDataSet.merge}, and the merged first and second matrices then replace the ones
  * held by this dataset.
  */
 public void shuffle() {
   List<FloatDataSet> examples = asList();
   Collections.shuffle(examples);

   FloatDataSet shuffled = FloatDataSet.merge(examples);
   setFirst(shuffled.getFirst());
   setSecond(shuffled.getSecond());
 }
コード例 #3
0
  /**
   * Merges a list of datasets into one dataset by copying their examples row by row.
   *
   * <p>The result matrices are sized from {@code totalExamples(data)} rows and the column counts
   * of the first dataset's first and second matrices. Examples are copied in list order, and
   * within each dataset in index order.
   *
   * @param data the datasets to merge; must not be empty
   * @return a new dataset holding every example of every dataset in {@code data}
   * @throws IllegalArgumentException if {@code data} is empty
   */
  public static FloatDataSet merge(List<FloatDataSet> data) {
    if (data.isEmpty()) {
      throw new IllegalArgumentException("Unable to merge empty dataset");
    }

    FloatDataSet head = data.get(0);
    int totalRows = totalExamples(data);
    FloatMatrix inputs = new FloatMatrix(totalRows, head.getFirst().columns);
    FloatMatrix outputs = new FloatMatrix(totalRows, head.getSecond().columns);

    // Running destination row shared across all source datasets.
    int row = 0;
    for (FloatDataSet part : data) {
      for (int j = 0; j < part.numExamples(); j++, row++) {
        FloatDataSet example = part.get(j);
        inputs.putRow(row, example.getFirst());
        outputs.putRow(row, example.getSecond());
      }
    }

    return new FloatDataSet(inputs, outputs);
  }
コード例 #4
0
 /**
  * Writes the first and second matrices of {@code d} into row {@code i} of this dataset's first
  * and second matrices.
  *
  * @param d the dataset whose first/second matrices supply the row contents; must not be null
  * @param i the zero-based row to write to
  * @throws IllegalArgumentException if {@code d} is null or {@code i} is outside
  *     {@code [0, numExamples())}
  */
 public void addRow(FloatDataSet d, int i) {
   if (d == null) {
     throw new IllegalArgumentException("Invalid dataset for adding a row: null");
   }
   // NOTE(review): assumes numExamples() is the row count of both matrices, so the valid
   // targets are 0 .. numExamples() - 1. The previous check (i > numExamples()) let
   // i == numExamples() and negative indices through to putRow — confirm against numExamples().
   if (i < 0 || i >= numExamples()) {
     throw new IllegalArgumentException("Invalid index for adding a row");
   }
   getFirst().putRow(i, d.getFirst());
   getSecond().putRow(i, d.getSecond());
 }