@Test
public void testMmul() {
    // Build a 1x10 complex row vector holding 1..10 (imaginary parts zero) and its transpose.
    IComplexNDArray rowVec = Nd4j.createComplex(Nd4j.linspace(1, 10, 10));
    IComplexNDArray colVec = rowVec.transpose();
    assertEquals(true, rowVec.isRowVector());
    assertEquals(true, colVec.isColumnVector());

    // row · column collapses to a scalar: 1^2 + 2^2 + ... + 10^2 = 385.
    INDArray dot = rowVec.mmul(colVec);
    INDArray expectedDot = Nd4j.scalar(Nd4j.createComplexNumber(385, 0));
    assertEquals(getFailureMessage(), expectedDot, dot);

    // column · row expands to a full 10x10 matrix.
    INDArray outer = colVec.mmul(rowVec);
    int[] expectedOuterShape = {10, 10};
    assertEquals(true, Shape.shapeEquals(expectedOuterShape, outer.shape()));

    // (2x1) · (1x2) product: [1, 2]^T x [3, 4] = [[3, 4], [6, 8]].
    IComplexNDArray leftCol =
            Nd4j.createComplex(ComplexUtil.complexNumbersFor(new double[] {1, 2})).reshape(2, 1);
    IComplexNDArray rightRow =
            Nd4j.createComplex(ComplexUtil.complexNumbersFor(new double[] {3, 4}));
    INDArray product = leftCol.mmul(rightRow);

    IComplexNumber[][] expectedEntries = {
        {Nd4j.createComplexNumber(3, 0), Nd4j.createComplexNumber(4, 0)},
        {Nd4j.createComplexNumber(6, 0), Nd4j.createComplexNumber(8, 0)}
    };
    INDArray expectedProduct = Nd4j.createComplex(expectedEntries);
    assertEquals(getFailureMessage(), expectedProduct, product);
}
@Test
public void testVectorInit() {
    // One shared 4-element buffer (values 1..4) is viewed under three different shapes.
    DataBuffer buffer = Nd4j.linspace(1, 4, 4).data();

    // A rank-1 shape {4} is treated as a row vector.
    IComplexNDArray flat = Nd4j.createComplex(buffer, new int[] {4});
    assertEquals(true, flat.isRowVector());

    // An explicit 1x4 shape is also a row vector.
    IComplexNDArray row = Nd4j.createComplex(buffer, new int[] {1, 4});
    assertEquals(true, row.isRowVector());

    // A 4x1 shape yields a column vector.
    IComplexNDArray column = Nd4j.createComplex(buffer, new int[] {4, 1});
    assertEquals(true, column.isColumnVector());
}