/**
   * Computes the two-dimensional discrete Fourier transform of a real matrix.
   *
   * <p>The input is first resized to {@code rows x cols}. Each dimension is handled
   * independently: it is zero-padded when the input is smaller than requested and truncated
   * when it is larger, always keeping the top-left corner of the input. A 1-d transform is then
   * applied to every column, followed by a 1-d transform of every row of that intermediate
   * result.
   *
   * @param input the real-valued input to transform
   * @param rows the number of rows in the transformed output matrix
   * @param cols the number of columns in the transformed output matrix
   * @return the {@code rows x cols} discrete Fourier transform of the resized input
   * @throws IllegalArgumentException if {@code rows} or {@code cols} is negative
   */
  public static ComplexFloatMatrix complexDisceteFourierTransform(
      FloatMatrix input, int rows, int cols) {
    if (rows < 0 || cols < 0) {
      throw new IllegalArgumentException(
          "Output shape must be non-negative, got " + rows + "x" + cols);
    }

    // Pad or truncate each dimension on its own so mixed cases (fewer rows but more columns
    // than requested, or vice versa) work. A fresh FloatMatrix is zero-filled, so any cell not
    // copied from the input acts as zero padding.
    FloatMatrix resized = new FloatMatrix(rows, cols);
    int keepRows = Math.min(rows, input.rows);
    int keepCols = Math.min(cols, input.columns);
    for (int j = 0; j < keepCols; j++) {
      for (int i = 0; i < keepRows; i++) {
        resized.put(i, j, input.get(i, j));
      }
    }
    ComplexFloatMatrix base = new ComplexFloatMatrix(resized);

    // Separable 2-d DFT: transform every column, then every row of the column-transformed data.
    ComplexFloatMatrix columnPass = new ComplexFloatMatrix(rows, cols);
    for (int j = 0; j < cols; j++) {
      columnPass.putColumn(j, complexDiscreteFourierTransform1d(base.getColumn(j)));
    }

    ComplexFloatMatrix ret = new ComplexFloatMatrix(rows, cols);
    for (int i = 0; i < rows; i++) {
      ret.putRow(i, complexDiscreteFourierTransform1d(columnPass.getRow(i)));
    }
    return ret;
  }
  /**
   * Computes the 2-d convolution of {@code input} with {@code kernel} using discrete Fourier
   * transforms.
   *
   * <p>Both operands are transformed at the full output shape, which is
   * {@code (input.rows + kernel.rows - 1) x (input.columns + kernel.columns - 1)}. The
   * transforms are multiplied element-wise and transformed back, and the real part is the full
   * linear convolution.
   *
   * <p>When {@code type} is {@code Type.VALID}, only the region where the kernel lies entirely
   * inside the input is returned. That region has shape
   * {@code (input.rows - kernel.rows + 1) x (input.columns - kernel.columns + 1)}. For any
   * other {@code type}, the full convolution is returned.
   *
   * @param input the matrix to convolve
   * @param kernel the convolution kernel
   * @param type the output mode; {@code VALID} crops to the fully-overlapping region
   * @return the real part of the (possibly cropped) convolution
   * @throws IllegalArgumentException if {@code type} is {@code VALID} and the kernel is larger
   *     than the input in either dimension
   */
  public static FloatMatrix conv2d(FloatMatrix input, FloatMatrix kernel, Type type) {
    int fullRows = input.rows + kernel.rows - 1;
    int fullCols = input.columns + kernel.columns - 1;

    boolean valid = type == Type.VALID;
    int validRows = input.rows - kernel.rows + 1;
    int validCols = input.columns - kernel.columns + 1;
    // Fail fast, before the transforms: a kernel larger than the input has no valid region.
    if (valid && (validRows < 1 || validCols < 1)) {
      throw new IllegalArgumentException(
          "Kernel "
              + kernel.rows
              + "x"
              + kernel.columns
              + " is larger than input "
              + input.rows
              + "x"
              + input.columns
              + "; no valid convolution region");
    }

    // Padding both operands to the full shape makes the frequency-domain product (a circular
    // convolution) equal to the linear convolution.
    ComplexFloatMatrix fftInput = complexDisceteFourierTransform(input, fullRows, fullCols);
    ComplexFloatMatrix fftKernel = complexDisceteFourierTransform(kernel, fullRows, fullCols);
    ComplexFloatMatrix product = fftKernel.mul(fftInput);
    FloatMatrix ret = complexInverseDisceteFourierTransform(product).getReal();

    if (valid) {
      // (full - valid) / 2 == kernel - 1: the first full-output index at which the kernel is
      // entirely inside the input. Zero is legal here (a 1-row or 1-column kernel).
      int startRow = kernel.rows - 1;
      int startCol = kernel.columns - 1;
      ret =
          ret.get(
              RangeUtils.interval(startRow, startRow + validRows),
              RangeUtils.interval(startCol, startCol + validCols));
    }

    return ret;
  }