Пример #1
0
 /**
  * Creates an iterator over the transitions leaving {@code source} for the observation at
  * {@code inputPosition} of {@code inputSeq}.
  *
  * <p>Every transition whose label equals {@code output} (or every transition when {@code output}
  * is {@code null}) is weighted with the sum of the destination state's emission log-probability
  * for the observation and the source state's transition log-probability. All other transitions
  * receive {@code IMPOSSIBLE_WEIGHT}. {@code nextIndex} is then advanced to the first transition
  * whose weight is not {@code IMPOSSIBLE_WEIGHT}.
  *
  * @param source state whose outgoing transitions are enumerated
  * @param inputSeq observation sequence being scored
  * @param inputPosition position within {@code inputSeq} of the current observation
  * @param output required transition label, or {@code null} to accept every label
  * @param hmm model that supplies the emission and transition multinomials
  */
 public TransitionIterator(
     State source, FeatureSequence inputSeq, int inputPosition, String output, HMM hmm) {
   this.source = source;
   this.hmm = hmm;
   this.inputSequence = inputSeq;
   // Integer.valueOf replaces the deprecated Integer(int) constructor.
   this.inputFeature = Integer.valueOf(inputSequence.getIndexAtPosition(inputPosition));
   this.inputPos = inputPosition;
   this.weights = new double[source.destinations.length];
   for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) {
     if (output == null || output.equals(source.labels[transIndex])) {
       // xxx (open question kept from the original author): should this be the emission of the
       // _next_ observation? The emission is currently taken from the destination state.
       int destIndex = source.getDestinationState(transIndex).getIndex();
       double logEmissionProb =
           hmm.emissionMultinomial[destIndex].logProbability(inputSeq.get(inputPosition));
       double logTransitionProb =
           hmm.transitionMultinomial[source.getIndex()].logProbability(
               source.destinationNames[transIndex]);
       // weight = log-probability of taking this transition and emitting the observation
       weights[transIndex] = logEmissionProb + logTransitionProb;
       assert (!Double.isNaN(weights[transIndex]));
     } else {
       weights[transIndex] = IMPOSSIBLE_WEIGHT;
     }
   }
   // Position the iterator on the first transition that is actually possible.
   nextIndex = 0;
   while (nextIndex < source.destinations.length && weights[nextIndex] == IMPOSSIBLE_WEIGHT) {
     nextIndex++;
   }
 }
Пример #2
0
  /**
   * Accumulates per-feature totals into {@code featureCounts} and per-feature document counts
   * into {@code documentFrequencies} for every instance in {@code instances}.
   *
   * <p>The data class of the first instance decides how the whole list is read:
   *
   * <ul>
   *   <li>{@code FeatureSequence}: token occurrences are tallied per document, then each distinct
   *       feature adds its tally to {@code featureCounts} and 1 to {@code documentFrequencies}.
   *   <li>{@code FeatureVector}: each location adds its value to {@code featureCounts} and 1 to
   *       {@code documentFrequencies}.
   * </ul>
   *
   * <p>An empty list or an unsupported data class is logged and otherwise ignored. Progress is
   * written to {@code System.err} every 1000 instances.
   */
  public void count() {

    if (instances.size() == 0) {
      logger.info("Instance list is empty");
      return;
    }

    int index = 0;

    if (instances.get(0).getData() instanceof FeatureSequence) {

      for (Instance instance : instances) {
        FeatureSequence features = (FeatureSequence) instance.getData();

        // Fresh tally of feature index -> occurrences within this document.
        TIntIntHashMap docCounts = new TIntIntHashMap();
        for (int i = 0; i < features.getLength(); i++) {
          docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
        }

        // Visit EVERY distinct feature. The previous bound (keys.length - 1) skipped the last
        // key, so one feature per document was never counted.
        for (int feature : docCounts.keys()) {
          featureCounts[feature] += docCounts.get(feature);
          documentFrequencies[feature]++;
        }

        index++;
        if (index % 1000 == 0) {
          System.err.println(index);
        }
      }
    } else if (instances.get(0).getData() instanceof FeatureVector) {

      for (Instance instance : instances) {
        FeatureVector features = (FeatureVector) instance.getData();

        for (int location = 0; location < features.numLocations(); location++) {
          int feature = features.indexAtLocation(location);
          double value = features.valueAtLocation(location);

          documentFrequencies[feature]++;
          featureCounts[feature] += value;
        }

        index++;
        if (index % 1000 == 0) {
          System.err.println(index);
        }
      }
    } else {
      logger.info("Unsupported data class: " + instances.get(0).getData().getClass().getName());
    }
  }
Пример #3
0
  /**
   * Initialises the sampler state from {@code documents} and then runs the estimation loop.
   *
   * <p>The instance list is shallow-cloned, the count arrays are (re)allocated, every token is
   * given a topic drawn uniformly from {@code r}, and the overload of {@code estimate} that takes
   * a document offset and a document count is invoked over the whole collection.
   *
   * @param documents instances whose data must be {@code FeatureSequence}
   * @param numIterations forwarded to the estimation loop
   * @param showTopicsInterval forwarded to the estimation loop
   * @param outputModelInterval forwarded to the estimation loop
   * @param outputModelFilename forwarded to the estimation loop
   * @param r random source used for the initial assignment and forwarded to the loop
   * @throws ClassCastException if an instance's data is not a {@code FeatureSequence}; a hint is
   *     printed to {@code System.err} before the exception is rethrown
   */
  public void estimate(
      InstanceList documents,
      int numIterations,
      int showTopicsInterval,
      int outputModelInterval,
      String outputModelFilename,
      Randoms r) {
    ilist = documents.shallowClone();
    numTypes = ilist.getDataAlphabet().size();
    final int numDocs = ilist.size();

    // Allocate the count tables; topics[][] rows are sized per document further down.
    topics = new int[numDocs][];
    docTopicCounts = new int[numDocs][numTopics];
    typeTopicCounts = new int[numTypes][numTopics];
    tokensPerTopic = new int[numTopics];
    tAlpha = alpha * numTopics;
    vBeta = beta * numTypes;

    long startTime = System.currentTimeMillis();

    for (int docIndex = 0; docIndex < numDocs; docIndex++) {
      FeatureSequence tokens;
      try {
        tokens = (FeatureSequence) ilist.get(docIndex).getData();
      } catch (ClassCastException e) {
        System.err.println(
            "LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }

      final int length = tokens.getLength();
      numTokens += length;
      topics[docIndex] = new int[length];

      // Uniform random initial topic for every token, mirrored into all three count tables.
      for (int pos = 0; pos < length; pos++) {
        final int drawn = r.nextInt(numTopics);
        topics[docIndex][pos] = drawn;
        docTopicCounts[docIndex][drawn]++;
        typeTopicCounts[tokens.getIndexAtPosition(pos)][drawn]++;
        tokensPerTopic[drawn]++;
      }
    }

    // Historical timings noted by the original author for this call:
    //   124.5 seconds
    //   144.8 seconds after using FeatureSequence instead of tokens[][] array
    //   121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
    //   106.3 seconds after avoiding array lookup in inner loop with a temporary variable
    this.estimate(
        0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
  }
Пример #4
0
  /**
   * Appends {@code additionalDocuments} to the model's instance list, grows the count tables to
   * fit, and gives every token of the new documents a uniformly random topic.
   *
   * <p>Fix: the enlarged type/topic table is now installed in {@code typeTopicCounts}. Previously
   * it was built and filled but then discarded, so the counts still went to the old, smaller table
   * and any type index introduced by the new documents ran past its end.
   *
   * <p>NOTE(review): {@code vBeta} is derived from {@code numTypes} in {@code estimate} but is not
   * refreshed here after {@code numTypes} grows — confirm whether it should be.
   *
   * @param additionalDocuments new instances; must share the data alphabet of the existing list
   *     and carry {@code FeatureSequence} data
   * @param numIterations not used by this method
   * @param showTopicsInterval not used by this method
   * @param outputModelInterval not used by this method
   * @param outputModelFilename not used by this method
   * @param r random source for the initial topic assignment
   * @throws IllegalStateException if no documents have been loaded yet
   * @throws ClassCastException if an instance's data is not a {@code FeatureSequence}; a hint is
   *     printed to {@code System.err} before the exception is rethrown
   */
  public void addDocuments(
      InstanceList additionalDocuments,
      int numIterations,
      int showTopicsInterval,
      int outputModelInterval,
      String outputModelFilename,
      Randoms r) {
    if (ilist == null) {
      throw new IllegalStateException("Must already have some documents first.");
    }
    for (Instance inst : additionalDocuments) {
      ilist.add(inst);
    }
    assert (ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
    assert (additionalDocuments.getDataAlphabet().size() >= numTypes);
    numTypes = additionalDocuments.getDataAlphabet().size();
    int numNewDocs = additionalDocuments.size();
    int numOldDocs = topics.length;
    int numDocs = numOldDocs + numNewDocs;

    // Expand the per-document arrays; existing rows are shared, new rows are filled below.
    int[][] newTopics = new int[numDocs][];
    System.arraycopy(topics, 0, newTopics, 0, numOldDocs);
    topics = newTopics;

    int[][] newDocTopicCounts = new int[numDocs][numTopics];
    System.arraycopy(docTopicCounts, 0, newDocTopicCounts, 0, docTopicCounts.length);
    docTopicCounts = newDocTopicCounts;

    // Expand the per-type table (the alphabet may have grown) and copy the existing counts.
    int[][] newTypeTopicCounts = new int[numTypes][numTopics];
    for (int i = 0; i < typeTopicCounts.length; i++) {
      System.arraycopy(typeTopicCounts[i], 0, newTypeTopicCounts[i], 0, numTopics);
    }
    typeTopicCounts = newTypeTopicCounts; // previously missing: the new table was never installed

    FeatureSequence fs;
    for (int di = numOldDocs; di < numDocs; di++) {
      try {
        fs = (FeatureSequence) additionalDocuments.get(di - numOldDocs).getData();
      } catch (ClassCastException e) {
        System.err.println(
            "LDA and other topic models expect FeatureSequence data, not FeatureVector data.  "
                + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
        throw e;
      }
      int seqLen = fs.getLength();
      numTokens += seqLen;
      topics[di] = new int[seqLen];
      // Randomly assign tokens to topics
      for (int si = 0; si < seqLen; si++) {
        int topic = r.nextInt(numTopics);
        topics[di][si] = topic;
        docTopicCounts[di][topic]++;
        typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
        tokensPerTopic[topic]++;
      }
    }
  }
Пример #5
0
 /**
  * Records which label-to-label transitions occur in the targets of {@code trainingSet}.
  *
  * @param trainingSet instances whose targets are {@code FeatureSequence} label sequences
  * @return a square matrix sized by {@code outputAlphabet} in which {@code [i][j]} is
  *     {@code true} when label {@code j} directly follows label {@code i} in some target
  */
 private boolean[][] labelConnectionsIn(InstanceList trainingSet) {
   final int labelCount = outputAlphabet.size();
   final boolean[][] observed = new boolean[labelCount][labelCount];
   for (Instance example : trainingSet) {
     final FeatureSequence labels = (FeatureSequence) example.getTarget();
     // Walk adjacent pairs (previous, current) of the label sequence.
     int current = 1;
     while (current < labels.size()) {
       final int from = outputAlphabet.lookupIndex(labels.get(current - 1));
       final int to = outputAlphabet.lookupIndex(labels.get(current));
       assert (from >= 0 && to >= 0);
       observed[from][to] = true;
       current++;
     }
   }
   return observed;
 }
Пример #6
0
 /**
  * Writes the current topic assignment of every token to {@code pw}.
  *
  * <p>Output is a header line followed by one space-separated line per token: document index,
  * position, type index, type object (looked up in the data alphabet) and assigned topic.
  *
  * @param pw destination writer; it is neither flushed nor closed here
  */
 public void printState(PrintWriter pw) {
   final Alphabet vocabulary = ilist.getDataAlphabet();
   pw.println("#doc pos typeindex type topic");
   for (int doc = 0; doc < topics.length; doc++) {
     final FeatureSequence tokens = (FeatureSequence) ilist.get(doc).getData();
     final int[] assignments = topics[doc];
     for (int pos = 0; pos < assignments.length; pos++) {
       final int typeIndex = tokens.getIndexAtPosition(pos);
       final Object typeObject = vocabulary.lookupObject(typeIndex);
       pw.print(doc);
       pw.print(' ');
       pw.print(pos);
       pw.print(' ');
       pw.print(typeIndex);
       pw.print(' ');
       pw.print(typeObject);
       pw.print(' ');
       pw.print(assignments[pos]);
       pw.println();
     }
   }
 }
Пример #7
0
  /**
   * Resamples the topic of every token in one document.
   *
   * <p>For each position the token is removed from the document, type and global topic counts, a
   * weight per topic is computed from the remaining counts, a topic is drawn with
   * {@code r.nextDiscrete}, and the token is added back under the drawn topic.
   *
   * @param oneDocTokens token sequence of the document
   * @param oneDocTopics current topic per position; updated in place
   * @param oneDocTopicCounts per-topic token counts of this document; updated in place
   * @param topicWeights scratch buffer that is overwritten for every token
   * @param r random source used for the draw
   */
  private void sampleTopicsForOneDoc(
      FeatureSequence oneDocTokens,
      int[] oneDocTopics, // indexed by seq position
      int[] oneDocTopicCounts, // indexed by topic index
      double[] topicWeights,
      Randoms r) {
    final int length = oneDocTokens.getLength();
    for (int pos = 0; pos < length; pos++) {
      final int wordType = oneDocTokens.getIndexAtPosition(pos);
      final int previous = oneDocTopics[pos];

      // Take the token out of every count table.
      oneDocTopicCounts[previous]--;
      final int[] countsForType = typeTopicCounts[wordType];
      countsForType[previous]--;
      tokensPerTopic[previous]--;

      // Unnormalised weight per topic; the document-length denominator
      // (docLen - 1 + tAlpha) is the same for all topics and therefore omitted.
      Arrays.fill(topicWeights, 0.0);
      double total = 0;
      for (int t = 0; t < numTopics; t++) {
        final double weight =
            ((countsForType[t] + beta) / (tokensPerTopic[t] + vBeta))
                * ((oneDocTopicCounts[t] + alpha));
        total += weight;
        topicWeights[t] = weight;
      }

      // Draw the replacement topic and put the token back under it.
      final int drawn = r.nextDiscrete(topicWeights, total);
      oneDocTopics[pos] = drawn;
      oneDocTopicCounts[drawn]++;
      countsForType[drawn]++;
      tokensPerTopic[drawn]++;
    }
  }
  /**
   * Loads the small-text fixture with a third {@code loadInstances} argument of 0 and again with
   * 2 (presumably a rare-word pruning threshold — confirm against {@code LDAUtils}), prints both
   * results, and checks the per-document sequence sizes of the second load.
   */
  @Test
  public void testLoadRareWords() throws UnsupportedEncodingException, FileNotFoundException {
    final String datasetPath = "src/main/resources/datasets/SmallTexts.txt";

    final InstanceList unpruned = LDAUtils.loadInstances(datasetPath, "stoplist.txt", 0);
    System.out.println(LDAUtils.instancesToString(unpruned));
    System.out.println("Non pruned Alphabet size: " + unpruned.getDataAlphabet().size());
    System.out.println("No. instances: " + unpruned.size());

    final InstanceList pruned = LDAUtils.loadInstances(datasetPath, "stoplist.txt", 2);
    System.out.println("Alphabet size: " + pruned.getDataAlphabet().size());
    System.out.println(LDAUtils.instancesToString(pruned));
    System.out.println("No. instances: " + pruned.size());

    final int[] expectedSizes = {0, 3, 3, 0, 0};
    int doc = 0;
    for (Instance instance : pruned) {
      final FeatureSequence tokens = (FeatureSequence) instance.getData();
      // Compare against size(), not getFeatures().length: even though a feature sequence may be
      // "empty", its underlying array is still 2 long, so the array length would not match.
      assertEquals(expectedSizes[doc], tokens.size());
      doc++;
    }
  }
Пример #9
0
  /**
   * Resamples, for a single document, the topic of every token and then the topic of every label.
   *
   * <p>The method runs in two passes that share the per-document dense topic index
   * ({@code localTopicIndex} / {@code nonZeroTopics}):
   *
   * <ol>
   *   <li><b>Tokens</b> ({@code doc.instance.getData()}): per position it calls
   *       {@code removeOldTopicContribution}, {@code calcSamplingValuesPerType} and
   *       {@code findNewTopic}, then updates {@code localTopicCounts}, {@code tokensPerTopic},
   *       {@code cachedCoefficients} and the masses held in {@code massValue}.
   *   <li><b>Labels</b> ({@code doc.instance.getTarget()}): the same steps against
   *       {@code localLblTopicCounts}, {@code labelsPerTopic}, {@code cachedLabelCoefficients}
   *       and {@code massLblValue}, with {@code 1 / lblWeight} passed where the token pass uses
   *       {@code lblWeight}.
   * </ol>
   *
   * <p>When {@code ignoreLabels} is set, each pass uses only its own counts; otherwise the other
   * side's counts are mixed in, scaled by {@code lblWeight} (tokens) or {@code 1 / lblWeight}
   * (labels). At the end the cached coefficients of all touched topics are reset to their
   * smoothing-only values and the smoothing masses are written back to the instance fields.
   *
   * <p>NOTE(review): the exact contracts of {@code removeOldTopicContribution},
   * {@code calcSamplingValuesPerType} and {@code findNewTopic} are not visible here; comments
   * below describe only what this method itself does with their arguments and results.
   *
   * <p>NOTE(review): several expressions use {@code 1 / lblWeight}; if {@code lblWeight} were an
   * integral type this would be integer division — confirm it is declared as a floating type.
   *
   * @param doc the document to resample; its {@code topicSequence} and {@code lblTopicSequence}
   *     feature arrays are modified in place
   */
  protected void sampleTopicsForOneDoc(TopicAssignment doc // int[][] typeTopicCounts
      // ,
      // double[] cachedCoefficients,
      // int[] tokensPerTopic,
      // double betaSum,
      // double beta,
      // double smoothingOnlyMass,
      // int[][] lblTypeTopicCounts,
      // double[] cachedLabelCoefficients,
      // int[] labelsPerTopic,
      // double gammaSum,
      // double gamma,
      // double smoothingOnlyLblMass
      ) {

    FeatureSequence tokenSequence = (FeatureSequence) doc.instance.getData();

    LabelSequence topicSequence = (LabelSequence) doc.topicSequence;

    // Mass accumulators for the token pass; smoothing mass starts from the instance field.
    MassValue massValue = new MassValue();
    massValue.topicBetaMass = 0.0;
    massValue.topicTermMass = 0.0;
    massValue.smoothingOnlyMass = smoothingOnlyMass;

    int nonZeroTopics = 0;

    // oneDocTopics is the array returned by getFeatures(); it is written to below.
    int[] oneDocTopics = topicSequence.getFeatures();
    int[] localTopicCounts = new int[numTopics];
    int[] localTopicIndex = new int[numTopics];

    // Label Init
    LabelSequence lblTopicSequence = (LabelSequence) doc.lblTopicSequence;
    FeatureSequence labelSequence = (FeatureSequence) doc.instance.getTarget();

    // Mass accumulators for the label pass.
    MassValue massLblValue = new MassValue();
    massLblValue.topicBetaMass = 0.0;
    massLblValue.topicTermMass = 0.0;
    massLblValue.smoothingOnlyMass = smoothingOnlyLabelMass;

    int[] oneDocLblTopics = lblTopicSequence.getFeatures();
    int[] localLblTopicCounts = new int[numTopics];

    // initSampling

    int docLength = tokenSequence.getLength();
    // populate token topic counts, skipping positions that have no topic yet
    for (int position = 0; position < docLength; position++) {
      if (oneDocTopics[position] == ParallelTopicModel.UNASSIGNED_TOPIC) {
        continue;
      }
      localTopicCounts[oneDocTopics[position]]++;
    }

    // docLength is temporarily reused for the label sequence length.
    docLength = labelSequence.getLength();
    // populate label topic counts, skipping positions that have no topic yet
    for (int position = 0; position < docLength; position++) {
      if (oneDocLblTopics[position] == ParallelTopicModel.UNASSIGNED_TOPIC) {
        continue;
      }
      localLblTopicCounts[oneDocLblTopics[position]]++;
    }

    // Build an array that densely lists the topics that
    //  have non-zero counts (in either the token or the label counts).
    int denseIndex = 0;
    for (int topic = 0; topic < numTopics; topic++) {
      if (localTopicCounts[topic] != 0 || localLblTopicCounts[topic] != 0) {
        localTopicIndex[denseIndex] = topic;
        denseIndex++;
      }
    }

    // Record the total number of non-zero topics
    nonZeroTopics = denseIndex;
    // NOTE(review): this branch has no effect (unused local); it looks like a leftover
    // debugger breakpoint anchor.
    if (nonZeroTopics < 20) {
      int a = 1;
    }

    // Initialize the topic count/beta sampling bucket
    // Initialize cached coefficients and the topic/beta
    //  normalizing constant.
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
      int topic = localTopicIndex[denseIndex];
      int n = localTopicCounts[topic];
      int nl = localLblTopicCounts[topic];

      if (ignoreLabels) {
        // initialize the normalization constant for the (B * n_{t|d}) term
        massValue.topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
        // massLblValue.topicBetaMass += gamma * nl / (labelsPerTopic[topic] + gammaSum);
        // update the coefficients for the non-zero topics
        cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
        // cachedLabelCoefficients[topic] = (alpha[topic] + nl) / (labelsPerTopic[topic] +
        // gammaSum);
      } else {
        // Same quantities, with the label counts mixed in scaled by lblWeight.
        massValue.topicBetaMass += beta * (n + lblWeight * nl) / (tokensPerTopic[topic] + betaSum);
        // massLblValue.topicBetaMass += gamma * (nl + (1 / lblWeight) * n) / (labelsPerTopic[topic]
        // + gammaSum);
        cachedCoefficients[topic] =
            ((1 + lblWeight) * alpha[topic] + n + lblWeight * nl)
                / (tokensPerTopic[topic] + betaSum);
        // cachedLabelCoefficients[topic] = ((1 + (1 / lblWeight)) * alpha[topic] + nl + (1 /
        // lblWeight) * n) / (labelsPerTopic[topic] + gammaSum);
      }
    }

    // end of Init Sampling

    // ---------------------------------------------------------------------------------------
    // Pass 1: resample the topic of every token.
    // ---------------------------------------------------------------------------------------
    double[] topicTermScores = new double[numTopics];
    int[] currentTypeTopicCounts;
    // Iterate over the positions (words) in the document
    docLength = tokenSequence.getLength();
    for (int position = 0; position < docLength; position++) {

      int type = tokenSequence.getIndexAtPosition(position);
      currentTypeTopicCounts = typeTopicCounts[type];

      // The helper returns the (possibly changed) number of non-zero topics.
      nonZeroTopics =
          removeOldTopicContribution(
              position,
              oneDocTopics,
              massValue,
              localTopicCounts,
              localLblTopicCounts,
              localTopicIndex,
              cachedCoefficients,
              tokensPerTopic,
              betaSum,
              beta,
              lblWeight,
              nonZeroTopics);

      // calcSamplingValuesPerType
      calcSamplingValuesPerType(
          // tokenSequence,
          position,
          oneDocTopics,
          massValue,
          topicTermScores,
          currentTypeTopicCounts,
          localTopicCounts,
          // localTopicIndex,
          cachedCoefficients,
          tokensPerTopic,
          betaSum,
          beta,
          typeTotals,
          type,
          typeSkewIndexes,
          skewWeight);

      double sample = 0;

      // Uniform draw scaled to the sum of the three masses.
      sample =
          random.nextUniform()
              * (massValue.smoothingOnlyMass + massValue.topicBetaMass + massValue.topicTermMass);

      // Kept only for the diagnostic message below.
      double origSample = sample;

      // Make sure it actually gets set
      int newTopic = -1;

      newTopic =
          findNewTopic(
              sample,
              massValue,
              topicTermScores,
              currentTypeTopicCounts,
              localTopicCounts,
              localLblTopicCounts,
              localTopicIndex,
              tokensPerTopic,
              betaSum,
              beta,
              nonZeroTopics,
              lblWeight);

      if (newTopic == -1) {
        // Sampling failed: report the masses and fall back to the last topic.
        System.err.println(
            "WorkerRunnable sampling error: "
                + origSample
                + " "
                + sample
                + " "
                + massValue.smoothingOnlyMass
                + " "
                + massValue.topicBetaMass
                + " "
                + massValue.topicTermMass);
        newTopic = numTopics - 1; // TODO is this appropriate
        // throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
      }
      // assert(newTopic != -1);

      // Put that new topic into the counts
      oneDocTopics[position] = newTopic;

      // Subtract the new topic's current contribution from the masses; it is added back
      // below once the counts have been incremented.
      if (ignoreLabels) {
        massValue.smoothingOnlyMass -=
            alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass -=
            beta * localTopicCounts[newTopic] / (tokensPerTopic[newTopic] + betaSum);
      } else {
        massValue.smoothingOnlyMass -=
            (1 + lblWeight) * alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass -=
            beta
                * (localTopicCounts[newTopic] + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);
      }
      localTopicCounts[newTopic]++;

      // If this is a new topic for this document,
      //  add the topic to the dense index.
      if (localTopicCounts[newTopic] == 1 && localLblTopicCounts[newTopic] == 0) {

        // First find the point where we
        //  should insert the new topic by going to
        //  the end (which is the only reason we're keeping
        //  track of the number of non-zero
        //  topics) and working backwards

        denseIndex = nonZeroTopics;

        // Shift larger topic ids one slot right so the dense index stays ascending.
        while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic) {

          localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
          denseIndex--;
        }

        localTopicIndex[denseIndex] = newTopic;
        nonZeroTopics++;
      }

      tokensPerTopic[newTopic]++;

      // Recompute the coefficient and add the masses back using the incremented counts.
      if (ignoreLabels) {
        // update the coefficients for the non-zero topics
        cachedCoefficients[newTopic] =
            (alpha[newTopic] + localTopicCounts[newTopic]) / (tokensPerTopic[newTopic] + betaSum);

        massValue.smoothingOnlyMass +=
            alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass +=
            beta * localTopicCounts[newTopic] / (tokensPerTopic[newTopic] + betaSum);
      } else {
        massValue.smoothingOnlyMass +=
            (1 + lblWeight) * alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass +=
            beta
                * (localTopicCounts[newTopic] + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);

        cachedCoefficients[newTopic] =
            ((1 + lblWeight) * alpha[newTopic]
                    + localTopicCounts[newTopic]
                    + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);
      }
    }

    // ---------------------------------------------------------------------------------------
    // Pass 2: resample the topic of every label.
    // ---------------------------------------------------------------------------------------
    // sample labels
    // init labels: label-side beta mass and cached coefficients for the non-zero topics
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
      int topic = localTopicIndex[denseIndex];
      int n = localTopicCounts[topic];
      int nl = localLblTopicCounts[topic];

      if (ignoreLabels) {
        // initialize the normalization constant for the (B * n_{t|d}) term
        // massValue.topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
        massLblValue.topicBetaMass += gamma * nl / (labelsPerTopic[topic] + gammaSum);
        // update the coefficients for the non-zero topics
        // cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] = (alpha[topic] + nl) / (labelsPerTopic[topic] + gammaSum);
      } else {
        // massValue.topicBetaMass += beta * (n + lblWeight * nl) / (tokensPerTopic[topic] +
        // betaSum);
        massLblValue.topicBetaMass +=
            gamma * (nl + (1 / lblWeight) * n) / (labelsPerTopic[topic] + gammaSum);
        // cachedCoefficients[topic] = ((1 + lblWeight) * alpha[topic] + n + lblWeight * nl) /
        // (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] =
            ((1 + (1 / lblWeight)) * alpha[topic] + nl + (1 / lblWeight) * n)
                / (labelsPerTopic[topic] + gammaSum);
      }
    }

    double[] topicLblTermScores = new double[numTopics];
    int[] currentLblTypeTopicCounts;
    int docLblLength = labelSequence.getLength();

    // Iterate over the positions (labels) in the document
    for (int position = 0; position < docLblLength; position++) {

      int type = labelSequence.getIndexAtPosition(position);
      currentLblTypeTopicCounts = lbltypeTopicCounts[type];

      // Same helper as the token pass, with label and token count arrays swapped.
      nonZeroTopics =
          removeOldTopicContribution(
              position,
              oneDocLblTopics,
              massLblValue,
              localLblTopicCounts,
              localTopicCounts,
              localTopicIndex,
              cachedLabelCoefficients,
              labelsPerTopic,
              gammaSum,
              gamma,
              1 / lblWeight,
              nonZeroTopics);

      // calcSamplingValuesPerType
      calcSamplingValuesPerType(
          // labelSequence,
          position,
          oneDocLblTopics,
          massLblValue,
          topicLblTermScores,
          currentLblTypeTopicCounts,
          localLblTopicCounts,
          // localTopicIndex,
          cachedLabelCoefficients,
          labelsPerTopic,
          gammaSum,
          gamma,
          lblTypeTotals,
          type,
          lblTypeSkewIndexes,
          lblSkewWeight);
      // massLblValue.smoothingOnlyMass = 0; //ignore smoothing mass

      // Uniform draw scaled to the sum of the three label masses.
      double sample =
          random.nextUniform()
              * (massLblValue.smoothingOnlyMass
                  + massLblValue.topicBetaMass
                  + massLblValue.topicTermMass);

      // double sample = random.nextUniform() * (massValue.smoothingOnlyMass +
      // massValue.topicBetaMass + massLblValue.smoothingOnlyMass + massLblValue.topicBetaMass +
      // massLblValue.topicTermMass);

      // Kept only for the diagnostic message below.
      double origSample = sample;

      // Make sure it actually gets set
      int newTopic = -1;

      newTopic =
          findNewTopic(
              sample,
              massLblValue,
              topicLblTermScores,
              currentLblTypeTopicCounts,
              localLblTopicCounts,
              localTopicCounts,
              localTopicIndex,
              labelsPerTopic,
              gammaSum,
              gamma,
              nonZeroTopics,
              1 / lblWeight);

      if (newTopic == -1) {
        System.err.println(
            "WorkerRunnable sampling labels error: "
                + origSample
                + " "
                + sample
                + " "
                + massLblValue.smoothingOnlyMass
                + " "
                + massLblValue.topicBetaMass
                + " "
                + massLblValue.topicTermMass);
        // newTopic = numTopics - 1; // TODO is this appropriate
        // throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
      }
      // NOTE(review): unlike the token pass there is no fallback here. With assertions
      // disabled a value of -1 is stored below and then used as an array index.
      assert (newTopic != -1);

      // Put that new topic into the counts
      oneDocLblTopics[position] = newTopic;

      // Subtract the new topic's current contribution; added back after the increments.
      if (ignoreLabels) {
        massLblValue.smoothingOnlyMass -=
            alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass -=
            gamma * localLblTopicCounts[newTopic] / (labelsPerTopic[newTopic] + gammaSum);
      } else {

        massLblValue.smoothingOnlyMass -=
            (1 + 1 / lblWeight) * alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass -=
            gamma
                * (localLblTopicCounts[newTopic] + (1 / lblWeight) * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);
      }

      localLblTopicCounts[newTopic]++;

      // If this is a new topic for this document,
      //  add the topic to the dense index.
      if (localLblTopicCounts[newTopic] == 1 && localTopicCounts[newTopic] == 0) {

        // First find the point where we
        //  should insert the new topic by going to
        //  the end (which is the only reason we're keeping
        //  track of the number of non-zero
        //  topics) and working backwards

        denseIndex = nonZeroTopics;

        // Shift larger topic ids one slot right so the dense index stays ascending.
        while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic) {

          localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
          denseIndex--;
        }

        localTopicIndex[denseIndex] = newTopic;
        nonZeroTopics++;
      }

      labelsPerTopic[newTopic]++;

      // update the coefficients for the non-zero topics
      if (ignoreLabels) {
        cachedLabelCoefficients[newTopic] =
            (alpha[newTopic] + localLblTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);

        massLblValue.smoothingOnlyMass +=
            alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass +=
            gamma * localLblTopicCounts[newTopic] / (labelsPerTopic[newTopic] + gammaSum);

      } else {

        cachedLabelCoefficients[newTopic] =
            ((1 + 1 / lblWeight) * alpha[newTopic]
                    + localLblTopicCounts[newTopic]
                    + 1 / lblWeight * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);

        massLblValue.smoothingOnlyMass +=
            (1 + 1 / lblWeight) * alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass +=
            gamma
                * (localLblTopicCounts[newTopic] + (1 / lblWeight) * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);
      }
    }
    if (shouldSaveState) {
      // Update the document-topic count histogram,
      //  for dirichlet estimation
      // (docLength holds the token sequence length again at this point.)
      docLengthCounts[docLength]++;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];

        topicDocCounts[topic][localTopicCounts[topic]]++;
      }

      // Same histograms for the label side.
      docLblLengthCounts[docLblLength]++;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];

        topicLblDocCounts[topic][localLblTopicCounts[topic]]++;
      }
    }

    // Clean up our mess: reset the coefficients to values with only
    // smoothing. The next doc will update its own non-zero topics...
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {

      int topic = localTopicIndex[denseIndex];

      if (ignoreLabels) {
        cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] = alpha[topic] / (labelsPerTopic[topic] + gammaSum);
      } else {
        cachedCoefficients[topic] =
            (1 + lblWeight) * alpha[topic] / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] =
            (1 + 1 / lblWeight) * alpha[topic] / (labelsPerTopic[topic] + gammaSum);
      }
    }

    // Persist the updated smoothing masses for the next document.
    smoothingOnlyMass = massValue.smoothingOnlyMass;
    smoothingOnlyLabelMass = massLblValue.smoothingOnlyMass;
  }
Пример #10
0
  /**
   * Once we have sampled the local counts, trash the "global" type topic counts and reuse the space
   * to build a summary of the type topic counts specific to this worker's section of the corpus.
   */
  public void buildLocalTypeTopicCounts() {

    // Clear the topic totals
    Arrays.fill(tokensPerTopic, 0);

    Arrays.fill(labelsPerTopic, 0);

    // Clear the type/topic counts, only
    //  looking at the entries before the first 0 entry.

    for (int type = 0; type < typeTopicCounts.length; type++) {

      int[] topicCounts = typeTopicCounts[type];

      int position = 0;
      while (position < topicCounts.length && topicCounts[position] > 0) {
        topicCounts[position] = 0;
        position++;
      }
    }

    for (int lbltype = 0; lbltype < lbltypeTopicCounts.length; lbltype++) {

      int[] lbltopicCounts = lbltypeTopicCounts[lbltype];

      int position = 0;
      while (position < lbltopicCounts.length && lbltopicCounts[position] > 0) {
        lbltopicCounts[position] = 0;
        position++;
      }
    }

    for (int doc = startDoc; doc < data.size() && doc < startDoc + numDocs; doc++) {

      TopicAssignment document = data.get(doc);

      FeatureSequence tokens = (FeatureSequence) document.instance.getData();
      FeatureSequence labels = (FeatureSequence) document.instance.getTarget();
      FeatureSequence topicSequence = (FeatureSequence) document.topicSequence;

      int[] topics = topicSequence.getFeatures();

      for (int position = 0; position < tokens.size(); position++) {

        int topic = topics[position];

        if (topic == ParallelTopicModel.UNASSIGNED_TOPIC) {
          continue;
        }

        tokensPerTopic[topic]++;

        // The format for these arrays is
        //  the topic in the rightmost bits
        //  the count in the remaining (left) bits.
        // Since the count is in the high bits, sorting (desc)
        //  by the numeric value of the int guarantees that
        //  higher counts will be before the lower counts.

        int type = tokens.getIndexAtPosition(position);

        int[] currentTypeTopicCounts = typeTopicCounts[type];

        // Start by assuming that the array is either empty
        //  or is in sorted (descending) order.

        // Here we are only adding counts, so if we find
        //  an existing location with the topic, we only need
        //  to ensure that it is not larger than its left neighbor.

        int index = 0;
        int currentTopic = currentTypeTopicCounts[index] & topicMask;
        int currentValue;

        while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
          index++;
          if (index == currentTypeTopicCounts.length) {
            System.out.println("overflow on type " + type);
          }
          currentTopic = currentTypeTopicCounts[index] & topicMask;
        }
        currentValue = currentTypeTopicCounts[index] >> topicBits;

        if (currentValue == 0) {
          // new value is 1, so we don't have to worry about sorting
          //  (except by topic suffix, which doesn't matter)

          currentTypeTopicCounts[index] = (1 << topicBits) + topic;
        } else {
          currentTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + topic;

          // Now ensure that the array is still sorted by
          //  bubbling this value up.
          while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
            int temp = currentTypeTopicCounts[index];
            currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
            currentTypeTopicCounts[index - 1] = temp;

            index--;
          }
        }
      }

      FeatureSequence lblTopicSequence = (FeatureSequence) document.lblTopicSequence;
      int[] lblTopics = lblTopicSequence.getFeatures();

      for (int position = 0; position < labels.size(); position++) {

        int topic = lblTopics[position];

        if (topic == ParallelTopicModel.UNASSIGNED_TOPIC) {
          continue;
        }

        labelsPerTopic[topic]++;

        // The format for these arrays is
        //  the topic in the rightmost bits
        //  the count in the remaining (left) bits.
        // Since the count is in the high bits, sorting (desc)
        //  by the numeric value of the int guarantees that
        //  higher counts will be before the lower counts.

        int type = labels.getIndexAtPosition(position);

        int[] currentlblTypeTopicCounts = lbltypeTopicCounts[type];

        // Start by assuming that the array is either empty
        //  or is in sorted (descending) order.

        // Here we are only adding counts, so if we find
        //  an existing location with the topic, we only need
        //  to ensure that it is not larger than its left neighbor.

        int index = 0;
        int currentTopic = currentlblTypeTopicCounts[index] & topicMask;
        int currentValue;

        while (currentlblTypeTopicCounts[index] > 0 && currentTopic != topic) {
          index++;
          if (index == currentlblTypeTopicCounts.length) {
            System.out.println("overflow on type " + type);
          }
          currentTopic = currentlblTypeTopicCounts[index] & topicMask;
        }
        currentValue = currentlblTypeTopicCounts[index] >> topicBits;

        if (currentValue == 0) {
          // new value is 1, so we don't have to worry about sorting
          //  (except by topic suffix, which doesn't matter)

          currentlblTypeTopicCounts[index] = (1 << topicBits) + topic;
        } else {
          currentlblTypeTopicCounts[index] = ((currentValue + 1) << topicBits) + topic;

          // Now ensure that the array is still sorted by
          //  bubbling this value up.
          while (index > 0
              && currentlblTypeTopicCounts[index] > currentlblTypeTopicCounts[index - 1]) {
            int temp = currentlblTypeTopicCounts[index];
            currentlblTypeTopicCounts[index] = currentlblTypeTopicCounts[index - 1];
            currentlblTypeTopicCounts[index - 1] = temp;

            index--;
          }
        }
      }
    }
  }
Пример #11
0
  /**
   * Makes one pass over every document returned by {@code model.getData()} and
   * accumulates the per-topic statistics held in this object's fields.
   *
   * <p>{@code topicCodocumentMatrices}, {@code wordTypeCounts} and {@code numTokens} are
   * reallocated or zeroed at the start of the call. The other accumulators written here
   * ({@code numNonZeroDocuments}, {@code sumCountTimesLogCount},
   * {@code numDocumentsAtProportions}, {@code numRank1Documents}) are only added to;
   * presumably they are initialised before this method runs -- TODO confirm.
   *
   * <p>For each non-empty document and each topic with at least one token in it, the
   * method:
   * <ul>
   *   <li>increments {@code numNonZeroDocuments[topic]};</li>
   *   <li>adds {@code count * log(count)} to {@code sumCountTimesLogCount[topic]};</li>
   *   <li>increments {@code numDocumentsAtProportions[topic][i]} for each leading
   *       threshold in {@code DEFAULT_DOC_PROPORTIONS} that the smoothed topic
   *       proportion reaches (the scan stops at the first threshold not reached);</li>
   *   <li>updates {@code topicCodocumentMatrices[topic]}: the diagonal entry for a top
   *       word is incremented when the word occurs in the document with this topic, and
   *       the two symmetric off-diagonal entries are incremented when both words do.</li>
   * </ul>
   * The topic with the largest token count in the document (lowest index on ties) gets
   * {@code numRank1Documents} incremented.
   */
  public void collectDocumentStatistics() {

    topicCodocumentMatrices = new int[numTopics][numTopWords][numTopWords];
    wordTypeCounts = new int[alphabet.size()];
    numTokens = 0;

    // Marks a top-word slot that holds no word. The array default of 0 cannot be used
    //  for this purpose because 0 is itself a legitimate word type index; leaving empty
    //  slots at 0 would make them look like word type 0 in the co-document loop below.
    final int noWord = -1;

    // This is an array of hash sets containing the words-of-interest for each topic,
    //  used for checking if the word at some position is one of those words.
    IntHashSet[] topicTopWordIndices = new IntHashSet[numTopics];

    // The same as the topic top words, but with int indices instead of strings,
    //  used for iterating over positions. Empty slots hold noWord.
    int[][] topicWordIndicesInOrder = new int[numTopics][numTopWords];

    // This is an array of hash sets that will hold the words-of-interest present in a document,
    //  which will be cleared after every document.
    IntHashSet[] docTopicWordIndices = new IntHashSet[numTopics];

    // The count of each topic, again cleared after every document.
    int[] topicCounts = new int[numTopics];

    for (int topic = 0; topic < numTopics; topic++) {
      IntHashSet wordIndices = new IntHashSet();

      for (int i = 0; i < numTopWords; i++) {
        if (topicTopWords[topic][i] != null) {
          int type = alphabet.lookupIndex(topicTopWords[topic][i]);
          topicWordIndicesInOrder[topic][i] = type;
          wordIndices.add(type);
        } else {
          topicWordIndicesInOrder[topic][i] = noWord;
        }
      }

      topicTopWordIndices[topic] = wordIndices;
      docTopicWordIndices[topic] = new IntHashSet();
    }

    for (TopicAssignment document : model.getData()) {

      FeatureSequence tokens = (FeatureSequence) document.instance.getData();
      FeatureSequence topics = (FeatureSequence) document.topicSequence;

      // First pass over the document: global word counts, per-topic token counts, and
      //  which of each topic's top words actually occur with that topic.
      for (int position = 0; position < tokens.size(); position++) {
        int type = tokens.getIndexAtPosition(position);
        int topic = topics.getIndexAtPosition(position);

        numTokens++;
        wordTypeCounts[type]++;

        topicCounts[topic]++;

        if (topicTopWordIndices[topic].contains(type)) {
          docTopicWordIndices[topic].add(type);
        }
      }

      int docLength = tokens.size();

      if (docLength > 0) {
        int maxTopic = -1;
        int maxCount = -1;

        for (int topic = 0; topic < numTopics; topic++) {

          if (topicCounts[topic] > 0) {
            numNonZeroDocuments[topic]++;

            // Strict comparison: on ties the lowest-numbered topic stays the maximum.
            if (topicCounts[topic] > maxCount) {
              maxTopic = topic;
              maxCount = topicCounts[topic];
            }

            sumCountTimesLogCount[topic] += topicCounts[topic] * Math.log(topicCounts[topic]);

            double proportion =
                (model.alpha[topic] + topicCounts[topic]) / (model.alphaSum + docLength);
            // Stops at the first threshold the proportion falls below, so this counts
            //  every threshold only if DEFAULT_DOC_PROPORTIONS is ascending -- TODO confirm.
            for (int i = 0; i < DEFAULT_DOC_PROPORTIONS.length; i++) {
              if (proportion < DEFAULT_DOC_PROPORTIONS[i]) {
                break;
              }
              numDocumentsAtProportions[topic][i]++;
            }

            IntHashSet supportedWords = docTopicWordIndices[topic];
            int[] indices = topicWordIndicesInOrder[topic];

            for (int i = 0; i < numTopWords; i++) {
              if (indices[i] != noWord && supportedWords.contains(indices[i])) {
                for (int j = i; j < numTopWords; j++) {
                  if (i == j) {
                    // Diagonals are total number of documents with word W in topic T
                    topicCodocumentMatrices[topic][i][i]++;
                  } else if (indices[j] != noWord && supportedWords.contains(indices[j])) {
                    topicCodocumentMatrices[topic][i][j]++;
                    topicCodocumentMatrices[topic][j][i]++;
                  }
                }
              }
            }

            // Reset the per-document scratch state for this topic. Topics with a zero
            //  count never had anything recorded, so they need no reset.
            docTopicWordIndices[topic].clear();
            topicCounts[topic] = 0;
          }
        }

        if (maxTopic > -1) {
          numRank1Documents[maxTopic]++;
        }
      }
    }
  }