/**
 * Computes the standard set of diagnostics for a trained topic model.
 *
 * <p>First the top {@code numTopWords} words of every topic are extracted. Topics with fewer
 * words keep {@code null} in their unused slots. Document-level statistics are then collected in
 * one pass over the model's training data. Finally each diagnostic is appended to
 * {@code diagnostics} in a fixed order.
 *
 * @param model a trained model whose sorted words, alphabet and training data are read
 * @param numTopWords number of top words per topic to examine
 */
public TopicModelDiagnostics(ParallelTopicModel model, int numTopWords) {
    numTopics = model.getNumTopics();
    this.numTopWords = numTopWords;

    this.model = model;

    alphabet = model.getAlphabet();
    topicSortedWords = model.getSortedWords();

    topicTopWords = new String[numTopics][numTopWords];

    numRank1Documents = new int[numTopics];
    numNonZeroDocuments = new int[numTopics];
    numDocumentsAtProportions = new int[numTopics][DEFAULT_DOC_PROPORTIONS.length];
    sumCountTimesLogCount = new double[numTopics];

    diagnostics = new ArrayList<TopicScores>();

    for (int topic = 0; topic < numTopics; topic++) {
      TreeSet<IDSorter> sortedWords = topicSortedWords.get(topic);

      // Some topics have fewer than numTopWords words with non-zero weight; in that case the
      // remaining slots of topicTopWords[topic] stay null.
      int filled = 0;
      for (IDSorter info : sortedWords) {
        if (filled >= numTopWords) {
          break;
        }
        topicTopWords[topic][filled++] = (String) alphabet.lookupObject(info.getID());
      }
    }

    collectDocumentStatistics();

    diagnostics.add(getTokensPerTopic(model.tokensPerTopic));
    diagnostics.add(getDocumentEntropy(model.tokensPerTopic));
    diagnostics.add(getWordLengthScores());
    diagnostics.add(getCoherence());
    diagnostics.add(getDistanceFromUniform());
    diagnostics.add(getDistanceFromCorpus());
    diagnostics.add(getEffectiveNumberOfWords());
    diagnostics.add(getTokenDocumentDiscrepancies());
    diagnostics.add(getRank1Percent());
    diagnostics.add(getDocumentPercentRatio(FIFTY_PERCENT_INDEX, TWO_PERCENT_INDEX));
    diagnostics.add(getDocumentPercent(5));
    diagnostics.add(getExclusivity());
  }
 /**
  * Trains a new {@link ParallelTopicModel} on {@code documents} with 4 sampling threads and
  * stores it in {@code lda}, then prints the model's alphaSum and beta.
  *
  * @param documents training instances
  * @param numTopics number of topics
  * @param numIterations number of sampling iterations
  * @param alpha per-topic prior; the model is constructed with {@code alpha * numTopics} as
  *     its alphaSum
  * @param beta word prior passed to the model
  * @throws java.io.UncheckedIOException if {@code estimate()} fails with an I/O error
  */
 public void trainDocuments(
     InstanceList documents, int numTopics, int numIterations, double alpha, double beta) {
   double alphaSum = alpha * numTopics;
   lda = new ParallelTopicModel(numTopics, alphaSum, beta);
   lda.addInstances(documents);
   lda.setNumThreads(4);
   lda.setNumIterations(numIterations);
   lda.printLogLikelihood = false;
   try {
     lda.estimate();
   } catch (IOException e) {
     // Do not continue as if training succeeded; surface the failure with its cause.
     throw new java.io.UncheckedIOException("LDA estimation failed", e);
   }
   System.out.println("LDA parameter, alphaSum: " + lda.alphaSum + ", beta: " + lda.beta);
 }
  /**
   * Command-line entry point: trains a topic model and optionally writes its diagnostics as XML.
   *
   * <p>Arguments: {@code <instances-file> <num-topics> [<output-xml>]}. The model is trained for
   * 1000 iterations with alphaSum 5.0 and beta 0.01, and diagnostics use the top 20 words per
   * topic. The XML is written only when exactly three arguments are given.
   *
   * @param args instance-list file, number of topics, and an optional output path
   * @throws IllegalArgumentException if fewer than two arguments are given
   * @throws Exception if loading, training or writing fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 2) {
      throw new IllegalArgumentException(
          "Usage: <instances-file> <num-topics> [<output-xml>]");
    }

    InstanceList instances = InstanceList.load(new File(args[0]));
    int numTopics = Integer.parseInt(args[1]);
    ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01);
    model.addInstances(instances);
    model.setNumIterations(1000);

    model.estimate();

    TopicModelDiagnostics diagnostics = new TopicModelDiagnostics(model, 20);

    if (args.length == 3) {
      // Write explicit UTF-8 rather than the platform default charset.
      PrintWriter out = new PrintWriter(args[2], "UTF-8");
      try {
        out.println(diagnostics.toXML());
      } finally {
        // Always release the file handle, even if XML generation fails.
        out.close();
      }
    }
  }
Exemplo n.º 4
0
  /**
   * Infers topic distributions for the instances returned by {@code generateInstanceList()}
   * and prints one line per document. Each line holds the document id followed by up to
   * {@code topN} space-separated {@code topic,probability} pairs, taken in the order produced by
   * {@code CustomComparator} (presumably descending probability; confirm).
   *
   * <p>The model is read from {@code inferencerFile}. Any exception is reported on standard
   * error and the method then returns normally.
   */
  public void doInference() {

    try {

      ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
      TopicInferencer inferencer = model.getInferencer();

      // NOTE(review): readFile() is called only for its side effects; presumably it loads the
      // data used by generateInstanceList() and docIds. Confirm.
      readFile();
      InstanceList testing = generateInstanceList();

      for (int i = 0; i < testing.size(); i++) {

        // Sampler settings: 10 iterations, thinning 1, burn-in 5.
        double[] testProbabilities = inferencer.getSampledDistribution(testing.get(i), 10, 1, 5);

        ArrayList<Pair<Integer, Double>> probabilityList =
            new ArrayList<Pair<Integer, Double>>(testProbabilities.length);
        for (int j = 0; j < testProbabilities.length; j++) {
          probabilityList.add(new Pair<Integer, Double>(j, testProbabilities[j]));
        }

        Collections.sort(probabilityList, new CustomComparator());

        StringBuilder probabilities = new StringBuilder();
        for (int j = 0; j < testProbabilities.length && j < topN; j++) {
          Pair<Integer, Double> entry = probabilityList.get(j);
          if (j > 0) {
            probabilities.append(" ");
          }
          probabilities
              .append(entry.getFirst().toString())
              .append(",")
              .append(entry.getSecond().toString());
        }

        System.out.println(docIds.get(i) + "," + probabilities.toString());
      }

    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
  /**
   * Scores how similar a query document is to a test document, using their topic vectors.
   *
   * <p>The query vector is what {@code lda.getTopicProbabilities(qdocId)} returns. The target
   * vector is row {@code targetDocId} of {@code testTopicDistribution}, which
   * {@code generateTestInference()} fills. The two are compared with
   * {@code rs.util.vlc.Util.cosineProduct} (presumably cosine similarity; confirm).
   *
   * @param qdocId index of the query document in the trained model
   * @param targetDocId index of the test document
   * @return the similarity score returned by {@code cosineProduct}
   */
  protected double queryVsmSimilarity(int qdocId, int targetDocId) {
    double[] queryTopics = lda.getTopicProbabilities(qdocId);
    double[] targetTopics = testTopicDistribution[targetDocId];
    return rs.util.vlc.Util.cosineProduct(queryTopics, targetTopics);
  }
Exemplo n.º 6
0
  /**
   * Manual smoke test: loads the model from {@code inferencerFile} and imports {@code fileName}
   * (UTF-8). Each line is split into name, label and text fields. The method then prints the
   * sampled topic distribution of the second instance, one {@code topic: probability} line per
   * topic, for at most 1000 topics.
   *
   * @throws Exception if the model or data file cannot be read
   */
  public void test() throws Exception {

    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList instances = new InstanceList(new SerialPipes(pipeList));
    Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
    try {
      // addThruPipe drains the iterator, so the reader can be closed as soon as it returns.
      instances.addThruPipe(
          new CsvIterator(
              fileReader,
              Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
              3,
              2,
              1)); // data, label, name fields
    } finally {
      fileReader.close();
    }

    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(1), 10, 1, 5);
    // Print at most 1000 topics. Bounding by the array length avoids overrunning models with
    // fewer topics.
    int limit = Math.min(testProbabilities.length, 1000);
    for (int i = 0; i < limit; i++) {
      System.out.println(i + ": " + testProbabilities[i]);
    }
  }
 /**
  * Infers a topic distribution for every document in {@code test} using the trained model in
  * {@code lda}. The result is stored in {@code testTopicDistribution}, indexed like {@code test}.
  * Sampler settings: 800 iterations, thinning 5, burn-in 100.
  *
  * @throws IllegalStateException if no model has been trained yet ({@code lda} is null)
  */
 public void generateTestInference() {
   generateTestInference(800, 5, 100);
 }

 /**
  * Infers a topic distribution for every document in {@code test} with explicit sampler
  * settings, storing the result in {@code testTopicDistribution}.
  *
  * @param iterations number of sampling iterations per document
  * @param thinning sampling interval passed to the inferencer
  * @param burnIn number of initial iterations passed to the inferencer as burn-in
  * @throws IllegalStateException if no model has been trained yet ({@code lda} is null)
  */
 public void generateTestInference(int iterations, int thinning, int burnIn) {
   if (lda == null) {
     // Fail the caller instead of terminating the whole JVM from a library method.
     throw new IllegalStateException("Should run lda estimation first.");
   }
   // Reallocate if missing or if the test set size changed since the last call. A stale array
   // would either overflow or keep rows for documents that are no longer present.
   if (testTopicDistribution == null || testTopicDistribution.length != test.size()) {
     testTopicDistribution = new double[test.size()][];
   }
   TopicInferencer infer = lda.getInferencer();
   for (int ti = 0; ti < test.size(); ti++) {
     testTopicDistribution[ti] =
         infer.getSampledDistribution(test.get(ti), iterations, thinning, burnIn);
   }
 }
  /**
   * Makes one pass over the model's training data and fills the per-topic document statistics
   * used by the diagnostics.
   *
   * <p>Always resets and fills {@code wordTypeCounts} and {@code numTokens}. Every non-empty
   * document then updates, for each topic with at least one token in it:
   * {@code numNonZeroDocuments}, {@code sumCountTimesLogCount}, {@code numDocumentsAtProportions}
   * and {@code topicCodocumentMatrices}. It also adds one to {@code numRank1Documents} for the
   * document's majority topic. In {@code topicCodocumentMatrices[topic]}, entry {@code [i][j]}
   * counts documents in which top words {@code i} and {@code j} both have a token assigned to
   * {@code topic}. The diagonal {@code [i][i]} counts documents in which top word {@code i} has
   * such a token.
   *
   * <p>Expects {@code topicTopWords} (null for unused slots) and the per-topic count arrays to be
   * allocated already; only {@code topicCodocumentMatrices} and {@code wordTypeCounts} are
   * (re)allocated here.
   */
  public void collectDocumentStatistics() {

    // Marks a top-word slot that holds no word (the topic has fewer than numTopWords words).
    // Type indices from the alphabet are never negative.
    final int noWord = -1;

    topicCodocumentMatrices = new int[numTopics][numTopWords][numTopWords];
    wordTypeCounts = new int[alphabet.size()];
    numTokens = 0;

    // Per topic: the set of its top-word type indices, for membership tests on each token.
    IntHashSet[] topicTopWordIndices = new IntHashSet[numTopics];

    // Per topic: the top-word type indices in rank order, for iterating over positions.
    int[][] topicWordIndicesInOrder = new int[numTopics][numTopWords];

    // Per topic: the top words seen with that topic in the current document; cleared per doc.
    IntHashSet[] docTopicWordIndices = new IntHashSet[numTopics];

    // Per topic: token count in the current document; reset per document.
    int[] topicCounts = new int[numTopics];

    for (int topic = 0; topic < numTopics; topic++) {
      IntHashSet wordIndices = new IntHashSet();

      for (int i = 0; i < numTopWords; i++) {
        String word = topicTopWords[topic][i];
        if (word == null) {
          // Use an explicit sentinel. The array default 0 is a valid word type, and leaving it
          // would inflate the co-document counts whenever type 0 is a real top word.
          topicWordIndicesInOrder[topic][i] = noWord;
        } else {
          int type = alphabet.lookupIndex(word);
          topicWordIndicesInOrder[topic][i] = type;
          wordIndices.add(type);
        }
      }

      topicTopWordIndices[topic] = wordIndices;
      docTopicWordIndices[topic] = new IntHashSet();
    }

    for (TopicAssignment document : model.getData()) {

      FeatureSequence tokens = (FeatureSequence) document.instance.getData();
      FeatureSequence topics = (FeatureSequence) document.topicSequence;
      int docLength = tokens.size();

      for (int position = 0; position < docLength; position++) {
        int type = tokens.getIndexAtPosition(position);
        int topic = topics.getIndexAtPosition(position);

        numTokens++;
        wordTypeCounts[type]++;
        topicCounts[topic]++;

        if (topicTopWordIndices[topic].contains(type)) {
          docTopicWordIndices[topic].add(type);
        }
      }

      if (docLength == 0) {
        continue;
      }

      int maxTopic = -1;
      int maxCount = -1;

      for (int topic = 0; topic < numTopics; topic++) {
        int count = topicCounts[topic];
        if (count <= 0) {
          continue;
        }

        numNonZeroDocuments[topic]++;

        // Strict '>' keeps the lowest-numbered topic on ties.
        if (count > maxCount) {
          maxTopic = topic;
          maxCount = count;
        }

        sumCountTimesLogCount[topic] += count * Math.log(count);

        // Smoothed share of this document assigned to the topic. Thresholds are checked in
        // order and counting stops at the first one not reached (assumes
        // DEFAULT_DOC_PROPORTIONS is ascending; confirm).
        double proportion = (model.alpha[topic] + count) / (model.alphaSum + docLength);
        for (int i = 0; i < DEFAULT_DOC_PROPORTIONS.length; i++) {
          if (proportion < DEFAULT_DOC_PROPORTIONS[i]) {
            break;
          }
          numDocumentsAtProportions[topic][i]++;
        }

        IntHashSet supportedWords = docTopicWordIndices[topic];
        int[] indices = topicWordIndicesInOrder[topic];

        for (int i = 0; i < numTopWords; i++) {
          if (indices[i] == noWord || !supportedWords.contains(indices[i])) {
            continue;
          }
          // Diagonal: documents with top word i assigned to this topic.
          topicCodocumentMatrices[topic][i][i]++;
          for (int j = i + 1; j < numTopWords; j++) {
            if (indices[j] != noWord && supportedWords.contains(indices[j])) {
              topicCodocumentMatrices[topic][i][j]++;
              topicCodocumentMatrices[topic][j][i]++;
            }
          }
        }

        supportedWords.clear();
        topicCounts[topic] = 0;
      }

      if (maxTopic > -1) {
        numRank1Documents[maxTopic]++;
      }
    }
  }