Example #1
0
  /**
   * Samples the observation produced after executing {@code a} in state {@code s}.
   *
   * <p>When the indexer object's index equals {@code iterations}, the "complete" observation is
   * returned. For the listen action, an observation for the door whose tigerness attribute is 1 on
   * the left door (i.e. the left door when it is 1, the right door otherwise) is returned with
   * probability {@code 1 - noise}, and one for the opposite door otherwise; the concrete
   * observation name is the door prefix followed by an index drawn uniformly from
   * {@code [0, observationsPerState)}. Any other action yields the null observation.
   *
   * @param d domain used to look up observations by name
   * @param a the action that was executed
   * @param s the state from which the observation is generated
   * @return the sampled observation
   */
  public static Observation makeObservationFor(POMDPDomain d, GroundedAction a, POMDPState s) {
    ObjectInstance indexer = s.getObject(Names.OBJ_INDEXER);
    int index = indexer.getDiscValForAttribute(Names.ATTR_INDEX);

    if (index == iterations) {
      return d.getObservation(Names.OBS_COMPLETE);
    }

    if (!a.action.getName().equals(Names.ACTION_LISTEN)) {
      return d.getObservation(Names.OBS_NULL);
    }

    ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
    boolean tigerBehindLeft = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS) == 1;

    // Reuse the per-thread generator instead of allocating a new Random on every call.
    java.util.concurrent.ThreadLocalRandom random =
        java.util.concurrent.ThreadLocalRandom.current();

    // With probability 1 - noise the agent hears the tiger's true door, otherwise the other one.
    boolean hearCorrectly = random.nextDouble() < 1 - noise;
    boolean hearLeft = tigerBehindLeft == hearCorrectly;

    // Only the observation actually returned is built, so only one index is drawn.
    return d.getObservation(
        (hearLeft ? Names.OBS_LEFT_DOOR : Names.OBS_RIGHT_DOOR)
            + random.nextInt(observationsPerState));
  }
    /**
     * Checks whether the object named by {@code params[0]} sits on the same cell as the object
     * named by {@code params[1]}, comparing their discrete x and y attributes.
     *
     * @param s the state holding both objects
     * @param params {@code params[0]} is the first object's name, {@code params[1]} the second's
     * @return {@code true} when both the x and the y values match
     */
    @Override
    public boolean isTrue(State s, String[] params) {
      ObjectInstance agentOb = s.getObject(params[0]);
      ObjectInstance locationOb = s.getObject(params[1]);

      int agentX = agentOb.getDiscValForAttribute(ATTX);
      int agentY = agentOb.getDiscValForAttribute(ATTY);
      int locX = locationOb.getDiscValForAttribute(ATTX);
      int locY = locationOb.getDiscValForAttribute(ATTY);

      boolean sameX = agentX == locX;
      boolean sameY = agentY == locY;
      return sameX && sameY;
    }
    /**
     * Paints {@code ob} as a solid blue rectangle filling the grid cell given by its discrete x/y
     * attributes.
     *
     * <p>The canvas is divided into {@code map.length} columns and {@code map[0].length} rows.
     */
    @Override
    public void paintObject(
        Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) {
      g2.setColor(Color.BLUE);

      // Canvas size of a single grid cell, chosen so the whole map fits on the canvas.
      float cellWidth = cWidth / (float) ExampleGridWorld.this.map.length;
      float cellHeight = cHeight / (float) ExampleGridWorld.this.map[0].length;

      int cellX = ob.getDiscValForAttribute(ATTX);
      int cellY = ob.getDiscValForAttribute(ATTY);

      // The Java canvas origin is the top-left corner, while grid row 0 is drawn at the bottom,
      // so the row is flipped when computing the cell's top edge.
      float left = cellX * cellWidth;
      float top = cHeight - cellHeight - cellY * cellHeight;

      g2.fill(new Rectangle2D.Float(left, top, cellWidth, cellHeight));
    }
    /**
     * Moves the first agent object one step in a randomly sampled direction, mutating {@code s}.
     *
     * <p>The direction index is drawn from {@code directionProbs} by comparing a uniform roll
     * against the running sum of probabilities; the resulting position comes from
     * {@code moveResult}.
     *
     * @param s the state to modify in place
     * @param params unused by this implementation
     * @return {@code s} itself, after modification
     */
    @Override
    protected State performActionHelper(State s, String[] params) {
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      int curX = agent.getDiscValForAttribute(ATTX);
      int curY = agent.getDiscValForAttribute(ATTY);

      double roll = Math.random();

      // NOTE(review): if the probabilities' running sum never exceeds the roll (e.g. they sum to
      // slightly less than 1 through rounding), direction 0 is used.
      int chosenDir = 0;
      double cumulative = 0.;
      int candidate = 0;
      while (candidate < this.directionProbs.length) {
        cumulative += this.directionProbs[candidate];
        if (roll < cumulative) {
          chosenDir = candidate;
          break;
        }
        candidate++;
      }

      int[] destination = this.moveResult(curX, curY, chosenDir);
      agent.setValue(ATTX, destination[0]);
      agent.setValue(ATTY, destination[1]);

      // The input state was modified in place and is handed back.
      return s;
    }
    /**
     * Builds a one-hot feature vector for the map value under the agent.
     *
     * <p>The vector has {@code getDim()} entries. When the agent's cell value {@code v} is
     * positive, entry {@code v - 1} is set to 1; otherwise the vector stays all zero.
     *
     * @param s state containing the agent object
     * @return the feature vector
     */
    @Override
    public double[] generateFeatureVectorFrom(State s) {
      ObjectInstance agentOb = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
      int x = agentOb.getDiscValForAttribute(GridWorldDomain.ATTX);
      int y = agentOb.getDiscValForAttribute(GridWorldDomain.ATTY);

      double[] features = new double[this.getDim()];

      int cellValue = this.map[x][y];
      if (cellValue > 0) {
        features[cellValue - 1] = 1.;
      }
      return features;
    }
    /**
     * Reports whether {@code s} is terminal, i.e. whether the first agent object is on the goal
     * cell ({@code goalX}, {@code goalY}).
     *
     * @param s the state to test
     * @return {@code true} exactly when the agent is at the goal
     */
    @Override
    public boolean isTerminal(State s) {
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      int ax = agent.getDiscValForAttribute(ATTX);
      int ay = agent.getDiscValForAttribute(ATTY);

      return ax == this.goalX && ay == this.goalY;
    }
    /**
     * Returns the reward for a transition, based only on where the agent ends up.
     *
     * @param s the previous state (unused)
     * @param a the action taken (unused)
     * @param sprime the resulting state
     * @return 100 when the agent is on the goal cell in {@code sprime}, otherwise -1
     */
    @Override
    public double reward(State s, GroundedAction a, State sprime) {
      ObjectInstance agent = sprime.getFirstObjectOfClass(CLASSAGENT);
      int ax = agent.getDiscValForAttribute(ATTX);
      int ay = agent.getDiscValForAttribute(ATTY);

      boolean atGoal = ax == this.goalX && ay == this.goalY;
      return atGoal ? 100. : -1;
    }
    /**
     * Enumerates the possible outcome states of a move and their probabilities.
     *
     * <p>Every direction whose {@code moveResult} changes the agent's position yields its own
     * copied state. All directions that leave the agent in place are merged into a single outcome
     * (a copy of {@code s}), whose probability is the sum of theirs. Outcomes are listed in the
     * order their first contributing direction appears in {@code directionProbs}.
     *
     * @param s the current state (not modified)
     * @param params unused by this implementation
     * @return the list of outcome states with their probabilities
     */
    @Override
    public List<TransitionProbability> getTransitions(State s, String[] params) {
      ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT);
      int curX = agent.getDiscValForAttribute(ATTX);
      int curY = agent.getDiscValForAttribute(ATTY);

      List<TransitionProbability> outcomes = new ArrayList<TransitionProbability>(4);
      // Shared outcome that accumulates every direction leaving the agent where it is.
      TransitionProbability stay = null;

      for (int dir = 0; dir < this.directionProbs.length; dir++) {
        int[] dest = this.moveResult(curX, curY, dir);
        boolean moved = dest[0] != curX || dest[1] != curY;

        if (!moved) {
          if (stay == null) {
            stay = new TransitionProbability(s.copy(), this.directionProbs[dir]);
            outcomes.add(stay);
          } else {
            stay.p += this.directionProbs[dir];
          }
          continue;
        }

        State next = s.copy();
        ObjectInstance movedAgent = next.getFirstObjectOfClass(CLASSAGENT);
        movedAgent.setValue(ATTX, dest[0]);
        movedAgent.setValue(ATTY, dest[1]);
        outcomes.add(new TransitionProbability(next, this.directionProbs[dir]));
      }

      return outcomes;
    }
    /**
     * Builds a feature vector combining a one-hot cell encoding with distance features.
     *
     * <p>When the agent's map value {@code v} is positive, entry {@code v - 1} is set to 1
     * (presumably this falls within the first {@code numCells} entries — TODO confirm). Entries
     * from {@code numCells} to the end are distance features. They are seeded to 61. The goal
     * feature ({@code vec[numCells]}) is set directly to the Manhattan distance to
     * ({@code gx}, {@code gy}) when {@code gx != -1}. The remaining features are then passed to
     * {@code updateNearest} while cells are scanned in growing square rings around the agent.
     *
     * <p>NOTE(review): the grid bounds (0..29, i.e. a 30x30 map) and the maximum ring radius (16)
     * are hard-coded; presumably this must match the map size — confirm.
     *
     * @param s state containing the agent object
     * @return the feature vector
     */
    @Override
    public double[] generateFeatureVectorFrom(State s) {

      ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
      int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
      int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY);

      double[] vec = new double[this.getDim()];

      // one-hot entry for the agent's cell value; non-positive values leave it all zero
      if (this.map[ax][ay] > 0) {
        vec[map[ax][ay] - 1] = 1.;
      }

      // now do distances
      // first seed every distance feature with a sentinel "not found yet" value
      for (int i = this.numCells; i < vec.length; i++) {
        vec[i] = 61.;
      }

      // set goal (type 0) to its goal position assuming only 1 instance of it, so we don't scan
      // large distances for it; gx == -1 apparently means "no goal", leaving the sentinel
      if (this.gx != -1) {
        vec[this.numCells] = Math.abs(this.gx - ax) + Math.abs(this.gy - ay);
      }

      // now do scan: ring r is the square border at Chebyshev distance r from the agent,
      // clipped to the 0..29 grid. Corner cells (and the agent cell at r == 0) are passed to
      // updateNearest more than once.
      // NOTE(review): rings are ordered by Chebyshev distance, not Manhattan distance. A cell in
      // ring r can be up to 2r away, so an early break could keep a farther cell if
      // foundNearestForAll only checks that every feature was seen — confirm its semantics.
      for (int r = 0; r < 16; r++) {

        int x;

        // scan top (row ay + r)
        int y = ay + r;
        if (y < 30) {
          for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) {
            this.updateNearest(vec, ax, ay, x, y);
          }
        }

        // scan bottom (row ay - r)
        y = ay - r;
        if (y > -1) {
          for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) {
            this.updateNearest(vec, ax, ay, x, y);
          }
        }

        // scan left (column ax - r)
        x = ax - r;
        if (x > -1) {
          for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) {
            this.updateNearest(vec, ax, ay, x, y);
          }
        }

        // scan right (column ax + r)
        x = ax + r;
        if (x < 30) {
          for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) {
            this.updateNearest(vec, ax, ay, x, y);
          }
        }

        // stop scanning once the helper reports every distance feature has been resolved
        if (this.foundNearestForAll(vec)) {
          break;
        }
      }

      return vec;
    }
  /**
   * Paints glyphs for the action(s) the policy prefers in state {@code s}.
   *
   * <p>The glyph's cell comes from the discrete x/y attribute values of the objects returned by
   * {@code xObjectInstance}/{@code yObjectInstance}. The canvas is divided into
   * {@code numXCells} by {@code numYCells} cells. When a count is {@code -1}, the number of
   * discrete values of the corresponding attribute is used instead. The y axis is flipped, so
   * value 0 is drawn in the bottom row. If an attribute is not discrete, its offset and size stay
   * 0.
   *
   * <p>Every action whose selection probability is at least the maximum selection probability is
   * painted. Under {@code MAXACTIONSOFTTIE} the threshold is lowered by {@code softTieDelta}, so
   * near-ties are painted too. Actions without a registered glyph painter are skipped.
   *
   * @param g2 graphics context to draw into
   * @param s the state whose policy is rendered
   * @param policy the policy queried for its action distribution in {@code s}
   * @param cWidth canvas width
   * @param cHeight canvas height
   */
  @Override
  public void paintStatePolicy(Graphics2D g2, State s, Policy policy, float cWidth, float cHeight) {
    ObjectInstance xOb = this.xObjectInstance(s);
    ObjectInstance yOb = this.yObjectInstance(s);

    Attribute xAtt = xOb.getObjectClass().getAttribute(xAttName);
    Attribute yAtt = yOb.getObjectClass().getAttribute(yAttName);

    float domainXScale = 0f;
    float domainYScale = 0f;
    float xval = 0f;
    float yval = 0f;
    float width = 0f;
    float height = 0f;

    if (xAtt.type == Attribute.AttributeType.DISC) {
      if (this.numXCells != -1) {
        domainXScale = this.numXCells;
      } else {
        domainXScale = xAtt.discValues.size();
      }

      width = cWidth / domainXScale;
      xval = xOb.getDiscValForAttribute(xAttName) * width;
    }

    if (yAtt.type == Attribute.AttributeType.DISC) {
      if (this.numYCells != -1) {
        domainYScale = this.numYCells;
      } else {
        domainYScale = yAtt.discValues.size();
      }

      height = cHeight / domainYScale;
      // flip so that y value 0 is drawn at the bottom of the canvas
      yval = cHeight - height - yOb.getDiscValForAttribute(yAttName) * height;
    }

    List<ActionProb> pdist = policy.getActionDistributionForState(s);

    // Paint threshold: the highest selection probability, relaxed for soft ties.
    double threshold = 0.;
    for (ActionProb ap : pdist) {
      if (ap.pSelection > threshold) {
        threshold = ap.pSelection;
      }
    }
    if (this.renderStyle == PolicyGlyphRenderStyle.MAXACTIONSOFTTIE) {
      threshold -= this.softTieDelta;
    }

    // Only (soft-tied) max actions are painted, each at full cell size; no probability-scaled
    // glyphs are drawn by this method.
    for (ActionProb ap : pdist) {
      if (ap.pSelection < threshold) {
        continue;
      }
      ActionGlyphPainter agp = this.actionNameToGlyphPainter.get(ap.ga.actionName());
      if (agp != null) {
        agp.paintGlyph(g2, xval, yval, width, height);
      }
    }
  }