/** * Discrete fourier transform 2d * * @param input the input to transform * @param rows the number of rows in the transformed output matrix * @param cols the number of columns in the transformed output matrix * @return the discrete fourier transform of the input */ public static ComplexFloatMatrix complexDisceteFourierTransform( FloatMatrix input, int rows, int cols) { ComplexFloatMatrix base; // pad if (input.rows < rows || input.columns < cols) base = MatrixUtil.complexPadWithZeros(input, rows, cols); // truncation else if (input.rows > rows || input.columns > cols) { base = new ComplexFloatMatrix(input); base = base.get( MatrixUtil.toIndices(RangeUtils.interval(0, rows)), MatrixUtil.toIndices(RangeUtils.interval(0, cols))); } else base = new ComplexFloatMatrix(input); ComplexFloatMatrix temp = new ComplexFloatMatrix(base.rows, base.columns); ComplexFloatMatrix ret = new ComplexFloatMatrix(base.rows, base.columns); for (int i = 0; i < base.columns; i++) { ComplexFloatMatrix column = base.getColumn(i); temp.putColumn(i, complexDiscreteFourierTransform1d(column)); } for (int i = 0; i < ret.rows; i++) { ComplexFloatMatrix row = temp.getRow(i); ret.putRow(i, complexDiscreteFourierTransform1d(row)); } return ret; }
/**
 * 2-D convolution of {@code input} with {@code kernel}, computed via the discrete fourier
 * transform.
 *
 * <p>Both operands are transformed at the full linear-convolution size
 * {@code (input.rows + kernel.rows - 1) x (input.columns + kernel.columns - 1)}. They are then
 * multiplied element-wise and inverse-transformed, and the real part is kept. For
 * {@link Type#VALID} the centred {@code (input - kernel + 1)} region of that full result is
 * returned. For any other {@code type} the full result is returned.
 *
 * @param input the matrix to convolve
 * @param kernel the convolution kernel
 * @param type the output mode; only {@code Type.VALID} changes the output size
 * @return the real part of the convolution
 * @throws IllegalStateException if {@code type} is {@code VALID} and the kernel is larger than
 *     the input along either axis, so the valid region would be empty
 */
public static FloatMatrix conv2d(FloatMatrix input, FloatMatrix kernel, Type type) {
    // Full linear-convolution size along each axis (input + kernel - 1).
    int retRows = input.rows + kernel.rows - 1;
    int retCols = input.columns + kernel.columns - 1;

    ComplexFloatMatrix fftInput = complexDisceteFourierTransform(input, retRows, retCols);
    ComplexFloatMatrix fftKernel = complexDisceteFourierTransform(kernel, retRows, retCols);
    ComplexFloatMatrix mul = fftKernel.mul(fftInput);
    ComplexFloatMatrix retComplex = complexInverseDisceteFourierTransform(mul);
    FloatMatrix ret = retComplex.getReal();

    if (type == Type.VALID) {
        int validRows = input.rows - kernel.rows + 1;
        int validCols = input.columns - kernel.columns + 1;
        if (validRows < 1 || validCols < 1) {
            throw new IllegalStateException(
                    "VALID convolution is empty: kernel " + kernel.rows + "x" + kernel.columns
                            + " is larger than input " + input.rows + "x" + input.columns);
        }
        // Centre the valid window inside the full result. This works out to
        // start = kernel size - 1, which is 0 for a single-row or single-column kernel and is
        // a legal (0-based) index.
        int startRow = (retRows - validRows) / 2;
        int startCol = (retCols - validCols) / 2;
        ret = ret.get(
                RangeUtils.interval(startRow, startRow + validRows),
                RangeUtils.interval(startCol, startCol + validCols));
    }
    return ret;
}