/**
 * Runs the forward convolution through cuDNN and adds the bias to the result.
 *
 * <p>The batch size, height and width are read from dimensions 0, 2 and 3 of {@code input};
 * the output depth, input depth and kernel height/width are read from dimensions 0..3 of
 * {@code weights}. The output shape is whatever cuDNN reports for the configured descriptors.
 *
 * @param input   4d activations; its strides are passed to cuDNN as-is
 * @param weights 4d filter array indexed as [outDepth, inDepth, kH, kW]
 * @param bias    bias array, described to cuDNN as a [1, outDepth, 1, 1] tensor
 * @param kernel  not read by this method (kernel size is taken from {@code weights})
 * @param strides convolution strides; elements 0 and 1 are passed to the convolution descriptor
 * @param pad     padding; elements 0 and 1 are passed to the convolution descriptor
 * @return a newly created 'c'-ordered array holding the convolution output plus bias
 */
@Override
public INDArray preOutput(
        INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad) {
    int batchSize = input.size(0);
    int inputHeight = input.size(2);
    int inputWidth = input.size(3);

    int filterCount = weights.size(0);
    int channelCount = weights.size(1);
    int filterHeight = weights.size(2);
    int filterWidth = weights.size(3);

    // Describe the input, the filter bank and the convolution itself.
    int[] inputStrides = input.stride();
    checkCudnn(
            cudnnSetTensor4dDescriptorEx(
                    cudnnContext.srcTensorDesc,
                    dataType,
                    batchSize,
                    channelCount,
                    inputHeight,
                    inputWidth,
                    inputStrides[0],
                    inputStrides[1],
                    inputStrides[2],
                    inputStrides[3]));
    checkCudnn(
            cudnnSetFilter4dDescriptor(
                    cudnnContext.filterDesc,
                    dataType,
                    tensorFormat,
                    filterCount,
                    channelCount,
                    filterHeight,
                    filterWidth));
    checkCudnn(
            cudnnSetConvolution2dDescriptor(
                    cudnnContext.convDesc,
                    pad[0],
                    pad[1],
                    strides[0],
                    strides[1],
                    1,
                    1,
                    CUDNN_CROSS_CORRELATION));

    // Ask cuDNN for the output dimensions rather than computing them by hand.
    int[] outN = new int[1];
    int[] outC = new int[1];
    int[] outH = new int[1];
    int[] outW = new int[1];
    checkCudnn(
            cudnnGetConvolution2dForwardOutputDim(
                    cudnnContext.convDesc,
                    cudnnContext.srcTensorDesc,
                    cudnnContext.filterDesc,
                    outN,
                    outC,
                    outH,
                    outW));

    INDArray z = Nd4j.createUninitialized(new int[] {outN[0], outC[0], outH[0], outW[0]}, 'c');
    int[] outputStrides = z.stride();
    checkCudnn(
            cudnnSetTensor4dDescriptorEx(
                    cudnnContext.dstTensorDesc,
                    dataType,
                    outN[0],
                    outC[0],
                    outH[0],
                    outW[0],
                    outputStrides[0],
                    outputStrides[1],
                    outputStrides[2],
                    outputStrides[3]));

    // Let cuDNN choose the forward algorithm for these descriptors.
    int[] forwardAlgo = new int[1];
    checkCudnn(
            cudnnGetConvolutionForwardAlgorithm(
                    cudnnContext,
                    cudnnContext.srcTensorDesc,
                    cudnnContext.filterDesc,
                    cudnnContext.convDesc,
                    cudnnContext.dstTensorDesc,
                    CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
                    0,
                    forwardAlgo));

    // Register the arrays with the allocator and obtain their pointers for this context.
    Allocator allocator = AtomicAllocator.getInstance();
    CudaContext context = allocator.getFlowController().prepareAction(input, weights, bias, z);
    Pointer inputPtr = allocator.getPointer(input, context);
    Pointer weightsPtr = allocator.getPointer(weights, context);
    Pointer biasPtr = allocator.getPointer(bias, context);
    Pointer outputPtr = allocator.getPointer(z, context);

    checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));

    // Replace the shared workspace if the chosen algorithm needs more than it currently holds.
    checkCudnn(
            cudnnGetConvolutionForwardWorkspaceSize(
                    cudnnContext,
                    cudnnContext.srcTensorDesc,
                    cudnnContext.filterDesc,
                    cudnnContext.convDesc,
                    cudnnContext.dstTensorDesc,
                    forwardAlgo[0],
                    sizeInBytes));
    long requiredBytes = sizeInBytes.get(0);
    if (requiredBytes > workSpace.capacity()) {
        workSpace.deallocate();
        workSpace = new WorkSpace(requiredBytes);
    }

    checkCudnn(
            cudnnConvolutionForward(
                    cudnnContext,
                    alpha,
                    cudnnContext.srcTensorDesc,
                    inputPtr,
                    cudnnContext.filterDesc,
                    weightsPtr,
                    cudnnContext.convDesc,
                    forwardAlgo[0],
                    workSpace,
                    workSpace.capacity(),
                    beta,
                    cudnnContext.dstTensorDesc,
                    outputPtr));

    // Bias is described as [1, outC, 1, 1]. The same alpha factor is passed for both the bias
    // and the destination operand of the add.
    // NOTE(review): presumably alpha == 1 so the bias accumulates into z - confirm at the field.
    checkCudnn(
            cudnnSetTensor4dDescriptor(
                    cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, outC[0], 1, 1));
    checkCudnn(
            cudnnAddTensor(
                    cudnnContext,
                    alpha,
                    cudnnContext.biasTensorDesc,
                    biasPtr,
                    alpha,
                    cudnnContext.dstTensorDesc,
                    outputPtr));

    allocator.registerAction(context, input, weights, bias, z);

    return z;
}
@Override public Pair<Gradient, INDArray> backpropGradient( INDArray input, INDArray weights, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, String afn) { int miniBatch = input.size(0); int inH = input.size(2); int inW = input.size(3); int outDepth = weights.size(0); int inDepth = weights.size(1); int kH = weights.size(2); int kW = weights.size(3); int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false); int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false); if (!Shape.strideDescendingCAscendingF(delta)) { // apparently not supported by cuDNN delta = delta.dup(); } int[] srcStride = input.stride(); int[] deltaStride = delta.stride(); int[] algo = new int[1]; checkCudnn( cudnnSetTensor4dDescriptorEx( cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, srcStride[0], srcStride[1], srcStride[2], srcStride[3])); checkCudnn( cudnnSetTensor4dDescriptorEx( cudnnContext.deltaTensorDesc, dataType, miniBatch, outDepth, outH, outW, deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3])); checkCudnn( cudnnSetConvolution2dDescriptor( cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION)); checkCudnn( cudnnSetFilter4dDescriptor( cudnnContext.filterDesc, dataType, tensorFormat, outDepth, inDepth, kH, kW)); checkCudnn( cudnnGetConvolutionBackwardFilterAlgorithm( cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, 0, algo)); INDArray epsNext = Nd4j.create(new int[] {miniBatch, inDepth, inH, inW}, 'c'); int[] dstStride = epsNext.stride(); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator .getFlowController() .prepareAction(input, weights, weightGradView, biasGradView, delta, epsNext); Pointer srcData = allocator.getPointer(input, context); Pointer filterData = 
allocator.getPointer(weights, context); Pointer filterGradData = allocator.getPointer(weightGradView, context); Pointer biasGradData = allocator.getPointer(biasGradView, context); Pointer deltaData = allocator.getPointer(delta, context); Pointer dstData = allocator.getPointer(epsNext, context); checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream()))); checkCudnn( cudnnSetTensor4dDescriptorEx( cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, dstStride[0], dstStride[1], dstStride[2], dstStride[3])); checkCudnn( cudnnGetConvolutionBackwardFilterWorkspaceSize( cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo[0], sizeInBytes)); long sizeInBytes1 = sizeInBytes.get(0); checkCudnn( cudnnGetConvolutionBackwardDataWorkspaceSize( cudnnContext, cudnnContext.filterDesc, cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], sizeInBytes)); long sizeInBytes2 = sizeInBytes.get(0); if (sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) { workSpace.deallocate(); workSpace = new WorkSpace(Math.max(sizeInBytes1, sizeInBytes2)); } checkCudnn( cudnnSetTensor4dDescriptor( cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, outDepth, 1, 1)); checkCudnn( cudnnConvolutionBackwardBias( cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta, cudnnContext.biasTensorDesc, biasGradData)); checkCudnn( cudnnConvolutionBackwardFilter( cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo[0], workSpace, workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData)); checkCudnn( cudnnConvolutionBackwardData( cudnnContext, alpha, cudnnContext.filterDesc, filterData, cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo[0], workSpace, workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData)); 
allocator.registerAction(context, input, weights, weightGradView, biasGradView, delta, epsNext); Gradient retGradient = new DefaultGradient(); retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); return new Pair<>(retGradient, epsNext); }