Ejemplo n.º 1
0
 /**
  * Writes every non-empty directed graph node of {@code graph} to the given sinks.
  *
  * <p>Nodes whose unique graph node id is 0 are skipped.
  *
  * <p>Binary form ({@code out}): two ints per node, the leaf reference followed by the decision
  * node reference. A reference is 0 when the id is 0 (or the referenced node is absent);
  * otherwise it is the id OR-ed with the node-type constant shifted left by 30 bits.
  *
  * <p>Text form ({@code pw}): one line per node, {@code DGN<id> <leaf> <decision>}, where
  * {@code <leaf>} is {@code 0}, {@code id<leafId>} or {@code DGN<leafId>}, and {@code <decision>}
  * is {@code 0} or {@code -<decisionId>}.
  *
  * @param graph the graph whose directed graph nodes are written
  * @param out binary sink, or {@code null} to skip the binary form
  * @param pw text sink, or {@code null} to skip the text form
  * @throws IOException if writing to {@code out} fails
  * @throws IllegalArgumentException if a node's leaf is neither a {@code LeafNode} nor a
  *     {@code DirectedGraphNode}
  */
 private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
     throws IOException {
   for (DirectedGraphNode graphNode : graph.getDirectedGraphNodes()) {
     final int nodeId = graphNode.getUniqueGraphNodeID();
     if (nodeId == 0) {
       // Id 0 denotes an empty node, which is never written.
       continue;
     }

     // Resolve the leaf reference: its id plus the type constant used to tag it in binary form.
     final Node leaf = graphNode.getLeafNode();
     int leafId = 0;
     int leafTypeTag = 0;
     if (leaf instanceof LeafNode) {
       leafId = ((LeafNode) leaf).getUniqueLeafId();
       leafTypeTag = DirectedGraphReader.LEAFNODE;
     } else if (leaf instanceof DirectedGraphNode) {
       leafId = ((DirectedGraphNode) leaf).getUniqueGraphNodeID();
       leafTypeTag = DirectedGraphReader.DIRECTEDGRAPHNODE;
     } else if (leaf != null) {
       throw new IllegalArgumentException("Unexpected leaf type: " + leaf.getClass());
     }

     final DecisionNode decision = graphNode.getDecisionNode();
     int decisionId = 0;
     if (decision != null) {
       decisionId = decision.getUniqueDecisionNodeId();
     }

     if (out != null) {
       // An id of 0 is written as a plain 0, without any type tag.
       int taggedLeaf = 0;
       if (leafId != 0) {
         taggedLeaf = leafId | (leafTypeTag << 30);
       }
       out.writeInt(taggedLeaf);

       int taggedDecision = 0;
       if (decisionId != 0) {
         taggedDecision = decisionId | (DirectedGraphReader.DECISIONNODE << 30);
       }
       out.writeInt(taggedDecision);
     }

     if (pw != null) {
       pw.print("DGN" + nodeId);

       final String leafToken;
       if (leafId == 0) {
         leafToken = " 0";
       } else if (leaf.isLeafNode()) {
         // leafId != 0 implies leaf != null, so the dereference is safe here.
         leafToken = " id" + leafId;
       } else {
         assert leaf.isDirectedGraphNode();
         leafToken = " DGN" + leafId;
       }
       pw.print(leafToken);

       pw.print(decisionId == 0 ? " 0" : " -" + decisionId);
       pw.println();
     }
   }
 }
Ejemplo n.º 2
0
  /**
   * Writes all decision nodes of {@code graph} in binary form, in text form, or both.
   *
   * <p>Binary form ({@code out}), per decision node: the feature index, the ordinal of the node
   * type, a type-dependent value (the criterion value for the binary node types, the number of
   * daughters for byte/short nodes), and then one int per daughter. A missing daughter is written
   * as 0. Decision node daughters are written as their id OR-ed with the decision node type
   * constant shifted left by 30 bits; leaf and directed graph node daughters are tagged the same
   * way with their own type constant, except that an id of 0 is written as a plain 0.
   *
   * <p>Text form ({@code pw}), one line per decision node: {@code -<id> <nodeDefinition>}
   * followed by one space-separated token per daughter: {@code 0} for a missing daughter or an
   * id of 0, {@code -<id>} for a decision node, {@code id<id>} for a leaf, {@code DGN<id>} for a
   * directed graph node.
   *
   * <p>The two forms are independent: the text dump never writes to {@code out}, so either sink
   * may be {@code null} without affecting the other.
   *
   * @param graph the graph whose decision nodes are written
   * @param out binary sink, or {@code null} to skip the binary form
   * @param pw text sink, or {@code null} to skip the text form
   * @throws IOException if writing to {@code out} fails
   */
  private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
      throws IOException {
    for (DecisionNode decNode : graph.getDecisionNodes()) {
      int id = decNode.getUniqueDecisionNodeId();
      String nodeDefinition = decNode.getNodeDefinition();
      int featureIndex = decNode.getFeatureIndex();
      DecisionNode.Type nodeType = decNode.getDecisionNodeType();

      if (out != null) {
        // dump in binary form to output
        out.writeInt(featureIndex);
        out.writeInt(nodeType.ordinal());
        // Now, questionValue, which depends on nodeType
        switch (nodeType) {
          case BinaryByteDecisionNode:
            out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryShortDecisionNode:
            out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryFloatDecisionNode:
            out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case ByteDecisionNode:
          case ShortDecisionNode:
            // Non-binary nodes store their daughter count instead of a criterion value.
            out.writeInt(decNode.getNumberOfDaugthers());
        }

        // The child nodes
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            out.writeInt(0);
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            // Mark as decision node:
            daughterID |= DirectedGraphReader.DECISIONNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            // Mark as leaf node:
            if (daughterID != 0) daughterID |= DirectedGraphReader.LEAFNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            // Mark as directed graph node:
            if (daughterID != 0) daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << 30;
            out.writeInt(daughterID);
          }
        }
      }
      if (pw != null) {
        // dump to print writer
        StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition);
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          strNode.append(" ");
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            strNode.append("0");
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            // Text output only: writing to 'out' here would throw a NullPointerException when
            // no binary sink is given, and would otherwise add a stray int to the binary stream.
            strNode.append("-").append(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("id").append(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("DGN").append(daughterID);
          }
        }
        pw.println(strNode.toString());
      }
    }
  }
Ejemplo n.º 3
0
  /**
   * Trains a duration tree by agglomerative clustering of feature vectors and saves it.
   *
   * <p>At most {@code MAXDATA} feature vectors are used (all of them when the property is 0).
   * The clusterer is run repeatedly while it reports that it can cluster more; every non-null
   * intermediate graph is saved under the {@code DURTREE} path with a {@code .level<n>} suffix.
   * In the final graph, each leaf is replaced by a {@code FloatLeafNode} holding the standard
   * deviation and the mean (in that order) of the unit durations divided by the unit file's
   * sample rate -- presumably durations in seconds, assuming unit durations are in samples.
   * The resulting graph is then saved under the {@code DURTREE} path.
   *
   * @return {@code false} if the last clustering step produced no graph, {@code true} otherwise
   * @throws IOException declared for the file reading and graph saving steps
   * @throws MaryConfigurationException declared for the reader setup steps
   */
  @Override
  public boolean compute() throws IOException, MaryConfigurationException {
    logger.info("Duration tree trainer started.");
    FeatureFileReader features = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
    UnitFileReader units = new UnitFileReader(getProp(UNITFILE));

    // Select the training subset: the first MAXDATA vectors, or all of them if MAXDATA is 0.
    FeatureVector[] available = features.getFeatureVectors();
    int requested = Integer.parseInt(getProp(MAXDATA));
    int limit = (requested == 0) ? available.length : requested;
    FeatureVector[] training = new FeatureVector[Math.min(limit, available.length)];
    System.arraycopy(available, 0, training, 0, training.length);
    logger.debug(
        "Total of " + available.length + " feature vectors -- will use " + training.length);

    AgglomerativeClusterer clusterer =
        new AgglomerativeClusterer(
            training,
            features.getFeatureDefinition(),
            null,
            new DurationDistanceMeasure(units),
            Float.parseFloat(getProp(PROPORTIONTESTDATA)));
    DirectedGraphWriter graphWriter = new DirectedGraphWriter();

    // Cluster level by level, keeping a snapshot of every intermediate graph.
    DirectedGraph graph;
    int level = 0;
    do {
      graph = clusterer.cluster();
      level += 1;
      if (graph != null) {
        graphWriter.saveGraph(graph, getProp(DURTREE) + ".level" + level);
      }
    } while (clusterer.canClusterMore());

    if (graph == null) {
      return false;
    }

    // Swap every feature-vector leaf for a float leaf carrying {stddev, mean} of its durations.
    for (LeafNode leaf : graph.getLeafNodes()) {
      FeatureVectorLeafNode oldLeaf = (FeatureVectorLeafNode) leaf;
      FeatureVector[] members = oldLeaf.getFeatureVectors();
      double[] durations = new double[members.length];
      for (int k = 0; k < members.length; k++) {
        int unitIndex = members[k].getUnitIndex();
        durations[k] = units.getUnit(unitIndex).duration / (float) units.getSampleRate();
      }
      double mean = MathUtils.mean(durations);
      double stddev = MathUtils.standardDeviation(durations, mean);
      float[] stats = {(float) stddev, (float) mean};
      FloatLeafNode newLeaf = new FloatLeafNode(stats);

      Node parent = oldLeaf.getMother();
      assert parent != null;
      if (!parent.isDecisionNode()) {
        assert parent.isDirectedGraphNode();
        DirectedGraphNode graphParent = (DirectedGraphNode) parent;
        assert graphParent.getLeafNode() == oldLeaf;
        graphParent.setLeafNode(newLeaf);
      } else {
        ((DecisionNode) parent).replaceDaughter(newLeaf, oldLeaf.getNodeIndex());
      }
    }
    graphWriter.saveGraph(graph, getProp(DURTREE));
    return true;
  }