Exemplo n.º 1
0
 /**
  * Builds a fresh random 0/1 mask with the same shape as {@code output}.
  *
  * <p>Each cell becomes 1 when a {@link Math#random()} draw is strictly greater than
  * {@code dropoutFraction}, and 0 otherwise.
  *
  * @return a newly allocated mask of size {@code output.rows() x output.columns()}
  */
 private Mx getDropoutMask() {
   final int rowCount = output.rows();
   final int columnCount = output.columns();
   final Mx mask = new VecBasedMx(rowCount, columnCount);
   final int cellCount = mask.dim();
   int cell = 0;
   while (cell < cellCount) {
     if (Math.random() > dropoutFraction) {
       mask.set(cell, 1);
     } else {
       mask.set(cell, 0);
     }
     cell++;
   }
   return mask;
 }
  /**
   * Fits a base oblivious tree with {@code got}, then solves one regularized linear system per
   * leaf to obtain the polynomial coefficients of the resulting {@link ExponentialObliviousTree}.
   *
   * <p>Side effects: stores the base tree's features in {@code features}, prints the summed
   * squared error of the base and the final model to {@code System.out}, and calls
   * {@code precalculateMissCoefficients}. If the base tree does not have exactly {@code depth}
   * features, the targets are dumped to {@code badloss.txt} and the JVM is terminated with exit
   * code -1.
   *
   * @param ds data set to fit on
   * @param loss L2 loss whose {@code target} holds the expected value for each data set row
   * @return the fitted tree built from {@code features}, the per-leaf coefficients and
   *     {@code DistCoef}
   */
  @Override
  public ExponentialObliviousTree fit(final VecDataSet ds, final L2 loss) {
    final ObliviousTree base = got.fit(ds, loss);
    features = base.features();

    // Summed (not averaged) squared error of the base tree, printed for comparison below.
    double baseMse = 0;
    for (int i = 0; i < ds.length(); i++) {
      baseMse += sqr(base.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("\nBase_MSE = " + baseMse);

    if (features.size() != depth) {
      System.out.println("Oblivious Tree bug");
      // try-with-resources guarantees the dump file is closed even if reading a target throws.
      try (PrintWriter printWriter = new PrintWriter(new File("badloss.txt"))) {
        for (int i = 0; i < ds.length(); i++) {
          printWriter.println(loss.target.get(i));
        }
      } catch (FileNotFoundException e) {
        e.printStackTrace();
      }
      System.exit(-1);
    }

    precalculateMissCoefficients(ds, loss);

    final int leafCount = 1 << depth;
    final double[][] out = new double[leafCount][(depth + 1) * (depth + 2) / 2];
    for (int index = 0; index < leafCount; index++) {
      // Assemble the system a * value = b for this leaf from the precalculated coefficients.
      final Mx a = new VecBasedMx(numberOfVariablesByLeaf, numberOfVariablesByLeaf);
      final Vec b = new ArrayVec(numberOfVariablesByLeaf);
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        b.set(i, -linearMissCoefficient[index][i]);
      }
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        for (int j = 0; j < numberOfVariablesByLeaf; j++) {
          a.set(i, j, quadraticMissCoefficient[index][i][j]);
        }
      }
      // NOTE(review): presumably adjust() adds to the cell, making this a fixed 0.1 diagonal
      // regularization term - confirm against Mx.adjust.
      for (int i = 0; i < numberOfVariablesByLeaf; i++) {
        a.adjust(i, i, 1e-1);
      }
      final Vec value = GreedyPolynomialExponentRegion.solveLinearEquationUsingLQ(a, b);

      // Copy the solution into the packed lower-triangular layout (row k, column j <= k).
      for (int k = 0; k <= depth; k++) {
        for (int j = 0; j <= k; j++) {
          out[index][k * (k + 1) / 2 + j] = value.get(getIndex(0, k, j));
        }
      }
    }

    final ExponentialObliviousTree ret = new ExponentialObliviousTree(features, out, DistCoef);

    // Summed squared error of the extended model on the same data.
    double mse = 0;
    for (int i = 0; i < ds.length(); i++) {
      mse += sqr(ret.value(ds.data().row(i)) - loss.target.get(i));
    }
    System.out.println("MSE = " + mse);
    return ret;
  }
Exemplo n.º 3
0
  /**
   * Returns a copy of {@code original} without its first (leftmost) column.
   *
   * @param original source matrix; it is not modified
   * @return a new matrix with the same number of rows and one column fewer
   */
  private Mx leftContract(final Mx original) {
    final VecBasedMx trimmed = new VecBasedMx(original.rows(), original.columns() - 1);
    final int rowCount = trimmed.rows();
    final int lastSourceColumn = trimmed.columns();

    for (int row = 0; row < rowCount; row++) {
      // Source column srcCol lands in destination column srcCol - 1.
      for (int srcCol = 1; srcCol <= lastSourceColumn; srcCol++) {
        trimmed.set(row, srcCol - 1, original.get(row, srcCol));
      }
    }
    return trimmed;
  }
 /**
  * Computes the share of rows of {@code x} that coincide with the corresponding row of
  * {@code targets}, where two rows coincide if their {@code VecTools.distance} is below
  * {@code MathTools.EPSILON}.
  *
  * @param x predictions; cast to {@link Mx} without a type check
  * @return number of matching rows divided by {@code targets.rows()}
  */
 @Override
 public double value(final Vec x) {
   final Mx predicted = (Mx) x;
   final int rowCount = predicted.rows();
   int matched = 0;
   for (int row = 0; row < rowCount; row++) {
     final boolean sameRow =
         VecTools.distance(predicted.row(row), targets.row(row)) < MathTools.EPSILON;
     if (sameRow) {
       matched += 1;
     }
   }
   return matched / (double) targets.rows();
 }
 /**
  * Stores the optimization inputs and derives {@code laplacian} from {@code S}.
  *
  * <p>{@code laplacian} starts as a copy of {@code S} scaled by -1; each diagonal cell is then
  * adjusted by the sum of the matching row of {@code S}. {@code S} itself is only read.
  *
  * @param ds data set to optimize on
  * @param target target whose labels are split into {@code classesIdxs}
  * @param S square matrix used to build {@code laplacian}
  * @param c value stored in {@code this.c}
  * @param step value stored in {@code this.step}
  */
 public CMLMetricOptimization(
     final VecDataSet ds,
     final BlockwiseMLLLogit target,
     final Mx S,
     final double c,
     final double step) {
   this.ds = ds;
   this.target = target;
   this.c = c;
   this.step = step;
   this.classesIdxs = MCTools.splitClassesIdxs(target.labels());

   this.laplacian = VecTools.copy(S);
   VecTools.scale(laplacian, -1.0);
   final int rowCount = laplacian.rows();
   for (int row = 0; row < rowCount; row++) {
     // NOTE(review): presumably adjust() adds the row sum to the diagonal cell - confirm.
     laplacian.adjust(row, row, VecTools.sum(S.row(row)));
   }
 }
 /**
  * Optimizes every column of {@code codingMatrix} independently and collects the results.
  *
  * <p>For column {@code l}, a {@code ColumnTargetFunction} is built around
  * {@code binClassifiers[l]} and passed to {@code optimizeColumn} together with that column of
  * {@code codingMatrix}. Progress is printed to {@code System.out}.
  *
  * @param codingMatrix matrix whose columns are the starting points
  * @param binClassifiers one classifier per column of {@code codingMatrix}
  * @return a new matrix of the same shape holding the optimized columns
  */
 public Mx trainProbs(final Mx codingMatrix, final Func[] binClassifiers) {
   final Mx optimized = new VecBasedMx(codingMatrix.rows(), codingMatrix.columns());
   final int columnCount = optimized.columns();
   int column = 0;
   while (column < columnCount) {
     System.out.println("Optimize column " + column);
     final FuncC1 targetFunction = new ColumnTargetFunction(binClassifiers[column]);
     final Vec optimizedColumn = optimizeColumn(targetFunction, codingMatrix.col(column));
     VecTools.assign(optimized.col(column), optimizedColumn);
     column++;
   }
   return optimized;
 }
Exemplo n.º 7
0
  /**
   * Returns a copy of {@code original} with one extra column prepended; every cell of that new
   * leftmost column holds {@code bias}.
   *
   * @param original source matrix; it is not modified
   * @return a new matrix with the same number of rows and one column more
   */
  private Mx leftExtend(final Mx original) {
    final int rowCount = original.rows();
    final int sourceColumns = original.columns();
    final VecBasedMx widened = new VecBasedMx(rowCount, sourceColumns + 1);

    for (int row = 0; row < rowCount; row++) {
      widened.set(row, 0, bias);
      // Source column srcCol lands in destination column srcCol + 1.
      for (int srcCol = 0; srcCol < sourceColumns; srcCol++) {
        widened.set(row, srcCol + 1, original.get(row, srcCol));
      }
    }
    return widened;
  }
Exemplo n.º 8
0
  /**
   * Runs the forward pass of this layer.
   *
   * <p>{@code activations} becomes {@code input} with a bias column prepended (when
   * {@code bias} is non-zero) or a plain copy of {@code input}. {@code output} is
   * {@code activations * weights^T} passed through {@code rectifier} in place. When
   * {@code dropoutFraction} is positive, the output is multiplied cell by cell with a freshly
   * drawn {@code dropoutMask} in training mode, or scaled by {@code 1 - dropoutFraction}
   * otherwise.
   */
  public void forward() {
    if (bias == 0) {
      activations = VecTools.copy(input);
    } else {
      activations = leftExtend(input);
    }

    output = MxTools.multiply(activations, MxTools.transpose(weights));
    rectifier.value(output, output);

    // Nothing more to do when dropout is disabled.
    if (!(dropoutFraction > 0)) {
      return;
    }

    final int cellCount = output.dim();
    if (isTrain) {
      dropoutMask = getDropoutMask();
      for (int cell = 0; cell < cellCount; cell++) {
        output.set(cell, output.get(cell) * dropoutMask.get(cell));
      }
      return;
    }

    for (int cell = 0; cell < cellCount; cell++) {
      output.set(cell, output.get(cell) * (1 - dropoutFraction));
    }
  }
 /**
  * Returns the dimension of this function, which is whatever {@code targets.dim()} reports.
  *
  * @return {@code targets.dim()}
  */
 @Override
 public int dim() {
   final int targetDim = targets.dim();
   return targetDim;
 }
Exemplo n.º 10
0
  /**
   * Runs the backward pass of this layer.
   *
   * <p>Reads {@code output}, {@code activations}, {@code weights} and {@code dropoutMask};
   * overwrites {@code difference}, {@code input} and {@code activations}:
   * <ul>
   *   <li>{@code difference} = {@code cnc^T * activations} divided by the number of rows of
   *       {@code activations}, where {@code cnc} is {@code output} without its first column
   *       (when {@code bias_b} is non-zero) or a copy of {@code output};</li>
   *   <li>{@code activations} is replaced in place by {@code rectifier.grad} of itself;</li>
   *   <li>{@code input} = {@code cnc * weights}, multiplied cell by cell with the new
   *       {@code activations} and, when {@code dropoutFraction} is positive, with
   *       {@code dropoutMask}.</li>
   * </ul>
   */
  public void backward() {
    // NOTE(review): forward() tests `bias` while this tests `bias_b` - confirm they are
    // intentionally different fields.
    final Mx cnc;
    if (bias_b == 0) {
      cnc = VecTools.copy(output);
    } else {
      cnc = leftContract(output);
    }

    difference = MxTools.multiply(MxTools.transpose(cnc), activations);
    final int sampleCount = activations.rows();
    final int differenceCells = difference.dim();
    for (int cell = 0; cell < differenceCells; cell++) {
      difference.set(cell, difference.get(cell) / sampleCount);
    }

    input = MxTools.multiply(cnc, weights);

    rectifier.grad(activations, activations);
    final boolean dropoutEnabled = dropoutFraction > 0;
    final int inputCells = input.dim();
    for (int cell = 0; cell < inputCells; cell++) {
      input.set(cell, input.get(cell) * activations.get(cell));
      if (dropoutEnabled) {
        // NOTE(review): dropoutMask is created with the shape of `output` in forward() but is
        // indexed here with `input` indices - verify the shapes agree.
        input.set(cell, input.get(cell) * dropoutMask.get(cell));
      }
    }
  }