public void printTopWords(int numWords, boolean useNewLines) { class WordProb implements Comparable { int wi; double p; public WordProb(int wi, double p) { this.wi = wi; this.p = p; } public final int compareTo(Object o2) { if (p > ((WordProb) o2).p) return -1; else if (p == ((WordProb) o2).p) return 0; else return 1; } } WordProb[] wp = new WordProb[numTypes]; for (int ti = 0; ti < numTopics; ti++) { for (int wi = 0; wi < numTypes; wi++) wp[wi] = new WordProb(wi, ((double) typeTopicCounts[wi][ti]) / tokensPerTopic[ti]); Arrays.sort(wp); if (useNewLines) { System.out.println("\nTopic " + ti); for (int i = 0; i < numWords; i++) System.out.println( ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p); } else { System.out.print("Topic " + ti + ": "); for (int i = 0; i < numWords; i++) System.out.print(ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " "); System.out.println(); } } }
public void printDocumentTopics(PrintWriter pw, double threshold, int max) { pw.println("#doc source topic proportion ..."); int docLen; double topicDist[] = new double[topics.length]; for (int di = 0; di < topics.length; di++) { pw.print(di); pw.print(' '); if (ilist.get(di).getSource() != null) { pw.print(ilist.get(di).getSource().toString()); } else { pw.print("null-source"); } pw.print(' '); docLen = topics[di].length; for (int ti = 0; ti < numTopics; ti++) topicDist[ti] = (((float) docTopicCounts[di][ti]) / docLen); if (max < 0) max = numTopics; for (int tp = 0; tp < max; tp++) { double maxvalue = 0; int maxindex = -1; for (int ti = 0; ti < numTopics; ti++) if (topicDist[ti] > maxvalue) { maxvalue = topicDist[ti]; maxindex = ti; } if (maxindex == -1 || topicDist[maxindex] < threshold) break; pw.print(maxindex + " " + topicDist[maxindex] + " "); topicDist[maxindex] = 0; } pw.println(' '); } }
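// A minimal usage sketch for printDocumentTopics above; the output file name and the 5%
// threshold are illustrative assumptions, and the call must run where IOException is handled.
PrintWriter pw = new PrintWriter(new FileWriter("doc-topics.txt"));
printDocumentTopics(pw, 0.05, -1); // max = -1 prints all topics above the threshold
pw.close();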
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { int featuresLength; int version = in.readInt(); ilist = (InstanceList) in.readObject(); numTopics = in.readInt(); alpha = in.readDouble(); beta = in.readDouble(); tAlpha = in.readDouble(); vBeta = in.readDouble(); int numDocs = ilist.size(); topics = new int[numDocs][]; for (int di = 0; di < ilist.size(); di++) { int docLen = ((FeatureSequence) ilist.get(di).getData()).getLength(); topics[di] = new int[docLen]; for (int si = 0; si < docLen; si++) topics[di][si] = in.readInt(); } docTopicCounts = new int[numDocs][numTopics]; for (int di = 0; di < ilist.size(); di++) for (int ti = 0; ti < numTopics; ti++) docTopicCounts[di][ti] = in.readInt(); int numTypes = ilist.getDataAlphabet().size(); typeTopicCounts = new int[numTypes][numTopics]; for (int fi = 0; fi < numTypes; fi++) for (int ti = 0; ti < numTopics; ti++) typeTopicCounts[fi][ti] = in.readInt(); tokensPerTopic = new int[numTopics]; for (int ti = 0; ti < numTopics; ti++) tokensPerTopic[ti] = in.readInt(); }
private InstanceList readFile() throws IOException {
  String NL = System.getProperty("line.separator");
  Scanner scanner = new Scanner(new FileInputStream(fileName), encoding);

  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
  pipeList.add(new TokenSequence2FeatureSequence());
  InstanceList testing = new InstanceList(new SerialPipes(pipeList));

  // Compile the doc-id,label,text pattern once, outside the loop.
  Pattern pattern = Pattern.compile("^(.*?),(.*?),(.*)$");
  try {
    while (scanner.hasNextLine()) {
      String text = scanner.nextLine();
      text = text.replaceAll("\\x0d", "");
      Matcher matcher = pattern.matcher(text);
      if (matcher.find()) {
        docIds.add(matcher.group(1));
        testing.addThruPipe(new Instance(matcher.group(3), null, "test instance", null));
      }
    }
  } finally {
    scanner.close();
  }
  return testing;
}
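// For reference, readFile() assumes one document per line in a three-field CSV layout.
// A hypothetical input line and what the pattern "^(.*?),(.*?),(.*)$" extracts from it:
//   doc42,someLabel,the full text of the test document
// group(1) -> "doc42" (kept in docIds), group(2) -> "someLabel" (unused here),
// group(3) -> "the full text of the test document" (fed through the pipes into testing).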
/** This is (mostly) copied from CRF4.java */
public boolean[][] labelConnectionsIn(
    Alphabet outputAlphabet, InstanceList trainingSet, String start) {
  int numLabels = outputAlphabet.size();
  boolean[][] connections = new boolean[numLabels][numLabels];
  for (int i = 0; i < trainingSet.size(); i++) {
    Instance instance = trainingSet.getInstance(i);
    FeatureSequence output = (FeatureSequence) instance.getTarget();
    for (int j = 1; j < output.size(); j++) {
      int sourceIndex = outputAlphabet.lookupIndex(output.get(j - 1));
      int destIndex = outputAlphabet.lookupIndex(output.get(j));
      assert (sourceIndex >= 0 && destIndex >= 0);
      connections[sourceIndex][destIndex] = true;
    }
  }

  // Handle start state
  if (start != null) {
    int startIndex = outputAlphabet.lookupIndex(start);
    for (int j = 0; j < outputAlphabet.size(); j++) {
      connections[startIndex][j] = true;
    }
  }

  return connections;
}
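// A small usage sketch (hypothetical label alphabet, training set, and "O" start label)
// showing how the returned matrix can be inspected; it relies only on the method above
// and Alphabet.lookupObject.
boolean[][] connections = labelConnectionsIn(outputAlphabet, trainingSet, "O");
for (int i = 0; i < connections.length; i++)
  for (int j = 0; j < connections[i].length; j++)
    if (connections[i][j])
      System.out.println(outputAlphabet.lookupObject(i) + " -> " + outputAlphabet.lookupObject(j));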
public void count() {
  TIntIntHashMap docCounts = new TIntIntHashMap();
  int index = 0;

  if (instances.size() == 0) {
    logger.info("Instance list is empty");
    return;
  }

  if (instances.get(0).getData() instanceof FeatureSequence) {
    for (Instance instance : instances) {
      FeatureSequence features = (FeatureSequence) instance.getData();

      for (int i = 0; i < features.getLength(); i++) {
        docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
      }

      // Fold this document's counts into the corpus-wide totals.
      int[] keys = docCounts.keys();
      for (int i = 0; i < keys.length; i++) {
        int feature = keys[i];
        featureCounts[feature] += docCounts.get(feature);
        documentFrequencies[feature]++;
      }

      docCounts = new TIntIntHashMap();

      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else if (instances.get(0).getData() instanceof FeatureVector) {
    for (Instance instance : instances) {
      FeatureVector features = (FeatureVector) instance.getData();

      for (int location = 0; location < features.numLocations(); location++) {
        int feature = features.indexAtLocation(location);
        double value = features.valueAtLocation(location);
        documentFrequencies[feature]++;
        featureCounts[feature] += value;
      }

      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else {
    logger.info("Unsupported data class: " + instances.get(0).getData().getClass().getName());
  }
}
public void estimate(
    InstanceList documents,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  ilist = documents.shallowClone();
  numTypes = ilist.getDataAlphabet().size();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  docTopicCounts = new int[numDocs][numTopics];
  typeTopicCounts = new int[numTypes][numTopics];
  tokensPerTopic = new int[numTopics];
  tAlpha = alpha * numTopics;
  vBeta = beta * numTypes;

  long startTime = System.currentTimeMillis();

  // Initialize with random assignments of tokens to topics
  // and finish allocating this.topics and this.tokens
  int topic, seqLen;
  FeatureSequence fs;
  for (int di = 0; di < numDocs; di++) {
    try {
      fs = (FeatureSequence) ilist.get(di).getData();
    } catch (ClassCastException e) {
      System.err.println(
          "LDA and other topic models expect FeatureSequence data, not FeatureVector data. "
              + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
      throw e;
    }
    seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    // Randomly assign tokens to topics
    for (int si = 0; si < seqLen; si++) {
      topic = r.nextInt(numTopics);
      topics[di][si] = topic;
      docTopicCounts[di][topic]++;
      typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
      tokensPerTopic[topic]++;
    }
  }

  this.estimate(
      0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
  // 124.5 seconds
  // 144.8 seconds after using FeatureSequence instead of tokens[][] array
  // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
  // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable
}
public FeatureCountTool(InstanceList instances) { this.instances = instances; numFeatures = instances.getDataAlphabet().size(); featureCounts = new double[numFeatures]; documentFrequencies = new int[numFeatures]; }
public void generateTestInference() {
  if (lda == null) {
    System.out.println("Run the LDA estimation first.");
    System.exit(1);
  }
  if (testTopicDistribution == null) testTopicDistribution = new double[test.size()][];
  TopicInferencer infer = lda.getInferencer();
  int iterations = 800;
  int thinning = 5;
  int burnIn = 100;
  for (int ti = 0; ti < test.size(); ti++) {
    testTopicDistribution[ti] =
        infer.getSampledDistribution(test.get(ti), iterations, thinning, burnIn);
  }
}
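// A follow-up sketch (not part of the original class): report the most probable topic for
// each test document, using the distributions filled in by generateTestInference() above.
for (int ti = 0; ti < testTopicDistribution.length; ti++) {
  int best = 0;
  for (int k = 1; k < testTopicDistribution[ti].length; k++)
    if (testTopicDistribution[ti][k] > testTopicDistribution[ti][best]) best = k;
  System.out.println("test doc " + ti + ": topic " + best + " p=" + testTopicDistribution[ti][best]);
}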
private InstanceList generateInstanceList() throws Exception {
  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
  pipeList.add(new TokenSequence2FeatureSequence());

  Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
  InstanceList instances = new InstanceList(new SerialPipes(pipeList));
  instances.addThruPipe(
      new CsvIterator(
          fileReader,
          Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
          3, 2, 1)); // data, label, name fields
  return instances;
}
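// Hypothetical example of an input line for the CsvIterator above. With the group order
// (3, 2, 1) mapped to (data, label, name), a line such as
//   doc42 en this is the document text ...
// yields name = "doc42", label = "en", and data = "this is the document text ...".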
// Just for testing. mallet/bin/vectors2topics is recommended instead.
public static void main(String[] args) {
  InstanceList ilist = InstanceList.load(new File(args[0]));
  int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
  int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
  System.out.println("Data loaded.");
  TopicalNGrams tng = new TopicalNGrams(10);
  tng.estimate(ilist, numIterations, 1, 0, null, new Randoms());
  tng.printTopWords(numTopWords, true);
}
// Recommended to use mallet/bin/vectors2topics instead.
public static void main(String[] args) throws IOException {
  InstanceList ilist = InstanceList.load(new File(args[0]));
  int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
  int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
  System.out.println("Data loaded.");
  LDA lda = new LDA(10);
  lda.estimate(ilist, numIterations, 50, 0, null, new Randoms()); // should be 1100
  lda.printTopWords(numTopWords, true);
  lda.printDocumentTopics(new File(args[0] + ".lda"));
}
public void doInference() {
  try {
    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();
    // TopicInferencer inferencer = TopicInferencer.read(new File(inferencerFile));

    // InstanceList testing = readFile();
    readFile();
    InstanceList testing = generateInstanceList(); // readFile();

    for (int i = 0; i < testing.size(); i++) {
      StringBuilder probabilities = new StringBuilder();
      double[] testProbabilities = inferencer.getSampledDistribution(testing.get(i), 10, 1, 5);
      ArrayList probabilityList = new ArrayList();
      for (int j = 0; j < testProbabilities.length; j++) {
        probabilityList.add(new Pair<Integer, Double>(j, testProbabilities[j]));
      }
      Collections.sort(probabilityList, new CustomComparator());
      for (int j = 0; j < testProbabilities.length && j < topN; j++) {
        if (j > 0) probabilities.append(" ");
        probabilities.append(
            ((Pair<Integer, Double>) probabilityList.get(j)).getFirst().toString()
                + ","
                + ((Pair<Integer, Double>) probabilityList.get(j)).getSecond().toString());
      }
      System.out.println(docIds.get(i) + "," + probabilities.toString());
    }
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
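// CustomComparator is referenced above but not shown in this listing. A minimal sketch,
// assuming it simply orders the (topic, probability) pairs by probability, highest first
// (requires java.util.Comparator); treat it as an illustration, not the original class.
class CustomComparator implements Comparator<Pair<Integer, Double>> {
  public int compare(Pair<Integer, Double> a, Pair<Integer, Double> b) {
    return Double.compare(b.getSecond(), a.getSecond());
  }
}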
public static void main(String[] args) throws Exception {
  CommandOption.setSummary(
      FeatureCountTool.class,
      "Print feature counts and instances per feature (e.g. document frequencies) in an instance list");
  CommandOption.process(FeatureCountTool.class, args);

  InstanceList instances = InstanceList.load(new File(inputFile.value));
  FeatureCountTool counter = new FeatureCountTool(instances);
  counter.count();
  counter.printCounts();
}
public void printState(PrintWriter pw) { Alphabet a = ilist.getDataAlphabet(); pw.println("#doc pos typeindex type topic"); for (int di = 0; di < topics.length; di++) { FeatureSequence fs = (FeatureSequence) ilist.get(di).getData(); for (int si = 0; si < topics[di].length; si++) { int type = fs.getIndexAtPosition(si); pw.print(di); pw.print(' '); pw.print(si); pw.print(' '); pw.print(type); pw.print(' '); pw.print(a.lookupObject(type)); pw.print(' '); pw.print(topics[di][si]); pw.println(); } } }
/* One iteration of Gibbs sampling, across all documents. */
public void sampleTopicsForAllDocs(Randoms r) {
  double[] topicWeights = new double[numTopics];
  // Loop over every document in the corpus
  for (int di = 0; di < topics.length; di++) {
    sampleTopicsForOneDoc(
        (FeatureSequence) ilist.get(di).getData(),
        topics[di],
        docTopicCounts[di],
        topicWeights,
        r);
  }
}
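// sampleTopicsForOneDoc is not shown in this listing. The sketch below illustrates the
// standard collapsed Gibbs update it would perform, reusing this class's count arrays and
// hyperparameters (alpha, beta, vBeta); it is an illustrative assumption, not the original body.
private void sampleTopicsForOneDocSketch(
    FeatureSequence oneDocTokens, int[] oneDocTopics, int[] oneDocTopicCounts,
    double[] topicWeights, Randoms r) {
  int docLen = oneDocTokens.getLength();
  for (int si = 0; si < docLen; si++) {
    int type = oneDocTokens.getIndexAtPosition(si);
    int oldTopic = oneDocTopics[si];
    // Remove the current assignment before sampling a replacement.
    oneDocTopicCounts[oldTopic]--;
    typeTopicCounts[type][oldTopic]--;
    tokensPerTopic[oldTopic]--;
    // p(topic) proportional to (n_dt + alpha) * (n_wt + beta) / (n_t + V*beta)
    double sum = 0;
    for (int ti = 0; ti < numTopics; ti++) {
      topicWeights[ti] =
          (oneDocTopicCounts[ti] + alpha)
              * (typeTopicCounts[type][ti] + beta)
              / (tokensPerTopic[ti] + vBeta);
      sum += topicWeights[ti];
    }
    int newTopic = r.nextDiscrete(topicWeights, sum);
    // Restore the counts with the newly sampled topic.
    oneDocTopics[si] = newTopic;
    oneDocTopicCounts[newTopic]++;
    typeTopicCounts[type][newTopic]++;
    tokensPerTopic[newTopic]++;
  }
}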
public void test() throws Exception {
  ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
  TopicInferencer inferencer = model.getInferencer();

  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
  pipeList.add(new TokenSequence2FeatureSequence());

  InstanceList instances = new InstanceList(new SerialPipes(pipeList));
  Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
  instances.addThruPipe(
      new CsvIterator(
          fileReader,
          Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
          3, 2, 1)); // data, label, name fields

  double[] testProbabilities = inferencer.getSampledDistribution(instances.get(1), 10, 1, 5);
  // One probability per topic; iterate over the distribution's actual length.
  for (int i = 0; i < testProbabilities.length; i++)
    System.out.println(i + ": " + testProbabilities[i]);
}
public static void main(String[] args) {
  // String malletFile = "dataset/vlc_lectures.all.en.f8.mallet";
  // String simFile = "dataset/vlc/sim5p.csv";
  // String solutionFile = "dataset/vlc/task1_solution.en.f8.lm.txt";
  // String queryFile = "dataset/task1_query.en.f8.txt";
  // String targetFile = "dataset/task1_target.en.f8.txt";

  String malletFile = "dataset/vlc/folds/all.0.4189.mallet";
  String trainMalletFile = "dataset/vlc/folds/training.0.mallet";
  String testMalletFile = "dataset/vlc/folds/test.0.mallet";
  String queryFile = "dataset/vlc/folds/query.0.csv";
  String linkFile = "dataset/vlc/folds/trainingPairs.0.csv";
  String targetFile = "dataset/vlc/folds/target.0.csv";
  String solutionFile = "dataset/vlc/task1_solution.en.f8.lm.txt";

  int numTopics = 160;
  int numIterations = 200;
  double alpha = 0.0016;
  double beta = 0.0001;

  InstanceList train = InstanceList.load(new File(trainMalletFile));
  InstanceList test = InstanceList.load(new File(testMalletFile));
  SeparateParallelLda spl = new SeparateParallelLda(train, test);
  spl.trainDocuments(numTopics, numIterations, alpha, beta);
  spl.generateTestInference();
  spl.lda.printTopWords(System.out, 10, true);

  BasicTask1Solution solver = new Task1SolutionWithSeparateData(spl);
  double precision;
  try {
    solver.retrieveTask1Solution(queryFile, solutionFile);
    precision = Task1Solution.evaluateResult(targetFile, solutionFile);
    System.out.println(
        String.format(
            "SeparateParallelLda: iteration: %d, precision: %f", numIterations, precision));
  } catch (Exception e) {
    e.printStackTrace();
  }
}
public TestCRFPipe(String trainingFilename) throws IOException {
  ArrayList<Pipe> pipes = new ArrayList<Pipe>();
  PrintWriter out = new PrintWriter("test.out");

  int[][] conjunctions = new int[3][];
  conjunctions[0] = new int[] {-1};
  conjunctions[1] = new int[] {1};
  conjunctions[2] = new int[] {-2, -1};

  pipes.add(new SimpleTaggerSentence2TokenSequence());
  // pipes.add(new FeaturesInWindow("PREV-", -1, 1));
  // pipes.add(new FeaturesInWindow("NEXT-", 1, 2));
  pipes.add(new OffsetConjunctions(conjunctions));
  pipes.add(new TokenTextCharSuffix("C1=", 1));
  pipes.add(new TokenTextCharSuffix("C2=", 2));
  pipes.add(new TokenTextCharSuffix("C3=", 3));
  pipes.add(new RegexMatches("CAPITALIZED", Pattern.compile("^\\p{Lu}.*")));
  pipes.add(new RegexMatches("STARTSNUMBER", Pattern.compile("^[0-9].*")));
  pipes.add(new RegexMatches("HYPHENATED", Pattern.compile(".*\\-.*")));
  pipes.add(new RegexMatches("DOLLARSIGN", Pattern.compile("\\$.*")));
  pipes.add(new TokenFirstPosition("FIRSTTOKEN"));
  pipes.add(new TokenSequence2FeatureVectorSequence());
  pipes.add(new SequencePrintingPipe(out));

  Pipe pipe = new SerialPipes(pipes);

  InstanceList trainingInstances = new InstanceList(pipe);
  trainingInstances.addThruPipe(
      new LineGroupIterator(
          new BufferedReader(
              new InputStreamReader(new GZIPInputStream(new FileInputStream(trainingFilename)))),
          Pattern.compile("^\\s*$"),
          true));

  out.close();
}
/* One iteration of Gibbs sampling, across the given range of documents. */
public void sampleTopicsForDocs(int start, int length, Randoms r) {
  assert (start + length <= docTopicCounts.length);
  double[] topicWeights = new double[numTopics];
  // Loop over every document in the range
  for (int di = start; di < start + length; di++) {
    sampleTopicsForOneDoc(
        (FeatureSequence) ilist.get(di).getData(),
        topics[di],
        docTopicCounts[di],
        topicWeights,
        r);
  }
}
/* One iteration of Gibbs sampling, across all documents. */
private void sampleTopicsForAllDocs(Randoms r) {
  double[] uniTopicWeights = new double[numTopics];
  double[] biTopicWeights = new double[numTopics * 2];
  // Loop over every document in the corpus
  for (int di = 0; di < topics.length; di++) {
    sampleTopicsForOneDoc(
        (FeatureSequenceWithBigrams) ilist.get(di).getData(),
        topics[di],
        grams[di],
        docTopicCounts[di],
        uniTopicWeights,
        biTopicWeights,
        r);
  }
}
public void addDocuments(
    InstanceList additionalDocuments,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  if (ilist == null) throw new IllegalStateException("Must already have some documents first.");
  for (Instance inst : additionalDocuments) ilist.add(inst);
  assert (ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
  assert (additionalDocuments.getDataAlphabet().size() >= numTypes);
  numTypes = additionalDocuments.getDataAlphabet().size();
  int numNewDocs = additionalDocuments.size();
  int numOldDocs = topics.length;
  int numDocs = numOldDocs + numNewDocs;

  // Expand various arrays to make space for the new data.
  int[][] newTopics = new int[numDocs][];
  for (int i = 0; i < topics.length; i++) newTopics[i] = topics[i];
  topics = newTopics; // The rest of this array will be initialized below.

  int[][] newDocTopicCounts = new int[numDocs][numTopics];
  for (int i = 0; i < docTopicCounts.length; i++) newDocTopicCounts[i] = docTopicCounts[i];
  docTopicCounts = newDocTopicCounts; // The rest of this array will be initialized below.

  int[][] newTypeTopicCounts = new int[numTypes][numTopics];
  for (int i = 0; i < typeTopicCounts.length; i++)
    for (int j = 0; j < numTopics; j++)
      newTypeTopicCounts[i][j] = typeTopicCounts[i][j];
  // Switch to the expanded array so counts for word types introduced by the new documents fit.
  typeTopicCounts = newTypeTopicCounts; // This array further populated below

  FeatureSequence fs;
  for (int di = numOldDocs; di < numDocs; di++) {
    try {
      fs = (FeatureSequence) additionalDocuments.get(di - numOldDocs).getData();
    } catch (ClassCastException e) {
      System.err.println(
          "LDA and other topic models expect FeatureSequence data, not FeatureVector data. "
              + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
      throw e;
    }
    int seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    // Randomly assign tokens to topics
    for (int si = 0; si < seqLen; si++) {
      int topic = r.nextInt(numTopics);
      topics[di][si] = topic;
      docTopicCounts[di][topic]++;
      typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
      tokensPerTopic[topic]++;
    }
  }
}
/** * Initialize this separate model using a complete list. * * @param documents * @param testStartIndex */ public void divideDocuments(InstanceList documents, int testStartIndex) { Alphabet dataAlpha = documents.getDataAlphabet(); Alphabet targetAlpha = documents.getTargetAlphabet(); this.training = new InstanceList(dataAlpha, targetAlpha); this.test = new InstanceList(dataAlpha, targetAlpha); int di = 0; for (di = 0; di < testStartIndex; di++) { training.add(documents.get(di)); } for (di = testStartIndex; di < documents.size(); di++) { test.add(documents.get(di)); } }
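// Usage sketch for divideDocuments above, called from within this class: hold out the last
// 20% of a corpus as the test split (the .mallet file name is a hypothetical placeholder).
InstanceList documents = InstanceList.load(new File("corpus.mallet"));
divideDocuments(documents, (int) (documents.size() * 0.8));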
public static void main(String[] args) throws Exception { InstanceList instances = InstanceList.load(new File(args[0])); int numTopics = Integer.parseInt(args[1]); ParallelTopicModel model = new ParallelTopicModel(numTopics, 5.0, 0.01); model.addInstances(instances); model.setNumIterations(1000); model.estimate(); TopicModelDiagnostics diagnostics = new TopicModelDiagnostics(model, 20); if (args.length == 3) { PrintWriter out = new PrintWriter(args[2]); out.println(diagnostics.toXML()); out.close(); } }
public void printCounts() { Alphabet alphabet = instances.getDataAlphabet(); NumberFormat nf = NumberFormat.getInstance(); nf.setMinimumFractionDigits(0); nf.setMaximumFractionDigits(6); nf.setGroupingUsed(false); for (int feature = 0; feature < numFeatures; feature++) { Formatter formatter = new Formatter(new StringBuilder(), Locale.US); formatter.format( "%s\t%s\t%d", alphabet.lookupObject(feature).toString(), nf.format(featureCounts[feature]), documentFrequencies[feature]); System.out.println(formatter); } }
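// The two arrays printed above are enough to derive IDF weights. A hypothetical helper
// (not part of the original FeatureCountTool) sketching that computation:
public double[] getIdf() {
  int numDocs = instances.size();
  double[] idf = new double[numFeatures];
  for (int feature = 0; feature < numFeatures; feature++)
    idf[feature] =
        documentFrequencies[feature] == 0
            ? 0.0
            : Math.log((double) numDocs / documentFrequencies[feature]);
  return idf;
}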
public void printState(PrintWriter pw) { pw.println("#doc pos typeindex type bigrampossible? topic bigram"); for (int di = 0; di < topics.length; di++) { FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData(); for (int si = 0; si < topics[di].length; si++) { int type = fs.getIndexAtPosition(si); pw.print(di); pw.print(' '); pw.print(si); pw.print(' '); pw.print(type); pw.print(' '); pw.print(uniAlphabet.lookupObject(type)); pw.print(' '); pw.print(fs.getBiIndexAtPosition(si) == -1 ? 0 : 1); pw.print(' '); pw.print(topics[di][si]); pw.print(' '); pw.print(grams[di][si]); pw.println(); } } }
public void estimate( InstanceList documents, int numIterations, int showTopicsInterval, int outputModelInterval, String outputModelFilename, Randoms r) { ilist = documents; uniAlphabet = ilist.getDataAlphabet(); biAlphabet = ((FeatureSequenceWithBigrams) ilist.get(0).getData()).getBiAlphabet(); numTypes = uniAlphabet.size(); numBitypes = biAlphabet.size(); int numDocs = ilist.size(); topics = new int[numDocs][]; grams = new int[numDocs][]; docTopicCounts = new int[numDocs][numTopics]; typeNgramTopicCounts = new int[numTypes][2][numTopics]; unitypeTopicCounts = new int[numTypes][numTopics]; bitypeTopicCounts = new int[numBitypes][numTopics]; tokensPerTopic = new int[numTopics]; bitokensPerTopic = new int[numTypes][numTopics]; tAlpha = alpha * numTopics; vBeta = beta * numTypes; vGamma = gamma * numTypes; long startTime = System.currentTimeMillis(); // Initialize with random assignments of tokens to topics // and finish allocating this.topics and this.tokens int topic, gram, seqLen, fi; for (int di = 0; di < numDocs; di++) { FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData(); seqLen = fs.getLength(); numTokens += seqLen; topics[di] = new int[seqLen]; grams[di] = new int[seqLen]; // Randomly assign tokens to topics int prevFi = -1, prevTopic = -1; for (int si = 0; si < seqLen; si++) { // randomly sample a topic for the word at position si topic = r.nextInt(numTopics); // if a bigram is allowed at position si, then sample a gram status for it. gram = (fs.getBiIndexAtPosition(si) == -1 ? 0 : r.nextInt(2)); if (gram != 0) biTokens++; topics[di][si] = topic; grams[di][si] = gram; docTopicCounts[di][topic]++; fi = fs.getIndexAtPosition(si); if (prevFi != -1) typeNgramTopicCounts[prevFi][gram][prevTopic]++; if (gram == 0) { unitypeTopicCounts[fi][topic]++; tokensPerTopic[topic]++; } else { bitypeTopicCounts[fs.getBiIndexAtPosition(si)][topic]++; bitokensPerTopic[prevFi][topic]++; } prevFi = fi; prevTopic = topic; } } for (int iterations = 0; iterations < numIterations; iterations++) { sampleTopicsForAllDocs(r); if (iterations % 10 == 0) System.out.print(iterations); else System.out.print("."); System.out.flush(); if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) { System.out.println(); printTopWords(5, false); } if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) { this.write(new File(outputModelFilename + '.' + iterations)); } } System.out.println( "\nTotal time (sec): " + ((System.currentTimeMillis() - startTime) / 1000.0)); }
public void printTopWords(int numWords, boolean useNewLines) { class WordProb implements Comparable { int wi; double p; public WordProb(int wi, double p) { this.wi = wi; this.p = p; } public final int compareTo(Object o2) { if (p > ((WordProb) o2).p) return -1; else if (p == ((WordProb) o2).p) return 0; else return 1; } } for (int ti = 0; ti < numTopics; ti++) { // Unigrams WordProb[] wp = new WordProb[numTypes]; for (int wi = 0; wi < numTypes; wi++) wp[wi] = new WordProb(wi, (double) unitypeTopicCounts[wi][ti]); Arrays.sort(wp); int numToPrint = Math.min(wp.length, numWords); if (useNewLines) { System.out.println("\nTopic " + ti + " unigrams"); for (int i = 0; i < numToPrint; i++) System.out.println( uniAlphabet.lookupObject(wp[i].wi).toString() + " " + wp[i].p / tokensPerTopic[ti]); } else { System.out.print("Topic " + ti + ": "); for (int i = 0; i < numToPrint; i++) System.out.print(uniAlphabet.lookupObject(wp[i].wi).toString() + " "); } // Bigrams /* wp = new WordProb[numBitypes]; int bisum = 0; for (int wi = 0; wi < numBitypes; wi++) { wp[wi] = new WordProb (wi, ((double)bitypeTopicCounts[wi][ti])); bisum += bitypeTopicCounts[wi][ti]; } Arrays.sort (wp); numToPrint = Math.min(wp.length, numWords); if (useNewLines) { System.out.println ("\nTopic "+ti+" bigrams"); for (int i = 0; i < numToPrint; i++) System.out.println (biAlphabet.lookupObject(wp[i].wi).toString() + " " + wp[i].p/bisum); } else { System.out.print (" "); for (int i = 0; i < numToPrint; i++) System.out.print (biAlphabet.lookupObject(wp[i].wi).toString() + " "); System.out.println(); } */ // Ngrams AugmentableFeatureVector afv = new AugmentableFeatureVector(new Alphabet(), 10000, false); for (int di = 0; di < topics.length; di++) { FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData(); for (int si = topics[di].length - 1; si >= 0; si--) { if (topics[di][si] == ti && grams[di][si] == 1) { String gramString = uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString(); while (grams[di][si] == 1 && --si >= 0) gramString = uniAlphabet.lookupObject(fs.getIndexAtPosition(si)).toString() + "_" + gramString; afv.add(gramString, 1.0); } } } // System.out.println ("pre-sorting"); int numNgrams = afv.numLocations(); // System.out.println ("post-sorting "+numNgrams); wp = new WordProb[numNgrams]; int ngramSum = 0; for (int loc = 0; loc < numNgrams; loc++) { wp[loc] = new WordProb(afv.indexAtLocation(loc), afv.valueAtLocation(loc)); ngramSum += wp[loc].p; } Arrays.sort(wp); int numUnitypeTokens = 0, numBitypeTokens = 0, numUnitypeTypes = 0, numBitypeTypes = 0; for (int fi = 0; fi < numTypes; fi++) { numUnitypeTokens += unitypeTopicCounts[fi][ti]; if (unitypeTopicCounts[fi][ti] != 0) numUnitypeTypes++; } for (int fi = 0; fi < numBitypes; fi++) { numBitypeTokens += bitypeTopicCounts[fi][ti]; if (bitypeTopicCounts[fi][ti] != 0) numBitypeTypes++; } if (useNewLines) { System.out.println( "\nTopic " + ti + " unigrams " + numUnitypeTokens + "/" + numUnitypeTypes + " bigrams " + numBitypeTokens + "/" + numBitypeTypes + " phrases " + Math.round(afv.oneNorm()) + "/" + numNgrams); for (int i = 0; i < Math.min(numNgrams, numWords); i++) System.out.println( afv.getAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p / ngramSum); } else { System.out.print( " (unigrams " + numUnitypeTokens + "/" + numUnitypeTypes + " bigrams " + numBitypeTokens + "/" + numBitypeTypes + " phrases " + Math.round(afv.oneNorm()) + "/" + numNgrams + ")\n "); // System.out.print (" (unique-ngrams="+numNgrams+" // 
      // ngram-count="+Math.round(afv.oneNorm())+")\n      ");
      for (int i = 0; i < Math.min(numNgrams, numWords); i++)
        System.out.print(afv.getAlphabet().lookupObject(wp[i].wi).toString() + " ");
      System.out.println();
    }
  }
}
public static CRF4 createCRF(File trainingFile, CRFInfo crfInfo) throws FileNotFoundException { Reader trainingFileReader = new FileReader(trainingFile); // Create a pipe that we can use to convert the training // file to a feature vector sequence. Pipe p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence(); // The training file does contain tags (aka targets) p.setTargetProcessing(true); // Register the default tag with the pipe, by looking it up // in the targetAlphabet before we look up any other tag. p.getTargetAlphabet().lookupIndex(crfInfo.defaultLabel); // Create a new instancelist to hold the training data. InstanceList trainingData = new InstanceList(p); // Read in the training data. trainingData.add(new LineGroupIterator(trainingFileReader, Pattern.compile("^\\s*$"), true)); // Create the CRF model. CRF4 crf = new CRF4(p, null); // Set various config options crf.setGaussianPriorVariance(crfInfo.gaussianVariance); crf.setTransductionType(crfInfo.transductionType); // Set up the model's states. if (crfInfo.stateInfoList != null) { Iterator stateIter = crfInfo.stateInfoList.iterator(); while (stateIter.hasNext()) { CRFInfo.StateInfo state = (CRFInfo.StateInfo) stateIter.next(); crf.addState( state.name, state.initialCost, state.finalCost, state.destinationNames, state.labelNames, state.weightNames); } } else if (crfInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) crf.addStatesForLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) crf.addStatesForHalfLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) crf.addStatesForBiLabelsConnectedAsIn(trainingData); else throw new RuntimeException("Unexpected state structure " + crfInfo.stateStructure); // Set up the weight groups. if (crfInfo.weightGroupInfoList != null) { Iterator wgIter = crfInfo.weightGroupInfoList.iterator(); while (wgIter.hasNext()) { CRFInfo.WeightGroupInfo wg = (CRFInfo.WeightGroupInfo) wgIter.next(); FeatureSelection fs = FeatureSelection.createFromRegex( crf.getInputAlphabet(), Pattern.compile(wg.featureSelectionRegex)); crf.setFeatureSelection(crf.getWeightsIndex(wg.name), fs); } } // Train the CRF. crf.train(trainingData, null, null, null, crfInfo.maxIterations); return crf; }