/**
 * Computes the 2-D convolution of {@code input} with {@code kernel} in the frequency domain.
 *
 * <p>Both operands are transformed at the full convolution size
 * ({@code input.rows + kernel.rows - 1} by {@code input.columns + kernel.columns - 1}). The
 * transforms are multiplied element-wise and inverse-transformed, and the real part of the
 * result is kept.
 *
 * @param input  the matrix to convolve
 * @param kernel the convolution kernel
 * @param type   {@code Type.VALID} crops the result to the region where the kernel fully
 *               overlaps the input ({@code input - kernel + 1} in each dimension); any other
 *               value returns the full-size result uncropped
 * @return the real part of the convolution, full-size or valid-cropped as requested
 * @throws IllegalStateException if {@code type} is {@code Type.VALID} and the kernel is larger
 *         than the input in either dimension, so no valid region exists
 */
public static DoubleMatrix conv2d(DoubleMatrix input, DoubleMatrix kernel, Type type) {
    // Shape arithmetic is done with plain ints. The jblas in-place ops (addi/subi/divi) return
    // `this`, so chaining them on small shape vectors silently aliases every intermediate.
    int inRows = input.rows;
    int inCols = input.columns;
    int kRows = kernel.rows;
    int kCols = kernel.columns;

    // Full linear-convolution output size.
    int retRows = inRows + kRows - 1;
    int retCols = inCols + kCols - 1;

    // NOTE(review): assumes complexDisceteFourierTransform zero-pads to (retRows, retCols) so the
    // spectral product corresponds to linear rather than circular convolution -- confirm.
    ComplexDoubleMatrix fftInput = complexDisceteFourierTransform(input, retRows, retCols);
    ComplexDoubleMatrix fftKernel = complexDisceteFourierTransform(kernel, retRows, retCols);
    // muli multiplies in place; fftKernel is a local temporary, so overwriting it is fine.
    ComplexDoubleMatrix mul = fftKernel.muli(fftInput);
    ComplexDoubleMatrix retComplex = complexInverseDisceteFourierTransform(mul);
    DoubleMatrix ret = retComplex.getReal();

    if (type == Type.VALID) {
        int validRows = inRows - kRows + 1;
        int validCols = inCols - kCols + 1;
        if (validRows < 1 || validCols < 1) {
            throw new IllegalStateException("Kernel " + kRows + "x" + kCols
                    + " is larger than input " + inRows + "x" + inCols
                    + "; no valid convolution region");
        }
        // Centre the valid window inside the full result: start = (full - valid) / 2, which
        // works out to kernel size - 1 and may legitimately be 0 for a 1-wide kernel.
        int startRow = (retRows - validRows) / 2;
        int startCol = (retCols - validCols) / 2;
        // RangeUtils.interval has an exclusive end, so start + valid selects exactly `valid` items.
        ret = ret.get(
                RangeUtils.interval(startRow, startRow + validRows),
                RangeUtils.interval(startCol, startCol + validCols));
    }
    return ret;
}