예제 #1
0
  /**
   * Fills the conditional probability table (CPT) of every node in {@code jnodes} from the joint
   * counts stored in {@code lattice}, then builds a fresh {@code JunctionTreeAlgorithm} over
   * {@code jbn} and resets {@code evidence} to an empty map.
   *
   * <p>Equivalent to {@code setProbabilities(lattice, 0.5)}: every count is smoothed additively
   * with a pseudo-count of 0.5 before each CPT row is normalized.
   *
   * @param lattice lattice whose node for the variable set {child} ∪ parents holds the counts
   */
  public void setProbabilities(Lattice lattice) {
    setProbabilities(lattice, 0.5);
  }

  /**
   * Fills the CPT of every node in {@code jnodes} from the joint counts in {@code lattice}, using
   * additive smoothing: {@code P(child = k | row) = (count_k + a) / sum_j (count_j + a)}, where
   * {@code a} is {@code pseudoCount}. It then builds a fresh {@code JunctionTreeAlgorithm} over
   * {@code jbn} and resets {@code evidence} to an empty map.
   *
   * @param lattice lattice whose node for the variable set {child} ∪ parents holds the counts
   * @param pseudoCount additive-smoothing term added to every count; must be strictly positive so
   *     that rows with no observations still normalize to a finite (uniform) distribution
   * @throws IllegalArgumentException if {@code pseudoCount} is not strictly positive (or is NaN)
   */
  public void setProbabilities(Lattice lattice, double pseudoCount) {
    if (!(pseudoCount > 0)) {
      throw new IllegalArgumentException("pseudoCount must be > 0, got " + pseudoCount);
    }

    for (Integer nodeID : jnodes.keySet()) {
      BayesNode n = jnodes.get(nodeID);
      List<BayesNode> parents = n.getParents();
      // CPT dimension order used below: parents in getParents() order, then the child last.
      List<BayesNode> parentsAndChild = new ArrayList<BayesNode>(parents);
      parentsAndChild.add(n);
      int dims = parentsAndChild.size();
      int childOutcomes = n.getOutcomeCount();

      // The variables of this CPT, identified by their lattice numbers.
      BitSet numbers = new BitSet();
      numbers.set(nodeID);
      int nbRowsInCPT = 1;
      for (BayesNode parent : parents) {
        numbers.set(nodesNumber.get(parent));
        nbRowsInCPT *= parent.getOutcomeCount();
      }
      LatticeNode latticeNode = lattice.getNode(numbers);

      // This code indexes the lattice matrix with its axes in ascending node-number order (the
      // iteration order of the set bits of `numbers`), whereas the CPT uses parentsAndChild order.
      Map<Integer, Integer> latticeAxisByNumber = new HashMap<Integer, Integer>();
      int nextAxis = 0;
      for (int bit = numbers.nextSetBit(0); bit >= 0; bit = numbers.nextSetBit(bit + 1)) {
        latticeAxisByNumber.put(bit, nextAxis++);
      }

      // Per-dimension radix (outcome count) and target lattice axis. These are invariant across
      // CPT cells, so they are computed once per node instead of once per cell.
      int[] radices = new int[dims];
      int[] latticeAxis = new int[dims];
      for (int d = 0; d < dims; d++) {
        BayesNode node = parentsAndChild.get(d);
        radices[d] = node.getOutcomeCount();
        int number = nodesNumber.get(node);
        latticeAxis[d] = latticeAxisByNumber.get(number);
      }

      int[] counts = new int[nbRowsInCPT * childOutcomes];
      int[] cptIndex = new int[dims];
      int[] latticeIndex = new int[dims];
      for (int c = 0; c < counts.length; c++) {
        decodeCptIndex(c, radices, cptIndex);
        // Permute the CPT coordinates into the lattice's axis order.
        for (int d = 0; d < dims; d++) {
          latticeIndex[latticeAxis[d]] = cptIndex[d];
        }
        counts[c] = latticeNode.getMatrixCell(latticeIndex);
      }

      n.setProbabilities(smoothRowsAdditively(counts, childOutcomes, pseudoCount));
    }

    System.out.println("Compiling network for inference...");
    inferer = new JunctionTreeAlgorithm();
    inferer.setNetwork(jbn);
    evidence = new HashMap<BayesNode, String>();
    System.out.println("Compiled.");
  }

  /**
   * Decodes the flat CPT offset {@code offset} into one coordinate per dimension, writing them to
   * {@code digits}. The last dimension varies fastest; the first coordinate receives whatever
   * quotient remains after peeling off the other dimensions.
   */
  private static void decodeCptIndex(int offset, int[] radices, int[] digits) {
    int rest = offset;
    for (int d = radices.length - 1; d > 0; d--) {
      digits[d] = rest % radices[d];
      rest /= radices[d];
    }
    digits[0] = rest;
  }

  /**
   * Normalizes each consecutive run of {@code rowWidth} counts into a distribution, adding
   * {@code pseudoCount} to every count first. Returns a new array the same length as
   * {@code counts}.
   */
  private static double[] smoothRowsAdditively(int[] counts, int rowWidth, double pseudoCount) {
    double[] probabilities = new double[counts.length];
    for (int rowStart = 0; rowStart < counts.length; rowStart += rowWidth) {
      double rowTotal = 0.0;
      for (int k = 0; k < rowWidth; k++) {
        rowTotal += counts[rowStart + k] + pseudoCount;
      }
      for (int k = 0; k < rowWidth; k++) {
        probabilities[rowStart + k] = (counts[rowStart + k] + pseudoCount) / rowTotal;
      }
    }
    return probabilities;
  }
예제 #2
0
 /**
  * Drops all recorded evidence: replaces {@code evidence} with a new empty map and hands that
  * same empty map to the inferer.
  */
 public void clearEvidences() {
   final HashMap<BayesNode, String> emptyEvidence = new HashMap<BayesNode, String>();
   this.evidence = emptyEvidence;
   this.inferer.setEvidence(emptyEvidence);
 }
예제 #3
0
 /**
  * Returns the beliefs the inferer reports for {@code n}. These reflect whatever evidence was last
  * passed to the inferer via {@code setEvidence}. Presumably there is one probability per outcome
  * of {@code n}; confirm against the inferer's API.
  *
  * @param n node whose beliefs are requested
  * @return the array returned by the inferer's {@code getBeliefs}
  */
 public double[] getBelief(BayesNode n) {
   final double[] beliefs = this.inferer.getBeliefs(n);
   return beliefs;
 }
예제 #4
0
 /**
  * Pushes the current {@code evidence} map (the same instance, not a copy) to the inferer.
  */
 public void recordEvidence() {
   this.inferer.setEvidence(this.evidence);
 }