/**
 * Evaluates the ANN model carried by {@code parameter} on {@code dataset} and reports the mean
 * absolute loss and the classification precision.
 *
 * <p>For each sample the model is run through an {@link AnnTrainer}; the last row of the
 * returned array is compared element by element with the expected output. A sample counts as
 * misclassified when, for at least one output element, the expected and the actual value lie on
 * opposite sides of the activation function's center value. The per-sample loss is the mean
 * absolute difference over all output elements.
 *
 * <p>An empty {@code dataset} makes both reported figures {@code NaN} (floating-point division
 * by zero); no exception is thrown for that case.
 *
 * @param dataset samples to evaluate; each one is expected to wrap an {@code AnnData}
 * @param parameter parameter whose serializable payload is an {@code Object[]} with the
 *     {@code AnnModel} at index 0
 * @return a result asking for the evaluation to be repeated after
 *     {@code EVALUATION_INTERVAL_IN_SECONDS}, with no gradient and auditing disabled
 */
@Override
  protected ComputeResult doCompute(List<Data> dataset, Parameter parameter) {
    AnnModel annModel = (AnnModel) ((Object[]) parameter.getSerializable())[0];

    int accurate = 0;
    double totalLoss = 0;

    // Threshold separating the two classes: 0 when the last layer uses TANH, 0.5 otherwise
    // (presumably because the other activation functions are centered on 0.5 - confirm).
    double activationFunctionCenterValue = 0.5;
    if (annModel.getConfiguration().activationFunctionOfLayer(annModel.getLayerCount() - 1)
        == ActivationFunction.TANH) {
      activationFunctionCenterValue = 0;
    }

    AnnTrainer trainer = new AnnTrainer();
    // One slot per layer, handed to every trainer.run call; how it is filled is up to
    // AnnTrainer (presumably per-layer weighted sums - confirm).
    double[][] sum = new double[annModel.getLayerCount()][];
    for (Data data : dataset) {
      AnnData annData = (AnnData) data.getSerializable();
      double[][] output = trainer.run(annModel, annData.getInput(), sum);
      double[] lastLayerOutput = output[output.length - 1];

      double loss = 0;
      boolean error = false;
      for (int i = 0; i < lastLayerOutput.length; i++) {
        double expected = annData.getOutput()[i];
        double actual = lastLayerOutput[i];
        loss += Math.abs(expected - actual);
        // Do not stop at the first misclassified element: the loss below is averaged over the
        // whole output layer, so every element has to contribute to the sum.
        if ((expected - activationFunctionCenterValue) * (actual - activationFunctionCenterValue)
            < 0) {
          error = true;
        }
      }

      if (!error) {
        accurate++;
      }

      totalLoss += loss / lastLayerOutput.length;
    }

    LogUtility.logAnnModel(getLogger(), annModel);
    System.out.println("Loss = " + (totalLoss / dataset.size()));

    String message = "Classification precision: " + 100.0 * accurate / dataset.size() + "%";
    getLogger().info(message);
    System.out.println(message);

    // This computation only evaluates: it produces no gradient and asks to be run again later.
    ComputeResult result = new ComputeResult();
    result.setRepeat(true);
    result.setRepeatDelayInSeconds(EVALUATION_INTERVAL_IN_SECONDS);
    result.setGradient(null);
    result.setAudit(false);

    return result;
  }
  // Example no. 2
  // 0
  /**
   * Produces a shuffled dataset in which every valid positive instance obtained from the parent
   * producer is paired with a generated negative instance.
   *
   * <p>Positive instances are fetched from {@code super.produceData} in batches. An instance is
   * discarded when any word index of its input is rejected by {@code isWordInvalid}. Each kept
   * instance gets output {@code 1}; its negative counterpart is a copy of the input whose middle
   * element is replaced by a word sampled from the dictionary that differs from the original
   * middle word, with output {@code -1}.
   *
   * <p>Batches are processed as a whole, so the returned list holds at least {@code size}
   * elements and may hold more. For {@code size <= 0} an empty list is returned.
   *
   * @param size minimum number of instances to produce
   * @return the shuffled list of positive and negative instances
   * @throws DataProducerException if the parent producer fails to deliver positive instances
   */
  @Override
  public List<Data> produceData(int size) throws DataProducerException {
    List<Data> dataset = Lists.newArrayList();
    // Half of the requested instances are positives, but always ask for at least one: with
    // size == 1 the integer division yields 0, and requesting zero instances on every pass
    // could keep this loop from ever making progress.
    int positiveBatchSize = Math.max(1, size / 2);
    while (dataset.size() < size) {
      for (Data data : super.produceData(positiveBatchSize)) {
        WordEmbeddingTrainingInstance positiveInstance =
            (WordEmbeddingTrainingInstance) data.getSerializable();
        // validate positive instance: a single invalid word disqualifies the whole window
        boolean invalidWindow = false;
        for (int index : positiveInstance.getInput()) {
          if (isWordInvalid(index)) {
            invalidWindow = true;
            break;
          }
        }

        if (invalidWindow) {
          continue;
        }

        positiveInstance.setOutput(1);
        dataset.add(data);

        // add the matching negative instance, built from a copy of the positive input
        WordEmbeddingTrainingInstance negativeInstance = new WordEmbeddingTrainingInstance();
        negativeInstance.setInput(Lists.newArrayList(positiveInstance.getInput()));
        negativeInstance.setOutput(-1);
        int centerIndex = negativeInstance.getInput().size() / 2;
        // generate negative word: resample until it differs from the original middle word
        int negativeWord;
        do {
          negativeWord =
              this.dictionary.sampleWordUniformlyAboveFrequenceRank(this.frequencyRankBound);
        } while (negativeWord == negativeInstance.getInput().get(centerIndex));
        // set negative word
        negativeInstance.getInput().set(centerIndex, negativeWord);
        dataset.add(new Data(negativeInstance));
      }
    }

    Collections.shuffle(dataset);

    return dataset;
  }