/**
 * buildClassifier selects a classifier from the set of classifiers by minimising error on the
 * training data.
 *
 * @param data the training data to be used for generating the classifier
 * @exception Exception if the classifier could not be built successfully
 */
public void buildClassifier(Instances data) throws Exception {
  if (m_Classifiers.length == 0) {
    throw new Exception("No base classifiers have been set!");
  }

  Instances newData = new Instances(data);
  newData.deleteWithMissingClass();
  newData.randomize(new Random(m_Seed));
  if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
    newData.stratify(m_NumXValFolds);

  Instances train = newData; // train on all data by default
  Instances test = newData;  // test on training data by default
  Classifier bestClassifier = null;
  int bestIndex = -1;
  double bestPerformance = Double.NaN;
  int numClassifiers = m_Classifiers.length;

  for (int i = 0; i < numClassifiers; i++) {
    Classifier currentClassifier = getClassifier(i);
    Evaluation evaluation;
    if (m_NumXValFolds > 1) {
      evaluation = new Evaluation(newData);
      for (int j = 0; j < m_NumXValFolds; j++) {
        train = newData.trainCV(m_NumXValFolds, j);
        test = newData.testCV(m_NumXValFolds, j);
        currentClassifier.buildClassifier(train);
        evaluation.setPriors(train);
        evaluation.evaluateModel(currentClassifier, test);
      }
    } else {
      currentClassifier.buildClassifier(train);
      evaluation = new Evaluation(train);
      evaluation.evaluateModel(currentClassifier, test);
    }
    double error = evaluation.errorRate();
    if (m_Debug) {
      System.err.println("Error rate: " + Utils.doubleToString(error, 6, 4)
          + " for classifier " + currentClassifier.getClass().getName());
    }
    if ((i == 0) || (error < bestPerformance)) {
      bestClassifier = currentClassifier;
      bestPerformance = error;
      bestIndex = i;
    }
  }

  m_ClassifierIndex = bestIndex;
  m_Classifier = bestClassifier;
  if (m_NumXValFolds > 1) {
    m_Classifier.buildClassifier(newData);
  }
}
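The selection-by-error logic above mirrors what Weka's MultiScheme meta-classifier does out of the box. The following is a minimal, hedged usage sketch (not taken from the snippet's project); the ARFF file name is hypothetical and the candidate classifiers are only illustrative choices.

import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.MultiScheme;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MultiSchemeSketch {
  public static void main(String[] args) throws Exception {
    // hypothetical training file
    Instances data = DataSource.read("train.arff");
    data.setClassIndex(data.numAttributes() - 1);

    // candidate base classifiers; MultiScheme keeps the one with the lowest
    // cross-validated (or training) error, much like the method above
    MultiScheme ms = new MultiScheme();
    ms.setClassifiers(new Classifier[] { new J48(), new NaiveBayes() });
    ms.buildClassifier(data);

    // independent 10-fold cross-validation of the selected scheme
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(ms, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}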
/**
 * Test with the classifier.
 *
 * @param trainFileName
 * @param testFileName
 */
public static void classify(String trainFileName, String testFileName) {
  try {
    File inputFile = new File(fileName + trainFileName); // training corpus file
    ArffLoader atf = new ArffLoader();
    atf.setFile(inputFile);
    Instances instancesTrain = atf.getDataSet(); // read in the training file

    inputFile = new File(fileName + testFileName); // test corpus file
    atf.setFile(inputFile);
    Instances instancesTest = atf.getDataSet(); // read in the test file

    // set the class attribute (last attribute) for both sets
    instancesTest.setClassIndex(instancesTest.numAttributes() - 1);
    instancesTrain.setClassIndex(instancesTrain.numAttributes() - 1);

    classifier = (Classifier) Class.forName(CLASSIFIERNAME).newInstance();
    classifier.buildClassifier(instancesTrain);

    // first argument: a trained classifier; second argument: the dataset to evaluate it on
    Evaluation eval = new Evaluation(instancesTrain);
    eval.evaluateModel(classifier, instancesTest);
    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString());
    System.out.println(eval.toMatrixString());
    // 1 - errorRate() is the overall accuracy, not precision
    System.out.println("accuracy is :" + (1 - eval.errorRate()));
  } catch (Exception e) {
    e.printStackTrace();
  }
}
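As a hedged aside on the metric printed above: 1 - errorRate() is overall accuracy; per-class precision, recall, and F1 can be read from the same Evaluation object. A minimal self-contained sketch (file names are hypothetical):

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class PerClassMetricsSketch {
  public static void main(String[] args) throws Exception {
    // hypothetical train/test ARFF files
    Instances train = DataSource.read("train.arff");
    Instances test = DataSource.read("test.arff");
    train.setClassIndex(train.numAttributes() - 1);
    test.setClassIndex(test.numAttributes() - 1);

    Classifier cls = new NaiveBayes();
    cls.buildClassifier(train);

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(cls, test);
    System.out.println("accuracy = " + (1 - eval.errorRate()));
    // per-class metrics, indexed by class value
    for (int c = 0; c < test.numClasses(); c++) {
      System.out.println("class " + c
          + " precision=" + eval.precision(c)
          + " recall=" + eval.recall(c)
          + " F1=" + eval.fMeasure(c));
    }
  }
}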
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  if (!(m_Classifier instanceof WeightedInstancesHandler)) {
    throw new IllegalArgumentException("Classifier must be a " + "WeightedInstancesHandler!");
  }

  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  instances = new Instances(instances);
  instances.deleteWithMissingClass();

  // only class? -> build ZeroR model
  if (instances.numAttributes() == 1) {
    System.err.println(
        "Cannot build model (only class attribute present in data!), "
            + "using ZeroR model instead!");
    m_ZeroR = new weka.classifiers.rules.ZeroR();
    m_ZeroR.buildClassifier(instances);
    return;
  } else {
    m_ZeroR = null;
  }

  m_Train = new Instances(instances, 0, instances.numInstances());
  m_NNSearch.setInstances(m_Train);
}
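This buildClassifier only stores the data; in a lazy learner of this kind the work happens at prediction time through the nearest-neighbour search object. Below is a hedged, self-contained sketch of that query step, assuming Weka's LinearNNSearch; the concrete search class, k, and file name are illustrative, not taken from the original code.

import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.neighboursearch.LinearNNSearch;

public class NNSearchSketch {
  public static void main(String[] args) throws Exception {
    Instances train = DataSource.read("train.arff"); // hypothetical file
    train.setClassIndex(train.numAttributes() - 1);

    LinearNNSearch search = new LinearNNSearch();
    search.setInstances(train);

    // query the k nearest neighbours of the first training instance
    Instance target = train.instance(0);
    Instances neighbours = search.kNearestNeighbours(target, 5);
    double[] distances = search.getDistances();
    for (int i = 0; i < neighbours.numInstances(); i++) {
      System.out.println("neighbour " + i + " class="
          + neighbours.instance(i).classValue() + " dist=" + distances[i]);
    }
  }
}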
public static Instances getKnowledgeBase() { if (knowledgeBase == null) { try { // load knowledgebase from file CreateAppInsertIntoVm.knowledgeBase = Action.loadKnowledge(Configuration.getInstance().getKBCreateAppInsertIntoVm()); // prediction is also performed therefore the classifier and the evaluator must be // instantiated if (!isOnlyLearning()) { System.out.println("Classify data CreateAppInsertInto"); if (knowledgeBase.numInstances() > 0) { classifier = new MultilayerPerceptron(); classifier.buildClassifier(knowledgeBase); evaluation = new Evaluation(knowledgeBase); evaluation.crossValidateModel( classifier, knowledgeBase, 10, knowledgeBase.getRandomNumberGenerator(randomData.nextLong(1, 1000))); System.out.println("Classified data CreateAppInsertInto"); } else { System.out.println("No Instancedata for classifier CreateAppInsertIntoVm"); } } } catch (Exception e) { e.printStackTrace(); } } return knowledgeBase; }
/** * Boosting method. Boosts any classifier that can handle weighted instances. * * @param data the training data to be used for generating the boosted classifier. * @throws Exception if the classifier could not be built successfully */ protected void buildClassifierWithWeights(Instances data) throws Exception { Instances trainData, training, trainingWeightsNotNormalized; int numInstances = data.numInstances(); Random randomInstance = new Random(m_Seed); double minLoss = Double.MAX_VALUE; // Create a copy of the data so that when the weights are diddled // with it doesn't mess up the weights for anyone else trainingWeightsNotNormalized = new Instances(data, 0, numInstances); // Do boostrap iterations for (m_NumIterationsPerformed = -1; m_NumIterationsPerformed < m_Classifiers.length; m_NumIterationsPerformed++) { if (m_Debug) { System.err.println("Training classifier " + (m_NumIterationsPerformed + 1)); } training = new Instances(trainingWeightsNotNormalized); normalizeWeights(training, m_SumOfWeights); // Select instances to train the classifier on if (m_WeightThreshold < 100) { trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100); } else { trainData = new Instances(training, 0, numInstances); } // Build classifier if (m_NumIterationsPerformed == -1) { m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(data); } else { if (m_Classifiers[m_NumIterationsPerformed] instanceof Randomizable) ((Randomizable) m_Classifiers[m_NumIterationsPerformed]) .setSeed(randomInstance.nextInt()); m_Classifiers[m_NumIterationsPerformed].buildClassifier(trainData); } // Update instance weights setWeights(trainingWeightsNotNormalized, m_NumIterationsPerformed); // Has progress been made? double loss = 0; for (Instance inst : trainingWeightsNotNormalized) { loss += Math.log(inst.weight()); } if (m_Debug) { System.err.println("Current loss on log scale: " + loss); } if ((m_NumIterationsPerformed > -1) && (loss > minLoss)) { if (m_Debug) { System.err.println("Loss has increased: bailing out."); } break; } minLoss = loss; } }
/** trains the classifier */
@Override
public void train() throws Exception {
  if (_train.classIndex() == -1)
    _train.setClassIndex(_train.numAttributes() - 1);

  _cl.buildClassifier(_train);

  // evaluate classifier and print some statistics
  evaluate();
}
private void jButton3ActionPerformed( java.awt.event.ActionEvent evt) { // GEN-FIRST:event_jButton3ActionPerformed // TODO add your handling code here: switch (jComboBox1.getSelectedIndex()) { case 0: model = new NaiveBayes(); jTextArea1.append("Building NaiveBayes model from training data ...\n"); break; case 1: model = new Id3(); jTextArea1.append("Building ID3 model from training data ...\n"); break; case 2: model = new J48(); jTextArea1.append("Building J48 model from training data ...\n"); break; } try { model.buildClassifier(training); jTextArea1.append("Model building is complete ...\n"); jButton4.setEnabled(true); jButton6.setEnabled(true); } catch (Exception ex) { jTextArea1.append("Model building failed ...\n"); jTextArea1.append(ex.getMessage()); jTextArea1.append("\n"); jButton4.setEnabled(true); jButton6.setEnabled(false); model = null; } } // GEN-LAST:event_jButton3ActionPerformed
/**
 * @param args
 * @throws Exception
 */
public static void main(String[] args) throws Exception {
  Instances isTrainingSet = createSet(4);
  Instance instance1 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
  Instance instance2 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
  Instance instance22 = createInstance(new double[] {0, 0, 0, 0}, "S3", isTrainingSet);
  isTrainingSet.add(instance1);
  isTrainingSet.add(instance2);
  isTrainingSet.add(instance22);

  Instances isTestingSet = createSet(4);
  Instance instance3 = createInstance(new double[] {1, 0.7, 0.1, 0.7}, "S1", isTrainingSet);
  Instance instance4 = createInstance(new double[] {0.1, 0.2, 1, 0.3}, "S2", isTrainingSet);
  isTestingSet.add(instance3);
  isTestingSet.add(instance4);

  // Create a Bayesian network classifier (another scheme such as M5P could be used instead)
  Classifier cModel = (Classifier) new BayesNet();
  cModel.buildClassifier(isTrainingSet);

  // Test the model
  Evaluation eTest = new Evaluation(isTrainingSet);
  eTest.evaluateModel(cModel, isTestingSet);

  // Print the result à la Weka explorer:
  String strSummary = eTest.toSummaryString();
  System.out.println(strSummary);

  // Get the likelihood of each class:
  // fDistribution[i] is the predicted probability of class i (here S1, S2, S3)
  double[] fDistribution = cModel.distributionForInstance(instance4);
  for (int i = 0; i < fDistribution.length; i++) {
    System.out.println(fDistribution[i]);
  }
}
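createSet and createInstance above are project-specific helpers. As a hedged reference, the sketch below builds a comparable small dataset with Weka's public in-memory API (ArrayList of Attribute plus DenseInstance); the attribute and class names are invented for illustration.

import java.util.ArrayList;
import java.util.Arrays;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

public class InMemoryDatasetSketch {
  public static void main(String[] args) {
    // four numeric attributes plus a nominal class (names are illustrative)
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int i = 1; i <= 4; i++) {
      atts.add(new Attribute("a" + i));
    }
    atts.add(new Attribute("class", Arrays.asList("S1", "S2", "S3")));

    Instances data = new Instances("demo", atts, 0);
    data.setClassIndex(data.numAttributes() - 1);

    // one labelled row: feature values followed by the class value
    DenseInstance inst = new DenseInstance(data.numAttributes());
    inst.setDataset(data);
    inst.setValue(0, 1.0);
    inst.setValue(1, 0.7);
    inst.setValue(2, 0.1);
    inst.setValue(3, 0.7);
    inst.setValue(4, data.classAttribute().indexOfValue("S1"));
    data.add(inst);

    System.out.println(data); // prints the dataset in ARFF form
  }
}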
/**
 * Analyses the decision point data provided by the given cluster decision analyzer and supplies
 * the analyzer with a visualization of the analysis result.
 *
 * @param cda the cluster decision analyzer providing the learning instances to be analysed
 */
public void analyse(ClusterDecisionAnalyzer cda) {
  clusterDecisionAnalyzer = cda;

  // create empty data set with attribute information
  Instances data = cda.getDataInfo();

  // in case no single learning instance can be provided (as the decision point is never
  // reached, or decision classes cannot be specified properly) --> do not call the algorithm
  if (data.numInstances() == 0) {
    System.out.println("No learning instances available");
  }
  // actually solve the classification problem
  else {
    try {
      myClassifier.buildClassifier(data);
      // build up result visualization
      cda.setResultVisualization(createResultVisualization());
      cda.setEvaluationVisualization(createEvaluationVisualization(data));
    } catch (Exception ex) {
      ex.printStackTrace();
      cda.setResultVisualization(
          createMessagePanel("Error while solving the classification problem"));
    }
  }
}
public void run() throws Exception { BufferedReader datafileclassificationpickup = readDataFile(Config.outputPath() + "DaysPickUpClassification.txt"); BufferedReader datafileclassificationdropoff = readDataFile(Config.outputPath() + "DaysDropOffClassification.txt"); BufferedReader datafileregresssionpickup = readDataFile(Config.outputPath() + "DaysPickUpRegression.txt"); BufferedReader datafileregresssiondropoff = readDataFile(Config.outputPath() + "DaysDropOffRegression.txt"); dataclassificationpickup = new Instances(datafileclassificationpickup); dataclassificationpickup.setClassIndex(dataclassificationpickup.numAttributes() - 1); dataclassificationdropoff = new Instances(datafileclassificationdropoff); dataclassificationdropoff.setClassIndex(dataclassificationdropoff.numAttributes() - 1); dataregressionpickup = new Instances(datafileregresssionpickup); dataregressionpickup.setClassIndex(dataregressionpickup.numAttributes() - 1); dataregressiondropoff = new Instances(datafileregresssiondropoff); dataregressiondropoff.setClassIndex(dataregressiondropoff.numAttributes() - 1); System.out.println("KNN classification model"); ibkclassificationpickup = new IBk(10); ibkclassificationpickup.buildClassifier(dataclassificationpickup); ibkclassificationdropoff = new IBk(10); ibkclassificationdropoff.buildClassifier(dataclassificationdropoff); System.out.println("Classification Model Ready"); System.out.println("KNN regression model"); ibkregressionpickup = new IBk(10); ibkregressionpickup.buildClassifier(dataregressionpickup); ibkregressiondropoff = new IBk(10); ibkregressiondropoff.buildClassifier(dataregressiondropoff); System.out.println("Regression Model Ready"); instclassificationpickup = new DenseInstance(9); instclassificationpickup.setDataset(dataclassificationpickup); instclassificationdropoff = new DenseInstance(9); instclassificationdropoff.setDataset(dataclassificationdropoff); instregressionpickup = new DenseInstance(9); instregressionpickup.setDataset(dataregressionpickup); instregressiondropoff = new DenseInstance(9); instregressiondropoff.setDataset(dataregressiondropoff); System.out.println("Models ready"); }
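The run() method above prepares empty DenseInstance templates for the four IBk models but stops before any prediction is made. A hedged, self-contained sketch of the follow-up step is shown below; the ARFF file name and the placeholder feature values are assumptions, and k=10 simply mirrors the snippet.

import weka.classifiers.lazy.IBk;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class KnnPredictSketch {
  public static void main(String[] args) throws Exception {
    // hypothetical file standing in for DaysPickUpClassification data
    Instances data = DataSource.read("DaysPickUpClassification.arff");
    data.setClassIndex(data.numAttributes() - 1);

    IBk knn = new IBk(10); // 10 nearest neighbours, as in the snippet
    knn.buildClassifier(data);

    // build a query instance with the same attribute layout as the data
    Instance query = new DenseInstance(data.numAttributes());
    query.setDataset(data);
    for (int i = 0; i < data.numAttributes() - 1; i++) {
      query.setValue(i, 0.0); // placeholder feature values
    }
    double predicted = knn.classifyInstance(query);
    System.out.println("predicted class index: " + predicted);
  }
}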
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {
  if (!Groovy.isPresent())
    throw new Exception("Groovy classes not in CLASSPATH!");

  // try loading the module
  initGroovyObject();

  // build the model
  if (m_GroovyObject != null)
    m_GroovyObject.buildClassifier(instances);
  else
    System.err.println("buildClassifier: No Groovy object present!");
}
public Classifier buildWekaClassifier(Instances wekaInstances) { bayesNet = new BayesNet(); try { bayesNet.buildClassifier(wekaInstances); } catch (Exception e) { System.err.println(e.getMessage()); e.printStackTrace(); System.exit(-1); } return bayesNet; }
/**
 * Build - Create the transformation for this node, and train a classifier of type H upon it.
 * The dataset should have the class at index 'j', with all indices less than L that are *not*
 * in paY removed.
 */
public void build(Instances D, Classifier H) throws Exception {
  // transform data
  T = transform(D);
  // build SLC 'h'
  h = AbstractClassifier.makeCopy(H);
  h.buildClassifier(T);
  // save templates
  // t_ = new SparseInstance(T.numAttributes());
  // t_.setDataset(T);
  // t_.setClassMissing(); // [?,x,x,x]
  T.clear();
}
/** runs 10-fold CV over the training file */
public void execute() throws Exception {
  // run filter
  m_Filter.setInputFormat(m_Training);
  Instances filtered = Filter.useFilter(m_Training, m_Filter);

  // train classifier on complete file for tree
  m_Classifier.buildClassifier(filtered);

  // 10-fold CV with seed=1
  m_Evaluation = new Evaluation(filtered);
  m_Evaluation.crossValidateModel(
      m_Classifier, filtered, 10, m_Training.getRandomNumberGenerator(1));
}
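One caveat with execute() above: the filter is fitted on the full training set before cross-validation, so, depending on the filter, its statistics can leak across folds. A hedged alternative sketch using Weka's FilteredClassifier, which re-applies the filter inside each training fold; the concrete filter, classifier, and file name here are illustrative.

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.unsupervised.attribute.Remove;

public class FilteredClassifierSketch {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("train.arff"); // hypothetical file
    data.setClassIndex(data.numAttributes() - 1);

    // the filter is (re)built on each training fold, not on the full data
    Remove remove = new Remove();
    remove.setAttributeIndices("1"); // illustrative: drop the first attribute

    FilteredClassifier fc = new FilteredClassifier();
    fc.setFilter(remove);
    fc.setClassifier(new J48());

    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(fc, data, 10, new Random(1));
    System.out.println(eval.toSummaryString());
  }
}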
@Override
public final void run() {
  try {
    Classifier copiedClassifier = AbstractClassifier.makeCopy(classifier);
    copiedClassifier.buildClassifier(train);
    // log.print("The " + threadId + "th classifier is built!!!");
    // accuracy = getAccuracy(copiedClassifier, test);
    // classifier = AbstractClassifier.makeCopy(classifier);
    // classifier.buildClassifier(train);
    log.print("The " + threadId + "th classifier is built!!!");
    accuracy = getAccuracy(copiedClassifier, test);
  } catch (Exception e) {
    log.print(e.getStackTrace().toString());
    log.print(e.toString());
  }
  multiThreadEval.finishOneThreads();
  log.print("The " + threadId + "th thread is finished! accuracy = " + accuracy);
}
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData) throws Exception { System.err.println( "INFO: Starting split validation to predict '" + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test=" + testData.numInstances() + ") ..."); if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set"); c.buildClassifier(trainData); Evaluation eval = new Evaluation(testData); eval.useNoPriors(); double[] predictions = eval.evaluateModel(c, testData); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); // write predictions to file { System.err.println("INFO: Writing predictions to file ..."); Writer out = new FileWriter("prediction.trec"); writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out); out.close(); } // write predicted distributions to CSV { System.err.println("INFO: Writing predicted distributions to CSV ..."); Writer out = new FileWriter("predicted_distribution.csv"); writePredictedDistributions(c, testData, 0, out); out.close(); } }
/** * Cleanses the data based on misclassifications when performing cross-validation. * * @param data the data to train with and cleanse * @return the cleansed data * @throws Exception if something goes wrong */ private Instances cleanseCross(Instances data) throws Exception { Instance inst; Instances crossSet = new Instances(data); Instances temp = new Instances(data, data.numInstances()); Instances inverseSet = new Instances(data, data.numInstances()); int count = 0; double ans; int iterations = 0; int classIndex = m_classIndex; if (classIndex < 0) { classIndex = data.classIndex(); } if (classIndex < 0) { classIndex = data.numAttributes() - 1; } // loop until perfect while (count != crossSet.numInstances() && crossSet.numInstances() >= m_numOfCrossValidationFolds) { count = crossSet.numInstances(); // check if hit maximum number of iterations iterations++; if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) { break; } crossSet.setClassIndex(classIndex); if (crossSet.classAttribute().isNominal()) { crossSet.stratify(m_numOfCrossValidationFolds); } // do the folds temp = new Instances(crossSet, crossSet.numInstances()); for (int fold = 0; fold < m_numOfCrossValidationFolds; fold++) { Instances train = crossSet.trainCV(m_numOfCrossValidationFolds, fold); m_cleansingClassifier.buildClassifier(train); Instances test = crossSet.testCV(m_numOfCrossValidationFolds, fold); // now test for (int i = 0; i < test.numInstances(); i++) { inst = test.instance(i); ans = m_cleansingClassifier.classifyInstance(inst); if (crossSet.classAttribute().isNumeric()) { if (ans >= inst.classValue() - m_numericClassifyThreshold && ans <= inst.classValue() + m_numericClassifyThreshold) { temp.add(inst); } else if (m_invertMatching) { inverseSet.add(inst); } } else { // class is nominal if (ans == inst.classValue()) { temp.add(inst); } else if (m_invertMatching) { inverseSet.add(inst); } } } } crossSet = temp; } if (m_invertMatching) { inverseSet.setClassIndex(data.classIndex()); return inverseSet; } else { crossSet.setClassIndex(data.classIndex()); return crossSet; } }
/** * Cleanses the data based on misclassifications when used training data. * * @param data the data to train with and cleanse * @return the cleansed data * @throws Exception if something goes wrong */ private Instances cleanseTrain(Instances data) throws Exception { Instance inst; Instances buildSet = new Instances(data); Instances temp = new Instances(data, data.numInstances()); Instances inverseSet = new Instances(data, data.numInstances()); int count = 0; double ans; int iterations = 0; int classIndex = m_classIndex; if (classIndex < 0) { classIndex = data.classIndex(); } if (classIndex < 0) { classIndex = data.numAttributes() - 1; } // loop until perfect while (count != buildSet.numInstances()) { // check if hit maximum number of iterations iterations++; if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) { break; } // build classifier count = buildSet.numInstances(); buildSet.setClassIndex(classIndex); m_cleansingClassifier.buildClassifier(buildSet); temp = new Instances(buildSet, buildSet.numInstances()); // test on training data for (int i = 0; i < buildSet.numInstances(); i++) { inst = buildSet.instance(i); ans = m_cleansingClassifier.classifyInstance(inst); if (buildSet.classAttribute().isNumeric()) { if (ans >= inst.classValue() - m_numericClassifyThreshold && ans <= inst.classValue() + m_numericClassifyThreshold) { temp.add(inst); } else if (m_invertMatching) { inverseSet.add(inst); } } else { // class is nominal if (ans == inst.classValue()) { temp.add(inst); } else if (m_invertMatching) { inverseSet.add(inst); } } } buildSet = temp; } if (m_invertMatching) { inverseSet.setClassIndex(data.classIndex()); return inverseSet; } else { buildSet.setClassIndex(data.classIndex()); return buildSet; } }
/** * Build Decorate classifier * * @param data the training data to be used for generating the classifier * @exception Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifier == null) { throw new Exception("A base classifier has not been specified!"); } if (data.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } if (data.classAttribute().isNumeric()) { throw new UnsupportedClassTypeException("Decorate can't handle a numeric class!"); } if (m_NumIterations < m_DesiredSize) throw new Exception("Max number of iterations must be >= desired ensemble size!"); // initialize random number generator if (m_Seed == -1) m_Random = new Random(); else m_Random = new Random(m_Seed); int i = 1; // current committee size int numTrials = 1; // number of Decorate iterations Instances divData = new Instances(data); // local copy of data - diversity data divData.deleteWithMissingClass(); Instances artData = null; // artificial data // compute number of artficial instances to add at each iteration int artSize = (int) (Math.abs(m_ArtSize) * divData.numInstances()); if (artSize == 0) artSize = 1; // atleast add one random example computeStats(data); // Compute training data stats for creating artificial examples // initialize new committee m_Committee = new Vector(); Classifier newClassifier = m_Classifier; newClassifier.buildClassifier(divData); m_Committee.add(newClassifier); double eComm = computeError(divData); // compute ensemble error if (m_Debug) System.out.println( "Initialize:\tClassifier " + i + " added to ensemble. Ensemble error = " + eComm); // repeat till desired committee size is reached OR the max number of iterations is exceeded while (i < m_DesiredSize && numTrials < m_NumIterations) { // Generate artificial training examples artData = generateArtificialData(artSize, data); // Label artificial examples labelData(artData); addInstances(divData, artData); // Add new artificial data // Build new classifier Classifier tmp[] = Classifier.makeCopies(m_Classifier, 1); newClassifier = tmp[0]; newClassifier.buildClassifier(divData); // Remove all the artificial data removeInstances(divData, artSize); // Test if the new classifier should be added to the ensemble m_Committee.add(newClassifier); // add new classifier to current committee double currError = computeError(divData); if (currError <= eComm) { // adding the new member did not increase the error i++; eComm = currError; if (m_Debug) System.out.println( "Iteration: " + (1 + numTrials) + "\tClassifier " + i + " added to ensemble. Ensemble error = " + eComm); } else { // reject the current classifier because it increased the ensemble error m_Committee.removeElementAt(m_Committee.size() - 1); // pop the last member } numTrials++; } }
public static void main(String[] args) throws Exception { BufferedReader reader = new BufferedReader(new FileReader("spambase.arff")); Instances data = new Instances(reader); reader.close(); // setting class attribute data.setClassIndex(data.numAttributes() - 1); int i = data.numInstances(); int j = data.numAttributes() - 1; File file = new File("tablelog.csv"); Writer output = null; output = new BufferedWriter(new FileWriter(file)); output.write( "%missing,auc1,correct1,fmeasure1,auc2,correct2,fmeasure2,auc3,correct3,fmeasure3\n"); Random randomGenerator = new Random(); data.randomize(randomGenerator); int numBlock = data.numInstances() / 2; double num0 = 0, num1 = 0, num2 = 0; /*mdata.instance(0).setMissing(0); mdata.deleteWithMissing(0); System.out.println(mdata.numInstances()+","+data.numInstances());*/ // Instances traindata=null; // Instances testdata=null; // System.out.println(data.instance(3).stringValue(1)); for (int perc = 10; perc < 101; perc = perc + 10) { Instances mdata = new Instances(data); int numMissing = perc * numBlock / 100; double y11[] = new double[2]; double y21[] = new double[2]; double y31[] = new double[2]; double y12[] = new double[2]; double y22[] = new double[2]; double y32[] = new double[2]; double y13[] = new double[2]; double y23[] = new double[2]; double y33[] = new double[2]; for (int p = 0; p < 2; p++) { Instances traindata = mdata.trainCV(2, p); Instances testdata = mdata.testCV(2, p); num0 = 0; num1 = 0; num2 = 0; for (int t = 0; t < numBlock; t++) { if (traindata.instance(t).classValue() == 0) num0++; if (traindata.instance(t).classValue() == 1) num1++; // if (traindata.instance(t).classValue()==2) num2++; } // System.out.println(mdata.instance(0).classValue()); Instances trainwithmissing = new Instances(traindata); Instances testwithmissing = new Instances(testdata); for (int q = 0; q < j; q++) { int r = randomGenerator.nextInt((int) i / 2); for (int k = 0; k < numMissing; k++) { // int r = randomGenerator.nextInt((int) i/2); // int c = randomGenerator.nextInt(j); trainwithmissing.instance((r + k) % numBlock).setMissing(q); testwithmissing.instance((r + k) % numBlock).setMissing(q); } } // trainwithmissing.deleteWithMissing(0);System.out.println(traindata.numInstances()+","+trainwithmissing.numInstances()); Classifier cModel = (Classifier) new Logistic(); // try for different classifiers and datasets cModel.buildClassifier(trainwithmissing); Evaluation eTest1 = new Evaluation(trainwithmissing); eTest1.evaluateModel(cModel, testdata); // eTest.crossValidateModel(cModel,mdata,10,mdata.getRandomNumberGenerator(1)); y11[p] = num0 / numBlock * eTest1.areaUnderROC(0) + num1 / numBlock * eTest1.areaUnderROC(1) /*+num2/numBlock*eTest1.areaUnderROC(2)*/; y21[p] = eTest1.correct(); y31[p] = num0 / numBlock * eTest1.fMeasure(0) + num1 / numBlock * eTest1.fMeasure(1) /*+num2/numBlock*eTest1.fMeasure(2)*/; Classifier cModel2 = (Classifier) new Logistic(); cModel2.buildClassifier(traindata); Evaluation eTest2 = new Evaluation(traindata); eTest2.evaluateModel(cModel2, testwithmissing); y12[p] = num0 / numBlock * eTest2.areaUnderROC(0) + num1 / numBlock * eTest2.areaUnderROC(1) /*+num2/numBlock*eTest2.areaUnderROC(2)*/; y22[p] = eTest2.correct(); y32[p] = num0 / numBlock * eTest2.fMeasure(0) + num1 / numBlock * eTest2.fMeasure(1) /*+num2/numBlock*eTest2.fMeasure(2)*/; Classifier cModel3 = (Classifier) new Logistic(); cModel3.buildClassifier(trainwithmissing); Evaluation eTest3 = new Evaluation(trainwithmissing); eTest3.evaluateModel(cModel3, testwithmissing); 
y13[p] = num0 / numBlock * eTest3.areaUnderROC(0) + num1 / numBlock * eTest3.areaUnderROC(1) /*+num2/numBlock*eTest3.areaUnderROC(2)*/; y23[p] = eTest3.correct(); y33[p] = num0 / numBlock * eTest3.fMeasure(0) + num1 / numBlock * eTest3.fMeasure(1) /*+num2/numBlock*eTest3.fMeasure(2)*/; // System.out.println(num0+","+num1+","+num2+"\n"); } double auc1 = (y11[0] + y11[1]) / 2; double auc2 = (y12[0] + y12[1]) / 2; double auc3 = (y13[0] + y13[1]) / 2; double corr1 = (y21[0] + y21[1]) / i; double corr2 = (y22[0] + y22[1]) / i; double corr3 = (y23[0] + y23[1]) / i; double fm1 = (y31[0] + y31[1]) / 2; double fm2 = (y32[0] + y32[1]) / 2; double fm3 = (y33[0] + y33[1]) / 2; output.write( perc + "," + auc1 + "," + corr1 + "," + fm1 + "," + auc2 + "," + corr2 + "," + fm2 + "," + auc3 + "," + corr3 + "," + fm3 + "\n"); // System.out.println(num0); // mdata=data; } output.close(); }
// Input: a question; output: the category the question belongs to.
public double classifyByBayes(String question) throws Exception {
  double label = -1;
  List<Question> questionID = questionDAO.getQuestionIDLabeled();

  // define the data format
  Attribute att1 = new Attribute("法律政策");
  Attribute att2 = new Attribute("位置交通");
  Attribute att3 = new Attribute("风水");
  Attribute att4 = new Attribute("房价");
  Attribute att5 = new Attribute("楼层");
  Attribute att6 = new Attribute("户型");
  Attribute att7 = new Attribute("小区配套");
  Attribute att8 = new Attribute("贷款");
  Attribute att9 = new Attribute("买房时机");
  Attribute att10 = new Attribute("开发商");
  FastVector labels = new FastVector();
  labels.addElement("1");
  labels.addElement("2");
  labels.addElement("3");
  labels.addElement("4");
  labels.addElement("5");
  labels.addElement("6");
  labels.addElement("7");
  labels.addElement("8");
  labels.addElement("9");
  labels.addElement("10");
  Attribute att11 = new Attribute("类别", labels);

  FastVector attributes = new FastVector();
  attributes.addElement(att1);
  attributes.addElement(att2);
  attributes.addElement(att3);
  attributes.addElement(att4);
  attributes.addElement(att5);
  attributes.addElement(att6);
  attributes.addElement(att7);
  attributes.addElement(att8);
  attributes.addElement(att9);
  attributes.addElement(att10);
  attributes.addElement(att11);
  Instances dataset = new Instances("Test-dataset", attributes, 0);
  dataset.setClassIndex(10);

  Classifier classifier = null;
  if (!new File("naivebayes.model").exists()) {
    // add the training data
    double[] values = new double[11];
    for (int i = 0; i < questionID.size(); i++) {
      for (int m = 0; m < 11; m++) {
        values[m] = 0;
      }
      int whitewordcount = 0;
      whitewordcount = questionDAO.getHitWhiteWordNum(questionID.get(i).getId());
      if (whitewordcount != 0) {
        List<QuestionWhiteWord> questionwhiteword =
            questionDAO.getHitQuestionWhiteWord(questionID.get(i).getId());
        for (int j = 0; j < questionwhiteword.size(); j++) {
          values[getAttIndex(questionwhiteword.get(j).getWordId()) - 1]++;
        }
        for (int m = 0; m < 11; m++) {
          values[m] = values[m] / whitewordcount;
        }
      }
      values[10] = questionID.get(i).getType() - 1;
      Instance inst = new Instance(1.0, values);
      dataset.add(inst);
    }
    // build the classifier
    classifier = new NaiveBayes();
    classifier.buildClassifier(dataset);
    SerializationHelper.write("naivebayes.model", classifier);
  } else {
    classifier = (Classifier) SerializationHelper.read("naivebayes.model");
  }

  System.out.println("*************begin evaluation*******************");
  Evaluation evaluation = new Evaluation(dataset);
  // strictly speaking, a separate dataset should be used here rather than the training set
  evaluation.evaluateModel(classifier, dataset);
  System.out.println(evaluation.toSummaryString());

  // classify
  System.out.println("*************begin classification*******************");
  Instance subject = new Instance(1.0, getQuestionVector(question));
  subject.setDataset(dataset);
  label = classifier.classifyInstance(subject);
  System.out.println("label: " + label);
  // double dis[] = classifier.distributionForInstance(inst);
  // for (double i : dis) {
  //   System.out.print(i + " ");
  // }
  System.out.println(questionID.size());
  return label + 1;
}
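The comment in classifyByBayes already notes that evaluating on the training set is optimistic. A hedged sketch of a simple randomized hold-out split that could be used instead; the 2/3–1/3 ratio and file name are illustrative choices, not the project's.

import java.util.Random;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class HoldOutSplitSketch {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("questions.arff"); // hypothetical file
    data.setClassIndex(data.numAttributes() - 1);
    data.randomize(new Random(42));

    // 2/3 train, 1/3 held-out test
    int trainSize = (int) Math.round(data.numInstances() * 2.0 / 3.0);
    Instances train = new Instances(data, 0, trainSize);
    Instances test = new Instances(data, trainSize, data.numInstances() - trainSize);

    Classifier nb = new NaiveBayes();
    nb.buildClassifier(train);

    Evaluation eval = new Evaluation(train);
    eval.evaluateModel(nb, test);
    System.out.println(eval.toSummaryString());
  }
}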
@Override public void buildClassifier(Instances data) throws Exception { for (Classifier classifier : classifiers) { classifier.buildClassifier(data); } }
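buildClassifier above trains every member but leaves prediction to the caller. A common follow-up, shown here only as a hedged assumption about how such an ensemble might be used, is to average the members' class distributions; the member list, classifiers, and file name below are illustrative.

import java.util.Arrays;
import java.util.List;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class AveragingEnsembleSketch {
  public static void main(String[] args) throws Exception {
    Instances data = DataSource.read("train.arff"); // hypothetical file
    data.setClassIndex(data.numAttributes() - 1);

    List<Classifier> members = Arrays.<Classifier>asList(new J48(), new NaiveBayes());
    for (Classifier c : members) {
      c.buildClassifier(data);
    }

    // average the per-class probabilities of all members for one instance
    Instance query = data.instance(0);
    double[] avg = new double[data.numClasses()];
    for (Classifier c : members) {
      double[] dist = c.distributionForInstance(query);
      for (int k = 0; k < avg.length; k++) {
        avg[k] += dist[k] / members.size();
      }
    }
    System.out.println(Arrays.toString(avg));
  }
}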
public void buildClassifier(Classifier classifier) throws Exception { this.classifier = classifier; classifier.buildClassifier(trainingData); }
/** * Generates the classifier. * * @param instances set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(instances); // only class? -> build ZeroR model if (instances.numAttributes() == 1) { System.err.println( "Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_ZeroR = new weka.classifiers.rules.ZeroR(); m_ZeroR.buildClassifier(instances); return; } else { m_ZeroR = null; } // reset variable m_NumClasses = instances.numClasses(); m_ClassIndex = instances.classIndex(); m_NumAttributes = instances.numAttributes(); m_NumInstances = instances.numInstances(); m_TotalAttValues = 0; // allocate space for attribute reference arrays m_StartAttIndex = new int[m_NumAttributes]; m_NumAttValues = new int[m_NumAttributes]; // set the starting index of each attribute and the number of values for // each attribute and the total number of values for all attributes (not including class). for (int i = 0; i < m_NumAttributes; i++) { if (i != m_ClassIndex) { m_StartAttIndex[i] = m_TotalAttValues; m_NumAttValues[i] = instances.attribute(i).numValues(); m_TotalAttValues += m_NumAttValues[i]; } else { m_StartAttIndex[i] = -1; m_NumAttValues[i] = m_NumClasses; } } // allocate space for counts and frequencies m_ClassCounts = new double[m_NumClasses]; m_AttCounts = new double[m_TotalAttValues]; m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues]; m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues]; m_Header = new Instances(instances, 0); // Calculate the counts for (int k = 0; k < m_NumInstances; k++) { int classVal = (int) instances.instance(k).classValue(); m_ClassCounts[classVal]++; int[] attIndex = new int[m_NumAttributes]; for (int i = 0; i < m_NumAttributes; i++) { if (i == m_ClassIndex) { attIndex[i] = -1; } else { attIndex[i] = m_StartAttIndex[i] + (int) instances.instance(k).value(i); m_AttCounts[attIndex[i]]++; } } for (int Att1 = 0; Att1 < m_NumAttributes; Att1++) { if (attIndex[Att1] == -1) continue; for (int Att2 = 0; Att2 < m_NumAttributes; Att2++) { if ((attIndex[Att2] != -1)) { m_AttAttCounts[attIndex[Att1]][attIndex[Att2]]++; m_ClassAttAttCounts[classVal][attIndex[Att1]][attIndex[Att2]]++; } } } } // compute mutual information between each attribute and class m_mutualInformation = new double[m_NumAttributes]; for (int att = 0; att < m_NumAttributes; att++) { if (att == m_ClassIndex) continue; m_mutualInformation[att] = mutualInfo(att); } }
/**
 * Text classification is a bit special: when the StringToWordVector object computes term
 * (attribute) weights it needs corpus-wide statistics such as DF, so the data has to be
 * processed as a batch. Note that in Weka some learning algorithms are batch-based and some
 * are not.
 */
public void finishBatch() throws Exception {
  filter.setIDFTransform(true);
  filter.setInputFormat(instances);
  // this is what actually produces a dataset in the input format Weka's algorithms expect
  Instances filteredData = Filter.useFilter(instances, filter);
  // the actual training of the classifier
  classifier.buildClassifier(filteredData);
}
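finishBatch fits the StringToWordVector on the whole training batch, so IDF uses global document frequencies. A hedged sketch of the matching step for new data follows: the same already-initialized filter instance is reused, so the test vectors share the training vocabulary and IDF weights. Variable and file names are illustrative, not from the original class.

import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayesMultinomial;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class BatchFilterSketch {
  public static void main(String[] args) throws Exception {
    // hypothetical ARFF files with a string attribute and a nominal class
    Instances train = DataSource.read("text-train.arff");
    Instances test = DataSource.read("text-test.arff");
    train.setClassIndex(train.numAttributes() - 1);
    test.setClassIndex(test.numAttributes() - 1);

    StringToWordVector filter = new StringToWordVector();
    filter.setOutputWordCounts(true);
    filter.setIDFTransform(true);
    filter.setInputFormat(train); // vocabulary and DF statistics come from train only

    Instances filteredTrain = Filter.useFilter(train, filter);
    // reuse the fitted filter; the class attribute is tracked in its output format
    Instances filteredTest = Filter.useFilter(test, filter);

    Classifier cls = new NaiveBayesMultinomial();
    cls.buildClassifier(filteredTrain);
    System.out.println("first test prediction: "
        + cls.classifyInstance(filteredTest.instance(0)));
  }
}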
public void exec(PrintWriter printer) {
  try {
    FileWriter outFile = null;
    PrintWriter out = null;
    if (printer == null) {
      outFile = new FileWriter(id + ".results");
      out = new PrintWriter(outFile);
    } else out = printer;
    DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
    ProcessTweets tweetsProcessor = null;
    System.out.println("***************************************");
    System.out.println("***\tEXECUTING TEST\t" + id + "***");
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.println("Train size:" + traincorpus.size());
    System.out.println("Test size:" + testcorpus.size());
    out.println("***************************************");
    out.println("***\tEXECUTING TEST\t***");
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.println("Train size:" + traincorpus.size());
    out.println("Test size:" + testcorpus.size());
    String cloneID = "";
    boolean clonar = false;
    if (baseline) {
      System.out.println("***************************************");
      System.out.println("***\tEXECUTING TEST BASELINE\t***");
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.println("Train size:" + traincorpus.size());
      System.out.println("Test size:" + testcorpus.size());
      out.println("***************************************");
      out.println("***\tEXECUTING TEST\t***");
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.println("Train size:" + traincorpus.size());
      out.println("Test size:" + testcorpus.size());
      BaselineClassifier base = new BaselineClassifier(testcorpus, 8);
      precision = base.getPrecision();
      recall = base.getRecall();
      fmeasure = base.getFmeasure();
      System.out.println("+++++++++++++++++++++++++++++++++++++++");
      System.out.printf("Precision: %.3f\n", precision);
      System.out.printf("Recall: %.3f\n", recall);
      System.out.printf("F-measure: %.3f\n", fmeasure);
      System.out.println("***************************************");
      out.println("+++++++++++++++++++++++++++++++++++++++");
      out.printf("Precision: %.3f\n", precision);
      out.printf("Recall: %.3f\n", recall);
      out.printf("F-measure: %.3f\n", fmeasure);
      out.println("***************************************");
      out.flush();
      out.close();
      return;
    } else {
      System.out.println("Stemming: " + stemming);
      System.out.println("Lematization:" + lematization);
      System.out.println("URLs:" + urls);
      System.out.println("Hashtags:" + hashtags);
      System.out.println("Mentions:" + mentions);
      System.out.println("Unigrams:" + unigrams);
      System.out.println("Bigrams:" + bigrams);
      System.out.println("TF:" + tf);
      System.out.println("TF-IDF:" + tfidf);
      out.println("Stemming: " + stemming);
      out.println("Lematization:" + lematization);
      out.println("URLs:" + urls);
      out.println("Hashtags:" + hashtags);
      out.println("Mentions:" + mentions);
      out.println("Unigrams:" + unigrams);
      out.println("Bigrams:" + bigrams);
      out.println("TF:" + tf);
      out.println("TF-IDF:" + tfidf);
    }
    // If the processed tweets already exist, skip reprocessing them
    System.out.println("1-Process tweets " + dateFormat.format(new Date()));
    out.println("1-Process tweets " + dateFormat.format(new Date()));
    List<ProcessedTweet> train = null;
    String[] ids = id.split("-");
    cloneID = ids[0] + "-" + (Integer.valueOf(ids[1]) + 6);
    if (((Integer.valueOf(ids[1]) / 6) % 2) == 0) clonar = true;
    if (new File(id + "-train.ptweets").exists()) {
      train = ProcessedTweetSerialization.fromFile(id + "-train.ptweets");
      tweetsProcessor =
          new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
      if (lematization) {
        tweetsProcessor.doLematization(train);
      }
      if (stemming) {
        tweetsProcessor.doStemming(train);
      }
    } else {
      tweetsProcessor =
          new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
      // The setTraining call is an addition so that the languages of the URLs in the
      // parallel corpus can be told apart
      // tweetsProcessor.setTraining(true);
      train = tweetsProcessor.processTweets(traincorpus);
      // tweetsProcessor.setTraining(false);
      ProcessedTweetSerialization.toFile(id + "-train.ptweets", train);
      /*
      if (clonar) {
        File f = new File (id+"-train.ptweets");
        Path p = f.toPath();
        CopyOption[] options = new CopyOption[]{
          StandardCopyOption.REPLACE_EXISTING,
          StandardCopyOption.COPY_ATTRIBUTES
        };
        Files.copy(p, new File (cloneID+"-train.ptweets").toPath(), options);
        Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+12)+"-train.ptweets").toPath(), options);
        Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+18)+"-train.ptweets").toPath(), options);
        Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+24)+"-train.ptweets").toPath(), options);
        Files.copy(p, new File (ids[0]+"-"+(Integer.valueOf(ids[1])+30)+"-train.ptweets").toPath(), options);
      }
      */
    }
    // Generate the bags of words (BOW). As before, if they already exist they are not rebuilt.
    System.out.println("2-Fill topics " + dateFormat.format(new Date()));
    out.println("2-Fill topics " + dateFormat.format(new Date()));
    TopicsList topics = null;
    if (new File(id + ".topics").exists()) {
      topics = TopicsSerialization.fromFile(id + ".topics");
      if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF);
      else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF);
      topics.prepareTopics();
    } else {
      topics = new TopicsList();
      if (tf) topics.setSelectionFeature(TopicDesc.TERM_TF);
      else topics.setSelectionFeature(TopicDesc.TERM_TF_IDF);
      System.out.println("Filling topics " + dateFormat.format(new Date()));
      topics.fillTopics(train);
      System.out.println("Preparing topics topics " + dateFormat.format(new Date()));
      // Serialization has to happen before prepareTopics, because otherwise the tf and
      // tfidf values cannot be computed
      System.out.println("Serializing topics topics " + dateFormat.format(new Date()));
      /*
      if (clonar) {
        TopicsSerialization.toFile(cloneID+".topics", topics);
      }
      */
      topics.prepareTopics();
      TopicsSerialization.toFile(id + ".topics", topics);
    }
    System.out.println("3-Generate arff train file " + dateFormat.format(new Date()));
    out.println("3-Generate arff train file " + dateFormat.format(new Date()));
    // If the arff train file does not exist, create it;
    // otherwise reuse the work done previously
    if (!new File(id + "-train.arff").exists()) {
      BufferedWriter bw = topics.generateArffHeader(id + "-train.arff");
      int tope = traincorpus.size();
      if (tweetsProcessor == null)
        tweetsProcessor =
            new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
      for (int indTweet = 0; indTweet < tope; indTweet++) {
        topics.generateArffVector(bw, train.get(indTweet));
      }
      bw.flush();
      bw.close();
    }
    // Now process the test data
    System.out.println("5-build test dataset " + dateFormat.format(new Date()));
    out.println("5-build test dataset " + dateFormat.format(new Date()));
    List<ProcessedTweet> test = null;
    if (new File(id + "-test.ptweets").exists())
      test = ProcessedTweetSerialization.fromFile(id + "-test.ptweets");
    else {
      if (tweetsProcessor == null)
        tweetsProcessor =
            new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
      test = tweetsProcessor.processTweets(testcorpus);
      ProcessedTweetSerialization.toFile(id + "-test.ptweets", test);
      /*
      if (clonar) {
        File f = new File (id+"-test.ptweets");
        Path p = f.toPath();
        CopyOption[] options = new CopyOption[]{
          StandardCopyOption.REPLACE_EXISTING,
          StandardCopyOption.COPY_ATTRIBUTES
        };
        Files.copy(p, new File (cloneID+"-test.ptweets").toPath(), options);
      }
      */
    }
    // If the arff test file does not exist, create it; otherwise reuse the work done
    // previously, as before
    if (!new File(id + "-test.arff").exists()) {
      BufferedWriter bw = topics.generateArffHeader(id + "-test.arff");
      int tope = testcorpus.size();
      if (tweetsProcessor == null)
        tweetsProcessor =
            new ProcessTweets(stemming, lematization, urls, hashtags, mentions, unigrams, bigrams);
      for (int indTweet = 0; indTweet < tope; indTweet++) {
        topics.generateArffVector(bw, test.get(indTweet));
      }
      bw.flush();
      bw.close();
    }
    int topeTopics = topics.getTopicsList().size();
    topics.getTopicsList().clear();
    // Generate the classifier
    // FJRM 25-08-2013: reordered to try to free the memory held by the topics and have more
    // memory available
    System.out.println("4-Generate classifier " + dateFormat.format(new Date()));
    out.println("4-Generate classifier " + dateFormat.format(new Date()));
    Classifier cls = null;
    DataSource sourceTrain = null;
    Instances dataTrain = null;
    if (new File(id + "-MNB.classifier").exists()) {
      ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id + "-MNB.classifier"));
      cls = (Classifier) ois.readObject();
      ois.close();
    } else {
      sourceTrain = new DataSource(id + "-train.arff");
      dataTrain = sourceTrain.getDataSet();
      if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
      // Train the classifier
      cls = new weka.classifiers.bayes.NaiveBayesMultinomial();
      int clase = dataTrain.numAttributes() - 1;
      dataTrain.setClassIndex(clase);
      cls.buildClassifier(dataTrain);
      ObjectOutputStream oos =
          new ObjectOutputStream(new FileOutputStream(id + "-MNB.classifier"));
      oos.writeObject(cls);
      oos.flush();
      oos.close();
      // data.delete(); // not deleted, kept for the svm
    }
    // Now evaluate the classifier on the test data
    System.out.println("6-Evaluate classifier MNB " + dateFormat.format(new Date()));
    out.println("6-Evaluate classifier MNB" + dateFormat.format(new Date()));
    DataSource sourceTest = new DataSource(id + "-test.arff");
    Instances dataTest = sourceTest.getDataSet();
    int clase = dataTest.numAttributes() - 1;
    dataTest.setClassIndex(clase);
    Evaluation eval = new Evaluation(dataTest);
    eval.evaluateModel(cls, dataTest);
    // Now compute precision, recall and f-measure, and also print the confusion matrices
    precision = 0;
    recall = 0;
    fmeasure = 0;
    for (int ind = 0; ind < topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    /* DO NOT DELETE
    System.out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date()));
    out.println("7-Evaluate classifier SVM"+dateFormat.format(new Date()));
    if (new File(id+"-SVM.classifier").exists()) {
      ObjectInputStream ois = new ObjectInputStream(new FileInputStream(id+"-SVM.classifier"));
      cls = (Classifier) ois.readObject();
      ois.close();
    } else {
      if (dataTrain==null) {
        sourceTrain = new DataSource(id+"-train.arff");
        dataTrain = sourceTrain.getDataSet();
        if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
      }
      // Train the classifier
      cls = new weka.classifiers.functions.LibSVM();
      clase = dataTrain.numAttributes()-1;
      dataTrain.setClassIndex(clase);
      cls.buildClassifier(dataTrain);
      ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(id+"-SVM.classifier"));
      oos.writeObject(cls);
      oos.flush();
      oos.close();
      dataTrain.delete();
    }
    eval.evaluateModel(cls, dataTest);
    precision=0;
    recall=0;
    fmeasure=0;
    for(int ind=0; ind<topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    */
    System.out.println("Done " + dateFormat.format(new Date()));
    out.println("Done " + dateFormat.format(new Date()));
    if (printer == null) {
      out.flush();
      out.close();
    }
    // Attempt to free memory
    if (dataTrain != null) dataTrain.delete();
    if (dataTest != null) dataTest.delete();
    if (train != null) train.clear();
    if (test != null) test.clear();
    if (topics != null) {
      topics.getTopicsList().clear();
      topics = null;
    }
    if (dataTest != null) dataTest.delete();
    if (cls != null) cls = null;
    if (tweetsProcessor != null) tweetsProcessor = null;
    System.gc();
  }
  catch (Exception e) {
    e.printStackTrace();
  }
}
public static void execSVM(String expName) {
  try {
    FileWriter outFile = null;
    PrintWriter out = null;
    outFile = new FileWriter(expName + "-SVM.results");
    out = new PrintWriter(outFile);
    DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");
    ProcessTweets tweetsProcessor = null;
    System.out.println("***************************************");
    System.out.println("***\tEXECUTING TEST\t" + expName + "***");
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    out.println("***************************************");
    out.println("***\tEXECUTING TEST\t" + expName + "***");
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.println("4-Generate classifier " + dateFormat.format(new Date()));
    Classifier cls = null;
    DataSource sourceTrain = new DataSource(expName + "-train.arff");
    Instances dataTrain = sourceTrain.getDataSet();
    if (dataTrain.classIndex() == -1) dataTrain.setClassIndex(dataTrain.numAttributes() - 1);
    // Train the classifier
    // cls = new weka.classifiers.functions.LibSVM();
    int clase = dataTrain.numAttributes() - 1;
    cls = new weka.classifiers.bayes.ComplementNaiveBayes();
    dataTrain.setClassIndex(clase);
    cls.buildClassifier(dataTrain);
    ObjectOutputStream oos =
        new ObjectOutputStream(new FileOutputStream(expName + "-SVM.classifier"));
    oos.writeObject(cls);
    oos.flush();
    oos.close();
    DataSource sourceTest = new DataSource(expName + "-test.arff");
    Instances dataTest = sourceTest.getDataSet();
    dataTest.setClassIndex(clase);
    Evaluation eval = new Evaluation(dataTest);
    eval.evaluateModel(cls, dataTest);
    // Now compute precision, recall and f-measure, and also print the confusion matrices
    float precision = 0;
    float recall = 0;
    float fmeasure = 0;
    int topeTopics = 8;
    for (int ind = 0; ind < topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("++++++++++++++ CNB ++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("++++++++++++++ CNB ++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    // ANOTHER CLASSIFIER: ZeroR
    cls = new weka.classifiers.rules.ZeroR();
    dataTrain.setClassIndex(clase);
    cls.buildClassifier(dataTrain);
    eval = new Evaluation(dataTest);
    eval.evaluateModel(cls, dataTest);
    precision = 0;
    recall = 0;
    fmeasure = 0;
    for (int ind = 0; ind < topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("++++++++++++++ ZEROR ++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("++++++++++++++ ZEROR ++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    // ANOTHER CLASSIFIER: J48
    /*
    cls = new weka.classifiers.trees.J48();
    dataTrain.setClassIndex(clase);
    cls.buildClassifier(dataTrain);
    eval = new Evaluation(dataTest);
    eval.evaluateModel(cls, dataTest);
    precision=0;
    recall=0;
    fmeasure=0;
    for(int ind=0; ind<topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("++++++++++++++ J48 ++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("++++++++++++++ J48 ++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    // ANOTHER: SMO
    cls = new weka.classifiers.functions.SMO();
    dataTrain.setClassIndex(clase);
    cls.buildClassifier(dataTrain);
    eval = new Evaluation(dataTest);
    eval.evaluateModel(cls, dataTest);
    precision=0;
    recall=0;
    fmeasure=0;
    for(int ind=0; ind<topeTopics; ind++) {
      precision += eval.precision(ind);
      recall += eval.recall(ind);
      fmeasure += eval.fMeasure(ind);
    }
    precision = precision / topeTopics;
    recall = recall / topeTopics;
    fmeasure = fmeasure / topeTopics;
    System.out.println("++++++++++++++ SMO ++++++++++++++++++++");
    System.out.println(eval.toMatrixString());
    System.out.println("+++++++++++++++++++++++++++++++++++++++");
    System.out.printf("Precision: %.3f\n", precision);
    System.out.printf("Recall: %.3f\n", recall);
    System.out.printf("F-measure: %.3f\n", fmeasure);
    System.out.println("***************************************");
    out.println("++++++++++++++ SMO ++++++++++++++++++++");
    out.println(eval.toMatrixString());
    out.println("+++++++++++++++++++++++++++++++++++++++");
    out.printf("Precision: %.3f\n", precision);
    out.printf("Recall: %.3f\n", recall);
    out.printf("F-measure: %.3f\n", fmeasure);
    out.println("***************************************");
    */
    out.flush();
    out.close();
    dataTest.delete();
    dataTrain.delete();
  } catch (FileNotFoundException e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  } catch (IOException e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  } catch (Exception e) {
    // TODO Auto-generated catch block
    e.printStackTrace();
  }
}
/** * Gets the results for the supplied train and test datasets. Now performs a deep copy of the * classifier before it is built and evaluated (just in case the classifier is not initialized * properly in buildClassifier()). * * @param train the training Instances. * @param test the testing Instances. * @return the results stored in an array. The objects stored in the array may be Strings, * Doubles, or null (for the missing value). * @throws Exception if a problem occurs while getting the results */ public Object[] getResult(Instances train, Instances test) throws Exception { if (train.classAttribute().type() != Attribute.NUMERIC) { throw new Exception("Class attribute is not numeric!"); } if (m_Template == null) { throw new Exception("No classifier has been specified"); } ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean(); boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported(); if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled()) thMonitor.setThreadCpuTimeEnabled(true); int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0; Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics]; long thID = Thread.currentThread().getId(); long CPUStartTime = -1, trainCPUTimeElapsed = -1, testCPUTimeElapsed = -1, trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed; Evaluation eval = new Evaluation(train); m_Classifier = AbstractClassifier.makeCopy(m_Template); trainTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); m_Classifier.buildClassifier(train); if (canMeasureCPUTime) trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; testTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); eval.evaluateModel(m_Classifier, test); if (canMeasureCPUTime) testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; testTimeElapsed = System.currentTimeMillis() - testTimeStart; thMonitor = null; m_result = eval.toSummaryString(); // The results stored are all per instance -- can be multiplied by the // number of instances to get absolute numbers int current = 0; result[current++] = new Double(train.numInstances()); result[current++] = new Double(eval.numInstances()); result[current++] = new Double(eval.meanAbsoluteError()); result[current++] = new Double(eval.rootMeanSquaredError()); result[current++] = new Double(eval.relativeAbsoluteError()); result[current++] = new Double(eval.rootRelativeSquaredError()); result[current++] = new Double(eval.correlationCoefficient()); result[current++] = new Double(eval.SFPriorEntropy()); result[current++] = new Double(eval.SFSchemeEntropy()); result[current++] = new Double(eval.SFEntropyGain()); result[current++] = new Double(eval.SFMeanPriorEntropy()); result[current++] = new Double(eval.SFMeanSchemeEntropy()); result[current++] = new Double(eval.SFMeanEntropyGain()); // Timing stats result[current++] = new Double(trainTimeElapsed / 1000.0); result[current++] = new Double(testTimeElapsed / 1000.0); if (canMeasureCPUTime) { result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0); result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0); } else { result[current++] = new Double(Utils.missingValue()); result[current++] = new Double(Utils.missingValue()); } // sizes if (m_NoSizeDetermination) { result[current++] = -1.0; result[current++] = -1.0; 
result[current++] = -1.0; } else { ByteArrayOutputStream bastream = new ByteArrayOutputStream(); ObjectOutputStream oostream = new ObjectOutputStream(bastream); oostream.writeObject(m_Classifier); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(train); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(test); result[current++] = new Double(bastream.size()); } // Prediction interval statistics result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions()); result[current++] = new Double(eval.sizeOfPredictedRegions()); if (m_Classifier instanceof Summarizable) { result[current++] = ((Summarizable) m_Classifier).toSummaryString(); } else { result[current++] = null; } for (int i = 0; i < addm; i++) { if (m_doesProduce[i]) { try { double dv = ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]); if (!Utils.isMissingValue(dv)) { Double value = new Double(dv); result[current++] = value; } else { result[current++] = null; } } catch (Exception ex) { System.err.println(ex); } } else { result[current++] = null; } } // get the actual metrics from the evaluation object List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics(); if (metrics != null) { for (AbstractEvaluationMetric m : metrics) { if (m.appliesToNumericClass()) { List<String> statNames = m.getStatisticNames(); for (String s : statNames) { result[current++] = new Double(m.getStatistic(s)); } } } } if (current != RESULT_SIZE + addm + m_numPluginStatistics) { throw new Error("Results didn't fit RESULT_SIZE"); } return result; }