@Test
  public void testPreOutputMethodContained() {
    // Fixtures come from helpers defined elsewhere in this class.
    Layer containedLayer = getContainedConfig();
    INDArray containedCol = getContainedCol();

    // Expected pre-output with shape [1, 2, 4, 4]: the run 4,4,4,4,8,8,8,8 repeated four times.
    double[] expectedValues = {
      4., 4., 4., 4., 8., 8., 8., 8.,
      4., 4., 4., 4., 8., 8., 8., 8.,
      4., 4., 4., 4., 8., 8., 8., 8.,
      4., 4., 4., 4., 8., 8., 8., 8.
    };
    INDArray expected = Nd4j.create(expectedValues, new int[] {1, 2, 4, 4});

    // Feed the prepared column buffer straight into the layer implementation and run preOutput.
    org.deeplearning4j.nn.layers.convolution.ConvolutionLayer convLayer =
        (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) containedLayer;
    convLayer.setCol(containedCol);
    INDArray actual = convLayer.preOutput(true);

    // Check the shape first so a mismatch fails with a clearer message than the value comparison.
    assertArrayEquals(expected.shape(), actual.shape());
    assertEquals(expected, actual);
  }
  /**
   * Manual sanity check of a single convolution layer on a fixed 3x3 input holding 1..9.
   *
   * <p>Builds a one-layer network from {@code TestCase(1, 1, 2, 2, 1, 1, 0, 0, 3, 3)}, then prints
   * the layer's parameters and its {@code preOutput(false)} result to stdout. Nothing is asserted;
   * the output is meant for inspection. NOTE(review): the TestCase argument order is presumably
   * (nIn, nOut, kW, kH, dW, dH, padW, padH, inputWidth, inputHeight) — confirm against TestCase.
   */
  public static void testAccuracy() {
    // Row-major values for an array of shape {1, 1, 3, 3}.
    double[] pixels = {
      1.0, 2.0, 3.0,
      4.0, 5.0, 6.0,
      7.0, 8.0, 9.0
    };
    INDArray image = Nd4j.create(pixels, new int[] {1, 1, 3, 3}, 'c');

    TestCase scenario = new TestCase(1, 1, 2, 2, 1, 1, 0, 0, 3, 3);
    ConvolutionLayer layerConf =
        new ConvolutionLayer.Builder(scenario.kW, scenario.kH)
            .nIn(scenario.nInputPlane)
            .stride(scenario.dW, scenario.dH)
            .padding(scenario.padW, scenario.padH)
            .nOut(scenario.nOutputPlane)
            .build();

    MultiLayerConfiguration networkConf =
        new NeuralNetConfiguration.Builder().list().layer(0, layerConf).build();
    MultiLayerNetwork network = new MultiLayerNetwork(networkConf);
    network.init();
    network.setInput(image);
    network.getLayer(0).setInput(image);

    org.deeplearning4j.nn.layers.convolution.ConvolutionLayer convLayer =
        (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) network.getLayer(0);
    System.out.println(convLayer.params());
    System.out.println(convLayer.preOutput(false));
  }
  /**
   * Benchmarks the forward pass ({@code preOutput(false)}) of a single-layer convolution network
   * for every entry in {@code allTestCases}. One result line per case is appended to
   * {@code dl4jPerformance.csv} in the working directory.
   *
   * <p>Each line has the form {@code Convolution(nIn nOut kW kH dW dH padW padH inputWidth
   * inputHeight)  forward, <mean milliseconds per call>}. A failure in one case is reported on
   * stderr, and the remaining cases still run.
   *
   * <p>NOTE(review): geometry is passed in (width, height) order everywhere. That covers the
   * kernel, stride and padding, and the input shape {@code [batchSize, nInputPlane, inputWidth,
   * inputHeight]}. If DL4J reads these as (height, width), the layer computes the transposed
   * problem, which should cost the same. Confirm this before comparing non-square cases with other
   * frameworks.
   */
  public static void testForward() {
    // Open the results file once for the whole run rather than once per test case.
    try (BufferedWriter writer =
        new BufferedWriter(new FileWriter(new File("dl4jPerformance.csv"), true))) {
      for (TestCase testCase : allTestCases) {
        String label =
            "Convolution("
                + testCase.nInputPlane
                + " "
                + testCase.nOutputPlane
                + " "
                + testCase.kW
                + " "
                + testCase.kH
                + " "
                + testCase.dW
                + " "
                + testCase.dH
                + " "
                + testCase.padW
                + " "
                + testCase.padH
                + " "
                + testCase.inputWidth
                + " "
                + testCase.inputHeight
                + ") ";
        try {
          ConvolutionLayer convolutionLayerBuilder =
              new ConvolutionLayer.Builder(testCase.kW, testCase.kH)
                  .nIn(testCase.nInputPlane)
                  .stride(testCase.dW, testCase.dH)
                  .padding(testCase.padW, testCase.padH)
                  .nOut(testCase.nOutputPlane)
                  .build();

          MultiLayerConfiguration.Builder builder =
              new NeuralNetConfiguration.Builder().list().layer(0, convolutionLayerBuilder);

          MultiLayerConfiguration conf = builder.build();
          MultiLayerNetwork model = new MultiLayerNetwork(conf);
          model.init();
          INDArray input =
              Nd4j.rand(
                  seed, batchSize, testCase.nInputPlane, testCase.inputWidth, testCase.inputHeight);
          model.setInput(input);
          model.getLayer(0).setInput(input);
          org.deeplearning4j.nn.layers.convolution.ConvolutionLayer convolutionLayer =
              (org.deeplearning4j.nn.layers.convolution.ConvolutionLayer) model.getLayer(0);

          // Untimed warm-up call. Any one-time work done on first invocation (class loading,
          // lazy allocations, presumably) is then not charged to the timed loop.
          convolutionLayer.preOutput(false);

          // nanoTime() returns a long; keep the arithmetic integral so the elapsed time is exact.
          long start = System.nanoTime();
          for (int i = 0; i < forwardIterations; i++) {
            convolutionLayer.preOutput(false);
          }
          long elapsedNanos = System.nanoTime() - start;
          double timeMillis = elapsedNanos / 1e6 / forwardIterations;

          writer.write(label + " forward, " + timeMillis + "\n");
          // Flush per case so finished results survive a crash in a later, larger case.
          writer.flush();
        } catch (Exception ex) {
          // Best effort: report which case failed and continue with the next one.
          System.err.println("Forward benchmark failed for " + label);
          ex.printStackTrace();
        }
      }
    } catch (java.io.IOException ex) {
      System.err.println("Could not open or close dl4jPerformance.csv");
      ex.printStackTrace();
    }
  }