Example #1
  /**
   * Tests with a classifier.
   *
   * @param trainFileName name of the training ARFF file
   * @param testFileName name of the test ARFF file
   */
  public static void classify(String trainFileName, String testFileName) {
    try {
      File inputFile = new File(fileName + trainFileName); // training corpus file
      ArffLoader atf = new ArffLoader();
      atf.setFile(inputFile);
      Instances instancesTrain = atf.getDataSet(); // read in the training set

      inputFile = new File(fileName + testFileName); // test corpus file
      atf.setFile(inputFile);
      Instances instancesTest = atf.getDataSet(); // read in the test set

      // set the class attribute (the last attribute) on both sets
      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
      classifier.buildClassifier(instancesTrain);

      Evaluation eval = new Evaluation(instancesTrain);
      // first argument: a trained classifier; second argument: the dataset to evaluate it on
      eval.evaluateModel(classifier, instancesTest);

      System.out.println(eval.toClassDetailsString());
      System.out.println(eval.toSummaryString());
      System.out.println(eval.toMatrixString());
      System.out.println("precision is :" + (1 - eval.errorRate()));

    } catch (Exception e) {
      e.printStackTrace();
    }
  }
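
The snippet above relies on three members that are not shown: the fileName base path, the CLASSIFIERNAME constant, and the classifier field. A minimal sketch of how they might be declared (the values are illustrative placeholders, not from the source):

  private static final String fileName = "/path/to/corpus/"; // base directory prepended to both file names
  private static final String CLASSIFIERNAME = "weka.classifiers.bayes.NaiveBayes"; // any fully qualified classifier class
  private static Classifier classifier; // assigned inside classify(...)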
Example #2
File: WekaTest.java Project: fsteeg/tm2
  /**
   * @param args
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
    Instances isTrainingSet = createSet(4);
    Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
    Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
    Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet);
    isTrainingSet.add(instance1);
    isTrainingSet.add(instance2);
    isTrainingSet.add(instance22);
    Instances isTestingSet = createSet(4);
    Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTestingSet);
    Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTestingSet);
    isTestingSet.add(instance3);
    isTestingSet.add(instance4);

    // Create a Bayes network classifier
    Classifier cModel = (Classifier) new BayesNet(); // M5P
    cModel.buildClassifier(isTrainingSet);

    // Test the model
    Evaluation eTest = new Evaluation(isTrainingSet);
    eTest.evaluateModel(cModel, isTestingSet);

    // Print the result à la Weka explorer:
    String strSummary = eTest.toSummaryString();
    System.out.println(strSummary);

    // Get the class distribution for an instance:
    // fDistribution[i] is the predicted probability of class i
    double[] fDistribution = cModel.distributionForInstance(instance4);
    for (int i = 0; i < fDistribution.length; i++) {
      System.out.println(fDistribution[i]);
    }
  }
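
createSet and createInstance are project helpers from fsteeg/tm2 that are not reproduced on this page. A hypothetical reconstruction against the Weka 3.6-era API the snippet implies (weka.core.Instance was still a concrete class; 3.7+ uses DenseInstance):

  private static Instances createSet(int numAttributes) {
    FastVector attrs = new FastVector();
    for (int i = 0; i < numAttributes; i++) {
      attrs.addElement(new Attribute("att" + i)); // numeric attributes
    }
    FastVector labels = new FastVector(); // nominal class with the labels used above
    labels.addElement("S1");
    labels.addElement("S2");
    labels.addElement("S3");
    attrs.addElement(new Attribute("label", labels));
    Instances set = new Instances("demo", attrs, 0);
    set.setClassIndex(set.numAttributes() - 1);
    return set;
  }

  private static Instance createInstance(double[] values, String label, Instances dataset) {
    Instance inst = new Instance(dataset.numAttributes());
    inst.setDataset(dataset); // gives the instance access to the attribute definitions
    for (int i = 0; i < values.length; i++) {
      inst.setValue(i, values[i]);
    }
    inst.setClassValue(label);
    return inst;
  }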
Example #3
  /**
   * Creates an evaluation overview of the built classifier.
   *
   * @param data the data set on which the classifier is evaluated
   * @return the panel to be displayed as result evaluation view for the current decision point
   */
  protected JPanel createEvaluationVisualization(Instances data) {
    // build text field to display evaluation statistics
    JTextPane statistic = new JTextPane();

    try {
      // build evaluation statistics
      Evaluation evaluation = new Evaluation(data);
      evaluation.evaluateModel(myClassifier, data);
      statistic.setText(
          evaluation.toSummaryString()
              + "\n\n"
              + evaluation.toClassDetailsString()
              + "\n\n"
              + evaluation.toMatrixString());

    } catch (Exception ex) {
      ex.printStackTrace();
      return createMessagePanel("Error while creating the decision tree evaluation view");
    }

    statistic.setFont(new Font("Courier", Font.PLAIN, 14));
    statistic.setEditable(false);
    statistic.setCaretPosition(0);

    JPanel resultViewPanel = new JPanel();
    resultViewPanel.setLayout(new BoxLayout(resultViewPanel, BoxLayout.PAGE_AXIS));
    resultViewPanel.add(new JScrollPane(statistic));

    return resultViewPanel;
  }
Example #4
  public static void run(String[] args) throws Exception {
    /*
     * args[0]: train arff path
     * args[1]: test arff path
     */
    DataSource source = new DataSource(args[0]);
    Instances data = source.getDataSet();
    data.setClassIndex(data.numAttributes() - 1);
    NaiveBayes model = new NaiveBayes();
    model.buildClassifier(data);

    // Evaluation:
    Evaluation eval = new Evaluation(data);
    Instances testData = new DataSource(args[1]).getDataSet();
    testData.setClassIndex(testData.numAttributes() - 1);
    eval.evaluateModel(model, testData);
    System.out.println(model.toString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
    System.out.println("======\nConfusion Matrix:");
    double[][] confusionM = eval.confusionMatrix();
    for (int i = 0; i < confusionM.length; ++i) {
      for (int j = 0; j < confusionM[i].length; ++j) {
        System.out.format("%10s ", confusionM[i][j]);
      }
      System.out.print("\n");
    }
  }
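
Printing the raw matrix of doubles by hand works, but Evaluation can render the same confusion matrix pre-formatted, e.g.:

    System.out.println(eval.toMatrixString("======\nConfusion Matrix:"));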
Example #5
  public static void wekaAlgorithms(Instances data) throws Exception {
    classifier = new FilteredClassifier(); // meta-classifier: applies a filter, then a base classifier
    classifier.setClassifier(new NaiveBayes());
    // alternative base classifiers:
    // classifier.setClassifier(new J48());
    // classifier.setClassifier(new RandomForest());
    // classifier.setClassifier(new ZeroR());
    // classifier.setClassifier(new IBk());

    data.setClassIndex(data.numAttributes() - 1);
    Evaluation eval = new Evaluation(data);

    int folds = 10;
    eval.crossValidateModel(classifier, data, folds, new Random(1));

    System.out.println("===== Evaluating on filtered (training) dataset =====");
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toClassDetailsString());
    double[][] mat = eval.confusionMatrix();
    System.out.println("========= Confusion Matrix =========");
    for (int i = 0; i < mat.length; i++) {
      for (int j = 0; j < mat[i].length; j++) {
        System.out.print(mat[i][j] + "  ");
      }
      System.out.println();
    }
  }
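
Note that the FilteredClassifier above is never given a filter via setFilter(...), so it runs with its default AllFilter and behaves like the base classifier alone. A typical configuration would look like this (a sketch, e.g. for text data):

    FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(new weka.filters.unsupervised.attribute.StringToWordVector());
    fc.setClassifier(new NaiveBayes());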
Example #6
 /** evaluates the classifier */
 @Override
 public void evaluate() throws Exception {
   // evaluate classifier and print some statistics
   if (_test.classIndex() == -1) _test.setClassIndex(_test.numAttributes() - 1);
   Evaluation eval = new Evaluation(_train);
   eval.evaluateModel(_cl, _test);
   System.out.println(eval.toSummaryString("\nResults\n======\n", false));
   System.out.println(eval.toMatrixString());
 }
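
This method assumes three fields initialized elsewhere in the class; an illustrative sketch (the names come from the code, the initialization is assumed):

 private Instances _train; // e.g. new DataSource("train.arff").getDataSet(), class index set
 private Instances _test;  // e.g. new DataSource("test.arff").getDataSet()
 private Classifier _cl;   // a classifier already built on _train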
Example #7
 /** uses the meta-classifier */
 protected static void useClassifier(Instances data) throws Exception {
   System.out.println("\n1. Meta-classfier");
   AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
   CfsSubsetEval eval = new CfsSubsetEval();
   GreedyStepwise search = new GreedyStepwise();
   search.setSearchBackwards(true);
   J48 base = new J48();
   classifier.setClassifier(base);
   classifier.setEvaluator(eval);
   classifier.setSearch(search);
   Evaluation evaluation = new Evaluation(data);
   evaluation.crossValidateModel(classifier, data, 10, new Random(1));
   System.out.println(evaluation.toSummaryString());
 }
Example #8
  public void evaluateClassifyMethod() throws Exception {

    Instances data1 = getInstances();
    Evaluation eval = new Evaluation(data1);
    J48 tree = new J48();
    eval.crossValidateModel(tree, data1, 10, new Random(1));
    System.out.println("决策树准确率:");
    System.out.println(eval.toSummaryString("\nResults\n\n", false));

    Instances data2 = getInstances();
    Evaluation eval2 = new Evaluation(data2);
    NaiveBayes bayes = new NaiveBayes();
    eval2.crossValidateModel(bayes, data2, 10, new Random(1));
    System.out.println("贝叶斯准确率:");
    System.out.println(eval2.toSummaryString("\nResults\n\n", false));

    Instances data3 = getInstances();
    Evaluation eval3 = new Evaluation(data3);
    IBk ibk = new IBk(3);
    eval3.crossValidateModel(ibk, data3, 10, new Random(1));
    System.out.println("KNN准确率:");
    System.out.println(eval3.toSummaryString("\nResults\n\n", false));
  }
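
When the three classifiers are compared programmatically rather than by reading the printed summaries, each Evaluation exposes the accuracy directly, e.g.:

    System.out.println("J48 accuracy: " + eval.pctCorrect() + " %");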
Example #9
  public static Double runClassify(String trainFile, String testFile) {
    double predictOrder = 0.0;
    double trueOrder = 0.0;
    try {
      String trainWekaFileName = trainFile;
      String testWekaFileName = testFile;

      Instances train = DataSource.read(trainWekaFileName);
      Instances test = DataSource.read(testWekaFileName);

      train.setClassIndex(0);
      test.setClassIndex(0);

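      // delete in descending index order so earlier deletions do not shift the later indices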
      train.deleteAttributeAt(8);
      test.deleteAttributeAt(8);
      train.deleteAttributeAt(6);
      test.deleteAttributeAt(6);
      train.deleteAttributeAt(5);
      test.deleteAttributeAt(5);
      train.deleteAttributeAt(4);
      test.deleteAttributeAt(4);

      // AdditiveRegression classifier = new AdditiveRegression();

      // NaiveBayes classifier = new NaiveBayes();

      RandomForest classifier = new RandomForest();
      // LibSVM classifier = new LibSVM();

      classifier.buildClassifier(train);
      Evaluation eval = new Evaluation(train);
      eval.evaluateModel(classifier, test);

      System.out.println(eval.toSummaryString("\nResults\n\n", true));
      // System.out.println(eval.toClassDetailsString());
      // System.out.println(eval.toMatrixString());
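      // 892 is presumably the first row ID of this test set (it matches the Kaggle Titanic test file)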
      int k = 892;
      for (int i = 0; i < test.numInstances(); i++) {
        predictOrder = classifier.classifyInstance(test.instance(i));
        trueOrder = test.instance(i).classValue();
        System.out.println((k++) + "," + (int) predictOrder);
      }

    } catch (Exception e) {
      e.printStackTrace();
    }
    return predictOrder;
  }
Example #10
  public static void trainModel(Instances dataTrain, Instances dataTest) {
    try {
      LibLINEAR classifier = new LibLINEAR();
      classifier.setBias(10);

      classifier.buildClassifier(dataTrain);

      Evaluation eval = new Evaluation(dataTrain);
      eval.evaluateModel(classifier, dataTest);

      System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    } catch (Exception e) {
      e.printStackTrace();
    }
  }
Example #11
  @Override
  public void crossValidation(String traindata) throws Exception {
    DataSource ds = new DataSource(traindata);
    Instances instances = ds.getDataSet();
    StringToWordVector stv = new StringToWordVector();
    stv.setOptions(
        weka.core.Utils.splitOptions(
            "-R first-last -W 1000 "
                + "-prune-rate -1.0 -N 0 "
                + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
                + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters  \\\" \\r\\n\\t.,;:\\\'\\\"()?!\""));

    stv.setInputFormat(instances);
    instances = Filter.useFilter(instances, stv);
    instances.setClassIndex(0);
    Evaluation eval = new Evaluation(instances);
    eval.crossValidateModel(this.classifier, instances, 10, new Random(1));
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());
  }
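
The splitOptions string above maps onto StringToWordVector's setters; a roughly equivalent sketch (the -N 0 and tokenizer settings shown in the string are the filter's defaults):

    StringToWordVector stv = new StringToWordVector();
    stv.setAttributeIndices("first-last");                 // -R first-last
    stv.setWordsToKeep(1000);                              // -W 1000
    stv.setPeriodicPruning(-1.0);                          // -prune-rate -1.0
    stv.setMinTermFreq(1);                                 // -M 1
    stv.setStemmer(new weka.core.stemmers.NullStemmer());  // -stemmer weka.core.stemmers.NullStemmer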
Example #12
  @Override
  protected Evaluation test() {
    working = true;
    System.out.println("Agent " + getLocalName() + ": Testing...");

    // evaluate classifier and print some statistics
    Evaluation eval = null;
    try {
      eval = new Evaluation(train);
      eval.evaluateModel(cls, test);
      System.out.println(
          eval.toSummaryString(getLocalName() + " agent: " + "\nResults\n=======\n", false));

    } catch (Exception e) {
      e.printStackTrace();
    }
    working = false;
    return eval;
  } // end test
Example #13
  private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
      throws Exception {
    System.err.println(
        "INFO: Starting split validation to predict '"
            + trainData.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' (#train="
            + trainData.numInstances()
            + ",#test="
            + testData.numInstances()
            + ") ...");

    if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set");

    c.buildClassifier(trainData);
    Evaluation eval = new Evaluation(testData);
    eval.useNoPriors();
    double[] predictions = eval.evaluateModel(c, testData);

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    // write predictions to file
    {
      System.err.println("INFO: Writing predictions to file ...");
      Writer out = new FileWriter("prediction.trec");
      writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
      out.close();
    }

    // write predicted distributions to CSV
    {
      System.err.println("INFO: Writing predicted distributions to CSV ...");
      Writer out = new FileWriter("predicted_distribution.csv");
      writePredictedDistributions(c, testData, 0, out);
      out.close();
    }
  }
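
writePredictionsTrecEval and writePredictedDistributions are project helpers that are not shown here. A hypothetical sketch of the latter, writing one CSV row of class probabilities per test instance, keyed by the attribute at idIndex:

  static void writePredictedDistributions(Classifier c, Instances data, int idIndex, Writer out)
      throws Exception {
    for (int i = 0; i < data.numInstances(); i++) {
      double[] dist = c.distributionForInstance(data.instance(i));
      StringBuilder row = new StringBuilder(data.instance(i).toString(idIndex));
      for (double p : dist) {
        row.append(',').append(p);
      }
      out.write(row.append('\n').toString());
    }
  }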
Example #14
  /** outputs some data about the classifier */
  public String toString() {
    StringBuffer result;

    result = new StringBuffer();
    result.append("Weka - Demo\n===========\n\n");

    result.append(
        "Classifier...: "
            + m_Classifier.getClass().getName()
            + " "
            + Utils.joinOptions(m_Classifier.getOptions())
            + "\n");
    if (m_Filter instanceof OptionHandler)
      result.append(
          "Filter.......: "
              + m_Filter.getClass().getName()
              + " "
              + Utils.joinOptions(((OptionHandler) m_Filter).getOptions())
              + "\n");
    else result.append("Filter.......: " + m_Filter.getClass().getName() + "\n");
    result.append("Training file: " + m_TrainingFile + "\n");
    result.append("\n");

    result.append(m_Classifier.toString() + "\n");
    result.append(m_Evaluation.toSummaryString() + "\n");
    try {
      result.append(m_Evaluation.toMatrixString() + "\n");
    } catch (Exception e) {
      e.printStackTrace();
    }
    try {
      result.append(m_Evaluation.toClassDetailsString() + "\n");
    } catch (Exception e) {
      e.printStackTrace();
    }

    return result.toString();
  }
Example #15
  static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception {
    System.err.println(
        "INFO: Starting crossvalidation to predict '"
            + data.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' ...");

    StringBuffer sb = new StringBuffer();
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE);

    // write predictions to file
    {
      Writer out = new FileWriter("cv.log");
      out.write(sb.toString());
      out.close();
    }

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  }
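
The (StringBuffer, Range, Boolean) varargs used above are the older (pre-3.7) way of capturing fold predictions; Weka 3.7+ expects an AbstractOutput instead. A sketch of the newer form:

    StringBuffer buffer = new StringBuffer();
    weka.classifiers.evaluation.output.prediction.PlainText output =
        new weka.classifiers.evaluation.output.prediction.PlainText();
    output.setBuffer(buffer);
    output.setHeader(data);
    output.setAttributes("first"); // same id column as the Range above
    eval.crossValidateModel(c, data, folds, new Random(1), output);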
Example #16
  public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
    HashMap<String, Object> population_params = load_defaults(null);
    RawLoader rl = new RawLoader(population_params, true, false, rng_seed);
    List<DrugUser> learningData = rl.getLearningData();

    Instances nhbs_data =
        new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size());
    for (DrugUser du : learningData) {
      nhbs_data.add(du.getInstance());
    }
    System.out.println(nhbs_data.toSummaryString());
    nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state"));

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    ArrayList<String> options = new ArrayList<String>();
    options.add("-Q");
    options.add("" + rng_seed);
    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // ROC=0.60
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    Classifier classifier = new Logistic();
    // ROC=0.686
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    // Classifier classifier = new RandomForest();
    // String[] options = {"-I", "100", "-K", "4"}; //-I trees, -K features per tree.  generally,
    // might want to optimize (or not
    // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    // options.add("-I"); options.add("100"); options.add("-K"); options.add("4");
    // ROC=0.673

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent
    // ROC=0.633

    // Classifier classifier = new AdaBoostM1();
    // ROC=0.66
    // Classifier classifier = new MultiBoostAB();
    // ROC=0.67
    // Classifier classifier = new Stacking();
    // ROC=0.495

    // J48 classifier = new J48(); // new instance of tree //building a C45 tree classifier
    // ROC=0.585
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options

    classifier.setOptions(options.toArray(new String[0]));

    // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
    // classifier.buildClassifier(nhbs_data); // build classifier

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross validation
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
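
The ROC figures quoted in the comments above can be reproduced from the same Evaluation object, since crossValidateModel records the predictions:

    System.out.println("AUC (class 0): " + eval.areaUnderROC(0));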
Example #17
  public static void test_NHBS_old() throws Exception {
    // load the data
    CSVLoader loader = new CSVLoader();
    // these must come before the getDataSet()
    // loader.setEnclosureCharacters(",\'\"S");
    // loader.setNominalAttributes("16,71"); //zip code, drug name
    // loader.setStringAttributes("");
    // loader.setDateAttributes("0,1");
    // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv"));
    loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv"));
    Instances nhbs_data = loader.getDataSet();
    loader.setMissingValue("NOVALUE");
    // loader.setMissingValue("");

    nhbs_data.deleteAttributeAt(12); // zip code
    nhbs_data.deleteAttributeAt(1); // date - redundant with age
    nhbs_data.deleteAttributeAt(0); // date
    System.out.println("classifying attribute:");
    nhbs_data.setClassIndex(1); // new index  3->2->1
    nhbs_data.attribute(1).getMetadata().toString(); // HCVEIARSLT1

    // wishlist: perhaps it would be smarter to throw out unclassified instances; they interfere
    // with the scoring
    nhbs_data.deleteWithMissingClass();
    // nhbs_data.setClass(new Attribute("HIVRSLT"));//.setClassIndex(1); //2nd column.  all are
    // mostly negative
    // nhbs_data.setClass(new Attribute("HCVEIARSLT1"));//.setClassIndex(2); //3rd column

    // #14, i.e. rds_fem, should be made numeric
    System.out.println("NHBS IDU 2009 Dataset");
    System.out.println("Summary of input:");
    // System.out.println(nhbs_data.toSummaryString());
    System.out.println("  Num of classes: " + nhbs_data.numClasses());
    System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
    for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
      Attribute attr = nhbs_data.attribute(idx);
      System.out.println("" + idx + ": " + attr.toString());
      System.out.println("     distinct values:" + nhbs_data.numDistinctValues(idx));
      // System.out.println("" + attr.enumerateValues());
    }

    // System.exit(0);
    // nhbs_data.deleteAttributeAt(0); //response ID
    // nhbs_data.deleteAttributeAt(16); //zip

    // Classifier classifier = new NNge(); //best nearest-neighbor classifier: 40.00
    // Classifier classifier = new MINND();
    // Classifier classifier = new CitationKNN();
    // Classifier classifier = new LibSVM(); //requires LibSVM classes. only gets 37.7%
    // Classifier classifier = new SMOreg();
    // Classifier classifier = new LinearNNSearch();

    // LinearRegression: Cannot handle multi-valued nominal class!
    // Classifier classifier = new LinearRegression();

    Classifier classifier = new RandomForest();
    String[] options = {
      "-I", "100", "-K", "4"
    }; // -I trees, -K features per tree.  generally, might want to optimize (or not
       // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
    classifier.setOptions(options);
    // Classifier classifier = new Logistic();

    // KStar classifier = new KStar();
    // classifier.setGlobalBlend(20); //the amount of not greedy, in percent

    // does poorly
    // Classifier classifier = new AdaBoostM1();
    // Classifier classifier = new MultiBoostAB();
    // Classifier classifier = new Stacking();

    // building a C45 tree classifier
    // J48 classifier = new J48(); // new instance of tree
    // String[] options = new String[1];
    // options[0] = "-U"; // unpruned tree
    // classifier.setOptions(options); // set the options
    // classifier.buildClassifier(nhbs_data); // build classifier

    // wishlist: remove infrequent values
    // weka.filters.unsupervised.instance.RemoveFrequentValues()
    Filter f1 = new RemoveUseless();
    f1.setInputFormat(nhbs_data);
    nhbs_data = Filter.useFilter(nhbs_data, f1);

    // evaluation
    Evaluation eval = new Evaluation(nhbs_data);
    eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1));
    System.out.println(eval.toSummaryString("\nResults\n\n", false));
    System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toCumulativeMarginDistributionString());
  }
Example #18
  public static void main(String[] args) {

    if (args.length < 1) {
      System.out.println("usage: C4_5TweetTopicCategorization <root_path>");
      System.exit(-1);
    }

    String rootPath = args[0];
    File dataFolder = new File(rootPath + "/data");
    String resultFolderPath = rootPath + "/results/C4_5/";

    CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer();
    Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class);
    PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH);

    File resultFolder = new File(resultFolderPath);
    if (!resultFolder.exists()) resultFolder.mkdir();

    CSVLoader csvLoader = new CSVLoader();

    try {
      for (File dataSetName : dataFolder.listFiles()) {

        Instances data = null;
        try {
          csvLoader.setSource(dataSetName);
          csvLoader.setStringAttributes("2");
          data = csvLoader.getDataSet();
        } catch (IOException ioe) {
          logger.error(ioe);
          crisisMailer.sendEmailAlert(ioe);
          System.exit(-1);
        }

        data.setClassIndex(data.numAttributes() - 1);
        data.deleteWithMissingClass();

        Instances vectorizedData = null;
        StringToWordVector stringToWordVectorFilter = new StringToWordVector();
        try {
          stringToWordVectorFilter.setInputFormat(data);
          stringToWordVectorFilter.setAttributeIndices("2");
          stringToWordVectorFilter.setIDFTransform(true);
          stringToWordVectorFilter.setLowerCaseTokens(true);
          stringToWordVectorFilter.setOutputWordCounts(false);
          stringToWordVectorFilter.setUseStoplist(true);

          vectorizedData = Filter.useFilter(data, stringToWordVectorFilter);
          vectorizedData.deleteAttributeAt(0);
          // System.out.println(vectorizedData);
        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }

        J48 j48Classifier = new J48();

        /*
        FilteredClassifier filteredClassifier = new FilteredClassifier();
        filteredClassifier.setFilter(stringToWordVectorFilter);
        filteredClassifier.setClassifier(j48Classifier);
        */

        try {
          Evaluation eval = new Evaluation(vectorizedData);
          eval.crossValidateModel(
              j48Classifier, vectorizedData, 5, new Random(System.currentTimeMillis()));

          FileOutputStream resultOutputStream =
              new FileOutputStream(new File(resultFolderPath + dataSetName.getName()));

          resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes());
          resultOutputStream.write(eval.toMatrixString().getBytes());
          resultOutputStream.write(eval.toClassDetailsString().getBytes());
          resultOutputStream.close();

        } catch (Exception exception) {
          logger.error(exception);
          crisisMailer.sendEmailAlert(exception);
          System.exit(-1);
        }
      }
    } catch (Exception exception) {
      logger.error(exception);
      crisisMailer.sendEmailAlert(exception);
      System.exit(-1);
    }
  }
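
The commented-out FilteredClassifier block above is worth noting: cross-validating it on the unvectorized data re-fits the StringToWordVector inside every fold, avoiding the mild leakage of vectorizing the full dataset once up front. A sketch of that variant:

    FilteredClassifier filteredClassifier = new FilteredClassifier();
    filteredClassifier.setFilter(stringToWordVectorFilter);
    filteredClassifier.setClassifier(j48Classifier);
    Evaluation leakFreeEval = new Evaluation(data);
    leakFreeEval.crossValidateModel(
        filteredClassifier, data, 5, new Random(System.currentTimeMillis()));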
Example #19
  public double getLiblinear(String path, String train, String test) {
    // accuracy for this run
    double accuracy = 0.0;

    try {
      LibLINEAR c1 = new LibLINEAR();

      // String[] options = weka.core.Utils.splitOptions("-S 1 -C 1.0 -E 0.001 -B 0");
      // c1.setOptions(options);

      ArffLoader atf = new ArffLoader();
      File TraininputFile = new File(train);
      atf.setFile(TraininputFile); // training corpus file
      Instances instancesTrain = atf.getDataSet(); // read in the training set
      instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

      File TestinputFile = new File(test);
      atf.setFile(TestinputFile); // test corpus file
      Instances instancesTest = atf.getDataSet(); // read in the test set
      // set the class attribute index (the first attribute is index 0);
      // instancesTest.numAttributes() returns the total number of attributes
      instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

      c1.buildClassifier(instancesTrain); // train

      Evaluation eval = new Evaluation(instancesTrain);
      eval.evaluateModel(c1, instancesTest);
      // eval.crossValidateModel(c1, instancesTrain, 10, new Random(1));
      File newfile = new File(path + "OutLiblinear_temp" + ".txt");

      BufferedWriter bufferedWriter =
          new BufferedWriter(new OutputStreamWriter(new FileOutputStream(newfile), "utf-8"));

      bufferedWriter.write(eval.toSummaryString() + "\r\n");
      bufferedWriter.write(eval.toClassDetailsString() + "\r\n");
      bufferedWriter.write(eval.toMatrixString() + "\r\n");

      bufferedWriter.flush();
      bufferedWriter.close();

      BufferedReader bufferedReader = new BufferedReader(new FileReader(newfile));
      String[] splitLineString = new String[5];
      while (bufferedReader.ready()) {
        bufferedReader.readLine();
        String lineString = bufferedReader.readLine();
        splitLineString = lineString.split(" ");
        System.out.println(splitLineString[4]);
        break;
      }
      bufferedReader.close();

      // extract the classification accuracy
      String tempLine;
      BufferedReader tempBF = new BufferedReader(new FileReader(newfile));
      while (tempBF.ready()) {
        tempLine = tempBF.readLine();
        if (tempLine.contains("Correctly Classified Instances")) {
          tempLine = tempLine.substring(tempLine.lastIndexOf(".") - 2, tempLine.lastIndexOf(" "));
          accuracy = Double.parseDouble(tempLine);
          break;
        }
      }

      tempBF.close();

    } catch (Exception e) {
      System.out.println("Can't run linlinear of weka.");
    }

    return accuracy;
  }
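
Writing the summary to a temporary file and parsing the accuracy back out recovers a number the Evaluation object already exposes; the parsing block could be replaced with a single call:

      accuracy = eval.pctCorrect(); // "Correctly Classified Instances", as a percentage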
Example #20
  /**
   * Gets the results for the supplied train and test datasets. Now performs a deep copy of the
   * classifier before it is built and evaluated (just in case the classifier is not initialized
   * properly in buildClassifier()).
   *
   * @param train the training Instances.
   * @param test the testing Instances.
   * @return the results stored in an array. The objects stored in the array may be Strings,
   *     Doubles, or null (for the missing value).
   * @throws Exception if a problem occurs while getting the results
   */
  public Object[] getResult(Instances train, Instances test) throws Exception {

    if (train.classAttribute().type() != Attribute.NUMERIC) {
      throw new Exception("Class attribute is not numeric!");
    }
    if (m_Template == null) {
      throw new Exception("No classifier has been specified");
    }
    ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
    boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
    if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled())
      thMonitor.setThreadCpuTimeEnabled(true);

    int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
    Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics];
    long thID = Thread.currentThread().getId();
    long CPUStartTime = -1,
        trainCPUTimeElapsed = -1,
        testCPUTimeElapsed = -1,
        trainTimeStart,
        trainTimeElapsed,
        testTimeStart,
        testTimeElapsed;
    Evaluation eval = new Evaluation(train);
    m_Classifier = AbstractClassifier.makeCopy(m_Template);

    trainTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID);
    m_Classifier.buildClassifier(train);
    if (canMeasureCPUTime) trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    testTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID);
    eval.evaluateModel(m_Classifier, test);
    if (canMeasureCPUTime) testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    testTimeElapsed = System.currentTimeMillis() - testTimeStart;
    thMonitor = null;

    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = new Double(train.numInstances());
    result[current++] = new Double(eval.numInstances());

    result[current++] = new Double(eval.meanAbsoluteError());
    result[current++] = new Double(eval.rootMeanSquaredError());
    result[current++] = new Double(eval.relativeAbsoluteError());
    result[current++] = new Double(eval.rootRelativeSquaredError());
    result[current++] = new Double(eval.correlationCoefficient());

    result[current++] = new Double(eval.SFPriorEntropy());
    result[current++] = new Double(eval.SFSchemeEntropy());
    result[current++] = new Double(eval.SFEntropyGain());
    result[current++] = new Double(eval.SFMeanPriorEntropy());
    result[current++] = new Double(eval.SFMeanSchemeEntropy());
    result[current++] = new Double(eval.SFMeanEntropyGain());

    // Timing stats
    result[current++] = new Double(trainTimeElapsed / 1000.0);
    result[current++] = new Double(testTimeElapsed / 1000.0);
    if (canMeasureCPUTime) {
      result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0);
      result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0);
    } else {
      result[current++] = new Double(Utils.missingValue());
      result[current++] = new Double(Utils.missingValue());
    }

    // sizes
    if (m_NoSizeDetermination) {
      result[current++] = -1.0;
      result[current++] = -1.0;
      result[current++] = -1.0;
    } else {
      ByteArrayOutputStream bastream = new ByteArrayOutputStream();
      ObjectOutputStream oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(m_Classifier);
      result[current++] = new Double(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(train);
      result[current++] = new Double(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(test);
      result[current++] = new Double(bastream.size());
    }

    // Prediction interval statistics
    result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions());
    result[current++] = new Double(eval.sizeOfPredictedRegions());

    if (m_Classifier instanceof Summarizable) {
      result[current++] = ((Summarizable) m_Classifier).toSummaryString();
    } else {
      result[current++] = null;
    }

    for (int i = 0; i < addm; i++) {
      if (m_doesProduce[i]) {
        try {
          double dv =
              ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]);
          if (!Utils.isMissingValue(dv)) {
            Double value = new Double(dv);
            result[current++] = value;
          } else {
            result[current++] = null;
          }
        } catch (Exception ex) {
          System.err.println(ex);
        }
      } else {
        result[current++] = null;
      }
    }

    // get the actual metrics from the evaluation object
    List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics();
    if (metrics != null) {
      for (AbstractEvaluationMetric m : metrics) {
        if (m.appliesToNumericClass()) {
          List<String> statNames = m.getStatisticNames();
          for (String s : statNames) {
            result[current++] = new Double(m.getStatistic(s));
          }
        }
      }
    }

    if (current != RESULT_SIZE + addm + m_numPluginStatistics) {
      throw new Error("Results didn't fit RESULT_SIZE");
    }
    return result;
  }
Example #21
  // Takes a question as input; returns the category the question belongs to.
  public double classifyByBayes(String question) throws Exception {
    double label = -1;
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    // define the data format (the attribute names below are Chinese category labels:
    // legal policy, location/transport, feng shui, price, floor, unit layout,
    // neighborhood amenities, mortgage, buying timing, developer)
    Attribute att1 = new Attribute("法律政策");
    Attribute att2 = new Attribute("位置交通");
    Attribute att3 = new Attribute("风水");
    Attribute att4 = new Attribute("房价");
    Attribute att5 = new Attribute("楼层");
    Attribute att6 = new Attribute("户型");
    Attribute att7 = new Attribute("小区配套");
    Attribute att8 = new Attribute("贷款");
    Attribute att9 = new Attribute("买房时机");
    Attribute att10 = new Attribute("开发商");
    FastVector labels = new FastVector();
    labels.addElement("1");
    labels.addElement("2");
    labels.addElement("3");
    labels.addElement("4");
    labels.addElement("5");
    labels.addElement("6");
    labels.addElement("7");
    labels.addElement("8");
    labels.addElement("9");
    labels.addElement("10");
    Attribute att11 = new Attribute("类别", labels);

    FastVector attributes = new FastVector();
    attributes.addElement(att1);
    attributes.addElement(att2);
    attributes.addElement(att3);
    attributes.addElement(att4);
    attributes.addElement(att5);
    attributes.addElement(att6);
    attributes.addElement(att7);
    attributes.addElement(att8);
    attributes.addElement(att9);
    attributes.addElement(att10);
    attributes.addElement(att11);
    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(10);

    Classifier classifier = null;
    if (!new File("naivebayes.model").exists()) {
      // add the data
      double[] values = new double[11];
      for (int i = 0; i < questionID.size(); i++) {
        for (int m = 0; m < 11; m++) {
          values[m] = 0;
        }
        int whitewordcount = 0;
        whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
        if (whitewordcount != 0) {
          List<QuestionWhiteWord> questionwhiteword =
              questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
          for (int j = 0; j < questionwhiteword.size(); j++) {
            values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
          }
          for (int m = 0; m < 11; m++) {
            values[m] = values[m] / whitewordcount;
          }
        }
        values[10] = questionID.get(i).getType() - 1;
        Instance inst = new Instance(1.0, values);
        dataset.add(inst);
      }
      // build the classifier
      classifier = new NaiveBayes();
      classifier.buildClassifier(dataset);
      SerializationHelper.write("naivebayes.model", classifier);
    } else {
      classifier = (Classifier) SerializationHelper.read("naivebayes.model");
    }

    System.out.println("*************begin evaluation*******************");
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.evaluateModel(classifier, dataset); // ideally this should use a held-out set, not the training data
    System.out.println(evaluation.toSummaryString());

    // classify
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    // double[] dis = classifier.distributionForInstance(inst);
    // for (double i : dis) {
    //   System.out.print(i + "    ");
    // }

    System.out.println(questionID.size());
    return label + 1;
  }
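
As the comment in the evaluation step notes, scoring on the training set is optimistic. When no held-out set exists, cross-validation on the same data is a quick alternative, sketched here:

    Evaluation cv = new Evaluation(dataset);
    cv.crossValidateModel(new NaiveBayes(), dataset, 10, new java.util.Random(1));
    System.out.println(cv.toSummaryString());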
Example #22
  public static void main(String[] args) throws Exception {

    /*
     * First we load our predictions from the CSV formatted file.
     */
    CSVLoader predictCsvLoader = new CSVLoader();
    predictCsvLoader.setSource(new File("predict.csv"));

    /*
     * Since we are not using the ARFF format here, we have to give the
     * loader a little bit of information about the data types. Columns
     * 3,8,10 need to be of type string and columns 1,4,11 are nominal
     * types.
     */
    predictCsvLoader.setStringAttributes("3,8,10");
    predictCsvLoader.setNominalAttributes("1,4,11");
    Instances predictDataSet = predictCsvLoader.getDataSet();

    /*
     * Here we set the attribute we want to test the predictions with
     */
    Attribute testAttribute = predictDataSet.attribute(0);
    predictDataSet.setClass(testAttribute);

    /*
     * We still have to remove all string attributes before we can test
     */
    predictDataSet.deleteStringAttributes();

    /*
     * Next we load the training data from our ARFF file
     */
    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    /*
     * Now we tell the data set which attribute we want to classify, in our
     * case, we want to classify the first column: survived
     */
    Attribute trainAttribute = trainDataSet.attribute(0);
    trainDataSet.setClass(trainAttribute);

    /*
     * The RandomForest implementation cannot handle columns of type string,
     * so we remove them for now.
     */
    trainDataSet.deleteStringAttributes();

    /*
     * Now we read in the serialized model from disk
     */
    Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

    /*
     * Next we will use an Evaluation class to evaluate the performance of
     * our Classifier.
     */
    Evaluation evaluation = new Evaluation(trainDataSet);
    evaluation.evaluateModel(classifier, predictDataSet);

    /*
     * After we evaluate the Classifier, we write out the summary
     * information to the screen.
     */
    System.out.println(classifier);
    System.out.println(evaluation.toSummaryString());
  }
Example #23
 public void crossValidation() throws Exception {
   Evaluation eval = new Evaluation(trainingData);
   eval.crossValidateModel(classifier, trainingData, 10, new Random(1));
   System.out.println(eval.toSummaryString("Results", false));
 }
Example #24
 public void testModel() throws Exception {
   Evaluation eval = new Evaluation(testData);
   eval.evaluateModel(classifier, testData);
   System.out.println(eval.toSummaryString("Results", false));
 }
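
Examples #23 and #24 assume fields initialized elsewhere in the class; an illustrative sketch (the names come from the code, the initialization is assumed):

 private Instances trainingData; // class index already set
 private Instances testData;     // class index already set
 private Classifier classifier;  // built on trainingData beforehand

Note also that Example #24 seeds Evaluation with testData, so the class priors come from the test set; seeding with the training set (as in Example #12) is the more common pattern.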
Example #25
  /**
   * Accepts and processes a classifier encapsulated in an incremental classifier event
   *
   * @param ce an <code>IncrementalClassifierEvent</code> value
   */
  @Override
  public void acceptClassifier(final IncrementalClassifierEvent ce) {
    try {
      if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) {
        m_throughput = new StreamThroughput(statusMessagePrefix());
        m_throughput.setSamplePeriod(m_statusFrequency);

        // m_eval = new Evaluation(ce.getCurrentInstance().dataset());
        m_eval = new Evaluation(ce.getStructure());
        m_eval.useNoPriors();

        m_dataLegend = new Vector();
        m_reset = true;
        m_dataPoint = new double[0];
        Instances inst = ce.getStructure();
        System.err.println("NEW BATCH");
        m_instanceCount = 0;

        if (m_windowSize > 0) {
          m_window = new LinkedList<Instance>();
          m_windowEval = new Evaluation(ce.getStructure());
          m_windowEval.useNoPriors();
          m_windowedPreds = new LinkedList<double[]>();

          if (m_logger != null) {
            m_logger.logMessage(
                statusMessagePrefix()
                    + "[IncrementalClassifierEvaluator] Chart output using windowed "
                    + "evaluation over "
                    + m_windowSize
                    + " instances");
          }
        }

        /*
         * if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix()
         * + "IncrementalClassifierEvaluator: started processing...");
         * m_logger.logMessage(statusMessagePrefix() +
         * " [IncrementalClassifierEvaluator]" + statusMessagePrefix() +
         * " started processing..."); }
         */
      } else {
        Instance inst = ce.getCurrentInstance();
        if (inst != null) {
          m_throughput.updateStart();
          m_instanceCount++;
          // if (inst.attribute(inst.classIndex()).isNominal()) {
          double[] dist = ce.getClassifier().distributionForInstance(inst);
          double pred = 0;
          if (!inst.isMissing(inst.classIndex())) {
            if (m_outputInfoRetrievalStats) {
              // store predictions so AUC etc can be output.
              m_eval.evaluateModelOnceAndRecordPrediction(dist, inst);
            } else {
              m_eval.evaluateModelOnce(dist, inst);
            }

            if (m_windowSize > 0) {

              m_windowEval.evaluateModelOnce(dist, inst);
              m_window.addFirst(inst);
              m_windowedPreds.addFirst(dist);

              if (m_instanceCount > m_windowSize) {
                // "forget" the oldest prediction
                Instance oldest = m_window.removeLast();

                double[] oldDist = m_windowedPreds.removeLast();
                oldest.setWeight(-oldest.weight());
                m_windowEval.evaluateModelOnce(oldDist, oldest);
                oldest.setWeight(-oldest.weight());
              }
            }
          } else {
            pred = ce.getClassifier().classifyInstance(inst);
          }
          if (inst.classIndex() >= 0) {
            // need to check that the class is not missing
            if (inst.attribute(inst.classIndex()).isNominal()) {
              if (!inst.isMissing(inst.classIndex())) {
                if (m_dataPoint.length < 2) {
                  m_dataPoint = new double[3];
                  m_dataLegend.addElement("Accuracy");
                  m_dataLegend.addElement("RMSE (prob)");
                  m_dataLegend.addElement("Kappa");
                }
                // int classV = (int) inst.value(inst.classIndex());

                if (m_windowSize > 0) {
                  m_dataPoint[1] = m_windowEval.rootMeanSquaredError();
                  m_dataPoint[2] = m_windowEval.kappa();
                } else {
                  m_dataPoint[1] = m_eval.rootMeanSquaredError();
                  m_dataPoint[2] = m_eval.kappa();
                }
                // int maxO = Utils.maxIndex(dist);
                // if (maxO == classV) {
                // dist[classV] = -1;
                // maxO = Utils.maxIndex(dist);
                // }
                // m_dataPoint[1] -= dist[maxO];
              } else {
                if (m_dataPoint.length < 1) {
                  m_dataPoint = new double[1];
                  m_dataLegend.addElement("Confidence");
                }
              }
              double primaryMeasure = 0;
              if (!inst.isMissing(inst.classIndex())) {
                if (m_windowSize > 0) {
                  primaryMeasure = 1.0 - m_windowEval.errorRate();
                } else {
                  primaryMeasure = 1.0 - m_eval.errorRate();
                }
              } else {
                // record confidence as the primary measure
                // (another possibility would be entropy of
                // the distribution, or perhaps average
                // confidence)
                primaryMeasure = dist[Utils.maxIndex(dist)];
              }
              // double [] dataPoint = new double[1];
              m_dataPoint[0] = primaryMeasure;
              // double min = 0; double max = 100;
              /*
               * ChartEvent e = new
               * ChartEvent(IncrementalClassifierEvaluator.this, m_dataLegend,
               * min, max, dataPoint);
               */

              m_ce.setLegendText(m_dataLegend);
              m_ce.setMin(0);
              m_ce.setMax(1);
              m_ce.setDataPoint(m_dataPoint);
              m_ce.setReset(m_reset);
              m_reset = false;
            } else {
              // numeric class
              if (m_dataPoint.length < 1) {
                m_dataPoint = new double[1];
                if (inst.isMissing(inst.classIndex())) {
                  m_dataLegend.addElement("Prediction");
                } else {
                  m_dataLegend.addElement("RMSE");
                }
              }
              if (!inst.isMissing(inst.classIndex())) {
                double update;
                if (!inst.isMissing(inst.classIndex())) {
                  if (m_windowSize > 0) {
                    update = m_windowEval.rootMeanSquaredError();
                  } else {
                    update = m_eval.rootMeanSquaredError();
                  }
                } else {
                  update = pred;
                }
                m_dataPoint[0] = update;
                if (update > m_max) {
                  m_max = update;
                }
                if (update < m_min) {
                  m_min = update;
                }
              }

              m_ce.setLegendText(m_dataLegend);
              m_ce.setMin((inst.isMissing(inst.classIndex()) ? m_min : 0));
              m_ce.setMax(m_max);
              m_ce.setDataPoint(m_dataPoint);
              m_ce.setReset(m_reset);
              m_reset = false;
            }
            notifyChartListeners(m_ce);
          }
          m_throughput.updateEnd(m_logger);
        }

        if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED || inst == null) {
          if (m_logger != null) {
            m_logger.logMessage(
                "[IncrementalClassifierEvaluator]"
                    + statusMessagePrefix()
                    + " Finished processing.");
          }
          m_throughput.finished(m_logger);

          // save memory if using windowed evaluation for charting
          m_windowEval = null;
          m_window = null;
          m_windowedPreds = null;

          if (m_textListeners.size() > 0) {
            String textTitle = ce.getClassifier().getClass().getName();
            textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
            String results =
                "=== Performance information ===\n\n"
                    + "Scheme:   "
                    + textTitle
                    + "\n"
                    + "Relation: "
                    + m_eval.getHeader().relationName()
                    + "\n\n"
                    + m_eval.toSummaryString();
            if (m_eval.getHeader().classIndex() >= 0
                && m_eval.getHeader().classAttribute().isNominal()
                && (m_outputInfoRetrievalStats)) {
              results += "\n" + m_eval.toClassDetailsString();
            }

            if (m_eval.getHeader().classIndex() >= 0
                && m_eval.getHeader().classAttribute().isNominal()) {
              results += "\n" + m_eval.toMatrixString();
            }
            textTitle = "Results: " + textTitle;
            TextEvent te = new TextEvent(this, results, textTitle);
            notifyTextListeners(te);
          }
        }
      }
    } catch (Exception ex) {
      if (m_logger != null) {
        m_logger.logMessage(
            "[IncrementalClassifierEvaluator]"
                + statusMessagePrefix()
                + " Error processing prediction "
                + ex.getMessage());
        m_logger.statusMessage(
            statusMessagePrefix() + "ERROR: problem processing prediction (see log for details)");
      }
      ex.printStackTrace();
      stop();
    }
  }