/**
 * Computes the Q-value using the uncached transition dynamics produced by the Action object
 * methods. This computation *is* compatible with
 * {@link burlap.behavior.singleagent.options.Option} objects.
 *
 * @param sh the given state
 * @param ga the given action
 * @return the double value of a Q-value for the given state-action pair.
 */
protected double computeQ(StateHashTuple sh, GroundedAction ga) {

    double q = 0.;

    if (ga.action instanceof Option) {

        Option o = (Option) ga.action;
        double expectedR = o.getExpectedRewards(sh.s, ga.params);
        q += expectedR;

        List<TransitionProbability> tps = o.getTransitions(sh.s, ga.params);
        for (TransitionProbability tp : tps) {
            double vp = this.value(tp.s);

            // note that for options, tp.p will be the *discounted* probability of transition to s',
            // so there is no need for a discount factor to be included
            q += tp.p * vp;
        }

    } else {

        List<TransitionProbability> tps = ga.action.getTransitions(sh.s, ga.params);
        for (TransitionProbability tp : tps) {
            double vp = this.value(tp.s);
            double discount = this.gamma;
            double r = rf.reward(sh.s, ga, tp.s);
            q += tp.p * (r + (discount * vp));
        }
    }

    return q;
}
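/*
 * For reference, a compact sketch of the backup the method above computes (rf, gamma, and
 * value() are the fields and methods used above; the equations are illustrative notation, not
 * additional API):
 *
 *   primitive action a:  Q(s,a) = sum_{s'} T(s'|s,a) * ( R(s,a,s') + gamma * V(s') )
 *   option o:            Q(s,o) = E[cumulative discounted reward while o runs]
 *                               + sum_{s'} p_gamma(s'|s,o) * V(s')
 *
 * Here p_gamma(s'|s,o) already folds the discounting over the option's (possibly multi-step)
 * duration into the transition probability, which is why the option branch adds no explicit
 * gamma factor.
 */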
/**
 * Returns the Q-value for a given state and the possible transitions from it for a given
 * action. This computation *is* compatible with
 * {@link burlap.behavior.singleagent.options.Option} objects.
 *
 * @param s the given state
 * @param trans the given action transitions
 * @return the double value of a Q-value
 */
protected double computeQ(State s, ActionTransitions trans) {

    double q = 0.;

    if (trans.ga.action instanceof Option) {

        Option o = (Option) trans.ga.action;
        double expectedR = o.getExpectedRewards(s, trans.ga.params);
        q += expectedR;

        for (HashedTransitionProbability tp : trans.transitions) {
            double vp = this.value(tp.sh);

            // note that for options, tp.p will be the *discounted* probability of transition to s',
            // so there is no need for a discount factor to be included
            q += tp.p * vp;
        }

    } else {

        for (HashedTransitionProbability tp : trans.transitions) {
            double vp = this.value(tp.sh);
            double discount = this.gamma;
            double r = rf.reward(s, trans.ga, tp.sh.s);
            q += tp.p * (r + (discount * vp));
        }
    }

    return q;
}
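/*
 * A minimal usage sketch, assuming a planner that has cached an ActionTransitions list for a
 * state (the transitionsForState variable is hypothetical; computeQ is the method defined
 * above): the updated value estimate for the state under a Bellman backup is simply the
 * maximum Q-value over the cached transitions.
 *
 *   double maxQ = Double.NEGATIVE_INFINITY;
 *   for (ActionTransitions at : transitionsForState) {  // hypothetical per-state cache
 *       maxQ = Math.max(maxQ, this.computeQ(s, at));
 *   }
 *   // maxQ becomes the new V(s) in a value-iteration-style pass
 */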
/**
 * Options need to have their transition probabilities computed and to keep track of the
 * possible termination states using a hashed data structure. This method tells each option
 * which state hashing factory to use.
 */
protected void initializeOptionsForExpectationComputations() {
    for (Action a : this.actions) {
        if (a instanceof Option) {
            ((Option) a).setExpectationHashingFactory(hashingFactory);
        }
    }
}
/**
 * Sets whether options that are decomposed into primitives will also have the option that
 * produced them listed. The default value is true. If option decomposition is not enabled,
 * changing this value will do nothing. When decomposition is enabled and this is set to true,
 * primitive actions taken by an option in EpisodeAnalysis objects will be recorded with a
 * special action name that indicates which option was called to produce the primitive action
 * as well as which step of the option the primitive action is. When set to false, the recorded
 * names of primitives will only be the primitive action's name and it will be unclear which
 * option was taken to generate them.
 *
 * @param toggle whether to annotate the primitive actions of options with the calling option's
 *     name.
 */
public void toggleShouldAnnotateOptionDecomposition(boolean toggle) {
    shouldAnnotateOptions = toggle;
    for (Action a : actions) {
        if (a instanceof Option) {
            ((Option) a).toggleShouldAnnotateResults(toggle);
        }
    }
}
/**
 * Sets whether the primitive actions taken during an option's execution will be included as
 * steps in produced EpisodeAnalysis objects. The default value is true. If this is set to
 * false, then EpisodeAnalysis objects returned from a learning episode will record options as
 * a single "action" and the steps taken by the option will be hidden.
 *
 * @param toggle whether to decompose options into the primitive actions taken by them or not.
 */
public void toggleShouldDecomposeOption(boolean toggle) {
    this.shouldDecomposeOptions = toggle;
    for (Action a : actions) {
        if (a instanceof Option) {
            ((Option) a).toggleShouldRecordResults(toggle);
        }
    }
}
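/*
 * A minimal usage sketch of the two option-recording toggles defined above (the agent variable
 * is a placeholder for an instance of this class):
 *
 *   agent.toggleShouldDecomposeOption(true);              // record the primitive steps an option takes
 *   agent.toggleShouldAnnotateOptionDecomposition(true);  // tag each recorded primitive with its calling option
 *
 * Setting the first toggle to false instead records each option execution as a single step in
 * the resulting EpisodeAnalysis, in which case the annotation toggle has no effect.
 */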
@Override
public EpisodeAnalysis runLearningEpisodeFrom(State initialState) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);
    maxWeightChangeInLastEpisode = 0.;

    State curState = initialState;
    eStepCounter = 0;
    Map<Integer, EligibilityTraceVector> traces =
            new HashMap<Integer, GradientDescentSarsaLam.EligibilityTraceVector>();

    GroundedAction action = this.learningPolicy.getAction(curState);
    List<ActionApproximationResult> allCurApproxResults = this.getAllActionApproximations(curState);
    ActionApproximationResult curApprox =
            ActionApproximationResult.extractApproximationForAction(allCurApproxResults, action);

    while (!tf.isTerminal(curState) && eStepCounter < maxEpisodeSize) {

        WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult);

        State nextState = action.executeIn(curState);
        GroundedAction nextAction = this.learningPolicy.getAction(nextState);
        List<ActionApproximationResult> allNextApproxResults = this.getAllActionApproximations(nextState);
        ActionApproximationResult nextApprox =
                ActionApproximationResult.extractApproximationForAction(allNextApproxResults, nextAction);
        double nextQV = nextApprox.approximationResult.predictedValue;
        if (tf.isTerminal(nextState)) {
            nextQV = 0.;
        }

        // manage option specifics
        double r = 0.;
        double discount = this.gamma;
        if (action.action.isPrimitive()) {
            r = rf.reward(curState, action, nextState);
            eStepCounter++;
            ea.recordTransitionTo(nextState, action, r);
        } else {
            Option o = (Option) action.action;
            r = o.getLastCumulativeReward();
            int n = o.getLastNumSteps();
            discount = Math.pow(this.gamma, n);
            eStepCounter += n;
            if (this.shouldDecomposeOptions) {
                ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
            } else {
                ea.recordTransitionTo(nextState, action, r);
            }
        }

        // delta
        double delta = r + (discount * nextQV) - curApprox.approximationResult.predictedValue;

        if (useReplacingTraces) {
            // then first clear traces of unselected actions and reset the trace for the selected one
            for (ActionApproximationResult aar : allCurApproxResults) {
                if (!aar.ga.equals(action)) {
                    // clear unselected action trace
                    for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                        traces.remove(fw.weightId());
                    }
                } else {
                    // reset trace of selected action
                    for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                        EligibilityTraceVector storedTrace = traces.get(fw.weightId());
                        if (storedTrace != null) {
                            storedTrace.eligibilityValue = 0.;
                        }
                    }
                }
            }
        }

        // update all traces
        Set<Integer> deletedSet = new HashSet<Integer>();
        for (EligibilityTraceVector et : traces.values()) {
            int weightId = et.weight.weightId();

            et.eligibilityValue += gradient.getPartialDerivative(weightId);
            double newWeight = et.weight.weightValue() + this.learningRate * delta * et.eligibilityValue;
            et.weight.setWeight(newWeight);

            double deltaW = Math.abs(et.initialWeightValue - newWeight);
            if (deltaW > maxWeightChangeInLastEpisode) {
                maxWeightChangeInLastEpisode = deltaW;
            }

            et.eligibilityValue *= this.lambda * discount;
            if (et.eligibilityValue < this.minEligibityForUpdate) {
                deletedSet.add(weightId);
            }
        }

        // add new traces if need be
        for (FunctionWeight fw : curApprox.approximationResult.functionWeights) {

            int weightId = fw.weightId();
            // the trace map is keyed by weight id, so test for the id rather than the weight object
            if (!traces.containsKey(weightId)) {
                // then it's new and we need to add it
                EligibilityTraceVector et =
                        new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId));
                double newWeight = fw.weightValue() + this.learningRate * delta * et.eligibilityValue;
                fw.setWeight(newWeight);

                double deltaW = Math.abs(et.initialWeightValue - newWeight);
                if (deltaW > maxWeightChangeInLastEpisode) {
                    maxWeightChangeInLastEpisode = deltaW;
                }

                et.eligibilityValue *= this.lambda * discount;
                if (et.eligibilityValue >= this.minEligibityForUpdate) {
                    traces.put(weightId, et);
                }
            }
        }

        // delete any traces
        for (Integer t : deletedSet) {
            traces.remove(t);
        }

        // move on
        curState = nextState;
        action = nextAction;
        curApprox = nextApprox;
        allCurApproxResults = allNextApproxResults;
    }

    if (episodeHistory.size() >= numEpisodesToStore) {
        episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
}
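/*
 * For reference, the per-weight update performed in the loops above is gradient-descent
 * SARSA(lambda). With alpha = this.learningRate, e_i the eligibility trace of weight w_i, and
 * k the number of steps the chosen action or option lasted (k = 1 for primitives), each
 * iteration applies:
 *
 *   delta = r + gamma^k * Q(s', a') - Q(s, a)
 *   e_i   = e_i + dQ(s,a)/dw_i          (gradient accumulated into the trace)
 *   w_i   = w_i + alpha * delta * e_i   (weight update)
 *   e_i   = lambda * gamma^k * e_i      (trace decay for the next step)
 *
 * Traces whose magnitude falls below minEligibityForUpdate are dropped to keep the map small.
 */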
@Override
public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps) {

    this.toggleShouldAnnotateOptionDecomposition(shouldAnnotateOptions);

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);

    StateHashTuple curState = this.stateHash(initialState);
    eStepCounter = 0;

    maxQChangeInLastEpisode = 0.;

    while (!tf.isTerminal(curState.s) && eStepCounter < maxSteps) {

        GroundedAction action = (GroundedAction) learningPolicy.getAction(curState.s);
        QValue curQ = this.getQ(curState, action);

        StateHashTuple nextState = this.stateHash(action.executeIn(curState.s));
        double maxQ = 0.;

        if (!tf.isTerminal(nextState.s)) {
            maxQ = this.getMaxQ(nextState);
        }

        // manage option specifics
        double r = 0.;
        double discount = this.gamma;
        if (action.action.isPrimitive()) {
            r = rf.reward(curState.s, action, nextState.s);
            eStepCounter++;
            ea.recordTransitionTo(nextState.s, action, r);
        } else {
            Option o = (Option) action.action;
            r = o.getLastCumulativeReward();
            int n = o.getLastNumSteps();
            discount = Math.pow(this.gamma, n);
            eStepCounter += n;
            if (this.shouldDecomposeOptions) {
                ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
            } else {
                ea.recordTransitionTo(nextState.s, action, r);
            }
        }

        double oldQ = curQ.q;

        // update Q-value
        curQ.q = curQ.q
                + this.learningRate.pollLearningRate(curState.s, action)
                        * (r + (discount * maxQ) - curQ.q);

        double deltaQ = Math.abs(oldQ - curQ.q);
        if (deltaQ > maxQChangeInLastEpisode) {
            maxQChangeInLastEpisode = deltaQ;
        }

        // move on
        curState = nextState;
    }

    if (episodeHistory.size() >= numEpisodesToStore) {
        episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
}
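/*
 * For reference, the update above is the standard Q-learning rule generalized to options; with
 * alpha the state-action learning rate polled from this.learningRate and k the number of steps
 * the chosen action or option lasted (k = 1 for primitives), each iteration applies:
 *
 *   Q(s,a) <- Q(s,a) + alpha * ( r + gamma^k * max_{a'} Q(s',a') - Q(s,a) )
 *
 * where r is the one-step reward for a primitive action and the option's cumulative discounted
 * reward otherwise.
 */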