@Test
  public void testSGDUpdater() {
    double lr = 0.05;

    // Single dense layer configured to use the plain SGD updater.
    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.SGD)
                    .build())
            .build();

    int paramCount = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray paramView = Nd4j.create(1, paramCount);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, paramView, true);
    Updater updater = UpdaterCreator.getUpdater(layer);

    // Apply one SGD step (in place) to the shared test gradient.
    updater.update(layer, gradient, -1, 1);

    // Build a reference copy of the raw gradients so the expected value (lr * g)
    // can be recomputed independently of the updater's in-place work.
    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());

    for (String paramKey : gradientDup.gradientForVariable().keySet()) {
      val = gradientDup.getGradientFor(paramKey);
      gradExpected = val.mul(lr);
      assertEquals(gradExpected, gradient.getGradientFor(paramKey));
    }
    assertEquals(lr, layer.conf().getLayer().getLearningRate(), 1e-4);
  }
  @Test
  public void testNoOpUpdater() {
    Random rng = new Random(12345L);
    double lr = 0.5;

    // Layer configured with Updater.NONE: the update step must leave gradients untouched.
    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.NONE)
                    .build())
            .build();

    int paramCount = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray paramView = Nd4j.create(1, paramCount);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, paramView, true);
    Updater updater = UpdaterCreator.getUpdater(layer);

    // Fill weight gradient first, then bias gradient, from the seeded RNG
    // (draw order matters for reproducibility).
    for (int idx = 0; idx < weightGradient.length(); idx++) {
      weightGradient.putScalar(idx, rng.nextDouble());
    }
    for (int idx = 0; idx < biasGradient.length(); idx++) {
      biasGradient.putScalar(idx, rng.nextDouble());
    }

    gradient.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradient.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGradient);

    updater.update(layer, gradient, -1, 1);

    // No-op updater: the stored gradients must equal the generated ones exactly.
    assertEquals(weightGradient, gradient.getGradientFor(DefaultParamInitializer.WEIGHT_KEY));
    assertEquals(biasGradient, gradient.getGradientFor(DefaultParamInitializer.BIAS_KEY));
  }
  @Test
  public void testAdamUpdater() {
    // Verifies one ADAM step against a manual recomputation of the update rule:
    //   m = beta1*m + (1-beta1)*g;  v = beta2*v + (1-beta2)*g^2
    //   update = alphat * m / (sqrt(v) + epsilon)
    INDArray m, v;
    double lr = 0.01;
    int iteration = 0;
    double beta1 = 0.8;
    double beta2 = 0.888;

    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .iterations(iteration)
            .adamMeanDecay(beta1)
            .adamVarDecay(beta2)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.ADAM)
                    .build())
            .build();

    int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // ADAM keeps per-parameter state (m and v); attach a view array for it.
    int updaterStateSize = updater.stateSizeForLayer(layer);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);

    updater.update(layer, gradient, iteration, 1);

    // At iteration 0, beta1t == beta2t == 1, so alphat is 0/0 (NaN) and the
    // fallback below replaces it with epsilon.
    double beta1t = FastMath.pow(beta1, iteration);
    double beta2t = FastMath.pow(beta2, iteration);
    double alphat = lr * FastMath.sqrt(1 - beta2t) / (1 - beta1t);
    if (Double.isNaN(alphat) || alphat == 0.0) alphat = epsilon;

    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);

    for (Map.Entry<String, INDArray> entry : gradientDup.gradientForVariable().entrySet()) {
      val = entry.getValue();
      // Moments start at zero, matching the freshly initialized updater state.
      m = Nd4j.zeros(val.shape());
      v = Nd4j.zeros(val.shape());

      m.muli(beta1).addi(val.mul(1.0 - beta1));
      v.muli(beta2).addi(val.mul(val).mul(1.0 - beta2));
      gradExpected = m.mul(alphat).divi(Transforms.sqrt(v).addi(epsilon));
      // Dump both arrays on mismatch to make assertion failures diagnosable.
      if (!gradExpected.equals(gradient.getGradientFor(entry.getKey()))) {
        System.out.println(Arrays.toString(gradExpected.dup().data().asFloat()));
        System.out.println(
            Arrays.toString(gradient.getGradientFor(entry.getKey()).dup().data().asFloat()));
      }
      assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
    }

    // Confirm the builder propagated both decay rates into the layer config.
    assertEquals(beta1, layer.conf().getLayer().getAdamMeanDecay(), 1e-4);
    assertEquals(beta2, layer.conf().getLayer().getAdamVarDecay(), 1e-4);
  }
  @Test
  public void testRMSPropUpdater() {
    // Verifies one RMSProp step against a manual recomputation:
    //   cache = decay*cache + (1-decay)*g^2;  update = lr * g / sqrt(cache + eps)
    double lr = 0.01;
    double rmsDecay = 0.25;
    // Per-parameter running average of squared gradients (starts empty -> zeros).
    Map<String, INDArray> lastG = new HashMap<>();

    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .rmsDecay(rmsDecay)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.RMSPROP)
                    .build())
            .build();

    int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // RMSProp keeps per-parameter state (the squared-gradient cache).
    int updaterStateSize = updater.stateSizeForLayer(layer);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);

    updater.update(layer, gradient, -1, 1);

    // Reference copies of the raw gradients for the manual recomputation.
    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());

    for (Map.Entry<String, INDArray> entry : gradientDup.gradientForVariable().entrySet()) {
      key = entry.getKey();
      val = entry.getValue();
      INDArray lastGTmp = lastG.get(key);

      // First (and only) step: cache starts at zero.
      if (lastGTmp == null) lastGTmp = Nd4j.zeros(val.shape());

      lastGTmp.muli(rmsDecay).addi(val.mul(val).muli(1 - rmsDecay));
      gradExpected = val.mul(lr).div(Transforms.sqrt(lastGTmp.add(Nd4j.EPS_THRESHOLD)));

      assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
      lastG.put(key, lastGTmp);
    }
    assertEquals(rmsDecay, layer.conf().getLayer().getRmsDecay(), 1e-4);
  }
  @Test
  public void testNestorovsUpdater() {
    // Verifies one Nesterov-momentum step. With velocity starting at zero:
    //   v = mu*vPrev - lr*g = -lr*g, and the applied update is
    //   mu*vPrev + (-mu - 1)*v = (1 + mu)*lr*g.
    double lr = 1e-2;
    double mu = 0.6;
    INDArray v, vPrev;

    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .momentum(mu)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS)
                    .build())
            .build();

    int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // Nesterov updater keeps per-parameter velocity state.
    int updaterStateSize = updater.stateSizeForLayer(layer);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);

    updater.update(layer, gradient, -1, 1);

    // Reference copies of the raw gradients for the manual recomputation.
    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());

    for (Map.Entry<String, INDArray> entry : gradientDup.gradientForVariable().entrySet()) {
      val = entry.getValue();
      v = Nd4j.zeros(val.shape());
      // vPrev aliases the zero array; the muli below mutates it but zero stays zero.
      vPrev = v;
      v = vPrev.mul(mu).subi(val.mul(lr));
      gradExpected = vPrev.muli(mu).addi(v.mul(-mu - 1));

      assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
    }

    assertEquals(mu, layer.conf().getLayer().getMomentum(), 1e-4);
  }
  @Test
  public void testAdaGradUpdater() {
    // Verifies one AdaGrad step. After a single step the historical
    // accumulator equals g^2, so the expected update is
    //   lr * g / sqrt(g^2 + epsilon)
    // which is what sqrt(g^2 + eps).rdiv(lr).mul(g) computes below.
    double lr = 1e-2;

    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD)
                    .build())
            .build();

    int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // AdaGrad keeps per-parameter state (the squared-gradient history).
    int updaterStateSize = updater.stateSizeForLayer(layer);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);

    updater.update(layer, gradient, -1, 1);

    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient);
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient);

    for (Map.Entry<String, INDArray> entry : gradientDup.gradientForVariable().entrySet()) {
      val = entry.getValue();
      gradExpected = Transforms.sqrt(val.mul(val).add(epsilon)).rdiv(lr).mul(val);
      assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));
    }
    assertEquals(lr, layer.conf().getLayer().getLearningRate(), 1e-4);
  }
  @Test
  public void testAdaDeltaUpdate() {
    // Verifies two consecutive AdaDelta steps against a manual recomputation.
    // AdaDelta maintains two running averages per parameter: msg (mean squared
    // gradient) and msdx (mean squared update/delta-x).
    INDArray dxSquared;
    Map<String, INDArray> msg = new HashMap<>();
    Map<String, INDArray> msdx = new HashMap<>();

    double rho = 0.85;

    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .rho(rho)
            .layer(
                new DenseLayer.Builder()
                    .nIn(nIn)
                    .nOut(nOut)
                    .updater(org.deeplearning4j.nn.conf.Updater.ADADELTA)
                    .build())
            .build();

    int numParams = LayerFactories.getFactory(conf).initializer().numParams(conf, true);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = LayerFactories.getFactory(conf).create(conf, null, 0, params, true);
    Updater updater = UpdaterCreator.getUpdater(layer);
    // AdaDelta keeps per-parameter state (both running averages).
    int updaterStateSize = updater.stateSizeForLayer(layer);
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    updater.setStateViewArray(layer, updaterState, true);

    Gradient gradientDup = new DefaultGradient();
    gradientDup.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, weightGradient.dup());
    gradientDup.setGradientFor(DefaultParamInitializer.BIAS_KEY, biasGradient.dup());

    // Two iterations so the second step exercises non-zero accumulated state.
    for (int i = 0; i < 2; i++) {
      updater.update(layer, gradient, i, 1);

      // calculations for one iteration / update

      for (Map.Entry<String, INDArray> entry : gradientDup.gradientForVariable().entrySet()) {
        key = entry.getKey();
        val = entry.getValue();
        INDArray msgTmp = msg.get(key);
        INDArray msdxTmp = msdx.get(key);

        // First iteration: both accumulators start at zero.
        if (msgTmp == null) {
          msgTmp = Nd4j.zeros(val.shape());
          msdxTmp = Nd4j.zeros(val.shape());
        }

        msgTmp.muli(rho);
        // NOTE(review): canonical AdaDelta accumulates rho*msg + (1-rho)*g^2,
        // but this computes (rho*msg + (1-rho)) * g^2 — confirm this matches
        // the AdaDeltaUpdater implementation it is being checked against.
        msgTmp.addi(1 - rho).muli(val.mul(val));

        // Expected update: sqrt(msdx + eps) / sqrt(msg + eps) * g
        gradExpected =
            Transforms.sqrt(msdxTmp.add(Nd4j.EPS_THRESHOLD))
                .divi(Transforms.sqrt(msgTmp.add(Nd4j.EPS_THRESHOLD)))
                .muli(val);
        gradientDup.setGradientFor(key, gradExpected);
        assertEquals(gradExpected, gradient.getGradientFor(entry.getKey()));

        // Fold this step's squared update into the msdx accumulator for the
        // next iteration.
        msdxTmp.muli(rho);
        dxSquared = gradExpected.mul(gradExpected);
        msdxTmp.addi(dxSquared.muli(1 - rho));

        msg.put(key, msgTmp);
        msdx.put(key, msdxTmp);
      }
      assertEquals(rho, layer.conf().getLayer().getRho(), 1e-4);
    }
  }
  @Test
  public void testMultiLayerUpdater() throws Exception {
    // Verifies that a MultiLayerNetwork gets a MultiLayerUpdater composed of
    // one per-layer updater of the correct type, and that applying the network
    // updater matches applying each layer's updater independently.
    Nd4j.getRandom().setSeed(12345L);
    double lr = 0.03;

    // Four layers, each with a different updater type.
    MultiLayerConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .learningRate(lr)
            .momentum(0.6)
            .list()
            .layer(
                0,
                new DenseLayer.Builder()
                    .nIn(4)
                    .nOut(5)
                    .updater(org.deeplearning4j.nn.conf.Updater.SGD)
                    .build())
            .layer(
                1,
                new DenseLayer.Builder()
                    .nIn(5)
                    .nOut(6)
                    .updater(org.deeplearning4j.nn.conf.Updater.NONE)
                    .build())
            .layer(
                2,
                new DenseLayer.Builder()
                    .nIn(6)
                    .nOut(7)
                    .updater(org.deeplearning4j.nn.conf.Updater.ADAGRAD)
                    .build())
            .layer(
                3,
                new DenseLayer.Builder()
                    .nIn(7)
                    .nOut(8)
                    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS)
                    .build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    Updater updater = UpdaterCreator.getUpdater(net);
    assertNotNull(updater);
    assertTrue(updater.getClass() == MultiLayerUpdater.class);

    // Reflectively inspect the private per-layer updater array to check types.
    Field f = MultiLayerUpdater.class.getDeclaredField("layerUpdaters");
    f.setAccessible(true);
    Updater[] updaters = (Updater[]) f.get(updater);
    assertNotNull(updaters);
    assertTrue(updaters.length == net.getnLayers());
    assertTrue(updaters[0] instanceof SgdUpdater);
    assertTrue(updaters[1] instanceof NoOpUpdater);
    assertTrue(updaters[2] instanceof AdaGradUpdater);
    assertTrue(updaters[3] instanceof NesterovsUpdater);

    // Build an independent set of per-layer updaters to compute the expected
    // results. Stateless updaters (SGD, NoOp) need no state view; the stateful
    // ones (AdaGrad, Nesterovs) get freshly initialized state arrays.
    Updater[] uArr = new Updater[4];
    uArr[0] = new SgdUpdater();
    uArr[1] = new NoOpUpdater();
    uArr[2] = new AdaGradUpdater();
    int updaterStateSize = uArr[2].stateSizeForLayer(net.getLayer(2));
    INDArray updaterState = Nd4j.create(1, updaterStateSize);
    uArr[2].setStateViewArray(net.getLayer(2), updaterState, true);

    uArr[3] = new NesterovsUpdater();
    updaterStateSize = uArr[3].stateSizeForLayer(net.getLayer(3));
    updaterState = Nd4j.create(1, updaterStateSize);
    uArr[3].setStateViewArray(net.getLayer(3), updaterState, true);

    int[] nIns = {4, 5, 6, 7};
    int[] nOuts = {5, 6, 7, 8};

    // Multiple iterations so stateful updaters accumulate history.
    for (int i = 0; i < 5; i++) {
      Gradient gradient = new DefaultGradient();
      Map<String, INDArray> expectedGradient = new LinkedHashMap<>();

      for (int j = 0; j < net.getnLayers(); j++) {
        // Generate test gradient:
        INDArray wGrad = Nd4j.rand(1, nIns[j] * nOuts[j]);
        INDArray bGrad = Nd4j.rand(1, nOuts[j]);

        // Network-level gradient keys are prefixed with the layer index.
        String wKey = j + "_" + DefaultParamInitializer.WEIGHT_KEY;
        String bKey = j + "_" + DefaultParamInitializer.BIAS_KEY;

        gradient.setGradientFor(wKey, wGrad);
        gradient.setGradientFor(bKey, bGrad);

        // Also put copy of gradient through separate layer updaters to compare
        Gradient layerGradient = new DefaultGradient();
        layerGradient.setGradientFor(DefaultParamInitializer.WEIGHT_KEY, wGrad.dup());
        layerGradient.setGradientFor(DefaultParamInitializer.BIAS_KEY, bGrad.dup());

        uArr[j].update(net.getLayer(j), layerGradient, i, 1);
        for (String s : layerGradient.gradientForVariable().keySet()) {
          expectedGradient.put(j + "_" + s, layerGradient.getGradientFor(s));
        }
      }

      // The composite updater must produce exactly the per-layer results.
      updater.update(net, gradient, i, 1);
      assertEquals(gradient.gradientForVariable(), expectedGradient);
    }
  }
  /**
   * Check backprop gradients for a MultiLayerNetwork.
   *
   * @param mln MultiLayerNetwork to test. This must be initialized.
   * @param epsilon Usually on the order of 1e-4 or so.
   * @param maxRelError Maximum relative error. Usually < 1e-5 or so, though maybe more for deep
   *     networks or those with nonlinear activation
   * @param minAbsoluteError Minimum absolute error to cause a failure. Numerical gradients can be
   *     non-zero due to precision issues. For example, 0.0 vs. 1e-18: relative error is 1.0, but
   *     not really a failure
   * @param print Whether to print full pass/failure details for each parameter gradient
   * @param exitOnFirstError If true: return upon first failure. If false: continue checking even if
   *     one parameter gradient has failed. Typically use false for debugging, true for unit tests.
   * @param input Input array to use for forward pass. May be mini-batch data.
   * @param labels Labels/targets to use to calculate backprop gradient. May be mini-batch data.
   * @return true if gradients are passed, false otherwise.
   */
  public static boolean checkGradients(
      MultiLayerNetwork mln,
      double epsilon,
      double maxRelError,
      double minAbsoluteError,
      boolean print,
      boolean exitOnFirstError,
      INDArray input,
      INDArray labels) {
    // Basic sanity checks on input:
    if (epsilon <= 0.0 || epsilon > 0.1)
      throw new IllegalArgumentException(
          "Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
    if (maxRelError <= 0.0 || maxRelError > 0.25)
      throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
    if (!(mln.getOutputLayer() instanceof BaseOutputLayer))
      throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");

    // Check network configuration: gradient checks are only meaningful when the
    // updater does not rescale gradients, i.e. Updater.NONE or SGD with lr=1.0.
    int layerCount = 0;
    for (NeuralNetConfiguration n : mln.getLayerWiseConfigurations().getConfs()) {
      org.deeplearning4j.nn.conf.Updater u = n.getLayer().getUpdater();
      if (u == org.deeplearning4j.nn.conf.Updater.SGD) {
        // Must have LR of 1.0
        double lr = n.getLayer().getLearningRate();
        if (lr != 1.0) {
          throw new IllegalStateException(
              "When using SGD updater, must also use lr=1.0 for layer "
                  + layerCount
                  + "; got "
                  + u
                  + " with lr="
                  + lr);
        }
      } else if (u != org.deeplearning4j.nn.conf.Updater.NONE) {
        throw new IllegalStateException(
            "Must have Updater.NONE (or SGD + lr=1.0) for layer " + layerCount + "; got " + u);
      }
      // Bug fix: layerCount was never incremented, so every error message
      // reported "layer 0" regardless of the offending layer's index.
      layerCount++;
    }

    // Compute the analytic (backprop) gradient once, up front.
    mln.setInput(input);
    mln.setLabels(labels);
    mln.computeGradientAndScore();
    Pair<Gradient, Double> gradAndScore = mln.gradientAndScore();

    Updater updater = UpdaterCreator.getUpdater(mln);
    updater.update(mln, gradAndScore.getFirst(), 0, mln.batchSize());

    INDArray gradientToCheck =
        gradAndScore
            .getFirst()
            .gradient()
            .dup(); // need dup: gradients are a *view* of the full gradient array (which will
                    // change every time backprop is done)
    INDArray originalParams =
        mln.params().dup(); // need dup: params are a *view* of full parameters

    int nParams = originalParams.length();

    int totalNFailures = 0;
    double maxError = 0.0;
    // Central-difference numerical gradient for each parameter:
    //   dScore/dw_i ~= (score(w_i + eps) - score(w_i - eps)) / (2*eps)
    for (int i = 0; i < nParams; i++) {
      // (w+epsilon): Do forward pass and score
      INDArray params = originalParams.dup();
      params.putScalar(i, params.getDouble(i) + epsilon);
      mln.setParameters(params);
      mln.computeGradientAndScore();
      double scorePlus = mln.score();

      // (w-epsilon): Do forward pass and score
      params.putScalar(i, params.getDouble(i) - 2 * epsilon); // +eps - 2*eps = -eps
      mln.setParameters(params);
      mln.computeGradientAndScore();
      double scoreMinus = mln.score();

      // Calculate numerical parameter gradient:
      double scoreDelta = scorePlus - scoreMinus;

      double numericalGradient = scoreDelta / (2 * epsilon);
      if (Double.isNaN(numericalGradient))
        throw new IllegalStateException(
            "Numerical gradient was NaN for parameter " + i + " of " + nParams);

      double backpropGradient = gradientToCheck.getDouble(i);
      // http://cs231n.github.io/neural-networks-3/#gradcheck
      // use mean centered
      double relError =
          Math.abs(backpropGradient - numericalGradient)
              / (Math.abs(numericalGradient) + Math.abs(backpropGradient));
      if (backpropGradient == 0.0 && numericalGradient == 0.0)
        relError = 0.0; // Edge case: i.e., RNNs with time series length of 1.0

      if (relError > maxError) maxError = relError;
      if (relError > maxRelError || Double.isNaN(relError)) {
        // Relative error exceeded the threshold, but tiny absolute differences
        // are tolerated (they are dominated by floating-point noise).
        double absError = Math.abs(backpropGradient - numericalGradient);
        if (absError < minAbsoluteError) {
          log.info(
              "Param "
                  + i
                  + " passed: grad= "
                  + backpropGradient
                  + ", numericalGrad= "
                  + numericalGradient
                  + ", relError= "
                  + relError
                  + "; absolute error = "
                  + absError
                  + " < minAbsoluteError = "
                  + minAbsoluteError);
        } else {
          if (print)
            log.info(
                "Param "
                    + i
                    + " FAILED: grad= "
                    + backpropGradient
                    + ", numericalGrad= "
                    + numericalGradient
                    + ", relError= "
                    + relError
                    + ", scorePlus="
                    + scorePlus
                    + ", scoreMinus= "
                    + scoreMinus);
          if (exitOnFirstError) return false;
          totalNFailures++;
        }
      } else if (print) {
        log.info(
            "Param "
                + i
                + " passed: grad= "
                + backpropGradient
                + ", numericalGrad= "
                + numericalGradient
                + ", relError= "
                + relError);
      }
    }

    if (print) {
      int nPass = nParams - totalNFailures;
      log.info(
          "GradientCheckUtil.checkGradients(): "
              + nParams
              + " params checked, "
              + nPass
              + " passed, "
              + totalNFailures
              + " failed. Largest relative error = "
              + maxError);
    }

    return totalNFailures == 0;
  }