/**
 * Computes the area under the ROC curve, expressed as the Wilcoxon-Mann-Whitney statistic.
 *
 * <p>The true-positive and false-positive columns are walked from the first row to the last. Each
 * step contributes the positives it drops, multiplied by the negatives already passed plus half
 * of the negatives dropped in the same step (ties count one half). The last row is compared
 * against an implicit (0, 0) point. The sum is normalised by the totals taken from the first row.
 *
 * @param tcurve threshold-curve Instances previously produced by this class
 * @return the ROC area, or Double.NaN if tcurve is empty or was not generated by ThresholdCurve
 */
public static double getROCArea(Instances tcurve) {
  final int numPoints = tcurve.numInstances();
  if (numPoints == 0 || !RELATION_NAME.equals(tcurve.relationName())) {
    return Double.NaN;
  }

  final double[] truePos = tcurve.attributeToDoubleArray(tcurve.attribute(TRUE_POS_NAME).index());
  final double[] falsePos = tcurve.attributeToDoubleArray(tcurve.attribute(FALSE_POS_NAME).index());
  final double positives = truePos[0];
  final double negatives = falsePos[0];

  double sum = 0.0;
  double negativesSoFar = 0.0;
  for (int k = 0; k < numPoints; k++) {
    final boolean hasNext = k + 1 < numPoints;
    final double droppedPos = truePos[k] - (hasNext ? truePos[k + 1] : 0.0);
    final double droppedNeg = falsePos[k] - (hasNext ? falsePos[k + 1] : 0.0);
    sum += droppedPos * (negativesSoFar + 0.5 * droppedNeg);
    negativesSoFar += droppedNeg;
  }
  return sum / (negatives * positives);
}
/** * Calculates the area under the precision-recall curve (AUPRC). * * @param tcurve a previously extracted threshold curve Instances. * @return the PRC area, or Double.NaN if you don't pass in a ThresholdCurve generated Instances. */ public static double getPRCArea(Instances tcurve) { final int n = tcurve.numInstances(); if (!RELATION_NAME.equals(tcurve.relationName()) || (n == 0)) { return Double.NaN; } final int pInd = tcurve.attribute(PRECISION_NAME).index(); final int rInd = tcurve.attribute(RECALL_NAME).index(); final double[] pVals = tcurve.attributeToDoubleArray(pInd); final double[] rVals = tcurve.attributeToDoubleArray(rInd); double area = 0; double xlast = rVals[n - 1]; // start from the first real p/r pair (not the artificial zero point) for (int i = n - 2; i >= 0; i--) { double recallDelta = rVals[i] - xlast; area += (pVals[i] * recallDelta); xlast = rVals[i]; } if (area == 0) { return Utils.missingValue(); } return area; }
/** * Builds the model of the base learner. * * @param data the training data * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); if (m_Classifier == null) { throw new Exception("No base classifier has been set!"); } if (m_MatrixSource == MATRIX_ON_DEMAND) { String costName = data.relationName() + CostMatrix.FILE_EXTENSION; File costFile = new File(getOnDemandDirectory(), costName); if (!costFile.exists()) { throw new Exception("On-demand cost file doesn't exist: " + costFile); } setCostMatrix(new CostMatrix(new BufferedReader(new FileReader(costFile)))); } else if (m_CostMatrix == null) { // try loading an old format cost file m_CostMatrix = new CostMatrix(data.numClasses()); m_CostMatrix.readOldFormat(new BufferedReader(new FileReader(m_CostFile))); } if (!m_MinimizeExpectedCost) { Random random = null; if (!(m_Classifier instanceof WeightedInstancesHandler)) { random = new Random(m_Seed); } data = m_CostMatrix.applyCostMatrix(data, random); } m_Classifier.buildClassifier(data); }
/** * Determines the output format based on the input format and returns this. In case the output * format cannot be returned immediately, i.e., immediateOutputFormat() returns false, then this * method will be called from batchFinished(). * * @param inputFormat the input format to base the output format on * @return the output format * @throws Exception in case the determination goes wrong * @see #hasImmediateOutputFormat() * @see #batchFinished() */ protected Instances determineOutputFormat(Instances inputFormat) throws Exception { Instances data; Instances result; FastVector atts; FastVector values; HashSet hash; int i; int n; boolean isDate; Instance inst; Vector sorted; m_Cols.setUpper(inputFormat.numAttributes() - 1); data = new Instances(inputFormat); atts = new FastVector(); for (i = 0; i < data.numAttributes(); i++) { if (!m_Cols.isInRange(i) || !data.attribute(i).isNumeric()) { atts.addElement(data.attribute(i)); continue; } // date attribute? isDate = (data.attribute(i).type() == Attribute.DATE); // determine all available attribtues in dataset hash = new HashSet(); for (n = 0; n < data.numInstances(); n++) { inst = data.instance(n); if (inst.isMissing(i)) continue; if (isDate) hash.add(inst.stringValue(i)); else hash.add(new Double(inst.value(i))); } // sort values sorted = new Vector(); for (Object o : hash) sorted.add(o); Collections.sort(sorted); // create attribute from sorted values values = new FastVector(); for (Object o : sorted) { if (isDate) values.addElement(o.toString()); else values.addElement(Utils.doubleToString(((Double) o).doubleValue(), MAX_DECIMALS)); } atts.addElement(new Attribute(data.attribute(i).name(), values)); } result = new Instances(inputFormat.relationName(), atts, 0); result.setClassIndex(inputFormat.classIndex()); return result; }
/** * Determines the output format based on the input format and returns this. In case the output * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then * this method will called from batchFinished() after the call of preprocess(Instances), in which, * e.g., statistics for the actual processing step can be gathered. * * @param inputFormat the input format to base the output format on * @return the output format * @throws Exception in case the determination goes wrong */ protected Instances determineOutputFormat(Instances inputFormat) throws Exception { Instances result; FastVector atts; int i; int numAtts; Vector<Integer> indices; Vector<Integer> subset; Random rand; int index; // determine the number of attributes numAtts = inputFormat.numAttributes(); if (inputFormat.classIndex() > -1) numAtts--; if (m_NumAttributes < 1) { numAtts = (int) Math.round((double) numAtts * m_NumAttributes); } else { if (m_NumAttributes < numAtts) numAtts = (int) m_NumAttributes; } if (getDebug()) System.out.println("# of atts: " + numAtts); // determine random indices indices = new Vector<Integer>(); for (i = 0; i < inputFormat.numAttributes(); i++) { if (i == inputFormat.classIndex()) continue; indices.add(i); } subset = new Vector<Integer>(); rand = new Random(m_Seed); for (i = 0; i < numAtts; i++) { index = rand.nextInt(indices.size()); subset.add(indices.get(index)); indices.remove(index); } Collections.sort(subset); if (inputFormat.classIndex() > -1) subset.add(inputFormat.classIndex()); if (getDebug()) System.out.println("indices: " + subset); // generate output format atts = new FastVector(); m_Indices = new int[subset.size()]; for (i = 0; i < subset.size(); i++) { atts.addElement(inputFormat.attribute(subset.get(i))); m_Indices[i] = subset.get(i); } result = new Instances(inputFormat.relationName(), atts, 0); if (inputFormat.classIndex() > -1) result.setClassIndex(result.numAttributes() - 1); return result; }
/**
 * Finds the row of a threshold curve whose threshold is closest to the requested value.
 *
 * <p>Thresholds are read from the last attribute of tcurve and searched via binarySearch over
 * their sorted order.
 *
 * @param tcurve a set of instances that have been generated by this class
 * @param threshold the target threshold, expected to lie in [0, 1]
 * @return the index of the closest row; 0 if tcurve has exactly one row; -1 if tcurve is empty,
 *     was not generated by this class, or threshold is below 0 or above 1
 */
public static int getThresholdInstance(Instances tcurve, double threshold) {
  if (!RELATION_NAME.equals(tcurve.relationName()) || tcurve.numInstances() == 0) {
    return -1;
  }
  if (threshold < 0 || threshold > 1.0) {
    return -1;
  }
  if (tcurve.numInstances() == 1) {
    return 0;
  }

  final double[] thresholds = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);
  return binarySearch(Utils.sort(thresholds), thresholds, threshold);
}
/** * Determines the output format based on the input format and returns this. * * @param inputFormat the input format to base the output format on * @return the output format * @throws Exception in case the determination goes wrong */ protected Instances determineOutputFormat(Instances inputFormat) throws Exception { Instances result; Attribute att; Attribute attSorted; FastVector atts; FastVector values; Vector<String> sorted; int i; int n; m_AttributeIndices.setUpper(inputFormat.numAttributes() - 1); // determine sorted indices atts = new FastVector(); m_NewOrder = new int[inputFormat.numAttributes()][]; for (i = 0; i < inputFormat.numAttributes(); i++) { att = inputFormat.attribute(i); if (!att.isNominal() || !m_AttributeIndices.isInRange(i)) { m_NewOrder[i] = new int[0]; atts.addElement(inputFormat.attribute(i).copy()); continue; } // sort labels sorted = new Vector<String>(); for (n = 0; n < att.numValues(); n++) sorted.add(att.value(n)); Collections.sort(sorted, m_Comparator); // determine new indices m_NewOrder[i] = new int[att.numValues()]; values = new FastVector(); for (n = 0; n < att.numValues(); n++) { m_NewOrder[i][n] = sorted.indexOf(att.value(n)); values.addElement(sorted.get(n)); } attSorted = new Attribute(att.name(), values); attSorted.setWeight(att.weight()); atts.addElement(attSorted); } // generate new header result = new Instances(inputFormat.relationName(), atts, 0); result.setClassIndex(inputFormat.classIndex()); return result; }
/** * Calculates the n point precision result, which is the precision averaged over n evenly spaced * (w.r.t recall) samples of the curve. * * @param tcurve a previously extracted threshold curve Instances. * @param n the number of points to average over. * @return the n-point precision. */ public static double getNPointPrecision(Instances tcurve, int n) { if (!RELATION_NAME.equals(tcurve.relationName()) || (tcurve.numInstances() == 0)) { return Double.NaN; } int recallInd = tcurve.attribute(RECALL_NAME).index(); int precisInd = tcurve.attribute(PRECISION_NAME).index(); double[] recallVals = tcurve.attributeToDoubleArray(recallInd); int[] sorted = Utils.sort(recallVals); double isize = 1.0 / (n - 1); double psum = 0; for (int i = 0; i < n; i++) { int pos = binarySearch(sorted, recallVals, i * isize); double recall = recallVals[sorted[pos]]; double precis = tcurve.instance(sorted[pos]).value(precisInd); /* System.err.println("Point " + (i + 1) + ": i=" + pos + " r=" + (i * isize) + " p'=" + precis + " r'=" + recall); */ // interpolate figures for non-endpoints while ((pos != 0) && (pos < sorted.length - 1)) { pos++; double recall2 = recallVals[sorted[pos]]; if (recall2 != recall) { double precis2 = tcurve.instance(sorted[pos]).value(precisInd); double slope = (precis2 - precis) / (recall2 - recall); double offset = precis - recall * slope; precis = isize * i * slope + offset; /* System.err.println("Point2 " + (i + 1) + ": i=" + pos + " r=" + (i * isize) + " p'=" + precis2 + " r'=" + recall2 + " p''=" + precis); */ break; } } psum += precis; } return psum / n; }
/**
 * Converts the buffered input into an all-nominal output format, then defers to
 * super.batchFinished().
 *
 * <p>Each input attribute becomes a nominal attribute of the same name. Its labels are the
 * String.valueOf renderings of the values seen in that column, in first-seen order. The new
 * header is set as the output format and every buffered input instance is pushed.
 *
 * <p>NOTE(review): instances are pushed unchanged, so their stored values are not remapped to
 * the indices of the new nominal labels, and no class index is set on the new header. Confirm
 * that this is intended.
 *
 * @return the value returned by super.batchFinished()
 * @throws Exception propagated from setting the output format, pushing, or the superclass
 */
public boolean batchFinished() throws Exception {
  final Instances source = getInputFormat();
  final Instances target = new Instances(source.relationName());
  final int attCount = source.numAttributes();
  final int rowCount = source.numInstances();

  for (int col = 0; col < attCount; col++) {
    FastVector labels = new FastVector();
    for (int row = 0; row < rowCount; row++) {
      String label = String.valueOf(source.instance(row).value(col));
      if (labels.indexOf(label) == -1) {
        labels.addElement(label);
      }
    }
    target.appendAttribute(new Attribute(source.attribute(col).name(), labels));
  }
  setOutputFormat(target);

  for (int row = 0; row < rowCount; row++) {
    push(source.instance(row));
  }
  return super.batchFinished();
}
/** * Sets the format of output instances. The derived class should use this method once it has * determined the outputformat. The output queue is cleared. * * @param outputFormat the new output format */ protected void setOutputFormat(Instances outputFormat) { if (outputFormat != null) { m_OutputFormat = outputFormat.stringFreeStructure(); initOutputLocators(m_OutputFormat, null); // Rename the relation String relationName = outputFormat.relationName() + "-" + this.getClass().getName(); if (this instanceof OptionHandler) { String[] options = ((OptionHandler) this).getOptions(); for (int i = 0; i < options.length; i++) { relationName += options[i].trim(); } } m_OutputFormat.setRelationName(relationName); } else { m_OutputFormat = null; } m_OutputQueue = new Queue(); }
/**
 * Sets every value of each attribute named in the ATTRIBS parameter to missing, in place.
 *
 * <p>The parameter is a whitespace-separated list of attribute names. Names not present in data
 * are reported with a warning and skipped.
 *
 * @param data the data set to modify
 * @throws TaskException as declared; presumably raised by getParameterVal (not visible here)
 */
@Override
protected void manipulateAttributes(Instances data) throws TaskException {
  // Trim first: otherwise a leading space makes split() produce an empty first token, which
  // was then reported as a missing attribute.
  String[] attribs = this.getParameterVal(ATTRIBS).trim().split("\\s+");
  for (String name : attribs) {
    if (name.length() == 0) {
      // only happens when the parameter is empty or all whitespace
      continue;
    }
    if (data.attribute(name) == null) {
      Logger.getInstance()
          .message(
              "Attribute " + name + " not found in data set " + data.relationName(),
              Logger.V_WARNING);
      continue;
    }

    int attrIndex = data.attribute(name).index();
    Enumeration<Instance> insts = data.enumerateInstances();
    while (insts.hasMoreElements()) {
      insts.nextElement().setMissing(attrIndex);
    }
  }
}
public void fillWekaInstances(weka.core.Instances winsts) { // set name setName(winsts.relationName()); // set attributes List onto_attrs = new ArrayList(); for (int i = 0; i < winsts.numAttributes(); i++) { Attribute a = new Attribute(); a.fillWekaAttribute(winsts.attribute(i)); onto_attrs.add(a); } setAttributes(onto_attrs); // set instances List onto_insts = new ArrayList(); for (int i = 0; i < winsts.numInstances(); i++) { Instance inst = new Instance(); weka.core.Instance winst = winsts.instance(i); List instvalues = new ArrayList(); List instmis = new ArrayList(); for (int j = 0; j < winst.numValues(); j++) { if (winst.isMissing(j)) { instvalues.add(new Double(0.0)); instmis.add(new Boolean(true)); } else { instvalues.add(new Double(winst.value(j))); instmis.add(new Boolean(false)); } } inst.setValues(instvalues); inst.setMissing(instmis); onto_insts.add(inst); } setInstances(onto_insts); setClass_index(winsts.classIndex()); }
/** * Determines the output format based on the input format and returns this. In case the output * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then * this method will called from batchFinished() after the call of preprocess(Instances), in which, * e.g., statistics for the actual processing step can be gathered. * * @param inputFormat the input format to base the output format on * @return the output format * @throws Exception in case the determination goes wrong */ protected Instances determineOutputFormat(Instances inputFormat) throws Exception { Instances result; Attribute att; ArrayList<Attribute> atts; int i; m_AttributeIndices.setUpper(inputFormat.numAttributes() - 1); // generate new header atts = new ArrayList<Attribute>(); for (i = 0; i < inputFormat.numAttributes(); i++) { att = inputFormat.attribute(i); if (m_AttributeIndices.isInRange(i)) { if (m_ReplaceAll) atts.add(att.copy(att.name().replaceAll(m_Find, m_Replace))); else atts.add(att.copy(att.name().replaceFirst(m_Find, m_Replace))); } else { atts.add((Attribute) att.copy()); } } result = new Instances(inputFormat.relationName(), atts, 0); result.setClassIndex(inputFormat.classIndex()); return result; }
/** * Sets the format of the input instances. * * @param instanceInfo an Instances object containing the input instance structure (any instances * contained in the object are ignored - only the structure is required). * @return true if the outputFormat may be collected immediately * @throws Exception if the format couldn't be set successfully */ @Override public boolean setInputFormat(Instances instanceInfo) throws Exception { super.setInputFormat(instanceInfo); int classIndex = instanceInfo.classIndex(); // setup the map if (m_renameVals != null && m_renameVals.length() > 0) { String[] vals = m_renameVals.split(","); for (String val : vals) { String[] parts = val.split(":"); if (parts.length != 2) { throw new WekaException("Invalid replacement string: " + val); } if (parts[0].length() == 0 || parts[1].length() == 0) { throw new WekaException("Invalid replacement string: " + val); } m_renameMap.put( m_ignoreCase ? parts[0].toLowerCase().trim() : parts[0].trim(), parts[1].trim()); } } // try selected atts as a numeric range first Range tempRange = new Range(); tempRange.setInvert(m_invert); if (m_selectedColsString == null) { m_selectedColsString = ""; } try { tempRange.setRanges(m_selectedColsString); tempRange.setUpper(instanceInfo.numAttributes() - 1); m_selectedAttributes = tempRange.getSelection(); m_selectedCols = tempRange; } catch (Exception r) { // OK, now try as named attributes StringBuffer indexes = new StringBuffer(); String[] attNames = m_selectedColsString.split(","); boolean first = true; for (String n : attNames) { n = n.trim(); Attribute found = instanceInfo.attribute(n); if (found == null) { throw new WekaException( "Unable to find attribute '" + n + "' in the incoming instances'"); } if (first) { indexes.append("" + (found.index() + 1)); first = false; } else { indexes.append("," + (found.index() + 1)); } } tempRange = new Range(); tempRange.setRanges(indexes.toString()); tempRange.setUpper(instanceInfo.numAttributes() - 1); m_selectedAttributes 
= tempRange.getSelection(); m_selectedCols = tempRange; } ArrayList<Attribute> attributes = new ArrayList<Attribute>(); for (int i = 0; i < instanceInfo.numAttributes(); i++) { if (m_selectedCols.isInRange(i)) { if (instanceInfo.attribute(i).isNominal()) { List<String> valsForAtt = new ArrayList<String>(); for (int j = 0; j < instanceInfo.attribute(i).numValues(); j++) { String origV = instanceInfo.attribute(i).value(j); String replace = m_ignoreCase ? m_renameMap.get(origV.toLowerCase()) : m_renameMap.get(origV); if (replace != null && !valsForAtt.contains(replace)) { valsForAtt.add(replace); } else { valsForAtt.add(origV); } } Attribute newAtt = new Attribute(instanceInfo.attribute(i).name(), valsForAtt); attributes.add(newAtt); } else { // ignore any selected attributes that are not nominal Attribute att = (Attribute) instanceInfo.attribute(i).copy(); attributes.add(att); } } else { Attribute att = (Attribute) instanceInfo.attribute(i).copy(); attributes.add(att); } } Instances outputFormat = new Instances(instanceInfo.relationName(), attributes, 0); outputFormat.setClassIndex(classIndex); setOutputFormat(outputFormat); return true; }
/**
 * Builds the output header.
 *
 * <p>Output attributes:
 * <ul>
 *   <li>Each transformed non-class attribute becomes one numeric attribute per class label,
 *       named {@code pr_<attribute>|<class label>}.</li>
 *   <li>Attributes found by name in m_unchanged are copied from there unchanged.</li>
 *   <li>The class attribute is appended last and set as the class.</li>
 * </ul>
 *
 * <p>On the first call (m_remove == null) the excluded attributes are determined:
 * <ul>
 *   <li>non-class numeric attributes, when m_excludeNumericAttributes is set;</li>
 *   <li>non-class nominal attributes, when m_excludeNominalAttributes is set or they have fewer
 *       than m_nominalConversionThreshold values.</li>
 * </ul>
 * If any are excluded, m_remove is configured to drop them. m_unchanged becomes the input
 * filtered by a Remove with the inverted selection (presumably retaining just the excluded
 * attributes).
 *
 * @param inputFormat the input format; must have a class attribute
 * @return the output format
 * @throws Exception if both nominal and numeric attributes are excluded, or configuring the
 *     Remove filters fails
 */
@Override
protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
  if (m_excludeNominalAttributes && m_excludeNumericAttributes) {
    throw new Exception(
        "No transformation will be done if both nominal and "
            + "numeric attributes are excluded!");
  }

  final int classIdx = inputFormat.classIndex();
  final int numAtts = inputFormat.numAttributes();

  if (m_remove == null) {
    List<Integer> excluded = new ArrayList<Integer>();
    if (m_excludeNumericAttributes) {
      for (int a = 0; a < numAtts; a++) {
        if (a != classIdx && inputFormat.attribute(a).isNumeric()) {
          excluded.add(a);
        }
      }
    }
    if (m_excludeNominalAttributes || m_nominalConversionThreshold > 1) {
      for (int a = 0; a < numAtts; a++) {
        Attribute candidate = inputFormat.attribute(a);
        if (a == classIdx || !candidate.isNominal()) {
          continue;
        }
        if (m_excludeNominalAttributes
            || candidate.numValues() < m_nominalConversionThreshold) {
          excluded.add(a);
        }
      }
    }

    if (!excluded.isEmpty()) {
      int[] indices = new int[excluded.size()];
      int k = 0;
      for (Integer idx : excluded) {
        indices[k++] = idx;
      }

      m_remove = new Remove();
      m_remove.setAttributeIndicesArray(indices);
      m_remove.setInputFormat(inputFormat);

      Remove keepOnly = new Remove();
      keepOnly.setAttributeIndicesArray(indices);
      keepOnly.setInvertSelection(true);
      keepOnly.setInputFormat(inputFormat);
      m_unchanged = Filter.useFilter(inputFormat, keepOnly);
    }
  }

  Attribute classAtt = inputFormat.classAttribute();
  ArrayList<Attribute> atts = new ArrayList<Attribute>();
  for (int a = 0; a < numAtts; a++) {
    if (a == classIdx) {
      continue;
    }
    String attName = inputFormat.attribute(a).name();
    Attribute kept = (m_unchanged == null) ? null : m_unchanged.attribute(attName);
    if (kept != null) {
      atts.add((Attribute) kept.copy());
    } else {
      for (int c = 0; c < classAtt.numValues(); c++) {
        atts.add(new Attribute("pr_" + attName + "|" + classAtt.value(c)));
      }
    }
  }
  atts.add((Attribute) classAtt.copy());

  Instances header = new Instances(inputFormat.relationName(), atts, 0);
  header.setClassIndex(header.numAttributes() - 1);
  return header;
}
/** * generates source code from the filter * * @param filter the filter to output as source * @param className the name of the generated class * @param input the input data the header is generated for * @param output the output data the header is generated for * @return the generated source code * @throws Exception if source code cannot be generated */ public static String wekaStaticWrapper( Sourcable filter, String className, Instances input, Instances output) throws Exception { StringBuffer result; int i; int n; result = new StringBuffer(); result.append("// Generated with Weka " + Version.VERSION + "\n"); result.append("//\n"); result.append("// This code is public domain and comes with no warranty.\n"); result.append("//\n"); result.append("// Timestamp: " + new Date() + "\n"); result.append("// Relation: " + input.relationName() + "\n"); result.append("\n"); result.append("package weka.filters;\n"); result.append("\n"); result.append("import weka.core.Attribute;\n"); result.append("import weka.core.Capabilities;\n"); result.append("import weka.core.Capabilities.Capability;\n"); result.append("import weka.core.FastVector;\n"); result.append("import weka.core.Instance;\n"); result.append("import weka.core.Instances;\n"); result.append("import weka.filters.Filter;\n"); result.append("\n"); result.append("public class WekaWrapper\n"); result.append(" extends Filter {\n"); // globalInfo result.append("\n"); result.append(" /**\n"); result.append(" * Returns only the toString() method.\n"); result.append(" *\n"); result.append(" * @return a string describing the filter\n"); result.append(" */\n"); result.append(" public String globalInfo() {\n"); result.append(" return toString();\n"); result.append(" }\n"); // getCapabilities result.append("\n"); result.append(" /**\n"); result.append(" * Returns the capabilities of this filter.\n"); result.append(" *\n"); result.append(" * @return the capabilities\n"); result.append(" */\n"); result.append(" public Capabilities 
getCapabilities() {\n"); result.append(((Filter) filter).getCapabilities().toSource("result", 4)); result.append(" return result;\n"); result.append(" }\n"); // objectsToInstance result.append("\n"); result.append(" /**\n"); result.append(" * turns array of Objects into an Instance object\n"); result.append(" *\n"); result.append(" * @param obj the Object array to turn into an Instance\n"); result.append(" * @param format the data format to use\n"); result.append(" * @return the generated Instance object\n"); result.append(" */\n"); result.append(" protected Instance objectsToInstance(Object[] obj, Instances format) {\n"); result.append(" Instance result;\n"); result.append(" double[] values;\n"); result.append(" int i;\n"); result.append("\n"); result.append(" values = new double[obj.length];\n"); result.append("\n"); result.append(" for (i = 0 ; i < obj.length; i++) {\n"); result.append(" if (obj[i] == null)\n"); result.append(" values[i] = Instance.missingValue();\n"); result.append(" else if (format.attribute(i).isNumeric())\n"); result.append(" values[i] = (Double) obj[i];\n"); result.append(" else if (format.attribute(i).isNominal())\n"); result.append(" values[i] = format.attribute(i).indexOfValue((String) obj[i]);\n"); result.append(" }\n"); result.append("\n"); result.append(" // create new instance\n"); result.append(" result = new Instance(1.0, values);\n"); result.append(" result.setDataset(format);\n"); result.append("\n"); result.append(" return result;\n"); result.append(" }\n"); // instanceToObjects result.append("\n"); result.append(" /**\n"); result.append(" * turns the Instance object into an array of Objects\n"); result.append(" *\n"); result.append(" * @param inst the instance to turn into an array\n"); result.append(" * @return the Object array representing the instance\n"); result.append(" */\n"); result.append(" protected Object[] instanceToObjects(Instance inst) {\n"); result.append(" Object[] result;\n"); result.append(" int i;\n"); 
result.append("\n"); result.append(" result = new Object[inst.numAttributes()];\n"); result.append("\n"); result.append(" for (i = 0 ; i < inst.numAttributes(); i++) {\n"); result.append(" if (inst.isMissing(i))\n"); result.append(" result[i] = null;\n"); result.append(" else if (inst.attribute(i).isNumeric())\n"); result.append(" result[i] = inst.value(i);\n"); result.append(" else\n"); result.append(" result[i] = inst.stringValue(i);\n"); result.append(" }\n"); result.append("\n"); result.append(" return result;\n"); result.append(" }\n"); // instancesToObjects result.append("\n"); result.append(" /**\n"); result.append(" * turns the Instances object into an array of Objects\n"); result.append(" *\n"); result.append(" * @param data the instances to turn into an array\n"); result.append(" * @return the Object array representing the instances\n"); result.append(" */\n"); result.append(" protected Object[][] instancesToObjects(Instances data) {\n"); result.append(" Object[][] result;\n"); result.append(" int i;\n"); result.append("\n"); result.append(" result = new Object[data.numInstances()][];\n"); result.append("\n"); result.append(" for (i = 0; i < data.numInstances(); i++)\n"); result.append(" result[i] = instanceToObjects(data.instance(i));\n"); result.append("\n"); result.append(" return result;\n"); result.append(" }\n"); // setInputFormat result.append("\n"); result.append(" /**\n"); result.append(" * Only tests the input data.\n"); result.append(" *\n"); result.append(" * @param instanceInfo the format of the data to convert\n"); result.append(" * @return always true, to indicate that the output format can \n"); result.append(" * be collected immediately.\n"); result.append(" */\n"); result.append(" public boolean setInputFormat(Instances instanceInfo) throws Exception {\n"); result.append(" super.setInputFormat(instanceInfo);\n"); result.append(" \n"); result.append(" // generate output format\n"); result.append(" FastVector atts = new FastVector();\n"); 
result.append(" FastVector attValues;\n"); for (i = 0; i < output.numAttributes(); i++) { result.append(" // " + output.attribute(i).name() + "\n"); if (output.attribute(i).isNumeric()) { result.append( " atts.addElement(new Attribute(\"" + output.attribute(i).name() + "\"));\n"); } else if (output.attribute(i).isNominal()) { result.append(" attValues = new FastVector();\n"); for (n = 0; n < output.attribute(i).numValues(); n++) { result.append(" attValues.addElement(\"" + output.attribute(i).value(n) + "\");\n"); } result.append( " atts.addElement(new Attribute(\"" + output.attribute(i).name() + "\", attValues));\n"); } else { throw new UnsupportedAttributeTypeException( "Attribute type '" + output.attribute(i).type() + "' (position " + (i + 1) + ") is not supported!"); } } result.append(" \n"); result.append( " Instances format = new Instances(\"" + output.relationName() + "\", atts, 0);\n"); result.append(" format.setClassIndex(" + output.classIndex() + ");\n"); result.append(" setOutputFormat(format);\n"); result.append(" \n"); result.append(" return true;\n"); result.append(" }\n"); // input result.append("\n"); result.append(" /**\n"); result.append(" * Directly filters the instance.\n"); result.append(" *\n"); result.append(" * @param instance the instance to convert\n"); result.append(" * @return always true, to indicate that the output can \n"); result.append(" * be collected immediately.\n"); result.append(" */\n"); result.append(" public boolean input(Instance instance) throws Exception {\n"); result.append( " Object[] filtered = " + className + ".filter(instanceToObjects(instance));\n"); result.append(" push(objectsToInstance(filtered, getOutputFormat()));\n"); result.append(" return true;\n"); result.append(" }\n"); // batchFinished result.append("\n"); result.append(" /**\n"); result.append(" * Performs a batch filtering of the buffered data, if any available.\n"); result.append(" *\n"); result.append(" * @return true if instances were filtered 
otherwise false\n"); result.append(" */\n"); result.append(" public boolean batchFinished() throws Exception {\n"); result.append(" if (getInputFormat() == null)\n"); result.append(" throw new NullPointerException(\"No input instance format defined\");;\n"); result.append("\n"); result.append(" Instances inst = getInputFormat();\n"); result.append(" if (inst.numInstances() > 0) {\n"); result.append( " Object[][] filtered = " + className + ".filter(instancesToObjects(inst));\n"); result.append(" for (int i = 0; i < filtered.length; i++) {\n"); result.append(" push(objectsToInstance(filtered[i], getOutputFormat()));\n"); result.append(" }\n"); result.append(" }\n"); result.append("\n"); result.append(" flushInput();\n"); result.append(" m_NewBatch = true;\n"); result.append(" m_FirstBatchDone = true;\n"); result.append("\n"); result.append(" return (inst.numInstances() > 0);\n"); result.append(" }\n"); // toString result.append("\n"); result.append(" /**\n"); result.append(" * Returns only the classnames and what filter it is based on.\n"); result.append(" *\n"); result.append(" * @return a short description\n"); result.append(" */\n"); result.append(" public String toString() {\n"); result.append( " return \"Auto-generated filter wrapper, based on " + filter.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n" + "\" + this.getClass().getName() + \"/" + className + "\";\n"); result.append(" }\n"); // main result.append("\n"); result.append(" /**\n"); result.append(" * Runs the filter from commandline.\n"); result.append(" *\n"); result.append(" * @param args the commandline arguments\n"); result.append(" */\n"); result.append(" public static void main(String args[]) {\n"); result.append(" runFilter(new WekaWrapper(), args);\n"); result.append(" }\n"); result.append("}\n"); // actual filter code result.append("\n"); result.append(filter.toSource(className, input)); return result.toString(); }
/** * Takes an evaluation object from a task and aggregates it with the overall one. * * @param eval the evaluation object to aggregate * @param classifier the classifier used by the task * @param testData the testData from the task * @param plotInstances the ClassifierErrorsPlotInstances object from the task * @param setNum the set number processed by the task * @param maxSetNum the maximum number of sets in this batch */ protected synchronized void aggregateEvalTask( Evaluation eval, Classifier classifier, Instances testData, ClassifierErrorsPlotInstances plotInstances, int setNum, int maxSetNum) { m_eval.aggregate(eval); if (m_aggregatedPlotInstances == null) { m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances()); m_aggregatedPlotShapes = plotInstances.getPlotShapes(); m_aggregatedPlotSizes = plotInstances.getPlotSizes(); } else { Instances temp = plotInstances.getPlotInstances(); for (int i = 0; i < temp.numInstances(); i++) { m_aggregatedPlotInstances.add(temp.get(i)); m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i)); m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i)); } } m_setsComplete++; // if (ce.getSetNumber() == ce.getMaxSetNumber()) { if (m_setsComplete == maxSetNum) { try { String textTitle = classifier.getClass().getName(); String textOptions = ""; if (classifier instanceof OptionHandler) { textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions()); } textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length()); String resultT = "=== Evaluation result ===\n\n" + "Scheme: " + textTitle + "\n" + ((textOptions.length() > 0) ? 
"Options: " + textOptions + "\n" : "") + "Relation: " + testData.relationName() + "\n\n" + m_eval.toSummaryString(); if (testData.classAttribute().isNominal()) { resultT += "\n" + m_eval.toClassDetailsString() + "\n" + m_eval.toMatrixString(); } TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT, textTitle); notifyTextListeners(te); // set up visualizable errors if (m_visualizableErrorListeners.size() > 0) { PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances); errorD.setShapeSize(m_aggregatedPlotSizes); errorD.setShapeType(m_aggregatedPlotShapes); errorD.setPlotName(textTitle + " " + textOptions); /* PlotData2D errorD = m_PlotInstances.getPlotData( textTitle + " " + textOptions); */ VisualizableErrorEvent vel = new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD); notifyVisualizableErrorListeners(vel); m_PlotInstances.cleanUp(); } if (testData.classAttribute().isNominal() && m_thresholdListeners.size() > 0) { ThresholdCurve tc = new ThresholdCurve(); Instances result = tc.getCurve(m_eval.predictions(), 0); result.setRelationName(testData.relationName()); PlotData2D pd = new PlotData2D(result); String htmlTitle = "<html><font size=-2>" + textTitle; String newOptions = ""; if (classifier instanceof OptionHandler) { String[] options = ((OptionHandler) classifier).getOptions(); if (options.length > 0) { for (int ii = 0; ii < options.length; ii++) { if (options[ii].length() == 0) { continue; } if (options[ii].charAt(0) == '-' && !(options[ii].charAt(1) >= '0' && options[ii].charAt(1) <= '9')) { newOptions += "<br>"; } newOptions += options[ii]; } } } htmlTitle += " " + newOptions + "<br>" + " (class: " + testData.classAttribute().value(0) + ")" + "</font></html>"; pd.setPlotName(textTitle + " (class: " + testData.classAttribute().value(0) + ")"); pd.setPlotNameHTML(htmlTitle); boolean[] connectPoints = new boolean[result.numInstances()]; for (int jj = 1; jj < connectPoints.length; jj++) { connectPoints[jj] = true; } 
pd.setConnectPoints(connectPoints); ThresholdDataEvent rde = new ThresholdDataEvent( ClassifierPerformanceEvaluator.this, pd, testData.classAttribute()); notifyThresholdListeners(rde); } if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix() + "Finished."); } } catch (Exception ex) { if (m_logger != null) { m_logger.logMessage( "[ClassifierPerformanceEvaluator] " + statusMessagePrefix() + " problem constructing evaluation results. " + ex.getMessage()); } ex.printStackTrace(); } finally { m_visual.setStatic(); // save memory m_PlotInstances = null; m_setsComplete = 0; m_tasks = null; m_aggregatedPlotInstances = null; } } }
/** * Determines the output format based on the input format and returns this. In case the output * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then * this method will called from batchFinished() after the call of preprocess(Instances), in which, * e.g., statistics for the actual processing step can be gathered. * * @param inputFormat the input format to base the output format on * @return the output format * @throws Exception in case the determination goes wrong * @see #hasImmediateOutputFormat() * @see #batchFinished() */ protected Instances determineOutputFormat(Instances inputFormat) throws Exception { FastVector atts; FastVector values; Instances result; int i; // attributes must be numeric m_Attributes.setUpper(inputFormat.numAttributes() - 1); m_AttributeIndices = m_Attributes.getSelection(); for (i = 0; i < m_AttributeIndices.length; i++) { // ignore class if (m_AttributeIndices[i] == inputFormat.classIndex()) { m_AttributeIndices[i] = NON_NUMERIC; continue; } // not numeric -> ignore it if (!inputFormat.attribute(m_AttributeIndices[i]).isNumeric()) m_AttributeIndices[i] = NON_NUMERIC; } // get old attributes atts = new FastVector(); for (i = 0; i < inputFormat.numAttributes(); i++) atts.addElement(inputFormat.attribute(i)); if (!getDetectionPerAttribute()) { m_OutlierAttributePosition = new int[1]; m_OutlierAttributePosition[0] = atts.size(); // add 2 new attributes values = new FastVector(); values.addElement("no"); values.addElement("yes"); atts.addElement(new Attribute("Outlier", values)); values = new FastVector(); values.addElement("no"); values.addElement("yes"); atts.addElement(new Attribute("ExtremeValue", values)); } else { m_OutlierAttributePosition = new int[m_AttributeIndices.length]; for (i = 0; i < m_AttributeIndices.length; i++) { if (m_AttributeIndices[i] == NON_NUMERIC) continue; m_OutlierAttributePosition[i] = atts.size(); // add new attributes values = new FastVector(); values.addElement("no"); 
values.addElement("yes"); atts.addElement( new Attribute( inputFormat.attribute(m_AttributeIndices[i]).name() + "_Outlier", values)); values = new FastVector(); values.addElement("no"); values.addElement("yes"); atts.addElement( new Attribute( inputFormat.attribute(m_AttributeIndices[i]).name() + "_ExtremeValue", values)); if (getOutputOffsetMultiplier()) atts.addElement( new Attribute(inputFormat.attribute(m_AttributeIndices[i]).name() + "_Offset")); } } // generate header result = new Instances(inputFormat.relationName(), atts, 0); result.setClassIndex(inputFormat.classIndex()); return result; }
/** * pads the data to conform to the necessary number of attributes * * @param data the data to pad * @return the padded data */ protected Instances pad(Instances data) { Instances result; int i; int n; String prefix; int numAtts; boolean isLast; int index; Vector<Integer> padded; int[] indices; FastVector atts; // determine number of padding attributes switch (m_Padding) { case PADDING_ZERO: if (data.classIndex() > -1) numAtts = (nextPowerOf2(data.numAttributes() - 1) + 1) - data.numAttributes(); else numAtts = nextPowerOf2(data.numAttributes()) - data.numAttributes(); break; default: throw new IllegalStateException( "Padding " + new SelectedTag(m_Algorithm, TAGS_PADDING) + " not implemented!"); } result = new Instances(data); prefix = getAlgorithm().getSelectedTag().getReadable(); // any padding necessary? if (numAtts > 0) { // add padding attributes isLast = (data.classIndex() == data.numAttributes() - 1); padded = new Vector<Integer>(); for (i = 0; i < numAtts; i++) { if (isLast) index = result.numAttributes() - 1; else index = result.numAttributes(); result.insertAttributeAt(new Attribute(prefix + "_padding_" + (i + 1)), index); // record index padded.add(new Integer(index)); } // get padded indices indices = new int[padded.size()]; for (i = 0; i < padded.size(); i++) indices[i] = padded.get(i); // determine number of padding attributes switch (m_Padding) { case PADDING_ZERO: for (i = 0; i < result.numInstances(); i++) { for (n = 0; n < indices.length; n++) result.instance(i).setValue(indices[n], 0); } break; } } // rename all attributes apart from class data = result; atts = new FastVector(); n = 0; for (i = 0; i < data.numAttributes(); i++) { n++; if (i == data.classIndex()) atts.addElement((Attribute) data.attribute(i).copy()); else atts.addElement(new Attribute(prefix + "_" + n)); } // create new dataset result = new Instances(data.relationName(), atts, data.numInstances()); result.setClassIndex(data.classIndex()); for (i = 0; i < data.numInstances(); 
i++) result.add(new DenseInstance(1.0, data.instance(i).toDoubleArray())); return result; }