Exemple #1
0
 /**
  * Writes one line per document listing its most prominent topics.
  *
  * <p>Each line holds the document index, the instance source (or {@code null-source} when the
  * source is null), and then {@code topic proportion} pairs in decreasing order of proportion.
  * Output for a document stops once {@code max} pairs were written, no topic with a positive
  * proportion is left, or the next proportion falls below {@code threshold}.
  *
  * @param pw destination writer
  * @param threshold smallest proportion that is still printed
  * @param max maximum number of topics printed per document; a negative value means all topics
  */
 public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
   pw.println("#doc source topic proportion ...");
   // Indexed by topic, so it needs numTopics slots. Sizing it by the document count overflowed
   // whenever there were fewer documents than topics.
   double[] topicDist = new double[numTopics];
   if (max < 0) {
     max = numTopics;
   }
   for (int di = 0; di < topics.length; di++) {
     pw.print(di);
     pw.print(' ');
     Object source = ilist.get(di).getSource();
     if (source != null) {
       pw.print(source.toString());
     } else {
       pw.print("null-source");
     }
     pw.print(' ');
     int docLen = topics[di].length;
     for (int ti = 0; ti < numTopics; ti++) {
       // Float division is kept so the printed proportions stay exactly as before.
       topicDist[ti] = (((float) docTopicCounts[di][ti]) / docLen);
     }
     // Selection by repeated maximum: emit the largest remaining proportion, then zero it out.
     for (int tp = 0; tp < max; tp++) {
       double maxvalue = 0;
       int maxindex = -1;
       for (int ti = 0; ti < numTopics; ti++) {
         if (topicDist[ti] > maxvalue) {
           maxvalue = topicDist[ti];
           maxindex = ti;
         }
       }
       if (maxindex == -1 || topicDist[maxindex] < threshold) {
         break;
       }
       pw.print(maxindex + " " + topicDist[maxindex] + " ");
       topicDist[maxindex] = 0;
     }
     pw.println(' ');
   }
 }
  /**
   * Returns a copy of {@code trainingList} whose feature values are linearly rescaled, feature by
   * feature, so that each feature's observed minimum maps to {@code lower} and its observed
   * maximum maps to {@code upper}.
   *
   * <p>Minimum and maximum are taken only over the values actually stored in the feature vectors.
   * A feature whose minimum equals its maximum keeps its original value.
   *
   * @param trainingList instances to copy and rescale; their data is cast to FeatureVector
   * @param lower value assigned to each feature's minimum
   * @param upper value assigned to each feature's maximum
   * @return the rescaled copy
   */
  public static InstanceList scale(InstanceList trainingList, double lower, double upper) {
    InstanceList scaled = copy(trainingList);
    int featureCount = scaled.getDataAlphabet().size();

    // Per-feature extremes, seeded so that the first observed value always replaces the seed.
    double[] highest = new double[featureCount];
    double[] lowest = new double[featureCount];
    for (int f = 0; f < featureCount; f++) {
      highest[f] = -Double.MAX_VALUE;
      lowest[f] = Double.MAX_VALUE;
    }

    // Pass 1: collect the observed range of every feature.
    for (int row = 0; row < scaled.size(); row++) {
      FeatureVector vector = (FeatureVector) scaled.get(row).getData();
      for (int slot = 0; slot < vector.numLocations(); slot++) {
        int feature = vector.indexAtLocation(slot);
        double observed = vector.valueAtLocation(slot);
        highest[feature] = Math.max(observed, highest[feature]);
        lowest[feature] = Math.min(observed, lowest[feature]);
      }
    }

    // Pass 2: overwrite each stored value with its rescaled counterpart.
    for (int row = 0; row < scaled.size(); row++) {
      FeatureVector vector = (FeatureVector) scaled.get(row).getData();
      for (int slot = 0; slot < vector.numLocations(); slot++) {
        int feature = vector.indexAtLocation(slot);
        double observed = vector.valueAtLocation(slot);
        double hi = highest[feature];
        double lo = lowest[feature];

        double rescaled;
        if (hi == lo) {
          // Degenerate range: leave the value untouched.
          rescaled = observed;
        } else if (observed == lo) {
          rescaled = lower;
        } else if (observed == hi) {
          rescaled = upper;
        } else {
          rescaled = lower + (upper - lower) * (observed - lo) / (hi - lo);
        }
        vector.setValueAtLocation(slot, rescaled);
      }
    }

    return scaled;
  }
  /**
   * Accumulates, over all instances, the total count of every feature ({@code featureCounts}) and
   * the number of instances it occurs in ({@code documentFrequencies}).
   *
   * <p>The data class of the first instance decides how every instance is read: FeatureSequence
   * data is tallied position by position, FeatureVector data contributes each stored value. Any
   * other data class is only logged. An empty instance list is logged and left uncounted. The
   * running instance number is printed to standard error every 1000 instances.
   */
  public void count() {
    if (instances.size() == 0) {
      logger.info("Instance list is empty");
      return;
    }

    int index = 0;

    if (instances.get(0).getData() instanceof FeatureSequence) {

      for (Instance instance : instances) {
        FeatureSequence features = (FeatureSequence) instance.getData();

        // Occurrences of each feature within this single document.
        TIntIntHashMap docCounts = new TIntIntHashMap();
        for (int i = 0; i < features.getLength(); i++) {
          docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
        }

        // Fold every distinct feature of the document into the totals. The bound used to be
        // keys.length - 1, which silently dropped one feature per document.
        int[] keys = docCounts.keys();
        for (int i = 0; i < keys.length; i++) {
          int feature = keys[i];
          featureCounts[feature] += docCounts.get(feature);
          documentFrequencies[feature]++;
        }

        index++;
        if (index % 1000 == 0) {
          System.err.println(index);
        }
      }
    } else if (instances.get(0).getData() instanceof FeatureVector) {

      for (Instance instance : instances) {
        FeatureVector features = (FeatureVector) instance.getData();

        for (int location = 0; location < features.numLocations(); location++) {
          int feature = features.indexAtLocation(location);
          double value = features.valueAtLocation(location);

          documentFrequencies[feature]++;
          featureCounts[feature] += value;
        }

        index++;
        if (index % 1000 == 0) {
          System.err.println(index);
        }
      }
    } else {
      logger.info("Unsupported data class: " + instances.get(0).getData().getClass().getName());
    }
  }
  /**
   * Splits a complete document list into this object's training and test lists.
   *
   * <p>Both lists are created fresh and share the data and target alphabets of {@code documents}.
   * Documents before {@code testStartIndex} go to the training list, the remaining ones to the
   * test list.
   *
   * @param documents the complete list of documents to divide
   * @param testStartIndex index of the first document that belongs to the test list
   */
  public void divideDocuments(InstanceList documents, int testStartIndex) {
    this.training = new InstanceList(documents.getDataAlphabet(), documents.getTargetAlphabet());
    this.test = new InstanceList(documents.getDataAlphabet(), documents.getTargetAlphabet());

    int position = 0;
    // Leading part: training documents.
    while (position < testStartIndex) {
      training.add(documents.get(position));
      position++;
    }
    // Trailing part: test documents.
    position = testStartIndex;
    while (position < documents.size()) {
      test.add(documents.get(position));
      position++;
    }
  }
  /**
   * Clones an instance list and replaces every element with a new instance that owns private
   * copies of the feature indices and values.
   *
   * <p>Target, name and source of each instance are carried over by reference. The data of every
   * instance is cast to FeatureVector.
   *
   * @param instances the list to copy
   * @return the cloned list holding the rebuilt instances
   */
  public static InstanceList copy(InstanceList instances) {
    InstanceList duplicate = (InstanceList) instances.clone();
    Alphabet dataAlphabet = duplicate.getDataAlphabet();

    int position = 0;
    while (position < duplicate.size()) {
      Instance original = duplicate.get(position);
      // The feature vector is read through a clone of the instance, as before.
      Instance cloned = (Instance) original.clone();
      FeatureVector sourceVector = (FeatureVector) cloned.getData();

      int[] sourceIndices = sourceVector.getIndices();
      double[] sourceValues = sourceVector.getValues();
      int length = sourceIndices.length;

      int[] indexCopy = new int[length];
      double[] valueCopy = new double[length];
      System.arraycopy(sourceIndices, 0, indexCopy, 0, length);
      System.arraycopy(sourceValues, 0, valueCopy, 0, length);

      duplicate.set(
          position,
          new Instance(
              new FeatureVector(dataAlphabet, indexCopy, valueCopy),
              original.getTarget(),
              original.getName(),
              original.getSource()));
      position++;
    }

    return duplicate;
  }
Exemple #6
0
 /**
  * Creates the left and right child nodes of this node.
  *
  * <p>An instance goes to the left child when its value for the feature index reported by
  * {@code m_gainRatio} is less than or equal to the reported threshold; otherwise it goes to the
  * right child. The sizes of both partitions are logged.
  *
  * @throws IllegalStateException if the instance list is null
  */
 public void split() {
   if (m_ilist == null) throw new IllegalStateException("Frozen.  Cannot split.");

   // First pass: decide the side of every instance and count the left-hand ones.
   boolean[] goesLeft = new boolean[m_instIndices.length];
   int leftCount = 0;
   for (int pos = 0; pos < m_instIndices.length; pos++) {
     FeatureVector fv = (FeatureVector) m_ilist.get(m_instIndices[pos]).getData();
     goesLeft[pos] =
         fv.value(m_gainRatio.getMaxValuedIndex()) <= m_gainRatio.getMaxValuedThreshold();
     if (goesLeft[pos]) {
       leftCount++;
     }
   }
   int rightCount = m_instIndices.length - leftCount;
   logger.info("leftChild.size=" + leftCount + " rightChild.size=" + rightCount);

   // Second pass: distribute the instance indices, preserving their relative order.
   int[] leftIndices = new int[leftCount];
   int[] rightIndices = new int[rightCount];
   int nextLeft = 0;
   int nextRight = 0;
   for (int pos = 0; pos < m_instIndices.length; pos++) {
     if (goesLeft[pos]) {
       leftIndices[nextLeft++] = m_instIndices[pos];
     } else {
       rightIndices[nextRight++] = m_instIndices[pos];
     }
   }

   m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);
   m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);
 }
  /**
   * Flattens a sentence-based instance list into a token-based one: every position of a
   * sentence's feature vector sequence becomes its own instance, labelled with the label found at
   * the same position. Name and source of the sentence instance are reused for all of its tokens.
   *
   * <p>If a sentence has a different number of labels and feature vectors, an error is printed
   * and the JVM is terminated with exit code -1.
   *
   * @param METrainerDummyPipe pipe the resulting instance list is created with
   * @param orgList the sentence-based instance list
   * @return the token-based instance list
   */
  public static InstanceList convertFeatsforClassifier(
      final Pipe METrainerDummyPipe, final InstanceList orgList) {

    final InstanceList tokenList = new InstanceList(METrainerDummyPipe);

    final int sentenceCount = orgList.size();
    for (int s = 0; s < sentenceCount; s++) {
      final Instance sentence = orgList.get(s);

      final FeatureVectorSequence vectors = (FeatureVectorSequence) sentence.getData();
      final LabelSequence labels = (LabelSequence) sentence.getTarget();
      final LabelAlphabet labelAlphabet = (LabelAlphabet) labels.getAlphabet();
      final Object source = sentence.getSource();
      final Object name = sentence.getName();

      final boolean aligned = labels.size() == vectors.size();
      if (!aligned) {
        System.err.println(
            "failed making token instances: size of labelsequence != size of featue vector sequence: "
                + labels.size()
                + " - "
                + vectors.size());
        System.exit(-1);
      }

      // One instance per token position.
      for (int t = 0; t < vectors.size(); t++) {
        tokenList.add(
            new Instance(
                vectors.getFeatureVector(t), labelAlphabet.lookupLabel(labels.get(t)), name, source));
      }
    }

    return tokenList;
  }
Exemple #8
0
 /**
  * Restores the model state from a serialization stream.
  *
  * <p>Read order: a version number (consumed but otherwise unused), the instance list, the number
  * of topics, alpha, beta, tAlpha, vBeta, the per-token topic assignments of every document, the
  * per-document topic counts, the per-type topic counts and finally the tokens-per-topic totals.
  *
  * @param in stream to read from
  * @throws IOException if reading from the stream fails
  * @throws ClassNotFoundException if the class of the stored instance list cannot be found
  */
 private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
   in.readInt(); // version number: read to advance the stream, not used afterwards
   ilist = (InstanceList) in.readObject();
   numTopics = in.readInt();
   alpha = in.readDouble();
   beta = in.readDouble();
   tAlpha = in.readDouble();
   vBeta = in.readDouble();

   final int docCount = ilist.size();

   // Topic assignment of every token; each row is as long as its document.
   topics = new int[docCount][];
   for (int doc = 0; doc < docCount; doc++) {
     FeatureSequence tokens = (FeatureSequence) ilist.get(doc).getData();
     topics[doc] = new int[tokens.getLength()];
     int[] row = topics[doc];
     for (int pos = 0; pos < row.length; pos++) {
       row[pos] = in.readInt();
     }
   }

   docTopicCounts = new int[docCount][numTopics];
   for (int doc = 0; doc < docCount; doc++) {
     for (int topic = 0; topic < numTopics; topic++) {
       docTopicCounts[doc][topic] = in.readInt();
     }
   }

   final int typeCount = ilist.getDataAlphabet().size();
   typeTopicCounts = new int[typeCount][numTopics];
   for (int type = 0; type < typeCount; type++) {
     for (int topic = 0; topic < numTopics; topic++) {
       typeTopicCounts[type][topic] = in.readInt();
     }
   }

   tokensPerTopic = new int[numTopics];
   for (int topic = 0; topic < numTopics; topic++) {
     tokensPerTopic[topic] = in.readInt();
   }
 }
  /**
   * Trains an SVM on the given instances and stores the result in {@code classifier}.
   *
   * <p>Every instance is converted with {@code SVM.getSvmNodes}; instances for which that
   * conversion yields {@code null} are left out of the training problem. The label index of the
   * instance target is used as the training label. If {@code param.gamma} is 0 and the data
   * alphabet is not empty, gamma is set to the reciprocal of the alphabet size. When the
   * parameter check reports a problem, the message is printed and the JVM exits with code 1.
   *
   * @param trainingList instances to train on; targets are cast to Label
   * @return the trained classifier, built with the pipe of {@code trainingList}
   */
  public SVM train(InstanceList trainingList) {
    int numInstances = trainingList.size();
    svm_node[][] inputs = new svm_node[numInstances][];
    double[] labels = new double[numInstances];

    // Pack usable instances densely. Skipped instances used to leave null rows in problem.x
    // while problem.l still counted them.
    int kept = 0;
    for (int i = 0; i < numInstances; i++) {
      Instance instance = trainingList.get(i);
      svm_node[] input = SVM.getSvmNodes(instance);
      if (input == null) {
        continue;
      }
      inputs[kept] = input;
      labels[kept] = ((Label) instance.getTarget()).getIndex();
      kept++;
    }

    if (kept < numInstances) {
      // Trim the unused tail so the array lengths agree with problem.l.
      svm_node[][] trimmedInputs = new svm_node[kept][];
      double[] trimmedLabels = new double[kept];
      System.arraycopy(inputs, 0, trimmedInputs, 0, kept);
      System.arraycopy(labels, 0, trimmedLabels, 0, kept);
      inputs = trimmedInputs;
      labels = trimmedLabels;
    }

    svm_problem problem = new svm_problem();
    problem.l = kept;
    problem.x = inputs;
    problem.y = labels;

    int max_index = trainingList.getDataAlphabet().size();

    if (param.gamma == 0 && max_index > 0) {
      param.gamma = 1.0 / max_index;
    }

    // No per-class weights (param.weight_label / param.weight) are configured here.

    String error_msg = svm.svm_check_parameter(problem, param);

    if (error_msg != null) {
      System.err.print("Error: " + error_msg + "\n");
      System.exit(1);
    }

    svm_model model = svm.svm_train(problem, param);

    classifier = new SVM(model, trainingList.getPipe());

    return classifier;
  }
Exemple #10
0
 /**
  * Performs one Gibbs sampling sweep by calling {@code sampleTopicsForOneDoc} once for every
  * document, in document order.
  *
  * @param r random number source handed to each per-document call
  */
 public void sampleTopicsForAllDocs(Randoms r) {
   // A single weight buffer is allocated here and passed to every per-document call.
   final double[] weightBuffer = new double[numTopics];
   int doc = 0;
   while (doc < topics.length) {
     FeatureSequence tokens = (FeatureSequence) ilist.get(doc).getData();
     sampleTopicsForOneDoc(tokens, topics[doc], docTopicCounts[doc], weightBuffer, r);
     doc++;
   }
 }
Exemple #11
0
  /**
   * Initializes the model from {@code documents} and then runs estimation over all of them.
   *
   * <p>The instance list is shallow-cloned, all count arrays are allocated, and every token is
   * assigned a topic drawn with {@code r.nextInt(numTopics)}; the document-topic, type-topic and
   * tokens-per-topic counts are updated accordingly. Afterwards the work is delegated to the
   * range-based {@code estimate} overload, starting at document 0 and covering all documents.
   *
   * @param documents instances whose data must be FeatureSequence objects
   * @param numIterations forwarded to the range-based overload
   * @param showTopicsInterval forwarded to the range-based overload
   * @param outputModelInterval forwarded to the range-based overload
   * @param outputModelFilename forwarded to the range-based overload
   * @param r random number source
   * @throws ClassCastException if an instance does not hold FeatureSequence data
   */
  public void estimate(
      InstanceList documents,
      int numIterations,
      int showTopicsInterval,
      int outputModelInterval,
      String outputModelFilename,
      Randoms r) {
    ilist = documents.shallowClone();
    numTypes = ilist.getDataAlphabet().size();
    final int docCount = ilist.size();

    topics = new int[docCount][];
    docTopicCounts = new int[docCount][numTopics];
    typeTopicCounts = new int[numTypes][numTopics];
    tokensPerTopic = new int[numTopics];
    tAlpha = alpha * numTopics;
    vBeta = beta * numTypes;

    long startTime = System.currentTimeMillis();

    // Give every token a random starting topic and finish allocating this.topics.
    for (int doc = 0; doc < docCount; doc++) {
      FeatureSequence tokens;
      try {
        tokens = (FeatureSequence) ilist.get(doc).getData();
      } catch (ClassCastException e) {
        System.err.println(
            "LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }

      final int length = tokens.getLength();
      numTokens += length;
      topics[doc] = new int[length];

      for (int pos = 0; pos < length; pos++) {
        final int drawn = r.nextInt(numTopics);
        topics[doc][pos] = drawn;
        docTopicCounts[doc][drawn]++;
        typeTopicCounts[tokens.getIndexAtPosition(pos)][drawn]++;
        tokensPerTopic[drawn]++;
      }
    }

    this.estimate(
        0,
        docCount,
        numIterations,
        showTopicsInterval,
        outputModelInterval,
        outputModelFilename,
        r);
    // Timing notes kept from the original author:
    // 124.5 seconds
    // 144.8 seconds after using FeatureSequence instead of tokens[][] array
    // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
    // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable
  }
Exemple #12
0
 /**
  * Performs one Gibbs sampling sweep over the documents {@code start} (inclusive) to
  * {@code start + length} (exclusive), calling {@code sampleTopicsForOneDoc} for each of them.
  *
  * @param start index of the first document to sample
  * @param length number of consecutive documents to sample
  * @param r random number source handed to each per-document call
  */
 public void sampleTopicsForDocs(int start, int length, Randoms r) {
   assert (start + length <= docTopicCounts.length);
   // A single weight buffer is allocated here and passed to every per-document call.
   final double[] weightBuffer = new double[numTopics];
   final int end = start + length;
   int doc = start;
   while (doc < end) {
     FeatureSequence tokens = (FeatureSequence) ilist.get(doc).getData();
     sampleTopicsForOneDoc(tokens, topics[doc], docTopicCounts[doc], weightBuffer, r);
     doc++;
   }
 }
Exemple #13
0
  /**
   * Appends documents to an already initialized model and gives each of their tokens a random
   * topic.
   *
   * <p>The new instances are added to {@code ilist}; {@code topics}, {@code docTopicCounts} and
   * {@code typeTopicCounts} are enlarged to cover the new documents and any new alphabet entries,
   * and all counts are updated for the random assignments. No sampling iterations are run here;
   * the interval and filename parameters are not used by this method.
   *
   * @param additionalDocuments documents to append; their data must be FeatureSequence objects
   * @param numIterations unused
   * @param showTopicsInterval unused
   * @param outputModelInterval unused
   * @param outputModelFilename unused
   * @param r random number source for the initial topic assignments
   * @throws IllegalStateException if the model has no documents yet
   * @throws ClassCastException if an instance does not hold FeatureSequence data
   */
  public void addDocuments(
      InstanceList additionalDocuments,
      int numIterations,
      int showTopicsInterval,
      int outputModelInterval,
      String outputModelFilename,
      Randoms r) {
    if (ilist == null) throw new IllegalStateException("Must already have some documents first.");
    for (Instance inst : additionalDocuments) ilist.add(inst);
    assert (ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
    assert (additionalDocuments.getDataAlphabet().size() >= numTypes);
    numTypes = additionalDocuments.getDataAlphabet().size();
    // NOTE(review): quantities cached from the previous numTypes are not recomputed here —
    // confirm whether callers expect that.
    int numNewDocs = additionalDocuments.size();
    int numOldDocs = topics.length;
    int numDocs = numOldDocs + numNewDocs;

    // Expand the per-document arrays; rows of existing documents are carried over by reference,
    // rows of the new documents are filled in below.
    int[][] newTopics = new int[numDocs][];
    System.arraycopy(topics, 0, newTopics, 0, topics.length);
    topics = newTopics;

    int[][] newDocTopicCounts = new int[numDocs][numTopics];
    System.arraycopy(docTopicCounts, 0, newDocTopicCounts, 0, docTopicCounts.length);
    docTopicCounts = newDocTopicCounts;

    // Expand the per-type counts to the (possibly larger) alphabet and copy the old values.
    int[][] newTypeTopicCounts = new int[numTypes][numTopics];
    for (int i = 0; i < typeTopicCounts.length; i++) {
      System.arraycopy(typeTopicCounts[i], 0, newTypeTopicCounts[i], 0, numTopics);
    }
    // The enlarged array was previously built but never installed, so updates for newly
    // introduced types indexed past the end of the old array.
    typeTopicCounts = newTypeTopicCounts;

    FeatureSequence fs;
    for (int di = numOldDocs; di < numDocs; di++) {
      try {
        fs = (FeatureSequence) additionalDocuments.get(di - numOldDocs).getData();
      } catch (ClassCastException e) {
        System.err.println(
            "LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }
      int seqLen = fs.getLength();
      numTokens += seqLen;
      topics[di] = new int[seqLen];
      // Randomly assign tokens to topics
      for (int si = 0; si < seqLen; si++) {
        int topic = r.nextInt(numTopics);
        topics[di][si] = topic;
        docTopicCounts[di][topic]++;
        typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
        tokensPerTopic[topic]++;
      }
    }
  }
 /**
  * Performs one Gibbs sampling sweep by calling {@code sampleTopicsForOneDoc} once for every
  * document, in document order.
  *
  * @param r random number source handed to each per-document call
  */
 private void sampleTopicsForAllDocs(Randoms r) {
   // Both weight buffers are allocated once and passed to every per-document call; the second
   // one has two slots per topic.
   final double[] unigramBuffer = new double[numTopics];
   final double[] bigramBuffer = new double[numTopics * 2];
   int doc = 0;
   while (doc < topics.length) {
     FeatureSequenceWithBigrams tokens = (FeatureSequenceWithBigrams) ilist.get(doc).getData();
     sampleTopicsForOneDoc(
         tokens, topics[doc], grams[doc], docTopicCounts[doc], unigramBuffer, bigramBuffer, r);
     doc++;
   }
 }
 /**
  * Fills {@code testTopicDistribution} with one sampled topic distribution per test instance,
  * obtained from the inferencer of the estimated model.
  *
  * <p>When no model is available a message is printed and the JVM exits with code 1. The result
  * array is allocated only if it does not exist yet.
  */
 public void generateTestInference() {
   if (lda == null) {
     System.out.println("Should run lda estimation first.");
     System.exit(1);
     return;
   }

   if (testTopicDistribution == null) {
     testTopicDistribution = new double[test.size()][];
   }

   // Sampler settings passed to the inferencer for every instance.
   final int iterations = 800;
   final int thinning = 5;
   final int burnIn = 100;

   TopicInferencer inferencer = lda.getInferencer();
   int position = 0;
   while (position < test.size()) {
     testTopicDistribution[position] =
         inferencer.getSampledDistribution(test.get(position), iterations, thinning, burnIn);
     position++;
   }
 }
Exemple #16
0
  /**
   * Loads the topic model from {@code inferencerFile}, infers a topic distribution for every
   * instance produced by {@code readFile()} / {@code generateInstanceList()}, and prints one line
   * per instance to standard output.
   *
   * <p>Each line starts with the document id, followed by up to {@code topN}
   * {@code topic,probability} pairs in the order established by {@code CustomComparator}. Any
   * exception is reported on standard error and not rethrown.
   */
  public void doInference() {
    try {
      ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
      TopicInferencer inferencer = model.getInferencer();

      readFile();
      InstanceList testing = generateInstanceList();

      for (int doc = 0; doc < testing.size(); doc++) {
        double[] distribution = inferencer.getSampledDistribution(testing.get(doc), 10, 1, 5);

        // Pair every topic index with its probability so both can be sorted together.
        ArrayList ranked = new ArrayList();
        for (int topic = 0; topic < distribution.length; topic++) {
          ranked.add(new Pair<Integer, Double>(topic, distribution[topic]));
        }
        Collections.sort(ranked, new CustomComparator());

        StringBuilder line = new StringBuilder();
        int rank = 0;
        while (rank < distribution.length && rank < topN) {
          if (rank > 0) {
            line.append(" ");
          }
          Pair<Integer, Double> entry = (Pair<Integer, Double>) ranked.get(rank);
          line.append(entry.getFirst().toString());
          line.append(",");
          line.append(entry.getSecond().toString());
          rank++;
        }

        System.out.println(docIds.get(doc) + "," + line.toString());
      }
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
Exemple #17
0
 /**
  * Writes the topic assignment of every token: a header line followed by one line per token
  * holding the document index, the position in the document, the type index, the type object
  * looked up in the data alphabet, and the assigned topic, separated by single spaces.
  *
  * @param pw destination writer
  */
 public void printState(PrintWriter pw) {
   final Alphabet alphabet = ilist.getDataAlphabet();
   pw.println("#doc pos typeindex type topic");
   for (int doc = 0; doc < topics.length; doc++) {
     final FeatureSequence tokens = (FeatureSequence) ilist.get(doc).getData();
     final int[] assignment = topics[doc];
     for (int pos = 0; pos < assignment.length; pos++) {
       final int typeIndex = tokens.getIndexAtPosition(pos);
       final Object[] columns = {doc, pos, typeIndex, alphabet.lookupObject(typeIndex)};
       for (Object column : columns) {
         pw.print(column);
         pw.print(' ');
       }
       pw.print(assignment[pos]);
       pw.println();
     }
   }
 }
Exemple #18
0
  /**
   * Converts a feature vector into a Mallet instance.
   *
   * <p>The vector is first rendered as one text line: its id, a tab, its label (or {@code 1} when
   * the vector is not a LabeledFeatureVector), a tab, and then a {@code feature=value} entry
   * followed by a space for every key of {@code m} that the vector contains. That line is passed
   * to {@code loadInstances} and the first resulting instance is returned.
   *
   * @param fv the feature vector to convert
   * @param m metadata whose key set determines which features are written, and in which order
   * @return the first instance loaded from the rendered line
   */
  public static Instance convert(FeatureVector<?> fv, Metadata m) {
    StringBuilder line = new StringBuilder(fv.getId() + "\t");

    // Unlabeled vectors get the placeholder label "1".
    String label =
        (fv instanceof LabeledFeatureVector)
            ? ((LabeledFeatureVector<?>) fv).getLabel().toString()
            : "1";
    line.append(label).append("\t");

    for (String feature : m.keySet()) {
      if (!fv.containsKey(feature)) {
        continue;
      }
      line.append(feature + "=" + fv.get(feature).getValue() + " ");
    }

    return loadInstances(line.toString()).get(0);
  }
Exemple #19
0
  /**
   * Loads the topic model from {@code inferencerFile}, reads the instances of {@code fileName}
   * (UTF-8, one CSV-style record per line: name, label, data) and prints the sampled topic
   * distribution of the instance at index 1, one {@code index: probability} line per topic.
   *
   * <p>At most 1000 entries are printed, as before; distributions with fewer entries are now
   * printed completely instead of running past the end of the array.
   *
   * @throws Exception if the model or the data file cannot be read, or inference fails
   */
  public void test() throws Exception {

    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList instances = new InstanceList(new SerialPipes(pipeList));
    Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
    try {
      instances.addThruPipe(
          new CsvIterator(
              fileReader,
              Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
              3,
              2,
              1)); // data, label, name fields
    } finally {
      // The reader was previously left open.
      fileReader.close();
    }

    // NOTE(review): index 1 is the second instance in the file — confirm this is intended.
    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(1), 10, 1, 5);

    // Never read past the end of the distribution.
    int limit = Math.min(1000, testProbabilities.length);
    for (int i = 0; i < limit; i++) {
      System.out.println(i + ": " + testProbabilities[i]);
    }
  }
Exemple #20
0
 /**
  * Sums the weighted data log-probability of all instances in {@code ilist}.
  *
  * <p>For an instance with a labeling, the log-probability under its best label is used. For an
  * instance without one, the instance is classified and the log-probabilities under every
  * predicted label with non-zero weight are added, each scaled by that weight. All terms are
  * scaled by the instance weight.
  *
  * @param ilist instances to evaluate
  * @return the accumulated log-likelihood
  */
 public double dataLogLikelihood(InstanceList ilist) {
   double total = 0;
   final int count = ilist.size();
   for (int pos = 0; pos < count; pos++) {
     final double weight = ilist.getInstanceWeight(pos);
     final Instance current = ilist.get(pos);
     final Labeling known = current.getLabeling();

     if (known != null) {
       total += weight * dataLogProbability(current, known.getBestIndex());
       continue;
     }

     // Unlabeled instance: weight each label by the classifier's own prediction.
     final Labeling guessed = this.classify(current).getLabeling();
     for (int loc = 0; loc < guessed.numLocations(); loc++) {
       final int label = guessed.indexAtLocation(loc);
       final double labelWeight = guessed.valueAtLocation(loc);
       if (labelWeight != 0) {
         total += weight * labelWeight * dataLogProbability(current, label);
       }
     }
   }
   return total;
 }
 /**
  * Writes the state of every token: a header line followed by one line per token holding the
  * document index, the position, the type index, the type object looked up in
  * {@code uniAlphabet}, a 0/1 flag that is 0 when the bigram index at that position is -1, the
  * assigned topic, and the entry of {@code grams} for that token, separated by single spaces.
  *
  * @param pw destination writer
  */
 public void printState(PrintWriter pw) {
   pw.println("#doc pos typeindex type bigrampossible? topic bigram");
   for (int doc = 0; doc < topics.length; doc++) {
     final FeatureSequenceWithBigrams tokens =
         (FeatureSequenceWithBigrams) ilist.get(doc).getData();
     for (int pos = 0; pos < topics[doc].length; pos++) {
       final int typeIndex = tokens.getIndexAtPosition(pos);
       pw.print(doc);
       pw.print(' ');
       pw.print(pos);
       pw.print(' ');
       pw.print(typeIndex);
       pw.print(' ');
       pw.print(uniAlphabet.lookupObject(typeIndex));
       pw.print(' ');
       final int bigramPossible;
       if (tokens.getBiIndexAtPosition(pos) == -1) {
         bigramPossible = 0;
       } else {
         bigramPossible = 1;
       }
       pw.print(bigramPossible);
       pw.print(' ');
       pw.print(topics[doc][pos]);
       pw.print(' ');
       pw.print(grams[doc][pos]);
       pw.println();
     }
   }
 }
Exemple #22
0
 /**
  * Sums the weighted log-probability that the classifier assigns to the known labels of the
  * instances in {@code ilist}. Instances without a labeling are skipped.
  *
  * <p>A labeling with exactly one location contributes the log of the predicted value of its
  * best label. Otherwise every label with non-zero weight contributes the log of its predicted
  * value, scaled by that weight. All terms are scaled by the instance weight.
  *
  * @param ilist instances to evaluate
  * @return the accumulated log-likelihood
  */
 public double labelLogLikelihood(InstanceList ilist) {
   double total = 0;
   final int count = ilist.size();
   for (int pos = 0; pos < count; pos++) {
     final double weight = ilist.getInstanceWeight(pos);
     final Instance current = ilist.get(pos);
     final Labeling truth = current.getLabeling();
     if (truth != null) {
       final Labeling guessed = this.classify(current).getLabeling();
       if (truth.numLocations() != 1) {
         for (int loc = 0; loc < truth.numLocations(); loc++) {
           final int label = truth.indexAtLocation(loc);
           final double labelWeight = truth.valueAtLocation(loc);
           if (labelWeight != 0) {
             total += weight * labelWeight * Math.log(guessed.value(label));
           }
         }
       } else {
         total += weight * Math.log(guessed.value(truth.getBestIndex()));
       }
     }
   }
   return total;
 }
Exemple #23
0
  /**
   * Command-line wrapper to train, test, or run a generic CRF-based tagger.
   *
   * @param args the command line arguments. Options (shell and Java quoting should be added as
   *     needed):
   *     <dl>
   *       <dt><code>--help</code> <em>boolean</em>
   *       <dd>Print this command line option usage information. Give <code>true</code> for longer
   *           documentation. Default is <code>false</code>.
   *       <dt><code>--prefix-code</code> <em>Java-code</em>
   *       <dd>Java code you want run before any other interpreted code. Note that the text is
   *           interpreted without modification, so unlike some other Java code options, you need to
   *           include any necessary 'new's. Default is null.
   *       <dt><code>--gaussian-variance</code> <em>positive-number</em>
   *       <dd>The Gaussian prior variance used for training. Default is 10.0.
   *       <dt><code>--train</code> <em>boolean</em>
   *       <dd>Whether to train. Default is <code>false</code>.
   *       <dt><code>--iterations</code> <em>positive-integer</em>
   *       <dd>Number of training iterations. Default is 500.
   *       <dt><code>--test</code> <code>lab</code> or <code>seg=</code><em>start-1</em><code>.
   *           </code><em>continue-1</em><code>,</code>...<code>,</code><em>start-n</em><code>.
   *           </code><em>continue-n</em>
   *       <dd>Test measuring labeling or segmentation (<em>start-i</em>, <em>continue-i</em>)
   *           accuracy. Default is no testing.
   *       <dt><code>--training-proportion</code> <em>number-between-0-and-1</em>
   *       <dd>Fraction of data to use for training in a random split. Default is 0.5.
   *       <dt><code>--model-file</code> <em>filename</em>
   *       <dd>The filename for reading (train/run) or saving (train) the model. Default is null.
   *       <dt><code>--random-seed</code> <em>integer</em>
   *       <dd>The random seed for randomly selecting a proportion of the instance list for training
   *           Default is 0.
   *       <dt><code>--orders</code> <em>comma-separated-integers</em>
   *       <dd>List of label Markov orders (main and backoff) Default is 1.
   *       <dt><code>--forbidden</code> <em>regular-expression</em>
   *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> matches the expression, the
   *           corresponding transition is forbidden. Default is <code>\\s</code> (nothing
   *           forbidden).
   *       <dt><code>--allowed</code> <em>regular-expression</em>
   *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> does not match the expression, the
   *           corresponding expression is forbidden. Default is <code>.*</code> (everything
   *           allowed).
   *       <dt><code>--default-label</code> <em>string</em>
   *       <dd>Label for initial context and uninteresting tokens. Default is <code>O</code>.
   *       <dt><code>--viterbi-output</code> <em>boolean</em>
   *       <dd>Print Viterbi periodically during training. Default is <code>false</code>.
   *       <dt><code>--fully-connected</code> <em>boolean</em>
   *       <dd>Include all allowed transitions, even those not in training data. Default is <code>
   *           true</code>.
   *       <dt><code>--n-best</code> <em>positive-integer</em>
   *       <dd>Number of answers to output when applying model. Default is 1.
   *       <dt><code>--include-input</code> <em>boolean</em>
   *       <dd>Whether to include input features when printing decoding output. Default is <code>
   *           false</code>.
   *     </dl>
   *     Remaining arguments:
   *     <ul>
   *       <li><em>training-data-file</em> if training
   *       <li><em>training-and-test-data-file</em>, if training and testing with random split
   *       <li><em>training-data-file</em> <em>test-data-file</em> if training and testing from
   *           separate files
   *       <li><em>test-data-file</em> if testing
   *       <li><em>input-data-file</em> if applying to new data (unlabeled)
   *     </ul>
   *
   * @exception Exception if an error occurs
   */
  public static void main(String[] args) throws Exception {
    Reader trainingFile = null, testFile = null;
    InstanceList trainingData = null, testData = null;
    int numEvaluations = 0;
    int iterationsBetweenEvals = 16;
    int restArgs = commandOptions.processOptions(args);
    if (restArgs == args.length) {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing data file(s)");
    }
    if (trainOption.value) {
      trainingFile = new FileReader(new File(args[restArgs]));
      if (testOption.value != null && restArgs < args.length - 1)
        testFile = new FileReader(new File(args[restArgs + 1]));
    } else testFile = new FileReader(new File(args[restArgs]));

    Pipe p = null;
    CRF crf = null;
    TransducerEvaluator eval = null;
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value));
      crf = (CRF) s.readObject();
      s.close();
      p = crf.getInputPipe();
    } else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }

    if (trainOption.value) {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      trainingData.addThruPipe(
          new LineGroupIterator(trainingFile, Pattern.compile("^\\s*$"), true));
      logger.info("Number of features in training data: " + p.getDataAlphabet().size());
      if (testOption.value != null) {
        if (testFile != null) {
          testData = new InstanceList(p);
          testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
        } else {
          Random r = new Random(randomSeedOption.value);
          InstanceList[] trainingLists =
              trainingData.split(
                  r, new double[] {trainingFractionOption.value, 1 - trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else if (testOption.value != null) {
      p.setTargetProcessing(true);
      testData = new InstanceList(p);
      testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
    } else {
      p.setTargetProcessing(false);
      testData = new InstanceList(p);
      testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
    }
    logger.info("Number of predicates: " + p.getDataAlphabet().size());

    if (testOption.value != null) {
      if (testOption.value.startsWith("lab"))
        eval =
            new TokenAccuracyEvaluator(
                new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
      else if (testOption.value.startsWith("seg=")) {
        String[] pairs = testOption.value.substring(4).split(",");
        if (pairs.length < 1) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Missing segment start/continue labels: " + testOption.value);
        }
        String startTags[] = new String[pairs.length];
        String continueTags[] = new String[pairs.length];
        for (int i = 0; i < pairs.length; i++) {
          String[] pair = pairs[i].split("\\.");
          if (pair.length != 2) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException(
                "Incorrectly-specified segment start and end labels: " + pairs[i]);
          }
          startTags[i] = pair[0];
          continueTags[i] = pair[1];
        }
        eval =
            new MultiSegmentationEvaluator(
                new InstanceList[] {trainingData, testData},
                new String[] {"Training", "Testing"},
                startTags,
                continueTags);
      } else {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Invalid test option: " + testOption.value);
      }
    }

    if (p.isTargetProcessing()) {
      Alphabet targets = p.getTargetAlphabet();
      StringBuffer buf = new StringBuffer("Labels:");
      for (int i = 0; i < targets.size(); i++)
        buf.append(" ").append(targets.lookupObject(i).toString());
      logger.info(buf.toString());
    }
    if (trainOption.value) {
      crf =
          train(
              trainingData,
              testData,
              eval,
              ordersOption.value,
              defaultOption.value,
              forbiddenOption.value,
              allowedOption.value,
              connectedOption.value,
              iterationsOption.value,
              gaussianVarianceOption.value,
              crf);
      if (modelOption.value != null) {
        ObjectOutputStream s = new ObjectOutputStream(new FileOutputStream(modelOption.value));
        s.writeObject(crf);
        s.close();
      }
    } else {
      if (crf == null) {
        if (modelOption.value == null) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException("Missing model file option");
        }
        ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value));
        crf = (CRF) s.readObject();
        s.close();
      }
      if (eval != null) test(new NoopTransducerTrainer(crf), eval, testData);
      else {
        boolean includeInput = includeInputOption.value();
        for (int i = 0; i < testData.size(); i++) {
          Sequence input = (Sequence) testData.get(i).getData();
          Sequence[] outputs = apply(crf, input, nBestOption.value);
          int k = outputs.length;
          boolean error = false;
          for (int a = 0; a < k; a++) {
            if (outputs[a].size() != input.size()) {
              System.err.println("Failed to decode input sequence " + i + ", answer " + a);
              error = true;
            }
          }
          if (!error) {
            for (int j = 0; j < input.size(); j++) {
              StringBuffer buf = new StringBuffer();
              for (int a = 0; a < k; a++) buf.append(outputs[a].get(j).toString()).append(" ");
              if (includeInput) {
                FeatureVector fv = (FeatureVector) input.get(j);
                buf.append(fv.toString(true));
              }
              System.out.println(buf.toString());
            }
            System.out.println();
          }
        }
      }
    }
  }
Exemple #24
0
 /**
  * Builds a new list holding the instances selected by {@code m_instIndices}.
  *
  * <p>The returned list is constructed with the pipe of {@code m_ilist}, and the selected
  * instances are added in the order in which their indices appear in {@code m_instIndices}.
  *
  * @return a freshly created {@code InstanceList} containing the selected instances
  */
 public InstanceList getInstances() {
   InstanceList selected = new InstanceList(m_ilist.getPipe());
   for (int index : m_instIndices) {
     selected.add(m_ilist.get(index));
   }
   return selected;
 }
  /**
   * Prints, for every topic, its top unigrams followed by its top n-gram phrases to
   * {@code System.out}.
   *
   * <p>Unigrams are ranked by their entry in {@code unitypeTopicCounts} for the topic. Phrases are
   * rebuilt from the per-token {@code topics} and {@code grams} assignments: a token whose gram
   * flag is 1 is joined with {@code '_'} to the tokens preceding it, for as long as the flag stays
   * 1. A phrase is credited to the topic of its last token only.
   *
   * <p>Alongside the phrases a summary is printed with the token/type totals for unigrams and
   * bigrams and the total/distinct phrase counts of the topic.
   *
   * @param numWords upper bound on how many unigrams, and how many phrases, are printed per topic
   * @param useNewLines if {@code true}, every entry is printed on its own line followed by its
   *     count divided by the relevant total; if {@code false}, entries are printed
   *     space-separated without weights
   */
  public void printTopWords(int numWords, boolean useNewLines) {
    /** Pairs an alphabet index with a count; the natural order is descending count. */
    class WordProb implements Comparable<WordProb> {
      final int wi;
      final double p;

      WordProb(int wi, double p) {
        this.wi = wi;
        this.p = p;
      }

      @Override
      public int compareTo(WordProb other) {
        // Arguments are swapped on purpose so that larger counts sort first.
        return Double.compare(other.p, p);
      }
    }

    for (int ti = 0; ti < numTopics; ti++) {
      // Unigrams: rank every word type by its count in this topic.
      WordProb[] wp = new WordProb[numTypes];
      for (int wi = 0; wi < numTypes; wi++) {
        wp[wi] = new WordProb(wi, unitypeTopicCounts[wi][ti]);
      }
      Arrays.sort(wp);
      int numToPrint = Math.min(wp.length, numWords);
      if (useNewLines) {
        System.out.println("\nTopic " + ti + " unigrams");
        for (int i = 0; i < numToPrint; i++) {
          System.out.println(
              uniAlphabet.lookupObject(wp[i].wi).toString() + " " + wp[i].p / tokensPerTopic[ti]);
        }
      } else {
        System.out.print("Topic " + ti + ": ");
        for (int i = 0; i < numToPrint; i++) {
          System.out.print(uniAlphabet.lookupObject(wp[i].wi).toString() + " ");
        }
      }

      // Ngrams: count every phrase whose last token is assigned to this topic.
      AugmentableFeatureVector afv = new AugmentableFeatureVector(new Alphabet(), 10000, false);
      for (int di = 0; di < topics.length; di++) {
        FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData();
        // Scan right to left so a phrase is discovered at its final token.
        for (int si = topics[di].length - 1; si >= 0; si--) {
          if (topics[di][si] == ti && grams[di][si] == 1) {
            String gramString = uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString();
            // Keep prepending the previous token while the current one is flagged as a
            // continuation. This consumes si, so the outer loop resumes before the phrase.
            while (grams[di][si] == 1 && --si >= 0) {
              gramString =
                  uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString() + "_" + gramString;
            }
            afv.add(gramString, 1.0);
          }
        }
      }

      int numNgrams = afv.numLocations();
      wp = new WordProb[numNgrams];
      // Accumulated as a double: the values are doubles, and an int accumulator would hide a
      // narrowing cast inside "+=".
      double ngramSum = 0;
      for (int loc = 0; loc < numNgrams; loc++) {
        wp[loc] = new WordProb(afv.indexAtLocation(loc), afv.valueAtLocation(loc));
        ngramSum += wp[loc].p;
      }
      Arrays.sort(wp);

      // Token and type totals for this topic.
      int numUnitypeTokens = 0;
      int numUnitypeTypes = 0;
      for (int fi = 0; fi < numTypes; fi++) {
        numUnitypeTokens += unitypeTopicCounts[fi][ti];
        if (unitypeTopicCounts[fi][ti] != 0) {
          numUnitypeTypes++;
        }
      }
      int numBitypeTokens = 0;
      int numBitypeTypes = 0;
      for (int fi = 0; fi < numBitypes; fi++) {
        numBitypeTokens += bitypeTopicCounts[fi][ti];
        if (bitypeTopicCounts[fi][ti] != 0) {
          numBitypeTypes++;
        }
      }

      // The same summary text is used by both output layouts.
      String stats =
          "unigrams "
              + numUnitypeTokens
              + "/"
              + numUnitypeTypes
              + " bigrams "
              + numBitypeTokens
              + "/"
              + numBitypeTypes
              + " phrases "
              + Math.round(afv.oneNorm())
              + "/"
              + numNgrams;
      int numNgramsToPrint = Math.min(numNgrams, numWords);

      if (useNewLines) {
        System.out.println("\nTopic " + ti + " " + stats);
        for (int i = 0; i < numNgramsToPrint; i++) {
          System.out.println(
              afv.getAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p / ngramSum);
        }
      } else {
        System.out.print(" (" + stats + ")\n         ");
        for (int i = 0; i < numNgramsToPrint; i++) {
          System.out.print(afv.getAlphabet().lookupObject(wp[i].wi).toString() + " ");
        }
        System.out.println();
      }
    }
  }
  /**
   * Initializes the model state from {@code documents} with random assignments, then runs
   * {@code numIterations} rounds of {@code sampleTopicsForAllDocs}.
   *
   * <p>Initialization allocates all count tables, then for every token draws a topic uniformly
   * from {@code numTopics}. Where {@code getBiIndexAtPosition} is not {@code -1} it also draws a
   * gram status from {@code r.nextInt(2)}; otherwise the gram status is 0. The count tables are
   * updated to match these assignments.
   *
   * <p>During sampling, progress is written to {@code System.out}: the iteration number when it is
   * a multiple of 10, otherwise a dot. The total elapsed time in seconds is printed at the end.
   *
   * @param documents instances whose data are cast to {@code FeatureSequenceWithBigrams}; the
   *     first instance is read to obtain the bigram alphabet
   * @param numIterations number of sampling rounds to run
   * @param showTopicsInterval when non-zero, {@code printTopWords(5, false)} is called on every
   *     positive iteration that is a multiple of this value
   * @param outputModelInterval when non-zero, the model is written on every positive iteration
   *     that is a multiple of this value
   * @param outputModelFilename prefix of the output file; the iteration number is appended after
   *     a {@code '.'}
   * @param r source of the random draws used for initialization and passed to the sampler
   */
  public void estimate(
      InstanceList documents,
      int numIterations,
      int showTopicsInterval,
      int outputModelInterval,
      String outputModelFilename,
      Randoms r) {
    ilist = documents;
    uniAlphabet = ilist.getDataAlphabet();
    FeatureSequenceWithBigrams firstDocument =
        (FeatureSequenceWithBigrams) ilist.get(0).getData();
    biAlphabet = firstDocument.getBiAlphabet();
    numTypes = uniAlphabet.size();
    numBitypes = biAlphabet.size();
    final int documentCount = ilist.size();

    // Allocate the assignment arrays and the count tables.
    topics = new int[documentCount][];
    grams = new int[documentCount][];
    docTopicCounts = new int[documentCount][numTopics];
    typeNgramTopicCounts = new int[numTypes][2][numTopics];
    unitypeTopicCounts = new int[numTypes][numTopics];
    bitypeTopicCounts = new int[numBitypes][numTopics];
    tokensPerTopic = new int[numTopics];
    bitokensPerTopic = new int[numTypes][numTopics];
    tAlpha = alpha * numTopics;
    vBeta = beta * numTypes;
    vGamma = gamma * numTypes;

    final long startedAt = System.currentTimeMillis();

    // Random initialization; this also finishes allocating the per-document rows of
    // topics and grams.
    for (int doc = 0; doc < documentCount; doc++) {
      FeatureSequenceWithBigrams sequence =
          (FeatureSequenceWithBigrams) ilist.get(doc).getData();
      final int length = sequence.getLength();
      numTokens += length;
      int[] docTopics = new int[length];
      int[] docGrams = new int[length];
      topics[doc] = docTopics;
      grams[doc] = docGrams;

      int previousType = -1;
      int previousTopic = -1;
      for (int pos = 0; pos < length; pos++) {
        // The topic is always drawn first; the gram status is drawn only where a bigram
        // is allowed at this position.
        final int sampledTopic = r.nextInt(numTopics);
        final int biIndex = sequence.getBiIndexAtPosition(pos);
        final int gramStatus;
        if (biIndex == -1) {
          gramStatus = 0;
        } else {
          gramStatus = r.nextInt(2);
        }
        final boolean isBigram = gramStatus != 0;
        if (isBigram) {
          biTokens++;
        }

        docTopics[pos] = sampledTopic;
        docGrams[pos] = gramStatus;
        docTopicCounts[doc][sampledTopic]++;

        final int type = sequence.getIndexAtPosition(pos);
        if (previousType != -1) {
          typeNgramTopicCounts[previousType][gramStatus][previousTopic]++;
        }
        if (isBigram) {
          bitypeTopicCounts[biIndex][sampledTopic]++;
          bitokensPerTopic[previousType][sampledTopic]++;
        } else {
          unitypeTopicCounts[type][sampledTopic]++;
          tokensPerTopic[sampledTopic]++;
        }

        previousType = type;
        previousTopic = sampledTopic;
      }
    }

    for (int iter = 0; iter < numIterations; iter++) {
      sampleTopicsForAllDocs(r);

      // Progress indicator: a number every tenth iteration, a dot otherwise.
      if (iter % 10 != 0) {
        System.out.print(".");
      } else {
        System.out.print(iter);
      }
      System.out.flush();

      final boolean pastFirstIteration = iter > 0;
      if (pastFirstIteration && showTopicsInterval != 0 && iter % showTopicsInterval == 0) {
        System.out.println();
        printTopWords(5, false);
      }
      if (pastFirstIteration && outputModelInterval != 0 && iter % outputModelInterval == 0) {
        this.write(new File(outputModelFilename + '.' + iter));
      }
    }

    final double elapsedSeconds = (System.currentTimeMillis() - startedAt) / 1000.0;
    System.out.println("\nTotal time (sec): " + elapsedSeconds);
  }
 /**
  * Returns the final element of {@code list}.
  *
  * <p>No emptiness check is made: the element is fetched at index {@code list.size() - 1}.
  *
  * @return the instance stored at the last index of {@code list}
  */
 private Instance getLastInstance() {
   int lastIndex = list.size() - 1;
   return list.get(lastIndex);
 }