/**
   * Discrete Fourier transform in two dimensions.
   *
   * <p>The input is first fitted to {@code rows x cols}. If either dimension of {@code input} is
   * smaller than requested, the matrix is handed to {@code MatrixUtil.complexPadWithZeros}.
   * Otherwise, if either dimension is larger, it is cut down by selecting the row and column
   * indices produced by {@code RangeUtils.interval(0, rows)} and
   * {@code RangeUtils.interval(0, cols)}. When the dimensions already match, the input is only
   * converted to complex form.
   *
   * <p>The transform is computed separably: a 1-D DFT is applied to every column of the fitted
   * matrix, then a 1-D DFT is applied to every row of that intermediate result.
   *
   * <p>NOTE(review): if one dimension needs padding while the other needs truncation, only the
   * padding helper runs. Confirm that it also shrinks the oversized dimension.
   *
   * @param input the real-valued input to transform
   * @param rows the number of rows in the transformed output matrix
   * @param cols the number of columns in the transformed output matrix
   * @return the discrete Fourier transform of the input
   */
  public static ComplexFloatMatrix complexDisceteFourierTransform(
      FloatMatrix input, int rows, int cols) {
    boolean mustPad = input.rows < rows || input.columns < cols;
    boolean mustTruncate = input.rows > rows || input.columns > cols;

    ComplexFloatMatrix fitted;
    if (mustPad) {
      fitted = MatrixUtil.complexPadWithZeros(input, rows, cols);
    } else if (!mustTruncate) {
      fitted = new ComplexFloatMatrix(input);
    } else {
      ComplexFloatMatrix full = new ComplexFloatMatrix(input);
      fitted =
          full.get(
              MatrixUtil.toIndices(RangeUtils.interval(0, rows)),
              MatrixUtil.toIndices(RangeUtils.interval(0, cols)));
    }

    int height = fitted.rows;
    int width = fitted.columns;

    // First pass: transform each column independently.
    ComplexFloatMatrix columnPass = new ComplexFloatMatrix(height, width);
    for (int c = 0; c < width; c++) {
      columnPass.putColumn(c, complexDiscreteFourierTransform1d(fitted.getColumn(c)));
    }

    // Second pass: transform each row of the column-transformed matrix.
    ComplexFloatMatrix result = new ComplexFloatMatrix(height, width);
    for (int r = 0; r < height; r++) {
      result.putRow(r, complexDiscreteFourierTransform1d(columnPass.getRow(r)));
    }
    return result;
  }
  /**
   * Discrete Fourier transform in two dimensions.
   *
   * <p>The input is first fitted to {@code shape}:
   *
   * <ul>
   *   <li>If any dimension of {@code input} is smaller than requested, it is handed to
   *       {@code MatrixUtil.complexPadWithZeros}.
   *   <li>Otherwise, if any dimension is larger, the leading sub-matrix is copied element by
   *       element into a new array of the requested shape.
   *   <li>When the shapes already match, the input is only converted to complex form.
   * </ul>
   *
   * <p>The transform is computed separably: a 1-D DFT over every column, then a 1-D DFT over
   * every row of that intermediate result. Only the row/column structure of the arrays is used,
   * so this presumably expects a 2-D {@code shape}. TODO: confirm behaviour for higher ranks.
   *
   * <p>NOTE(review): if one dimension needs padding while another needs truncation, only the
   * padding helper runs. Confirm that it also shrinks the oversized dimension.
   *
   * @param input the input to transform
   * @param shape the shape of the output matrix
   * @return the discrete Fourier transform of the input
   */
  public static ComplexDoubleMatrix complexDisceteFourierTransform(NDArray input, int[] shape) {
    ComplexNDArray base;

    if (ArrayUtil.anyLess(input.shape(), shape)) {
      // At least one dimension must grow: delegate to the padding helper.
      base = MatrixUtil.complexPadWithZeros(input, shape);
    } else if (ArrayUtil.anyMore(input.shape(), shape)) {
      // No dimension must grow and at least one must shrink: keep the leading sub-matrix.
      // Elements are copied by (row, col) position. A flat copy of the first prod(shape)
      // linear elements picks the wrong values whenever an inner dimension is truncated
      // (e.g. 3x3 -> 2x2).
      base = new ComplexNDArray(shape);
      for (int col = 0; col < base.columns; col++) {
        for (int row = 0; row < base.rows; row++) {
          base.put(row, col, input.get(row, col));
        }
      }
    } else {
      base = new ComplexNDArray(input);
    }

    ComplexNDArray temp = new ComplexNDArray(shape);
    ComplexNDArray ret = new ComplexNDArray(shape);

    // First pass: 1-D DFT of every column.
    for (int i = 0; i < base.columns; i++) {
      ComplexDoubleMatrix column = base.getColumn(i);
      temp.putColumn(i, complexDiscreteFourierTransform1d(column));
    }

    // Second pass: 1-D DFT of every row of the column-transformed array.
    for (int i = 0; i < ret.rows; i++) {
      ComplexDoubleMatrix row = temp.getRow(i);
      ret.putRow(i, complexDiscreteFourierTransform1d(row));
    }
    return ret;
  }