/**
 * Builds the optimizer held in {@code opt}, picking the implementation from
 * {@code optimizationAlgorithm}.
 *
 * <p>{@code OptimizationAlgorithm.CONJUGATE_GRADIENT} selects a
 * {@code VectorizedNonZeroStoppingConjugateGradient}. Any other value (including
 * {@code null}) falls back to {@code VectorizedDeepLearningGradientAscent}.
 *
 * <p>Both optimizers receive {@code tolerance}. The gradient-ascent optimizer also
 * receives {@code maxStep} as its maximum step size, but only when {@code maxStep}
 * is strictly positive.
 *
 * <p>{@code this} is passed as both constructor arguments. NOTE(review): presumably
 * the enclosing class implements the interfaces those constructors expect — confirm.
 */
private void createOptimizationAlgorithm() {
  if (optimizationAlgorithm != OptimizationAlgorithm.CONJUGATE_GRADIENT) {
    // Keep a typed handle so the step-size setter below needs no downcast of opt.
    VectorizedDeepLearningGradientAscent ascent = new VectorizedDeepLearningGradientAscent(this, this);
    opt = ascent;
    opt.setTolerance(tolerance);
    if (maxStep > 0) {
      ascent.setMaxStepSize(maxStep);
    }
    return;
  }

  opt = new VectorizedNonZeroStoppingConjugateGradient(this, this);
  opt.setTolerance(tolerance);
}
  /** Epoch count used when {@code extraParams} does not supply one at index 2. */
  private static final int DEFAULT_TRAINING_EPOCHS = 1000;

  /**
   * Trains on {@code x}.
   *
   * <p>Steps, in order:
   * <ol>
   *   <li>Lazily creates the optimizer via {@code createOptimizationAlgorithm()} if
   *       {@code opt} is still {@code null}.</li>
   *   <li>Sets {@code x} as the network input.</li>
   *   <li>Runs the optimizer for {@code epochs} iterations.</li>
   *   <li>Calls {@code network.backProp(lr, epochs, extraParams)}.</li>
   * </ol>
   *
   * <p>{@code epochs} is taken from {@code extraParams[2]} when {@code extraParams} has
   * at least three entries. Otherwise it defaults to {@link #DEFAULT_TRAINING_EPOCHS}.
   * A {@code null} {@code extraParams} is treated as "no extra params" rather than
   * throwing a NullPointerException here.
   *
   * @param x training input handed to {@code network.setInput}
   */
  public void train(INDArray x) {
    if (opt == null) {
      createOptimizationAlgorithm();
    }

    network.setInput(x);

    // Guard against a missing extra-params array; previously this dereference threw an NPE.
    int epochs = (extraParams == null || extraParams.length < 3)
        ? DEFAULT_TRAINING_EPOCHS
        : (int) extraParams[2];

    opt.setMaxIterations(epochs);
    opt.optimize(epochs);
    network.backProp(lr, epochs, extraParams);
  }