@Override @SuppressWarnings("unchecked") public void onReceive(final Object message) throws Exception { if (message instanceof MoreWorkMessage) { if (stateTracker.getCurrent() != null && stateTracker.getCurrent().getClass().isAssignableFrom(UpdateableImpl.class)) { BaseMultiLayerNetwork current = (BaseMultiLayerNetwork) stateTracker.getCurrent().get(); if (current.getLayers() == null || current.getSigmoidLayers() == null) throw new IllegalStateException("Invalid model found when prompted to save.."); current.clearInput(); stateTracker.setCurrent(new UpdateableImpl(current)); if (stateTracker.hasBegun()) modelSaver.save(current); } else if (stateTracker .getCurrent() .get() .getClass() .isAssignableFrom(DeepAutoEncoder.class)) { DeepAutoEncoder current = (DeepAutoEncoder) stateTracker.getCurrent().get(); stateTracker.setCurrent(new UpdateableEncoderImpl(current)); if (stateTracker.hasBegun()) modelSaver.save(current); } } else if (message instanceof DistributedPubSubMediator.UnsubscribeAck || message instanceof DistributedPubSubMediator.SubscribeAck) { // reply mediator.tell( new DistributedPubSubMediator.Publish(ClusterListener.TOPICS, message), getSelf()); log.info("Sending sub/unsub over"); } else unhandled(message); }
/**
 * Decides whether training should terminate early at the given epoch.
 * <p>
 * Only epochs that are multiples of {@code validationEpochs}, and never an epoch below 2, are
 * evaluated; every other epoch returns {@code false}. On an evaluated epoch the network's
 * current score is compared with {@code bestLoss}. A score that is lower than both
 * {@code bestLoss} and {@code bestLoss * improvementThreshold} becomes the new
 * {@code bestLoss} and raises {@code patience} to at least {@code epoch * patienceIncrease}.
 * Training stops once {@code epoch} exceeds {@code patience}.
 *
 * @param epoch the current epoch
 * @return {@code true} if training should stop at this epoch, {@code false} otherwise
 */
@Override
public boolean shouldStop(int epoch) {
    // Evaluated first (before the epoch < 2 check) so the modulo always runs, as before.
    final boolean onValidationEpoch = epoch % validationEpochs == 0;
    if (epoch < 2 || !onValidationEpoch) {
        return false;
    }

    final double latestScore = network.score();
    final boolean improvedEnough =
            latestScore < bestLoss && latestScore < bestLoss * improvementThreshold;
    if (improvedEnough) {
        bestLoss = latestScore;
        patience = Math.max(patience, epoch * patienceIncrease);
    }

    final boolean stopNow = epoch > patience;
    if (stopNow) {
        log.info("Returning early on finetune");
    }
    return stopNow;
}