private double multiLL(DoubleMatrix2D coeffs, Node dep, List<Node> indep) {
    // Log-likelihood of the observed categories of `dep` under a multinomial logit model with
    // predictors `indep`. Row 0 of `coeffs` multiplies the intercept column; column k of `coeffs`
    // gives the linear predictor of category k + 1. Category 0 is the reference category, whose
    // linear predictor is fixed at 0 (presumably matching how `coeffs` was fitted -- confirm
    // against the fitting code). The original version prepended a column of 1.0 and then
    // exponentiated it, which gave the reference category weight e instead of 1.

    DoubleMatrix2D indepData =
        factory2D.make(internalData.subsetColumns(indep).getDoubleData().toArray());
    List<Node> depList = new ArrayList<>();
    depList.add(dep);
    DoubleMatrix2D depData =
        factory2D.make(internalData.subsetColumns(depList).getDoubleData().toArray());

    int N = indepData.rows();

    // Linear predictors of the non-reference categories: [1 | X] * coeffs (N x K).
    DoubleMatrix2D eta =
        Algebra.DEFAULT.mult(factory2D.appendColumns(factory2D.make(N, 1, 1.0), indepData), coeffs);
    int numNonReference = eta.columns();

    double ll = 0;
    for (int i = 0; i < N; i++) {
      // Shift by the row maximum (including the reference's 0) so exp() cannot overflow.
      double max = 0.0;
      for (int k = 0; k < numNonReference; k++) {
        max = Math.max(max, eta.get(i, k));
      }

      double sumExp = Math.exp(-max); // reference category: exp(0 - max)
      for (int k = 0; k < numNonReference; k++) {
        sumExp += Math.exp(eta.get(i, k) - max);
      }

      // log P(category) = eta_category - logSumExp(eta)
      int category = (int) depData.get(i, 0);
      double observedEta = category == 0 ? 0.0 : eta.get(i, category - 1);
      ll += observedEta - max - Math.log(sumExp);
    }
    return ll;
  }
예제 #2
0
  /**
   * Finds the data model to search over. If the wrapper holds more than one data model, the whole
   * list is returned; otherwise the wrapper's selected data model is returned, provided it is a
   * discrete, continuous or mixed data set, a covariance matrix, or time-series data.
   *
   * @param dataWrapper the wrapper supplying the data model(s).
   * @return the data model list (when it has more than one element) or the selected data model.
   * @throws IllegalArgumentException if the selected model is of an unsupported kind, or is a data
   *     set reported as none of discrete, continuous or mixed.
   */
  private DataModel getSelectedDataModel(DataWrapper dataWrapper) {
    DataModelList dataModelList = dataWrapper.getDataModelList();

    if (dataModelList.size() > 1) {
      return dataModelList;
    }

    DataModel dataModel = dataWrapper.getSelectedDataModel();

    if (dataModel instanceof DataSet) {
      DataSet dataSet = (DataSet) dataModel;

      if (dataSet.isDiscrete() || dataSet.isContinuous() || dataSet.isMixed()) {
        return dataSet;
      }

      // Mixed data is accepted above, so this is only reached for data sets of no known kind.
      throw new IllegalArgumentException(
          "<html>"
              + "This data set is not recognized as discrete, continuous, or mixed;"
              + "<br>there is no algorithm in Tetrad currently to "
              + "<br>search over such data sets."
              + "</html>");
    }

    if (dataModel instanceof ICovarianceMatrix || dataModel instanceof TimeSeriesData) {
      return dataModel;
    }

    throw new IllegalArgumentException("Unexpected dataModel source: " + dataModel);
  }
예제 #3
0
  // Causes a package cycle.
  public void testManualDiscretize2() {
    // Simulate 100 rows of continuous data from a random 5-node linear SEM.
    Graph dag = new Dag(GraphUtils.randomGraph(5, 0, 5, 3, 3, 3, false));
    SemIm semIm = new SemIm(new SemPm(dag));
    DataSet continuousData = semIm.simulateData(100, false);

    List<Node> columns = continuousData.getVariables();
    Discretizer discretizer = new Discretizer(continuousData);

    // Alternate equal-count and equal-interval binning, with a different bin count per column.
    discretizer.equalCounts(columns.get(0), 3);
    discretizer.equalIntervals(columns.get(1), 2);
    discretizer.equalCounts(columns.get(2), 5);
    discretizer.equalIntervals(columns.get(3), 8);
    discretizer.equalCounts(columns.get(4), 4);

    DataSet discrete = discretizer.discretize();
    System.out.println(discrete);

    // The test expects the largest category index in each column to be (number of bins - 1).
    assertEquals(2, maxInColumn(discrete, 0));
    assertEquals(1, maxInColumn(discrete, 1));
    assertEquals(4, maxInColumn(discrete, 2));
    assertEquals(7, maxInColumn(discrete, 3));
    assertEquals(3, maxInColumn(discrete, 4));
  }
예제 #4
0
  /**
   * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the
   * extending class.
   *
   * <p>Runs LiNGAM with the configured prune factor on a continuous, rectangular data set. It
   * publishes the result graph and lays it out to match the source graph, or in a circle when
   * there is no source graph.
   *
   * @throws IllegalArgumentException if the data model is not a continuous {@code DataSet}.
   */
  public void execute() {
    DataModel model = getDataModel();
    if (!(model instanceof DataSet)) {
      throw new IllegalArgumentException("Expecting a rectangular data set.");
    }

    DataSet continuousData = (DataSet) model;
    if (!continuousData.isContinuous()) {
      throw new IllegalArgumentException("Expecting a continuous data set.");
    }

    Lingam lingam = new Lingam();
    LingamParams lingamParams = (LingamParams) getParams();
    lingam.setPruneFactor(lingamParams.getPruneFactor());

    Graph resultGraph = lingam.search(continuousData);
    setResultGraph(resultGraph);

    if (getSourceGraph() == null) {
      GraphUtils.circleLayout(resultGraph, 200, 200, 150);
    } else {
      GraphUtils.arrangeBySourceGraph(resultGraph, getSourceGraph());
    }
  }
예제 #5
0
  /**
   * Constructs the score from a list of discrete data sets, creating one BDeu score per data set.
   * All scores are made to share the first score's variable list.
   *
   * @param dataModels the data models to score; each must be a discrete {@code DataSet}.
   * @throws NullPointerException if {@code dataModels} is null.
   * @throws IllegalArgumentException if the list is empty or any element is not a discrete data
   *     set.
   */
  public BdeuScoreImages(List<DataModel> dataModels) {
    if (dataModels == null) {
      throw new NullPointerException();
    }

    if (dataModels.isEmpty()) {
      throw new IllegalArgumentException("At least one data set is required.");
    }

    List<BDeuScore> scores = new ArrayList<>();

    for (DataModel model : dataModels) {
      if (!(model instanceof DataSet)) {
        throw new IllegalArgumentException("Only discrete data sets may be used as input.");
      }

      DataSet dataSet = (DataSet) model;

      if (!dataSet.isDiscrete()) {
        throw new IllegalArgumentException("Datasets must be discrete.");
      }

      scores.add(new BDeuScore(dataSet));
    }

    // Index 0 already owns these variables; every other score, including the one at index 1,
    // must be switched over to them so nodes line up across data sets.
    List<Node> variables = scores.get(0).getVariables();

    for (int i = 1; i < scores.size(); i++) {
      scores.get(i).setVariables(variables);
    }

    this.scores = scores;
    this.variables = variables;
  }
  /**
   * Constructs the clustering search over a data set.
   *
   * <p>For continuous data, it uses a {@code ContinuousTetradTest} (with {@code testType}) and a
   * Fisher Z independence test, and computes a covariance matrix. For discrete data, it uses a
   * {@code DiscreteTetradTest} and a chi-square independence test. In both cases, when
   * {@code testType} is {@code TETRAD_DELTA}, it also creates a {@code DeltaTetradTest} with
   * fourth-moment caching turned off.
   *
   * @param dataSet the data; must be entirely continuous or entirely discrete.
   * @param testType the tetrad test type.
   * @param alpha the significance level.
   * @throws IllegalArgumentException if the data set is neither continuous nor discrete.
   */
  public FindOneFactorClustersWithCausalIndicators(
      DataSet dataSet, TestType testType, double alpha) {
    if (dataSet.isContinuous()) {
      this.test = new ContinuousTetradTest(dataSet, testType, alpha);
      this.indTest = new IndTestFisherZ(dataSet, alpha);
      this.cov = new CovarianceMatrix(dataSet);
    } else if (dataSet.isDiscrete()) {
      this.test = new DiscreteTetradTest(dataSet, alpha);
      this.indTest = new IndTestChiSquare(dataSet, alpha);
    } else {
      // Fail fast instead of leaving every field unset for mixed data.
      throw new IllegalArgumentException("Data set must be either continuous or discrete.");
    }

    this.variables = dataSet.getVariables();
    this.alpha = alpha;
    this.testType = testType;
    this.dataModel = dataSet;

    if (testType == TestType.TETRAD_DELTA) {
      deltaTest = new DeltaTetradTest(dataSet);
      deltaTest.setCacheFourthMoments(false);
    }
  }
  /**
   * Builds the test from several data sets over the same variables (taken from the first data
   * set). Each data set is centered, its cross-product matrix divided by the row count is stored,
   * and one {@code RecursivePartialCorrelation} is created per such matrix.
   *
   * @param dataSets the data sets; the first one supplies the variables and the row count.
   * @param alpha the significance level.
   */
  public IndTestFisherZPercentIndependent(List<DataSet> dataSets, double alpha) {
    this.dataSets = dataSets;
    this.variables = dataSets.get(0).getVariables();

    this.data = new ArrayList<TetradMatrix>();
    for (DataSet original : dataSets) {
      TetradMatrix centered = DataUtils.center(original).getDoubleData();
      this.data.add(centered);
    }

    this.ncov = new ArrayList<TetradMatrix>();
    for (TetradMatrix centered : this.data) {
      TetradMatrix crossProduct = centered.transpose().times(centered);
      this.ncov.add(crossProduct.scalarMult(1.0 / centered.rows()));
    }

    setAlpha(alpha);

    this.rows = new int[dataSets.get(0).getNumRows()];
    int[] rowIndices = getRows();
    for (int r = 0; r < rowIndices.length; r++) {
      rowIndices[r] = r;
    }

    this.variablesMap = new HashMap<Node, Integer>();
    int index = 0;
    for (Node variable : variables) {
      variablesMap.put(variable, index++);
    }

    this.recursivePartialCorrelation = new ArrayList<RecursivePartialCorrelation>();
    for (TetradMatrix normalizedCov : ncov) {
      recursivePartialCorrelation.add(
          new RecursivePartialCorrelation(getVariables(), normalizedCov));
    }
  }
예제 #8
0
  /**
   * Factory to return the correct param editor for independence test params. This will go in a
   * little box in the search editor.
   *
   * <p>For {@code GesIndTestParams} under a GES or IMaGES runner, a
   * {@code GesIndTestParamsEditor} is returned, told whether the runner's data model is a discrete
   * data set. Otherwise a generic {@code IndTestParamsEditor} is returned.
   *
   * @throws NullPointerException if {@code indTestParams} is null.
   */
  private JComponent getIndTestParamBox(IndTestParams indTestParams) {
    if (indTestParams == null) {
      throw new NullPointerException();
    }

    if (indTestParams instanceof GesIndTestParams) {
      GesIndTestParams params = (GesIndTestParams) indTestParams;

      if (getAlgorithmRunner() instanceof IGesRunner) {
        DataModel dataModel = ((IGesRunner) getAlgorithmRunner()).getDataModel();
        boolean discreteData = dataModel instanceof DataSet && ((DataSet) dataModel).isDiscrete();
        return new GesIndTestParamsEditor(params, discreteData);
      }

      if (getAlgorithmRunner() instanceof ImagesRunner) {
        // Same guarded check as the GES branch: a non-tabular data model is treated as
        // non-discrete instead of throwing ClassCastException.
        DataModel dataModel = ((ImagesRunner) getAlgorithmRunner()).getDataModel();
        boolean discreteData = dataModel instanceof DataSet && ((DataSet) dataModel).isDiscrete();
        return new GesIndTestParamsEditor(params, discreteData);
      }
    }

    return new IndTestParamsEditor(indTestParams);
  }
예제 #9
0
  /**
   * Constructs a delta sextad test from a continuous tabular data set.
   *
   * <p>The covariance matrix is computed from the data as given. A centered copy of the data is
   * then stored one array per variable, along with its row count, variables, variable-to-column
   * index map, and per-variable means. Presumably these feed the fourth-moment statistics
   * mentioned for this test.
   *
   * @param dataSet a non-null, continuous data set.
   * @throws NullPointerException if {@code dataSet} is null.
   * @throws IllegalArgumentException if {@code dataSet} is not continuous.
   */
  public DeltaSextadTest(DataSet dataSet) {
    if (dataSet == null) {
      throw new NullPointerException();
    }
    if (!dataSet.isContinuous()) {
      throw new IllegalArgumentException();
    }

    this.cov = new CovarianceMatrix(dataSet);

    List<DataSet> singleton = new ArrayList<DataSet>();
    singleton.add(dataSet);
    this.dataSet = DataUtils.center(singleton).get(0);

    // Transposing makes data[v] the column for variable v.
    this.data = this.dataSet.getDoubleData().transpose().toArray();
    this.N = dataSet.getNumRows();
    this.variables = dataSet.getVariables();
    this.numVars = dataSet.getNumColumns();

    this.variablesHash = new HashMap<Node, Integer>();
    int column = 0;
    for (Node variable : variables) {
      variablesHash.put(variable, column++);
    }

    this.means = new double[numVars];
    for (int v = 0; v < numVars; v++) {
      means[v] = mean(data[v], N);
    }
  }
  /**
   * Constructs a new Fisher Z independence test with the listed arguments.
   *
   * @param data A 2D continuous data set with no missing values.
   * @param variables A list of variables, a subset of the variables of <code>data</code>.
   * @param alpha The significance cutoff level. p values less than alpha will be reported as
   *     dependent.
   */
  public IndTestFisherZShortTriangular(TetradMatrix data, List<Node> variables, double alpha) {
    DataSet dataSet = ColtDataSet.makeContinuousData(variables, data);
    this.covMatrix = new ShortTriangularMatrix(dataSet.getNumColumns());
    this.covMatrix.becomeCorrelationMatrix(dataSet);
    this.variables = dataSet.getVariables();
    setAlpha(alpha);

    this.deterministicTest = new IndTestFisherZGeneralizedInverse(dataSet, alpha);
  }
예제 #11
0
  /**
   * Checks that a GeneralizedSemIm built from a linear SemIm reproduces its data: for every
   * column, data simulated from the two models must have means within 0.3 of each other and a
   * variance ratio within 0.2 of 1.
   */
  @Test
  public void test2() {
    RandomUtil.getInstance().setSeed(2999983L);

    int sampleSize = 1000;

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);
    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);

    SemPm semPm = new SemPm(graph);
    SemIm semIm = new SemIm(semPm);
    DataSet dataSet = semIm.simulateData(sampleSize, false);

    print(semPm);

    GeneralizedSemPm _semPm = new GeneralizedSemPm(semPm);
    GeneralizedSemIm _semIm = new GeneralizedSemIm(_semPm, semIm);
    DataSet _dataSet = _semIm.simulateDataMinimizeSurface(sampleSize, false);

    print(_semPm);

    // getDoubleData() does not depend on the column, so extract every column once up front
    // instead of copying both data matrices twice per loop iteration. After transposing,
    // row j of each array is column j of the data.
    double[][] columns = dataSet.getDoubleData().transpose().toArray();
    double[][] _columns = _dataSet.getDoubleData().transpose().toArray();

    for (int j = 0; j < dataSet.getNumColumns(); j++) {
      double[] col = columns[j];
      double[] _col = _columns[j];

      double mean = StatUtils.mean(col);
      double _mean = StatUtils.mean(_col);

      double variance = StatUtils.variance(col);
      double _variance = StatUtils.variance(_col);

      assertEquals(mean, _mean, 0.3);
      assertEquals(1.0, variance / _variance, .2);
    }
  }
예제 #12
0
  /**
   * Returns the largest integer value in one column of a data set.
   *
   * @param dataSet A discrete data set.
   * @param column the column in question.
   * @return the max value in that column, or -1 when there are no rows or every value is below -1.
   */
  private int maxInColumn(DataSet dataSet, int column) {
    int numRows = dataSet.getNumRows();
    int largest = -1;
    int row = 0;
    while (row < numRows) {
      largest = Math.max(largest, dataSet.getInt(row, column));
      row++;
    }
    return largest;
  }
예제 #13
0
  /**
   * Fills the "DAG in pattern" and "DAG Model Statistics" tabs for the stored pattern at the GES
   * runner's current index.
   *
   * <p>It does nothing more than validate the stored patterns unless the runner's data model is a
   * {@code DataSet}. In that case it handles bidirected edges in the pattern, picks a DAG from the
   * pattern with Meek rules, and reports continuous or discrete model statistics for it.
   *
   * @throws IllegalArgumentException if no patterns were recorded, or the data set is neither
   *     continuous nor discrete.
   */
  private void calcStats() {
    IGesRunner runner = (IGesRunner) getAlgorithmRunner();

    if (runner.getTopGraphs().isEmpty()) {
      throw new IllegalArgumentException(
          "No patterns were recorded. Please adjust the number of " + "patterns to store.");
    }

    Graph resultGraph = runner.getTopGraphs().get(runner.getIndex()).getGraph();

    if (getAlgorithmRunner().getDataModel() instanceof DataSet) {

      // resultGraph may be the output of a PC search, and such graphs sometimes contain
      // bidirected edges. Iterate over an array snapshot of the edges so that modifying the graph
      // inside the loop cannot invalidate the iteration, whether getEdges() returns a live view
      // or a copy.
      Set<Edge> allEdges = resultGraph.getEdges();

      for (Edge edge : allEdges.toArray(new Edge[0])) {
        if (edge.getEndpoint1() == Endpoint.ARROW && edge.getEndpoint2() == Endpoint.ARROW) {
          // NOTE(review): these two calls were written as alternative "options" (orient
          // node1 -> node2, or remove the edge), but both run. Setting an already-ARROW endpoint
          // presumably leaves the edge bidirected, so the removal likely decides the outcome.
          // Confirm which behavior is intended.
          resultGraph.setEndpoint(edge.getNode1(), edge.getNode2(), Endpoint.ARROW);
          resultGraph.removeEdge(edge);
        }
      }

      Pattern pattern = new Pattern(resultGraph);
      PatternToDag ptd = new PatternToDag(pattern);
      Graph dag = ptd.patternToDagMeekRules();

      DataSet dataSet = (DataSet) getAlgorithmRunner().getDataModel();
      String report;

      if (dataSet.isContinuous()) {
        report = reportIfContinuous(dag, dataSet);
      } else if (dataSet.isDiscrete()) {
        report = reportIfDiscrete(dag, dataSet);
      } else {
        throw new IllegalArgumentException(
            "Model statistics require a continuous or discrete data set.");
      }

      JScrollPane dagWorkbenchScroll = dagWorkbenchScroll(dag);
      modelStatsText.setLineWrap(true);
      modelStatsText.setWrapStyleWord(true);
      modelStatsText.setText(report);

      removeStatsTabs();
      tabbedPane.addTab("DAG in pattern", dagWorkbenchScroll);
      tabbedPane.addTab("DAG Model Statistics", modelStatsText);
    }
  }
예제 #14
0
  /**
   * Stores the data set together with its variables, whether it is discrete, and its row count.
   * A covariance matrix is computed only for non-discrete data.
   *
   * @param dataSet the data set to use.
   */
  private void setDataSet(DataSet dataSet) {
    this.variables = dataSet.getVariables();
    this.dataSet = dataSet;
    this.discrete = dataSet.isDiscrete();

    if (!isDiscrete()) {
      this.covariances = new CovarianceMatrix(dataSet);
    }

    this.sampleSize = dataSet.getNumRows();
  }
  /**
   * Constructs a Fisher Z independence test over the correlation matrix of a continuous data set,
   * using the given significance level. The correlation matrix is kept in triangular form, and a
   * generalized-inverse Fisher Z test over the same data is kept alongside it.
   *
   * @param dataSet A data set containing only continuous columns.
   * @param alpha The alpha level of the test.
   * @throws IllegalArgumentException if the data set is not continuous.
   */
  public IndTestFisherZShortTriangular(DataSet dataSet, double alpha) {
    if (!dataSet.isContinuous()) {
      throw new IllegalArgumentException("Data set must be continuous.");
    }

    ShortTriangularMatrix correlations = new ShortTriangularMatrix(dataSet.getNumColumns());
    correlations.becomeCorrelationMatrix(dataSet);
    this.covMatrix = correlations;

    this.variables = dataSet.getVariables();
    setAlpha(alpha);

    this.deterministicTest = new IndTestFisherZGeneralizedInverse(dataSet, alpha);
    this.dataSet = dataSet;
  }
예제 #16
0
  /** Creates a cell count table for the given data set. */
  public DataSetProbs(DataSet dataSet) {
    if (dataSet == null) {
      throw new NullPointerException();
    }

    this.dataSet = dataSet;
    dims = new int[dataSet.getNumColumns()];

    for (int i = 0; i < dims.length; i++) {
      DiscreteVariable variable = (DiscreteVariable) dataSet.getVariable(i);
      dims[i] = variable.getNumCategories();
    }

    numRows = dataSet.getNumRows();
  }
  /**
   * Returns the row indices to use when testing x against y given z.
   *
   * <p>Filtering out rows that have missing values in x, y or z was disabled because it took an
   * inordinate amount of time (jdramsey, 2015-09-29). This method therefore ignores its arguments
   * and returns every row index 0..n-1 of {@code internalData}. The array is built once, cached in
   * {@code _rows}, and the same instance is returned on every call, so any modification by a
   * caller is visible to later callers.
   *
   * @param x currently unused.
   * @param y currently unused.
   * @param z currently unused.
   * @return the indices of all rows of {@code internalData}.
   */
  private int[] getNonMissingRows(Node x, Node y, List<Node> z) {
    if (_rows == null) {
      int numRows = internalData.getNumRows();
      int[] allRows = new int[numRows];
      for (int k = 0; k < numRows; k++) {
        allRows[k] = k;
      }
      _rows = allRows;
    }

    return _rows;
  }
예제 #18
0
  /**
   * Constructs a new on-the-fly calculator that estimates conditional probabilities directly from
   * the given discrete data set, for the given Bayes PM.
   *
   * @param bayesPm the Bayes PM giving the DAG and parametrization (but not parameter values).
   * @param dataSet the discrete data set from which conditional probabilities are estimated on the
   *     fly.
   * @throws NullPointerException if either argument is null.
   * @throws IllegalArgumentException as declared, e.g. when PM variables are missing from the data
   *     (presumably raised by {@code BayesUtils.ensureVarsInData}).
   */
  public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet)
      throws IllegalArgumentException {
    if (bayesPm == null || dataSet == null) {
      throw new NullPointerException();
    }

    // Estimation is impossible unless every PM variable appears in the data.
    BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);

    this.bayesPm = new BayesPm(bayesPm);

    // Fix the node order now; it must be maintained even if the PM later changes.
    // NOTE(review): the nodes are read from the argument PM, not from the copy stored above.
    this.nodes = bayesPm.getDag().getNodes().toArray(new Node[0]);

    initialize();

    // Restrict the data to the IM's variables, in the IM's order.
    this.dataSet = dataSet.subsetColumns(getVariables());

    // Evidence starts out as a tautologous proposition.
    this.evidence = new Evidence(Proposition.tautology(this));
  }
예제 #19
0
  /**
   * Estimates the conditional probability of {@code assertion} given {@code condition} by
   * counting rows of the data set.
   *
   * <p>Rows with a missing value ({@code DiscreteVariable.MISSING_VALUE}) in any column are
   * skipped, and {@code missingValueCaseFound} is set when at least one such row is seen. Among
   * the remaining rows, the estimate is the number satisfying both condition and assertion divided
   * by the number satisfying the condition.
   *
   * @return the estimated conditional probability; NaN if no complete row satisfies the condition.
   * @throws IllegalArgumentException if the propositions have different variable sources, or the
   *     assertion's variables are not the same set as the data set's variables.
   */
  public double getConditionalProb(Proposition assertion, Proposition condition) {
    if (assertion.getVariableSource() != condition.getVariableSource()) {
      throw new IllegalArgumentException(
          "Assertion and condition must be " + "for the same Bayes IM.");
    }

    List<Node> assertionVars = assertion.getVariableSource().getVariables();
    List<Node> dataVars = dataSet.getVariables();

    assertionVars = GraphUtils.replaceNodes(assertionVars, dataVars);

    // NOTE(review): only set equality is checked here, although the message mentions order and
    // rows are read in data-column order below. Confirm that the variable source always matches
    // the data column order.
    if (!new HashSet<Node>(assertionVars).equals(new HashSet<Node>(dataVars))) {
      throw new IllegalArgumentException(
          "Assertion variable and data variables"
              + " are either different or in a different order: "
              + "\n\tAssertion vars: "
              + assertionVars
              + "\n\tData vars: "
              + dataVars);
    }

    int[] point = new int[dims.length];
    int conditionCount = 0;
    int jointCount = 0;
    this.missingValueCaseFound = false;

    rows:
    for (int i = 0; i < numRows; i++) {
      for (int j = 0; j < dims.length; j++) {
        point[j] = dataSet.getInt(i, j);

        if (point[j] == DiscreteVariable.MISSING_VALUE) {
          // The flag is reset above on every call; it has to be raised here when a row is
          // skipped, otherwise it could never become true.
          this.missingValueCaseFound = true;
          continue rows;
        }
      }

      if (condition.isPermissibleCombination(point)) {
        conditionCount++;

        if (assertion.isPermissibleCombination(point)) {
          jointCount++;
        }
      }
    }

    // 0 / 0.0 yields NaN when no complete row satisfies the condition.
    return jointCount / (double) conditionCount;
  }
  /**
   * Constructs the test over a (possibly mixed) data set.
   *
   * <p>Two copies of the data are made. {@code originalData} is kept as given. The working copy,
   * later stored as {@code internalData}, is passed with each of its variables to
   * {@code expandVariable}, and the returned nodes are recorded in {@code variablesPerNode}.
   * Logistic and linear regression helpers are then built over the working copy.
   *
   * @param data the data set to test over.
   * @param alpha the significance level.
   */
  public IndTestMixedMultipleTTest(DataSet data, double alpha) {
    this.searchVariables = data.getVariables();
    this.originalData = data.copy();
    DataSet workingCopy = data.copy();
    this.alpha = alpha;

    for (Node variable : workingCopy.getVariables()) {
      variablesPerNode.put(variable, expandVariable(workingCopy, variable));
    }

    this.internalData = workingCopy;
    this.logisticRegression = new LogisticRegression(workingCopy);
    this.regression = new RegressionDataset(workingCopy);
  }
예제 #21
0
  /**
   * Opens a "Scatter Plots" editor window for the data set selected in the data editor. Instead
   * shows a message if the selection is not a non-empty tabular data set.
   */
  public void actionPerformed(ActionEvent e) {
    Object selected = dataEditor.getSelectedDataModel();

    // Guard the cast: a non-tabular selection would otherwise throw ClassCastException inside
    // the event handler.
    if (selected != null && !(selected instanceof DataSet)) {
      JOptionPane.showMessageDialog(findOwner(), "Scatter plots require a tabular data set.");
      return;
    }

    DataSet dataSet = (DataSet) selected;
    if (dataSet == null || dataSet.getNumColumns() == 0) {
      JOptionPane.showMessageDialog(
          findOwner(), "Cannot display a scatter plot for an empty data set.");
      return;
    }

    JPanel panel = new ScatterPlotView(dataSet);
    EditorWindow editorWindow = new EditorWindow(panel, "Scatter Plots", "Save", true, dataEditor);

    DesktopController.getInstance().addEditorWindow(editorWindow, JLayeredPane.PALETTE_LAYER);
    editorWindow.pack();
    editorWindow.setVisible(true);
  }
예제 #22
0
 /**
  * Creates the search over the given data set, then sets a structure prior of 0.001 and a sample
  * prior of 10.
  *
  * @param dataSet the data to search over.
  */
 public GesConcurrent(DataSet dataSet) {
   setDataSet(dataSet);

   if (dataSet.isDiscrete()) {
     // NOTE(review): this BDeu score is configured but never stored or returned, so nothing
     // beyond its construction has any effect. Confirm whether it was meant to be kept.
     BDeuScore discreteScore = new BDeuScore(dataSet);
     discreteScore.setSamplePrior(10);
     discreteScore.setStructurePrior(0.001);
   }

   setStructurePrior(0.001);
   setSamplePrior(10.);
 }
예제 #23
0
  /**
   * Builds a plain-text fit report for a DAG over discrete data. The report gives the
   * likelihood-ratio p-value, degrees of freedom, chi-square statistic and BIC score, all taken
   * from {@code BayesProperties}.
   *
   * @param dag the DAG to evaluate.
   * @param dataSet the data; every column is cast to {@code DiscreteVariable}.
   * @return the report text.
   */
  private String reportIfDiscrete(Graph dag, DataSet dataSet) {
    List<Node> vars = dataSet.getVariables();
    Map<String, DiscreteVariable> varsByName = new HashMap<String, DiscreteVariable>();
    for (int i = 0; i < dataSet.getNumColumns(); i++) {
      DiscreteVariable var = (DiscreteVariable) vars.get(i);
      varsByName.put(var.getName(), var);
    }

    // NOTE(review): bayesPm is given the data's categories here but is not used to build the
    // report below. Confirm whether it should feed BayesProperties.
    BayesPm bayesPm = new BayesPm(new Dag(dag));
    List<Node> nodes = bayesPm.getDag().getNodes();

    for (Node node : nodes) {
      DiscreteVariable var = varsByName.get(node.getName());

      if (var != null) {
        int numCategories = var.getNumCategories();
        List<String> categories = new ArrayList<String>();
        for (int j = 0; j < numCategories; j++) {
          categories.add(var.getCategory(j));
        }
        bayesPm.setCategories(node, categories);
      }
    }

    BayesProperties properties = new BayesProperties(dataSet, dag);
    properties.setGraph(dag);

    NumberFormat nf = NumberFormat.getInstance();
    nf.setMaximumFractionDigits(4);

    StringBuilder buf = new StringBuilder();
    buf.append("\nP-value = ").append(properties.getLikelihoodRatioP());
    buf.append("\nDf = ").append(properties.getPValueDf());
    buf.append("\nChi square = ").append(nf.format(properties.getPValueChisq()));
    buf.append("\nBIC score = ").append(nf.format(properties.getBic()));
    buf.append("\n\nH0: Completely disconnected graph.");

    return buf.toString();
  }
예제 #24
0
파일: MBFS.java 프로젝트: bd2kccd/r-causal
  /**
   * Runs an Mbfs search targeted at the variable named by the "targetName" parameter. It uses
   * the configured independence test, the "depth" parameter and this algorithm's knowledge. The
   * target name is also stored in {@code targetName}.
   *
   * @throws IllegalArgumentException if the data set has no variable with the target name.
   */
  @Override
  public Graph search(DataSet dataSet, Parameters parameters) {
    edu.cmu.tetrad.search.Mbfs search =
        new edu.cmu.tetrad.search.Mbfs(
            test.getTest(dataSet, parameters), parameters.getInt("depth"));

    search.setKnowledge(knowledge);

    this.targetName = parameters.getString("targetName");
    Node target = dataSet.getVariable(targetName);

    // Fail with a clear message instead of handing a null target to the search.
    if (target == null) {
      throw new IllegalArgumentException("Target variable not found in data set: " + targetName);
    }

    return search.search(target);
  }
  /**
   * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the
   * extending class.
   *
   * <p>Runs LiNGAM with pruning on a continuous, rectangular data set, using the
   * independence-test alpha from the parameters. The result graph is laid out to match the source
   * graph, or in a circle when there is no source graph.
   *
   * @throws IllegalArgumentException if the data model is not a continuous {@code DataSet}.
   */
  public void execute() {
    DataModel source = getDataModel();

    if (!(source instanceof DataSet)) {
      throw new IllegalArgumentException("Expecting a rectangular data set.");
    }

    DataSet data = (DataSet) source;

    if (!data.isContinuous()) {
      throw new IllegalArgumentException("Expecting a continuous data set.");
    }

    Lingam lingam = new Lingam();
    lingam.setAlpha(getParams().getIndTestParams().getAlpha());
    lingam.setPruningDone(true);
    GraphWithParameters result = lingam.lingam(data);
    setResultGraph(result.getGraph());

    // Avoid passing a null source graph to arrangeBySourceGraph.
    if (getSourceGraph() != null) {
      GraphUtils.arrangeBySourceGraph(getResultGraph(), getSourceGraph());
    } else {
      GraphUtils.circleLayout(getResultGraph(), 200, 200, 150);
    }
  }
  /**
   * Returns whether the value of variable {@code x} in row {@code i} of {@code internalData} is
   * missing. Discrete values equal to {@code DiscreteVariable.MISSING_VALUE} and continuous NaN
   * values count as missing; variables of any other type are never reported as missing.
   */
  private boolean isMissing(Node x, int i) {
    int j = internalData.getColumn(x);

    if (x instanceof DiscreteVariable) {
      // Use the shared constant instead of a hard-coded -99.
      return internalData.getInt(i, j) == DiscreteVariable.MISSING_VALUE;
    }

    if (x instanceof ContinuousVariable) {
      return Double.isNaN(internalData.getDouble(i, j));
    }

    return false;
  }
예제 #27
0
  /**
   * Opens a "Normality Tests" editor window for the data editor's current selection. Instead
   * shows a message if the selection is not a non-empty tabular data set. Missing values are not
   * screened here.
   */
  public void actionPerformed(ActionEvent e) {
    Object selected = dataEditor.getSelectedDataModel();

    // Guard the cast: a non-tabular selection would otherwise throw ClassCastException inside
    // the event handler.
    if (selected != null && !(selected instanceof DataSet)) {
      JOptionPane.showMessageDialog(findOwner(), "Normality tests require a tabular data set.");
      return;
    }

    DataSet dataSet = (DataSet) selected;
    if (dataSet == null || dataSet.getNumColumns() == 0) {
      JOptionPane.showMessageDialog(
          findOwner(), "Cannot run normality tests on an empty data set.");
      return;
    }

    JPanel panel = createNormalityTestDialog(null);

    EditorWindow window = new EditorWindow(panel, "Normality Tests", "Close", false, dataEditor);
    DesktopController.getInstance().addEditorWindow(window, JLayeredPane.PALETTE_LAYER);
    window.setVisible(true);
  }
  /**
   * Returns the nodes that represent {@code node} in regression models, appending indicator
   * columns to {@code dataSet} when needed.
   *
   * <p>Continuous variables and discrete variables with fewer than three categories are returned
   * as-is. For a discrete variable with three or more categories, the first category is the
   * reference. For every other category, a binary {@code DiscreteVariable} column is added to
   * {@code dataSet} (mutating it), holding 1 in rows whose value is that category and 0 otherwise.
   *
   * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete.
   */
  private List<Node> expandVariable(DataSet dataSet, Node node) {
    if (node instanceof ContinuousVariable) {
      return Collections.singletonList(node);
    }

    if (!(node instanceof DiscreteVariable)) {
      throw new IllegalArgumentException();
    }

    DiscreteVariable discreteNode = (DiscreteVariable) node;

    if (discreteNode.getNumCategories() < 3) {
      return Collections.singletonList(node);
    }

    List<String> varCats = new ArrayList<String>(discreteNode.getCategories());

    // first category is reference
    varCats.remove(0);
    List<Node> variables = new ArrayList<Node>();

    for (String cat : varCats) {
      // Choose an unused column name. Retrying the same name would never terminate if it were
      // already taken, so append a numeric suffix until the name is free.
      String baseName = node.getName() + "MULTINOM" + "." + cat;
      String newVarName = baseName;
      for (int suffix = 1; dataSet.getVariable(newVarName) != null; suffix++) {
        newVarName = baseName + "_" + suffix;
      }

      Node newVar = new DiscreteVariable(newVarName, 2);
      variables.add(newVar);

      dataSet.addVariable(newVar);
      int newVarIndex = dataSet.getColumn(newVar);
      int nodeColumn = dataSet.getColumn(node);
      int catIndex = discreteNode.getIndex(cat);
      int numCases = dataSet.getNumRows();

      for (int l = 0; l < numCases; l++) {
        Object dataCell = dataSet.getObject(l, nodeColumn);
        int dataCellIndex = discreteNode.getIndex(dataCell.toString());
        dataSet.setInt(l, newVarIndex, dataCellIndex == catIndex ? 1 : 0);
      }
    }

    return variables;
  }
예제 #29
0
  /**
   * Splits the rows of a data set into consecutive, named pieces according to the split-cases
   * spec.
   *
   * <p>Row indices 0..n-1 are shuffled first when {@code params.isDataShuffled()} is true. Split
   * {@code k} then takes the rows at positions [b_k, b_(k+1)) of that index list. Here b_0 = 0,
   * b_1..b_m are the spec's breakpoints, and b_(m+1) is the spec's sample size.
   *
   * @param dataSet the data set to split.
   * @param params the split parameters (shuffle flag, number of splits, and spec).
   * @return a {@code DataModelList} with one data set per split, named by the spec's split names.
   * @throws IllegalArgumentException if more splits are requested than breakpoints + 1.
   */
  public static DataModel createSplits(DataSet dataSet, SplitCasesParams params) {
    List<Integer> indices = new ArrayList<Integer>(dataSet.getNumRows());
    for (int i = 0; i < dataSet.getNumRows(); i++) {
      indices.add(i);
    }

    if (params.isDataShuffled()) {
      Collections.shuffle(indices);
    }

    SplitCasesSpec spec = params.getSpec();
    int numSplits = params.getNumSplits();
    int sampleSize = spec.getSampleSize();
    int[] breakpoints = spec.getBreakpoints();
    List<String> splitNames = spec.getSplitNames();

    // Each split needs a start and an end boundary; report the problem clearly instead of
    // failing with an ArrayIndexOutOfBoundsException below.
    if (numSplits > breakpoints.length + 1) {
      throw new IllegalArgumentException(
          "Number of splits ("
              + numSplits
              + ") exceeds the number of breakpoints plus one ("
              + (breakpoints.length + 1)
              + ").");
    }

    // Boundaries: 0, the breakpoints, then the sample size.
    int[] _breakpoints = new int[breakpoints.length + 2];
    _breakpoints[0] = 0;
    _breakpoints[_breakpoints.length - 1] = sampleSize;
    System.arraycopy(breakpoints, 0, _breakpoints, 1, breakpoints.length);

    DataModelList list = new DataModelList();
    int ncols = dataSet.getNumColumns();
    for (int n = 0; n < numSplits; n++) {
      int _sampleSize = _breakpoints[n + 1] - _breakpoints[n];

      DataSet _data = new ColtDataSet(_sampleSize, dataSet.getVariables());
      _data.setName(splitNames.get(n));

      for (int i = 0; i < _sampleSize; i++) {
        int oldCase = indices.get(i + _breakpoints[n]);

        for (int j = 0; j < ncols; j++) {
          _data.setObject(i, j, dataSet.getObject(oldCase, j));
        }
      }

      list.add(_data);
    }

    return list;
  }
예제 #30
0
  /**
   * Returns the fourth-moment statistic computed by {@code sxyzw} for the given variable indices.
   *
   * <p>When {@code cacheFourthMoments} is set, values are memoized in {@code fourthMoment}, which
   * is initialized lazily on first use. A stored 0.0 is treated as "not yet computed", so a true
   * zero value is recomputed on every call.
   */
  private double getForthMoment(int x, int y, int z, int w) {
    if (!cacheFourthMoments) {
      return sxyzw(x, y, z, w);
    }

    if (fourthMoment == null) {
      initializeForthMomentMatrix(dataSet.getVariables());
    }

    double cached = fourthMoment[x][y][z][w];
    if (cached != 0.0) {
      return cached;
    }

    double computed = sxyzw(x, y, z, w);
    setForthMoment(x, y, z, w, computed);
    return computed;
  }