/**
   * Logs training progress after each learning event and restarts the interval timer.
   *
   * <p>Three values are logged: the current iteration, the total network error, and the seconds
   * elapsed since {@code start` was last set. These are routine status messages, so they are
   * logged at INFO rather than ERROR level.
   *
   * @param event learning event; its source is cast to {@link BackPropagation}, so it is expected
   *     to come from a back-propagation learning rule
   */
  @Override
  public void handleLearningEvent(LearningEvent event) {
    BackPropagation bp = (BackPropagation) event.getSource();
    // Take a single timestamp for both the measurement and the reset so consecutive intervals
    // tile exactly, with no gaps or overlaps caused by the logging calls themselves.
    long now = System.currentTimeMillis();
    LOG.info("Current iteration: " + bp.getCurrentIteration());
    LOG.info("Error: " + bp.getTotalNetworkError());
    LOG.info("Calculation time: " + (now - start) / 1000.0);
    start = now;
  }
  /**
   * Trains a convolutional network on the MNIST training set, runs 6-fold cross-validation on
   * the MNIST test set, and saves the trained network.
   *
   * <p>If reading the MNIST files or saving the network fails with an {@link IOException}, the
   * error is rethrown as a {@link java.io.UncheckedIOException} with the original cause. The
   * failure then terminates the program visibly instead of being printed and ignored.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    try {
      DataSet trainSet =
          MNISTDataSet.createFromFile(
              MNISTDataSet.TRAIN_LABEL_NAME, MNISTDataSet.TRAIN_IMAGE_NAME, 60000);
      DataSet testSet =
          MNISTDataSet.createFromFile(
              MNISTDataSet.TEST_LABEL_NAME, MNISTDataSet.TEST_IMAGE_NAME, 10000);

      // NOTE(review): raw MNIST digits are 28x28; a 32x32 input presumably relies on
      // MNISTDataSet padding the images — confirm.
      Layer2D.Dimensions inputDimension = new Layer2D.Dimensions(32, 32);
      Kernel convolutionKernel = new Kernel(5, 5);
      Kernel poolingKernel = new Kernel(2, 2);

      ConvolutionalNetwork convolutionNetwork =
          new ConvolutionalNetwork.Builder(inputDimension, 1)
              .withConvolutionLayer(convolutionKernel, 10)
              .withPoolingLayer(poolingKernel)
              .withConvolutionLayer(convolutionKernel, 1)
              .withPoolingLayer(poolingKernel)
              .withConvolutionLayer(convolutionKernel, 1)
              .withFullConnectedLayer(10)
              .createNetwork();

      BackPropagation backPropagation = new MomentumBackpropagation();
      backPropagation.setLearningRate(0.0001);
      backPropagation.setMaxError(0.00001);
      backPropagation.setMaxIterations(500);
      backPropagation.setErrorFunction(new MeanSquaredError());
      // Register the progress listener exactly once; a second registration would log every
      // learning event twice.
      backPropagation.addListener(new LearningListener());

      convolutionNetwork.setLearningRule(backPropagation);

      System.out.println("Started training...");

      // Train on the training split; the test split is reserved for evaluation below.
      convolutionNetwork.learn(trainSet);

      System.out.println("Done training!");

      CrossValidation crossValidation = new CrossValidation(convolutionNetwork, testSet, 6);
      crossValidation.run();

      // NOTE(review): this writes to the filesystem root, which usually requires elevated
      // permissions — confirm the intended output location.
      convolutionNetwork.save("/mnist.nnet");

    } catch (IOException e) {
      throw new java.io.UncheckedIOException(
          "Failed to load MNIST data or save the trained network", e);
    }
  }