/**
 * Exercises {@code mmul} on complex arrays in three configurations:
 * <ol>
 *   <li>row vector times its transpose, expected to equal the complex scalar
 *       {@code 385 + 0i} (the sum of the squares of 1 through 10);</li>
 *   <li>the transpose times the row vector, expected to have shape {@code [10, 10]};</li>
 *   <li>a {@code 2 x 1} array holding 1, 2 times an array holding 3, 4, expected to
 *       equal {@code {{3, 4}, {6, 8}}} with zero imaginary parts.</li>
 * </ol>
 */
@Test
public void testMmul() {
    IComplexNDArray rowVector = Nd4j.createComplex(Nd4j.linspace(1, 10, 10));
    IComplexNDArray columnVector = rowVector.transpose();

    // Sanity-check the orientation of both operands before multiplying.
    boolean isRow = rowVector.isRowVector();
    assertEquals(true, isRow);
    boolean isColumn = columnVector.isColumnVector();
    assertEquals(true, isColumn);

    // Inner product: row times column collapses to a single complex value.
    INDArray inner = rowVector.mmul(columnVector);
    INDArray expectedInner = Nd4j.scalar(Nd4j.createComplexNumber(385, 0));
    assertEquals(getFailureMessage(), expectedInner, inner);

    // Outer product: column times row; only the resulting shape is checked.
    INDArray outer = columnVector.mmul(rowVector);
    int[] expectedOuterShape = {10, 10};
    boolean outerShapeMatches = Shape.shapeEquals(expectedOuterShape, outer.shape());
    assertEquals(true, outerShapeMatches);

    // Small hand-checkable case: [1, 2] reshaped to 2 x 1, multiplied by [3, 4].
    double[] leftValues = {1, 2};
    double[] rightValues = {3, 4};
    IComplexNDArray left =
            Nd4j.createComplex(ComplexUtil.complexNumbersFor(leftValues)).reshape(2, 1);
    IComplexNDArray right = Nd4j.createComplex(ComplexUtil.complexNumbersFor(rightValues));
    INDArray actualProduct = left.mmul(right);

    IComplexNumber[][] expectedEntries = {
        {Nd4j.createComplexNumber(3, 0), Nd4j.createComplexNumber(4, 0)},
        {Nd4j.createComplexNumber(6, 0), Nd4j.createComplexNumber(8, 0)}
    };
    INDArray expectedProduct = Nd4j.createComplex(expectedEntries);
    assertEquals(getFailureMessage(), expectedProduct, actualProduct);
}
/**
 * Checks that a row taken from a slice of a larger complex array multiplies the same
 * way as an equal array that was created directly.
 *
 * <p>The test first asserts that {@code slice(0).getRow(1)} of a complex linspace
 * (1 through 30, reshaped to {@code 3 x 5 x 2}) equals a directly created
 * {@code 1 x 2} array holding 4 and 19, then multiplies both by the same
 * {@code 2 x 1} array holding 2 and 6 and compares the two products.
 */
@Test
public void testMmulColumnVector() {
    // Reference operand, created directly with an explicit 1 x 2 shape.
    double[] directValues = {4, 19};
    int[] rowShape = {1, 2};
    IComplexNDArray direct =
            Nd4j.createComplex(ComplexUtil.complexNumbersFor(directValues), rowShape);

    // The same values, obtained as a row of the first slice of a 3-d array.
    IComplexNDArray cube = Nd4j.complexLinSpace(1, 30, 30).reshape(3, 5, 2);
    IComplexNDArray fromSlice = cube.slice(0).getRow(1);
    assertEquals(direct, fromSlice);

    // Shared right-hand operand with an explicit 2 x 1 shape.
    double[] multiplierValues = {2, 6};
    int[] columnShape = {2, 1};
    IComplexNDArray multiplier =
            Nd4j.createComplex(ComplexUtil.complexNumbersFor(multiplierValues), columnShape);

    IComplexNDArray directProduct = direct.mmul(multiplier);
    IComplexNDArray sliceProduct = fromSlice.mmul(multiplier);
    assertEquals(getFailureMessage(), directProduct, sliceProduct);
}
/**
 * Multiplies two {@code 2 x 2} complex matrices, built from the real linspaces 1..4
 * and 5..8, and compares the product with {@code {{23, 31}, {34, 46}}}.
 *
 * <p>NOTE(review): the expected values match a column-major fill of the reshaped
 * linspaces ({@code [[1, 3], [2, 4]]} times {@code [[5, 7], [6, 8]]}) - confirm this
 * is the ordering the test is meant to run under.
 */
@Test
public void testTwoByTwoMmul() {
    IComplexNDArray left = Nd4j.createComplex(Nd4j.linspace(1, 4, 4).reshape(2, 2));
    IComplexNDArray right = Nd4j.createComplex(Nd4j.linspace(5, 8, 4).reshape(2, 2));

    double[][] expectedValues = {
        {23, 31},
        {34, 46}
    };
    IComplexNDArray expected = Nd4j.createComplex(Nd4j.create(expectedValues));

    IComplexNDArray actual = left.mmul(right);
    assertEquals(getFailureMessage(), expected, actual);
}