예제 #1
0
파일: MBFS.java 프로젝트: bd2kccd/r-causal
  /**
   * Runs an {@code Mbfs} search for the variable named by the {@code targetName} parameter.
   *
   * <p>The independence test is obtained from the {@code test} wrapper, the search depth is read
   * from the {@code depth} parameter, and the stored {@code knowledge} is handed to the search
   * before it runs. The resolved target name is remembered in {@code this.targetName}.
   *
   * @param dataSet the data passed to the test wrapper and used to look up the target variable
   * @param parameters supplies the {@code depth} (int) and {@code targetName} (String) settings
   * @return the graph produced by {@code Mbfs.search} for the target variable
   */
  @Override
  public Graph search(DataSet dataSet, Parameters parameters) {
    edu.cmu.tetrad.search.Mbfs mbfs =
        new edu.cmu.tetrad.search.Mbfs(
            test.getTest(dataSet, parameters), parameters.getInt("depth"));
    mbfs.setKnowledge(knowledge);

    this.targetName = parameters.getString("targetName");

    // NOTE(review): the lookup result is passed on unchecked; confirm what getVariable returns
    // for an unknown name.
    return mbfs.search(dataSet.getVariable(this.targetName));
  }
예제 #2
0
  /**
   * Creates a cell count table for the given data set.
   *
   * <p>Records, for every column, the number of categories of its variable, together with the
   * number of rows. Every variable is cast to {@code DiscreteVariable}, so a data set containing
   * any other kind of variable fails with a {@code ClassCastException}.
   *
   * @param dataSet the data set to tabulate; must not be null
   * @throws NullPointerException if {@code dataSet} is null
   */
  public DataSetProbs(DataSet dataSet) {
    if (dataSet == null) {
      throw new NullPointerException();
    }

    this.dataSet = dataSet;
    this.dims = new int[dataSet.getNumColumns()];

    // One dimension per column: the category count of that column's variable.
    for (int column = 0; column < this.dims.length; column++) {
      this.dims[column] = ((DiscreteVariable) dataSet.getVariable(column)).getNumCategories();
    }

    this.numRows = dataSet.getNumRows();
  }
예제 #3
0
  /**
   * Discretizes only the first variable of a simulated data set (equal-count bins, three
   * categories) and checks that this variable becomes discrete while the remaining four
   * variables stay continuous.
   */
  public void testManualDiscretize3() {
    Graph graph = new Dag(GraphUtils.randomGraph(5, 0, 5, 3, 3, 3, false));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100, false);

    List<Node> nodes = data.getVariables();

    Discretizer discretizer = new Discretizer(data);
    // Set once; the original called this setter twice in a row with the same argument.
    discretizer.setVariablesCopied(true);
    discretizer.equalCounts(nodes.get(0), 3);

    DataSet discretized = discretizer.discretize();

    System.out.println(discretized);

    // Only the variable that was given a discretization spec changes type.
    assertTrue(discretized.getVariable(0) instanceof DiscreteVariable);

    for (int column = 1; column <= 4; column++) {
      assertTrue(discretized.getVariable(column) instanceof ContinuousVariable);
    }
  }
  /**
   * Expands a discrete variable with three or more categories into binary indicator variables,
   * one per category other than the first (which serves as the reference category).
   *
   * <p>Continuous variables and discrete variables with fewer than three categories are returned
   * unchanged as a singleton list. Otherwise each indicator is appended to {@code dataSet} as a
   * new two-category column holding 1 in the rows where the original variable takes that
   * category and 0 elsewhere. Note that this mutates {@code dataSet}.
   *
   * <p>Indicator names are {@code <name>MULTINOM.<category>}. If that name is already present in
   * the data set, a numeric suffix ({@code .1}, {@code .2}, ...) is appended until the name is
   * unused.
   *
   * @param dataSet the data set holding {@code node}; receives the new indicator columns
   * @param node the variable to expand
   * @return the variables that represent {@code node}: either {@code node} itself or the newly
   *     created indicators
   * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
   */
  private List<Node> expandVariable(DataSet dataSet, Node node) {
    if (node instanceof ContinuousVariable) {
      return Collections.singletonList(node);
    }

    if (!(node instanceof DiscreteVariable)) {
      throw new IllegalArgumentException(
          "Expected a continuous or discrete variable but got: " + node);
    }

    DiscreteVariable discrete = (DiscreteVariable) node;

    if (discrete.getNumCategories() < 3) {
      return Collections.singletonList(node);
    }

    List<String> varCats = new ArrayList<String>(discrete.getCategories());

    // first category is reference
    varCats.remove(0);
    List<Node> variables = new ArrayList<Node>();

    for (String cat : varCats) {
      // Look for an unused name. The candidate has to change between attempts; retrying the
      // same name would never terminate once that name is taken.
      String baseName = node.getName() + "MULTINOM" + "." + cat;
      String newVarName = baseName;
      int suffix = 0;

      while (dataSet.getVariable(newVarName) != null) {
        newVarName = baseName + "." + (++suffix);
      }

      Node newVar = new DiscreteVariable(newVarName, 2);
      variables.add(newVar);

      dataSet.addVariable(newVar);
      int newVarIndex = dataSet.getColumn(newVar);

      // These do not depend on the row, so look them up once per indicator.
      int sourceIndex = dataSet.getColumn(node);
      int catIndex = discrete.getIndex(cat);
      int numCases = dataSet.getNumRows();

      for (int row = 0; row < numCases; row++) {
        Object dataCell = dataSet.getObject(row, sourceIndex);
        int dataCellIndex = discrete.getIndex(dataCell.toString());

        dataSet.setInt(row, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
      }
    }

    return variables;
  }
예제 #5
0
파일: Tsls.java 프로젝트: jdramsey/tetrad
  /**
   * Estimates the coefficients of {@code semIm} by two-stage least squares (2SLS) and writes
   * them into the model.
   *
   * <p>The method works in two phases:
   *
   * <ol>
   *   <li><b>Latent/latent edges.</b> For every endogenous latent (one that has parents), the
   *       fixed measurements of its parents form the regressor matrix {@code Z}, the remaining
   *       measurements of those parents form the instrument matrix {@code V}, and the latent's
   *       own fixed measurement is the response. The edge from each latent to its fixed
   *       measurement is set to 1.
   *   <li><b>Measurement model.</b> For every observed node that is not a fixed measurement, the
   *       loading on its latent parent is estimated from the fixed measurement of that latent,
   *       instrumented by the latent's other measurements.
   * </ol>
   *
   * <p>When the {@code nodeName} field is non-null, only the node with that name is estimated;
   * if it is an endogenous latent, {@code computeAsymptLatentCovar} is also invoked for it.
   * The fields {@code lNames} and {@code A_hat} are overwritten as a side effect.
   *
   * <p>NOTE(review): the code assumes every latent has an entry in the fixed-measurement lists
   * filled by {@code setFixedNodes}, and that every non-fixed observed node has a latent parent;
   * confirm against callers.
   *
   * @param semIm the instantiated model whose parameters are to be estimated
   * @return the same {@code semIm} instance, with the estimated parameter values set
   */
  protected SemIm estimateCoeffs(SemIm semIm) {
    SemGraph semGraph = semIm.getSemPm().getGraph();

    // Get list of fixed measurements that will be kept fixed, and the
    // respective latent variables that are their parents.
    // "X" variables are exogenous, while "Y" variables are endogenous.
    List<Node> ly = new LinkedList<Node>();
    List<Node> lx = new LinkedList<Node>();
    List<Node> my1 = new LinkedList<Node>();
    List<Node> mx1 = new LinkedList<Node>();
    List<Node> observed = new LinkedList<Node>();

    for (Node nodeA : semGraph.getNodes()) {
      if (nodeA.getNodeType() == NodeType.ERROR) {
        continue;
      }
      if (nodeA.getNodeType() == NodeType.LATENT) {
        if (semGraph.getParents(nodeA).size() == 0) {
          lx.add(nodeA);
        } else {
          ly.add(nodeA);
        }
      } else {
        observed.add(nodeA);
      }
    }
    setFixedNodes(semGraph, mx1, my1);

    // ------------------------------------------------------------------

    // Estimate freeParameters for the latent/latent edges
    for (Node current : ly) {
      if (nodeName != null && !nodeName.equals(current.getName())) {
        continue;
      }

      // Build Z, the matrix containing the data for the fixed measurements
      // associated with the parents of the current (endogenous) latent node
      List<Node> fixedEndoMeasures = new LinkedList<Node>();
      List<Node> fixedExoMeasures = new LinkedList<Node>();
      List<Node> endoParents = new LinkedList<Node>();
      List<Node> exoParents = new LinkedList<Node>();
      lNames = new String[lx.size() + ly.size()];

      for (Node parent : semGraph.getParents(current)) {
        if (parent.getNodeType() == NodeType.ERROR) {
          continue;
        }
        if (lx.contains(parent)) {
          fixedExoMeasures.add(mx1.get(lx.indexOf(parent)));
          exoParents.add(parent);
        } else {
          fixedEndoMeasures.add(my1.get(ly.indexOf(parent)));
          endoParents.add(parent);
        }
      }

      int n = dataSet.getNumRows();
      int c = fixedEndoMeasures.size() + fixedExoMeasures.size();
      if (c == 0) {
        continue;
      }

      // Columns of Z: endogenous parents first, then exogenous parents. lNames records the
      // latent that each column stands for, in the same order.
      double[][] Z = new double[n][c];
      int count = 0;

      for (int i = 0; i < fixedEndoMeasures.size(); i++) {
        copyDataColumn(fixedEndoMeasures.get(i), Z, i);
        lNames[count++] = endoParents.get(i).getName();
      }
      for (int i = 0; i < fixedExoMeasures.size(); i++) {
        copyDataColumn(fixedExoMeasures.get(i), Z, fixedEndoMeasures.size() + i);
        lNames[count++] = exoParents.get(i).getName();
      }

      // Build V, the matrix containing the data for the nonfixed measurements
      // associated with the parents of the current (endogenous) latent node
      List<Node> freeEndoMeasures = new LinkedList<Node>();
      List<Node> freeExoMeasures = new LinkedList<Node>();

      for (Node parent : semGraph.getParents(current)) {
        if (parent.getNodeType() == NodeType.ERROR) {
          continue;
        }
        List<Node> otherMeasures = new LinkedList<Node>();

        for (Node child : semGraph.getChildren(parent)) {
          if (child.getNodeType() == NodeType.MEASURED) {
            otherMeasures.add(child);
          }
        }

        if (lx.contains(parent)) {
          otherMeasures.remove(mx1.get(lx.indexOf(parent)));
          freeExoMeasures.addAll(otherMeasures);
        } else {
          otherMeasures.remove(my1.get(ly.indexOf(parent)));
          freeEndoMeasures.addAll(otherMeasures);
        }
      }

      c = freeEndoMeasures.size() + freeExoMeasures.size();
      if (c == 0) {
        continue;
      }

      double[][] V = new double[n][c];
      for (int i = 0; i < freeEndoMeasures.size(); i++) {
        copyDataColumn(freeEndoMeasures.get(i), V, i);
      }
      for (int i = 0; i < freeExoMeasures.size(); i++) {
        copyDataColumn(freeExoMeasures.get(i), V, freeEndoMeasures.size() + i);
      }

      // The response is the fixed measurement of the current latent itself.
      Node ownFixedMeasure =
          lx.contains(current) ? mx1.get(lx.indexOf(current)) : my1.get(ly.indexOf(current));
      double[] yi = readDataColumn(ownFixedMeasure, n);

      // Build Z_hat = V (V'V)^-1 V'Z, the projection of Z onto the instruments.
      double[][] Z_hat =
          MatrixUtils.product(
              V,
              MatrixUtils.product(
                  MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(V), V)),
                  MatrixUtils.product(MatrixUtils.transpose(V), Z)));
      // A_hat = (Z_hat'Z_hat)^-1 Z_hat'yi
      A_hat =
          MatrixUtils.product(
              MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(Z_hat), Z_hat)),
              MatrixUtils.product(MatrixUtils.transpose(Z_hat), yi));

      // Set the edge for the fixed measurement
      semIm.setParamValue(current, my1.get(ly.indexOf(current)), 1.);

      // Set the edge for the latents (same column order as Z: endogenous, then exogenous)
      for (int i = 0; i < endoParents.size(); i++) {
        semIm.setParamValue(endoParents.get(i), current, A_hat[i]);
      }
      for (int i = 0; i < exoParents.size(); i++) {
        semIm.setParamValue(exoParents.get(i), current, A_hat[endoParents.size() + i]);
      }

      if (nodeName != null && nodeName.equals(current.getName())) {
        computeAsymptLatentCovar(yi, A_hat, Z, Z_hat, dataSet.getNumRows());
        break;
      }
    }

    // ------------------------------------------------------------------

    // Estimate freeParameters of the measurement model

    // Set the edges of the fixed measurements of exogenous
    for (Node current : lx) {
      semIm.setParamValue(current, mx1.get(lx.indexOf(current)), 1.);
    }

    for (Node current : observed) {
      if (nodeName != null && !nodeName.equals(current.getName())) {
        continue;
      }
      if (mx1.contains(current) || my1.contains(current)) {
        continue;
      }

      // First, get the parent of this observed node (the last non-error parent wins).
      Node currentLatent = null;

      for (Node parent : semGraph.getParents(current)) {
        if (parent.getNodeType() == NodeType.ERROR) {
          continue;
        }
        currentLatent = parent;
      }

      List<Node> otherMeasures = new LinkedList<Node>();
      for (Node child : semGraph.getChildren(currentLatent)) {
        if ((child.getNodeType() == NodeType.MEASURED) && child != current) {
          otherMeasures.add(child);
        }
      }

      Node fixedMeasurement;
      if (lx.contains(currentLatent)) {
        fixedMeasurement = mx1.get(lx.indexOf(currentLatent));
      } else {
        fixedMeasurement = my1.get(ly.indexOf(currentLatent));
      }
      otherMeasures.remove(fixedMeasurement);

      int n = dataSet.getNumRows();
      int c = otherMeasures.size();
      if (c == 0) {
        continue;
      }

      // Build C, the column matrix containing the data for the fixed
      // measurement associated with the only latent parent of the current
      // observed node (as assumed by the structure of our measurement model).
      // The data set is indexed (row, column), as everywhere else in this method.
      double[] C = readDataColumn(fixedMeasurement, n);

      // Build V, the matrix containing the data for the other measurements
      // associated with the parents of the (latent) parent of the current
      // observed node. The only difference with respect to the estimation
      // of the within-latent coefficients is that here we only include
      // the other measurements attached to the parent of the current node,
      // assuming that the error term of the current node is independent
      // of the error term of the others and that each measurement is
      // taken with respect to only one latent.
      double[][] V = new double[n][c];
      for (int i = 0; i < c; i++) {
        copyDataColumn(otherMeasures.get(i), V, i);
      }

      double[] yi = readDataColumn(current, n);

      // C_hat = V (V'V)^-1 V'C, then the loading is <C_hat, yi> / <C_hat, C_hat>.
      double[] C_hat =
          MatrixUtils.product(
              V,
              MatrixUtils.product(
                  MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(V), V)),
                  MatrixUtils.product(MatrixUtils.transpose(V), C)));
      double loading =
          MatrixUtils.innerProduct(
              MatrixUtils.scalarProduct(1. / MatrixUtils.innerProduct(C_hat, C_hat), C_hat), yi);

      // Set the edge for the current measurement
      semIm.setParamValue(currentLatent, current, loading);
    }

    return semIm;
  }

  /** Returns the column of {@code dataSet} whose variable has the same name as {@code node}. */
  private int dataColumnIndexOf(Node node) {
    Node variable = dataSet.getVariable(node.getName());
    return dataSet.getVariables().indexOf(variable);
  }

  /** Copies the data column for {@code node} into column {@code targetColumn} of {@code target}. */
  private void copyDataColumn(Node node, double[][] target, int targetColumn) {
    int colIndex = dataColumnIndexOf(node);

    for (int row = 0; row < target.length; row++) {
      target[row][targetColumn] = dataSet.getDouble(row, colIndex);
    }
  }

  /** Returns the first {@code numRows} values of the data column for {@code node}. */
  private double[] readDataColumn(Node node, int numRows) {
    int colIndex = dataColumnIndexOf(node);
    double[] values = new double[numRows];

    for (int row = 0; row < numRows; row++) {
      values[row] = dataSet.getDouble(row, colIndex);
    }

    return values;
  }
  /**
   * Prepares the working state used by the estimator.
   *
   * <ol>
   *   <li>Estimates {@code observedIm} from {@code dataSet} using a symmetric Dirichlet prior
   *       (parameter 0.5) over {@code bayesPmObs}.
   *   <li>Builds {@code mixedData}: one column per entry of {@code nodes}. Columns for measured
   *       nodes are copied from {@code dataSet}; columns for latent nodes are filled with -99
   *       in every row.
   *   <li>Calls {@code estimateIM(bayesPm, mixedData)}.
   *   <li>Allocates {@code estimatedCounts}, {@code estimatedCountsDenom} and {@code condProbs}
   *       with the row/column dimensions reported by {@code estimatedIm} for each node.
   * </ol>
   */
  private void initialize() {
    observedIm =
        DirichletEstimator.estimate(
            DirichletBayesIm.symmetricDirichletIm(bayesPmObs, 0.5), dataSet);

    final int caseCount = dataSet.getNumRows();
    final int nodeCount = nodes.length;

    // The mixed data set has the measured variables of dataSet plus a fresh discrete variable
    // for every latent node.
    List<Node> mixedVariables = new LinkedList<Node>();

    for (int k = 0; k < nodeCount; k++) {
      Node node = nodes[k];

      if (node.getNodeType() != NodeType.LATENT) {
        String measuredName = bayesPm.getVariable(node).getName();
        mixedVariables.add(dataSet.getVariable(measuredName));
      } else {
        int categoryCount = bayesPm.getNumCategories(node);
        mixedVariables.add(new DiscreteVariable(node.getName(), categoryCount));
      }
    }

    DataSet mixed = new ColtDataSet(caseCount, mixedVariables);

    for (int col = 0; col < nodeCount; col++) {
      if (nodes[col].getNodeType() == NodeType.LATENT) {
        // NOTE(review): -99 is presumably the discrete missing-value marker; confirm against
        // DiscreteVariable.MISSING_VALUE.
        for (int row = 0; row < caseCount; row++) {
          mixed.setInt(row, col, -99);
        }
        continue;
      }

      String measuredName = bayesPm.getVariable(nodes[col]).getName();
      int sourceCol = dataSet.getColumn(dataSet.getVariable(measuredName));

      for (int row = 0; row < caseCount; row++) {
        mixed.setInt(row, col, dataSet.getInt(row, sourceCol));
      }
    }

    mixedData = mixed;
    allVariables = mixedData.getVariables();

    // Find the bayes net which is parameterized using mixedData or set randomly when that's
    // not possible.
    estimateIM(bayesPm, mixedData);

    // Allocate the per-node tables with the same row/column layout as estimatedIm.
    estimatedCounts = new double[nodeCount][][];
    estimatedCountsDenom = new double[nodeCount][];
    condProbs = new double[nodeCount][][];

    for (int node = 0; node < nodeCount; node++) {
      int rowCount = estimatedIm.getNumRows(node);

      estimatedCounts[node] = new double[rowCount][];
      estimatedCountsDenom[node] = new double[rowCount];
      condProbs[node] = new double[rowCount][];

      for (int row = 0; row < rowCount; row++) {
        int colCount = estimatedIm.getNumColumns(node);
        estimatedCounts[node][row] = new double[colCount];
        condProbs[node][row] = new double[colCount];
      }
    }
  }
예제 #7
0
  /**
   * Returns a new data set in which every discrete variable gains one extra category that stands
   * for "missing", and every missing discrete value is recoded to that category.
   *
   * <p>For each discrete variable a replacement variable named {@code <oldName>+} is created
   * whose categories are the old ones followed by {@code "Missing"} (or {@code "Missing1"},
   * {@code "Missing2"}, ... if that label is already one of the categories). Variables that are
   * not discrete are carried over as the same variable objects; their values are copied only
   * when the variable is a {@code ContinuousVariable}.
   *
   * @param dataSet the data set to filter; it is not modified
   * @return a new data set with the same number of rows and columns as {@code dataSet}
   */
  public final DataSet filter(DataSet dataSet) {
    final int numColumns = dataSet.getNumColumns();

    // Pass 1: build the variable list of the result.
    List<Node> resultVariables = new LinkedList<>();

    for (int col = 0; col < numColumns; col++) {
      Node source = dataSet.getVariable(col);

      if (source instanceof DiscreteVariable) {
        DiscreteVariable discrete = (DiscreteVariable) source;
        String sourceName = discrete.getName();
        List<String> existingCategories = discrete.getCategories();
        List<String> extendedCategories = new LinkedList<>(existingCategories);

        // Pick a label for the extra category that does not clash with an existing one.
        String missingLabel = "Missing";
        for (int suffix = 1; existingCategories.contains(missingLabel); suffix++) {
          missingLabel = "Missing" + suffix;
        }

        extendedCategories.add(missingLabel);
        resultVariables.add(new DiscreteVariable(sourceName + "+", extendedCategories));
      } else {
        resultVariables.add(source);
      }
    }

    final int numRows = dataSet.getNumRows();
    DataSet result = new ColtDataSet(numRows, resultVariables);

    // Pass 2: copy the values. The extra category was appended last, so its index equals the
    // original number of categories.
    for (int col = 0; col < numColumns; col++) {
      Node source = dataSet.getVariable(col);

      if (source instanceof ContinuousVariable) {
        for (int row = 0; row < numRows; row++) {
          result.setDouble(row, col, dataSet.getDouble(row, col));
        }
      } else if (source instanceof DiscreteVariable) {
        int missingIndex = ((DiscreteVariable) source).getNumCategories();

        for (int row = 0; row < numRows; row++) {
          int value = dataSet.getInt(row, col);
          result.setInt(row, col, value == DiscreteVariable.MISSING_VALUE ? missingIndex : value);
        }
      }
    }

    return result;
  }