/**
 * Creates an eligibility trace for the given VFA weight, initialized to the
 * supplied eligibility value. The weight's value at construction time is
 * recorded so the total change applied to it can be measured later.
 *
 * @param weight the VFA function weight being traced
 * @param eligibilityValue the initial eligibility to assign to the weight
 */
public EligibilityTraceVector(FunctionWeight weight, double eligibilityValue) {
	this.eligibilityValue = eligibilityValue;
	this.weight = weight;
	this.initialWeightValue = weight.weightValue();
}
@Override public EpisodeAnalysis runLearningEpisodeFrom(State initialState) { EpisodeAnalysis ea = new EpisodeAnalysis(initialState); maxWeightChangeInLastEpisode = 0.; State curState = initialState; eStepCounter = 0; Map<Integer, EligibilityTraceVector> traces = new HashMap<Integer, GradientDescentSarsaLam.EligibilityTraceVector>(); GroundedAction action = this.learningPolicy.getAction(curState); List<ActionApproximationResult> allCurApproxResults = this.getAllActionApproximations(curState); ActionApproximationResult curApprox = ActionApproximationResult.extractApproximationForAction(allCurApproxResults, action); while (!tf.isTerminal(curState) && eStepCounter < maxEpisodeSize) { WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult); State nextState = action.executeIn(curState); GroundedAction nextAction = this.learningPolicy.getAction(nextState); List<ActionApproximationResult> allNextApproxResults = this.getAllActionApproximations(nextState); ActionApproximationResult nextApprox = ActionApproximationResult.extractApproximationForAction(allNextApproxResults, nextAction); double nextQV = nextApprox.approximationResult.predictedValue; if (tf.isTerminal(nextState)) { nextQV = 0.; } // manage option specifics double r = 0.; double discount = this.gamma; if (action.action.isPrimitive()) { r = rf.reward(curState, action, nextState); eStepCounter++; ea.recordTransitionTo(nextState, action, r); } else { Option o = (Option) action.action; r = o.getLastCumulativeReward(); int n = o.getLastNumSteps(); discount = Math.pow(this.gamma, n); eStepCounter += n; if (this.shouldDecomposeOptions) { ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults()); } else { ea.recordTransitionTo(nextState, action, r); } } // delta double delta = r + (discount * nextQV) - curApprox.approximationResult.predictedValue; if (useReplacingTraces) { // then first clear traces of unselected action and reset the trace for the selected one for 
(ActionApproximationResult aar : allCurApproxResults) { if (!aar.ga.equals(action)) { // clear unselected action trace for (FunctionWeight fw : aar.approximationResult.functionWeights) { traces.remove(fw.weightId()); } } else { // reset trace of selected action for (FunctionWeight fw : aar.approximationResult.functionWeights) { EligibilityTraceVector storedTrace = traces.get(fw.weightId()); if (storedTrace != null) { storedTrace.eligibilityValue = 0.; } } } } } // update all traces Set<Integer> deletedSet = new HashSet<Integer>(); for (EligibilityTraceVector et : traces.values()) { int weightId = et.weight.weightId(); et.eligibilityValue += gradient.getPartialDerivative(weightId); double newWeight = et.weight.weightValue() + this.learningRate * delta * et.eligibilityValue; et.weight.setWeight(newWeight); double deltaW = Math.abs(et.initialWeightValue - newWeight); if (deltaW > maxWeightChangeInLastEpisode) { maxWeightChangeInLastEpisode = deltaW; } et.eligibilityValue *= this.lambda * discount; if (et.eligibilityValue < this.minEligibityForUpdate) { deletedSet.add(weightId); } } // add new traces if need be for (FunctionWeight fw : curApprox.approximationResult.functionWeights) { int weightId = fw.weightId(); if (!traces.containsKey(fw)) { // then it's new and we need to add it EligibilityTraceVector et = new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId)); double newWeight = fw.weightValue() + this.learningRate * delta * et.eligibilityValue; fw.setWeight(newWeight); double deltaW = Math.abs(et.initialWeightValue - newWeight); if (deltaW > maxWeightChangeInLastEpisode) { maxWeightChangeInLastEpisode = deltaW; } et.eligibilityValue *= this.lambda * discount; if (et.eligibilityValue >= this.minEligibityForUpdate) { traces.put(weightId, et); } } } // delete any traces for (Integer t : deletedSet) { traces.remove(t); } // move on curState = nextState; action = nextAction; curApprox = nextApprox; allCurApproxResults = allNextApproxResults; } if 
(episodeHistory.size() >= numEpisodesToStore) { episodeHistory.poll(); episodeHistory.offer(ea); } return ea; }