/**
 * Moves the agent of {@code s} in a randomly sampled direction.
 *
 * <p>A direction index is drawn by comparing a single uniform roll against the running sum of
 * {@code directionProbs}; if the roll is never below that running sum, index 0 is used. The
 * agent's x and y attributes are then overwritten with the result of {@code moveResult}.
 *
 * @param s the state to modify in place
 * @param params action parameters (not read by this method)
 * @return the same {@code s} instance, after modification
 */
@Override
    protected State performActionHelper(State s, String[] params) {

      // where the agent currently stands
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      final int startX = agent.getDiscValForAttribute(ATTX);
      final int startY = agent.getDiscValForAttribute(ATTY);

      // walk the cumulative distribution until the roll falls inside a bucket
      final double roll = Math.random();
      double cumulative = 0.;
      int chosenDir = 0;
      int candidate = 0;
      while (candidate < this.directionProbs.length) {
        cumulative += this.directionProbs[candidate];
        if (roll < cumulative) {
          chosenDir = candidate;
          break;
        }
        candidate++;
      }

      // apply the sampled direction and write the outcome back onto the agent
      int[] destination = this.moveResult(startX, startY, chosenDir);
      agent.setValue(ATTX, destination[0]);
      agent.setValue(ATTY, destination[1]);

      // the input state itself was mutated, so hand it back
      return s;
    }
    /**
     * Lists the possible outcome states of this action from {@code s} with their probabilities.
     *
     * <p>One entry is produced for each direction index whose {@code moveResult} differs from
     * the agent's current position. All directions that leave the position unchanged are merged
     * into a single entry whose probability is the sum of theirs. {@code s} itself is never
     * modified; every outcome is built on a copy.
     *
     * @param s the state the action is applied in
     * @param params action parameters (not read by this method)
     * @return the outcome states paired with their probabilities
     */
    @Override
    public List<TransitionProbability> getTransitions(State s, String[] params) {

      // where the agent currently stands
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      final int startX = agent.getDiscValForAttribute(ATTX);
      final int startY = agent.getDiscValForAttribute(ATTY);

      List<TransitionProbability> outcomes = new ArrayList<TransitionProbability>(4);
      TransitionProbability stayPut = null; // shared entry for every direction that goes nowhere

      for (int dir = 0; dir < this.directionProbs.length; dir++) {
        int[] landing = this.moveResult(startX, startY, dir);
        final double prob = this.directionProbs[dir];
        final boolean stayed = landing[0] == startX && landing[1] == startY;

        if (stayed) {
          if (stayPut == null) {
            // first direction that leads nowhere: create the shared "no change" outcome
            stayPut = new TransitionProbability(s.copy(), prob);
            outcomes.add(stayPut);
          } else {
            // later ones just add their mass to the existing entry
            stayPut.p += prob;
          }
          continue;
        }

        // the agent actually moved: build the successor state on a copy
        State next = s.copy();
        ObjectInstance movedAgent = next.getFirstObjectOfClass(CLASSAGENT);
        movedAgent.setValue(ATTX, landing[0]);
        movedAgent.setValue(ATTY, landing[1]);
        outcomes.add(new TransitionProbability(next, prob));
      }

      return outcomes;
    }
    /**
     * Builds a feature vector of length {@code getDim()} from the agent's position in {@code s}.
     *
     * <p>If the map value at the agent's (x, y) is positive, the entry at index
     * {@code value - 1} is set to 1; every other entry stays 0.
     *
     * @param s the state to extract features from
     * @return the feature vector
     */
    @Override
    public double[] generateFeatureVectorFrom(State s) {

      ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
      final int agentX = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
      final int agentY = agent.getDiscValForAttribute(GridWorldDomain.ATTY);

      double[] features = new double[this.getDim()];

      // positive map values select a feature slot; anything else leaves the vector all-zero
      final int cellValue = this.map[agentX][agentY];
      if (cellValue > 0) {
        features[cellValue - 1] = 1.;
      }

      return features;
    }
    /**
     * Reports whether the agent in {@code s} stands on the goal cell.
     *
     * @param s the state to test
     * @return {@code true} when the agent's x and y equal {@code goalX} and {@code goalY}
     */
    @Override
    public boolean isTerminal(State s) {

      // read the agent's coordinates from the state
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      final int agentX = agent.getDiscValForAttribute(ATTX);
      final int agentY = agent.getDiscValForAttribute(ATTY);

      // terminal exactly when both coordinates match the goal
      return agentX == this.goalX && agentY == this.goalY;
    }
    /**
     * Scores a transition by where the agent ends up.
     *
     * <p>Only {@code sprime} is inspected; {@code s} and {@code a} are not read.
     *
     * @param s the state before the action
     * @param a the action taken
     * @param sprime the state after the action
     * @return 100 when the agent in {@code sprime} is at ({@code goalX}, {@code goalY}),
     *     otherwise -1
     */
    @Override
    public double reward(State s, GroundedAction a, State sprime) {

      // the reward depends solely on the agent's position in the resulting state
      ObjectInstance agent = sprime.getFirstObjectOfClass(CLASSAGENT);
      final int agentX = agent.getDiscValForAttribute(ATTX);
      final int agentY = agent.getDiscValForAttribute(ATTY);

      final boolean reachedGoal = agentX == this.goalX && agentY == this.goalY;
      return reachedGoal ? 100. : -1.;
    }
    /**
     * Builds a feature vector of length {@code getDim()} from the agent's position in {@code s}.
     *
     * <p>The entries below {@code numCells} mark the map value of the agent's own cell: when
     * that value is positive, index {@code value - 1} is set to 1. Entries from
     * {@code numCells} onward are first seeded with 61 and then refined by scanning square
     * rings of radius 0 to 15 around the agent, clamped to coordinates 0..29, passing each
     * visited cell to {@code updateNearest}. The scan stops early once
     * {@code foundNearestForAll} reports true.
     *
     * @param s the state to extract features from
     * @return the feature vector
     */
    @Override
    public double[] generateFeatureVectorFrom(State s) {

      ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
      final int agentX = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
      final int agentY = agent.getDiscValForAttribute(GridWorldDomain.ATTY);

      double[] features = new double[this.getDim()];

      // feature for the cell the agent is standing on
      final int cellValue = this.map[agentX][agentY];
      if (cellValue > 0) {
        features[cellValue - 1] = 1.;
      }

      // distance slots start at the maximum value until something nearer is found
      for (int slot = this.numCells; slot < features.length; slot++) {
        features[slot] = 61.;
      }

      // the first distance slot is filled directly from the known goal position (gx == -1 means
      // it is not set), using the Manhattan distance, so it need not be discovered by the scan
      if (this.gx != -1) {
        features[this.numCells] = Math.abs(this.gx - agentX) + Math.abs(this.gy - agentY);
      }

      // scan outward ring by ring
      for (int radius = 0; radius < 16; radius++) {

        // extent of the ring's edges, clamped to the 0..29 coordinate range
        final int minX = Math.max(agentX - radius, 0);
        final int maxX = Math.min(agentX + radius, 29);
        final int minY = Math.max(agentY - radius, 0);
        final int maxY = Math.min(agentY + radius, 29);

        // edge at y = agentY + radius
        final int topY = agentY + radius;
        if (topY <= 29) {
          for (int cx = minX; cx <= maxX; cx++) {
            this.updateNearest(features, agentX, agentY, cx, topY);
          }
        }

        // edge at y = agentY - radius
        final int bottomY = agentY - radius;
        if (bottomY >= 0) {
          for (int cx = minX; cx <= maxX; cx++) {
            this.updateNearest(features, agentX, agentY, cx, bottomY);
          }
        }

        // edge at x = agentX - radius
        final int leftX = agentX - radius;
        if (leftX >= 0) {
          for (int cy = minY; cy <= maxY; cy++) {
            this.updateNearest(features, agentX, agentY, leftX, cy);
          }
        }

        // edge at x = agentX + radius
        final int rightX = agentX + radius;
        if (rightX <= 29) {
          for (int cy = minY; cy <= maxY; cy++) {
            this.updateNearest(features, agentX, agentY, rightX, cy);
          }
        }

        // no need to look farther once every distance slot has been resolved
        if (this.foundNearestForAll(features)) {
          break;
        }
      }

      return features;
    }
Example #7
0
 /**
  * Looks up the object instance of a state that carries the y-position information.
  *
  * <p>When {@code yClassName} is set, the first object of that class is returned; otherwise the
  * object named {@code yObjectName} is returned.
  *
  * @param s the state to query for the y-position holder
  * @return the object instance of {@code s} that holds the y-position information
  */
 protected ObjectInstance yObjectInstance(State s) {
   // no class configured: fall back to looking the object up by name
   if (this.yClassName == null) {
     return s.getObject(yObjectName);
   }
   return s.getFirstObjectOfClass(yClassName);
 }