/**
 * Runs {@code test()} and copies the resulting error statistics into a
 * {@code pikater.ontology.messages.Evaluation} message object.
 *
 * <p>The kappa statistic and the relative absolute error are stored as {@code -1} when
 * obtaining or storing them raises an exception; the other statistics are copied without
 * such a fallback.
 *
 * @return a freshly created message object carrying the copied statistics
 */
@Override
protected pikater.ontology.messages.Evaluation evaluateCA() {
  Evaluation measured = test();
  pikater.ontology.messages.Evaluation message = new pikater.ontology.messages.Evaluation();

  float errorRate = (float) measured.errorRate();
  message.setError_rate(errorRate);

  try {
    message.setKappa_statistic((float) measured.kappa());
  } catch (Exception kappaFailure) {
    // Fall back to the -1 marker when kappa is unavailable.
    message.setKappa_statistic(-1);
  }

  float meanAbsolute = (float) measured.meanAbsoluteError();
  message.setMean_absolute_error(meanAbsolute);

  try {
    message.setRelative_absolute_error((float) measured.relativeAbsoluteError());
  } catch (Exception relativeFailure) {
    // Same -1 marker as used for kappa above.
    message.setRelative_absolute_error(-1);
  }

  float rootMeanSquared = (float) measured.rootMeanSquaredError();
  message.setRoot_mean_squared_error(rootMeanSquared);

  float rootRelativeSquared = (float) measured.rootRelativeSquaredError();
  message.setRoot_relative_squared_error(rootRelativeSquared);

  return message;
}
/** * 用分类器测试 * * @param trainFileName * @param testFileName */ public static void classify(String trainFileName, String testFileName) { try { File inputFile = new File(fileName + trainFileName); // 训练语料文件 ArffLoader atf = new ArffLoader(); atf.setFile(inputFile); Instances instancesTrain = atf.getDataSet(); // 读入训练文件 // 设置类标签类 inputFile = new File(fileName + testFileName); // 测试语料文件 atf.setFile(inputFile); Instances instancesTest = atf.getDataSet(); // 读入测试文件 instancesTest.setClassIndex(instancesTest.numAttributes() - 1); instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1); classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance(); classifier.buildClassifier(instancesTrain); Evaluation eval = new Evaluation(instancesTrain); // 第一个为一个训练过的分类器,第二个参数是在某个数据集上评价的数据集 eval.evaluateModel(classifier, instancesTest); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); System.out.println("precision is :" + (1 - eval.errorRate())); } catch (Exception e) { e.printStackTrace(); } }
/** * Finds the best parameter combination. (recursive for each parameter being optimised). * * @param depth the index of the parameter to be optimised at this level * @param trainData the data the search is based on * @param random a random number generator * @throws Exception if an error occurs */ protected void findParamsByCrossValidation(int depth, Instances trainData, Random random) throws Exception { if (depth < m_CVParams.size()) { CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth); double upper; switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: upper = m_NumAttributes; break; case 2: upper = m_TrainFoldSize; break; default: upper = cvParam.m_Upper; break; } double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1); for (cvParam.m_ParamValue = cvParam.m_Lower; cvParam.m_ParamValue <= upper; cvParam.m_ParamValue += increment) { findParamsByCrossValidation(depth + 1, trainData, random); } } else { Evaluation evaluation = new Evaluation(trainData); // Set the classifier options String[] options = createOptions(); if (m_Debug) { System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":"); for (int i = 0; i < options.length; i++) { System.err.print(" " + options[i]); } System.err.println(""); } ((OptionHandler) m_Classifier).setOptions(options); for (int j = 0; j < m_NumFolds; j++) { // We want to randomize the data the same way for every // learning scheme. Instances train = trainData.trainCV(m_NumFolds, j, new Random(1)); Instances test = trainData.testCV(m_NumFolds, j); m_Classifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(m_Classifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4)); } if ((m_BestPerformance == -99) || (error < m_BestPerformance)) { m_BestPerformance = error; m_BestClassifierOptions = createOptions(); } } }
/** * Buildclassifier selects a classifier from the set of classifiers by minimising error on the * training data. * * @param data the training data to be used for generating the boosted classifier. * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifiers.length == 0) { throw new Exception("No base classifiers have been set!"); } Instances newData = new Instances(data); newData.deleteWithMissingClass(); newData.randomize(new Random(m_Seed)); if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) newData.stratify(m_NumXValFolds); Instances train = newData; // train on all data by default Instances test = newData; // test on training data by default Classifier bestClassifier = null; int bestIndex = -1; double bestPerformance = Double.NaN; int numClassifiers = m_Classifiers.length; for (int i = 0; i < numClassifiers; i++) { Classifier currentClassifier = getClassifier(i); Evaluation evaluation; if (m_NumXValFolds > 1) { evaluation = new Evaluation(newData); for (int j = 0; j < m_NumXValFolds; j++) { train = newData.trainCV(m_NumXValFolds, j); test = newData.testCV(m_NumXValFolds, j); currentClassifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(currentClassifier, test); } } else { currentClassifier.buildClassifier(train); evaluation = new Evaluation(train); evaluation.evaluateModel(currentClassifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println( "Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName()); } if ((i == 0) || (error < bestPerformance)) { bestClassifier = currentClassifier; bestPerformance = error; bestIndex = i; } } m_ClassifierIndex = bestIndex; m_Classifier = bestClassifier; if (m_NumXValFolds > 1) { m_Classifier.buildClassifier(newData); } }
/**
 * Cross-validates this classifier on {@code trainingData} and reports the outcome.
 *
 * <p>The folds are drawn with a freshly created, unseeded {@link Random}, so repeated calls
 * may return different values.
 *
 * <p>NOTE(review): despite the method name, the value returned is
 * {@code Evaluation.errorRate()}, not an accuracy. Confirm what callers expect before
 * changing either the name or the return value.
 *
 * @param nbFolds the number of folds for the cross-validation
 * @return the error rate measured by the cross-validation
 * @throws Exception if the cross-validation fails
 */
public double predictAccuracyXVal(int nbFolds) throws Exception {
  Evaluation crossValidation = new Evaluation(trainingData);
  Random foldRandomizer = new Random();
  Object[] noPredictionOutput = new Object[] {};
  crossValidation.crossValidateModel(
      this, trainingData, nbFolds, foldRandomizer, noPredictionOutput);
  double measuredErrorRate = crossValidation.errorRate();
  return measuredErrorRate;
}
/** * Accepts and processes a classifier encapsulated in an incremental classifier event * * @param ce an <code>IncrementalClassifierEvent</code> value */ @Override public void acceptClassifier(final IncrementalClassifierEvent ce) { try { if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) { m_throughput = new StreamThroughput(statusMessagePrefix()); m_throughput.setSamplePeriod(m_statusFrequency); // m_eval = new Evaluation(ce.getCurrentInstance().dataset()); m_eval = new Evaluation(ce.getStructure()); m_eval.useNoPriors(); m_dataLegend = new Vector(); m_reset = true; m_dataPoint = new double[0]; Instances inst = ce.getStructure(); System.err.println("NEW BATCH"); m_instanceCount = 0; if (m_windowSize > 0) { m_window = new LinkedList<Instance>(); m_windowEval = new Evaluation(ce.getStructure()); m_windowEval.useNoPriors(); m_windowedPreds = new LinkedList<double[]>(); if (m_logger != null) { m_logger.logMessage( statusMessagePrefix() + "[IncrementalClassifierEvaluator] Chart output using windowed " + "evaluation over " + m_windowSize + " instances"); } } /* * if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix() * + "IncrementalClassifierEvaluator: started processing..."); * m_logger.logMessage(statusMessagePrefix() + * " [IncrementalClassifierEvaluator]" + statusMessagePrefix() + * " started processing..."); } */ } else { Instance inst = ce.getCurrentInstance(); if (inst != null) { m_throughput.updateStart(); m_instanceCount++; // if (inst.attribute(inst.classIndex()).isNominal()) { double[] dist = ce.getClassifier().distributionForInstance(inst); double pred = 0; if (!inst.isMissing(inst.classIndex())) { if (m_outputInfoRetrievalStats) { // store predictions so AUC etc can be output. 
m_eval.evaluateModelOnceAndRecordPrediction(dist, inst); } else { m_eval.evaluateModelOnce(dist, inst); } if (m_windowSize > 0) { m_windowEval.evaluateModelOnce(dist, inst); m_window.addFirst(inst); m_windowedPreds.addFirst(dist); if (m_instanceCount > m_windowSize) { // "forget" the oldest prediction Instance oldest = m_window.removeLast(); double[] oldDist = m_windowedPreds.removeLast(); oldest.setWeight(-oldest.weight()); m_windowEval.evaluateModelOnce(oldDist, oldest); oldest.setWeight(-oldest.weight()); } } } else { pred = ce.getClassifier().classifyInstance(inst); } if (inst.classIndex() >= 0) { // need to check that the class is not missing if (inst.attribute(inst.classIndex()).isNominal()) { if (!inst.isMissing(inst.classIndex())) { if (m_dataPoint.length < 2) { m_dataPoint = new double[3]; m_dataLegend.addElement("Accuracy"); m_dataLegend.addElement("RMSE (prob)"); m_dataLegend.addElement("Kappa"); } // int classV = (int) inst.value(inst.classIndex()); if (m_windowSize > 0) { m_dataPoint[1] = m_windowEval.rootMeanSquaredError(); m_dataPoint[2] = m_windowEval.kappa(); } else { m_dataPoint[1] = m_eval.rootMeanSquaredError(); m_dataPoint[2] = m_eval.kappa(); } // int maxO = Utils.maxIndex(dist); // if (maxO == classV) { // dist[classV] = -1; // maxO = Utils.maxIndex(dist); // } // m_dataPoint[1] -= dist[maxO]; } else { if (m_dataPoint.length < 1) { m_dataPoint = new double[1]; m_dataLegend.addElement("Confidence"); } } double primaryMeasure = 0; if (!inst.isMissing(inst.classIndex())) { if (m_windowSize > 0) { primaryMeasure = 1.0 - m_windowEval.errorRate(); } else { primaryMeasure = 1.0 - m_eval.errorRate(); } } else { // record confidence as the primary measure // (another possibility would be entropy of // the distribution, or perhaps average // confidence) primaryMeasure = dist[Utils.maxIndex(dist)]; } // double [] dataPoint = new double[1]; m_dataPoint[0] = primaryMeasure; // double min = 0; double max = 100; /* * ChartEvent e = new * 
ChartEvent(IncrementalClassifierEvaluator.this, m_dataLegend, * min, max, dataPoint); */ m_ce.setLegendText(m_dataLegend); m_ce.setMin(0); m_ce.setMax(1); m_ce.setDataPoint(m_dataPoint); m_ce.setReset(m_reset); m_reset = false; } else { // numeric class if (m_dataPoint.length < 1) { m_dataPoint = new double[1]; if (inst.isMissing(inst.classIndex())) { m_dataLegend.addElement("Prediction"); } else { m_dataLegend.addElement("RMSE"); } } if (!inst.isMissing(inst.classIndex())) { double update; if (!inst.isMissing(inst.classIndex())) { if (m_windowSize > 0) { update = m_windowEval.rootMeanSquaredError(); } else { update = m_eval.rootMeanSquaredError(); } } else { update = pred; } m_dataPoint[0] = update; if (update > m_max) { m_max = update; } if (update < m_min) { m_min = update; } } m_ce.setLegendText(m_dataLegend); m_ce.setMin((inst.isMissing(inst.classIndex()) ? m_min : 0)); m_ce.setMax(m_max); m_ce.setDataPoint(m_dataPoint); m_ce.setReset(m_reset); m_reset = false; } notifyChartListeners(m_ce); } m_throughput.updateEnd(m_logger); } if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED || inst == null) { if (m_logger != null) { m_logger.logMessage( "[IncrementalClassifierEvaluator]" + statusMessagePrefix() + " Finished processing."); } m_throughput.finished(m_logger); // save memory if using windowed evaluation for charting m_windowEval = null; m_window = null; m_windowedPreds = null; if (m_textListeners.size() > 0) { String textTitle = ce.getClassifier().getClass().getName(); textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length()); String results = "=== Performance information ===\n\n" + "Scheme: " + textTitle + "\n" + "Relation: " + m_eval.getHeader().relationName() + "\n\n" + m_eval.toSummaryString(); if (m_eval.getHeader().classIndex() >= 0 && m_eval.getHeader().classAttribute().isNominal() && (m_outputInfoRetrievalStats)) { results += "\n" + m_eval.toClassDetailsString(); } if (m_eval.getHeader().classIndex() >= 0 && 
m_eval.getHeader().classAttribute().isNominal()) { results += "\n" + m_eval.toMatrixString(); } textTitle = "Results: " + textTitle; TextEvent te = new TextEvent(this, results, textTitle); notifyTextListeners(te); } } } } catch (Exception ex) { if (m_logger != null) { m_logger.logMessage( "[IncrementalClassifierEvaluator]" + statusMessagePrefix() + " Error processing prediction " + ex.getMessage()); m_logger.statusMessage( statusMessagePrefix() + "ERROR: problem processing prediction (see log for details)"); } ex.printStackTrace(); stop(); } }