Code example #1
File: MultiScheme.java  Project: bigbigbug/wekax
  /**
   * buildClassifier selects a classifier from the set of classifiers by minimising error on the
   * training data.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    newData.randomize(new Random(m_Seed));
    if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
      newData.stratify(m_NumXValFolds);
    Instances train = newData; // train on all data by default
    Instances test = newData; // test on training data by default
    Classifier bestClassifier = null;
    int bestIndex = -1;
    double bestPerformance = Double.NaN;
    int numClassifiers = m_Classifiers.length;
    for (int i = 0; i < numClassifiers; i++) {
      Classifier currentClassifier = getClassifier(i);
      Evaluation evaluation;
      if (m_NumXValFolds > 1) {
        evaluation = new Evaluation(newData);
        for (int j = 0; j < m_NumXValFolds; j++) {
          train = newData.trainCV(m_NumXValFolds, j);
          test = newData.testCV(m_NumXValFolds, j);
          currentClassifier.buildClassifier(train);
          evaluation.setPriors(train);
          evaluation.evaluateModel(currentClassifier, test);
        }
      } else {
        currentClassifier.buildClassifier(train);
        evaluation = new Evaluation(train);
        evaluation.evaluateModel(currentClassifier, test);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + currentClassifier.getClass().getName());
      }

      if ((i == 0) || (error < bestPerformance)) {
        bestClassifier = currentClassifier;
        bestPerformance = error;
        bestIndex = i;
      }
    }
    m_ClassifierIndex = bestIndex;
    m_Classifier = bestClassifier;
    if (m_NumXValFolds > 1) {
      m_Classifier.buildClassifier(newData);
    }
  }
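A minimal usage sketch (not from the original project; the base classifiers and fold count below are illustrative assumptions). With a fold count greater than 1, the selection above is driven by cross-validation error rather than training error.

  // Hedged sketch: configuring MultiScheme to pick between two base learners.
  import weka.classifiers.Classifier;
  import weka.classifiers.bayes.NaiveBayes;
  import weka.classifiers.meta.MultiScheme;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;

  public static Classifier selectBest(Instances data) throws Exception {
    MultiScheme ms = new MultiScheme();
    ms.setClassifiers(new Classifier[] {new J48(), new NaiveBayes()}); // candidate set
    ms.setNumFolds(5); // > 1 selects by cross-validation error instead of training error
    ms.setSeed(1);
    ms.buildClassifier(data); // 'data' must have its class index set
    return ms;
  }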
Code example #2
File: SMOTE.java  Project: reacherxu/Graduation
  /**
   * Tests with the classifier.
   *
   * @param trainFileName the training corpus file
   * @param testFileName the test corpus file
   */
  public static void classify(String trainFileName, String testFileName) {
    try {
      File inputFile = new File(fileName + trainFileName); // training corpus file
      ArffLoader atf = new ArffLoader();
      atf.setFile(inputFile);
      Instances instancesTrain = atf.getDataSet(); // read in the training file

      // set the class label attribute
      inputFile = new File(fileName + testFileName); // test corpus file
      atf.setFile(inputFile);
      Instances instancesTest = atf.getDataSet(); // read in the test file

      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
      classifier.buildClassifier(instancesTrain);

      Evaluation eval = new Evaluation(instancesTrain);
      // the first argument is a trained classifier; the second is the dataset on which to evaluate it
      eval.evaluateModel(classifier, instancesTest);

      System.out.println(eval.toClassDetailsString());
      System.out.println(eval.toSummaryString());
      System.out.println(eval.toMatrixString());
      System.out.println("precision is :" + (1 - eval.errorRate()));

    } catch (Exception e) {
      e.printStackTrace();
    }
  }
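Note that 1 - errorRate() is the overall accuracy (fraction of correctly classified instances), not precision; per-class precision is what toClassDetailsString() reports. A small follow-up sketch, reusing eval from above:

      double accuracy = 1.0 - eval.errorRate(); // fraction of correct predictions
      System.out.println("accuracy is: " + accuracy);
      System.out.println("pctCorrect: " + eval.pctCorrect() + " %"); // the same quantity as a percentage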
Code example #3
File: LWL.java  Project: alishakiba/jDenetX
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (!(m_Classifier instanceof WeightedInstancesHandler)) {
      throw new IllegalArgumentException("Classifier must be a " + "WeightedInstancesHandler!");
    }

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    } else {
      m_ZeroR = null;
    }

    m_Train = new Instances(instances, 0, instances.numInstances());

    m_NNSearch.setInstances(m_Train);
  }
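A minimal caller sketch (assumed setup, not from jDenetX): LWL only accepts base learners that implement WeightedInstancesHandler, such as DecisionStump.

    LWL lwl = new LWL();
    lwl.setClassifier(new weka.classifiers.trees.DecisionStump()); // handles instance weights
    lwl.setKNN(10); // neighbourhood size used for the local weighting
    lwl.buildClassifier(train); // 'train' is an Instances object with its class index set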
Code example #4
  public static Instances getKnowledgeBase() {
    if (knowledgeBase == null) {
      try {
        // load knowledgebase from file
        CreateAppInsertIntoVm.knowledgeBase =
            Action.loadKnowledge(Configuration.getInstance().getKBCreateAppInsertIntoVm());

        // prediction is also performed, so the classifier and the evaluator must be
        // instantiated
        if (!isOnlyLearning()) {
          System.out.println("Classify data CreateAppInsertInto");
          if (knowledgeBase.numInstances() > 0) {
            classifier = new MultilayerPerceptron();
            classifier.buildClassifier(knowledgeBase);
            evaluation = new Evaluation(knowledgeBase);
            evaluation.crossValidateModel(
                classifier,
                knowledgeBase,
                10,
                knowledgeBase.getRandomNumberGenerator(randomData.nextLong(1, 1000)));
            System.out.println("Classified data CreateAppInsertInto");
          } else {
            System.out.println("No Instancedata for classifier CreateAppInsertIntoVm");
          }
        }
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
    return knowledgeBase;
  }
Code example #5
File: RealAdaBoost.java  Project: dachylong/weka
  /**
   * Boosting method. Boosts any classifier that can handle weighted instances.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @throws Exception if the classifier could not be built successfully
   */
  protected void buildClassifierWithWeights(Instances data) throws Exception {

    Instances trainData, training, trainingWeightsNotNormalized;
    int numInstances = data.numInstances();
    Random randomInstance = new Random(m_Seed);
    double minLoss = Double.MAX_VALUE;

    // Create a copy of the data so that when the weights are diddled
    // with it doesn't mess up the weights for anyone else
    trainingWeightsNotNormalized = new Instances(data, 0, numInstances);

    // Do boosting iterations
    for (m_NumIterationsPerformed = -1;
        m_NumIterationsPerformed < m_Classifiers.length;
        m_NumIterationsPerformed++) {
      if (m_Debug) {
        System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
      }

      training = new Instances(trainingWeightsNotNormalized);
      normalizeWeights(training, m_SumOfWeights);

      // Select instances to train the classifier on
      if (m_WeightThreshold < 100) {
        trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100);
      } else {
        trainData = new Instances(training, 0, numInstances);
      }

      // Build classifier
      if (m_NumIterationsPerformed == -1) {
        m_ZeroR = new weka.classifiers.rules.ZeroR();
        m_ZeroR.buildClassifier(data);
      } else {
        if (m_Classifiers[m_NumIterationsPerformed] instanceof Randomizable)
          ((Randomizable) m_Classifiers[m_NumIterationsPerformed])
              .setSeed(randomInstance.nextInt());
        m_Classifiers[m_NumIterationsPerformed].buildClassifier(trainData);
      }

      // Update instance weights
      setWeights(trainingWeightsNotNormalized, m_NumIterationsPerformed);

      // Has progress been made?
      double loss = 0;
      for (Instance inst : trainingWeightsNotNormalized) {
        loss += Math.log(inst.weight());
      }
      if (m_Debug) {
        System.err.println("Current loss on log scale: " + loss);
      }
      if ((m_NumIterationsPerformed > -1) && (loss > minLoss)) {
        if (m_Debug) {
          System.err.println("Loss has increased: bailing out.");
        }
        break;
      }
      minLoss = loss;
    }
  }
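The helper normalizeWeights is not shown above; a plausible sketch of such a step (an assumption, not the project's actual implementation) rescales the instance weights so they sum to a target value:

  // Hypothetical weight-normalisation helper (illustrative only):
  private static void normalizeWeights(Instances data, double targetSum) {
    double sum = 0;
    for (Instance inst : data) {
      sum += inst.weight();
    }
    for (Instance inst : data) {
      inst.setWeight(inst.weight() * targetSum / sum); // weights now sum to targetSum
    }
  }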
Code example #6
 /** trains the classifier */
 @Override
 public void train() throws Exception {
   if (_train.classIndex() == -1) _train.setClassIndex(_train.numAttributes() - 1);
   _cl.buildClassifier(_train);
   // evaluate classifier and print some statistics
   evaluate();
 }
Code example #7
File: jFMain.java  Project: andresusanto/TreeLearner
  private void jButton3ActionPerformed(
      java.awt.event.ActionEvent evt) { // GEN-FIRST:event_jButton3ActionPerformed
    // TODO add your handling code here:
    switch (jComboBox1.getSelectedIndex()) {
      case 0:
        model = new NaiveBayes();
        jTextArea1.append("Building NaiveBayes model from training data ...\n");
        break;
      case 1:
        model = new Id3();
        jTextArea1.append("Building ID3 model from training data ...\n");
        break;
      case 2:
        model = new J48();
        jTextArea1.append("Building J48 model from training data ...\n");
        break;
    }

    try {
      model.buildClassifier(training);
      jTextArea1.append("Model building is complete ...\n");
      jButton4.setEnabled(true);
      jButton6.setEnabled(true);
    } catch (Exception ex) {
      jTextArea1.append("Model building failed ...\n");
      jTextArea1.append(ex.getMessage());
      jTextArea1.append("\n");
      jButton4.setEnabled(true);
      jButton6.setEnabled(false);
      model = null;
    }
  } // GEN-LAST:event_jButton3ActionPerformed
Code example #8
File: WekaTest.java  Project: fsteeg/tm2
  /**
   * @param args
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
    Instances isTrainingSet = createSet(4);
    Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
    Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
    Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet);
    isTrainingSet.add(instance1);
    isTrainingSet.add(instance2);
    isTrainingSet.add(instance22);
    Instances isTestingSet = createSet(4);
    Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTestingSet);
    Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTestingSet);
    isTestingSet.add(instance3);
    isTestingSet.add(instance4);

    // Create a Bayesian network classifier
    Classifier cModel = (Classifier) new BayesNet(); // M5P
    cModel.buildClassifier(isTrainingSet);

    // Test the model
    Evaluation eTest = new Evaluation(isTrainingSet);
    eTest.evaluateModel(cModel, isTestingSet);

    // Print the result à la Weka explorer:
    String strSummary = eTest.toSummaryString();
    System.out.println(strSummary);

    // Get the likelihood of each class
    // fDistribution[0] is the probability of being “positive”
    // fDistribution[1] is the probability of being “negative”
    double[] fDistribution = cModel.distributionForInstance(instance4);
    for (int i = 0; i < fDistribution.length; i++) {
      System.out.println(fDistribution[i]);
    }
  }
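The helpers createSet and createInstance are not shown in this file; one plausible shape for them (hypothetical, written against the Weka 3.7+ DenseInstance API, assuming numeric attributes plus a nominal class with labels S1..S3):

  private static Instances createSet(int numAttrs) {
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int i = 0; i < numAttrs; i++) {
      atts.add(new Attribute("att" + i)); // numeric attribute
    }
    atts.add(new Attribute("class", Arrays.asList("S1", "S2", "S3"))); // nominal class
    Instances set = new Instances("dataset", atts, 0);
    set.setClassIndex(set.numAttributes() - 1);
    return set;
  }

  private static Instance createInstance(double[] values, String label, Instances dataset) {
    Instance inst = new DenseInstance(dataset.numAttributes()); // values start out missing
    inst.setDataset(dataset); // attach the header so the nominal label can be resolved
    for (int i = 0; i < values.length; i++) {
      inst.setValue(i, values[i]);
    }
    inst.setClassValue(label);
    return inst;
  }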
Code example #9
File: DecisionAnalyzer.java  Project: CaoAo/BeehiveZ
  /**
   * Analyses the decision points according to the given cluster decision analyzer, and provides
   * the analyzer with a visualization of the analysis result.
   *
   * @param cda the cluster decision analyzer supplying the data set and receiving the
   *     visualizations
   */
  public void analyse(ClusterDecisionAnalyzer cda) {
    clusterDecisionAnalyzer = cda;

    // create empty data set with attribute information
    Instances data = cda.getDataInfo();

    // in case not a single learning instance can be provided (because the decision
    // point is never reached, or the decision classes cannot be specified properly)
    // --> do not call the algorithm
    if (data.numInstances() == 0) {
      System.out.println("No learning instances available");
    }
    // actually solve the classification problem
    else {
      try {
        myClassifier.buildClassifier(data);
        // build up result visualization
        cda.setResultVisualization(createResultVisualization());
        cda.setEvaluationVisualization(createEvaluationVisualization(data));
      } catch (Exception ex) {
        ex.printStackTrace();
        cda.setResultVisualization(
            createMessagePanel("Error while solving the classification problem"));
      }
    }
  }
Code example #10
  public void run() throws Exception {
    BufferedReader datafileclassificationpickup =
        readDataFile(Config.outputPath() + "DaysPickUpClassification.txt");
    BufferedReader datafileclassificationdropoff =
        readDataFile(Config.outputPath() + "DaysDropOffClassification.txt");
    BufferedReader datafileregresssionpickup =
        readDataFile(Config.outputPath() + "DaysPickUpRegression.txt");
    BufferedReader datafileregresssiondropoff =
        readDataFile(Config.outputPath() + "DaysDropOffRegression.txt");

    dataclassificationpickup = new Instances(datafileclassificationpickup);
    dataclassificationpickup.setClassIndex(dataclassificationpickup.numAttributes() - 1);

    dataclassificationdropoff = new Instances(datafileclassificationdropoff);
    dataclassificationdropoff.setClassIndex(dataclassificationdropoff.numAttributes() - 1);

    dataregressionpickup = new Instances(datafileregresssionpickup);
    dataregressionpickup.setClassIndex(dataregressionpickup.numAttributes() - 1);

    dataregressiondropoff = new Instances(datafileregresssiondropoff);
    dataregressiondropoff.setClassIndex(dataregressiondropoff.numAttributes() - 1);

    System.out.println("KNN classification model");
    ibkclassificationpickup = new IBk(10);
    ibkclassificationpickup.buildClassifier(dataclassificationpickup);
    ibkclassificationdropoff = new IBk(10);
    ibkclassificationdropoff.buildClassifier(dataclassificationdropoff);
    System.out.println("Classification Model Ready");

    System.out.println("KNN regression model");
    ibkregressionpickup = new IBk(10);
    ibkregressionpickup.buildClassifier(dataregressionpickup);
    ibkregressiondropoff = new IBk(10);
    ibkregressiondropoff.buildClassifier(dataregressiondropoff);
    System.out.println("Regression Model Ready");

    instclassificationpickup = new DenseInstance(9);
    instclassificationpickup.setDataset(dataclassificationpickup);
    instclassificationdropoff = new DenseInstance(9);
    instclassificationdropoff.setDataset(dataclassificationdropoff);
    instregressionpickup = new DenseInstance(9);
    instregressionpickup.setDataset(dataregressionpickup);
    instregressiondropoff = new DenseInstance(9);
    instregressiondropoff.setDataset(dataregressiondropoff);
    System.out.println("Models ready");
  }
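After run(), the prepared instances would typically be filled and queried like this (a sketch; the attribute meanings are assumptions since the ARFF layout is not shown, and instregressionpickup is assumed to be filled the same way):

    for (int a = 0; a < 8; a++) {
      instclassificationpickup.setValue(a, 0.0); // placeholder feature values
    }
    double pickupClass = ibkclassificationpickup.classifyInstance(instclassificationpickup);
    // for the regression model, classifyInstance returns the numeric estimate
    double pickupEstimate = ibkregressionpickup.classifyInstance(instregressionpickup);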
Code example #11
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {
    if (!Groovy.isPresent()) throw new Exception("Groovy classes not in CLASSPATH!");

    // try loading the module
    initGroovyObject();

    // build the model
    if (m_GroovyObject != null) m_GroovyObject.buildClassifier(instances);
    else System.err.println("buildClassifier: No Groovy object present!");
  }
Code example #12
 public Classifier buildWekaClassifier(Instances wekaInstances) {
   bayesNet = new BayesNet();
   try {
     bayesNet.buildClassifier(wekaInstances);
   } catch (Exception e) {
     System.err.println(e.getMessage());
     e.printStackTrace();
     System.exit(-1);
   }
   return bayesNet;
 }
Code example #13
File: CNode.java  Project: Waikato/meka
 /**
  * Build - Create transformation for this node, and train classifier of type H upon it. The
  * dataset should have class as index 'j', and remove all indices less than L *not* in paY.
  */
 public void build(Instances D, Classifier H) throws Exception {
   // transform data
   T = transform(D);
   // build SLC 'h'
   h = AbstractClassifier.makeCopy(H);
   h.buildClassifier(T);
   // save templates
   // t_ = new SparseInstance(T.numAttributes());
   // t_.setDataset(T);
   // t_.setClassMissing();								// [?,x,x,x]
   T.clear();
 }
Code example #14
  /** runs 10-fold CV over the training file */
  public void execute() throws Exception {
    // run filter
    m_Filter.setInputFormat(m_Training);
    Instances filtered = Filter.useFilter(m_Training, m_Filter);

    // train classifier on complete file for tree
    m_Classifier.buildClassifier(filtered);

    // 10-fold CV with seed=1
    m_Evaluation = new Evaluation(filtered);

    m_Evaluation.crossValidateModel(
        m_Classifier, filtered, 10, m_Training.getRandomNumberGenerator(1));
  }
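Filtering the complete file before cross-validating lets the filter see the test folds; wrapping the pair in a FilteredClassifier instead re-applies the filter inside each training fold. A sketch of that alternative, reusing the same fields:

    FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(m_Filter);
    fc.setClassifier(m_Classifier);
    m_Evaluation = new Evaluation(m_Training);
    m_Evaluation.crossValidateModel(fc, m_Training, 10, m_Training.getRandomNumberGenerator(1));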
Code example #15
 @Override
 public final void run() {
   try {
     Classifier copiedClassifier = AbstractClassifier.makeCopy(classifier);
     copiedClassifier.buildClassifier(train);
     log.print("The " + threadId + "th classifier is built!!!");
     accuracy = getAccuracy(copiedClassifier, test);
   } catch (Exception e) {
      log.print(java.util.Arrays.toString(e.getStackTrace())); // toString() on the array itself would not print the trace
     log.print(e.toString());
   }
   multiThreadEval.finishOneThreads();
   log.print("The " + threadId + "th thread is finshed! accuracy = " + accuracy);
 }
Code example #16
File: Driver.java  Project: illes/multimodal
  private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
      throws Exception {
    System.err.println(
        "INFO: Starting split validation to predict '"
            + trainData.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' (#train="
            + trainData.numInstances()
            + ",#test="
            + testData.numInstances()
            + ") ...");

    if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set");

    c.buildClassifier(trainData);
    Evaluation eval = new Evaluation(testData);
    eval.useNoPriors();
    double[] predictions = eval.evaluateModel(c, testData);

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    // write predictions to file
    {
      System.err.println("INFO: Writing predictions to file ...");
      Writer out = new FileWriter("prediction.trec");
      writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
      out.close();
    }

    // write predicted distributions to CSV
    {
      System.err.println("INFO: Writing predicted distributions to CSV ...");
      Writer out = new FileWriter("predicted_distribution.csv");
      writePredictedDistributions(c, testData, 0, out);
      out.close();
    }
  }
Code example #17
File: RemoveMisclassified.java  Project: naranil/weka
  /**
   * Cleanses the data based on misclassifications when performing cross-validation.
   *
   * @param data the data to train with and cleanse
   * @return the cleansed data
   * @throws Exception if something goes wrong
   */
  private Instances cleanseCross(Instances data) throws Exception {

    Instance inst;
    Instances crossSet = new Instances(data);
    Instances temp = new Instances(data, data.numInstances());
    Instances inverseSet = new Instances(data, data.numInstances());
    int count = 0;
    double ans;
    int iterations = 0;
    int classIndex = m_classIndex;
    if (classIndex < 0) {
      classIndex = data.classIndex();
    }
    if (classIndex < 0) {
      classIndex = data.numAttributes() - 1;
    }

    // loop until perfect
    while (count != crossSet.numInstances()
        && crossSet.numInstances() >= m_numOfCrossValidationFolds) {

      count = crossSet.numInstances();

      // check if hit maximum number of iterations
      iterations++;
      if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) {
        break;
      }

      crossSet.setClassIndex(classIndex);

      if (crossSet.classAttribute().isNominal()) {
        crossSet.stratify(m_numOfCrossValidationFolds);
      }
      // do the folds
      temp = new Instances(crossSet, crossSet.numInstances());

      for (int fold = 0; fold < m_numOfCrossValidationFolds; fold++) {
        Instances train = crossSet.trainCV(m_numOfCrossValidationFolds, fold);
        m_cleansingClassifier.buildClassifier(train);
        Instances test = crossSet.testCV(m_numOfCrossValidationFolds, fold);
        // now test
        for (int i = 0; i < test.numInstances(); i++) {
          inst = test.instance(i);
          ans = m_cleansingClassifier.classifyInstance(inst);
          if (crossSet.classAttribute().isNumeric()) {
            if (ans >= inst.classValue() - m_numericClassifyThreshold
                && ans <= inst.classValue() + m_numericClassifyThreshold) {
              temp.add(inst);
            } else if (m_invertMatching) {
              inverseSet.add(inst);
            }
          } else { // class is nominal
            if (ans == inst.classValue()) {
              temp.add(inst);
            } else if (m_invertMatching) {
              inverseSet.add(inst);
            }
          }
        }
      }
      crossSet = temp;
    }

    if (m_invertMatching) {
      inverseSet.setClassIndex(data.classIndex());
      return inverseSet;
    } else {
      crossSet.setClassIndex(data.classIndex());
      return crossSet;
    }
  }
Code example #18
File: RemoveMisclassified.java  Project: naranil/weka
  /**
   * Cleanses the data based on misclassifications when using the training data.
   *
   * @param data the data to train with and cleanse
   * @return the cleansed data
   * @throws Exception if something goes wrong
   */
  private Instances cleanseTrain(Instances data) throws Exception {

    Instance inst;
    Instances buildSet = new Instances(data);
    Instances temp = new Instances(data, data.numInstances());
    Instances inverseSet = new Instances(data, data.numInstances());
    int count = 0;
    double ans;
    int iterations = 0;
    int classIndex = m_classIndex;
    if (classIndex < 0) {
      classIndex = data.classIndex();
    }
    if (classIndex < 0) {
      classIndex = data.numAttributes() - 1;
    }

    // loop until perfect
    while (count != buildSet.numInstances()) {

      // check if hit maximum number of iterations
      iterations++;
      if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) {
        break;
      }

      // build classifier
      count = buildSet.numInstances();
      buildSet.setClassIndex(classIndex);
      m_cleansingClassifier.buildClassifier(buildSet);

      temp = new Instances(buildSet, buildSet.numInstances());

      // test on training data
      for (int i = 0; i < buildSet.numInstances(); i++) {
        inst = buildSet.instance(i);
        ans = m_cleansingClassifier.classifyInstance(inst);
        if (buildSet.classAttribute().isNumeric()) {
          if (ans >= inst.classValue() - m_numericClassifyThreshold
              && ans <= inst.classValue() + m_numericClassifyThreshold) {
            temp.add(inst);
          } else if (m_invertMatching) {
            inverseSet.add(inst);
          }
        } else { // class is nominal
          if (ans == inst.classValue()) {
            temp.add(inst);
          } else if (m_invertMatching) {
            inverseSet.add(inst);
          }
        }
      }
      buildSet = temp;
    }

    if (m_invertMatching) {
      inverseSet.setClassIndex(data.classIndex());
      return inverseSet;
    } else {
      buildSet.setClassIndex(data.classIndex());
      return buildSet;
    }
  }
Code example #19
File: Decorate.java  Project: paolopavan/cfr
  /**
   * Build Decorate classifier
   *
   * @param data the training data to be used for generating the classifier
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    if (data.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    if (data.classAttribute().isNumeric()) {
      throw new UnsupportedClassTypeException("Decorate can't handle a numeric class!");
    }
    if (m_NumIterations < m_DesiredSize)
      throw new Exception("Max number of iterations must be >= desired ensemble size!");

    // initialize random number generator
    if (m_Seed == -1) m_Random = new Random();
    else m_Random = new Random(m_Seed);

    int i = 1; // current committee size
    int numTrials = 1; // number of Decorate iterations
    Instances divData = new Instances(data); // local copy of data - diversity data
    divData.deleteWithMissingClass();
    Instances artData = null; // artificial data

    // compute the number of artificial instances to add at each iteration
    int artSize = (int) (Math.abs(m_ArtSize) * divData.numInstances());
    if (artSize == 0) artSize = 1; // add at least one random example
    computeStats(data); // Compute training data stats for creating artificial examples

    // initialize new committee
    m_Committee = new Vector();
    Classifier newClassifier = m_Classifier;
    newClassifier.buildClassifier(divData);
    m_Committee.add(newClassifier);
    double eComm = computeError(divData); // compute ensemble error
    if (m_Debug)
      System.out.println(
          "Initialize:\tClassifier " + i + " added to ensemble. Ensemble error = " + eComm);

    // repeat till desired committee size is reached OR the max number of iterations is exceeded
    while (i < m_DesiredSize && numTrials < m_NumIterations) {
      // Generate artificial training examples
      artData = generateArtificialData(artSize, data);

      // Label artificial examples
      labelData(artData);
      addInstances(divData, artData); // Add new artificial data

      // Build new classifier
      Classifier tmp[] = Classifier.makeCopies(m_Classifier, 1);
      newClassifier = tmp[0];
      newClassifier.buildClassifier(divData);
      // Remove all the artificial data
      removeInstances(divData, artSize);

      // Test if the new classifier should be added to the ensemble
      m_Committee.add(newClassifier); // add new classifier to current committee
      double currError = computeError(divData);
      if (currError <= eComm) { // adding the new member did not increase the error
        i++;
        eComm = currError;
        if (m_Debug)
          System.out.println(
              "Iteration: "
                  + (1 + numTrials)
                  + "\tClassifier "
                  + i
                  + " added to ensemble. Ensemble error = "
                  + eComm);
      } else { // reject the current classifier because it increased the ensemble error
        m_Committee.removeElementAt(m_Committee.size() - 1); // pop the last member
      }
      numTrials++;
    }
  }
Code example #20
  public static void main(String[] args) throws Exception {

    BufferedReader reader = new BufferedReader(new FileReader("spambase.arff"));
    Instances data = new Instances(reader);
    reader.close();
    // setting class attribute
    data.setClassIndex(data.numAttributes() - 1);

    int i = data.numInstances();
    int j = data.numAttributes() - 1;

    File file = new File("tablelog.csv");
    Writer output = null;
    output = new BufferedWriter(new FileWriter(file));
    output.write(
        "%missing,auc1,correct1,fmeasure1,auc2,correct2,fmeasure2,auc3,correct3,fmeasure3\n");

    Random randomGenerator = new Random();
    data.randomize(randomGenerator);
    int numBlock = data.numInstances() / 2;
    double num0 = 0, num1 = 0, num2 = 0;
    /*mdata.instance(0).setMissing(0);
    mdata.deleteWithMissing(0);
    System.out.println(mdata.numInstances()+","+data.numInstances());*/
    // Instances traindata=null;
    // Instances testdata=null;
    // System.out.println(data.instance(3).stringValue(1));
    for (int perc = 10; perc < 101; perc = perc + 10) {
      Instances mdata = new Instances(data);
      int numMissing = perc * numBlock / 100;
      double y11[] = new double[2];
      double y21[] = new double[2];
      double y31[] = new double[2];
      double y12[] = new double[2];
      double y22[] = new double[2];
      double y32[] = new double[2];
      double y13[] = new double[2];
      double y23[] = new double[2];
      double y33[] = new double[2];
      for (int p = 0; p < 2; p++) {
        Instances traindata = mdata.trainCV(2, p);
        Instances testdata = mdata.testCV(2, p);
        num0 = 0;
        num1 = 0;
        num2 = 0;
        for (int t = 0; t < numBlock; t++) {
          if (traindata.instance(t).classValue() == 0) num0++;
          if (traindata.instance(t).classValue() == 1) num1++;
          // if (traindata.instance(t).classValue()==2) num2++;
        }
        // System.out.println(mdata.instance(0).classValue());
        Instances trainwithmissing = new Instances(traindata);
        Instances testwithmissing = new Instances(testdata);
        for (int q = 0; q < j; q++) {
          int r = randomGenerator.nextInt((int) i / 2);

          for (int k = 0; k < numMissing; k++) {
            // int r = randomGenerator.nextInt((int) i/2);
            // int c = randomGenerator.nextInt(j);
            trainwithmissing.instance((r + k) % numBlock).setMissing(q);
            testwithmissing.instance((r + k) % numBlock).setMissing(q);
          }
        }
        // trainwithmissing.deleteWithMissing(0);System.out.println(traindata.numInstances()+","+trainwithmissing.numInstances());
        Classifier cModel =
            (Classifier) new Logistic(); // try for different classifiers and datasets
        cModel.buildClassifier(trainwithmissing);
        Evaluation eTest1 = new Evaluation(trainwithmissing);
        eTest1.evaluateModel(cModel, testdata);
        // eTest.crossValidateModel(cModel,mdata,10,mdata.getRandomNumberGenerator(1));
        y11[p] =
            num0 / numBlock * eTest1.areaUnderROC(0)
                + num1
                    / numBlock
                    * eTest1.areaUnderROC(1) /*+num2/numBlock*eTest1.areaUnderROC(2)*/;
        y21[p] = eTest1.correct();
        y31[p] =
            num0 / numBlock * eTest1.fMeasure(0)
                + num1 / numBlock * eTest1.fMeasure(1) /*+num2/numBlock*eTest1.fMeasure(2)*/;

        Classifier cModel2 = (Classifier) new Logistic();
        cModel2.buildClassifier(traindata);
        Evaluation eTest2 = new Evaluation(traindata);
        eTest2.evaluateModel(cModel2, testwithmissing);
        y12[p] =
            num0 / numBlock * eTest2.areaUnderROC(0)
                + num1
                    / numBlock
                    * eTest2.areaUnderROC(1) /*+num2/numBlock*eTest2.areaUnderROC(2)*/;
        y22[p] = eTest2.correct();
        y32[p] =
            num0 / numBlock * eTest2.fMeasure(0)
                + num1 / numBlock * eTest2.fMeasure(1) /*+num2/numBlock*eTest2.fMeasure(2)*/;

        Classifier cModel3 = (Classifier) new Logistic();
        cModel3.buildClassifier(trainwithmissing);
        Evaluation eTest3 = new Evaluation(trainwithmissing);
        eTest3.evaluateModel(cModel3, testwithmissing);
        y13[p] =
            num0 / numBlock * eTest3.areaUnderROC(0)
                + num1
                    / numBlock
                    * eTest3.areaUnderROC(1) /*+num2/numBlock*eTest3.areaUnderROC(2)*/;
        y23[p] = eTest3.correct();
        y33[p] =
            num0 / numBlock * eTest3.fMeasure(0)
                + num1 / numBlock * eTest3.fMeasure(1) /*+num2/numBlock*eTest3.fMeasure(2)*/;
        // System.out.println(num0+","+num1+","+num2+"\n");
      }
      double auc1 = (y11[0] + y11[1]) / 2;
      double auc2 = (y12[0] + y12[1]) / 2;
      double auc3 = (y13[0] + y13[1]) / 2;
      double corr1 = (y21[0] + y21[1]) / i;
      double corr2 = (y22[0] + y22[1]) / i;
      double corr3 = (y23[0] + y23[1]) / i;
      double fm1 = (y31[0] + y31[1]) / 2;
      double fm2 = (y32[0] + y32[1]) / 2;
      double fm3 = (y33[0] + y33[1]) / 2;
      output.write(
          perc + "," + auc1 + "," + corr1 + "," + fm1 + "," + auc2 + "," + corr2 + "," + fm2 + ","
              + auc3 + "," + corr3 + "," + fm3 + "\n"); // System.out.println(num0);
      // mdata=data;

    }
    output.close();
  }
Code example #21
  // Takes a question as input and returns the type the question belongs to.
  public double classifyByBayes(String question) throws Exception {
    double label = -1;
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    // define the data format
    Attribute att1 = new Attribute("法律政策");
    Attribute att2 = new Attribute("位置交通");
    Attribute att3 = new Attribute("风水");
    Attribute att4 = new Attribute("房价");
    Attribute att5 = new Attribute("楼层");
    Attribute att6 = new Attribute("户型");
    Attribute att7 = new Attribute("小区配套");
    Attribute att8 = new Attribute("贷款");
    Attribute att9 = new Attribute("买房时机");
    Attribute att10 = new Attribute("开发商");
    FastVector labels = new FastVector();
    labels.addElement("1");
    labels.addElement("2");
    labels.addElement("3");
    labels.addElement("4");
    labels.addElement("5");
    labels.addElement("6");
    labels.addElement("7");
    labels.addElement("8");
    labels.addElement("9");
    labels.addElement("10");
    Attribute att11 = new Attribute("类别", labels);

    FastVector attributes = new FastVector();
    attributes.addElement(att1);
    attributes.addElement(att2);
    attributes.addElement(att3);
    attributes.addElement(att4);
    attributes.addElement(att5);
    attributes.addElement(att6);
    attributes.addElement(att7);
    attributes.addElement(att8);
    attributes.addElement(att9);
    attributes.addElement(att10);
    attributes.addElement(att11);
    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(10);

    Classifier classifier = null;
    if (!new File("naivebayes.model").exists()) {
      // add the data
      double[] values = new double[11];
      for (int i = 0; i < questionID.size(); i++) {
        for (int m = 0; m < 11; m++) {
          values[m] = 0;
        }
        int whitewordcount = 0;
        whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
        if (whitewordcount != 0) {
          List<QuestionWhiteWord> questionwhiteword =
              questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
          for (int j = 0; j < questionwhiteword.size(); j++) {
            values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
          }
          for (int m = 0; m < 11; m++) {
            values[m] = values[m] / whitewordcount;
          }
        }
        values[10] = questionID.get(i).getType() - 1;
        Instance inst = new Instance(1.0, values);
        dataset.add(inst);
      }
      // build the classifier
      classifier = new NaiveBayes();
      classifier.buildClassifier(dataset);
      SerializationHelper.write("naivebayes.model", classifier);
    } else {
      classifier = (Classifier) SerializationHelper.read("naivebayes.model");
    }

    System.out.println("*************begin evaluation*******************");
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.evaluateModel(classifier, dataset); // strictly speaking, a separate data set should be used here, not the training set
    System.out.println(evaluation.toSummaryString());

    // classify
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    //        double dis[]=classifier.distributionForInstance(inst);
    //        for(double i:dis){
    //            System.out.print(i+"    ");
    //        }

    System.out.println(questionID.size());
    return label + 1;
  }
Code example #22
 @Override
 public void buildClassifier(Instances data) throws Exception {
   for (Classifier classifier : classifiers) {
     classifier.buildClassifier(data);
   }
 }
Code example #23
File: MyWekaExplorer.java  Project: Teofebano19/MyANN
 public void buildClassifier(Classifier classifier) throws Exception {
   this.classifier = classifier;
   classifier.buildClassifier(trainingData);
 }
Code example #24
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // only class? -> build ZeroR model
    if (instances.numAttributes() == 1) {
      System.err.println(
          "Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_ZeroR = new weka.classifiers.rules.ZeroR();
      m_ZeroR.buildClassifier(instances);
      return;
    } else {
      m_ZeroR = null;
    }

    // reset variable
    m_NumClasses = instances.numClasses();
    m_ClassIndex = instances.classIndex();
    m_NumAttributes = instances.numAttributes();
    m_NumInstances = instances.numInstances();
    m_TotalAttValues = 0;

    // allocate space for attribute reference arrays
    m_StartAttIndex = new int[m_NumAttributes];
    m_NumAttValues = new int[m_NumAttributes];

    // set the starting index of each attribute and the number of values for
    // each attribute and the total number of values for all attributes (not including class).
    for (int i = 0; i < m_NumAttributes; i++) {
      if (i != m_ClassIndex) {
        m_StartAttIndex[i] = m_TotalAttValues;
        m_NumAttValues[i] = instances.attribute(i).numValues();
        m_TotalAttValues += m_NumAttValues[i];
      } else {
        m_StartAttIndex[i] = -1;
        m_NumAttValues[i] = m_NumClasses;
      }
    }

    // allocate space for counts and frequencies
    m_ClassCounts = new double[m_NumClasses];
    m_AttCounts = new double[m_TotalAttValues];
    m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues];
    m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
    m_Header = new Instances(instances, 0);

    // Calculate the counts
    for (int k = 0; k < m_NumInstances; k++) {
      int classVal = (int) instances.instance(k).classValue();
      m_ClassCounts[classVal]++;
      int[] attIndex = new int[m_NumAttributes];
      for (int i = 0; i < m_NumAttributes; i++) {
        if (i == m_ClassIndex) {
          attIndex[i] = -1;
        } else {
          attIndex[i] = m_StartAttIndex[i] + (int) instances.instance(k).value(i);
          m_AttCounts[attIndex[i]]++;
        }
      }
      for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) {
        if (attIndex[Att1] == -1) continue;
        for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) {
          if ((attIndex[Att2] != -1)) {
            m_AttAttCounts[attIndex[Att1]][attIndex[Att2]]++;
            m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++;
          }
        }
      }
    }

    // compute mutual information between each attribute and class
    m_mutualInformation = new double[m_NumAttributes];
    for (int att = 0; att < m_NumAttributes; att++) {
      if (att == m_ClassIndex) continue;
      m_mutualInformation[att] = mutualInfo(att);
    }
  }
Code example #25
 /**
  * Text classification is a special case: when a StringToWordVector object computes term
  * (attribute) weights it needs corpus-wide statistics such as DF, so the instances must be
  * processed as a batch. Note that in Weka some learning algorithms are batch-based and some
  * are not.
  */
 public void finishBatch() throws Exception {
   filter.setIDFTransform(true);
   filter.setInputFormat(instances);
   Instances filteredData = Filter.useFilter(instances, filter); // this actually produces a data set in the input format Weka algorithms expect
   classifier.buildClassifier(filteredData); // actually train the classifier
 }
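A typical batch setup for the filter referenced above might look like this (a sketch; the parameter choices are assumptions):

   StringToWordVector filter = new StringToWordVector();
   filter.setTFTransform(true); // log-scaled term frequency
   filter.setIDFTransform(true); // weight terms by inverse document frequency
   filter.setLowerCaseTokens(true);
   filter.setInputFormat(instances); // learns the dictionary and DF counts from the whole batch
   Instances vectors = Filter.useFilter(instances, filter);
   classifier.buildClassifier(vectors);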
Code example #26
  public void exec(PrintWriter printer) {
    try {
      FileWriter outFile = null;
      PrintWriter out = null;
      if (printer == null) {
        outFile = new FileWriter(id + ".results");
        out = new PrintWriter(outFile);
      } else out = printer;

      DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
      ProcessTweets tweetsProcessor = null;
      System.out.println("***************************************");
      System.out.println("***\tEXECUTING TEST\t" + id + "***");
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.println("Train size:" + traincorpus.size());
      System.out.println("Test size:" + testcorpus.size());
      out.println("***************************************");
      out.println("***\tEXECUTING TEST\t***");
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.println("Train size:" + traincorpus.size());
      out.println("Test size:" + testcorpus.size());
      String cloneID = "";
      boolean clonar = false;
      if (baseline) {
        System.out.println("***************************************");
        System.out.println("***\tEXECUTING TEST BASELINE\t***");
        System.out.println("+++++++++++++++++++++++++++++++++++++++");
        System.out.println("Train size:" + traincorpus.size());
        System.out.println("Test size:" + testcorpus.size());
        out.println("***************************************");
        out.println("***\tEXECUTING TEST\t***");
        out.println("+++++++++++++++++++++++++++++++++++++++");
        out.println("Train size:" + traincorpus.size());
        out.println("Test size:" + testcorpus.size());

        BaselineClassifier base = new BaselineClassifier(testcorpus, 8);
        precision = base.getPrecision();
        recall = base.getRecall();
        fmeasure = base.getFmeasure();
        System.out.println("+++++++++++++++++++++++++++++++++++++++");
        System.out.printf("Precision: %.3f\n", precision);
        System.out.printf("Recall: %.3f\n", recall);
        System.out.printf("F-measure: %.3f\n", fmeasure);
        System.out.println("***************************************");
        out.println("+++++++++++++++++++++++++++++++++++++++");
        out.printf("Precision: %.3f\n", precision);
        out.printf("Recall: %.3f\n", recall);
        out.printf("F-measure: %.3f\n", fmeasure);
        out.println("***************************************");
        out.flush();
        out.close();
        return;
      } else {
        System.out.println("Stemming: " + stemming);
        System.out.println("Lematization:" + lematization);
        System.out.println("URLs:" + urls);
        System.out.println("Hashtags:" + hashtags);
        System.out.println("Mentions:" + mentions);
        System.out.println("Unigrams:" + unigrams);
        System.out.println("Bigrams:" + bigrams);
        System.out.println("TF:" + tf);
        System.out.println("TF-IDF:" + tfidf);
        out.println("Stemming: " + stemming);
        out.println("Lematization:" + lematization);
        out.println("URLs:" + urls);
        out.println("Hashtags:" + hashtags);
        out.println("Mentions:" + mentions);
        out.println("Unigrams:" + unigrams);
        out.println("Bigrams:" + bigrams);
        out.println("TF:" + tf);
        out.println("TF-IDF:" + tfidf);
      }
      // If the tweets are already processed, skip reprocessing
      System.out.println("1-Process tweets " + dateFormat.format(new Date()));
      out.println("1-Process tweets " + dateFormat.format(new Date()));

      List<ProcessedTweet> train = null;
      String[] ids = id.split("-");
      cloneID = ids[0] + "-" + (Integer.valueOf(ids[1]) + 6);
      if (((Integer.valueOf(ids[1]) / 6) % 2) == 0) clonar = true;

      if (new File(id + "-train.ptweets").exists()) {
        train = ProcessedTweetSerialization.fromFile(id + "-train.ptweets");
        tweetsProcessor =
            new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
        if (lematization) {
          tweetsProcessor.doLematization(train);
        }
        if (stemming) {
          tweetsProcessor.doStemming(train);
        }
      } else {
        tweetsProcessor =
            new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
        // The setTraining flag is an addition to distinguish the languages of the URLs in the
        // parallel corpus
        //				tweetsProcessor.setTraining(true);
        train = tweetsProcessor.processTweets(traincorpus);
        //				tweetsProcessor.setTraining(false);
        ProcessedTweetSerialization.toFile(id + "-train.ptweets", train);
        /*
        				if (clonar)
        				{
        					File f = new File (id+"-train.ptweets");
        					Path p = f.toPath();
        					CopyOption[] options = new CopyOption[]{
        						      StandardCopyOption.REPLACE_EXISTING,
        						      StandardCopyOption.COPY_ATTRIBUTES
        						     };
        					Files.copy(p, new File (cloneID+"-train.ptweets").toPath(), options);
        					Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+12)+"-train.ptweets").toPath(), options);
        					Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+18)+"-train.ptweets").toPath(), options);
        					Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+24)+"-train.ptweets").toPath(), options);
        					Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+30)+"-train.ptweets").toPath(), options);
        				}
        */
      }

      // Generate the BOWs. As before, do not create them if they already exist.
      System.out.println("2-Fill topics " + dateFormat.format(new Date()));
      out.println("2-Fill topics " + dateFormat.format(new Date()));
      TopicsList topics = null;
      if (new File(id + ".topics").exists()) {
        topics = TopicsSerialization.fromFile(id + ".topics");
        if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF);
        else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF);
        topics.prepareTopics();
      } else {

        topics = new TopicsList();
        if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF);
        else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF);
        System.out.println("Filling topics " + dateFormat.format(new Date()));
        topics.fillTopics(train);
        System.out.println("Preparing topics topics " + dateFormat.format(new Date()));
        // Serialization has to happen before preparing; otherwise the tf and tf-idf values
        // cannot be computed
        System.out.println("Serializing topics " + dateFormat.format(new Date()));
        /*
        				if (clonar)
        				{
        					TopicsSerialization.toFile(cloneID+".topics", topics);
        				}
        */
        topics.prepareTopics();
        TopicsSerialization.toFile(id + ".topics", topics);
      }
      System.out.println("3-Generate arff train file " + dateFormat.format(new Date()));
      out.println("3-Generate arff train file " + dateFormat.format(new Date()));

      // If the arff file does not exist, create it; otherwise reuse the work done
      // so far, as before
      if (!new File(id + "-train.arff").exists()) {

        BufferedWriter bw = topics.generateArffHeader(id + "-train.arff");
        int tope = traincorpus.size();
        if (tweetsProcessor == null)
          tweetsProcessor =
              new ProcessTweets(
                  stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
        for (int indTweet = 0; indTweet < tope; indTweet++) {
          topics.generateArffVector(bw, train.get(indTweet));
        }
        bw.flush();
        bw.close();
      }

      // Now process the test data
      System.out.println("5-build test dataset " + dateFormat.format(new Date()));
      out.println("5-build test dataset " + dateFormat.format(new Date()));

      List<ProcessedTweet> test = null;
      if (new File(id + "-test.ptweets").exists())
        test = ProcessedTweetSerialization.fromFile(id + "-test.ptweets");
      else {
        if (tweetsProcessor == null)
          tweetsProcessor =
              new ProcessTweets(
                  stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
        test = tweetsProcessor.processTweets(testcorpus);
        ProcessedTweetSerialization.toFile(id + "-test.ptweets", test);
        /*
        				if (clonar)
        				{
        					File f = new File (id+"-test.ptweets");
        					Path p = f.toPath();
        					CopyOption[] options = new CopyOption[]{
        						      StandardCopyOption.REPLACE_EXISTING,
        						      StandardCopyOption.COPY_ATTRIBUTES
        						     };
        					Files.copy(p, new File (cloneID+"-test.ptweets").toPath(), options);
        				}
        */

      }

      // If the arff file does not exist, create it; otherwise reuse the work done
      // so far, as before
      if (!new File(id + "-test.arff").exists()) {
        BufferedWriter bw = topics.generateArffHeader(id + "-test.arff");
        int tope = testcorpus.size();
        if (tweetsProcessor == null)
          tweetsProcessor =
              new ProcessTweets(
                  stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
        for (int indTweet = 0; indTweet < tope; indTweet++) {
          topics.generateArffVector(bw, test.get(indTweet));
        }
        bw.flush();
        bw.close();
      }
      int topeTopics = topics.getTopicsList().size();
      topics.getTopicsList().clear();
      // Build the classifier
      // FJRM 25-08-2013: reordered to try to free the topics' memory and have more
      // available
      System.out.println("4-Generate classifier " + dateFormat.format(new Date()));
      out.println("4-Generate classifier " + dateFormat.format(new Date()));

      Classifier cls = null;
      DataSource sourceTrain = null;
      Instances dataTrain = null;
      if (new File(id + "-MNB.classifier").exists()) {
        ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id + "-MNB.classifier"));
        cls = (Classifier) ois.readObject();
        ois.close();
      } else {
        sourceTrain = new DataSource(id + "-train.arff");
        dataTrain = sourceTrain.getDataSet();
        if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
        // Train the classifier
        cls = new weka.classifiers.bayes.NaiveBayesMultinomial();
        int clase = dataTrain.numAttributes() - 1;
        dataTrain.setClassIndex(clase);
        cls.buildClassifier(dataTrain);
        ObjectOutputStream oos =
            new ObjectOutputStream(new FileOutputStream(id + "-MNB.classifier"));
        oos.writeObject(cls);
        oos.flush();
        oos.close();
        // data.delete(); // not deleted here, still needed for the SVM
      }
      // Now evaluate the classifier on the test data
      System.out.println("6-Evaluate classifier MNB " + dateFormat.format(new Date()));
      out.println("6-Evaluate classifier MNB" + dateFormat.format(new Date()));
      DataSource sourceTest = new DataSource(id + "-test.arff");
      Instances dataTest = sourceTest.getDataSet();
      int clase = dataTest.numAttributes() - 1;
      dataTest.setClassIndex(clase);
      Evaluation eval = new Evaluation(dataTest);
      eval.evaluateModel(cls, dataTest);
      // Now compute the precision, recall and F-measure values, and print the confusion
      // matrices

      precision = 0;
      recall = 0;
      fmeasure = 0;
      for (int ind = 0; ind < topeTopics; ind++) {
        precision += eval.precision(ind);
        recall += eval.recall(ind);
        fmeasure += eval.fMeasure(ind);
      }
      precision = precision / topeTopics;
      recall = recall / topeTopics;
      fmeasure = fmeasure / topeTopics;
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.println(eval.toMatrixString());
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.printf("Precision: %.3f\n", precision);
      System.out.printf("Recall: %.3f\n", recall);
      System.out.printf("F-measure: %.3f\n", fmeasure);
      System.out.println("***************************************");
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.println(eval.toMatrixString());
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.printf("Precision: %.3f\n", precision);
      out.printf("Recall: %.3f\n", recall);
      out.printf("F-measure: %.3f\n", fmeasure);
      out.println("***************************************");
      /*			DO NOT DELETE
      			System.out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date()));
      			out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date()));
      			if (new File(id+"-SVM.classifier").exists())
      			{
      				ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id+"-SVM.classifier"));
      				cls = (Classifier) ois.readObject();
      				ois.close();
      			}
      			else
      			{
      				if (dataTrain==null)
      				{
      					sourceTrain = new DataSource(id+"-train.arff");
      					dataTrain = sourceTrain.getDataSet();
      					if (dataTrain.classIndex() == -1)
      						dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
      				}
      	//Train the classifier
      				cls = new weka.classifiers.functions.LibSVM();
      				clase = dataTrain.numAttributes()-1;
      				dataTrain.setClassIndex(clase);
      				cls.buildClassifier(dataTrain);
      				ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(id+"-SVM.classifier"));
      				oos.writeObject(cls);
      				oos.flush();
      				oos.close();
      				dataTrain.delete();
      			}
      			eval.evaluateModel(cls, dataTest);
      			precision=0;
      			recall=0;
      			fmeasure=0;
      			for(int ind=0; ind<topeTopics; ind++)
      			{
      				precision += eval.precision(ind);
      				recall += eval.recall(ind);
      				fmeasure += eval.fMeasure(ind);
      			}
      			precision = precision / topeTopics;
      			recall = recall / topeTopics;
      			fmeasure = fmeasure / topeTopics;
      			System.out.println("+++++++++++++++++++++++++++++++++++++++");
      			System.out.println(eval.toMatrixString());
      			System.out.println("+++++++++++++++++++++++++++++++++++++++");
      			System.out.printf("Precision: %.3f\n", precision);
      			System.out.printf("Recall: %.3f\n", recall);
      			System.out.printf("F-measure: %.3f\n", fmeasure);
      			System.out.println("***************************************");
      			out.println("+++++++++++++++++++++++++++++++++++++++");
      			out.println(eval.toMatrixString());
      			out.println("+++++++++++++++++++++++++++++++++++++++");
      			out.printf("Precision: %.3f\n", precision);
      			out.printf("Recall: %.3f\n", recall);
      			out.printf("F-measure: %.3f\n", fmeasure);
      			out.println("***************************************");
      */
      System.out.println("Done " + dateFormat.format(new Date()));
      out.println("Done " + dateFormat.format(new Date()));
      if (printer == null) {
        out.flush();
        out.close();
      }
      // Attempt to free memory
      if (dataTrain != null) dataTrain.delete();
      if (dataTest != null) dataTest.delete();
      if (train != null) train.clear();
      if (test != null) test.clear();
      if (topics != null) {
        topics.getTopicsList().clear();
        topics = null;
      }
      if (dataTest != null) dataTest.delete();
      if (cls != null) cls = null;
      if (tweetsProcessor != null) tweetsProcessor = null;
      System.gc();
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
Code example #27
  public static void execSVM(String expName) {
    try {
      FileWriter outFile = null;
      PrintWriter out = null;
      outFile = new FileWriter(expName + "-SVM.results");
      out = new PrintWriter(outFile);
      DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
      ProcessTweets tweetsProcessor = null;
      System.out.println("***************************************");
      System.out.println("***\tEXECUTING TEST\t" + expName + "***");
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      out.println("***************************************");
      out.println("***\tEXECUTING TEST\t" + expName + "***");
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.println("4-Generate classifier " + dateFormat.format(new Date()));

      Classifier cls = null;
      DataSource sourceTrain = new DataSource(expName + "-train.arff");
      Instances dataTrain = sourceTrain.getDataSet();
      if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
      // Train the classifier. Note: despite the method name and the "-SVM" file
      // names, the active classifier here is ComplementNaiveBayes; the LibSVM
      // line is commented out.
      // cls = new weka.classifiers.functions.LibSVM();
      int clase = dataTrain.numAttributes() - 1;
      cls = new weka.classifiers.bayes.ComplementNaiveBayes();
      dataTrain.setClassIndex(clase);
      cls.buildClassifier(dataTrain);
      ObjectOutputStream oos =
          new ObjectOutputStream(new FileOutputStream(expName + "-SVM.classifier"));
      oos.writeObject(cls);
      oos.flush();
      oos.close();
      DataSource sourceTest = new DataSource(expName + "-test.arff");
      Instances dataTest = sourceTest.getDataSet();
      dataTest.setClassIndex(clase);
      // Evaluation takes its class priors from the Instances passed to the
      // constructor; using the training set is the more common choice.
      Evaluation eval = new Evaluation(dataTest);
      eval.evaluateModel(cls, dataTest);
      // Now compute precision, recall and F-measure, and print the confusion
      // matrices

      float precision = 0;
      float recall = 0;
      float fmeasure = 0;
      int topeTopics = 8; // number of topic classes to average over
      for (int ind = 0; ind < topeTopics; ind++) {
        precision += eval.precision(ind);
        recall += eval.recall(ind);
        fmeasure += eval.fMeasure(ind);
      }
      precision = precision / topeTopics;
      recall = recall / topeTopics;
      fmeasure = fmeasure / topeTopics;
      System.out.println("++++++++++++++ CNB ++++++++++++++++++++");
      System.out.println(eval.toMatrixString());
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.printf("Precision: %.3f\n", precision);
      System.out.printf("Recall: %.3f\n", recall);
      System.out.printf("F-measure: %.3f\n", fmeasure);
      System.out.println("***************************************");
      out.println("++++++++++++++ CNB ++++++++++++++++++++");
      out.println(eval.toMatrixString());
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.printf("Precision: %.3f\n", precision);
      out.printf("Recall: %.3f\n", recall);
      out.printf("F-measure: %.3f\n", fmeasure);
      out.println("***************************************");
      // Another classifier: ZeroR (majority-class baseline)
      cls = new weka.classifiers.rules.ZeroR();
      dataTrain.setClassIndex(clase);
      cls.buildClassifier(dataTrain);
      eval = new Evaluation(dataTest);
      eval.evaluateModel(cls, dataTest);
      precision = 0;
      recall = 0;
      fmeasure = 0;
      for (int ind = 0; ind < topeTopics; ind++) {
        precision += eval.precision(ind);
        recall += eval.recall(ind);
        fmeasure += eval.fMeasure(ind);
      }
      precision = precision / topeTopics;
      recall = recall / topeTopics;
      fmeasure = fmeasure / topeTopics;
      System.out.println("++++++++++++++ ZEROR ++++++++++++++++++++");
      System.out.println(eval.toMatrixString());
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.printf("Precision: %.3f\n", precision);
      System.out.printf("Recall: %.3f\n", recall);
      System.out.printf("F-measure: %.3f\n", fmeasure);
      System.out.println("***************************************");
      out.println("++++++++++++++ ZEROR ++++++++++++++++++++");
      out.println(eval.toMatrixString());
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.printf("Precision: %.3f\n", precision);
      out.printf("Recall: %.3f\n", recall);
      out.printf("F-measure: %.3f\n", fmeasure);
      out.println("***************************************");
      // Another classifier: J48 (block commented out)
      /*
      			cls = new weka.classifiers.trees.J48();
      			dataTrain.setClassIndex(clase);
      			cls.buildClassifier(dataTrain);
      			eval = new Evaluation(dataTest);
      			eval.evaluateModel(cls, dataTest);
      			precision=0;
      			recall=0;
      			fmeasure=0;
      			for(int ind=0; ind<topeTopics; ind++)
      			{
      				precision += eval.precision(ind);
      				recall += eval.recall(ind);
      				fmeasure += eval.fMeasure(ind);
      			}
      			precision = precision / topeTopics;
      			recall = recall / topeTopics;
      			fmeasure = fmeasure / topeTopics;
      			System.out.println("++++++++++++++ J48 ++++++++++++++++++++");
      			System.out.println(eval.toMatrixString());
      			System.out.println("+++++++++++++++++++++++++++++++++++++++");
      			System.out.printf("Precision: %.3f\n", precision);
      			System.out.printf("Recall: %.3f\n", recall);
      			System.out.printf("F-measure: %.3f\n", fmeasure);
      			System.out.println("***************************************");
      			out.println("++++++++++++++ J48 ++++++++++++++++++++");
      			out.println(eval.toMatrixString());
      			out.println("+++++++++++++++++++++++++++++++++++++++");
      			out.printf("Precision: %.3f\n", precision);
      			out.printf("Recall: %.3f\n", recall);
      			out.printf("F-measure: %.3f\n", fmeasure);
      			out.println("***************************************");

      // Another classifier: SMO
      			cls = new weka.classifiers.functions.SMO();
      			dataTrain.setClassIndex(clase);
      			cls.buildClassifier(dataTrain);
      			eval = new Evaluation(dataTest);
      			eval.evaluateModel(cls, dataTest);
      			precision=0;
      			recall=0;
      			fmeasure=0;
      			for(int ind=0; ind<topeTopics; ind++)
      			{
      				precision += eval.precision(ind);
      				recall += eval.recall(ind);
      				fmeasure += eval.fMeasure(ind);
      			}
      			precision = precision / topeTopics;
      			recall = recall / topeTopics;
      			fmeasure = fmeasure / topeTopics;
      			System.out.println("++++++++++++++ SMO ++++++++++++++++++++");
      			System.out.println(eval.toMatrixString());
      			System.out.println("+++++++++++++++++++++++++++++++++++++++");
      			System.out.printf("Precision: %.3f\n", precision);
      			System.out.printf("Recall: %.3f\n", recall);
      			System.out.printf("F-measure: %.3f\n", fmeasure);
      			System.out.println("***************************************");
      			out.println("++++++++++++++ SMO ++++++++++++++++++++");
      			out.println(eval.toMatrixString());
      			out.println("+++++++++++++++++++++++++++++++++++++++");
      			out.printf("Precision: %.3f\n", precision);
      			out.printf("Recall: %.3f\n", recall);
      			out.printf("F-measure: %.3f\n", fmeasure);
      			out.println("***************************************");
      */
      out.flush();
      out.close();
      dataTest.delete();
      dataTrain.delete();
    } catch (Exception e) {
      // FileNotFoundException and IOException are subclasses of Exception, so a
      // single handler replaces the three identical auto-generated catch blocks.
      e.printStackTrace();
    }
  }
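Each evaluation block above macro-averages precision, recall and F-measure over the first topeTopics classes by hand. Recent Weka versions also expose class-frequency-weighted aggregates directly on Evaluation; a minimal sketch, assuming eval has already been populated with evaluateModel as above:

// Class-frequency-weighted aggregates, as an alternative to the manual
// macro-average over topeTopics classes computed above.
double wPrecision = eval.weightedPrecision();
double wRecall = eval.weightedRecall();
double wFMeasure = eval.weightedFMeasure();
System.out.printf("Weighted P/R/F: %.3f / %.3f / %.3f%n", wPrecision, wRecall, wFMeasure);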
コード例 #28
  /**
   * Gets the results for the supplied train and test datasets. Now performs a deep copy of the
   * classifier before it is built and evaluated (just in case the classifier is not initialized
   * properly in buildClassifier()).
   *
   * @param train the training Instances.
   * @param test the testing Instances.
   * @return the results stored in an array. The objects stored in the array may be Strings,
   *     Doubles, or null (for the missing value).
   * @throws Exception if a problem occurs while getting the results
   */
  public Object[] getResult(Instances train, Instances test) throws Exception {

    if (train.classAttribute().type() != Attribute.NUMERIC) {
      throw new Exception("Class attribute is not numeric!");
    }
    if (m_Template == null) {
      throw new Exception("No classifier has been specified");
    }
    ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
    boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
    if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled())
      thMonitor.setThreadCpuTimeEnabled(true);

    int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
    Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics];
    long thID = Thread.currentThread().getId();
    long CPUStartTime = -1,
        trainCPUTimeElapsed = -1,
        testCPUTimeElapsed = -1,
        trainTimeStart,
        trainTimeElapsed,
        testTimeStart,
        testTimeElapsed;
    Evaluation eval = new Evaluation(train);
    m_Classifier = AbstractClassifier.makeCopy(m_Template);

    trainTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID);
    m_Classifier.buildClassifier(train);
    if (canMeasureCPUTime) trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    testTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID);
    eval.evaluateModel(m_Classifier, test);
    if (canMeasureCPUTime) testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
    thMonitor = null;

    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = new Double(train.numInstances());
    result[current++] = new Double(eval.numInstances());

    result[current++] = new Double(eval.meanAbsoluteError());
    result[current++] = new Double(eval.rootMeanSquaredError());
    result[current++] = new Double(eval.relativeAbsoluteError());
    result[current++] = new Double(eval.rootRelativeSquaredError());
    result[current++] = new Double(eval.correlationCoefficient());

    result[current++] = new Double(eval.SFPriorEntropy());
    result[current++] = new Double(eval.SFSchemeEntropy());
    result[current++] = new Double(eval.SFEntropyGain());
    result[current++] = new Double(eval.SFMeanPriorEntropy());
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
    result[current++] = new Double(eval.SFMeanEntropyGain());

    // Timing stats
    result[current++] = new Double(trainTimeElapsed / 1000.0);
    result[current++] = new Double(testTimeElapsed / 1000.0);
    if (canMeasureCPUTime) {
      result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0);
      result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0);
    } else {
      result[current++] = new Double(Utils.missingValue());
      result[current++] = new Double(Utils.missingValue());
    }

    // sizes
    if (m_NoSizeDetermination) {
      result[current++] = -1.0;
      result[current++] = -1.0;
      result[current++] = -1.0;
    } else {
      ByteArrayOutputStream bastream = new ByteArrayOutputStream();
      ObjectOutputStream oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(m_Classifier);
      result[current++] = new Double(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(train);
      result[current++] = new Double(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(test);
      result[current++] = new Double(bastream.size());
    }

    // Prediction interval statistics
    result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions());
    result[current++] = new Double(eval.sizeOfPredictedRegions());

    if (m_Classifier instanceof Summarizable) {
      result[current++] = ((Summarizable) m_Classifier).toSummaryString();
    } else {
      result[current++] = null;
    }

    for (int i = 0; i < addm; i++) {
      if (m_doesProduce[i]) {
        try {
          double dv =
              ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]);
          if (!Utils.isMissingValue(dv)) {
            Double value = new Double(dv);
            result[current++] = value;
          } else {
            result[current++] = null;
          }
        } catch (Exception ex) {
          System.err.println(ex);
        }
      } else {
        result[current++] = null;
      }
    }

    // get the actual metrics from the evaluation object
    List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics();
    if (metrics != null) {
      for (AbstractEvaluationMetric m : metrics) {
        if (m.appliesToNumericClass()) {
          List<String> statNames = m.getStatisticNames();
          for (String s : statNames) {
            result[current++] = new Double(m.getStatistic(s));
          }
        }
      }
    }

    if (current != RESULT_SIZE + addm + m_numPluginStatistics) {
      throw new Error("Results didn't fit RESULT_SIZE");
    }
    return result;
  }
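For reference, the deep copy mentioned in the Javadoc is performed by AbstractClassifier.makeCopy, which serializes and deserializes the template. A standalone sketch of the same timing pattern used in the method body, assuming hypothetical template and train variables and the java.lang.management imports:

// Copy the template, then time training with wall-clock and per-thread
// user CPU time (getThreadUserTime reports nanoseconds).
ThreadMXBean monitor = ManagementFactory.getThreadMXBean();
long tid = Thread.currentThread().getId();
Classifier copy = AbstractClassifier.makeCopy(template); // serialization round-trip
long wallStart = System.currentTimeMillis();
long cpuStart = monitor.isThreadCpuTimeSupported() ? monitor.getThreadUserTime(tid) : -1;
copy.buildClassifier(train);
long wallMs = System.currentTimeMillis() - wallStart;
long cpuNs = (cpuStart >= 0) ? monitor.getThreadUserTime(tid) - cpuStart : -1;
System.out.printf("train: %.3f s wall, %.3f s CPU%n", wallMs / 1000.0, cpuNs / 1.0e9);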