Example #1
  protected void sampleTopicsForOneDoc(TopicAssignment doc) {
    // The former parameters (typeTopicCounts, cachedCoefficients, tokensPerTopic, betaSum,
    // beta, smoothingOnlyMass, lblTypeTopicCounts, cachedLabelCoefficients, labelsPerTopic,
    // gammaSum, gamma, smoothingOnlyLblMass) are read from the enclosing worker's fields.

    FeatureSequence tokenSequence = (FeatureSequence) doc.instance.getData();

    LabelSequence topicSequence = (LabelSequence) doc.topicSequence;

    MassValue massValue = new MassValue();
    massValue.topicBetaMass = 0.0;
    massValue.topicTermMass = 0.0;
    massValue.smoothingOnlyMass = smoothingOnlyMass;

    int nonZeroTopics = 0;

    int[] oneDocTopics = topicSequence.getFeatures();
    int[] localTopicCounts = new int[numTopics];
    int[] localTopicIndex = new int[numTopics];

    // Label Init
    LabelSequence lblTopicSequence = (LabelSequence) doc.lblTopicSequence;
    FeatureSequence labelSequence = (FeatureSequence) doc.instance.getTarget();

    MassValue massLblValue = new MassValue();
    massLblValue.topicBetaMass = 0.0;
    massLblValue.topicTermMass = 0.0;
    massLblValue.smoothingOnlyMass = smoothingOnlyLabelMass;

    int[] oneDocLblTopics = lblTopicSequence.getFeatures();
    int[] localLblTopicCounts = new int[numTopics];

    // initSampling

    int docLength = tokenSequence.getLength();
    //		populate topic counts
    for (int position = 0; position < docLength; position++) {
      if (oneDocTopics[position] == ParallelTopicModel.UNASSIGNED_TOPIC) {
        continue;
      }
      localTopicCounts[oneDocTopics[position]]++;
    }

    docLength = labelSequence.getLength();
    //		populate topic counts
    for (int position = 0; position < docLength; position++) {
      if (oneDocLblTopics[position] == ParallelTopicModel.UNASSIGNED_TOPIC) {
        continue;
      }
      localLblTopicCounts[oneDocLblTopics[position]]++;
    }

    // Build an array that densely lists the topics that
    //  have non-zero counts.
    int denseIndex = 0;
    for (int topic = 0; topic < numTopics; topic++) {
      if (localTopicCounts[topic] != 0 || localLblTopicCounts[topic] != 0) {
        localTopicIndex[denseIndex] = topic;
        denseIndex++;
      }
    }

    // Record the total number of non-zero topics
    nonZeroTopics = denseIndex;

    // Initialize the topic count/beta sampling bucket
    // Initialize cached coefficients and the topic/beta
    //  normalizing constant.
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
      int topic = localTopicIndex[denseIndex];
      int n = localTopicCounts[topic];
      int nl = localLblTopicCounts[topic];

      if (ignoreLabels) {
        //	initialize the normalization constant for the (B * n_{t|d}) term
        massValue.topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
        // massLblValue.topicBetaMass += gamma * nl / (labelsPerTopic[topic] + gammaSum);
        //	update the coefficients for the non-zero topics
        cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
        // cachedLabelCoefficients[topic] = (alpha[topic] + nl) / (labelsPerTopic[topic] +
        // gammaSum);
      } else {
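        // Label-aware variant: document counts enter as the weighted mix n + lblWeight * nl,
        // and the alpha prior is scaled by (1 + lblWeight) accordingly.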
        massValue.topicBetaMass += beta * (n + lblWeight * nl) / (tokensPerTopic[topic] + betaSum);
        // massLblValue.topicBetaMass += gamma * (nl + (1 / lblWeight) * n) / (labelsPerTopic[topic]
        // + gammaSum);
        cachedCoefficients[topic] =
            ((1 + lblWeight) * alpha[topic] + n + lblWeight * nl)
                / (tokensPerTopic[topic] + betaSum);
        // cachedLabelCoefficients[topic] = ((1 + (1 / lblWeight)) * alpha[topic] + nl + (1 /
        // lblWeight) * n) / (labelsPerTopic[topic] + gammaSum);
      }
    }

    // end of Init Sampling

    double[] topicTermScores = new double[numTopics];
    int[] currentTypeTopicCounts;
    //	Iterate over the positions (words) in the document
    docLength = tokenSequence.getLength();
    for (int position = 0; position < docLength; position++) {

      int type = tokenSequence.getIndexAtPosition(position);
      currentTypeTopicCounts = typeTopicCounts[type];

      nonZeroTopics =
          removeOldTopicContribution(
              position,
              oneDocTopics,
              massValue,
              localTopicCounts,
              localLblTopicCounts,
              localTopicIndex,
              cachedCoefficients,
              tokensPerTopic,
              betaSum,
              beta,
              lblWeight,
              nonZeroTopics);

      // calcSamplingValuesPerType
      calcSamplingValuesPerType(
          // tokenSequence,
          position,
          oneDocTopics,
          massValue,
          topicTermScores,
          currentTypeTopicCounts,
          localTopicCounts,
          // localTopicIndex,
          cachedCoefficients,
          tokensPerTopic,
          betaSum,
          beta,
          typeTotals,
          type,
          typeSkewIndexes,
          skewWeight);

      double sample =
          random.nextUniform()
              * (massValue.smoothingOnlyMass + massValue.topicBetaMass + massValue.topicTermMass);

      double origSample = sample;

      //	findNewTopic returns -1 if no sampling bucket was matched
      int newTopic =
          findNewTopic(
              sample,
              massValue,
              topicTermScores,
              currentTypeTopicCounts,
              localTopicCounts,
              localLblTopicCounts,
              localTopicIndex,
              tokensPerTopic,
              betaSum,
              beta,
              nonZeroTopics,
              lblWeight);

      if (newTopic == -1) {
        System.err.println(
            "WorkerRunnable sampling error: "
                + origSample
                + " "
                + sample
                + " "
                + massValue.smoothingOnlyMass
                + " "
                + massValue.topicBetaMass
                + " "
                + massValue.topicTermMass);
        newTopic = numTopics - 1; // TODO is this appropriate
        // throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
      }
      // assert(newTopic != -1);

      //			Put that new topic into the counts
      oneDocTopics[position] = newTopic;

      if (ignoreLabels) {
        massValue.smoothingOnlyMass -=
            alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass -=
            beta * localTopicCounts[newTopic] / (tokensPerTopic[newTopic] + betaSum);
      } else {
        massValue.smoothingOnlyMass -=
            (1 + lblWeight) * alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass -=
            beta
                * (localTopicCounts[newTopic] + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);
      }
      localTopicCounts[newTopic]++;

      // If this is a new topic for this document,
      //  add the topic to the dense index.
      if (localTopicCounts[newTopic] == 1 && localLblTopicCounts[newTopic] == 0) {

        // First find the point where we
        //  should insert the new topic by going to
        //  the end (which is the only reason we're keeping
        //  track of the number of non-zero
        //  topics) and working backwards

        denseIndex = nonZeroTopics;

        while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic) {

          localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
          denseIndex--;
        }

        localTopicIndex[denseIndex] = newTopic;
        nonZeroTopics++;
      }

      tokensPerTopic[newTopic]++;

      if (ignoreLabels) {
        //	update the coefficients for the non-zero topics
        cachedCoefficients[newTopic] =
            (alpha[newTopic] + localTopicCounts[newTopic]) / (tokensPerTopic[newTopic] + betaSum);

        massValue.smoothingOnlyMass +=
            alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass +=
            beta * localTopicCounts[newTopic] / (tokensPerTopic[newTopic] + betaSum);
      } else {
        massValue.smoothingOnlyMass +=
            (1 + lblWeight) * alpha[newTopic] * beta / (tokensPerTopic[newTopic] + betaSum);
        massValue.topicBetaMass +=
            beta
                * (localTopicCounts[newTopic] + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);

        cachedCoefficients[newTopic] =
            ((1 + lblWeight) * alpha[newTopic]
                    + localTopicCounts[newTopic]
                    + lblWeight * localLblTopicCounts[newTopic])
                / (tokensPerTopic[newTopic] + betaSum);
      }
    }

    // sample labels
    // init labels
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
      int topic = localTopicIndex[denseIndex];
      int n = localTopicCounts[topic];
      int nl = localLblTopicCounts[topic];

      if (ignoreLabels) {
        //	initialize the normalization constant for the (B * n_{t|d}) term
        // massValue.topicBetaMass += beta * n / (tokensPerTopic[topic] + betaSum);
        massLblValue.topicBetaMass += gamma * nl / (labelsPerTopic[topic] + gammaSum);
        //	update the coefficients for the non-zero topics
        // cachedCoefficients[topic] = (alpha[topic] + n) / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] = (alpha[topic] + nl) / (labelsPerTopic[topic] + gammaSum);
      } else {
        // massValue.topicBetaMass += beta * (n + lblWeight * nl) / (tokensPerTopic[topic] +
        // betaSum);
        massLblValue.topicBetaMass +=
            gamma * (nl + (1 / lblWeight) * n) / (labelsPerTopic[topic] + gammaSum);
        // cachedCoefficients[topic] = ((1 + lblWeight) * alpha[topic] + n + lblWeight * nl) /
        // (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] =
            ((1 + (1 / lblWeight)) * alpha[topic] + nl + (1 / lblWeight) * n)
                / (labelsPerTopic[topic] + gammaSum);
      }
    }

    double[] topicLblTermScores = new double[numTopics];
    int[] currentLblTypeTopicCounts;
    int docLblLength = labelSequence.getLength();

    //	Iterate over the positions (words) in the document
    for (int position = 0; position < docLblLength; position++) {

      int type = labelSequence.getIndexAtPosition(position);
      currentLblTypeTopicCounts = lbltypeTopicCounts[type];

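      // The label pass reuses the same helpers as the word pass, with the roles of the
      // token and label count arrays swapped and the mixing weight inverted (1 / lblWeight).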
      nonZeroTopics =
          removeOldTopicContribution(
              position,
              oneDocLblTopics,
              massLblValue,
              localLblTopicCounts,
              localTopicCounts,
              localTopicIndex,
              cachedLabelCoefficients,
              labelsPerTopic,
              gammaSum,
              gamma,
              1 / lblWeight,
              nonZeroTopics);

      // calcSamplingValuesPerType
      calcSamplingValuesPerType(
          // labelSequence,
          position,
          oneDocLblTopics,
          massLblValue,
          topicLblTermScores,
          currentLblTypeTopicCounts,
          localLblTopicCounts,
          // localTopicIndex,
          cachedLabelCoefficients,
          labelsPerTopic,
          gammaSum,
          gamma,
          lblTypeTotals,
          type,
          lblTypeSkewIndexes,
          lblSkewWeight);
      // massLblValue.smoothingOnlyMass = 0; //ignore smoothing mass

      double sample =
          random.nextUniform()
              * (massLblValue.smoothingOnlyMass
                  + massLblValue.topicBetaMass
                  + massLblValue.topicTermMass);

      // double sample = random.nextUniform() * (massValue.smoothingOnlyMass +
      // massValue.topicBetaMass + massLblValue.smoothingOnlyMass + massLblValue.topicBetaMass +
      // massLblValue.topicTermMass);

      double origSample = sample;

      //	findNewTopic returns -1 if no sampling bucket was matched
      int newTopic =
          findNewTopic(
              sample,
              massLblValue,
              topicLblTermScores,
              currentLblTypeTopicCounts,
              localLblTopicCounts,
              localTopicCounts,
              localTopicIndex,
              labelsPerTopic,
              gammaSum,
              gamma,
              nonZeroTopics,
              1 / lblWeight);

      if (newTopic == -1) {
        System.err.println(
            "WorkerRunnable sampling labels error: "
                + origSample
                + " "
                + sample
                + " "
                + massLblValue.smoothingOnlyMass
                + " "
                + massLblValue.topicBetaMass
                + " "
                + massLblValue.topicTermMass);
        newTopic = numTopics - 1; // TODO is this appropriate (mirrors the word-sampling fallback)
        // throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
      }
      assert (newTopic != -1);

      //			Put that new topic into the counts
      oneDocLblTopics[position] = newTopic;

      if (ignoreLabels) {
        massLblValue.smoothingOnlyMass -=
            alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass -=
            gamma * localLblTopicCounts[newTopic] / (labelsPerTopic[newTopic] + gammaSum);
      } else {

        massLblValue.smoothingOnlyMass -=
            (1 + 1 / lblWeight) * alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass -=
            gamma
                * (localLblTopicCounts[newTopic] + (1 / lblWeight) * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);
      }

      localLblTopicCounts[newTopic]++;

      // If this is a new topic for this document,
      //  add the topic to the dense index.
      if (localLblTopicCounts[newTopic] == 1 && localTopicCounts[newTopic] == 0) {

        // First find the point where we
        //  should insert the new topic by going to
        //  the end (which is the only reason we're keeping
        //  track of the number of non-zero
        //  topics) and working backwards

        denseIndex = nonZeroTopics;

        while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic) {

          localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
          denseIndex--;
        }

        localTopicIndex[denseIndex] = newTopic;
        nonZeroTopics++;
      }

      labelsPerTopic[newTopic]++;

      //	update the coefficients for the non-zero topics
      if (ignoreLabels) {
        cachedLabelCoefficients[newTopic] =
            (alpha[newTopic] + localLblTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);

        massLblValue.smoothingOnlyMass +=
            alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass +=
            gamma * localLblTopicCounts[newTopic] / (labelsPerTopic[newTopic] + gammaSum);

      } else {

        cachedLabelCoefficients[newTopic] =
            ((1 + 1 / lblWeight) * alpha[newTopic]
                    + localLblTopicCounts[newTopic]
                    + 1 / lblWeight * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);

        massLblValue.smoothingOnlyMass +=
            (1 + 1 / lblWeight) * alpha[newTopic] * gamma / (labelsPerTopic[newTopic] + gammaSum);
        massLblValue.topicBetaMass +=
            gamma
                * (localLblTopicCounts[newTopic] + (1 / lblWeight) * localTopicCounts[newTopic])
                / (labelsPerTopic[newTopic] + gammaSum);
      }
    }
    if (shouldSaveState) {
      // Update the document-topic count histogram,
      //  for Dirichlet hyperparameter estimation
      docLengthCounts[docLength]++;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];

        topicDocCounts[topic][localTopicCounts[topic]]++;
      }

      docLblLengthCounts[docLblLength]++;

      for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
        int topic = localTopicIndex[denseIndex];

        topicLblDocCounts[topic][localLblTopicCounts[topic]]++;
      }
    }

    //	Clean up our mess: reset the coefficients to values with only
    //	smoothing. The next doc will update its own non-zero topics...
    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {

      int topic = localTopicIndex[denseIndex];

      if (ignoreLabels) {
        cachedCoefficients[topic] = alpha[topic] / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] = alpha[topic] / (labelsPerTopic[topic] + gammaSum);
      } else {
        cachedCoefficients[topic] =
            (1 + lblWeight) * alpha[topic] / (tokensPerTopic[topic] + betaSum);
        cachedLabelCoefficients[topic] =
            (1 + 1 / lblWeight) * alpha[topic] / (labelsPerTopic[topic] + gammaSum);
      }
    }

    smoothingOnlyMass = massValue.smoothingOnlyMass;
    smoothingOnlyLabelMass = massLblValue.smoothingOnlyMass;
  }
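
For reference, the three per-token masses tracked in MassValue (smoothingOnlyMass, topicBetaMass, topicTermMass) appear to follow the SparseLDA bucket decomposition of the collapsed Gibbs conditional (Yao, Mimno, and McCallum). A sketch of the unlabeled case (ignoreLabels == true), assuming betaSum = V * beta for vocabulary size V:

p(z = t \mid w, \cdot) \;\propto\; \frac{\alpha_t \beta}{n_t + \beta_{sum}} \;+\; \frac{n_{t|d}\,\beta}{n_t + \beta_{sum}} \;+\; \frac{n_{w|t}\,(\alpha_t + n_{t|d})}{n_t + \beta_{sum}}

Here n_t = tokensPerTopic[t], n_{t|d} = localTopicCounts[t], and n_{w|t} is read from typeTopicCounts[w]; the three terms map to the smoothing-only, topic-beta, and topic-term buckets in that order. The labeled branch replaces n_{t|d} with n_{t|d} + lblWeight * nl_{t|d} and scales alpha_t by (1 + lblWeight); the label-sampling pass mirrors this with gamma, gammaSum, and the inverted weight 1 / lblWeight.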
Example #2
  protected int removeOldTopicContribution(
      int position,
      int[] oneDocTopics,
      MassValue mass,
      int[] localTopicCounts,
      int[] localLblTopicCounts,
      int[] localTopicIndex,
      double[] cachedCoefficients,
      int[] tokensPerTopic,
      double betaSum,
      double beta,
      double lblWeight,
      int nonZeroTopics) {

    int oldTopic = oneDocTopics[position];

    int denseIndex = 0;

    if (oldTopic != ParallelTopicModel.UNASSIGNED_TOPIC) {
      //	Remove this token from all counts.

      // Remove this topic's contribution to the
      //  normalizing constants

      if (ignoreLabels) {
        mass.smoothingOnlyMass -= alpha[oldTopic] * beta / (tokensPerTopic[oldTopic] + betaSum);
        mass.topicBetaMass -=
            beta * localTopicCounts[oldTopic] / (tokensPerTopic[oldTopic] + betaSum);
      } else {

        mass.smoothingOnlyMass -=
            (1 + lblWeight) * alpha[oldTopic] * beta / (tokensPerTopic[oldTopic] + betaSum);
        mass.topicBetaMass -=
            beta
                * (localTopicCounts[oldTopic] + lblWeight * localLblTopicCounts[oldTopic])
                / (tokensPerTopic[oldTopic] + betaSum);
      }
      // Decrement the local doc/topic counts

      localTopicCounts[oldTopic]--;

      // Maintain the dense index, if we are deleting
      //  the old topic
      if (localTopicCounts[oldTopic] == 0 && localLblTopicCounts[oldTopic] == 0) {

        // First get to the dense location associated with
        //  the old topic.

        denseIndex = 0;

        // We know it's in there somewhere, so we don't
        //  need bounds checking.
        while (localTopicIndex[denseIndex] != oldTopic) {
          denseIndex++;
        }

        // shift all remaining dense indices to the left.
        while (denseIndex < nonZeroTopics) {
          if (denseIndex < localTopicIndex.length - 1) {
            localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
          }
          denseIndex++;
        }

        nonZeroTopics--;
      }

      // Decrement the global topic count totals
      tokensPerTopic[oldTopic]--;
      assert (tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";

      // Add the old topic's contribution back into the
      //  normalizing constants.
      if (ignoreLabels) {
        mass.smoothingOnlyMass += alpha[oldTopic] * beta / (tokensPerTopic[oldTopic] + betaSum);
        mass.topicBetaMass +=
            beta * localTopicCounts[oldTopic] / (tokensPerTopic[oldTopic] + betaSum);

        // Reset the cached coefficient for this topic
        cachedCoefficients[oldTopic] =
            (alpha[oldTopic] + localTopicCounts[oldTopic]) / (tokensPerTopic[oldTopic] + betaSum);
      } else {
        mass.smoothingOnlyMass +=
            (1 + lblWeight) * alpha[oldTopic] * beta / (tokensPerTopic[oldTopic] + betaSum);
        mass.topicBetaMass +=
            beta
                * (localTopicCounts[oldTopic] + lblWeight * localLblTopicCounts[oldTopic])
                / (tokensPerTopic[oldTopic] + betaSum);

        // Reset the cached coefficient for this topic
        cachedCoefficients[oldTopic] =
            ((1 + lblWeight) * alpha[oldTopic]
                    + localTopicCounts[oldTopic]
                    + lblWeight * localLblTopicCounts[oldTopic])
                / (tokensPerTopic[oldTopic] + betaSum);
      }
    }

    return nonZeroTopics;
  }
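
Both methods keep localTopicIndex[0..nonZeroTopics-1] as a sorted, dense list of the topics with a non-zero count in the current document, so the sampler only loops over topics that can actually contribute mass. As a standalone illustration of just that bookkeeping (the class and method names below are hypothetical, not part of the original source), the two shift operations look like this:

public class DenseIndexSketch {

  // Remove `topic` from the sorted prefix; returns the new number of non-zero topics.
  static int removeTopic(int[] localTopicIndex, int nonZeroTopics, int topic) {
    int denseIndex = 0;
    while (localTopicIndex[denseIndex] != topic) { // the topic is known to be present
      denseIndex++;
    }
    for (int i = denseIndex; i < nonZeroTopics - 1; i++) {
      localTopicIndex[i] = localTopicIndex[i + 1]; // shift the tail one slot to the left
    }
    return nonZeroTopics - 1;
  }

  // Insert `topic`, keeping the prefix sorted; returns the new number of non-zero topics.
  static int insertTopic(int[] localTopicIndex, int nonZeroTopics, int topic) {
    int denseIndex = nonZeroTopics;
    while (denseIndex > 0 && localTopicIndex[denseIndex - 1] > topic) {
      localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1]; // shift right to make room
      denseIndex--;
    }
    localTopicIndex[denseIndex] = topic;
    return nonZeroTopics + 1;
  }

  public static void main(String[] args) {
    int[] index = new int[8];
    int nonZero = 0;
    nonZero = insertTopic(index, nonZero, 5);
    nonZero = insertTopic(index, nonZero, 2);
    nonZero = insertTopic(index, nonZero, 7);
    nonZero = removeTopic(index, nonZero, 5);
    // index[0..nonZero-1] is now {2, 7}
    System.out.println(java.util.Arrays.toString(java.util.Arrays.copyOf(index, nonZero)));
  }
}

Returning the updated count, as removeOldTopicContribution does, is necessary because Java passes the int by value; the caller must assign the return value back to its own nonZeroTopics.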