コード例 #1
0
ファイル: SharedTree.java プロジェクト: herberteuler/h2o-3
 /**
  * Runs the builder-specific scoring ({@code score1}) for one row and, for classifiers,
  * post-processes the per-class outputs in {@code fs}.
  *
  * <p>For a classifier, {@code fs} is divided in place by the value returned from
  * {@code score1} whenever that value is not infinite, is strictly positive and is not
  * already 1. Afterwards, if {@code _parms._balance_classes} is set, {@code fs} is passed to
  * {@code GenModel.correctProbabilities} together with the model's prior and model class
  * distributions.
  *
  * @param chks chunks handed through unchanged to {@code score1}
  * @param weight row weight handed through unchanged to {@code score1}
  * @param offset row offset handed through unchanged to {@code score1}
  * @param fs per-class output array; filled by {@code score1} and adjusted in place here
  * @param row row index handed through unchanged to {@code score1}
  */
 void score2(Chunk chks[], double weight, double offset, double fs[ /*nclass*/], int row) {
   final double total = score1(chks, weight, offset, fs, row);
   if (!isClassifier()) {
     return; // Nothing to correct for non-classifiers.
   }

   // Only rescale when the total is usable as a divisor and would actually change fs.
   final boolean needsScaling = !Double.isInfinite(total) && total > 0.0 && total != 1.0;
   if (needsScaling) {
     ArrayUtils.div(fs, total);
   }

   if (_parms._balance_classes) {
     GenModel.correctProbabilities(
         fs, _model._output._priorClassDist, _model._output._modelClassDist);
   }
 }
コード例 #2
0
ファイル: KMeans.java プロジェクト: huamichaelchen/h2o-3
    /**
     * Processes one set of chunks: assigns every row to its closest cluster center and
     * accumulates the per-cluster statistics for this chunk.
     *
     * <p>When {@code _hasWeight} is set, the last chunk in {@code cs} is read as the weight
     * column and is excluded from the data columns. Rows with weight 0 are skipped; any other
     * weight is asserted to be exactly 1.
     *
     * <p>Results written to this task's fields:
     * <ul>
     *   <li>{@code _cMeans} - per-cluster column sums for non-categorical columns, divided by
     *       the cluster's row count at the end (clusters with no rows are left as zeros);</li>
     *   <li>{@code _cats} - per-cluster, per-column level histograms for categorical columns
     *       ({@code null} entries for non-categorical columns);</li>
     *   <li>{@code _cSqr} - per-cluster running total of {@code cd._dist};</li>
     *   <li>{@code _size} - per-cluster row counts;</li>
     *   <li>{@code _worst_err} / {@code _worst_row} - the largest {@code cd._dist} seen and the
     *       row ({@code cs[0].start() + row}) where it occurred.</li>
     * </ul>
     * {@code _centers}, {@code _means}, {@code _mults} and {@code _modes} are set to
     * {@code null} before returning.
     *
     * @param cs the chunks for this task; data columns first, optional weight column last
     */
    @Override
    public void map(Chunk[] cs) {
      final int numCols = cs.length - (_hasWeight ? 1 : 0);
      assert _centers[0].length == numCols;

      _cMeans = new double[_k][numCols];
      _cSqr = new double[_k];
      _size = new long[_k];
      _worst_err = 0;

      // Histogram storage: only categorical columns get a counts array, the rest stay null
      // (the freshly allocated outer arrays already hold null in every slot).
      _cats = new long[_k][numCols][];
      for (int k = 0; k < _k; k++) {
        for (int c = 0; c < numCols; c++) {
          if (_isCats[c] != null) {
            _cats[k][c] = new long[cs[c].vec().cardinality()];
          }
        }
      }

      // Assign each row to its nearest center and fold it into that cluster's accumulators.
      double[] rowVals = new double[numCols]; // Scratch buffer holding the current row
      ClusterDist nearest = new ClusterDist();
      final int numRows = cs[0]._len;
      for (int r = 0; r < numRows; r++) {
        double weight;
        if (_hasWeight) {
          weight = cs[numCols].atd(r);
        } else {
          weight = 1;
        }
        if (weight == 0) {
          continue; // Holdout row: ignore it
        }
        assert (weight == 1); // K-Means only works for weight 1 (or weight 0 for holdout)

        data(rowVals, cs, r, _means, _mults, _modes); // Load the row as doubles
        closest(_centers, rowVals, _isCats, nearest); // Locate the closest center
        final int winner = nearest._cluster;
        assert winner != -1; // No broken rows
        _cSqr[winner] += nearest._dist;

        // Fold the row into the winning cluster.
        for (int c = 0; c < numCols; c++) {
          if (_isCats[c] == null) {
            _cMeans[winner][c] += rowVals[c]; // Numeric column: add to the running sum
          } else {
            _cats[winner][c][(int) rowVals[c]]++; // Categorical column: bump the level count
          }
        }
        _size[winner]++;

        // Remember the row that is farthest from its center.
        if (nearest._dist > _worst_err) {
          _worst_err = nearest._dist;
          _worst_row = cs[0].start() + r;
        }
      }

      // Turn the per-cluster sums into local means; empty clusters are left untouched.
      for (int k = 0; k < _k; k++) {
        if (_size[k] == 0) {
          continue;
        }
        ArrayUtils.div(_cMeans[k], _size[k]);
      }

      // Clear the input-side fields now that the pass is complete.
      _centers = null;
      _mults = null;
      _means = null;
      _modes = null;
    }