@Override protected ComputeResult doCompute(List<Data> dataset, Parameter parameter) { AnnModel annModel = (AnnModel) ((Object[]) parameter.getSerializable())[0]; int accurate = 0; double totalLoss = 0; double activationFunctionCenterValue = 0.5; if (annModel.getConfiguration().activationFunctionOfLayer(annModel.getLayerCount() - 1) == ActivationFunction.TANH) { activationFunctionCenterValue = 0; } AnnTrainer trainer = new AnnTrainer(); double[][] sum = new double[annModel.getLayerCount()][]; for (Data data : dataset) { double loss = 0; boolean error = false; AnnData annData = (AnnData) data.getSerializable(); double[][] output = trainer.run(annModel, annData.getInput(), sum); for (int i = 0; i < output[output.length - 1].length; i++) { double expected = annData.getOutput()[i]; double actual = output[output.length - 1][i]; loss += Math.abs(expected - actual); if ((expected - activationFunctionCenterValue) * (actual - activationFunctionCenterValue) < 0) { error = true; break; } } if (!error) { accurate++; } totalLoss += loss / output[output.length - 1].length; } LogUtility.logAnnModel(getLogger(), annModel); System.out.println("Loss = " + (totalLoss / dataset.size())); String message = "Classification precision: " + 100.0 * accurate / dataset.size() + "%"; getLogger().info(message); System.out.println(message); // LogUtility.logAnnModel(getLogger(), annModel); // return computation result ComputeResult result = new ComputeResult(); result.setRepeat(true); result.setRepeatDelayInSeconds(EVALUATION_INTERVAL_IN_SECONDS); result.setGradient(null); result.setAudit(false); return result; }
@Override public List<Data> produceData(int size) throws DataProducerException { // add negative instances List<Data> dataset = Lists.newArrayList(); while (dataset.size() < size) { for (Data data : super.produceData(size / 2)) { WordEmbeddingTrainingInstance positiveInstance = (WordEmbeddingTrainingInstance) data.getSerializable(); // validate positive instance boolean invalidWindow = false; for (int index : positiveInstance.getInput()) { invalidWindow = isWordInvalid(index); if (invalidWindow) { break; } } if (invalidWindow) { continue; } positiveInstance.setOutput(1); dataset.add(data); WordEmbeddingTrainingInstance negativeInstance = new WordEmbeddingTrainingInstance(); negativeInstance.setInput(Lists.newArrayList(positiveInstance.getInput())); negativeInstance.setOutput(-1); // generate negative word int negativeWord = this.dictionary.sampleWordUniformlyAboveFrequenceRank(this.frequencyRankBound); while (negativeWord == negativeInstance.getInput().get(negativeInstance.getInput().size() / 2)) { negativeWord = this.dictionary.sampleWordUniformlyAboveFrequenceRank(this.frequencyRankBound); } // set negative word negativeInstance.getInput().set(negativeInstance.getInput().size() / 2, negativeWord); dataset.add(new Data(negativeInstance)); } } Collections.shuffle(dataset); return dataset; }