示例#1
0
  /**
   * Computes the impurity of a node from its per-class sample counts.
   *
   * <p>The measure is chosen by the {@code rule} field: Gini index ({@code 1 - sum p^2}), entropy
   * ({@code -sum p log2 p}) or classification error ({@code |1 - max p|}), where {@code p} is the
   * fraction of the node's samples in a class. Classes with a zero count are skipped.
   *
   * @param count the sample count in each class.
   * @param n the number of samples in the node.
   * @return the impurity of a node; 0.0 when {@code rule} is none of the handled cases.
   */
  private double impurity(int[] count, int n) {
    double result = 0.0;

    switch (rule) {
      case GINI:
        result = 1.0;
        for (int c : count) {
          if (c > 0) {
            double fraction = (double) c / n;
            result -= fraction * fraction;
          }
        }
        break;

      case ENTROPY:
        for (int c : count) {
          if (c > 0) {
            double fraction = (double) c / n;
            result -= fraction * Math.log2(fraction);
          }
        }
        break;

      case CLASSIFICATION_ERROR:
        {
          // Fraction of the majority class; the error is its complement.
          double majority = 0.0;
          for (int c : count) {
            if (c > 0) {
              majority = Math.max(majority, c / (double) n);
            }
          }
          result = Math.abs(1 - majority);
        }
        break;

      default:
        // Unknown rule: fall through with 0.0, as before.
        break;
    }

    return result;
  }
示例#2
0
  /**
   * Test of learn method, of class LogisticRegression.
   *
   * <p>Runs leave-one-out cross validation on the Iris data set and checks the total number of
   * misclassified samples. A failure to load the data set, or any exception during training,
   * fails the test instead of being printed and ignored.
   */
  @Test
  public void testIris() {
    System.out.println("Iris");
    ArffParser arffParser = new ArffParser();
    arffParser.setResponseIndex(4);

    AttributeDataset iris;
    try {
      iris = arffParser.parse(smile.data.parser.IOUtils.getTestDataFile("weka/iris.arff"));
    } catch (Exception ex) {
      // Previously swallowed, which let the test pass without checking anything.
      throw new AssertionError("Failed to load weka/iris.arff", ex);
    }

    double[][] x = iris.toArray(new double[iris.size()][]);
    int[] y = iris.toArray(new int[iris.size()]);

    int n = x.length;
    LOOCV loocv = new LOOCV(n);
    int error = 0;
    for (int i = 0; i < n; i++) {
      double[][] trainx = Math.slice(x, loocv.train[i]);
      int[] trainy = Math.slice(y, loocv.train[i]);
      LogisticRegression logit = new LogisticRegression(trainx, trainy);

      int test = loocv.test[i];
      if (y[test] != logit.predict(x[test])) {
        error++;
      }
    }

    System.out.println("Logistic Regression error = " + error);
    assertEquals(3, error);
  }
示例#3
0
  /**
   * Trains the two-class model with the L2 logistic loss.
   *
   * <p>Labels equal to 1 are recoded to {@code +1} and every other label to {@code -1}. The
   * initial score {@code b} is {@code 0.5 * log((1 + mean) / (1 - mean))} of the recoded labels.
   * Each of the {@code T} rounds does three things:
   *
   * <ul>
   *   <li>marks a random subsample of {@code round(n * f)} instances;
   *   <li>fits a regression tree ({@code J} leaves) to the pseudo-responses {@code 2y / (1 +
   *       exp(2yF))};
   *   <li>adds {@code shrinkage} times the tree's prediction to the running scores {@code F}.
   * </ul>
   *
   * @param attributes the attribute properties.
   * @param x the training instances.
   * @param y the class labels; only the value 1 is treated as the positive class.
   */
  private void train2(Attribute[] attributes, double[][] x, int[] y) {
    final int n = x.length;
    final int subsampleSize = (int) Math.round(n * f);

    // Recode the labels to {-1, +1}.
    int[] signs = new int[n];
    for (int i = 0; i < n; i++) {
      signs[i] = (y[i] == 1) ? 1 : -1;
    }

    // NOTE(review): if every recoded label is +1 (or every one is -1), mean is +/-1 and b becomes
    // infinite.
    double mean = Math.mean(signs);
    b = 0.5 * Math.log((1 + mean) / (1 - mean));

    double[] score = new double[n]; // current F(x_i)
    Arrays.fill(score, b);
    double[] pseudoResponse = new double[n]; // regression target of each tree

    int[][] order = SmileUtils.sort(attributes, x);
    // The node output reads pseudoResponse, so the same array must be refilled every round.
    RegressionTree.NodeOutput output = new L2NodeOutput(pseudoResponse);
    trees = new RegressionTree[T];

    int[] perm = new int[n];
    for (int i = 0; i < n; i++) {
      perm[i] = i;
    }
    int[] samples = new int[n];

    for (int round = 0; round < T; round++) {
      Arrays.fill(samples, 0);
      Math.permutate(perm);
      for (int i = 0; i < subsampleSize; i++) {
        samples[perm[i]] = 1;
      }

      for (int i = 0; i < n; i++) {
        pseudoResponse[i] = 2.0 * signs[i] / (1 + Math.exp(2 * signs[i] * score[i]));
      }

      trees[round] =
          new RegressionTree(attributes, x, pseudoResponse, J, order, samples, output);

      for (int i = 0; i < n; i++) {
        score[i] += shrinkage * trees[round].predict(x[i]);
      }
    }
  }
示例#4
0
文件: FLD.java 项目: grue/smile
  /**
   * Projects a single sample into the discriminant space.
   *
   * <p>The result is {@code Math.atx(scaling, x, ·)} (presumably {@code scaling' * x}) with
   * {@code Math.minus(·, smean)} applied, i.e. centered by the projected mean.
   *
   * @param x the input vector; must have length {@code p}.
   * @return a new vector of length {@code scaling[0].length}.
   * @throws IllegalArgumentException if {@code x.length != p}.
   */
  @Override
  public double[] project(double[] x) {
    if (x.length == p) {
      double[] projected = new double[scaling[0].length];
      Math.atx(scaling, x, projected);
      Math.minus(projected, smean);
      return projected;
    }

    throw new IllegalArgumentException(
        String.format("Invalid input vector size: %d, expected: %d", x.length, p));
  }
示例#5
0
  /**
   * Predicts the class label of an instance and writes the class posteriors.
   *
   * <p>Two classes: the score {@code F = b + shrinkage * sum(trees[i](x))} gives {@code P(0) = 1 /
   * (1 + exp(2F))} and {@code P(1) = 1 - P(0)}; label 1 is returned when {@code F > 0}. More
   * classes: each class score is {@code shrinkage * sum(forest[j][i](x))}. The posteriors are the
   * softmax of those scores, shifted by the maximum score before exponentiating. The class with
   * the largest score is returned.
   *
   * @param x the instance to classify.
   * @param posteriori output array of length {@code k} receiving the posterior probabilities.
   * @return the predicted class label.
   * @throws IllegalArgumentException if {@code posteriori.length != k}.
   */
  @Override
  public int predict(double[] x, double[] posteriori) {
    if (posteriori.length != k) {
      throw new IllegalArgumentException(
          String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
    }

    if (k == 2) {
      double score = b;
      for (int t = 0; t < T; t++) {
        score += shrinkage * trees[t].predict(x);
      }

      posteriori[0] = 1.0 / (1.0 + Math.exp(2 * score));
      posteriori[1] = 1.0 - posteriori[0];
      return score > 0 ? 1 : 0;
    }

    double max = Double.NEGATIVE_INFINITY;
    int best = -1;
    for (int c = 0; c < k; c++) {
      double sum = 0.0;
      for (int t = 0; t < T; t++) {
        sum += shrinkage * forest[c][t].predict(x);
      }
      posteriori[c] = sum;

      if (sum > max) {
        max = sum;
        best = c;
      }
    }

    // Softmax; subtracting the maximum keeps exp() from overflowing.
    double normalizer = 0.0;
    for (int c = 0; c < k; c++) {
      posteriori[c] = Math.exp(posteriori[c] - max);
      normalizer += posteriori[c];
    }
    for (int c = 0; c < k; c++) {
      posteriori[c] /= normalizer;
    }

    return best;
  }
示例#6
0
文件: FLD.java 项目: grue/smile
  /**
   * Projects a batch of samples into the discriminant space, row by row.
   *
   * <p>Each output row is {@code Math.atx(scaling, row, ·)} followed by {@code Math.minus(·,
   * smean)}.
   *
   * @param x the input vectors; every row must have length {@code p}.
   * @return one projected row of length {@code scaling[0].length} per input row.
   * @throws IllegalArgumentException if any row's length differs from {@code p}.
   */
  @Override
  public double[][] project(double[][] x) {
    final int rows = x.length;
    final double[][] projected = new double[rows][scaling[0].length];

    for (int r = 0; r < rows; r++) {
      double[] row = x[r];
      if (row.length != p) {
        throw new IllegalArgumentException(
            String.format("Invalid input vector size: %d, expected: %d", row.length, p));
      }

      double[] out = projected[r];
      Math.atx(scaling, row, out);
      Math.minus(out, smean);
    }

    return projected;
  }
示例#7
0
  /**
   * Test the model on a validation dataset.
   *
   * <p>Scores are accumulated tree by tree. After the first {@code t + 1} trees (or, with more
   * than two classes, the first {@code t + 1} trees of every class), each instance is labeled
   * and the accuracy is recorded. Two classes: label 1 when the score (starting at {@code b})
   * is positive. More classes: the arg-max of the class scores.
   *
   * @param x the test data set.
   * @param y the test data response values.
   * @return accuracies with first 1, 2, ..., decision trees.
   */
  public double[] test(double[][] x, int[] y) {
    final double[] accuracy = new double[T];

    final int n = x.length;
    final int[] label = new int[n];

    final Accuracy measure = new Accuracy();

    if (k == 2) {
      double[] score = new double[n];
      Arrays.fill(score, b);
      for (int t = 0; t < T; t++) {
        for (int j = 0; j < n; j++) {
          score[j] += shrinkage * trees[t].predict(x[j]);
          label[j] = score[j] > 0 ? 1 : 0;
        }
        accuracy[t] = measure.measure(y, label);
      }
      return accuracy;
    }

    double[][] score = new double[n][k];
    for (int t = 0; t < T; t++) {
      for (int j = 0; j < n; j++) {
        double[] classScores = score[j];
        for (int c = 0; c < k; c++) {
          classScores[c] += shrinkage * forest[c][t].predict(x[j]);
        }
        label[j] = Math.whichMax(classScores);
      }
      accuracy[t] = measure.measure(y, label);
    }
    return accuracy;
  }
示例#8
0
 /**
  * Push step of the push-relabel maximum-flow algorithm: moves excess flow from {@code u} to
  * {@code v}.
  *
  * <p>The amount sent is the smaller of {@code excess[u]} and the residual capacity {@code
  * graph[u][v] - flow[u][v]}. The flow is updated antisymmetrically, and that amount of excess
  * moves from {@code u} to {@code v}.
  */
 private void push(double[][] flow, double[] excess, int u, int v) {
   double residual = graph[u][v] - flow[u][v];
   double delta = Math.min(excess[u], residual);

   flow[u][v] += delta;
   flow[v][u] -= delta;

   excess[u] -= delta;
   excess[v] += delta;
 }
示例#9
0
 /**
  * Relabel step of the push-relabel maximum-flow algorithm.
  *
  * <p>Sets {@code height[u]} to one more than the lowest height among vertices reachable from
  * {@code u} through an edge with positive residual capacity. Self-loops are ignored. If
  * {@code u} has no such edge, its height is left unchanged. The search starts from {@code
  * 2 * n}, so the new height is at most {@code 2 * n + 1}.
  */
 private void relabel(double[][] flow, int[] height, int u) {
   int minHeight = 2 * n;
   boolean hasResidualEdge = false;
   for (int v = 0; v < n; v++) {
     // A self-loop does not lead to another vertex. Counting it used to feed u's own (possibly
     // already updated) height into the minimum.
     if (v != u && graph[u][v] - flow[u][v] > 0) {
       minHeight = Math.min(minHeight, height[v]);
       hasResidualEdge = true;
     }
   }

   // Assign once after the scan, rather than on every residual edge inside the loop.
   if (hasResidualEdge) {
     height[u] = minHeight + 1;
   }
 }
示例#10
0
  /**
   * Constructor.
   *
   * <p>Validates the parameters and stores them. It then does three things, in this order:
   *
   * <ul>
   *   <li>creates empty key and data lists;
   *   <li>fills {@code r1} and {@code r2} with {@code k} values each from {@code
   *       Math.randomInt(MAX_HASH_RND)}, drawn alternately;
   *   <li>creates {@code L} hash tables.
   * </ul>
   *
   * @param d the dimensionality of data.
   * @param L the number of hash tables.
   * @param k the number of random projection hash functions, which is usually set to log(N) where N
   *     is the dataset size.
   * @param w the width of random projections. It should be sufficiently away from 0. But we should
   *     not choose an w value that is too large, which will increase the query time.
   * @param H the size of universal hash tables.
   * @throws IllegalArgumentException if {@code d < 2}, {@code L < 1}, {@code k < 1}, {@code w <=
   *     0} or {@code H < 1}.
   */
  public LSH(int d, int L, int k, double w, int H) {
    if (d < 2) {
      throw new IllegalArgumentException("Invalid input space dimension: " + d);
    }
    if (L < 1) {
      throw new IllegalArgumentException("Invalid number of hash tables: " + L);
    }
    if (k < 1) {
      throw new IllegalArgumentException(
          "Invalid number of random projections per hash value: " + k);
    }
    if (w <= 0.0) {
      throw new IllegalArgumentException("Invalid width of random projections: " + w);
    }
    if (H < 1) {
      throw new IllegalArgumentException("Invalid size of hash tables: " + H);
    }

    // The fields must be set before any Hash is created.
    this.d = d;
    this.L = L;
    this.k = k;
    this.w = w;
    this.H = H;

    keys = new ArrayList<double[]>();
    data = new ArrayList<E>();

    r1 = new int[k];
    r2 = new int[k];
    for (int j = 0; j < k; j++) {
      r1[j] = Math.randomInt(MAX_HASH_RND);
      r2[j] = Math.randomInt(MAX_HASH_RND);
    }

    hash = new ArrayList<Hash>(L);
    while (hash.size() < L) {
      hash.add(new Hash());
    }
  }
示例#11
0
  /**
   * Constructor. Learns a gradient tree boosting for classification.
   *
   * <p>Two-class problems are trained by {@code train2}; problems with more classes are trained by
   * {@code traink}. Afterwards the {@code importance()} of every tree is summed into {@code
   * importance}.
   *
   * @param attributes the attribute properties; if {@code null}, every column of {@code x} is
   *     treated as a numeric attribute named {@code V1, V2, ...}.
   * @param x the training instances.
   * @param y the class labels; must be non-negative, and the number of classes is {@code max(y) +
   *     1}.
   * @param T the number of iterations (trees).
   * @param J the number of leaves in each tree.
   * @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
   * @param f the sampling fraction for stochastic tree boosting.
   * @throws IllegalArgumentException in any of these cases:
   *     <ul>
   *       <li>the sizes of {@code x} and {@code y} differ;
   *       <li>{@code x} is empty;
   *       <li>{@code T < 1};
   *       <li>{@code shrinkage} or {@code f} lies outside (0, 1];
   *       <li>a label is negative;
   *       <li>the largest label is below 1.
   *     </ul>
   */
  public GradientTreeBoost(
      Attribute[] attributes, double[][] x, int[] y, int T, int J, double shrinkage, double f) {
    if (x.length != y.length) {
      throw new IllegalArgumentException(
          String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }

    if (x.length == 0) {
      // Guards the x[0] access below.
      throw new IllegalArgumentException("Empty training data");
    }

    if (T < 1) {
      throw new IllegalArgumentException("Invalid number of trees: " + T);
    }

    if (shrinkage <= 0 || shrinkage > 1) {
      throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
    }

    if (f <= 0 || f > 1) {
      throw new IllegalArgumentException("Invalid sampling fraction: " + f);
    }

    // max(y) + 1 alone cannot detect negative labels (e.g. {-1, 1} gives k == 2), so check them
    // explicitly.
    for (int label : y) {
      if (label < 0) {
        throw new IllegalArgumentException("Negative class label: " + label);
      }
    }

    if (attributes == null) {
      int p = x[0].length;
      attributes = new Attribute[p];
      for (int i = 0; i < p; i++) {
        attributes[i] = new NumericAttribute("V" + (i + 1));
      }
    }

    this.T = T;
    this.J = J;
    this.shrinkage = shrinkage;
    this.f = f;
    this.k = Math.max(y) + 1;

    if (k < 2) {
      throw new IllegalArgumentException("Only one class or negative class labels.");
    }

    importance = new double[attributes.length];
    if (k == 2) {
      train2(attributes, x, y);
      for (RegressionTree tree : trees) {
        addImportance(tree);
      }
    } else {
      traink(attributes, x, y);
      for (RegressionTree[] grove : forest) {
        for (RegressionTree tree : grove) {
          addImportance(tree);
        }
      }
    }
  }

  /** Adds the variable importance of {@code tree} to the running totals in {@code importance}. */
  private void addImportance(RegressionTree tree) {
    double[] imp = tree.importance();
    for (int i = 0; i < imp.length; i++) {
      importance[i] += imp[i];
    }
  }
示例#12
0
  /**
   * Generates a random neighbor that differs in only one medoid from the current clusters.
   *
   * <p>A random medoid slot is overwritten with a random data point that is not already a medoid
   * (compared by reference), and the assignments are updated:
   *
   * <ul>
   *   <li>a point closer to the new medoid than to its current medoid moves to it;
   *   <li>a point that belonged to the replaced medoid, and is not closer to the new one, is
   *       reassigned to the nearest of all current medoids.
   * </ul>
   *
   * @param data the data points.
   * @param medoids the current medoids; one entry is overwritten in place.
   * @param y the cluster assignment of each point; updated in place.
   * @param d the distance of each point to its assigned medoid; updated in place.
   * @return the sum of {@code d} after the update.
   * @throws IllegalArgumentException if {@code data.length <= k}, when no non-medoid point can
   *     exist.
   */
  private double getRandomNeighbor(T[] data, T[] medoids, int[] y, double[] d) {
    int n = data.length;
    if (n <= k) {
      // Every point would already be a medoid, so the rejection loop below would never end.
      // NOTE(review): repeated references in data can still exhaust the candidates when n > k.
      throw new IllegalArgumentException(
          String.format("Too few data points (%d) for %d medoids", n, k));
    }

    int index = Math.randomInt(k);
    T medoid;
    boolean dup;
    do {
      dup = false;
      medoid = data[Math.randomInt(n)];
      for (int i = 0; i < k; i++) {
        if (medoid == medoids[i]) {
          dup = true;
          break;
        }
      }
    } while (dup);

    medoids[index] = medoid;

    for (int i = 0; i < n; i++) {
      double dist = distance.d(data[i], medoid);
      if (d[i] > dist) {
        // The new medoid is closer than the point's current medoid.
        y[i] = index;
        d[i] = dist;
      } else if (y[i] == index) {
        // The point's old medoid was replaced. Start from the new medoid (y[i] already equals
        // index) and look for a nearer one among the others.
        d[i] = dist;
        for (int j = 0; j < k; j++) {
          if (j != index) {
            dist = distance.d(data[i], medoids[j]);
            if (d[i] > dist) {
              y[i] = j;
              d[i] = dist;
            }
          }
        }
      }
    }

    return Math.sum(d);
  }
示例#13
0
  /**
   * Constructor. Sizes the index from the data and inserts every (key, data) pair.
   *
   * <p>The dimensionality is taken from {@code keys[0]}. With {@code n = keys.length}, the number
   * of hash tables is {@code max(50, (int) n^0.25)} and the number of random projections per hash
   * value is {@code max(3, (int) log10(n))}.
   *
   * @param keys the keys of data objects.
   * @param data the data objects.
   * @param w the width of random projections. It should be sufficiently away from 0. But we should
   *     not choose an w value that is too large, which will increase the query time.
   * @param H the size of universal hash tables; must be at least {@code keys.length}.
   * @throws IllegalArgumentException if {@code keys} and {@code data} differ in length, if the
   *     hash table size is smaller than the number of keys, or if the delegated constructor
   *     rejects a derived parameter.
   */
  public LSH(double[][] keys, E[] data, double w, int H) {
    this(
        keys[0].length,
        Math.max(50, (int) Math.pow(keys.length, 0.25)),
        Math.max(3, (int) Math.log10(keys.length)),
        w,
        H);

    if (keys.length != data.length) {
      throw new IllegalArgumentException("The array size of keys and data are different.");
    }
    if (H < keys.length) {
      throw new IllegalArgumentException("Hash table size is too small: " + H);
    }

    int position = 0;
    for (double[] key : keys) {
      put(key, data[position++]);
    }
  }
示例#14
0
  /**
   * Test of learn method, of class FPGrowth.
   *
   * <p>Mines the kosarak transaction data with a minimum support of 1500 and checks the number of
   * frequent item sets. Duplicate items within a transaction are removed before mining. A missing
   * or unreadable data file fails the test.
   */
  @Test
  public void testKosarak() {
    System.out.println("kosarak");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    InputStream stream = getClass().getResourceAsStream("/smile/data/transaction/kosarak.dat");
    if (stream == null) {
      throw new AssertionError("Missing test resource /smile/data/transaction/kosarak.dat");
    }

    // try-with-resources closes the reader and the underlying stream; the original leaked them.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(stream))) {
      String line;
      while ((line = input.readLine()) != null) {
        if (line.trim().isEmpty()) {
          continue;
        }

        String[] s = line.split(" ");

        Set<Integer> items = new HashSet<Integer>();
        for (String token : s) {
          items.add(Integer.parseInt(token));
        }

        int j = 0;
        int[] point = new int[items.size()];
        for (int item : items) {
          point[j++] = item;
        }
        dataList.add(point);
      }
    } catch (IOException ex) {
      // Mining a partially read data set would only produce a confusing count mismatch.
      throw new AssertionError("Failed to read kosarak.dat", ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    long time = System.currentTimeMillis();
    FPGrowth fpgrowth = new FPGrowth(data, 1500);
    System.out.format(
        "Done building FP-tree: %.2f secs.\n", (System.currentTimeMillis() - time) / 1000.0);

    time = System.currentTimeMillis();
    List<ItemSet> results = fpgrowth.learn();
    System.out.format(
        "%d frequent item sets discovered: %.2f secs.\n",
        results.size(), (System.currentTimeMillis() - time) / 1000.0);

    assertEquals(219725, results.size());
  }
示例#15
0
  /**
   * Evaluates the kernel as {@code sum_i sqrt(x[i] * y[i])}.
   *
   * @param x the first vector.
   * @param y the second vector; must have the same length as {@code x}.
   * @return the kernel value.
   * @throws IllegalArgumentException if the arrays have different lengths.
   */
  @Override
  public double k(double[] x, double[] y) {
    if (x.length != y.length) {
      throw new IllegalArgumentException(
          String.format("Arrays have different length: x[%d], y[%d]", x.length, y.length));
    }

    double total = 0;
    int i = 0;
    while (i < x.length) {
      total += Math.sqrt(x[i] * y[i]);
      i++;
    }
    return total;
  }
示例#16
0
    /**
     * Returns the hash value of given vector x.
     *
     * <p>Computes {@code floor((b[m] + a[m] . x) / w)} over the first {@code d} coordinates. A
     * negative result is shifted up by {@code Integer.MAX_VALUE}.
     *
     * @param x the vector to be hashed.
     * @param m the m-<i>th</i> hash function to be employed.
     * @return the hash value.
     */
    int hash(double[] x, int m) {
      double projection = b[m];
      double[] weights = a[m];
      for (int j = 0; j < d; j++) {
        projection += weights[j] * x[j];
      }

      int bucket = (int) Math.floor(projection / w);
      // NOTE(review): the double-to-int cast saturates. A bucket of Integer.MIN_VALUE becomes -1
      // after the shift and stays negative.
      return bucket < 0 ? bucket + Integer.MAX_VALUE : bucket;
    }
示例#17
0
  /**
   * Returns a summary with the CLARANS distortion and, for each cluster, its size and its share of
   * all data points as a percentage with one decimal place.
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(String.format("CLARANS distortion: %.5f%n", distortion))
        .append(String.format("Clusters of %d data points:%n", y.length));

    for (int cluster = 0; cluster < k; cluster++) {
      // Share in tenths of a percent, printed as "<whole>.<tenth>%".
      int permille = (int) Math.round(1000.0 * size[cluster] / y.length);
      sb.append(
          String.format(
              "%3d\t%5d (%2d.%1d%%)%n", cluster, size[cluster], permille / 10, permille % 10));
    }

    return sb.toString();
  }
示例#18
0
    /**
     * Computes a leaf output as {@code sum(r) / sum(|r| * (2 - |r|))} over the instances in the
     * node, where {@code r} are the values in {@code y} (presumably the pseudo-responses).
     *
     * <p>When the denominator is below {@code 1E-10}, the mean of {@code r} over the node is
     * returned instead of dividing by (near) zero, which would give an infinite or NaN leaf value.
     * An empty node yields 0.
     *
     * @param samples per-instance sample counts; only instances with a positive count are used.
     * @return the leaf output value.
     */
    @Override
    public double calculate(int[] samples) {
      int n = 0;
      double nu = 0.0;
      double de = 0.0;
      for (int i = 0; i < samples.length; i++) {
        if (samples[i] > 0) {
          n++;
          double abs = Math.abs(y[i]);
          nu += y[i];
          de += abs * (2.0 - abs);
        }
      }

      if (de < 1E-10) {
        return n > 0 ? nu / n : 0.0;
      }

      return nu / de;
    }
示例#19
0
文件: Label.java 项目: grue/smile
  /**
   * Converts a coordinate to a string of the form {@code (c0,c1,...)}. Each component is passed
   * through {@code Math.round(value, 2)} (presumably rounding to two decimal places).
   *
   * <p>The original replaced the trailing comma with {@code setCharAt(builder.length(), ')')},
   * which is out of range and threw for every non-empty input. Separators are now written only
   * between components.
   *
   * @param c the coordinate components.
   * @return the formatted coordinate; {@code "()"} when {@code c} is empty.
   */
  public static String coordToString(double... c) {
    StringBuilder builder = new StringBuilder("(");
    for (int i = 0; i < c.length; i++) {
      if (i > 0) {
        builder.append(',');
      }
      builder.append(Math.round(c[i], 2));
    }
    return builder.append(')').toString();
  }
示例#20
0
  /**
   * Runs spectral clustering with {@code k} clusters and prints timing and diagnostics.
   *
   * <p>If the smallest eigenvalue of the resulting decomposition exceeds 0.3, it stores the
   * clustering in {@code result} and calls {@code cluster(data, k + 1)}.
   *
   * @param data the data points.
   * @param k the number of clusters to try.
   */
  private void _cluster(double[][] data, int k) {
    long clock = System.currentTimeMillis();
    // NOTE(review): 0.355 is SpectralClustering's third argument (presumably the kernel width);
    // its origin is not documented here.
    SpectralClustering cluster = new SpectralClustering(data, k, 0.355);
    // The original message said "DBSCAN", but this times spectral clustering.
    System.out.format(
        "Spectral clustering of %d samples in %dms\n",
        data.length, System.currentTimeMillis() - clock);
    System.out.println("getNumClusters:" + cluster.getNumClusters());
    // Print the sizes rather than the array's identity hash.
    System.out.println("getClusterSize:" + java.util.Arrays.toString(cluster.getClusterSize()));
    System.out.println("toString:" + cluster.toString());

    EigenValueDecomposition eigen = cluster.getEigen();
    double[] eigenValues = eigen.getEigenValues();
    double sd = smile.math.Math.sd(eigenValues);
    System.out.println("sd(eigen.getEigenValues()):" + sd);

    if (Math.min(eigenValues) > 0.3) {
      result = cluster;
      cluster(data, k + 1);
    }
  }
示例#21
0
  /**
   * Test of learn method, of class FPGrowth.
   *
   * <p>Mines the pima transaction data with a minimum support of 20. It checks that both the
   * {@code PrintStream} variant and the list-returning variant of {@code learn} report the
   * expected number of item sets. A missing or unreadable data file fails the test.
   */
  @Test
  public void testPima() {
    System.out.println("pima");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    InputStream stream =
        getClass().getResourceAsStream("/smile/data/transaction/pima.D38.N768.C2");
    if (stream == null) {
      throw new AssertionError("Missing test resource /smile/data/transaction/pima.D38.N768.C2");
    }

    // try-with-resources closes the reader and the underlying stream; the original leaked them.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(stream))) {
      String line;
      while ((line = input.readLine()) != null) {
        if (line.trim().isEmpty()) {
          continue;
        }

        String[] s = line.split(" ");

        int[] point = new int[s.length];
        for (int i = 0; i < s.length; i++) {
          point[i] = Integer.parseInt(s[i]);
        }

        dataList.add(point);
      }
    } catch (IOException ex) {
      // Mining a partially read data set would only produce a confusing count mismatch.
      throw new AssertionError("Failed to read pima.D38.N768.C2", ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    long time = System.currentTimeMillis();
    FPGrowth fpgrowth = new FPGrowth(data, 20);
    System.out.format(
        "Done building FP-tree: %.2f secs.\n", (System.currentTimeMillis() - time) / 1000.0);

    time = System.currentTimeMillis();
    long numItemsets = fpgrowth.learn(System.out);
    System.out.format(
        "%d frequent item sets discovered: %.2f secs.\n",
        numItemsets, (System.currentTimeMillis() - time) / 1000.0);

    assertEquals(1803, numItemsets);
    assertEquals(1803, fpgrowth.learn().size());
  }
示例#22
0
    /**
     * Constructor. Builds the random projections and allocates the table.
     *
     * <p>Each of the {@code k} rows of {@code a} gets {@code d} draws from {@code
     * GaussianDistribution.getInstance()}. After its row is filled, each {@code b[i]} gets {@code
     * Math.random(0, w)}. Finally an empty table of {@code H} entries is allocated.
     */
    Hash() {
      a = new double[k][d];
      b = new double[k];

      GaussianDistribution gaussian = GaussianDistribution.getInstance();
      for (int row = 0; row < k; row++) {
        double[] projection = a[row];
        for (int col = 0; col < d; col++) {
          projection[col] = gaussian.rand();
        }
        // Drawn after the row so the random sequence matches the original order.
        b[row] = Math.random(0, w);
      }

      table = new HashEntry[H];
    }
示例#23
0
文件: ARMTest.java 项目: myui/smile
  /**
   * Test of learn method, of class ARM.
   *
   * <p>Mines association rules from the kosarak transaction data with a minimum support of 0.003
   * and a minimum confidence of 0.5. Duplicate items within a transaction are removed first. A
   * missing or unreadable data file fails the test.
   */
  @Test
  public void testLearnKosarak() {
    System.out.println("kosarak");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    InputStream stream = getClass().getResourceAsStream("/smile/data/transaction/kosarak.dat");
    if (stream == null) {
      throw new AssertionError("Missing test resource /smile/data/transaction/kosarak.dat");
    }

    // try-with-resources closes the reader and the underlying stream; the original leaked them.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(stream))) {
      String line;
      while ((line = input.readLine()) != null) {
        if (line.trim().isEmpty()) {
          continue;
        }

        String[] s = line.split(" ");

        Set<Integer> items = new HashSet<Integer>();
        for (String token : s) {
          items.add(Integer.parseInt(token));
        }

        int j = 0;
        int[] point = new int[items.size()];
        for (int item : items) {
          point[j++] = item;
        }

        dataList.add(point);
      }
    } catch (IOException ex) {
      // Mining a partially read data set would only produce a confusing count mismatch.
      throw new AssertionError("Failed to read kosarak.dat", ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    ARM instance = new ARM(data, 0.003);
    long numRules = instance.learn(0.5, System.out);
    System.out.format("%d association rules discovered\n", numRules);
    assertEquals(17932, numRules);
  }
示例#24
0
  /**
   * Appends to {@code neighbors} every candidate key within {@code radius} of {@code q}.
   *
   * <p>Only the indices returned by {@code obtainCandidates(q)} are examined. When {@code
   * identicalExcluded} is set, a key that is the very same array object as {@code q} is skipped.
   *
   * @param q the query key.
   * @param radius the inclusive distance threshold; must be positive.
   * @param neighbors the list that receives the matching neighbors.
   * @throws IllegalArgumentException if {@code radius <= 0}.
   */
  @Override
  public void range(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
    if (radius <= 0.0) {
      throw new IllegalArgumentException("Invalid radius: " + radius);
    }

    for (int index : obtainCandidates(q)) {
      double[] key = keys.get(index);
      boolean isQueryItself = identicalExcluded && key == q;
      if (!isQueryItself) {
        double dist = Math.distance(q, key);
        if (dist <= radius) {
          neighbors.add(new Neighbor<double[], E>(key, data.get(index), index, dist));
        }
      }
    }
  }
示例#25
0
  /**
   * Returns up to {@code k} nearest neighbors of {@code q} among the candidates returned by {@code
   * obtainCandidates(q)}.
   *
   * <p>The selection works as follows:
   *
   * <ul>
   *   <li>a {@code HeapSelect} over a {@code k}-sized array is pre-filled with a sentinel neighbor
   *       at distance {@code Double.MAX_VALUE};
   *   <li>a candidate is offered to the heap only if it is closer than the heap's current top;
   *   <li>if fewer than {@code k} candidates were accepted, a shorter array holding only the
   *       accepted entries is returned.
   * </ul>
   *
   * When {@code identicalExcluded} is set, a key that is the very same array object as {@code q}
   * is skipped.
   *
   * @param q the query key.
   * @param k the number of neighbors to search for.
   * @return the neighbors, in the order left by {@code HeapSelect.sort()}.
   * @throws IllegalArgumentException if {@code k < 1}.
   */
  @Override
  public Neighbor<double[], E>[] knn(double[] q, int k) {
    if (k < 1) {
      throw new IllegalArgumentException("Invalid k: " + k);
    }

    Set<Integer> candidates = obtainCandidates(q);

    Neighbor<double[], E> sentinel = new Neighbor<double[], E>(null, null, 0, Double.MAX_VALUE);
    @SuppressWarnings("unchecked")
    Neighbor<double[], E>[] neighbors =
        (Neighbor<double[], E>[]) java.lang.reflect.Array.newInstance(sentinel.getClass(), k);
    HeapSelect<Neighbor<double[], E>> heap = new HeapSelect<Neighbor<double[], E>>(neighbors);
    for (int i = 0; i < k; i++) {
      heap.add(sentinel);
    }

    int hits = 0;
    for (int index : candidates) {
      double[] key = keys.get(index);
      if (identicalExcluded && key == q) {
        continue;
      }

      double dist = Math.distance(q, key);
      if (dist < heap.peek().distance) {
        heap.add(new Neighbor<double[], E>(key, data.get(index), index, dist));
        hits++;
      }
    }

    heap.sort();

    if (hits >= k) {
      return neighbors;
    }

    // NOTE(review): assumes heap.sort() leaves the k - hits remaining sentinels at the front of
    // the array, so the accepted neighbors are the trailing hits entries. Confirm against
    // HeapSelect.
    @SuppressWarnings("unchecked")
    Neighbor<double[], E>[] found =
        (Neighbor<double[], E>[]) java.lang.reflect.Array.newInstance(sentinel.getClass(), hits);
    System.arraycopy(neighbors, k - hits, found, 0, hits);
    return found;
  }
示例#26
0
  /**
   * Returns the closest candidate to {@code q} among the indices returned by {@code
   * obtainCandidates(q)}.
   *
   * <p>When {@code identicalExcluded} is set, a key that is the very same array object as {@code
   * q} is skipped. If no candidate qualifies, the result has index -1, null key and value, and
   * distance {@code Double.MAX_VALUE}.
   *
   * @param q the query key.
   * @return the nearest neighbor found.
   */
  @Override
  public Neighbor<double[], E> nearest(double[] q) {
    Set<Integer> candidates = obtainCandidates(q);
    Neighbor<double[], E> best = new Neighbor<double[], E>(null, null, -1, Double.MAX_VALUE);

    for (int index : candidates) {
      double[] key = keys.get(index);
      if (identicalExcluded && key == q) {
        continue;
      }

      double dist = Math.distance(q, key);
      if (dist >= best.distance) {
        continue;
      }

      best.index = index;
      best.distance = dist;
      best.key = key;
      best.value = data.get(index);
    }

    return best;
  }
示例#27
0
    /**
     * Computes a leaf output as {@code ((k - 1) / k) * sum(r) / sum(|r| * (1 - |r|))} over the
     * instances in the node, where {@code r} are the values in {@code y} (presumably the
     * pseudo-responses).
     *
     * <p>If the denominator is below {@code 1E-10}, the plain mean of {@code r} over the node is
     * returned instead.
     *
     * @param samples per-instance sample counts; only instances with a positive count are used.
     * @return the leaf output value (NaN when no instance has a positive count).
     */
    @Override
    public double calculate(int[] samples) {
      int count = 0;
      double numerator = 0.0;
      double denominator = 0.0;
      for (int i = 0; i < samples.length; i++) {
        if (samples[i] <= 0) {
          continue;
        }
        count++;
        double magnitude = Math.abs(y[i]);
        numerator += y[i];
        denominator += magnitude * (1.0 - magnitude);
      }

      return denominator < 1E-10
          ? numerator / count
          : ((k - 1.0) / k) * (numerator / denominator);
    }
示例#28
0
文件: ARMTest.java 项目: myui/smile
  /**
   * Test of learn method, of class ARM.
   *
   * <p>Mines association rules from the pima transaction data with a minimum support of 20 and a
   * minimum confidence of 0.9. It checks both the {@code PrintStream} variant and the
   * list-returning variant of {@code learn}. A missing or unreadable data file fails the test.
   */
  @Test
  public void testLearnPima() {
    System.out.println("pima");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    InputStream stream =
        getClass().getResourceAsStream("/smile/data/transaction/pima.D38.N768.C2");
    if (stream == null) {
      throw new AssertionError("Missing test resource /smile/data/transaction/pima.D38.N768.C2");
    }

    // try-with-resources closes the reader and the underlying stream; the original leaked them.
    try (BufferedReader input = new BufferedReader(new InputStreamReader(stream))) {
      String line;
      while ((line = input.readLine()) != null) {
        if (line.trim().isEmpty()) {
          continue;
        }

        String[] s = line.split(" ");

        int[] point = new int[s.length];
        for (int i = 0; i < s.length; i++) {
          point[i] = Integer.parseInt(s[i]);
        }

        dataList.add(point);
      }
    } catch (IOException ex) {
      // Mining a partially read data set would only produce a confusing count mismatch.
      throw new AssertionError("Failed to read pima.D38.N768.C2", ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    ARM instance = new ARM(data, 20);
    long numRules = instance.learn(0.9, System.out);
    System.out.format("%d association rules discovered\n", numRules);
    assertEquals(6803, numRules);
    assertEquals(6803, instance.learn(0.9).size());
  }
示例#29
0
文件: FLD.java 项目: grue/smile
  /**
   * Predicts the class whose projected mean ({@code smu[i]}) is nearest to the projection of
   * {@code x}. On ties the lowest class index wins; 0 is returned if no distance is finite.
   *
   * @param x the instance; must have length {@code p}.
   * @return the predicted class label.
   * @throws IllegalArgumentException if {@code x.length != p}.
   */
  @Override
  public int predict(double[] x) {
    if (x.length != p) {
      throw new IllegalArgumentException(
          String.format("Invalid input vector size: %d, expected: %d", x.length, p));
    }

    double[] projected = project(x);

    int label = 0;
    double best = Double.POSITIVE_INFINITY;
    for (int c = 0; c < k; c++) {
      double dist = Math.distance(projected, smu[c]);
      if (dist < best) {
        best = dist;
        label = c;
      }
    }

    return label;
  }
示例#30
0
  /**
   * Test the model on a validation dataset.
   *
   * <p>Scores are accumulated tree by tree. After the first {@code t + 1} trees (or, with more
   * than two classes, the first {@code t + 1} trees of every class), each instance is labeled and
   * every measure is evaluated. Two classes: label 1 when the score (starting at {@code b}) is
   * positive. More classes: the arg-max of the class scores.
   *
   * @param x the test data set.
   * @param y the test data labels.
   * @param measures the performance measures of classification.
   * @return performance measures with first 1, 2, ..., decision trees.
   */
  public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
    final int m = measures.length;
    final double[][] results = new double[T][m];

    final int n = x.length;
    final int[] label = new int[n];

    if (k == 2) {
      double[] score = new double[n];
      Arrays.fill(score, b);
      for (int t = 0; t < T; t++) {
        for (int j = 0; j < n; j++) {
          score[j] += shrinkage * trees[t].predict(x[j]);
          label[j] = score[j] > 0 ? 1 : 0;
        }

        double[] row = results[t];
        for (int j = 0; j < m; j++) {
          row[j] = measures[j].measure(y, label);
        }
      }
      return results;
    }

    double[][] score = new double[n][k];
    for (int t = 0; t < T; t++) {
      for (int j = 0; j < n; j++) {
        double[] classScores = score[j];
        for (int c = 0; c < k; c++) {
          classScores[c] += shrinkage * forest[c][t].predict(x[j]);
        }
        label[j] = Math.whichMax(classScores);
      }

      double[] row = results[t];
      for (int j = 0; j < m; j++) {
        row[j] = measures[j].measure(y, label);
      }
    }
    return results;
  }