/** * Discrete fourier transform 2d * * @param input the input to transform * @param rows the number of rows in the transformed output matrix * @param cols the number of columns in the transformed output matrix * @return the discrete fourier transform of the input */ public static ComplexDoubleMatrix complexDisceteFourierTransform( ComplexDoubleMatrix input, int rows, int cols) { ComplexDoubleMatrix base; // pad if (input.rows < rows || input.columns < cols) base = MatrixUtil.complexPadWithZeros(input, rows, cols); // truncation else if (input.rows > rows || input.columns > cols) { base = input.dup(); base = base.get( MatrixUtil.toIndices(RangeUtils.interval(0, rows)), MatrixUtil.toIndices(RangeUtils.interval(0, cols))); } else base = input.dup(); ComplexDoubleMatrix temp = new ComplexDoubleMatrix(base.rows, base.columns); ComplexDoubleMatrix ret = new ComplexDoubleMatrix(base.rows, base.columns); for (int i = 0; i < base.columns; i++) { ComplexDoubleMatrix column = base.getColumn(i); temp.putColumn(i, complexDiscreteFourierTransform1d(column)); } for (int i = 0; i < ret.rows; i++) { ComplexDoubleMatrix row = temp.getRow(i); ret.putRow(i, complexDiscreteFourierTransform1d(row)); } return ret; }
/**
 * 1d discrete fourier transform of a one-dimensional {@link ComplexNDArray}, computed as a matrix
 * product with a DFT matrix built from {@code exp(range * range^T * (-2*pi*i / len))}, where
 * {@code len} is the larger of the array's row and column counts. This is the 1d analogue of
 * MATLAB's {@code fft}.
 *
 * <p>NOTE(review): the product is formed as {@code dftMatrix.mmul(inputC)}, which is only
 * conformable when the array is laid out as a column ({@code len x 1}). Confirm how a 1-d
 * ComplexNDArray maps onto rows/columns.
 *
 * @param inputC the input to transform; must have a one-dimensional shape
 * @return the discrete fourier transform, wrapped via {@code ComplexNDArray.wrap(inputC, ...)}
 * @throws IllegalArgumentException if {@code inputC.shape()} does not have exactly one dimension
 */
public static ComplexNDArray complexDiscreteFourierTransform1d(ComplexNDArray inputC) {
  if (inputC.shape().length == 1) {
    final double n = Math.max(inputC.rows, inputC.columns);
    // Phase step of the DFT kernel: (0 - 2i) * pi / n == -2*pi*i / n.
    final ComplexDouble phaseStep = new ComplexDouble(0, -2).muli(FastMath.PI).divi(n);
    // Presumably the indices 0 .. n-1; TODO confirm complexRangeVector's orientation/contents.
    final ComplexDoubleMatrix indices = MatrixUtil.complexRangeVector(0, n);
    final ComplexDoubleMatrix exponents = indices.mmul(indices.transpose().mul(phaseStep));
    final ComplexDoubleMatrix dftMatrix = exp(exponents);
    final ComplexDoubleMatrix transformed = dftMatrix.mmul(inputC);
    return ComplexNDArray.wrap(inputC, transformed);
  }
  throw new IllegalArgumentException("Illegal input: Must be a vector");
}
/**
 * Computes only the singular values of a complex matrix using LAPACK {@code zgesvd}; no singular
 * vectors are requested ({@code jobu = jobvt = 'N'}).
 *
 * @param A ComplexDoubleMatrix of dimension m * n; not modified, because LAPACK works on a copy
 * @return a real-valued (!) vector of length min(m, n) holding the singular values
 * @throws LapackConvergenceException if zgesvd reports a positive info code (non-convergence)
 */
public static DoubleMatrix SVDValues(ComplexDoubleMatrix A) {
  final int m = A.rows;
  final int n = A.columns;
  final int k = min(m, n);

  final DoubleMatrix singularValues = new DoubleMatrix(k);
  // zgesvd's real workspace RWORK is documented with dimension 5 * min(M, N).
  final double[] realWork = new double[5 * k];

  // zgesvd destroys its input array, hence A.dup(). U and V^H are not computed, so null
  // buffers are passed for them (with ldu = 1 and ldvt = min(m, n)).
  final int status =
      NativeBlas.zgesvd(
          'N', 'N', m, n,
          A.dup().data, 0, m,
          singularValues.data, 0,
          null, 0, 1,
          null, 0, k,
          realWork, 0);

  if (status <= 0) {
    return singularValues;
  }
  throw new LapackConvergenceException(
      "GESVD", status + " superdiagonals of an intermediate bidiagonal form failed to converge.");
}
/**
 * Computes a full singular-value decomposition of a complex matrix using LAPACK {@code zgesvd}
 * with {@code jobu = jobvt = 'A'} (all columns of U and all rows of V^H).
 *
 * @param A ComplexDoubleMatrix of dimension m * n; not modified, because LAPACK works on a copy
 * @return a ComplexDoubleMatrix[3] array of U (m x m), S (the min(m, n) singular values) and V
 *     (n x n) such that A = U * diag(S) * V'
 * @throws LapackConvergenceException if zgesvd reports a positive info code (non-convergence)
 */
public static ComplexDoubleMatrix[] fullSVD(ComplexDoubleMatrix A) {
  final int m = A.rows;
  final int n = A.columns;
  final int k = min(m, n);

  final ComplexDoubleMatrix left = new ComplexDoubleMatrix(m, m);
  final DoubleMatrix sigma = new DoubleMatrix(k);
  // zgesvd fills this buffer with V^H, not V.
  final ComplexDoubleMatrix rightHermitian = new ComplexDoubleMatrix(n, n);
  // zgesvd's real workspace RWORK is documented with dimension 5 * min(M, N).
  final double[] realWork = new double[5 * k];

  final int status =
      NativeBlas.zgesvd(
          'A', 'A', m, n,
          A.dup().data, 0, m,
          sigma.data, 0,
          left.data, 0, m,
          rightHermitian.data, 0, n,
          realWork, 0);

  if (status > 0) {
    throw new LapackConvergenceException(
        "GESVD",
        status + " superdiagonals of an intermediate bidiagonal form failed to converge.");
  }

  final ComplexDoubleMatrix[] usv = new ComplexDoubleMatrix[3];
  usv[0] = left;
  usv[1] = new ComplexDoubleMatrix(sigma);
  // Conjugate-transpose V^H back into V.
  usv[2] = rightHermitian.hermitian();
  return usv;
}
/**
 * 2d convolution of {@code input} with {@code kernel}, computed in the frequency domain.
 *
 * <p>Both operands are transformed at the full convolution size {@code (inRows + kRows - 1) x
 * (inCols + kCols - 1)}. The transforms are multiplied element-wise and transformed back, and the
 * real part is taken as the full convolution.
 *
 * <p>For {@code Type.VALID}, only the {@code (inRows - kRows + 1) x (inCols - kCols + 1)} region is
 * returned: the rows {@code [kRows - 1, inRows)} and columns {@code [kCols - 1, inCols)} of the
 * full result. Any other type returns the full result.
 *
 * @param input the matrix to convolve
 * @param kernel the convolution kernel
 * @param type which part of the convolution to return
 * @return the (full or valid) convolution of input and kernel
 * @throws IllegalArgumentException for {@code Type.VALID} when the kernel is larger than the input
 *     in some dimension, so that the valid region is empty
 */
public static DoubleMatrix conv2d(DoubleMatrix input, DoubleMatrix kernel, Type type) {
  // Plain ints avoid jblas in-place ops (addi/subi/divi) mutating and aliasing the shape vectors.
  final int inRows = input.rows;
  final int inCols = input.columns;
  final int kRows = kernel.rows;
  final int kCols = kernel.columns;

  // Full convolution size.
  final int retRows = inRows + kRows - 1;
  final int retCols = inCols + kCols - 1;

  ComplexDoubleMatrix fftInput = complexDisceteFourierTransform(input, retRows, retCols);
  ComplexDoubleMatrix fftKernel = complexDisceteFourierTransform(kernel, retRows, retCols);
  // Convolution theorem: element-wise product in the frequency domain. fftKernel is a local
  // temporary, so overwriting it in place is safe.
  ComplexDoubleMatrix product = fftKernel.muli(fftInput);
  DoubleMatrix ret = complexInverseDisceteFourierTransform(product).getReal();

  if (type == Type.VALID) {
    final int validRows = inRows - kRows + 1;
    final int validCols = inCols - kCols + 1;
    if (validRows < 1 || validCols < 1) {
      throw new IllegalArgumentException(
          "Kernel " + kRows + "x" + kCols + " is larger than input " + inRows + "x" + inCols
              + "; valid convolution is empty");
    }
    // Centre the valid window in the full result. (retRows - validRows) == 2 * (kRows - 1),
    // so the division is exact and startRow == kRows - 1 (0-based).
    final int startRow = (retRows - validRows) / 2;
    final int startCol = (retCols - validCols) / 2;
    // RangeUtils.interval is 0-based and end-exclusive.
    ret =
        ret.get(
            RangeUtils.interval(startRow, startRow + validRows),
            RangeUtils.interval(startCol, startCol + validCols));
  }
  return ret;
}
/**
 * 1d discrete fourier transform of a row or column vector.
 *
 * <p>The transform is computed as a product with the n x n matrix {@code W = exp(k * k^T *
 * (-2*pi*i / n))}, where {@code k} is the index vector from {@code
 * MatrixUtil.complexRangeVector(0, n)} (presumably 0 .. n-1; confirm). W is symmetric, so a row
 * vector is transformed as {@code x * W} and a column vector as {@code W * x}. The result has the
 * same orientation as the input. This is the 1d analogue of MATLAB's {@code fft}.
 *
 * @param inputC the vector to transform (1 x n or n x 1)
 * @return the discrete fourier transform of the input, with the same shape as the input
 * @throws IllegalArgumentException if the input is not a vector
 */
public static ComplexDoubleMatrix complexDiscreteFourierTransform1d(ComplexDoubleMatrix inputC) {
  if (inputC.rows != 1 && inputC.columns != 1) {
    throw new IllegalArgumentException("Illegal input: Must be a vector");
  }
  double len = Math.max(inputC.rows, inputC.columns);
  // Phase step -2*pi*i / n of the DFT kernel.
  ComplexDouble c2 = new ComplexDouble(0, -2).muli(FastMath.PI).divi(len);
  ComplexDoubleMatrix range = MatrixUtil.complexRangeVector(0, len);
  // Force a column vector so range * range^T is the n x n outer product. A row vector here would
  // collapse it to 1 x 1, and jblas's scalar mmul shortcut would then silently return the input.
  if (range.isRowVector()) {
    range = range.transpose();
  }
  ComplexDoubleMatrix matrix = exp(range.mmul(range.transpose().mul(c2)));
  // Choose the conformable multiplication order: (1 x n)(n x n) or (n x n)(n x 1).
  return inputC.isRowVector() ? inputC.mmul(matrix) : matrix.mmul(inputC);
}