Beispiel #1
0
  /**
   * Runs an end-to-end HMM demo: generate data, train, save, reload, decode and evaluate.
   *
   * <p>Observation/state sequence pairs are sampled from a fixed 3-state, 2-symbol model via
   * {@code HMM.generateDataSequences}. A new model is trained on the observation sequences plus
   * their true state labels. The true and trained parameters are printed, and the trained model
   * is saved to {@code HMMModel.dat}. That file is then reloaded into a separate instance. The
   * reloaded model predicts the single best state path for one randomly chosen sequence and
   * reports {@code P(O|Theta)} for it.
   *
   * @param args ignored
   */
  public static void main(String[] args) {

    // Shared by saveModel and loadModel so the two can never drift apart.
    final String modelFilePath = "HMMModel.dat";

    final int numStates = 3;
    final int numObservations = 2;
    final double epsilon = 1e-8;
    final int maxIter = 10;

    // Ground-truth model used to generate the training data.
    double[] pi = new double[] {0.33, 0.33, 0.34};

    double[][] A =
        new double[][] {
          {0.5, 0.3, 0.2},
          {0.3, 0.5, 0.2},
          {0.2, 0.4, 0.4}
        };

    double[][] B =
        new double[][] {
          {0.7, 0.3},
          {0.5, 0.5},
          {0.4, 0.6}
        };

    // Generate the data sequences for training: data[0] holds observations, data[1] true states.
    int numSequences = 10000;
    int tMin = 5;
    int tMax = 10;
    int[][][] data = HMM.generateDataSequences(numSequences, tMin, tMax, pi, A, B);
    int[][] Os = data[0];
    int[][] Qs = data[1];

    // Set to false to skip training and reuse a model previously saved to modelFilePath.
    final boolean trainModel = true;
    if (trainModel) {
      // Lower-case name: a local called "HMM" would obscure the HMM type inside this scope.
      HMM hmm = new HMM(numStates, numObservations, epsilon, maxIter);
      hmm.feedData(Os);
      hmm.feedLabels(Qs);
      hmm.train();

      printHmmParameters("True Model Parameters: \n", pi, A, B);
      printHmmParameters("Trained Model Parameters: \n", hmm.pi, hmm.A, hmm.B);

      hmm.saveModel(modelFilePath);
    }

    // Predict the single best state path for one sequence, using a model reloaded from disk.
    // Bound by Os.length rather than numSequences so this stays in range if Os is ever replaced
    // by a smaller, hand-written data set.
    int id = new Random().nextInt(Os.length);
    int[] O = Os[id];

    HMM loadedModel = new HMM();
    loadedModel.loadModel(modelFilePath);
    int[] Q = loadedModel.predict(O);

    fprintf("Observation sequence: \n");
    loadedModel.showObservationSequence(O);
    fprintf("True state sequence: \n");
    loadedModel.showStateSequence(Qs[id]);
    fprintf("Predicted state sequence: \n");
    loadedModel.showStateSequence(Q);
    double p = loadedModel.evaluate(O);
    System.out.format("P(O|Theta) = %f\n", p);
  }

  /**
   * Prints a heading, then an HMM's initial state distribution, its state transition probability
   * matrix and its observation probability matrix, each with its own label.
   *
   * @param heading text printed first (passed verbatim to {@code fprintf})
   * @param initialDistribution initial state distribution
   * @param transition state transition probability matrix
   * @param emission observation probability matrix
   */
  private static void printHmmParameters(
      String heading, double[] initialDistribution, double[][] transition, double[][] emission) {
    fprintf(heading);
    fprintf("Initial State Distribution: \n");
    display(initialDistribution);
    fprintf("State Transition Probability Matrix: \n");
    display(transition);
    fprintf("Observation Probability Matrix: \n");
    display(emission);
  }