/**
   * Transition function. For every item index {@code i} (0 .. current.getState().size() - 1),
   * gives the probability of item {@code i} going from the current stock (plus the purchase in
   * {@code action}) to the stock recorded in {@code possible}.
   *
   * <p>The lookup row is {@code current stock + purchase}. The lookup column is
   * {@code row - possible stock}. A (row, column) pair that falls outside the item's matrix yields
   * 0.0. When the possible stock is exactly 0 and the column is positive, that column and every
   * larger column of the row are summed, because the user could have consumed more than was
   * available. A negative possible stock yields 0.0.
   *
   * @param current - The current state
   * @param action - The action whose purchases are added to the current stock
   * @param possible - The possible state
   * @return The list of probabilities where each index refers to the probability of that item
   */
  private List<Double> transition(State current, Action action, State possible) {
    Map<Integer, Integer> stockBefore = current.getState();
    List<Double> perItem = new ArrayList<Double>();
    for (int item = 0; item < stockBefore.size(); ++item) {
      Map<Integer, Integer> bought = action.getPurchases();
      Map<Integer, Integer> stockAfter = possible.getState();
      int row = stockBefore.get(item) + bought.get(item);
      int next = stockAfter.get(item);
      int col = row - next;
      Matrix table = this.probabilities.get(item);

      // Guard: the (row, col) cell must lie inside this item's matrix.
      boolean outside = col < 0 || col >= table.getNumCols() || row >= table.getNumRows();
      if (outside) {
        perItem.add(0.0);
        continue;
      }

      double p = 0.0;
      if (next > 0) {
        p = table.get(row, col);
      } else if (next == 0) {
        if (col == 0) {
          p = table.get(row, col);
        } else {
          // col > 0 here (negative columns were rejected above): accumulate the tail of the row.
          for (int c = col; c < table.getNumCols(); ++c) {
            p += table.get(row, c);
          }
        }
      }
      // A negative next stock leaves p at 0.0.
      perItem.add(p);
    }
    return perItem;
  }
  /**
   * Trains a fully connected feed-forward network of sigmoid units with per-instance
   * (stochastic) backpropagation plus momentum. A held-out validation set is used for early
   * stopping. The learned weights are left in instance fields
   * (arrayListOfEachLayersWeightMatrices and biasWeightsAcrossAllLayers).
   *
   * Outline:
   *   1. Shuffle features and labels together. Then one-hot (N-of-K) encode label column 0 into
   *      labels.valueCount(0) output columns.
   *   2. Keep the first (int) (rows * validationSetPercentageOfData) rows for training and the
   *      remaining rows for validation. NOTE(review): despite its name,
   *      validationSetPercentageOfData is the fraction of rows used for TRAINING.
   *   3. Allocate numHiddenLayers + 1 weight matrices, one per pair of adjacent layers. Layer 0
   *      is the input features and the last layer is the output nodes. Matrix i is sized
   *      [nodes in layer i][nodes in layer i + 1]:
   *
   *                           nodes in next layer (j)
   *        nodes in       [                             ]
   *        previous       [            w_ij             ]
   *        layer (i)      [                             ]
   *
   *      Parallel lists hold:
   *        - the per-instance weight and bias changes,
   *        - the previous changes (for momentum),
   *        - the per-layer deltas,
   *        - the per-layer outputs f(net).
   *      Bias weights are kept in separate per-layer arrays, not as an extra matrix row.
   *   4. Initialize every weight and bias weight uniformly in [-1, 1).
   *   5. Run at most 10000 epochs. For each training instance:
   *        - run a forward pass;
   *        - run a backward pass that records change = learningRate * delta * input;
   *        - apply change + momentum * previousChange to every weight and bias.
   *      After each epoch, compute the validation-set MSE. From the sixth epoch on, stop once the
   *      summed absolute difference between the latest MSE and the five most recent MSEs is below
   *      0.01. Otherwise reshuffle the training rows.
   *
   * @param features input rows; the local reference is reassigned to the training subset
   * @param labels label column 0 is expected to hold a class index in [0, valueCount(0)); the
   *     local reference is reassigned to the one-hot training labels
   * @throws Exception declared by this method; presumably propagated from
   *     calculateMSEOnValidationSet or the Matrix helpers — TODO confirm
   */
  public void train(Matrix features, Matrix labels) throws Exception {

    // Ring buffer of the five most recent validation MSE values (named "accuracy" historically).
    double[] recentAccuracies = new double[5];
    int currentAccuracyIndex = 0;
    double currentAccuracy = 0;

    Random rand = new Random();
    // Shuffle feature rows and label rows together so they stay aligned.
    features.shuffle(rand, labels);

    // One-hot (N-of-K) encode label column 0. Row r gets a 1 in column labels.get(r, 0) and 0 in
    // every other column, giving one output node per class for backpropagation.
    Matrix newNOfKLabelsMatrix = new Matrix();
    newNOfKLabelsMatrix.setSize(
        labels.rows(), labels.valueCount(0)); // one column per class value of label column 0
    for (int row = 0; row < newNOfKLabelsMatrix.rows(); row++) { // for each instance
      for (int k = 0; k < labels.valueCount(0); k++) {
        if (labels.get(row, 0) == k) {
          for (int m = 0; m < labels.valueCount(0); m++) {
            newNOfKLabelsMatrix.set(row, m, 0);
          }
          newNOfKLabelsMatrix.set(row, k, 1);
        }
      }
    }
    labels = newNOfKLabelsMatrix;

    // Hold out the tail of the (shuffled) data as a validation set that is never trained on.
    // Its MSE drives the early-stopping check at the end of each epoch.
    // NOTE(review): validationSetPercentageOfData is used as the fraction of rows kept for
    // TRAINING; the remaining rows form the validation set.
    int numRowsToGetIntoTrainingSet = (int) (features.rows() * validationSetPercentageOfData);

    Matrix featuresForTrainingTrimmed = new Matrix();
    featuresForTrainingTrimmed.setSize(numRowsToGetIntoTrainingSet, features.cols());
    Matrix featuresValidationSet = new Matrix();
    featuresValidationSet.setSize(features.rows() - numRowsToGetIntoTrainingSet, features.cols());

    Matrix labelsForTrainingTrimmed = new Matrix();
    labelsForTrainingTrimmed.setSize(numRowsToGetIntoTrainingSet, labels.cols());
    Matrix labelsValidationSet = new Matrix();
    labelsValidationSet.setSize(features.rows() - numRowsToGetIntoTrainingSet, labels.cols());

    // Copy the first numRowsToGetIntoTrainingSet feature rows into the training set and the
    // remaining rows into the validation set.
    for (int row = 0; row < features.rows(); row++) {
      for (int col = 0; col < features.cols(); col++) {
        if (row < numRowsToGetIntoTrainingSet) {
          featuresForTrainingTrimmed.set(row, col, features.get(row, col));
        } else {
          featuresValidationSet.set(row - numRowsToGetIntoTrainingSet, col, features.get(row, col));
        }
      }
    }

    // Split the one-hot label rows the same way.
    for (int row = 0; row < labels.rows(); row++) {
      for (int col = 0; col < labels.cols(); col++) {
        if (row < numRowsToGetIntoTrainingSet) {
          labelsForTrainingTrimmed.set(row, col, labels.get(row, col));
        } else {
          labelsValidationSet.set(row - numRowsToGetIntoTrainingSet, col, labels.get(row, col));
        }
      }
    }

    // From here on, features/labels refer to the training subset only.
    features = featuresForTrainingTrimmed;
    labels = labelsForTrainingTrimmed;
    // Weight matrix per connection: index i connects layer i to layer i + 1 (layer 0 = inputs).
    arrayListOfEachLayersWeightMatrices = new ArrayList<double[][]>();

    for (int i = 0; i < numHiddenLayers + 1; i++) { // each layer
      double[][] specificLayersWeightMatrix;
      if (i == 0) { // inputs -> first hidden layer
        specificLayersWeightMatrix =
            new double[features.cols()][numNodesPerHiddenLayer[i]]; // INPUTS are the rows
      } else if (i == numHiddenLayers) {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][labels.cols()]; // OUTPUTS ARE THE COLUMNS
      } else {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][numNodesPerHiddenLayer[i]];
      }
      arrayListOfEachLayersWeightMatrices.add(specificLayersWeightMatrix);
    }

    // Per-instance weight changes computed during backpropagation (same shapes as the weights).
    changeInWeightMatricesForEveryLayer = new ArrayList<double[][]>();

    for (int i = 0; i < numHiddenLayers + 1; i++) { // each layer
      double[][] specificLayersWeightMatrix;
      if (i == 0) { // inputs -> first hidden layer
        specificLayersWeightMatrix =
            new double[features.cols()][numNodesPerHiddenLayer[i]]; // INPUTS are the rows
      } else if (i == numHiddenLayers) {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][labels.cols()]; // OUTPUTS ARE THE COLUMNS
      } else {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][numNodesPerHiddenLayer[i]];
      }
      changeInWeightMatricesForEveryLayer.add(specificLayersWeightMatrix);
    }

    // Same-shaped scratch matrices.
    // NOTE(review): not read or written again within this method.
    temporaryStashChangeInWeightMatricesForEveryLayer = new ArrayList<double[][]>();

    for (int i = 0; i < numHiddenLayers + 1; i++) { // each layer
      double[][] specificLayersWeightMatrix;
      if (i == 0) { // inputs -> first hidden layer
        specificLayersWeightMatrix =
            new double[features.cols()][numNodesPerHiddenLayer[i]]; // INPUTS are the rows
      } else if (i == numHiddenLayers) {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][labels.cols()]; // OUTPUTS ARE THE COLUMNS
      } else {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][numNodesPerHiddenLayer[i]];
      }
      temporaryStashChangeInWeightMatricesForEveryLayer.add(specificLayersWeightMatrix);
    }

    // Delta (error term) arrays, one per layer including the input layer, so indices line up
    // with storedFNetForEachLayer. Index 0 (inputs) is allocated but never written below.

    arrayListOfEachLayersDeltaArray = new ArrayList<double[]>();
    for (int i = 0;
        i < numHiddenLayers + 2;
        i++) { // numHiddenLayers + 2 layers: inputs, hidden layers, outputs
      double[] specificLayersDeltaArray;
      if (i == 0) { // input layer
        specificLayersDeltaArray = new double[features.cols()]; // one entry per input feature
      } else if (i == (numHiddenLayers + 1)) {
        specificLayersDeltaArray = new double[labels.cols()]; // output layer
      } else {
        specificLayersDeltaArray = new double[numNodesPerHiddenLayer[i - 1]];
      }
      arrayListOfEachLayersDeltaArray.add(specificLayersDeltaArray);
    }

    // Previous weight changes, used for the momentum term. Java zero-initializes the arrays,
    // so the first update has no momentum contribution.
    previousChangeInWeightMatricesForEachLayer = new ArrayList<double[][]>();

    for (int i = 0; i < numHiddenLayers + 1; i++) { // each layer
      double[][] specificLayersWeightMatrix;
      if (i == 0) { // inputs -> first hidden layer
        specificLayersWeightMatrix =
            new double[features.cols()][numNodesPerHiddenLayer[i]]; // INPUTS are the rows
      } else if (i == numHiddenLayers) {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][labels.cols()]; // OUTPUTS ARE THE COLUMNS
      } else {
        specificLayersWeightMatrix =
            new double[numNodesPerHiddenLayer[i - 1]][numNodesPerHiddenLayer[i]];
      }
      previousChangeInWeightMatricesForEachLayer.add(specificLayersWeightMatrix);
    }

    // Initialize every weight uniformly in [-1, 1) (zero mean).
    // NOTE(review): these loops use numNodesPerHiddenLayer.length + 1 while the allocation above
    // used numHiddenLayers + 1; they must be equal or get(i) goes out of range.

    double[][] currentLayersWeightMatrix;
    for (int i = 0; i < numNodesPerHiddenLayer.length + 1; i++) { // scroll across each layer

      currentLayersWeightMatrix = arrayListOfEachLayersWeightMatrices.get(i);
      for (int j = 0; j < currentLayersWeightMatrix.length; j++) {
        for (int k = 0; k < currentLayersWeightMatrix[j].length; k++) {
          currentLayersWeightMatrix[j][k] = (2 * rand.nextDouble()) - 1;
        }
      }
    }

    // Bias weights: one array per non-input layer (each hidden layer, then the output layer).
    biasWeightsAcrossAllLayers = new ArrayList<double[]>();
    for (int i = 0; i < numHiddenLayers + 1; i++) {
      if (i < numHiddenLayers) {
        double[] biasArrayToBeAdded = new double[numNodesPerHiddenLayer[i]];
        biasWeightsAcrossAllLayers.add(biasArrayToBeAdded);
      } else {
        double[] biasArrayForOutputNodesToBeAdded = new double[labels.cols()];
        biasWeightsAcrossAllLayers.add(biasArrayForOutputNodesToBeAdded);
      }
    }

    // Initialize every bias weight uniformly in [-1, 1).
    double[] currentBiasLayersWeightArray;
    for (int i = 0; i < numNodesPerHiddenLayer.length + 1; i++) { // scroll across each layer
      currentBiasLayersWeightArray = biasWeightsAcrossAllLayers.get(i);
      for (int j = 0; j < currentBiasLayersWeightArray.length; j++) {

        currentBiasLayersWeightArray[j] = (2 * rand.nextDouble()) - 1;
      }
    }

    // Previous bias changes, used for the momentum term (start at zero).
    previousBiasChangeInWeightsAcrossAllLayers = new ArrayList<double[]>();
    for (int i = 0; i < numHiddenLayers + 1; i++) {
      if (i < numHiddenLayers) {
        double[] biasArrayToBeAdded = new double[numNodesPerHiddenLayer[i]];
        previousBiasChangeInWeightsAcrossAllLayers.add(biasArrayToBeAdded);
      } else {
        double[] biasArrayForOutputNodesToBeAdded = new double[labels.cols()];
        previousBiasChangeInWeightsAcrossAllLayers.add(biasArrayForOutputNodesToBeAdded);
      }
    }

    // Scratch bias-change arrays.
    // NOTE(review): not read or written again within this method.
    temporarilyStashedChangeInBiasWeightsAcrossAllLayers = new ArrayList<double[]>();
    for (int i = 0; i < numHiddenLayers + 1; i++) {
      if (i < numHiddenLayers) {
        double[] biasArrayToBeAdded = new double[numNodesPerHiddenLayer[i]];
        temporarilyStashedChangeInBiasWeightsAcrossAllLayers.add(biasArrayToBeAdded);
      } else {
        double[] biasArrayForOutputNodesToBeAdded = new double[labels.cols()];
        temporarilyStashedChangeInBiasWeightsAcrossAllLayers.add(biasArrayForOutputNodesToBeAdded);
      }
    }

    // Per-instance bias changes computed during backpropagation.
    changeInBiasArrayForEveryLayer = new ArrayList<double[]>();
    for (int i = 0; i < numHiddenLayers + 1; i++) {
      if (i < numHiddenLayers) {
        double[] biasArrayToBeAdded = new double[numNodesPerHiddenLayer[i]];
        changeInBiasArrayForEveryLayer.add(biasArrayToBeAdded);
      } else {
        double[] biasArrayForOutputNodesToBeAdded = new double[labels.cols()];
        changeInBiasArrayForEveryLayer.add(biasArrayForOutputNodesToBeAdded);
      }
    }

    // Per-layer outputs f(net). Index 0 holds the current input row and the last index holds the
    // output layer. The arrays are allocated up front and overwritten for every instance.

    storedFNetForEachLayer =
        new ArrayList<double[]>(); // f_net is the output that is fed into the next layer
    for (int i = 0;
        i < numHiddenLayers + 2;
        i++) { // the inputs count as layer 0
      double[] thisLayersFNetValues;
      if (i == 0) {
        thisLayersFNetValues = new double[features.cols()]; // input layer
      } else if (i == numHiddenLayers + 1) { // output layer
        thisLayersFNetValues = new double[labels.cols()];
      } else {
        thisLayersFNetValues =
            new double[numNodesPerHiddenLayer[i - 1]]; // hidden layer i - 1
      }
      storedFNetForEachLayer.add(thisLayersFNetValues);
    }

    // -----BEGIN THE TRAINING-----
    double netValAtNode = 0;
    double fOfNetValAtNode = 0;
    for (int epoch = 0;
        epoch < 10000;
        epoch++) { // hard cap of 10000 epochs in case early stopping never triggers
      System.out.println("---Epoch " + epoch + "---");
      for (int instance = 0;
          instance < features.rows();
          instance++) { // features already refers to the training subset here
        // Forward pass: compute f(net) for every node, layer by layer.
        for (int layer = 0;
            layer < numHiddenLayers + 2;
            layer++) { // layer 0 = inputs, last layer = outputs
          if (layer == 0) {
            // Layer 0 is a copy of the current instance's feature row.
            storedFNetForEachLayer.set(
                layer, Arrays.copyOf(features.row(instance), features.row(0).length));
            continue;
          }
          double[] thisLayersFNetValues =
              storedFNetForEachLayer.get(
                  layer); // the existing array is overwritten in place
          for (int node = 0; node < storedFNetForEachLayer.get(layer).length; node++) {
            netValAtNode = 0;
            // net = sum over inputs of (previous layer output * weight) + bias (a dot product).
            for (int colInInputVector = 0;
                colInInputVector < storedFNetForEachLayer.get(layer - 1).length;
                colInInputVector++) {
              netValAtNode +=
                  (storedFNetForEachLayer.get(layer - 1)[colInInputVector]
                      * arrayListOfEachLayersWeightMatrices.get(layer - 1)[colInInputVector][node]);
            }
            netValAtNode += (biasWeightsAcrossAllLayers.get(layer - 1)[node]);
            // Both branches compute the logistic sigmoid 1 / (1 + e^(-net)).
            if (netValAtNode < 0) {
              fOfNetValAtNode = (1 / (1 + Math.pow(Math.E, (-1 * netValAtNode))));
            } else {
              fOfNetValAtNode =
                  (1
                      / (1
                          + (1
                              / (Math.pow(
                                  Math.E,
                                  (netValAtNode)))))); // e^(-net) written as 1 / e^net
            }
            thisLayersFNetValues[node] = fOfNetValAtNode; // written in place
          }
          storedFNetForEachLayer.set(
              layer,
              thisLayersFNetValues); // redundant: the same array was already modified in place
        }
        // Backward pass: compute deltas from the output layer down to layer 1, and record the
        // weight/bias changes (learningRate * delta * input). The weights themselves are updated
        // only after this pass, so the hidden-layer deltas use the pre-update weights.
        for (int layer = numHiddenLayers + 1; layer > 0; layer--) { // across each layer, backward
          if (layer == numHiddenLayers + 1) { // output layer
            for (int node = 0; node < labels.cols(); node++) {
              double deltaArrayForThisLayer[] = arrayListOfEachLayersDeltaArray.get(layer);
              // Output delta: (target - output) * f'(net), where f'(net) = output * (1 - output).
              // deltaArrayForThisLayer aliases the stored array, so no set() is needed.
              deltaArrayForThisLayer[node] =
                  ((labels.get(instance, node) - storedFNetForEachLayer.get(layer)[node])
                      * (storedFNetForEachLayer.get(layer)[node])
                      * (1 - (storedFNetForEachLayer.get(layer)[node])));
              // NOTE(review): numNodesPerHiddenLayer[layer - 2] assumes at least one hidden
              // layer. With numHiddenLayers == 0 this indexes -1 and throws.
              for (int inputToThisNode = 0;
                  inputToThisNode < numNodesPerHiddenLayer[layer - 2] + 1;
                  inputToThisNode++) {
                double changeInWeightBetweenIJ = 0;
                if (inputToThisNode == numNodesPerHiddenLayer[layer - 2]) { // bias input (x = 1)

                  changeInWeightBetweenIJ =
                      (learningRate
                          * 1
                          * arrayListOfEachLayersDeltaArray
                              .get(layer)[node]);
                  // Holds bias *changes* for this layer, despite the variable name.
                  double[] thisLayersBiasWeights =
                      changeInBiasArrayForEveryLayer.get(
                          layer - 1);
                  thisLayersBiasWeights[node] =
                      (changeInWeightBetweenIJ);
                } else {

                  changeInWeightBetweenIJ =
                      (learningRate
                          * storedFNetForEachLayer.get(layer - 1)[inputToThisNode]
                          * arrayListOfEachLayersDeltaArray.get(layer)[node]);
                  double[][] changeInWeightsMatrixForThisLayer =
                      changeInWeightMatricesForEveryLayer.get(layer - 1);
                  changeInWeightsMatrixForThisLayer[inputToThisNode][node] =
                      changeInWeightBetweenIJ;
                }
              }
            }
          } else { // hidden layer

            for (int node = 0;
                node < numNodesPerHiddenLayer[layer - 1] + 1;
                node++) { // the extra last index is treated as the bias node and skipped
              double deltaArrayForThisLayer[] = arrayListOfEachLayersDeltaArray.get(layer);

              if (node == numNodesPerHiddenLayer[layer - 1]) { // bias node
                // Nothing to do: bias changes are recorded per receiving node below.
              } else { // regular hidden node
                double summedOutgoingWeightsCrossOutputDelta = 0;

                // Sum over the next layer of (its delta * weight from this node to it).
                for (int outgoingEdgeToOutgoingNode = 0;
                    outgoingEdgeToOutgoingNode
                        < arrayListOfEachLayersDeltaArray.get(layer + 1).length;
                    outgoingEdgeToOutgoingNode++) {
                  summedOutgoingWeightsCrossOutputDelta +=
                      (arrayListOfEachLayersDeltaArray.get(layer + 1)[outgoingEdgeToOutgoingNode]
                          * arrayListOfEachLayersWeightMatrices
                              .get(layer)[node][outgoingEdgeToOutgoingNode]);
                }

                // Hidden delta: weighted downstream delta * f'(net). Written in place.
                deltaArrayForThisLayer[node] =
                    ((summedOutgoingWeightsCrossOutputDelta)
                        * (storedFNetForEachLayer.get(layer)[node])
                        * (1 - (storedFNetForEachLayer.get(layer)[node])));

                if (layer == 1) {
                  // First hidden layer: loop over the input features plus one bias input.
                  for (int inputToTheNeuralNet = 0;
                      inputToTheNeuralNet < features.cols() + 1;
                      inputToTheNeuralNet++) {
                    double changeInWeightBetweenIJ = 0;
                    if (inputToTheNeuralNet
                        == features.cols()) { // bias input (x = 1)

                      changeInWeightBetweenIJ =
                          (learningRate
                              * 1
                              * arrayListOfEachLayersDeltaArray
                                  .get(layer)[node]);
                      // Holds bias *changes* for this layer, despite the variable name.
                      double[] thisLayersBiasWeights =
                          changeInBiasArrayForEveryLayer.get(
                              layer - 1);
                      thisLayersBiasWeights[node] =
                          (changeInWeightBetweenIJ);

                    } else {

                      changeInWeightBetweenIJ =
                          (learningRate
                              * storedFNetForEachLayer.get(layer - 1)[inputToTheNeuralNet]
                              * arrayListOfEachLayersDeltaArray.get(layer)[node]);
                      double[][] changeInWeightsMatrixForThisLayer =
                          changeInWeightMatricesForEveryLayer.get(layer - 1);
                      changeInWeightsMatrixForThisLayer[inputToTheNeuralNet][node] =
                          changeInWeightBetweenIJ;
                    }
                  }
                } else {
                  // Deeper hidden layer: the inputs are the previous hidden layer's outputs plus a
                  // bias input.
                  for (int inputToThisNode = 0;
                      inputToThisNode < numNodesPerHiddenLayer[layer - 2] + 1;
                      inputToThisNode++) {
                    double changeInWeightBetweenIJ = 0;
                    if (inputToThisNode
                        == numNodesPerHiddenLayer[layer - 2]) { // bias input (x = 1)

                      changeInWeightBetweenIJ =
                          (learningRate
                              * 1
                              * arrayListOfEachLayersDeltaArray
                                  .get(layer)[node]);
                      // Holds bias *changes* for this layer, despite the variable name.
                      double[] thisLayersBiasWeights =
                          changeInBiasArrayForEveryLayer.get(
                              layer - 1);
                      thisLayersBiasWeights[node] =
                          (changeInWeightBetweenIJ);

                    } else {

                      changeInWeightBetweenIJ =
                          (learningRate
                              * storedFNetForEachLayer.get(layer - 1)[inputToThisNode]
                              * arrayListOfEachLayersDeltaArray.get(layer)[node]);
                      double[][] changeInWeightsMatrixForThisLayer =
                          changeInWeightMatricesForEveryLayer.get(layer - 1);
                      changeInWeightsMatrixForThisLayer[inputToThisNode][node] =
                          changeInWeightBetweenIJ;
                    }
                  }
                }
              }
            }
          }
        }

        // Apply the bias updates:
        //   applied change = current change + momentum * previous applied change.
        // The applied change is then stored as the new "previous" change.

        for (int w = 0; w < previousBiasChangeInWeightsAcrossAllLayers.size(); w++) {
          for (int y = 0; y < previousBiasChangeInWeightsAcrossAllLayers.get(w).length; y++) {
            double currentChangeInWeightVal = changeInBiasArrayForEveryLayer.get(w)[y];
            double[] fullBiasWeightList = biasWeightsAcrossAllLayers.get(w);
            double previousXYCoordInBiasWeightMatrix =
                previousBiasChangeInWeightsAcrossAllLayers.get(w)[y];
            double thisIsTheWeightChangeIncludingMomentum =
                (currentChangeInWeightVal + (momentum * previousXYCoordInBiasWeightMatrix));
            fullBiasWeightList[y] += thisIsTheWeightChangeIncludingMomentum;
            double[] arrayOfPreviousBiases = previousBiasChangeInWeightsAcrossAllLayers.get(w);
            arrayOfPreviousBiases[y] = thisIsTheWeightChangeIncludingMomentum;
          }
        }

        // Apply the weight updates the same way, once every layer has been processed.
        for (int w = 0; w < arrayListOfEachLayersWeightMatrices.size(); w++) {
          for (int y = 0; y < arrayListOfEachLayersWeightMatrices.get(w).length; y++) {
            for (int z = 0; z < arrayListOfEachLayersWeightMatrices.get(w)[y].length; z++) {
              double currentXYCoordInMatrix = changeInWeightMatricesForEveryLayer.get(w)[y][z];
              double[] fullWeightListForLayer = arrayListOfEachLayersWeightMatrices.get(w)[y];

              double previousXYCoordInChangeInWeightMatrix =
                  previousChangeInWeightMatricesForEachLayer.get(w)[y][z];
              double thisIsTheWeightChangeIncludingMomentum =
                  (currentXYCoordInMatrix + (previousXYCoordInChangeInWeightMatrix * momentum));
              fullWeightListForLayer[z] += thisIsTheWeightChangeIncludingMomentum;
              // Holds previous weight changes (not biases), despite the variable name.
              double[][] arrayOfPreviousBiases = previousChangeInWeightMatricesForEachLayer.get(w);
              arrayOfPreviousBiases[y][z] = thisIsTheWeightChangeIncludingMomentum;
              // deltaW(t) = learningRate * delta_j * x_i + momentum * deltaW(t - 1)
              // w(t + 1) = w(t) + deltaW(t)
            }
          }
        }

      }

      // Early stopping: evaluate the MSE on the held-out validation set after every epoch.
      currentAccuracy = calculateMSEOnValidationSet(featuresValidationSet, labelsValidationSet);
      System.out.println(" Current MSE on epoch # " + epoch + " is: " + currentAccuracy);
      currentAccuracyIndex++;
      recentAccuracies[currentAccuracyIndex % 5] = currentAccuracy;
      double sumAccuracies = 0;
      if (currentAccuracyIndex > 5) { // only once the five-slot buffer has been filled
        for (int i = 0; i < recentAccuracies.length; i++) {
          sumAccuracies +=
              Math.abs(recentAccuracies[currentAccuracyIndex % 5] - recentAccuracies[i]);
        }
        if (sumAccuracies
            < 0.01) { // the five recent MSEs are within a combined 0.01 of the latest one
          break;
        }
      }

      // NOTE(review): the weights are not rolled back to the best-validation epoch. When the
      // stopping criterion fires, the model may already be slightly past its best point.
      features.shuffle(
          rand, labels); // reshuffle the training rows (labels move with them) every epoch
    }
    return;
  }
Example #3
0
File: Main.java Project: YpGu/gcoev
  /*
   * Backward pass 1. Starting from the last time slot and moving toward t0, fills in:
   *  (1) \hat{mu}      (mu_hat_s)
   *  (2) \hat{grad_mu} (grad_mu_hat_s), only when update_grad is true
   *  (3) \hat{V}       (v_hat_s)
   *
   * Slots are indexed relative to t0, so slot t is absolute time t + t0. The gradient lists are
   * flattened: entry (t, s) lives at index t * (T - t0) + s.
   *
   * At the last slot, the hat values are copied from the forward-pass values. Each earlier slot
   * t - 1 (down to slot 0) is then derived from slot t using
   *   factor_1 = (1 - lambda) V^{t-1} / (sigma^2 + (1 - lambda)^2 V^{t-1})
   *   factor_2 = sigma^2           / (sigma^2 + (1 - lambda)^2 V^{t-1})
   * together with the neighbor term A^{t-1} h'^{t-1}, which is read below as an n x K matrix.
   * Nothing is done when T - 1 <= t0.
   */
  public static void backward1(boolean update_grad) {
    final int slots = T - t0;
    final int last = slots - 1;
    if (last <= 0) {
      return;
    }

    // Initial condition at the final slot: \hat{mu} = mu, \hat{V} = V, \hat{grad_mu} = grad_mu.
    mu_hat_s.set(last, mu_s.get(last));
    v_hat_s.set(last, v_s.get(last));
    if (update_grad) {
      for (int s = 0; s < slots; s++) {
        int idx = last * slots + s;
        grad_mu_hat_s.set(idx, grad_mu_s.get(idx));
      }
    }

    // Walk backward: the values at slot t produce the values at slot t - 1.
    for (int t = last - 1; t > 0; t--) {
      double vPrev = v_s.get(t - 1); // V^{t-1}
      double vHat = v_hat_s.get(t); // \hat{V}^{t}
      double[][] muPrev = mu_s.get(t - 1); // \mu^{t-1}
      double[][] muHat = mu_hat_s.get(t); // \hat{\mu}^{t}

      // A^{t-1} * h'^{t-1}
      Matrix neighborTerm = new Matrix(AS.get(t - 1)).times(new Matrix(h_prime_s.get(t - 1)));

      double sigmaSq = sigma * sigma;
      double denom = sigmaSq + (1 - lambda) * (1 - lambda) * vPrev;
      double w1 = (1 - lambda) * vPrev / denom; // factor_1
      double w2 = sigmaSq / denom; // factor_2

      // \hat{\mu}^{t-1}
      double[][] muHatPrev = new double[n][K];
      for (int i = 0; i < n; i++) {
        for (int k = 0; k < K; k++) {
          muHatPrev[i][k] =
              w1 * (muHat[i][k] - lambda * neighborTerm.get(i, k)) + w2 * muPrev[i][k];
        }
      }
      // \hat{V}^{t-1}
      double vHatPrev =
          vPrev + w1 * w1 * (vHat - (1 - lambda) * (1 - lambda) * vPrev - sigmaSq);

      mu_hat_s.set(t - 1, muHatPrev);
      v_hat_s.set(t - 1, vHatPrev);

      // \hat{grad_mu}^{t-1 / s} for every s
      if (update_grad) {
        for (int s = 0; s < slots; s++) {
          double[][] gradHat = grad_mu_hat_s.get(t * slots + s);
          double[][] gradPrev = grad_mu_s.get((t - 1) * slots + s);
          double[][] gradHatPrev = new double[n][K];
          for (int i = 0; i < n; i++) {
            for (int k = 0; k < K; k++) {
              gradHatPrev[i][k] = w1 * gradHat[i][k] + w2 * gradPrev[i][k];
            }
          }
          grad_mu_hat_s.set((t - 1) * slots + s, gradHatPrev);
        }
      }
    }
  }
Example #4
0
File: Main.java Project: YpGu/gcoev
  /**
   * Forward (filtering) pass 1.
   *
   * <p>For every time step t = 1 .. T-t0-1, in increasing order, the prediction carried over from
   * step t-1,
   * <pre>
   *   (1 - lambda) * mu^{t-1} + lambda * A^{t-1} h'^{t-1},  variance sigma^2 + (1-lambda)^2 V^{t-1},
   * </pre>
   * is combined with \hat{h}^t (variance delta_t^2) by inverse-variance weighting. The result
   * overwrites:
   * <ul>
   *   <li>mu_s[t]: the weighted mean;
   *   <li>v_s[t]: the combined variance;
   *   <li>grad_mu_s[t * (T-t0) + s] for every s, only when {@code update_grad} is true.
   * </ul>
   * Step t = 0 is never modified (mu and V keep their initial values and grad_mu keeps its value;
   * presumably these are initialised elsewhere).
   *
   * @param update_grad whether grad_mu_s should be propagated as well
   * @param iter iteration counter; not read by the current implementation
   */
  public static void forward1(boolean update_grad, int iter) {
    final int steps = T - t0;
    // t = 0 holds the initial state and is intentionally skipped.
    for (int t = 1; t < steps; t++) {
      final double delta_t = delta_s.get(t);
      final double[][] observed = h_hat_s.get(t); // \hat{h}^t
      final double[][] prevMu = mu_s.get(t - 1); // mu^{t-1}
      final double prevV = v_s.get(t - 1); // V^{t-1}

      // A^{t-1} h'^{t-1}
      final Matrix neighborMean =
          new Matrix(AS.get(t - 1)).times(new Matrix(h_prime_s.get(t - 1)));

      // Variance of the prediction, and the total variance that normalises both weights.
      final double priorVar = sigma * sigma + (1 - lambda) * (1 - lambda) * prevV;
      final double normalizer =
          delta_t * delta_t + sigma * sigma + (1 - lambda) * (1 - lambda) * prevV;
      final double priorWeight = (delta_t * delta_t) / normalizer;
      final double obsWeight = priorVar / normalizer;

      final double[][] mu = new double[n][K];
      for (int i = 0; i < n; i++) {
        final double[] prevRow = prevMu[i];
        final double[] obsRow = observed[i];
        final double[] outRow = mu[i];
        for (int k = 0; k < K; k++) {
          final double prediction = (1 - lambda) * prevRow[k] + lambda * neighborMean.get(i, k);
          outRow[k] = priorWeight * prediction + obsWeight * obsRow[k];
        }
      }
      mu_s.set(t, mu);
      v_s.set(t, obsWeight * delta_t * delta_t);

      if (!update_grad) {
        continue;
      }
      // d mu^t / d \hat{h}^s = priorWeight (1 - lambda) d mu^{t-1} / d \hat{h}^s (+ obsWeight if s == t)
      final double carry = priorWeight * (1 - lambda);
      for (int s = 0; s < steps; s++) {
        final double[][] prevGrad = grad_mu_s.get((t - 1) * steps + s);
        final double[][] grad = new double[n][K];
        for (int i = 0; i < n; i++) {
          for (int k = 0; k < K; k++) {
            final double propagated = carry * prevGrad[i][k];
            grad[i][k] = (t == s) ? propagated + obsWeight : propagated;
          }
        }
        grad_mu_s.set(t * steps + s, grad);
      }
    }
  }
Example #5
0
File: Main.java Project: YpGu/gcoev
  /**
   * Computes the gradient of the lower bound evaluated by {@code compute_objective1} with respect
   * to every \hat{h}^s. The gradient is stored into grad_h_hat_s[s] and written to
   * {@code ./grad/grad_<iteration>.txt}.
   *
   * <p>Time step 0 has no term in the bound. For every t >= 1, every s, node i and dimension k,
   * with g = grad_mu_hat_s[t * (T-t0) + s] and g_pre = grad_mu_hat_s[(t-1) * (T-t0) + s], it
   * accumulates:
   * <ol>
   *   <li>sum_j G_t[i][j] * g[i][k] * (h'_j[k] - E_w[h'[k]]), where E_w is the softmax-weighted
   *       mean returned by {@link #softmax_weighted_hprime};
   *   <li>-(\hat{mu}^t - (1-lambda) \hat{mu}^{t-1} - lambda A^{t-1}h'^{t-1})[i][k]
   *       * (g[i][k] - (1-lambda) g_pre[i][k]) / sigma^2.
   * </ol>
   *
   * @param iteration iteration number; only used to name the output file
   * @throws IllegalStateException if a softmax exponent is NaN or infinite
   */
  public static void compute_gradient1(int iteration) {
    final int numSteps = T - t0;
    double[][][] tmp_grad_h_hat_s = new double[numSteps][n][K];

    /* t = 0 has no term in the ELBO, so accumulation starts at t = 1 */
    for (int t = 1; t < numSteps; t++) {
      double delta_t = delta_s.get(t);
      double[][] G_t = GS.get(t);
      double[][] h_prime_t = h_prime_s.get(t);
      double[][] mu_hat_t = mu_hat_s.get(t);
      double[][] mu_hat_pre_t = mu_hat_s.get(t - 1);

      Matrix a = new Matrix(AS.get(t - 1));
      Matrix hprime_pre_t = new Matrix(h_prime_s.get(t - 1));
      Matrix ave_neighbors = a.times(hprime_pre_t); // A^{t-1} h'^{t-1}

      /* 0.5 * |h'_l|^2 * delta_t^2: depends only on t, so it is computed once, outside the s loop */
      double[] hp2delta2 = new double[n];
      for (int l = 0; l < n; l++) {
        for (int k = 0; k < K; k++) {
          hp2delta2[l] += 0.5 * h_prime_t[l][k] * h_prime_t[l][k] * delta_t * delta_t;
        }
      }

      /* softmax-weighted mean of h' for each node i: also independent of s */
      double[][] weighted_exp = new double[n][];
      for (int i = 0; i < n; i++) {
        weighted_exp[i] = softmax_weighted_hprime(h_prime_t, mu_hat_t[i], hp2delta2, t, i);
      }

      for (int s = 0; s < numSteps; s++) {
        double[][] grad_hat_t = grad_mu_hat_s.get(t * numSteps + s);
        double[][] grad_hat_pre_t = grad_mu_hat_s.get((t - 1) * numSteps + s);
        double[][] acc = tmp_grad_h_hat_s[s];

        for (int i = 0; i < n; i++) {
          /* first term: link likelihood */
          for (int j = 0; j < n; j++) {
            for (int k = 0; k < K; k++) {
              acc[i][k] +=
                  G_t[i][j] * grad_hat_t[i][k] * (h_prime_t[j][k] - weighted_exp[i][k]);
            }
          }

          /* second term: temporal smoothness prior */
          for (int k = 0; k < K; k++) {
            double residual =
                mu_hat_t[i][k]
                    - (1 - lambda) * mu_hat_pre_t[i][k]
                    - lambda * ave_neighbors.get(i, k);
            acc[i][k] -=
                residual
                    * (grad_hat_t[i][k] - (1 - lambda) * grad_hat_pre_t[i][k])
                    / (sigma * sigma);
          }
        }
      }
    }

    /* update global gradient (tmp_grad_h_hat_s is local, so its slices can be stored directly) */
    for (int t = 0; t < numSteps; t++) {
      grad_h_hat_s.set(t, tmp_grad_h_hat_s[t]);
    }
    FileParser.output_2d(grad_h_hat_s, "./grad/grad_" + iteration + ".txt");
  }

  /**
   * Returns the softmax-weighted mean of the rows of {@code h_prime_t}. The weights are
   * proportional to exp(h'_l . mu_hat_i + hp2delta2[l]), which makes the result the derivative,
   * with respect to mu_hat_i, of the log_sum_exp term used in {@code compute_objective1}.
   * Exponents are shifted by their maximum before exponentiating, so large exponents cannot
   * overflow to Infinity.
   *
   * @param h_prime_t h' at time t, one row per node
   * @param mu_hat_i \hat{mu} of node i at time t
   * @param hp2delta2 per-node offset 0.5 * |h'_l|^2 * delta_t^2
   * @param t time step, only used in the error message
   * @param i node index, only used in the error message
   * @return a new array of length K
   * @throws IllegalStateException if any exponent is NaN or infinite
   */
  private static double[] softmax_weighted_hprime(
      double[][] h_prime_t, double[] mu_hat_i, double[] hp2delta2, int t, int i) {
    double[] powers = new double[n];
    double max = Double.NEGATIVE_INFINITY;
    for (int l = 0; l < n; l++) {
      double p = Operations.inner_product(h_prime_t[l], mu_hat_i, K) + hp2delta2[l];
      if (Double.isNaN(p) || Double.isInfinite(p)) {
        throw new IllegalStateException(
            "ERROR2: non-finite softmax exponent at t=" + t + ", i=" + i + ", l=" + l + ": " + p);
      }
      powers[l] = p;
      max = Math.max(max, p);
    }

    double[] weighted = new double[K];
    double den = 0;
    for (int l = 0; l < n; l++) {
      double e = Math.exp(powers[l] - max); // in (0, 1] after the shift
      for (int k = 0; k < K; k++) {
        weighted[k] += h_prime_t[l][k] * e;
      }
      den += e; // once per l: the normalizer is sum_l e_l, not K * sum_l e_l
    }
    for (int k = 0; k < K; k++) {
      weighted[k] /= den;
    }
    return weighted;
  }
Example #6
0
File: Main.java Project: YpGu/gcoev
  /**
   * Evaluates the lower bound while h' is held fixed.
   *
   * <p>Time step 0 contributes nothing. For every t >= 1 and node i, two quantities are added:
   * <ul>
   *   <li>sum over j with G_t[i][j] != 0 of G_t[i][j] * (h'_j . \hat{mu}_i - LSE_i), where LSE_i
   *       is log_sum_exp over all nodes l of (h'_l . \hat{mu}_i + 0.5 |h'_l|^2 delta_t^2);
   *   <li>-0.5 * |\hat{mu}_i^t - (1-lambda) \hat{mu}_i^{t-1} - lambda (A^{t-1} h'^{t-1})_i|^2
   *       / sigma^2.
   * </ul>
   *
   * @return the accumulated bound
   */
  public static double compute_objective1() {
    final int steps = T - t0;
    double total = 0;
    // The t = 0 slice has no term, so accumulation starts at t = 1.
    for (int t = 1; t < steps; t++) {
      final double[][] G_t = GS.get(t);
      final double[][] hp = h_prime_s.get(t);
      final double[][] muHat = mu_hat_s.get(t);
      final double[][] muHatPrev = mu_hat_s.get(t - 1);
      final double delta_t = delta_s.get(t);

      // A^{t-1} h'^{t-1}
      final Matrix neighborMean =
          new Matrix(AS.get(t - 1)).times(new Matrix(h_prime_s.get(t - 1)));

      // 0.5 * |h'_l|^2 * delta_t^2 for every node l
      final double[] hp2delta2 = new double[n];
      for (int l = 0; l < n; l++) {
        final double[] row = hp[l];
        for (int k = 0; k < K; k++) {
          hp2delta2[l] += 0.5 * row[k] * row[k] * delta_t * delta_t;
        }
      }

      for (int i = 0; i < n; i++) {
        final double[] mu_i = muHat[i];

        // Link term: G-weighted log-softmax over all nodes l.
        final List<Double> exponents = new ArrayList<Double>(n);
        for (int l = 0; l < n; l++) {
          exponents.add(Operations.inner_product(hp[l], mu_i, K) + hp2delta2[l]);
        }
        final double lse = log_sum_exp(exponents);
        for (int j = 0; j < n; j++) {
          final double weight = G_t[i][j];
          if (weight == 0) {
            continue;
          }
          total += weight * (Operations.inner_product(hp[j], mu_i, K) - lse);
        }

        // Temporal prior term.
        for (int k = 0; k < K; k++) {
          final double diff =
              mu_i[k] - (1 - lambda) * muHatPrev[i][k] - lambda * neighborMean.get(i, k);
          total -= 0.5 * diff * diff / (sigma * sigma);
        }
      }
    }
    return total;
  }