  @Override
  public void init(Map<String, INDArray> params, NeuralNetConfiguration conf) {
    Distribution dist = Distributions.createDistribution(conf.getDist());

    int nL = conf.getNOut(); // i.e., the number of neurons in this layer
    int nLast = conf.getNIn(); // i.e., the number of neurons in the previous layer

    conf.addVariable(RECURRENT_WEIGHTS);
    conf.addVariable(INPUT_WEIGHTS);
    conf.addVariable(BIAS);

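    // Recurrent weights have shape [nL, 4*nL + 3]: one nL-wide block per gate
    // (IFOG order), plus 3 extra columns for the peephole connections of the
    // Graves LSTM formulation.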
    params.put(
        RECURRENT_WEIGHTS, WeightInitUtil.initWeights(nL, 4 * nL + 3, conf.getWeightInit(), dist));
    params.put(
        INPUT_WEIGHTS, WeightInitUtil.initWeights(nLast, 4 * nL, conf.getWeightInit(), dist));
    INDArray biases =
        Nd4j.zeros(1, 4 * nL); // Order: input, forget, output, input modulation, i.e., IFOG
    biases.put(
        new NDArrayIndex[] {new NDArrayIndex(0), NDArrayIndex.interval(nL, 2 * nL)},
        Nd4j.ones(1, nL).muli(5));
    /*The above line initializes the forget gate biases to 5.
     * See Sutskever PhD thesis, pg19:
     * "it is important for [the forget gate activations] to be approximately 1 at the early stages of learning,
     *  which is accomplished by initializing [the forget gate biases] to a large value (such as 5). If it is
     *  not done, it will be harder to learn long range dependencies because the smaller values of the forget
     *  gates will create a vanishing gradients problem."
     *  http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
     */
    params.put(BIAS, biases);

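    // persist() marks the underlying buffers as non-collectable (a legacy Nd4j
    // DataBuffer API), so the initialized values are retained.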
    params.get(RECURRENT_WEIGHTS).data().persist();
    params.get(INPUT_WEIGHTS).data().persist();
    params.get(BIAS).data().persist();
  }
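
  /*
   * Shape check (hypothetical numbers, not from the source): with nIn = 10 and
   * nOut = 20, the initializer above produces
   *   RECURRENT_WEIGHTS: [20, 83]  (4*20 + 3)
   *   INPUT_WEIGHTS:     [10, 80]  (4*20)
   *   BIAS:              [1, 80], with columns 20..39 (forget gate) set to 5.0
   */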
  protected INDArray createWeightMatrix(NeuralNetConfiguration conf) {
    /*
     * Create a 4d weight tensor of shape (number of kernels, num input channels,
     * kernel height, kernel width). Inputs to the convolution layer have shape
     * (batch size, num input feature maps, image height, image width).
     */
    Distribution dist = Distributions.createDistribution(conf.getDist());
    return WeightInitUtil.initWeights(
        Ints.concat(new int[] {conf.getNOut(), conf.getNIn()}, conf.getKernelSize()),
        conf.getWeightInit(),
        dist);
  }
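
  /*
   * Shape check (hypothetical numbers, not from the source): with nIn = 3 input
   * channels, nOut = 16 kernels, and kernelSize = {5, 5}, the returned weight
   * tensor has shape [16, 3, 5, 5].
   */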