/**
 * Writes the label index for the training data to {@code labPath}.
 *
 * <p>The training records are read as {@code (Text, IntWritable)} pairs from the sequence files
 * under {@code <user.dir>/../out/training}, and handed to
 * {@code BayesUtils.writeLabelIndex} together with the job configuration.
 *
 * @param labPath location the label index is written to
 * @return the value reported by {@code BayesUtils.writeLabelIndex} (used by the caller as the
 *     number of labels)
 * @throws IOException if the training data cannot be read or the index cannot be written
 */
private long createLabelIndex(Path labPath) throws IOException {
    String workingDir = System.getProperty("user.dir");
    Path trainingDir = new Path(workingDir + "/../out/training");

    Iterable<Pair<Text, IntWritable>> trainingRecords =
        new SequenceFileDirIterable<Text, IntWritable>(
            trainingDir, PathType.LIST, PathFilters.logsCRCFilter(), getConf());

    return BayesUtils.writeLabelIndex(getConf(), labPath, trainingRecords);
  }
  /**
   * Runs the training pipeline and serializes the resulting naive Bayes model.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>register the command-line options (note: {@code args} is never read in this method;
   *       every path and parameter used below is fixed in the code),
   *   <li>write the label index to {@code <user.dir>/../out/labelindex/},
   *   <li>delete any previous intermediate output under {@code /tmp},
   *   <li>run the "index instances" job, then the "weight summer" job, then the "theta summer"
   *       job, each reading the output of an earlier step,
   *   <li>read the model back from {@code /tmp/}, validate it and serialize it to
   *       {@code <user.dir>/../out/model}.
   * </ol>
   *
   * @param args command-line arguments (unused here)
   * @return {@code 0} on success, {@code -1} as soon as one of the jobs fails
   * @throws Exception if any of the underlying Hadoop/Mahout calls fails
   */
  @Override
  public int run(String[] args) throws Exception {
    final String workingDir = System.getProperty("user.dir");

    // Option registration.
    addInputOption();
    addOutputOption();
    addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
    addOption(
        buildOption(
            TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
    addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
    addOption(DefaultOptionCreator.overwriteOption().create());

    // Label index.
    final Path labelIndexPath = new Path(workingDir + "/../out/labelindex/");
    final long numLabels = createLabelIndex(labelIndexPath);

    // Fixed training parameters.
    final float smoothingAlpha = 1.0F;
    final boolean complementary = true;

    HadoopUtil.setSerializations(getConf());
    HadoopUtil.cacheFiles(labelIndexPath, getConf());

    // Intermediate locations shared by the jobs below; start from a clean slate.
    final Path summedObservationsPath = new Path("/tmp/summedObservations");
    final Path weightsPath = new Path("/tmp/weights");
    final Path thetasPath = new Path("/tmp/thetas");
    HadoopUtil.delete(getConf(), summedObservationsPath);
    HadoopUtil.delete(getConf(), weightsPath);
    HadoopUtil.delete(getConf(), thetasPath);

    // Job 1: add up all the vectors with the same label, mapping the labels into our index.
    final Path trainingPath = new Path(workingDir + "/../out/training");
    Job indexInstancesJob =
        prepareJob(
            trainingPath,
            summedObservationsPath,
            SequenceFileInputFormat.class,
            IndexInstancesMapper.class,
            IntWritable.class,
            VectorWritable.class,
            VectorSumReducer.class,
            IntWritable.class,
            VectorWritable.class,
            SequenceFileOutputFormat.class);
    indexInstancesJob.setCombinerClass(VectorSumReducer.class);
    if (!indexInstancesJob.waitForCompletion(true)) {
      return -1;
    }

    // Job 2: sum up all the weights from the previous step, per label and per feature.
    Job weightSummerJob =
        prepareJob(
            summedObservationsPath,
            weightsPath,
            SequenceFileInputFormat.class,
            WeightsMapper.class,
            Text.class,
            VectorWritable.class,
            VectorSumReducer.class,
            Text.class,
            VectorWritable.class,
            SequenceFileOutputFormat.class);
    weightSummerJob.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(numLabels));
    weightSummerJob.setCombinerClass(VectorSumReducer.class);
    if (!weightSummerJob.waitForCompletion(true)) {
      return -1;
    }

    // The per label and per feature vectors go into the cache.
    HadoopUtil.cacheFiles(weightsPath, getConf());

    if (complementary) {
      // Job 3: calculate the per label theta normalizers, written out to the
      // LABEL_THETA_NORMALIZER vector.
      // See http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight
      // Magnitude Errors.
      Job thetaSummerJob =
          prepareJob(
              summedObservationsPath,
              thetasPath,
              SequenceFileInputFormat.class,
              ThetaMapper.class,
              Text.class,
              VectorWritable.class,
              VectorSumReducer.class,
              Text.class,
              VectorWritable.class,
              SequenceFileOutputFormat.class);
      thetaSummerJob.setCombinerClass(VectorSumReducer.class);
      thetaSummerJob.getConfiguration().setFloat(ThetaMapper.ALPHA_I, smoothingAlpha);
      thetaSummerJob.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, complementary);
      if (!thetaSummerJob.waitForCompletion(true)) {
        return -1;
      }
    }

    // The per label theta normalizers go into the cache.
    HadoopUtil.cacheFiles(thetasPath, getConf());

    // Validate the model, then write it out to the official output location.
    getConf().setFloat(ThetaMapper.ALPHA_I, smoothingAlpha);
    getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, complementary);
    NaiveBayesModel trainedModel = BayesUtils.readModelFromDir(new Path("/tmp/"), getConf());
    trainedModel.validate();
    trainedModel.serialize(new Path(workingDir + "/../out/model"), getConf());

    return 0;
  }
  // Example #3
  // 0
  /**
   * Classifies every record of a tab-separated file with a trained naive Bayes model and prints,
   * for each record, the score of every label followed by the best-scoring label.
   *
   * <p>Expected arguments, in order: model path, label index path, dictionary path, document
   * frequency path, and the path of the tab-separated input file. With fewer than five
   * arguments a usage line is printed and the method returns.
   *
   * <p>For each input line, field 0 is used as the complaint id and field 19 as the complaint
   * description. The description is tokenized, words missing from the dictionary are dropped,
   * and the remaining words are turned into a TF-IDF vector that is passed to the classifier.
   * Lines with fewer than 20 fields are skipped with a message on standard error.
   *
   * @param args see above
   * @throws IllegalStateException if the document frequency data has no entry for key
   *     {@code -1}, which this method reads as the total number of training documents
   * @throws Exception if loading the model/dictionaries, reading the input, or tokenizing fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 5) {
      System.out.println(
          "Arguments: [model] [label index] [dictionnary] [document frequency] [Customer description]");
      return;
    }
    String modelPath = args[0];
    String labelIndexPath = args[1];
    String dictionaryPath = args[2];
    String documentFrequencyPath = args[3];
    String carsPath = args[4];

    // Column positions inside one tab-separated input record.
    final int idColumn = 0;
    final int descriptionColumn = 19;

    Configuration configuration = new Configuration();

    // model is a matrix (wordId, labelId) => probability score
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);

    StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

    // labels is a map classId => label
    Map<Integer, String> labels =
        BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
    Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
    Map<Integer, Long> documentFrequency =
        readDocumentFrequency(configuration, new Path(documentFrequencyPath));

    // The total document count is looked up under the key -1; fail with a clear message
    // instead of a bare NullPointerException when it is missing.
    Long totalDocuments = documentFrequency.get(-1);
    if (totalDocuments == null) {
      throw new IllegalStateException(
          "Document frequency data at "
              + documentFrequencyPath
              + " has no entry for key -1 (total document count)");
    }

    int labelCount = labels.size();
    int documentCount = totalDocuments.intValue();

    System.out.println("Number of labels: " + labelCount);
    System.out.println("Number of documents in training set: " + documentCount);

    // analyzer used to extract words from the complaint description
    Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);
    try {
      BufferedReader reader = new BufferedReader(new FileReader(carsPath));
      try {
        String line;
        while ((line = reader.readLine()) != null) {
          String[] tokens = line.split("\t", 47);
          if (tokens.length <= descriptionColumn) {
            // Not enough fields to contain a description: skip rather than crash the whole run.
            System.err.println(
                "Skipping malformed line: expected at least "
                    + (descriptionColumn + 1)
                    + " tab-separated fields but found "
                    + tokens.length);
            continue;
          }
          String cmplid = tokens[idColumn];
          String cdescr = tokens[descriptionColumn];

          System.out.println("Complaint id: " + cmplid + "\t" + cdescr);

          Multiset<String> words = ConcurrentHashMultiset.create();
          int wordCount = 0;

          // extract words from complaint description
          TokenStream ts = analyzer.tokenStream("text", new StringReader(cdescr));
          try {
            CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
              if (termAtt.length() > 0) {
                String word = termAtt.toString();
                // if the word is not in the dictionary, skip it
                if (dictionary.get(word) != null) {
                  words.add(word);
                  wordCount++;
                }
              }
            }
            // Complete the TokenStream consumer workflow before the analyzer is reused for
            // the next line.
            ts.end();
          } finally {
            ts.close();
          }

          // create vector wordId => weight using tfidf
          Vector vector = new RandomAccessSparseVector(1000);
          TFIDF tfidf = new TFIDF();
          for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            if (freq == null) {
              // Word is in the dictionary but has no document frequency: it cannot be weighted.
              continue;
            }
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
          }

          // With the classifier, we get one score for each label.
          // The label with the highest score is the one the complaint is most likely to
          // be associated with.
          Vector resultVector = classifier.classifyFull(vector);
          double bestScore = -Double.MAX_VALUE;
          int bestCategoryId = -1;
          for (Element element : resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
              bestScore = score;
              bestCategoryId = categoryId;
            }
            System.out.print("  " + labels.get(categoryId) + ": " + score);
          }
          System.out.println(" => " + labels.get(bestCategoryId));
        }
      } finally {
        reader.close();
      }
    } finally {
      analyzer.close();
    }
  }