Пример #1
0
  /**
   * Discrete fourier transform 2d.
   *
   * <p>The input is first brought to {@code rows x cols}: each dimension that is too large is
   * truncated (leading rows/columns are kept) and each dimension that is too small is padded via
   * {@code MatrixUtil.complexPadWithZeros}. The 1d transform is then applied to every column of
   * the resized matrix, and afterwards to every row of that intermediate result.
   *
   * @param input the input to transform; it is not modified by this method's own code
   * @param rows the number of rows in the transformed output matrix
   * @param cols the number of columns in the transformed output matrix
   * @return the discrete fourier transform of the input, as a newly allocated matrix
   */
  public static ComplexDoubleMatrix complexDisceteFourierTransform(
      ComplexDoubleMatrix input, int rows, int cols) {
    // The code below only reads from base (getColumn) and writes into freshly allocated
    // matrices, so no defensive copy of the input is taken.
    ComplexDoubleMatrix base = input;

    // truncation: handled per dimension and BEFORE padding, so that a matrix which is too large
    // in one dimension and too small in the other is never handed to the padding helper while
    // still exceeding the target shape.
    if (base.rows > rows || base.columns > cols) {
      base =
          base.get(
              MatrixUtil.toIndices(RangeUtils.interval(0, Math.min(rows, base.rows))),
              MatrixUtil.toIndices(RangeUtils.interval(0, Math.min(cols, base.columns))));
    }

    // pad
    if (base.rows < rows || base.columns < cols) {
      base = MatrixUtil.complexPadWithZeros(base, rows, cols);
    }

    ComplexDoubleMatrix temp = new ComplexDoubleMatrix(base.rows, base.columns);
    ComplexDoubleMatrix ret = new ComplexDoubleMatrix(base.rows, base.columns);

    // first pass: transform every column
    for (int i = 0; i < base.columns; i++) {
      ComplexDoubleMatrix column = base.getColumn(i);
      temp.putColumn(i, complexDiscreteFourierTransform1d(column));
    }

    // second pass: transform every row of the column-transformed matrix
    for (int i = 0; i < ret.rows; i++) {
      ComplexDoubleMatrix row = temp.getRow(i);
      ret.putRow(i, complexDiscreteFourierTransform1d(row));
    }
    return ret;
  }
Пример #2
0
  /**
   * Computes the 1d discrete fourier transform of a vector.
   *
   * <p>A square matrix of complex exponentials is built from the vector length and multiplied
   * with the input; the product is wrapped back into a {@code ComplexNDArray} via
   * {@code ComplexNDArray.wrap}.
   *
   * @param inputC the input to transform
   * @return the discrete fourier transform of the passed in input
   * @throws IllegalArgumentException if the shape of {@code inputC} does not have exactly one
   *     dimension
   */
  public static ComplexNDArray complexDiscreteFourierTransform1d(ComplexNDArray inputC) {
    final boolean isVector = inputC.shape().length == 1;
    if (!isVector) {
      throw new IllegalArgumentException("Illegal input: Must be a vector");
    }

    // Vector length: whichever of the two dimensions is the larger one.
    final double n = Math.max(inputC.rows, inputC.columns);

    // The constant (0, -2) * pi / n; presumably -2*pi*i/n, assuming the constructor takes
    // (real, imaginary) -- confirm against ComplexDouble.
    final ComplexDouble factor = new ComplexDouble(0, -2).muli(FastMath.PI).divi(n);

    final ComplexDoubleMatrix indices = MatrixUtil.complexRangeVector(0, n);
    final ComplexDoubleMatrix scaledTransposed = indices.transpose().mul(factor);
    final ComplexDoubleMatrix exponents = indices.mmul(scaledTransposed);
    final ComplexDoubleMatrix dftMatrix = exp(exponents);

    final ComplexDoubleMatrix transformed = dftMatrix.mmul(inputC);
    return ComplexNDArray.wrap(inputC, transformed);
  }
Пример #3
0
  /**
   * 2d convolution of {@code input} with {@code kernel}, computed by multiplying their 2d
   * discrete fourier transforms element-wise and transforming the product back.
   *
   * <p>Both operands are transformed at the full output shape {@code (input.rows + kernel.rows -
   * 1) x (input.columns + kernel.columns - 1)}. For {@code Type.VALID} the result is cropped to
   * the window of shape {@code (input.rows - kernel.rows + 1) x (input.columns - kernel.columns +
   * 1)} centered in the full result; for any other type the full result is returned.
   *
   * @param input the matrix to convolve
   * @param kernel the convolution kernel
   * @param type the kind of convolution; only {@code Type.VALID} triggers cropping
   * @return the real part of the inverse-transformed product, cropped for {@code Type.VALID}
   * @throws IllegalStateException if the crop window computed for {@code Type.VALID} is invalid
   *     (for example because the kernel is larger than the input)
   */
  public static DoubleMatrix conv2d(DoubleMatrix input, DoubleMatrix kernel, Type type) {

    DoubleMatrix xShape = new DoubleMatrix(1, 2);
    xShape.put(0, input.rows);
    xShape.put(1, input.columns);

    DoubleMatrix yShape = new DoubleMatrix(1, 2);
    yShape.put(0, kernel.rows);
    yShape.put(1, kernel.columns);

    // Use the copying add/sub/div here and below. The in-place variants (addi/subi/divi) return
    // their receiver, which would make every shape vector an alias of xShape and collapse the
    // VALID window computation to zero.
    DoubleMatrix zShape = xShape.add(yShape).sub(1);
    int retRows = (int) zShape.get(0);
    int retCols = (int) zShape.get(1);

    ComplexDoubleMatrix fftInput = complexDisceteFourierTransform(input, retRows, retCols);
    ComplexDoubleMatrix fftKernel = complexDisceteFourierTransform(kernel, retRows, retCols);
    // In-place is fine here: fftKernel is a local that is not used again.
    ComplexDoubleMatrix mul = fftKernel.muli(fftInput);
    ComplexDoubleMatrix retComplex = complexInverseDisceteFourierTransform(mul);
    DoubleMatrix ret = retComplex.getReal();

    if (type == Type.VALID) {

      DoubleMatrix validShape = xShape.sub(yShape).add(1);

      // start = (full - valid) / 2 = kernel size - 1, end = start + valid = input size.
      DoubleMatrix start = zShape.sub(validShape).div(2);
      DoubleMatrix end = start.add(validShape);

      // A start of 0 is legitimate (kernel with a single row or column); only negative is illegal.
      if (start.get(0) < 0 || start.get(1) < 0)
        throw new IllegalStateException("Illegal row index " + start);
      // The window must contain at least one row and one column; this also rejects kernels that
      // are larger than the input.
      if (end.get(0) <= start.get(0) || end.get(1) <= start.get(1))
        throw new IllegalStateException("Illegal column index " + end);

      ret =
          ret.get(
              RangeUtils.interval((int) start.get(0), (int) end.get(0)),
              RangeUtils.interval((int) start.get(1), (int) end.get(1)));
    }

    return ret;
  }
Пример #4
0
  /**
   * 1d discrete fourier transform, note that this will throw an exception if the passed in input
   * isn't a vector. See matlab's fft for more information.
   *
   * <p>A square matrix {@code exp(c * j * k)} with {@code c = (0, -2) * pi / len} is built from
   * the vector length and multiplied with the input. The result has the same orientation (row or
   * column) as the input.
   *
   * @param inputC the input to transform; must have exactly one row or exactly one column
   * @return the discrete fourier transform of the passed in input
   * @throws IllegalArgumentException if {@code inputC} has neither a single row nor a single
   *     column
   */
  public static ComplexDoubleMatrix complexDiscreteFourierTransform1d(ComplexDoubleMatrix inputC) {
    if (inputC.rows != 1 && inputC.columns != 1)
      throw new IllegalArgumentException("Illegal input: Must be a vector");

    double len = Math.max(inputC.rows, inputC.columns);
    ComplexDouble c2 = new ComplexDouble(0, -2).muli(FastMath.PI).divi(len);
    // NOTE(review): assumes complexRangeVector returns a len x 1 column so that the product
    // below is the len x len outer product -- confirm against MatrixUtil.
    ComplexDoubleMatrix range = MatrixUtil.complexRangeVector(0, len);
    ComplexDoubleMatrix matrix = exp(range.mmul(range.transpose().mul(c2)));

    // Pick the multiplication order from the actual dimensions instead of isRowVector(), so the
    // inner dimensions always agree: (len x len) * (len x 1) for a column input, and
    // (1 x len) * (len x len) for a row input.
    if (inputC.columns == 1) {
      return matrix.mmul(inputC);
    }
    return inputC.mmul(matrix);
  }