Esempio n. 1
0
 /**
  * Builds a new mask matrix with the same row and column counts as {@code output}.
  *
  * <p>Every cell is filled independently: it receives 1 when a fresh {@link Math#random()} draw
  * is strictly greater than {@code dropoutFraction}, and 0 otherwise.
  *
  * @return the newly allocated 0/1 mask
  */
 private Mx getDropoutMask() {
   final Mx mask = new VecBasedMx(output.rows(), output.columns());
   for (int cell = 0; cell < mask.dim(); cell++) {
     if (Math.random() > dropoutFraction) {
       mask.set(cell, 1);
     } else {
       mask.set(cell, 0);
     }
   }
   return mask;
 }
  /**
   * Fits a base oblivious tree via {@code got} and then, for each of the {@code 1 << depth}
   * leaves, solves a linear system built from the precalculated miss coefficients to obtain the
   * coefficients of the returned {@link ExponentialObliviousTree}.
   *
   * <p>Side effects visible in this method:
   * <ul>
   *   <li>{@code features} is overwritten with the base tree's features;</li>
   *   <li>the summed squared error of the base tree and of the resulting model are printed to
   *       standard output (labelled "Base_MSE" / "MSE"; note that both are sums, not means);</li>
   *   <li>if the base tree does not have exactly {@code depth} features, all targets are dumped
   *       to {@code badloss.txt} and the JVM is terminated with exit code -1.</li>
   * </ul>
   *
   * @param ds the data set to fit on
   * @param loss the L2 loss whose {@code target} supplies the expected value for every row
   * @return the fitted model built from {@code features}, the solved coefficients and
   *     {@code DistCoef}
   */
  @Override
  public ExponentialObliviousTree fit(final VecDataSet ds, final L2 loss) {
    final ObliviousTree base = got.fit(ds, loss);
    features = base.features();
    double baseMse = 0;
    for (int i = 0; i < ds.length(); i++) {
      baseMse += sqr(base.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("\nBase_MSE = " + baseMse);

    if (features.size() != depth) {
      System.out.println("Oblivious Tree bug");
      // try-with-resources guarantees the writer is closed even if writing a target fails.
      try (PrintWriter printWriter = new PrintWriter(new File("badloss.txt"))) {
        for (int i = 0; i < ds.length(); i++) {
          printWriter.println(loss.target.get(i));
        }
      } catch (FileNotFoundException e) {
        e.printStackTrace();
      }
      System.exit(-1);
    }

    precalculateMissCoefficients(ds, loss);

    final int leafCount = 1 << depth;
    final double[][] out = new double[leafCount][(depth + 1) * (depth + 2) / 2];
    for (int index = 0; index < leafCount; index++) {
      // Assemble the system a * value = b for this leaf from the miss coefficients.
      final Mx a = new VecBasedMx(numberOfVariablesByLeaf, numberOfVariablesByLeaf);
      final Vec b = new ArrayVec(numberOfVariablesByLeaf);
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        b.set(i, -linearMissCoefficient[index][i]);
      }
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        for (int j = 0; j < numberOfVariablesByLeaf; j++) {
          a.set(i, j, quadraticMissCoefficient[index][i][j]);
        }
      }
      // Adds 1e-1 to every diagonal entry before solving
      // (presumably regularisation to keep the system well-conditioned; confirm).
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        a.adjust(i, i, 1e-1);
      }
      final Vec value = GreedyPolynomialExponentRegion.solveLinearEquationUsingLQ(a, b);

      // Store the (k, j) coefficients with j <= k in packed triangular order.
      for (int k = 0; k <= depth; k++) {
        for (int j = 0; j <= k; j++) {
          out[index][k * (k + 1) / 2 + j] = value.get(getIndex(0, k, j));
        }
      }
    }

    final ExponentialObliviousTree ret = new ExponentialObliviousTree(features, out, DistCoef);
    double mse = 0;
    for (int i = 0; i < ds.length(); i++) {
      mse += sqr(ret.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("MSE = " + mse);
    // NOTE: a sanity check that reported when mse > baseMse + 1e-5 used to live here but was
    // disabled; the two values are currently only printed, never compared.
    return ret;
  }
Esempio n. 3
0
  /**
   * Recomputes {@code difference} and {@code input} from the current {@code output},
   * {@code activations} and {@code weights}.
   *
   * <p>Steps, in order:
   * <ol>
   *   <li>take {@code leftContract(output)} when {@code bias_b} is non-zero, otherwise a copy of
   *       {@code output};</li>
   *   <li>set {@code difference} to the product of that matrix transposed with
   *       {@code activations}, every entry divided by {@code activations.rows()};</li>
   *   <li>set {@code input} to the product of that matrix with {@code weights};</li>
   *   <li>overwrite {@code activations} in place with {@code rectifier.grad} of itself and
   *       multiply {@code input} element-wise by it, and additionally by {@code dropoutMask} when
   *       {@code dropoutFraction} is positive.</li>
   * </ol>
   */
  public void backward() {
    final Mx cnc = bias_b != 0 ? leftContract(output) : VecTools.copy(output);

    // Must run before rectifier.grad below, which overwrites activations.
    difference = MxTools.multiply(MxTools.transpose(cnc), activations);
    for (int k = 0; k < difference.dim(); k++) {
      final double averaged = difference.get(k) / activations.rows();
      difference.set(k, averaged);
    }

    input = MxTools.multiply(cnc, weights);

    rectifier.grad(activations, activations);
    final boolean maskEnabled = dropoutFraction > 0;
    // NOTE(review): dropoutMask is created with output's dimensions but is indexed here over
    // input.dim() -- confirm the two always agree.
    for (int k = 0; k < input.dim(); k++) {
      input.set(k, input.get(k) * activations.get(k));
      if (maskEnabled) {
        input.set(k, input.get(k) * dropoutMask.get(k));
      }
    }
  }
Esempio n. 4
0
  /**
   * Computes {@code output} from {@code input}.
   *
   * <p>{@code activations} becomes {@code leftExtend(input)} when {@code bias} is non-zero and a
   * copy of {@code input} otherwise. {@code output} is then the product of {@code activations}
   * with the transposed {@code weights}, passed in place through {@code rectifier.value}.
   *
   * <p>When {@code dropoutFraction} is positive, {@code output} is additionally scaled
   * element-wise: by a freshly generated {@code dropoutMask} if {@code isTrain} is set, or by
   * {@code 1 - dropoutFraction} if it is not.
   */
  public void forward() {
    if (bias == 0) {
      activations = VecTools.copy(input);
    } else {
      activations = leftExtend(input);
    }

    output = MxTools.multiply(activations, MxTools.transpose(weights));
    rectifier.value(output, output);

    // Written as a negated '>' so that a NaN fraction is treated exactly like "no dropout".
    if (!(dropoutFraction > 0)) {
      return;
    }

    if (!isTrain) {
      // Outside training every output is scaled by the keep ratio instead of being masked.
      for (int k = 0; k < output.dim(); k++) {
        output.set(k, output.get(k) * (1 - dropoutFraction));
      }
      return;
    }

    // The mask is generated only after output exists, because it takes its size from output.
    dropoutMask = getDropoutMask();
    for (int k = 0; k < output.dim(); k++) {
      output.set(k, output.get(k) * dropoutMask.get(k));
    }
  }