/**
 * Runs a per-item learning-curve experiment over every fold of the Bill data set.
 *
 * <p>For each fold a fresh {@code BilinearSparseOnlineLearner} is built from a fixed
 * parameter set and re-initialised. The fold's whole test split is read into memory.
 * Then the learner is trained one training item at a time. After every item the
 * learner is:
 * <ul>
 * <li>evaluated on the complete test split with a {@code RootMeanSumLossEvaluator};</li>
 * <li>serialised to {@code FOLD_ROOT(fold)/learner_<n>}, where {@code n} is the number
 * of training items processed so far (1-based);</li>
 * <li>logged at debug level: U/W row sparsity, the bias diagonal (only when bias
 * learning is enabled) and the test loss.</li>
 * </ul>
 *
 * @throws IOException if a learner snapshot cannot be written
 */
@Override
public void performExperiment() throws IOException {
	final BilinearLearnerParameters params = new BilinearLearnerParameters();
	params.put(BilinearLearnerParameters.ETA0_U, 0.02);
	params.put(BilinearLearnerParameters.ETA0_W, 0.02);
	params.put(BilinearLearnerParameters.LAMBDA, 0.001);
	params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
	params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
	params.put(BilinearLearnerParameters.BIAS, true);
	params.put(BilinearLearnerParameters.ETA0_BIAS, 0.5);
	params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1));
	params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());

	// NOTE(review): meaning of the literal arguments (98, true) is defined by
	// BillMatlabFileDataGenerator; confirm against its constructor docs.
	final BillMatlabFileDataGenerator bmfdg = new BillMatlabFileDataGenerator(new File(BILL_DATA()), 98, true);
	prepareExperimentLog(params);

	for (int i = 0; i < bmfdg.nFolds(); i++) {
		logger.debug("Fold: " + i);
		final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
		learner.reinitParams();

		// Materialise the whole test split: it is re-evaluated after every training item.
		bmfdg.setFold(i, Mode.TEST);
		final List<Pair<Matrix>> testpairs = new ArrayList<Pair<Matrix>>();
		for (Pair<Matrix> next = bmfdg.generate(); next != null; next = bmfdg.generate()) {
			testpairs.add(next);
		}

		logger.debug("...training");
		bmfdg.setFold(i, Mode.TRAINING);
		int j = 0;
		for (Pair<Matrix> next = bmfdg.generate(); next != null; next = bmfdg.generate()) {
			// j is logged before the increment, so snapshots below are numbered from 1.
			logger.debug("...trying item " + j++);
			learner.process(next.firstObject(), next.secondObject());
			final Matrix u = learner.getU();
			final Matrix w = learner.getW();

			// Read the bias flag null-safely: unboxing a missing Boolean would throw an NPE.
			final Boolean biasFlag = learner.getParams().getTyped(BilinearLearnerParameters.BIAS);
			final boolean biasMode = biasFlag != null && biasFlag;
			// Snapshot the bias (before evaluation, as before) only when it will be logged.
			// This avoids a dense copy per item and a copy of a possibly-null bias when bias
			// learning is disabled.
			final Matrix bias = biasMode ? MatrixFactory.getDenseDefault().copyMatrix(learner.getBias()) : null;

			final BilinearEvaluator eval = new RootMeanSumLossEvaluator();
			eval.setLearner(learner);
			final double loss = eval.evaluate(testpairs);

			logger.debug(String.format("Saving learner, Fold %d, Item %d", i, j));
			final File learnerOut = new File(FOLD_ROOT(i), String.format("learner_%d", j));
			IOUtils.writeBinary(learnerOut, learner);

			logger.debug("W row sparcity: " + SandiaMatrixUtils.rowSparcity(w));
			logger.debug("U row sparcity: " + SandiaMatrixUtils.rowSparcity(u));
			if (biasMode) {
				logger.debug("Bias: " + SandiaMatrixUtils.diag(bias));
			}
			logger.debug(String.format("... loss: %f", loss));
		}
	}
}
public double sumLoss( List<Pair<Matrix>> pairs, Matrix u, Matrix w, Matrix bias, BilinearLearnerParameters params) { LossFunction loss = params.getTyped(BilinearLearnerParameters.LOSS); loss = new MatLossFunction(loss); double total = 0; int i = 0; int ntasks = 0; for (Pair<Matrix> pair : pairs) { Matrix X = pair.firstObject(); Matrix Y = pair.secondObject(); SparseMatrix Yexp = BilinearSparseOnlineLearner.expandY(Y); Matrix expectedAll = u.transpose().times(X.transpose()).times(w); loss.setY(Yexp); loss.setX(expectedAll); if (bias != null) loss.setBias(bias); logger.debug("Testing pair: " + i); total += loss.eval(null); // Assums an identity w. i++; ntasks += Y.getNumColumns(); } total /= ntasks; return Math.sqrt(total); }