/** * TimeAveraging as part of Elastic Averaging Algorithm Cf. equation 6 of arXiv:1412.6651v5 * * @param nodeAverageModel current average of per-node models * @return Time-average of node-averages (consensus model, "the" model) */ public static DeepLearningModelInfo timeAverage(DeepLearningModelInfo nodeAverageModel) { float pa = (float) nodeAverageModel.get_params()._elastic_averaging_moving_rate; assert (pa > 0 && pa <= 1); DeepLearningModelInfo elasticAverage = DKV.getGet(nodeAverageModel.elasticAverageModelInfoKey()); // get latest version from DKV if (elasticAverage == null || pa == 1) { elasticAverage = nodeAverageModel.deep_clone(); } else { nodeAverageModel.mult(pa); elasticAverage.mult(1 - pa); elasticAverage.add(nodeAverageModel); // ignore processed local value set here elasticAverage.set_processed_global(nodeAverageModel.get_processed_global()); } elasticAverage.set_processed_local(0); DKV.put(elasticAverage.elasticAverageModelInfoKey(), elasticAverage); // nodeAverageModel.computeStats(); // elasticAverage.computeStats(); // Log.info("Local Model :\n" + nodeAverageModel.toString()); // Log.info("Elastic Average:\n" + elasticAverage.toString()); return elasticAverage; }