コード例 #1
0
  /**
   * Exercises complex {@code mmul} on vector operands.
   *
   * <ul>
   *   <li>A row vector times its own transpose must collapse to the scalar 385 (1^2 + ... + 10^2).
   *   <li>The reversed product must be a 10x10 outer product.
   *   <li>A 2x1 column [1, 2] times a 1x2 row [3, 4] must give [[3, 4], [6, 8]].
   * </ul>
   */
  @Test
  public void testMmul() {
    // Row vector built from linspace(1, 10, 10) and its transpose.
    IComplexNDArray row = Nd4j.createComplex(Nd4j.linspace(1, 10, 10));
    IComplexNDArray column = row.transpose();
    assertEquals(true, row.isRowVector());
    assertEquals(true, column.isColumnVector());

    // row * column is the inner product: 1^2 + 2^2 + ... + 10^2 = 385.
    INDArray dot = row.mmul(column);
    INDArray expectedDot = Nd4j.scalar(Nd4j.createComplexNumber(385, 0));
    assertEquals(getFailureMessage(), expectedDot, dot);

    // column * row is the outer product, so only its 10x10 shape is checked here.
    INDArray outer = column.mmul(row);
    int[] expectedOuterShape = {10, 10};
    assertEquals(true, Shape.shapeEquals(expectedOuterShape, outer.shape()));

    // [1, 2] reshaped to 2x1, multiplied by [3, 4].
    double[] leftValues = {1, 2};
    double[] rightValues = {3, 4};
    IComplexNDArray left =
        Nd4j.createComplex(ComplexUtil.complexNumbersFor(leftValues)).reshape(2, 1);
    IComplexNDArray right = Nd4j.createComplex(ComplexUtil.complexNumbersFor(rightValues));
    INDArray actual = left.mmul(right);

    IComplexNumber[][] expectedEntries = {
      {Nd4j.createComplexNumber(3, 0), Nd4j.createComplexNumber(4, 0)},
      {Nd4j.createComplexNumber(6, 0), Nd4j.createComplexNumber(8, 0)}
    };
    INDArray expected = Nd4j.createComplex(expectedEntries);
    assertEquals(getFailureMessage(), expected, actual);
  }
コード例 #2
0
  /**
   * Checks complex {@code mmul} of a 1x2 row by a 2x1 column vector.
   *
   * <p>The row is obtained two ways: it is built explicitly as [4, 19], and it is extracted via
   * {@code slice(0).getRow(1)} from a reshaped linspace. Both must yield the same product. That
   * product must also equal the expected scalar 122.
   */
  @Test
  public void testMmulColumnVector() {
    // Explicitly constructed 1x2 row vector [4, 19].
    IComplexNDArray three =
        Nd4j.createComplex(ComplexUtil.complexNumbersFor(new double[] {4, 19}), new int[] {1, 2});
    // Row 1 of slice 0 of a 3x5x2 array filled from complexLinSpace(1, 30, 30); the test
    // asserts it equals [4, 19].
    IComplexNDArray test = Nd4j.complexLinSpace(1, 30, 30).reshape(3, 5, 2);
    IComplexNDArray sliceRow = test.slice(0).getRow(1);
    assertEquals(getFailureMessage(), three, sliceRow);

    // 2x1 column vector [2, 6].
    IComplexNDArray twoSix =
        Nd4j.createComplex(ComplexUtil.complexNumbersFor(new double[] {2, 6}), new int[] {2, 1});
    IComplexNDArray threeTwoSix = three.mmul(twoSix);
    // Pin the absolute value (4 * 2 + 19 * 6 = 122) rather than only comparing the two paths,
    // so a defect that corrupts both products identically is still caught.
    assertEquals(
        getFailureMessage(), Nd4j.scalar(Nd4j.createComplexNumber(122, 0)), threeTwoSix);

    // Multiplying the extracted row must give the same result as the explicit row.
    IComplexNDArray sliceRowTwoSix = sliceRow.mmul(twoSix);
    assertEquals(getFailureMessage(), threeTwoSix, sliceRowTwoSix);
  }
コード例 #3
0
  /**
   * Multiplies two 2x2 complex matrices built from linspace ranges (1..4 and 5..8) and compares
   * the product against a hard-coded expected matrix.
   */
  @Test
  public void testTwoByTwoMmul() {
    // NOTE(review): the expected entries correspond to a column-major fill of the linspace values
    // (left = [[1, 3], [2, 4]], right = [[5, 7], [6, 8]]); confirm the default ordering is 'f'.
    IComplexNDArray left = Nd4j.createComplex(Nd4j.linspace(1, 4, 4).reshape(2, 2));
    IComplexNDArray right = Nd4j.createComplex(Nd4j.linspace(5, 8, 4).reshape(2, 2));

    double[][] expectedEntries = {{23, 31}, {34, 46}};
    IComplexNDArray expected = Nd4j.createComplex(Nd4j.create(expectedEntries));
    IComplexNDArray product = left.mmul(right);

    assertEquals(getFailureMessage(), expected, product);
  }