/** * buildClassifier selects a classifier from the set of classifiers by minimising error on the * training data. * * @param data the training data to be used for selecting the classifier. * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifiers.length == 0) { throw new Exception("No base classifiers have been set!"); } Instances newData = new Instances(data); newData.deleteWithMissingClass(); newData.randomize(new Random(m_Seed)); if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) newData.stratify(m_NumXValFolds); Instances train = newData; // train on all data by default Instances test = newData; // test on training data by default Classifier bestClassifier = null; int bestIndex = -1; double bestPerformance = Double.NaN; int numClassifiers = m_Classifiers.length; for (int i = 0; i < numClassifiers; i++) { Classifier currentClassifier = getClassifier(i); Evaluation evaluation; if (m_NumXValFolds > 1) { evaluation = new Evaluation(newData); for (int j = 0; j < m_NumXValFolds; j++) { train = newData.trainCV(m_NumXValFolds, j); test = newData.testCV(m_NumXValFolds, j); currentClassifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(currentClassifier, test); } } else { currentClassifier.buildClassifier(train); evaluation = new Evaluation(train); evaluation.evaluateModel(currentClassifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println( "Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName()); } if ((i == 0) || (error < bestPerformance)) { bestClassifier = currentClassifier; bestPerformance = error; bestIndex = i; } } m_ClassifierIndex = bestIndex; m_Classifier = bestClassifier; if (m_NumXValFolds > 1) { m_Classifier.buildClassifier(newData); } }
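This buildClassifier resembles the cross-validated model selection done by Weka's MultiScheme meta-classifier. A minimal driver sketch, assuming the surrounding class exposes the usual MultiScheme setters (setClassifiers, setNumFolds, setSeed) and that a training file named train.arff exists (both names are placeholders):

import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.MultiScheme;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class SelectBestDemo {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("train.arff"); // placeholder path
    data.setClassIndex(data.numAttributes() - 1);

    MultiScheme selector = new MultiScheme();
    selector.setClassifiers(new Classifier[] { new NaiveBayes(), new J48() });
    selector.setNumFolds(10); // folds > 1 selects by cross-validated error, as above
    selector.setSeed(1);
    selector.buildClassifier(data); // keeps the classifier with the lowest error rate
    System.out.println(selector);   // reports which scheme was selected
  }
}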
/** * Test with the classifier. * * @param trainFileName * @param testFileName */ public static void classify(String trainFileName, String testFileName) { try { File inputFile = new File(fileName + trainFileName); // training corpus file ArffLoader atf = new ArffLoader(); atf.setFile(inputFile); Instances instancesTrain = atf.getDataSet(); // read the training file // set the class attribute inputFile = new File(fileName + testFileName); // test corpus file atf.setFile(inputFile); Instances instancesTest = atf.getDataSet(); // read the test file instancesTest.setClassIndex(instancesTest.numAttributes() - 1); instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1); classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance(); classifier.buildClassifier(instancesTrain); Evaluation eval = new Evaluation(instancesTrain); // the first argument is a trained classifier, the second is the dataset to evaluate it on eval.evaluateModel(classifier, instancesTest); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); System.out.println("accuracy is :" + (1 - eval.errorRate())); } catch (Exception e) { e.printStackTrace(); } }
public static void run(String[] args) throws Exception { /** * ************************************************* * * @param args[0]: train arff path * @param args[1]: test arff path */ DataSource source = new DataSource(args[0]); Instances data = source.getDataSet(); data.setClassIndex(data.numAttributes() - 1); NaiveBayes model = new NaiveBayes(); model.buildClassifier(data); // Evaluation: Evaluation eval = new Evaluation(data); Instances testData = new DataSource(args[1]).getDataSet(); testData.setClassIndex(testData.numAttributes() - 1); eval.evaluateModel(model, testData); System.out.println(model.toString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); System.out.println("======\nConfusion Matrix:"); double[][] confusionM = eval.confusionMatrix(); for (int i = 0; i < confusionM.length; ++i) { for (int j = 0; j < confusionM[i].length; ++j) { System.out.format("%10s ", confusionM[i][j]); } System.out.print("\n"); } }
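When only a single ARFF file is available, the same Evaluation object can run a k-fold cross-validation instead of a separate test set. A minimal sketch (the file name data.arff is a placeholder):

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class CrossValidationDemo {
  public static void main(String[] args) throws Exception {
    Instances data = new DataSource("data.arff").getDataSet(); // placeholder path
    data.setClassIndex(data.numAttributes() - 1);

    Evaluation eval = new Evaluation(data);
    // crossValidateModel trains a fresh copy of the classifier for every fold
    eval.crossValidateModel(new NaiveBayes(), data, 10, new Random(1));
    System.out.println(eval.toSummaryString("\n10-fold CV results\n======\n", false));
    System.out.println(eval.toMatrixString());
  }
}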
/** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { Instances isTrainingSet = createSet(4); Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet); Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet); Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet); isTrainingSet.add(instance1); isTrainingSet.add(instance2); isTrainingSet.add(instance22); Instances isTestingSet = createSet(4); Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet); Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet); isTestingSet.add(instance3); isTestingSet.add(instance4); // Create a Bayesian network classifier (alternatives such as M5P were also tried here) Classifier cModel = (Classifier) new BayesNet(); cModel.buildClassifier(isTrainingSet); // Test the model Evaluation eTest = new Evaluation(isTrainingSet); eTest.evaluateModel(cModel, isTestingSet); // Print the result à la Weka explorer: String strSummary = eTest.toSummaryString(); System.out.println(strSummary); // Get the likelihood of each class // fDistribution[0] is the probability of being “positive” // fDistribution[1] is the probability of being “negative” double[] fDistribution = cModel.distributionForInstance(instance4); for (int i = 0; i < fDistribution.length; i++) { System.out.println(fDistribution[i]); } }
/** * Creates an evaluation overview of the built classifier. * * @return the panel to be displayed as result evaluation view for the current decision point */ protected JPanel createEvaluationVisualization(Instances data) { // build text field to display evaluation statistics JTextPane statistic = new JTextPane(); try { // build evaluation statistics Evaluation evaluation = new Evaluation(data); evaluation.evaluateModel(myClassifier, data); statistic.setText( evaluation.toSummaryString() + "\n\n" + evaluation.toClassDetailsString() + "\n\n" + evaluation.toMatrixString()); } catch (Exception ex) { ex.printStackTrace(); return createMessagePanel("Error while creating the decision tree evaluation view"); } statistic.setFont(new Font("Courier", Font.PLAIN, 14)); statistic.setEditable(false); statistic.setCaretPosition(0); JPanel resultViewPanel = new JPanel(); resultViewPanel.setLayout(new BoxLayout(resultViewPanel, BoxLayout.PAGE_AXIS)); resultViewPanel.add(new JScrollPane(statistic)); return resultViewPanel; }
/** * Finds the best parameter combination. (recursive for each parameter being optimised). * * @param depth the index of the parameter to be optimised at this level * @param trainData the data the search is based on * @param random a random number generator * @throws Exception if an error occurs */ protected void findParamsByCrossValidation(int depth, Instances trainData, Random random) throws Exception { if (depth < m_CVParams.size()) { CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth); double upper; switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: upper = m_NumAttributes; break; case 2: upper = m_TrainFoldSize; break; default: upper = cvParam.m_Upper; break; } double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1); for (cvParam.m_ParamValue = cvParam.m_Lower; cvParam.m_ParamValue <= upper; cvParam.m_ParamValue += increment) { findParamsByCrossValidation(depth + 1, trainData, random); } } else { Evaluation evaluation = new Evaluation(trainData); // Set the classifier options String[] options = createOptions(); if (m_Debug) { System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":"); for (int i = 0; i < options.length; i++) { System.err.print(" " + options[i]); } System.err.println(""); } ((OptionHandler) m_Classifier).setOptions(options); for (int j = 0; j < m_NumFolds; j++) { // We want to randomize the data the same way for every // learning scheme. Instances train = trainData.trainCV(m_NumFolds, j, new Random(1)); Instances test = trainData.testCV(m_NumFolds, j); m_Classifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(m_Classifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4)); } if ((m_BestPerformance == -99) || (error < m_BestPerformance)) { m_BestPerformance = error; m_BestClassifierOptions = createOptions(); } } }
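This recursive search mirrors the parameter sweep inside Weka's CVParameterSelection. A minimal sketch of how it is normally driven, with an illustrative range for J48's -C option (file name and range are placeholders):

import weka.classifiers.meta.CVParameterSelection;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class ParamSearchDemo {
  public static void main(String[] args) throws Exception {
    Instances trainData = DataSource.read("train.arff"); // placeholder path
    trainData.setClassIndex(trainData.numAttributes() - 1);

    CVParameterSelection ps = new CVParameterSelection();
    ps.setClassifier(new J48());
    ps.setNumFolds(10);
    ps.addCVParameter("C 0.1 0.5 5"); // sweep -C from 0.1 to 0.5 in 5 steps
    ps.buildClassifier(trainData);    // internally runs the cross-validated search
    System.out.println(ps);           // the summary includes the best options found
  }
}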
/** * Main method for testing this class. * * @param argv the options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new Decorate(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); } }
/** * Main method for testing this class * * @param argv options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new UnivariateLinearRegression(), argv)); } catch (Exception e) { System.out.println(e.getMessage()); e.printStackTrace(); } }
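Both main methods above hand argv straight to Evaluation.evaluateModel(Classifier, String[]), so they are meant to be driven with Weka's standard command-line options. A hedged sketch of an equivalent programmatic call (file names are placeholders; with -t alone and no -T, Weka falls back to 10-fold cross-validation on the training file):

// equivalent to: java ... UnivariateLinearRegression -t train.arff -T test.arff
String[] argv = { "-t", "train.arff", "-T", "test.arff" };
System.out.println(Evaluation.evaluateModel(new UnivariateLinearRegression(), argv));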
/** evaluates the classifier */ @Override public void evaluate() throws Exception { // evaluate classifier and print some statistics if (_test.classIndex() == -1) _test.setClassIndex(_test.numAttributes() - 1); Evaluation eval = new Evaluation(_train); eval.evaluateModel(_cl, _test); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); System.out.println(eval.toMatrixString()); }
public Evaluation evaluateClassifier(Instances trainInstances, Instances testInstances) { try { Evaluation eval = new Evaluation(trainInstances); eval.evaluateModel(bayesNet, testInstances); return eval; } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); return null; } }
private double[] makePredictions( Classifier classifier, Instances validationSet, Evaluation evaluation) { double[] predictions = null; try { predictions = evaluation.evaluateModel(classifier, validationSet); } catch (ArrayIndexOutOfBoundsException e) { throw new ClassifierPredictionException( "Error applying the trained classifier to the validation instances. The number of features of the instance exceeds the number of features the classifier was trained on.", e); } catch (Exception e) { throw new ClassifierPredictionException( "Error applying the trained classifier to the validation instances.", e); } return predictions; }
public static Double runClassify(String trainFile, String testFile) { double predictOrder = 0.0; double trueOrder = 0.0; try { String trainWekaFileName = trainFile; String testWekaFileName = testFile; Instances train = DataSource.read(trainWekaFileName); Instances test = DataSource.read(testWekaFileName); train.setClassIndex(0); test.setClassIndex(0); train.deleteAttributeAt(8); test.deleteAttributeAt(8); train.deleteAttributeAt(6); test.deleteAttributeAt(6); train.deleteAttributeAt(5); test.deleteAttributeAt(5); train.deleteAttributeAt(4); test.deleteAttributeAt(4); // AdditiveRegression classifier = new AdditiveRegression(); // NaiveBayes classifier = new NaiveBayes(); RandomForest classifier = new RandomForest(); // LibSVM classifier = new LibSVM(); classifier.buildClassifier(train); Evaluation eval = new Evaluation(train); eval.evaluateModel(classifier, test); System.out.println(eval.toSummaryString("\nResults\n\n", true)); // System.out.println(eval.toClassDetailsString()); // System.out.println(eval.toMatrixString()); int k = 892; for (int i = 0; i < test.numInstances(); i++) { predictOrder = classifier.classifyInstance(test.instance(i)); trueOrder = test.instance(i).classValue(); System.out.println((k++) + "," + (int) predictOrder); } } catch (Exception e) { e.printStackTrace(); } return predictOrder; }
/** * SVM trainer * * @param dataTrain * @param dataTest */ public static void trainModelLibSVM(Instances dataTrain, Instances dataTest) { try { LibSVM classifier = new LibSVM(); CVParameterSelection ps = new CVParameterSelection(); ps.setClassifier(classifier); ps.setNumFolds(5); // using 5-fold CV // ps.addCVParameter("C 0.1 0.5 5"); // build and output best options ps.buildClassifier(dataTrain); Evaluation eval = new Evaluation(dataTrain); eval.evaluateModel(ps, dataTest); System.out.println("Results of the set :::::::::::::::::::::: "); System.out.println( "Percentage of correctly classified instances : " + eval.pctCorrect() + "\n" + "Percentage of incorrectly classified instances : " + eval.pctIncorrect()); System.out.println("No of correct predictions : " + eval.correct()); System.out.println("TRUTHFUL"); System.out.println( "Precision : " + eval.precision(0) + "\n" + "Recall : " + eval.recall(0) + "\n" + "F measure/score : " + eval.fMeasure(0)); System.out.println("DECEPTIVE"); System.out.println( "Precision : " + eval.precision(1) + "\n" + "Recall : " + eval.recall(1) + "\n" + "F measure/score : " + eval.fMeasure(1)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } }
public static void trainModel(Instances dataTrain, Instances dataTest) { try { LibLINEAR classifier = new LibLINEAR(); classifier.setBias(10); classifier.buildClassifier(dataTrain); Evaluation eval = new Evaluation(dataTrain); eval.evaluateModel(classifier, dataTest); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } }
/** * Utility method for fast 5-fold cross validation of a naive bayes model * * @param fullModel a <code>NaiveBayesUpdateable</code> value * @param trainingSet an <code>Instances</code> value * @param r a <code>Random</code> value * @return a <code>double</code> value * @exception Exception if an error occurs */ public static double crossValidate( NaiveBayesUpdateable fullModel, Instances trainingSet, Random r) throws Exception { // make some copies for fast evaluation of 5-fold xval Classifier[] copies = AbstractClassifier.makeCopies(fullModel, 5); Evaluation eval = new Evaluation(trainingSet); // make some splits for (int j = 0; j < 5; j++) { Instances test = trainingSet.testCV(5, j); // unlearn these test instances for (int k = 0; k < test.numInstances(); k++) { test.instance(k).setWeight(-test.instance(k).weight()); ((NaiveBayesUpdateable) copies[j]).updateClassifier(test.instance(k)); // reset the weight back to its original value test.instance(k).setWeight(-test.instance(k).weight()); } eval.evaluateModel(copies[j], test); } return eval.incorrect(); }
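A minimal usage sketch for the helper above (assuming trainingSet is an Instances object with its class index set, and the usual NaiveBayesUpdateable and Random imports): the model must already be trained on the full training set, because each fold is then "unlearned" from a copy of it before being used as the test fold; the return value is the number of misclassified instances accumulated over the five folds.

NaiveBayesUpdateable nb = new NaiveBayesUpdateable();
nb.buildClassifier(trainingSet); // train once on all data
double misclassified = crossValidate(nb, trainingSet, new Random(1));
System.out.println("5-fold CV error rate: " + misclassified / trainingSet.numInstances());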
@Override protected Evaluation test() { working = true; System.out.println("Agent " + getLocalName() + ": Testing..."); // evaluate classifier and print some statistics Evaluation eval = null; try { eval = new Evaluation(train); eval.evaluateModel(cls, test); System.out.println( eval.toSummaryString(getLocalName() + " agent: " + "\nResults\n=======\n", false)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } working = false; return eval; } // end test
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData) throws Exception { System.err.println( "INFO: Starting split validation to predict '" + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test=" + testData.numInstances() + ") ..."); if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set"); c.buildClassifier(trainData); Evaluation eval = new Evaluation(testData); eval.useNoPriors(); double[] predictions = eval.evaluateModel(c, testData); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); // write predictions to file { System.err.println("INFO: Writing predictions to file ..."); Writer out = new FileWriter("prediction.trec"); writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out); out.close(); } // write predicted distributions to CSV { System.err.println("INFO: Writing predicted distributions to CSV ..."); Writer out = new FileWriter("predicted_distribution.csv"); writePredictedDistributions(c, testData, 0, out); out.close(); } }
public static void execSVM(String expName) { try { FileWriter outFile = null; PrintWriter out = null; outFile = new FileWriter(expName + "-SVM.results"); out = new PrintWriter(outFile); DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); ProcessTweets tweetsProcessor = null; System.out.println("***************************************"); System.out.println("***\tEXECUTING TEST\t" + expName + "***"); System.out.println("+++++++++++++++++++++++++++++++++++++++"); out.println("***************************************"); out.println("***\tEXECUTING TEST\t" + expName + "***"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.println("4-Generate classifier " + dateFormat.format(new Date())); Classifier cls = null; DataSource sourceTrain = new DataSource(expName + "-train.arff"); Instances dataTrain = sourceTrain.getDataSet(); if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1); // Train the classifier // cls = new weka.classifiers.functions.LibSVM(); int clase = dataTrain.numAttributes() - 1; cls = new weka.classifiers.bayes.ComplementNaiveBayes(); dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(expName + "-SVM.classifier")); oos.writeObject(cls); oos.flush(); oos.close(); DataSource sourceTest = new DataSource(expName + "-test.arff"); Instances dataTest = sourceTest.getDataSet(); dataTest.setClassIndex(clase); Evaluation eval = new Evaluation(dataTest); eval.evaluateModel(cls, dataTest); // Now compute the precision, recall and f-measure values; also print the confusion // matrices float precision = 0; float recall = 0; float fmeasure = 0; int topeTopics = 8; for (int ind = 0; ind < topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("++++++++++++++ CNB ++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("++++++++++++++ CNB ++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); // ANOTHER CLASSIFIER: ZeroR cls = new weka.classifiers.rules.ZeroR(); dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); eval = new Evaluation(dataTest); eval.evaluateModel(cls, dataTest); precision = 0; recall = 0; fmeasure = 0; for (int ind = 0; ind < topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("++++++++++++++ ZEROR ++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************");
out.println("++++++++++++++ ZEROR ++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); // OTRO CLASIFICADOR J48 /* cls = new weka.classifiers.trees.J48(); dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); eval = new Evaluation(dataTest); eval.evaluateModel(cls, dataTest); precision=0; recall=0; fmeasure=0; for(int ind=0; ind<topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("++++++++++++++ J48 ++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("++++++++++++++ J48 ++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); //OTRO SMO cls = new weka.classifiers.functions.SMO(); dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); eval = new Evaluation(dataTest); eval.evaluateModel(cls, dataTest); precision=0; recall=0; fmeasure=0; for(int ind=0; ind<topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("++++++++++++++ SMO ++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("++++++++++++++ SMO ++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); */ out.flush(); out.close(); dataTest.delete(); dataTrain.delete(); } catch (FileNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } }
/** * Gets the results for the supplied train and test datasets. Now performs a deep copy of the * classifier before it is built and evaluated (just in case the classifier is not initialized * properly in buildClassifier()). * * @param train the training Instances. * @param test the testing Instances. * @return the results stored in an array. The objects stored in the array may be Strings, * Doubles, or null (for the missing value). * @throws Exception if a problem occurs while getting the results */ public Object[] getResult(Instances train, Instances test) throws Exception { if (train.classAttribute().type() != Attribute.NUMERIC) { throw new Exception("Class attribute is not numeric!"); } if (m_Template == null) { throw new Exception("No classifier has been specified"); } ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean(); boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported(); if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled()) thMonitor.setThreadCpuTimeEnabled(true); int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0; Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics]; long thID = Thread.currentThread().getId(); long CPUStartTime = -1, trainCPUTimeElapsed = -1, testCPUTimeElapsed = -1, trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed; Evaluation eval = new Evaluation(train); m_Classifier = AbstractClassifier.makeCopy(m_Template); trainTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); m_Classifier.buildClassifier(train); if (canMeasureCPUTime) trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; testTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); eval.evaluateModel(m_Classifier, test); if (canMeasureCPUTime) testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; testTimeElapsed = System.currentTimeMillis() - testTimeStart; thMonitor = null; m_result = eval.toSummaryString(); // The results stored are all per instance -- can be multiplied by the // number of instances to get absolute numbers int current = 0; result[current++] = new Double(train.numInstances()); result[current++] = new Double(eval.numInstances()); result[current++] = new Double(eval.meanAbsoluteError()); result[current++] = new Double(eval.rootMeanSquaredError()); result[current++] = new Double(eval.relativeAbsoluteError()); result[current++] = new Double(eval.rootRelativeSquaredError()); result[current++] = new Double(eval.correlationCoefficient()); result[current++] = new Double(eval.SFPriorEntropy()); result[current++] = new Double(eval.SFSchemeEntropy()); result[current++] = new Double(eval.SFEntropyGain()); result[current++] = new Double(eval.SFMeanPriorEntropy()); result[current++] = new Double(eval.SFMeanSchemeEntropy()); result[current++] = new Double(eval.SFMeanEntropyGain()); // Timing stats result[current++] = new Double(trainTimeElapsed / 1000.0); result[current++] = new Double(testTimeElapsed / 1000.0); if (canMeasureCPUTime) { result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0); result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0); } else { result[current++] = new Double(Utils.missingValue()); result[current++] = new Double(Utils.missingValue()); } // sizes if (m_NoSizeDetermination) { result[current++] = -1.0; result[current++] = -1.0; 
result[current++] = -1.0; } else { ByteArrayOutputStream bastream = new ByteArrayOutputStream(); ObjectOutputStream oostream = new ObjectOutputStream(bastream); oostream.writeObject(m_Classifier); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(train); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(test); result[current++] = new Double(bastream.size()); } // Prediction interval statistics result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions()); result[current++] = new Double(eval.sizeOfPredictedRegions()); if (m_Classifier instanceof Summarizable) { result[current++] = ((Summarizable) m_Classifier).toSummaryString(); } else { result[current++] = null; } for (int i = 0; i < addm; i++) { if (m_doesProduce[i]) { try { double dv = ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]); if (!Utils.isMissingValue(dv)) { Double value = new Double(dv); result[current++] = value; } else { result[current++] = null; } } catch (Exception ex) { System.err.println(ex); } } else { result[current++] = null; } } // get the actual metrics from the evaluation object List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics(); if (metrics != null) { for (AbstractEvaluationMetric m : metrics) { if (m.appliesToNumericClass()) { List<String> statNames = m.getStatisticNames(); for (String s : statNames) { result[current++] = new Double(m.getStatistic(s)); } } } } if (current != RESULT_SIZE + addm + m_numPluginStatistics) { throw new Error("Results didn't fit RESULT_SIZE"); } return result; }
public void exec(PrintWriter printer) { try { FileWriter outFile = null; PrintWriter out = null; if (printer == null) { outFile = new FileWriter(id + ".results"); out = new PrintWriter(outFile); } else out = printer; DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss"); ProcessTweets tweetsProcessor = null; System.out.println("***************************************"); System.out.println("***\tEXECUTING TEST\t" + id + "***"); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.println("Train size:" + traincorpus.size()); System.out.println("Test size:" + testcorpus.size()); out.println("***************************************"); out.println("***\tEXECUTING TEST\t***"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.println("Train size:" + traincorpus.size()); out.println("Test size:" + testcorpus.size()); String cloneID = ""; boolean clonar = false; if (baseline) { System.out.println("***************************************"); System.out.println("***\tEXECUTING TEST BASELINE\t***"); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.println("Train size:" + traincorpus.size()); System.out.println("Test size:" + testcorpus.size()); out.println("***************************************"); out.println("***\tEXECUTING TEST\t***"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.println("Train size:" + traincorpus.size()); out.println("Test size:" + testcorpus.size()); BaselineClassifier base = new BaselineClassifier(testcorpus, 8); precision = base.getPrecision(); recall = base.getRecall(); fmeasure = base.getFmeasure(); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); out.flush(); out.close(); return; } else { System.out.println("Stemming: " + stemming); System.out.println("Lematization:" + lematization); System.out.println("URLs:" + urls); System.out.println("Hashtags:" + hashtags); System.out.println("Mentions:" + mentions); System.out.println("Unigrams:" + unigrams); System.out.println("Bigrams:" + bigrams); System.out.println("TF:" + tf); System.out.println("TF-IDF:" + tfidf); out.println("Stemming: " + stemming); out.println("Lematization:" + lematization); out.println("URLs:" + urls); out.println("Hashtags:" + hashtags); out.println("Mentions:" + mentions); out.println("Unigrams:" + unigrams); out.println("Bigrams:" + bigrams); out.println("TF:" + tf); out.println("TF-IDF:" + tfidf); } // If the tweets are already processed, skip reprocessing them System.out.println("1-Process tweets " + dateFormat.format(new Date())); out.println("1-Process tweets " + dateFormat.format(new Date())); List<ProcessedTweet> train = null; String[] ids = id.split("-"); cloneID = ids[0] + "-" + (Integer.valueOf(ids[1]) + 6); if (((Integer.valueOf(ids[1]) / 6) % 2) == 0) clonar = true; if (new File(id + "-train.ptweets").exists()) { train = ProcessedTweetSerialization.fromFile(id + "-train.ptweets"); tweetsProcessor = new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams); if (lematization) { tweetsProcessor.doLematization(train); } if (stemming) {
tweetsProcessor.doStemming(train); } } else { tweetsProcessor = new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams); // The setTraining flag was added to be able to tell apart the languages of the urls in the // parallel corpus // tweetsProcessor.setTraining(true); train = tweetsProcessor.processTweets(traincorpus); // tweetsProcessor.setTraining(false); ProcessedTweetSerialization.toFile(id + "-train.ptweets", train); /* if (clonar) { File f = new File (id+"-train.ptweets"); Path p = f.toPath(); CopyOption[] options = new CopyOption[]{ StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.COPY_ATTRIBUTES }; Files.copy(p, new File (cloneID+"-train.ptweets").toPath(), options); Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+12)+"-train.ptweets").toPath(), options); Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+18)+"-train.ptweets").toPath(), options); Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+24)+"-train.ptweets").toPath(), options); Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+30)+"-train.ptweets").toPath(), options); } */ } // Generate the bags of words. As before, do not re-create them if they already exist. System.out.println("2-Fill topics " + dateFormat.format(new Date())); out.println("2-Fill topics " + dateFormat.format(new Date())); TopicsList topics = null; if (new File(id + ".topics").exists()) { topics = TopicsSerialization.fromFile(id + ".topics"); if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF); else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF); topics.prepareTopics(); } else { topics = new TopicsList(); if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF); else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF); System.out.println("Filling topics " + dateFormat.format(new Date())); topics.fillTopics(train); System.out.println("Preparing topics topics " + dateFormat.format(new Date())); // Serialization has to happen before preparing, otherwise the tf and // tfidf values cannot be computed System.out.println("Serializing topics topics " + dateFormat.format(new Date())); /* if (clonar) { TopicsSerialization.toFile(cloneID+".topics", topics); } */ topics.prepareTopics(); TopicsSerialization.toFile(id + ".topics", topics); } System.out.println("3-Generate arff train file " + dateFormat.format(new Date())); out.println("3-Generate arff train file " + dateFormat.format(new Date())); // If the arff file does not exist, create it;
otherwise keep doing what has been done so far, // reusing the previous work if (!new File(id + "-train.arff").exists()) { BufferedWriter bw = topics.generateArffHeader(id + "-train.arff"); int tope = traincorpus.size(); if (tweetsProcessor == null) tweetsProcessor = new ProcessTweets( stemming, lematization, urls, hashtags, mentions, unigrams, bigrams); for (int indTweet = 0; indTweet < tope; indTweet++) { topics.generateArffVector(bw, train.get(indTweet)); } bw.flush(); bw.close(); } // Now process the test data System.out.println("5-build test dataset " + dateFormat.format(new Date())); out.println("5-build test dataset " + dateFormat.format(new Date())); List<ProcessedTweet> test = null; if (new File(id + "-test.ptweets").exists()) test = ProcessedTweetSerialization.fromFile(id + "-test.ptweets"); else { if (tweetsProcessor == null) tweetsProcessor = new ProcessTweets( stemming, lematization, urls, hashtags, mentions, unigrams, bigrams); test = tweetsProcessor.processTweets(testcorpus); ProcessedTweetSerialization.toFile(id + "-test.ptweets", test); /* if (clonar) { File f = new File (id+"-test.ptweets"); Path p = f.toPath(); CopyOption[] options = new CopyOption[]{ StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.COPY_ATTRIBUTES }; Files.copy(p, new File (cloneID+"-test.ptweets").toPath(), options); } */ } // If the arff file does not exist, create it; otherwise keep doing what has been done so far, // reusing the previous work if (!new File(id + "-test.arff").exists()) { BufferedWriter bw = topics.generateArffHeader(id + "-test.arff"); int tope = testcorpus.size(); if (tweetsProcessor == null) tweetsProcessor = new ProcessTweets( stemming, lematization, urls, hashtags, mentions, unigrams, bigrams); for (int indTweet = 0; indTweet < tope; indTweet++) { topics.generateArffVector(bw, test.get(indTweet)); } bw.flush(); bw.close(); } int topeTopics = topics.getTopicsList().size(); topics.getTopicsList().clear(); // Build the classifier // FJRM 25-08-2013 Reordered to try to free the topics' memory and have // more available System.out.println("4-Generate classifier " + dateFormat.format(new Date())); out.println("4-Generate classifier " + dateFormat.format(new Date())); Classifier cls = null; DataSource sourceTrain = null; Instances dataTrain = null; if (new File(id + "-MNB.classifier").exists()) { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id + "-MNB.classifier")); cls = (Classifier) ois.readObject(); ois.close(); } else { sourceTrain = new DataSource(id + "-train.arff"); dataTrain = sourceTrain.getDataSet(); if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1); // Train the classifier cls = new weka.classifiers.bayes.NaiveBayesMultinomial(); int clase = dataTrain.numAttributes() - 1; dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(id + "-MNB.classifier")); oos.writeObject(cls); oos.flush(); oos.close(); // data.delete();// not deleted, kept for the svm } // Now evaluate the classifier on the test data System.out.println("6-Evaluate classifier MNB " + dateFormat.format(new Date())); out.println("6-Evaluate classifier MNB" + dateFormat.format(new Date())); DataSource sourceTest = new DataSource(id + "-test.arff"); Instances dataTest = sourceTest.getDataSet(); int clase = dataTest.numAttributes() - 1; dataTest.setClassIndex(clase); Evaluation eval = new Evaluation(dataTest); eval.evaluateModel(cls, dataTest); //
Now compute the precision, recall and f-measure values; also print the confusion // matrices precision = 0; recall = 0; fmeasure = 0; for (int ind = 0; ind < topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); /* DO NOT DELETE System.out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date())); out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date())); if (new File(id+"-SVM.classifier").exists()) { ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id+"-SVM.classifier")); cls = (Classifier) ois.readObject(); ois.close(); } else { if (dataTrain==null) { sourceTrain = new DataSource(id+"-train.arff"); dataTrain = sourceTrain.getDataSet(); if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1); } //Train the classifier cls = new weka.classifiers.functions.LibSVM(); clase = dataTrain.numAttributes()-1; dataTrain.setClassIndex(clase); cls.buildClassifier(dataTrain); ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(id+"-SVM.classifier")); oos.writeObject(cls); oos.flush(); oos.close(); dataTrain.delete(); } eval.evaluateModel(cls, dataTest); precision=0; recall=0; fmeasure=0; for(int ind=0; ind<topeTopics; ind++) { precision += eval.precision(ind); recall += eval.recall(ind); fmeasure += eval.fMeasure(ind); } precision = precision / topeTopics; recall = recall / topeTopics; fmeasure = fmeasure / topeTopics; System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.println(eval.toMatrixString()); System.out.println("+++++++++++++++++++++++++++++++++++++++"); System.out.printf("Precision: %.3f\n", precision); System.out.printf("Recall: %.3f\n", recall); System.out.printf("F-measure: %.3f\n", fmeasure); System.out.println("***************************************"); out.println("+++++++++++++++++++++++++++++++++++++++"); out.println(eval.toMatrixString()); out.println("+++++++++++++++++++++++++++++++++++++++"); out.printf("Precision: %.3f\n", precision); out.printf("Recall: %.3f\n", recall); out.printf("F-measure: %.3f\n", fmeasure); out.println("***************************************"); */ System.out.println("Done " + dateFormat.format(new Date())); out.println("Done " + dateFormat.format(new Date())); if (printer == null) { out.flush(); out.close(); } // Try to free memory if (dataTrain != null) dataTrain.delete(); if (dataTest != null) dataTest.delete(); if (train != null) train.clear(); if (test != null) test.clear(); if (topics != null) { topics.getTopicsList().clear(); topics = null; } if (dataTest != null) dataTest.delete(); if (cls != null) cls = null; if (tweetsProcessor != null) tweetsProcessor = null; System.gc(); }
catch (Exception e) { e.printStackTrace(); } }
public double getLiblinear(String path, String train, String test) { // accuracy of this run double accuracy = 0.0; try { LibLINEAR c1 = new LibLINEAR(); // * String[] options=weka.core.Utils.splitOptions( // * "-S 1 -C 1.0 -E 0.001 -B 0"); c1.setOptions(options); ArffLoader atf = new ArffLoader(); File TraininputFile = new File(train); atf.setFile(TraininputFile); // training corpus file Instances instancesTrain = atf.getDataSet(); // read the training file instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1); File TestinputFile = new File(test); atf.setFile(TestinputFile); // test corpus file Instances instancesTest = atf.getDataSet(); // read the test file // set the index of the class attribute (the first attribute is index 0); instancesTest.numAttributes() returns the total number of attributes instancesTest.setClassIndex(instancesTest.numAttributes() - 1); c1.buildClassifier(instancesTrain); // train Evaluation eval = new Evaluation(instancesTrain); eval.evaluateModel(c1, instancesTest); // eval.crossValidateModel(c1, instancesTrain, 10, new // Random(1)); File newfile = new File(path + "OutLiblinear_temp" + ".txt"); BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(newfile), "utf-8")); bufferedWriter.write(eval.toSummaryString() + "\r\n"); bufferedWriter.write(eval.toClassDetailsString() + "\r\n"); bufferedWriter.write(eval.toMatrixString() + "\r\n"); bufferedWriter.flush(); bufferedWriter.close(); BufferedReader bufferedReader = new BufferedReader(new FileReader(newfile)); String[] splitLineString = new String[5]; while (bufferedReader.ready()) { bufferedReader.readLine(); String lineString = bufferedReader.readLine(); splitLineString = lineString.split(" "); System.out.println(splitLineString[4]); break; } bufferedReader.close(); // extract the classification accuracy String tempLine; BufferedReader tempBF = new BufferedReader(new FileReader(newfile)); while (tempBF.ready()) { tempLine = tempBF.readLine(); if (tempLine.contains("Correctly Classified Instances")) { tempLine = tempLine.substring(tempLine.lastIndexOf(".") - 2, tempLine.lastIndexOf(" ")); accuracy = Double.parseDouble(tempLine); break; } } tempBF.close(); } catch (Exception e) { System.out.println("Can't run LibLINEAR from weka."); } return accuracy; }
public static void main(String[] args) throws Exception { BufferedReader reader = new BufferedReader(new FileReader("spambase.arff")); Instances data = new Instances(reader); reader.close(); // setting class attribute data.setClassIndex(data.numAttributes() - 1); int i = data.numInstances(); int j = data.numAttributes() - 1; File file = new File("tablelog.csv"); Writer output = null; output = new BufferedWriter(new FileWriter(file)); output.write( "%missing,auc1,correct1,fmeasure1,auc2,correct2,fmeasure2,auc3,correct3,fmeasure3\n"); Random randomGenerator = new Random(); data.randomize(randomGenerator); int numBlock = data.numInstances() / 2; double num0 = 0, num1 = 0, num2 = 0; /*mdata.instance(0).setMissing(0); mdata.deleteWithMissing(0); System.out.println(mdata.numInstances()+","+data.numInstances());*/ // Instances traindata=null; // Instances testdata=null; // System.out.println(data.instance(3).stringValue(1)); for (int perc = 10; perc < 101; perc = perc + 10) { Instances mdata = new Instances(data); int numMissing = perc * numBlock / 100; double y11[] = new double[2]; double y21[] = new double[2]; double y31[] = new double[2]; double y12[] = new double[2]; double y22[] = new double[2]; double y32[] = new double[2]; double y13[] = new double[2]; double y23[] = new double[2]; double y33[] = new double[2]; for (int p = 0; p < 2; p++) { Instances traindata = mdata.trainCV(2, p); Instances testdata = mdata.testCV(2, p); num0 = 0; num1 = 0; num2 = 0; for (int t = 0; t < numBlock; t++) { if (traindata.instance(t).classValue() == 0) num0++; if (traindata.instance(t).classValue() == 1) num1++; // if (traindata.instance(t).classValue()==2) num2++; } // System.out.println(mdata.instance(0).classValue()); Instances trainwithmissing = new Instances(traindata); Instances testwithmissing = new Instances(testdata); for (int q = 0; q < j; q++) { int r = randomGenerator.nextInt((int) i / 2); for (int k = 0; k < numMissing; k++) { // int r = randomGenerator.nextInt((int) i/2); // int c = randomGenerator.nextInt(j); trainwithmissing.instance((r + k) % numBlock).setMissing(q); testwithmissing.instance((r + k) % numBlock).setMissing(q); } } // trainwithmissing.deleteWithMissing(0);System.out.println(traindata.numInstances()+","+trainwithmissing.numInstances()); Classifier cModel = (Classifier) new Logistic(); // try for different classifiers and datasets cModel.buildClassifier(trainwithmissing); Evaluation eTest1 = new Evaluation(trainwithmissing); eTest1.evaluateModel(cModel, testdata); // eTest.crossValidateModel(cModel,mdata,10,mdata.getRandomNumberGenerator(1)); y11[p] = num0 / numBlock * eTest1.areaUnderROC(0) + num1 / numBlock * eTest1.areaUnderROC(1) /*+num2/numBlock*eTest1.areaUnderROC(2)*/; y21[p] = eTest1.correct(); y31[p] = num0 / numBlock * eTest1.fMeasure(0) + num1 / numBlock * eTest1.fMeasure(1) /*+num2/numBlock*eTest1.fMeasure(2)*/; Classifier cModel2 = (Classifier) new Logistic(); cModel2.buildClassifier(traindata); Evaluation eTest2 = new Evaluation(traindata); eTest2.evaluateModel(cModel2, testwithmissing); y12[p] = num0 / numBlock * eTest2.areaUnderROC(0) + num1 / numBlock * eTest2.areaUnderROC(1) /*+num2/numBlock*eTest2.areaUnderROC(2)*/; y22[p] = eTest2.correct(); y32[p] = num0 / numBlock * eTest2.fMeasure(0) + num1 / numBlock * eTest2.fMeasure(1) /*+num2/numBlock*eTest2.fMeasure(2)*/; Classifier cModel3 = (Classifier) new Logistic(); cModel3.buildClassifier(trainwithmissing); Evaluation eTest3 = new Evaluation(trainwithmissing); eTest3.evaluateModel(cModel3, testwithmissing); 
y13[p] = num0 / numBlock * eTest3.areaUnderROC(0) + num1 / numBlock * eTest3.areaUnderROC(1) /*+num2/numBlock*eTest3.areaUnderROC(2)*/; y23[p] = eTest3.correct(); y33[p] = num0 / numBlock * eTest3.fMeasure(0) + num1 / numBlock * eTest3.fMeasure(1) /*+num2/numBlock*eTest3.fMeasure(2)*/; // System.out.println(num0+","+num1+","+num2+"\n"); } double auc1 = (y11[0] + y11[1]) / 2; double auc2 = (y12[0] + y12[1]) / 2; double auc3 = (y13[0] + y13[1]) / 2; double corr1 = (y21[0] + y21[1]) / i; double corr2 = (y22[0] + y22[1]) / i; double corr3 = (y23[0] + y23[1]) / i; double fm1 = (y31[0] + y31[1]) / 2; double fm2 = (y32[0] + y32[1]) / 2; double fm3 = (y33[0] + y33[1]) / 2; output.write( perc + "," + auc1 + "," + corr1 + "," + fm1 + "," + auc2 + "," + corr2 + "," + fm2 + "," + auc3 + "," + corr3 + "," + fm3 + "\n"); // System.out.println(num0); // mdata=data; } output.close(); }
// Given an input question, return the category it belongs to. public double classifyByBayes(String question) throws Exception { double label = -1; List<Question> questionID = questionDAO.getQuestionIDLabeled(); // Define the data format Attribute att1 = new Attribute("法律政策"); Attribute att2 = new Attribute("位置交通"); Attribute att3 = new Attribute("风水"); Attribute att4 = new Attribute("房价"); Attribute att5 = new Attribute("楼层"); Attribute att6 = new Attribute("户型"); Attribute att7 = new Attribute("小区配套"); Attribute att8 = new Attribute("贷款"); Attribute att9 = new Attribute("买房时机"); Attribute att10 = new Attribute("开发商"); FastVector labels = new FastVector(); labels.addElement("1"); labels.addElement("2"); labels.addElement("3"); labels.addElement("4"); labels.addElement("5"); labels.addElement("6"); labels.addElement("7"); labels.addElement("8"); labels.addElement("9"); labels.addElement("10"); Attribute att11 = new Attribute("类别", labels); FastVector attributes = new FastVector(); attributes.addElement(att1); attributes.addElement(att2); attributes.addElement(att3); attributes.addElement(att4); attributes.addElement(att5); attributes.addElement(att6); attributes.addElement(att7); attributes.addElement(att8); attributes.addElement(att9); attributes.addElement(att10); attributes.addElement(att11); Instances dataset = new Instances("Test-dataset", attributes, 0); dataset.setClassIndex(10); Classifier classifier = null; if (!new File("naivebayes.model").exists()) { // Add the data double[] values = new double[11]; for (int i = 0; i < questionID.size(); i++) { for (int m = 0; m < 11; m++) { values[m] = 0; } int whitewordcount = 0; whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId()); if (whitewordcount != 0) { List<QuestionWhiteWord> questionwhiteword = questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId()); for (int j = 0; j < questionwhiteword.size(); j++) { values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++; } for (int m = 0; m < 11; m++) { values[m] = values[m] / whitewordcount; } } values[10] = questionID.get(i).getType() - 1; Instance inst = new Instance(1.0, values); dataset.add(inst); } // Build the classifier classifier = new NaiveBayes(); classifier.buildClassifier(dataset); SerializationHelper.write("naivebayes.model", classifier); } else { classifier = (Classifier) SerializationHelper.read("naivebayes.model"); } System.out.println("*************begin evaluation*******************"); Evaluation evaluation = new Evaluation(dataset); evaluation.evaluateModel(classifier, dataset); // Strictly speaking, a separate dataset should be used here rather than the training set data. System.out.println(evaluation.toSummaryString()); // Classify System.out.println("*************begin classification*******************"); Instance subject = new Instance(1.0, getQuestionVector(question)); subject.setDataset(dataset); label = classifier.classifyInstance(subject); System.out.println("label: " + label); // double dis[]=classifier.distributionForInstance(inst); // for(double i:dis){ // System.out.print(i+" "); // } System.out.println(questionID.size()); return label + 1; }
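The comment above already notes that evaluation should use held-out data rather than the training set. A hedged sketch of a simple hold-out split on the same dataset object (fold counts are illustrative, and it assumes dataset has been populated with the labelled questions as in the model-building branch):

// Hold out one third of the labelled questions for evaluation
dataset.randomize(new java.util.Random(1));
Instances evalTrain = dataset.trainCV(3, 0);
Instances evalTest = dataset.testCV(3, 0);

Classifier heldOutModel = new NaiveBayes();
heldOutModel.buildClassifier(evalTrain);

Evaluation heldOutEval = new Evaluation(evalTrain);
heldOutEval.evaluateModel(heldOutModel, evalTest);
System.out.println(heldOutEval.toSummaryString());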
public void runFilter() throws Exception { System.out.println("filtering attributes..."); System.out.println("running weka filters and weka-libsvm"); File svmfile = new File(sentiAnalysis.DIR.concat(sentiAnalysis.outout.concat(".libsvm"))); LibSVMLoader libl = new LibSVMLoader(); libl.setFile(svmfile); Instances data = libl.getDataSet(); NumericToNominal nm = new NumericToNominal(); // Converting last index // attribute to type // nominal from numeric nm.setAttributeIndices("last"); // as the last index would be class // label for the data nm.setInputFormat(data); filteredData = Filter.useFilter(data, nm); // filtered data stored in // new Instances object AttrNo = filteredData.numAttributes(); // number of attributes in given // file RecordNo = filteredData.numInstances(); // Number of records in given // file lowerBound = 0; upperBound = AttrNo - 1; AttributeSelection atsl = new AttributeSelection(); Ranker search = new Ranker(); InfoGainAttributeEval infog = new InfoGainAttributeEval(); // Applying // Attribute // Selection // using // InfoGain // evaluator // with // Ranker // search atsl.setEvaluator(infog); atsl.setSearch(search); atsl.SelectAttributes(filteredData); InfoGain = atsl.rankedAttributes(); SelectedAttributes = atsl.selectedAttributes(); // count non zero infoGain int count = 0; for (int i = 0; i < InfoGain.length; i++) { count = (InfoGain[i][1] > 0) ? count + 1 : count; } System.out.println("writing attributes with non-zero InfoGain..."); FileWriter svmout = new FileWriter(sentiAnalysis.DIR.concat(sentiAnalysis.outout.concat("_new.libsvm"))); for (int i = 0; i < RecordNo; i++) { int index = 1; svmout.write((int) filteredData.instance(i).value(filteredData.classIndex()) + " "); for (int j = 0; j < count; j++) { svmout.write( index + ":" + (int) filteredData.instance(i).value((int) InfoGain[j][0]) + " "); index++; } svmout.write("\n"); } svmout.close(); // filtered File newsvm = new File(sentiAnalysis.DIR.concat(sentiAnalysis.outout.concat("_new.libsvm"))); LibSVMLoader liblnew = new LibSVMLoader(); liblnew.setFile(newsvm); Instances newdata = liblnew.getDataSet(); nm = new NumericToNominal(); // Converting last index attribute to type // nominal from numeric nm.setAttributeIndices("last"); // as the last index would be class // label for the data nm.setInputFormat(newdata); Instances filteredDataNew = Filter.useFilter(newdata, nm); // filtered // data // stored in // new // Instances // object // test file File newsvmtest = new File(sentiAnalysis.DIR.concat(sentiAnalysis.outout.concat("_test.libsvm"))); LibSVMLoader libltest = new LibSVMLoader(); libltest.setFile(newsvmtest); Instances newdatatest = libltest.getDataSet(); nm = new NumericToNominal(); // Converting last index attribute to type // nominal from numeric nm.setAttributeIndices("last"); // as the last index would be class // label for the data nm.setInputFormat(newdatatest); Instances filteredDataTest = Filter.useFilter(newdatatest, nm); // filtered // data // stored // in // new // Instances // object // weka.classifiers.functions.LibSVM -S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 // -M 40.0 -C 1.0 -E 0.001 -P 0.1 -seed 1 String[] options = new String[1]; options[0] = "-S 0 -K 2 -D 3 -G 0.1 -R 0.0 -N 0.5 -M 40.0 -C 1.0 -E 0.001 -P 0.1 -seed 1 -h 0"; System.out.println("building classifier..."); LibSVM svm_model = new LibSVM(); svm_model.setOptions(options); // set the options svm_model.buildClassifier(filteredData); // build classifier DecimalFormat df = new DecimalFormat("0.00"); System.out.println("running cross 
validation..."); Evaluation eval = new Evaluation(filteredData); // eval.crossValidateModel(svm_model, filteredDataNew, 10, new // Random(1)); eval.evaluateModel(svm_model, filteredDataTest); FileWriter results = new FileWriter(sentiAnalysis.DIR.concat(sentiAnalysis.outout.concat("_results.txt"))); results.write("Classifier 1: Support Vector Machines\n"); results.write("Positive class precision: " + df.format(eval.precision(0)) + "\n"); results.write("Positive class recall: " + df.format(eval.recall(0)) + "\n"); results.write("Positive class f-score: " + df.format(eval.fMeasure(0)) + "\n"); results.write("Negative class precision: " + df.format(eval.precision(0)) + "\n"); results.write("Negative class recall: " + df.format(eval.precision(0)) + "\n"); results.write("Negative class f-score: " + df.format(eval.fMeasure(0)) + "\n"); System.out.println("generating results..."); System.out.println("*" + sentiAnalysis.outout + "*\t" + "\tPositive\tNegative\tNeutral"); System.out.println( "Precision\t" + df.format(eval.precision(0)) + "\t" + df.format(eval.precision(2)) + "\t" + df.format(eval.precision(1))); System.out.println( "Recall\t" + df.format(eval.recall(0)) + "\t" + df.format(eval.recall(2)) + "\t" + df.format(eval.recall(1))); System.out.println( "F-score\t" + df.format(eval.fMeasure(0)) + "\t" + df.format(eval.fMeasure(2)) + "\t" + df.format(eval.fMeasure(1))); results.close(); }
public void testModel() throws Exception { Evaluation eval = new Evaluation(testData); eval.evaluateModel(classifier, testData); System.out.println(eval.toSummaryString("Results", false)); }
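If per-class figures are wanted as well, the same Evaluation object can print the detail string and confusion matrix used elsewhere in these examples; a sketch of an extended variant, assuming the same classifier and testData fields:

public void testModelDetailed() throws Exception {
  Evaluation eval = new Evaluation(testData);
  eval.evaluateModel(classifier, testData);
  System.out.println(eval.toSummaryString("Results", false));
  System.out.println(eval.toClassDetailsString()); // per-class precision/recall/F
  System.out.println(eval.toMatrixString());       // confusion matrix
}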
public QSARModel train(Instances data) throws QSARException { // GET A UUID AND DEFINE THE TEMPORARY FILE WHERE THE TRAINING DATA // ARE STORED IN ARFF FORMAT PRIOR TO TRAINING. final String rand = java.util.UUID.randomUUID().toString(); final String temporaryFilePath = ServerFolders.temp + "/" + rand + ".arff"; final File tempFile = new File(temporaryFilePath); // SAVE THE DATA IN THE TEMPORARY FILE try { ArffSaver dataSaver = new ArffSaver(); dataSaver.setInstances(data); dataSaver.setDestination(new FileOutputStream(tempFile)); dataSaver.writeBatch(); if (!tempFile.exists()) { throw new IOException("Temporary File was not created"); } } catch (final IOException ex) { /* * The content of the dataset cannot be * written to the destination file due to * some communication issue. */ tempFile.delete(); throw new RuntimeException( "Unexpected condition while trying to save the " + "dataset in a temporary ARFF file", ex); } NaiveBayes classifier = new NaiveBayes(); String[] generalOptions = { "-c", Integer.toString(data.classIndex() + 1), "-t", temporaryFilePath, /// Save the model in the following directory "-d", ServerFolders.models_weka + "/" + uuid }; try { Evaluation.evaluateModel(classifier, generalOptions); } catch (final Exception ex) { tempFile.delete(); throw new QSARException( Cause.XQReg350, "Unexpected condition while trying to train " + "an SVM model. Possible explanation : {" + ex.getMessage() + "}", ex); } QSARModel model = new QSARModel(); model.setParams(getParameters()); model.setCode(uuid.toString()); model.setAlgorithm(YaqpAlgorithms.NAIVE_BAYES); model.setDataset(datasetUri); model.setModelStatus(ModelStatus.UNDER_DEVELOPMENT); ArrayList<Feature> independentFeatures = new ArrayList<Feature>(); for (int i = 0; i < data.numAttributes(); i++) { Feature f = new Feature(data.attribute(i).name()); if (data.classIndex() != i) { independentFeatures.add(f); } } Feature dependentFeature = new Feature(data.classAttribute().name()); Feature predictedFeature = dependentFeature; model.setDependentFeature(dependentFeature); model.setIndependentFeatures(independentFeatures); model.setPredictionFeature(predictedFeature); tempFile.delete(); return model; }
public static void main(String[] args) throws Exception { /* * First we load our predictions from the CSV formatted file. */ CSVLoader predictCsvLoader = new CSVLoader(); predictCsvLoader.setSource(new File("predict.csv")); /* * Since we are not using the ARFF format here, we have to give the * loader a little bit of information about the data types. Columns * 3,8,10 need to be of type string and columns 1,4,11 are nominal * types. */ predictCsvLoader.setStringAttributes("3,8,10"); predictCsvLoader.setNominalAttributes("1,4,11"); Instances predictDataSet = predictCsvLoader.getDataSet(); /* * Here we set the attribute we want to test the predictions with */ Attribute testAttribute = predictDataSet.attribute(0); predictDataSet.setClass(testAttribute); /* * We still have to remove all string attributes before we can test */ predictDataSet.deleteStringAttributes(); /* * Next we load the training data from our ARFF file */ ArffLoader trainLoader = new ArffLoader(); trainLoader.setSource(new File("train.arff")); trainLoader.setRetrieval(Loader.BATCH); Instances trainDataSet = trainLoader.getDataSet(); /* * Now we tell the data set which attribute we want to classify, in our * case, we want to classify the first column: survived */ Attribute trainAttribute = trainDataSet.attribute(0); trainDataSet.setClass(trainAttribute); /* * The RandomForest implementation cannot handle columns of type string, * so we remove them for now. */ trainDataSet.deleteStringAttributes(); /* * Now we read in the serialized model from disk */ Classifier classifier = (Classifier) SerializationHelper.read("titanic.model"); /* * Next we will use an Evaluation class to evaluate the performance of * our Classifier. */ Evaluation evaluation = new Evaluation(trainDataSet); evaluation.evaluateModel(classifier, predictDataSet, new Object[] {}); /* * After we evaluate the Classifier, we write out the summary * information to the screen. */ System.out.println(classifier); System.out.println(evaluation.toSummaryString()); }