/**
 * Command-line entry point: parses {@code args} into an {@link Opts} instance and then trains
 * and saves a model via {@link #trainAndSaveModel(Opts)}.
 *
 * <p>If the arguments cannot be parsed, the parser's error message and the usage summary are
 * written to standard error and the JVM exits with status {@code 2}.
 *
 * @param args raw command-line arguments
 */
@SneakyThrows
public static void main(String[] args) {
  val opts = new Opts();
  val cmdLineParser = new CmdLineParser(opts);
  try {
    cmdLineParser.parseArgument(args);
  } catch (CmdLineException e) {
    // Report *why* parsing failed before showing usage; otherwise the user only sees the
    // generic option list and has to guess which argument was rejected.
    System.err.println(e.getMessage());
    cmdLineParser.printUsage(System.err);
    System.exit(2);
  }
  trainAndSaveModel(opts);
}
@SneakyThrows public static void trainAndSaveModel(Opts opts) { // Load labeled data List<String> templateLines = linesFromPath(opts.templateFile).collect(toList()); val predExtractor = ConllFormat.predicatesFromTemplate(templateLines.stream()); List<List<Pair<ConllFormat.Row, String>>> labeledData = ConllFormat.readData(linesFromPath(opts.trainPath), true) .stream() .map(x -> x.stream().map(y -> y.asLabeledPair().swap()).collect(Collectors.toList())) .collect(Collectors.toList()); // Split train/test data logger.info( "CRF training with {} threads and {} labeled examples", opts.numThreads, labeledData.size()); val trainTestPair = splitData(labeledData, opts.testSplitRatio); List<List<Pair<ConllFormat.Row, String>>> trainLabeledData = trainTestPair.getOne(); List<List<Pair<ConllFormat.Row, String>>> testLabeledData = trainTestPair.getTwo(); // Set up Train options CRFTrainer.Opts trainOpts = new CRFTrainer.Opts(); trainOpts.sigmaSq = opts.sigmaSquared; trainOpts.lbfgsHistorySize = opts.lbfgsHistorySize; trainOpts.minExpectedFeatureCount = (int) (1.0 / opts.featureKeepProb); trainOpts.numThreads = opts.numThreads; // Trainer CRFTrainer<String, ConllFormat.Row, String> trainer = new CRFTrainer<>(trainLabeledData, predExtractor, trainOpts); // Setup iteration callback, weird trick here where you require // the trainer to make a model for each iteration but then need // to modify the iteration-callback to use it Parallel.MROpts evalMrOpts = Parallel.MROpts.withIdAndThreads("mr-crf-train-eval", opts.numThreads); trainOpts.optimizerOpts.iterCallback = (weights) -> { CRFModel<String, ConllFormat.Row, String> crfModel = trainer.modelForWeights(weights); long start = System.currentTimeMillis(); List<List<Pair<String, ConllFormat.Row>>> trainEvalData = trainLabeledData .stream() .map(x -> x.stream().map(Pair::swap).collect(toList())) .collect(toList()); Evaluation<String> eval = Evaluation.compute(crfModel, trainEvalData, evalMrOpts); long stop = System.currentTimeMillis(); 
logger.info( "Train Accuracy: {} (took {} ms)", eval.tokenAccuracy.accuracy(), stop - start); if (!testLabeledData.isEmpty()) { start = System.currentTimeMillis(); List<List<Pair<String, ConllFormat.Row>>> testEvalData = testLabeledData .stream() .map(x -> x.stream().map(Pair::swap).collect(toList())) .collect(toList()); eval = Evaluation.compute(crfModel, testEvalData, evalMrOpts); stop = System.currentTimeMillis(); logger.info( "Test Accuracy: {} (took {} ms)", eval.tokenAccuracy.accuracy(), stop - start); } }; CRFModel<String, ConllFormat.Row, String> crfModel = trainer.train(trainLabeledData); Parallel.shutdownExecutor(evalMrOpts.executorService, Long.MAX_VALUE); Vector weights = crfModel.weights(); val dos = new DataOutputStream(new FileOutputStream(opts.modelPath)); logger.info("Writing model to {}", opts.modelPath); ConllFormat.saveModel(dos, templateLines, crfModel.featureEncoder, weights); }