/**
   * Returns the VFA Q-value approximation for the given state and action.
   *
   * @param s the state for which the VFA result should be returned
   * @param ga the action for which the VFA result should be returned
   * @return the VFA Q-value approximation for the given state and action.
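   *
   * <p>Usage sketch (illustrative; assumes a state {@code s} and grounded action {@code ga} are in scope):
   * <pre>{@code
   * ActionApproximationResult approx = this.getActionApproximation(s, ga);
   * double qEstimate = approx.approximationResult.predictedValue;
   * }</pre>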
   */
  protected ActionApproximationResult getActionApproximation(State s, GroundedAction ga) {
    List<GroundedAction> gaList = new ArrayList<GroundedAction>(1);
    gaList.add(ga);

    List<ActionApproximationResult> results = vfa.getStateActionValues(s, gaList);

    return ActionApproximationResult.extractApproximationForAction(results, ga);
  }

  /**
   * Creates a Q-value object in which the Q-value is determined from VFA.
   *
   * @param results the VFA prediction results for each action.
   * @param s the state of the Q-value
   * @param ga the action taken
   * @return a Q-value object in which the Q-value is determined from VFA.
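   *
   * <p>Usage sketch (illustrative; assumes the results were produced for state {@code s}, e.g. by
   * {@code getAllActionApproximations(s)}):
   * <pre>{@code
   * List<ActionApproximationResult> results = this.getAllActionApproximations(s);
   * QValue q = this.getQFromFeaturesFor(results, s, ga);
   * }</pre>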
   */
  protected QValue getQFromFeaturesFor(
      List<ActionApproximationResult> results, State s, GroundedAction ga) {

    ActionApproximationResult result =
        ActionApproximationResult.extractApproximationForAction(results, ga);
    QValue q = new QValue(s, ga, result.approximationResult.predictedValue);

    return q;
  }

  @Override
  public EpisodeAnalysis runLearningEpisodeFrom(State initialState) {

    EpisodeAnalysis ea = new EpisodeAnalysis(initialState);
    maxWeightChangeInLastEpisode = 0.;

    State curState = initialState;
    eStepCounter = 0;
    Map<Integer, EligibilityTraceVector> traces =
        new HashMap<Integer, EligibilityTraceVector>();

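    // select the initial action from the learning policy and get its current Q-value approximation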
    GroundedAction action = this.learningPolicy.getAction(curState);
    List<ActionApproximationResult> allCurApproxResults = this.getAllActionApproximations(curState);
    ActionApproximationResult curApprox =
        ActionApproximationResult.extractApproximationForAction(allCurApproxResults, action);

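    // step through the episode until a terminal state is reached or the step limit is hit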
    while (!tf.isTerminal(curState) && eStepCounter < maxEpisodeSize) {

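      // gradient of the current Q-value approximation with respect to the VFA weights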
      WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult);

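      // execute the action, observe the resulting state, and choose the next action on-policy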
      State nextState = action.executeIn(curState);
      GroundedAction nextAction = this.learningPolicy.getAction(nextState);
      List<ActionApproximationResult> allNextApproxResults =
          this.getAllActionApproximations(nextState);
      ActionApproximationResult nextApprox =
          ActionApproximationResult.extractApproximationForAction(allNextApproxResults, nextAction);
      double nextQV = nextApprox.approximationResult.predictedValue;
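      // bootstrap with a Q-value of zero from terminal states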
      if (tf.isTerminal(nextState)) {
        nextQV = 0.;
      }

      // manage option-specific reward accumulation, discounting, and step counting
      double r = 0.;
      double discount = this.gamma;
      if (action.action.isPrimitive()) {
        r = rf.reward(curState, action, nextState);
        eStepCounter++;
        ea.recordTransitionTo(nextState, action, r);
      } else {
        Option o = (Option) action.action;
        r = o.getLastCumulativeReward();
        int n = o.getLastNumSteps();
        discount = Math.pow(this.gamma, n);
        eStepCounter += n;
        if (this.shouldDecomposeOptions) {
          ea.appendAndMergeEpisodeAnalysis(o.getLastExecutionResults());
        } else {
          ea.recordTransitionTo(nextState, action, r);
        }
      }

      // TD error: delta = r + discount * Q(s', a') - Q(s, a)
      double delta = r + (discount * nextQV) - curApprox.approximationResult.predictedValue;

      if (useReplacingTraces) {
        // clear the traces of unselected actions and reset the traces of the selected action
        for (ActionApproximationResult aar : allCurApproxResults) {
          if (!aar.ga.equals(action)) { // clear unselected action trace
            for (FunctionWeight fw : aar.approximationResult.functionWeights) {
              traces.remove(fw.weightId());
            }
          } else { // reset trace of selected action
            for (FunctionWeight fw : aar.approximationResult.functionWeights) {
              EligibilityTraceVector storedTrace = traces.get(fw.weightId());
              if (storedTrace != null) {
                storedTrace.eligibilityValue = 0.;
              }
            }
          }
        }
      }

      // update all existing traces: accumulate the gradient, apply the weight update, then decay
      Set<Integer> deletedSet = new HashSet<Integer>();
      for (EligibilityTraceVector et : traces.values()) {

        int weightId = et.weight.weightId();

        et.eligibilityValue += gradient.getPartialDerivative(weightId);
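        // gradient-descent weight update: w += learningRate * delta * eligibility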
        double newWeight =
            et.weight.weightValue() + this.learningRate * delta * et.eligibilityValue;
        et.weight.setWeight(newWeight);

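        // track the largest absolute weight change made during this episode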
        double deltaW = Math.abs(et.initialWeightValue - newWeight);
        if (deltaW > maxWeightChangeInLastEpisode) {
          maxWeightChangeInLastEpisode = deltaW;
        }

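        // decay the eligibility value and flag the trace for removal if it falls below the minimum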
        et.eligibilityValue *= this.lambda * discount;
        if (et.eligibilityValue < this.minEligibityForUpdate) {
          deletedSet.add(weightId);
        }
      }

      // add new traces if need be
      for (FunctionWeight fw : curApprox.approximationResult.functionWeights) {

        int weightId = fw.weightId();
        if (!traces.containsKey(weightId)) {

          // then it's new and we need to add it
          EligibilityTraceVector et =
              new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId));
          double newWeight = fw.weightValue() + this.learningRate * delta * et.eligibilityValue;
          fw.setWeight(newWeight);

          double deltaW = Math.abs(et.initialWeightValue - newWeight);
          if (deltaW > maxWeightChangeInLastEpisode) {
            maxWeightChangeInLastEpisode = deltaW;
          }

          et.eligibilityValue *= this.lambda * discount;
          if (et.eligibilityValue >= this.minEligibityForUpdate) {
            traces.put(weightId, et);
          }
        }
      }

      // remove traces that decayed below the minimum eligibility threshold
      for (Integer t : deletedSet) {
        traces.remove(t);
      }

      // move on: shift the next state, action, and approximations to be the current ones
      curState = nextState;
      action = nextAction;
      curApprox = nextApprox;
      allCurApproxResults = allNextApproxResults;
    }

    // record the episode, evicting the oldest one if the history is full
    if (episodeHistory.size() >= numEpisodesToStore) {
      episodeHistory.poll();
    }
    episodeHistory.offer(ea);

    return ea;
  }