コード例 #1
0
 /**
  * Reports a sub process wrapper result back to the caller and optionally persists the model.
  *
  * <p>The time and score of {@code res} are printed to standard output (without a trailing
  * newline). If the {@code modelOutputFilePrefix} property is set, the classifier is serialized to
  * {@code <prefix>.model}; the attribute selection is serialized to
  * {@code <prefix>.attributeselection} when the result carries one, otherwise any existing file
  * at that path is deleted.
  *
  * @param res The classifier result.
  * @throws RuntimeException wrapping any exception raised while writing the files
  */
 @Override
 protected void _processResults(ClassifierResult res) {
   System.out.print(
       "SubProcessWrapper: Time(" + res.getTime() + ") Score(" + res.getScore() + ")");

   final String prefix = mProperties.getProperty("modelOutputFilePrefix", null);
   if (prefix == null) {
     // No output prefix configured: nothing to persist.
     return;
   }

   final String selectionPath = prefix + ".attributeselection";
   try {
     if (res.getAttributeSelection() == null) {
       // This result has no attribute selection: delete any file already present at that path.
       File existing = new File(selectionPath);
       if (existing.exists()) {
         existing.delete();
       }
     } else {
       weka.core.SerializationHelper.write(selectionPath, res.getAttributeSelection());
     }
     weka.core.SerializationHelper.write(prefix + ".model", res.getClassifier());
   } catch (Exception e) {
     throw new RuntimeException(e);
   }
 }
コード例 #2
0
ファイル: train.java プロジェクト: kavyasrinet/CIFAR-10
  /**
   * Trains a Naive Bayes classifier on {@code src/train.arff} and serializes it to disk.
   *
   * <p>The attribute named {@code class} is used as the class attribute. The time spent in
   * {@code buildClassifier} is printed in seconds, and the trained model is written to
   * {@code NaiveBayes.model} in the working directory.
   *
   * @param args command-line arguments (unused)
   * @throws IllegalStateException if the training data has no attribute named {@code class}
   * @throws Exception if loading the data, training, or writing the model fails
   */
  public static void main(String args[]) throws Exception {
    // Single source of truth for the output path so the log message cannot drift from it.
    final String modelPath = "NaiveBayes.model";

    ArffLoader trainLoader = new ArffLoader();
    trainLoader.setSource(new File("src/train.arff"));
    trainLoader.setRetrieval(Loader.BATCH);
    Instances trainDataSet = trainLoader.getDataSet();

    weka.core.Attribute trainAttribute = trainDataSet.attribute("class");
    if (trainAttribute == null) {
      // Fail with a clear message instead of an opaque error inside setClass(null).
      throw new IllegalStateException("src/train.arff has no attribute named \"class\"");
    }
    trainDataSet.setClass(trainAttribute);

    NaiveBayes classifier = new NaiveBayes();

    final long startTime = System.currentTimeMillis();
    classifier.buildClassifier(trainDataSet);
    final long endTime = System.currentTimeMillis();
    double executionTime = (endTime - startTime) / 1000.0;
    System.out.println("Total execution time: " + executionTime);

    SerializationHelper.write(modelPath, classifier);
    // The original message named "classifier.model", which is not the file that is written.
    System.out.println("Saved trained model to " + modelPath);
  }
コード例 #3
0
  /**
   * Takes a question and returns the category it belongs to.
   *
   * <p>The data format is ten numeric feature attributes followed by a nominal class attribute
   * with labels "1".."10". If {@code naivebayes.model} does not exist, a Naive Bayes classifier
   * is trained from the labeled questions returned by {@code questionDAO} and serialized to that
   * file; otherwise the classifier is deserialized from it. The classifier is then evaluated
   * against {@code dataset} and used to classify the vector built by
   * {@code getQuestionVector(question)}.
   *
   * @param question the question to classify
   * @return the value returned by {@code classifyInstance} plus one (training stores
   *     {@code getType() - 1} as the class value, so this maps back to the {@code getType()} scale)
   * @throws Exception propagated from the DAO, Weka training/evaluation and serialization calls
   */
  public double classifyByBayes(String question) throws Exception {
    double label = -1;
    List<Question> questionID = questionDAO.getQuestionIDLabeled();

    // Define the data format: ten numeric features ...
    String[] featureNames = {
      "法律政策", "位置交通", "风水", "房价", "楼层", "户型", "小区配套", "贷款", "买房时机", "开发商"
    };
    final int featureCount = featureNames.length;
    final int classIndex = featureCount;

    FastVector attributes = new FastVector();
    for (String featureName : featureNames) {
      attributes.addElement(new Attribute(featureName));
    }

    // ... followed by the nominal class attribute with labels "1".."10".
    FastVector labels = new FastVector();
    for (int k = 1; k <= 10; k++) {
      labels.addElement(String.valueOf(k));
    }
    attributes.addElement(new Attribute("类别", labels));

    Instances dataset = new Instances("Test-dataset", attributes, 0);
    dataset.setClassIndex(classIndex);

    Classifier classifier = null;
    if (!new File("naivebayes.model").exists()) {
      // Add the training data.
      for (Question labeled : questionID) {
        // A NEW array per instance is required: Instance keeps a reference to the array it is
        // given and Instances.add() only shallow-copies, so reusing one array would make every
        // training row alias the values of the last question processed.
        double[] values = new double[featureCount + 1];

        int whitewordcount = questionDAO.getHitWhiteWordNum(labeled.getId());
        if (whitewordcount != 0) {
          List<QuestionWhiteWord> questionwhiteword =
              questionDAO.getHitQuestionWhiteWord(labeled.getId());
          // Count the hits per attribute (getAttIndex is used as a 1-based index here) ...
          for (int j = 0; j < questionwhiteword.size(); j++) {
            values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
          }
          // ... and turn the feature counts into fractions of the total hit count.
          for (int m = 0; m < featureCount; m++) {
            values[m] = values[m] / whitewordcount;
          }
        }
        // Class value is the 0-based index into the nominal labels.
        values[classIndex] = labeled.getType() - 1;
        dataset.add(new Instance(1.0, values));
      }
      // Build the classifier and cache it on disk.
      classifier = new NaiveBayes();
      classifier.buildClassifier(dataset);
      SerializationHelper.write("naivebayes.model", classifier);
    } else {
      classifier = (Classifier) SerializationHelper.read("naivebayes.model");
    }

    System.out.println("*************begin evaluation*******************");
    Evaluation evaluation = new Evaluation(dataset);
    // Strictly speaking, a separate data set should be used here rather than the training data.
    // NOTE(review): when the model was loaded from disk, dataset holds no instances at this point.
    evaluation.evaluateModel(classifier, dataset);
    System.out.println(evaluation.toSummaryString());

    // Classification.
    System.out.println("*************begin classification*******************");
    Instance subject = new Instance(1.0, getQuestionVector(question));
    subject.setDataset(dataset);
    label = classifier.classifyInstance(subject);
    System.out.println("label: " + label);

    System.out.println(questionID.size());
    return label + 1;
  }
コード例 #4
0
ファイル: MyWekaExplorer.java プロジェクト: Teofebano19/MyANN
 /**
  * Serializes the current classifier to the given file.
  *
  * @param filename path of the file to write the classifier to
  * @throws Exception if {@code SerializationHelper.write} fails
  */
 public void saveModel(String filename) throws Exception {
   SerializationHelper.write(filename, this.classifier);
 }
コード例 #5
0
    /**
     * Trains and evaluates Naive Bayes, multilayer perceptron and SMO classifiers, then pushes
     * the results to the charts and tables.
     *
     * <p>Reads {@code cross_validation_data.arff}, uses attribute index 13 as the class and
     * shuffles the data with a fixed seed. Each trained model is serialized to
     * {@code naivebayes.model}, {@code mlp.model} and {@code svm.model} respectively. Naive Bayes
     * is evaluated with 10-fold cross-validation; the MLP and SMO models are evaluated on the
     * training data itself. Progress is reported through {@code publish}. Any exception is
     * printed and swallowed.
     *
     * @return always {@code null}
     */
    @Override
    public Void doInBackground() {
      try {
        publish("Reading data...");
        final Instances trainingdata;
        // try-with-resources closes the reader even when parsing the ARFF file fails.
        try (BufferedReader reader =
            new BufferedReader(new FileReader("cross_validation_data.arff"))) {
          trainingdata = new Instances(reader);
        }
        // setting class attribute
        trainingdata.setClassIndex(13);
        trainingdata.randomize(new Random(1));

        publish("Training Naive Bayes Classifier...");

        NaiveBayes nb = new NaiveBayes();
        long startTime = System.nanoTime();
        nb.buildClassifier(trainingdata);
        // Running times are in seconds with millisecond resolution (ns -> whole ms -> s).
        double runningTimeNB = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        // saving the naive bayes model
        weka.core.SerializationHelper.write("naivebayes.model", nb);
        System.out.println("running time" + runningTimeNB);
        publish("Done training NB.\nEvaluating NB using 10-fold cross-validation...");
        evalNB = new Evaluation(trainingdata);
        evalNB.crossValidateModel(nb, trainingdata, 10, new Random(1));
        publish("Done evaluating NB.");

        MultilayerPerceptron mlp = new MultilayerPerceptron();
        mlp.setOptions(Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a"));
        publish("Training ANN...");
        startTime = System.nanoTime();
        mlp.buildClassifier(trainingdata);
        // Same computation as for NB; the previous long arithmetic truncated to whole seconds.
        double runningTimeANN = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        // saving the MLP model
        weka.core.SerializationHelper.write("mlp.model", mlp);

        // The model is scored on the data it was trained on, not cross-validated, so the
        // status message says so.
        // NOTE(review): this makes the ANN/SVM accuracy not directly comparable to the
        // cross-validated NB accuracy shown next to it.
        publish("Done training ANN.\nEvaluating ANN on the training data...");

        evalANN = new Evaluation(trainingdata);
        evalANN.evaluateModel(mlp, trainingdata);

        publish("Done evaluating ANN.");
        publish("Training SVM...");
        SMO svm = new SMO();

        startTime = System.nanoTime();
        svm.buildClassifier(trainingdata);
        double runningTimeSVM = ((System.nanoTime() - startTime) / 1000000) / 1000.0;
        weka.core.SerializationHelper.write("svm.model", svm);
        publish("Done training SVM.\nEvaluating SVM on the training data...");
        evalSVM = new Evaluation(trainingdata);
        evalSVM.evaluateModel(svm, trainingdata);
        publish("Done evaluating SVM.");

        // Chart updates are handed to the JavaFX application thread.
        Platform.runLater(
            new Runnable() {
              @Override
              public void run() {
                // Bar chart: percentage of correctly classified instances per classifier.
                bc.getData()
                    .get(0)
                    .getData()
                    .get(0)
                    .setYValue(evalANN.correct() / trainingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(1)
                    .setYValue(evalSVM.correct() / trainingdata.size() * 100);
                bc.getData()
                    .get(0)
                    .getData()
                    .get(2)
                    .setYValue(evalNB.correct() / trainingdata.size() * 100);

                // Line chart: per-class recall (in percent), one series per classifier.
                for (int i = 0; i < NUM_CLASSES; i++) {
                  lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100);
                  lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100);
                  lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100);
                }
              }
            });

        // NOTE(review): panel and summaryTable are updated from this background thread; if they
        // are Swing components these calls probably belong on the EDT -- confirm.
        panel.fillConfTable(evalSVM.confusionMatrix());

        summaryTable.setValueAt(evalANN.correct() / trainingdata.size() * 100., 0, 1);
        summaryTable.setValueAt(evalSVM.correct() / trainingdata.size() * 100, 0, 2);
        summaryTable.setValueAt(evalNB.correct() / trainingdata.size() * 100, 0, 3);

        summaryTable.setValueAt(runningTimeANN, 1, 1);
        summaryTable.setValueAt(runningTimeSVM, 1, 2);
        summaryTable.setValueAt(runningTimeNB, 1, 3);

      } catch (Exception e1) {
        // Best-effort: failures are only reported on stderr and the worker still returns null.
        e1.printStackTrace();
      }
      return null;
    }