/**
 * Expands a discrete variable with three or more categories into one binary indicator variable
 * per non-reference category, adding each indicator as a new column of {@code dataSet}.
 *
 * <p>Continuous variables and discrete variables with fewer than three categories are returned
 * unchanged (as a singleton list) and the data set is not modified. Otherwise the first category
 * is treated as the reference category and gets no indicator; for every remaining category a
 * two-category variable named {@code <name>MULTINOM.<category>} is added whose value is 1 in the
 * rows where the original variable takes that category and 0 elsewhere. If a variable with that
 * name already exists in the data set, a numeric suffix ({@code _1}, {@code _2}, ...) is appended
 * until the name is unused.
 *
 * @param dataSet the data set holding the column for {@code node}; indicator columns are added
 *     to it
 * @param node the variable to expand
 * @return the singleton list containing {@code node} if no expansion is needed, otherwise the
 *     list of newly created indicator variables in category order
 * @throws IllegalArgumentException if {@code node} is neither a {@code ContinuousVariable} nor a
 *     {@code DiscreteVariable}
 */
private List<Node> expandVariable(DataSet dataSet, Node node) {
  if (node instanceof ContinuousVariable) {
    return Collections.singletonList(node);
  }

  if (!(node instanceof DiscreteVariable)) {
    throw new IllegalArgumentException(
        "Node must be a ContinuousVariable or a DiscreteVariable: " + node);
  }

  DiscreteVariable discrete = (DiscreteVariable) node;

  if (discrete.getNumCategories() < 3) {
    return Collections.singletonList(node);
  }

  List<String> varCats = new ArrayList<String>(discrete.getCategories());

  // first category is reference
  varCats.remove(0);
  List<Node> variables = new ArrayList<Node>();

  for (String cat : varCats) {
    String baseName = node.getName() + "MULTINOM" + "." + cat;

    // The original loop retried the identical name forever on a collision; vary the name with a
    // numeric suffix so the search for an unused name always makes progress.
    Node newVar = new DiscreteVariable(baseName, 2);
    for (int suffix = 1; dataSet.getVariable(newVar.getName()) != null; suffix++) {
      newVar = new DiscreteVariable(baseName + "_" + suffix, 2);
    }

    variables.add(newVar);

    dataSet.addVariable(newVar);
    int newVarIndex = dataSet.getColumn(newVar);
    int numCases = dataSet.getNumRows();

    // Neither of these changes while the rows are being filled in, so look them up once.
    int sourceColumn = dataSet.getColumn(node);
    int catIndex = discrete.getIndex(cat);

    for (int row = 0; row < numCases; row++) {
      Object dataCell = dataSet.getObject(row, sourceColumn);
      int dataCellIndex = discrete.getIndex(dataCell.toString());

      if (dataCellIndex == catIndex) {
        dataSet.setInt(row, newVarIndex, 1);
      } else {
        dataSet.setInt(row, newVarIndex, 0);
      }
    }
  }

  return variables;
}
Exemplo n.º 2
0
  /**
   * Copies the rows of a data set into a sequence of consecutive splits.
   *
   * <p>Row indices are taken in their natural order, or in a shuffled order when
   * {@code params.isDataShuffled()} is true. The split boundaries are 0, followed by the
   * breakpoints of the params' spec, followed by the spec's sample size; split {@code n} receives
   * the rows whose positions in that ordering lie between boundary {@code n} (inclusive) and
   * boundary {@code n + 1} (exclusive). Each split is a new data set over the same variables,
   * named by the corresponding entry of the spec's split names.
   *
   * @param dataSet the data set whose rows are distributed over the splits
   * @param params supplies the shuffle flag, the number of splits, and the split spec
   * @return a list of data models holding one data set per split, in split order
   */
  public static DataModel createSplits(DataSet dataSet, SplitCasesParams params) {
    int rowCount = dataSet.getNumRows();
    List<Integer> rowOrder = new ArrayList<Integer>(rowCount);
    int next = 0;
    while (next < dataSet.getNumRows()) {
      rowOrder.add(next);
      next++;
    }

    if (params.isDataShuffled()) {
      Collections.shuffle(rowOrder);
    }

    SplitCasesSpec spec = params.getSpec();
    int numSplits = params.getNumSplits();
    int sampleSize = spec.getSampleSize();
    int[] innerBreaks = spec.getBreakpoints();
    List<String> splitNames = spec.getSplitNames();

    // Boundaries: leading 0 (array default), the spec's breakpoints, then the sample size.
    int[] bounds = new int[innerBreaks.length + 2];
    System.arraycopy(innerBreaks, 0, bounds, 1, innerBreaks.length);
    bounds[bounds.length - 1] = sampleSize;

    DataModelList splits = new DataModelList();
    int columnCount = dataSet.getNumColumns();

    for (int split = 0; split < numSplits; split++) {
      int end = bounds[split + 1];
      int start = bounds[split];
      int splitSize = end - start;

      DataSet part = new ColtDataSet(splitSize, dataSet.getVariables());
      part.setName(splitNames.get(split));

      for (int target = 0, position = start; target < splitSize; target++, position++) {
        int sourceRow = rowOrder.get(position);

        for (int col = 0; col < columnCount; col++) {
          Object value = dataSet.getObject(sourceRow, col);
          part.setObject(target, col, value);
        }
      }

      splits.add(part);
    }

    return splits;
  }