Exemplo n.º 1
0
  /**
   * Compute {@code y <- alpha * a * x + beta * y} (general matrix-vector multiplication).
   *
   * <p>This is a pure-Java implementation; {@code a} is used as-is (no transpose). Following the
   * reference BLAS convention, when {@code beta} is zero the previous contents of {@code y} are
   * discarded rather than multiplied, so NaN/Inf values already in {@code y} do not propagate.
   *
   * <p>Callers must pass {@code x} with at least {@code a.columns} elements and {@code y} with
   * {@code a.rows} elements.
   *
   * @param alpha scalar multiplier applied to {@code a * x}
   * @param a the matrix operand
   * @param x the vector operand
   * @param beta scalar multiplier applied to the incoming contents of {@code y}
   * @param y the accumulator vector; overwritten in place with the result
   * @return {@code y}, to allow call chaining
   */
  public static FloatMatrix gemv(
      float alpha, FloatMatrix a, FloatMatrix x, float beta, FloatMatrix y) {
    // Step 1: y <- beta * y. Zero explicitly for beta == 0 (instead of multiplying) so stale
    // NaN/Inf entries are dropped; skip the pass entirely for beta == 1.
    if (beta == 0.0f) {
      for (int i = 0; i < y.length; i++) {
        y.data[i] = 0.0f;
      }
    } else if (beta != 1.0f) {
      for (int i = 0; i < y.length; i++) {
        y.data[i] *= beta;
      }
    }

    // Nothing to add when alpha is zero.
    if (alpha == 0.0f) {
      return y;
    }

    // Step 2: y <- y + alpha * a * x, accumulated column by column. Row i of y receives the
    // contribution a(i, j) * alpha * x(j); columns with x(j) == 0 contribute nothing.
    for (int j = 0; j < a.columns; j++) {
      float xj = x.get(j);
      if (xj != 0.0f) {
        float scaled = alpha * xj;
        for (int i = 0; i < a.rows; i++) {
          y.data[i] += a.get(i, j) * scaled;
        }
      }
    }
    return y;
  }
Exemplo n.º 2
0
  /**
   * Compute the thin ("economy-size") singular-value decomposition of {@code A}.
   *
   * <p>"Sparse" here does not refer to sparse storage. It means that U and V keep only
   * {@code k = min(A.rows, A.columns)} columns instead of being full square matrices.
   *
   * <p>{@code A} itself is not modified; a duplicate of its data is handed to LAPACK's
   * {@code sgesvd}.
   *
   * @param A the matrix to decompose, of dimension m x n
   * @return a three-element array {@code {U, S, V}} with U of size m x k, S a k-element vector of
   *     singular values, and V of size n x k, such that {@code A = U * diag(S) * V'}
   * @throws LapackConvergenceException if sgesvd reports (info &gt; 0) that the bidiagonal
   *     iteration did not converge
   */
  public static FloatMatrix[] sparseSVD(FloatMatrix A) {
    final int rows = A.rows;
    final int cols = A.columns;
    final int k = min(rows, cols);

    final FloatMatrix left = new FloatMatrix(rows, k);
    final FloatMatrix sigma = new FloatMatrix(k);
    // LAPACK fills V' (k x cols); it is transposed before returning.
    final FloatMatrix rightT = new FloatMatrix(k, cols);

    // Jobs 'S','S' request the first k left/right singular vectors only.
    final int status =
        NativeBlas.sgesvd(
            'S', 'S', rows, cols, A.dup().data, 0, rows,
            sigma.data, 0, left.data, 0, rows, rightT.data, 0, k);

    if (status <= 0) {
      return new FloatMatrix[] {left, sigma, rightT.transpose()};
    }
    throw new LapackConvergenceException(
        "GESVD", status + " superdiagonals of an intermediate bidiagonal form failed to converge.");
  }
Exemplo n.º 3
0
  /**
   * Compute the generalized cross product of {@code multiplicands.length} vectors in a space of
   * dimension {@code multiplicands.length + 1}.
   *
   * <p>For two multiplicands, the call is delegated to {@code crossProduct3D}. Otherwise, each
   * component {@code row} of the result is the determinant of the multiplicand matrix with that
   * row struck out. The sign alternates with {@code (-1)^row}, so the result is a cofactor
   * expansion along a row of basis vectors.
   *
   * <p>NOTE(review): assumes every multiplicand has {@code multiplicands.length + 1} components,
   * and that {@code new FloatMatrix(true, multiplicands)} lays them out so that striking one row
   * yields a square minor. TODO confirm.
   *
   * @param multiplicands the vectors to combine
   * @return the generalized cross product
   */
  protected static FloatVector crossProductAll(FloatVector... multiplicands) {
    final int dim = multiplicands.length + 1;

    // Fast path: the ordinary three-dimensional cross product.
    if (dim == 3) {
      return multiplicands[0].crossProduct3D(multiplicands[1]);
    }

    final FloatVector zero = new FloatVector(dim);
    FloatVector acc = zero;
    final FloatMatrix stacked = new FloatMatrix(true, multiplicands);

    int row = 0;
    while (row < dim) {
      float cofactor = stacked.strikeRow(row).determinant();
      if (row % 2 != 0) {
        cofactor = -cofactor;
      }
      // NOTE(review): relies on setComponent/scalarMultiply/add returning new vectors; if
      // setComponent mutated `zero` in place, unit entries would accumulate across rows. Verify.
      acc = acc.add(zero.setComponent(row, 1f).scalarMultiply(cofactor));
      row++;
    }
    return acc;
  }
Exemplo n.º 4
0
  /**
   * Compute only the singular values of a matrix, without the singular vectors.
   *
   * <p>{@code A} is not modified; a duplicate of its data is passed to LAPACK's {@code sgesvd}
   * with jobs 'N','N', so no U or V' output is requested.
   *
   * @param A FloatMatrix of dimension m * n
   * @return a vector of {@code min(m, n)} singular values
   * @throws LapackConvergenceException if sgesvd reports (info &gt; 0) that the bidiagonal
   *     iteration did not converge
   */
  public static FloatMatrix SVDValues(FloatMatrix A) {
    final int nRows = A.rows;
    final int nCols = A.columns;
    final FloatMatrix values = new FloatMatrix(min(nRows, nCols));

    // U and V' are not computed, so null arrays with a leading dimension of 1 are passed.
    final int status =
        NativeBlas.sgesvd(
            'N', 'N', nRows, nCols, A.dup().data, 0, nRows,
            values.data, 0, null, 0, 1, null, 0, 1);

    if (status <= 0) {
      return values;
    }
    throw new LapackConvergenceException(
        "GESVD", status + " superdiagonals of an intermediate bidiagonal form failed to converge.");
  }
Exemplo n.º 5
0
  /**
   * Compute the two-dimensional convolution of {@code input} with {@code kernel} using the FFT.
   *
   * <p>The full linear convolution has size {@code (input.rows + kernel.rows - 1) x
   * (input.columns + kernel.columns - 1)}. When {@code type} is {@code Type.VALID}, only the
   * {@code (input.rows - kernel.rows + 1) x (input.columns - kernel.columns + 1)} region, where
   * the kernel fully overlaps the input, is returned. Any other {@code type} returns the full
   * result.
   *
   * @param input the matrix to convolve
   * @param kernel the convolution kernel
   * @param type {@code Type.VALID} to crop to the valid region; otherwise the full result
   * @return the convolution result
   * @throws IllegalStateException if {@code type} is VALID and the kernel is larger than the
   *     input in either dimension
   */
  public static FloatMatrix conv2d(FloatMatrix input, FloatMatrix kernel, Type type) {
    // Output size of the full linear convolution.
    int fullRows = input.rows + kernel.rows - 1;
    int fullCols = input.columns + kernel.columns - 1;

    // Both operands are transformed at the full output size (presumably zero-padded by
    // complexDisceteFourierTransform), so that the pointwise product in the frequency domain
    // corresponds to a linear rather than a circular convolution.
    ComplexFloatMatrix fftInput = complexDisceteFourierTransform(input, fullRows, fullCols);
    ComplexFloatMatrix fftKernel = complexDisceteFourierTransform(kernel, fullRows, fullCols);
    ComplexFloatMatrix product = fftKernel.mul(fftInput);
    FloatMatrix ret = complexInverseDisceteFourierTransform(product).getReal();

    // NOTE(review): only VALID is special-cased; every other Type (e.g. a SAME mode, if one
    // exists) returns the full-size result.
    if (type == Type.VALID) {
      int validRows = input.rows - kernel.rows + 1;
      int validCols = input.columns - kernel.columns + 1;
      if (validRows < 1 || validCols < 1) {
        throw new IllegalStateException(
            "VALID convolution requires the kernel ("
                + kernel.rows
                + "x"
                + kernel.columns
                + ") to be no larger than the input ("
                + input.rows
                + "x"
                + input.columns
                + ")");
      }

      // The valid region begins at offset (kernel size - 1), which equals
      // (full size - valid size) / 2. A kernel with a single row or column therefore starts at
      // index 0, which is legitimate (the previous "< 1" check wrongly rejected it). The upper
      // bound is start + valid size, i.e. the input size.
      int startRow = kernel.rows - 1;
      int startCol = kernel.columns - 1;
      ret =
          ret.get(
              RangeUtils.interval(startRow, startRow + validRows),
              RangeUtils.interval(startCol, startCol + validCols));
    }

    return ret;
  }