/**
   * Train and test the RBF net using 5x2-fold cross validation.
   *
   * <p>Each of the five repetitions re-partitions {@code dataset} with {@code partitionData}. It then
   * trains on one half and scores on the other, and finally swaps the halves. In both folds the
   * network is built from the first {@code clusters[0]} rows of its training half and trained on the
   * whole training half via {@code RunRBF.testRBF}.
   *
   * @param size sizes the per-point result arrays ({@code size * 10} entries) and offsets each
   *     repetition by {@code size} slots. NOTE(review): presumably the number of rows in
   *     {@code dataset}; slots of consecutive repetitions overlap if a half produced by
   *     {@code partitionData} has more than {@code size / 2} rows. Confirm.
   * @param dimensions not read by this method
   * @param repeats only {@code repeats[1]} is read; it is forwarded to {@code RunRBF.testRBF}
   * @param rbfBasisFunction forwarded to {@code RunRBF.testRBF}
   * @param activationFunction not read by this method
   * @param learningRate forwarded to {@code RunRBF.testRBF}
   * @param clusters {@code clusters[0]} caps the number of building rows per fold;
   *     {@code clusters[1]} is forwarded to {@code RunRBF.testRBF}
   * @param momentums not read by this method
   * @param hidden not read by this method
   * @param dataset rows whose last column is the target (Rosenbrock) value
   * @return four arrays:
   *     <ul>
   *       <li>[0] ten summary totals, in this order: aboveRight, aboveWrong, belowRight,
   *           belowWrong, sum of absolute errors, sum of actual values, sum of predicted values,
   *           minimum absolute error, maximum absolute error, sum of signed percent errors;
   *       <li>[1] per-point absolute errors;
   *       <li>[2] per-point actual values;
   *       <li>[3] per-point predicted values.
   *     </ul>
   *     Within a repetition, points scored in the first fold use even slots and points scored in
   *     the swapped fold use odd slots. The percent-error total is NaN or infinite if any actual
   *     value is 0.
   */
  public static double[][] trainRBF(
      int size,
      int dimensions,
      int[] repeats,
      int rbfBasisFunction,
      ActivationFunction activationFunction,
      double learningRate,
      int[] clusters,
      double momentums,
      int hidden,
      double[][] dataset) {

    RbfStats stats = new RbfStats(size * 10);
    System.out.println("RBF Neural Network:");
    // For 5 repetitions of 2-fold cross validation
    for (int i = 0; i < 5; i++) {
      System.out.println("Running cross-validation, on part: " + i);
      int count = i + 1;
      System.out.println("Two-fold cross validation repetition " + count);
      double[][][] datasets = partitionData(dataset);
      // Each repetition owns `size` consecutive result slots.
      int repetitionBase = size * i;

      // First: train on datasets[0], test on datasets[1]
      double[][] buildingData = firstRows(datasets[0], clusters[0]);
      long startTime = System.currentTimeMillis();
      RBFNeuralNetwork rbf =
          RunRBF.testRBF(
              buildingData, datasets[0], learningRate, clusters[1], rbfBasisFunction, repeats[1]);
      long end = System.currentTimeMillis();
      System.out.println("Run time: " + (end - startTime));
      scoreFold(rbf, datasets[1], repetitionBase, stats);

      // Cross validation, swap testing and training sets. Build from the same-sized prefix as the
      // first fold so both halves of the 2-fold split are treated symmetrically.
      rbf =
          RunRBF.testRBF(
              firstRows(datasets[1], clusters[0]),
              datasets[1],
              learningRate,
              clusters[1],
              rbfBasisFunction,
              repeats[1]);
      scoreFold(rbf, datasets[0], repetitionBase + 1, stats);
    }

    return stats.toOutput();
  }

  /** Returns a new array holding the first {@code min(limit, rows.length)} row references. */
  private static double[][] firstRows(double[][] rows, int limit) {
    int n = Math.min(limit, rows.length);
    double[][] prefix = new double[n][];
    System.arraycopy(rows, 0, prefix, 0, n);
    return prefix;
  }

  /**
   * Scores {@code rbf} on every row of {@code testSet} and folds the results into {@code stats}.
   * Row {@code j} is recorded at per-point slot {@code firstSlot + 2 * j}.
   */
  private static void scoreFold(
      RBFNeuralNetwork rbf, double[][] testSet, int firstSlot, RbfStats stats) {
    for (int j = 0; j < testSet.length; j++) {
      double[] point = testSet[j];
      // Actual output of the Rosenbrock function (last column)
      double actualValue = point[point.length - 1];
      stats.realValue += actualValue;
      // Value plus or minus a constant percent of the actual value
      double offsetValue = plusOrMinus10(actualValue);
      // Approximated value
      double predictedValue = rbf.getResult(point);
      stats.estimatedValue += predictedValue;

      double variance = Math.abs(predictedValue - actualValue);

      int slot = firstSlot + j * 2;
      stats.errors[slot] = variance;
      stats.allValues[slot] = actualValue;
      stats.allGuesses[slot] = predictedValue;

      // Signed percent error; not finite when actualValue == 0.
      stats.percentError += 100 * ((predictedValue - actualValue) / actualValue);
      stats.varianceSum += variance;
      // Explicit comparisons (rather than Math.min/max) so a NaN variance is ignored.
      if (stats.minError > variance) {
        stats.minError = variance;
      }
      if (stats.maxError < variance) {
        stats.maxError = variance;
      }

      // NOTE(review): presumably aboveValue reports whether the network's output at `point`
      // exceeds offsetValue. Under that reading each branch below records whether the network
      // placed the true value on the correct side of the offset. Confirm against
      // RBFNeuralNetwork.aboveValue.
      boolean abovePredicted = rbf.aboveValue(point, offsetValue);
      boolean offsetAboveActual = actualValue < offsetValue;
      if (abovePredicted) {
        if (offsetAboveActual) {
          stats.aboveWrong += 1.0;
        } else {
          stats.aboveRight += 1.0;
        }
      } else if (offsetAboveActual) {
        stats.belowRight += 1.0;
      } else {
        stats.belowWrong += 1.0;
      }
    }
  }

  /** Mutable running totals collected across every scored point of every fold. */
  private static final class RbfStats {
    double aboveRight; // predicted above, and correct
    double aboveWrong; // predicted above, and incorrect
    double belowRight; // predicted below, and correct
    double belowWrong; // predicted below, and incorrect
    double varianceSum; // sum of absolute errors
    double realValue; // sum of the actual values
    double estimatedValue; // sum of the approximated values
    // Start at +infinity so any observed error (however large) becomes the minimum.
    double minError = Double.POSITIVE_INFINITY;
    double maxError; // maximum absolute error
    double percentError; // sum of the signed percent errors
    final double[] errors;
    final double[] allValues;
    final double[] allGuesses;

    RbfStats(int capacity) {
      errors = new double[capacity];
      allValues = new double[capacity];
      allGuesses = new double[capacity];
    }

    /** Packs the totals and per-point arrays into the layout returned by {@code trainRBF}. */
    double[][] toOutput() {
      double[] summary = {
        aboveRight,
        aboveWrong,
        belowRight,
        belowWrong,
        varianceSum,
        realValue,
        estimatedValue,
        minError,
        maxError,
        percentError
      };
      return new double[][] {summary, errors, allValues, allGuesses};
    }
  }