public void printTopWords(int numWords, boolean useNewLines) {
  class WordProb implements Comparable {
    int wi;
    double p;

    public WordProb(int wi, double p) {
      this.wi = wi;
      this.p = p;
    }

    public final int compareTo(Object o2) {
      if (p > ((WordProb) o2).p) return -1;
      else if (p == ((WordProb) o2).p) return 0;
      else return 1;
    }
  }

  WordProb[] wp = new WordProb[numTypes];
  for (int ti = 0; ti < numTopics; ti++) {
    for (int wi = 0; wi < numTypes; wi++)
      wp[wi] = new WordProb(wi, ((double) typeTopicCounts[wi][ti]) / tokensPerTopic[ti]);
    Arrays.sort(wp);
    if (useNewLines) {
      System.out.println("\nTopic " + ti);
      for (int i = 0; i < numWords; i++)
        System.out.println(
            ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " " + wp[i].p);
    } else {
      System.out.print("Topic " + ti + ": ");
      for (int i = 0; i < numWords; i++)
        System.out.print(ilist.getDataAlphabet().lookupObject(wp[i].wi).toString() + " ");
      System.out.println();
    }
  }
}
/**
 * Trains a classifier on a 90/10 split of the concordance data, evaluates it on the held-out
 * portion, and saves the classifier.
 *
 * @param targetTerm the target term whose occurrences are classified
 * @param sourceFile concordance file containing the training data
 * @param trainingAlgo name of the training algorithm to use
 * @param outputFileClassifier file the trained classifier is saved to
 * @param outputFileResults output file for the evaluation results
 * @param termWindowSize size of the context window around the target term
 * @param pipe Mallet pipe used to build the instances
 * @param useCollocationalVector use collocational ("coll") vectors instead of bag-of-words ("bow")
 * @return one classification result per label
 */
private static List<ClassificationResult> runTrainingAndClassification(
    String targetTerm,
    String sourceFile,
    String trainingAlgo,
    String outputFileClassifier,
    String outputFileResults,
    int termWindowSize,
    Pipe pipe,
    boolean useCollocationalVector) {
  // Read in concordance file and create list of Mallet training instances
  // TODO: Remove duplication of code (see execConvertToMalletFormat(...))
  String vectorType = useCollocationalVector ? "coll" : "bow";
  InstanceList instanceList =
      readConcordanceFileToInstanceList(
          targetTerm, sourceFile, termWindowSize, pipe, useCollocationalVector);

  // Create splits for training and testing
  double[] proportions = {0.9, 0.1};
  InstanceList[] splitLists = instanceList.split(proportions);
  InstanceList trainingList = splitLists[0];
  InstanceList testList = splitLists[1];

  // Train the classifier
  ClassifierTrainer classifierTrainer = getClassifierTrainerForAlgorithm(trainingAlgo);
  Classifier classifier = classifierTrainer.train(trainingList);
  if (classifier.getLabelAlphabet() != null) {
    // TODO: Make sure this is not null in RandomClassifier
    System.out.println("Labels:\n" + classifier.getLabelAlphabet());
    System.out.println(
        "Size of data alphabet (= type count of training list): "
            + classifier.getAlphabet().size());
  }

  // Run tests and collect the per-label results
  Trial trial = new Trial(classifier, testList);
  List<ClassificationResult> results = new ArrayList<ClassificationResult>();
  for (int i = 0; i < classifier.getLabelAlphabet().size(); i++) {
    Label label = classifier.getLabelAlphabet().lookupLabel(i);
    ClassificationResult result =
        new MalletClassificationResult(
            trainingAlgo, targetTerm, vectorType, label.toString(), termWindowSize, trial,
            sourceFile);
    results.add(result);
    System.out.println(result.toString());
  }

  // Save classifier
  saveClassifierToFile(outputFileClassifier, classifier, trainingAlgo, termWindowSize);
  return results;
}
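// A minimal, hypothetical sketch of how the Trial created above could be turned into
// per-label precision/recall/F1 without the project-specific MalletClassificationResult
// wrapper. It uses only the stock MALLET Trial API; the method name printTrialMetrics and
// the parameters are illustrative, mirroring the variables in runTrainingAndClassification.
static void printTrialMetrics(Classifier classifier, InstanceList testList) {
  Trial trial = new Trial(classifier, testList);
  System.out.println("accuracy = " + trial.getAccuracy());
  LabelAlphabet labels = classifier.getLabelAlphabet();
  for (int i = 0; i < labels.size(); i++) {
    System.out.println(
        labels.lookupLabel(i)
            + " P=" + trial.getPrecision(i)
            + " R=" + trial.getRecall(i)
            + " F1=" + trial.getF1(i));
  }
}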
/**
 * Converts the sentence-based instance list into a token-based one. This is needed for the
 * ME-version of JET (JetMeClassifier).
 *
 * @param METrainerDummyPipe pipe the new token-level instances are added through
 * @param inst just the features for one sentence to be transformed
 * @return the token-level instance list
 */
public static InstanceList convertFeatsforClassifier(
    final Pipe METrainerDummyPipe, final Instance inst) {
  final InstanceList iList = new InstanceList(METrainerDummyPipe);

  final FeatureVectorSequence fvs = (FeatureVectorSequence) inst.getData();
  final LabelSequence ls = (LabelSequence) inst.getTarget();
  final LabelAlphabet ldict = (LabelAlphabet) ls.getAlphabet();
  final Object source = inst.getSource();
  final Object name = inst.getName();

  if (ls.size() != fvs.size()) {
    System.err.println(
        "failed making token instances: size of label sequence != size of feature vector sequence: "
            + ls.size()
            + " - "
            + fvs.size());
    System.exit(-1);
  }

  for (int j = 0; j < fvs.size(); j++) {
    final Instance I =
        new Instance(fvs.getFeatureVector(j), ldict.lookupLabel(ls.get(j)), name, source);
    iList.add(I);
  }

  return iList;
}
public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
  pw.println("#doc source topic proportion ...");
  int docLen;
  // Sized by the number of topics; sizing it by the number of documents (as before)
  // fails whenever numTopics exceeds the document count.
  double[] topicDist = new double[numTopics];
  for (int di = 0; di < topics.length; di++) {
    pw.print(di);
    pw.print(' ');
    if (ilist.get(di).getSource() != null) {
      pw.print(ilist.get(di).getSource().toString());
    } else {
      pw.print("null-source");
    }
    pw.print(' ');
    docLen = topics[di].length;
    for (int ti = 0; ti < numTopics; ti++)
      topicDist[ti] = (((float) docTopicCounts[di][ti]) / docLen);
    if (max < 0) max = numTopics;
    for (int tp = 0; tp < max; tp++) {
      double maxvalue = 0;
      int maxindex = -1;
      for (int ti = 0; ti < numTopics; ti++)
        if (topicDist[ti] > maxvalue) {
          maxvalue = topicDist[ti];
          maxindex = ti;
        }
      if (maxindex == -1 || topicDist[maxindex] < threshold) break;
      pw.print(maxindex + " " + topicDist[maxindex] + " ");
      topicDist[maxindex] = 0;
    }
    pw.println(' ');
  }
}
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
  int featuresLength;
  int version = in.readInt();
  ilist = (InstanceList) in.readObject();
  numTopics = in.readInt();
  alpha = in.readDouble();
  beta = in.readDouble();
  tAlpha = in.readDouble();
  vBeta = in.readDouble();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  for (int di = 0; di < ilist.size(); di++) {
    int docLen = ((FeatureSequence) ilist.get(di).getData()).getLength();
    topics[di] = new int[docLen];
    for (int si = 0; si < docLen; si++) topics[di][si] = in.readInt();
  }
  docTopicCounts = new int[numDocs][numTopics];
  for (int di = 0; di < ilist.size(); di++)
    for (int ti = 0; ti < numTopics; ti++) docTopicCounts[di][ti] = in.readInt();
  int numTypes = ilist.getDataAlphabet().size();
  typeTopicCounts = new int[numTypes][numTopics];
  for (int fi = 0; fi < numTypes; fi++)
    for (int ti = 0; ti < numTopics; ti++) typeTopicCounts[fi][ti] = in.readInt();
  tokensPerTopic = new int[numTopics];
  for (int ti = 0; ti < numTopics; ti++) tokensPerTopic[ti] = in.readInt();
}
private InstanceList readFile() throws IOException {
  String NL = System.getProperty("line.separator");
  Scanner scanner = new Scanner(new FileInputStream(fileName), encoding);

  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
  pipeList.add(new TokenSequence2FeatureSequence());

  InstanceList testing = new InstanceList(new SerialPipes(pipeList));

  // Compile the line pattern once instead of once per line.
  Pattern pattern = Pattern.compile("^(.*?),(.*?),(.*)$");

  try {
    while (scanner.hasNextLine()) {
      String text = scanner.nextLine();
      text = text.replaceAll("\\x0d", "");

      Matcher matcher = pattern.matcher(text);
      if (matcher.find()) {
        docIds.add(matcher.group(1));
        testing.addThruPipe(new Instance(matcher.group(3), null, "test instance", null));
      }
    }
  } finally {
    scanner.close();
  }

  return testing;
}
public void count() {

  TIntIntHashMap docCounts = new TIntIntHashMap();

  int index = 0;

  if (instances.size() == 0) {
    logger.info("Instance list is empty");
    return;
  }

  if (instances.get(0).getData() instanceof FeatureSequence) {

    for (Instance instance : instances) {
      FeatureSequence features = (FeatureSequence) instance.getData();

      for (int i = 0; i < features.getLength(); i++) {
        docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
      }

      // Iterate over every feature seen in this document; the previous
      // "keys.length - 1" bound silently dropped the last feature.
      int[] keys = docCounts.keys();
      for (int i = 0; i < keys.length; i++) {
        int feature = keys[i];
        featureCounts[feature] += docCounts.get(feature);
        documentFrequencies[feature]++;
      }

      docCounts = new TIntIntHashMap();

      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else if (instances.get(0).getData() instanceof FeatureVector) {

    for (Instance instance : instances) {
      FeatureVector features = (FeatureVector) instance.getData();

      for (int location = 0; location < features.numLocations(); location++) {
        int feature = features.indexAtLocation(location);
        double value = features.valueAtLocation(location);

        documentFrequencies[feature]++;
        featureCounts[feature] += value;
      }

      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else {
    logger.info("Unsupported data class: " + instances.get(0).getData().getClass().getName());
  }
}
public SVM train(InstanceList trainingList) {
  svm_problem problem = new svm_problem();
  problem.l = trainingList.size();
  problem.x = new svm_node[problem.l][];
  problem.y = new double[problem.l];

  for (int i = 0; i < trainingList.size(); i++) {
    Instance instance = trainingList.get(i);
    svm_node[] input = SVM.getSvmNodes(instance);
    if (input == null) {
      continue;
    }
    int labelIndex = ((Label) instance.getTarget()).getIndex();
    problem.x[i] = input;
    problem.y[i] = labelIndex;
  }

  int max_index = trainingList.getDataAlphabet().size();

  if (param.gamma == 0 && max_index > 0) {
    param.gamma = 1.0 / max_index;
  }

  // int numLabels = trainingList.getTargetAlphabet().size();
  // int[] weight_label = new int[numLabels];
  // double[] weight = trainingList.targetLabelDistribution().getValues();
  // double minValue = Double.MAX_VALUE;
  //
  // for (int i = 0; i < weight.length; i++) {
  //   if (minValue > weight[i]) {
  //     minValue = weight[i];
  //   }
  // }
  //
  // for (int i = 0; i < weight.length; i++) {
  //   weight_label[i] = i;
  //   weight[i] = weight[i] / minValue;
  // }
  //
  // param.weight_label = weight_label;
  // param.weight = weight;

  String error_msg = svm.svm_check_parameter(problem, param);

  if (error_msg != null) {
    System.err.print("Error: " + error_msg + "\n");
    System.exit(1);
  }

  svm_model model = svm.svm_train(problem, param);
  classifier = new SVM(model, trainingList.getPipe());

  return classifier;
}
public void estimate(
    InstanceList documents,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  ilist = documents.shallowClone();
  numTypes = ilist.getDataAlphabet().size();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  docTopicCounts = new int[numDocs][numTopics];
  typeTopicCounts = new int[numTypes][numTopics];
  tokensPerTopic = new int[numTopics];
  tAlpha = alpha * numTopics;
  vBeta = beta * numTypes;

  long startTime = System.currentTimeMillis();

  // Initialize with random assignments of tokens to topics
  // and finish allocating this.topics and this.tokens
  int topic, seqLen;
  FeatureSequence fs;
  for (int di = 0; di < numDocs; di++) {
    try {
      fs = (FeatureSequence) ilist.get(di).getData();
    } catch (ClassCastException e) {
      System.err.println(
          "LDA and other topic models expect FeatureSequence data, not FeatureVector data. "
              + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
      throw e;
    }
    seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    // Randomly assign tokens to topics
    for (int si = 0; si < seqLen; si++) {
      topic = r.nextInt(numTopics);
      topics[di][si] = topic;
      docTopicCounts[di][topic]++;
      typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
      tokensPerTopic[topic]++;
    }
  }

  this.estimate(
      0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
  // 124.5 seconds
  // 144.8 seconds after using FeatureSequence instead of tokens[][] array
  // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
  // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable
}
public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices) {
  if (instIndices == null) {
    instIndices = new int[ilist.size()];
    for (int ii = 0; ii < instIndices.length; ii++) instIndices[ii] = ii;
  }
  m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
  m_ilist = ilist;
  m_instIndices = instIndices;
  m_dataDict = m_ilist.getDataAlphabet();
  m_minNumInsts = minNumInsts;
  m_parent = parent;
  m_leftChild = m_rightChild = null;
}
public InstanceList readArray(String[] cleanTexts) {
  StringArrayIterator iterator = new StringArrayIterator(cleanTexts);

  // Construct a new instance list, passing it the pipe we want to use to
  // process instances.
  InstanceList instances = new InstanceList(pipe);

  // Process each instance provided by the iterator.
  instances.addThruPipe(iterator);

  // Assign names and targets after the instances have been added; iterating
  // before addThruPipe (as before) walked an empty list and did nothing.
  int index = 0;
  for (Instance inst : instances) {
    inst.setName(name_id.get(index));
    inst.setTarget("english");
    index++;
  }

  return instances;
}
public FeatureCountTool(InstanceList instances) {
  this.instances = instances;

  numFeatures = instances.getDataAlphabet().size();
  featureCounts = new double[numFeatures];
  documentFrequencies = new int[numFeatures];
}
public void split() {
  if (m_ilist == null) throw new IllegalStateException("Frozen. Cannot split.");

  int numLeftChildren = 0;
  boolean[] toLeftChild = new boolean[m_instIndices.length];
  for (int i = 0; i < m_instIndices.length; i++) {
    Instance instance = m_ilist.get(m_instIndices[i]);
    FeatureVector fv = (FeatureVector) instance.getData();
    if (fv.value(m_gainRatio.getMaxValuedIndex()) <= m_gainRatio.getMaxValuedThreshold()) {
      toLeftChild[i] = true;
      numLeftChildren++;
    } else toLeftChild[i] = false;
  }

  logger.info(
      "leftChild.size=" + numLeftChildren
          + " rightChild.size=" + (m_instIndices.length - numLeftChildren));

  int[] leftIndices = new int[numLeftChildren];
  int[] rightIndices = new int[m_instIndices.length - numLeftChildren];
  int li = 0, ri = 0;
  for (int i = 0; i < m_instIndices.length; i++) {
    if (toLeftChild[i]) leftIndices[li++] = m_instIndices[i];
    else rightIndices[ri++] = m_instIndices[i];
  }

  m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);
  m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);
}
public boolean train(
    InstanceList ilist, InstanceList validation, InstanceList testing, TransducerEvaluator eval) {
  assert (ilist.size() > 0);
  if (emissionEstimator == null) {
    emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];
    transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];
    emissionMultinomial = new Multinomial[numStates()];
    transitionMultinomial = new Multinomial[numStates()];
    Alphabet transitionAlphabet = new Alphabet();
    for (int i = 0; i < numStates(); i++)
      transitionAlphabet.lookupIndex(((State) states.get(i)).getName(), true);
    for (int i = 0; i < numStates(); i++) {
      emissionEstimator[i] = new Multinomial.LaplaceEstimator(inputAlphabet);
      transitionEstimator[i] = new Multinomial.LaplaceEstimator(transitionAlphabet);
      emissionMultinomial[i] =
          new Multinomial(getUniformArray(inputAlphabet.size()), inputAlphabet);
      transitionMultinomial[i] =
          new Multinomial(getUniformArray(transitionAlphabet.size()), transitionAlphabet);
    }
    initialEstimator = new Multinomial.LaplaceEstimator(transitionAlphabet);
  }
  for (Instance instance : ilist) {
    FeatureSequence input = (FeatureSequence) instance.getData();
    FeatureSequence output = (FeatureSequence) instance.getTarget();
    new SumLatticeDefault(this, input, output, new Incrementor());
  }
  initialMultinomial = initialEstimator.estimate();
  for (int i = 0; i < numStates(); i++) {
    emissionMultinomial[i] = emissionEstimator[i].estimate();
    transitionMultinomial[i] = transitionEstimator[i].estimate();
    getState(i).setInitialWeight(initialMultinomial.logProbability(getState(i).getName()));
  }
  return true;
}
public void generateTestInference() {
  if (lda == null) {
    System.out.println("Run the LDA estimation before inference.");
    System.exit(1);
    return;
  }
  if (testTopicDistribution == null) testTopicDistribution = new double[test.size()][];
  TopicInferencer infer = lda.getInferencer();
  int iterations = 800;
  int thinning = 5;
  int burnIn = 100;
  for (int ti = 0; ti < test.size(); ti++) {
    testTopicDistribution[ti] =
        infer.getSampledDistribution(test.get(ti), iterations, thinning, burnIn);
  }
}
// Train a Maxent classifier from the features in the training feature table.
// Lines should be formatted as:
//
//   [name] [label] [data ... ]
//
public static Classifier TrainMaxent(String trainingFilename, File modelFile) throws IOException {
  // build data input pipe
  ArrayList<Pipe> pipes = new ArrayList<Pipe>();

  // define pipe
  // the features in [data ...] should look like: feature:value
  pipes.add(new Target2Label());
  pipes.add(new Csv2FeatureVector());

  Pipe pipe = new SerialPipes(pipes);
  pipe.setTargetProcessing(true);

  // read data
  InstanceList trainingInstances = new InstanceList(pipe);
  FileReader training_file_reader = new FileReader(trainingFilename);
  CsvIterator reader =
      new CsvIterator(
          training_file_reader,
          "(\\w+)\\s+([^\\s]+)\\s+(.*)",
          3, 2, 1); // (data, label, name) field indices
  trainingInstances.addThruPipe(reader);
  training_file_reader.close();

  // calculate running time
  long startTime = System.currentTimeMillis();
  PrintStream temp = System.err;
  System.setErr(System.out);

  // train a Maxent classifier (could be other classifiers)
  ClassifierTrainer trainer = new MaxEntTrainer(Gaussian_Variance);
  Classifier classifier = trainer.train(trainingInstances);

  System.setErr(temp);

  // calculate running time
  long endTime = System.currentTimeMillis();
  long totalTime = endTime - startTime;
  System.out.println("Total training time: " + totalTime);

  // write model
  ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile));
  oos.writeObject(classifier);
  oos.close();

  return classifier;
}
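// A hedged usage sketch (not part of the original code): read the model written by
// TrainMaxent back from disk and classify a single line in the same
// "[name] [label] [data ...]" format. Only stock MALLET/JDK calls are used; the method
// name classifyOneLine and its arguments are illustrative placeholders.
public static void classifyOneLine(File modelFile, String line)
    throws IOException, ClassNotFoundException {
  ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
  Classifier classifier = (Classifier) ois.readObject();
  ois.close();

  // Push the line through the classifier's own pipe so the features match training.
  InstanceList instances = new InstanceList(classifier.getInstancePipe());
  instances.addThruPipe(
      new CsvIterator(new StringReader(line), "(\\w+)\\s+([^\\s]+)\\s+(.*)", 3, 2, 1));

  Classification c = classifier.classify(instances.get(0));
  System.out.println(c.getLabeling().getBestLabel());
}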
public void trainClassifier(File dir, String... args) throws Exception {
  InstanceListCreator instanceListCreator = new InstanceListCreator();
  InstanceList instanceList = instanceListCreator.createInstanceList(getTrainingDataFile(dir));
  instanceList.save(new File(dir, "training-data.ser"));

  String factoryName = args[0];
  Class<ClassifierTrainerFactory<?>> factoryClass = createTrainerFactory(factoryName);
  if (factoryClass == null) {
    String factoryName2 = "org.cleartk.ml.mallet.factory." + factoryName + "TrainerFactory";
    factoryClass = createTrainerFactory(factoryName2);
  }
  if (factoryClass == null) {
    throw new IllegalArgumentException(
        String.format(
            "name for classifier trainer factory is not valid: name given ='%s'. Valid classifier names include: %s, %s, %s, and %s",
            factoryName,
            ClassifierTrainerFactory.NAMES[0],
            ClassifierTrainerFactory.NAMES[1],
            ClassifierTrainerFactory.NAMES[2],
            ClassifierTrainerFactory.NAMES[3]));
  }

  String[] factoryArgs = new String[args.length - 1];
  System.arraycopy(args, 1, factoryArgs, 0, factoryArgs.length);
  ClassifierTrainerFactory<?> factory = factoryClass.newInstance();
  ClassifierTrainer<?> trainer = null;
  try {
    trainer = factory.createTrainer(factoryArgs);
  } catch (Throwable t) {
    throw new IllegalArgumentException(
        "Unable to create trainer. Usage for "
            + factoryClass.getCanonicalName()
            + ": "
            + factory.getUsageMessage(),
        t);
  }

  this.classifier = trainer.train(instanceList);

  ObjectOutputStream oos =
      new ObjectOutputStream(new FileOutputStream(new File(dir, MODEL_NAME)));
  oos.writeObject(classifier);
  oos.close();
}
public void train(String[] trainSections, String[] testSections) throws IOException {
  pipe = defaultPipe();
  InstanceList trainingInstanceList = prepareInstanceList(trainSections);
  InstanceList testingInstanceList = prepareInstanceList(testSections);
  // Classifier classifier = trainer.train(trainingInstanceList, testingInstanceList);
  Classifier classifier = trainer.train(trainingInstanceList);
  System.out.println("training size: " + trainingInstanceList.size());
  System.out.println("testing size: " + testingInstanceList.size());
  // showAccuracy(classifier, testingInstanceList);
  // getTypeSpecificAccuracy(trainingInstanceList, testingInstanceList, true);
  // showInterpolatedTCAccuracy(trainingInstanceList, testingInstanceList);
}
/**
 * Create and train a CRF model from the given training data, optionally testing it on the given
 * test data.
 *
 * @param training training data
 * @param testing test data (possibly <code>null</code>)
 * @param eval accuracy evaluator (possibly <code>null</code>)
 * @param orders label Markov orders (main and backoff)
 * @param defaultLabel default label
 * @param forbidden regular expression specifying impossible label transitions <em>current</em>
 *     <code>,</code><em>next</em> (<code>null</code> indicates no forbidden transitions)
 * @param allowed regular expression specifying allowed label transitions (<code>null</code>
 *     indicates everything is allowed that is not forbidden)
 * @param connected whether to include even transitions not occurring in the training data
 * @param iterations number of training iterations
 * @param var Gaussian prior variance
 * @return the trained model
 */
public static CRF train(
    InstanceList training,
    InstanceList testing,
    TransducerEvaluator eval,
    int[] orders,
    String defaultLabel,
    String forbidden,
    String allowed,
    boolean connected,
    int iterations,
    double var,
    CRF crf) {
  Pattern forbiddenPat = Pattern.compile(forbidden);
  Pattern allowedPat = Pattern.compile(allowed);
  if (crf == null) {
    crf = new CRF(training.getPipe(), (Pipe) null);
    String startName =
        crf.addOrderNStates(
            training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);
    for (int i = 0; i < crf.numStates(); i++)
      crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    crf.getState(startName).setInitialWeight(0.0);
  }
  logger.info("Training on " + training.size() + " instances");
  if (testing != null) logger.info("Testing on " + testing.size() + " instances");

  // Configure the trainer that is actually used below; previously the Gaussian prior
  // variance was set on a throwaway trainer created inside the if-block above.
  CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
  crft.setGaussianPriorVariance(var);

  if (featureInductionOption.value) {
    crft.trainWithFeatureInduction(
        training, null, testing, eval, iterations, 10, 20, 500, 0.5, false, null);
  } else {
    boolean converged;
    for (int i = 1; i <= iterations; i++) {
      converged = crft.train(training, 1);
      if (i % 1 == 0 && eval != null) // Change the 1 to a higher integer to evaluate less often
        eval.evaluate(crft);
      if (viterbiOutputOption.value && i % 10 == 0)
        new ViterbiWriter(
                "", new InstanceList[] {training, testing}, new String[] {"training", "testing"})
            .evaluate(crft);
      if (converged) break;
    }
  }
  return crf;
}
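// A small, assumed usage sketch for the CRF returned by train(...): decode each test
// instance into its most likely label sequence. Transducer.transduce(Sequence) is the
// stock MALLET decoding call; the method name tag and the printing are illustrative only.
public static void tag(CRF crf, InstanceList testing) {
  for (int i = 0; i < testing.size(); i++) {
    Sequence input = (Sequence) testing.get(i).getData();
    Sequence output = crf.transduce(input);
    StringBuilder sb = new StringBuilder();
    for (int j = 0; j < output.size(); j++) sb.append(output.get(j)).append(' ');
    System.out.println(sb.toString().trim());
  }
}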
/**
 * Train a classifier.
 *
 * @param trainingInstances the full set of labeled instances
 * @param trainingPortion the fraction to be used for training (<= 1.0); the rest is used for
 *     testing
 * @return the trained classifier
 */
public Classifier train(InstanceList trainingInstances, double trainingPortion) {
  InstanceList[] instanceLists =
      trainingInstances.split(
          new Random(), new double[] {trainingPortion, (1 - trainingPortion)});
  // InstanceList[] instanceLists =
  //     trainingInstances.splitInOrder(new double[] {trainingPortion, (1 - trainingPortion)});
  return this.train(instanceLists[0], instanceLists[1]);
}
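// Design note (an assumption about intent, not part of the original code): new Random()
// yields a different train/test split on every run, so evaluation numbers fluctuate.
// A hedged variant with a fixed seed makes the split, and thus the evaluation, repeatable;
// the method name trainReproducibly and the seed value are placeholders.
public Classifier trainReproducibly(InstanceList trainingInstances, double trainingPortion) {
  InstanceList[] lists =
      trainingInstances.split(
          new Random(42L), new double[] {trainingPortion, 1 - trainingPortion});
  return this.train(lists[0], lists[1]);
}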
/**
 * Calculates the minimum description length of this node, i.e., the length of the binary
 * encoding that describes the feature and the split value used at this node.
 */
public double getMDL() {
  int numClasses = m_ilist.getTargetAlphabet().size();
  double mdl = getSize() * getGainRatio().getBaseEntropy();
  mdl += ((numClasses - 1) * Math.log(getSize() / 2.0)) / (2 * GainRatio.log2);
  double piPow = Math.pow(Math.PI, numClasses / 2.0);
  double gammaVal = Maths.gamma(numClasses / 2.0);
  mdl += Math.log(piPow / gammaVal) / GainRatio.log2;
  return mdl;
}
// Just for testing. Recommended instead: mallet/bin/vectors2topics
public static void main(String[] args) {
  InstanceList ilist = InstanceList.load(new File(args[0]));
  int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
  int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
  System.out.println("Data loaded.");
  TopicalNGrams tng = new TopicalNGrams(10);
  // Use the command-line settings; previously these were parsed but ignored in favor of
  // hard-coded values (200 iterations, 60 top words).
  tng.estimate(ilist, numIterations, 1, 0, null, new Randoms());
  tng.printTopWords(numTopWords, true);
}
private InstanceList generateInstanceList() throws Exception {
  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
  pipeList.add(new TokenSequence2FeatureSequence());

  Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
  InstanceList instances = new InstanceList(new SerialPipes(pipeList));
  instances.addThruPipe(
      new CsvIterator(
          fileReader,
          Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
          3, 2, 1)); // data, label, name fields

  return instances;
}
double getAccuracy(Classifier classifier, InstanceList instanceList) {
  int total = instanceList.size();
  int correct = 0;
  for (Instance instance : instanceList) {
    Classification classification = classifier.classify(instance);
    if (classification.bestLabelIsCorrect()) correct++;
  }
  return (1.0 * correct) / total;
}
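// The same number can be obtained from MALLET's built-in Trial, which additionally exposes
// per-label precision, recall, and F1. Shown as an alternative sketch, not a replacement;
// the method name getAccuracyViaTrial is a placeholder.
double getAccuracyViaTrial(Classifier classifier, InstanceList instanceList) {
  return new Trial(classifier, instanceList).getAccuracy();
}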
// Recommended to use mallet/bin/vectors2topics instead.
public static void main(String[] args) throws IOException {
  InstanceList ilist = InstanceList.load(new File(args[0]));
  int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
  int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
  System.out.println("Data loaded.");
  LDA lda = new LDA(10);
  lda.estimate(ilist, numIterations, 50, 0, null, new Randoms()); // should be 1100
  lda.printTopWords(numTopWords, true);
  lda.printDocumentTopics(new File(args[0] + ".lda"));
}
public DocumentStream(long ts, String rootDir, FeedSettings settings) {
  this.ts = ts;
  this.rootDir = rootDir + "/" + ts + "/";
  // this.maxWordsPerTopic = maxWordsPerTopic;
  this.settings = settings;
  list = LdaModel.createInstanceList(this.ts);
  testing = new InstanceList(list.getPipe());
  lda = new LdaModel();
}
/**
 * Prepare instances for use with LDA.
 *
 * @param r reader over tab-separated rows in the order: name, label, data
 * @return the piped instance list
 */
public static InstanceList loadInstancesLDA(Reader r) {
  ArrayList<Pipe> pipeList = new ArrayList<Pipe>();

  // Pipes: lowercase, tokenize, remove stopwords, map to features
  pipeList.add(new Target2Label());
  pipeList.add(new CharSequenceLowercase());
  pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}[\\p{L}\\p{P}]+\\p{L}")));
  pipeList.add(
      new TokenSequenceRemoveStopwords(stopWords, stopWordsEncoding, false, false, false));
  pipeList.add(new TokenSequence2FeatureSequence());
  SerialPipes pipes = new SerialPipes(pipeList);

  InstanceList instances = new InstanceList(pipes);

  // create instances with: 3: data; 2: label; 1: name fields
  instances.addThruPipe(new CsvIterator(r, Pattern.compile("(.*)\t(.*)\t(.*)"), 3, 2, 1));

  return instances;
}
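// A hedged sketch of how the instance list from loadInstancesLDA(...) might be fed to
// MALLET's ParallelTopicModel. The topic count, hyperparameters, thread count, and
// iteration count are illustrative placeholders, not values taken from the original code.
public static ParallelTopicModel fitLda(InstanceList instances) throws IOException {
  int numTopics = 50;
  ParallelTopicModel model = new ParallelTopicModel(numTopics, 50.0, 0.01); // alphaSum, beta
  model.addInstances(instances);
  model.setNumThreads(2);
  model.setNumIterations(1000);
  model.estimate();
  return model;
}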
public void doInference() {

  try {

    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();

    // TopicInferencer inferencer = TopicInferencer.read(new File(inferencerFile));
    // InstanceList testing = readFile();

    readFile(); // populates docIds as a side effect
    InstanceList testing = generateInstanceList(); // readFile();

    for (int i = 0; i < testing.size(); i++) {
      StringBuilder probabilities = new StringBuilder();
      double[] testProbabilities = inferencer.getSampledDistribution(testing.get(i), 10, 1, 5);

      // Pair each topic index with its probability and sort by probability.
      ArrayList<Pair<Integer, Double>> probabilityList = new ArrayList<Pair<Integer, Double>>();
      for (int j = 0; j < testProbabilities.length; j++) {
        probabilityList.add(new Pair<Integer, Double>(j, testProbabilities[j]));
      }
      Collections.sort(probabilityList, new CustomComparator());

      for (int j = 0; j < testProbabilities.length && j < topN; j++) {
        if (j > 0) probabilities.append(" ");
        probabilities.append(
            probabilityList.get(j).getFirst().toString()
                + ","
                + probabilityList.get(j).getSecond().toString());
      }

      System.out.println(docIds.get(i) + "," + probabilities.toString());
    }
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
public static void main(String[] args) throws Exception {
  CommandOption.setSummary(
      FeatureCountTool.class,
      "Print feature counts and instances per feature (e.g. document frequencies) in an instance list");
  CommandOption.process(FeatureCountTool.class, args);

  InstanceList instances = InstanceList.load(new File(inputFile.value));
  FeatureCountTool counter = new FeatureCountTool(instances);
  counter.count();
  counter.printCounts();
}
public void printState(PrintWriter pw) {
  Alphabet a = ilist.getDataAlphabet();
  pw.println("#doc pos typeindex type topic");
  for (int di = 0; di < topics.length; di++) {
    FeatureSequence fs = (FeatureSequence) ilist.get(di).getData();
    for (int si = 0; si < topics[di].length; si++) {
      int type = fs.getIndexAtPosition(si);
      pw.print(di);
      pw.print(' ');
      pw.print(si);
      pw.print(' ');
      pw.print(type);
      pw.print(' ');
      pw.print(a.lookupObject(type));
      pw.print(' ');
      pw.print(topics[di][si]);
      pw.println();
    }
  }
}