/**
 * Evaluates the regularized collaborative-filtering objective and/or its gradient.
 *
 * <p>{@code params} is split into two blocks. The first {@code rows * features} entries are
 * reshaped to a {@code rows x features} matrix and assigned to the instance field {@code x}.
 * The remaining entries are reshaped to a {@code columns x features} matrix {@code theta}.
 * Predictions are {@code x * theta^T}, compared against {@code y}.
 *
 * <p>Cost (flag 1 or 3): half the sum of the squared prediction errors multiplied element-wise
 * by {@code r}, plus {@code lambda / 2 * (sum(theta^2) + sum(x^2))} when {@code lambda != 0}.
 * The result is stored in the field {@code cost}.
 *
 * <p>Gradient (flag 2 or 3): for each row of {@code x} (and each row of {@code theta}), only the
 * entries where {@code r} equals exactly 1 contribute to the data term. Then {@code lambda * x}
 * and {@code lambda * theta} are added. The two gradient arrays are combined with
 * {@code MatrixUtil.merge} (x-gradient first, then theta-gradient, presumably concatenated to
 * mirror the layout of {@code params}) and stored in the field {@code gradient}.
 *
 * <p>NOTE(review): the cost weights errors by the value of {@code r}, while the gradient selects
 * entries where {@code r == 1}. The two are consistent only if {@code r} is a 0/1 indicator
 * matrix — confirm with callers.
 *
 * @param params flattened parameters: the x block followed by the theta block
 * @param flag 1 = compute cost only, 2 = compute gradient only, 3 = compute both
 * @return {@code cost} when {@code flag == 1}; otherwise the {@code gradient} field. For any
 *     flag other than 2 or 3 the gradient is not recomputed by this call.
 */
@Override
  public Object compute(FloatMatrix params, int flag) {

    // Note: this assigns the instance field x (not a local), so the call has a side effect.
    x = params.getRange(0, rows * features);
    FloatMatrix theta = params.getRange(rows * features, params.length);

    x = x.reshape(rows, features);
    theta = theta.reshape(columns, features);

    if (flag == 1 || flag == 3) {
      // Element-wise squared error of every prediction x * theta^T against y.
      FloatMatrix M = MatrixFunctions.pow(x.mmul(theta.transpose()).sub(y), 2);
      // Weight by r, sum all entries (column sums, then the row sum), and halve.
      this.cost = M.mul(r).columnSums().rowSums().get(0) / 2;

      if (lambda != 0) {
        // L2 penalty on both parameter blocks.
        float cost1 =
            (lambda / 2)
                * (MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0)
                    + MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0));
        this.cost += cost1;
      }
    }

    if (flag == 2 || flag == 3) {

      FloatMatrix xGrad = FloatMatrix.zeros(x.rows, x.columns);
      FloatMatrix thetaGrad = FloatMatrix.zeros(theta.rows, theta.columns);

      int[] indices;
      FloatMatrix thetaTemp;
      FloatMatrix xTemp;
      FloatMatrix yTemp;
      // Data term of the x-gradient: row i only sees the theta rows of columns it has r == 1 for.
      for (int i = 0; i < rows; i++) {
        indices = r.getRow(i).eq(1).findIndices();
        // No observed entries: the data term for this row stays zero.
        if (indices.length == 0) continue;

        thetaTemp = theta.getRows(indices);
        yTemp = y.getRow(i).get(indices);
        xGrad.putRow(i, x.getRow(i).mmul(thetaTemp.transpose()).sub(yTemp).mmul(thetaTemp));
      }
      // Regularization applies to every row, including rows with no observed entries.
      xGrad = xGrad.add(x.mmul(lambda));

      // Data term of the theta-gradient: column i only sees the x rows with r == 1 in column i.
      for (int i = 0; i < columns; i++) {
        indices = r.getColumn(i).eq(1).findIndices();
        if (indices.length == 0) continue;

        xTemp = x.getRows(indices);
        yTemp = y.getColumn(i).get(indices);
        thetaGrad.putRow(
            i, xTemp.mmul(theta.getRow(i).transpose()).sub(yTemp).transpose().mmul(xTemp));
      }
      thetaGrad = thetaGrad.add(theta.mmul(lambda));

      // x-gradient first, then theta-gradient, in the same order the params were split above.
      this.gradient = MatrixUtil.merge(xGrad.data, thetaGrad.data);
    }

    return flag == 1 ? cost : gradient;
  }
  /**
   * Computes the mean observed rating of every row of {@code y}.
   *
   * <p>Only entries whose indicator in {@code r} equals exactly 1 contribute to a row's mean.
   * A row with no such entries gets a mean of 0 rather than NaN (the mean of an empty selection
   * is 0/0). This keeps the result safe to add back onto predictions. The method only reads
   * {@code y} and {@code r}; it does not produce or store mean-centred ratings.
   *
   * @return a {@code rows x 1} column vector holding the per-row mean of the observed ratings
   */
  public FloatMatrix normalizeRatings() {
    FloatMatrix yMean = FloatMatrix.zeros(rows, 1);

    for (int i = 0; i < rows; i++) {
      int[] rated = r.getRow(i).eq(1).findIndices();
      if (rated.length == 0) {
        // No observed ratings for this row: leave the mean at 0 instead of producing NaN.
        continue;
      }
      yMean.put(i, y.getRow(i).get(rated).mean());
    }

    return yMean;
  }
 /**
  * Builds the initial classification weight matrix for binary or unary (terminal)
  * classification.
  *
  * <p>The matrix has {@code numOuts} rows and {@code numHidden + 1} columns. The first
  * {@code numHidden} columns are filled by {@code MatrixUtil.rand} using bounds
  * {@code [-1/sqrt(numHidden), 1/sqrt(numHidden)]} (presumably uniform; see that helper). The
  * trailing bias column is left at zero. The whole matrix is then scaled by
  * {@code scalingForInit}.
  *
  * @return the scaled weight matrix
  */
 FloatMatrix randomClassificationMatrix() {
   float bound = 1.0f / (float) Math.sqrt((float) numHidden);
   FloatMatrix weights = FloatMatrix.zeros(numOuts, numHidden + 1);
   FloatMatrix randomBlock = MatrixUtil.rand(numOuts, numHidden, -bound, bound, rng);
   // Only columns [0, numHidden) are written, so the bias column keeps its zero values.
   weights.put(interval(0, numOuts), interval(0, numHidden), randomBlock);
   FloatMatrix scaled = SimpleBlas.scal(scalingForInit, weights);
   return scaled;
 }
 /**
  * Creates a placeholder data set whose two matrices are each a single zero entry.
  *
  * @return a new instance identical to one built with the no-argument constructor
  */
 public static FloatDataSet empty() {
   return new FloatDataSet();
 }
 /** Creates a data set whose two matrices are each a 1x1 zero matrix. */
 public FloatDataSet() {
   this(FloatMatrix.zeros(1, 1), FloatMatrix.zeros(1, 1));
 }