/**
 * Reports a sub-process classifier result to the caller and optionally persists it.
 *
 * <p>The result's time and score are printed to standard output without a trailing newline.
 * When the {@code modelOutputFilePrefix} property is set, the classifier is serialized to
 * {@code <prefix>.model}. The attribute selection is serialized to
 * {@code <prefix>.attributeselection} when the result carries one; otherwise an existing file
 * of that name is deleted.
 *
 * @param res The classifier result.
 * @throws RuntimeException wrapping any exception raised while updating the output files
 */
@Override
protected void _processResults(ClassifierResult res) {
  final String summary =
      "SubProcessWrapper: Time(" + res.getTime() + ") Score(" + res.getScore() + ")";
  System.out.print(summary);

  final String prefix = mProperties.getProperty("modelOutputFilePrefix", null);
  if (prefix == null) {
    // No output prefix configured: nothing to persist.
    return;
  }

  final String selectionPath = prefix + ".attributeselection";
  final String modelPath = prefix + ".model";
  try {
    if (res.getAttributeSelection() == null) {
      // Remove a leftover selection file so it cannot be paired with the new model.
      File stale = new File(selectionPath);
      if (stale.exists()) {
        stale.delete();
      }
    } else {
      weka.core.SerializationHelper.write(selectionPath, res.getAttributeSelection());
    }
    weka.core.SerializationHelper.write(modelPath, res.getClassifier());
  } catch (Exception e) {
    throw new RuntimeException(e);
  }
}
public static void main(String args[]) throws Exception { ArffLoader trainLoader = new ArffLoader(); trainLoader.setSource(new File("src/train.arff")); trainLoader.setRetrieval(Loader.BATCH); Instances trainDataSet = trainLoader.getDataSet(); weka.core.Attribute trainAttribute = trainDataSet.attribute("class"); trainDataSet.setClass(trainAttribute); // trainDataSet.deleteStringAttributes(); NaiveBayes classifier = new NaiveBayes(); final double startTime = System.currentTimeMillis(); classifier.buildClassifier(trainDataSet); final double endTime = System.currentTimeMillis(); double executionTime = (endTime - startTime) / (1000.0); System.out.println("Total execution time: " + executionTime); SerializationHelper.write("NaiveBayes.model", classifier); System.out.println("Saved trained model to classifier.model"); }
// 输入问题,输出问题所属类型。 public double classifyByBayes(String question) throws Exception { double label = -1; List<Question> questionID = questionDAO.getQuestionIDLabeled(); // 定义数据格式 Attribute att1 = new Attribute("法律政策"); Attribute att2 = new Attribute("位置交通"); Attribute att3 = new Attribute("风水"); Attribute att4 = new Attribute("房价"); Attribute att5 = new Attribute("楼层"); Attribute att6 = new Attribute("户型"); Attribute att7 = new Attribute("小区配套"); Attribute att8 = new Attribute("贷款"); Attribute att9 = new Attribute("买房时机"); Attribute att10 = new Attribute("开发商"); FastVector labels = new FastVector(); labels.addElement("1"); labels.addElement("2"); labels.addElement("3"); labels.addElement("4"); labels.addElement("5"); labels.addElement("6"); labels.addElement("7"); labels.addElement("8"); labels.addElement("9"); labels.addElement("10"); Attribute att11 = new Attribute("类别", labels); FastVector attributes = new FastVector(); attributes.addElement(att1); attributes.addElement(att2); attributes.addElement(att3); attributes.addElement(att4); attributes.addElement(att5); attributes.addElement(att6); attributes.addElement(att7); attributes.addElement(att8); attributes.addElement(att9); attributes.addElement(att10); attributes.addElement(att11); Instances dataset = new Instances("Test-dataset", attributes, 0); dataset.setClassIndex(10); Classifier classifier = null; if (!new File("naivebayes.model").exists()) { // 添加数据 double[] values = new double[11]; for (int i = 0; i < questionID.size(); i++) { for (int m = 0; m < 11; m++) { values[m] = 0; } int whitewordcount = 0; whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId()); if (whitewordcount != 0) { List<QuestionWhiteWord> questionwhiteword = questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId()); for (int j = 0; j < questionwhiteword.size(); j++) { values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++; } for (int m = 0; m < 11; m++) { values[m] = values[m] / whitewordcount; } } 
values[10] = questionID.get(i).getType() - 1; Instance inst = new Instance(1.0, values); dataset.add(inst); } // 构造分类器 classifier = new NaiveBayes(); classifier.buildClassifier(dataset); SerializationHelper.write("naivebayes.model", classifier); } else { classifier = (Classifier) SerializationHelper.read("naivebayes.model"); } System.out.println("*************begin evaluation*******************"); Evaluation evaluation = new Evaluation(dataset); evaluation.evaluateModel(classifier, dataset); // 按道理说,这里应该使用另一份数据,而不是训练集data。 System.out.println(evaluation.toSummaryString()); // 分类 System.out.println("*************begin classification*******************"); Instance subject = new Instance(1.0, getQuestionVector(question)); subject.setDataset(dataset); label = classifier.classifyInstance(subject); System.out.println("label: " + label); // double dis[]=classifier.distributionForInstance(inst); // for(double i:dis){ // System.out.print(i+" "); // } System.out.println(questionID.size()); return label + 1; }
/**
 * Serializes the current classifier to the given file.
 *
 * @param filename path of the file the model is written to
 * @throws Exception if {@code SerializationHelper.write} fails
 */
public void saveModel(String filename) throws Exception {
  final Object model = classifier;
  SerializationHelper.write(filename, model);
}
@Override public Void doInBackground() { BufferedReader reader; try { publish("Reading data..."); reader = new BufferedReader(new FileReader("cross_validation_data.arff")); final Instances trainingdata = new Instances(reader); reader.close(); // setting class attribute trainingdata.setClassIndex(13); trainingdata.randomize(new Random(1)); long startTime = System.nanoTime(); publish("Training Naive Bayes Classifier..."); NaiveBayes nb = new NaiveBayes(); startTime = System.nanoTime(); nb.buildClassifier(trainingdata); double runningTimeNB = (System.nanoTime() - startTime) / 1000000; runningTimeNB /= 1000; // saving the naive bayes model weka.core.SerializationHelper.write("naivebayes.model", nb); System.out.println("running time" + runningTimeNB); publish("Done training NB.\nEvaluating NB using 10-fold cross-validation..."); evalNB = new Evaluation(trainingdata); evalNB.crossValidateModel(nb, trainingdata, 10, new Random(1)); publish("Done evaluating NB."); // System.out.println(evalNB.toSummaryString("\nResults for Naive Bayes\n======\n", false)); MultilayerPerceptron mlp = new MultilayerPerceptron(); mlp.setOptions(Utils.splitOptions("-L 0.3 -M 0.2 -N 500 -V 0 -S 0 -E 20 -H a")); publish("Training ANN..."); startTime = System.nanoTime(); mlp.buildClassifier(trainingdata); long runningTimeANN = (System.nanoTime() - startTime) / 1000000; runningTimeANN /= 1000; // saving the MLP model weka.core.SerializationHelper.write("mlp.model", mlp); publish("Done training ANN.\nEvaluating ANN using 10-fold cross-validation..."); evalANN = new Evaluation(trainingdata); evalANN.evaluateModel(mlp, trainingdata); // evalMLP.crossValidateModel(mlp, trainingdata, 10, new Random(1)); publish("Done evaluating ANN."); publish("Training SVM..."); SMO svm = new SMO(); startTime = System.nanoTime(); svm.buildClassifier(trainingdata); long runningTimeSVM = (System.nanoTime() - startTime) / 1000000; runningTimeSVM /= 1000; weka.core.SerializationHelper.write("svm.model", svm); publish("Done 
training SVM.\nEvaluating SVM using 10-fold cross-validation..."); evalSVM = new Evaluation(trainingdata); evalSVM.evaluateModel(svm, trainingdata); publish("Done evaluating SVM."); Platform.runLater( new Runnable() { @Override public void run() { bc.getData() .get(0) .getData() .get(0) .setYValue(evalANN.correct() / trainingdata.size() * 100); bc.getData() .get(0) .getData() .get(1) .setYValue(evalSVM.correct() / trainingdata.size() * 100); bc.getData() .get(0) .getData() .get(2) .setYValue(evalNB.correct() / trainingdata.size() * 100); for (int i = 0; i < NUM_CLASSES; i++) { lineChart.getData().get(0).getData().get(i).setYValue(evalANN.recall(i) * 100); lineChart.getData().get(1).getData().get(i).setYValue(evalSVM.recall(i) * 100); lineChart.getData().get(2).getData().get(i).setYValue(evalNB.recall(i) * 100); } } }); panel.fillConfTable(evalSVM.confusionMatrix()); summaryTable.setValueAt(evalANN.correct() / trainingdata.size() * 100., 0, 1); summaryTable.setValueAt(evalSVM.correct() / trainingdata.size() * 100, 0, 2); summaryTable.setValueAt(evalNB.correct() / trainingdata.size() * 100, 0, 3); summaryTable.setValueAt(runningTimeANN, 1, 1); summaryTable.setValueAt(runningTimeSVM, 1, 2); summaryTable.setValueAt(runningTimeNB, 1, 3); } catch (Exception e1) { // TODO Auto-generated catch block e1.printStackTrace(); } return null; }