/** * Dump the graph in Mary format * * @param graph graph * @param destFile the destination file * @throws IOException IOException */ public void saveGraph(DirectedGraph graph, String destFile) throws IOException { if (graph == null) throw new NullPointerException("Cannot dump null graph"); if (destFile == null) throw new NullPointerException("No destination file"); logger.debug("Dumping directed graph in Mary format to " + destFile + " ..."); // Open the destination file and output the header DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(destFile))); // create new CART-header and write it to output file MaryHeader hdr = new MaryHeader(MaryHeader.DIRECTED_GRAPH); hdr.writeTo(out); Properties props = graph.getProperties(); if (props == null) { out.writeShort(0); } else { ByteArrayOutputStream baos = new ByteArrayOutputStream(); props.store(baos, null); byte[] propData = baos.toByteArray(); out.writeShort(propData.length); out.write(propData); } // feature definition graph.getFeatureDefinition().writeBinaryTo(out); // dump graph dumpBinary(graph, out); // finish out.close(); logger.debug(" ... done\n"); }
/**
 * Numbers the directed graph nodes of the given graph consecutively, starting at 1, in the order
 * in which {@code graph.getDirectedGraphNodes()} yields them.
 *
 * @param graph the graph whose directed graph nodes receive ids
 * @return how many nodes were numbered (which is also the largest id handed out)
 */
private int setUniqueDirectedGraphNodeIds(DirectedGraph graph) {
    int count = 0;
    for (DirectedGraphNode node : graph.getDirectedGraphNodes()) {
        count = count + 1;
        node.setUniqueGraphNodeID(count);
    }
    return count;
}
/**
 * Numbers the decision nodes of the given graph consecutively, starting at 1, in the order in
 * which {@code graph.getDecisionNodes()} yields them.
 *
 * @param graph the graph whose decision nodes receive ids
 * @return how many nodes were numbered (which is also the largest id handed out)
 */
private int setUniqueDecisionNodeIds(DirectedGraph graph) {
    int count = 0;
    for (DecisionNode node : graph.getDecisionNodes()) {
        count = count + 1;
        node.setUniqueDecisionNodeId(count);
    }
    return count;
}
/**
 * Numbers the leaf nodes of the given graph consecutively, starting at 1, in the order in which
 * {@code graph.getLeafNodes()} yields them.
 *
 * @param graph the graph whose leaf nodes receive ids
 * @return how many leaves were numbered (which is also the largest id handed out)
 */
private int setUniqueLeafNodeIds(DirectedGraph graph) {
    int count = 0;
    for (LeafNode node : graph.getLeafNodes()) {
        count = count + 1;
        node.setUniqueLeafId(count);
    }
    return count;
}
private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException { for (DirectedGraphNode g : graph.getDirectedGraphNodes()) { int id = g.getUniqueGraphNodeID(); if (id == 0) continue; // empty node, do not write Node leaf = g.getLeafNode(); int leafID = 0; int leafNodeType = 0; if (leaf != null) { if (leaf instanceof LeafNode) { leafID = ((LeafNode) leaf).getUniqueLeafId(); leafNodeType = DirectedGraphReader.LEAFNODE; } else if (leaf instanceof DirectedGraphNode) { leafID = ((DirectedGraphNode) leaf).getUniqueGraphNodeID(); leafNodeType = DirectedGraphReader.DIRECTEDGRAPHNODE; } else { throw new IllegalArgumentException("Unexpected leaf type: " + leaf.getClass()); } } DecisionNode d = g.getDecisionNode(); int decID = d != null ? d.getUniqueDecisionNodeId() : 0; if (out != null) { int outLeafId = leafID == 0 ? 0 : leafID | (leafNodeType << 30); out.writeInt(outLeafId); int outDecId = decID == 0 ? 0 : decID | (DirectedGraphReader.DECISIONNODE << 30); out.writeInt(outDecId); } if (pw != null) { pw.print("DGN" + id); if (leafID == 0) { pw.print(" 0"); } else if (leaf.isLeafNode()) { pw.print(" id" + leafID); } else { assert leaf.isDirectedGraphNode(); pw.print(" DGN" + leafID); } if (decID == 0) pw.print(" 0"); else pw.print(" -" + decID); pw.println(); } } }
private void printLeafNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException { for (LeafNode leaf : graph.getLeafNodes()) { if (leaf.getUniqueLeafId() == 0) // empty leaf, do not write continue; LeafType leafType = leaf.getLeafNodeType(); if (leafType == LeafType.FeatureVectorLeafNode) { leafType = LeafType.IntArrayLeafNode; // save feature vector leaf nodes as int array leaf nodes } if (out != null) { // Leaf node type out.writeInt(leafType.ordinal()); } if (pw != null) { pw.print("id" + leaf.getUniqueLeafId() + " " + leafType); } switch (leaf.getLeafNodeType()) { case IntArrayLeafNode: int data[] = ((IntArrayLeafNode) leaf).getIntData(); // Number of data points following: if (out != null) out.writeInt(data.length); if (pw != null) pw.print(" " + data.length); // for each index, write the index for (int i = 0; i < data.length; i++) { if (out != null) out.writeInt(data[i]); if (pw != null) pw.print(" " + data[i]); } break; case FloatLeafNode: float stddev = ((FloatLeafNode) leaf).getStDeviation(); float mean = ((FloatLeafNode) leaf).getMean(); if (out != null) { out.writeFloat(stddev); out.writeFloat(mean); } if (pw != null) { pw.print(" 1 " + stddev + " " + mean); } break; case IntAndFloatArrayLeafNode: case StringAndFloatLeafNode: int data1[] = ((IntAndFloatArrayLeafNode) leaf).getIntData(); float floats[] = ((IntAndFloatArrayLeafNode) leaf).getFloatData(); // Number of data points following: if (out != null) out.writeInt(data1.length); if (pw != null) pw.print(" " + data1.length); // for each index, write the index and then its float for (int i = 0; i < data1.length; i++) { if (out != null) { out.writeInt(data1[i]); out.writeFloat(floats[i]); } if (pw != null) pw.print(" " + data1[i] + " " + floats[i]); } break; case FeatureVectorLeafNode: FeatureVector fv[] = ((FeatureVectorLeafNode) leaf).getFeatureVectors(); // Number of data points following: if (out != null) out.writeInt(fv.length); if (pw != null) pw.print(" " + fv.length); // for 
each feature vector, write the index for (int i = 0; i < fv.length; i++) { if (out != null) out.writeInt(fv[i].getUnitIndex()); if (pw != null) pw.print(" " + fv[i].getUnitIndex()); } break; case PdfLeafNode: throw new IllegalArgumentException("Writing of pdf leaf nodes not yet implemented"); } if (pw != null) pw.println(); } }
private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException { for (DecisionNode decNode : graph.getDecisionNodes()) { int id = decNode.getUniqueDecisionNodeId(); String nodeDefinition = decNode.getNodeDefinition(); int featureIndex = decNode.getFeatureIndex(); DecisionNode.Type nodeType = decNode.getDecisionNodeType(); if (out != null) { // dump in binary form to output out.writeInt(featureIndex); out.writeInt(nodeType.ordinal()); // Now, questionValue, which depends on nodeType switch (nodeType) { case BinaryByteDecisionNode: out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte()); assert decNode.getNumberOfDaugthers() == 2; break; case BinaryShortDecisionNode: out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort()); assert decNode.getNumberOfDaugthers() == 2; break; case BinaryFloatDecisionNode: out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat()); assert decNode.getNumberOfDaugthers() == 2; break; case ByteDecisionNode: case ShortDecisionNode: out.writeInt(decNode.getNumberOfDaugthers()); } // The child nodes for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) { Node daughter = decNode.getDaughter(i); if (daughter == null) { out.writeInt(0); } else if (daughter.isDecisionNode()) { int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId(); // Mark as decision node: daughterID |= DirectedGraphReader.DECISIONNODE << 30; out.writeInt(daughterID); } else if (daughter.isLeafNode()) { int daughterID = ((LeafNode) daughter).getUniqueLeafId(); // Mark as leaf node: if (daughterID != 0) daughterID |= DirectedGraphReader.LEAFNODE << 30; out.writeInt(daughterID); } else if (daughter.isDirectedGraphNode()) { int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID(); // Mark as directed graph node: if (daughterID != 0) daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << 30; out.writeInt(daughterID); } } } if (pw != null) { // 
dump to print writer StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition); for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) { strNode.append(" "); Node daughter = decNode.getDaughter(i); if (daughter == null) { strNode.append("0"); } else if (daughter.isDecisionNode()) { int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId(); strNode.append("-").append(daughterID); out.writeInt(daughterID); } else if (daughter.isLeafNode()) { int daughterID = ((LeafNode) daughter).getUniqueLeafId(); if (daughterID == 0) strNode.append("0"); else strNode.append("id").append(daughterID); } else if (daughter.isDirectedGraphNode()) { int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID(); if (daughterID == 0) strNode.append("0"); else strNode.append("DGN").append(daughterID); } } pw.println(strNode.toString()); } } }
@Override public boolean compute() throws IOException, MaryConfigurationException { logger.info("Duration tree trainer started."); FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE)); UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE)); FeatureVector[] allFeatureVectors = featureFile.getFeatureVectors(); int maxData = Integer.parseInt(getProp(MAXDATA)); if (maxData == 0) maxData = allFeatureVectors.length; FeatureVector[] featureVectors = new FeatureVector[Math.min(maxData, allFeatureVectors.length)]; System.arraycopy(allFeatureVectors, 0, featureVectors, 0, featureVectors.length); logger.debug( "Total of " + allFeatureVectors.length + " feature vectors -- will use " + featureVectors.length); AgglomerativeClusterer clusterer = new AgglomerativeClusterer( featureVectors, featureFile.getFeatureDefinition(), null, new DurationDistanceMeasure(unitFile), Float.parseFloat(getProp(PROPORTIONTESTDATA))); DirectedGraphWriter writer = new DirectedGraphWriter(); DirectedGraph graph; int iteration = 0; do { graph = clusterer.cluster(); iteration++; if (graph != null) { writer.saveGraph(graph, getProp(DURTREE) + ".level" + iteration); } } while (clusterer.canClusterMore()); if (graph == null) { return false; } // Now replace each leaf with a FloatLeafNode containing mean and stddev for (LeafNode leaf : graph.getLeafNodes()) { FeatureVectorLeafNode fvLeaf = (FeatureVectorLeafNode) leaf; FeatureVector[] fvs = fvLeaf.getFeatureVectors(); double[] dur = new double[fvs.length]; for (int i = 0; i < fvs.length; i++) { dur[i] = unitFile.getUnit(fvs[i].getUnitIndex()).duration / (float) unitFile.getSampleRate(); } double mean = MathUtils.mean(dur); double stddev = MathUtils.standardDeviation(dur, mean); FloatLeafNode floatLeaf = new FloatLeafNode(new float[] {(float) stddev, (float) mean}); Node mother = fvLeaf.getMother(); assert mother != null; if (mother.isDecisionNode()) { ((DecisionNode) mother).replaceDaughter(floatLeaf, 
fvLeaf.getNodeIndex()); } else { assert mother.isDirectedGraphNode(); assert ((DirectedGraphNode) mother).getLeafNode() == fvLeaf; ((DirectedGraphNode) mother).setLeafNode(floatLeaf); } } writer.saveGraph(graph, getProp(DURTREE)); return true; }