public static void wekaAlgorithms(Instances data) throws Exception { classifier = new FilteredClassifier(); // new instance of tree classifier.setClassifier(new NaiveBayes()); // classifier.setClassifier(new J48()); // classifier.setClassifier(new RandomForest()); // classifier.setClassifier(new ZeroR()); // classifier.setClassifier(new NaiveBayes()); // classifier.setClassifier(new IBk()); data.setClassIndex(data.numAttributes() - 1); Evaluation eval = new Evaluation(data); int folds = 10; eval.crossValidateModel(classifier, data, folds, new Random(1)); System.out.println("===== Evaluating on filtered (training) dataset ====="); System.out.println(eval.toSummaryString()); System.out.println(eval.toClassDetailsString()); double[][] mat = eval.confusionMatrix(); System.out.println("========= Confusion Matrix ========="); for (int i = 0; i < mat.length; i++) { for (int j = 0; j < mat.length; j++) { System.out.print(mat[i][j] + " "); } System.out.println(" "); } }
/** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { Instances isTrainingSet = createSet(4); Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet); Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet); Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet); isTrainingSet.add(instance1); isTrainingSet.add(instance2); isTrainingSet.add(instance22); Instances isTestingSet = createSet(4); Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet); Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet); isTestingSet.add(instance3); isTestingSet.add(instance4); // Create a naïve bayes classifier Classifier cModel = (Classifier) new BayesNet(); // M5P cModel.buildClassifier(isTrainingSet); // Test the model Evaluation eTest = new Evaluation(isTrainingSet); eTest.evaluateModel(cModel, isTestingSet); // Print the result à la Weka explorer: String strSummary = eTest.toSummaryString(); System.out.println(strSummary); // Get the likelihood of each classes // fDistribution[0] is the probability of being “positive” // fDistribution[1] is the probability of being “negative” double[] fDistribution = cModel.distributionForInstance(instance4); for (int i = 0; i < fDistribution.length; i++) { System.out.println(fDistribution[i]); } }
/** * Creates an evaluation overview of the built classifier. * * @return the panel to be displayed as result evaluation view for the current decision point */ protected JPanel createEvaluationVisualization(Instances data) { // build text field to display evaluation statistics JTextPane statistic = new JTextPane(); try { // build evaluation statistics Evaluation evaluation = new Evaluation(data); evaluation.evaluateModel(myClassifier, data); statistic.setText( evaluation.toSummaryString() + "\n\n" + evaluation.toClassDetailsString() + "\n\n" + evaluation.toMatrixString()); } catch (Exception ex) { ex.printStackTrace(); return createMessagePanel("Error while creating the decision tree evaluation view"); } statistic.setFont(new Font("Courier", Font.PLAIN, 14)); statistic.setEditable(false); statistic.setCaretPosition(0); JPanel resultViewPanel = new JPanel(); resultViewPanel.setLayout(new BoxLayout(resultViewPanel, BoxLayout.PAGE_AXIS)); resultViewPanel.add(new JScrollPane(statistic)); return resultViewPanel; }
public static void run(String[] args) throws Exception { /** * ************************************************* * * @param args[0]: train arff path * @param args[1]: test arff path */ DataSource source = new DataSource(args[0]); Instances data = source.getDataSet(); data.setClassIndex(data.numAttributes() - 1); NaiveBayes model = new NaiveBayes(); model.buildClassifier(data); // Evaluation: Evaluation eval = new Evaluation(data); Instances testData = new DataSource(args[1]).getDataSet(); testData.setClassIndex(testData.numAttributes() - 1); eval.evaluateModel(model, testData); System.out.println(model.toString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); System.out.println("======\nConfusion Matrix:"); double[][] confusionM = eval.confusionMatrix(); for (int i = 0; i < confusionM.length; ++i) { for (int j = 0; j < confusionM[i].length; ++j) { System.out.format("%10s ", confusionM[i][j]); } System.out.print("\n"); } }
/** * Finds the best parameter combination. (recursive for each parameter being optimised). * * @param depth the index of the parameter to be optimised at this level * @param trainData the data the search is based on * @param random a random number generator * @throws Exception if an error occurs */ protected void findParamsByCrossValidation(int depth, Instances trainData, Random random) throws Exception { if (depth < m_CVParams.size()) { CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth); double upper; switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: upper = m_NumAttributes; break; case 2: upper = m_TrainFoldSize; break; default: upper = cvParam.m_Upper; break; } double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1); for (cvParam.m_ParamValue = cvParam.m_Lower; cvParam.m_ParamValue <= upper; cvParam.m_ParamValue += increment) { findParamsByCrossValidation(depth + 1, trainData, random); } } else { Evaluation evaluation = new Evaluation(trainData); // Set the classifier options String[] options = createOptions(); if (m_Debug) { System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":"); for (int i = 0; i < options.length; i++) { System.err.print(" " + options[i]); } System.err.println(""); } ((OptionHandler) m_Classifier).setOptions(options); for (int j = 0; j < m_NumFolds; j++) { // We want to randomize the data the same way for every // learning scheme. Instances train = trainData.trainCV(m_NumFolds, j, new Random(1)); Instances test = trainData.testCV(m_NumFolds, j); m_Classifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(m_Classifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4)); } if ((m_BestPerformance == -99) || (error < m_BestPerformance)) { m_BestPerformance = error; m_BestClassifierOptions = createOptions(); } } }
/** evaluates the classifier */ @Override public void evaluate() throws Exception { // evaluate classifier and print some statistics if (_test.classIndex() == -1) _test.setClassIndex(_test.numAttributes() - 1); Evaluation eval = new Evaluation(_train); eval.evaluateModel(_cl, _test); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); System.out.println(eval.toMatrixString()); }
/** * Buildclassifier selects a classifier from the set of classifiers by minimising error on the * training data. * * @param data the training data to be used for generating the boosted classifier. * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifiers.length == 0) { throw new Exception("No base classifiers have been set!"); } Instances newData = new Instances(data); newData.deleteWithMissingClass(); newData.randomize(new Random(m_Seed)); if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) newData.stratify(m_NumXValFolds); Instances train = newData; // train on all data by default Instances test = newData; // test on training data by default Classifier bestClassifier = null; int bestIndex = -1; double bestPerformance = Double.NaN; int numClassifiers = m_Classifiers.length; for (int i = 0; i < numClassifiers; i++) { Classifier currentClassifier = getClassifier(i); Evaluation evaluation; if (m_NumXValFolds > 1) { evaluation = new Evaluation(newData); for (int j = 0; j < m_NumXValFolds; j++) { train = newData.trainCV(m_NumXValFolds, j); test = newData.testCV(m_NumXValFolds, j); currentClassifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(currentClassifier, test); } } else { currentClassifier.buildClassifier(train); evaluation = new Evaluation(train); evaluation.evaluateModel(currentClassifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println( "Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName()); } if ((i == 0) || (error < bestPerformance)) { bestClassifier = currentClassifier; bestPerformance = error; bestIndex = i; } } m_ClassifierIndex = bestIndex; m_Classifier = bestClassifier; if (m_NumXValFolds > 1) { m_Classifier.buildClassifier(newData); } }
/**
 * Evaluates the {@code bayesNet} classifier on the given test instances.
 *
 * @param trainInstances the instances used to initialise the evaluation
 * @param testInstances the instances the classifier is evaluated on
 * @return the evaluation, or {@code null} if it failed (the error is printed to standard error)
 */
public Evaluation evaluateClassifier(Instances trainInstances, Instances testInstances) {
  Evaluation result = null;
  try {
    Evaluation eval = new Evaluation(trainInstances);
    eval.evaluateModel(bayesNet, testInstances);
    result = eval;
  } catch (Exception e) {
    System.err.println(e.getMessage());
    e.printStackTrace();
  }
  return result;
}
/**
 * Runs a 10-fold cross-validation of an {@link AttributeSelectedClassifier} (J48 base learner,
 * CFS subset evaluator, backwards greedy stepwise search) and prints the evaluation summary.
 *
 * @param data the dataset to cross-validate on
 * @throws Exception if the cross-validation fails
 */
protected static void useClassifier(Instances data) throws Exception {
  System.out.println("\n1. Meta-classfier");

  // Attribute selection set-up.
  CfsSubsetEval subsetEvaluator = new CfsSubsetEval();
  GreedyStepwise search = new GreedyStepwise();
  search.setSearchBackwards(true);

  // Meta-classifier wrapping the tree learner.
  AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
  classifier.setClassifier(new J48());
  classifier.setEvaluator(subsetEvaluator);
  classifier.setSearch(search);

  Evaluation evaluation = new Evaluation(data);
  evaluation.crossValidateModel(classifier, data, 10, new Random(1));
  System.out.println(evaluation.toSummaryString());
}
public static Double runClassify(String trainFile, String testFile) { double predictOrder = 0.0; double trueOrder = 0.0; try { String trainWekaFileName = trainFile; String testWekaFileName = testFile; Instances train = DataSource.read(trainWekaFileName); Instances test = DataSource.read(testWekaFileName); train.setClassIndex(0); test.setClassIndex(0); train.deleteAttributeAt(8); test.deleteAttributeAt(8); train.deleteAttributeAt(6); test.deleteAttributeAt(6); train.deleteAttributeAt(5); test.deleteAttributeAt(5); train.deleteAttributeAt(4); test.deleteAttributeAt(4); // AdditiveRegression classifier = new AdditiveRegression(); // NaiveBayes classifier = new NaiveBayes(); RandomForest classifier = new RandomForest(); // LibSVM classifier = new LibSVM(); classifier.buildClassifier(train); Evaluation eval = new Evaluation(train); eval.evaluateModel(classifier, test); System.out.println(eval.toSummaryString("\nResults\n\n", true)); // System.out.println(eval.toClassDetailsString()); // System.out.println(eval.toMatrixString()); int k = 892; for (int i = 0; i < test.numInstances(); i++) { predictOrder = classifier.classifyInstance(test.instance(i)); trueOrder = test.instance(i).classValue(); System.out.println((k++) + "," + (int) predictOrder); } } catch (Exception e) { e.printStackTrace(); } return predictOrder; }
public static Instances getKnowledgeBase() { if (knowledgeBase == null) { try { // load knowledgebase from file CreateAppInsertIntoVm.knowledgeBase = Action.loadKnowledge(Configuration.getInstance().getKBCreateAppInsertIntoVm()); // prediction is also performed therefore the classifier and the evaluator must be // instantiated if (!isOnlyLearning()) { System.out.println("Classify data CreateAppInsertInto"); if (knowledgeBase.numInstances() > 0) { classifier = new MultilayerPerceptron(); classifier.buildClassifier(knowledgeBase); evaluation = new Evaluation(knowledgeBase); evaluation.crossValidateModel( classifier, knowledgeBase, 10, knowledgeBase.getRandomNumberGenerator(randomData.nextLong(1, 1000))); System.out.println("Classified data CreateAppInsertInto"); } else { System.out.println("No Instancedata for classifier CreateAppInsertIntoVm"); } } } catch (Exception e) { e.printStackTrace(); } } return knowledgeBase; }
/** * Process a classifier's prediction for an instance and update a set of plotting instances and * additional plotting info. m_PlotShape for nominal class datasets holds shape types (actual data * points have automatic shape type assignment; classifier error data points have box shape type). * For numeric class datasets, the actual data points are stored in m_PlotInstances and m_PlotSize * stores the error (which is later converted to shape size values). * * @param toPredict the actual data point * @param classifier the classifier * @param eval the evaluation object to use for evaluating the classifier on the instance to * predict * @see #m_PlotShapes * @see #m_PlotSizes * @see #m_PlotInstances */ public void process(Instance toPredict, Classifier classifier, Evaluation eval) { double pred; double[] values; int i; try { pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict); if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { toPredict = ((weka.classifiers.misc.InputMappedClassifier) classifier) .constructMappedInstance(toPredict); } if (!m_SaveForVisualization) return; if (m_PlotInstances != null) { values = new double[m_PlotInstances.numAttributes()]; for (i = 0; i < m_PlotInstances.numAttributes(); i++) { if (i < toPredict.classIndex()) { values[i] = toPredict.value(i); } else if (i == toPredict.classIndex()) { values[i] = pred; values[i + 1] = toPredict.value(i); i++; } else { values[i] = toPredict.value(i - 1); } } m_PlotInstances.add(new DenseInstance(1.0, values)); if (toPredict.classAttribute().isNominal()) { if (toPredict.isMissing(toPredict.classIndex()) || Utils.isMissingValue(pred)) { m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE)); } else if (pred != toPredict.classValue()) { // set to default error point shape m_PlotShapes.addElement(new Integer(Plot2D.ERROR_SHAPE)); } else { // otherwise set to constant (automatically assigned) point shape m_PlotShapes.addElement(new 
Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } m_PlotSizes.addElement(new Integer(Plot2D.DEFAULT_SHAPE_SIZE)); } else { // store the error (to be converted to a point size later) Double errd = null; if (!toPredict.isMissing(toPredict.classIndex()) && !Utils.isMissingValue(pred)) { errd = new Double(pred - toPredict.classValue()); m_PlotShapes.addElement(new Integer(Plot2D.CONST_AUTOMATIC_SHAPE)); } else { // missing shape if actual class not present or prediction is missing m_PlotShapes.addElement(new Integer(Plot2D.MISSING_SHAPE)); } m_PlotSizes.addElement(errd); } } } catch (Exception ex) { ex.printStackTrace(); } }
// calculate if an App fits to a pm // TODO: gst: use WEKA to calc fit factor!! private int calculateFit(App app2, VirtualMachine vm) { int output = 0; if (Action.isOnlyLearning() == false && CreateAppInsertIntoVm.evaluation != null) { // is free space available in the VM if (app2.getCpu() + vm.getCurrentCpuUsage() < vm.getCurrentCpuAllocation() && app2.getMemory() + vm.getCurrentMemoryUsage() < vm.getCurrentMemoryAllocation() && app2.getStorage() + vm.getCurrentStorageUsage() < vm.getCurrentCpuAllocation()) { Instance instance = createInstance(Instance.missingValue(), vm); instance.setDataset(CreateAppInsertIntoVm.getKnowledgeBase()); try { output = (int) (evaluation.evaluateModelOnce(classifier, instance) * 100); } catch (Exception e) { e.printStackTrace(); } } } else { if (app2.getCpu() + vm.getCurrentCpuUsage() < vm.getCurrentCpuAllocation() && app2.getMemory() + vm.getCurrentMemoryUsage() < vm.getCurrentMemoryAllocation() && app2.getStorage() + vm.getCurrentStorageUsage() < vm.getCurrentCpuAllocation()) { output = randomData.nextInt(1, 100); } } return output; }
/** * Adds the prediction intervals as additional attributes at the end. Since classifiers can * returns varying number of intervals per instance, the dataset is filled with missing values for * non-existing intervals. */ protected void addPredictionIntervals() { int maxNum; int num; int i; int n; FastVector preds; FastVector atts; Instances data; Instance inst; Instance newInst; double[] values; double[][] predInt; // determine the maximum number of intervals maxNum = 0; preds = m_Evaluation.predictions(); for (i = 0; i < preds.size(); i++) { num = ((NumericPrediction) preds.elementAt(i)).predictionIntervals().length; if (num > maxNum) maxNum = num; } // create new header atts = new FastVector(); for (i = 0; i < m_PlotInstances.numAttributes(); i++) atts.addElement(m_PlotInstances.attribute(i)); for (i = 0; i < maxNum; i++) { atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-lowerBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-upperBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-width")); } data = new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances()); data.setClassIndex(m_PlotInstances.classIndex()); // update data for (i = 0; i < m_PlotInstances.numInstances(); i++) { inst = m_PlotInstances.instance(i); // copy old values values = new double[data.numAttributes()]; System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes()); // add interval data predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals(); for (n = 0; n < maxNum; n++) { if (n < predInt.length) { values[m_PlotInstances.numAttributes() + n * 3 + 0] = predInt[n][0]; values[m_PlotInstances.numAttributes() + n * 3 + 1] = predInt[n][1]; values[m_PlotInstances.numAttributes() + n * 3 + 2] = predInt[n][1] - predInt[n][0]; } else { values[m_PlotInstances.numAttributes() + n * 3 + 0] = Utils.missingValue(); values[m_PlotInstances.numAttributes() + n * 3 + 1] = 
Utils.missingValue(); values[m_PlotInstances.numAttributes() + n * 3 + 2] = Utils.missingValue(); } } // create new Instance newInst = new DenseInstance(inst.weight(), values); data.add(newInst); } m_PlotInstances = data; }
public static void trainModel(Instances dataTrain, Instances dataTest) { try { LibLINEAR classifier = new LibLINEAR(); classifier.setBias(10); classifier.buildClassifier(dataTrain); Evaluation eval = new Evaluation(dataTrain); eval.evaluateModel(classifier, dataTest); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } }
/**
 * Runs the Weka command-line evaluation for a fresh {@link Decorate} instance and prints the
 * result. If the evaluation fails, only the exception message is printed to standard error.
 *
 * @param argv the options
 */
public static void main(String[] argv) {
  try {
    String report = Evaluation.evaluateModel(new Decorate(), argv);
    System.out.println(report);
  } catch (Exception e) {
    System.err.println(e.getMessage());
  }
}
/**
 * Runs {@code test()} and copies its error statistics into an ontology evaluation message.
 *
 * <p>The kappa statistic and the relative absolute error are set to {@code -1} when reading or
 * storing them throws.
 *
 * @return the evaluation message filled with the error statistics
 */
@Override
protected pikater.ontology.messages.Evaluation evaluateCA() {
  Evaluation eval = test();

  pikater.ontology.messages.Evaluation result = new pikater.ontology.messages.Evaluation();

  float errorRate = (float) eval.errorRate();
  result.setError_rate(errorRate);

  try {
    result.setKappa_statistic((float) eval.kappa());
  } catch (Exception e) {
    result.setKappa_statistic(-1);
  }

  float meanAbsoluteError = (float) eval.meanAbsoluteError();
  result.setMean_absolute_error(meanAbsoluteError);

  try {
    result.setRelative_absolute_error((float) eval.relativeAbsoluteError());
  } catch (Exception e) {
    result.setRelative_absolute_error(-1);
  }

  float rootMeanSquaredError = (float) eval.rootMeanSquaredError();
  result.setRoot_mean_squared_error(rootMeanSquaredError);

  float rootRelativeSquaredError = (float) eval.rootRelativeSquaredError();
  result.setRoot_relative_squared_error(rootRelativeSquaredError);

  return result;
}
/**
 * Runs the Weka command-line evaluation for a fresh {@link UnivariateLinearRegression} instance
 * and prints the result. If the evaluation fails, the exception message and the stack trace are
 * printed.
 *
 * @param argv options
 */
public static void main(String[] argv) {
  try {
    String report = Evaluation.evaluateModel(new UnivariateLinearRegression(), argv);
    System.out.println(report);
  } catch (Exception e) {
    System.out.println(e.getMessage());
    e.printStackTrace();
  }
}
/** * Utility method for fast 5-fold cross validation of a naive bayes model * * @param fullModel a <code>NaiveBayesUpdateable</code> value * @param trainingSet an <code>Instances</code> value * @param r a <code>Random</code> value * @return a <code>double</code> value * @exception Exception if an error occurs */ public static double crossValidate( NaiveBayesUpdateable fullModel, Instances trainingSet, Random r) throws Exception { // make some copies for fast evaluation of 5-fold xval Classifier[] copies = AbstractClassifier.makeCopies(fullModel, 5); Evaluation eval = new Evaluation(trainingSet); // make some splits for (int j = 0; j < 5; j++) { Instances test = trainingSet.testCV(5, j); // unlearn these test instances for (int k = 0; k < test.numInstances(); k++) { test.instance(k).setWeight(-test.instance(k).weight()); ((NaiveBayesUpdateable) copies[j]).updateClassifier(test.instance(k)); // reset the weight back to its original value test.instance(k).setWeight(-test.instance(k).weight()); } eval.evaluateModel(copies[j], test); } return eval.incorrect(); }
@Override protected Evaluation test() { working = true; System.out.println("Agent " + getLocalName() + ": Testing..."); // evaluate classifier and print some statistics Evaluation eval = null; try { eval = new Evaluation(train); eval.evaluateModel(cls, test); System.out.println( eval.toSummaryString(getLocalName() + " agent: " + "\nResults\n=======\n", false)); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } working = false; return eval; } // end test
/**
 * Loads the dataset, converts its string attributes to word vectors, and prints the summary and
 * confusion matrix of a 10-fold cross-validation of this object's classifier.
 *
 * <p>After filtering, attribute 0 is used as the class.
 *
 * @param traindata location of the training data
 * @throws Exception if loading, filtering or cross-validating fails
 */
@Override
public void crossValidation(String traindata) throws Exception {
  DataSource ds = new DataSource(traindata);
  Instances instances = ds.getDataSet();

  String filterOptions =
      "-R first-last -W 1000 "
          + "-prune-rate -1.0 -N 0 "
          + "-stemmer weka.core.stemmers.NullStemmer -M 1 "
          + "-tokenizer \"weka.core.tokenizers.WordTokenizer -delimiters \\\" \\r\\n\\t.,;:\\\'\\\"()?!\"";
  StringToWordVector stv = new StringToWordVector();
  stv.setOptions(weka.core.Utils.splitOptions(filterOptions));
  stv.setInputFormat(instances);

  instances = Filter.useFilter(instances, stv);
  instances.setClassIndex(0);

  Evaluation eval = new Evaluation(instances);
  eval.crossValidateModel(this.classifier, instances, 10, new Random(1));

  System.out.println(eval.toSummaryString());
  System.out.println(eval.toMatrixString());
}
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData) throws Exception { System.err.println( "INFO: Starting split validation to predict '" + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test=" + testData.numInstances() + ") ..."); if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set"); c.buildClassifier(trainData); Evaluation eval = new Evaluation(testData); eval.useNoPriors(); double[] predictions = eval.evaluateModel(c, testData); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); // write predictions to file { System.err.println("INFO: Writing predictions to file ..."); Writer out = new FileWriter("prediction.trec"); writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out); out.close(); } // write predicted distributions to CSV { System.err.println("INFO: Writing predicted distributions to CSV ..."); Writer out = new FileWriter("predicted_distribution.csv"); writePredictedDistributions(c, testData, 0, out); out.close(); } }
protected static Evaluation adjustForInputMappedClassifier( Evaluation eval, weka.classifiers.Classifier classifier, Instances inst, ClassifierErrorsPlotInstances plotInstances) throws Exception { if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { Instances mappedClassifierHeader = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(inst, 0)); eval = new Evaluation(new Instances(mappedClassifierHeader, 0)); if (!eval.getHeader().equalHeaders(inst)) { // When the InputMappedClassifier is loading a model, // we need to make a new dataset that maps the test instances to // the structure expected by the mapped classifier - this is only // to ensure that the ClassifierPlotInstances object is configured // in accordance with what the embeded classifier was trained with Instances mappedClassifierDataset = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(mappedClassifierHeader, 0)); for (int zz = 0; zz < inst.numInstances(); zz++) { Instance mapped = ((weka.classifiers.misc.InputMappedClassifier) classifier) .constructMappedInstance(inst.instance(zz)); mappedClassifierDataset.add(mapped); } eval.setPriors(mappedClassifierDataset); plotInstances.setInstances(mappedClassifierDataset); plotInstances.setClassifier(classifier); plotInstances.setClassIndex(mappedClassifierDataset.classIndex()); plotInstances.setEvaluation(eval); } } return eval; }
/** * 用分类器测试 * * @param trainFileName * @param testFileName */ public static void classify(String trainFileName, String testFileName) { try { File inputFile = new File(fileName + trainFileName); // 训练语料文件 ArffLoader atf = new ArffLoader(); atf.setFile(inputFile); Instances instancesTrain = atf.getDataSet(); // 读入训练文件 // 设置类标签类 inputFile = new File(fileName + testFileName); // 测试语料文件 atf.setFile(inputFile); Instances instancesTest = atf.getDataSet(); // 读入测试文件 instancesTest.setClassIndex(instancesTest.numAttributes() - 1); instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1); classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance(); classifier.buildClassifier(instancesTrain); Evaluation eval = new Evaluation(instancesTrain); // 第一个为一个训练过的分类器,第二个参数是在某个数据集上评价的数据集 eval.evaluateModel(classifier, instancesTest); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString()); System.out.println(eval.toMatrixString()); System.out.println("precision is :" + (1 - eval.errorRate())); } catch (Exception e) { e.printStackTrace(); } }
/**
 * Returns a report with the classifier and filter set-up, the training file, the classifier's
 * own description and the evaluation summary, confusion matrix and class details.
 *
 * <p>If the matrix or the class details cannot be produced, the exception is printed and that
 * section is left out.
 *
 * @return the report text
 */
public String toString() {
  StringBuilder result = new StringBuilder();

  result.append("Weka - Demo\n===========\n\n");

  result
      .append("Classifier...: ")
      .append(m_Classifier.getClass().getName())
      .append(" ")
      .append(Utils.joinOptions(m_Classifier.getOptions()))
      .append("\n");

  // The options are only listed for filters that expose them.
  result.append("Filter.......: ").append(m_Filter.getClass().getName());
  if (m_Filter instanceof OptionHandler) {
    result.append(" ").append(Utils.joinOptions(((OptionHandler) m_Filter).getOptions()));
  }
  result.append("\n");

  result.append("Training file: ").append(m_TrainingFile).append("\n");
  result.append("\n");

  result.append(m_Classifier.toString()).append("\n");
  result.append(m_Evaluation.toSummaryString()).append("\n");

  try {
    result.append(m_Evaluation.toMatrixString() + "\n");
  } catch (Exception e) {
    e.printStackTrace();
  }

  try {
    result.append(m_Evaluation.toClassDetailsString() + "\n");
  } catch (Exception e) {
    e.printStackTrace();
  }

  return result.toString();
}
/** runs 10fold CV over the training file */ public void execute() throws Exception { // run filter m_Filter.setInputFormat(m_Training); Instances filtered = Filter.useFilter(m_Training, m_Filter); // train classifier on complete file for tree m_Classifier.buildClassifier(filtered); // 10fold CV with seed=1 m_Evaluation = new Evaluation(filtered); m_Evaluation.crossValidateModel( m_Classifier, filtered, 10, m_Training.getRandomNumberGenerator(1)); }
private static void run() throws Exception { DataSource source = new DataSource("src/files/powerpuffgirls.arff"); int folds = 10; int runs = 30; HashMap<String, Classifier> hash = new HashMap<>(); hash.put("J48", new J48()); hash.put("NaiveBayes", new NaiveBayes()); hash.put("IBk=1", new IBk(1)); hash.put("IBk=3", new IBk(3)); hash.put("MultilayerPerceptron", new MultilayerPerceptron()); // LibSVM svm = new LibSVM(); // svm.setOptions(new String[]{"-S 0 -K 2 -D 3 -G 0.0 -R 0.0 -N 0.5 -M 0.40 -C 1.0 -E // 0.001 -P 0.1"}); // hash.put("LibSVM", svm); Instances data = source.getDataSet(); data.setClassIndex(4); System.out.println("#seed \t correctly instances \t percentage of corrects\n"); for (Entry<String, Classifier> entry : hash.entrySet()) { System.out.println("\n Algorithm: " + entry.getKey() + "\n"); for (int i = 1; i <= runs; i++) { Evaluation eval = new Evaluation(data); eval.crossValidateModel(entry.getValue(), data, folds, new Random(i)); System.out.println(summary(eval)); } } }
/**
 * Applies the classifier to the validation set through the given evaluation.
 *
 * @param classifier the trained classifier
 * @param validationSet the instances to predict
 * @param evaluation the evaluation that runs and records the predictions
 * @return the predictions returned by the evaluation
 * @throws ClassifierPredictionException if the evaluation throws; an {@link
 *     ArrayIndexOutOfBoundsException} gets a dedicated message about the number of features
 */
private double[] makePredictions(
    Classifier classifier, Instances validationSet, Evaluation evaluation) {
  try {
    return evaluation.evaluateModel(classifier, validationSet);
  } catch (ArrayIndexOutOfBoundsException e) {
    throw new ClassifierPredictionException(
        "Error applying the trained classifier to the train instances. The number of features of the instance exceeds the number of features the classifier was trained on.",
        e);
  } catch (Exception e) {
    throw new ClassifierPredictionException(
        "Error applying the trained classifier to the test instances.", e);
  }
}
@Override protected DataInstances getPredictions(Instances test, DataInstances onto_test) { Evaluation eval = test(); double pre[] = new double[test.numInstances()]; for (int i = 0; i < test.numInstances(); i++) { try { pre[i] = eval.evaluateModelOnce((Classifier) getModelObject(), test.instance(i)); } catch (Exception e) { pre[i] = Integer.MAX_VALUE; } } // copy results to the DataInstancs int i = 0; Iterator itr = onto_test.getInstances().iterator(); while (itr.hasNext()) { Instance next_instance = (Instance) itr.next(); next_instance.setPrediction(pre[i]); i++; } return onto_test; }
static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception { System.err.println( "INFO: Starting crossvalidation to predict '" + data.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' ..."); StringBuffer sb = new StringBuffer(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE); // write predictions to file { Writer out = new FileWriter("cv.log"); out.write(sb.toString()); out.close(); } System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); }