/**
 * Compares this word label unit with another object.
 *
 * <p>Two {@code WordLabelUnit}s are equal when both are negative, regardless of their label
 * strings, or otherwise when their {@link #getLabel()} strings are equal. Any object that is not
 * a {@code WordLabelUnit} is never equal.
 *
 * <p>NOTE(review): {@code hashCode} is not visible here. If it hashes {@code getLabel()}, two
 * negative units with different label strings would be equal but could hash differently.
 * Confirm they share a label string.
 *
 * @param unit the object to compare against
 * @return whether {@code unit} is an equal word label unit
 */
@Override
public boolean equals(Object unit) {
  if (unit instanceof WordLabelUnit) {
    WordLabelUnit other = (WordLabelUnit) unit;
    // Negative units are interchangeable; only otherwise do label strings matter.
    return (isNegative() && other.isNegative()) || getLabel().equals(other.getLabel());
  }
  return false;
}
/**
 * Builds the global (assignment-dependent) features for giving {@code candidateLabelUnit} to
 * {@code word}, using the labels already present in {@code y} at sequence indices
 * {@code 0..lastIndex} (inclusive).
 *
 * <p>Two feature groups are produced and added to the returned vector:
 *
 * <ul>
 *   <li>{@code "WORD"}: a label bigram with the word immediately before and after {@code word},
 *       and label trigrams over (-2,-1,0), (-1,0,+1) and (0,+1,+2). These are built only when
 *       {@code params.getUseGlobalEntityFeatures()} is set and the context words fall within
 *       {@code 0..lastIndex}. When {@code params.getUseFullFeatures()} is set and an entity span
 *       can be reconstructed, an entity-type + entity-string feature is also added.
 *   <li>{@code "WORDPAIR"}: features joining the candidate label with the label of every
 *       non-negative pair in {@code 0..lastIndex} that has {@code word} as W1 or W2. With full
 *       features enabled, this includes the label/base of the other word ("w-rel-w") and
 *       entity-level features anchored on the last word of the candidate's entity
 *       ("ent-rel-ent"). This group is scaled by {@code params.getRelWeight()}.
 * </ul>
 *
 * <p>A negative candidate produces no pair features: the loop skips all pairs, and the
 * entity-span reconstruction yields nothing for negative candidates.
 *
 * @param word the word being labeled
 * @param instance the instance; cast to {@code JointInstance} for word-id lookups
 * @param y the label assignment read at indices {@code <= lastIndex}
 * @param lastIndex the highest sequence index whose label in {@code y} is consulted
 * @param candidateLabelUnit the label hypothesized for {@code word}
 * @return a vector holding the {@code "WORD"} and weighted {@code "WORDPAIR"} feature groups
 */
protected SparseFeatureVector calcWordGlobalFeatures(
    Word word, Instance instance, Label y, int lastIndex, WordLabelUnit candidateLabelUnit) {
  SparseFeatureVector globalFv = new SparseFeatureVector(params);
  StringSparseVector wordPairFv = new StringSparseVector(params);
  StringSparseVector wordFv = new StringSparseVector(params);
  String candidateLabel = candidateLabelUnit.getLabel();
  String wBase = word.getWord().getRealBase();
  String wPOS = word.getWord().getPOS();
  // Sequence indices of the words at offsets -2, -1, +1, +2 from `word`; -1 means "not seen".
  int[] contextWords = new int[4];
  Arrays.fill(contextWords, -1);
  for (int i = 0; i <= lastIndex; i++) {
    if (y.getLabel(i) instanceof PairLabelUnit) {
      // Relation (pair) entries: only non-negative relations contribute.
      if (y.getLabel(i).isNegative()) continue;
      // Loop-invariant: a negative candidate gets no word-pair features at all.
      if (candidateLabelUnit.isNegative()) {
        continue;
      }
      Pair pair = (Pair) instance.getSequence().get(i);
      if (pair.getW1().getId() == word.getId()) {
        // `word` is the first argument: features are ordered candidate -> relation.
        String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
        wordPairFv.add(candidateLabel.concat(pairLabel), 1.);
        wordPairFv.add(candidateLabel.concat(wBase).concat(pairLabel), 1.);
        if (params.getUseFullFeatures()) {
          // w-rel-w
          int w2Idx = ((JointInstance) instance).getWord(pair.getW2().getId());
          // Only use the other word's label if it has already been assigned.
          if (w2Idx <= lastIndex) {
            String w2Base = pair.getW2().getWord().getRealBase();
            String w2Label = ((WordLabelUnit) y.getLabel(w2Idx)).getLabel();
            wordPairFv.add(candidateLabel.concat(pairLabel).concat(w2Label), 1.);
            wordPairFv.add(
                candidateLabel.concat(wBase).concat(pairLabel).concat(w2Label).concat(w2Base), 1.);
          }
        }
      } else if (pair.getW2().getId() == word.getId()) {
        // `word` is the second argument: features are ordered relation -> candidate.
        String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
        wordPairFv.add(pairLabel.concat(candidateLabel), 1.);
        wordPairFv.add(pairLabel.concat(wBase).concat(candidateLabel), 1.);
        if (params.getUseFullFeatures()) {
          // w-rel-w
          int w1Idx = ((JointInstance) instance).getWord(pair.getW1().getId());
          if (w1Idx <= lastIndex) {
            String w1Base = pair.getW1().getWord().getRealBase();
            String w1Label = ((WordLabelUnit) y.getLabel(w1Idx)).getLabel();
            wordPairFv.add(w1Label.concat(pairLabel).concat(candidateLabel), 1.);
            wordPairFv.add(
                w1Base.concat(w1Label).concat(pairLabel).concat(candidateLabel).concat(wBase), 1.);
          }
        }
      }
    } else {
      // Word entries: record neighbouring words and emit label bigrams with the +-1 neighbours.
      assert y.getLabel(i) instanceof WordLabelUnit;
      if (!params.getUseGlobalEntityFeatures()) continue;
      Word w2 = (Word) instance.getSequence().get(i);
      if (w2.getId() == word.getId() - 2) {
        contextWords[0] = i;
      }
      if (w2.getId() == word.getId() - 1) {
        String wordLabel = ((WordLabelUnit) y.getLabel(i)).getLabel();
        contextWords[1] = i;
        addBigramToFV(
            wordFv,
            wordLabel,
            w2.getWord().getRealBase(),
            w2.getWord().getPOS(),
            candidateLabel,
            wBase,
            wPOS);
      } else if (word.getId() == w2.getId() - 1) {
        String wordLabel = ((WordLabelUnit) y.getLabel(i)).getLabel();
        contextWords[2] = i;
        addBigramToFV(
            wordFv,
            candidateLabel,
            wBase,
            wPOS,
            wordLabel,
            w2.getWord().getRealBase(),
            w2.getWord().getPOS());
      } else if (word.getId() == w2.getId() - 2) {
        contextWords[3] = i;
      }
    }
  }
  // Entity-level features, anchored on the last word of the reconstructed entity span.
  List<Word> entityWords = getEntityWords(word, instance, y, lastIndex, candidateLabelUnit);
  if (entityWords.size() > 0) {
    if (params.getUseFullFeatures()) {
      String entString = entityString(entityWords);
      String candidateLabelType = candidateLabelUnit.getType();
      wordFv.add(candidateLabelType.concat(entString), 1.);
      Word lastWord = entityWords.get(entityWords.size() - 1);
      for (int i = 0; i <= lastIndex; i++) {
        if (y.getLabel(i) instanceof PairLabelUnit) {
          if (y.getLabel(i).isNegative()) continue;
          Pair pair = (Pair) instance.getSequence().get(i);
          if (pair.getW1().getId() == lastWord.getId()) {
            String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
            wordPairFv.add(candidateLabelType.concat(pairLabel), 1.);
            wordPairFv.add(candidateLabelType.concat(entString).concat(pairLabel), 1.);
            // ent-rel-ent
            int w2Idx = ((JointInstance) instance).getWord(pair.getW2().getId());
            if (w2Idx <= lastIndex) {
              // Reconstruct the other argument's entity from its already-assigned label.
              List<Word> e2Words =
                  getEntityWords(
                      pair.getW2(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w2Idx));
              if (e2Words.size() > 0) {
                String e2String = entityString(e2Words);
                String w2LabelType = ((WordLabelUnit) y.getLabel(w2Idx)).getType();
                wordPairFv.add(candidateLabelType.concat(pairLabel).concat(w2LabelType), 1.);
                wordPairFv.add(
                    candidateLabelType
                        .concat(entString)
                        .concat(pairLabel)
                        .concat(w2LabelType)
                        .concat(e2String),
                    1.);
              }
            }
          } else if (pair.getW2().getId() == lastWord.getId()) {
            String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
            wordPairFv.add(pairLabel.concat(candidateLabelType), 1.);
            wordPairFv.add(pairLabel.concat(candidateLabelType).concat(entString), 1.);
            // ent-rel-ent
            int w1Idx = ((JointInstance) instance).getWord(pair.getW1().getId());
            if (w1Idx <= lastIndex) {
              List<Word> e1Words =
                  getEntityWords(
                      pair.getW1(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w1Idx));
              if (e1Words.size() > 0) {
                String e1String = entityString(e1Words);
                String w1LabelType = ((WordLabelUnit) y.getLabel(w1Idx)).getType();
                wordPairFv.add(w1LabelType.concat(pairLabel).concat(candidateLabelType), 1.);
                wordPairFv.add(
                    w1LabelType
                        .concat(e1String)
                        .concat(pairLabel)
                        .concat(candidateLabelType)
                        .concat(entString),
                    1.);
              }
            }
          }
        }
      }
    }
  }
  // Label trigram (-2, -1, 0).
  if (contextWords[0] >= 0 && contextWords[1] >= 0) {
    Word w0 = (Word) instance.getSequence().get(contextWords[0]);
    String l0 = ((WordLabelUnit) y.getLabel(contextWords[0])).getLabel();
    Word w1 = (Word) instance.getSequence().get(contextWords[1]);
    String l1 = ((WordLabelUnit) y.getLabel(contextWords[1])).getLabel();
    addTrigramToFV(
        wordFv,
        l0,
        w0.getWord().getRealBase(),
        w0.getWord().getPOS(),
        l1,
        w1.getWord().getRealBase(),
        w1.getWord().getPOS(),
        candidateLabel,
        wBase,
        wPOS);
  }
  // Label trigram (-1, 0, +1).
  if (contextWords[1] >= 0 && contextWords[2] >= 0) {
    Word w1 = (Word) instance.getSequence().get(contextWords[1]);
    String l1 = ((WordLabelUnit) y.getLabel(contextWords[1])).getLabel();
    Word w2 = (Word) instance.getSequence().get(contextWords[2]);
    String l2 = ((WordLabelUnit) y.getLabel(contextWords[2])).getLabel();
    addTrigramToFV(
        wordFv,
        l1,
        w1.getWord().getRealBase(),
        w1.getWord().getPOS(),
        candidateLabel,
        wBase,
        wPOS,
        l2,
        w2.getWord().getRealBase(),
        w2.getWord().getPOS());
  }
  // Label trigram (0, +1, +2).
  if (contextWords[2] >= 0 && contextWords[3] >= 0) {
    Word w2 = (Word) instance.getSequence().get(contextWords[2]);
    String l2 = ((WordLabelUnit) y.getLabel(contextWords[2])).getLabel();
    Word w3 = (Word) instance.getSequence().get(contextWords[3]);
    String l3 = ((WordLabelUnit) y.getLabel(contextWords[3])).getLabel();
    addTrigramToFV(
        wordFv,
        candidateLabel,
        wBase,
        wPOS,
        l2,
        w2.getWord().getRealBase(),
        w2.getWord().getPOS(),
        l3,
        w3.getWord().getRealBase(),
        w3.getWord().getPOS());
  }
  globalFv.add(wordFv, "WORD");
  // Relation-related features are scaled by a configurable weight before being merged.
  wordPairFv.mult(params.getRelWeight());
  globalFv.add(wordPairFv, "WORDPAIR");
  return globalFv;
}
/**
 * Collects the words of the entity span that {@code word} would belong to if it were given
 * {@code candidateLabelUnit}, using the BILOU positions already assigned in {@code y}.
 *
 * <p>The span is built from the candidate's position:
 *
 * <ul>
 *   <li>{@code U}: the word alone;
 *   <li>{@code L}: walk left over consecutive {@code I} words until a {@code B} word;
 *   <li>{@code B}: walk right over consecutive {@code I} words until an {@code L} word;
 *   <li>{@code I}: both walks; the left one must succeed before the right one is tried.
 * </ul>
 *
 * <p>If a walk reaches a word with any other position, an unavailable label, or the sentence
 * edge before finding its boundary, the span is incomplete and an empty list is returned.
 *
 * @param word the word being labeled
 * @param instance the instance; must be a {@code JointInstance}
 * @param y the label assignment consulted for neighbouring words
 * @param lastIndex neighbours whose sequence index is {@code >= lastIndex} are treated as
 *     unlabeled
 * @param candidateLabelUnit the label hypothesized for {@code word}
 * @return the entity words in sentence order; empty when the candidate is negative, has an
 *     unrecognized position, or its span cannot be completed
 */
protected List<Word> getEntityWords(
    Word word, Instance instance, Label y, int lastIndex, WordLabelUnit candidateLabelUnit) {
  List<Word> words = new Vector<Word>();
  if (candidateLabelUnit.isNegative()) {
    return words;
  }
  String position = candidateLabelUnit.getPosition();
  boolean extendsLeft = position.equals("L") || position.equals("I");
  boolean extendsRight = position.equals("B") || position.equals("I");
  if (!extendsLeft && !extendsRight && !position.equals("U")) {
    return words;
  }
  JointInstance jointInstance = (JointInstance) instance;
  int targetId = word.getId();
  if (extendsLeft) {
    List<Word> left = walkToEntityBoundary(jointInstance, y, lastIndex, targetId, -1, "B");
    if (left.isEmpty()) {
      return words;
    }
    // The walk yields words nearest-first; prepending restores sentence order.
    for (Word w : left) {
      words.add(0, w);
    }
  }
  words.add(word);
  if (extendsRight) {
    List<Word> right = walkToEntityBoundary(jointInstance, y, lastIndex, targetId, 1, "L");
    if (right.isEmpty()) {
      words.clear();
      return words;
    }
    words.addAll(right);
  }
  return words;
}

/**
 * Walks from word id {@code targetId} in direction {@code step} (-1 = left, +1 = right) over
 * consecutive {@code I} words and stops at the first word positioned {@code boundary}.
 *
 * @return the visited words nearest-first, ending with the boundary word; empty if a word with
 *     another position, a word at sequence index {@code >= lastIndex}, or the sentence edge is
 *     reached first
 */
private List<Word> walkToEntityBoundary(
    JointInstance instance, Label y, int lastIndex, int targetId, int step, String boundary) {
  List<Word> span = new Vector<Word>();
  int nwords = instance.getNumWords();
  for (int i = targetId + step; i >= 0 && i < nwords; i += step) {
    int seqIndex = instance.getWord(i);
    // NOTE(review): calcWordGlobalFeatures treats indices <= lastIndex as labeled, whereas this
    // cutoff excludes lastIndex itself; kept as in the original — confirm against the caller.
    if (seqIndex >= lastIndex) {
      break;
    }
    String pos = ((WordLabelUnit) y.getLabel(seqIndex)).getPosition();
    if (!pos.equals(boundary) && !pos.equals("I")) {
      // Any other position means the candidate span is not closed on this side.
      break;
    }
    span.add((Word) instance.getSequence().get(seqIndex));
    if (pos.equals(boundary)) {
      return span;
    }
  }
  span.clear();
  return span;
}
/**
 * Returns the labels this word may take, given the partial assignment {@code label}. The result
 * is pruned so that it stays consistent with the BILOU positions of neighbouring words that
 * are already labeled.
 *
 * <p>The starting set is the non-empty result of {@code getRelWordLabels(label, instance)} if
 * there is one; otherwise it is a fresh {@code TreeSet} copy of
 * {@code params.getPossibleLabels().get(getType())}. It is then filtered in place by:
 *
 * <ul>
 *   <li>the gold position and, if non-empty, the gold type, when
 *       {@code params.getUseGoldEntitySpan()} is set;
 *   <li>the previous word's relations and labels, and the word two to the left;
 *   <li>the next word's relations and labels, and the word two to the right;
 *   <li>sentence edges: the first word drops I/L candidates and the last word drops B/I
 *       candidates.
 * </ul>
 *
 * <p>A sequence index counts as labeled when it is {@code < label.size()}.
 *
 * <p>NOTE(review): when {@code getRelWordLabels} returns a non-empty collection, that very
 * collection is pruned via {@code Iterator.remove()}. Confirm it is a fresh, modifiable
 * collection and not a cached one.
 *
 * @param label the partial label assignment
 * @param instance the instance containing this word
 * @return the remaining candidate labels. If empty and verbosity is above 0, diagnostics are
 *     printed; with assertions enabled, an empty result fails the final assert.
 */
public Collection<LabelUnit> getPossibleLabels(Label label, JointInstance instance) {
  int size = label.size();
  Collection<LabelUnit> labels = getRelWordLabels(label, instance);
  boolean hasRelation = labels.size() > 0;
  if (!hasRelation) {
    labels = new TreeSet<LabelUnit>(params.getPossibleLabels().get(getType()));
  }
  if (params.getUseGoldEntitySpan()) {
    // Keep only candidates whose BILOU position matches the gold label of this word.
    String goldPosition =
        ((WordLabelUnit) instance.getGoldLabel().getLabel(instance.getWord(id))).getPosition();
    for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
      WordLabelUnit unit = (WordLabelUnit) unitIt.next();
      if (!unit.getPosition().equals(goldPosition)) {
        unitIt.remove();
      }
    }
    // An empty gold type leaves the type unconstrained.
    String goldType = instance.getGoldType(label, id);
    if (!goldType.equals("")) {
      for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
        WordLabelUnit unit = (WordLabelUnit) unitIt.next();
        if (!unit.getType().equals(goldType)) {
          unitIt.remove();
        }
      }
    }
  }
  if (id > 0) {
    int prevWordIndex = instance.getWord(id - 1);
    // If the previous word has relation-derived labels, this word cannot be I/L.
    // Presumably relations attach to entity-final words — confirm.
    Collection<LabelUnit> relLabelUnits =
        ((Word) instance.getSequence().get(prevWordIndex)).getRelWordLabels(label, instance);
    if (relLabelUnits.size() != 0) {
      for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
        WordLabelUnit unit = (WordLabelUnit) unitIt.next();
        if (unit.getPosition().equals(WordLabelUnit.I)
            || unit.getPosition().equals(WordLabelUnit.L)) {
          unitIt.remove();
        }
      }
      // B-*, O, U-*
    }
    if (prevWordIndex < size) {
      // Previous word is labeled: an open entity (B/I) must be continued with the same type.
      assert label.getLabel(prevWordIndex) instanceof WordLabelUnit;
      String position = ((WordLabelUnit) label.getLabel(prevWordIndex)).getPosition();
      String type = ((WordLabelUnit) label.getLabel(prevWordIndex)).getType();
      if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
        for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
          WordLabelUnit unit = (WordLabelUnit) unitIt.next();
          if (!unit.getType().equals(type)
              || unit.getPosition().equals(WordLabelUnit.B)
              || unit.getPosition().equals(WordLabelUnit.U)
              || unit.getPosition().equals(WordLabelUnit.O)) {
            unitIt.remove();
          }
        }
        // I-type, L-type
      } else {
        // No open entity on the left: this word cannot be I/L.
        for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
          WordLabelUnit unit = (WordLabelUnit) unitIt.next();
          if (unit.getPosition().equals(WordLabelUnit.I)
              || unit.getPosition().equals(WordLabelUnit.L)) {
            unitIt.remove();
          }
        }
        // B-*, O, U-*
      }
    }
    if (id > 1) {
      int pprevWordIndex = instance.getWord(id - 2);
      if (pprevWordIndex < size) {
        // Word two to the left opens/continues an entity: I/L here must share its type.
        String position = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getPosition();
        String type = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getType();
        if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
          for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
            WordLabelUnit unit = (WordLabelUnit) unitIt.next();
            if ((unit.getPosition().equals(WordLabelUnit.L)
                    || unit.getPosition().equals(WordLabelUnit.I))
                && !unit.getType().equals(type)) {
              unitIt.remove();
            }
          }
          // B-*, O, I-type, U-*, L-type
        }
      }
    }
  } else if (id == 0) {
    // First word of the sentence cannot continue an entity.
    for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
      WordLabelUnit unit = (WordLabelUnit) unitIt.next();
      if (unit.getPosition().equals(WordLabelUnit.I)
          || unit.getPosition().equals(WordLabelUnit.L)) {
        unitIt.remove();
      }
    }
    // B-*, O, U-*
  }
  if (id < instance.getNumWords() - 1) {
    int nextWordIndex = instance.getWord(id + 1);
    // If the next word has relation-derived labels, B/I here must use a type they allow.
    Collection<LabelUnit> relLabelUnits =
        ((Word) instance.getSequence().get(nextWordIndex)).getRelWordLabels(label, instance);
    if (relLabelUnits.size() != 0) {
      Set<String> possibleTypes = new TreeSet<String>();
      for (LabelUnit relLabelUnit : relLabelUnits) {
        possibleTypes.add(((WordLabelUnit) relLabelUnit).getType());
      }
      for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
        WordLabelUnit unit = (WordLabelUnit) unitIt.next();
        if ((unit.getPosition().equals(WordLabelUnit.B)
                || unit.getPosition().equals(WordLabelUnit.I))
            && !possibleTypes.contains(unit.getType())) {
          unitIt.remove();
        }
      }
      // B-TYPE, I-TYPE, L, O, U
    }
    if (nextWordIndex < size) {
      // Next word is labeled: if it continues an entity (I/L), this word must open or
      // continue one of the same type; otherwise this word must not leave an entity open.
      assert label.getLabel(nextWordIndex) instanceof WordLabelUnit;
      String position = ((WordLabelUnit) label.getLabel(nextWordIndex)).getPosition();
      String type = ((WordLabelUnit) label.getLabel(nextWordIndex)).getType();
      if (position.equals(WordLabelUnit.I) || position.equals(WordLabelUnit.L)) {
        for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
          WordLabelUnit unit = (WordLabelUnit) unitIt.next();
          if (!unit.getType().equals(type)
              || unit.getPosition().equals(WordLabelUnit.L)
              || unit.getPosition().equals(WordLabelUnit.U)
              || unit.getPosition().equals(WordLabelUnit.O)) {
            unitIt.remove();
          }
        }
        // B-type, I-type
      } else {
        for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
          WordLabelUnit unit = (WordLabelUnit) unitIt.next();
          if (unit.getPosition().equals(WordLabelUnit.B)
              || unit.getPosition().equals(WordLabelUnit.I)) {
            unitIt.remove();
          }
        }
        // L-*, O, U-*
      }
    }
    if (id < instance.getNumWords() - 2) {
      int nNextWordIndex = instance.getWord(id + 2);
      if (nNextWordIndex < size) {
        // Word two to the right continues an entity: B/I here must share its type.
        String position = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getPosition();
        String type = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getType();
        if (position.equals(WordLabelUnit.L) || position.equals(WordLabelUnit.I)) {
          // B-type, O, I-type, U-*, L-*
          for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
            WordLabelUnit unit = (WordLabelUnit) unitIt.next();
            if ((unit.getPosition().equals(WordLabelUnit.B)
                    || unit.getPosition().equals(WordLabelUnit.I))
                && !unit.getType().equals(type)) {
              unitIt.remove();
            }
          }
        }
      }
    }
  } else if (id == instance.getNumWords() - 1) {
    // Last word of the sentence cannot leave an entity open.
    // L-*, O, U-*
    for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
      WordLabelUnit unit = (WordLabelUnit) unitIt.next();
      if (unit.getPosition().equals(WordLabelUnit.B)
          || unit.getPosition().equals(WordLabelUnit.I)) {
        unitIt.remove();
      }
    }
  }
  // Diagnostics for an over-constrained word: print the current assignment ("?" = unlabeled).
  if (labels.size() == 0 && params.getVerbosity() > 0) {
    System.out.println("rel? " + hasRelation + " , id: " + id);
    System.out.println(getRelWordLabels(label, instance));
    for (int i = 0; i < instance.getNumWords(); i++) {
      int index = instance.getWord(i);
      if (index >= size) {
        System.out.print(" " + i + ":" + index + ":?");
      } else {
        System.out.print(
            " " + i + ":" + index + ":" + ((WordLabelUnit) label.getLabel(index)).getLabel());
      }
    }
    System.out.println();
  }
  assert labels.size() > 0 : id;
  return labels;
}
/**
 * Returns the negative-class label unit for words.
 *
 * @return the unit supplied by {@code WordLabelUnit.getNegativeClassLabelUnit()}
 */
@Override
public LabelUnit getNegativeClassLabel() {
  final LabelUnit negativeUnit = WordLabelUnit.getNegativeClassLabelUnit();
  return negativeUnit;
}