  /**
   * Builds the initial Bayes IM for the given Bayes PM. Each variable in the given Bayes PM must
   * be equal to a variable in the given data set. The IM is first created with the
   * MlBayesIm.RANDOM option; then, for every measured node whose parents are all measured, its
   * conditional probabilities are copied from the Bayes IM already estimated for the observed
   * variables. The Bayes IM so obtained is used as the initial Bayes net in the iterative
   * procedure implemented in the maximize method.
   */
  private void estimateIM(BayesPm bayesPm, DataSet dataSet) {
    // Both arguments are mandatory.
    if (bayesPm == null || dataSet == null) {
      throw new NullPointerException();
    }

    // Every variable of the PM has to be present in the data set; otherwise
    // estimation is impossible.
    BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);

    // Fresh IM created with the RANDOM option. Any row that is not overwritten
    // below keeps whatever value this constructor assigned.
    this.estimatedIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM);

    int nodeCount = estimatedIm.getNumNodes();

    for (int nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
      int rowCount = estimatedIm.getNumRows(nodeIndex);
      int colCount = estimatedIm.getNumColumns(nodeIndex);
      int[] parents = estimatedIm.getParents(nodeIndex);

      // Latent nodes have no counterpart in observedIm; leave them as initialized.
      if (nodes[nodeIndex].getNodeType() == NodeType.LATENT) {
        continue;
      }

      // Locate the same-named node in the IM estimated over the observed variables.
      Node observedNode = observedIm.getNode(nodes[nodeIndex].getName());
      int observedIndex = observedIm.getNodeIndex(observedNode);

      // A table conditioned on a latent parent cannot be taken from observedIm either.
      boolean allParentsMeasured = true;
      for (int k = 0; k < parents.length && allParentsMeasured; k++) {
        allParentsMeasured = nodes[parents[k]].getNodeType() != NodeType.LATENT;
      }

      if (!allParentsMeasured) {
        continue;
      }

      // The node and all of its parents are measured: copy the whole table cell by cell.
      for (int row = 0; row < rowCount; row++) {
        for (int col = 0; col < colCount; col++) {
          estimatedIm.setProbability(
              nodeIndex, row, col, observedIm.getProbability(observedIndex, row, col));
        }
      }
    }
  }
  /**
   * Prepares the state used by the iterative estimation.
   *
   * <p>In order, this method: estimates {@code observedIm} from {@code dataSet} using a symmetric
   * Dirichlet prior (0.5) over {@code bayesPmObs}; builds {@code mixedData}, a data set with one
   * column per entry of {@code nodes} in which measured columns are copied from {@code dataSet}
   * and latent columns are filled with -99; calls {@code estimateIM} to create {@code
   * estimatedIm}; and allocates the {@code estimatedCounts}, {@code estimatedCountsDenom} and
   * {@code condProbs} tables with the dimensions of {@code estimatedIm}.
   */
  private void initialize() {
    // IM over the measured variables only, estimated from the data with a Dirichlet prior.
    DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5);
    observedIm = DirichletEstimator.estimate(prior, dataSet);

    int caseCount = dataSet.getNumRows();

    // One variable per node, in node order: a new discrete variable for each latent,
    // the existing data set variable for each measured node.
    List<Node> mixedVariables = new LinkedList<Node>();
    for (int idx = 0; idx < nodes.length; idx++) {
      Node node = nodes[idx];
      if (node.getNodeType() != NodeType.LATENT) {
        String measuredName = bayesPm.getVariable(node).getName();
        mixedVariables.add(dataSet.getVariable(measuredName));
      } else {
        int categoryCount = bayesPm.getNumCategories(node);
        mixedVariables.add(new DiscreteVariable(node.getName(), categoryCount));
      }
    }

    DataSet dsMixed = new ColtDataSet(caseCount, mixedVariables);

    for (int column = 0; column < nodes.length; column++) {
      if (nodes[column].getNodeType() == NodeType.LATENT) {
        // -99 is the value expectation() tests for to detect a missing cell.
        for (int caseIndex = 0; caseIndex < caseCount; caseIndex++) {
          dsMixed.setInt(caseIndex, column, -99);
        }
        continue;
      }

      // Measured column: copy the values from the matching column of the original data.
      String measuredName = bayesPm.getVariable(nodes[column]).getName();
      int sourceColumn = dataSet.getColumn(dataSet.getVariable(measuredName));
      for (int caseIndex = 0; caseIndex < caseCount; caseIndex++) {
        dsMixed.setInt(caseIndex, column, dataSet.getInt(caseIndex, sourceColumn));
      }
    }

    mixedData = dsMixed;
    allVariables = mixedData.getVariables();

    // Find the Bayes net which is parameterized from the observed IM, or left at its
    // initial values where that is not possible.
    estimateIM(bayesPm, mixedData);

    // Work tables, shaped like the conditional probability tables of estimatedIm.
    int nodeCount = nodes.length;
    estimatedCounts = new double[nodeCount][][];
    estimatedCountsDenom = new double[nodeCount][];
    condProbs = new double[nodeCount][][];

    for (int n = 0; n < nodeCount; n++) {
      int rowCount = estimatedIm.getNumRows(n);
      int colCount = estimatedIm.getNumColumns(n);
      estimatedCounts[n] = new double[rowCount][colCount];
      estimatedCountsDenom[n] = new double[rowCount];
      condProbs[n] = new double[rowCount][colCount];
    }
  }
  /**
   * This method takes an instantiated Bayes net (BayesIm) whose graph includes all the variables
   * (observed and latent) and computes estimated counts using the data in the DataSet mixedData.
   * The counts that are estimated correspond to cells in the conditional probability tables of the
   * Bayes net. The outermost loop (indexed by j) is over the set of variables. If the variable has
   * no parents, each case in the dataset is examined and the count for the observed value of the
   * variable is increased by 1.0; if the value of the variable is missing, the marginal
   * probabilities of its values given the values of the variables that are available for that
   * case are used to increment the corresponding estimated counts. If a variable has parents then
   * there is a loop which steps through all possible sets of values of its parents. This loop is
   * indexed by the variable "row". Each case in the dataset is examined. If the variable and all
   * its parents have values in the case, the corresponding estimated counts are incremented by
   * 1.0. If the variable or any of its parents have missing values, the joint marginal is computed
   * for the variable and the set of values of its parents corresponding to "row" and the
   * corresponding estimated counts are incremented by the appropriate probability. The estimated
   * counts are stored in the double[][][] array estimatedCounts. The count (possibly fractional)
   * of the number of times each combination of parent values occurs is stored in the double[][]
   * array estimatedCountsDenom. These two arrays are used to compute the estimated conditional
   * probabilities of the output Bayes net.
   */
  private BayesIm expectation(BayesIm inputBayesIm) {
    // Cells of mixedData holding this value are treated as missing.
    final int missingValue = -99;

    int numCases = mixedData.getNumRows();
    int numVariables = allVariables.size();

    // Updater used to obtain (joint) marginals under inputBayesIm given per-case evidence.
    RowSummingExactUpdater rseu = new RowSummingExactUpdater(inputBayesIm);

    // ---- Accumulate the (possibly fractional) counts for every variable. ----
    for (int j = 0; j < numVariables; j++) {
      DiscreteVariable var = (DiscreteVariable) allVariables.get(j);
      Node varNode = graph.getNode(var.getName());
      int varIndex = inputBayesIm.getNodeIndex(varNode);
      int[] parentVarIndices = inputBayesIm.getParents(varIndex);
      int numCategories = var.getNumCategories();

      if (parentVarIndices.length == 0) {
        // Variable without parents: a single row of counts.
        for (int col = 0; col < numCategories; col++) {
          estimatedCounts[j][0][col] = 0.0;
        }

        for (int i = 0; i < numCases; i++) {
          int observed = mixedData.getInt(i, j);

          if (observed != missingValue) {
            // Value present in this case: a full count for the observed category.
            estimatedCounts[j][0][observed] += 1.0;
            continue;
          }

          // Value missing: condition on the values of the other variables in this case.
          Evidence evidenceThisCase = Evidence.tautology(inputBayesIm);
          boolean existsEvidence = false;

          for (int k = 0; k < numVariables; k++) {
            if (k == j) {
              continue;
            }

            int otherValue = mixedData.getInt(i, k);
            if (otherValue == missingValue) {
              continue;
            }

            existsEvidence = true;
            Node otherNode = graph.getNode(allVariables.get(k).getName());
            int otherIndex = inputBayesIm.getNodeIndex(otherNode);
            evidenceThisCase.getProposition().setCategory(otherIndex, otherValue);
          }

          if (!existsEvidence) {
            continue; // No other variable contained useful data.
          }

          rseu.setEvidence(evidenceThisCase);

          // Spread the case over the categories according to the marginal given the evidence.
          for (int m = 0; m < numCategories; m++) {
            estimatedCounts[j][0][m] += rseu.getMarginal(varIndex, m);
          }
        }
      } else {
        // Variable with parents: one row of counts per combination of parent values.
        int numRows = inputBayesIm.getNumRows(varIndex);

        for (int row = 0; row < numRows; row++) {
          int[] parValues = inputBayesIm.getParentValues(varIndex, row);

          // Reset through the same index (j) that the accumulation and the output loop use.
          // The original reset used varIndex, which is only equivalent when j == varIndex.
          estimatedCountsDenom[j][row] = 0.0;
          for (int col = 0; col < numCategories; col++) {
            estimatedCounts[j][row][col] = 0.0;
          }

          for (int i = 0; i < numCases; i++) {
            // A case is relevant to this row unless an observed parent value contradicts it.
            // NOTE(review): parent indices from inputBayesIm are used as mixedData column
            // indices, which assumes both use the same ordering - confirm.
            boolean parentMatch = true;
            boolean parentMissing = false;

            for (int p = 0; p < parentVarIndices.length; p++) {
              int parentValue = mixedData.getInt(i, parentVarIndices[p]);
              if (parentValue == missingValue) {
                parentMissing = true;
              } else if (parentValue != parValues[p]) {
                parentMatch = false;
                break;
              }
            }

            if (!parentMatch) {
              continue; // Not a matching case; go to next.
            }

            int observed = mixedData.getInt(i, j);

            if (observed != missingValue && !parentMissing) {
              // Fully observed family: a full count.
              estimatedCounts[j][row][observed] += 1.0;
              estimatedCountsDenom[j][row] += 1.0;
              continue; // Next case.
            }

            // The variable or one of its parents is missing in this case. Build the evidence
            // from every value that is present in the case, including the variable itself when
            // it is observed. This loop had been commented out, which left existsEvidence
            // permanently false and made the marginal-based update below unreachable.
            Evidence evidenceThisCase = Evidence.tautology(inputBayesIm);
            boolean existsEvidence = false;

            for (int k = 0; k < numVariables; k++) {
              int otherValue = mixedData.getInt(i, k);
              if (otherValue == missingValue) {
                continue;
              }

              existsEvidence = true;
              Node otherNode = graph.getNode(allVariables.get(k).getName());
              int otherIndex = inputBayesIm.getNodeIndex(otherNode);
              evidenceThisCase.getProposition().setCategory(otherIndex, otherValue);
            }

            if (!existsEvidence) {
              continue; // Nothing observed in this case at all.
            }

            rseu.setEvidence(evidenceThisCase);

            // Expected number of times this parent configuration occurs in the case.
            estimatedCountsDenom[j][row] += rseu.getJointMarginal(parentVarIndices, parValues);

            // Slot 0 holds the child, slots 1..n hold the parents fixed at this row's values.
            int[] parPlusChildIndices = new int[parentVarIndices.length + 1];
            int[] parPlusChildValues = new int[parentVarIndices.length + 1];

            parPlusChildIndices[0] = varIndex;
            for (int pc = 1; pc < parPlusChildIndices.length; pc++) {
              parPlusChildIndices[pc] = parentVarIndices[pc - 1];
              parPlusChildValues[pc] = parValues[pc - 1];
            }

            // Expected count of (child = m, parents = parValues) for every child category.
            for (int m = 0; m < numCategories; m++) {
              parPlusChildValues[0] = m;
              estimatedCounts[j][row][m] +=
                  rseu.getJointMarginal(parPlusChildIndices, parPlusChildValues);
            }
          }
        }
      }
    }

    // ---- Turn the counts into conditional probabilities of a new Bayes IM. ----
    BayesIm outputBayesIm = new MlBayesIm(bayesPm);

    for (int j = 0; j < nodes.length; j++) {
      DiscreteVariable var = (DiscreteVariable) allVariables.get(j);
      Node varNode = graph.getNode(var.getName());
      int varIndex = inputBayesIm.getNodeIndex(varNode);

      int numRows = inputBayesIm.getNumRows(j);
      int numCols = inputBayesIm.getNumColumns(j);

      if (numRows == 1) {
        // Single-row table: normalize by the total of the row's counts.
        double sum = 0.0;
        for (int m = 0; m < numCols; m++) {
          sum += estimatedCounts[j][0][m];
        }

        for (int m = 0; m < numCols; m++) {
          condProbs[j][0][m] = estimatedCounts[j][0][m] / sum;
          outputBayesIm.setProbability(varIndex, 0, m, condProbs[j][0][m]);
        }
      } else {
        for (int row = 0; row < numRows; row++) {
          for (int m = 0; m < numCols; m++) {
            if (estimatedCountsDenom[j][row] != 0.0) {
              condProbs[j][row][m] = estimatedCounts[j][row][m] / estimatedCountsDenom[j][row];
            } else {
              // Parent configuration never occurred (not even fractionally): undefined.
              condProbs[j][row][m] = Double.NaN;
            }
            outputBayesIm.setProbability(varIndex, row, m, condProbs[j][row][m]);
          }
        }
      }
    }

    return outputBayesIm;
  }
 /**
  * Creates an estimator from an existing Bayes IM and a data set by delegating to another
  * constructor of this class with the IM's Bayes PM.
  *
  * @param inputBayesIm the Bayes IM whose Bayes PM is used; must not be null, because
  *     {@code getBayesPm()} is called on it directly
  * @param dataSet the data set, passed through unchanged to the delegated constructor
  */
 public EmBayesEstimator(BayesIm inputBayesIm, DataSet dataSet) {
   this(inputBayesIm.getBayesPm(), dataSet);
   // Only the Bayes PM is taken from inputBayesIm here; the IM itself is not stored.
 }