/** * Load the directed graph from the given file * * @param fileName the file to load the cart from * @param featDefinition the feature definition * @param dummy unused, just here for compatibility with the FeatureFileIndexer. * @throws IOException , {@link MaryConfigurationException} if a problem occurs while loading */ public DirectedGraph load(InputStream inStream) throws IOException, MaryConfigurationException { BufferedInputStream buffInStream = new BufferedInputStream(inStream); assert buffInStream.markSupported(); buffInStream.mark(10000); // open the CART-File and read the header DataInput raf = new DataInputStream(buffInStream); MaryHeader maryHeader = new MaryHeader(raf); if (!maryHeader.hasCurrentVersion()) { throw new IOException("Wrong version of database file"); } if (maryHeader.getType() != MaryHeader.DIRECTED_GRAPH) { if (maryHeader.getType() == MaryHeader.CARTS) { buffInStream.reset(); return new MaryCARTReader().loadFromStream(buffInStream); } else { throw new IOException("Not a directed graph file"); } } // Read properties short propDataLength = raf.readShort(); Properties props; if (propDataLength == 0) { props = null; } else { byte[] propsData = new byte[propDataLength]; raf.readFully(propsData); ByteArrayInputStream bais = new ByteArrayInputStream(propsData); props = new Properties(); props.load(bais); bais.close(); } // Read the feature definition FeatureDefinition featureDefinition = new FeatureDefinition(raf); // read the decision nodes int numDecNodes = raf.readInt(); // number of decision nodes // First we need to read all nodes into memory, then we can link them properly // in terms of parent/child. DecisionNode[] dns = new DecisionNode[numDecNodes]; int[][] childIndexes = new int[numDecNodes][]; for (int i = 0; i < numDecNodes; i++) { // read one decision node int featureNameIndex = raf.readInt(); int nodeTypeNr = raf.readInt(); DecisionNode.Type nodeType = DecisionNode.Type.values()[nodeTypeNr]; int numChildren = 2; // for binary nodes switch (nodeType) { case BinaryByteDecisionNode: int criterion = raf.readInt(); dns[i] = new DecisionNode.BinaryByteDecisionNode( featureNameIndex, (byte) criterion, featureDefinition); break; case BinaryShortDecisionNode: criterion = raf.readInt(); dns[i] = new DecisionNode.BinaryShortDecisionNode( featureNameIndex, (short) criterion, featureDefinition); break; case BinaryFloatDecisionNode: float floatCriterion = raf.readFloat(); dns[i] = new DecisionNode.BinaryFloatDecisionNode( featureNameIndex, floatCriterion, featureDefinition); break; case ByteDecisionNode: numChildren = raf.readInt(); if (featureDefinition.getNumberOfValues(featureNameIndex) != numChildren) { throw new IOException( "Inconsistent cart file: feature " + featureDefinition.getFeatureName(featureNameIndex) + " should have " + featureDefinition.getNumberOfValues(featureNameIndex) + " values, but decision node " + i + " has only " + numChildren + " child nodes"); } dns[i] = new DecisionNode.ByteDecisionNode(featureNameIndex, numChildren, featureDefinition); break; case ShortDecisionNode: numChildren = raf.readInt(); if (featureDefinition.getNumberOfValues(featureNameIndex) != numChildren) { throw new IOException( "Inconsistent cart file: feature " + featureDefinition.getFeatureName(featureNameIndex) + " should have " + featureDefinition.getNumberOfValues(featureNameIndex) + " values, but decision node " + i + " has only " + numChildren + " child nodes"); } dns[i] = new DecisionNode.ShortDecisionNode(featureNameIndex, numChildren, featureDefinition); } dns[i].setUniqueDecisionNodeId(i + 1); // now read the children, indexes only: childIndexes[i] = new int[numChildren]; for (int k = 0; k < numChildren; k++) { childIndexes[i][k] = raf.readInt(); } } // read the leaves int numLeafNodes = raf.readInt(); // number of leaves, it does not include empty leaves LeafNode[] lns = new LeafNode[numLeafNodes]; for (int j = 0; j < numLeafNodes; j++) { // read one leaf node int leafTypeNr = raf.readInt(); LeafNode.LeafType leafNodeType = LeafNode.LeafType.values()[leafTypeNr]; switch (leafNodeType) { case IntArrayLeafNode: int numData = raf.readInt(); int[] data = new int[numData]; for (int d = 0; d < numData; d++) { data[d] = raf.readInt(); } lns[j] = new LeafNode.IntArrayLeafNode(data); break; case FloatLeafNode: float stddev = raf.readFloat(); float mean = raf.readFloat(); lns[j] = new LeafNode.FloatLeafNode(new float[] {stddev, mean}); break; case IntAndFloatArrayLeafNode: case StringAndFloatLeafNode: int numPairs = raf.readInt(); int[] ints = new int[numPairs]; float[] floats = new float[numPairs]; for (int d = 0; d < numPairs; d++) { ints[d] = raf.readInt(); floats[d] = raf.readFloat(); } if (leafNodeType == LeafNode.LeafType.IntAndFloatArrayLeafNode) lns[j] = new LeafNode.IntAndFloatArrayLeafNode(ints, floats); else lns[j] = new LeafNode.StringAndFloatLeafNode(ints, floats); break; case FeatureVectorLeafNode: throw new IllegalArgumentException( "Reading feature vector leaf nodes is not yet implemented"); case PdfLeafNode: throw new IllegalArgumentException("Reading pdf leaf nodes is not yet implemented"); } lns[j].setUniqueLeafId(j + 1); } // Graph nodes int numDirectedGraphNodes = raf.readInt(); DirectedGraphNode[] graphNodes = new DirectedGraphNode[numDirectedGraphNodes]; int[] dgnLeafIndices = new int[numDirectedGraphNodes]; int[] dgnDecIndices = new int[numDirectedGraphNodes]; for (int g = 0; g < numDirectedGraphNodes; g++) { graphNodes[g] = new DirectedGraphNode(null, null); graphNodes[g].setUniqueGraphNodeID(g + 1); dgnLeafIndices[g] = raf.readInt(); dgnDecIndices[g] = raf.readInt(); } // Now, link up the decision nodes with their daughters for (int i = 0; i < numDecNodes; i++) { // System.out.print(dns[i]+" "+dns[i].getFeatureName()+" "); for (int k = 0; k < childIndexes[i].length; k++) { Node child = childIndexToNode(childIndexes[i][k], dns, lns, graphNodes); dns[i].addDaughter(child); // System.out.print(" "+dns[i].getDaughter(k)); } // System.out.println(); } // And link up directed graph nodes for (int g = 0; g < numDirectedGraphNodes; g++) { Node leaf = childIndexToNode(dgnLeafIndices[g], dns, lns, graphNodes); graphNodes[g].setLeafNode(leaf); Node dec = childIndexToNode(dgnDecIndices[g], dns, lns, graphNodes); if (dec != null && !dec.isDecisionNode()) throw new IllegalArgumentException("Only decision nodes allowed, read " + dec.getClass()); graphNodes[g].setDecisionNode((DecisionNode) dec); // System.out.println("Graph node "+(g+1)+", leaf: "+Integer.toHexString(dgnLeafIndices[g])+", // "+leaf+" -- dec: "+Integer.toHexString(dgnDecIndices[g])+", "+dec); } Node rootNode; if (graphNodes.length > 0) { rootNode = graphNodes[0]; } else if (dns.length > 0) { rootNode = dns[0]; // CART behaviour, not sure if this is needed: // Now count all data once, so that getNumberOfData() // will return the correct figure. ((DecisionNode) rootNode).countData(); } else if (lns.length > 0) { rootNode = lns[0]; // single-leaf tree... } else { rootNode = null; } // set the rootNode as the rootNode of cart return new DirectedGraph(rootNode, featureDefinition, props); }
private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException { for (DirectedGraphNode g : graph.getDirectedGraphNodes()) { int id = g.getUniqueGraphNodeID(); if (id == 0) continue; // empty node, do not write Node leaf = g.getLeafNode(); int leafID = 0; int leafNodeType = 0; if (leaf != null) { if (leaf instanceof LeafNode) { leafID = ((LeafNode) leaf).getUniqueLeafId(); leafNodeType = DirectedGraphReader.LEAFNODE; } else if (leaf instanceof DirectedGraphNode) { leafID = ((DirectedGraphNode) leaf).getUniqueGraphNodeID(); leafNodeType = DirectedGraphReader.DIRECTEDGRAPHNODE; } else { throw new IllegalArgumentException("Unexpected leaf type: " + leaf.getClass()); } } DecisionNode d = g.getDecisionNode(); int decID = d != null ? d.getUniqueDecisionNodeId() : 0; if (out != null) { int outLeafId = leafID == 0 ? 0 : leafID | (leafNodeType << 30); out.writeInt(outLeafId); int outDecId = decID == 0 ? 0 : decID | (DirectedGraphReader.DECISIONNODE << 30); out.writeInt(outDecId); } if (pw != null) { pw.print("DGN" + id); if (leafID == 0) { pw.print(" 0"); } else if (leaf.isLeafNode()) { pw.print(" id" + leafID); } else { assert leaf.isDirectedGraphNode(); pw.print(" DGN" + leafID); } if (decID == 0) pw.print(" 0"); else pw.print(" -" + decID); pw.println(); } } }
private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw) throws IOException { for (DecisionNode decNode : graph.getDecisionNodes()) { int id = decNode.getUniqueDecisionNodeId(); String nodeDefinition = decNode.getNodeDefinition(); int featureIndex = decNode.getFeatureIndex(); DecisionNode.Type nodeType = decNode.getDecisionNodeType(); if (out != null) { // dump in binary form to output out.writeInt(featureIndex); out.writeInt(nodeType.ordinal()); // Now, questionValue, which depends on nodeType switch (nodeType) { case BinaryByteDecisionNode: out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte()); assert decNode.getNumberOfDaugthers() == 2; break; case BinaryShortDecisionNode: out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort()); assert decNode.getNumberOfDaugthers() == 2; break; case BinaryFloatDecisionNode: out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat()); assert decNode.getNumberOfDaugthers() == 2; break; case ByteDecisionNode: case ShortDecisionNode: out.writeInt(decNode.getNumberOfDaugthers()); } // The child nodes for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) { Node daughter = decNode.getDaughter(i); if (daughter == null) { out.writeInt(0); } else if (daughter.isDecisionNode()) { int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId(); // Mark as decision node: daughterID |= DirectedGraphReader.DECISIONNODE << 30; out.writeInt(daughterID); } else if (daughter.isLeafNode()) { int daughterID = ((LeafNode) daughter).getUniqueLeafId(); // Mark as leaf node: if (daughterID != 0) daughterID |= DirectedGraphReader.LEAFNODE << 30; out.writeInt(daughterID); } else if (daughter.isDirectedGraphNode()) { int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID(); // Mark as directed graph node: if (daughterID != 0) daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << 30; out.writeInt(daughterID); } } } if (pw != null) { // dump to print writer StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition); for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) { strNode.append(" "); Node daughter = decNode.getDaughter(i); if (daughter == null) { strNode.append("0"); } else if (daughter.isDecisionNode()) { int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId(); strNode.append("-").append(daughterID); out.writeInt(daughterID); } else if (daughter.isLeafNode()) { int daughterID = ((LeafNode) daughter).getUniqueLeafId(); if (daughterID == 0) strNode.append("0"); else strNode.append("id").append(daughterID); } else if (daughter.isDirectedGraphNode()) { int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID(); if (daughterID == 0) strNode.append("0"); else strNode.append("DGN").append(daughterID); } } pw.println(strNode.toString()); } } }
@Override public boolean compute() throws IOException, MaryConfigurationException { logger.info("Duration tree trainer started."); FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE)); UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE)); FeatureVector[] allFeatureVectors = featureFile.getFeatureVectors(); int maxData = Integer.parseInt(getProp(MAXDATA)); if (maxData == 0) maxData = allFeatureVectors.length; FeatureVector[] featureVectors = new FeatureVector[Math.min(maxData, allFeatureVectors.length)]; System.arraycopy(allFeatureVectors, 0, featureVectors, 0, featureVectors.length); logger.debug( "Total of " + allFeatureVectors.length + " feature vectors -- will use " + featureVectors.length); AgglomerativeClusterer clusterer = new AgglomerativeClusterer( featureVectors, featureFile.getFeatureDefinition(), null, new DurationDistanceMeasure(unitFile), Float.parseFloat(getProp(PROPORTIONTESTDATA))); DirectedGraphWriter writer = new DirectedGraphWriter(); DirectedGraph graph; int iteration = 0; do { graph = clusterer.cluster(); iteration++; if (graph != null) { writer.saveGraph(graph, getProp(DURTREE) + ".level" + iteration); } } while (clusterer.canClusterMore()); if (graph == null) { return false; } // Now replace each leaf with a FloatLeafNode containing mean and stddev for (LeafNode leaf : graph.getLeafNodes()) { FeatureVectorLeafNode fvLeaf = (FeatureVectorLeafNode) leaf; FeatureVector[] fvs = fvLeaf.getFeatureVectors(); double[] dur = new double[fvs.length]; for (int i = 0; i < fvs.length; i++) { dur[i] = unitFile.getUnit(fvs[i].getUnitIndex()).duration / (float) unitFile.getSampleRate(); } double mean = MathUtils.mean(dur); double stddev = MathUtils.standardDeviation(dur, mean); FloatLeafNode floatLeaf = new FloatLeafNode(new float[] {(float) stddev, (float) mean}); Node mother = fvLeaf.getMother(); assert mother != null; if (mother.isDecisionNode()) { ((DecisionNode) mother).replaceDaughter(floatLeaf, fvLeaf.getNodeIndex()); } else { assert mother.isDirectedGraphNode(); assert ((DirectedGraphNode) mother).getLeafNode() == fvLeaf; ((DirectedGraphNode) mother).setLeafNode(floatLeaf); } } writer.saveGraph(graph, getProp(DURTREE)); return true; }