Example #1
0
 /**
  * Collects the word labels compatible with the relation labels already assigned to pairs that
  * contain this word.
  *
  * <p>Every other word of {@code instance} is paired with this word. This word itself is also
  * included when {@code params.getUseSelfRelation()} is set. A pair counts only if its index is
  * below {@code label.size()} and its label is non-negative. For each such pair, the labels that
  * {@code params.getPossibleLabels()} maps to {@code "<pairLabel>:2"} are added when this word is
  * the pair's second argument. Otherwise the labels for {@code "<pairLabel>:1"} are added (first
  * argument, or a self pair).
  *
  * @param label the partial labelling; only indices below {@code label.size()} are read
  * @param instance the sentence containing this word
  * @return a {@code TreeSet} of compatible word labels; empty when no decided, non-negative
  *     relation touches this word
  */
 public Collection<LabelUnit> getRelWordLabels(Label label, JointInstance instance) {
   // TODO: no relations??
   Set<LabelUnit> compatible = new TreeSet<LabelUnit>();
   boolean foundRelation = false;
   for (int other = 0; other < instance.getNumWords(); ++other) {
     if (other == id && !params.getUseSelfRelation()) {
       continue;
     }
     // Pair indices are looked up with the smaller word id first; a self pair is (id, id).
     int pairIndex =
         other <= id ? instance.getWordPair(other, id) : instance.getWordPair(id, other);
     if (pairIndex >= label.size()) {
       continue; // this pair has not been labelled yet
     }
     LabelUnit pairUnit = label.getLabel(pairIndex);
     assert pairUnit instanceof PairLabelUnit;
     if (pairUnit.isNegative()) {
       continue;
     }
     // "2": this word is the pair's second argument; "1": it is the first (or a self pair).
     String argumentSlot = other < id ? "2" : "1";
     String key = ((PairLabelUnit) pairUnit).toString() + ":" + argumentSlot;
     compatible.addAll(params.getPossibleLabels().get(key));
     foundRelation = true;
   }
   assert foundRelation || compatible.isEmpty();
   return compatible;
 }
   /**
    * Builds the global (label-interaction) feature vector for assigning {@code candidateLabelUnit}
    * to {@code pair}, given the partial labelling {@code y} whose positions {@code 0..lastIndex}
    * are already decided.
    *
    * <p>Features are accumulated into three sub-vectors. At the end they are merged into the
    * result under the prefixes "PARA", "COUNT" and "WORDPAIR", and the result is scaled by
    * {@code params.getRelWeight()}:
    *
    * <ul>
    *   <li>When {@code params.getUseGlobalRelationFeatures()} is set: parallel ("PARA-E1",
    *       "PARA-E2") and sequential ("SEQ", "RSEQ") features against every decided, non-negative
    *       pair that shares a word with {@code pair}. This also adds second-order and triangle
    *       counters, plus "PROJ" counters for pairs whose spans interleave with this one.
    *   <li>Conjunctions of the candidate relation label with the decided word labels of its two
    *       argument words.
    *   <li>When {@code params.getUseFullFeatures()} is set and both argument words are decided:
    *       word-relation-word and entity-relation-entity conjunctions.
    * </ul>
    *
    * @param pair the word pair being labelled
    * @param instance the sentence; it is cast to {@code JointInstance} unconditionally
    * @param y the partial label sequence; only positions up to {@code lastIndex} are read
    * @param lastIndex the last position of {@code y} that is already decided
    * @param candidateLabelUnit the relation label under consideration
    * @return the scaled global feature vector, or a fresh unscaled vector if
    *     {@code candidateLabelUnit} is negative
    */
   protected SparseFeatureVector calcPairGlobalFeatures(
       Pair pair, Instance instance, Label y, int lastIndex, PairLabelUnit candidateLabelUnit) {
     SparseFeatureVector globalFv = new SparseFeatureVector(params);
     // A negative candidate gets no global features (and the returned vector is not scaled).
     if (candidateLabelUnit.isNegative()) {
       return globalFv;
     }
     // Sub-vectors merged into globalFv with distinct prefixes at the end of the method.
     SparseFeatureVector parallelFv = new SparseFeatureVector(params);
     StringSparseVector counterFv = new StringSparseVector(params);
     StringSparseVector wordPairFv = new StringSparseVector(params);

     String candidateLabel = candidateLabelUnit.getLabel();
     Word w1 = pair.getW1();
     Word w2 = pair.getW2();

     // Positions of the two argument words within the label sequence y.
     JointInstance jInstance = (JointInstance) instance;
     int w1Index = jInstance.getWord(w1.getId());
     int w2Index = jInstance.getWord(w2.getId());

     String w1Base = w1.getWord().getRealBase();
     String w1POS = w1.getWord().getPOS();
     String w2Base = w2.getWord().getRealBase();
     String w2POS = w2.getWord().getPOS();

     // Visit every decided position; negative labels contribute nothing.
     for (int i = 0; i <= lastIndex; i++) {
       if (y.getLabel(i).isNegative()) continue;
       if (y.getLabel(i) instanceof PairLabelUnit) {
         // Relation-relation interaction features are optional.
         if (!params.getUseGlobalRelationFeatures()) continue;
         PairLabelUnit adjLabelUnit = (PairLabelUnit) y.getLabel(i);
         String adjLabel = adjLabelUnit.getLabel();
         Pair adjPair = (Pair) instance.getSequence().get(i);
         Word adjW1 = adjPair.getW1();
         Word adjW2 = adjPair.getW2();
         // Dispatch on which word (if any) the decided pair shares with this pair. In each branch,
         // triangleIndex is the pair joining the two non-shared words. A triangle feature is added
         // only when that pair is already decided (index <= lastIndex) and non-negative.
         if (adjW1.getId() == w1.getId()) {
           // Shared first word: decided pair (w1, adjW2) vs. candidate pair (w1, w2).
           if (adjW2.getId() < w2.getId()) {
             addSecondOrderInfoToFV(counterFv, candidateLabel, w1Base, w1POS, adjLabel, "PARA-E1");
             addParallelInfoToFV(
                 parallelFv,
                 w1Base,
                 w1POS,
                 adjW2.getWord(),
                 adjLabel,
                 w2.getWord(),
                 candidateLabel,
                 "PARA-E1");
             int triangleIndex = jInstance.getWordPair(adjW2.getId(), w2.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     adjLabel,
                     w1Base,
                     w1POS,
                     candidateLabel,
                     w2Base,
                     w2POS,
                     thirdLabelUnit.getLabel(),
                     adjW2.getWord().getRealBase(),
                     adjW2.getWord().getPOS());
               }
             }
           } else {
             addSecondOrderInfoToFV(counterFv, adjLabel, w1Base, w1POS, candidateLabel, "PARA-E1");
             addParallelInfoToFV(
                 parallelFv,
                 w1Base,
                 w1POS,
                 w2.getWord(),
                 candidateLabel,
                 adjW2.getWord(),
                 adjLabel,
                 "PARA-E1");
             int triangleIndex = jInstance.getWordPair(w2.getId(), adjW2.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     candidateLabel,
                     w1Base,
                     w1POS,
                     adjLabel,
                     adjW2.getWord().getRealBase(),
                     adjW2.getWord().getPOS(),
                     thirdLabelUnit.getLabel(),
                     w2Base,
                     w2POS);
               }
             }
           }
         } else if (adjW2.getId() == w1.getId()) {
           // Chain: decided pair (adjW1, w1) ends where the candidate pair (w1, w2) starts.
           // NOTE(review): if pairs always have W1 id < W2 id, then adjW1 < w1 < w2 here and the
           // "RSEQ" else-branch below is unreachable -- confirm.
           if (adjW1.getId() < w2.getId()) {
             addParallelInfoToFV(
                 parallelFv,
                 w1Base,
                 w1POS,
                 adjW1.getWord(),
                 adjLabel,
                 w2.getWord(),
                 candidateLabel,
                 "SEQ");
             int triangleIndex = jInstance.getWordPair(adjW1.getId(), w2.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     adjLabel,
                     w1Base,
                     w1POS,
                     candidateLabel,
                     w2Base,
                     w2POS,
                     thirdLabelUnit.getLabel(),
                     adjW1.getWord().getRealBase(),
                     adjW1.getWord().getPOS());
               }
             }
           } else {
             // NOTE(review): unlike the other branches, this call pairs w2 with adjLabel and
             // adjW1 with candidateLabel -- confirm the word/label pairing is intended.
             addParallelInfoToFV(
                 parallelFv,
                 w1Base,
                 w1POS,
                 w2.getWord(),
                 adjLabel,
                 adjW1.getWord(),
                 candidateLabel,
                 "RSEQ");
             int triangleIndex = jInstance.getWordPair(w2.getId(), adjW1.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     candidateLabel,
                     w1Base,
                     w1POS,
                     adjLabel,
                     adjW1.getWord().getRealBase(),
                     adjW1.getWord().getPOS(),
                     thirdLabelUnit.getLabel(),
                     w2Base,
                     w2POS);
               }
             }
           }
           addSecondOrderInfoToFV(counterFv, candidateLabel, w1Base, w1POS, adjLabel, "SEQ");
         } else if (adjW2.getId() == w2.getId()) {
           // Shared second word: decided pair (adjW1, w2) vs. candidate pair (w1, w2).
           if (adjW1.getId() < w1.getId()) {
             addParallelInfoToFV(
                 parallelFv,
                 w2Base,
                 w2POS,
                 adjW1.getWord(),
                 adjLabel,
                 w1.getWord(),
                 candidateLabel,
                 "PARA-E2");
             addSecondOrderInfoToFV(counterFv, candidateLabel, w2Base, w2POS, adjLabel, "PARA-E2");
             int triangleIndex = jInstance.getWordPair(adjW1.getId(), w1.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     adjLabel,
                     w2Base,
                     w2POS,
                     candidateLabel,
                     w1Base,
                     w1POS,
                     thirdLabelUnit.getLabel(),
                     adjW1.getWord().getRealBase(),
                     adjW1.getWord().getPOS());
               }
             }
           } else {
             addParallelInfoToFV(
                 parallelFv,
                 w2Base,
                 w2POS,
                 w1.getWord(),
                 candidateLabel,
                 adjW1.getWord(),
                 adjLabel,
                 "PARA-E2");
             addSecondOrderInfoToFV(counterFv, adjLabel, w2Base, w2POS, candidateLabel, "PARA-E2");
             int triangleIndex = jInstance.getWordPair(w1.getId(), adjW1.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     candidateLabel,
                     w2Base,
                     w2POS,
                     adjLabel,
                     adjW1.getWord().getRealBase(),
                     adjW1.getWord().getPOS(),
                     thirdLabelUnit.getLabel(),
                     w1Base,
                     w1POS);
               }
             }
           }
         } else if (adjW1.getId() == w2.getId()) {
           // Chain: candidate pair (w1, w2) ends where the decided pair (w2, adjW2) starts.
           // NOTE(review): if pairs always have W1 id < W2 id, then w1 < w2 < adjW2 here and the
           // "RSEQ" else-branch below is unreachable -- confirm.
           if (w1.getId() < adjW2.getId()) {
             addParallelInfoToFV(
                 parallelFv,
                 w2Base,
                 w2POS,
                 w1.getWord(),
                 candidateLabel,
                 adjW2.getWord(),
                 adjLabel,
                 "SEQ");
             int triangleIndex = jInstance.getWordPair(w1.getId(), adjW2.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     candidateLabel,
                     w2Base,
                     w2POS,
                     adjLabel,
                     adjW2.getWord().getRealBase(),
                     adjW2.getWord().getPOS(),
                     thirdLabelUnit.getLabel(),
                     w1Base,
                     w1POS);
               }
             }
           } else {
             // NOTE(review): unlike the other branches, this call pairs adjW2 with candidateLabel
             // and w1 with adjLabel -- confirm the word/label pairing is intended.
             addParallelInfoToFV(
                 parallelFv,
                 w2Base,
                 w2POS,
                 adjW2.getWord(),
                 candidateLabel,
                 w1.getWord(),
                 adjLabel,
                 "RSEQ");
             int triangleIndex = jInstance.getWordPair(adjW2.getId(), w1.getId());
             if (triangleIndex <= lastIndex) {
               PairLabelUnit thirdLabelUnit = ((PairLabelUnit) y.getLabel(triangleIndex));
               if (!thirdLabelUnit.isNegative()) {
                 addTriangleToFV(
                     counterFv,
                     adjLabel,
                     w2Base,
                     w2POS,
                     candidateLabel,
                     w1Base,
                     w1POS,
                     thirdLabelUnit.getLabel(),
                     adjW2.getWord().getRealBase(),
                     adjW2.getWord().getPOS());
               }
             }
           }
           addSecondOrderInfoToFV(counterFv, adjLabel, w2Base, w2POS, candidateLabel, "SEQ");
         } else {
           // No shared word: count the two interleaved (crossing) id orderings of the four words.
           if (adjW1.getId() < w1.getId()
               && w1.getId() < adjW2.getId()
               && adjW2.getId() < w2.getId()) {
             counterFv.add("PROJ".concat(adjLabel).concat(candidateLabel), 1.);
           } else if (w1.getId() < adjW1.getId()
               && adjW1.getId() < w2.getId()
               && w2.getId() < adjW2.getId()) {
             counterFv.add("PROJ".concat(candidateLabel).concat(adjLabel), 1.);
           }
         }
       } else {
         // A decided word label: conjoin it with the candidate relation label when the word is
         // one of this pair's arguments (the first argument's label precedes the relation label).
         if (w1Index == i) {
           assert y.getLabel(i) instanceof WordLabelUnit;
           String w1Label = ((WordLabelUnit) y.getLabel(i)).getLabel();
           wordPairFv.add(w1Label.concat(candidateLabel), 1.);
           wordPairFv.add(w1Label.concat(w1Base).concat(candidateLabel), 1.);
         }
         if (w2Index == i) {
           assert y.getLabel(i) instanceof WordLabelUnit;
           String w2Label = ((WordLabelUnit) y.getLabel(i)).getLabel();
           wordPairFv.add(candidateLabel.concat(w2Label), 1.);
           wordPairFv.add(candidateLabel.concat(w2Base).concat(w2Label), 1.);
         }
       }
     }
     if (params.getUseFullFeatures()) {
       // Same lookups as w1Index / w2Index above.
       int w1Idx = ((JointInstance) instance).getWord(pair.getW1().getId());
       int w2Idx = ((JointInstance) instance).getWord(pair.getW2().getId());
       // Only when both argument words already have decided labels.
       if (w1Idx <= lastIndex && w2Idx <= lastIndex) {
         // w-rel-w
         String w1Label = ((WordLabelUnit) y.getLabel(w1Idx)).getLabel();
         String w2Label = ((WordLabelUnit) y.getLabel(w2Idx)).getLabel();
         wordPairFv.add(w1Label.concat(candidateLabel).concat(w2Label), 1.);
         wordPairFv.add(
             w1Label.concat(w1Base).concat(candidateLabel).concat(w2Label).concat(w2Base), 1.);
         // Entity-level features are added only when getEntityWords returns a non-empty list.
         List<Word> e1Words =
             getEntityWords(pair.getW1(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w1Idx));
         String e1String = "", e2String = "", w1LabelType = "", w2LabelType = "";
         if (e1Words.size() != 0) {
           e1String = entityString(e1Words);
           w1LabelType = ((WordLabelUnit) y.getLabel(w1Idx)).getType();
           wordPairFv.add(w1LabelType.concat(candidateLabel), 1.);
           wordPairFv.add(w1LabelType.concat(e1String).concat(candidateLabel), 1.);
         }
         List<Word> e2Words =
             getEntityWords(pair.getW2(), instance, y, lastIndex, (WordLabelUnit) y.getLabel(w2Idx));
         if (e2Words.size() != 0) {
           e2String = entityString(e2Words);
           w2LabelType = ((WordLabelUnit) y.getLabel(w2Idx)).getType();
           wordPairFv.add(candidateLabel.concat(w2LabelType), 1.);
           wordPairFv.add(candidateLabel.concat(w2LabelType).concat(e2String), 1.);
         }
         // ent-rel-ent
         if (e1Words.size() != 0 && e2Words.size() != 0) {
           wordPairFv.add(w1LabelType.concat(candidateLabel).concat(w2LabelType), 1.);
           wordPairFv.add(
               w1LabelType
                   .concat(e1String)
                   .concat(candidateLabel)
                   .concat(w2LabelType)
                   .concat(e2String),
               1.);
         }
       }
     }
     // Merge the sub-vectors under their prefixes and apply the relation weight.
     globalFv.add(parallelFv, "PARA");
     globalFv.add(counterFv, "COUNT");
     globalFv.add(wordPairFv, "WORDPAIR");
     globalFv.scale(params.getRelWeight());
     return globalFv;
   }
Example #3
0
 /**
  * Returns the word labels this word may still take, given the partial labelling {@code label}.
  *
  * <p>The starting set comes from {@code getRelWordLabels}, i.e. the labels licensed by decided
  * relations on this word. When that set is empty, all labels registered for {@code getType()}
  * are used instead. The set is then filtered, in place, by B/I/L/O/U position and type
  * consistency with:
  *
  * <ul>
  *   <li>the gold position and type, when {@code params.getUseGoldEntitySpan()} is set;
  *   <li>the decided labels of the words at {@code id-2 .. id+2};
  *   <li>the relation labels of the immediately adjacent words.
  * </ul>
  *
  * <p>If the result is empty and {@code params.getVerbosity() > 0}, diagnostic output is printed
  * to stdout.
  *
  * @param label the partial labelling; only indices below {@code label.size()} are read
  * @param instance the sentence containing this word
  * @return the remaining candidate labels; asserted (when assertions are enabled) to be non-empty
  */
 public Collection<LabelUnit> getPossibleLabels(Label label, JointInstance instance) {
   // Number of decided positions; indices >= size are treated as undecided below.
   int size = label.size();
   Collection<LabelUnit> labels = getRelWordLabels(label, instance);
   boolean hasRelation = labels.size() > 0;
   if (!hasRelation) {
     // No relation constraint: start from every label of this word's type (copied, so the
     // removals below do not touch the shared params collection).
     labels = new TreeSet<LabelUnit>(params.getPossibleLabels().get(getType()));
   }
   if (params.getUseGoldEntitySpan()) {
     // Keep only labels with the gold position ...
     String goldPosition =
         ((WordLabelUnit) instance.getGoldLabel().getLabel(instance.getWord(id))).getPosition();
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (!unit.getPosition().equals(goldPosition)) {
         unitIt.remove();
       }
     }
     // ... and, when getGoldType returns a non-empty type, only labels of that type.
     String goldType = instance.getGoldType(label, id);
     if (!goldType.equals("")) {
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if (!unit.getType().equals(goldType)) {
           unitIt.remove();
         }
       }
     }
   }
   // --- Constraints from the left context ---
   if (id > 0) {
     int prevWordIndex = instance.getWord(id - 1);
     // If the previous word has any relation-licensed labels, this word may not be I or L.
     Collection<LabelUnit> relLabelUnits =
         ((Word) instance.getSequence().get(prevWordIndex)).getRelWordLabels(label, instance);
     if (relLabelUnits.size() != 0) {
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if (unit.getPosition().equals(WordLabelUnit.I)
             || unit.getPosition().equals(WordLabelUnit.L)) {
           unitIt.remove();
         }
       }
       // B-*, O, U-*
     }
     if (prevWordIndex < size) {
       assert label.getLabel(prevWordIndex) instanceof WordLabelUnit;
       String position = ((WordLabelUnit) label.getLabel(prevWordIndex)).getPosition();
       String type = ((WordLabelUnit) label.getLabel(prevWordIndex)).getType();
       if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
         // Previous word is B/I: this word must be I or L of the same type.
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (!unit.getType().equals(type)
               || unit.getPosition().equals(WordLabelUnit.B)
               || unit.getPosition().equals(WordLabelUnit.U)
               || unit.getPosition().equals(WordLabelUnit.O)) {
             unitIt.remove();
           }
         }
         // I-type, L-type
       } else {
         // Previous word is not B/I: this word may not be I or L.
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (unit.getPosition().equals(WordLabelUnit.I)
               || unit.getPosition().equals(WordLabelUnit.L)) {
             unitIt.remove();
           }
         }
         // B-*, O, U-*
       }
     }
     if (id > 1) {
       int pprevWordIndex = instance.getWord(id - 2);
       if (pprevWordIndex < size) {
         String position = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getPosition();
         String type = ((WordLabelUnit) label.getLabel(pprevWordIndex)).getType();
         // Word two to the left is B/I: any I or L here must share its type.
         if (position.equals(WordLabelUnit.B) || position.equals(WordLabelUnit.I)) {
           for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
             WordLabelUnit unit = (WordLabelUnit) unitIt.next();
             if ((unit.getPosition().equals(WordLabelUnit.L)
                     || unit.getPosition().equals(WordLabelUnit.I))
                 && !unit.getType().equals(type)) {
               unitIt.remove();
             }
           }
           // B-*, O, I-type, U-*, L-type
         }
       }
     }
   } else if (id == 0) {
     // First word of the sentence: may not be I or L.
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (unit.getPosition().equals(WordLabelUnit.I)
           || unit.getPosition().equals(WordLabelUnit.L)) {
         unitIt.remove();
       }
     }
     // B-*, O, U-*
   }
   // --- Constraints from the right context ---
   if (id < instance.getNumWords() - 1) {
     int nextWordIndex = instance.getWord(id + 1);
     // If the next word has relation-licensed labels, a B or I here must have one of their types.
     Collection<LabelUnit> relLabelUnits =
         ((Word) instance.getSequence().get(nextWordIndex)).getRelWordLabels(label, instance);
     if (relLabelUnits.size() != 0) {
       Set<String> possibleTypes = new TreeSet<String>();
       for (LabelUnit relLabelUnit : relLabelUnits) {
         possibleTypes.add(((WordLabelUnit) relLabelUnit).getType());
       }
       for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
         WordLabelUnit unit = (WordLabelUnit) unitIt.next();
         if ((unit.getPosition().equals(WordLabelUnit.B)
                 || unit.getPosition().equals(WordLabelUnit.I))
             && !possibleTypes.contains(unit.getType())) {
           unitIt.remove();
         }
       }
       // B-TYPE, I-TYPE, L, O, U
     }
     if (nextWordIndex < size) {
       assert label.getLabel(nextWordIndex) instanceof WordLabelUnit;
       String position = ((WordLabelUnit) label.getLabel(nextWordIndex)).getPosition();
       String type = ((WordLabelUnit) label.getLabel(nextWordIndex)).getType();
       if (position.equals(WordLabelUnit.I) || position.equals(WordLabelUnit.L)) {
         // Next word is I/L: this word must be B or I of the same type.
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (!unit.getType().equals(type)
               || unit.getPosition().equals(WordLabelUnit.L)
               || unit.getPosition().equals(WordLabelUnit.U)
               || unit.getPosition().equals(WordLabelUnit.O)) {
             unitIt.remove();
           }
         }
         // B-type, I-type
       } else {
         // Next word is not I/L: this word may not be B or I.
         for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
           WordLabelUnit unit = (WordLabelUnit) unitIt.next();
           if (unit.getPosition().equals(WordLabelUnit.B)
               || unit.getPosition().equals(WordLabelUnit.I)) {
             unitIt.remove();
           }
         }
         // L-*, O, U-*
       }
     }
     if (id < instance.getNumWords() - 2) {
       int nNextWordIndex = instance.getWord(id + 2);
       if (nNextWordIndex < size) {
         String position = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getPosition();
         String type = ((WordLabelUnit) label.getLabel(nNextWordIndex)).getType();
         // Word two to the right is L/I: any B or I here must share its type.
         if (position.equals(WordLabelUnit.L) || position.equals(WordLabelUnit.I)) {
           // B-type, O, I-type, U-*, L-*
           for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
             WordLabelUnit unit = (WordLabelUnit) unitIt.next();
             if ((unit.getPosition().equals(WordLabelUnit.B)
                     || unit.getPosition().equals(WordLabelUnit.I))
                 && !unit.getType().equals(type)) {
               unitIt.remove();
             }
           }
         }
       }
     }
   } else if (id == instance.getNumWords() - 1) {
     // Last word of the sentence: may not be B or I.
     // L-*, O, U-*
     for (Iterator<LabelUnit> unitIt = labels.iterator(); unitIt.hasNext(); ) {
       WordLabelUnit unit = (WordLabelUnit) unitIt.next();
       if (unit.getPosition().equals(WordLabelUnit.B)
           || unit.getPosition().equals(WordLabelUnit.I)) {
         unitIt.remove();
       }
     }
   }
   // Diagnostics: the constraints eliminated every label. Print the word-label state of the
   // sentence ("?" marks undecided positions).
   if (labels.size() == 0 && params.getVerbosity() > 0) {
     System.out.println("rel? " + hasRelation + " , id: " + id);
     System.out.println(getRelWordLabels(label, instance));
     for (int i = 0; i < instance.getNumWords(); i++) {
       int index = instance.getWord(i);
       if (index >= size) {
         System.out.print(" " + i + ":" + index + ":?");
       } else {
         System.out.print(
             " " + i + ":" + index + ":" + ((WordLabelUnit) label.getLabel(index)).getLabel());
       }
     }
     System.out.println();
   }
   assert labels.size() > 0 : id;
   return labels;
 }