/** Returns the features for the highest-score current parse(s). */ public HashVector getFeats() { HashVector result = new HashVector(); // result.reset(theta.size(),0.0); for (ParseResult p : bestParses) p.getFeats(result); if (bestParses.size() > 1) result.divideBy(bestParses.size()); return result; }
/**
 * Returns the features for the highest-score current parse with semantics that equal sem.
 *
 * <p>The candidate parses come from {@code findBestParses(allParses, sem)}. Each one writes its
 * features into one fresh vector; when more than one parse is found, the vector is then divided
 * by their number.
 *
 * @param sem the semantics the selected parses must have
 * @return a newly created vector holding the collected features
 */
public HashVector getFeats(Exp sem) {
  HashVector feats = new HashVector();
  List<ParseResult> matching = findBestParses(allParses, sem);
  for (ParseResult parse : matching) {
    parse.getFeats(feats);
  }
  final int numMatching = matching.size();
  if (numMatching > 1) {
    feats.divideBy(numMatching);
  }
  return feats;
}
/**
 * Adds the input vector into the current parameters.
 *
 * <p>Delegates to {@code p.addTimesInto} with a weight of 1 and {@code theta} as the target.
 *
 * @param p the vector to add into {@code theta}
 */
public void updateParams(HashVector p) {
  final int unitWeight = 1;
  p.addTimesInto(unitWeight, theta);
}
public boolean isCorrect(String words, Exp sem, Parser parser) { List<ParseResult> parses = parser.bestParses(); if (parses.size() > 0) { noAnswer = false; } else { noAnswer = true; } if (parses.size() == 1) { ParseResult p = parses.get(0); Exp e = p.getExp(); e = e.copy(); e.simplify(); List l = p.getLexEntries(); parsed++; if (e.equals(sem)) { if (verbose) { System.out.println("CORRECT"); printLex(l); } int lits = sem.allLitsCount(); correctParses++; return true; } else { // one parse, it was wrong... oh well... if (verbose) { System.out.println("WRONG"); System.out.println(parses.size() + " parses: " + parses); printLex(l); } wrongParses++; boolean hasCorrect = parser.hasParseFor(sem); if (verbose) { System.out.println("Had correct parse: " + hasCorrect); System.out.print("Feats: "); Exp eb = parser.bestSem(); Chart c = parser.getChart(); HashVector h = c.computeExpFeatVals(eb); h.divideBy(c.computeNorm(eb)); h.dropSmallEntries(); System.out.println(h); } } } else { noParses++; if (parses.size() > 1) { // There are more than one equally high scoring // logical forms. If this is the case, we abstain // from returning a result. 
if (verbose) { System.out.println("too many parses"); System.out.println(parses.size() + " parses: " + parses); } Exp e = parses.get(0).getExp(); ParseResult p = parses.get(0); boolean hasCorrect = parser.hasParseFor(sem); if (verbose) System.out.println("Had correct parse: " + hasCorrect); } else { // no parses, potentially reparse with word skipping if (verbose) System.out.println("no parses"); if (emptyTest) { List<LexEntry> emps = new LinkedList<LexEntry>(); for (int j = 0; j < Globals.tokens.size(); j++) { List l = Globals.tokens.subList(j, j + 1); LexEntry le = new LexEntry(l, Cat.EMP); emps.add(le); } parser.setTempLexicon(new Lexicon(emps)); String mes = null; if (verbose) mes = "EMPTY"; parser.parseTimed(words, null, mes); parser.setTempLexicon(null); parses = parser.bestParses(); if (parses.size() == 1) { ParseResult p = parses.get(0); List l = p.getLexEntries(); Exp e = p.getExp(); e = e.copy(); e.simplify(); int noEmpty = p.noEmpty(); if (e.equals(sem)) { if (verbose) { System.out.println("CORRECT"); printLex(l); } emptyCorrect++; } else { // one parse, but wrong if (verbose) { System.out.println("WRONG: " + e); printLex(l); boolean hasCorrect = parser.hasParseFor(sem); System.out.println("Had correct parse: " + hasCorrect); } } } else { // too many parses or no parses emptyNoParses++; if (verbose) { System.out.println("WRONG:" + parses); boolean hasCorrect = parser.hasParseFor(sem); System.out.println("Had correct parse: " + hasCorrect); } } } } } return false; }
public void stocGradTrain(Parser parser, boolean testEachRound) { int numUpdates = 0; List<LexEntry> fixedEntries = new LinkedList<LexEntry>(); fixedEntries.addAll(parser.returnLex().getLexicon()); // add all sentential lexical entries. for (int l = 0; l < trainData.size(); l++) { parser.addLexEntries(trainData.getDataSet(l).makeSentEntries()); } parser.setGlobals(); DataSet data = null; // for each pass over the data for (int j = 0; j < EPOCHS; j++) { System.out.println("Training, iteration " + j); int total = 0, correct = 0, wrong = 0, looCorrect = 0, looWrong = 0; for (int l = 0; l < trainData.size(); l++) { // the variables to hold the current training example String words = null; Exp sem = null; data = trainData.getDataSet(l); if (verbose) System.out.println("---------------------"); String filename = trainData.getFilename(l); if (verbose) System.out.println("DataSet: " + filename); if (verbose) System.out.println("---------------------"); // loop through the training examples // try to create lexical entries for each training example for (int i = 0; i < data.size(); i++) { // print running stats if (verbose) { if (total != 0) { double r = (double) correct / total; double p = (double) correct / (correct + wrong); System.out.print(i + ": =========== r:" + r + " p:" + p); System.out.println(" (epoch:" + j + " file:" + l + " " + filename + ")"); } else System.out.println(i + ": ==========="); } // get the training example words = data.sent(i); sem = data.sem(i); if (verbose) { System.out.println(words); System.out.println(sem); } List<String> tokens = Parser.tokenize(words); if (tokens.size() > maxSentLen) continue; total++; String mes = null; boolean hasCorrect = false; // first, get all possible lexical entries from // a manipulation of the best parse. 
List<LexEntry> lex = makeLexEntriesChart(words, sem, parser); if (verbose) { System.out.println("Adding:"); for (LexEntry le : lex) { System.out.println(le + " : " + LexiconFeatSet.initialWeight(le)); } } parser.addLexEntries(lex); if (verbose) System.out.println("Lex Size: " + parser.returnLex().size()); // first parse to see if we are currently correct if (verbose) mes = "First"; parser.parseTimed(words, null, mes); Chart firstChart = parser.getChart(); Exp best = parser.bestSem(); // this just collates and outputs the training // accuracy. if (sem.equals(best)) { // System.out.println(parser.bestParses().get(0)); if (verbose) { System.out.println("CORRECT:" + best); lex = parser.getMaxLexEntriesFor(sem); System.out.println("Using:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } correct++; } else { if (verbose) { System.out.println("WRONG: " + best); lex = parser.getMaxLexEntriesFor(best); System.out.println("Using:"); printLex(lex); if (best != null && lex.size() == 0) { System.out.println("ERROR: empty lex"); } } wrong++; } // compute first half of parameter update: // subtract the expectation of parameters // under the distribution that is conditioned // on the sentence alone. double norm = firstChart.computeNorm(); HashVector update = new HashVector(); HashVector firstfeats = null, secondfeats = null; if (norm != 0.0) { firstfeats = firstChart.computeExpFeatVals(); firstfeats.divideBy(norm); firstfeats.dropSmallEntries(); firstfeats.addTimesInto(-1.0, update); } else continue; firstChart = null; if (verbose) mes = "Second"; parser.parseTimed(words, sem, mes); hasCorrect = parser.hasParseFor(sem); // compute second half of parameter update: // add the expectation of parameters // under the distribution that is conditioned // on the sentence and correct logical form. 
if (!hasCorrect) continue; Chart secondChart = parser.getChart(); double secnorm = secondChart.computeNorm(sem); if (norm != 0.0) { secondfeats = secondChart.computeExpFeatVals(sem); secondfeats.divideBy(secnorm); secondfeats.dropSmallEntries(); secondfeats.addTimesInto(1.0, update); lex = parser.getMaxLexEntriesFor(sem); data.setBestLex(i, lex); if (verbose) { System.out.println("Best LexEntries:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } } else continue; // now do the update double scale = alpha_0 / (1.0 + c * numUpdates); if (verbose) System.out.println("Scale: " + scale); update.multiplyBy(scale); update.dropSmallEntries(); numUpdates++; if (verbose) { System.out.println("Update:"); System.out.println(update); } if (!update.isBad()) { if (!update.valuesInRange(-100, 100)) { System.out.println("WARNING: large update"); System.out.println("first feats: " + firstfeats); System.out.println("second feats: " + secondfeats); } parser.updateParams(update); } else { System.out.println( "ERROR: Bad Update: " + update + " -- norm: " + norm + " -- feats: "); parser.getParams().printValues(update); System.out.println(); } } // end for each training example } // end for each data set double r = (double) correct / total; // we can prune the lexical items that were not used // in a max scoring parse. if (pruneLex) { Lexicon cur = new Lexicon(); cur.addLexEntries(fixedEntries); cur.addLexEntries(data.getBestLex()); parser.setLexicon(cur); } if (testEachRound) { System.out.println("Testing"); test(parser, false); } } // end epochs loop }