예제 #1
0
 /**
  * Compares this word label unit with another object.
  *
  * <p>Two negative units are always considered equal. Otherwise the units are equal when
  * their {@code getLabel()} strings are equal.
  *
  * @param unit the object to compare with; {@code null} or a non-{@code WordLabelUnit} yields
  *     {@code false}
  * @return whether {@code unit} denotes the same word label
  */
 @Override
 public boolean equals(Object unit) {
   if (unit instanceof WordLabelUnit) {
     WordLabelUnit other = (WordLabelUnit) unit;
     // NOTE(review): because all negative units compare equal here, hashCode() must give every
     // negative unit the same hash for the equals/hashCode contract to hold — confirm.
     return (isNegative() && other.isNegative()) || getLabel().equals(other.getLabel());
   }
   return false;
 }
예제 #2
0
  /**
   * Builds the global feature vector for assigning {@code candidateLabelUnit} to {@code word},
   * reading the labels already present in {@code y} at sequence indices {@code 0..lastIndex}.
   *
   * <p>Two string-feature groups are collected:
   *
   * <ul>
   *   <li>"WORDPAIR": features for each non-negative {@code PairLabelUnit} whose pair has
   *       {@code word} as its first or second argument. Each one combines the candidate label
   *       with the pair label. These are skipped entirely when the candidate is negative. When
   *       full features are enabled, this group also includes:
   *       <ul>
   *         <li>the other argument's label and base form, when that argument's index is
   *             {@code <= lastIndex};
   *         <li>entity-level (type + entity string) features anchored on the last word of the
   *             candidate entity.
   *       </ul>
   *   <li>"WORD": features from the surrounding words and the candidate entity:
   *       <ul>
   *         <li>label bigrams with the words at id -1 / +1, when global entity features are
   *             enabled;
   *         <li>label trigrams over ids (-2,-1,0), (-1,0,+1), (0,+1,+2), also when global
   *             entity features are enabled;
   *         <li>the candidate type concatenated with the entity string, when full features
   *             are enabled.
   *       </ul>
   * </ul>
   *
   * <p>The WORDPAIR vector is multiplied (via {@code mult}) by {@code params.getRelWeight()}
   * before it is added to the result.
   *
   * @param word the word being labeled
   * @param instance the instance; cast to {@code JointInstance} to look up word indices
   * @param y the partial label sequence
   * @param lastIndex largest sequence index (inclusive) whose label is read from {@code y}
   * @param candidateLabelUnit the candidate label for {@code word}
   * @return a new vector holding the "WORD" and "WORDPAIR" feature groups
   */
  protected SparseFeatureVector calcWordGlobalFeatures(
      Word word, Instance instance, Label y, int lastIndex, WordLabelUnit candidateLabelUnit) {
    SparseFeatureVector globalFv = new SparseFeatureVector(params);
    StringSparseVector wordPairFv = new StringSparseVector(params);
    StringSparseVector wordFv = new StringSparseVector(params);

    String candidateLabel = candidateLabelUnit.getLabel();
    String wBase = word.getWord().getRealBase();
    String wPOS = word.getWord().getPOS();

    // Sequence indices of the words at relative ids -2, -1, +1, +2 (-1 = not seen / not used).
    int[] contextWords = new int[4];
    Arrays.fill(contextWords, -1);

    // Pass over every labeled position: pair-label features and context-word collection.
    for (int i = 0; i <= lastIndex; i++) {
      if (y.getLabel(i) instanceof PairLabelUnit) {
        if (y.getLabel(i).isNegative()) continue;
        // NOTE(review): this check does not depend on i and could be hoisted out of the loop.
        if (candidateLabelUnit.isNegative()) {
          continue;
        }
        Pair pair = (Pair) instance.getSequence().get(i);
        if (pair.getW1().getId() == word.getId()) {
          // word is the first argument of the pair.
          String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
          wordPairFv.add(candidateLabel.concat(pairLabel), 1.);
          wordPairFv.add(candidateLabel.concat(wBase).concat(pairLabel), 1.);
          if (params.getUseFullFeatures()) {
            // w-rel-w
            int w2Idx = ((JointInstance) instance).getWord(pair.getW2().getId());
            // Use the other argument's label only if it has already been assigned in y.
            if (w2Idx <= lastIndex) {
              String w2Base = pair.getW2().getWord().getRealBase();
              String w2Label = ((WordLabelUnit) y.getLabel(w2Idx)).getLabel();
              wordPairFv.add(candidateLabel.concat(pairLabel).concat(w2Label), 1.);
              wordPairFv.add(
                  candidateLabel.concat(wBase).concat(pairLabel).concat(w2Label).concat(w2Base),
                  1.);
            }
          }
        } else if (pair.getW2().getId() == word.getId()) {
          // word is the second argument of the pair.
          String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
          wordPairFv.add(pairLabel.concat(candidateLabel), 1.);
          wordPairFv.add(pairLabel.concat(wBase).concat(candidateLabel), 1.);
          if (params.getUseFullFeatures()) {
            // w-rel-w
            int w1Idx = ((JointInstance) instance).getWord(pair.getW1().getId());
            // Use the other argument's label only if it has already been assigned in y.
            if (w1Idx <= lastIndex) {
              String w1Base = pair.getW1().getWord().getRealBase();
              String w1Label = ((WordLabelUnit) y.getLabel(w1Idx)).getLabel();
              wordPairFv.add(w1Label.concat(pairLabel).concat(candidateLabel), 1.);
              wordPairFv.add(
                  w1Base.concat(w1Label).concat(pairLabel).concat(candidateLabel).concat(wBase),
                  1.);
            }
          }
        }
      } else {
        assert y.getLabel(i) instanceof WordLabelUnit;
        // Context n-gram features are only collected when global entity features are enabled.
        if (!params.getUseGlobalEntityFeatures()) continue;
        Word w2 = (Word) instance.getSequence().get(i);
        if (w2.getId() == word.getId() - 2) {
          contextWords[0] = i;
        }
        if (w2.getId() == word.getId() - 1) {
          // Left bigram: (label at id-1, candidate label).
          String wordLabel = ((WordLabelUnit) y.getLabel(i)).getLabel();
          contextWords[1] = i;
          addBigramToFV(
              wordFv,
              wordLabel,
              w2.getWord().getRealBase(),
              w2.getWord().getPOS(),
              candidateLabel,
              wBase,
              wPOS);
        } else if (word.getId() == w2.getId() - 1) {
          // Right bigram: (candidate label, label at id+1).
          String wordLabel = ((WordLabelUnit) y.getLabel(i)).getLabel();
          contextWords[2] = i;
          addBigramToFV(
              wordFv,
              candidateLabel,
              wBase,
              wPOS,
              wordLabel,
              w2.getWord().getRealBase(),
              w2.getWord().getPOS());
        } else if (word.getId() == w2.getId() - 2) {
          contextWords[3] = i;
        }
      }
    }

    // Entity-level features; getEntityWords returns an empty list for negative candidates.
    List<Word> entityWords = getEntityWords(word, instance, y, lastIndex, candidateLabelUnit);
    if (entityWords.size() > 0) {
      if (params.getUseFullFeatures()) {
        String entString = entityString(entityWords);
        String candidateLabelType = candidateLabelUnit.getType();
        wordFv.add(candidateLabelType.concat(entString), 1.);
        // Entity-pair features are anchored on the entity's last word.
        Word lastWord = entityWords.get(entityWords.size() - 1);
        for (int i = 0; i <= lastIndex; i++) {
          if (y.getLabel(i) instanceof PairLabelUnit) {
            if (y.getLabel(i).isNegative()) continue;
            Pair pair = (Pair) instance.getSequence().get(i);
            if (pair.getW1().getId() == lastWord.getId()) {
              String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
              wordPairFv.add(candidateLabelType.concat(pairLabel), 1.);
              wordPairFv.add(candidateLabelType.concat(entString).concat(pairLabel), 1.);
              // ent-rel-ent
              int w2Idx = ((JointInstance) instance).getWord(pair.getW2().getId());
              if (w2Idx <= lastIndex) {
                List<Word> e2Words =
                    getEntityWords(
                        pair.getW2(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w2Idx));
                if (e2Words.size() > 0) {
                  String e2String = entityString(e2Words);
                  String w2LabelType = ((WordLabelUnit) y.getLabel(w2Idx)).getType();
                  wordPairFv.add(candidateLabelType.concat(pairLabel).concat(w2LabelType), 1.);
                  wordPairFv.add(
                      candidateLabelType
                          .concat(entString)
                          .concat(pairLabel)
                          .concat(w2LabelType)
                          .concat(e2String),
                      1.);
                }
              }
            } else if (pair.getW2().getId() == lastWord.getId()) {
              String pairLabel = ((PairLabelUnit) y.getLabel(i)).getLabel();
              wordPairFv.add(pairLabel.concat(candidateLabelType), 1.);
              wordPairFv.add(pairLabel.concat(candidateLabelType).concat(entString), 1.);
              // ent-rel-ent
              int w1Idx = ((JointInstance) instance).getWord(pair.getW1().getId());
              if (w1Idx <= lastIndex) {
                List<Word> e1Words =
                    getEntityWords(
                        pair.getW1(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w1Idx));
                if (e1Words.size() > 0) {
                  String e1String = entityString(e1Words);
                  String w1LabelType = ((WordLabelUnit) y.getLabel(w1Idx)).getType();
                  wordPairFv.add(w1LabelType.concat(pairLabel).concat(candidateLabelType), 1.);
                  wordPairFv.add(
                      w1LabelType
                          .concat(e1String)
                          .concat(pairLabel)
                          .concat(candidateLabelType)
                          .concat(entString),
                      1.);
                }
              }
            }
          }
        }
      }
    }
    // Label trigrams over relative positions (-2,-1,0), (-1,0,+1) and (0,+1,+2), emitted only
    // when both context positions were seen in the loop above.
    if (contextWords[0] >= 0 && contextWords[1] >= 0) {
      Word w0 = (Word) instance.getSequence().get(contextWords[0]);
      String l0 = ((WordLabelUnit) y.getLabel(contextWords[0])).getLabel();
      Word w1 = (Word) instance.getSequence().get(contextWords[1]);
      String l1 = ((WordLabelUnit) y.getLabel(contextWords[1])).getLabel();
      addTrigramToFV(
          wordFv,
          l0,
          w0.getWord().getRealBase(),
          w0.getWord().getPOS(),
          l1,
          w1.getWord().getRealBase(),
          w1.getWord().getPOS(),
          candidateLabel,
          wBase,
          wPOS);
    }
    if (contextWords[1] >= 0 && contextWords[2] >= 0) {
      Word w1 = (Word) instance.getSequence().get(contextWords[1]);
      String l1 = ((WordLabelUnit) y.getLabel(contextWords[1])).getLabel();
      Word w2 = (Word) instance.getSequence().get(contextWords[2]);
      String l2 = ((WordLabelUnit) y.getLabel(contextWords[2])).getLabel();
      addTrigramToFV(
          wordFv,
          l1,
          w1.getWord().getRealBase(),
          w1.getWord().getPOS(),
          candidateLabel,
          wBase,
          wPOS,
          l2,
          w2.getWord().getRealBase(),
          w2.getWord().getPOS());
    }
    if (contextWords[2] >= 0 && contextWords[3] >= 0) {
      Word w2 = (Word) instance.getSequence().get(contextWords[2]);
      String l2 = ((WordLabelUnit) y.getLabel(contextWords[2])).getLabel();
      Word w3 = (Word) instance.getSequence().get(contextWords[3]);
      String l3 = ((WordLabelUnit) y.getLabel(contextWords[3])).getLabel();
      addTrigramToFV(
          wordFv,
          candidateLabel,
          wBase,
          wPOS,
          l2,
          w2.getWord().getRealBase(),
          w2.getWord().getPOS(),
          l3,
          w3.getWord().getRealBase(),
          w3.getWord().getPOS());
    }
    globalFv.add(wordFv, "WORD");
    // Relation (pair) features are weighted by the configured relation weight.
    wordPairFv.mult(params.getRelWeight());
    globalFv.add(wordPairFv, "WORDPAIR");
    return globalFv;
  }
예제 #3
0
 /**
  * Collects the words of the entity that {@code word} would belong to if it received
  * {@code candidateLabelUnit}. The collection uses the B/I/L/U positions already assigned in
  * {@code y}.
  *
  * <p>A U candidate is a single-word entity. The other positions extend the entity outward:
  *
  * <ul>
  *   <li>L and I candidates walk left over I-positioned words until they reach a B.
  *   <li>B and I candidates walk right over I-positioned words until they reach an L.
  * </ul>
  *
  * <p>A walk fails, and an empty list is returned, when it meets any other position, runs off
  * the sentence, or reaches a word that is not labelled yet.
  *
  * <p>NOTE(review): words whose sequence index is {@code >= lastIndex} are treated as not yet
  * labelled, while callers such as calcWordGlobalFeatures read {@code y} up to and including
  * {@code lastIndex}. Confirm whether the bound here should be {@code > lastIndex}.
  *
  * @param word the word receiving the candidate label
  * @param instance the instance containing {@code word}; must be a {@code JointInstance}
  *     unless the candidate is negative or U-positioned
  * @param y the partial label sequence
  * @param lastIndex words whose sequence index is {@code >= lastIndex} are treated as unlabelled
  * @param candidateLabelUnit the candidate label for {@code word}
  * @return the entity's words in sentence order. The list is empty when the candidate is
  *     negative, when its position is not recognised, or when the entity is not complete in
  *     {@code y}.
  */
 protected List<Word> getEntityWords(
     Word word, Instance instance, Label y, int lastIndex, WordLabelUnit candidateLabelUnit) {
   List<Word> words = new Vector<Word>();
   if (candidateLabelUnit.isNegative()) {
     return words;
   }
   String position = candidateLabelUnit.getPosition();
   if (position.equals("U")) {
     words.add(word);
     return words;
   }
   boolean extendsLeft = position.equals("L") || position.equals("I");
   boolean extendsRight = position.equals("B") || position.equals("I");
   if (!extendsLeft && !extendsRight) {
     return words; // unrecognised position
   }
   JointInstance jointInstance = (JointInstance) instance;
   int targetId = word.getId();
   words.add(word);
   if (extendsLeft && !prependEntityWordsLeft(words, targetId, jointInstance, y, lastIndex)) {
     words.clear();
   } else if (extendsRight
       && !appendEntityWordsRight(words, targetId, jointInstance, y, lastIndex)) {
     words.clear();
   }
   return words;
 }

 /**
  * Walks left from word id {@code targetId}. Each I-positioned word is prepended to
  * {@code words}. The walk stops at the first B-positioned word, which is prepended as well.
  *
  * @return {@code true} if a B word was reached; {@code false} if the walk hit another
  *     position, the start of the sentence, or a word that is not labelled yet
  */
 private static boolean prependEntityWordsLeft(
     List<Word> words, int targetId, JointInstance instance, Label y, int lastIndex) {
   for (int wordId = targetId - 1; wordId >= 0; --wordId) {
     int seqIndex = instance.getWord(wordId);
     if (seqIndex >= lastIndex) {
       return false; // neighbour not labelled yet
     }
     String pos = ((WordLabelUnit) y.getLabel(seqIndex)).getPosition();
     if (!pos.equals("B") && !pos.equals("I")) {
       return false; // any other position ends the entity without a B
     }
     words.add(0, (Word) instance.getSequence().get(seqIndex));
     if (pos.equals("B")) {
       return true;
     }
   }
   return false;
 }

 /**
  * Walks right from word id {@code targetId}. Each I-positioned word is appended to
  * {@code words}. The walk stops at the first L-positioned word, which is appended as well.
  *
  * @return {@code true} if an L word was reached; {@code false} if the walk hit another
  *     position, the end of the sentence, or a word that is not labelled yet
  */
 private static boolean appendEntityWordsRight(
     List<Word> words, int targetId, JointInstance instance, Label y, int lastIndex) {
   int numWords = instance.getNumWords();
   for (int wordId = targetId + 1; wordId < numWords; ++wordId) {
     int seqIndex = instance.getWord(wordId);
     if (seqIndex >= lastIndex) {
       return false; // neighbour not labelled yet
     }
     String pos = ((WordLabelUnit) y.getLabel(seqIndex)).getPosition();
     if (!pos.equals("L") && !pos.equals("I")) {
       return false; // any other position ends the entity without an L
     }
     words.add((Word) instance.getSequence().get(seqIndex));
     if (pos.equals("L")) {
       return true;
     }
   }
   return false;
 }
예제 #4
0
파일: Word.java 프로젝트: tticoin/JointER
 /**
  * Returns the candidate word labels for this word (identified by the field {@code id}),
  * pruned by its position in the sentence and by labels already assigned to nearby words.
  *
  * <p>The starting set is the result of {@code getRelWordLabels(label, instance)} when that is
  * non-empty. Otherwise it is a new {@code TreeSet} copy of
  * {@code params.getPossibleLabels().get(getType())}. Units are then removed with
  * {@code Iterator.remove} according to:
  *
  * <ul>
  *   <li>the gold position and, if non-empty, the gold type, when
  *       {@code params.getUseGoldEntitySpan()} is set;
  *   <li>the previous/next word: whether it has relation labels, and its assigned position and
  *       type (only when its sequence index is below {@code label.size()});
  *   <li>the words two positions away, which restrict the type of I/L candidates (left side)
  *       or B/I candidates (right side);
  *   <li>sentence boundaries: I/L are removed for the first word, B/I for the last word.
  * </ul>
  *
  * <p>If nothing survives and {@code params.getVerbosity() > 0}, a diagnostic line per word is
  * printed to standard output.
  *
  * @param label the partial label sequence; sequence indices {@code >= label.size()} are
  *     treated as unassigned
  * @param instance the instance containing this word
  * @return the remaining candidate labels (asserted non-empty when assertions are enabled)
  */
 public Collection<LabelUnit> getPossibleLabels(Label label, JointInstance instance) {
   int size = label.size();
   // NOTE(review): when hasRelation is true, the collection returned by getRelWordLabels is
   // pruned in place below; confirm it is a fresh copy rather than shared state.
   Collection<LabelUnit> labels = getRelWordLabels(label, instance);
   boolean hasRelation = labels.size() > 0;
   if (!hasRelation) {
     labels = new TreeSet<LabelUnit>(params.getPossibleLabels().get(getType()));
   }
   // Optionally restrict candidates to the gold position and (when known) the gold type.
   if (params.getUseGoldEntitySpan()) {
     String goldPosition =
         ((WordLabelUnit) instance.getGoldLabel().getLabel(instance.getWord(id))).getPosition();
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (!unit.getPosition().equals(goldPosition)) {
         unitIt.remove();
       }
     }
     String goldType = instance.getGoldType(label, id);
     if (!goldType.equals("")) {
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if (!unit.getType().equals(goldType)) {
           unitIt.remove();
         }
       }
     }
   }
   // Constraints from the words to the left.
   if (id > 0) {
     int prevWordIndex = instance.getWord(id - 1);
     Collection<LabelUnit> relLabelUnits =
         ((Word) instance.getSequence().get(prevWordIndex)).getRelWordLabels(label, instance);
     // Previous word has relation labels: drop I/L candidates.
     if (relLabelUnits.size() != 0) {
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if (unit.getPosition().equals(WordLabelUnit.I)
             || unit.getPosition().equals(WordLabelUnit.L)) {
           unitIt.remove();
         }
       }
       // B-*, O, U-*
     }
     // Previous word already labelled: after B/I only same-type I/L may follow; otherwise I/L
     // candidates are dropped.
     if (prevWordIndex < size) {
       assert label.getLabel(prevWordIndex) instanceof WordLabelUnit;
       String position = ((WordLabelUnit) label.getLabel(prevWordIndex)).getPosition();
       String type = ((WordLabelUnit) label.getLabel(prevWordIndex)).getType();
       if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (!unit.getType().equals(type)
               || unit.getPosition().equals(WordLabelUnit.B)
               || unit.getPosition().equals(WordLabelUnit.U)
               || unit.getPosition().equals(WordLabelUnit.O)) {
             unitIt.remove();
           }
         }
         // I-type, L-type
       } else {
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (unit.getPosition().equals(WordLabelUnit.I)
               || unit.getPosition().equals(WordLabelUnit.L)) {
             unitIt.remove();
           }
         }
         // B-*, O, U-*
       }
     }
     // Word two to the left labelled B/I: I/L candidates must share its type.
     if (id > 1) {
       int pprevWordIndex = instance.getWord(id - 2);
       if (pprevWordIndex < size) {
         String position = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getPosition();
         String type = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getType();
         if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
           for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
             WordLabelUnit unit = (WordLabelUnit) unitIt.next();
             if ((unit.getPosition().equals(WordLabelUnit.L)
                     || unit.getPosition().equals(WordLabelUnit.I))
                 && !unit.getType().equals(type)) {
               unitIt.remove();
             }
           }
           // B-*, O, I-type, U-*, L-type
         }
       }
     }
   } else if (id == 0) {
     // First word of the sentence: drop I/L candidates.
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (unit.getPosition().equals(WordLabelUnit.I)
           || unit.getPosition().equals(WordLabelUnit.L)) {
         unitIt.remove();
       }
     }
     // B-*, O, U-*
   }
   // Constraints from the words to the right.
   if (id < instance.getNumWords() - 1) {
     int nextWordIndex = instance.getWord(id + 1);
     Collection<LabelUnit> relLabelUnits =
         ((Word) instance.getSequence().get(nextWordIndex)).getRelWordLabels(label, instance);
     // Next word has relation labels: B/I candidates must have one of those labels' types.
     if (relLabelUnits.size() != 0) {
       Set<String> possibleTypes = new TreeSet<String>();
       for (LabelUnit relLabelUnit : relLabelUnits) {
         possibleTypes.add(((WordLabelUnit) relLabelUnit).getType());
       }
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if ((unit.getPosition().equals(WordLabelUnit.B)
                 || unit.getPosition().equals(WordLabelUnit.I))
             && !possibleTypes.contains(unit.getType())) {
           unitIt.remove();
         }
       }
       // B-TYPE, I-TYPE, L, O, U
     }
     // Next word already labelled: before I/L only same-type B/I may precede; otherwise B/I
     // candidates are dropped.
     if (nextWordIndex < size) {
       assert label.getLabel(nextWordIndex) instanceof WordLabelUnit;
       String position = ((WordLabelUnit) label.getLabel(nextWordIndex)).getPosition();
       String type = ((WordLabelUnit) label.getLabel(nextWordIndex)).getType();
       if (position.equals(WordLabelUnit.I) || position.equals(WordLabelUnit.L)) {
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (!unit.getType().equals(type)
               || unit.getPosition().equals(WordLabelUnit.L)
               || unit.getPosition().equals(WordLabelUnit.U)
               || unit.getPosition().equals(WordLabelUnit.O)) {
             unitIt.remove();
           }
         }
         // B-type, I-type
       } else {
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (unit.getPosition().equals(WordLabelUnit.B)
               || unit.getPosition().equals(WordLabelUnit.I)) {
             unitIt.remove();
           }
         }
         // L-*, O, U-*
       }
     }
     // Word two to the right labelled I/L: B/I candidates must share its type.
     if (id < instance.getNumWords() - 2) {
       int nNextWordIndex = instance.getWord(id + 2);
       if (nNextWordIndex < size) {
         String position = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getPosition();
         String type = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getType();
         if (position.equals(WordLabelUnit.L) || position.equals(WordLabelUnit.I)) {
           // B-type, O, I-type, U-*, L-*
           for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
             WordLabelUnit unit = (WordLabelUnit) unitIt.next();
             if ((unit.getPosition().equals(WordLabelUnit.B)
                     || unit.getPosition().equals(WordLabelUnit.I))
                 && !unit.getType().equals(type)) {
               unitIt.remove();
             }
           }
         }
       }
     }
   } else if (id == instance.getNumWords() - 1) {
     // Last word of the sentence: drop B/I candidates.
     // L-*, O, U-*
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (unit.getPosition().equals(WordLabelUnit.B)
           || unit.getPosition().equals(WordLabelUnit.I)) {
         unitIt.remove();
       }
     }
   }
   // Diagnostic dump when no candidate survives: for each word prints
   // "id:sequenceIndex:label", or "?" in place of the label when it is unassigned.
   if (labels.size() == 0 && params.getVerbosity() > 0) {
     System.out.println("rel? " + hasRelation + " , id: " + id);
     System.out.println(getRelWordLabels(label, instance));
     for (int i = 0; i < instance.getNumWords(); i++) {
       int index = instance.getWord(i);
       if (index >= size) {
         System.out.print(" " + i + ":" + index + ":?");
       } else {
         System.out.print(
             " " + i + ":" + index + ":" + ((WordLabelUnit) label.getLabel(index)).getLabel());
       }
     }
     System.out.println();
   }
   assert labels.size() > 0 : id;
   return labels;
 }
예제 #5
0
파일: Word.java 프로젝트: tticoin/JointER
 /**
  * Returns the label unit used as the negative class for words.
  *
  * @return the unit produced by {@code WordLabelUnit.getNegativeClassLabelUnit()}
  */
 @Override
 public LabelUnit getNegativeClassLabel() {
   final LabelUnit negativeUnit = WordLabelUnit.getNegativeClassLabelUnit();
   return negativeUnit;
 }