/**
 * Classifies the given test instance by reducing the output of distributionForInstance() to a
 * single prediction. The instance has to belong to a dataset when it is being classified. Note
 * that a classifier MUST implement either this or distributionForInstance().
 *
 * <p>Nominal class: the index of the first, strictly largest, positive entry of the
 * distribution; missing if no entry is positive. Numeric or date class: the first entry of the
 * distribution. Any other class type: missing.
 *
 * @param instance the instance to be classified
 * @return the predicted most likely class for the instance or Utils.missingValue() if no
 *     prediction is made
 * @exception Exception if an error occurred during the prediction, or the predicted
 *     distribution is null
 */
@Override
public double classifyInstance(Instance instance) throws Exception {
  final double[] distribution = distributionForInstance(instance);
  if (distribution == null) {
    throw new Exception("Null distribution predicted");
  }

  final int classType = instance.classAttribute().type();
  if (classType == Attribute.NUMERIC || classType == Attribute.DATE) {
    // Regression-style classifiers put their single prediction in slot 0.
    return distribution[0];
  }
  if (classType != Attribute.NOMINAL) {
    return Utils.missingValue();
  }

  // Strict '>' against a running best that starts at 0: ties keep the earliest index, and an
  // all-zero (or all-NaN) distribution yields no winner.
  int bestIndex = -1;
  double bestProb = 0;
  for (int k = 0; k < distribution.length; k++) {
    if (distribution[k] > bestProb) {
      bestProb = distribution[k];
      bestIndex = k;
    }
  }
  return (bestIndex < 0) ? Utils.missingValue() : bestIndex;
}
/** * Convert a single instance over. The converted instance is added to the end of the output queue. * * @param instance the instance to convert */ protected void convertInstance(Instance instance) { int index = 0; double[] vals = new double[outputFormatPeek().numAttributes()]; // Copy and convert the values for (int i = 0; i < getInputFormat().numAttributes(); i++) { if (m_DiscretizeCols.isInRange(i) && getInputFormat().attribute(i).isNumeric()) { int j; double currentVal = instance.value(i); if (m_CutPoints[i] == null) { if (instance.isMissing(i)) { vals[index] = Utils.missingValue(); } else { vals[index] = 0; } index++; } else { if (!m_MakeBinary) { if (instance.isMissing(i)) { vals[index] = Utils.missingValue(); } else { for (j = 0; j < m_CutPoints[i].length; j++) { if (currentVal <= m_CutPoints[i][j]) { break; } } vals[index] = j; } index++; } else { for (j = 0; j < m_CutPoints[i].length; j++) { if (instance.isMissing(i)) { vals[index] = Utils.missingValue(); } else if (currentVal <= m_CutPoints[i][j]) { vals[index] = 0; } else { vals[index] = 1; } index++; } } } } else { vals[index] = instance.value(i); index++; } } Instance inst = null; if (instance instanceof SparseInstance) { inst = new SparseInstance(instance.weight(), vals); } else { inst = new DenseInstance(instance.weight(), vals); } inst.setDataset(getOutputFormat()); copyValues(inst, false, instance.dataset(), getOutputFormat()); inst.setDataset(getOutputFormat()); push(inst); }
/** * Adds the prediction intervals as additional attributes at the end. Since classifiers can * returns varying number of intervals per instance, the dataset is filled with missing values for * non-existing intervals. */ protected void addPredictionIntervals() { int maxNum; int num; int i; int n; FastVector preds; FastVector atts; Instances data; Instance inst; Instance newInst; double[] values; double[][] predInt; // determine the maximum number of intervals maxNum = 0; preds = m_Evaluation.predictions(); for (i = 0; i < preds.size(); i++) { num = ((NumericPrediction) preds.elementAt(i)).predictionIntervals().length; if (num > maxNum) maxNum = num; } // create new header atts = new FastVector(); for (i = 0; i < m_PlotInstances.numAttributes(); i++) atts.addElement(m_PlotInstances.attribute(i)); for (i = 0; i < maxNum; i++) { atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-lowerBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-upperBoundary")); atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-width")); } data = new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances()); data.setClassIndex(m_PlotInstances.classIndex()); // update data for (i = 0; i < m_PlotInstances.numInstances(); i++) { inst = m_PlotInstances.instance(i); // copy old values values = new double[data.numAttributes()]; System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes()); // add interval data predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals(); for (n = 0; n < maxNum; n++) { if (n < predInt.length) { values[m_PlotInstances.numAttributes() + n * 3 + 0] = predInt[n][0]; values[m_PlotInstances.numAttributes() + n * 3 + 1] = predInt[n][1]; values[m_PlotInstances.numAttributes() + n * 3 + 2] = predInt[n][1] - predInt[n][0]; } else { values[m_PlotInstances.numAttributes() + n * 3 + 0] = Utils.missingValue(); values[m_PlotInstances.numAttributes() + n * 3 + 1] = 
Utils.missingValue(); values[m_PlotInstances.numAttributes() + n * 3 + 2] = Utils.missingValue(); } } // create new Instance newInst = new DenseInstance(inst.weight(), values); data.add(newInst); } m_PlotInstances = data; }
/** * Constructs an instance suitable for passing to the model for scoring * * @param incoming the incoming instance * @return an instance with values mapped to be consistent with what the model is expecting */ protected Instance mapIncomingFieldsToModelFields(Instance incoming) { Instances modelHeader = m_model.getHeader(); double[] vals = new double[modelHeader.numAttributes()]; for (int i = 0; i < modelHeader.numAttributes(); i++) { if (m_attributeMap[i] < 0) { // missing or type mismatch vals[i] = Utils.missingValue(); continue; } Attribute modelAtt = modelHeader.attribute(i); Attribute incomingAtt = incoming.dataset().attribute(m_attributeMap[i]); if (incoming.isMissing(incomingAtt.index())) { vals[i] = Utils.missingValue(); continue; } if (modelAtt.isNumeric()) { vals[i] = incoming.value(m_attributeMap[i]); } else if (modelAtt.isNominal()) { String incomingVal = incoming.stringValue(m_attributeMap[i]); int modelIndex = modelAtt.indexOfValue(incomingVal); if (modelIndex < 0) { vals[i] = Utils.missingValue(); } else { vals[i] = modelIndex; } } else if (modelAtt.isString()) { vals[i] = 0; modelAtt.setStringValue(incoming.stringValue(m_attributeMap[i])); } } if (modelHeader.classIndex() >= 0) { // set class to missing value vals[modelHeader.classIndex()] = Utils.missingValue(); } Instance newInst = null; if (incoming instanceof SparseInstance) { newInst = new SparseInstance(incoming.weight(), vals); } else { newInst = new DenseInstance(incoming.weight(), vals); } newInst.setDataset(modelHeader); return newInst; }
/**
 * Converts an input instance to the output format.
 *
 * <p>Each non-class attribute is handled in one of two ways:
 *
 * <ul>
 *   <li>If it is listed in {@code m_unchanged}, it is copied verbatim.
 *   <li>Otherwise it is expanded into one value per class label: the probability that the k-th
 *       estimator registered for the attribute name assigns to the attribute's value. A missing
 *       input gives missing values for all of these slots.
 * </ul>
 *
 * The class value is written to the last output slot.
 *
 * @param current the input instance to convert
 * @return a transformed instance (no dataset is set on it)
 * @throws Exception if a problem occurs
 */
protected Instance convertInstance(Instance current) throws Exception {
  final double[] out = new double[getOutputFormat().numAttributes()];
  int pos = 0;

  for (int att = 0; att < current.numAttributes(); att++) {
    if (att == current.classIndex()) {
      continue;
    }
    final String name = current.attribute(att).name();
    if (m_unchanged != null && m_unchanged.attribute(name) != null) {
      out[pos++] = current.value(att);
      continue;
    }

    final Estimator[] estimators = m_estimatorLookup.get(name);
    final int numClasses = current.classAttribute().numValues();
    final boolean missing = current.isMissing(att);
    for (int c = 0; c < numClasses; c++) {
      out[pos++] =
          missing ? Utils.missingValue() : estimators[c].getProbability(current.value(att));
    }
  }

  // The output format is expected to end with the class attribute.
  out[out.length - 1] = current.classValue();
  return new DenseInstance(current.weight(), out);
}
/**
 * Builds one threshold-curve row from the given confusion statistics.
 *
 * <p>The 13 values are, in order:
 *
 * <ol>
 *   <li>the four confusion counts: TP, FN, FP, TN;
 *   <li>FP rate and TP rate;
 *   <li>precision, recall, fallout and F-measure;
 *   <li>sample size, i.e. the fraction of instances predicted positive, {@code (TP+FP)/total};
 *   <li>lift, {@code TP / (sampleSize * (TP+FN))}, or missing when that expected count is
 *       below 1;
 *   <li>the given probability threshold.
 * </ol>
 *
 * @param tc the statistics
 * @param prob the probability
 * @return the generated instance, with weight 1.0
 */
private Instance makeInstance(TwoClassStats tc, double prob) {
  final double tp = tc.getTruePositive();
  final double fn = tc.getFalseNegative();
  final double fp = tc.getFalsePositive();
  final double tn = tc.getTrueNegative();

  final double sampleSize = (tp + fp) / (tp + fp + tn + fn);
  final double expectedByChance = sampleSize * (tp + fn);
  final double lift = (expectedByChance < 1) ? Utils.missingValue() : tp / expectedByChance;

  final double[] vals = {
    tp,
    fn,
    fp,
    tn,
    tc.getFalsePositiveRate(),
    tc.getTruePositiveRate(),
    tc.getPrecision(),
    tc.getRecall(),
    tc.getFallout(),
    tc.getFMeasure(),
    sampleSize,
    lift,
    prob
  };
  return new DenseInstance(1.0, vals);
}
/** * Calculates the area under the precision-recall curve (AUPRC). * * @param tcurve a previously extracted threshold curve Instances. * @return the PRC area, or Double.NaN if you don't pass in a ThresholdCurve generated Instances. */ public static double getPRCArea(Instances tcurve) { final int n = tcurve.numInstances(); if (!RELATION_NAME.equals(tcurve.relationName()) || (n == 0)) { return Double.NaN; } final int pInd = tcurve.attribute(PRECISION_NAME).index(); final int rInd = tcurve.attribute(RECALL_NAME).index(); final double[] pVals = tcurve.attributeToDoubleArray(pInd); final double[] rVals = tcurve.attributeToDoubleArray(rInd); double area = 0; double xlast = rVals[n - 1]; // start from the first real p/r pair (not the artificial zero point) for (int i = n - 2; i >= 0; i--) { double recallDelta = rVals[i] - xlast; area += (pVals[i] * recallDelta); xlast = rVals[i]; } if (area == 0) { return Utils.missingValue(); } return area; }
/**
 * Input an instance for filtering. Ordinarily the instance is processed and made available for
 * output immediately. Some filters require all instances be read before producing output.
 *
 * <p>If no attributes are selected, the instance is pushed through untouched. Otherwise, each
 * value of a selected attribute is handled as follows:
 *
 * <ul>
 *   <li>If its label has an entry in the rename map (looked up lower-cased when ignoring case),
 *       it is replaced by the index of the replacement label in the output format.
 *   <li>Missing values, unmapped labels and unselected attributes are copied unchanged.
 * </ul>
 *
 * @param instance the input instance
 * @return true if the filtered instance may now be collected with output(); false if the output
 *     format has no attributes
 * @throws IllegalStateException if no input structure has been defined.
 */
@Override
public boolean input(Instance instance) {
  if (getInputFormat() == null) {
    throw new IllegalStateException("No input instance format defined");
  }
  if (m_NewBatch) {
    resetQueue();
    m_NewBatch = false;
  }
  if (getOutputFormat().numAttributes() == 0) {
    return false;
  }
  if (m_selectedAttributes.length == 0) {
    // nothing to rename
    push(instance);
    return true;
  }

  double[] vals = new double[getOutputFormat().numAttributes()];
  for (int i = 0; i < instance.numAttributes(); i++) {
    double currentV = instance.value(i);
    // Missing values must be detected with isMissingValue(). An equality test against
    // Utils.missingValue() never matches a NaN marker (NaN == NaN is false), which used to send
    // missing values down the rename path as label index (int) NaN == 0.
    if (!m_selectedCols.isInRange(i) || Utils.isMissingValue(currentV)) {
      vals[i] = currentV;
      continue;
    }
    String currentS = instance.attribute(i).value((int) currentV);
    String replace =
        m_ignoreCase ? m_renameMap.get(currentS.toLowerCase()) : m_renameMap.get(currentS);
    vals[i] =
        (replace == null) ? currentV : getOutputFormat().attribute(i).indexOfValue(replace);
  }

  Instance inst = null;
  if (instance instanceof SparseInstance) {
    inst = new SparseInstance(instance.weight(), vals);
  } else {
    inst = new DenseInstance(instance.weight(), vals);
  }
  inst.setDataset(getOutputFormat());
  copyValues(inst, false, instance.dataset(), getOutputFormat());
  inst.setDataset(getOutputFormat());
  push(inst);
  return true;
}
/** * Method for building an Id3 tree. * * @param data the training data * @exception Exception if decision tree can't be built successfully */ private void makeTree(Instances data) throws Exception { // Check if no instances have reached this node. if (data.numInstances() == 0) { m_Attribute = null; m_ClassValue = Utils.missingValue(); m_Distribution = new double[data.numClasses()]; return; } // Compute attribute with maximum information gain. double[] infoGains = new double[data.numAttributes()]; Enumeration attEnum = data.enumerateAttributes(); while (attEnum.hasMoreElements()) { Attribute att = (Attribute) attEnum.nextElement(); infoGains[att.index()] = computeInfoGain(data, att); } m_Attribute = data.attribute(Utils.maxIndex(infoGains)); // Make leaf if information gain is zero. // Otherwise create successors. if (Utils.eq(infoGains[m_Attribute.index()], 0)) { m_Attribute = null; m_Distribution = new double[data.numClasses()]; Enumeration instEnum = data.enumerateInstances(); while (instEnum.hasMoreElements()) { Instance inst = (Instance) instEnum.nextElement(); m_Distribution[(int) inst.classValue()]++; } Utils.normalize(m_Distribution); m_ClassValue = Utils.maxIndex(m_Distribution); m_ClassAttribute = data.classAttribute(); } else { Instances[] splitData = splitData(data, m_Attribute); m_Successors = new Id3[m_Attribute.numValues()]; for (int j = 0; j < m_Attribute.numValues(); j++) { m_Successors[j] = new Id3(); m_Successors[j].makeTree(splitData[j]); } } }
/**
 * Return the full data set. If the structure hasn't yet been determined by a call to getStructure
 * then method should do so before processing the rest of the data set.
 *
 * <p>All remaining rows are read from the tokenizer first. Rows are collected via getInstance()
 * into {@code m_cumulativeInstances}. {@code m_cumulativeStructure} holds one value-to-index table
 * per column; it is allocated empty here and is read back afterwards, so it is presumably
 * populated by getInstance() (not visible here — verify).
 *
 * <p>The final header is then derived from those tables:
 *
 * <ul>
 *   <li>no recorded values: numeric attribute;
 *   <li>column in {@code m_StringAttributes}: string attribute;
 *   <li>otherwise: nominal attribute whose labels are ordered by the recorded index.
 * </ul>
 *
 * <p>Afterwards {@code m_structure} is replaced by an empty copy of the result, retrieval mode
 * is set to BATCH, and the source reader is closed.
 *
 * @return the full data set (the original text claiming an empty set of Instances was wrong)
 * @exception IOException if there is no source or parsing fails
 */
@Override
public Instances getDataSet() throws IOException {
  if ((m_sourceFile == null) && (m_sourceReader == null)) {
    throw new IOException("No source has been specified");
  }
  if (m_structure == null) {
    getStructure();
  }
  if (m_st == null) {
    m_st = new StreamTokenizer(m_sourceReader);
    initTokenizer(m_st);
  }
  // The field separator must be tokenized as a delimiter character, not as part of a word.
  m_st.ordinaryChar(m_FieldSeparator.charAt(0));
  // One value -> index table per column.
  m_cumulativeStructure = new ArrayList<Hashtable<Object, Integer>>(m_structure.numAttributes());
  for (int i = 0; i < m_structure.numAttributes(); i++) {
    m_cumulativeStructure.add(new Hashtable<Object, Integer>());
  }
  // First pass: buffer every remaining row in memory.
  m_cumulativeInstances = new ArrayList<ArrayList<Object>>();
  ArrayList<Object> current;
  while ((current = getInstance(m_st)) != null) {
    m_cumulativeInstances.add(current);
  }
  // Derive the final attribute types from the per-column value tables.
  ArrayList<Attribute> atts = new ArrayList<Attribute>(m_structure.numAttributes());
  for (int i = 0; i < m_structure.numAttributes(); i++) {
    String attname = m_structure.attribute(i).name();
    Hashtable<Object, Integer> tempHash = m_cumulativeStructure.get(i);
    if (tempHash.size() == 0) {
      // No values recorded for this column -> numeric attribute.
      atts.add(new Attribute(attname));
    } else {
      if (m_StringAttributes.isInRange(i)) {
        atts.add(new Attribute(attname, (ArrayList<String>) null));
      } else {
        ArrayList<String> values = new ArrayList<String>(tempHash.size());
        // add dummy objects in order to make the ArrayList's size == capacity
        for (int z = 0; z < tempHash.size(); z++) {
          values.add("dummy");
        }
        // Place each label at the index recorded for it, so label order matches the indices
        // used when the values are encoded below.
        Enumeration e = tempHash.keys();
        while (e.hasMoreElements()) {
          Object ob = e.nextElement();
          // if (ob instanceof Double) {
          int index = ((Integer) tempHash.get(ob)).intValue();
          String s = ob.toString();
          // NOTE(review): strips the first and last character whenever the label starts with a
          // quote; assumes a matching closing quote and length >= 2 (a lone quote would throw).
          if (s.startsWith("'") || s.startsWith("\""))
            s = s.substring(1, s.length() - 1);
          values.set(index, new String(s));
          // }
        }
        atts.add(new Attribute(attname, values));
      }
    }
  }

  // make the instances
  String relationName;
  if (m_sourceFile != null)
    relationName = (m_sourceFile.getName()).replaceAll("\\.[cC][sS][vV]$", "");
  else
    relationName = "stream";
  Instances dataSet = new Instances(relationName, atts, m_cumulativeInstances.size());
  // Second pass: encode every buffered row against the final header.
  for (int i = 0; i < m_cumulativeInstances.size(); i++) {
    current = m_cumulativeInstances.get(i);
    double[] vals = new double[dataSet.numAttributes()];
    for (int j = 0; j < current.size(); j++) {
      Object cval = current.get(j);
      if (cval instanceof String) {
        if (((String) cval).compareTo(m_MissingValue) == 0) {
          vals[j] = Utils.missingValue();
        } else {
          if (dataSet.attribute(j).isString()) {
            vals[j] = dataSet.attribute(j).addStringValue((String) cval);
          } else if (dataSet.attribute(j).isNominal()) {
            // find correct index
            Hashtable<Object, Integer> lookup = m_cumulativeStructure.get(j);
            int index = ((Integer) lookup.get(cval)).intValue();
            vals[j] = index;
          } else {
            // NOTE(review): (i + 1) is the row number, not the attribute position the message
            // claims; the offending column is j.
            throw new IllegalStateException(
                "Wrong attribute type at position " + (i + 1) + "!!!");
          }
        }
      } else if (dataSet.attribute(j).isNominal()) {
        // find correct index
        Hashtable<Object, Integer> lookup = m_cumulativeStructure.get(j);
        int index = ((Integer) lookup.get(cval)).intValue();
        vals[j] = index;
      } else if (dataSet.attribute(j).isString()) {
        vals[j] = dataSet.attribute(j).addStringValue("" + cval);
      } else {
        vals[j] = ((Double) cval).doubleValue();
      }
    }
    dataSet.add(new DenseInstance(1.0, vals));
  }
  m_structure = new Instances(dataSet, 0);
  setRetrieval(BATCH);
  m_cumulativeStructure = null; // conserve memory
  // close the stream
  // NOTE(review): not in a finally block, so the reader stays open if parsing throws above.
  m_sourceReader.close();
  return dataSet;
}
protected Instance makeInstance() throws IOException { if (m_current == null) { return null; } double[] vals = new double[m_structure.numAttributes()]; for (int i = 0; i < m_structure.numAttributes(); i++) { Object val = m_current.get(i); if (val.toString().equals("?")) { vals[i] = Utils.missingValue(); } else if (m_structure.attribute(i).isString()) { vals[i] = 0; m_structure.attribute(i).setStringValue(Utils.unquote(val.toString())); } else if (m_structure.attribute(i).isDate()) { String format = m_structure.attribute(i).getDateFormat(); SimpleDateFormat sdf = new SimpleDateFormat(format); String dateVal = Utils.unquote(val.toString()); try { vals[i] = sdf.parse(dateVal).getTime(); } catch (ParseException e) { throw new IOException( "Unable to parse date value " + dateVal + " using date format " + format + " for date attribute " + m_structure.attribute(i) + " (line: " + m_rowCount + ")"); } } else if (m_structure.attribute(i).isNumeric()) { try { Double v = Double.parseDouble(val.toString()); vals[i] = v.doubleValue(); } catch (NumberFormatException ex) { throw new IOException( "Was expecting a number for attribute " + m_structure.attribute(i).name() + " but read " + val.toString() + " instead. (line: " + m_rowCount + ")"); } } else { // nominal double index = m_structure.attribute(i).indexOfValue(Utils.unquote(val.toString())); if (index < 0) { throw new IOException( "Read unknown nominal value " + val.toString() + "for attribute " + m_structure.attribute(i).name() + " (line: " + m_rowCount + "). Try increasing the size of the memory buffer" + " (-B option) or explicitly specify legal nominal values with " + "the -L option."); } vals[i] = index; } } DenseInstance inst = new DenseInstance(1.0, vals); inst.setDataset(m_structure); return inst; }
/**
 * Converts a single instance and adds it to the end of the output queue.
 *
 * <p>Every value of a selected attribute that is numeric, non-missing and not the class is
 * replaced by the result of evaluating the expression with these bindings:
 *
 * <ul>
 *   <li>A: the value itself;
 *   <li>MAX, MIN, MEAN, SD, COUNT, SUM, SUMSQUARED: the attribute's statistics.
 * </ul>
 *
 * A NaN or infinite result is replaced by a missing value, with a warning on stderr. All other
 * values are kept unchanged, for both sparse and dense input. Sparse input yields sparse output,
 * and zero results are not stored explicitly.
 *
 * @param instance the instance to convert
 * @throws Exception if instance cannot be converted
 */
private void convertInstance(Instance instance) throws Exception {
  Instance inst;
  HashMap symbols = new HashMap(5);
  double[] vals = instance.toDoubleArray();
  int classIndex = getInputFormat().classIndex();

  if (instance instanceof SparseInstance) {
    int numAtts = instance.numAttributes();
    double[] newVals = new double[numAtts];
    int[] newIndices = new int[numAtts];
    int ind = 0;
    for (int j = 0; j < numAtts; j++) {
      double value = vals[j];
      if (m_SelectCols.isInRange(j)
          && instance.attribute(j).isNumeric()
          && !Utils.isMissingValue(value)
          && classIndex != j) {
        value = evaluateForAttribute(symbols, j, value);
        if (Double.isNaN(value) || Double.isInfinite(value)) {
          System.err.println("WARNING:Error in evaluating the expression: missing value set");
          value = Utils.missingValue();
        }
      }
      // Untransformed values (unselected, non-numeric, missing or class) are kept, just like in
      // the dense branch; previously selected-but-untransformable values were dropped (-> 0).
      if (value != 0.0) {
        newVals[ind] = value;
        newIndices[ind] = j;
        ind++;
      }
    }
    double[] tempVals = new double[ind];
    int[] tempInd = new int[ind];
    System.arraycopy(newVals, 0, tempVals, 0, ind);
    System.arraycopy(newIndices, 0, tempInd, 0, ind);
    inst = new SparseInstance(instance.weight(), tempVals, tempInd, numAtts);
  } else {
    for (int j = 0; j < getInputFormat().numAttributes(); j++) {
      if (m_SelectCols.isInRange(j)
          && instance.attribute(j).isNumeric()
          && !Utils.isMissingValue(vals[j])
          && classIndex != j) {
        vals[j] = evaluateForAttribute(symbols, j, vals[j]);
        if (Double.isNaN(vals[j]) || Double.isInfinite(vals[j])) {
          System.err.println("WARNING:Error in Evaluation the Expression: missing value set");
          vals[j] = Utils.missingValue();
        }
      }
    }
    inst = new DenseInstance(instance.weight(), vals);
  }
  inst.setDataset(instance.dataset());
  push(inst);
}

/**
 * Binds the expression variables for attribute {@code j} and evaluates the expression. A is
 * bound to the current value; MAX, MIN, MEAN, SD, COUNT, SUM and SUMSQUARED are bound to the
 * attribute's numeric statistics.
 */
private double evaluateForAttribute(HashMap symbols, int j, double a) throws Exception {
  symbols.put("A", Double.valueOf(a));
  symbols.put("MAX", Double.valueOf(m_attStats[j].numericStats.max));
  symbols.put("MIN", Double.valueOf(m_attStats[j].numericStats.min));
  symbols.put("MEAN", Double.valueOf(m_attStats[j].numericStats.mean));
  symbols.put("SD", Double.valueOf(m_attStats[j].numericStats.stdDev));
  symbols.put("COUNT", Double.valueOf(m_attStats[j].numericStats.count));
  symbols.put("SUM", Double.valueOf(m_attStats[j].numericStats.sum));
  symbols.put("SUMSQUARED", Double.valueOf(m_attStats[j].numericStats.sumSq));
  return eval(symbols);
}
@Override protected void notifyJobOutputListeners() { weka.classifiers.Classifier finalClassifier = ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getClassifier(); Instances modelHeader = ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getTrainingHeader(); String classAtt = ((weka.distributed.spark.WekaClassifierSparkJob) m_runningJob).getClassAttribute(); try { weka.distributed.spark.WekaClassifierSparkJob.setClassIndex(classAtt, modelHeader, true); } catch (Exception ex) { if (m_log != null) { m_log.logMessage(statusMessagePrefix() + ex.getMessage()); } ex.printStackTrace(); } if (finalClassifier == null) { if (m_log != null) { m_log.logMessage(statusMessagePrefix() + "No classifier produced!"); } } if (modelHeader == null) { if (m_log != null) { m_log.logMessage(statusMessagePrefix() + "No training header available for the model!"); } } if (finalClassifier != null) { if (m_textListeners.size() > 0) { String textual = finalClassifier.toString(); String title = "Spark: "; String classifierSpec = finalClassifier.getClass().getName(); if (finalClassifier instanceof OptionHandler) { classifierSpec += " " + Utils.joinOptions(((OptionHandler) finalClassifier).getOptions()); } title += classifierSpec; TextEvent te = new TextEvent(this, textual, title); for (TextListener t : m_textListeners) { t.acceptText(te); } } if (modelHeader != null) { // have to add a single bogus instance to the header to trick // the SerializedModelSaver into saving it (since it ignores // structure only DataSetEvents) :-) double[] vals = new double[modelHeader.numAttributes()]; for (int i = 0; i < vals.length; i++) { vals[i] = Utils.missingValue(); } Instance tempI = new DenseInstance(1.0, vals); modelHeader.add(tempI); DataSetEvent dse = new DataSetEvent(this, modelHeader); BatchClassifierEvent be = new BatchClassifierEvent(this, finalClassifier, dse, dse, 1, 1); for (BatchClassifierListener b : m_classifierListeners) { b.acceptClassifier(be); } } } }
/** * Gets the results for the supplied train and test datasets. Now performs a deep copy of the * classifier before it is built and evaluated (just in case the classifier is not initialized * properly in buildClassifier()). * * @param train the training Instances. * @param test the testing Instances. * @return the results stored in an array. The objects stored in the array may be Strings, * Doubles, or null (for the missing value). * @throws Exception if a problem occurs while getting the results */ public Object[] getResult(Instances train, Instances test) throws Exception { if (train.classAttribute().type() != Attribute.NUMERIC) { throw new Exception("Class attribute is not numeric!"); } if (m_Template == null) { throw new Exception("No classifier has been specified"); } ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean(); boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported(); if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled()) thMonitor.setThreadCpuTimeEnabled(true); int addm = (m_AdditionalMeasures != null) ? 
m_AdditionalMeasures.length : 0; Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics]; long thID = Thread.currentThread().getId(); long CPUStartTime = -1, trainCPUTimeElapsed = -1, testCPUTimeElapsed = -1, trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed; Evaluation eval = new Evaluation(train); m_Classifier = AbstractClassifier.makeCopy(m_Template); trainTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); m_Classifier.buildClassifier(train); if (canMeasureCPUTime) trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; trainTimeElapsed = System.currentTimeMillis() - trainTimeStart; testTimeStart = System.currentTimeMillis(); if (canMeasureCPUTime) CPUStartTime = thMonitor.getThreadUserTime(thID); eval.evaluateModel(m_Classifier, test); if (canMeasureCPUTime) testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime; testTimeElapsed = System.currentTimeMillis() - testTimeStart; thMonitor = null; m_result = eval.toSummaryString(); // The results stored are all per instance -- can be multiplied by the // number of instances to get absolute numbers int current = 0; result[current++] = new Double(train.numInstances()); result[current++] = new Double(eval.numInstances()); result[current++] = new Double(eval.meanAbsoluteError()); result[current++] = new Double(eval.rootMeanSquaredError()); result[current++] = new Double(eval.relativeAbsoluteError()); result[current++] = new Double(eval.rootRelativeSquaredError()); result[current++] = new Double(eval.correlationCoefficient()); result[current++] = new Double(eval.SFPriorEntropy()); result[current++] = new Double(eval.SFSchemeEntropy()); result[current++] = new Double(eval.SFEntropyGain()); result[current++] = new Double(eval.SFMeanPriorEntropy()); result[current++] = new Double(eval.SFMeanSchemeEntropy()); result[current++] = new Double(eval.SFMeanEntropyGain()); // Timing stats result[current++] = 
new Double(trainTimeElapsed / 1000.0); result[current++] = new Double(testTimeElapsed / 1000.0); if (canMeasureCPUTime) { result[current++] = new Double((trainCPUTimeElapsed / 1000000.0) / 1000.0); result[current++] = new Double((testCPUTimeElapsed / 1000000.0) / 1000.0); } else { result[current++] = new Double(Utils.missingValue()); result[current++] = new Double(Utils.missingValue()); } // sizes if (m_NoSizeDetermination) { result[current++] = -1.0; result[current++] = -1.0; result[current++] = -1.0; } else { ByteArrayOutputStream bastream = new ByteArrayOutputStream(); ObjectOutputStream oostream = new ObjectOutputStream(bastream); oostream.writeObject(m_Classifier); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(train); result[current++] = new Double(bastream.size()); bastream = new ByteArrayOutputStream(); oostream = new ObjectOutputStream(bastream); oostream.writeObject(test); result[current++] = new Double(bastream.size()); } // Prediction interval statistics result[current++] = new Double(eval.coverageOfTestCasesByPredictedRegions()); result[current++] = new Double(eval.sizeOfPredictedRegions()); if (m_Classifier instanceof Summarizable) { result[current++] = ((Summarizable) m_Classifier).toSummaryString(); } else { result[current++] = null; } for (int i = 0; i < addm; i++) { if (m_doesProduce[i]) { try { double dv = ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]); if (!Utils.isMissingValue(dv)) { Double value = new Double(dv); result[current++] = value; } else { result[current++] = null; } } catch (Exception ex) { System.err.println(ex); } } else { result[current++] = null; } } // get the actual metrics from the evaluation object List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics(); if (metrics != null) { for (AbstractEvaluationMetric m : metrics) { if (m.appliesToNumericClass()) { List<String> statNames = 
m.getStatisticNames(); for (String s : statNames) { result[current++] = new Double(m.getStatistic(s)); } } } } if (current != RESULT_SIZE + addm + m_numPluginStatistics) { throw new Error("Results didn't fit RESULT_SIZE"); } return result; }
/**
 * Classifies a given instance by delegating to the wrapped Groovy classifier.
 *
 * @param instance the instance to be classified
 * @return the Groovy classifier's prediction, or Utils.missingValue() if no Groovy object is
 *     set
 * @throws Exception if an error occurred during the prediction
 */
public double classifyInstance(Instance instance) throws Exception {
  return (m_GroovyObject == null)
      ? Utils.missingValue()
      : m_GroovyObject.classifyInstance(instance);
}
public static Instances retrieveInstances(InstanceQueryAdapter adapter, ResultSet rs) throws Exception { if (adapter.getDebug()) System.err.println("Getting metadata..."); ResultSetMetaData md = rs.getMetaData(); if (adapter.getDebug()) System.err.println("Completed getting metadata..."); // Determine structure of the instances int numAttributes = md.getColumnCount(); int[] attributeTypes = new int[numAttributes]; Hashtable[] nominalIndexes = new Hashtable[numAttributes]; FastVector[] nominalStrings = new FastVector[numAttributes]; for (int i = 1; i <= numAttributes; i++) { /* switch (md.getColumnType(i)) { case Types.CHAR: case Types.VARCHAR: case Types.LONGVARCHAR: case Types.BINARY: case Types.VARBINARY: case Types.LONGVARBINARY:*/ switch (adapter.translateDBColumnType(md.getColumnTypeName(i))) { case STRING: // System.err.println("String --> nominal"); attributeTypes[i - 1] = Attribute.NOMINAL; nominalIndexes[i - 1] = new Hashtable(); nominalStrings[i - 1] = new FastVector(); break; case TEXT: // System.err.println("Text --> string"); attributeTypes[i - 1] = Attribute.STRING; nominalIndexes[i - 1] = new Hashtable(); nominalStrings[i - 1] = new FastVector(); break; case BOOL: // System.err.println("boolean --> nominal"); attributeTypes[i - 1] = Attribute.NOMINAL; nominalIndexes[i - 1] = new Hashtable(); nominalIndexes[i - 1].put("false", new Double(0)); nominalIndexes[i - 1].put("true", new Double(1)); nominalStrings[i - 1] = new FastVector(); nominalStrings[i - 1].addElement("false"); nominalStrings[i - 1].addElement("true"); break; case DOUBLE: // System.err.println("BigDecimal --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case BYTE: // System.err.println("byte --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case SHORT: // System.err.println("short --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case INTEGER: // System.err.println("int --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case 
LONG: // System.err.println("long --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case FLOAT: // System.err.println("float --> numeric"); attributeTypes[i - 1] = Attribute.NUMERIC; break; case DATE: attributeTypes[i - 1] = Attribute.DATE; break; case TIME: attributeTypes[i - 1] = Attribute.DATE; break; default: // System.err.println("Unknown column type"); attributeTypes[i - 1] = Attribute.STRING; } } // For sqlite // cache column names because the last while(rs.next()) { iteration for // the tuples below will close the md object: Vector<String> columnNames = new Vector<String>(); for (int i = 0; i < numAttributes; i++) { columnNames.add(md.getColumnLabel(i + 1)); } // Step through the tuples if (adapter.getDebug()) System.err.println("Creating instances..."); FastVector instances = new FastVector(); int rowCount = 0; while (rs.next()) { if (rowCount % 100 == 0) { if (adapter.getDebug()) { System.err.print("read " + rowCount + " instances \r"); System.err.flush(); } } double[] vals = new double[numAttributes]; for (int i = 1; i <= numAttributes; i++) { /*switch (md.getColumnType(i)) { case Types.CHAR: case Types.VARCHAR: case Types.LONGVARCHAR: case Types.BINARY: case Types.VARBINARY: case Types.LONGVARBINARY:*/ switch (adapter.translateDBColumnType(md.getColumnTypeName(i))) { case STRING: String str = rs.getString(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { Double index = (Double) nominalIndexes[i - 1].get(str); if (index == null) { index = new Double(nominalStrings[i - 1].size()); nominalIndexes[i - 1].put(str, index); nominalStrings[i - 1].addElement(str); } vals[i - 1] = index.doubleValue(); } break; case TEXT: String txt = rs.getString(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { Double index = (Double) nominalIndexes[i - 1].get(txt); if (index == null) { index = new Double(nominalStrings[i - 1].size()); nominalIndexes[i - 1].put(txt, index); nominalStrings[i - 1].addElement(txt); } vals[i - 1] 
= index.doubleValue(); } break; case BOOL: boolean boo = rs.getBoolean(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (boo ? 1.0 : 0.0); } break; case DOUBLE: // BigDecimal bd = rs.getBigDecimal(i, 4); double dd = rs.getDouble(i); // Use the column precision instead of 4? if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { // newInst.setValue(i - 1, bd.doubleValue()); vals[i - 1] = dd; } break; case BYTE: byte by = rs.getByte(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (double) by; } break; case SHORT: short sh = rs.getShort(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (double) sh; } break; case INTEGER: int in = rs.getInt(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (double) in; } break; case LONG: long lo = rs.getLong(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (double) lo; } break; case FLOAT: float fl = rs.getFloat(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { vals[i - 1] = (double) fl; } break; case DATE: Date date = rs.getDate(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { // TODO: Do a value check here. vals[i - 1] = (double) date.getTime(); } break; case TIME: Time time = rs.getTime(i); if (rs.wasNull()) { vals[i - 1] = Utils.missingValue(); } else { // TODO: Do a value check here. 
vals[i - 1] = (double) time.getTime(); } break; default: vals[i - 1] = Utils.missingValue(); } } Instance newInst; if (adapter.getSparseData()) { newInst = new SparseInstance(1.0, vals); } else { newInst = new DenseInstance(1.0, vals); } instances.addElement(newInst); rowCount++; } // disconnectFromDatabase(); (perhaps other queries might be made) // Create the header and add the instances to the dataset if (adapter.getDebug()) System.err.println("Creating header..."); FastVector attribInfo = new FastVector(); for (int i = 0; i < numAttributes; i++) { /* Fix for databases that uppercase column names */ // String attribName = attributeCaseFix(md.getColumnName(i + 1)); String attribName = adapter.attributeCaseFix(columnNames.get(i)); switch (attributeTypes[i]) { case Attribute.NOMINAL: attribInfo.addElement(new Attribute(attribName, nominalStrings[i])); break; case Attribute.NUMERIC: attribInfo.addElement(new Attribute(attribName)); break; case Attribute.STRING: Attribute att = new Attribute(attribName, (FastVector) null); attribInfo.addElement(att); for (int n = 0; n < nominalStrings[i].size(); n++) { att.addStringValue((String) nominalStrings[i].elementAt(n)); } break; case Attribute.DATE: attribInfo.addElement(new Attribute(attribName, (String) null)); break; default: throw new Exception("Unknown attribute type"); } } Instances result = new Instances("QueryResult", attribInfo, instances.size()); for (int i = 0; i < instances.size(); i++) { result.add((Instance) instances.elementAt(i)); } return result; }