/**
   * Selects a next state for expansion when action a is applied in state s by randomly sampling
   * from the transition dynamics, with each possible next state weighted by the gap between its
   * lower and upper bound value functions.
   *
   * @param s the source state of the transition
   * @param a the action applied in the source state
   * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
   *     the expected gap of this transition.
   */
  protected StateSelectionAndExpectedGap getNextStateBySampling(State s, GroundedAction a) {

    List<TransitionProbability> tps = a.getTransitions(s);
    double sum = 0.;
    double[] weightedGap = new double[tps.size()];
    HashableState[] hashedStates = new HashableState[tps.size()];
    for (int i = 0; i < tps.size(); i++) {
      TransitionProbability tp = tps.get(i);
      HashableState nsh = this.hashingFactory.hashState(tp.s);
      hashedStates[i] = nsh;
      double gap = this.getGap(nsh);
      weightedGap[i] = tp.p * gap;
      sum += weightedGap[i];
    }

    double roll = RandomFactory.getMapped(0).nextDouble();
    double cumSum = 0.;
    for (int i = 0; i < weightedGap.length; i++) {
      cumSum += weightedGap[i] / sum;
      if (roll < cumSum) {
        StateSelectionAndExpectedGap select =
            new StateSelectionAndExpectedGap(hashedStates[i], sum);
        return select;
      }
    }

    throw new RuntimeException("Error: probabilities in state selection did not sum to 1.");
  }
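
The loop above is a standard roulette-wheel draw over gap-weighted transition probabilities. For reference, a minimal standalone sketch of the same sampling technique, with hypothetical names that are not part of the BURLAP API:

import java.util.Random;

public class WeightedSamplingSketch {

  /**
   * Returns an index sampled in proportion to the given non-negative weights; the weights do not
   * need to be normalized.
   */
  public static int sampleIndex(double[] weights, Random rng) {
    double sum = 0.;
    for (double w : weights) {
      sum += w;
    }
    double roll = rng.nextDouble() * sum;
    double cumSum = 0.;
    for (int i = 0; i < weights.length; i++) {
      cumSum += weights[i];
      if (roll < cumSum) {
        return i;
      }
    }
    // guard against floating point round-off rather than throwing
    return weights.length - 1;
  }

  public static void main(String[] args) {
    double[] weightedGaps = {0.1, 0.6, 0.3};
    System.out.println(sampleIndex(weightedGaps, new Random()));
  }
}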
  /**
   * Selects a next state for expansion when action a is applied in state s by choosing the
   * possible next state with the largest gap between its lower and upper bound value functions.
   * Ties are broken randomly.
   *
   * @param s the source state of the transition
   * @param a the action applied in the source state
   * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
   *     the expected gap of this transition.
   */
  protected StateSelectionAndExpectedGap getNextStateByMaxMargin(State s, GroundedAction a) {

    List<TransitionProbability> tps = a.getTransitions(s);
    double sum = 0.;
    double maxGap = Double.NEGATIVE_INFINITY;
    List<HashableState> maxStates = new ArrayList<HashableState>(tps.size());
    for (TransitionProbability tp : tps) {
      HashableState nsh = this.hashingFactory.hashState(tp.s);
      double gap = this.getGap(nsh);
      sum += tp.p * gap;
      if (gap == maxGap) {
        maxStates.add(nsh);
      } else if (gap > maxGap) {
        maxStates.clear();
        maxStates.add(nsh);
        maxGap = gap;
      }
    }

    int rint = RandomFactory.getMapped(0).nextInt(maxStates.size());
    StateSelectionAndExpectedGap select =
        new StateSelectionAndExpectedGap(maxStates.get(rint), sum);

    return select;
  }
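
Collecting every gap-maximizing state and then drawing one uniformly is a common way to break ties without bias. A self-contained sketch of the same pattern (illustrative only; the names are hypothetical):

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class MaxTieBreakSketch {

  /** Returns the index of a maximal entry of values, breaking ties uniformly at random. */
  public static int argmaxWithRandomTieBreak(double[] values, Random rng) {
    double max = Double.NEGATIVE_INFINITY;
    List<Integer> maxIndices = new ArrayList<Integer>();
    for (int i = 0; i < values.length; i++) {
      if (values[i] == max) {
        maxIndices.add(i);
      } else if (values[i] > max) {
        maxIndices.clear();
        maxIndices.add(i);
        max = values[i];
      }
    }
    return maxIndices.get(rng.nextInt(maxIndices.size()));
  }

  public static void main(String[] args) {
    double[] gaps = {0.2, 0.9, 0.9, 0.5};
    System.out.println(argmaxWithRandomTieBreak(gaps, new Random())); // prints 1 or 2
  }
}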
  /**
   * Returns the action suggested by the planner for the given state. If a plan including this state
   * has not already been computed, the planner will be called from this state to find one.
   *
   * @param s the state for which the suggested action is to be returned.
   * @return The suggested action for the given state.
   */
  public GroundedAction querySelectedActionForState(State s) {

    StateHashTuple sh = this.stateHash(s);
    StateHashTuple indexSH = mapToStateIndex.get(sh);
    if (indexSH == null) {
      this.planFromState(s);
      return internalPolicy.get(
          sh); // no need to translate because if the state didn't exist then it got indexed with
               // this state's rep
    }

    // otherwise it's already computed
    GroundedAction res = internalPolicy.get(sh);

    // do object matching from returned result to this query state and return result
    res = (GroundedAction) res.translateParameters(indexSH.s, sh.s);

    return res;
  }
  /**
   * Selects a next state for expansion when action a is applied in state s, according to the
   * planner's current state selection mode.
   *
   * @param s the source state of the transition
   * @param a the action applied in the source state
   * @return a {@link StateSelectionAndExpectedGap} object holding the next state to be expanded and
   *     the expected gap of this transition.
   */
  protected StateSelectionAndExpectedGap getNextState(State s, GroundedAction a) {

    if (this.selectionMode == StateSelectionMode.MODELBASED) {
      HashableState nsh = this.hashingFactory.hashState(a.executeIn(s));
      double gap = this.getGap(nsh);
      return new StateSelectionAndExpectedGap(nsh, gap);
    } else if (this.selectionMode == StateSelectionMode.WEIGHTEDMARGIN) {
      return this.getNextStateBySampling(s, a);
    } else if (this.selectionMode == StateSelectionMode.MAXMARGIN) {
      return this.getNextStateByMaxMargin(s, a);
    }
    throw new RuntimeException("Unknown state selection mode.");
  }
Example #5
File: ARTDP.java  Project: jskonhovd/burlap
  @Override
  public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);

    State curState = initialState;
    int steps = 0;
    while (!this.tf.isTerminal(curState) && steps < maxSteps) {
      GroundedAction ga = (GroundedAction) policy.getAction(curState);
      State nextState = ga.executeIn(curState);
      double r = this.rf.reward(curState, ga, nextState);

      ea.recordTransitionTo(nextState, ga, r);

      this.model.updateModel(curState, ga, nextState, r, this.tf.isTerminal(nextState));

      this.modelPlanner.performBellmanUpdateOn(curState);

      curState = nextState;
      steps++;
    }

    return ea;
  }
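
ARTDP interleaves model learning with a Bellman backup on each visited state. The backup itself is the standard expected update V(s) = max_a sum_s' T(s,a,s') [ R(s,a,s') + gamma * V(s') ]; a minimal tabular sketch of that update, assuming a hypothetical learned-model representation rather than BURLAP's classes:

import java.util.List;

public class BellmanBackupSketch {

  /** A hypothetical learned transition: next-state index, probability, and expected reward. */
  static class Transition {
    final int nextState;
    final double p;
    final double r;

    Transition(int nextState, double p, double r) {
      this.nextState = nextState;
      this.p = p;
      this.r = r;
    }
  }

  /**
   * Performs one Bellman backup on state s, updating V[s] in place. model.get(a) holds the learned
   * transitions for action a applied in s.
   */
  static void bellmanUpdate(List<List<Transition>> model, double[] V, int s, double gamma) {
    double best = Double.NEGATIVE_INFINITY;
    for (List<Transition> actionTransitions : model) {
      double q = 0.;
      for (Transition t : actionTransitions) {
        q += t.p * (t.r + gamma * V[t.nextState]);
      }
      best = Math.max(best, q);
    }
    V[s] = best;
  }
}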
  /**
   * Recursive method to perform A* up to an f-score depth
   *
   * @param lastNode the node to expand
   * @param minR the minimum cumulative reward at which to stop the search (in other words, the
   *     maximum cost)
   * @param cumulatedReward the amount of reward accumulated at this node
   * @return a search node with the goal state, or null if there is no path within the reward
   *     requirements from this node
   */
  protected PrioritizedSearchNode FLimtedDFS(
      PrioritizedSearchNode lastNode, double minR, double cumulatedReward) {

    if (lastNode.priority < minR) {
      return lastNode; // fail condition (either way, return the deepest node reached)
    }
    if (this.planEndNode(lastNode)) {
      return lastNode; // succeed condition
    }
    if (this.tf.isTerminal(lastNode.s.s)) {
      return null; // treat like a dead end if we're at a terminal state
    }

    State s = lastNode.s.s;

    // get all applicable grounded actions for the source state
    List<GroundedAction> gas =
        Action.getAllApplicableGroundedActionsFromActionList(this.actions, s);

    // generate successor nodes
    List<PrioritizedSearchNode> successors = new ArrayList<PrioritizedSearchNode>(gas.size());
    List<Double> successorGs = new ArrayList<Double>(gas.size());
    for (GroundedAction ga : gas) {

      State ns = ga.executeIn(s);
      StateHashTuple nsh = this.stateHash(ns);

      double r = rf.reward(s, ga, ns);
      double g = cumulatedReward + r;
      double hr = heuristic.h(ns);
      double f = g + hr;
      PrioritizedSearchNode pnsn = new PrioritizedSearchNode(nsh, ga, lastNode, f);

      // only add if this does not exist on our path already
      if (this.lastStateOnPathIsNew(pnsn)) {
        successors.add(pnsn);
        successorGs.add(g);
      }
    }

    // sort the successors by f-score to travel the most promising ones first
    Collections.sort(successors, nodeComparator);

    double maxCandR = Double.NEGATIVE_INFINITY;
    PrioritizedSearchNode bestCand = null;
    // since we want to expand the largest expected rewards first, iterate over the f-sorted
    // successors in reverse order
    for (int i = successors.size() - 1; i >= 0; i--) {
      PrioritizedSearchNode snode = successors.get(i);
      PrioritizedSearchNode cand = this.FLimtedDFS(snode, minR, successorGs.get(i));
      if (cand != null) {
        if (cand.priority > maxCandR) {
          bestCand = cand;
          maxCandR = cand.priority;
        }
      }
    }

    return bestCand;
  }
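
Here the f-score is f = g + h, with g the reward accumulated so far and h the heuristic estimate of the remaining reward; because rewards play the role of negative costs (see the minR parameter above), a larger f means a more promising node, which is why the f-sorted successors are walked in reverse. A standalone sketch of that ordering (illustrative only; the node class is hypothetical):

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class FOrderingSketch {

  /** A bare-bones search node holding the cumulative reward g and heuristic estimate h. */
  static class Node {
    final double g; // cumulative reward so far (a negative cost)
    final double h; // heuristic estimate of the remaining reward

    Node(double g, double h) {
      this.g = g;
      this.h = h;
    }

    double f() {
      return g + h;
    }
  }

  public static void main(String[] args) {
    List<Node> successors = new ArrayList<Node>();
    successors.add(new Node(-2.0, -3.0)); // f = -5.0
    successors.add(new Node(-1.0, -2.5)); // f = -3.5, most promising
    successors.add(new Node(-1.5, -3.0)); // f = -4.5

    // sort by descending f so the most promising successor is expanded first
    Collections.sort(
        successors,
        new Comparator<Node>() {
          @Override
          public int compare(Node a, Node b) {
            return Double.compare(b.f(), a.f());
          }
        });

    System.out.println(successors.get(0).f()); // prints -3.5
  }
}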
  @Override
  public EpisodeAnalysis runLearningEpisodeFrom(State initialState) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);
    maxWeightChangeInLastEpisode = 0.;

    State curState = initialState;
    eStepCounter = 0;
    Map<Integer, EligibilityTraceVector> traces =
        new HashMap<Integer, GradientDescentSarsaLam.EligibilityTraceVector>();

    GroundedAction action = this.learningPolicy.getAction(curState);
    List<ActionApproximationResult> allCurApproxResults = this.getAllActionApproximations(curState);
    ActionApproximationResult curApprox =
        ActionApproximationResult.extractApproximationForAction(allCurApproxResults, action);

    while (!tf.isTerminal(curState) && eStepCounter < maxEpisodeSize) {

      WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult);

      State nextState = action.executeIn(curState);
      GroundedAction nextAction = this.learningPolicy.getAction(nextState);
      List<ActionApproximationResult> allNextApproxResults =
          this.getAllActionApproximations(nextState);
      ActionApproximationResult nextApprox =
          ActionApproximationResult.extractApproximationForAction(allNextApproxResults, nextAction);
      double nextQV = nextApprox.approximationResult.predictedValue;
      if (tf.isTerminal(nextState)) {
        nextQV = 0.;
      }

      // manage option specifics
      double r = 0.;
      double discount = this.gamma;
      if (action.action.isPrimitive()) {
        r = rf.reward(curState, action, nextState);
        eStepCounter++;
        ea.recordTransitionTo(nextState, action, r);
      } else {
        Option o = (Option) action.action;
        r = o.getLastCumulativeReward();
        int n = o.getLastNumSteps();
        discount = Math.pow(this.gamma, n);
        eStepCounter += n;
        if (this.shouldDecomposeOptions) {
          ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
        } else {
          ea.recordTransitionTo(nextState, action, r);
        }
      }

      // delta
      double delta = r + (discount * nextQV) - curApprox.approximationResult.predictedValue;

      if (useReplacingTraces) {
        // then first clear traces of unselected action and reset the trace for the selected one
        for (ActionApproximationResult aar : allCurApproxResults) {
          if (!aar.ga.equals(action)) { // clear unselected action trace
            for (FunctionWeight fw : aar.approximationResult.functionWeights) {
              traces.remove(fw.weightId());
            }
          } else { // reset trace of selected action
            for (FunctionWeight fw : aar.approximationResult.functionWeights) {
              EligibilityTraceVector storedTrace = traces.get(fw.weightId());
              if (storedTrace != null) {
                storedTrace.eligibilityValue = 0.;
              }
            }
          }
        }
      }

      // update all traces
      Set<Integer> deletedSet = new HashSet<Integer>();
      for (EligibilityTraceVector et : traces.values()) {

        int weightId = et.weight.weightId();

        et.eligibilityValue += gradient.getPartialDerivative(weightId);
        double newWeight =
            et.weight.weightValue() + this.learningRate * delta * et.eligibilityValue;
        et.weight.setWeight(newWeight);

        double deltaW = Math.abs(et.initialWeightValue - newWeight);
        if (deltaW > maxWeightChangeInLastEpisode) {
          maxWeightChangeInLastEpisode = deltaW;
        }

        et.eligibilityValue *= this.lambda * discount;
        if (et.eligibilityValue < this.minEligibityForUpdate) {
          deletedSet.add(weightId);
        }
      }

      // add new traces if need be
      for (FunctionWeight fw : curApprox.approximationResult.functionWeights) {

        int weightId = fw.weightId();
        if (!traces.containsKey(weightId)) {

          // then it's new and we need to add it
          EligibilityTraceVector et =
              new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId));
          double newWeight = fw.weightValue() + this.learningRate * delta * et.eligibilityValue;
          fw.setWeight(newWeight);

          double deltaW = Math.abs(et.initialWeightValue - newWeight);
          if (deltaW > maxWeightChangeInLastEpisode) {
            maxWeightChangeInLastEpisode = deltaW;
          }

          et.eligibilityValue *= this.lambda * discount;
          if (et.eligibilityValue >= this.minEligibityForUpdate) {
            traces.put(weightId, et);
          }
        }
      }

      // remove traces whose eligibility fell below the minimum threshold
      for (Integer t : deletedSet) {
        traces.remove(t);
      }

      // move on
      curState = nextState;
      action = nextAction;
      curApprox = nextApprox;
      allCurApproxResults = allNextApproxResults;
    }

    if (episodeHistory.size() >= numEpisodesToStore) {
      episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
  }
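
Stripped of option handling and BURLAP's approximation classes, the update above is gradient-descent SARSA(lambda) with a linear approximator: compute the TD error delta = r + gamma * q(s',a') - q(s,a), fold the gradient of q(s,a) into the eligibility traces, and move the weights by alpha * delta along the traces. A minimal sketch with plain arrays and accumulating traces (hypothetical, for illustration only):

public class LinearSarsaLambdaSketch {

  /**
   * One SARSA(lambda) step for a linear approximator q(s,a) = w . phi(s,a), using accumulating
   * traces. The weights w and traces e are updated in place.
   *
   * @param phi feature vector of the current state-action pair (also the gradient of q w.r.t. w)
   * @param r the observed reward
   * @param qNext the estimate q(s',a') for the next state-action pair (0 if s' is terminal)
   */
  static void sarsaLambdaStep(
      double[] w, double[] e, double[] phi, double r, double qNext,
      double alpha, double gamma, double lambda) {

    // current estimate q(s,a) = w . phi
    double q = 0.;
    for (int i = 0; i < w.length; i++) {
      q += w[i] * phi[i];
    }

    // TD error
    double delta = r + gamma * qNext - q;

    for (int i = 0; i < w.length; i++) {
      // decay the trace, accumulate the gradient, then apply the trace-weighted update
      e[i] = gamma * lambda * e[i] + phi[i];
      w[i] += alpha * delta * e[i];
    }
  }
}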
Example #8
  @Override
  public EpisodeAnalysis runLearningEpisodeFrom(State initialState, int maxSteps) {

    this.toggleShouldAnnotateOptionDecomposition(shouldAnnotateOptions);

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);

    StateHashTuple curState = this.stateHash(initialState);
    eStepCounter = 0;

    maxQChangeInLastEpisode = 0.;

    while (!tf.isTerminal(curState.s) && eStepCounter < maxSteps) {

      GroundedAction action = (GroundedAction) learningPolicy.getAction(curState.s);
      QValue curQ = this.getQ(curState, action);

      StateHashTuple nextState = this.stateHash(action.executeIn(curState.s));
      double maxQ = 0.;

      if (!tf.isTerminal(nextState.s)) {
        maxQ = this.getMaxQ(nextState);
      }

      // manage option specifics
      double r = 0.;
      double discount = this.gamma;
      if (action.action.isPrimitive()) {
        r = rf.reward(curState.s, action, nextState.s);
        eStepCounter++;
        ea.recordTransitionTo(nextState.s, action, r);
      } else {
        Option o = (Option) action.action;
        r = o.getLastCumulativeReward();
        int n = o.getLastNumSteps();
        discount = Math.pow(this.gamma, n);
        eStepCounter += n;
        if (this.shouldDecomposeOptions) {
          ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
        } else {
          ea.recordTransitionTo(nextState.s, action, r);
        }
      }

      double oldQ = curQ.q;

      // update Q-value
      curQ.q =
          curQ.q
              + this.learningRate.pollLearningRate(curState.s, action)
                  * (r + (discount * maxQ) - curQ.q);

      double deltaQ = Math.abs(oldQ - curQ.q);
      if (deltaQ > maxQChangeInLastEpisode) {
        maxQChangeInLastEpisode = deltaQ;
      }

      // move on
      curState = nextState;
    }

    if (episodeHistory.size() >= numEpisodesToStore) {
      episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
  }
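
The Q-value update in this example is the standard tabular Q-learning rule, Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)), here with a learning-rate schedule polled per state-action pair. A minimal tabular sketch with a constant learning rate (illustrative only; names are hypothetical):

public class TabularQLearningSketch {

  /**
   * Applies one Q-learning backup to Q[s][a] and returns the absolute change, mirroring the
   * maxQChangeInLastEpisode bookkeeping above.
   */
  static double qLearningUpdate(
      double[][] Q, int s, int a, double r, int sPrime, boolean terminal,
      double alpha, double gamma) {

    // max_a' Q(s', a'), treated as 0 when s' is terminal
    double maxQ = 0.;
    if (!terminal) {
      maxQ = Double.NEGATIVE_INFINITY;
      for (double q : Q[sPrime]) {
        maxQ = Math.max(maxQ, q);
      }
    }

    double oldQ = Q[s][a];
    Q[s][a] = oldQ + alpha * (r + gamma * maxQ - oldQ);
    return Math.abs(Q[s][a] - oldQ);
  }
}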