コード例 #1
0
ファイル: DecisionTree.java プロジェクト: debmalyaroy/smile
  /**
   * Computes the impurity of a node according to the configured split rule.
   *
   * <p>Classes whose count is not positive are skipped and contribute nothing to the measure. If
   * the rule matches none of the cases handled below, the result is 0.0.
   *
   * @param count the sample count in each class.
   * @param n the number of samples in the node.
   * @return the impurity of the node.
   */
  private double impurity(int[] count, int n) {
    double result = 0.0;

    switch (rule) {
      case GINI:
        // Start from 1 and subtract the squared proportion of every non-empty class.
        result = 1.0;
        for (int c : count) {
          if (c <= 0) {
            continue;
          }
          double fraction = (double) c / n;
          result -= fraction * fraction;
        }
        break;

      case ENTROPY:
        // Accumulate -p * log2(p) over every non-empty class.
        for (int c : count) {
          if (c <= 0) {
            continue;
          }
          double fraction = (double) c / n;
          result -= fraction * Math.log2(fraction);
        }
        break;

      case CLASSIFICATION_ERROR:
        {
          // One minus the largest class proportion.
          double largest = 0;
          for (int c : count) {
            if (c > 0) {
              largest = Math.max(largest, c / (double) n);
            }
          }
          result = Math.abs(1 - largest);
          break;
        }
    }

    return result;
  }
コード例 #2
0
ファイル: GradientTreeBoost.java プロジェクト: nkhuyu/smile
  /**
   * Constructor. Learns a gradient tree boosting for classification.
   *
   * <p>When {@code attributes} is null, one numeric attribute named {@code V1}, {@code V2}, ... is
   * created per column of the first training instance. The number of classes is taken as
   * {@code Math.max(y) + 1}. With two classes the model is trained by {@code train2}, otherwise
   * by {@code traink}. After training, the variable importance is the sum of the importance
   * vectors of all trained trees.
   *
   * @param attributes the attribute properties, or null to treat every column as numeric.
   * @param x the training instances.
   * @param y the class labels.
   * @param T the number of iterations (trees).
   * @param J the number of leaves in each tree.
   * @param shrinkage the shrinkage parameter in (0, 1] controls the learning rate of procedure.
   * @param f the sampling fraction for stochastic tree boosting, in (0, 1].
   * @throws IllegalArgumentException if the sizes of {@code x} and {@code y} differ, if the
   *     training set is empty, if {@code shrinkage} or {@code f} is NaN or outside (0, 1], or if
   *     fewer than two classes are derived from {@code y}.
   */
  public GradientTreeBoost(
      Attribute[] attributes, double[][] x, int[] y, int T, int J, double shrinkage, double f) {
    if (x.length != y.length) {
      throw new IllegalArgumentException(
          String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }

    // Fail with a clear message instead of an index error on x[0] / the label scan below.
    if (x.length == 0) {
      throw new IllegalArgumentException("Empty training set.");
    }

    // Written as a negated in-range test so that NaN (for which every comparison is false)
    // is rejected along with out-of-range values.
    if (!(shrinkage > 0 && shrinkage <= 1)) {
      throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
    }

    if (!(f > 0 && f <= 1)) {
      throw new IllegalArgumentException("Invalid sampling fraction: " + f);
    }

    if (attributes == null) {
      // No schema supplied: synthesize numeric attributes V1..Vp from the first instance width.
      int p = x[0].length;
      attributes = new Attribute[p];
      for (int i = 0; i < p; i++) {
        attributes[i] = new NumericAttribute("V" + (i + 1));
      }
    }

    this.T = T;
    this.J = J;
    this.shrinkage = shrinkage;
    this.f = f;
    this.k = Math.max(y) + 1;

    if (k < 2) {
      throw new IllegalArgumentException("Only one class or negative class labels.");
    }

    importance = new double[attributes.length];
    if (k == 2) {
      train2(attributes, x, y);
      // Binary case: a single sequence of trees.
      for (RegressionTree tree : trees) {
        double[] imp = tree.importance();
        for (int i = 0; i < imp.length; i++) {
          importance[i] += imp[i];
        }
      }
    } else {
      traink(attributes, x, y);
      // Multi-class case: sum the importance over every tree of every group in the forest.
      for (RegressionTree[] grove : forest) {
        for (RegressionTree tree : grove) {
          double[] imp = tree.importance();
          for (int i = 0; i < imp.length; i++) {
            importance[i] += imp[i];
          }
        }
      }
    }
  }
コード例 #3
0
ファイル: LSH.java プロジェクト: changjiashuai/smile
  /**
   * Constructor.
   *
   * @param keys the keys of data objects.
   * @param data the data objects.
   * @param w the width of random projections. It should be sufficiently away from 0. But we should
   *     not choose an w value that is too large, which will increase the query time.
   * @param H the size of universal hash tables.
   */
  public LSH(double[][] keys, E[] data, double w, int H) {
    this(
        keys[0].length,
        Math.max(50, (int) Math.pow(keys.length, 0.25)),
        Math.max(3, (int) Math.log10(keys.length)),
        w,
        H);

    if (keys.length != data.length) {
      throw new IllegalArgumentException("The array size of keys and data are different.");
    }

    if (H < keys.length) {
      throw new IllegalArgumentException("Hash table size is too small: " + H);
    }

    int n = keys.length;
    for (int i = 0; i < n; i++) {
      put(keys[i], data[i]);
    }
  }
コード例 #4
0
ファイル: FPGrowthTest.java プロジェクト: myui/smile
  /**
   * Test of learn method, of class FPGrowth, on the kosarak transaction data.
   *
   * <p>Each non-blank line of the resource is split on single spaces and parsed as integers;
   * duplicate items within a line are removed before the transaction is stored. The reader is
   * always closed once loading ends. A read failure is only reported on standard error, and the
   * assertion below then runs against whatever was loaded.
   */
  @Test
  public void testKosarak() {
    System.out.println("kosarak");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    try {
      InputStream stream = getClass().getResourceAsStream("/smile/data/transaction/kosarak.dat");
      BufferedReader input = new BufferedReader(new InputStreamReader(stream));

      try {
        String line;
        while ((line = input.readLine()) != null) {
          if (line.trim().isEmpty()) {
            continue;
          }

          String[] s = line.split(" ");

          // De-duplicate the items of one transaction.
          Set<Integer> items = new HashSet<Integer>();
          for (int i = 0; i < s.length; i++) {
            items.add(Integer.parseInt(s[i]));
          }

          int j = 0;
          int[] point = new int[items.size()];
          for (int i : items) {
            point[j++] = i;
          }
          dataList.add(point);
        }
      } finally {
        // Release the reader (and the wrapped resource stream) even if parsing fails.
        input.close();
      }
    } catch (IOException ex) {
      System.err.println(ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    long time = System.currentTimeMillis();
    FPGrowth fpgrowth = new FPGrowth(data, 1500);
    System.out.format(
        "Done building FP-tree: %.2f secs.\n", (System.currentTimeMillis() - time) / 1000.0);

    time = System.currentTimeMillis();
    List<ItemSet> results = fpgrowth.learn();
    System.out.format(
        "%d frequent item sets discovered: %.2f secs.\n",
        results.size(), (System.currentTimeMillis() - time) / 1000.0);

    assertEquals(219725, results.size());
  }
コード例 #5
0
ファイル: FPGrowthTest.java プロジェクト: myui/smile
  /**
   * Test of learn method, of class FPGrowth, on the pima transaction data.
   *
   * <p>Each non-blank line of the resource is split on single spaces and parsed into one integer
   * transaction. The reader is always closed once loading ends. A read failure is only reported
   * on standard error, and the assertions below then run against whatever was loaded.
   */
  @Test
  public void testPima() {
    System.out.println("pima");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    try {
      InputStream stream =
          getClass().getResourceAsStream("/smile/data/transaction/pima.D38.N768.C2");
      BufferedReader input = new BufferedReader(new InputStreamReader(stream));

      try {
        String line;
        while ((line = input.readLine()) != null) {
          if (line.trim().isEmpty()) {
            continue;
          }

          String[] s = line.split(" ");

          int[] point = new int[s.length];
          for (int i = 0; i < s.length; i++) {
            point[i] = Integer.parseInt(s[i]);
          }

          dataList.add(point);
        }
      } finally {
        // Release the reader (and the wrapped resource stream) even if parsing fails.
        input.close();
      }
    } catch (IOException ex) {
      System.err.println(ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    long time = System.currentTimeMillis();
    FPGrowth fpgrowth = new FPGrowth(data, 20);
    System.out.format(
        "Done building FP-tree: %.2f secs.\n", (System.currentTimeMillis() - time) / 1000.0);

    time = System.currentTimeMillis();
    long numItemsets = fpgrowth.learn(System.out);
    System.out.format(
        "%d frequent item sets discovered: %.2f secs.\n",
        numItemsets, (System.currentTimeMillis() - time) / 1000.0);

    // Both the streaming and the list-returning variants must report the same count.
    assertEquals(1803, numItemsets);
    assertEquals(1803, fpgrowth.learn().size());
  }
コード例 #6
0
ファイル: ARMTest.java プロジェクト: myui/smile
  /**
   * Test of learn method, of class ARM, on the kosarak transaction data.
   *
   * <p>Each non-blank line of the resource is split on single spaces and parsed as integers;
   * duplicate items within a line are removed before the transaction is stored. The reader is
   * always closed once loading ends. A read failure is only reported on standard error, and the
   * assertion below then runs against whatever was loaded.
   */
  @Test
  public void testLearnKosarak() {
    System.out.println("kosarak");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    try {
      InputStream stream = getClass().getResourceAsStream("/smile/data/transaction/kosarak.dat");
      BufferedReader input = new BufferedReader(new InputStreamReader(stream));

      try {
        String line;
        while ((line = input.readLine()) != null) {
          if (line.trim().isEmpty()) {
            continue;
          }

          String[] s = line.split(" ");

          // De-duplicate the items of one transaction.
          Set<Integer> items = new HashSet<Integer>();
          for (int i = 0; i < s.length; i++) {
            items.add(Integer.parseInt(s[i]));
          }

          int j = 0;
          int[] point = new int[items.size()];
          for (int i : items) {
            point[j++] = i;
          }

          dataList.add(point);
        }
      } finally {
        // Release the reader (and the wrapped resource stream) even if parsing fails.
        input.close();
      }
    } catch (IOException ex) {
      System.err.println(ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    ARM instance = new ARM(data, 0.003);
    long numRules = instance.learn(0.5, System.out);
    System.out.format("%d association rules discovered\n", numRules);
    assertEquals(17932, numRules);
  }
コード例 #7
0
ファイル: ARMTest.java プロジェクト: myui/smile
  /**
   * Test of learn method, of class ARM, on the pima transaction data.
   *
   * <p>Each non-blank line of the resource is split on single spaces and parsed into one integer
   * transaction. The reader is always closed once loading ends. A read failure is only reported
   * on standard error, and the assertions below then run against whatever was loaded.
   */
  @Test
  public void testLearnPima() {
    System.out.println("pima");

    List<int[]> dataList = new ArrayList<int[]>(1000);

    try {
      InputStream stream =
          getClass().getResourceAsStream("/smile/data/transaction/pima.D38.N768.C2");
      BufferedReader input = new BufferedReader(new InputStreamReader(stream));

      try {
        String line;
        while ((line = input.readLine()) != null) {
          if (line.trim().isEmpty()) {
            continue;
          }

          String[] s = line.split(" ");

          int[] point = new int[s.length];
          for (int i = 0; i < s.length; i++) {
            point[i] = Integer.parseInt(s[i]);
          }

          dataList.add(point);
        }
      } finally {
        // Release the reader (and the wrapped resource stream) even if parsing fails.
        input.close();
      }
    } catch (IOException ex) {
      System.err.println(ex);
    }

    int[][] data = dataList.toArray(new int[dataList.size()][]);

    int n = Math.max(data);
    System.out.format("%d transactions, %d items\n", data.length, n);

    ARM instance = new ARM(data, 20);
    long numRules = instance.learn(0.9, System.out);
    System.out.format("%d association rules discovered\n", numRules);

    // Both the streaming and the list-returning variants must report the same count.
    assertEquals(6803, numRules);
    assertEquals(6803, instance.learn(0.9).size());
  }
コード例 #8
0
ファイル: CLARANS.java プロジェクト: debmalyaroy/smile
 /**
  * Constructor. Clustering data into k clusters.
  *
  * <p>Delegates to the five-argument constructor, passing
  * {@code Math.max(2, MulticoreExecutor.getThreadPoolSize())} as the fifth argument.
  * NOTE(review): the fifth argument is presumably the number of local searches to run, floored
  * at 2 and otherwise following the thread pool size — confirm against the five-argument
  * constructor.
  *
  * @param data the dataset for clustering.
  * @param distance the distance/dissimilarity measure.
  * @param k the number of clusters.
  * @param maxNeighbor the maximum number of neighbors examined during a random search of local
  *     minima.
  */
 public CLARANS(T[] data, Distance<T> distance, int k, int maxNeighbor) {
   this(data, distance, k, maxNeighbor, Math.max(2, MulticoreExecutor.getThreadPoolSize()));
 }