示例#1
0
  /**
   * Reads the input file line by line, records each document id in {@code docIds}, and pushes
   * the text of each line through a tokenizing pipe into a new instance list.
   *
   * <p>A line is used only if it contains at least two commas: the text before the first comma
   * is stored as the document id, the text after the second comma becomes the instance data.
   * Lines that do not match are skipped silently.
   *
   * @return the instances built from the matching lines, in file order
   * @throws IOException if the file cannot be opened or an I/O error occurs while reading it
   */
  private InstanceList readFile() throws IOException {

    // Compiled once up front instead of once per input line.
    Pattern linePattern = Pattern.compile("^(.*?),(.*?),(.*)$");

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    // Tokens are runs of two or more Unicode letters.
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList testing = new InstanceList(new SerialPipes(pipeList));

    // Opened right before the guarded region so nothing can throw between open and close.
    Scanner scanner = new Scanner(new FileInputStream(fileName), encoding);
    try {
      while (scanner.hasNextLine()) {
        // Drop any stray carriage returns left over from CRLF input.
        String text = scanner.nextLine().replace("\r", "");

        Matcher matcher = linePattern.matcher(text);
        if (matcher.find()) {
          docIds.add(matcher.group(1));
          testing.addThruPipe(new Instance(matcher.group(3), null, "test instance", null));
        }
      }

      // Scanner swallows read errors and simply reports "no more lines"; surface them so a
      // truncated read is not mistaken for a complete file.
      IOException readError = scanner.ioException();
      if (readError != null) {
        throw readError;
      }
    } finally {
      scanner.close();
    }

    return testing;
  }
示例#2
0
  /**
   * Builds a square boolean matrix marking which label-to-label transitions occur in the
   * training set. This is (mostly) copied from CRF4.java.
   *
   * <p>Entry {@code [a][b]} is set to {@code true} when label {@code b} directly follows label
   * {@code a} somewhere in the target sequence of a training instance. If {@code start} is
   * non-null, every entry in the row of the start label is also set to {@code true}.
   *
   * @param outputAlphabet alphabet used to map labels to row/column indices
   * @param trainingSet instances whose targets are cast to {@code FeatureSequence}
   * @param start label of the start state, or {@code null} for no start-state handling
   * @return the connection matrix, sized by the alphabet size at the time of the call
   */
  public boolean[][] labelConnectionsIn(
      Alphabet outputAlphabet, InstanceList trainingSet, String start) {
    final int labelCount = outputAlphabet.size();
    final boolean[][] linked = new boolean[labelCount][labelCount];

    // Mark every adjacent (previous label -> next label) pair seen in the training targets.
    for (int n = 0; n < trainingSet.size(); n++) {
      FeatureSequence labels = (FeatureSequence) trainingSet.getInstance(n).getTarget();
      int pos = 1;
      while (pos < labels.size()) {
        int from = outputAlphabet.lookupIndex(labels.get(pos - 1));
        int to = outputAlphabet.lookupIndex(labels.get(pos));
        assert (from >= 0 && to >= 0);
        linked[from][to] = true;
        pos++;
      }
    }

    // Without a start label there is nothing more to do.
    if (start == null) {
      return linked;
    }

    // The start state is connected to every label.
    int startRow = outputAlphabet.lookupIndex(start);
    for (int col = 0; col < outputAlphabet.size(); col++) {
      linked[startRow][col] = true;
    }
    return linked;
  }
示例#3
0
  /**
   * Loads the file named by {@code fileName} (decoded as UTF-8) into a new instance list via a
   * {@code CsvIterator} and a tokenizing pipe.
   *
   * <p>Each line is split by the regex below into three groups; group 3 is used as the data,
   * group 2 as the label and group 1 as the name.
   *
   * @return the instances produced from the file
   * @throws Exception if the file cannot be opened, read or closed, or if the pipe fails
   */
  private InstanceList generateInstanceList() throws Exception {

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    // Tokens are runs of two or more Unicode letters.
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList instances = new InstanceList(new SerialPipes(pipeList));

    Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
    try {
      instances.addThruPipe(
          new CsvIterator(
              fileReader,
              Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
              3,
              2,
              1)); // data, label, name fields
    } finally {
      // The reader was previously never closed, leaking the underlying file handle.
      // NOTE(review): assumes addThruPipe drains the iterator before returning — confirm.
      fileReader.close();
    }

    return instances;
  }
示例#4
0
  /**
   * Loads the topic model from {@code inferencerFile}, samples a topic distribution for every
   * instance of the input file, and prints one line per instance to standard output.
   *
   * <p>Each output line has the form {@code docId,topic,prob topic,prob ...} and lists at most
   * {@code topN} topics, in the order produced by sorting with {@code CustomComparator}.
   * Any exception is caught and reported on standard error; nothing is rethrown.
   */
  public void doInference() {

    try {

      TopicInferencer inferencer =
          ParallelTopicModel.read(new File(inferencerFile)).getInferencer();

      // The returned list is deliberately ignored: readFile() is called because it fills
      // docIds, which is read when printing below. The instances themselves come from
      // generateInstanceList().
      readFile();
      InstanceList testing = generateInstanceList();

      for (int docIndex = 0; docIndex < testing.size(); docIndex++) {

        // NOTE(review): 10, 1, 5 are presumably iterations / thinning / burn-in — confirm
        // against the TopicInferencer API.
        double[] topicWeights =
            inferencer.getSampledDistribution(testing.get(docIndex), 10, 1, 5);

        // Pair every topic index with its weight so the two can be sorted together.
        ArrayList<Pair<Integer, Double>> ranked = new ArrayList<Pair<Integer, Double>>();
        for (int topic = 0; topic < topicWeights.length; topic++) {
          ranked.add(new Pair<Integer, Double>(topic, topicWeights[topic]));
        }
        Collections.sort(ranked, new CustomComparator());

        // Emit the leading entries as space-separated "topic,weight" pairs.
        StringBuilder line = new StringBuilder();
        for (int rank = 0; rank < topicWeights.length && rank < topN; rank++) {
          if (rank > 0) {
            line.append(" ");
          }
          Pair<Integer, Double> entry = ranked.get(rank);
          line.append(entry.getFirst().toString())
              .append(",")
              .append(entry.getSecond().toString());
        }

        System.out.println(docIds.get(docIndex) + "," + line.toString());
      }

    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
示例#5
0
  /**
   * Smoke test: loads the topic model from {@code inferencerFile}, reads the instances of
   * {@code fileName} (decoded as UTF-8), samples a topic distribution for the instance at
   * index 1 and prints up to the first 1000 topic probabilities to standard output.
   *
   * @throws Exception if the model or the input file cannot be read, or if the input yields
   *     fewer than two instances
   */
  public void test() throws Exception {

    ParallelTopicModel model = ParallelTopicModel.read(new File(inferencerFile));
    TopicInferencer inferencer = model.getInferencer();

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    // Tokens are runs of two or more Unicode letters.
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}\\p{L}+")));
    pipeList.add(new TokenSequence2FeatureSequence());

    InstanceList instances = new InstanceList(new SerialPipes(pipeList));
    Reader fileReader = new InputStreamReader(new FileInputStream(new File(fileName)), "UTF-8");
    try {
      instances.addThruPipe(
          new CsvIterator(
              fileReader,
              Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
              3,
              2,
              1)); // data, label, name fields
    } finally {
      // The reader was previously never closed, leaking the underlying file handle.
      // NOTE(review): assumes addThruPipe drains the iterator before returning — confirm.
      fileReader.close();
    }

    // Index 1 is the second instance of the file.
    double[] testProbabilities = inferencer.getSampledDistribution(instances.get(1), 10, 1, 5);

    // Bound by the actual array length: the former fixed limit of 1000 threw
    // ArrayIndexOutOfBoundsException for models with fewer than 1000 topics.
    int printCount = Math.min(testProbabilities.length, 1000);
    for (int i = 0; i < printCount; i++) {
      System.out.println(i + ": " + testProbabilities[i]);
    }
  }
  /**
   * Runs a gzip-compressed training file through a feature-extraction pipeline whose last
   * stage is a {@code SequencePrintingPipe} writing to {@code test.out} in the working
   * directory.
   *
   * <p>The input is split into instances on blank lines (lines matching {@code ^\s*$}).
   *
   * @param trainingFilename path of the gzip-compressed training file
   * @throws IOException if {@code test.out} cannot be created or the training file cannot be
   *     opened, read or closed
   */
  public TestCRFPipe(String trainingFilename) throws IOException {

    ArrayList<Pipe> pipes = new ArrayList<Pipe>();

    PrintWriter out = new PrintWriter("test.out");
    try {
      // Offsets handed to OffsetConjunctions: -1, +1, and the pair (-2, -1).
      int[][] conjunctions = new int[3][];
      conjunctions[0] = new int[] {-1};
      conjunctions[1] = new int[] {1};
      conjunctions[2] = new int[] {-2, -1};

      pipes.add(new SimpleTaggerSentence2TokenSequence());
      // pipes.add(new FeaturesInWindow("PREV-", -1, 1));
      // pipes.add(new FeaturesInWindow("NEXT-", 1, 2));
      pipes.add(new OffsetConjunctions(conjunctions));
      pipes.add(new TokenTextCharSuffix("C1=", 1));
      pipes.add(new TokenTextCharSuffix("C2=", 2));
      pipes.add(new TokenTextCharSuffix("C3=", 3));
      pipes.add(new RegexMatches("CAPITALIZED", Pattern.compile("^\\p{Lu}.*")));
      pipes.add(new RegexMatches("STARTSNUMBER", Pattern.compile("^[0-9].*")));
      pipes.add(new RegexMatches("HYPHENATED", Pattern.compile(".*\\-.*")));
      pipes.add(new RegexMatches("DOLLARSIGN", Pattern.compile("\\$.*")));
      pipes.add(new TokenFirstPosition("FIRSTTOKEN"));
      pipes.add(new TokenSequence2FeatureVectorSequence());
      pipes.add(new SequencePrintingPipe(out));

      Pipe pipe = new SerialPipes(pipes);

      InstanceList trainingInstances = new InstanceList(pipe);

      // No charset is given, so the reader decodes with the platform default (as before).
      BufferedReader trainingReader =
          new BufferedReader(
              new InputStreamReader(new GZIPInputStream(new FileInputStream(trainingFilename))));
      try {
        trainingInstances.addThruPipe(
            new LineGroupIterator(trainingReader, Pattern.compile("^\\s*$"), true));
      } finally {
        // The gzip stream was previously never closed.
        // NOTE(review): assumes addThruPipe drains the iterator before returning — confirm.
        trainingReader.close();
      }
    } finally {
      // Always close the output file, even when building the pipe or reading the input fails.
      out.close();
    }
  }
示例#7
0
  /**
   * Builds and trains a {@code CRF4} model from a training file, configured by {@code crfInfo}.
   *
   * <p>The steps are: load the training file through a SimpleTagger pipe, create the CRF,
   * apply the Gaussian prior variance and transduction type, add states (either from the
   * explicit {@code stateInfoList} or from the selected state structure), set feature
   * selections for the configured weight groups, and finally train.
   *
   * @param trainingFile file with the tagged training data; instances are separated by blank
   *     lines (lines matching {@code ^\s*$})
   * @param crfInfo configuration read by this method: default label, prior variance,
   *     transduction type, states or state structure, weight groups, max iterations
   * @return the trained CRF
   * @throws FileNotFoundException if {@code trainingFile} cannot be opened
   * @throws RuntimeException if no state list is given and {@code crfInfo.stateStructure} is
   *     not one of the structures handled below
   */
  public static CRF4 createCRF(File trainingFile, CRFInfo crfInfo) throws FileNotFoundException {
    Reader trainingFileReader = new FileReader(trainingFile);

    // Create a pipe that we can use to convert the training
    // file to a feature vector sequence.
    Pipe p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();

    // Create a new instancelist to hold the training data.
    InstanceList trainingData;
    try {
      // The training file does contain tags (aka targets)
      p.setTargetProcessing(true);

      // Register the default tag with the pipe, by looking it up
      // in the targetAlphabet before we look up any other tag.
      p.getTargetAlphabet().lookupIndex(crfInfo.defaultLabel);

      trainingData = new InstanceList(p);

      // Read in the training data.
      trainingData.add(new LineGroupIterator(trainingFileReader, Pattern.compile("^\\s*$"), true));
    } finally {
      // The reader was previously never closed, leaking the file handle.
      // NOTE(review): assumes add() drains the iterator before returning — confirm.
      try {
        trainingFileReader.close();
      } catch (java.io.IOException ignored) {
        // Failure to close a read-only stream loses no data, and the method signature (kept
        // unchanged for callers) does not allow IOException.
      }
    }

    // Create the CRF model.
    CRF4 crf = new CRF4(p, null);

    // Set various config options
    crf.setGaussianPriorVariance(crfInfo.gaussianVariance);
    crf.setTransductionType(crfInfo.transductionType);

    // Set up the model's states: an explicit state list takes precedence over the
    // predefined state structures.
    if (crfInfo.stateInfoList != null) {
      Iterator stateIter = crfInfo.stateInfoList.iterator();
      while (stateIter.hasNext()) {
        CRFInfo.StateInfo state = (CRFInfo.StateInfo) stateIter.next();
        crf.addState(
            state.name,
            state.initialCost,
            state.finalCost,
            state.destinationNames,
            state.labelNames,
            state.weightNames);
      }
    } else if (crfInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) {
      crf.addStatesForLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) {
      crf.addStatesForHalfLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) {
      crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) {
      crf.addStatesForBiLabelsConnectedAsIn(trainingData);
    } else {
      throw new RuntimeException("Unexpected state structure " + crfInfo.stateStructure);
    }

    // Set up the weight groups.
    if (crfInfo.weightGroupInfoList != null) {
      Iterator wgIter = crfInfo.weightGroupInfoList.iterator();
      while (wgIter.hasNext()) {
        CRFInfo.WeightGroupInfo wg = (CRFInfo.WeightGroupInfo) wgIter.next();
        FeatureSelection fs =
            FeatureSelection.createFromRegex(
                crf.getInputAlphabet(), Pattern.compile(wg.featureSelectionRegex));
        crf.setFeatureSelection(crf.getWeightsIndex(wg.name), fs);
      }
    }

    // Train the CRF.
    crf.train(trainingData, null, null, null, crfInfo.maxIterations);

    return crf;
  }