예제 #1
2
  /**
   * Loads a directed graph from the given stream.
   *
   * <p>The stream must begin with a {@link MaryHeader}. If the header identifies a CART file
   * instead of a directed graph, the stream is rewound to its start and the CART is loaded with
   * {@link MaryCARTReader} instead.
   *
   * <p>After the header the stream holds, in this order: a short giving the length of an optional
   * {@link Properties} block and the block itself, the {@link FeatureDefinition}, the decision
   * nodes, the leaf nodes and the directed graph nodes. Child references are stored as encoded
   * indexes and are only resolved after all nodes have been read.
   *
   * @param inStream the stream to read the graph from
   * @return the loaded graph. Its root is the first directed graph node if there are any,
   *     otherwise the first decision node, otherwise the first leaf, otherwise {@code null}.
   * @throws IOException if the stream cannot be read, has the wrong version or file type, or
   *     contains inconsistent or unknown node data
   * @throws MaryConfigurationException if a configuration problem occurs while loading
   */
  public DirectedGraph load(InputStream inStream) throws IOException, MaryConfigurationException {
    BufferedInputStream buffInStream = new BufferedInputStream(inStream);
    assert buffInStream.markSupported();
    // Remember the start so that a CART file can be re-read from the beginning by MaryCARTReader.
    buffInStream.mark(10000);
    DataInput raf = new DataInputStream(buffInStream);

    MaryHeader maryHeader = new MaryHeader(raf);
    if (!maryHeader.hasCurrentVersion()) {
      throw new IOException("Wrong version of database file");
    }
    if (maryHeader.getType() != MaryHeader.DIRECTED_GRAPH) {
      if (maryHeader.getType() == MaryHeader.CARTS) {
        buffInStream.reset();
        return new MaryCARTReader().loadFromStream(buffInStream);
      } else {
        throw new IOException("Not a directed graph file");
      }
    }

    // Read properties: a short length (0 means "no properties"), then that many bytes.
    short propDataLength = raf.readShort();
    Properties props;
    if (propDataLength == 0) {
      props = null;
    } else {
      if (propDataLength < 0) {
        // Would otherwise surface as an unhelpful NegativeArraySizeException.
        throw new IOException("Invalid properties data length: " + propDataLength);
      }
      byte[] propsData = new byte[propDataLength];
      raf.readFully(propsData);
      props = new Properties();
      try (ByteArrayInputStream bais = new ByteArrayInputStream(propsData)) {
        props.load(bais);
      }
    }

    // Read the feature definition
    FeatureDefinition featureDefinition = new FeatureDefinition(raf);

    // Read the decision nodes. First we need to read all nodes into memory, then we can link
    // them properly in terms of parent/child.
    int numDecNodes = raf.readInt();
    DecisionNode.Type[] decisionNodeTypes = DecisionNode.Type.values();
    DecisionNode[] dns = new DecisionNode[numDecNodes];
    int[][] childIndexes = new int[numDecNodes][];
    for (int i = 0; i < numDecNodes; i++) {
      int featureNameIndex = raf.readInt();
      int nodeTypeNr = raf.readInt();
      if (nodeTypeNr < 0 || nodeTypeNr >= decisionNodeTypes.length) {
        throw new IOException("Unknown type " + nodeTypeNr + " for decision node " + i);
      }
      DecisionNode.Type nodeType = decisionNodeTypes[nodeTypeNr];
      int numChildren = 2; // binary nodes always have two children
      switch (nodeType) {
        case BinaryByteDecisionNode:
          {
            int criterion = raf.readInt();
            dns[i] =
                new DecisionNode.BinaryByteDecisionNode(
                    featureNameIndex, (byte) criterion, featureDefinition);
            break;
          }
        case BinaryShortDecisionNode:
          {
            int criterion = raf.readInt();
            dns[i] =
                new DecisionNode.BinaryShortDecisionNode(
                    featureNameIndex, (short) criterion, featureDefinition);
            break;
          }
        case BinaryFloatDecisionNode:
          {
            float floatCriterion = raf.readFloat();
            dns[i] =
                new DecisionNode.BinaryFloatDecisionNode(
                    featureNameIndex, floatCriterion, featureDefinition);
            break;
          }
        case ByteDecisionNode:
          numChildren = raf.readInt();
          checkDecisionNodeChildCount(featureDefinition, featureNameIndex, i, numChildren);
          dns[i] =
              new DecisionNode.ByteDecisionNode(featureNameIndex, numChildren, featureDefinition);
          break;
        case ShortDecisionNode:
          numChildren = raf.readInt();
          checkDecisionNodeChildCount(featureDefinition, featureNameIndex, i, numChildren);
          dns[i] =
              new DecisionNode.ShortDecisionNode(featureNameIndex, numChildren, featureDefinition);
          break;
        default:
          // Without this, dns[i] would stay null and fail below with a NullPointerException.
          throw new IOException("Unsupported type " + nodeType + " for decision node " + i);
      }
      dns[i].setUniqueDecisionNodeId(i + 1);
      // Now read the children, as encoded indexes only; they are resolved after all nodes are read.
      childIndexes[i] = new int[numChildren];
      for (int k = 0; k < numChildren; k++) {
        childIndexes[i][k] = raf.readInt();
      }
    }

    // Read the leaves. The count does not include empty leaves.
    int numLeafNodes = raf.readInt();
    LeafNode.LeafType[] leafTypes = LeafNode.LeafType.values();
    LeafNode[] lns = new LeafNode[numLeafNodes];
    for (int j = 0; j < numLeafNodes; j++) {
      int leafTypeNr = raf.readInt();
      if (leafTypeNr < 0 || leafTypeNr >= leafTypes.length) {
        throw new IOException("Unknown type " + leafTypeNr + " for leaf node " + j);
      }
      LeafNode.LeafType leafNodeType = leafTypes[leafTypeNr];
      switch (leafNodeType) {
        case IntArrayLeafNode:
          {
            int numData = raf.readInt();
            int[] data = new int[numData];
            for (int d = 0; d < numData; d++) {
              data[d] = raf.readInt();
            }
            lns[j] = new LeafNode.IntArrayLeafNode(data);
            break;
          }
        case FloatLeafNode:
          {
            // Stored order is stddev first, then mean.
            float stddev = raf.readFloat();
            float mean = raf.readFloat();
            lns[j] = new LeafNode.FloatLeafNode(new float[] {stddev, mean});
            break;
          }
        case IntAndFloatArrayLeafNode:
        case StringAndFloatLeafNode:
          {
            // Both types share the same on-disk layout: a count, then interleaved int/float pairs.
            int numPairs = raf.readInt();
            int[] ints = new int[numPairs];
            float[] floats = new float[numPairs];
            for (int d = 0; d < numPairs; d++) {
              ints[d] = raf.readInt();
              floats[d] = raf.readFloat();
            }
            if (leafNodeType == LeafNode.LeafType.IntAndFloatArrayLeafNode) {
              lns[j] = new LeafNode.IntAndFloatArrayLeafNode(ints, floats);
            } else {
              lns[j] = new LeafNode.StringAndFloatLeafNode(ints, floats);
            }
            break;
          }
        case FeatureVectorLeafNode:
          throw new IllegalArgumentException(
              "Reading feature vector leaf nodes is not yet implemented");
        case PdfLeafNode:
          throw new IllegalArgumentException("Reading pdf leaf nodes is not yet implemented");
        default:
          // Without this, lns[j] would stay null and fail below with a NullPointerException.
          throw new IOException("Unsupported type " + leafNodeType + " for leaf node " + j);
      }
      lns[j].setUniqueLeafId(j + 1);
    }

    // Graph nodes: each stores an encoded leaf reference and an encoded decision node reference.
    int numDirectedGraphNodes = raf.readInt();
    DirectedGraphNode[] graphNodes = new DirectedGraphNode[numDirectedGraphNodes];
    int[] dgnLeafIndices = new int[numDirectedGraphNodes];
    int[] dgnDecIndices = new int[numDirectedGraphNodes];
    for (int g = 0; g < numDirectedGraphNodes; g++) {
      graphNodes[g] = new DirectedGraphNode(null, null);
      graphNodes[g].setUniqueGraphNodeID(g + 1);
      dgnLeafIndices[g] = raf.readInt();
      dgnDecIndices[g] = raf.readInt();
    }

    // Now, link up the decision nodes with their daughters
    for (int i = 0; i < numDecNodes; i++) {
      for (int k = 0; k < childIndexes[i].length; k++) {
        Node child = childIndexToNode(childIndexes[i][k], dns, lns, graphNodes);
        dns[i].addDaughter(child);
      }
    }
    // And link up directed graph nodes
    for (int g = 0; g < numDirectedGraphNodes; g++) {
      Node leaf = childIndexToNode(dgnLeafIndices[g], dns, lns, graphNodes);
      graphNodes[g].setLeafNode(leaf);
      Node dec = childIndexToNode(dgnDecIndices[g], dns, lns, graphNodes);
      if (dec != null && !dec.isDecisionNode()) {
        throw new IllegalArgumentException("Only decision nodes allowed, read " + dec.getClass());
      }
      graphNodes[g].setDecisionNode((DecisionNode) dec);
    }

    Node rootNode;
    if (graphNodes.length > 0) {
      rootNode = graphNodes[0];
    } else if (dns.length > 0) {
      rootNode = dns[0];
      // CART behaviour, not sure if this is needed:
      // Now count all data once, so that getNumberOfData()
      // will return the correct figure.
      ((DecisionNode) rootNode).countData();
    } else if (lns.length > 0) {
      rootNode = lns[0]; // single-leaf tree...
    } else {
      rootNode = null;
    }

    return new DirectedGraph(rootNode, featureDefinition, props);
  }

  /**
   * Checks that a multi-way decision node has exactly one child per value of its feature.
   *
   * @param featureDefinition the feature definition the node refers to
   * @param featureNameIndex index of the node's feature in {@code featureDefinition}
   * @param nodeIndex zero-based index of the decision node in the file, used for the message
   * @param numChildren number of children stored for the node
   * @throws IOException if {@code numChildren} differs from the feature's number of values
   */
  private static void checkDecisionNodeChildCount(
      FeatureDefinition featureDefinition, int featureNameIndex, int nodeIndex, int numChildren)
      throws IOException {
    if (featureDefinition.getNumberOfValues(featureNameIndex) != numChildren) {
      throw new IOException(
          "Inconsistent cart file: feature "
              + featureDefinition.getFeatureName(featureNameIndex)
              + " should have "
              + featureDefinition.getNumberOfValues(featureNameIndex)
              + " values, but decision node "
              + nodeIndex
              + " has only "
              + numChildren
              + " child nodes");
    }
  }
예제 #2
0
 /**
  * Writes every non-empty directed graph node of the graph in binary and/or text form.
  *
  * <p>Binary form: two ints per node. The first is the encoded leaf reference and the second is
  * the encoded decision node reference. Each reference is the referenced node's unique id OR-ed
  * with its node kind shifted left by 30 bits, or {@code 0} when there is no reference.
  *
  * <p>Text form: one line per node, {@code DGN<id>} followed by the leaf token ({@code 0}, {@code
  * id<n>} or {@code DGN<n>}) and the decision token ({@code 0} or {@code -<n>}).
  *
  * @param graph the graph whose directed graph nodes are written
  * @param out binary output, or {@code null} to skip binary output
  * @param pw text output, or {@code null} to skip text output
  * @throws IOException if writing to {@code out} fails
  * @throws IllegalArgumentException if a node's leaf is neither a {@link LeafNode} nor a {@link
  *     DirectedGraphNode}
  */
 private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
     throws IOException {
   for (DirectedGraphNode graphNode : graph.getDirectedGraphNodes()) {
     int nodeId = graphNode.getUniqueGraphNodeID();
     if (nodeId == 0) {
       continue; // an id of 0 marks an empty node, which is not written
     }

     // Resolve the leaf reference into an id plus a node-kind tag.
     Node leafRef = graphNode.getLeafNode();
     int leafRefId = 0;
     int leafRefKind = 0;
     if (leafRef instanceof LeafNode) {
       leafRefId = ((LeafNode) leafRef).getUniqueLeafId();
       leafRefKind = DirectedGraphReader.LEAFNODE;
     } else if (leafRef instanceof DirectedGraphNode) {
       leafRefId = ((DirectedGraphNode) leafRef).getUniqueGraphNodeID();
       leafRefKind = DirectedGraphReader.DIRECTEDGRAPHNODE;
     } else if (leafRef != null) {
       throw new IllegalArgumentException("Unexpected leaf type: " + leafRef.getClass());
     }

     DecisionNode decisionRef = graphNode.getDecisionNode();
     int decisionRefId = (decisionRef == null) ? 0 : decisionRef.getUniqueDecisionNodeId();

     if (out != null) {
       int encodedLeaf = (leafRefId != 0) ? (leafRefId | (leafRefKind << 30)) : 0;
       out.writeInt(encodedLeaf);
       int encodedDecision =
           (decisionRefId != 0) ? (decisionRefId | (DirectedGraphReader.DECISIONNODE << 30)) : 0;
       out.writeInt(encodedDecision);
     }

     if (pw != null) {
       pw.print("DGN" + nodeId);
       if (leafRefId == 0) {
         pw.print(" 0");
       } else if (leafRef.isLeafNode()) {
         pw.print(" id" + leafRefId);
       } else {
         assert leafRef.isDirectedGraphNode();
         pw.print(" DGN" + leafRefId);
       }
       pw.print(decisionRefId == 0 ? " 0" : " -" + decisionRefId);
       pw.println();
     }
   }
 }
예제 #3
0
  /**
   * Writes all decision nodes of the graph in binary and/or text form.
   *
   * <p>Binary form, per node, in this order:
   *
   * <ol>
   *   <li>the feature index;
   *   <li>the ordinal of the node type;
   *   <li>a type-dependent value: the criterion for binary nodes, or the number of daughters for
   *       multi-way nodes;
   *   <li>one encoded int per daughter.
   * </ol>
   *
   * <p>A daughter is encoded as its unique id OR-ed with its node kind shifted left by 30 bits.
   * {@code 0} stands for a missing daughter, or for a leaf or graph node whose id is 0.
   *
   * <p>Text form: one line per node, {@code -<id> <definition>} followed by one token per daughter:
   * {@code -<n>} for decision nodes, {@code id<n>} for leaves, {@code DGN<n>} for directed graph
   * nodes, and {@code 0} for none.
   *
   * @param graph the graph whose decision nodes are written
   * @param out binary output, or {@code null} to skip binary output
   * @param pw text output, or {@code null} to skip text output
   * @throws IOException if writing to {@code out} fails
   * @throws IllegalArgumentException if, while writing binary output, a daughter is of an
   *     unsupported node kind
   */
  private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
      throws IOException {
    for (DecisionNode decNode : graph.getDecisionNodes()) {
      int id = decNode.getUniqueDecisionNodeId();
      String nodeDefinition = decNode.getNodeDefinition();
      int featureIndex = decNode.getFeatureIndex();
      DecisionNode.Type nodeType = decNode.getDecisionNodeType();

      if (out != null) {
        // dump in binary form to output
        out.writeInt(featureIndex);
        out.writeInt(nodeType.ordinal());
        // Now, questionValue, which depends on nodeType
        switch (nodeType) {
          case BinaryByteDecisionNode:
            out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryShortDecisionNode:
            out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryFloatDecisionNode:
            out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case ByteDecisionNode:
          case ShortDecisionNode:
            // Multi-way nodes store their number of daughters instead of a criterion.
            out.writeInt(decNode.getNumberOfDaugthers());
            break;
        }

        // The child nodes
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            out.writeInt(0);
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            // Mark as decision node:
            daughterID |= DirectedGraphReader.DECISIONNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            // Mark as leaf node:
            if (daughterID != 0) daughterID |= DirectedGraphReader.LEAFNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            // Mark as directed graph node:
            if (daughterID != 0) daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << 30;
            out.writeInt(daughterID);
          } else {
            // Writing nothing here would shift every following int and corrupt the file.
            throw new IllegalArgumentException("Unexpected daughter type: " + daughter.getClass());
          }
        }
      }
      if (pw != null) {
        // Dump to the print writer only; nothing may be written to the binary output here.
        StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition);
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          strNode.append(" ");
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            strNode.append("0");
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            strNode.append("-").append(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("id").append(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("DGN").append(daughterID);
          }
        }
        pw.println(strNode.toString());
      }
    }
  }
예제 #4
0
  /**
   * Trains the duration tree.
   *
   * <p>The steps are:
   *
   * <ol>
   *   <li>Read the feature and unit files.
   *   <li>Keep only the first {@code MAXDATA} feature vectors; {@code 0} means keep all of them.
   *   <li>Cluster the vectors agglomeratively with a {@code DurationDistanceMeasure}, saving the
   *       graph produced at each level as {@code <DURTREE>.level<n>}.
   *   <li>Replace every leaf of the final graph with a {@link FloatLeafNode} holding the standard
   *       deviation and mean of its units' durations, each computed as unit duration divided by
   *       the sample rate.
   *   <li>Save the result to {@code DURTREE}.
   * </ol>
   *
   * @return {@code false} if the last clustering step produced no graph, {@code true} otherwise
   * @throws IOException if an I/O error occurs
   * @throws MaryConfigurationException if a configuration problem occurs
   */
  @Override
  public boolean compute() throws IOException, MaryConfigurationException {
    logger.info("Duration tree trainer started.");
    FeatureFileReader featureFile = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
    UnitFileReader unitFile = new UnitFileReader(getProp(UNITFILE));

    FeatureVector[] available = featureFile.getFeatureVectors();
    int requested = Integer.parseInt(getProp(MAXDATA));
    // A MAXDATA of 0 means "use every feature vector".
    int usedCount = Math.min(requested == 0 ? available.length : requested, available.length);
    FeatureVector[] featureVectors = new FeatureVector[usedCount];
    System.arraycopy(available, 0, featureVectors, 0, usedCount);
    logger.debug(
        "Total of " + available.length + " feature vectors -- will use " + featureVectors.length);

    AgglomerativeClusterer clusterer =
        new AgglomerativeClusterer(
            featureVectors,
            featureFile.getFeatureDefinition(),
            null,
            new DurationDistanceMeasure(unitFile),
            Float.parseFloat(getProp(PROPORTIONTESTDATA)));
    DirectedGraphWriter writer = new DirectedGraphWriter();

    // Cluster level by level, saving every intermediate graph that is produced.
    DirectedGraph graph;
    for (int level = 1; ; level++) {
      graph = clusterer.cluster();
      if (graph != null) {
        writer.saveGraph(graph, getProp(DURTREE) + ".level" + level);
      }
      if (!clusterer.canClusterMore()) {
        break;
      }
    }

    if (graph == null) {
      return false;
    }

    // Swap each feature-vector leaf for a FloatLeafNode holding {stddev, mean} of the durations.
    for (LeafNode leaf : graph.getLeafNodes()) {
      FeatureVectorLeafNode vectorLeaf = (FeatureVectorLeafNode) leaf;
      FeatureVector[] leafVectors = vectorLeaf.getFeatureVectors();
      double[] durations = new double[leafVectors.length];
      int pos = 0;
      for (FeatureVector fv : leafVectors) {
        durations[pos++] =
            unitFile.getUnit(fv.getUnitIndex()).duration / (float) unitFile.getSampleRate();
      }
      double mean = MathUtils.mean(durations);
      double stddev = MathUtils.standardDeviation(durations, mean);
      FloatLeafNode durationLeaf = new FloatLeafNode(new float[] {(float) stddev, (float) mean});

      Node mother = vectorLeaf.getMother();
      assert mother != null;
      if (!mother.isDecisionNode()) {
        assert mother.isDirectedGraphNode();
        DirectedGraphNode graphMother = (DirectedGraphNode) mother;
        assert graphMother.getLeafNode() == vectorLeaf;
        graphMother.setLeafNode(durationLeaf);
      } else {
        ((DecisionNode) mother).replaceDaughter(durationLeaf, vectorLeaf.getNodeIndex());
      }
    }
    writer.saveGraph(graph, getProp(DURTREE));
    return true;
  }