/**
 * Decides whether the iteration should stop after the current superstep.
 *
 * <p>Two conditions end the iteration:
 * <ul>
 *   <li>{@code currentIteration} equals {@code maxNumberOfIterations}.</li>
 *   <li>A convergence aggregator name is configured and {@code convergenceCriterion} reports
 *       convergence for that aggregator's current aggregate.</li>
 * </ul>
 * An info message is logged in either case.
 *
 * @return {@code true} if the iteration should terminate, {@code false} otherwise
 * @throws RuntimeException if a convergence aggregator name is configured but no aggregator is
 *     registered under that name
 */
private boolean checkForConvergence() {
    if (currentIteration == maxNumberOfIterations) {
        if (log.isInfoEnabled()) {
            log.info(formatLogString(
                "maximum number of iterations [" + currentIteration + "] reached, terminating..."));
        }
        return true;
    }

    // Without a configured convergence aggregator only the iteration limit can stop the loop.
    if (convergenceAggregatorName == null) {
        return false;
    }

    @SuppressWarnings("unchecked")
    final Aggregator<Value> convergenceAggregator =
        (Aggregator<Value>) aggregators.get(convergenceAggregatorName);
    if (convergenceAggregator == null) {
        throw new RuntimeException("Error: Aggregator for convergence criterion was null.");
    }

    final Value currentAggregate = convergenceAggregator.getAggregate();
    final boolean converged = convergenceCriterion.isConverged(currentIteration, currentAggregate);
    if (converged && log.isInfoEnabled()) {
        log.info(formatLogString(
            "convergence reached after [" + currentIteration + "] iterations, terminating..."));
    }
    return converged;
}
private void onWorkerDoneEvent(WorkerDoneEvent workerDoneEvent) { if (this.endOfSuperstep) { throw new RuntimeException( "Encountered WorderDoneEvent when still in End-of-Superstep status."); } workerDoneEventCounter++; // if (log.isInfoEnabled()) { // log.info("Sync event handler received WorkerDoneEvent event (" + workerDoneEventCounter + // ")"); // } String[] aggNames = workerDoneEvent.getAggregatorNames(); Value[] aggregates = workerDoneEvent.getAggregates(userCodeClassLoader); if (aggNames.length != aggregates.length) { throw new RuntimeException("Inconsistent WorkerDoneEvent received!"); } for (int i = 0; i < aggNames.length; i++) { @SuppressWarnings("unchecked") Aggregator<Value> aggregator = (Aggregator<Value>) this.aggregators.get(aggNames[i]); aggregator.aggregate(aggregates[i]); } if (workerDoneEventCounter % numberOfEventsUntilEndOfSuperstep == 0) { endOfSuperstep = true; Thread.currentThread().interrupt(); } }
@Override public void invoke() throws Exception { userCodeClassLoader = LibraryCacheManager.getClassLoader(getEnvironment().getJobID()); TaskConfig taskConfig = new TaskConfig(getTaskConfiguration()); // store all aggregators this.aggregators = new HashMap<String, Aggregator<?>>(); for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators()) { aggregators.put(aggWithName.getName(), aggWithName.getAggregator()); } // store the aggregator convergence criterion if (taskConfig.usesConvergenceCriterion()) { convergenceCriterion = taskConfig.getConvergenceCriterion(); convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName(); Preconditions.checkNotNull(convergenceAggregatorName); } maxNumberOfIterations = taskConfig.getNumberOfIterations(); // set up the event handler int numEventsTillEndOfSuperstep = taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0); eventHandler = new SyncEventHandler(numEventsTillEndOfSuperstep, aggregators, userCodeClassLoader); headEventReader.subscribeToEvent(eventHandler, WorkerDoneEvent.class); IntegerRecord dummy = new IntegerRecord(); while (!terminationRequested()) { // notifyMonitor(IterationMonitoring.Event.SYNC_STARTING, currentIteration); if (log.isInfoEnabled()) { log.info(formatLogString("starting iteration [" + currentIteration + "]")); } // this call listens for events until the end-of-superstep is reached readHeadEventChannel(dummy); if (log.isInfoEnabled()) { log.info(formatLogString("finishing iteration [" + currentIteration + "]")); } if (checkForConvergence()) { if (log.isInfoEnabled()) { log.info( formatLogString( "signaling that all workers are to terminate in iteration [" + currentIteration + "]")); } requestTermination(); sendToAllWorkers(new TerminationEvent()); // notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration); } else { if (log.isInfoEnabled()) { log.info( formatLogString( "signaling that all workers are done in iteration [" + currentIteration + 
"]")); } AllWorkersDoneEvent allWorkersDoneEvent = new AllWorkersDoneEvent(aggregators); sendToAllWorkers(allWorkersDoneEvent); // reset all aggregators for (Aggregator<?> agg : aggregators.values()) { agg.reset(); } // notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration); currentIteration++; } } }