@Test
  public void testLeadingOnes() {
    // Round-trip every element of a rank-4 random complex array through its linear view.
    IComplexNDArray complexRand = Nd4j.complexRand(100, 1, 28, 28);
    assertArrayEquals(new int[] {100, 1, 28, 28}, complexRand.shape());
    IComplexNDArray arr = complexRand.linearView();
    for (int i = 0; i < arr.length(); i++) {
      arr.putScalar(i, arr.getComplex(i));
    }

    // Same round-trip for a rank-3 array with a trailing unit dimension. The linear view must
    // come from complexRand2 itself; otherwise this second shape is never exercised.
    IComplexNDArray complexRand2 = Nd4j.complexRand(28, 28, 1);
    assertArrayEquals(new int[] {28, 28, 1}, complexRand2.shape());
    IComplexNDArray arr2 = complexRand2.linearView();
    for (int i = 0; i < arr2.length(); i++) {
      arr2.putScalar(i, arr2.getComplex(i));
    }
  }
  /**
   * Fills the elements of a 1x8 complex array with scalars taken from a linspace(1..8) buffer,
   * then checks that getReal() equals a real array built directly from that buffer. Also checks
   * that the real part of complexOnes(10) equals ones(10).
   */
  @Test
  public void testGetReal() {
    DataBuffer linspaceData = Nd4j.linspace(1, 8, 8).data();
    int[] rowShape = new int[] {1, 8};
    IComplexNDArray complexRow = Nd4j.createComplex(rowShape);
    for (int idx = 0; idx < complexRow.length(); idx++) {
      complexRow.put(idx, Nd4j.scalar(linspaceData.getFloat(idx)));
    }
    INDArray expectedReal = Nd4j.create(linspaceData, rowShape);
    assertEquals(expectedReal, complexRow.getReal());

    INDArray realOnes = Nd4j.ones(10);
    IComplexNDArray complexOnes = Nd4j.complexOnes(10);
    assertEquals(getFailureMessage(), realOnes, complexOnes.getReal());
  }
 /**
  * Checks that createComplex(INDArray) carries each real value into the real component of the
  * matching complex element. The comparison is done vector by vector along dimension 0, both
  * element-wise (within 1e-1) and via getReal().
  */
 @Test
 public void testCreateComplexFromReal() {
   INDArray real = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8}, new int[] {2, 4});
   IComplexNDArray complex = Nd4j.createComplex(real);
   for (int v = 0; v < real.vectorsAlongDimension(0); v++) {
     INDArray realVec = real.vectorAlongDimension(v, 0);
     IComplexNDArray complexVec = complex.vectorAlongDimension(v, 0);
     assertEquals(realVec.length(), complexVec.length());
     for (int k = 0; k < realVec.length(); k++) {
       IComplexNumber element = complexVec.getComplex(k);
       double expected = realVec.getFloat(k);
       double actualReal = element.realComponent().doubleValue();
       assertEquals(expected, actualReal, 1e-1);
     }
     assertEquals(realVec, complexVec.getReal());
   }
 }
  /**
   * Writes real parts into a 4x2x2 complex array through its vectors along dimension 0, using
   * the matching vectors of a reshaped linspace(1..16) real array. It then re-fetches each
   * complex vector and checks that its real part equals the source vector.
   */
  @Test
  public void testPutComplex() {
    INDArray source = Nd4j.linspace(1, 16, 16).reshape(4, 2, 2);
    IComplexNDArray target = Nd4j.createComplex(4, 2, 2);

    // Write pass: putReal on each vector returned by vectorAlongDimension.
    for (int v = 0; v < target.vectorsAlongDimension(0); v++) {
      INDArray srcVec = source.vectorAlongDimension(v, 0);
      IComplexNDArray dstVec = target.vectorAlongDimension(v, 0);
      for (int k = 0; k < dstVec.length(); k++) {
        dstVec.putReal(k, srcVec.getFloat(k));
      }
    }

    // Verify pass: freshly obtained vectors must expose the written real parts.
    for (int v = 0; v < target.vectorsAlongDimension(0); v++) {
      assertEquals(source.vectorAlongDimension(v, 0), target.vectorAlongDimension(v, 0).real());
    }
  }
  /**
   * Checks element-wise putScalar/getComplex on a 2x2 complex matrix filled row by row with
   * 1..4. It then checks that a complex array created from a real 2x2 matrix reports length 4
   * and a data buffer of length 8. Finally, it checks that put(1, 1, scalar(5.0)) yields
   * real part 5.0 and imaginary part 0.0 at that position.
   */
  @Test
  public void testPutAndGet() {
    IComplexNDArray grid = Nd4j.createComplex(2, 2);
    for (int r = 0; r < 2; r++) {
      for (int c = 0; c < 2; c++) {
        grid.putScalar(r, c, Nd4j.createComplexNumber(2 * r + c + 1, 0));
      }
    }
    for (int r = 0; r < 2; r++) {
      for (int c = 0; c < 2; c++) {
        assertEquals(Nd4j.createComplexNumber(2 * r + c + 1, 0), grid.getComplex(r, c));
      }
    }

    IComplexNDArray fromReal =
        Nd4j.createComplex(Nd4j.create(new double[] {1, 2, 3, 4}, new int[] {2, 2}));
    assertEquals(4, fromReal.length());
    assertEquals(8, fromReal.data().length());
    fromReal.put(1, 1, Nd4j.scalar(5.0));

    IComplexNumber element = fromReal.getComplex(1, 1);
    assertEquals(getFailureMessage(), 5.0, element.realComponent().doubleValue(), 1e-1);
    assertEquals(getFailureMessage(), 0.0, element.imaginaryComponent().doubleValue(), 1e-1);
  }
  /**
   * Checks getComplex on a 1x8 complex vector built from linspace(1..8). It also checks row and
   * column extraction on a 2x4 complex matrix built from the same data.
   *
   * <p>The expected values (row 1 = 2, 4, 6, 8; column 1 element 1 = 4) imply the 2x4 matrix
   * is filled column-major. NOTE(review): this relies on the default ordering used by
   * Nd4j.create — confirm if that default changes.
   */
  @Test
  public void testVectorGet() {
    IComplexNDArray arr =
        Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {1, 8}));
    for (int i = 0; i < arr.length(); i++) {
      IComplexNumber curr = arr.getComplex(i);
      assertEquals(Nd4j.createDouble(i + 1, 0), curr);
    }

    IComplexNDArray matrix =
        Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 4}));
    IComplexNDArray row = matrix.getRow(1);
    IComplexNDArray column = matrix.getColumn(1);

    IComplexNDArray expectedRow =
        Nd4j.createComplex(Nd4j.create(new double[] {2, 4, 6, 8}, new int[] {1, 4}));
    IComplexNumber d = row.getComplex(3);
    assertEquals(Nd4j.createDouble(8, 0), d);
    // JUnit's contract is assertEquals(expected, actual); keep that order so failure messages
    // report the right values.
    assertEquals(expectedRow, row);

    IComplexNumber d2 = column.getComplex(1);

    assertEquals(Nd4j.createDouble(4, 0), d2);
  }