@Override
  public INDArray preOutput(
      INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad) {
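    // Forward pass via cuDNN: input is NCHW [miniBatch, inDepth, inH, inW],
    // weights are [outDepth, inDepth, kH, kW].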
    int miniBatch = input.size(0);
    int inH = input.size(2);
    int inW = input.size(3);

    int outDepth = weights.size(0);
    int inDepth = weights.size(1);
    int kH = weights.size(2);
    int kW = weights.size(3);

    int[] srcStride = input.stride();
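    // Describe the input tensor (NCHW with explicit strides), the filters and the
    // convolution itself; CUDNN_CROSS_CORRELATION is what most frameworks call "convolution".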
    checkCudnn(
        cudnnSetTensor4dDescriptorEx(
            cudnnContext.srcTensorDesc,
            dataType,
            miniBatch,
            inDepth,
            inH,
            inW,
            srcStride[0],
            srcStride[1],
            srcStride[2],
            srcStride[3]));
    checkCudnn(
        cudnnSetFilter4dDescriptor(
            cudnnContext.filterDesc, dataType, tensorFormat, outDepth, inDepth, kH, kW));
    checkCudnn(
        cudnnSetConvolution2dDescriptor(
            cudnnContext.convDesc,
            pad[0],
            pad[1],
            strides[0],
            strides[1],
            1,
            1,
            CUDNN_CROSS_CORRELATION));

    // find dimension of convolution output
    int[] algo = new int[1], n = new int[1], c = new int[1], h = new int[1], w = new int[1];
    checkCudnn(
        cudnnGetConvolution2dForwardOutputDim(
            cudnnContext.convDesc,
            cudnnContext.srcTensorDesc,
            cudnnContext.filterDesc,
            n,
            c,
            h,
            w));
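    // Allocate the output array with the dimensions cuDNN reports for this convolution.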
    INDArray z = Nd4j.createUninitialized(new int[] {n[0], c[0], h[0], w[0]}, 'c');
    int[] dstStride = z.stride();
    checkCudnn(
        cudnnSetTensor4dDescriptorEx(
            cudnnContext.dstTensorDesc,
            dataType,
            n[0],
            c[0],
            h[0],
            w[0],
            dstStride[0],
            dstStride[1],
            dstStride[2],
            dstStride[3]));
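    // Let cuDNN choose the fastest forward algorithm for these descriptors.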
    checkCudnn(
        cudnnGetConvolutionForwardAlgorithm(
            cudnnContext,
            cudnnContext.srcTensorDesc,
            cudnnContext.filterDesc,
            cudnnContext.convDesc,
            cudnnContext.dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
            0,
            algo));

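    // Register the arrays with the flow controller and fetch their device pointers.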
    Allocator allocator = AtomicAllocator.getInstance();
    CudaContext context = allocator.getFlowController().prepareAction(input, weights, bias, z);
    Pointer srcData = allocator.getPointer(input, context);
    Pointer filterData = allocator.getPointer(weights, context);
    Pointer biasData = allocator.getPointer(bias, context);
    Pointer dstData = allocator.getPointer(z, context);

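    // Run on the same CUDA stream as ND4J and make sure the shared workspace is large enough.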
    checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
    checkCudnn(
        cudnnGetConvolutionForwardWorkspaceSize(
            cudnnContext,
            cudnnContext.srcTensorDesc,
            cudnnContext.filterDesc,
            cudnnContext.convDesc,
            cudnnContext.dstTensorDesc,
            algo[0],
            sizeInBytes));
    if (sizeInBytes.get(0) > workSpace.capacity()) {
      workSpace.deallocate();
      workSpace = new WorkSpace(sizeInBytes.get(0));
    }
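    // Perform the convolution: z = alpha * conv(input, weights) + beta * z.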
    checkCudnn(
        cudnnConvolutionForward(
            cudnnContext,
            alpha,
            cudnnContext.srcTensorDesc,
            srcData,
            cudnnContext.filterDesc,
            filterData,
            cudnnContext.convDesc,
            algo[0],
            workSpace,
            workSpace.capacity(),
            beta,
            cudnnContext.dstTensorDesc,
            dstData));

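    // Broadcast-add the bias across the output feature maps of z.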
    checkCudnn(
        cudnnSetTensor4dDescriptor(
            cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, c[0], 1, 1));
    checkCudnn(
        cudnnAddTensor(
            cudnnContext,
            alpha,
            cudnnContext.biasTensorDesc,
            biasData,
            alpha,
            cudnnContext.dstTensorDesc,
            dstData));

    allocator.registerAction(context, input, weights, bias, z);

    return z;
  }

  @Override
  public INDArray activate(INDArray z, String afn) {
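    // Apply the activation function to z in place via cuDNN; unsupported functions yield null.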
    INDArray activation = z;

    Allocator allocator = AtomicAllocator.getInstance();
    CudaContext context = allocator.getFlowController().prepareAction(z);
    Pointer dstData = allocator.getPointer(z, context);

    checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
    switch (afn) {
      case "identity":
        break;
      case "sigmoid":
        checkCudnn(
            cudnnSetActivationDescriptor(
                cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0));
        checkCudnn(
            cudnnActivationForward(
                cudnnContext,
                cudnnContext.activationDesc,
                alpha,
                cudnnContext.dstTensorDesc,
                dstData,
                beta,
                cudnnContext.dstTensorDesc,
                dstData));
        break;
      case "relu":
        checkCudnn(
            cudnnSetActivationDescriptor(
                cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
        checkCudnn(
            cudnnActivationForward(
                cudnnContext,
                cudnnContext.activationDesc,
                alpha,
                cudnnContext.dstTensorDesc,
                dstData,
                beta,
                cudnnContext.dstTensorDesc,
                dstData));
        break;
      case "tanh":
        checkCudnn(
            cudnnSetActivationDescriptor(
                cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0));
        checkCudnn(
            cudnnActivationForward(
                cudnnContext,
                cudnnContext.activationDesc,
                alpha,
                cudnnContext.dstTensorDesc,
                dstData,
                beta,
                cudnnContext.dstTensorDesc,
                dstData));
        break;
      case "softmax":
        checkCudnn(
            cudnnSoftmaxForward(
                cudnnContext,
                CUDNN_SOFTMAX_ACCURATE,
                CUDNN_SOFTMAX_MODE_CHANNEL,
                alpha,
                cudnnContext.dstTensorDesc,
                dstData,
                beta,
                cudnnContext.dstTensorDesc,
                dstData));
        break;
      case "logsoftmax":
        checkCudnn(
            cudnnSoftmaxForward(
                cudnnContext,
                CUDNN_SOFTMAX_LOG,
                CUDNN_SOFTMAX_MODE_CHANNEL,
                alpha,
                cudnnContext.dstTensorDesc,
                dstData,
                beta,
                cudnnContext.dstTensorDesc,
                dstData));
        break;
      default:
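        // Activation not handled by this helper: return null to indicate nothing was applied.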
        activation = null;
    }

    allocator.registerAction(context, z);

    return activation;
  }

  @Override
  public Pair<Gradient, INDArray> backpropGradient(
      INDArray input,
      INDArray weights,
      INDArray delta,
      int[] kernel,
      int[] strides,
      int[] pad,
      INDArray biasGradView,
      INDArray weightGradView,
      String afn) {
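    // Backward pass via cuDNN: fills biasGradView and weightGradView and returns the
    // gradient with respect to the input (epsilon for the layer below).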
    int miniBatch = input.size(0);
    int inH = input.size(2);
    int inW = input.size(3);

    int outDepth = weights.size(0);
    int inDepth = weights.size(1);
    int kH = weights.size(2);
    int kW = weights.size(3);

    int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false);
    int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false);

    if (!Shape.strideDescendingCAscendingF(delta)) {
      // cuDNN apparently does not support arrays with non-descending strides; use a standard ('c' order) copy
      delta = delta.dup();
    }

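    // Describe the input, delta and filter tensors and pick the fastest backward-filter algorithm.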
    int[] srcStride = input.stride();
    int[] deltaStride = delta.stride();
    int[] algo = new int[1], dataAlgo = new int[1];
    checkCudnn(
        cudnnSetTensor4dDescriptorEx(
            cudnnContext.srcTensorDesc,
            dataType,
            miniBatch,
            inDepth,
            inH,
            inW,
            srcStride[0],
            srcStride[1],
            srcStride[2],
            srcStride[3]));
    checkCudnn(
        cudnnSetTensor4dDescriptorEx(
            cudnnContext.deltaTensorDesc,
            dataType,
            miniBatch,
            outDepth,
            outH,
            outW,
            deltaStride[0],
            deltaStride[1],
            deltaStride[2],
            deltaStride[3]));
    checkCudnn(
        cudnnSetConvolution2dDescriptor(
            cudnnContext.convDesc,
            pad[0],
            pad[1],
            strides[0],
            strides[1],
            1,
            1,
            CUDNN_CROSS_CORRELATION));
    checkCudnn(
        cudnnSetFilter4dDescriptor(
            cudnnContext.filterDesc, dataType, tensorFormat, outDepth, inDepth, kH, kW));
    checkCudnn(
        cudnnGetConvolutionBackwardFilterAlgorithm(
            cudnnContext,
            cudnnContext.srcTensorDesc,
            cudnnContext.deltaTensorDesc,
            cudnnContext.convDesc,
            cudnnContext.filterDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
            0,
            algo));

    INDArray epsNext = Nd4j.create(new int[] {miniBatch, inDepth, inH, inW}, 'c');
    int[] dstStride = epsNext.stride();

    Allocator allocator = AtomicAllocator.getInstance();
    CudaContext context =
        allocator
            .getFlowController()
            .prepareAction(input, weights, weightGradView, biasGradView, delta, epsNext);
    Pointer srcData = allocator.getPointer(input, context);
    Pointer filterData = allocator.getPointer(weights, context);
    Pointer filterGradData = allocator.getPointer(weightGradView, context);
    Pointer biasGradData = allocator.getPointer(biasGradView, context);
    Pointer deltaData = allocator.getPointer(delta, context);
    Pointer dstData = allocator.getPointer(epsNext, context);

    checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getOldStream())));
    checkCudnn(
        cudnnSetTensor4dDescriptorEx(
            cudnnContext.dstTensorDesc,
            dataType,
            miniBatch,
            inDepth,
            inH,
            inW,
            dstStride[0],
            dstStride[1],
            dstStride[2],
            dstStride[3]));
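    // Query the workspace requirements of both backward passes and grow the shared workspace if needed.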
    checkCudnn(
        cudnnGetConvolutionBackwardFilterWorkspaceSize(
            cudnnContext,
            cudnnContext.srcTensorDesc,
            cudnnContext.deltaTensorDesc,
            cudnnContext.convDesc,
            cudnnContext.filterDesc,
            algo[0],
            sizeInBytes));
    long sizeInBytes1 = sizeInBytes.get(0);
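    // The backward-data pass uses its own cuDNN algorithm; query it now that dstTensorDesc is configured.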
    checkCudnn(
        cudnnGetConvolutionBackwardDataAlgorithm(
            cudnnContext,
            cudnnContext.filterDesc,
            cudnnContext.deltaTensorDesc,
            cudnnContext.convDesc,
            cudnnContext.dstTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
            0,
            dataAlgo));
    checkCudnn(
        cudnnGetConvolutionBackwardDataWorkspaceSize(
            cudnnContext,
            cudnnContext.filterDesc,
            cudnnContext.deltaTensorDesc,
            cudnnContext.convDesc,
            cudnnContext.dstTensorDesc,
            dataAlgo[0],
            sizeInBytes));
    long sizeInBytes2 = sizeInBytes.get(0);
    if (sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) {
      workSpace.deallocate();
      workSpace = new WorkSpace(Math.max(sizeInBytes1, sizeInBytes2));
    }

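    // Bias gradient: reduce delta over the mini-batch and spatial dimensions.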
    checkCudnn(
        cudnnSetTensor4dDescriptor(
            cudnnContext.biasTensorDesc, tensorFormat, dataType, 1, outDepth, 1, 1));
    checkCudnn(
        cudnnConvolutionBackwardBias(
            cudnnContext,
            alpha,
            cudnnContext.deltaTensorDesc,
            deltaData,
            beta,
            cudnnContext.biasTensorDesc,
            biasGradData));
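    // Weight gradient, written directly into the gradient view array.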
    checkCudnn(
        cudnnConvolutionBackwardFilter(
            cudnnContext,
            alpha,
            cudnnContext.srcTensorDesc,
            srcData,
            cudnnContext.deltaTensorDesc,
            deltaData,
            cudnnContext.convDesc,
            algo[0],
            workSpace,
            workSpace.capacity(),
            beta,
            cudnnContext.filterDesc,
            filterGradData));
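    // Gradient with respect to the input (epsilon for the layer below).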
    checkCudnn(
        cudnnConvolutionBackwardData(
            cudnnContext,
            alpha,
            cudnnContext.filterDesc,
            filterData,
            cudnnContext.deltaTensorDesc,
            deltaData,
            cudnnContext.convDesc,
            dataAlgo[0],
            workSpace,
            workSpace.capacity(),
            beta,
            cudnnContext.dstTensorDesc,
            dstData));

    allocator.registerAction(context, input, weights, weightGradView, biasGradView, delta, epsNext);

    Gradient retGradient = new DefaultGradient();
    retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView);
    retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c');

    return new Pair<>(retGradient, epsNext);
  }