/**
 * Fits an {@link ExponentialObliviousTree} on {@code ds} against the L2 target {@code loss}.
 *
 * <p>A plain {@link ObliviousTree} is fitted first with {@code got}. Its split features are reused
 * and stored in {@code features}. After {@code precalculateMissCoefficients} runs, each of the
 * {@code 1 << depth} leaves gets its own regularized linear system, and the solution's
 * coefficients are packed into that leaf's row of the model's coefficient table.
 *
 * <p>The summed squared error of the base tree and of the resulting model are printed to stdout.
 * The printed labels say "MSE", but the values are sums over rows, not means.
 *
 * <p>If the base tree does not use exactly {@code depth} features, the target values are written
 * to {@code badloss.txt} and the JVM is terminated with status -1.
 *
 * @param ds   training data set
 * @param loss L2 target; {@code loss.target.get(i)} is compared with the model's value on row i
 * @return the fitted exponential oblivious tree
 */
@Override
  public ExponentialObliviousTree fit(final VecDataSet ds, final L2 loss) {
    final ObliviousTree base = got.fit(ds, loss);
    features = base.features();

    // Sum (not mean) of squared residuals of the base tree, printed for comparison below.
    double baseMse = 0;
    for (int i = 0; i < ds.length(); i++) {
      baseMse += sqr(base.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("\nBase_MSE = " + baseMse);

    if (features.size() != depth) {
      System.out.println("Oblivious Tree bug");
      dumpTargets(ds, loss, "badloss.txt");
      // NOTE(review): terminating the JVM from a fitting routine is drastic; consider throwing
      // an IllegalStateException instead so callers can recover.
      System.exit(-1);
    }

    precalculateMissCoefficients(ds, loss);
    final int leafCount = 1 << depth;
    // Coefficients (k, j) with 0 <= j <= k <= depth are packed row-wise as a triangle:
    // pair (k, j) goes to slot k * (k + 1) / 2 + j, which needs (depth + 1) * (depth + 2) / 2
    // slots in total.
    final double[][] out = new double[leafCount][(depth + 1) * (depth + 2) / 2];
    for (int index = 0; index < leafCount; index++) {
      final Vec value = solveLeafSystem(index);
      for (int k = 0; k <= depth; k++) {
        for (int j = 0; j <= k; j++) {
          out[index][k * (k + 1) / 2 + j] = value.get(getIndex(0, k, j));
        }
      }
    }

    final ExponentialObliviousTree ret = new ExponentialObliviousTree(features, out, DistCoef);
    // Sum (not mean) of squared residuals of the extended model.
    double mse = 0;
    for (int i = 0; i < ds.length(); i++) {
      mse += sqr(ret.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("MSE = " + mse);
    return ret;
  }

  /**
   * Builds and solves the linear system of one leaf: {@code (Q + 0.1 * I) x = -l}.
   *
   * <p>Q is {@code quadraticMissCoefficient[leaf]} and l is {@code linearMissCoefficient[leaf]}.
   * The diagonal shift of 0.1 regularizes the system so that it can be solved even when Q is
   * singular or ill-conditioned.
   *
   * @param leaf leaf index in {@code [0, 1 << depth)}
   * @return solution vector of length {@code numberOfVariablesByLeaf}
   */
  private Vec solveLeafSystem(final int leaf) {
    final double ridge = 1e-1;
    final Mx a = new VecBasedMx(numberOfVariablesByLeaf, numberOfVariablesByLeaf);
    final Vec b = new ArrayVec(numberOfVariablesByLeaf);
    for (int i = 0; i < numberOfVariablesByLeaf; i++) {
      b.set(i, -linearMissCoefficient[leaf][i]);
      for (int j = 0; j < numberOfVariablesByLeaf; j++) {
        a.set(i, j, quadraticMissCoefficient[leaf][i][j]);
      }
      // Applied after row i is filled, so it adds to (not overwrites) the diagonal entry.
      a.adjust(i, i, ridge);
    }
    return GreedyPolynomialExponentRegion.solveLinearEquationUsingLQ(a, b);
  }

  /**
   * Writes every target value of {@code loss} to {@code fileName}, one per line, for debugging.
   *
   * <p>The writer is always closed. A missing or unwritable file is reported on stderr and
   * otherwise ignored.
   */
  private static void dumpTargets(final VecDataSet ds, final L2 loss, final String fileName) {
    try (PrintWriter printWriter = new PrintWriter(new File(fileName))) {
      for (int i = 0; i < ds.length(); i++) {
        printWriter.println(loss.target.get(i));
      }
    } catch (FileNotFoundException e) {
      e.printStackTrace();
    }
  }
예제 #2
0
  /**
   * Builds an additive ensemble by gradient boosting over {@code learn}.
   *
   * <p>Each round does the following:
   * <ol>
   *   <li>Evaluates the gradient of {@code globalLoss} at the current prediction.</li>
   *   <li>Wraps that gradient into an L2 target via {@code DataTools.newTarget}.</li>
   *   <li>Fits the weak learner to it.</li>
   *   <li>Calls {@code invoke} with the partial ensemble.</li>
   *   <li>Shifts the prediction by the new member's output on the learn data, scaled by
   *       {@code -step}.</li>
   * </ol>
   *
   * @param learn      training data set
   * @param globalLoss loss whose gradient drives each round
   * @return ensemble of all fitted weak models, each weighted by {@code -step}
   */
  @Override
  public Ensemble fit(final VecDataSet learn, final GlobalLoss globalLoss) {
    // Running prediction of the ensemble built so far (a fresh ArrayVec of size xdim()).
    final Vec prediction = new ArrayVec(globalLoss.xdim());
    final List<Trans> members = new ArrayList<>(iterationsCount);
    final Trans lossGradient = globalLoss.gradient();

    int round = 0;
    while (round < iterationsCount) {
      final Vec gradAtPrediction = lossGradient.trans(prediction);
      final L2 roundTarget = DataTools.newTarget(factory, gradAtPrediction, learn);
      final Trans member = weak.fit(learn, roundTarget);
      members.add(member);

      // The partial ensemble is handed to invoke(...) before the prediction is advanced
      // (presumably to notify listeners).
      invoke(new Ensemble(members, -step));

      // Advance the prediction by -step times the new member's output on the learn data
      // (VecTools.append presumably adds into `prediction` in place).
      VecTools.append(prediction, VecTools.scale(member.transAll(learn.data()), -step));
      round++;
    }
    return new Ensemble(members, -step);
  }