/**
 * Selects a next state for expansion when action a is applied in state s by randomly sampling
 * from the transition dynamics weighted by the margin of the lower and upper bound value
 * functions.
 *
 * @param s the source state of the transition
 * @param a the action applied in the source state
 * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
 *     the expected margin size of this transition.
 */
protected StateSelectionAndExpectedGap getNextStateBySampling(State s, GroundedAction a) {

    List<TransitionProbability> tps = a.getTransitions(s);
    double sum = 0.;
    double[] weightedGap = new double[tps.size()];
    HashableState[] hashedStates = new HashableState[tps.size()];
    for (int i = 0; i < tps.size(); i++) {
        TransitionProbability tp = tps.get(i);
        HashableState nsh = this.hashingFactory.hashState(tp.s);
        hashedStates[i] = nsh;
        double gap = this.getGap(nsh);
        weightedGap[i] = tp.p * gap;
        sum += weightedGap[i];
    }

    double roll = RandomFactory.getMapped(0).nextDouble();
    double cumSum = 0.;
    for (int i = 0; i < weightedGap.length; i++) {
        cumSum += weightedGap[i] / sum;
        if (roll < cumSum) {
            return new StateSelectionAndExpectedGap(hashedStates[i], sum);
        }
    }

    throw new RuntimeException("Error: probabilities in state selection did not sum to 1.");
}
/**
 * Selects a next state for expansion when action a is applied in state s according to the next
 * possible state that has the largest lower and upper bound margin. Ties are broken randomly.
 *
 * @param s the source state of the transition
 * @param a the action applied in the source state
 * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
 *     the expected margin size of this transition.
 */
protected StateSelectionAndExpectedGap getNextStateByMaxMargin(State s, GroundedAction a) {

    List<TransitionProbability> tps = a.getTransitions(s);
    double sum = 0.;
    double maxGap = Double.NEGATIVE_INFINITY;
    List<HashableState> maxStates = new ArrayList<HashableState>(tps.size());
    for (TransitionProbability tp : tps) {
        HashableState nsh = this.hashingFactory.hashState(tp.s);
        double gap = this.getGap(nsh);
        sum += tp.p * gap;
        if (gap == maxGap) {
            maxStates.add(nsh);
        } else if (gap > maxGap) {
            maxStates.clear();
            maxStates.add(nsh);
            maxGap = gap;
        }
    }

    int rint = RandomFactory.getMapped(0).nextInt(maxStates.size());
    return new StateSelectionAndExpectedGap(maxStates.get(rint), sum);
}
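/*
 * Stand-alone illustration (not part of the planner above) of the two selection rules just
 * shown, using plain arrays in place of TransitionProbability/HashableState objects. Both
 * rules report the same expected gap, sum_i p_i * gap_i; they differ only in which successor
 * index they return. All names and numbers below are illustrative only.
 */
import java.util.Random;

public class MarginSelectionSketch {

    /** Roulette-wheel selection: successor i is chosen with probability p_i * gap_i / sum. */
    static int sampleByWeightedMargin(double[] probs, double[] gaps, Random rng) {
        double sum = 0.;
        double[] weighted = new double[probs.length];
        for (int i = 0; i < probs.length; i++) {
            weighted[i] = probs[i] * gaps[i];
            sum += weighted[i];
        }
        double roll = rng.nextDouble();
        double cum = 0.;
        for (int i = 0; i < weighted.length; i++) {
            cum += weighted[i] / sum;
            if (roll < cum) {
                return i;
            }
        }
        return weighted.length - 1; // guard against floating-point round-off
    }

    /** Max-margin selection: return the index of the successor with the largest gap. */
    static int selectByMaxMargin(double[] gaps) {
        int best = 0;
        for (int i = 1; i < gaps.length; i++) {
            if (gaps[i] > gaps[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.7, 0.2, 0.1};
        double[] gaps = {0.05, 0.4, 0.9};
        Random rng = new Random(0);
        System.out.println("sampled index: " + sampleByWeightedMargin(probs, gaps, rng));
        System.out.println("max-margin index: " + selectByMaxMargin(gaps));
    }
}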
/**
 * Returns the action suggested by the planner for the given state. If a plan including this state
 * has not already been computed, the planner will be called from this state to find one.
 *
 * @param s the state for which the suggested action is to be returned.
 * @return the suggested action for the given state.
 */
public GroundedAction querySelectedActionForState(State s) {

    StateHashTuple sh = this.stateHash(s);
    StateHashTuple indexSH = mapToStateIndex.get(sh);
    if (indexSH == null) {
        this.planFromState(s);
        // no need to translate, because if the state didn't exist it was indexed with this state's representation
        return internalPolicy.get(sh);
    }

    // otherwise the policy for this state has already been computed
    GroundedAction res = internalPolicy.get(sh);
    // do object matching from the returned result to this query state and return the result
    res = (GroundedAction) res.translateParameters(indexSH.s, sh.s);
    return res;
}
/**
 * Selects a next state for expansion when action a is applied in state s.
 *
 * @param s the source state of the transition
 * @param a the action applied in the source state
 * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
 *     the expected margin size of this transition.
 */
protected StateSelectionAndExpectedGap getNextState(State s, GroundedAction a) {

    if (this.selectionMode == StateSelectionMode.MODELBASED) {
        HashableState nsh = this.hashingFactory.hashState(a.executeIn(s));
        double gap = this.getGap(nsh);
        return new StateSelectionAndExpectedGap(nsh, gap);
    } else if (this.selectionMode == StateSelectionMode.WEIGHTEDMARGIN) {
        return this.getNextStateBySampling(s, a);
    } else if (this.selectionMode == StateSelectionMode.MAXMARGIN) {
        return this.getNextStateByMaxMargin(s, a);
    }

    throw new RuntimeException("Unknown state selection mode.");
}
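/*
 * Hedged sketch of how a gap-based next-state selection like getNextState is typically used in
 * a bounded-RTDP style trial (the setting these selection methods appear to come from): keep
 * expanding along selected successors until the expected bound gap is small or a depth limit is
 * hit. The Selection and selectNext names below are hypothetical stand-ins, not BURLAP classes;
 * selectNext here merely shrinks the gap so the loop terminates.
 */
import java.util.ArrayList;
import java.util.List;

public class GapTrialSketch {

    /** Hypothetical stand-in for a (state, expected gap) selection result. */
    static class Selection {
        final int stateId;
        final double expectedGap;
        Selection(int stateId, double expectedGap) {
            this.stateId = stateId;
            this.expectedGap = expectedGap;
        }
    }

    /** Hypothetical stand-in for getNextState(s, a); a real planner would consult the bounds. */
    static Selection selectNext(int stateId, double currentGap) {
        return new Selection(stateId + 1, currentGap * 0.5);
    }

    public static void main(String[] args) {
        double gapThreshold = 0.01;
        int maxDepth = 50;

        List<Integer> visited = new ArrayList<Integer>();
        int state = 0;
        double gap = 1.0;
        int depth = 0;
        // expand until the expected bound gap is small or the trial is too deep
        while (gap > gapThreshold && depth < maxDepth) {
            visited.add(state);
            Selection sel = selectNext(state, gap);
            state = sel.stateId;
            gap = sel.expectedGap;
            depth++;
        }
        System.out.println("visited " + visited.size() + " states; final gap = " + gap);
    }
}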
@Override
public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);

    State curState = initialState;
    int steps = 0;
    while (!this.tf.isTerminal(curState) && steps < maxSteps) {

        GroundedAction ga = (GroundedAction) policy.getAction(curState);
        State nextState = ga.executeIn(curState);
        double r = this.rf.reward(curState, ga, nextState);

        ea.recordTransitionTo(nextState, ga, r);

        this.model.updateModel(curState, ga, nextState, r, this.tf.isTerminal(nextState));
        this.modelPlanner.performBellmanUpdateOn(curState);

        curState = nextState;
        steps++;
    }

    return ea;
}
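/*
 * Minimal sketch (not BURLAP code) of the value backup that a call like
 * modelPlanner.performBellmanUpdateOn(curState) performs against the learned model:
 *   V(s) = max_a sum_{s'} T(s'|s,a) * [ R(s,a,s') + gamma * V(s') ].
 * The arrays below stand in for the model's estimated transition probabilities, rewards, and
 * successor values for each action; all names and numbers are illustrative.
 */
public class BellmanBackupSketch {

    /** Expected one-step return of a single action under the model estimates. */
    static double qBackup(double[] probs, double[] rewards, double[] nextValues, double gamma) {
        double q = 0.;
        for (int i = 0; i < probs.length; i++) {
            q += probs[i] * (rewards[i] + gamma * nextValues[i]);
        }
        return q;
    }

    public static void main(String[] args) {
        double gamma = 0.99;

        // two actions, each with two possible successors of s under the learned model
        double[][] probs = {{0.8, 0.2}, {1.0, 0.0}};
        double[][] rewards = {{-1., -1.}, {-2., 0.}};
        double[][] nextValues = {{10., 2.}, {6., 0.}};

        double v = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < probs.length; a++) {
            v = Math.max(v, qBackup(probs[a], rewards[a], nextValues[a], gamma));
        }
        System.out.println("V(s) after the Bellman update = " + v);
    }
}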
/**
 * Recursive method to perform A* up to an f-score depth.
 *
 * @param lastNode the node to expand
 * @param minR the minimum cumulative reward at which to stop the search (in other words, the
 *     maximum cost)
 * @param cumulatedReward the amount of reward accumulated at this node
 * @return a search node with the goal state, or null if there is no path within the reward
 *     requirements from this node
 */
protected PrioritizedSearchNode FLimtedDFS(
        PrioritizedSearchNode lastNode, double minR, double cumulatedReward) {

    if (lastNode.priority < minR) {
        return lastNode; // fail condition (either way, return the last point reached)
    }
    if (this.planEndNode(lastNode)) {
        return lastNode; // succeed condition
    }
    if (this.tf.isTerminal(lastNode.s.s)) {
        return null; // treat as a dead end if we're at a terminal state
    }

    State s = lastNode.s.s;

    // get all applicable grounded actions
    /*List<GroundedAction> gas = new ArrayList<GroundedAction>();
    for(Action a : actions){
        gas.addAll(s.getAllGroundedActionsFor(a));
    }*/
    List<GroundedAction> gas =
        Action.getAllApplicableGroundedActionsFromActionList(this.actions, s);

    // generate successor nodes
    List<PrioritizedSearchNode> successors = new ArrayList<PrioritizedSearchNode>(gas.size());
    List<Double> successorGs = new ArrayList<Double>(gas.size());
    for (GroundedAction ga : gas) {

        State ns = ga.executeIn(s);
        StateHashTuple nsh = this.stateHash(ns);

        double r = rf.reward(s, ga, ns);
        double g = cumulatedReward + r;
        double hr = heuristic.h(ns);
        double f = g + hr;
        PrioritizedSearchNode pnsn = new PrioritizedSearchNode(nsh, ga, lastNode, f);

        // only add this node if its state does not already exist on our path
        if (this.lastStateOnPathIsNew(pnsn)) {
            successors.add(pnsn);
            successorGs.add(g);
        }
    }

    // sort the successors by f-score so the most promising ones are explored first
    Collections.sort(successors, nodeComparator);

    double maxCandR = Double.NEGATIVE_INFINITY;
    PrioritizedSearchNode bestCand = null;
    // since we want to expand the largest expected rewards first, traverse the f-ordered
    // successors in reverse order
    for (int i = successors.size() - 1; i >= 0; i--) {
        PrioritizedSearchNode snode = successors.get(i);
        PrioritizedSearchNode cand = this.FLimtedDFS(snode, minR, successorGs.get(i));
        if (cand != null && cand.priority > maxCandR) {
            bestCand = cand;
            maxCandR = cand.priority;
        }
    }

    return bestCand;
}
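/*
 * Hedged sketch (not BURLAP code) of one common shape of the iterative-deepening driver that
 * wraps an f-limited DFS like the one above, in IDA* style but maximizing reward rather than
 * minimizing cost: run a pass under the current bound, and if no goal is reached, relax the
 * bound to the best f-value that fell below it and try again. The PassResult type, the
 * fLimitedPass parameter, and the dummy lambda in main are illustrative stand-ins; the actual
 * deepening logic in the surrounding planner may differ in how it signals failure.
 */
import java.util.function.DoubleFunction;

public class DeepeningDriverSketch {

    /** Hypothetical result of one f-limited pass: a null goal means the bound was too tight. */
    static class PassResult {
        final Object goalNode;   // null if no goal was found within the bound
        final double nextBound;  // best f-value that fell below the current bound
        PassResult(Object goalNode, double nextBound) {
            this.goalNode = goalNode;
            this.nextBound = nextBound;
        }
    }

    static Object iterativeDeepening(DoubleFunction<PassResult> fLimitedPass, double initialBound) {
        double bound = initialBound;
        while (true) {
            PassResult res = fLimitedPass.apply(bound);
            if (res.goalNode != null) {
                return res.goalNode; // a goal was reached within the bound
            }
            if (res.nextBound == Double.NEGATIVE_INFINITY) {
                return null; // search space exhausted, no plan exists
            }
            bound = res.nextBound; // relax the bound and search again
        }
    }

    public static void main(String[] args) {
        // dummy pass: pretend the goal only becomes reachable once the bound drops below -3
        Object goal = iterativeDeepening(
            bound -> bound < -3 ? new PassResult("goal", bound) : new PassResult(null, bound - 1),
            0.);
        System.out.println(goal);
    }
}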
@Override
public EpisodeAnalysis runLearningEpisodeFrom(State initialState) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);
    maxWeightChangeInLastEpisode = 0.;

    State curState = initialState;
    eStepCounter = 0;
    Map<Integer, EligibilityTraceVector> traces =
        new HashMap<Integer, GradientDescentSarsaLam.EligibilityTraceVector>();

    GroundedAction action = this.learningPolicy.getAction(curState);
    List<ActionApproximationResult> allCurApproxResults = this.getAllActionApproximations(curState);
    ActionApproximationResult curApprox =
        ActionApproximationResult.extractApproximationForAction(allCurApproxResults, action);

    while (!tf.isTerminal(curState) && eStepCounter < maxEpisodeSize) {

        WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult);

        State nextState = action.executeIn(curState);
        GroundedAction nextAction = this.learningPolicy.getAction(nextState);
        List<ActionApproximationResult> allNextApproxResults = this.getAllActionApproximations(nextState);
        ActionApproximationResult nextApprox =
            ActionApproximationResult.extractApproximationForAction(allNextApproxResults, nextAction);
        double nextQV = nextApprox.approximationResult.predictedValue;
        if (tf.isTerminal(nextState)) {
            nextQV = 0.;
        }

        // manage option specifics
        double r = 0.;
        double discount = this.gamma;
        if (action.action.isPrimitive()) {
            r = rf.reward(curState, action, nextState);
            eStepCounter++;
            ea.recordTransitionTo(nextState, action, r);
        } else {
            Option o = (Option) action.action;
            r = o.getLastCumulativeReward();
            int n = o.getLastNumSteps();
            discount = Math.pow(this.gamma, n);
            eStepCounter += n;
            if (this.shouldDecomposeOptions) {
                ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
            } else {
                ea.recordTransitionTo(nextState, action, r);
            }
        }

        // TD error
        double delta = r + (discount * nextQV) - curApprox.approximationResult.predictedValue;

        if (useReplacingTraces) {
            // clear the traces of unselected actions and reset the trace of the selected one
            for (ActionApproximationResult aar : allCurApproxResults) {
                if (!aar.ga.equals(action)) {
                    // clear unselected action trace
                    for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                        traces.remove(fw.weightId());
                    }
                } else {
                    // reset trace of the selected action
                    for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                        EligibilityTraceVector storedTrace = traces.get(fw.weightId());
                        if (storedTrace != null) {
                            storedTrace.eligibilityValue = 0.;
                        }
                    }
                }
            }
        }

        // update all existing traces
        Set<Integer> deletedSet = new HashSet<Integer>();
        for (EligibilityTraceVector et : traces.values()) {

            int weightId = et.weight.weightId();

            et.eligibilityValue += gradient.getPartialDerivative(weightId);
            double newWeight = et.weight.weightValue() + this.learningRate * delta * et.eligibilityValue;
            et.weight.setWeight(newWeight);

            double deltaW = Math.abs(et.initialWeightValue - newWeight);
            if (deltaW > maxWeightChangeInLastEpisode) {
                maxWeightChangeInLastEpisode = deltaW;
            }

            et.eligibilityValue *= this.lambda * discount;
            if (et.eligibilityValue < this.minEligibityForUpdate) {
                deletedSet.add(weightId);
            }
        }

        // add new traces if need be
        for (FunctionWeight fw : curApprox.approximationResult.functionWeights) {

            int weightId = fw.weightId();
            if (!traces.containsKey(weightId)) { // the map is keyed by weight id, not the weight object

                // then it's new and we need to add it
                EligibilityTraceVector et =
                    new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId));
                double newWeight = fw.weightValue() + this.learningRate * delta * et.eligibilityValue;
                fw.setWeight(newWeight);

                double deltaW = Math.abs(et.initialWeightValue - newWeight);
                if (deltaW > maxWeightChangeInLastEpisode) {
                    maxWeightChangeInLastEpisode = deltaW;
                }

                et.eligibilityValue *= this.lambda * discount;
                if (et.eligibilityValue >= this.minEligibityForUpdate) {
                    traces.put(weightId, et);
                }
            }
        }

        // delete any traces that have decayed below the update threshold
        for (Integer t : deletedSet) {
            traces.remove(t);
        }

        // move on
        curState = nextState;
        action = nextAction;
        curApprox = nextApprox;
        allCurApproxResults = allNextApproxResults;
    }

    // store the completed episode, evicting the oldest if the history is full
    if (episodeHistory.size() >= numEpisodesToStore) {
        episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
}
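/*
 * Minimal numeric sketch (not BURLAP code) of the gradient-descent SARSA(lambda) step that the
 * loop above applies to each traced weight of a linear approximator, in the same order as the
 * code: accumulate the trace, move the weight, then decay the trace.
 *   delta = r + gamma^n * Q(s',a') - Q(s,a)
 *   e_i  += dQ(s,a)/dw_i
 *   w_i  += alpha * delta * e_i
 *   e_i  *= lambda * gamma^n
 * For a linear approximator, dQ/dw_i is just the i-th feature value. All names and numbers
 * below are illustrative stand-ins for the trace and weight objects above.
 */
public class SarsaLambdaStepSketch {

    static double dot(double[] a, double[] b) {
        double s = 0.;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    public static void main(String[] args) {
        double alpha = 0.1, gamma = 0.99, lambda = 0.9;

        double[] w = {0.5, -0.2};        // weights of the linear approximator
        double[] e = {0.0, 0.0};         // eligibility traces, one per weight
        double[] phiS = {1.0, 0.3};      // features of (s, a); also the gradient dQ/dw
        double[] phiSPrime = {0.4, 1.0}; // features of (s', a')
        double r = -1.;

        double delta = r + gamma * dot(w, phiSPrime) - dot(w, phiS);
        for (int i = 0; i < w.length; i++) {
            e[i] += phiS[i];              // accumulate the trace
            w[i] += alpha * delta * e[i]; // move the weight along its trace
            e[i] *= lambda * gamma;       // decay the trace for the next step
        }
        System.out.println("delta = " + delta + ", w = [" + w[0] + ", " + w[1] + "]");
    }
}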
@Override
public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps) {

    this.toggleShouldAnnotateOptionDecomposition(shouldAnnotateOptions);

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);

    StateHashTuple curState = this.stateHash(initialState);
    eStepCounter = 0;
    maxQChangeInLastEpisode = 0.;

    while (!tf.isTerminal(curState.s) && eStepCounter < maxSteps) {

        GroundedAction action = (GroundedAction) learningPolicy.getAction(curState.s);
        QValue curQ = this.getQ(curState, action);

        StateHashTuple nextState = this.stateHash(action.executeIn(curState.s));
        double maxQ = 0.;
        if (!tf.isTerminal(nextState.s)) {
            maxQ = this.getMaxQ(nextState);
        }

        // manage option specifics
        double r = 0.;
        double discount = this.gamma;
        if (action.action.isPrimitive()) {
            r = rf.reward(curState.s, action, nextState.s);
            eStepCounter++;
            ea.recordTransitionTo(nextState.s, action, r);
        } else {
            Option o = (Option) action.action;
            r = o.getLastCumulativeReward();
            int n = o.getLastNumSteps();
            discount = Math.pow(this.gamma, n);
            eStepCounter += n;
            if (this.shouldDecomposeOptions) {
                ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
            } else {
                ea.recordTransitionTo(nextState.s, action, r);
            }
        }

        double oldQ = curQ.q;

        // update the Q-value
        curQ.q = curQ.q
            + this.learningRate.pollLearningRate(curState.s, action) * (r + (discount * maxQ) - curQ.q);

        double deltaQ = Math.abs(oldQ - curQ.q);
        if (deltaQ > maxQChangeInLastEpisode) {
            maxQChangeInLastEpisode = deltaQ;
        }

        // move on
        curState = nextState;
    }

    if (episodeHistory.size() >= numEpisodesToStore) {
        episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
}
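/*
 * Minimal numeric sketch (not BURLAP code) of the Q-learning backup applied above:
 *   Q(s,a) <- Q(s,a) + alpha * ( r + gamma^n * max_a' Q(s',a') - Q(s,a) )
 * where n = 1 for a primitive action and n is the option's step count otherwise (with r the
 * option's cumulative reward). The values below are illustrative only.
 */
public class QUpdateSketch {

    public static void main(String[] args) {
        double alpha = 0.1, gamma = 0.99;
        int n = 3;              // e.g., an option that lasted 3 steps
        double q = 2.0;         // current Q(s,a)
        double r = -3.0;        // cumulative reward over the option
        double maxNextQ = 5.0;  // max_a' Q(s',a')

        double discount = Math.pow(gamma, n);
        double updated = q + alpha * (r + discount * maxNextQ - q);
        System.out.println("updated Q(s,a) = " + updated);
    }
}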