private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
  int version = in.readInt(); // serial version; read to keep the stream aligned, currently unused
  ilist = (InstanceList) in.readObject();
  numTopics = in.readInt();
  alpha = in.readDouble();
  beta = in.readDouble();
  tAlpha = in.readDouble();
  vBeta = in.readDouble();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  for (int di = 0; di < numDocs; di++) {
    int docLen = ((FeatureSequence) ilist.get(di).getData()).getLength();
    topics[di] = new int[docLen];
    for (int si = 0; si < docLen; si++) topics[di][si] = in.readInt();
  }
  docTopicCounts = new int[numDocs][numTopics];
  for (int di = 0; di < numDocs; di++)
    for (int ti = 0; ti < numTopics; ti++) docTopicCounts[di][ti] = in.readInt();
  int numTypes = ilist.getDataAlphabet().size();
  typeTopicCounts = new int[numTypes][numTopics];
  for (int fi = 0; fi < numTypes; fi++)
    for (int ti = 0; ti < numTopics; ti++) typeTopicCounts[fi][ti] = in.readInt();
  tokensPerTopic = new int[numTopics];
  for (int ti = 0; ti < numTopics; ti++) tokensPerTopic[ti] = in.readInt();
}
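// A matching writeObject must emit fields in exactly the order readObject consumes
// them. The sketch below is inferred from that read order; CURRENT_SERIAL_VERSION is
// an assumed constant, not shown in the original.
private void writeObject(ObjectOutputStream out) throws IOException {
  out.writeInt(CURRENT_SERIAL_VERSION);
  out.writeObject(ilist);
  out.writeInt(numTopics);
  out.writeDouble(alpha);
  out.writeDouble(beta);
  out.writeDouble(tAlpha);
  out.writeDouble(vBeta);
  for (int di = 0; di < topics.length; di++)
    for (int si = 0; si < topics[di].length; si++) out.writeInt(topics[di][si]);
  for (int di = 0; di < topics.length; di++)
    for (int ti = 0; ti < numTopics; ti++) out.writeInt(docTopicCounts[di][ti]);
  for (int fi = 0; fi < typeTopicCounts.length; fi++)
    for (int ti = 0; ti < numTopics; ti++) out.writeInt(typeTopicCounts[fi][ti]);
  for (int ti = 0; ti < numTopics; ti++) out.writeInt(tokensPerTopic[ti]);
}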
public static InstanceList scale(InstanceList trainingList, double lower, double upper) {
  InstanceList ret = copy(trainingList);
  Alphabet featDict = ret.getDataAlphabet();
  // First pass: record the observed min and max of every feature.
  double[] feat_max = new double[featDict.size()];
  double[] feat_min = new double[featDict.size()];
  for (int i = 0; i < feat_max.length; i++) {
    feat_max[i] = -Double.MAX_VALUE;
    feat_min[i] = Double.MAX_VALUE;
  }
  for (int i = 0; i < ret.size(); i++) {
    FeatureVector fv = (FeatureVector) ret.get(i).getData();
    for (int loc = 0; loc < fv.numLocations(); loc++) {
      int featId = fv.indexAtLocation(loc);
      double value = fv.valueAtLocation(loc);
      feat_max[featId] = Math.max(value, feat_max[featId]);
      feat_min[featId] = Math.min(value, feat_min[featId]);
    }
  }
  // Second pass: linearly rescale each feature into [lower, upper].
  for (int i = 0; i < ret.size(); i++) {
    FeatureVector fv = (FeatureVector) ret.get(i).getData();
    for (int loc = 0; loc < fv.numLocations(); loc++) {
      int featId = fv.indexAtLocation(loc);
      double value = fv.valueAtLocation(loc);
      double maxValue = feat_max[featId];
      double minValue = feat_min[featId];
      double newValue;
      if (maxValue == minValue) {
        newValue = value; // constant feature: leave unchanged
      } else if (value == minValue) {
        newValue = lower;
      } else if (value == maxValue) {
        newValue = upper;
      } else {
        newValue = lower + (upper - lower) * (value - minValue) / (maxValue - minValue);
      }
      fv.setValueAtLocation(loc, newValue);
    }
  }
  return ret;
}
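// Hypothetical usage: rescale all features into [-1, 1] (the libsvm convention).
// scale(...) operates on a deep copy (via copy(...) below), so trainingList itself
// is left untouched.
InstanceList scaled = scale(trainingList, -1.0, 1.0);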
public SVM train(InstanceList trainingList) {
  svm_problem problem = new svm_problem();
  svm_node[][] x = new svm_node[trainingList.size()][];
  double[] y = new double[trainingList.size()];
  int n = 0;
  for (int i = 0; i < trainingList.size(); i++) {
    Instance instance = trainingList.get(i);
    svm_node[] input = SVM.getSvmNodes(instance);
    if (input == null) {
      continue; // skip instances with no usable features instead of leaving a null row
    }
    int labelIndex = ((Label) instance.getTarget()).getIndex();
    x[n] = input;
    y[n] = labelIndex;
    n++;
  }
  // Compact the problem to the instances actually kept; a null row in problem.x
  // would make svm_train fail.
  problem.l = n;
  problem.x = new svm_node[n][];
  problem.y = new double[n];
  System.arraycopy(x, 0, problem.x, 0, n);
  System.arraycopy(y, 0, problem.y, 0, n);
  // Default gamma to 1/num_features, as libsvm's command-line tools do.
  int max_index = trainingList.getDataAlphabet().size();
  if (param.gamma == 0 && max_index > 0) {
    param.gamma = 1.0 / max_index;
  }
  String error_msg = svm.svm_check_parameter(problem, param);
  if (error_msg != null) {
    System.err.println("Error: " + error_msg);
    System.exit(1);
  }
  svm_model model = svm.svm_train(problem, param);
  classifier = new SVM(model, trainingList.getPipe());
  return classifier;
}
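// Hypothetical setup for the `param` field used by the trainer above. The field
// names are libsvm's own (svm_parameter); the values are illustrative defaults,
// not taken from the original.
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.RBF;
param.C = 1.0;
param.gamma = 0; // 0 lets train(...) default it to 1/num_features
param.cache_size = 100; // kernel cache, in MB
param.eps = 1e-3; // stopping tolerance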
public void generateTestInference() {
  if (lda == null) {
    System.out.println("Should run lda estimation first.");
    System.exit(1);
  }
  if (testTopicDistribution == null) testTopicDistribution = new double[test.size()][];
  TopicInferencer infer = lda.getInferencer();
  // Sampling schedule: 800 Gibbs iterations, keeping every 5th sample after a
  // burn-in of 100 iterations.
  int iterations = 800;
  int thinning = 5;
  int burnIn = 100;
  for (int ti = 0; ti < test.size(); ti++) {
    testTopicDistribution[ti] =
        infer.getSampledDistribution(test.get(ti), iterations, thinning, burnIn);
  }
}
public boolean train(
    InstanceList ilist, InstanceList validation, InstanceList testing, TransducerEvaluator eval) {
  assert (ilist.size() > 0);
  if (emissionEstimator == null) {
    emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];
    transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];
    emissionMultinomial = new Multinomial[numStates()];
    transitionMultinomial = new Multinomial[numStates()];
    Alphabet transitionAlphabet = new Alphabet();
    for (int i = 0; i < numStates(); i++)
      transitionAlphabet.lookupIndex(((State) states.get(i)).getName(), true);
    for (int i = 0; i < numStates(); i++) {
      emissionEstimator[i] = new Multinomial.LaplaceEstimator(inputAlphabet);
      transitionEstimator[i] = new Multinomial.LaplaceEstimator(transitionAlphabet);
      emissionMultinomial[i] =
          new Multinomial(getUniformArray(inputAlphabet.size()), inputAlphabet);
      transitionMultinomial[i] =
          new Multinomial(getUniformArray(transitionAlphabet.size()), transitionAlphabet);
    }
    initialEstimator = new Multinomial.LaplaceEstimator(transitionAlphabet);
  }
  // Accumulate counts: the Incrementor passed to each lattice feeds the observed
  // (state, observation) and (state, state) events into the Laplace estimators.
  for (Instance instance : ilist) {
    FeatureSequence input = (FeatureSequence) instance.getData();
    FeatureSequence output = (FeatureSequence) instance.getTarget();
    new SumLatticeDefault(this, input, output, new Incrementor());
  }
  // Normalize the (add-one smoothed) counts into multinomials.
  initialMultinomial = initialEstimator.estimate();
  for (int i = 0; i < numStates(); i++) {
    emissionMultinomial[i] = emissionEstimator[i].estimate();
    transitionMultinomial[i] = transitionEstimator[i].estimate();
    getState(i).setInitialWeight(initialMultinomial.logProbability(getState(i).getName()));
  }
  return true;
}
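// Hypothetical training sketch for this count-based (HMM-style) trainer. The
// state-construction call is an assumption: MALLET's HMM exposes a similar
// addFullyConnectedStatesForLabels(), but the enclosing class here is not shown.
HMM hmm = new HMM(trainingData.getPipe(), null);
hmm.addFullyConnectedStatesForLabels(); // one fully connected state per label
hmm.train(trainingData, null, null, null); // no validation/test/evaluator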
public static InstanceList copy(InstanceList instances) {
  InstanceList ret = (InstanceList) instances.clone();
  Alphabet featDict = ret.getDataAlphabet();
  // Replace each instance with a deep copy whose FeatureVector owns fresh
  // index/value arrays, so callers can mutate the copy safely.
  for (int i = 0; i < ret.size(); i++) {
    Instance instance = ret.get(i);
    Instance clone = (Instance) instance.clone();
    FeatureVector fv = (FeatureVector) clone.getData();
    int[] indices = fv.getIndices();
    double[] values = fv.getValues();
    int[] newIndices = new int[indices.length];
    System.arraycopy(indices, 0, newIndices, 0, indices.length);
    double[] newValues = new double[indices.length];
    System.arraycopy(values, 0, newValues, 0, indices.length);
    FeatureVector newFv = new FeatureVector(featDict, newIndices, newValues);
    Instance newInstance =
        new Instance(newFv, instance.getTarget(), instance.getName(), instance.getSource());
    ret.set(i, newInstance);
  }
  return ret;
}
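// Quick illustrative check of the deep-copy contract: mutating a copied
// FeatureVector must leave the original list untouched.
InstanceList copied = copy(instances);
FeatureVector copyFv = (FeatureVector) copied.get(0).getData();
FeatureVector origFv = (FeatureVector) instances.get(0).getData();
copyFv.setValueAtLocation(0, copyFv.valueAtLocation(0) + 1.0);
assert origFv.valueAtLocation(0) != copyFv.valueAtLocation(0);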
/**
 * Converts the sentence-based instance list into a token-based one. This is needed for the
 * ME-version of JET (JetMeClassifier).
 *
 * @param METrainerDummyPipe the pipe the resulting token instances are associated with
 * @param orgList the sentence-based instance list
 * @return a token-based instance list with one instance per feature vector / label pair
 */
public static InstanceList convertFeatsforClassifier(
    final Pipe METrainerDummyPipe, final InstanceList orgList) {
  final InstanceList iList = new InstanceList(METrainerDummyPipe);
  for (int i = 0; i < orgList.size(); i++) {
    final Instance inst = orgList.get(i);
    final FeatureVectorSequence fvs = (FeatureVectorSequence) inst.getData();
    final LabelSequence ls = (LabelSequence) inst.getTarget();
    final LabelAlphabet ldict = (LabelAlphabet) ls.getAlphabet();
    final Object source = inst.getSource();
    final Object name = inst.getName();
    if (ls.size() != fvs.size()) {
      System.err.println(
          "failed making token instances: size of label sequence != size of feature vector sequence: "
              + ls.size()
              + " - "
              + fvs.size());
      System.exit(-1);
    }
    for (int j = 0; j < fvs.size(); j++) {
      final Instance tokenInst =
          new Instance(fvs.getFeatureVector(j), ldict.lookupLabel(ls.get(j)), name, source);
      iList.add(tokenInst);
    }
  }
  return iList;
}
/**
 * Create and train a CRF model from the given training data, optionally testing it on the given
 * test data.
 *
 * @param training training data
 * @param testing test data (possibly <code>null</code>)
 * @param eval accuracy evaluator (possibly <code>null</code>)
 * @param orders label Markov orders (main and backoff)
 * @param defaultLabel default label
 * @param forbidden regular expression specifying impossible label transitions <em>current</em>
 *     <code>,</code><em>next</em> (<code>null</code> indicates no forbidden transitions)
 * @param allowed regular expression specifying allowed label transitions (<code>null</code>
 *     indicates everything is allowed that is not forbidden)
 * @param connected whether to include even transitions not occurring in the training data
 * @param iterations number of training iterations
 * @param var Gaussian prior variance
 * @param crf a CRF to continue training, or <code>null</code> to start a fresh model
 * @return the trained model
 */
public static CRF train(
    InstanceList training,
    InstanceList testing,
    TransducerEvaluator eval,
    int[] orders,
    String defaultLabel,
    String forbidden,
    String allowed,
    boolean connected,
    int iterations,
    double var,
    CRF crf) {
  // null means "no constraint"; guard before compiling, since Pattern.compile(null) throws.
  Pattern forbiddenPat = forbidden == null ? null : Pattern.compile(forbidden);
  Pattern allowedPat = allowed == null ? null : Pattern.compile(allowed);
  if (crf == null) {
    crf = new CRF(training.getPipe(), (Pipe) null);
    String startName =
        crf.addOrderNStates(
            training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected);
    for (int i = 0; i < crf.numStates(); i++)
      crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT);
    crf.getState(startName).setInitialWeight(0.0);
  }
  logger.info("Training on " + training.size() + " instances");
  if (testing != null) logger.info("Testing on " + testing.size() + " instances");
  // Build the trainer once and set the Gaussian prior on it, so the variance
  // actually applies to the trainer that runs.
  CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf);
  crft.setGaussianPriorVariance(var);
  if (featureInductionOption.value) {
    crft.trainWithFeatureInduction(
        training, null, testing, eval, iterations, 10, 20, 500, 0.5, false, null);
  } else {
    boolean converged;
    for (int i = 1; i <= iterations; i++) {
      converged = crft.train(training, 1);
      if (i % 1 == 0 && eval != null) // change the 1 to a higher integer to evaluate less often
        eval.evaluate(crft);
      if (viterbiOutputOption.value && i % 10 == 0)
        new ViterbiWriter(
                "", new InstanceList[] {training, testing}, new String[] {"training", "testing"})
            .evaluate(crft);
      if (converged) break;
    }
  }
  return crf;
}
public void train(String[] trainSections, String[] testSections) throws IOException {
  pipe = defaultPipe();
  InstanceList trainingInstanceList = prepareInstanceList(trainSections);
  InstanceList testingInstanceList = prepareInstanceList(testSections);
  // Classifier classifier = trainer.train(trainingInstanceList, testingInstanceList);
  Classifier classifier = trainer.train(trainingInstanceList);
  System.out.println("training size: " + trainingInstanceList.size());
  System.out.println("testing size: " + testingInstanceList.size());
  // showAccuracy(classifier, testingInstanceList);
  // getTypeSpecificAccuracy(trainingInstanceList, testingInstanceList, true);
  // showInterpolatedTCAccuracy(trainingInstanceList, testingInstanceList);
}
double getAccuracy(Classifier classifier, InstanceList instanceList) {
  int total = instanceList.size();
  int correct = 0;
  for (Instance instance : instanceList) {
    Classification classification = classifier.classify(instance);
    if (classification.bestLabelIsCorrect()) correct++;
  }
  return (1.0 * correct) / total;
}
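// Illustrative call, assuming a classifier trained elsewhere and a held-out list:
double acc = getAccuracy(classifier, testingInstanceList);
System.out.println("Held-out accuracy: " + acc);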
public void count() {
  TIntIntHashMap docCounts = new TIntIntHashMap();
  int index = 0;
  if (instances.size() == 0) {
    logger.info("Instance list is empty");
    return;
  }
  if (instances.get(0).getData() instanceof FeatureSequence) {
    for (Instance instance : instances) {
      FeatureSequence features = (FeatureSequence) instance.getData();
      for (int i = 0; i < features.getLength(); i++) {
        docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
      }
      int[] keys = docCounts.keys();
      // Fold this document's counts into the global tallies, covering every key.
      for (int i = 0; i < keys.length; i++) {
        int feature = keys[i];
        featureCounts[feature] += docCounts.get(feature);
        documentFrequencies[feature]++;
      }
      docCounts = new TIntIntHashMap();
      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else if (instances.get(0).getData() instanceof FeatureVector) {
    for (Instance instance : instances) {
      FeatureVector features = (FeatureVector) instance.getData();
      for (int location = 0; location < features.numLocations(); location++) {
        int feature = features.indexAtLocation(location);
        double value = features.valueAtLocation(location);
        documentFrequencies[feature]++;
        featureCounts[feature] += value;
      }
      index++;
      if (index % 1000 == 0) {
        System.err.println(index);
      }
    }
  } else {
    logger.info("Unsupported data class: " + instances.get(0).getData().getClass().getName());
  }
}
public void estimate(
    InstanceList documents,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  ilist = documents.shallowClone();
  numTypes = ilist.getDataAlphabet().size();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  docTopicCounts = new int[numDocs][numTopics];
  typeTopicCounts = new int[numTypes][numTopics];
  tokensPerTopic = new int[numTopics];
  tAlpha = alpha * numTopics;
  vBeta = beta * numTypes;
  long startTime = System.currentTimeMillis();
  // Initialize with random assignments of tokens to topics
  // and finish allocating this.topics and this.tokens
  int topic, seqLen;
  FeatureSequence fs;
  for (int di = 0; di < numDocs; di++) {
    try {
      fs = (FeatureSequence) ilist.get(di).getData();
    } catch (ClassCastException e) {
      System.err.println(
          "LDA and other topic models expect FeatureSequence data, not FeatureVector data. "
              + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
      throw e;
    }
    seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    // Randomly assign tokens to topics
    for (int si = 0; si < seqLen; si++) {
      topic = r.nextInt(numTopics);
      topics[di][si] = topic;
      docTopicCounts[di][topic]++;
      typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
      tokensPerTopic[topic]++;
    }
  }
  this.estimate(
      0, numDocs, numIterations, showTopicsInterval, outputModelInterval, outputModelFilename, r);
  // 124.5 seconds
  // 144.8 seconds after using FeatureSequence instead of tokens[][] array
  // 121.6 seconds after putting "final" on FeatureSequence.getIndexAtPosition()
  // 106.3 seconds after avoiding array lookup in inner loop with a temporary variable
}
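// Hypothetical driver, assuming this estimate(...) belongs to the classic
// cc.mallet.topics.LDA-style class with an LDA(int numTopics) constructor:
// 1000 sampling sweeps, print top words every 50, no model checkpointing.
LDA lda = new LDA(100); // 100 topics
lda.estimate(instances, 1000, 50, 0, null, new Randoms());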
public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices) {
  if (instIndices == null) {
    instIndices = new int[ilist.size()];
    for (int ii = 0; ii < instIndices.length; ii++) instIndices[ii] = ii;
  }
  m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
  m_ilist = ilist;
  m_instIndices = instIndices;
  m_dataDict = m_ilist.getDataAlphabet();
  m_minNumInsts = minNumInsts;
  m_parent = parent;
  m_leftChild = m_rightChild = null;
}
public void addDocuments(
    InstanceList additionalDocuments,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  if (ilist == null) throw new IllegalStateException("Must already have some documents first.");
  for (Instance inst : additionalDocuments) ilist.add(inst);
  assert (ilist.getDataAlphabet() == additionalDocuments.getDataAlphabet());
  assert (additionalDocuments.getDataAlphabet().size() >= numTypes);
  numTypes = additionalDocuments.getDataAlphabet().size();
  int numNewDocs = additionalDocuments.size();
  int numOldDocs = topics.length;
  int numDocs = numOldDocs + numNewDocs;
  // Expand various arrays to make space for the new data.
  int[][] newTopics = new int[numDocs][];
  for (int i = 0; i < topics.length; i++) newTopics[i] = topics[i];
  topics = newTopics; // the rest of this array is initialized below
  int[][] newDocTopicCounts = new int[numDocs][numTopics];
  for (int i = 0; i < docTopicCounts.length; i++) newDocTopicCounts[i] = docTopicCounts[i];
  docTopicCounts = newDocTopicCounts; // the rest of this array is initialized below
  int[][] newTypeTopicCounts = new int[numTypes][numTopics];
  for (int i = 0; i < typeTopicCounts.length; i++)
    for (int j = 0; j < numTopics; j++) newTypeTopicCounts[i][j] = typeTopicCounts[i][j];
  typeTopicCounts = newTypeTopicCounts; // without this assignment the resized counts are lost
  FeatureSequence fs;
  for (int di = numOldDocs; di < numDocs; di++) {
    try {
      fs = (FeatureSequence) additionalDocuments.get(di - numOldDocs).getData();
    } catch (ClassCastException e) {
      System.err.println(
          "LDA and other topic models expect FeatureSequence data, not FeatureVector data. "
              + "With text2vectors, you can obtain such data with --keep-sequence or --keep-bisequence.");
      throw e;
    }
    int seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    // Randomly assign tokens to topics
    for (int si = 0; si < seqLen; si++) {
      int topic = r.nextInt(numTopics);
      topics[di][si] = topic;
      docTopicCounts[di][topic]++;
      typeTopicCounts[fs.getIndexAtPosition(si)][topic]++;
      tokensPerTopic[topic]++;
    }
  }
}
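// The iteration arguments above are presumably consumed by a follow-up sampling
// pass over the new documents, mirroring the estimate(...) overloads earlier in
// this class. A hypothetical incremental flow:
lda.estimate(firstBatch, 1000, 50, 0, null, r); // build the initial model
lda.addDocuments(secondBatch, 200, 50, 0, null, r); // extend state with new docs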
@Test
public void testLoadRareWords() throws UnsupportedEncodingException, FileNotFoundException {
  String dataset_fn = "src/main/resources/datasets/SmallTexts.txt";
  InstanceList nonPrunedInstances = LDAUtils.loadInstances(dataset_fn, "stoplist.txt", 0);
  System.out.println(LDAUtils.instancesToString(nonPrunedInstances));
  System.out.println("Non pruned Alphabet size: " + nonPrunedInstances.getDataAlphabet().size());
  System.out.println("No. instances: " + nonPrunedInstances.size());
  InstanceList originalInstances = LDAUtils.loadInstances(dataset_fn, "stoplist.txt", 2);
  System.out.println("Alphabet size: " + originalInstances.getDataAlphabet().size());
  System.out.println(LDAUtils.instancesToString(originalInstances));
  System.out.println("No. instances: " + originalInstances.size());
  int[] wordCounts = {0, 3, 3, 0, 0};
  int idx = 0;
  for (Instance instance : originalInstances) {
    FeatureSequence fs = (FeatureSequence) instance.getData();
    // Asserting on fs.getFeatures().length would fail here: even though the
    // feature sequence is "empty", the underlying array is 2 long.
    // assertEquals(wordCounts[idx++], fs.getFeatures().length);
    assertEquals(wordCounts[idx++], fs.size());
  }
}
/**
 * Initialize this separate model using a complete list.
 *
 * @param documents the full document collection to split
 * @param testStartIndex index of the first document to route to the test set
 */
public void divideDocuments(InstanceList documents, int testStartIndex) {
  Alphabet dataAlpha = documents.getDataAlphabet();
  Alphabet targetAlpha = documents.getTargetAlphabet();
  this.training = new InstanceList(dataAlpha, targetAlpha);
  this.test = new InstanceList(dataAlpha, targetAlpha);
  for (int di = 0; di < testStartIndex; di++) {
    training.add(documents.get(di));
  }
  for (int di = testStartIndex; di < documents.size(); di++) {
    test.add(documents.get(di));
  }
}
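// Hypothetical 80/20 split driver for divideDocuments(...); `model` and
// `documents` are assumed to exist in the caller.
int testStartIndex = (int) (documents.size() * 0.8); // first 80% train, rest test
model.divideDocuments(documents, testStartIndex);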
public void doInference() {
  try {
    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();
    readFile();
    InstanceList testing = generateInstanceList();
    for (int i = 0; i < testing.size(); i++) {
      StringBuilder probabilities = new StringBuilder();
      // 10 sampling iterations, thinning 1, burn-in 5.
      double[] testProbabilities = inferencer.getSampledDistribution(testing.get(i), 10, 1, 5);
      ArrayList<Pair<Integer, Double>> probabilityList = new ArrayList<Pair<Integer, Double>>();
      for (int j = 0; j < testProbabilities.length; j++) {
        probabilityList.add(new Pair<Integer, Double>(j, testProbabilities[j]));
      }
      Collections.sort(probabilityList, new CustomComparator());
      // Emit the topN most probable topics as "topic,probability" pairs.
      for (int j = 0; j < testProbabilities.length && j < topN; j++) {
        if (j > 0) probabilities.append(" ");
        probabilities.append(
            probabilityList.get(j).getFirst().toString()
                + ","
                + probabilityList.get(j).getSecond().toString());
      }
      System.out.println(docIds.get(i) + "," + probabilities.toString());
    }
  } catch (Exception e) {
    e.printStackTrace();
    System.err.println(e.getMessage());
  }
}
public double dataLogLikelihood(InstanceList ilist) {
  double logLikelihood = 0;
  for (int ii = 0; ii < ilist.size(); ii++) {
    double instanceWeight = ilist.getInstanceWeight(ii);
    Instance inst = ilist.get(ii);
    Labeling labeling = inst.getLabeling();
    if (labeling != null) {
      logLikelihood += instanceWeight * dataLogProbability(inst, labeling.getBestIndex());
    } else {
      // Unlabeled instance: average over the classifier's own predicted
      // label distribution.
      Labeling predicted = this.classify(inst).getLabeling();
      for (int lpos = 0; lpos < predicted.numLocations(); lpos++) {
        int li = predicted.indexAtLocation(lpos);
        double labelWeight = predicted.valueAtLocation(lpos);
        if (labelWeight == 0) continue;
        logLikelihood += instanceWeight * labelWeight * dataLogProbability(inst, li);
      }
    }
  }
  return logLikelihood;
}
private void showNFoldAccuracy(InstanceList instanceList, int n, int count) {
  InstanceList.CrossValidationIterator cvIt = instanceList.crossValidationIterator(n);
  double[] accuracies = new double[n];
  double accuracy = 0;
  int run = 0;
  double totalTP = 0;
  while (cvIt.hasNext()) {
    InstanceList[] nextSplit = cvIt.nextSplit();
    InstanceList trainingInstances = nextSplit[0];
    InstanceList testingInstances = nextSplit[1];
    trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
    Classifier classifier = trainer.train(trainingInstances);
    accuracies[run] = getAccuracy(classifier, testingInstances);
    accuracy += accuracies[run];
    totalTP += accuracies[run] * testingInstances.size();
    run++;
  }
  // Macro average: mean of the per-fold accuracies.
  System.out.println(n + "-Fold accuracy(avg): " + accuracy / n);
  System.out.println("Total tp:" + totalTP);
  System.out.println("Total count:" + count);
  // Micro average: correct predictions pooled over all folds.
  System.out.println(n + "-Fold accuracy: " + totalTP / count);
}
public double labelLogLikelihood(InstanceList ilist) {
  double logLikelihood = 0;
  for (int ii = 0; ii < ilist.size(); ii++) {
    double instanceWeight = ilist.getInstanceWeight(ii);
    Instance inst = ilist.get(ii);
    Labeling labeling = inst.getLabeling();
    if (labeling == null) continue; // only labeled instances contribute
    Labeling predicted = this.classify(inst).getLabeling();
    if (labeling.numLocations() == 1) {
      // Hard label: add log p(gold | x), weighted by the instance weight.
      logLikelihood += instanceWeight * Math.log(predicted.value(labeling.getBestIndex()));
    } else {
      // Soft labeling: weight each gold label's log-probability by its mass.
      for (int lpos = 0; lpos < labeling.numLocations(); lpos++) {
        int li = labeling.indexAtLocation(lpos);
        double labelWeight = labeling.valueAtLocation(lpos);
        if (labelWeight == 0) continue;
        logLikelihood += instanceWeight * labelWeight * Math.log(predicted.value(li));
      }
    }
  }
  return logLikelihood;
}
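// Both likelihood methods above compute a weighted log-likelihood. For
// labelLogLikelihood the quantity is
//
//   L = sum_i w_i * sum_l p~_i(l) * log p(l | x_i)
//
// where w_i is the instance weight, p~_i is the (possibly soft) gold labeling,
// and p(l | x_i) is the classifier's predicted distribution. For a single hard
// label this reduces to w_i * log p(l_i | x_i).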
/**
 * Command-line wrapper to train, test, or run a generic CRF-based tagger.
 *
 * @param args the command line arguments. Options (shell and Java quoting should be added as
 *     needed):
 *     <dl>
 *       <dt><code>--help</code> <em>boolean</em>
 *       <dd>Print this command line option usage information. Give <code>true</code> for longer
 *           documentation. Default is <code>false</code>.
 *       <dt><code>--prefix-code</code> <em>Java-code</em>
 *       <dd>Java code you want run before any other interpreted code. Note that the text is
 *           interpreted without modification, so unlike some other Java code options, you need
 *           to include any necessary 'new's. Default is null.
 *       <dt><code>--gaussian-variance</code> <em>positive-number</em>
 *       <dd>The Gaussian prior variance used for training. Default is 10.0.
 *       <dt><code>--train</code> <em>boolean</em>
 *       <dd>Whether to train. Default is <code>false</code>.
 *       <dt><code>--iterations</code> <em>positive-integer</em>
 *       <dd>Number of training iterations. Default is 500.
 *       <dt><code>--test</code> <code>lab</code> or <code>seg=</code><em>start-1</em><code>.
 *           </code><em>continue-1</em><code>,</code>...<code>,</code><em>start-n</em><code>.
 *           </code><em>continue-n</em>
 *       <dd>Test measuring labeling or segmentation (<em>start-i</em>, <em>continue-i</em>)
 *           accuracy. Default is no testing.
 *       <dt><code>--training-proportion</code> <em>number-between-0-and-1</em>
 *       <dd>Fraction of data to use for training in a random split. Default is 0.5.
 *       <dt><code>--model-file</code> <em>filename</em>
 *       <dd>The filename for reading (train/run) or saving (train) the model. Default is null.
 *       <dt><code>--random-seed</code> <em>integer</em>
 *       <dd>The random seed for randomly selecting a proportion of the instance list for
 *           training. Default is 0.
 *       <dt><code>--orders</code> <em>comma-separated-integers</em>
 *       <dd>List of label Markov orders (main and backoff). Default is 1.
 *       <dt><code>--forbidden</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> matches the expression, the
 *           corresponding transition is forbidden. Default is <code>\\s</code> (nothing
 *           forbidden).
 *       <dt><code>--allowed</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> does not match the expression,
 *           the corresponding transition is forbidden. Default is <code>.*</code> (everything
 *           allowed).
 *       <dt><code>--default-label</code> <em>string</em>
 *       <dd>Label for initial context and uninteresting tokens. Default is <code>O</code>.
 *       <dt><code>--viterbi-output</code> <em>boolean</em>
 *       <dd>Print Viterbi periodically during training. Default is <code>false</code>.
 *       <dt><code>--fully-connected</code> <em>boolean</em>
 *       <dd>Include all allowed transitions, even those not in training data. Default is
 *           <code>true</code>.
 *       <dt><code>--n-best</code> <em>positive-integer</em>
 *       <dd>Number of answers to output when applying model. Default is 1.
 *       <dt><code>--include-input</code> <em>boolean</em>
 *       <dd>Whether to include input features when printing decoding output. Default is
 *           <code>false</code>.
 *     </dl>
 *     Remaining arguments:
 *     <ul>
 *       <li><em>training-data-file</em> if training
 *       <li><em>training-and-test-data-file</em>, if training and testing with random split
 *       <li><em>training-data-file</em> <em>test-data-file</em> if training and testing from
 *           separate files
 *       <li><em>test-data-file</em> if testing
 *       <li><em>input-data-file</em> if applying to new data (unlabeled)
 *     </ul>
 *
 * @exception Exception if an error occurs
 */
public static void main(String[] args) throws Exception {
  Reader trainingFile = null, testFile = null;
  InstanceList trainingData = null, testData = null;
  int restArgs = commandOptions.processOptions(args);
  if (restArgs == args.length) {
    commandOptions.printUsage(true);
    throw new IllegalArgumentException("Missing data file(s)");
  }
  if (trainOption.value) {
    trainingFile = new FileReader(new File(args[restArgs]));
    if (testOption.value != null && restArgs < args.length - 1)
      testFile = new FileReader(new File(args[restArgs + 1]));
  } else {
    testFile = new FileReader(new File(args[restArgs]));
  }
  Pipe p = null;
  CRF crf = null;
  TransducerEvaluator eval = null;
  if (continueTrainingOption.value || !trainOption.value) {
    // Load an existing model; its pipe defines the feature space.
    if (modelOption.value == null) {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing model file option");
    }
    ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value));
    crf = (CRF) s.readObject();
    s.close();
    p = crf.getInputPipe();
  } else {
    p = new SimpleTaggerSentence2FeatureVectorSequence();
    p.getTargetAlphabet().lookupIndex(defaultOption.value);
  }
  if (trainOption.value) {
    p.setTargetProcessing(true);
    trainingData = new InstanceList(p);
    trainingData.addThruPipe(
        new LineGroupIterator(trainingFile, Pattern.compile("^\\s*$"), true));
    logger.info("Number of features in training data: " + p.getDataAlphabet().size());
    if (testOption.value != null) {
      if (testFile != null) {
        testData = new InstanceList(p);
        testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
      } else {
        // No separate test file: split the training data randomly.
        Random r = new Random(randomSeedOption.value);
        InstanceList[] trainingLists =
            trainingData.split(
                r,
                new double[] {trainingFractionOption.value, 1 - trainingFractionOption.value});
        trainingData = trainingLists[0];
        testData = trainingLists[1];
      }
    }
  } else if (testOption.value != null) {
    p.setTargetProcessing(true);
    testData = new InstanceList(p);
    testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
  } else {
    p.setTargetProcessing(false);
    testData = new InstanceList(p);
    testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
  }
  logger.info("Number of predicates: " + p.getDataAlphabet().size());
  if (testOption.value != null) {
    if (testOption.value.startsWith("lab"))
      eval =
          new TokenAccuracyEvaluator(
              new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
    else if (testOption.value.startsWith("seg=")) {
      String[] pairs = testOption.value.substring(4).split(",");
      if (pairs.length < 1) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException(
            "Missing segment start/continue labels: " + testOption.value);
      }
      String[] startTags = new String[pairs.length];
      String[] continueTags = new String[pairs.length];
      for (int i = 0; i < pairs.length; i++) {
        String[] pair = pairs[i].split("\\.");
        if (pair.length != 2) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Incorrectly-specified segment start and end labels: " + pairs[i]);
        }
        startTags[i] = pair[0];
        continueTags[i] = pair[1];
      }
      eval =
          new MultiSegmentationEvaluator(
              new InstanceList[] {trainingData, testData},
              new String[] {"Training", "Testing"},
              startTags,
              continueTags);
    } else {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Invalid test option: " + testOption.value);
    }
  }
  if (p.isTargetProcessing()) {
    Alphabet targets = p.getTargetAlphabet();
    StringBuffer buf = new StringBuffer("Labels:");
    for (int i = 0; i < targets.size(); i++)
      buf.append(" ").append(targets.lookupObject(i).toString());
    logger.info(buf.toString());
  }
  if (trainOption.value) {
    crf =
        train(
            trainingData,
            testData,
            eval,
            ordersOption.value,
            defaultOption.value,
            forbiddenOption.value,
            allowedOption.value,
            connectedOption.value,
            iterationsOption.value,
            gaussianVarianceOption.value,
            crf);
    if (modelOption.value != null) {
      ObjectOutputStream s = new ObjectOutputStream(new FileOutputStream(modelOption.value));
      s.writeObject(crf);
      s.close();
    }
  } else {
    if (crf == null) {
      if (modelOption.value == null) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value));
      crf = (CRF) s.readObject();
      s.close();
    }
    if (eval != null) {
      test(new NoopTransducerTrainer(crf), eval, testData);
    } else {
      // Apply the model: print the n-best output sequences for each input.
      boolean includeInput = includeInputOption.value();
      for (int i = 0; i < testData.size(); i++) {
        Sequence input = (Sequence) testData.get(i).getData();
        Sequence[] outputs = apply(crf, input, nBestOption.value);
        int k = outputs.length;
        boolean error = false;
        for (int a = 0; a < k; a++) {
          if (outputs[a].size() != input.size()) {
            System.err.println("Failed to decode input sequence " + i + ", answer " + a);
            error = true;
          }
        }
        if (!error) {
          for (int j = 0; j < input.size(); j++) {
            StringBuffer buf = new StringBuffer();
            for (int a = 0; a < k; a++) buf.append(outputs[a].get(j).toString()).append(" ");
            if (includeInput) {
              FeatureVector fv = (FeatureVector) input.get(j);
              buf.append(fv.toString(true));
            }
            System.out.println(buf.toString());
          }
          System.out.println();
        }
      }
    }
  }
}
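// Example invocations (the enclosing class is assumed to be MALLET's SimpleTagger;
// adjust the class name if this wrapper lives elsewhere):
//
//   Train, saving the model and measuring token accuracy on a random 50/50 split:
//     java cc.mallet.fst.SimpleTagger --train true --test lab \
//          --model-file tagger.model train-and-test.data
//
//   Apply a saved model to unlabeled data:
//     java cc.mallet.fst.SimpleTagger --model-file tagger.model input.data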
public void estimate(
    InstanceList documents,
    int numIterations,
    int showTopicsInterval,
    int outputModelInterval,
    String outputModelFilename,
    Randoms r) {
  ilist = documents;
  uniAlphabet = ilist.getDataAlphabet();
  biAlphabet = ((FeatureSequenceWithBigrams) ilist.get(0).getData()).getBiAlphabet();
  numTypes = uniAlphabet.size();
  numBitypes = biAlphabet.size();
  int numDocs = ilist.size();
  topics = new int[numDocs][];
  grams = new int[numDocs][];
  docTopicCounts = new int[numDocs][numTopics];
  typeNgramTopicCounts = new int[numTypes][2][numTopics];
  unitypeTopicCounts = new int[numTypes][numTopics];
  bitypeTopicCounts = new int[numBitypes][numTopics];
  tokensPerTopic = new int[numTopics];
  bitokensPerTopic = new int[numTypes][numTopics];
  tAlpha = alpha * numTopics;
  vBeta = beta * numTypes;
  vGamma = gamma * numTypes;
  long startTime = System.currentTimeMillis();
  // Initialize with random assignments of tokens to topics
  // and finish allocating this.topics and this.tokens
  int topic, gram, seqLen, fi;
  for (int di = 0; di < numDocs; di++) {
    FeatureSequenceWithBigrams fs = (FeatureSequenceWithBigrams) ilist.get(di).getData();
    seqLen = fs.getLength();
    numTokens += seqLen;
    topics[di] = new int[seqLen];
    grams[di] = new int[seqLen];
    // Randomly assign tokens to topics
    int prevFi = -1, prevTopic = -1;
    for (int si = 0; si < seqLen; si++) {
      // Randomly sample a topic for the word at position si.
      topic = r.nextInt(numTopics);
      // If a bigram is allowed at position si, then sample a gram status for it.
      gram = (fs.getBiIndexAtPosition(si) == -1 ? 0 : r.nextInt(2));
      if (gram != 0) biTokens++;
      topics[di][si] = topic;
      grams[di][si] = gram;
      docTopicCounts[di][topic]++;
      fi = fs.getIndexAtPosition(si);
      if (prevFi != -1) typeNgramTopicCounts[prevFi][gram][prevTopic]++;
      if (gram == 0) {
        unitypeTopicCounts[fi][topic]++;
        tokensPerTopic[topic]++;
      } else {
        bitypeTopicCounts[fs.getBiIndexAtPosition(si)][topic]++;
        bitokensPerTopic[prevFi][topic]++;
      }
      prevFi = fi;
      prevTopic = topic;
    }
  }
  for (int iterations = 0; iterations < numIterations; iterations++) {
    sampleTopicsForAllDocs(r);
    // Progress indicator: iteration number every 10 sweeps, a dot otherwise.
    if (iterations % 10 == 0) System.out.print(iterations);
    else System.out.print(".");
    System.out.flush();
    if (showTopicsInterval != 0 && iterations % showTopicsInterval == 0 && iterations > 0) {
      System.out.println();
      printTopWords(5, false);
    }
    if (outputModelInterval != 0 && iterations % outputModelInterval == 0 && iterations > 0) {
      this.write(new File(outputModelFilename + '.' + iterations));
    }
  }
  System.out.println(
      "\nTotal time (sec): " + ((System.currentTimeMillis() - startTime) / 1000.0));
}
private Instance getLastInstance() { return list.get(list.size() - 1); }
/**
 * Shows accuracy according to Ben Wellner's definition of accuracy, and writes the
 * misclassified instances to arg1Error.log, tallied per connective.
 *
 * @param classifier the trained classifier to evaluate
 * @param instanceList the instances to classify
 */
private void showAccuracy(Classifier classifier, InstanceList instanceList) throws IOException {
  int total = instanceList.size();
  int correct = 0;
  HashMap<String, Integer> errorMap = new HashMap<String, Integer>();
  FileWriter errorWriter = new FileWriter("arg1Error.log");
  for (Instance instance : instanceList) {
    Classification classification = classifier.classify(instance);
    if (classification.bestLabelIsCorrect()) {
      correct++;
    } else {
      Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
      Document doc = rankInstance.getDocument();
      Sentence s = doc.getSentence(rankInstance.getArg2Line());
      String conn =
          s.toString(rankInstance.getConnStart(), rankInstance.getConnEnd()).toLowerCase();
      // Tally errors per connective.
      if (errorMap.containsKey(conn)) {
        errorMap.put(conn, errorMap.get(conn) + 1);
      } else {
        errorMap.put(conn, 1);
      }
      int arg2Line = rankInstance.getArg2Line();
      int arg1Line =
          rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).first();
      int arg1HeadPos =
          rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).second();
      int predictedCandidateIndex =
          Integer.parseInt(classification.getLabeling().getBestLabel().toString());
      if (arg1Line == arg2Line) {
        errorWriter.write("FileName: " + doc.getFileName() + "\n");
        errorWriter.write("Sentential\n");
        errorWriter.write("Conn: " + conn + "\n");
        errorWriter.write("Arg1Head: " + s.get(arg1HeadPos).word() + "\n");
        errorWriter.write(s.toString() + "\n\n");
      } else {
        errorWriter.write("FileName: " + doc.getFileName() + "\n");
        errorWriter.write("Inter-Sentential\n");
        errorWriter.write("Arg1 in : " + arg1Line + "\n");
        errorWriter.write("Arg2 in : " + arg2Line + "\n");
        errorWriter.write("Conn: " + conn + "\n");
        errorWriter.write(s.toString() + "\n");
        Sentence s1 = doc.getSentence(arg1Line);
        errorWriter.write("Arg1Head: " + s1.get(arg1HeadPos) + "\n");
        errorWriter.write(s1.toString() + "\n\n");
      }
      int predictedArg1Line = rankInstance.getCandidates().get(predictedCandidateIndex).first();
      int predictedArg1HeadPos =
          rankInstance.getCandidates().get(predictedCandidateIndex).second();
      Sentence pSentence = doc.getSentence(predictedArg1Line);
      errorWriter.write(
          "Predicted arg1 sentence: "
              + pSentence.toString()
              + " [Correct: "
              + (predictedArg1Line == arg1Line)
              + "]\n");
      errorWriter.write(
          "Predicted head: " + pSentence.get(predictedArg1HeadPos).word() + "\n\n");
    }
  }
  errorWriter.close();
  // Print connectives sorted by descending error count.
  List<Entry<String, Integer>> list =
      new ArrayList<Entry<String, Integer>>(errorMap.entrySet());
  Collections.sort(
      list,
      new Comparator<Entry<String, Integer>>() {
        @Override
        public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) {
          return o2.getValue().compareTo(o1.getValue());
        }
      });
  for (Entry<String, Integer> item : list) {
    System.out.println(item.getKey() + "-" + item.getValue());
  }
  System.out.println("Total: " + total);
  System.out.println("Correct: " + correct);
  System.out.println("Accuracy: " + (1.0 * correct) / total);
}
public void run(Collection<ConnectionsDocument> docsToRun) {
  InstanceList instanceList = processDocs(docsToRun);
  log.info("instancelist size:" + instanceList.size());
  run(instanceList);
}