Example #1
0
  /**
   * Saves a directed graph to a file in binary Mary format.
   *
   * <p>The file is written in this order: the Mary header for directed graphs, the graph's
   * properties (a 16-bit length followed by the serialized properties, or a length of 0 when the
   * graph has no properties), the feature definition, and finally the graph itself as produced by
   * {@code dumpBinary}.
   *
   * <p>The output stream is always closed, also when writing fails part-way through.
   *
   * @param graph the graph to save; must not be null
   * @param destFile path of the destination file; must not be null
   * @throws IOException if the file cannot be opened or written, or if the serialized properties
   *     are too large for the 16-bit length field
   * @throws NullPointerException if {@code graph} or {@code destFile} is null
   */
  public void saveGraph(DirectedGraph graph, String destFile) throws IOException {
    if (graph == null) throw new NullPointerException("Cannot dump null graph");
    if (destFile == null) throw new NullPointerException("No destination file");

    logger.debug("Dumping directed graph in Mary format to " + destFile + " ...");

    // Open the destination file. Everything after this point runs inside try/finally so that the
    // file handle is released even if one of the write steps throws.
    DataOutputStream out =
        new DataOutputStream(new BufferedOutputStream(new FileOutputStream(destFile)));
    try {
      // Header marking the content as a directed graph
      MaryHeader hdr = new MaryHeader(MaryHeader.DIRECTED_GRAPH);
      hdr.writeTo(out);

      // Properties: 16-bit length prefix, then the serialized bytes (0 = no properties)
      Properties props = graph.getProperties();
      if (props == null) {
        out.writeShort(0);
      } else {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        props.store(baos, null);
        byte[] propData = baos.toByteArray();
        // writeShort() keeps only the low 16 bits; a longer blob would get a wrong length
        // prefix and make the file unreadable, so fail loudly instead.
        if (propData.length > 0xFFFF) {
          throw new IOException(
              "Graph properties too large to save: "
                  + propData.length
                  + " bytes do not fit into a 16-bit length field");
        }
        out.writeShort(propData.length);
        out.write(propData);
      }

      // feature definition
      graph.getFeatureDefinition().writeBinaryTo(out);

      // dump graph
      dumpBinary(graph, out);
    } finally {
      // Closing also flushes the buffered stream; on the success path a failure here propagates.
      out.close();
    }
    logger.debug(" ... done\n");
  }
Example #2
0
 /**
  * Numbers the directed graph nodes of a graph consecutively, starting at 1, in the order in
  * which {@code graph.getDirectedGraphNodes()} yields them.
  *
  * @param graph the graph whose directed graph nodes receive ids
  * @return how many directed graph nodes were numbered (equal to the highest id assigned, or 0 if
  *     there were none)
  */
 private int setUniqueDirectedGraphNodeIds(DirectedGraph graph) {
   int assigned = 0;
   for (DirectedGraphNode graphNode : graph.getDirectedGraphNodes()) {
     assigned += 1;
     graphNode.setUniqueGraphNodeID(assigned);
   }
   return assigned;
 }
Example #3
0
 /**
  * Numbers the decision nodes of a graph consecutively, starting at 1, in the order in which
  * {@code graph.getDecisionNodes()} yields them.
  *
  * @param graph the graph whose decision nodes receive ids
  * @return how many decision nodes were numbered (equal to the highest id assigned, or 0 if there
  *     were none)
  */
 private int setUniqueDecisionNodeIds(DirectedGraph graph) {
   int assigned = 0;
   for (DecisionNode decisionNode : graph.getDecisionNodes()) {
     assigned += 1;
     decisionNode.setUniqueDecisionNodeId(assigned);
   }
   return assigned;
 }
Example #4
0
 /**
  * Numbers the leaf nodes of a graph consecutively, starting at 1, in the order in which
  * {@code graph.getLeafNodes()} yields them.
  *
  * @param graph the graph whose leaf nodes receive ids
  * @return how many leaf nodes were numbered (equal to the highest id assigned, or 0 if there
  *     were none)
  */
 private int setUniqueLeafNodeIds(DirectedGraph graph) {
   int assigned = 0;
   for (LeafNode leafNode : graph.getLeafNodes()) {
     assigned += 1;
     leafNode.setUniqueLeafId(assigned);
   }
   return assigned;
 }
Example #5
0
 /**
  * Writes every directed graph node that has a non-zero id, in binary and/or text form.
  *
  * <p>Binary form (when {@code out} is not null): two ints per node. The first refers to the
  * node's "leaf" (a leaf node or another directed graph node), the second to its decision node.
  * Each int is 0 when the referenced id is 0 or the reference is missing; otherwise it is the id
  * combined with a node-type constant from {@code DirectedGraphReader} shifted left by 30 bits.
  *
  * <p>Text form (when {@code pw} is not null): one line per node, {@code DGN<id>} followed by the
  * leaf reference ({@code 0}, {@code id<n>} or {@code DGN<n>}) and the decision node reference
  * ({@code 0} or {@code -<n>}).
  *
  * @param graph the graph whose directed graph nodes are written
  * @param out binary sink, or null to skip binary output
  * @param pw text sink, or null to skip text output
  * @throws IOException if writing to {@code out} fails
  * @throws IllegalArgumentException if a node's leaf is neither a leaf node nor a directed graph
  *     node
  */
 private void printDirectedGraphNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
     throws IOException {
   for (DirectedGraphNode graphNode : graph.getDirectedGraphNodes()) {
     final int nodeId = graphNode.getUniqueGraphNodeID();
     if (nodeId == 0) {
       // empty node, do not write
       continue;
     }

     // Resolve what the node's "leaf" slot points to; instanceof is false for null, so a
     // missing leaf leaves both values at 0.
     final Node target = graphNode.getLeafNode();
     int targetId = 0;
     int targetType = 0;
     if (target instanceof LeafNode) {
       targetId = ((LeafNode) target).getUniqueLeafId();
       targetType = DirectedGraphReader.LEAFNODE;
     } else if (target instanceof DirectedGraphNode) {
       targetId = ((DirectedGraphNode) target).getUniqueGraphNodeID();
       targetType = DirectedGraphReader.DIRECTEDGRAPHNODE;
     } else if (target != null) {
       throw new IllegalArgumentException("Unexpected leaf type: " + target.getClass());
     }

     final DecisionNode decision = graphNode.getDecisionNode();
     int decisionId = 0;
     if (decision != null) {
       decisionId = decision.getUniqueDecisionNodeId();
     }

     if (out != null) {
       // An id of 0 is written untagged so that it stays 0.
       int taggedTarget = targetId;
       if (targetId != 0) {
         taggedTarget = targetId | (targetType << 30);
       }
       out.writeInt(taggedTarget);

       int taggedDecision = decisionId;
       if (decisionId != 0) {
         taggedDecision = decisionId | (DirectedGraphReader.DECISIONNODE << 30);
       }
       out.writeInt(taggedDecision);
     }

     if (pw != null) {
       pw.print("DGN" + nodeId);
       if (targetId == 0) {
         pw.print(" 0");
       } else if (target.isLeafNode()) {
         pw.print(" id" + targetId);
       } else {
         assert target.isDirectedGraphNode();
         pw.print(" DGN" + targetId);
       }
       pw.print(decisionId == 0 ? " 0" : " -" + decisionId);
       pw.println();
     }
   }
 }
Example #6
0
 /**
  * Writes every leaf node that has a non-zero id, in binary and/or text form.
  *
  * <p>For each leaf the leaf type is written first (binary: its ordinal as an int; text:
  * {@code id<n> <type>}), followed by type-specific data. Feature vector leaves are declared as
  * {@code IntArrayLeafNode} in the output and stored as the unit indices of their feature
  * vectors. In text form each leaf occupies one line.
  *
  * @param graph the graph whose leaf nodes are written
  * @param out binary sink, or null to skip binary output
  * @param pw text sink, or null to skip text output
  * @throws IOException if writing to {@code out} fails
  * @throws IllegalArgumentException if a pdf leaf node is encountered
  */
 private void printLeafNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
     throws IOException {
   final boolean binary = out != null;
   final boolean text = pw != null;

   for (LeafNode leaf : graph.getLeafNodes()) {
     final int leafId = leaf.getUniqueLeafId();
     if (leafId == 0) {
       // empty leaf, do not write
       continue;
     }

     // The type announced in the output may differ from the leaf's real type:
     // feature vector leaves are saved as int array leaves.
     final LeafType actualType = leaf.getLeafNodeType();
     final LeafType declaredType =
         actualType == LeafType.FeatureVectorLeafNode ? LeafType.IntArrayLeafNode : actualType;

     if (binary) {
       out.writeInt(declaredType.ordinal());
     }
     if (text) {
       pw.print("id" + leafId + " " + declaredType);
     }

     // The payload is chosen by the leaf's real type.
     switch (actualType) {
       case IntArrayLeafNode: {
         final int[] indices = ((IntArrayLeafNode) leaf).getIntData();
         // count first, then the values
         if (binary) {
           out.writeInt(indices.length);
         }
         if (text) {
           pw.print(" " + indices.length);
         }
         for (int index : indices) {
           if (binary) {
             out.writeInt(index);
           }
           if (text) {
             pw.print(" " + index);
           }
         }
         break;
       }
       case FloatLeafNode: {
         final FloatLeafNode floatLeaf = (FloatLeafNode) leaf;
         final float deviation = floatLeaf.getStDeviation();
         final float average = floatLeaf.getMean();
         if (binary) {
           out.writeFloat(deviation);
           out.writeFloat(average);
         }
         if (text) {
           pw.print(" 1 " + deviation + " " + average);
         }
         break;
       }
       case IntAndFloatArrayLeafNode:
       case StringAndFloatLeafNode: {
         final IntAndFloatArrayLeafNode pairLeaf = (IntAndFloatArrayLeafNode) leaf;
         final int[] indices = pairLeaf.getIntData();
         final float[] weights = pairLeaf.getFloatData();
         // count first, then (int, float) pairs
         if (binary) {
           out.writeInt(indices.length);
         }
         if (text) {
           pw.print(" " + indices.length);
         }
         for (int k = 0; k < indices.length; k++) {
           if (binary) {
             out.writeInt(indices[k]);
             out.writeFloat(weights[k]);
           }
           if (text) {
             pw.print(" " + indices[k] + " " + weights[k]);
           }
         }
         break;
       }
       case FeatureVectorLeafNode: {
         final FeatureVector[] vectors = ((FeatureVectorLeafNode) leaf).getFeatureVectors();
         // count first, then the unit index of every feature vector
         if (binary) {
           out.writeInt(vectors.length);
         }
         if (text) {
           pw.print(" " + vectors.length);
         }
         for (FeatureVector vector : vectors) {
           if (binary) {
             out.writeInt(vector.getUnitIndex());
           }
           if (text) {
             pw.print(" " + vector.getUnitIndex());
           }
         }
         break;
       }
       case PdfLeafNode:
         throw new IllegalArgumentException("Writing of pdf leaf nodes not yet implemented");
     }

     if (text) {
       pw.println();
     }
   }
 }
Example #7
0
  /**
   * Writes every decision node of the graph, in binary and/or text form.
   *
   * <p>Binary form (when {@code out} is not null), per node: the feature index (int), the ordinal
   * of the decision node type (int), a type-dependent value (the criterion value for the binary
   * node types, the number of daughters for byte/short decision nodes), and then one int per
   * daughter. A daughter int is 0 for a missing daughter or a daughter with id 0; otherwise it is
   * the daughter's id combined with a node-type constant from {@code DirectedGraphReader} shifted
   * left by 30 bits.
   *
   * <p>Text form (when {@code pw} is not null): one line per node, {@code -<id> <definition>}
   * followed by one reference per daughter ({@code 0}, {@code -<n>} for decision nodes,
   * {@code id<n>} for leaves, {@code DGN<n>} for directed graph nodes).
   *
   * <p>The two forms are independent of each other: text output never touches {@code out}, so the
   * method can be called with {@code out == null} for a text-only dump.
   *
   * @param graph the graph whose decision nodes are written
   * @param out binary sink, or null to skip binary output
   * @param pw text sink, or null to skip text output
   * @throws IOException if writing to {@code out} fails
   */
  private void printDecisionNodes(DirectedGraph graph, DataOutput out, PrintWriter pw)
      throws IOException {
    for (DecisionNode decNode : graph.getDecisionNodes()) {
      int id = decNode.getUniqueDecisionNodeId();
      String nodeDefinition = decNode.getNodeDefinition();
      int featureIndex = decNode.getFeatureIndex();
      DecisionNode.Type nodeType = decNode.getDecisionNodeType();

      if (out != null) {
        // dump in binary form to output
        out.writeInt(featureIndex);
        out.writeInt(nodeType.ordinal());
        // Now, questionValue, which depends on nodeType
        switch (nodeType) {
          case BinaryByteDecisionNode:
            out.writeInt(((BinaryByteDecisionNode) decNode).getCriterionValueAsByte());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryShortDecisionNode:
            out.writeInt(((BinaryShortDecisionNode) decNode).getCriterionValueAsShort());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case BinaryFloatDecisionNode:
            out.writeFloat(((BinaryFloatDecisionNode) decNode).getCriterionValueAsFloat());
            assert decNode.getNumberOfDaugthers() == 2;
            break;
          case ByteDecisionNode:
          case ShortDecisionNode:
            // Non-binary nodes store their daughter count in place of a criterion value
            out.writeInt(decNode.getNumberOfDaugthers());
        }

        // The child nodes
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            out.writeInt(0);
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            // Mark as decision node:
            daughterID |= DirectedGraphReader.DECISIONNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            // Mark as leaf node (an id of 0 stays 0, meaning "empty"):
            if (daughterID != 0) daughterID |= DirectedGraphReader.LEAFNODE << 30;
            out.writeInt(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            // Mark as directed graph node (an id of 0 stays 0, meaning "empty"):
            if (daughterID != 0) daughterID |= DirectedGraphReader.DIRECTEDGRAPHNODE << 30;
            out.writeInt(daughterID);
          }
        }
      }
      if (pw != null) {
        // dump to print writer
        StringBuilder strNode = new StringBuilder("-" + id + " " + nodeDefinition);
        for (int i = 0, n = decNode.getNumberOfDaugthers(); i < n; i++) {
          strNode.append(" ");
          Node daughter = decNode.getDaughter(i);
          if (daughter == null) {
            strNode.append("0");
          } else if (daughter.isDecisionNode()) {
            int daughterID = ((DecisionNode) daughter).getUniqueDecisionNodeId();
            // Text only: writing to 'out' here would fail when out is null and would append
            // stray ints to the binary record when it is not.
            strNode.append("-").append(daughterID);
          } else if (daughter.isLeafNode()) {
            int daughterID = ((LeafNode) daughter).getUniqueLeafId();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("id").append(daughterID);
          } else if (daughter.isDirectedGraphNode()) {
            int daughterID = ((DirectedGraphNode) daughter).getUniqueGraphNodeID();
            if (daughterID == 0) strNode.append("0");
            else strNode.append("DGN").append(daughterID);
          }
        }
        pw.println(strNode.toString());
      }
    }
  }
  /**
   * Trains a duration tree and saves it.
   *
   * <p>Steps: load the feature vectors (at most {@code MAXDATA} of them, where 0 means "all"),
   * cluster them repeatedly for as long as the clusterer reports that it can cluster more, and
   * save every non-null intermediate graph to {@code <DURTREE>.level<iteration>}. Afterwards each
   * leaf of the last graph is replaced by a float leaf holding the standard deviation and mean of
   * the unit durations in that leaf, and the result is saved to {@code DURTREE}.
   *
   * @return true if a tree was saved; false if the last clustering step produced no graph
   * @throws IOException if reading the input files or writing a graph fails
   * @throws MaryConfigurationException declared by the overridden method; presumably raised by
   *     the file readers -- not visible from this block
   */
  @Override
  public boolean compute() throws IOException, MaryConfigurationException {
    logger.info("Duration tree trainer started.");
    FeatureFileReader features = FeatureFileReader.getFeatureFileReader(getProp(FEATUREFILE));
    UnitFileReader units = new UnitFileReader(getProp(UNITFILE));

    // Restrict the training data to the first 'limit' feature vectors (0 = no restriction)
    FeatureVector[] available = features.getFeatureVectors();
    int limit = Integer.parseInt(getProp(MAXDATA));
    if (limit == 0) {
      limit = available.length;
    }
    FeatureVector[] training = new FeatureVector[Math.min(limit, available.length)];
    System.arraycopy(available, 0, training, 0, training.length);
    logger.debug(
        "Total of " + available.length + " feature vectors -- will use " + training.length);

    AgglomerativeClusterer clusterer =
        new AgglomerativeClusterer(
            training,
            features.getFeatureDefinition(),
            null,
            new DurationDistanceMeasure(units),
            Float.parseFloat(getProp(PROPORTIONTESTDATA)));
    DirectedGraphWriter graphWriter = new DirectedGraphWriter();

    // Cluster level by level, keeping a snapshot of every level that yields a graph
    DirectedGraph tree;
    int level = 0;
    do {
      tree = clusterer.cluster();
      level += 1;
      if (tree != null) {
        graphWriter.saveGraph(tree, getProp(DURTREE) + ".level" + level);
      }
    } while (clusterer.canClusterMore());

    if (tree == null) {
      return false;
    }

    // Now replace each leaf with a FloatLeafNode containing mean and stddev
    for (LeafNode leaf : tree.getLeafNodes()) {
      FeatureVectorLeafNode oldLeaf = (FeatureVectorLeafNode) leaf;
      FeatureVector[] members = oldLeaf.getFeatureVectors();

      // Unit duration divided by the sample rate (presumably seconds -- confirm units)
      double[] durations = new double[members.length];
      int next = 0;
      for (FeatureVector member : members) {
        durations[next++] =
            units.getUnit(member.getUnitIndex()).duration / (float) units.getSampleRate();
      }
      double average = MathUtils.mean(durations);
      double deviation = MathUtils.standardDeviation(durations, average);
      FloatLeafNode newLeaf = new FloatLeafNode(new float[] {(float) deviation, (float) average});

      // Hook the new leaf into the same place the old one occupied
      Node parent = oldLeaf.getMother();
      assert parent != null;
      if (!parent.isDecisionNode()) {
        assert parent.isDirectedGraphNode();
        assert ((DirectedGraphNode) parent).getLeafNode() == oldLeaf;
        ((DirectedGraphNode) parent).setLeafNode(newLeaf);
      } else {
        ((DecisionNode) parent).replaceDaughter(newLeaf, oldLeaf.getNodeIndex());
      }
    }
    graphWriter.saveGraph(tree, getProp(DURTREE));
    return true;
  }