/**
 * Draws a fresh dropout mask with the same shape as {@code output}.
 *
 * <p>Each entry is independently 1 when {@code Math.random() > dropoutFraction} and 0 otherwise.
 * For a fraction in [0, 1) an entry is therefore kept with probability {@code 1 - dropoutFraction}.
 *
 * @return a new {@code rows x columns} matrix of 0/1 values matching {@code output}
 */
private Mx getDropoutMask() {
    final int rowCount = output.rows();
    final int columnCount = output.columns();
    final Mx mask = new VecBasedMx(rowCount, columnCount);
    final int cells = mask.dim();
    for (int pos = 0; pos < cells; pos++) {
        final boolean keep = Math.random() > dropoutFraction;
        mask.set(pos, keep ? 1 : 0);
    }
    return mask;
}
@Override public ExponentialObliviousTree fit(final VecDataSet ds, final L2 loss) { final ObliviousTree base = got.fit(ds, loss); features = base.features(); double baseMse = 0; for (int i = 0; i < ds.length(); i++) baseMse += sqr(base.value(ds.data().row(i)) - loss.target.get(i)); System.out.println("\nBase_MSE = " + baseMse); if (features.size() != depth) { System.out.println("Oblivious Tree bug"); try { final PrintWriter printWriter = new PrintWriter(new File("badloss.txt")); for (int i = 0; i < ds.length(); i++) printWriter.println(loss.target.get(i)); printWriter.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } System.exit(-1); } precalculateMissCoefficients(ds, loss); // System.out.println("Precalc is over"); final double[][] out = new double[1 << depth][(depth + 1) * (depth + 2) / 2]; for (int index = 0; index < 1 << depth; index++) { final Mx a = new VecBasedMx(numberOfVariablesByLeaf, numberOfVariablesByLeaf); final Vec b = new ArrayVec(numberOfVariablesByLeaf); for (int i = 0; i < numberOfVariablesByLeaf; i++) b.set(i, -linearMissCoefficient[index][i]); for (int i = 0; i < numberOfVariablesByLeaf; i++) for (int j = 0; j < numberOfVariablesByLeaf; j++) a.set(i, j, quadraticMissCoefficient[index][i][j]); for (int i = 0; i < numberOfVariablesByLeaf; i++) a.adjust(i, i, 1e-1); final Vec value = GreedyPolynomialExponentRegion.solveLinearEquationUsingLQ(a, b); // System.out.println(a); for (int k = 0; k <= depth; k++) for (int j = 0; j <= k; j++) out[index][k * (k + 1) / 2 + j] = value.get(getIndex(0, k, j)); /*if(quadraticMissCoefficient[index][0][0] != 0) out[index][0] = linearMissCoefficient[index][0] / quadraticMissCoefficient[index][0][0];*/ // out[index][0] = base.values()[index]; // for (int i = 0; i < out[index].length; i++) // System.out.println(out[index][i]); } // for(int i =0 ; i < gradLambdas.size();i++) // System.out.println(serializeCondtion(i)); final ExponentialObliviousTree ret = new ExponentialObliviousTree(features, out, 
DistCoef); double mse = 0; for (int i = 0; i < ds.length(); i++) mse += sqr(ret.value(ds.data().row(i)) - loss.target.get(i)); System.out.println("MSE = " + mse); /*if (mse > baseMse + 1e-5) try { throw new Exception("Bad model work mse of based model less than mse of extended model"); } catch (Exception e) { e.printStackTrace(); //System.exit(-1); }*/ return ret; }
/**
 * Back-propagates through this layer, reading the value currently stored in {@code output}.
 *
 * <p>That value is presumably the upstream gradient at this point; confirm with the caller. When
 * {@code bias_b != 0} it is first passed through {@code leftContract}; otherwise it is copied.
 *
 * <p>Side effects:
 * <ul>
 *   <li>{@code difference} is set to {@code grad^T * activations}, divided element-wise by the
 *       number of rows of {@code activations}.</li>
 *   <li>{@code input} is set to {@code grad * weights}, multiplied element-wise by
 *       {@code rectifier.grad} evaluated at {@code activations}, and by {@code dropoutMask}
 *       when {@code dropoutFraction > 0}.</li>
 *   <li>{@code activations} is overwritten in place by {@code rectifier.grad}.</li>
 * </ul>
 */
public void backward() {
    final Mx upstream = bias_b == 0 ? VecTools.copy(output) : leftContract(output);

    difference = MxTools.multiply(MxTools.transpose(upstream), activations);
    // Read before rectifier.grad below rewrites activations in place.
    final int batchRows = activations.rows();
    for (int pos = 0; pos < difference.dim(); pos++) {
        difference.set(pos, difference.get(pos) / batchRows);
    }

    input = MxTools.multiply(upstream, weights);
    // Must run after difference is computed: this overwrites activations with the derivative.
    rectifier.grad(activations, activations);
    final boolean applyMask = dropoutFraction > 0;
    for (int pos = 0; pos < input.dim(); pos++) {
        input.set(pos, input.get(pos) * activations.get(pos));
        if (applyMask) {
            // NOTE(review): getDropoutMask() builds the mask with output's shape, but here it is
            // indexed with input's flat indices. If the two shapes differ, this misaligns or
            // overruns; confirm the intended layer shapes.
            input.set(pos, input.get(pos) * dropoutMask.get(pos));
        }
    }
}
/**
 * Runs the forward pass of this layer.
 *
 * <p>{@code activations} becomes {@code leftExtend(input)} when {@code bias != 0}, and a copy of
 * {@code input} otherwise. {@code output} is then set to {@code activations * weights^T} with
 * {@code rectifier.value} applied in place.
 *
 * <p>When {@code dropoutFraction > 0}, the output is changed depending on the mode:
 * <ul>
 *   <li>In training ({@code isTrain}), a new {@code dropoutMask} is drawn and multiplied into
 *       {@code output} element-wise.</li>
 *   <li>Otherwise, {@code output} is scaled by {@code 1 - dropoutFraction}.</li>
 * </ul>
 */
public void forward() {
    activations = bias == 0 ? VecTools.copy(input) : leftExtend(input);
    output = MxTools.multiply(activations, MxTools.transpose(weights));
    rectifier.value(output, output);

    // Written as !(x > 0) rather than x <= 0 so a NaN fraction still skips dropout.
    if (!(dropoutFraction > 0)) {
        return;
    }

    if (!isTrain) {
        // Inference: scale by the share of units getDropoutMask keeps on average during training.
        final double keepShare = 1 - dropoutFraction;
        for (int pos = 0; pos < output.dim(); pos++) {
            output.set(pos, output.get(pos) * keepShare);
        }
        return;
    }

    dropoutMask = getDropoutMask();
    for (int pos = 0; pos < output.dim(); pos++) {
        output.set(pos, output.get(pos) * dropoutMask.get(pos));
    }
}