  /**
   * Run the experiment: for each fold of the bill data, train a
   * {@link BilinearSparseOnlineLearner} online, evaluating the held-out loss
   * and serialising the learner after every training item.
   */
  @Override
  public void performExperiment() throws IOException {
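    // Configure the learner: SGD step sizes for U, W and the bias term, the
    // regularisation weight, the biconvex inner-loop stopping criteria, and
    // the initialisation strategies for W and U.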
    BilinearLearnerParameters params = new BilinearLearnerParameters();
    params.put(BilinearLearnerParameters.ETA0_U, 0.02);
    params.put(BilinearLearnerParameters.ETA0_W, 0.02);
    params.put(BilinearLearnerParameters.LAMBDA, 0.001);
    params.put(BilinearLearnerParameters.BICONVEX_TOL, 0.01);
    params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
    params.put(BilinearLearnerParameters.BIAS, true);
    params.put(BilinearLearnerParameters.ETA0_BIAS, 0.5);
    params.put(BilinearLearnerParameters.WINITSTRAT, new SingleValueInitStrat(0.1));
    params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
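    // Generator that replays the Matlab-format data one fold at a time; the
    // mode set below controls whether test or training pairs are emitted.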
    BillMatlabFileDataGenerator bmfdg =
        new BillMatlabFileDataGenerator(new File(BILL_DATA()), 98, true);
    prepareExperimentLog(params);
    for (int i = 0; i < bmfdg.nFolds(); i++) {
      logger.debug("Fold: " + i);
      BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);
      learner.reinitParams();

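      // Drain the generator in TEST mode first, collecting the held-out pairs
      // of this fold so loss can be measured after every training step.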
      bmfdg.setFold(i, Mode.TEST);
      List<Pair<Matrix>> testpairs = new ArrayList<Pair<Matrix>>();
      while (true) {
        Pair<Matrix> next = bmfdg.generate();
        if (next == null) break;
        testpairs.add(next);
      }
      logger.debug("...training");
      bmfdg.setFold(i, Mode.TRAINING);
      int j = 0;
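      // Online training pass: feed each (X, Y) pair to the learner in turn.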
      while (true) {
        Pair<Matrix> next = bmfdg.generate();
        if (next == null) break;
        logger.debug("...trying item " + j++);
        learner.process(next.firstObject(), next.secondObject());
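        // Snapshot the current factors; the bias is copied defensively, since
        // the learner may update it in place on later rounds.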
        Matrix u = learner.getU();
        Matrix w = learner.getW();
        Matrix bias = MatrixFactory.getDenseDefault().copyMatrix(learner.getBias());
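        // After each training item, measure loss on the held-out test pairs.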
        BilinearEvaluator eval = new RootMeanSumLossEvaluator();
        eval.setLearner(learner);
        double loss = eval.evaluate(testpairs);
        logger.debug(String.format("Saving learner, Fold %d, Item %d", i, j));
        File learnerOut = new File(FOLD_ROOT(i), String.format("learner_%d", j));
        IOUtils.writeBinary(learnerOut, learner);
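        // Diagnostics: row sparsity of W and U and, when the bias term is
        // enabled, its diagonal.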
        logger.debug("W row sparsity: " + SandiaMatrixUtils.rowSparcity(w));
        logger.debug("U row sparsity: " + SandiaMatrixUtils.rowSparcity(u));
        boolean biasMode = learner.getParams().getTyped(BilinearLearnerParameters.BIAS);
        if (biasMode) {
          logger.debug("Bias: " + SandiaMatrixUtils.diag(bias));
        }
        logger.debug(String.format("... loss: %f", loss));
      }
    }
  }
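
  /**
   * Compute the root mean loss over a set of (X, Y) pairs, mirroring what
   * {@link RootMeanSumLossEvaluator} reports: the configured loss is wrapped
   * in a {@link MatLossFunction}, evaluated against the bilinear prediction
   * U^T X^T W for every pair, normalised by the total number of tasks
   * (columns of Y), and square-rooted.
   */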
  public double sumLoss(
      List<Pair<Matrix>> pairs, Matrix u, Matrix w, Matrix bias, BilinearLearnerParameters params) {
    LossFunction loss = params.getTyped(BilinearLearnerParameters.LOSS);
    loss = new MatLossFunction(loss);
    double total = 0;
    int i = 0;
    int ntasks = 0;
    for (Pair<Matrix> pair : pairs) {
      Matrix X = pair.firstObject();
      Matrix Y = pair.secondObject();
      SparseMatrix Yexp = BilinearSparseOnlineLearner.expandY(Y);
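      // Bilinear prediction for this pair: U^T X^T W.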
      Matrix expectedAll = u.transpose().times(X.transpose()).times(w);
      loss.setY(Yexp);
      loss.setX(expectedAll);
      if (bias != null) loss.setBias(bias);
      logger.debug("Testing pair: " + i);
      total += loss.eval(null); // Assumes an identity W.
      i++;
      ntasks += Y.getNumColumns();
    }
    total /= ntasks;

    return Math.sqrt(total);
  }
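
  // Example (hypothetical) of calling sumLoss with the state saved above,
  // e.g. from within the training loop of performExperiment():
  //   double rms = sumLoss(testpairs, u, w, bias, params);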