Exemplo n.º 1
0
  /**
   * Puts the parser back into its initial state: clears the position lattice and the pending
   * token list, rewinds every cursor, and seeds position 0 with the beginning-of-sentence entry.
   */
  private void resetState() {
    positions.reset();

    // Rewind the cursors and flags.
    pos = 0;
    lastBackTracePos = 0;
    lastTokenPos = -1;
    unknownWordEndIndex = -1;
    end = false;

    pending.clear();

    // Seed the lattice with the BOS entry at position 0.
    final Position bos = positions.get(0);
    bos.add(0, 0, -1, -1, -1, Type.KNOWN);
  }
Exemplo n.º 2
0
  /**
   * Restricts the lattice to the span {@code [startPos, endPos]} and recomputes the path costs
   * inside it.
   *
   * <p>The first pass walks from {@code endPos} down to (but excluding) {@code startPos}; every
   * arriving arc whose origin lies inside the span is recorded as a forward arc on its origin
   * position, arcs starting before {@code startPos} are dropped, and the arriving-arc count of the
   * visited position is cleared. The second pass walks from {@code startPos} up to (but excluding)
   * {@code endPos} and re-adds the recorded forward arcs: from {@code startPos} only the path
   * identified by {@code bestStartIDX} is extended, from every other position the arcs are re-added
   * through {@code add(...)} with the penalty flag set.
   *
   * @param startPos first position of the span
   * @param endPos last position of the span
   * @param bestStartIDX index, within the entries of {@code startPos}, of the path to extend
   * @throws IOException declared by the helpers this method calls
   */
  private void pruneAndRescore(int startPos, int endPos, int bestStartIDX) throws IOException {
    if (VERBOSE) {
      System.out.println(
          "  pruneAndRescore startPos="
              + startPos
              + " endPos="
              + endPos
              + " bestStartIDX="
              + bestStartIDX);
    }

    // Pass 1 (right to left): turn the surviving backward arcs into forward arcs.
    int cursor = endPos;
    while (cursor > startPos) {
      final Position arrival = positions.get(cursor);
      if (VERBOSE) {
        System.out.println("    back pos=" + cursor);
      }
      for (int arc = 0; arc < arrival.count; arc++) {
        final int origin = arrival.backPos[arc];
        if (origin < startPos) {
          // The arc starts before the span: drop it.
          if (VERBOSE) {
            System.out.println("      prune");
          }
          continue;
        }
        // The arc lies inside the span: remember it on its origin position.
        positions
            .get(origin)
            .addForward(cursor, arc, arrival.backID[arc], arrival.backType[arc]);
      }
      // The loop never visits startPos itself, so the arriving arcs of every
      // visited position are cleared here; pass 2 rebuilds them.
      arrival.count = 0;
      cursor--;
    }

    // Pass 2 (left to right): re-add the forward arcs with fresh costs.
    for (int from = startPos; from < endPos; from++) {
      final Position source = positions.get(from);
      if (VERBOSE) {
        System.out.println("    forward pos=" + from + " count=" + source.forwardCount);
      }
      if (source.count == 0) {
        // Nothing arrives at this position, so nothing can leave it either.
        if (VERBOSE) {
          System.out.println("      skip");
        }
        source.forwardCount = 0;
        continue;
      }

      if (from != startPos) {
        // Inside the span: let add(...) choose the cheapest arriving path for
        // each forward arc.
        for (int arc = 0; arc < source.forwardCount; arc++) {
          final Type arcType = source.forwardType[arc];
          final int target = source.forwardPos[arc];
          if (VERBOSE) {
            System.out.println(
                "      + "
                    + arcType
                    + " word "
                    + new String(buffer.get(from, target - from))
                    + " toPos="
                    + target);
          }
          add(getDict(arcType), source, target, source.forwardID[arc], arcType, true);
        }
      } else {
        // At the start of the span only the best path is extended, which
        // "forces congruence": the sub-segmentation is scored in the context
        // of what the best path (the compound token) had matched.
        final int rightID =
            startPos == 0
                ? 0
                : getDict(source.backType[bestStartIDX]).getRightId(source.backID[bestStartIDX]);
        final int pathCost = source.costs[bestStartIDX];
        for (int arc = 0; arc < source.forwardCount; arc++) {
          final Type arcType = source.forwardType[arc];
          final Dictionary arcDict = getDict(arcType);
          final int wordID = source.forwardID[arc];
          final int target = source.forwardPos[arc];
          final int newCost =
              pathCost
                  + arcDict.getWordCost(wordID)
                  + costs.get(rightID, arcDict.getLeftId(wordID))
                  + computePenalty(from, target - from);
          if (VERBOSE) {
            System.out.println(
                "      + "
                    + arcType
                    + " word "
                    + new String(buffer.get(from, target - from))
                    + " toPos="
                    + target
                    + " cost="
                    + newCost
                    + " penalty="
                    + computePenalty(from, target - from)
                    + " toPos.idx="
                    + positions.get(target).count);
          }
          final Position destination = positions.get(target);
          destination.add(
              newCost, arcDict.getRightId(wordID), from, bestStartIDX, wordID, arcType);
        }
      }
      source.forwardCount = 0;
    }
  }
Exemplo n.º 3
0
  /**
   * Incrementally parse some more characters. This runs the viterbi search forwards "enough" so
   * that we generate some more tokens. How much forward depends on the chars coming in, since some
   * chars could cause longer-lasting ambiguity in the parsing. Once the ambiguity is resolved, then
   * we back trace, produce the pending tokens, and return.
   *
   * <p>The method returns early as soon as a backtrace has left at least one token in {@code
   * pending}. Otherwise it keeps advancing {@code pos} until {@code buffer.get(pos)} reports
   * {@code -1}, sets {@code end = true}, and (if any character was consumed) backtraces from the
   * cheapest entry of the final position, including the end-of-sentence connection cost.
   *
   * @throws IOException declared by the buffer/dictionary helpers called here
   */
  private void parse() throws IOException {
    if (VERBOSE) {
      System.out.println("\nPARSE");
    }

    // Advances over each position (character):
    while (true) {

      if (buffer.get(pos) == -1) {
        // End
        break;
      }

      final Position posData = positions.get(pos);
      // True when no position beyond pos has been touched yet, i.e. the next
      // position the array would hand out is pos + 1:
      final boolean isFrontier = positions.getNextPos() == pos + 1;

      if (posData.count == 0) {
        // No arcs arrive here; move to next position:
        if (VERBOSE) {
          System.out.println("    no arcs in; skip pos=" + pos);
        }
        pos++;
        continue;
      }

      if (pos > lastBackTracePos && posData.count == 1 && isFrontier) {
        // We are at a "frontier", and only one node is
        // alive, so whatever the eventual best path is must
        // come through this node.  So we can safely commit
        // to the prefix of the best path at this point:
        backtrace(posData, 0);

        // Re-base cost so we don't risk int overflow:
        posData.costs[0] = 0;

        if (pending.size() != 0) {
          return;
        } else {
          // This means the backtrace only produced
          // punctuation tokens, so we must keep parsing.
        }
      }

      if (pos - lastBackTracePos >= MAX_BACKTRACE_GAP) {
        // Safety: if we've buffered too much, force a
        // backtrace now.  We find the least-cost partial
        // path, across all paths, backtrace from it, and
        // then prune all others.  Note that this, in
        // general, can produce the wrong result, if the
        // total best path did not in fact back trace
        // through this partial best path.  But it's the
        // best we can do... (short of not having a
        // safety!).

        // First pass: find least cost partial path so far,
        // including ending at future positions:
        int leastIDX = -1;
        int leastCost = Integer.MAX_VALUE;
        Position leastPosData = null;
        for (int pos2 = pos; pos2 < positions.getNextPos(); pos2++) {
          final Position posData2 = positions.get(pos2);
          for (int idx = 0; idx < posData2.count; idx++) {
            // System.out.println("    idx=" + idx + " cost=" + cost);
            final int cost = posData2.costs[idx];
            // Strict comparison: on a tie the earliest position/index wins.
            if (cost < leastCost) {
              leastCost = cost;
              leastIDX = idx;
              leastPosData = posData2;
            }
          }
        }

        // We will always have at least one live path:
        assert leastIDX != -1;

        // Second pass: prune all but the best path:
        for (int pos2 = pos; pos2 < positions.getNextPos(); pos2++) {
          final Position posData2 = positions.get(pos2);
          if (posData2 != leastPosData) {
            posData2.reset();
          } else {
            if (leastIDX != 0) {
              // Move the winning entry into slot 0 so it is the only one
              // left once count is set to 1 below:
              posData2.costs[0] = posData2.costs[leastIDX];
              posData2.lastRightID[0] = posData2.lastRightID[leastIDX];
              posData2.backPos[0] = posData2.backPos[leastIDX];
              posData2.backIndex[0] = posData2.backIndex[leastIDX];
              posData2.backID[0] = posData2.backID[leastIDX];
              posData2.backType[0] = posData2.backType[leastIDX];
            }
            posData2.count = 1;
          }
        }

        backtrace(leastPosData, 0);

        // Re-base cost so we don't risk int overflow:
        Arrays.fill(leastPosData.costs, 0, leastPosData.count, 0);

        if (pos != leastPosData.pos) {
          // We jumped into a future position:
          assert pos < leastPosData.pos;
          pos = leastPosData.pos;
        }

        if (pending.size() != 0) {
          return;
        } else {
          // This means the backtrace only produced
          // punctuation tokens, so we must keep parsing.
          continue;
        }
      }

      if (VERBOSE) {
        System.out.println("\n  extend @ pos=" + pos + " char=" + (char) buffer.get(pos));
      }

      if (VERBOSE) {
        System.out.println("    " + posData.count + " arcs in");
      }

      // Set once a user or known dictionary word starting at this position
      // has been added:
      boolean anyMatches = false;

      // First try user dict:
      if (userFST != null) {
        userFST.getFirstArc(arc);
        // Sum of the arc outputs along the path walked so far:
        int output = 0;
        for (int posAhead = posData.pos; ; posAhead++) {
          final int ch = buffer.get(posAhead);
          if (ch == -1) {
            break;
          }
          if (userFST.findTargetArc(ch, arc, arc, posAhead == posData.pos, userFSTReader) == null) {
            break;
          }
          output += arc.output.intValue();
          if (arc.isFinal()) {
            if (VERBOSE) {
              System.out.println(
                  "    USER word "
                      + new String(buffer.get(pos, posAhead - pos + 1))
                      + " toPos="
                      + (posAhead + 1));
            }
            add(
                userDictionary,
                posData,
                posAhead + 1,
                output + arc.nextFinalOutput.intValue(),
                Type.USER,
                false);
            anyMatches = true;
          }
        }
      }

      // TODO: we can be more aggressive about user
      // matches?  if we are "under" a user match then don't
      // extend KNOWN/UNKNOWN paths?

      // The known dictionary is only consulted when the user dictionary
      // produced no match at this position:
      if (!anyMatches) {
        // Next, try known dictionary matches
        fst.getFirstArc(arc);
        int output = 0;

        for (int posAhead = posData.pos; ; posAhead++) {
          final int ch = buffer.get(posAhead);
          if (ch == -1) {
            break;
          }
          // System.out.println("    match " + (char) ch + " posAhead=" + posAhead);

          if (fst.findTargetArc(ch, arc, arc, posAhead == posData.pos, fstReader) == null) {
            break;
          }

          output += arc.output.intValue();

          // Optimization: for known words that are too-long
          // (compound), we should pre-compute the 2nd
          // best segmentation and store it in the
          // dictionary instead of recomputing it each time a
          // match is found.

          if (arc.isFinal()) {
            dictionary.lookupWordIds(output + arc.nextFinalOutput.intValue(), wordIdRef);
            if (VERBOSE) {
              System.out.println(
                  "    KNOWN word "
                      + new String(buffer.get(pos, posAhead - pos + 1))
                      + " toPos="
                      + (posAhead + 1)
                      + " "
                      + wordIdRef.length
                      + " wordIDs");
            }
            // One arc is added per word id returned for this surface form:
            for (int ofs = 0; ofs < wordIdRef.length; ofs++) {
              add(
                  dictionary,
                  posData,
                  posAhead + 1,
                  wordIdRef.ints[wordIdRef.offset + ofs],
                  Type.KNOWN,
                  false);
              anyMatches = true;
            }
          }
        }
      }

      // In the case of normal mode, it doesn't process unknown word greedily.
      // (Outside search mode, positions still covered by the previously added
      // unknown word are skipped without an unknown-word lookup.)

      if (!searchMode && unknownWordEndIndex > posData.pos) {
        pos++;
        continue;
      }

      final char firstCharacter = (char) buffer.get(pos);
      // An unknown word is generated when nothing matched, or when the
      // character definition asks for it via isInvoke regardless of matches:
      if (!anyMatches || characterDefinition.isInvoke(firstCharacter)) {

        // Find unknown match:
        final int characterId = characterDefinition.getCharacterClass(firstCharacter);
        final boolean isPunct = isPunctuation(firstCharacter);

        // NOTE: copied from UnknownDictionary.lookup:
        int unknownWordLength;
        if (!characterDefinition.isGroup(firstCharacter)) {
          unknownWordLength = 1;
        } else {
          // Extract unknown word. Characters with the same character class are considered to be
          // part of unknown word
          unknownWordLength = 1;
          // Extend while the class and the punctuation flag both match the
          // first character, up to MAX_UNKNOWN_WORD_LENGTH characters:
          for (int posAhead = pos + 1; unknownWordLength < MAX_UNKNOWN_WORD_LENGTH; posAhead++) {
            final int ch = buffer.get(posAhead);
            if (ch == -1) {
              break;
            }
            if (characterId == characterDefinition.getCharacterClass((char) ch)
                && isPunctuation((char) ch) == isPunct) {
              unknownWordLength++;
            } else {
              break;
            }
          }
        }

        unkDictionary.lookupWordIds(
            characterId, wordIdRef); // characters in input text are supposed to be the same
        if (VERBOSE) {
          System.out.println(
              "    UNKNOWN word len=" + unknownWordLength + " " + wordIdRef.length + " wordIDs");
        }
        for (int ofs = 0; ofs < wordIdRef.length; ofs++) {
          add(
              unkDictionary,
              posData,
              posData.pos + unknownWordLength,
              wordIdRef.ints[wordIdRef.offset + ofs],
              Type.UNKNOWN,
              false);
        }

        // Remember where this unknown word ends; read by the normal-mode
        // skip check above on later iterations:
        unknownWordEndIndex = posData.pos + unknownWordLength;
      }

      pos++;
    }

    end = true;

    if (pos > 0) {

      // Pick the cheapest entry at the final position, counting the
      // connection to end-of-sentence, and backtrace from it:
      final Position endPosData = positions.get(pos);
      int leastCost = Integer.MAX_VALUE;
      int leastIDX = -1;
      if (VERBOSE) {
        System.out.println("  end: " + endPosData.count + " nodes");
      }
      for (int idx = 0; idx < endPosData.count; idx++) {
        // Add EOS cost:
        final int cost = endPosData.costs[idx] + costs.get(endPosData.lastRightID[idx], 0);
        // System.out.println("    idx=" + idx + " cost=" + cost + " (pathCost=" +
        // endPosData.costs[idx] + " bgCost=" + costs.get(endPosData.lastRightID[idx], 0) + ")
        // backPos=" + endPosData.backPos[idx]);
        if (cost < leastCost) {
          leastCost = cost;
          leastIDX = idx;
        }
      }

      // NOTE(review): leastIDX stays -1 if endPosData.count is 0; the code
      // relies on at least one path reaching the final position -- confirm.
      backtrace(endPosData, leastIDX);
    } else {
      // No characters in the input string; return no tokens!
    }
  }
Exemplo n.º 4
0
  /**
   * Adds one arc for {@code wordID} that starts at {@code fromPosData} and ends at {@code endPos},
   * chained to the cheapest path arriving at {@code fromPosData}.
   *
   * <p>The cost stored on the new entry is the cheapest arriving path cost plus the connection cost
   * to this word, plus the word cost, plus (when it applies) the value of {@code computePenalty}.
   * The penalty is never applied to {@code Type.USER} words; for other types it is applied when
   * {@code addPenalty} is set, or when {@code searchMode} is on and {@code outputCompounds} is off.
   *
   * @param dict dictionary that supplies the word cost and the left/right ids of {@code wordID}
   * @param fromPosData position the arc starts from; must have at least one arriving entry
   * @param endPos position the arc ends at
   * @param wordID id of the word within {@code dict}
   * @param type type recorded on the new entry
   * @param addPenalty whether to apply the penalty regardless of the mode flags
   * @throws IOException declared by the helpers this method calls
   */
  private void add(
      Dictionary dict, Position fromPosData, int endPos, int wordID, Type type, boolean addPenalty)
      throws IOException {
    final int wordCost = dict.getWordCost(wordID);
    final int leftID = dict.getLeftId(wordID);
    assert fromPosData.count > 0;

    // Find the arriving entry with the smallest (path cost + connection cost).
    // The word cost is the same for every candidate, so it is added afterwards.
    int bestIdx = -1;
    int bestCost = Integer.MAX_VALUE;
    int candidate = 0;
    while (candidate < fromPosData.count) {
      final int candidateCost =
          fromPosData.costs[candidate] + costs.get(fromPosData.lastRightID[candidate], leftID);
      if (VERBOSE) {
        System.out.println(
            "      fromIDX="
                + candidate
                + ": cost="
                + candidateCost
                + " (prevCost="
                + fromPosData.costs[candidate]
                + " wordCost="
                + wordCost
                + " bgCost="
                + costs.get(fromPosData.lastRightID[candidate], leftID)
                + " leftID="
                + leftID);
      }
      if (candidateCost < bestCost) {
        bestCost = candidateCost;
        bestIdx = candidate;
        if (VERBOSE) {
          System.out.println("        **");
        }
      }
      candidate++;
    }

    int totalCost = bestCost + wordCost;

    if (VERBOSE) {
      System.out.println(
          "      + cost="
              + totalCost
              + " wordID="
              + wordID
              + " leftID="
              + leftID
              + " leastIDX="
              + bestIdx
              + " toPos="
              + endPos
              + " toPos.idx="
              + positions.get(endPos).count);
    }

    final boolean penaltyApplies =
        type != Type.USER && (addPenalty || (searchMode && !outputCompounds));
    if (penaltyApplies) {
      final int penalty = computePenalty(fromPosData.pos, endPos - fromPosData.pos);
      if (VERBOSE && penalty > 0) {
        System.out.println("        + penalty=" + penalty + " cost=" + (totalCost + penalty));
      }
      totalCost += penalty;
    }

    // The left id is stored as the entry's right id as well; the assertion
    // documents the expectation that both ids are equal for this word.
    assert leftID == dict.getRightId(wordID);
    final Position destination = positions.get(endPos);
    destination.add(totalCost, leftID, fromPosData.pos, bestIdx, wordID, type);
  }
Exemplo n.º 5
0
  /**
   * Backtrace from the provided position, back to the last time we back-traced, accumulating the
   * resulting tokens to the pending list. The pending list is then in-reverse (last token should be
   * returned first).
   *
   * <p>When {@code outputCompounds} and {@code searchMode} are both set, a non-user token whose
   * {@code computeSecondBestThreshold} value is positive triggers {@code pruneAndRescore} over its
   * span; if an alternate segmentation within the cost threshold exists, the backtrace follows
   * that segmentation and the original (compound) token is added once the alternate path reaches
   * the compound's start position.
   *
   * <p>On exit {@code lastBackTracePos} is set to the position of {@code endPosData}, and both
   * {@code buffer} and {@code positions} are told to free everything before it.
   *
   * @param endPosData position to start the backtrace from
   * @param fromIDX index of the entry within {@code endPosData} to start from
   * @throws IOException declared by the helpers this method calls
   */
  private void backtrace(final Position endPosData, final int fromIDX) throws IOException {
    final int endPos = endPosData.pos;

    if (VERBOSE) {
      // NOTE: "pos" here is not the local declared further down; it resolves
      // to the variable of the same name from the enclosing scope.
      System.out.println(
          "\n  backtrace: endPos="
              + endPos
              + " pos="
              + pos
              + "; "
              + (pos - lastBackTracePos)
              + " characters; last="
              + lastBackTracePos
              + " cost="
              + endPosData.costs[fromIDX]);
    }

    // Characters covered by this backtrace; token offsets below are relative
    // to lastBackTracePos:
    final char[] fragment = buffer.get(lastBackTracePos, endPos - lastBackTracePos);

    if (dotOut != null) {
      dotOut.onBacktrace(this, positions, lastBackTracePos, endPosData, fromIDX, fragment, end);
    }

    // From here on "pos" is a local cursor that walks backwards from endPos:
    int pos = endPos;
    int bestIDX = fromIDX;
    // Compound token held back while an alternate segmentation is being
    // traced; null when no alternate path is active:
    Token altToken = null;

    // We trace backwards, so this will be the leftWordID of
    // the token after the one we are now on:
    int lastLeftWordID = -1;

    // Number of tokens added to pending since the current alternate path
    // started (reset when an altToken is created):
    int backCount = 0;

    // TODO: sort of silly to make Token instances here; the
    // back trace has all info needed to generate the
    // token.  So, we could just directly set the attrs,
    // from the backtrace, in incrementToken w/o ever
    // creating Token; we'd have to defer calling freeBefore
    // until after the backtrace was fully "consumed" by
    // incrementToken.

    while (pos > lastBackTracePos) {
      // System.out.println("BT: back pos=" + pos + " bestIDX=" + bestIDX);
      final Position posData = positions.get(pos);
      assert bestIDX < posData.count;

      // Details of the arc that arrives at pos on the path being traced:
      int backPos = posData.backPos[bestIDX];
      assert backPos >= lastBackTracePos
          : "backPos=" + backPos + " vs lastBackTracePos=" + lastBackTracePos;
      int length = pos - backPos;
      Type backType = posData.backType[bestIDX];
      int backID = posData.backID[bestIDX];
      int nextBestIDX = posData.backIndex[bestIDX];

      if (outputCompounds && searchMode && altToken == null && backType != Type.USER) {

        // In searchMode, if best path had picked a too-long
        // token, we use the "penalty" to compute the allowed
        // max cost of an alternate back-trace.  If we find an
        // alternate back trace with cost below that
        // threshold, we pursue it instead (but also output
        // the long token).
        // System.out.println("    2nd best backPos=" + backPos + " pos=" + pos);

        final int penalty = computeSecondBestThreshold(backPos, pos - backPos);

        if (penalty > 0) {
          if (VERBOSE) {
            System.out.println(
                "  compound="
                    + new String(buffer.get(backPos, pos - backPos))
                    + " backPos="
                    + backPos
                    + " pos="
                    + pos
                    + " penalty="
                    + penalty
                    + " cost="
                    + posData.costs[bestIDX]
                    + " bestIDX="
                    + bestIDX
                    + " lastLeftID="
                    + lastLeftWordID);
          }

          // Use the penalty to set maxCost on the 2nd best
          // segmentation:
          int maxCost = posData.costs[bestIDX] + penalty;
          if (lastLeftWordID != -1) {
            // Include the connection cost to the token that follows:
            maxCost += costs.get(getDict(backType).getRightId(backID), lastLeftWordID);
          }

          // Now, prune all too-long tokens from the graph:
          pruneAndRescore(backPos, pos, posData.backIndex[bestIDX]);

          // Finally, find 2nd best back-trace and resume
          // backtrace there:
          int leastCost = Integer.MAX_VALUE;
          int leastIDX = -1;
          for (int idx = 0; idx < posData.count; idx++) {
            int cost = posData.costs[idx];
            // System.out.println("    idx=" + idx + " prevCost=" + cost);

            if (lastLeftWordID != -1) {
              cost +=
                  costs.get(
                      getDict(posData.backType[idx]).getRightId(posData.backID[idx]),
                      lastLeftWordID);
              // System.out.println("      += bgCost=" +
              // costs.get(getDict(posData.backType[idx]).getRightId(posData.backID[idx]),
              // lastLeftWordID) + " -> " + cost);
            }
            // System.out.println("penalty " + posData.backPos[idx] + " to " + pos);
            // cost += computePenalty(posData.backPos[idx], pos - posData.backPos[idx]);
            if (cost < leastCost) {
              // System.out.println("      ** ");
              leastCost = cost;
              leastIDX = idx;
            }
          }
          // System.out.println("  leastIDX=" + leastIDX);

          if (VERBOSE) {
            System.out.println(
                "  afterPrune: "
                    + posData.count
                    + " arcs arriving; leastCost="
                    + leastCost
                    + " vs threshold="
                    + maxCost
                    + " lastLeftWordID="
                    + lastLeftWordID);
          }

          // Only switch when a candidate exists, is within the threshold, and
          // is a different arc than the compound itself:
          if (leastIDX != -1 && leastCost <= maxCost && posData.backPos[leastIDX] != backPos) {
            // We should have pruned the altToken from the graph:
            assert posData.backPos[leastIDX] != backPos;

            // Save the current compound token, to output when
            // this alternate path joins back:
            altToken =
                new Token(
                    backID,
                    fragment,
                    backPos - lastBackTracePos,
                    length,
                    backType,
                    backPos,
                    getDict(backType));

            // Redirect our backtrace to 2nd best:
            bestIDX = leastIDX;
            nextBestIDX = posData.backIndex[bestIDX];

            backPos = posData.backPos[bestIDX];
            length = pos - backPos;
            backType = posData.backType[bestIDX];
            backID = posData.backID[bestIDX];
            backCount = 0;
            // System.out.println("  do alt token!");

          } else {
            // I think in theory it's possible there is no
            // 2nd best path, which is fine; in this case we
            // only output the compound token:
            // System.out.println("  no alt token! bestIDX=" + bestIDX);
          }
        }
      }

      // Offset of the current token within fragment:
      final int offset = backPos - lastBackTracePos;
      assert offset >= 0;

      if (altToken != null && altToken.getPosition() >= backPos) {

        // We've backtraced to the position where the
        // compound token starts; add it now:

        // The pruning we did when we created the altToken
        // ensures that the back trace will align back with
        // the start of the altToken:
        assert altToken.getPosition() == backPos : altToken.getPosition() + " vs " + backPos;

        // NOTE: not quite right: the compound token may
        // have had all punctuation back traced so far, but
        // then the decompounded token at this position is
        // not punctuation.  In this case backCount is 0,
        // but we should maybe add the altToken anyway...?

        if (backCount > 0) {
          // Count the token at this position too (it is added further below)
          // before recording how many positions the compound spans:
          backCount++;
          altToken.setPositionLength(backCount);
          if (VERBOSE) {
            System.out.println("    add altToken=" + altToken);
          }
          pending.add(altToken);
        } else {
          // This means alt token was all punct tokens:
          if (VERBOSE) {
            System.out.println("    discard all-punctuation altToken=" + altToken);
          }
          assert discardPunctuation;
        }
        altToken = null;
      }

      final Dictionary dict = getDict(backType);

      if (backType == Type.USER) {

        // Expand the phraseID we recorded into the actual
        // segmentation:
        // Element 0 is the base word id; each following element is the length
        // of one segment:
        final int[] wordIDAndLength = userDictionary.lookupSegmentation(backID);
        int wordID = wordIDAndLength[0];
        int current = 0;
        for (int j = 1; j < wordIDAndLength.length; j++) {
          final int len = wordIDAndLength[j];
          // System.out.println("    add user: len=" + len);
          pending.add(
              new Token(
                  wordID + j - 1,
                  fragment,
                  current + offset,
                  len,
                  Type.USER,
                  current + backPos,
                  dict));
          if (VERBOSE) {
            System.out.println("    add USER token=" + pending.get(pending.size() - 1));
          }
          current += len;
        }

        // Reverse the tokens we just added, because when we
        // serve them up from incrementToken we serve in
        // reverse:
        Collections.reverse(
            pending.subList(pending.size() - (wordIDAndLength.length - 1), pending.size()));

        backCount += wordIDAndLength.length - 1;
      } else {

        if (extendedMode && backType == Type.UNKNOWN) {
          // In EXTENDED mode we convert unknown word into
          // unigrams:
          int unigramTokenCount = 0;
          // Walk the word right to left so the unigrams land in pending in
          // the same reversed order as every other token:
          for (int i = length - 1; i >= 0; i--) {
            int charLen = 1;
            // Keep a surrogate pair together: when standing on a low
            // surrogate, step back one char and emit both as one unigram.
            if (i > 0 && Character.isLowSurrogate(fragment[offset + i])) {
              i--;
              charLen = 2;
            }
            // System.out.println("    extended tok offset="
            // + (offset + i));
            if (!discardPunctuation || !isPunctuation(fragment[offset + i])) {
              pending.add(
                  new Token(
                      CharacterDefinition.NGRAM,
                      fragment,
                      offset + i,
                      charLen,
                      Type.UNKNOWN,
                      backPos + i,
                      unkDictionary));
              unigramTokenCount++;
            }
          }
          backCount += unigramTokenCount;

        } else if (!discardPunctuation || length == 0 || !isPunctuation(fragment[offset])) {
          pending.add(new Token(backID, fragment, offset, length, backType, backPos, dict));
          if (VERBOSE) {
            System.out.println("    add token=" + pending.get(pending.size() - 1));
          }
          backCount++;
        } else {
          if (VERBOSE) {
            System.out.println(
                "    skip punctuation token=" + new String(fragment, offset, length));
          }
        }
      }

      // Step back along the chosen arc:
      lastLeftWordID = dict.getLeftId(backID);
      pos = backPos;
      bestIDX = nextBestIDX;
    }

    lastBackTracePos = endPos;

    if (VERBOSE) {
      System.out.println("  freeBefore pos=" + endPos);
    }
    // Notify the circular buffers that we are done with
    // these positions:
    buffer.freeBefore(endPos);
    positions.freeBefore(endPos);
  }