@Test public void testSubSampleLayerNoneBackprop() throws Exception { Layer layer = getCNNConfig(nChannelsIn, depth, kernelSize, stride, padding); Pair<Gradient, INDArray> out = layer.backpropGradient(epsilon); assertEquals(epsilon.shape().length, out.getSecond().shape().length); assertEquals(nExamples, out.getSecond().size(1)); // depth retained }
/**
 * Benchmarks the backward pass of a single convolution layer for every entry in
 * {@code allTestCases}, appending one line per case to {@code dl4jPerformance.csv}.
 *
 * <p>Each line has the form
 * {@code Convolution(nIn nOut kW kH dW dH padW padH inputWidth inputHeight)  backward, <ms>},
 * where {@code <ms>} is the mean wall-clock time of one {@code backpropGradient} call.
 * A failing case is reported on stderr and skipped; the remaining cases still run.
 */
public static void testBackward() {
    for (TestCase testCase : allTestCases) {
        try {
            double timeMillis = measureBackwardMillis(testCase);
            // Open the file only to append the finished result, so it is not held open
            // during measurement; each case's line is flushed as soon as it is known.
            try (BufferedWriter writer =
                    new BufferedWriter(new FileWriter(new File("dl4jPerformance.csv"), true))) {
                writer.write(describeConvolution(testCase) + " backward, " + timeMillis + "\n");
            }
        } catch (Exception ex) {
            // Best effort: report which case failed, then continue with the next one.
            System.err.println("Backward benchmark failed for " + describeConvolution(testCase));
            ex.printStackTrace();
        }
    }
}

/**
 * Builds a one-layer convolution network for {@code testCase} and returns the mean time, in
 * milliseconds, of {@code backwardIterations} calls to {@code backpropGradient} on that layer.
 */
private static double measureBackwardMillis(TestCase testCase) throws Exception {
    ConvolutionLayer convolutionLayerBuilder = new ConvolutionLayer.Builder(testCase.kW, testCase.kH)
            .nIn(testCase.nInputPlane)
            .stride(testCase.dW, testCase.dH)
            .padding(testCase.padW, testCase.padH)
            .nOut(testCase.nOutputPlane)
            .build();
    MultiLayerConfiguration.Builder builder =
            new NeuralNetConfiguration.Builder().list().layer(0, convolutionLayerBuilder);
    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    INDArray input = Nd4j.rand(
            seed, batchSize, testCase.nInputPlane, testCase.inputWidth, testCase.inputHeight);
    model.setInput(input);
    model.getLayer(0).setInput(input);
    model.feedForward();

    org.deeplearning4j.nn.api.Layer convolutionLayer = model.getLayer(0);
    INDArray output = convolutionLayer.activate();
    // Random epsilon with the same shape as the forward output of the layer.
    INDArray epsilon = Nd4j.rand(seed, output.size(0), output.size(1), output.size(2), output.size(3));

    // initGradientsView is not public, so reflection is used to call it before
    // backpropGradient is timed.
    Method initGradientView = model.getClass().getDeclaredMethod("initGradientsView");
    initGradientView.setAccessible(true);
    initGradientView.invoke(model);

    // Keep nanoTime readings as long: storing them in double loses precision once the
    // values exceed 2^53, and the javadoc prescribes long subtraction for elapsed time.
    long start = System.nanoTime();
    for (int i = 0; i < backwardIterations; i++) {
        convolutionLayer.backpropGradient(epsilon);
    }
    long elapsedNanos = System.nanoTime() - start;
    return elapsedNanos / 1e6 / backwardIterations;
}

/**
 * Formats the test-case label used in the CSV, e.g. {@code "Convolution(3 16 3 3 1 1 0 0 32 32) "}.
 * The trailing space is intentional and matches the original output format.
 */
private static String describeConvolution(TestCase testCase) {
    return "Convolution(" + testCase.nInputPlane + " " + testCase.nOutputPlane + " "
            + testCase.kW + " " + testCase.kH + " " + testCase.dW + " " + testCase.dH + " "
            + testCase.padW + " " + testCase.padH + " "
            + testCase.inputWidth + " " + testCase.inputHeight + ") ";
}