/**
 * Runs n-fold cross-validation over {@code instanceList} and prints the mean of the per-fold
 * scores returned by {@code getTypeSpecificAccuracy}.
 *
 * @param instanceList the instances to split into folds
 * @param n the number of folds; also the divisor of the reported mean
 */
void showNFoldTypeSpecificAccuracy(InstanceList instanceList, int n) {
    InstanceList.CrossValidationIterator cvIt = instanceList.crossValidationIterator(n);
    double accuracySum = 0;
    while (cvIt.hasNext()) {
        InstanceList[] nextSplit = cvIt.nextSplit();
        InstanceList trainingInstances = nextSplit[0];
        InstanceList testingInstances = nextSplit[1];
        double[] result = getTypeSpecificAccuracy(trainingInstances, testingInstances, true);
        // NOTE(review): element 2 is assumed to be the overall accuracy for this fold --
        // confirm against getTypeSpecificAccuracy's return layout.
        accuracySum += result[2];
    }
    System.out.println(n + "-Fold cross-validation:");
    System.out.println("Accuracy: " + accuracySum / n);
}
/**
 * Performs n-fold cross-validation and prints two accuracy figures. For each fold it trains a
 * fresh {@code MyClassifierTrainer} wrapping a {@code RankMaxEntTrainer}.
 *
 * <p>The first figure is the plain mean of the per-fold accuracies. The second is a pooled
 * accuracy: the sum over folds of (fold accuracy x fold test-set size), divided by the
 * caller-supplied {@code count}. Side effect: the {@code trainer} member is reassigned on every
 * fold, so it ends up holding the trainer built for the last fold.
 *
 * @param instanceList instances to split into folds
 * @param n number of folds; also the divisor for the mean accuracy
 * @param count denominator for the pooled accuracy (NOTE(review): presumably the total number
 *     of instances -- confirm callers pass instanceList.size())
 */
private void showNFoldAccuracy(InstanceList instanceList, int n, int count) {
    double accuracySum = 0;
    double weightedCorrect = 0;
    for (InstanceList.CrossValidationIterator folds = instanceList.crossValidationIterator(n);
            folds.hasNext(); ) {
        InstanceList[] split = folds.nextSplit();
        InstanceList trainSet = split[0];
        InstanceList testSet = split[1];
        trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
        Classifier model = trainer.train(trainSet);
        double foldAccuracy = getAccuracy(model, testSet);
        accuracySum += foldAccuracy;
        // Accuracy times test size approximates the number of correct predictions in the fold.
        weightedCorrect += foldAccuracy * testSet.size();
    }
    System.out.println(n + "-Fold accuracy(avg): " + accuracySum / n);
    System.out.println("Total tp:" + weightedCorrect);
    System.out.println("Total count:" + count);
    System.out.println(n + "-Fold accuracy: " + weightedCorrect / count);
}
public static ArrayList<Double> getAverageCrossValidationScore( InstanceList ilist, int i, ClassifierTrainer trainer) { double crossValidAccSum = 0; double crossValidPrcSum = 0; double crossValidRecSum = 0; double crossValidF1Sum = 0; int count = 0; // get gross validation folds CrossValidationIterator cvIlists = ilist.crossValidationIterator(i); while (cvIlists.hasNext()) { System.out.println("#############Performing " + count + " iteration###########"); InstanceList[] ilists = cvIlists.next(); System.out.println("The train set size is " + ilists[0].size()); System.out.println("The test set size is " + ilists[1].size()); Classifier classifier = trainer.train(ilists[0]); System.out.println("The training accuracy is " + classifier.getAccuracy(ilists[0])); System.out.println("The testing accuracy is " + classifier.getAccuracy(ilists[1])); System.out.println("The testing precision is " + classifier.getPrecision(ilists[1], 1)); System.out.println("The testing recall is " + classifier.getRecall(ilists[1], 1)); System.out.println("The testing f1score is " + classifier.getF1(ilists[1], 1)); crossValidAccSum += classifier.getAccuracy(ilists[1]); crossValidPrcSum += classifier.getPrecision(ilists[1], 1); crossValidRecSum += classifier.getRecall(ilists[1], 1); crossValidF1Sum += classifier.getF1(ilists[1], 1); count++; // additional calculations ArrayList<Classification> outClassifications = classifier.classify(ilists[1]); int p1l1 = 0; int p1l0 = 0; int p0l1 = 0; int p0l0 = 0; int countCorrect = 0; int countIncorrect = 0; System.out.println("Outclassification size " + outClassifications.size()); for (int k = 0; k < outClassifications.size(); k++) { // System.out.println("Data "+outClassifications.get(k).getInstance().getName()); // System.out.println("Labeling "+outClassifications.get(k).getLabeling()); uncomment to get // score double predictedLabel = outClassifications.get(k).getLabeling().getBestIndex(); // System.out.println("Predicted label "+ predictedLabel); double 
targetLabel = Double.valueOf(outClassifications.get(k).getInstance().getTarget().toString()); // System.out.println("Target "+ targetLabel); boolean bestlabelIsCorrect = outClassifications.get(k).bestLabelIsCorrect(); // System.out.println("Prediction "+bestlabelIsCorrect); if (bestlabelIsCorrect) countCorrect++; else countIncorrect++; if ((predictedLabel == 1.0) && (targetLabel == 1.0)) p1l1++; else if ((predictedLabel == 1.0) && (targetLabel == 0.0)) p1l0++; else if ((predictedLabel == 0.0) && (targetLabel == 1.0)) p0l1++; else if ((predictedLabel == 0.0) && (targetLabel == 0.0)) p0l0++; } System.out.println("Count Correct " + countCorrect); System.out.println("Count Incorrect " + countIncorrect); System.out.println("p1l1 " + p1l1); System.out.println("p1l0 " + p1l0); System.out.println("p0l1 " + p0l1); System.out.println("p0l0 " + p0l0); } ArrayList<Double> results = new ArrayList<Double>(); double crossValidAccAvg = crossValidAccSum / count; double crossValidPrcAvg = crossValidPrcSum / count; double crossValidRecAvg = crossValidRecSum / count; double crossValidF1Avg = crossValidF1Sum / count; results.add(crossValidAccAvg); results.add(crossValidPrcAvg); results.add(crossValidRecAvg); results.add(crossValidF1Avg); return results; }