/**
 * Gets the classifier specification string, which contains the class name of
 * the classifier and any options to the classifier.
 *
 * @return the classifier string.
 */
protected String getClassifierSpec() {
  Classifier c = getClassifier();
  // Start with the bare class name; append the option string only when the
  // classifier can actually report options.
  String spec = c.getClass().getName();
  if (c instanceof OptionHandler) {
    spec += " " + Utils.joinOptions(((OptionHandler) c).getOptions());
  }
  return spec;
}
/**
 * Gets the classifier specification string, which contains the class name of
 * the classifier and any options to the classifier.
 *
 * @param index the index of the classifier string to retrieve, starting from 0.
 * @return the classifier string, or the empty string if no classifier has been
 *     assigned (or the index given is out of range).
 */
protected String getClassifierSpec(int index) {
  // Fix: the original test (m_Classifiers.length < index) let both
  // index == m_Classifiers.length and negative indices through, causing an
  // ArrayIndexOutOfBoundsException in getClassifier(index) instead of the
  // documented empty-string result.
  if (index < 0 || index >= m_Classifiers.length) {
    return "";
  }
  Classifier c = getClassifier(index);
  if (c instanceof OptionHandler) {
    return c.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler) c).getOptions());
  }
  return c.getClass().getName();
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options.
 */
public Enumeration listOptions() {
  Vector options = new Vector(1);
  options.addElement(
      new Option(
          "\tSkips the determination of sizes (train/test/classifier)\n"
              + "\t(default: sizes are determined)",
          "no-size", 0, "-no-size"));
  options.addElement(
      new Option(
          "\tThe full class name of the classifier.\n"
              + "\teg: weka.classifiers.bayes.NaiveBayes",
          "W", 1, "-W <class name>"));

  // If a template classifier is set and supports options, append a header
  // followed by all of its own options.
  if ((m_Template != null) && (m_Template instanceof OptionHandler)) {
    options.addElement(
        new Option(
            "", "", 0,
            "\nOptions specific to classifier " + m_Template.getClass().getName() + ":"));
    for (Enumeration e = ((OptionHandler) m_Template).listOptions(); e.hasMoreElements(); ) {
      options.addElement(e.nextElement());
    }
  }
  return options.elements();
}
/**
 * Gets the key describing the current SplitEvaluator. For example this may
 * contain the name of the classifier used for classifier predictive
 * evaluation. The number of key fields must be constant for a given
 * SplitEvaluator.
 *
 * @return an array of objects containing the key.
 */
public Object[] getKey() {
  // The array is sized by KEY_SIZE (not by the three fields filled in here)
  // so the key width stays in sync with the class-wide constant.
  Object[] result = new Object[KEY_SIZE];
  result[0] = m_Template.getClass().getName();
  result[1] = m_ClassifierOptions;
  result[2] = m_ClassifierVersion;
  return result;
}
/**
 * Buildclassifier selects a classifier from the set of classifiers by
 * minimising error on the training data.
 *
 * @param data the training data to be used for generating the boosted
 *     classifier.
 * @exception Exception if the classifier could not be built successfully
 */
public void buildClassifier(Instances data) throws Exception {
  if (m_Classifiers.length == 0) {
    throw new Exception("No base classifiers have been set!");
  }
  // Work on a copy: instances with a missing class are dropped and the data
  // is shuffled with the configured seed before any evaluation.
  Instances newData = new Instances(data);
  newData.deleteWithMissingClass();
  newData.randomize(new Random(m_Seed));
  // Stratify only when cross-validation is in use and the class is nominal.
  if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
    newData.stratify(m_NumXValFolds);
  Instances train = newData; // train on all data by default
  Instances test = newData; // test on training data by default
  Classifier bestClassifier = null;
  int bestIndex = -1;
  // NaN initial value is safe: the (i == 0) clause below forces the first
  // candidate to be accepted regardless of the comparison with NaN.
  double bestPerformance = Double.NaN;
  int numClassifiers = m_Classifiers.length;
  for (int i = 0; i < numClassifiers; i++) {
    Classifier currentClassifier = getClassifier(i);
    Evaluation evaluation;
    if (m_NumXValFolds > 1) {
      // Score each candidate by m_NumXValFolds-fold cross-validation,
      // accumulating results across folds in a single Evaluation object.
      evaluation = new Evaluation(newData);
      for (int j = 0; j < m_NumXValFolds; j++) {
        train = newData.trainCV(m_NumXValFolds, j);
        test = newData.testCV(m_NumXValFolds, j);
        currentClassifier.buildClassifier(train);
        evaluation.setPriors(train);
        evaluation.evaluateModel(currentClassifier, test);
      }
    } else {
      // No cross-validation: train and evaluate on the full (cleaned) data.
      currentClassifier.buildClassifier(train);
      evaluation = new Evaluation(train);
      evaluation.evaluateModel(currentClassifier, test);
    }
    // Selection criterion is the error rate (works for both nominal and
    // numeric classes in Weka's Evaluation).
    double error = evaluation.errorRate();
    if (m_Debug) {
      System.err.println(
          "Error rate: "
              + Utils.doubleToString(error, 6, 4)
              + " for classifier "
              + currentClassifier.getClass().getName());
    }
    if ((i == 0) || (error < bestPerformance)) {
      bestClassifier = currentClassifier;
      bestPerformance = error;
      bestIndex = i;
    }
  }
  m_ClassifierIndex = bestIndex;
  m_Classifier = bestClassifier;
  if (m_NumXValFolds > 1) {
    // After CV the winner was last trained on a fold subset only, so it is
    // rebuilt on all of the data. In the non-CV branch it was already
    // trained on the full data above, so no rebuild is needed.
    m_Classifier.buildClassifier(newData);
  }
}
/**
 * Returns a text description of the split evaluator.
 *
 * @return a text description of the split evaluator.
 */
public String toString() {
  // No template classifier assigned yet.
  if (m_Template == null) {
    return "RegressionSplitEvaluator: " + "<null> classifier";
  }
  StringBuilder text = new StringBuilder("RegressionSplitEvaluator: ");
  text.append(m_Template.getClass().getName())
      .append(" ")
      .append(m_ClassifierOptions)
      .append("(version ")
      .append(m_ClassifierVersion)
      .append(")");
  return text.toString();
}
/**
 * Updates the options that the current classifier is using.
 *
 * <p>Caches the template classifier's option string and, for serializable
 * classifiers, its serialVersionUID as a version stamp; both default to the
 * empty string when unavailable.
 */
protected void updateOptions() {
  m_ClassifierOptions =
      (m_Template instanceof OptionHandler)
          ? Utils.joinOptions(((OptionHandler) m_Template).getOptions())
          : "";
  if (m_Template instanceof Serializable) {
    // Use the serialVersionUID as a cheap version identifier for the key.
    m_ClassifierVersion =
        "" + ObjectStreamClass.lookup(m_Template.getClass()).getSerialVersionUID();
  } else {
    m_ClassifierVersion = "";
  }
}
/**
 * Returns the value of the named measure.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @throws IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  // Note: capability is checked on the template but the measure is read from
  // the built classifier (m_Classifier), which is an instance of the same
  // scheme as the template.
  if (!(m_Template instanceof AdditionalMeasureProducer)) {
    throw new IllegalArgumentException(
        "ClassifierSplitEvaluator: "
            + "Can't return value for : "
            + additionalMeasureName
            + ". "
            + m_Template.getClass().getName()
            + " "
            + "is not an AdditionalMeasureProducer");
  }
  if (m_Classifier == null) {
    throw new IllegalArgumentException(
        "ClassifierSplitEvaluator: "
            + "Can't return result for measure, "
            + "classifier has not been built yet.");
  }
  return ((AdditionalMeasureProducer) m_Classifier).getMeasure(additionalMeasureName);
}
private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData) throws Exception { System.err.println( "INFO: Starting split validation to predict '" + trainData.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' (#train=" + trainData.numInstances() + ",#test=" + testData.numInstances() + ") ..."); if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set"); c.buildClassifier(trainData); Evaluation eval = new Evaluation(testData); eval.useNoPriors(); double[] predictions = eval.evaluateModel(c, testData); System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); // write predictions to file { System.err.println("INFO: Writing predictions to file ..."); Writer out = new FileWriter("prediction.trec"); writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out); out.close(); } // write predicted distributions to CSV { System.err.println("INFO: Writing predicted distributions to CSV ..."); Writer out = new FileWriter("predicted_distribution.csv"); writePredictedDistributions(c, testData, 0, out); out.close(); } }
/**
 * Outputs some data about the classifier: scheme and filter specifications,
 * the training file, the model itself, and the evaluation summary, confusion
 * matrix, and per-class details.
 */
public String toString() {
  StringBuffer buf = new StringBuffer();
  buf.append("Weka - Demo\n===========\n\n");

  buf.append("Classifier...: ")
      .append(m_Classifier.getClass().getName())
      .append(" ")
      .append(Utils.joinOptions(m_Classifier.getOptions()))
      .append("\n");

  // Filters only expose options when they implement OptionHandler.
  if (m_Filter instanceof OptionHandler) {
    buf.append("Filter.......: ")
        .append(m_Filter.getClass().getName())
        .append(" ")
        .append(Utils.joinOptions(((OptionHandler) m_Filter).getOptions()))
        .append("\n");
  } else {
    buf.append("Filter.......: ").append(m_Filter.getClass().getName()).append("\n");
  }
  buf.append("Training file: ").append(m_TrainingFile).append("\n");
  buf.append("\n");

  buf.append(m_Classifier.toString()).append("\n");
  buf.append(m_Evaluation.toSummaryString()).append("\n");
  // Matrix/details can fail (e.g. numeric class); report but keep going.
  try {
    buf.append(m_Evaluation.toMatrixString()).append("\n");
  } catch (Exception e) {
    e.printStackTrace();
  }
  try {
    buf.append(m_Evaluation.toClassDetailsString()).append("\n");
  } catch (Exception e) {
    e.printStackTrace();
  }
  return buf.toString();
}
static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception { System.err.println( "INFO: Starting crossvalidation to predict '" + data.classAttribute().name() + "' using '" + c.getClass().getCanonicalName() + ":" + Arrays.toString(c.getOptions()) + "' ..."); StringBuffer sb = new StringBuffer(); Evaluation eval = new Evaluation(data); eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE); // write predictions to file { Writer out = new FileWriter("cv.log"); out.write(sb.toString()); out.close(); } System.out.println(eval.toClassDetailsString()); System.out.println(eval.toSummaryString("\nResults\n======\n", false)); }
/**
 * Takes an evaluation object from a task and aggregates it with the overall
 * one.
 *
 * <p>Synchronized: multiple evaluation tasks report here concurrently. Each
 * call folds one set's results into the running aggregate; when the final set
 * arrives (m_setsComplete == maxSetNum) the combined results are formatted and
 * broadcast to text, visualizable-error, and threshold listeners, after which
 * the aggregation state is reset in the finally block.
 *
 * @param eval the evaluation object to aggregate
 * @param classifier the classifier used by the task
 * @param testData the testData from the task
 * @param plotInstances the ClassifierErrorsPlotInstances object from the task
 * @param setNum the set number processed by the task
 * @param maxSetNum the maximum number of sets in this batch
 */
protected synchronized void aggregateEvalTask(
    Evaluation eval,
    Classifier classifier,
    Instances testData,
    ClassifierErrorsPlotInstances plotInstances,
    int setNum,
    int maxSetNum) {

  m_eval.aggregate(eval);

  // First task to report seeds the aggregated plot data; later tasks append
  // their instances/shapes/sizes element by element.
  if (m_aggregatedPlotInstances == null) {
    m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
    m_aggregatedPlotShapes = plotInstances.getPlotShapes();
    m_aggregatedPlotSizes = plotInstances.getPlotSizes();
  } else {
    Instances temp = plotInstances.getPlotInstances();
    for (int i = 0; i < temp.numInstances(); i++) {
      m_aggregatedPlotInstances.add(temp.get(i));
      m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i));
      m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i));
    }
  }
  m_setsComplete++;

  // if (ce.getSetNumber() == ce.getMaxSetNumber()) {
  if (m_setsComplete == maxSetNum) {
    try {
      // Scheme title is the simple class name; options only for
      // OptionHandler schemes.
      String textTitle = classifier.getClass().getName();
      String textOptions = "";
      if (classifier instanceof OptionHandler) {
        textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions());
      }
      textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
      String resultT =
          "=== Evaluation result ===\n\n"
              + "Scheme: "
              + textTitle
              + "\n"
              + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n" : "")
              + "Relation: "
              + testData.relationName()
              + "\n\n"
              + m_eval.toSummaryString();
      // Per-class details and confusion matrix only make sense for a
      // nominal class attribute.
      if (testData.classAttribute().isNominal()) {
        resultT += "\n" + m_eval.toClassDetailsString() + "\n" + m_eval.toMatrixString();
      }
      TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT, textTitle);
      notifyTextListeners(te);

      // set up visualizable errors
      if (m_visualizableErrorListeners.size() > 0) {
        PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances);
        errorD.setShapeSize(m_aggregatedPlotSizes);
        errorD.setShapeType(m_aggregatedPlotShapes);
        errorD.setPlotName(textTitle + " " + textOptions);
        /* PlotData2D errorD = m_PlotInstances.getPlotData( textTitle + " " + textOptions); */
        VisualizableErrorEvent vel =
            new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD);
        notifyVisualizableErrorListeners(vel);
        m_PlotInstances.cleanUp();
      }

      // Threshold (e.g. ROC) curve: nominal class only, and only when
      // someone is listening. The curve is computed for class value 0.
      if (testData.classAttribute().isNominal() && m_thresholdListeners.size() > 0) {
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(m_eval.predictions(), 0);
        result.setRelationName(testData.relationName());
        PlotData2D pd = new PlotData2D(result);
        String htmlTitle = "<html><font size=-2>" + textTitle;
        String newOptions = "";
        if (classifier instanceof OptionHandler) {
          String[] options = ((OptionHandler) classifier).getOptions();
          if (options.length > 0) {
            for (int ii = 0; ii < options.length; ii++) {
              if (options[ii].length() == 0) {
                continue;
              }
              // Break the HTML title before each option flag ("-X"), but
              // not before negative numbers such as "-1".
              if (options[ii].charAt(0) == '-'
                  && !(options[ii].charAt(1) >= '0' && options[ii].charAt(1) <= '9')) {
                newOptions += "<br>";
              }
              newOptions += options[ii];
            }
          }
        }
        htmlTitle +=
            " "
                + newOptions
                + "<br>"
                + " (class: "
                + testData.classAttribute().value(0)
                + ")"
                + "</font></html>";
        pd.setPlotName(textTitle + " (class: " + testData.classAttribute().value(0) + ")");
        pd.setPlotNameHTML(htmlTitle);
        // Connect every point to its predecessor except the first.
        boolean[] connectPoints = new boolean[result.numInstances()];
        for (int jj = 1; jj < connectPoints.length; jj++) {
          connectPoints[jj] = true;
        }
        pd.setConnectPoints(connectPoints);
        ThresholdDataEvent rde =
            new ThresholdDataEvent(
                ClassifierPerformanceEvaluator.this, pd, testData.classAttribute());
        notifyThresholdListeners(rde);
      }
      if (m_logger != null) {
        m_logger.statusMessage(statusMessagePrefix() + "Finished.");
      }
    } catch (Exception ex) {
      if (m_logger != null) {
        m_logger.logMessage(
            "[ClassifierPerformanceEvaluator] "
                + statusMessagePrefix()
                + " problem constructing evaluation results. "
                + ex.getMessage());
      }
      ex.printStackTrace();
    } finally {
      m_visual.setStatic();
      // save memory
      // (also resets the aggregation state so the next batch starts clean)
      m_PlotInstances = null;
      m_setsComplete = 0;
      m_tasks = null;
      m_aggregatedPlotInstances = null;
    }
  }
}
/**
 * Notifies downstream listeners with the output of the finished Spark
 * classifier job: the textual model description goes to text listeners, and
 * the trained model (wrapped in a BatchClassifierEvent) goes to classifier
 * listeners.
 */
@Override
protected void notifyJobOutputListeners() {
  // Pull the trained model, its training header, and the class attribute
  // name from the completed Spark job.
  weka.classifiers.Classifier finalClassifier =
      ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getClassifier();
  Instances modelHeader =
      ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getTrainingHeader();
  String classAtt =
      ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getClassAttribute();
  try {
    weka.distributed.spark.WekaClassifierSparkJob.setClassIndex(classAtt, modelHeader, true);
  } catch (Exception ex) {
    // Best-effort: log and continue; the header may still be usable.
    if (m_log != null) {
      m_log.logMessage(statusMessagePrefix() + ex.getMessage());
    }
    ex.printStackTrace();
  }
  // Missing outputs are logged but do not abort — listeners are simply not
  // notified for the missing parts below.
  if (finalClassifier == null) {
    if (m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No classifier produced!");
    }
  }
  if (modelHeader == null) {
    if (m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No training header available for the model!");
    }
  }
  if (finalClassifier != null) {
    if (m_textListeners.size() > 0) {
      // Title is "Spark: " plus the classifier spec (class name + options).
      String textual = finalClassifier.toString();
      String title = "Spark: ";
      String classifierSpec = finalClassifier.getClass().getName();
      if (finalClassifier instanceof OptionHandler) {
        classifierSpec +=
            " " + Utils.joinOptions(((OptionHandler) finalClassifier).getOptions());
      }
      title += classifierSpec;
      TextEvent te = new TextEvent(this, textual, title);
      for (TextListener t : m_textListeners) {
        t.acceptText(te);
      }
    }
    if (modelHeader != null) {
      // have to add a single bogus instance to the header to trick
      // the SerializedModelSaver into saving it (since it ignores
      // structure only DataSetEvents) :-)
      double[] vals = new double[modelHeader.numAttributes()];
      for (int i = 0; i < vals.length; i++) {
        vals[i] = Utils.missingValue();
      }
      Instance tempI = new DenseInstance(1.0, vals);
      modelHeader.add(tempI);
      DataSetEvent dse = new DataSetEvent(this, modelHeader);
      BatchClassifierEvent be = new BatchClassifierEvent(this, finalClassifier, dse, dse, 1, 1);
      for (BatchClassifierListener b : m_classifierListeners) {
        b.acceptClassifier(be);
      }
    }
  }
}