/**
 * Scores one partition of {@link MultiDataSet}s with a network rebuilt from the stored JSON
 * configuration and the shared parameter vector.
 *
 * <p>All examples in the partition are merged into a single {@code MultiDataSet} and scored in one
 * call. When the configuration reports {@code isMiniBatch()}, the score is multiplied by the number
 * of rows in the first features array. Presumably this turns a per-example average into a sum, so
 * the caller can combine partitions; confirm against the caller.
 *
 * @param dataSetIterator the examples in this partition; may be empty
 * @return a single-element list holding this partition's score, or {@code 0.0} if the partition
 *     is empty
 * @throws IllegalStateException if the shared parameter vector's length does not match the
 *     network's parameter count
 */
@Override
  public Iterable<Double> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    // Guard: an empty partition contributes a zero score and needs no network.
    if (!dataSetIterator.hasNext()) {
      return Collections.singletonList(0.0);
    }

    // Drain the whole partition so it can be merged and scored in a single pass.
    List<MultiDataSet> batches = new ArrayList<>();
    while (dataSetIterator.hasNext()) {
      MultiDataSet batch = dataSetIterator.next();
      batches.add(batch);
    }
    MultiDataSet merged = org.nd4j.linalg.dataset.MultiDataSet.merge(batches);

    ComputationGraphConfiguration graphConf = ComputationGraphConfiguration.fromJson(json);
    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();

    // params.value() is shared by every executor on the same machine. That is safe here only
    // because scoring does not modify the parameter array.
    INDArray sharedParams = params.value();
    long expectedCount = graph.numParams(false);
    if (sharedParams.length() != expectedCount) {
      throw new IllegalStateException(
          "Network did not have same number of parameters as the broadcast set parameters");
    }
    graph.setParams(sharedParams);

    // Second argument is false: presumably training=false (inference-mode scoring); confirm
    // against ComputationGraph#score.
    double partitionScore = graph.score(merged, false);
    boolean scaleByExampleCount = graph.conf().isMiniBatch();
    if (scaleByExampleCount) {
      partitionScore *= merged.getFeatures(0).size(0);
    }
    return Collections.singletonList(partitionScore);
  }
 /**
  * Rebuilds a {@link ComputationGraph} from three files: a JSON configuration, a serialized
  * parameter array, and a Java-serialized updater.
  *
  * @param confOut path to the JSON configuration file, decoded using this object's
  *     {@code encoding}
  * @param paramOut path to the parameter array, read with {@code Nd4j.read}
  * @param updaterOut path to a Java-serialized {@link ComputationGraphUpdater}
  * @return an initialized network with the loaded parameters and updater applied
  * @throws IOException if any of the files cannot be opened or read
  */
 private ComputationGraph load(String confOut, String paramOut, String updaterOut)
     throws IOException {
   String confJSON = FileUtils.readFileToString(new File(confOut), encoding);

   // Each underlying stream is declared as its own resource, so it is closed even if a wrapping
   // constructor throws. The streams are buffered because DataInputStream/ObjectInputStream may
   // otherwise issue many tiny reads directly against the file.
   INDArray params;
   try (java.io.InputStream paramIn = Files.newInputStream(Paths.get(paramOut));
       DataInputStream dis = new DataInputStream(new java.io.BufferedInputStream(paramIn))) {
     params = Nd4j.read(dis);
   }

   // NOTE(review): Java native deserialization. This is safe only if updaterOut was written by a
   // trusted source (presumably this same job); never point it at untrusted input.
   ComputationGraphUpdater updater;
   try (java.io.InputStream updaterIn = Files.newInputStream(Paths.get(updaterOut));
       // The ObjectInputStream constructor reads the stream header and can throw. updaterIn is
       // still closed in that case because it is a separate resource.
       ObjectInputStream ois = new ObjectInputStream(new java.io.BufferedInputStream(updaterIn))) {
     updater = (ComputationGraphUpdater) ois.readObject();
   } catch (ClassNotFoundException e) {
     // Not expected: ComputationGraphUpdater is referenced directly by this class.
     throw new RuntimeException(e);
   }

   ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(confJSON);
   ComputationGraph net = new ComputationGraph(conf);
   net.init();
   net.setParams(params);
   net.setUpdater(updater);
   return net;
 }