예제 #1
0
  /**
   * Restricts this dataset to the examples whose outcome is one of {@code labels} and re-encodes
   * the outcomes so that each kept label is replaced by its position in {@code labels}.
   *
   * <p>The features of the kept examples become this dataset's first matrix. A freshly built
   * matrix with one row per kept example and {@code labels.length} columns becomes the second;
   * each of its rows comes from {@code MatrixUtil.toOutcomeVectorFloat(position, labels.length)}.
   * This dataset is only modified after every row has been remapped successfully.
   *
   * @param labels the labels to keep; the index of a label in this array is its new outcome
   * @throws IllegalStateException if a kept example has an outcome that is not in {@code labels}
   */
  public void filterAndStrip(int[] labels) {
    FloatDataSet filtered = filterBy(labels);

    // Position of each requested label in the array; this becomes the remapped outcome.
    Map<Integer, Integer> labelMap = new HashMap<>();
    for (int i = 0; i < labels.length; i++) {
      labelMap.put(labels[i], i);
    }

    int numRows = filtered.numExamples();
    FloatMatrix newLabelMatrix = new FloatMatrix(numRows, labels.length);

    for (int row = 0; row < numRows; row++) {
      int original = filtered.get(row).outcome();
      // Keep the lookup boxed: unboxing straight into an int would raise a bare
      // NullPointerException for an unmapped outcome instead of the descriptive error below.
      Integer remapped = labelMap.get(original);
      if (remapped == null) {
        throw new IllegalStateException("Label not found on row " + row);
      }
      newLabelMatrix.putRow(row, MatrixUtil.toOutcomeVectorFloat(remapped, labels.length));
    }

    setFirst(filtered.getFirst());
    setSecond(newLabelMatrix);
  }
예제 #2
0
 /**
  * Randomly reorders the examples of this dataset in place.
  *
  * <p>The list returned by {@code asList()} is shuffled with {@link Collections#shuffle}, merged
  * back into a single dataset, and that dataset's two matrices replace this one's.
  */
 public void shuffle() {
   List<FloatDataSet> examples = asList();
   Collections.shuffle(examples);

   FloatDataSet reordered = FloatDataSet.merge(examples);
   setFirst(reordered.getFirst());
   setSecond(reordered.getSecond());
 }
예제 #3
0
  /**
   * Builds a dataset containing only the examples whose label is one of the given labels.
   *
   * <p>Each element of {@code asList()} is kept when {@code getLabel} of that element appears in
   * {@code labels}; the kept elements are then combined with {@code FloatDataSet.merge}.
   *
   * @param labels the labels to keep
   * @return the merged dataset of all matching examples
   */
  public FloatDataSet filterBy(int[] labels) {
    List<FloatDataSet> examples = asList();

    // Boxed copy of the requested labels so membership can be tested with contains().
    List<Integer> wanted = new ArrayList<>();
    for (int label : labels) {
      wanted.add(label);
    }

    List<FloatDataSet> kept = new ArrayList<>();
    for (FloatDataSet example : examples) {
      if (!wanted.contains(example.getLabel(example))) {
        continue;
      }
      kept.add(example);
    }

    return FloatDataSet.merge(kept);
  }
예제 #4
0
  /**
   * Concatenates the examples of several datasets, in list order, into one new dataset.
   *
   * <p>The result's matrices are sized from the total example count of {@code data} and from the
   * column counts of its first element; every example is then copied in row by row.
   *
   * @param data the datasets to combine
   * @return a new dataset holding every example of every element of {@code data}
   * @throws IllegalArgumentException if {@code data} is empty
   */
  public static FloatDataSet merge(List<FloatDataSet> data) {
    if (data.isEmpty()) {
      throw new IllegalArgumentException("Unable to merge empty dataset");
    }

    FloatDataSet head = data.get(0);
    int totalRows = totalExamples(data);
    FloatMatrix features = new FloatMatrix(totalRows, head.getFirst().columns);
    FloatMatrix outcomes = new FloatMatrix(totalRows, head.getSecond().columns);

    // Running destination row across all source datasets.
    int nextRow = 0;
    for (FloatDataSet part : data) {
      for (int k = 0; k < part.numExamples(); k++) {
        FloatDataSet single = part.get(k);
        features.putRow(nextRow, single.getFirst());
        outcomes.putRow(nextRow, single.getSecond());
        nextRow++;
      }
    }

    return new FloatDataSet(features, outcomes);
  }
예제 #5
0
 /**
  * Derives the label of a dataset by passing its second matrix to {@code SimpleBlas.iamax}.
  *
  * @param data the dataset whose second matrix is inspected
  * @return the index returned by {@code SimpleBlas.iamax}; presumably the position of the
  *     largest-magnitude entry, i.e. the active column of a one-hot row (confirm against jblas)
  */
 private int getLabel(FloatDataSet data) {
   FloatMatrix outcomes = data.getSecond();
   return SimpleBlas.iamax(outcomes);
 }
예제 #6
0
 /**
  * Overwrites row {@code i} of this dataset's two matrices with the matrices of {@code d}.
  *
  * <p>Nothing is appended: the row is written with {@code putRow}, so {@code i} must address an
  * existing row. This assumes {@code numExamples()} is the row count of both matrices, which is
  * how the other methods of this class use it.
  *
  * @param d the dataset whose first and second matrices are written into row {@code i}
  * @param i the zero-based row to overwrite; must satisfy {@code 0 <= i < numExamples()}
  * @throws IllegalArgumentException if {@code d} is null or {@code i} is out of range
  */
 public void addRow(FloatDataSet d, int i) {
   // i == numExamples() is one past the last row, and negative indices are never valid.
   if (d == null || i < 0 || i >= numExamples()) {
     throw new IllegalArgumentException("Invalid index for adding a row");
   }
   getFirst().putRow(i, d.getFirst());
   getSecond().putRow(i, d.getSecond());
 }
예제 #7
0
 /**
  * Splits this dataset into consecutive batches of at most {@code num} elements each.
  *
  * <p>The list returned by {@code asList()} is divided with {@code Lists.partition}, and each
  * resulting chunk is combined into one dataset with {@code FloatDataSet.merge}.
  *
  * @param num the maximum number of elements of {@code asList()} per batch
  * @return the batches, in the order of the underlying examples
  */
 public List<FloatDataSet> dataSetBatches(int num) {
   List<FloatDataSet> examples = asList();
   List<List<FloatDataSet>> chunks = Lists.partition(examples, num);

   List<FloatDataSet> batches = new ArrayList<>();
   for (List<FloatDataSet> chunk : chunks) {
     batches.add(FloatDataSet.merge(chunk));
   }
   return batches;
 }
예제 #8
0
 /**
  * Sums {@code numExamples()} over every dataset in the collection.
  *
  * @param coll the datasets to count
  * @return the combined number of examples; {@code 0} for an empty collection
  */
 private static int totalExamples(Collection<FloatDataSet> coll) {
   int total = 0;
   for (FloatDataSet dataSet : coll) {
     total += dataSet.numExamples();
   }
   return total;
 }