/**
 * Computes the normalised weight of being in {@code stateI} at {@code time} and moving on to
 * {@code stateJ}, using precomputed forward and backward tables.
 *
 * <p>For every time step except the last one of {@code observation}, the weight is the product
 * of the forward value of {@code stateI}, the transition {@code stateI -> stateJ}, the emission
 * of the next observed symbol by {@code stateJ} and the backward value of {@code stateJ} at the
 * next time step. At the last time step there is no next symbol, so only the forward value and
 * the transition are multiplied.
 *
 * @param time the time step, used as second index into {@code forward} / {@code backward}
 * @param stateI the state occupied at {@code time}
 * @param stateJ the state moved to
 * @param observation the observed sequence; its entries are cast to {@code int} and used as the
 *     column index into the emission matrix
 * @param forward forward table, indexed {@code [state][time]}
 * @param backward backward table, indexed {@code [state][time]}
 * @return the weight divided by the sum over all states of {@code forward * backward} at
 *     {@code time}, or {@code 0} if that sum is exactly zero
 */
public double getProbability(
      int time,
      int stateI,
      int stateJ,
      DenseDoubleVector observation,
      double[][] forward,
      double[][] backward) {

    final boolean atLastStep = time == observation.getLength() - 1;

    // Factors are applied one by one, left to right, so rounding matches a single chained product.
    double weight = forward[stateI][time] * transitionProbabilities.get(stateI, stateJ);
    if (!atLastStep) {
      weight = weight * emissionProbabilities.get(stateJ, (int) observation.get(time + 1));
      weight = weight * backward[stateJ][time + 1];
    }

    // Normaliser: total forward * backward mass over all states at this time step.
    double total = 0;
    for (int state = 0; state < numStates; state++) {
      total += (forward[state][time] * backward[state][time]);
    }

    if (total == 0.0d) {
      return 0.0d;
    }
    return weight / total;
  }
  /**
   * Builds the forward table for the given observation sequence.
   *
   * <p>Column 0 is the initial probability of each state times the emission of the first symbol.
   * Every later column is obtained from the previous one by summing over all predecessor states
   * (previous value times transition) and multiplying by the emission of the symbol observed at
   * that time step.
   *
   * @param o the observed sequence; entries are cast to {@code int} and used as emission column
   *     index
   * @return a table indexed {@code [state][time]} with {@code numStates} rows and one column per
   *     observation
   */
  public double[][] calculateForwardProbabilities(DenseDoubleVector o) {
    final int steps = o.getLength();
    final double[][] alpha = new double[numStates][steps];

    // base case: first observation
    for (int state = 0; state < numStates; state++) {
      alpha[state][0] =
          initialProbability.get(state) * emissionProbabilities.get(state, ((int) o.get(0)));
    }

    // recursion: fill column t from column t - 1
    for (int t = 1; t < steps; t++) {
      for (int target = 0; target < numStates; target++) {
        double incoming = 0;
        for (int source = 0; source < numStates; source++) {
          incoming += (alpha[source][t - 1] * transitionProbabilities.get(source, target));
        }
        alpha[target][t] = incoming * emissionProbabilities.get(target, (int) o.get(t));
      }
    }

    return alpha;
  }
  /**
   * Builds the backward table for the given observation sequence.
   *
   * <p>The last column is set to 1 for every state. Every earlier column is obtained from the
   * following one by summing, over all successor states, the successor's backward value times the
   * transition into it times its emission of the next observed symbol.
   *
   * @param o the observed sequence; entries are cast to {@code int} and used as emission column
   *     index
   * @return a table indexed {@code [state][time]} with {@code numStates} rows and one column per
   *     observation
   */
  public double[][] calculateBackwardProbabilities(DenseDoubleVector o) {
    final int lastStep = o.getLength() - 1;
    final double[][] beta = new double[numStates][o.getLength()];

    // base case: the final time step
    for (int state = 0; state < numStates; state++) {
      beta[state][lastStep] = 1;
    }

    // recursion: walk backwards, filling column t from column t + 1
    for (int t = lastStep - 1; t >= 0; t--) {
      final int nextSymbol = (int) o.get(t + 1);
      for (int source = 0; source < numStates; source++) {
        double outgoing = 0;
        for (int target = 0; target < numStates; target++) {
          outgoing +=
              (beta[target][t + 1]
                  * transitionProbabilities.get(source, target)
                  * emissionProbabilities.get(target, nextSymbol));
        }
        beta[source][t] = outgoing;
      }
    }

    return beta;
  }
 /**
  * Creates a model from already existing parameters.
  *
  * <p>The arguments are stored by reference; no copies are made.
  *
  * @param numStates number of hidden states, must be greater than 1
  * @param numOutputStates number of distinct output symbols, must be greater than 0
  * @param initialProbabilities initial state distribution, must have length {@code numStates}
  * @param transitionProbabilities state transition matrix (dimensions are not checked here)
  * @param emissionProbabilities emission matrix (dimensions are not checked here)
  * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument}, when one
  *     of the checked conditions does not hold
  */
 public HMM(
     int numStates,
     int numOutputStates,
     DenseDoubleVector initialProbabilities,
     DenseDoubleMatrix transitionProbabilities,
     DenseDoubleMatrix emissionProbabilities) {
   // argument validation
   Preconditions.checkArgument(numStates > 1);
   Preconditions.checkArgument(initialProbabilities.getLength() == numStates);
   Preconditions.checkArgument(numOutputStates > 0);

   // model dimensions
   this.numStates = numStates;
   this.numOutputStates = numOutputStates;

   // model parameters
   this.initialProbability = initialProbabilities;
   this.transitionProbabilities = transitionProbabilities;
   this.emissionProbabilities = emissionProbabilities;
 }
  /**
   * Runs a forward pass through all layers, filling {@code zx} and {@code ax} in place from index
   * 1 upwards. {@code ax[0]} must already be set by the caller.
   *
   * <p>For each layer {@code i}, {@code zx[i]} is the result of {@code multiply} on
   * {@code ax[i - 1]} and {@code thetas[i - 1]}, and {@code ax[i]} is the layer's activation
   * applied to {@code zx[i]}. Hidden layers additionally get a vector of ones passed alongside
   * the activations (presumably the bias column - confirm against the {@code DenseDoubleMatrix}
   * constructor) and, if {@code conf.hiddenDropoutProbability} is positive, dropout is applied.
   *
   * @param thetas the weights, one entry per layer transition
   * @param ax the activations per layer, written in place
   * @param zx the pre-activation values per layer, written in place
   * @param conf the network configuration supplying layer sizes, activations, the dropout
   *     probability and the random source
   */
  public static void forwardPropagate(
      DoubleMatrix[] thetas, DoubleMatrix[] ax, DoubleMatrix[] zx, NetworkConfiguration conf) {
    final int outputLayer = conf.layerSizes.length - 1;

    for (int layer = 1; layer <= outputLayer; layer++) {
      zx[layer] = multiply(ax[layer - 1], thetas[layer - 1], false, true, conf);

      if (layer == outputLayer) {
        // the output layer carries no bias
        ax[layer] = conf.activations[layer].apply(zx[layer]);
        continue;
      }

      // hidden layer: activations together with a vector of ones
      ax[layer] =
          new DenseDoubleMatrix(
              DenseDoubleVector.ones(zx[layer].getRowCount()),
              conf.activations[layer].apply(zx[layer]));
      if (conf.hiddenDropoutProbability > 0d) {
        dropout(conf.rnd, ax[layer], conf.hiddenDropoutProbability);
      }
    }
  }
  /**
   * Writes the model to STDOUT: the state counts, then the initial, transition and emission
   * probabilities, each value formatted with exactly five fraction digits. Transition and
   * emission values are printed one matrix row per line.
   */
  private void print() {
    final DecimalFormat fiveDigits = new DecimalFormat();
    fiveDigits.setMinimumFractionDigits(5);
    fiveDigits.setMaximumFractionDigits(5);

    System.out.println("States: " + numStates + " | OutputStates: " + numOutputStates);
    System.out.println();

    // initial distribution, one state per line
    for (int state = 0; state < numStates; state++) {
      final String value = fiveDigits.format(initialProbability.get(state));
      System.out.println("init(" + state + ") = " + value);
    }
    System.out.println();

    // transition matrix, one source state per line
    for (int from = 0; from < numStates; from++) {
      final StringBuilder row = new StringBuilder();
      for (int to = 0; to < numStates; to++) {
        row.append("transition(")
            .append(from)
            .append(",")
            .append(to)
            .append(") = ")
            .append(fiveDigits.format(transitionProbabilities.get(from, to)))
            .append("  ");
      }
      System.out.println(row.toString());
    }
    System.out.println();

    // emission matrix, one state per line
    for (int state = 0; state < numStates; state++) {
      final StringBuilder row = new StringBuilder();
      for (int symbol = 0; symbol < numOutputStates; symbol++) {
        row.append("emission(")
            .append(state)
            .append(",")
            .append(symbol)
            .append(") = ")
            .append(fiveDigits.format(emissionProbabilities.get(state, symbol)))
            .append("  ");
      }
      System.out.println(row.toString());
    }
    System.out.println();
  }
  /**
   * Minimizes the given CostFunction with a nonlinear conjugate gradient method. <br>
   * It uses the Polak-Ribiere (PR) formula to calculate the conjugate direction and a line search
   * based on quadratic/cubic interpolation and cubic extrapolation. See <br>
   * <a href="http://en.wikipedia.org/wiki/Nonlinear_conjugate_gradient_method">Nonlinear
   * conjugate gradient method</a> <br>
   * for more information.
   *
   * <p>The sign of {@code length} selects what the run counter counts: for a positive value it is
   * advanced once per outer iteration (one line search each); for a negative value it is advanced
   * once per call to {@code f.evaluateCost}. In both cases the outer loop runs while the counter
   * is below {@code |length|}. The search also stops when two line searches in a row fail.
   *
   * <p>{@code MAX}, {@code RHO}, {@code SIG}, {@code INT}, {@code EXT} and {@code RATIO} are
   * constants declared outside this method.
   *
   * @param f the cost function to minimize; {@code evaluateCost} must return the cost and the
   *     gradient at the given point
   * @param pInput the input vector, also called starting point
   * @param length the iteration budget (positive: line searches, negative: cost evaluations)
   * @param verbose output the progress to STDOUT after every successful line search
   * @return a vector containing the optimized input (the last accepted point)
   */
  public static DoubleVector minimizeFunction(
      CostFunction f, DoubleVector pInput, int length, boolean verbose) {

    DoubleVector input = pInput;
    int M = 0; // remaining cost evaluations allowed in the current line search
    int i = 0; // zero the run length counter
    int red = 1; // numerator of the initial step size
    int ls_failed = 0; // no previous line search has failed
    // costs accepted so far; rebuilt on every successful line search but never returned
    DenseDoubleVector fX = new DenseDoubleVector(0);
    // get function value and gradient
    final Tuple<Double, DoubleVector> evaluateCost = f.evaluateCost(input);
    double f1 = evaluateCost.getFirst();
    DoubleVector df1 = evaluateCost.getSecond();
    i = i + (length < 0 ? 1 : 0); // the initial evaluation counts when counting evaluations
    DoubleVector s = df1.multiply(-1.0d); // search direction is
    // steepest

    double d1 = s.multiply(-1.0d).dot(s); // this is the slope, -(s . s)
    double z1 = red / (1.0 - d1); // initial step is red/((s . s)+1)

    while (i < Math.abs(length)) { // while not finished
      i = i + (length > 0 ? 1 : 0); // count iterations when length is positive
      // make a copy of current values
      DoubleVector X0 = input.deepCopy();
      double f0 = f1;
      DoubleVector df0 = df1.deepCopy();
      // begin line search
      input = input.add(s.multiply(z1));
      final Tuple<Double, DoubleVector> evaluateCost2 = f.evaluateCost(input);
      double f2 = evaluateCost2.getFirst();
      DoubleVector df2 = evaluateCost2.getSecond();

      i = i + (length < 0 ? 1 : 0); // count evaluations when length is negative
      double d2 = df2.dot(s);
      // initialize point 3 equal to point 1
      double f3 = f1;
      double d3 = d1;
      double z3 = -z1;
      // evaluation budget for this line search: MAX, additionally capped by what is left of
      // the overall budget when evaluations are being counted
      if (length > 0) {
        M = MAX;
      } else {
        M = Math.min(MAX, -length - i);
      }
      // initialize quantities
      int success = 0;
      double limit = -1; // negative value means: no upper limit on the step found yet

      while (true) {
        // interpolate while the current point fails one of the two acceptance tests and there
        // is budget left (the single | evaluates both tests, the result is the same as ||)
        while (((f2 > f1 + z1 * RHO * d1) | (d2 > -SIG * d1)) && (M > 0)) {
          limit = z1; // tighten the bracket
          double z2 = 0.0d;
          double A = 0.0d;
          double B = 0.0d;
          if (f2 > f1) {
            // quadratic fit
            z2 = z3 - (0.5 * d3 * z3 * z3) / (d3 * z3 + f2 - f3);
          } else {
            A = 6 * (f2 - f3) / z3 + 3 * (d2 + d3); // cubic fit
            B = 3 * (f3 - f2) - z3 * (d3 + 2 * d2);
            // numerical error possible - ok!
            z2 = (Math.sqrt(B * B - A * d2 * z3 * z3) - B) / A;
          }
          if (Double.isNaN(z2) || Double.isInfinite(z2)) {
            z2 = z3 / 2.0d; // if we had a numerical problem then
            // bisect
          }
          // don't accept too close to limits
          z2 = Math.max(Math.min(z2, INT * z3), (1 - INT) * z3);
          z1 = z1 + z2; // update the step
          input = input.add(s.multiply(z2));
          final Tuple<Double, DoubleVector> evaluateCost3 = f.evaluateCost(input);
          f2 = evaluateCost3.getFirst();
          df2 = evaluateCost3.getSecond();
          M = M - 1;
          i = i + (length < 0 ? 1 : 0); // count evaluations when length is negative
          d2 = df2.dot(s);
          z3 = z3 - z2; // z3 is now relative to the location of z2
        }
        if (f2 > f1 + z1 * RHO * d1 || d2 > -SIG * d1) {
          break; // this is a failure
        } else if (d2 > SIG * d1) {
          success = 1;
          break; // success
        } else if (M == 0) {
          break; // failure
        }
        double A = 6 * (f2 - f3) / z3 + 3 * (d2 + d3); // make cubic
        // extrapolation
        double B = 3 * (f3 - f2) - z3 * (d3 + 2 * d2);
        double z2 = -d2 * z3 * z3 / (B + Math.sqrt(B * B - A * d2 * z3 * z3));
        // num prob or wrong sign?
        if (Double.isNaN(z2) || Double.isInfinite(z2) || z2 < 0)
          if (limit < -0.5) { // if we have no upper limit
            z2 = z1 * (EXT - 1); // then extrapolate the maximum
            // amount
          } else {
            z2 = (limit - z1) / 2; // otherwise bisect
          }
        else if ((limit > -0.5) && (z2 + z1 > limit)) {
          // extrapolation beyond max?
          z2 = (limit - z1) / 2; // bisect
        } else if ((limit < -0.5) && (z2 + z1 > z1 * EXT)) {
          // extrapolation beyond limit
          z2 = z1 * (EXT - 1.0); // set to extrapolation limit
        } else if (z2 < -z3 * INT) {
          // step too small relative to the previous one
          z2 = -z3 * INT;
        } else if ((limit > -0.5) && (z2 < (limit - z1) * (1.0 - INT))) {
          // too close to the limit
          z2 = (limit - z1) * (1.0 - INT);
        }
        // set point 3 equal to point 2
        f3 = f2;
        d3 = d2;
        z3 = -z2;
        z1 = z1 + z2;
        // update current estimates
        input = input.add(s.multiply(z2));
        final Tuple<Double, DoubleVector> evaluateCost3 = f.evaluateCost(input);
        f2 = evaluateCost3.getFirst();
        df2 = evaluateCost3.getSecond();
        M = M - 1;
        i = i + (length < 0 ? 1 : 0); // count evaluations when length is negative
        d2 = df2.dot(s);
      } // end of line search

      DoubleVector tmp = null;

      if (success == 1) { // if line search succeeded
        f1 = f2;
        fX = new DenseDoubleVector(fX.toArray(), f1);
        if (verbose) System.out.print("Iteration " + i + " | Cost: " + f1 + "\r");
        // Polak-Ribiere direction: s =
        // (df2'*df2-df1'*df2)/(df1'*df1)*s - df2;
        final double numerator = (df2.dot(df2) - df1.dot(df2)) / df1.dot(df1);
        s = s.multiply(numerator).subtract(df2);
        tmp = df1;
        df1 = df2;
        df2 = tmp; // swap derivatives
        d2 = df1.dot(s);
        if (d2 > 0) { // new slope must be negative
          s = df1.multiply(-1.0d); // otherwise use steepest direction
          d2 = s.multiply(-1.0d).dot(s);
        }
        // realmin in octave = 2.2251e-308
        // slope ratio but max RATIO
        z1 = z1 * Math.min(RATIO, d1 / (d2 - 2.2251e-308));
        d1 = d2;
        ls_failed = 0; // this line search did not fail
      } else {
        input = X0;
        f1 = f0;
        df1 = df0; // restore point from before failed line search
        // line search failed twice in a row?
        if (ls_failed == 1 || i > Math.abs(length)) {
          break; // or we ran out of time, so we give up
        }
        tmp = df1;
        df1 = df2;
        df2 = tmp; // swap derivatives
        s = df1.multiply(-1.0d); // try steepest
        d1 = s.multiply(-1.0d).dot(s);
        z1 = 1.0d / (1.0d - d1);
        ls_failed = 1; // this line search failed
      }
    }

    return input;
  }
  /**
   * Re-estimates the initial, transition and emission probabilities of this model in place with
   * Baum-Welch iterations over all given observation sequences.
   *
   * <p>In every step the forward and backward tables of each sequence are computed with the
   * current parameters. The per-sequence statistics are accumulated over <em>all</em> sequences
   * before the parameters are replaced, so every sequence contributes to the update (previously
   * only the last sequence in the list had any effect). For a single sequence the result is
   * identical to re-estimating from that sequence alone.
   *
   * <p>The state weights come from {@code gamma(state, time, fwd, bwd)} and the transition
   * weights from {@code getProbability}; {@code gamma} is defined elsewhere in this class and is
   * presumably the per-state occupancy probability - confirm there.
   *
   * @param observations the training sequences; entries are compared against the output symbol
   *     indices {@code 0 .. numOutputStates - 1}
   * @param steps the number of re-estimation passes; nothing happens for values below 1
   */
  public void trainBaumWelch(List<DenseDoubleVector> observations, int steps) {
    // Without data there is nothing to estimate: keep the current parameters instead of
    // overwriting the whole model with zeros.
    if (steps <= 0 || observations.isEmpty()) {
      return;
    }
    final int numSequences = observations.size();

    for (int s = 0; s < steps; s++) {
      // statistics accumulated over all sequences for this step
      double[] initialSum = new double[numStates];
      double[] occupancy = new double[numStates];
      double[][] transitionNum = new double[numStates][numStates];
      double[][] emissionNum = new double[numStates][numOutputStates];

      for (DenseDoubleVector o : observations) {
        // calculate forward and backward probabilities
        double[][] fwd = calculateForwardProbabilities(o);
        double[][] bwd = calculateBackwardProbabilities(o);

        // contribution to the initial state probabilities
        for (int i = 0; i < numStates; i++) {
          initialSum[i] += gamma(i, 0, fwd, bwd);
        }

        for (int t = 0; t < o.getLength(); t++) {
          final double symbol = o.get(t);
          for (int i = 0; i < numStates; i++) {
            final double g = gamma(i, t, fwd, bwd);
            // shared denominator of the transition and emission estimates of state i
            occupancy[i] += g;

            // contribution to the transition probabilities
            for (int j = 0; j < numStates; j++) {
              transitionNum[i][j] += getProbability(t, i, j, o, fwd, bwd);
            }

            // contribution to the emission probabilities: only the observed symbol counts
            for (int k = 0; k < numOutputStates; k++) {
              if (k == symbol) {
                emissionNum[i][k] += g;
              }
            }
          }
        }
      }

      // TODO calculate the kullback leibler divergence of the output
      for (int i = 0; i < initialProbability.getLength(); i++) {
        // average over the sequences so the initial distribution stays normalised
        initialProbability.set(i, initialSum[i] / numSequences);
      }
      for (int col = 0; col < transitionProbabilities.getColumnCount(); col++) {
        for (int row = 0; row < transitionProbabilities.getRowCount(); row++) {
          transitionProbabilities.set(
              row, col, occupancy[row] != 0.0d ? transitionNum[row][col] / occupancy[row] : 0.0d);
        }
      }
      for (int col = 0; col < emissionProbabilities.getColumnCount(); col++) {
        for (int row = 0; row < emissionProbabilities.getRowCount(); row++) {
          emissionProbabilities.set(
              row, col, occupancy[row] != 0.0d ? emissionNum[row][col] / occupancy[row] : 0.0d);
        }
      }
    }
  }