/**
 * Builds the initial flattened parameter vector.
 *
 * <p>Draws a {@code rows x features} random matrix followed by a {@code columns x features}
 * random matrix (both via {@code FloatMatrix.rand}) and hands their backing arrays, in that
 * order, to {@code MatrixUtil.merge}.
 *
 * @return the result of merging the two random matrices' data arrays
 */
@Override
  public FloatMatrix getInitTheta() {
    // Arguments are evaluated left to right, so the x-part is drawn before the theta-part.
    return MatrixUtil.merge(
        FloatMatrix.rand(rows, features).data, FloatMatrix.rand(columns, features).data);
  }
  /**
   * Evaluates the squared-error cost and/or its gradient for a flattened parameter vector.
   *
   * <p>The first {@code rows * features} entries of {@code params} are reshaped to {@code rows x
   * features} and stored in the {@code x} field; the remaining entries are reshaped to {@code
   * columns x features} and used as theta. The cost is written to the {@code cost} field and the
   * gradient to the {@code gradient} field.
   *
   * @param params flattened x entries followed by flattened theta entries
   * @param flag 1 computes the cost only, 2 the gradient only, 3 both
   * @return the cost when {@code flag == 1}; otherwise the current value of the {@code gradient}
   *     field (left over from an earlier call when {@code flag} is neither 2 nor 3)
   */
  @Override
  public Object compute(FloatMatrix params, int flag) {
    final boolean wantCost = flag == 1 || flag == 3;
    final boolean wantGradient = flag == 2 || flag == 3;
    final int split = rows * features;

    // Unpack the flat vector: x-part first, theta-part after it.
    x = params.getRange(0, split);
    FloatMatrix theta = params.getRange(split, params.length);

    x = x.reshape(rows, features);
    theta = theta.reshape(columns, features);

    if (wantCost) {
      // Squared residual of x * theta' against y, weighted element-wise by r.
      FloatMatrix residual = x.mmul(theta.transpose()).sub(y);
      FloatMatrix squared = MatrixFunctions.pow(residual, 2);
      this.cost = squared.mul(r).columnSums().rowSums().get(0) / 2;

      // Regularization term is only added for a non-zero lambda.
      if (lambda != 0) {
        float penalty =
            (lambda / 2)
                * (MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0)
                    + MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0));
        this.cost += penalty;
      }
    }

    if (wantGradient) {
      FloatMatrix gradX = FloatMatrix.zeros(x.rows, x.columns);
      FloatMatrix gradTheta = FloatMatrix.zeros(theta.rows, theta.columns);

      // Gradient with respect to x, one row at a time, using only the entries where r equals 1.
      for (int row = 0; row < rows; row++) {
        int[] picked = r.getRow(row).eq(1).findIndices();
        if (picked.length != 0) {
          FloatMatrix thetaSubset = theta.getRows(picked);
          FloatMatrix targets = y.getRow(row).get(picked);
          gradX.putRow(
              row, x.getRow(row).mmul(thetaSubset.transpose()).sub(targets).mmul(thetaSubset));
        }
        // Rows without any selected entry keep their zero gradient.
      }
      gradX = gradX.add(x.mmul(lambda));

      // Gradient with respect to theta, one row per column of r/y.
      for (int col = 0; col < columns; col++) {
        int[] picked = r.getColumn(col).eq(1).findIndices();
        if (picked.length != 0) {
          FloatMatrix xSubset = x.getRows(picked);
          FloatMatrix targets = y.getColumn(col).get(picked);
          gradTheta.putRow(
              col,
              xSubset.mmul(theta.getRow(col).transpose()).sub(targets).transpose().mmul(xSubset));
        }
      }
      gradTheta = gradTheta.add(theta.mmul(lambda));

      // Flatten in the same x-then-theta order used when unpacking params.
      this.gradient = MatrixUtil.merge(gradX.data, gradTheta.data);
    }

    if (flag == 1) {
      return cost;
    }
    return gradient;
  }
// Example #3
// 0
  /**
   * Iterates K-means over the ex7data2 data set and logs the resulting cluster centroids.
   *
   * <p>Timing helpers {@code tic()} and {@code toc()} bracket the whole run.
   *
   * @throws IOException declared by this method; presumably raised when the data file cannot be
   *     loaded (confirm against {@code MatrixUtil.loadData})
   */
  public static void part1() throws IOException {
    tic();

    logger.info("加载数据...\n");
    // Data file is parsed with a whitespace delimiter regex.
    final FloatMatrix samples =
        MatrixUtil.loadData(COURSE_ML_PATH + "/ex7/ex7data2.data", "\\s+");

    logger.info("模型初始化...\n");
    // The literal 3 is presumably the number of clusters; verify against the KMeans constructor.
    final KMeans clusterer = new KMeans(samples, 3);

    logger.info("执行训练...\n");
    // The literal 10 is presumably the iteration count; verify against KMeans.run.
    final FloatMatrix centers = clusterer.run(10);

    logger.info("运行完毕.\n聚类中心如下:\n{}\n", centers);
    toc();
  }