예제 #1
0
  /**
   * Returns the data model to search over.
   *
   * <p>If the wrapper's model list holds more than one model, the whole list is returned.
   * Otherwise the wrapper's selected model is returned, provided it is a data set that reports
   * itself as discrete, continuous or mixed, a covariance matrix, or time series data.
   *
   * @param dataWrapper the wrapper supplying the model list and the selected model.
   * @return the model list (when it has more than one element) or the selected model.
   * @throws IllegalArgumentException if the selected model is not one of the accepted kinds.
   */
  private DataModel getSelectedDataModel(DataWrapper dataWrapper) {
    DataModelList models = dataWrapper.getDataModelList();
    if (models.size() > 1) {
      return models;
    }

    DataModel selected = dataWrapper.getSelectedDataModel();

    if (selected instanceof DataSet) {
      DataSet data = (DataSet) selected;
      boolean accepted = data.isDiscrete() || data.isContinuous() || data.isMixed();
      if (!accepted) {
        throw new IllegalArgumentException(
            "<html>"
                + "This data set contains a mixture of discrete and continuous "
                + "<br>columns; there are no algorithm in Tetrad currently to "
                + "<br>search over such data sets."
                + "</html>");
      }
      return data;
    }

    boolean otherAccepted =
        selected instanceof ICovarianceMatrix || selected instanceof TimeSeriesData;
    if (otherAccepted) {
      return selected;
    }

    throw new IllegalArgumentException("Unexpected dataModel source: " + selected);
  }
  /**
   * Constructs the search over the given data set.
   *
   * <p>Continuous data uses a {@code ContinuousTetradTest} and a Fisher Z independence test, and a
   * covariance matrix is also computed. Discrete data uses a {@code DiscreteTetradTest} and a
   * chi-square independence test. In either case, when {@code testType} is {@code TETRAD_DELTA} a
   * {@code DeltaTetradTest} is created with fourth-moment caching turned off.
   *
   * @param dataSet the data to search over; must be all-continuous or all-discrete.
   * @param testType the tetrad test type.
   * @param alpha the significance level passed to the tests.
   * @throws IllegalArgumentException if the data set is neither continuous nor discrete.
   */
  public FindOneFactorClustersWithCausalIndicators(
      DataSet dataSet, TestType testType, double alpha) {
    if (dataSet.isContinuous()) {
      this.test = new ContinuousTetradTest(dataSet, testType, alpha);
      this.indTest = new IndTestFisherZ(dataSet, alpha);
      this.cov = new CovarianceMatrix(dataSet);
    } else if (dataSet.isDiscrete()) {
      this.test = new DiscreteTetradTest(dataSet, alpha);
      this.indTest = new IndTestChiSquare(dataSet, alpha);
    } else {
      // Without this, the tests would be left null and the object would be unusable.
      throw new IllegalArgumentException(
          "Data set must be either all continuous or all discrete.");
    }

    this.variables = dataSet.getVariables();
    this.alpha = alpha;
    this.testType = testType;
    this.dataModel = dataSet;

    if (testType == TestType.TETRAD_DELTA) {
      deltaTest = new DeltaTetradTest(dataSet);
      deltaTest.setCacheFourthMoments(false);
    }
  }
예제 #3
0
  /**
   * Constructs the score from a list of discrete data sets, building one BDeu score per data set.
   * Every score after the first has {@code setVariables} called with the first score's variables,
   * and those variables become this score's variables.
   *
   * @param dataModels the data sets; each must be a discrete {@code DataSet}.
   * @throws NullPointerException if {@code dataModels} is null.
   * @throws IllegalArgumentException if the list is empty, or if any element is not a discrete
   *     data set.
   */
  public BdeuScoreImages(List<DataModel> dataModels) {
    if (dataModels == null) {
      throw new NullPointerException();
    }

    if (dataModels.isEmpty()) {
      throw new IllegalArgumentException("At least one data set must be supplied.");
    }

    List<BDeuScore> scores = new ArrayList<>(dataModels.size());

    for (DataModel model : dataModels) {
      if (!(model instanceof DataSet)) {
        throw new IllegalArgumentException("Only discrete data sets may be used as input.");
      }

      DataSet dataSet = (DataSet) model;

      // BDeu is a score for discrete data.
      if (!dataSet.isDiscrete()) {
        throw new IllegalArgumentException("Datasets must be discrete.");
      }

      scores.add(new BDeuScore(dataSet));
    }

    List<Node> variables = scores.get(0).getVariables();

    // Index 0 supplies the shared variables; every other score (starting at 1) is aligned to them.
    for (int i = 1; i < scores.size(); i++) {
      scores.get(i).setVariables(variables);
    }

    this.scores = scores;
    this.variables = variables;
  }
예제 #4
0
  /**
   * Factory to return the correct param editor for independence test params. This will go in a
   * little box in the search editor.
   *
   * <p>For {@code GesIndTestParams} with an {@code IGesRunner} or {@code ImagesRunner}, a {@code
   * GesIndTestParamsEditor} is returned. It is told whether the runner's data model is a discrete
   * data set. In every other case a generic {@code IndTestParamsEditor} is returned.
   *
   * @param indTestParams the parameters to edit; must not be null.
   * @return the editor component for the parameters.
   * @throws NullPointerException if {@code indTestParams} is null.
   */
  private JComponent getIndTestParamBox(IndTestParams indTestParams) {
    if (indTestParams == null) {
      throw new NullPointerException();
    }

    if (indTestParams instanceof GesIndTestParams) {
      GesIndTestParams params = (GesIndTestParams) indTestParams;

      if (getAlgorithmRunner() instanceof IGesRunner) {
        DataModel dataModel = ((IGesRunner) getAlgorithmRunner()).getDataModel();
        boolean discreteData = dataModel instanceof DataSet && ((DataSet) dataModel).isDiscrete();
        return new GesIndTestParamsEditor(params, discreteData);
      }

      if (getAlgorithmRunner() instanceof ImagesRunner) {
        // Guard the cast the same way as the GES branch: if the runner's model is not a DataSet,
        // an unchecked cast would throw ClassCastException.
        DataModel dataModel = ((ImagesRunner) getAlgorithmRunner()).getDataModel();
        boolean discreteData = dataModel instanceof DataSet && ((DataSet) dataModel).isDiscrete();
        return new GesIndTestParamsEditor(params, discreteData);
      }
    }

    return new IndTestParamsEditor(indTestParams);
  }
예제 #5
0
 /**
  * Creates the search over the given data set, then sets a structure prior of 0.001 and a
  * sample prior of 10.
  *
  * @param dataSet the data to search over.
  */
 public GesConcurrent(DataSet dataSet) {
   setDataSet(dataSet);

   if (dataSet.isDiscrete()) {
     // NOTE(review): this BDeu score is configured but never stored or used afterwards, so the
     // priors set on it have no effect here. Confirm what was intended.
     BDeuScore discreteScore = new BDeuScore(dataSet);
     discreteScore.setSamplePrior(10);
     discreteScore.setStructurePrior(0.001);
   }

   setStructurePrior(0.001);
   setSamplePrior(10.0);
 }
예제 #6
0
  /**
   * Computes and displays model statistics for the currently selected stored graph.
   *
   * <p>This only happens when the runner's data model is a data set. Bidirected edges in the
   * selected graph are processed, and the result is converted via a pattern to a DAG using the
   * Meek rules. A continuous or discrete report for that DAG is then shown in the "DAG in pattern"
   * and "DAG Model Statistics" tabs.
   *
   * @throws IllegalArgumentException if no graphs were recorded, or if the data set is neither
   *     continuous nor discrete.
   */
  private void calcStats() {
    IGesRunner runner = (IGesRunner) getAlgorithmRunner();

    if (runner.getTopGraphs().isEmpty()) {
      throw new IllegalArgumentException(
          "No patterns were recorded. Please adjust the number of " + "patterns to store.");
    }

    // NOTE(review): this graph may be the runner's stored result object; it is modified below.
    Graph resultGraph = runner.getTopGraphs().get(runner.getIndex()).getGraph();

    if (!(getAlgorithmRunner().getDataModel() instanceof DataSet)) {
      return;
    }

    // resultGraph may be the output of a PC search, and such graphs sometimes contain doubly
    // directed edges. Iterate over a snapshot of the edges because the loop modifies the graph,
    // which would otherwise risk a ConcurrentModificationException.
    for (Edge edge : new java.util.ArrayList<>(resultGraph.getEdges())) {
      if (edge.getEndpoint1() == Endpoint.ARROW && edge.getEndpoint2() == Endpoint.ARROW) {
        // Option 1: orient it from node1 to node2 (orientation chosen arbitrarily for now).
        resultGraph.setEndpoint(edge.getNode1(), edge.getNode2(), Endpoint.ARROW);

        // Option 2: remove such edges.
        // NOTE(review): both options are executed; confirm which one was intended.
        resultGraph.removeEdge(edge);
      }
    }

    Pattern pattern = new Pattern(resultGraph);
    PatternToDag ptd = new PatternToDag(pattern);
    Graph dag = ptd.patternToDagMeekRules();

    DataSet dataSet = (DataSet) getAlgorithmRunner().getDataModel();
    String report;

    if (dataSet.isContinuous()) {
      report = reportIfContinuous(dag, dataSet);
    } else if (dataSet.isDiscrete()) {
      report = reportIfDiscrete(dag, dataSet);
    } else {
      throw new IllegalArgumentException(
          "Model statistics can only be computed for all-continuous or all-discrete data.");
    }

    JScrollPane dagWorkbenchScroll = dagWorkbenchScroll(dag);
    modelStatsText.setLineWrap(true);
    modelStatsText.setWrapStyleWord(true);
    modelStatsText.setText(report);

    removeStatsTabs();
    tabbedPane.addTab("DAG in pattern", dagWorkbenchScroll);
    tabbedPane.addTab("DAG Model Statistics", modelStatsText);
  }
예제 #7
0
  /**
   * Stores the data set together with values derived from it: its variables, whether it is
   * discrete, and its number of rows. For non-discrete data it also stores its covariance matrix.
   *
   * @param dataSet the data set to search over.
   */
  private void setDataSet(DataSet dataSet) {
    this.variables = dataSet.getVariables();
    this.dataSet = dataSet;
    this.discrete = dataSet.isDiscrete();

    // Covariances are computed only for non-discrete data.
    if (!isDiscrete()) {
      this.covariances = new CovarianceMatrix(dataSet);
    }

    this.sampleSize = dataSet.getNumRows();
  }