public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
  FileSystem fs = output.getFileSystem(conf);

  Vector weightsPerLabel = null;
  Vector perLabelThetaNormalizer = null;
  Vector weightsPerFeature = null;
  Matrix weightsPerLabelAndFeature;
  float alphaI;

  FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
  try {
    // Binary layout: alphaI, then the three aggregate vectors, then one weight vector per label
    alphaI = in.readFloat();
    weightsPerFeature = VectorWritable.readVector(in);
    weightsPerLabel = VectorWritable.readVector(in);
    perLabelThetaNormalizer = VectorWritable.readVector(in);

    weightsPerLabelAndFeature = new SparseMatrix(weightsPerLabel.size(), weightsPerFeature.size());
    for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
      weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
    }
  } finally {
    Closeables.closeQuietly(in);
  }

  NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature,
      weightsPerLabel, perLabelThetaNormalizer, alphaI);
  model.validate();
  return model;
}
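The file that materialize() reads is produced by NaiveBayesModel.serialize(), which the training job below invokes as its last step. The write side has to mirror the read order above exactly; a minimal sketch of what it looks like, assuming the model keeps its constructor arguments in same-named fields:

public void serialize(Path output, Configuration conf) throws IOException {
  FileSystem fs = output.getFileSystem(conf);
  FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
  try {
    // Same order as materialize(): alphaI, the three aggregate vectors, then one row per label
    out.writeFloat(alphaI);
    VectorWritable.writeVector(out, weightsPerFeature);
    VectorWritable.writeVector(out, weightsPerLabel);
    VectorWritable.writeVector(out, perLabelThetaNormalizer);
    for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
      VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
    }
  } finally {
    Closeables.closeQuietly(out);
  }
}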
@Override
public int run(String[] args) throws Exception {
  String path = System.getProperty("user.dir");

  addInputOption();
  addOutputOption();
  addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
  addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
  addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
  addOption(DefaultOptionCreator.overwriteOption().create());

  // Build the label index from the training data and remember where it lives
  Path labPath = new Path(path + "/../out/labelindex/");
  long labelSize = createLabelIndex(labPath);

  // Note: the options declared above are never parsed here;
  // the smoothing parameter and the complementary flag are hard-coded
  float alphaI = 1.0f;
  boolean trainComplementary = true;

  HadoopUtil.setSerializations(getConf());
  HadoopUtil.cacheFiles(labPath, getConf());

  // Clean up intermediate output from previous runs
  HadoopUtil.delete(getConf(), new Path("/tmp/summedObservations"));
  HadoopUtil.delete(getConf(), new Path("/tmp/weights"));
  HadoopUtil.delete(getConf(), new Path("/tmp/thetas"));

  // Add up all the vectors with the same labels, while mapping the labels into our index
  Job indexInstances = prepareJob(
      new Path(path + "/../out/training"),
      new Path("/tmp/summedObservations"),
      SequenceFileInputFormat.class,
      IndexInstancesMapper.class,
      IntWritable.class,
      VectorWritable.class,
      VectorSumReducer.class,
      IntWritable.class,
      VectorWritable.class,
      SequenceFileOutputFormat.class);
  indexInstances.setCombinerClass(VectorSumReducer.class);
  boolean succeeded = indexInstances.waitForCompletion(true);
  if (!succeeded) {
    return -1;
  }

  // Sum up all the weights from the previous step, per label and per feature
  Job weightSummer = prepareJob(
      new Path("/tmp/summedObservations"),
      new Path("/tmp/weights"),
      SequenceFileInputFormat.class,
      WeightsMapper.class,
      Text.class,
      VectorWritable.class,
      VectorSumReducer.class,
      Text.class,
      VectorWritable.class,
      SequenceFileOutputFormat.class);
  weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
  weightSummer.setCombinerClass(VectorSumReducer.class);
  succeeded = weightSummer.waitForCompletion(true);
  if (!succeeded) {
    return -1;
  }

  // Put the per label and per feature vectors into the cache
  HadoopUtil.cacheFiles(new Path("/tmp/weights"), getConf());

  if (trainComplementary) {
    // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
    // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
    Job thetaSummer = prepareJob(
        new Path("/tmp/summedObservations"),
        new Path("/tmp/thetas"),
        SequenceFileInputFormat.class,
        ThetaMapper.class,
        Text.class,
        VectorWritable.class,
        VectorSumReducer.class,
        Text.class,
        VectorWritable.class,
        SequenceFileOutputFormat.class);
    thetaSummer.setCombinerClass(VectorSumReducer.class);
    thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
    thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
    succeeded = thetaSummer.waitForCompletion(true);
    if (!succeeded) {
      return -1;
    }
  }

  // Put the per label theta normalizers into the cache
  HadoopUtil.cacheFiles(new Path("/tmp/thetas"), getConf());

  // Validate our model and then write it out to the official output
  getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
  getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary);
  NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(new Path("/tmp/"), getConf());
  naiveBayesModel.validate();
  naiveBayesModel.serialize(new Path(path + "/../out/model"), getConf());

  return 0;
}
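The createLabelIndex() helper called near the top of run() is not shown above. A possible sketch, assuming the training SequenceFile keys have the form "/<label>/<docId>" (as produced by mahout seqdirectory/seq2sparse) and that BayesUtils.writeLabelIndex and the SequenceFileDirIterable/PathType/PathFilters classes from org.apache.mahout.common are available:

// Hypothetical sketch of createLabelIndex: collects the distinct labels from the training
// keys and persists them as the label index that is later cached via
// HadoopUtil.cacheFiles(labPath, getConf()) for IndexInstancesMapper.
private long createLabelIndex(Path labPath) throws IOException {
  String path = System.getProperty("user.dir");
  Set<String> labels = new TreeSet<String>();
  for (Pair<Text, VectorWritable> record :
      new SequenceFileDirIterable<Text, VectorWritable>(
          new Path(path + "/../out/training"), PathType.LIST, PathFilters.logsCRCFilter(), getConf())) {
    // key format assumed to be "/<label>/<docId>"; keep only the label component
    labels.add(record.getFirst().toString().split("/")[1]);
  }
  return BayesUtils.writeLabelIndex(getConf(), labels, labPath);
}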
public static void main(String[] args) throws Exception {
  if (args.length < 5) {
    System.out.println("Arguments: [model] [label index] [dictionary] [document frequency] [customer complaints file]");
    return;
  }
  String modelPath = args[0];
  String labelIndexPath = args[1];
  String dictionaryPath = args[2];
  String documentFrequencyPath = args[3];
  String carsPath = args[4];

  Configuration configuration = new Configuration();

  // model is a matrix (wordId, labelId) => probability score
  NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
  StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

  // labels is a map labelId => label name
  Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
  Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
  Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration, new Path(documentFrequencyPath));

  // analyzer used to extract words from the complaint description
  Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_43);

  int labelCount = labels.size();
  // the total number of training documents is stored under the special key -1
  int documentCount = documentFrequency.get(-1).intValue();

  System.out.println("Number of labels: " + labelCount);
  System.out.println("Number of documents in training set: " + documentCount);

  BufferedReader reader = new BufferedReader(new FileReader(carsPath));
  while (true) {
    String line = reader.readLine();
    if (line == null) {
      break;
    }

    String[] tokens = line.split("\t", 47);
    String cmplid = tokens[0];
    String cdescr = tokens[19];

    System.out.println("Complaint id: " + cmplid + "\t" + cdescr);

    Multiset<String> words = ConcurrentHashMultiset.create();

    // extract words from the complaint description
    TokenStream ts = analyzer.tokenStream("text", new StringReader(cdescr));
    CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
    ts.reset();
    int wordCount = 0;
    while (ts.incrementToken()) {
      if (termAtt.length() > 0) {
        String word = termAtt.toString();
        Integer wordId = dictionary.get(word);
        // if the word is not in the dictionary, skip it
        if (wordId != null) {
          words.add(word);
          wordCount++;
        }
      }
    }
    ts.end();
    ts.close();

    // create vector wordId => weight using tfidf
    // (cardinality must be at least the dictionary size so every wordId fits)
    Vector vector = new RandomAccessSparseVector(dictionary.size());
    TFIDF tfidf = new TFIDF();
    for (Multiset.Entry<String> entry : words.entrySet()) {
      String word = entry.getElement();
      int count = entry.getCount();
      Integer wordId = dictionary.get(word);
      Long freq = documentFrequency.get(wordId);
      double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
      vector.setQuick(wordId, tfIdfValue);
    }

    // With the classifier, we get one score for each label.
    // The label with the highest score is the one the complaint is most likely associated with.
    Vector resultVector = classifier.classifyFull(vector);
    double bestScore = -Double.MAX_VALUE;
    int bestCategoryId = -1;
    for (Element element : resultVector.all()) {
      int categoryId = element.index();
      double score = element.get();
      if (score > bestScore) {
        bestScore = score;
        bestCategoryId = categoryId;
      }
      System.out.print(" " + labels.get(categoryId) + ": " + score);
    }
    System.out.println(" => " + labels.get(bestCategoryId));
  }
  analyzer.close();
  reader.close();
}
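The readDictionnary() and readDocumentFrequency() helpers used in main() are not shown above. A possible sketch, assuming they read the dictionary.file-* (Text => IntWritable) and df-count frequency.file-* (IntWritable => LongWritable) SequenceFiles produced by mahout seq2sparse, using org.apache.mahout.common.Pair and org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable:

public static Map<String, Integer> readDictionnary(Configuration conf, Path dictionaryPath) {
  // word => wordId, as assigned by seq2sparse
  Map<String, Integer> dictionary = new HashMap<String, Integer>();
  for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(dictionaryPath, true, conf)) {
    dictionary.put(pair.getFirst().toString(), pair.getSecond().get());
  }
  return dictionary;
}

public static Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyPath) {
  // wordId => number of documents containing the word
  Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();
  for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(documentFrequencyPath, true, conf)) {
    documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
  }
  return documentFrequency;
}

The entry stored under key -1 in the document-frequency file holds the total number of training documents, which is why main() reads documentFrequency.get(-1) to get documentCount.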