/**
 * Trains a {@link MaxentModel} from a stream of events.
 *
 * <p>The call is rejected when {@code TrainerFactory.isSupportEvent} reports that the supplied
 * parameters do not support event training; otherwise the work is delegated to the
 * {@link EventTrainer} returned by {@code TrainerFactory.getEventTrainer}.
 *
 * @param events the events to train on
 * @param trainParams the training settings passed to the {@link TrainerFactory}
 * @param reportMap the map passed to the {@link TrainerFactory} together with the settings
 * @return the model produced by the event trainer
 * @throws IOException propagated from the trainer
 * @throws IllegalArgumentException if event training is not supported for {@code trainParams}
 * @deprecated Use {@link TrainerFactory#getEventTrainer(Map, Map)} to get an
 *     {@link EventTrainer} instead.
 */
public static MaxentModel train(
    ObjectStream<Event> events, Map<String, String> trainParams, Map<String, String> reportMap)
    throws IOException {
  final boolean eventTrainingSupported = TrainerFactory.isSupportEvent(trainParams);
  if (eventTrainingSupported) {
    return TrainerFactory.getEventTrainer(trainParams, reportMap).train(events);
  }
  throw new IllegalArgumentException("EventTrain is not supported");
}
public static POSModel train( String languageCode, ObjectStream<POSSample> samples, TrainingParameters trainParams, POSTaggerFactory posFactory) throws IOException { String beamSizeString = trainParams.getSettings().get(BeamSearch.BEAM_SIZE_PARAMETER); int beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; if (beamSizeString != null) { beamSize = Integer.parseInt(beamSizeString); } POSContextGenerator contextGenerator = posFactory.getPOSContextGenerator(); Map<String, String> manifestInfoEntries = new HashMap<String, String>(); TrainerType trainerType = TrainerFactory.getTrainerType(trainParams.getSettings()); MaxentModel posModel = null; SequenceClassificationModel<String> seqPosModel = null; if (TrainerType.EVENT_MODEL_TRAINER.equals(trainerType)) { ObjectStream<Event> es = new POSSampleEventStream(samples, contextGenerator); EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams.getSettings(), manifestInfoEntries); posModel = trainer.train(es); } else if (TrainerType.EVENT_MODEL_SEQUENCE_TRAINER.equals(trainerType)) { POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); EventModelSequenceTrainer trainer = TrainerFactory.getEventModelSequenceTrainer( trainParams.getSettings(), manifestInfoEntries); posModel = trainer.train(ss); } else if (TrainerType.SEQUENCE_TRAINER.equals(trainerType)) { SequenceTrainer trainer = TrainerFactory.getSequenceModelTrainer(trainParams.getSettings(), manifestInfoEntries); // TODO: This will probably cause issue, since the feature generator uses the outcomes array POSSampleSequenceStream ss = new POSSampleSequenceStream(samples, contextGenerator); seqPosModel = trainer.train(ss); } else { throw new IllegalArgumentException("Trainer type is not supported: " + trainerType); } if (posModel != null) { return new POSModel(languageCode, posModel, beamSize, manifestInfoEntries, posFactory); } else { return new POSModel(languageCode, seqPosModel, manifestInfoEntries, posFactory); } }
/**
 * Reads training parameters from a file, if one was given.
 *
 * <p>A {@code null} file name yields {@code null}. Otherwise the file is checked with
 * {@code checkInputFile}, parsed into a {@link TrainingParameters} instance, and its settings
 * are validated with {@code TrainerFactory.isValid}.
 *
 * @param paramFile the path of the parameter file, or {@code null} if none was supplied
 * @param supportSequenceTraining whether sequence training is supported; not consulted by this
 *     method
 * @return the loaded parameters, or {@code null} if {@code paramFile} is {@code null}
 * @throws TerminateToolException if reading the file fails or the loaded settings are invalid
 */
private static TrainingParameters loadTrainingParameters(
    final String paramFile, final boolean supportSequenceTraining) {
  if (paramFile == null) {
    return null;
  }

  final File source = new File(paramFile);
  checkInputFile("Training Parameter", source);

  final TrainingParameters loaded;
  InputStream in = null;
  try {
    in = new FileInputStream(source);
    loaded = new opennlp.tools.util.TrainingParameters(in);
  } catch (IOException e) {
    throw new TerminateToolException(-1, "Error during parameters loading: " + e.getMessage(), e);
  } finally {
    if (in != null) {
      try {
        in.close();
      } catch (IOException e) {
        // A failed close is only reported; it does not abort the load.
        System.err.println("Error closing the input stream");
      }
    }
  }

  final boolean settingsValid = TrainerFactory.isValid(loaded.getSettings());
  if (!settingsValid) {
    throw new TerminateToolException(
        1, "Training parameters file '" + paramFile + "' is invalid!");
  }
  return loaded;
}
/**
 * Reports whether the training algorithm requires sequence based feature generation.
 *
 * <p>The answer is taken from {@code TrainerFactory.isSupportSequence}.
 *
 * @param trainParams the training settings to inspect
 * @return the value returned by {@code TrainerFactory.isSupportSequence} for {@code trainParams}
 * @deprecated Use {@link TrainerFactory#isSequenceTraining(Map)} instead.
 */
public static boolean isSequenceTraining(Map<String, String> trainParams) {
  final boolean sequenceBased = TrainerFactory.isSupportSequence(trainParams);
  return sequenceBased;
}
/**
 * Checks the given training settings by delegating to {@code TrainerFactory.isValid}.
 *
 * @param trainParams the training settings to check
 * @return the value returned by {@code TrainerFactory.isValid} for {@code trainParams}
 * @deprecated Use {@link TrainerFactory#isValid(Map)} instead.
 */
public static boolean isValid(Map<String, String> trainParams) {
  final boolean valid = TrainerFactory.isValid(trainParams);
  return valid;
}
public void run(String format, String[] args) { super.run(format, args); mlParams = CmdLineUtil.loadTrainingParameters(params.getParams(), true); if (mlParams != null && !TrainerFactory.isValid(mlParams.getSettings())) { throw new TerminateToolException( 1, "Training parameters file '" + params.getParams() + "' is invalid!"); } if (mlParams == null) { mlParams = ModelUtil.createDefaultTrainingParameters(); mlParams.put(TrainingParameters.ALGORITHM_PARAM, getModelType(params.getType()).toString()); } File modelOutFile = params.getModel(); CmdLineUtil.checkOutputFile("pos tagger model", modelOutFile); Dictionary ngramDict = null; Integer ngramCutoff = params.getNgram(); if (ngramCutoff != null) { System.err.print("Building ngram dictionary ... "); try { ngramDict = POSTaggerME.buildNGramDictionary(sampleStream, ngramCutoff); sampleStream.reset(); } catch (IOException e) { throw new TerminateToolException( -1, "IO error while building NGram Dictionary: " + e.getMessage(), e); } System.err.println("done"); } POSTaggerFactory postaggerFactory = null; try { postaggerFactory = POSTaggerFactory.create(params.getFactory(), ngramDict, null); } catch (InvalidFormatException e) { throw new TerminateToolException(-1, e.getMessage(), e); } if (params.getDict() != null) { try { postaggerFactory.setTagDictionary(postaggerFactory.createTagDictionary(params.getDict())); } catch (IOException e) { throw new TerminateToolException( -1, "IO error while loading POS Dictionary: " + e.getMessage(), e); } } if (params.getTagDictCutoff() != null) { try { TagDictionary dict = postaggerFactory.getTagDictionary(); if (dict == null) { dict = postaggerFactory.createEmptyTagDictionary(); postaggerFactory.setTagDictionary(dict); } if (dict instanceof MutableTagDictionary) { POSTaggerME.populatePOSDictionary( sampleStream, (MutableTagDictionary) dict, params.getTagDictCutoff()); } else { throw new IllegalArgumentException( "Can't extend a POSDictionary that does not implement 
MutableTagDictionary."); } sampleStream.reset(); } catch (IOException e) { throw new TerminateToolException( -1, "IO error while creating/extending POS Dictionary: " + e.getMessage(), e); } } POSModel model; try { model = opennlp.tools.postag.POSTaggerME.train( params.getLang(), sampleStream, mlParams, postaggerFactory); } catch (IOException e) { throw new TerminateToolException( -1, "IO error while reading training data or indexing data: " + e.getMessage(), e); } finally { try { sampleStream.close(); } catch (IOException e) { // sorry that this can fail } } CmdLineUtil.writeModel("pos tagger", modelOutFile, model); }