/**
 * Tests with the classifier.
 *
 * @param trainFileName
 * @param testFileName
 */
public static void classify(String trainFileName, String testFileName) {
  try {
    File inputFile = new File(fileName + trainFileName); // training corpus file
    ArffLoader atf = new ArffLoader();
    atf.setFile(inputFile);
    Instances instancesTrain = atf.getDataSet(); // read in the training set

    inputFile = new File(fileName + testFileName); // test corpus file
    atf.setFile(inputFile);
    Instances instancesTest = atf.getDataSet(); // read in the test set

    // set the class attribute (the last attribute in both sets)
    instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
    instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

    classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
    classifier.buildClassifier(instancesTrain);

    Evaluation eval = new Evaluation(instancesTrain);
    // first argument: a trained classifier; second argument: the dataset to evaluate it on
    eval.evaluateModel(classifier, instancesTest);
    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());
    System.out.println("accuracy is: " + (1 - eval.errorRate())); // 1 - errorRate() is accuracy
  } catch (Exception e) {
    e.printStackTrace();
  }
}
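// The snippet above relies on class-level fields that it does not show. A minimal sketch of
// plausible declarations; the base path and the classifier class name are assumptions for
// illustration only, not taken from the original source:
private static String fileName = "data/";                                          // assumed base directory for the ARFF files
private static final String CLASSIFIERNAME = "weka.classifiers.bayes.NaiveBayes";  // assumed; any Classifier subclass name works
private static Classifier classifier;                                              // instantiated inside classify(...)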
/**
 * @param args
 * @throws Exception
 */
public static void main(String[] args) throws Exception {
  Instances isTrainingSet = createSet(4);
  Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
  Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
  Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet);
  isTrainingSet.add(instance1);
  isTrainingSet.add(instance2);
  isTrainingSet.add(instance22);

  Instances isTestingSet = createSet(4);
  Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
  Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
  isTestingSet.add(instance3);
  isTestingSet.add(instance4);

  // Create a naïve Bayes classifier
  Classifier cModel = (Classifier) new BayesNet(); // M5P
  cModel.buildClassifier(isTrainingSet);

  // Test the model
  Evaluation eTest = new Evaluation(isTrainingSet);
  eTest.evaluateModel(cModel, isTestingSet);

  // Print the result à la Weka explorer:
  String strSummary = eTest.toSummaryString();
  System.out.println(strSummary);

  // Get the likelihood of each class
  // fDistribution[0] is the probability of being "positive"
  // fDistribution[1] is the probability of being "negative"
  double[] fDistribution = cModel.distributionForInstance(instance4);
  for (int i = 0; i < fDistribution.length; i++) {
    System.out.println(fDistribution[i]);
  }
}
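// The main method above calls two helpers, createSet(...) and createInstance(...), that the
// snippet does not include. A rough, hypothetical sketch of what they might look like, inferred
// from the calls above (assumes the Weka 3.7+ API with DenseInstance and java.util imports;
// attribute names and the class values S1/S2/S3 are guesses):
private static Instances createSet(int numAttributes) {
  ArrayList<Attribute> attrs = new ArrayList<Attribute>();
  for (int i = 0; i < numAttributes; i++) {
    attrs.add(new Attribute("att" + i));                                 // numeric feature
  }
  attrs.add(new Attribute("class", Arrays.asList("S1", "S2", "S3")));    // nominal class
  Instances set = new Instances("dataset", attrs, 0);
  set.setClassIndex(set.numAttributes() - 1);
  return set;
}

private static Instance createInstance(double[] values, String label, Instances dataset) {
  Instance inst = new DenseInstance(dataset.numAttributes());
  inst.setDataset(dataset);          // attach header so setClassValue(String) can resolve the label
  for (int i = 0; i < values.length; i++) {
    inst.setValue(i, values[i]);
  }
  inst.setClassValue(label);
  return inst;
}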
/**
 * Creates an evaluation overview of the built classifier.
 *
 * @return the panel to be displayed as result evaluation view for the current decision point
 */
protected JPanel createEvaluationVisualization(Instances data) {
  // build text field to display evaluation statistics
  JTextPane statistic = new JTextPane();
  try {
    // build evaluation statistics
    Evaluation evaluation = new Evaluation(data);
    evaluation.evaluateModel(myClassifier, data);
    statistic.setText(evaluation.toSummaryString() + "\n\n"
        + evaluation.toClassDetailsString() + "\n\n"
        + evaluation.toMatrixString());
  } catch (Exception ex) {
    ex.printStackTrace();
    return createMessagePanel("Error while creating the decision tree evaluation view");
  }
  statistic.setFont(new Font("Courier", Font.PLAIN, 14));
  statistic.setEditable(false);
  statistic.setCaretPosition(0);
  JPanel resultViewPanel = new JPanel();
  resultViewPanel.setLayout(new BoxLayout(resultViewPanel, BoxLayout.PAGE_AXIS));
  resultViewPanel.add(new JScrollPane(statistic));
  return resultViewPanel;
}
public static void run(String[] args) throws Exception {
  /*
   * args[0]: train arff path
   * args[1]: test arff path
   */
  DataSource source = new DataSource(args[0]);
  Instances data = source.getDataSet();
  data.setClassIndex(data.numAttributes() - 1);
  NaiveBayes model = new NaiveBayes();
  model.buildClassifier(data);

  // Evaluation:
  Evaluation eval = new Evaluation(data);
  Instances testData = new DataSource(args[1]).getDataSet();
  testData.setClassIndex(testData.numAttributes() - 1);
  eval.evaluateModel(model, testData);
  System.out.println(model.toString());
  System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  System.out.println("======\nConfusion Matrix:");
  double[][] confusionM = eval.confusionMatrix();
  for (int i = 0; i < confusionM.length; ++i) {
    for (int j = 0; j < confusionM[i].length; ++j) {
      System.out.format("%10s ", confusionM[i][j]);
    }
    System.out.print("\n");
  }
}
public static void wekaAlgorithms(Instances data) throws Exception {
  classifier = new FilteredClassifier(); // filtered classifier wrapping the chosen base learner
  classifier.setClassifier(new NaiveBayes());
  // classifier.setClassifier(new J48());
  // classifier.setClassifier(new RandomForest());
  // classifier.setClassifier(new ZeroR());
  // classifier.setClassifier(new IBk());
  data.setClassIndex(data.numAttributes() - 1);
  Evaluation eval = new Evaluation(data);
  int folds = 10;
  eval.crossValidateModel(classifier, data, folds, new Random(1));
  System.out.println("===== Evaluating on filtered (training) dataset =====");
  System.out.println(eval.toSummaryString());
  System.out.println(eval.toClassDetailsString());
  double[][] mat = eval.confusionMatrix();
  System.out.println("========= Confusion Matrix =========");
  for (int i = 0; i < mat.length; i++) {
    for (int j = 0; j < mat[i].length; j++) {
      System.out.print(mat[i][j] + " ");
    }
    System.out.println(" ");
  }
}
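// The method above never calls setFilter(...) on the FilteredClassifier, so it relies on the
// class's default filter. A minimal, hypothetical variant that wires a filter in explicitly;
// StringToWordVector is only an illustrative choice, not taken from the original code:
public static void wekaFilteredAlgorithm(Instances data) throws Exception {
  FilteredClassifier fc = new FilteredClassifier();
  fc.setFilter(new StringToWordVector()); // assumed preprocessing step; pick what the data needs
  fc.setClassifier(new NaiveBayes());
  data.setClassIndex(data.numAttributes() - 1);
  Evaluation eval = new Evaluation(data);
  eval.crossValidateModel(fc, data, 10, new Random(1));
  System.out.println(eval.toSummaryString());
}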
/** evaluates the classifier */
@Override
public void evaluate() throws Exception {
  // evaluate classifier and print some statistics
  if (_test.classIndex() == -1)
    _test.setClassIndex(_test.numAttributes() - 1);
  Evaluation eval = new Evaluation(_train);
  eval.evaluateModel(_cl, _test);
  System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  System.out.println(eval.toMatrixString());
}
/** uses the meta-classifier */
protected static void useClassifier(Instances data) throws Exception {
  System.out.println("\n1. Meta-classifier");
  AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
  CfsSubsetEval eval = new CfsSubsetEval();
  GreedyStepwise search = new GreedyStepwise();
  search.setSearchBackwards(true);
  J48 base = new J48();
  classifier.setClassifier(base);
  classifier.setEvaluator(eval);
  classifier.setSearch(search);
  Evaluation evaluation = new Evaluation(data);
  evaluation.crossValidateModel(classifier, data, 10, new Random(1));
  System.out.println(evaluation.toSummaryString());
}
public void evaluateClassifyMethod() throws Exception {
  Instances data1 = getInstances();
  Evaluation eval = new Evaluation(data1);
  J48 tree = new J48();
  eval.crossValidateModel(tree, data1, 10, new Random(1));
  System.out.println("Decision tree accuracy:");
  System.out.println(eval.toSummaryString("\nResults\n\n", false));

  Instances data2 = getInstances();
  Evaluation eval2 = new Evaluation(data2);
  NaiveBayes bayes = new NaiveBayes();
  eval2.crossValidateModel(bayes, data2, 10, new Random(1));
  System.out.println("Naive Bayes accuracy:");
  System.out.println(eval2.toSummaryString("\nResults\n\n", false));

  Instances data3 = getInstances();
  Evaluation eval3 = new Evaluation(data3);
  IBk ibk = new IBk(3);
  eval3.crossValidateModel(ibk, data3, 10, new Random(1));
  System.out.println("KNN accuracy:");
  System.out.println(eval3.toSummaryString("\nResults\n\n", false));
}
public static Double runClassify(String trainFile, String testFile) {
  double predictOrder = 0.0;
  double trueOrder = 0.0;
  try {
    String trainWekaFileName = trainFile;
    String testWekaFileName = testFile;
    Instances train = DataSource.read(trainWekaFileName);
    Instances test = DataSource.read(testWekaFileName);
    train.setClassIndex(0);
    test.setClassIndex(0);
    train.deleteAttributeAt(8);
    test.deleteAttributeAt(8);
    train.deleteAttributeAt(6);
    test.deleteAttributeAt(6);
    train.deleteAttributeAt(5);
    test.deleteAttributeAt(5);
    train.deleteAttributeAt(4);
    test.deleteAttributeAt(4);
    // AdditiveRegression classifier = new AdditiveRegression();
    // NaiveBayes classifier = new NaiveBayes();
    RandomForest classifier = new RandomForest();
    // LibSVM classifier = new LibSVM();
    classifier.buildClassifier(train);
    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(classifier, test);
    System.out.println(eval.toSummaryString("\nResults\n\n", true));
    // System.out.println(eval.toClassDetailsString());
    // System.out.println(eval.toMatrixString());
    int k = 892;
    for (int i = 0; i < test.numInstances(); i++) {
      predictOrder = classifier.classifyInstance(test.instance(i));
      trueOrder = test.instance(i).classValue();
      System.out.println((k++) + "," + (int) predictOrder);
    }
  } catch (Exception e) {
    e.printStackTrace();
  }
  return predictOrder;
}
public static void trainModel(Instances dataTrain, Instances dataTest) {
  try {
    LibLINEAR classifier = new LibLINEAR();
    classifier.setBias(10);
    classifier.buildClassifier(dataTrain);
    Evaluation eval = new Evaluation(dataTrain);
    eval.evaluateModel(classifier, dataTest);
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  } catch (Exception e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  }
}
@Override
public void crossValidation(String traindata) throws Exception {
  DataSource ds = new DataSource(traindata);
  Instances instances = ds.getDataSet();
  StringToWordVector stv = new StringToWordVector();
  stv.setOptions(weka.core.Utils.splitOptions(
      "-R first-last -W 1000 "
          + "-prune-rate -1.0 -N 0 "
          + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
          + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters \\\" \\r\\n\\t.,;:\\\'\\\"()?!\""));
  stv.setInputFormat(instances);
  instances = Filter.useFilter(instances, stv);
  instances.setClassIndex(0);
  Evaluation eval = new Evaluation(instances);
  eval.crossValidateModel(this.classifier, instances, 10, new Random(1));
  System.out.println(eval.toSummaryString());
  System.out.println(eval.toMatrixString());
}
@Override
protected Evaluation test() {
  working = true;
  System.out.println("Agent " + getLocalName() + ": Testing...");
  // evaluate classifier and print some statistics
  Evaluation eval = null;
  try {
    eval = new Evaluation(train);
    eval.evaluateModel(cls, test);
    System.out.println(
        eval.toSummaryString(getLocalName() + " agent: " + "\nResults\n=======\n", false));
  } catch (Exception e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  }
  working = false;
  return eval;
} // end test
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
    throws Exception {
  System.err.println("INFO: Starting split validation to predict '"
      + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":"
      + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test="
      + testData.numInstances() + ") ...");
  if (trainData.classIndex() < 0)
    throw new IllegalStateException("class attribute not set");
  c.buildClassifier(trainData);
  Evaluation eval = new Evaluation(testData);
  eval.useNoPriors();
  double[] predictions = eval.evaluateModel(c, testData);
  System.out.println(eval.toClassDetailsString());
  System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  // write predictions to file
  {
    System.err.println("INFO: Writing predictions to file ...");
    Writer out = new FileWriter("prediction.trec");
    writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
    out.close();
  }
  // write predicted distributions to CSV
  {
    System.err.println("INFO: Writing predicted distributions to CSV ...");
    Writer out = new FileWriter("predicted_distribution.csv");
    writePredictedDistributions(c, testData, 0, out);
    out.close();
  }
}
/** outputs some data about the classifier */
public String toString() {
  StringBuffer result;
  result = new StringBuffer();
  result.append("Weka - Demo\n===========\n\n");
  result.append("Classifier...: " + m_Classifier.getClass().getName() + " "
      + Utils.joinOptions(m_Classifier.getOptions()) + "\n");
  if (m_Filter instanceof OptionHandler)
    result.append("Filter.......: " + m_Filter.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler) m_Filter).getOptions()) + "\n");
  else
    result.append("Filter.......: " + m_Filter.getClass().getName() + "\n");
  result.append("Training file: " + m_TrainingFile + "\n");
  result.append("\n");
  result.append(m_Classifier.toString() + "\n");
  result.append(m_Evaluation.toSummaryString() + "\n");
  try {
    result.append(m_Evaluation.toMatrixString() + "\n");
  } catch (Exception e) {
    e.printStackTrace();
  }
  try {
    result.append(m_Evaluation.toClassDetailsString() + "\n");
  } catch (Exception e) {
    e.printStackTrace();
  }
  return result.toString();
}
static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception {
  System.err.println("INFO: Starting crossvalidation to predict '" + data.classAttribute().name()
      + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions())
      + "' ...");
  StringBuffer sb = new StringBuffer();
  Evaluation eval = new Evaluation(data);
  eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE);
  // write predictions to file
  {
    Writer out = new FileWriter("cv.log");
    out.write(sb.toString());
    out.close();
  }
  System.out.println(eval.toClassDetailsString());
  System.out.println(eval.toSummaryString("\nResults\n======\n", false));
}
public static void analyze_accuracy_NHBS(int rng_seed) throws Exception {
  HashMap<String, Object> population_params = load_defaults(null);
  RawLoader rl = new RawLoader(population_params, true, false, rng_seed);
  List<DrugUser> learningData = rl.getLearningData();

  Instances nhbs_data =
      new Instances("learning_instances", DrugUser.getAttInfo(), learningData.size());
  for (DrugUser du : learningData) {
    nhbs_data.add(du.getInstance());
  }
  System.out.println(nhbs_data.toSummaryString());
  nhbs_data.setClass(DrugUser.getAttribMap().get("hcv_state"));

  // wishlist: remove infrequent values
  // weka.filters.unsupervised.instance.RemoveFrequentValues()
  Filter f1 = new RemoveUseless();
  f1.setInputFormat(nhbs_data);
  nhbs_data = Filter.useFilter(nhbs_data, f1);

  System.out.println("NHBS IDU 2009 Dataset");
  System.out.println("Summary of input:");
  // System.out.println(nhbs_data.toSummaryString());
  System.out.println("  Num of classes: " + nhbs_data.numClasses());
  System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
  for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
    Attribute attr = nhbs_data.attribute(idx);
    System.out.println("" + idx + ": " + attr.toString());
    System.out.println("  distinct values:" + nhbs_data.numDistinctValues(idx));
    // System.out.println("" + attr.enumerateValues());
  }

  ArrayList<String> options = new ArrayList<String>();
  options.add("-Q");
  options.add("" + rng_seed);

  // System.exit(0);
  // nhbs_data.deleteAttributeAt(0); // response ID
  // nhbs_data.deleteAttributeAt(16); // zip

  // Classifier classifier = new NNge(); // best nearest-neighbor classifier: 40.00, ROC=0.60
  // Classifier classifier = new MINND();
  // Classifier classifier = new CitationKNN();
  // Classifier classifier = new LibSVM(); // requires LibSVM classes; only gets 37.7%
  // Classifier classifier = new SMOreg();
  Classifier classifier = new Logistic(); // ROC=0.686
  // Classifier classifier = new LinearNNSearch();

  // LinearRegression: Cannot handle multi-valued nominal class!
  // Classifier classifier = new LinearRegression();

  // Classifier classifier = new RandomForest();
  // String[] options = {"-I", "100", "-K", "4"}; // -I trees, -K features per tree; generally,
  // might want to optimize (or not:
  // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
  // options.add("-I"); options.add("100"); options.add("-K"); options.add("4"); // ROC=0.673

  // KStar classifier = new KStar();
  // classifier.setGlobalBlend(20); // the amount of not greedy, in percent
  // ROC=0.633

  // Classifier classifier = new AdaBoostM1(); // ROC=0.66
  // Classifier classifier = new MultiBoostAB(); // ROC=0.67
  // Classifier classifier = new Stacking(); // ROC=0.495

  // J48 classifier = new J48(); // new instance of tree -- building a C4.5 tree classifier
  // ROC=0.585
  // String[] options = new String[1];
  // options[0] = "-U"; // unpruned tree
  // classifier.setOptions(options); // set the options

  classifier.setOptions((String[]) options.toArray(new String[0]));
  // not needed before CV: http://weka.wikispaces.com/Use+WEKA+in+your+Java+code
  // classifier.buildClassifier(nhbs_data); // build classifier

  // evaluation
  Evaluation eval = new Evaluation(nhbs_data);
  eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1)); // 10-fold cross-validation
  System.out.println(eval.toSummaryString("\nResults\n\n", false));
  System.out.println(eval.toClassDetailsString());
  // System.out.println(eval.toCumulativeMarginDistributionString());
}
public static void test_NHBS_old() throws Exception {
  // load the data
  CSVLoader loader = new CSVLoader();
  // these must come before the getDataSet()
  // loader.setEnclosureCharacters(",\'\"S");
  // loader.setNominalAttributes("16,71"); // zip code, drug name
  // loader.setStringAttributes("");
  // loader.setDateAttributes("0,1");
  // loader.setSource(new File("hcv/data/NHBS/IDU2_HCV_model_012913_cleaned_for_weka.csv"));
  loader.setSource(new File("/home/sasha/hcv/code/data/IDU2_HCV_model_012913_cleaned.csv"));
  Instances nhbs_data = loader.getDataSet();
  loader.setMissingValue("NOVALUE");
  // loader.setMissingValue("");

  nhbs_data.deleteAttributeAt(12); // zip code
  nhbs_data.deleteAttributeAt(1);  // date - redundant with age
  nhbs_data.deleteAttributeAt(0);  // date
  System.out.println("classifying attribute:");
  nhbs_data.setClassIndex(1); // new index 3->2->1
  nhbs_data.attribute(1).getMetadata().toString(); // HCVEIARSLT1

  // wishlist: perhaps it would be smarter to throw out unclassified instances? they interfere
  // with the scoring
  nhbs_data.deleteWithMissingClass();
  // nhbs_data.setClass(new Attribute("HIVRSLT")); //.setClassIndex(1); // 2nd column; all are mostly negative
  // nhbs_data.setClass(new Attribute("HCVEIARSLT1")); //.setClassIndex(2); // 3rd column
  // #14, i.e. rds_fem, should be made numeric

  System.out.println("NHBS IDU 2009 Dataset");
  System.out.println("Summary of input:");
  // System.out.println(nhbs_data.toSummaryString());
  System.out.println("  Num of classes: " + nhbs_data.numClasses());
  System.out.println("  Num of attributes: " + nhbs_data.numAttributes());
  for (int idx = 0; idx < nhbs_data.numAttributes(); ++idx) {
    Attribute attr = nhbs_data.attribute(idx);
    System.out.println("" + idx + ": " + attr.toString());
    System.out.println("  distinct values:" + nhbs_data.numDistinctValues(idx));
    // System.out.println("" + attr.enumerateValues());
  }

  // System.exit(0);
  // nhbs_data.deleteAttributeAt(0); // response ID
  // nhbs_data.deleteAttributeAt(16); // zip

  // Classifier classifier = new NNge(); // best nearest-neighbor classifier: 40.00
  // Classifier classifier = new MINND();
  // Classifier classifier = new CitationKNN();
  // Classifier classifier = new LibSVM(); // requires LibSVM classes; only gets 37.7%
  // Classifier classifier = new SMOreg();
  // Classifier classifier = new LinearNNSearch();

  // LinearRegression: Cannot handle multi-valued nominal class!
  // Classifier classifier = new LinearRegression();

  Classifier classifier = new RandomForest();
  String[] options = {"-I", "100", "-K", "4"}; // -I trees, -K features per tree; generally,
  // might want to optimize (or not:
  // https://cwiki.apache.org/confluence/display/MAHOUT/Random+Forests)
  classifier.setOptions(options);
  // Classifier classifier = new Logistic();

  // KStar classifier = new KStar();
  // classifier.setGlobalBlend(20); // the amount of not greedy, in percent
  // does poorly
  // Classifier classifier = new AdaBoostM1();
  // Classifier classifier = new MultiBoostAB();
  // Classifier classifier = new Stacking();

  // building a C4.5 tree classifier
  // J48 classifier = new J48(); // new instance of tree
  // String[] options = new String[1];
  // options[0] = "-U"; // unpruned tree
  // classifier.setOptions(options); // set the options
  // classifier.buildClassifier(nhbs_data); // build classifier

  // wishlist: remove infrequent values
  // weka.filters.unsupervised.instance.RemoveFrequentValues()
  Filter f1 = new RemoveUseless();
  f1.setInputFormat(nhbs_data);
  nhbs_data = Filter.useFilter(nhbs_data, f1);

  // evaluation
  Evaluation eval = new Evaluation(nhbs_data);
  eval.crossValidateModel(classifier, nhbs_data, 10, new Random(1));
  System.out.println(eval.toSummaryString("\nResults\n\n", false));
  System.out.println(eval.toClassDetailsString());
  // System.out.println(eval.toCumulativeMarginDistributionString());
}
public static void main(String[] args) {
  if (args.length < 1) {
    System.out.println("usage: C4_5TweetTopicCategorization <root_path>");
    System.exit(-1);
  }

  String rootPath = args[0];
  File dataFolder = new File(rootPath + "/data");
  String resultFolderPath = rootPath + "/results/C4_5/";

  CrisisMailer crisisMailer = CrisisMailer.getCrisisMailer();
  Logger logger = Logger.getLogger(C4_5TweetTopicCategorization.class);
  PropertyConfigurator.configure(Constants.LOG4J_PROPERTIES_FILE_PATH);

  File resultFolder = new File(resultFolderPath);
  if (!resultFolder.exists())
    resultFolder.mkdir();

  CSVLoader csvLoader = new CSVLoader();

  try {
    for (File dataSetName : dataFolder.listFiles()) {
      Instances data = null;
      try {
        csvLoader.setSource(dataSetName);
        csvLoader.setStringAttributes("2");
        data = csvLoader.getDataSet();
      } catch (IOException ioe) {
        logger.error(ioe);
        crisisMailer.sendEmailAlert(ioe);
        System.exit(-1);
      }

      data.setClassIndex(data.numAttributes() - 1);
      data.deleteWithMissingClass();

      Instances vectorizedData = null;
      StringToWordVector stringToWordVectorFilter = new StringToWordVector();
      try {
        stringToWordVectorFilter.setInputFormat(data);
        stringToWordVectorFilter.setAttributeIndices("2");
        stringToWordVectorFilter.setIDFTransform(true);
        stringToWordVectorFilter.setLowerCaseTokens(true);
        stringToWordVectorFilter.setOutputWordCounts(false);
        stringToWordVectorFilter.setUseStoplist(true);
        vectorizedData = Filter.useFilter(data, stringToWordVectorFilter);
        vectorizedData.deleteAttributeAt(0);
        // System.out.println(vectorizedData);
      } catch (Exception exception) {
        logger.error(exception);
        crisisMailer.sendEmailAlert(exception);
        System.exit(-1);
      }

      J48 j48Classifier = new J48();
      /*
       * FilteredClassifier filteredClassifier = new FilteredClassifier();
       * filteredClassifier.setFilter(stringToWordVectorFilter);
       * filteredClassifier.setClassifier(j48Classifier);
       */

      try {
        Evaluation eval = new Evaluation(vectorizedData);
        eval.crossValidateModel(
            j48Classifier, vectorizedData, 5, new Random(System.currentTimeMillis()));

        FileOutputStream resultOutputStream =
            new FileOutputStream(new File(resultFolderPath + dataSetName.getName()));
        resultOutputStream.write(eval.toSummaryString("=== Summary ===", false).getBytes());
        resultOutputStream.write(eval.toMatrixString().getBytes());
        resultOutputStream.write(eval.toClassDetailsString().getBytes());
        resultOutputStream.close();
      } catch (Exception exception) {
        logger.error(exception);
        crisisMailer.sendEmailAlert(exception);
        System.exit(-1);
      }
    }
  } catch (Exception exception) {
    logger.error(exception);
    crisisMailer.sendEmailAlert(exception);
    System.out.println(-1);
  }
}
public double getLiblinear(String path, String train, String test) {
  // accuracy for this run
  double accuracy = 0.0;
  try {
    LibLINEAR c1 = new LibLINEAR();
    /*
     * String[] options = weka.core.Utils.splitOptions("-S 1 -C 1.0 -E 0.001 -B 0");
     * c1.setOptions(options);
     */
    ArffLoader atf = new ArffLoader();
    File TraininputFile = new File(train);
    atf.setFile(TraininputFile); // training corpus file
    Instances instancesTrain = atf.getDataSet(); // read in the training set
    instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

    File TestinputFile = new File(test);
    atf.setFile(TestinputFile); // test corpus file
    Instances instancesTest = atf.getDataSet(); // read in the test set
    // set the index of the class attribute (indices start at 0);
    // instancesTest.numAttributes() returns the total number of attributes
    instancesTest.setClassIndex(instancesTest.numAttributes() - 1);

    c1.buildClassifier(instancesTrain); // train
    Evaluation eval = new Evaluation(instancesTrain);
    eval.evaluateModel(c1, instancesTest);
    // eval.crossValidateModel(c1, instancesTrain, 10, new Random(1));

    File newfile = new File(path + "OutLiblinear_temp" + ".txt");
    BufferedWriter bufferedWriter =
        new BufferedWriter(new OutputStreamWriter(new FileOutputStream(newfile), "utf-8"));
    bufferedWriter.write(eval.toSummaryString() + "\r\n");
    bufferedWriter.write(eval.toClassDetailsString() + "\r\n");
    bufferedWriter.write(eval.toMatrixString() + "\r\n");
    bufferedWriter.flush();
    bufferedWriter.close();

    BufferedReader bufferedReader = new BufferedReader(new FileReader(newfile));
    String[] splitLineString = new String[5];
    while (bufferedReader.ready()) {
      bufferedReader.readLine();
      String lineString = bufferedReader.readLine();
      splitLineString = lineString.split(" ");
      System.out.println(splitLineString[4]);
      break;
    }
    bufferedReader.close();

    // extract the classification accuracy
    String tempLine;
    BufferedReader tempBF = new BufferedReader(new FileReader(newfile));
    while (tempBF.ready()) {
      tempLine = tempBF.readLine();
      if (tempLine.contains("Correctly Classified Instances")) {
        tempLine = tempLine.substring(tempLine.lastIndexOf(".") - 2, tempLine.lastIndexOf(" "));
        accuracy = Double.parseDouble(tempLine);
        break;
      }
    }
    tempBF.close();
  } catch (Exception e) {
    System.out.println("Can't run liblinear of weka.");
  }
  return accuracy;
}
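// Note on the snippet above: the accuracy is recovered by re-reading the
// "Correctly Classified Instances" line from the summary file it just wrote. Weka's Evaluation
// exposes the same figure directly, so a simpler (hypothetical) helper would skip the file
// round-trip entirely:
public static double accuracyOf(Evaluation eval) {
  // pctCorrect() is the percentage of correctly classified instances,
  // i.e. the number parsed out of the summary text above.
  return eval.pctCorrect();
}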
/**
 * Gets the results for the supplied train and test datasets. Now performs a deep copy of the
 * classifier before it is built and evaluated (just in case the classifier is not initialized
 * properly in buildClassifier()).
 *
 * @param train the training Instances.
 * @param test the testing Instances.
 * @return the results stored in an array. The objects stored in the array may be Strings,
 *     Doubles, or null (for the missing value).
 * @throws Exception if a problem occurs while getting the results
 */
public Object[] getResult(Instances train, Instances test) throws Exception {
  if (train.classAttribute().type() != Attribute.NUMERIC) {
    throw new Exception("Class attribute is not numeric!");
  }
  if (m_Template == null) {
    throw new Exception("No classifier has been specified");
  }
  ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
  boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
  if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled())
    thMonitor.setThreadCpuTimeEnabled(true);

  int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
  Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics];
  long thID = Thread.currentThread().getId();
  long CPUStartTime = -1, trainCPUTimeElapsed = -1, testCPUTimeElapsed = -1,
      trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed;

  Evaluation eval = new Evaluation(train);
  m_Classifier = AbstractClassifier.makeCopy(m_Template);

  trainTimeStart = System.currentTimeMillis();
  if (canMeasureCPUTime)
    CPUStartTime = thMonitor.getThreadUserTime(thID);
  m_Classifier.buildClassifier(train);
  if (canMeasureCPUTime)
    trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
  trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

  testTimeStart = System.currentTimeMillis();
  if (canMeasureCPUTime)
    CPUStartTime = thMonitor.getThreadUserTime(thID);
  eval.evaluateModel(m_Classifier, test);
  if (canMeasureCPUTime)
    testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
  testTimeElapsed = System.currentTimeMillis() - testTimeStart;
  thMonitor = null;

  m_result = eval.toSummaryString();
  // The results stored are all per instance -- can be multiplied by the
  // number of instances to get absolute numbers
  int current = 0;
  result[current++] = new Double(train.numInstances());
  result[current++] = new Double(eval.numInstances());
  result[current++] = new Double(eval.meanAbsoluteError());
  result[current++] = new Double(eval.rootMeanSquaredError());
  result[current++] = new Double(eval.relativeAbsoluteError());
  result[current++] = new Double(eval.rootRelativeSquaredError());
  result[current++] = new Double(eval.correlationCoefficient());
  result[current++] = new Double(eval.SFPriorEntropy());
  result[current++] = new Double(eval.SFSchemeEntropy());
  result[current++] = new Double(eval.SFEntropyGain());
  result[current++] = new Double(eval.SFMeanPriorEntropy());
  result[current++] = new Double(eval.SFMeanSchemeEntropy());
  result[current++] = new Double(eval.SFMeanEntropyGain());

  // Timing stats
  result[current++] = new Double(trainTimeElapsed / 1000.0);
  result[current++] = new Double(testTimeElapsed / 1000.0);
  if (canMeasureCPUTime) {
    result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0);
    result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0);
  } else {
    result[current++] = new Double(Utils.missingValue());
    result[current++] = new Double(Utils.missingValue());
  }

  // sizes
  if (m_NoSizeDetermination) {
    result[current++] = -1.0;
    result[current++] = -1.0;
    result[current++] = -1.0;
  } else {
    ByteArrayOutputStream bastream = new ByteArrayOutputStream();
    ObjectOutputStream oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(m_Classifier);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(train);
    result[current++] = new Double(bastream.size());
    bastream = new ByteArrayOutputStream();
    oostream = new ObjectOutputStream(bastream);
    oostream.writeObject(test);
    result[current++] = new Double(bastream.size());
  }

  // Prediction interval statistics
  result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions());
  result[current++] = new Double(eval.sizeOfPredictedRegions());

  if (m_Classifier instanceof Summarizable) {
    result[current++] = ((Summarizable) m_Classifier).toSummaryString();
  } else {
    result[current++] = null;
  }

  for (int i = 0; i < addm; i++) {
    if (m_doesProduce[i]) {
      try {
        double dv =
            ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]);
        if (!Utils.isMissingValue(dv)) {
          Double value = new Double(dv);
          result[current++] = value;
        } else {
          result[current++] = null;
        }
      } catch (Exception ex) {
        System.err.println(ex);
      }
    } else {
      result[current++] = null;
    }
  }

  // get the actual metrics from the evaluation object
  List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics();
  if (metrics != null) {
    for (AbstractEvaluationMetric m : metrics) {
      if (m.appliesToNumericClass()) {
        List<String> statNames = m.getStatisticNames();
        for (String s : statNames) {
          result[current++] = new Double(m.getStatistic(s));
        }
      }
    }
  }

  if (current != RESULT_SIZE + addm + m_numPluginStatistics) {
    throw new Error("Results didn't fit RESULT_SIZE");
  }
  return result;
}
// Input a question; output the category the question belongs to.
public double classifyByBayes(String question) throws Exception {
  double label = -1;
  List<Question> questionID = questionDAO.getQuestionIDLabeled();

  // define the data format; the attribute names are the domain categories
  // (legal policy, location/transport, feng shui, housing price, floor, unit layout,
  //  community facilities, mortgage, buying timing, developer)
  Attribute att1 = new Attribute("法律政策");
  Attribute att2 = new Attribute("位置交通");
  Attribute att3 = new Attribute("风水");
  Attribute att4 = new Attribute("房价");
  Attribute att5 = new Attribute("楼层");
  Attribute att6 = new Attribute("户型");
  Attribute att7 = new Attribute("小区配套");
  Attribute att8 = new Attribute("贷款");
  Attribute att9 = new Attribute("买房时机");
  Attribute att10 = new Attribute("开发商");
  FastVector labels = new FastVector();
  labels.addElement("1");
  labels.addElement("2");
  labels.addElement("3");
  labels.addElement("4");
  labels.addElement("5");
  labels.addElement("6");
  labels.addElement("7");
  labels.addElement("8");
  labels.addElement("9");
  labels.addElement("10");
  Attribute att11 = new Attribute("类别", labels); // class attribute ("category")
  FastVector attributes = new FastVector();
  attributes.addElement(att1);
  attributes.addElement(att2);
  attributes.addElement(att3);
  attributes.addElement(att4);
  attributes.addElement(att5);
  attributes.addElement(att6);
  attributes.addElement(att7);
  attributes.addElement(att8);
  attributes.addElement(att9);
  attributes.addElement(att10);
  attributes.addElement(att11);
  Instances dataset = new Instances("Test-dataset", attributes, 0);
  dataset.setClassIndex(10);

  Classifier classifier = null;
  if (!new File("naivebayes.model").exists()) {
    // add the training data
    double[] values = new double[11];
    for (int i = 0; i < questionID.size(); i++) {
      for (int m = 0; m < 11; m++) {
        values[m] = 0;
      }
      int whitewordcount = 0;
      whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
      if (whitewordcount != 0) {
        List<QuestionWhiteWord> questionwhiteword =
            questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
        for (int j = 0; j < questionwhiteword.size(); j++) {
          values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
        }
        for (int m = 0; m < 11; m++) {
          values[m] = values[m] / whitewordcount;
        }
      }
      values[10] = questionID.get(i).getType() - 1;
      Instance inst = new Instance(1.0, values);
      dataset.add(inst);
    }
    // build the classifier
    classifier = new NaiveBayes();
    classifier.buildClassifier(dataset);
    SerializationHelper.write("naivebayes.model", classifier);
  } else {
    classifier = (Classifier) SerializationHelper.read("naivebayes.model");
  }

  System.out.println("*************begin evaluation*******************");
  Evaluation evaluation = new Evaluation(dataset);
  evaluation.evaluateModel(classifier, dataset); // strictly speaking, a separate dataset should be used here, not the training set
  System.out.println(evaluation.toSummaryString());

  // classification
  System.out.println("*************begin classification*******************");
  Instance subject = new Instance(1.0, getQuestionVector(question));
  subject.setDataset(dataset);
  label = classifier.classifyInstance(subject);
  System.out.println("label: " + label);
  // double dis[] = classifier.distributionForInstance(inst);
  // for (double i : dis) {
  //   System.out.print(i + " ");
  // }
  System.out.println(questionID.size());
  return label + 1;
}
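// The comment above notes that the evaluation really ought to use held-out data rather than the
// training set. A minimal sketch of that idea as a hypothetical helper (not part of the original
// class), splitting the labelled instances before building and evaluating the model:
public static void evaluateOnHoldout(Instances dataset) throws Exception {
  dataset.randomize(new java.util.Random(1));       // shuffle before splitting
  int folds = 5;
  Instances trainSplit = dataset.trainCV(folds, 0); // roughly 80% for training
  Instances testSplit = dataset.testCV(folds, 0);   // roughly 20% held out
  Classifier nb = new NaiveBayes();
  nb.buildClassifier(trainSplit);
  Evaluation evaluation = new Evaluation(trainSplit);
  evaluation.evaluateModel(nb, testSplit);
  System.out.println(evaluation.toSummaryString());
}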
public static void main(String[] args) throws Exception {
  /*
   * First we load our predictions from the CSV formatted file.
   */
  CSVLoader predictCsvLoader = new CSVLoader();
  predictCsvLoader.setSource(new File("predict.csv"));

  /*
   * Since we are not using the ARFF format here, we have to give the
   * loader a little bit of information about the data types. Columns
   * 3,8,10 need to be of type string and columns 1,4,11 are nominal
   * types.
   */
  predictCsvLoader.setStringAttributes("3,8,10");
  predictCsvLoader.setNominalAttributes("1,4,11");
  Instances predictDataSet = predictCsvLoader.getDataSet();

  /*
   * Here we set the attribute we want to test the predictions with.
   */
  Attribute testAttribute = predictDataSet.attribute(0);
  predictDataSet.setClass(testAttribute);

  /*
   * We still have to remove all string attributes before we can test.
   */
  predictDataSet.deleteStringAttributes();

  /*
   * Next we load the training data from our ARFF file.
   */
  ArffLoader trainLoader = new ArffLoader();
  trainLoader.setSource(new File("train.arff"));
  trainLoader.setRetrieval(Loader.BATCH);
  Instances trainDataSet = trainLoader.getDataSet();

  /*
   * Now we tell the data set which attribute we want to classify; in our
   * case, we want to classify the first column: survived.
   */
  Attribute trainAttribute = trainDataSet.attribute(0);
  trainDataSet.setClass(trainAttribute);

  /*
   * The RandomForest implementation cannot handle columns of type string,
   * so we remove them for now.
   */
  trainDataSet.deleteStringAttributes();

  /*
   * Now we read in the serialized model from disk.
   */
  Classifier classifier = (Classifier) SerializationHelper.read("titanic.model");

  /*
   * Next we will use an Evaluation class to evaluate the performance of
   * our Classifier.
   */
  Evaluation evaluation = new Evaluation(trainDataSet);
  evaluation.evaluateModel(classifier, predictDataSet, new Object[] {});

  /*
   * After we evaluate the Classifier, we write out the summary
   * information to the screen.
   */
  System.out.println(classifier);
  System.out.println(evaluation.toSummaryString());
}
public void crossValidation() throws Exception {
  Evaluation eval = new Evaluation(trainingData);
  eval.crossValidateModel(classifier, trainingData, 10, new Random(1));
  System.out.println(eval.toSummaryString("Results", false));
}
public void testModel() throws Exception {
  Evaluation eval = new Evaluation(testData);
  eval.evaluateModel(classifier, testData);
  System.out.println(eval.toSummaryString("Results", false));
}
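// The two methods above assume instance fields that the snippets do not show. A rough sketch of
// the surrounding state; the file names and the choice of J48 are assumptions for illustration
// only:
private Instances trainingData;             // e.g. loaded via DataSource.read("train.arff")
private Instances testData;                 // e.g. loaded via DataSource.read("test.arff")
private Classifier classifier = new J48();  // assumed; any weka.classifiers.Classifier
// A typical call order would be: load both datasets, set their class indices, call
// classifier.buildClassifier(trainingData), then crossValidation() and/or testModel().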
/**
 * Accepts and processes a classifier encapsulated in an incremental classifier event
 *
 * @param ce an <code>IncrementalClassifierEvent</code> value
 */
@Override
public void acceptClassifier(final IncrementalClassifierEvent ce) {
  try {
    if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) {
      m_throughput = new StreamThroughput(statusMessagePrefix());
      m_throughput.setSamplePeriod(m_statusFrequency);
      // m_eval = new Evaluation(ce.getCurrentInstance().dataset());
      m_eval = new Evaluation(ce.getStructure());
      m_eval.useNoPriors();
      m_dataLegend = new Vector();
      m_reset = true;
      m_dataPoint = new double[0];
      Instances inst = ce.getStructure();
      System.err.println("NEW BATCH");
      m_instanceCount = 0;
      if (m_windowSize > 0) {
        m_window = new LinkedList<Instance>();
        m_windowEval = new Evaluation(ce.getStructure());
        m_windowEval.useNoPriors();
        m_windowedPreds = new LinkedList<double[]>();
        if (m_logger != null) {
          m_logger.logMessage(statusMessagePrefix()
              + "[IncrementalClassifierEvaluator] Chart output using windowed "
              + "evaluation over " + m_windowSize + " instances");
        }
      }
      /*
       * if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix()
       * + "IncrementalClassifierEvaluator: started processing...");
       * m_logger.logMessage(statusMessagePrefix() +
       * " [IncrementalClassifierEvaluator]" + statusMessagePrefix() +
       * " started processing..."); }
       */
    } else {
      Instance inst = ce.getCurrentInstance();
      if (inst != null) {
        m_throughput.updateStart();
        m_instanceCount++;
        // if (inst.attribute(inst.classIndex()).isNominal()) {
        double[] dist = ce.getClassifier().distributionForInstance(inst);
        double pred = 0;
        if (!inst.isMissing(inst.classIndex())) {
          if (m_outputInfoRetrievalStats) {
            // store predictions so AUC etc can be output.
            m_eval.evaluateModelOnceAndRecordPrediction(dist, inst);
          } else {
            m_eval.evaluateModelOnce(dist, inst);
          }
          if (m_windowSize > 0) {
            m_windowEval.evaluateModelOnce(dist, inst);
            m_window.addFirst(inst);
            m_windowedPreds.addFirst(dist);
            if (m_instanceCount > m_windowSize) {
              // "forget" the oldest prediction
              Instance oldest = m_window.removeLast();
              double[] oldDist = m_windowedPreds.removeLast();
              oldest.setWeight(-oldest.weight());
              m_windowEval.evaluateModelOnce(oldDist, oldest);
              oldest.setWeight(-oldest.weight());
            }
          }
        } else {
          pred = ce.getClassifier().classifyInstance(inst);
        }
        if (inst.classIndex() >= 0) {
          // need to check that the class is not missing
          if (inst.attribute(inst.classIndex()).isNominal()) {
            if (!inst.isMissing(inst.classIndex())) {
              if (m_dataPoint.length < 2) {
                m_dataPoint = new double[3];
                m_dataLegend.addElement("Accuracy");
                m_dataLegend.addElement("RMSE (prob)");
                m_dataLegend.addElement("Kappa");
              }
              // int classV = (int) inst.value(inst.classIndex());
              if (m_windowSize > 0) {
                m_dataPoint[1] = m_windowEval.rootMeanSquaredError();
                m_dataPoint[2] = m_windowEval.kappa();
              } else {
                m_dataPoint[1] = m_eval.rootMeanSquaredError();
                m_dataPoint[2] = m_eval.kappa();
              }
              // int maxO = Utils.maxIndex(dist);
              // if (maxO == classV) {
              //   dist[classV] = -1;
              //   maxO = Utils.maxIndex(dist);
              // }
              // m_dataPoint[1] -= dist[maxO];
            } else {
              if (m_dataPoint.length < 1) {
                m_dataPoint = new double[1];
                m_dataLegend.addElement("Confidence");
              }
            }
            double primaryMeasure = 0;
            if (!inst.isMissing(inst.classIndex())) {
              if (m_windowSize > 0) {
                primaryMeasure = 1.0 - m_windowEval.errorRate();
              } else {
                primaryMeasure = 1.0 - m_eval.errorRate();
              }
            } else {
              // record confidence as the primary measure
              // (another possibility would be entropy of
              // the distribution, or perhaps average
              // confidence)
              primaryMeasure = dist[Utils.maxIndex(dist)];
            }
            // double [] dataPoint = new double[1];
            m_dataPoint[0] = primaryMeasure;
            // double min = 0; double max = 100;
            /*
             * ChartEvent e = new
             * ChartEvent(IncrementalClassifierEvaluator.this, m_dataLegend,
             * min, max, dataPoint);
             */
            m_ce.setLegendText(m_dataLegend);
            m_ce.setMin(0);
            m_ce.setMax(1);
            m_ce.setDataPoint(m_dataPoint);
            m_ce.setReset(m_reset);
            m_reset = false;
          } else {
            // numeric class
            if (m_dataPoint.length < 1) {
              m_dataPoint = new double[1];
              if (inst.isMissing(inst.classIndex())) {
                m_dataLegend.addElement("Prediction");
              } else {
                m_dataLegend.addElement("RMSE");
              }
            }
            if (!inst.isMissing(inst.classIndex())) {
              double update;
              if (!inst.isMissing(inst.classIndex())) {
                if (m_windowSize > 0) {
                  update = m_windowEval.rootMeanSquaredError();
                } else {
                  update = m_eval.rootMeanSquaredError();
                }
              } else {
                update = pred;
              }
              m_dataPoint[0] = update;
              if (update > m_max) {
                m_max = update;
              }
              if (update < m_min) {
                m_min = update;
              }
            }
            m_ce.setLegendText(m_dataLegend);
            m_ce.setMin((inst.isMissing(inst.classIndex()) ? m_min : 0));
            m_ce.setMax(m_max);
            m_ce.setDataPoint(m_dataPoint);
            m_ce.setReset(m_reset);
            m_reset = false;
          }
          notifyChartListeners(m_ce);
        }
        m_throughput.updateEnd(m_logger);
      }
      if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED || inst == null) {
        if (m_logger != null) {
          m_logger.logMessage("[IncrementalClassifierEvaluator]" + statusMessagePrefix()
              + " Finished processing.");
        }
        m_throughput.finished(m_logger);
        // save memory if using windowed evaluation for charting
        m_windowEval = null;
        m_window = null;
        m_windowedPreds = null;
        if (m_textListeners.size() > 0) {
          String textTitle = ce.getClassifier().getClass().getName();
          textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
          String results = "=== Performance information ===\n\n"
              + "Scheme: " + textTitle + "\n"
              + "Relation: " + m_eval.getHeader().relationName() + "\n\n"
              + m_eval.toSummaryString();
          if (m_eval.getHeader().classIndex() >= 0
              && m_eval.getHeader().classAttribute().isNominal()
              && (m_outputInfoRetrievalStats)) {
            results += "\n" + m_eval.toClassDetailsString();
          }
          if (m_eval.getHeader().classIndex() >= 0
              && m_eval.getHeader().classAttribute().isNominal()) {
            results += "\n" + m_eval.toMatrixString();
          }
          textTitle = "Results: " + textTitle;
          TextEvent te = new TextEvent(this, results, textTitle);
          notifyTextListeners(te);
        }
      }
    }
  } catch (Exception ex) {
    if (m_logger != null) {
      m_logger.logMessage("[IncrementalClassifierEvaluator]" + statusMessagePrefix()
          + " Error processing prediction " + ex.getMessage());
      m_logger.statusMessage(statusMessagePrefix()
          + "ERROR: problem processing prediction (see log for details)");
    }
    ex.printStackTrace();
    stop();
  }
}