コード例 #1
0
  /**
   * Accepts a batch data set and splits it according to the configured expression.
   *
   * <p>When an expression root ({@code m_root}) is present, every instance is passed to
   * {@code m_root.evaluate(instance, true)}:
   *
   * <ul>
   *   <li>Instances for which it returns {@code true} are collected for the downstream step at
   *       {@code m_indexOfTrueStep}.
   *   <li>The remaining instances are collected for the step at {@code m_indexOfFalseStep}.
   *   <li>A negative index is treated as "no step" for that outcome, and the corresponding
   *       instances are dropped.
   * </ul>
   *
   * <p>Each step with a non-negative index then receives its batch, which may be empty. When
   * there is no expression root, the incoming data set is forwarded unchanged to the "true"
   * step, if its index is non-negative.
   *
   * <p>{@code m_busy} is cleared in a {@code finally} block. An exception thrown while
   * processing, for example by a downstream listener, therefore cannot leave this step
   * permanently marked busy.
   *
   * @param e the incoming data set event
   */
  @Override
  public void acceptDataSet(DataSetEvent e) {
    m_busy = true;
    try {
      Instances data = e.getDataSet();
      if (m_log != null && !e.isStructureOnly()) {
        m_log.statusMessage(statusMessagePrefix() + "Processing batch...");
      }

      // Set up against the header only (zero-capacity copy of the incoming structure).
      init(new Instances(data, 0));

      if (m_root != null) {
        Instances trueBatch = new Instances(data, 0);
        Instances falseBatch = new Instances(data, 0);

        for (int i = 0; i < data.numInstances(); i++) {
          Instance current = data.instance(i);
          // Always evaluate, even if neither outcome has a connected step.
          boolean result = m_root.evaluate(current, true);

          if (result) {
            if (m_indexOfTrueStep >= 0) {
              trueBatch.add(current);
            }
          } else if (m_indexOfFalseStep >= 0) {
            falseBatch.add(current);
          }
        }

        if (m_indexOfTrueStep >= 0) {
          ((DataSourceListener) m_downstream[m_indexOfTrueStep])
              .acceptDataSet(new DataSetEvent(this, trueBatch));
        }
        if (m_indexOfFalseStep >= 0) {
          ((DataSourceListener) m_downstream[m_indexOfFalseStep])
              .acceptDataSet(new DataSetEvent(this, falseBatch));
        }
      } else if (m_indexOfTrueStep >= 0) {
        // No expression: pass everything through on the "true" connection.
        ((DataSourceListener) m_downstream[m_indexOfTrueStep])
            .acceptDataSet(new DataSetEvent(this, data));
      }

      if (m_log != null && !e.isStructureOnly()) {
        m_log.statusMessage(statusMessagePrefix() + "Finished");
      }
    } finally {
      m_busy = false;
    }
  }
コード例 #2
0
ファイル: Sorter.java プロジェクト: CSLeicester/weka
  /**
   * Accept and process a data set event by sorting its instances.
   *
   * <p>A structure-only event is passed straight on to data listeners. For a full batch:
   *
   * <ol>
   *   <li>{@code init} is called with the incoming header. If it rejects the structure with an
   *       {@link IllegalArgumentException}, processing is aborted (with or without a log
   *       attached).
   *   <li>Otherwise every instance is wrapped in an {@code InstanceHolder}, and the holders are
   *       sorted with {@code m_sortComparator}.
   *   <li>A new data set containing the instances in sorted order is sent to data listeners.
   * </ol>
   *
   * <p>{@code m_busy} is set for the duration of the call and always cleared on exit, including
   * when an exception propagates.
   *
   * @param e a <code>DataSetEvent</code> value
   */
  @Override
  public void acceptDataSet(DataSetEvent e) {
    m_busy = true;
    m_stopRequested.set(false);

    try {
      Instances data = e.getDataSet();

      if (m_log != null && data.numInstances() > 0) {
        m_log.statusMessage(statusMessagePrefix() + "Sorting batch...");
      }

      if (e.isStructureOnly()) {
        // nothing to sort - just notify listeners of the structure
        notifyDataListeners(new DataSetEvent(this, data));
        return;
      }

      try {
        init(new Instances(data, 0));
      } catch (IllegalArgumentException ex) {
        String message = "ERROR: There is a problem with the incoming instance structure";
        if (m_log != null) {
          stopWithErrorMessage(message, ex);
        } else {
          // No log attached: report on stderr rather than silently swallowing the error.
          System.err.println(statusMessagePrefix() + message + " : " + ex.getMessage());
        }
        // Always abort. Continuing after init() rejected the structure would sort using
        // whatever state init() left behind.
        return;
      }

      int numInstances = data.numInstances();
      List<InstanceHolder> holders = new ArrayList<InstanceHolder>(numInstances);
      for (int i = 0; i < numInstances; i++) {
        InstanceHolder h = new InstanceHolder();
        h.m_instance = data.instance(i);
        holders.add(h);
      }
      Collections.sort(holders, m_sortComparator);

      Instances output = new Instances(data, 0);
      for (InstanceHolder h : holders) {
        output.add(h.m_instance);
      }

      notifyDataListeners(new DataSetEvent(this, output));

      if (m_log != null) {
        m_log.statusMessage(statusMessagePrefix() + "Finished.");
      }
    } finally {
      m_busy = false;
    }
  }
コード例 #3
0
  /**
   * Accept a data set and generate cross-validation folds from it.
   *
   * <p>A structure-only event is passed on immediately as both a training-set event and a
   * test-set event.
   *
   * <p>Otherwise, if no fold thread is active ({@code m_foldThread == null}), a copy of the data
   * is taken and a minimum-priority worker thread is started. The worker:
   *
   * <ul>
   *   <li>randomizes the copy, unless {@code m_preserveOrder} is set;
   *   <li>stratifies it when the class attribute is nominal, again unless order is preserved;
   *   <li>for each of {@code getFolds()} folds, emits a {@link TrainingSetEvent} followed by a
   *       {@link TestSetEvent}, both with this fold maker as the event source;
   *   <li>when finished, clears {@code m_foldThread} and calls {@code block(false)}.
   * </ul>
   *
   * <p>After starting the worker, the calling thread calls {@code block(true)}, presumably to
   * wait for that signal. If a fold thread is already active, the event is ignored.
   *
   * <p>The worker treats {@code m_foldThread} becoming {@code null} as a cancellation request.
   * NOTE(review): this field is written from more than one thread; confirm it is declared
   * {@code volatile} so the worker reliably observes the change.
   *
   * @param e a <code>DataSetEvent</code> value
   */
  @Override
  public void acceptDataSet(DataSetEvent e) {
    if (e.isStructureOnly()) {
      // Pass on structure to training and test set listeners
      TrainingSetEvent tse = new TrainingSetEvent(this, e.getDataSet());
      TestSetEvent tsee = new TestSetEvent(this, e.getDataSet());
      notifyTrainingSetProduced(tse);
      notifyTestSetProduced(tsee);
      return;
    }
    if (m_foldThread == null) {
      // Work on a copy so randomizing/stratifying does not disturb the caller's data.
      final Instances dataSet = new Instances(e.getDataSet());
      m_foldThread =
          new Thread() {
            @Override
            public void run() {
              boolean errorOccurred = false;
              try {
                Random random = new Random(getSeed());
                if (!m_preserveOrder) {
                  dataSet.randomize(random);
                }
                if (dataSet.classIndex() >= 0
                    && dataSet.attribute(dataSet.classIndex()).isNominal()
                    && !m_preserveOrder) {
                  dataSet.stratify(getFolds());
                  if (m_logger != null) {
                    m_logger.logMessage("[" + getCustomName() + "] " + "stratifying data");
                  }
                }

                for (int i = 0; i < getFolds(); i++) {
                  // m_foldThread is nulled elsewhere to request cancellation.
                  if (m_foldThread == null) {
                    if (m_logger != null) {
                      m_logger.logMessage(
                          "[" + getCustomName() + "] Cross validation has been canceled!");
                    }
                    // exit gracefully
                    break;
                  }
                  Instances train =
                      (!m_preserveOrder)
                          ? dataSet.trainCV(getFolds(), i, random)
                          : dataSet.trainCV(getFolds(), i);
                  Instances test = dataSet.testCV(getFolds(), i);

                  // Inform all training set listeners. The fold maker (not this anonymous
                  // Thread) is the event source, matching the structure-only path above.
                  TrainingSetEvent tse =
                      new TrainingSetEvent(CrossValidationFoldMaker.this, train);
                  tse.m_setNumber = i + 1;
                  tse.m_maxSetNumber = getFolds();
                  String msg =
                      getCustomName() + "$" + CrossValidationFoldMaker.this.hashCode() + "|";
                  if (m_logger != null) {
                    m_logger.statusMessage(
                        msg
                            + "seed: "
                            + getSeed()
                            + " folds: "
                            + getFolds()
                            + "|Training fold "
                            + (i + 1));
                  }
                  if (m_foldThread != null) {
                    notifyTrainingSetProduced(tse);
                  }

                  // Inform all test set listeners (same event source as above).
                  TestSetEvent teste = new TestSetEvent(CrossValidationFoldMaker.this, test);
                  teste.m_setNumber = i + 1;
                  teste.m_maxSetNumber = getFolds();

                  if (m_logger != null) {
                    m_logger.statusMessage(
                        msg
                            + "seed: "
                            + getSeed()
                            + " folds: "
                            + getFolds()
                            + "|Test fold "
                            + (i + 1));
                  }
                  if (m_foldThread != null) {
                    notifyTestSetProduced(teste);
                  }
                }
              } catch (Exception ex) {
                // stop all processing
                errorOccurred = true;
                if (m_logger != null) {
                  m_logger.logMessage(
                      "[" + getCustomName() + "] problem during fold creation. " + ex.getMessage());
                }
                ex.printStackTrace();
                // Qualified: Thread also declares a stop() method.
                CrossValidationFoldMaker.this.stop();
              } finally {
                m_foldThread = null;

                if (errorOccurred) {
                  if (m_logger != null) {
                    m_logger.statusMessage(
                        getCustomName()
                            + "$"
                            + CrossValidationFoldMaker.this.hashCode()
                            + "|"
                            + "ERROR (See log for details).");
                  }
                } else if (isInterrupted()) {
                  String msg = "[" + getCustomName() + "] Cross validation interrupted";
                  if (m_logger != null) {
                    m_logger.logMessage(msg);
                    m_logger.statusMessage(
                        getCustomName()
                            + "$"
                            + CrossValidationFoldMaker.this.hashCode()
                            + "|"
                            + "INTERRUPTED");
                  } else {
                    System.err.println(msg);
                  }
                } else {
                  String msg =
                      getCustomName() + "$" + CrossValidationFoldMaker.this.hashCode() + "|";
                  if (m_logger != null) {
                    m_logger.statusMessage(msg + "Finished.");
                  }
                }
                // Signal the caller blocked in block(true) below.
                block(false);
              }
            }
          };
      m_foldThread.setPriority(Thread.MIN_PRIORITY);
      m_foldThread.start();

      block(true);
      m_foldThread = null;
    }
  }