/**
 * Creates this layer's bias and weight arrays, stores them in {@code params} under
 * {@code BIAS_KEY} and {@code WEIGHT_KEY}, and registers both keys as variables on {@code conf}.
 *
 * @param params map that receives the newly created parameter arrays
 * @param conf layer configuration; its kernel size must have exactly two entries, because the
 *     weight tensor is built as (nOut, nIn, kernel[0], kernel[1])
 * @throws IllegalArgumentException if the kernel size is {@code null} or does not have exactly
 *     two entries
 */
@Override
  public void init(Map<String, INDArray> params, NeuralNetConfiguration conf) {
    int[] kernelSize = conf.getKernelSize();
    // The weight tensor shape is {nOut, nIn} followed by the kernel size, so only a two-entry
    // kernel yields the intended 4d (kernels, channels, height, width) layout.
    if (kernelSize == null || kernelSize.length != 2) {
      throw new IllegalArgumentException(
          "Filter size must be == 2, got "
              + (kernelSize == null ? "null" : kernelSize.length + " dimensions"));
    }

    params.put(BIAS_KEY, createBias(conf));
    params.put(WEIGHT_KEY, createWeightMatrix(conf));
    // NOTE(review): variables are registered weights-first; order kept as-is in case callers
    // depend on it.
    conf.addVariable(WEIGHT_KEY);
    conf.addVariable(BIAS_KEY);
  }
 /**
  * Builds and initializes the convolution weight tensor.
  *
  * <p>The shape is {@code {nOut, nIn}} followed by the configured kernel size. For a two-entry
  * kernel this is a 4d matrix of (number of kernels, num input channels, kernel height, kernel
  * width). Inputs to the convolution layer are (batch size, num input feature maps, image
  * height, image width).
  *
  * @param conf layer configuration supplying nOut, nIn, kernel size, weight-init scheme and the
  *     distribution description
  * @return the weight array produced by {@code WeightInitUtil.initWeights}
  */
 protected INDArray createWeightMatrix(NeuralNetConfiguration conf) {
   Distribution weightDist = Distributions.createDistribution(conf.getDist());
   int[] outAndIn = {conf.getNOut(), conf.getNIn()};
   int[] weightShape = Ints.concat(outAndIn, conf.getKernelSize());
   return WeightInitUtil.initWeights(weightShape, conf.getWeightInit(), weightDist);
 }