/**
 * Samples the observation produced after action {@code a} leads to state {@code s}.
 *
 * <ul>
 *   <li>If the indexer object's index attribute equals {@code iterations}, the
 *       {@code Names.OBS_COMPLETE} observation is returned.
 *   <li>If the action is the listen action, one left-door and one right-door observation are
 *       drawn (each with a random suffix in {@code [0, observationsPerState)}), and the one
 *       matching the tiger's door is returned with probability {@code 1 - noise}; otherwise the
 *       other one is returned.
 *   <li>For any other action the {@code Names.OBS_NULL} observation is returned.
 * </ul>
 *
 * @param d domain used to look observations up by name
 * @param a the action that was taken
 * @param s the state the observation is generated for
 * @return the sampled observation
 */
public static Observation makeObservationFor(POMDPDomain d, GroundedAction a, POMDPState s) {
  int currentIndex = s.getObject(Names.OBJ_INDEXER).getDiscValForAttribute(Names.ATTR_INDEX);
  if (currentIndex == iterations) {
    return d.getObservation(Names.OBS_COMPLETE);
  }

  ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
  // Looked up but not consulted below; only the left door's tigerness is read.
  ObjectInstance rightDoor = s.getObject(Names.OBJ_RIGHT_DOOR);
  boolean tigerBehindLeft = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS) == 1;

  java.util.Random rng = new java.util.Random();

  boolean listening = a.action.getName().equals(Names.ACTION_LISTEN);
  if (!listening) {
    return d.getObservation(Names.OBS_NULL);
  }

  // Both candidate observations are drawn up front (left first, then right) before the
  // accuracy roll, so the random-number consumption order is: nextInt, nextInt, nextDouble.
  Observation hearLeft =
      d.getObservation(Names.OBS_LEFT_DOOR + rng.nextInt(observationsPerState));
  Observation hearRight =
      d.getObservation(Names.OBS_RIGHT_DOOR + rng.nextInt(observationsPerState));

  boolean accurate = rng.nextDouble() < 1 - noise;
  if (tigerBehindLeft) {
    if (accurate) {
      return hearLeft;
    }
    return hearRight;
  }
  if (accurate) {
    return hearRight;
  }
  return hearLeft;
}
/**
 * Tests whether two objects share the same x and y attribute values.
 *
 * @param s the state holding both objects
 * @param params {@code params[0]} names the agent object, {@code params[1]} names the location
 *     object
 * @return {@code true} when both the {@code ATTX} and {@code ATTY} values of the two objects
 *     are equal
 */
@Override
public boolean isTrue(State s, String[] params) {
  ObjectInstance agentObj = s.getObject(params[0]);
  ObjectInstance locationObj = s.getObject(params[1]);

  int agentX = agentObj.getDiscValForAttribute(ATTX);
  int agentY = agentObj.getDiscValForAttribute(ATTY);
  int locationX = locationObj.getDiscValForAttribute(ATTX);
  int locationY = locationObj.getDiscValForAttribute(ATTY);

  if (agentX != locationX) {
    return false;
  }
  return agentY == locationY;
}
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { // agent will be filled in blue g2.setColor(Color.BLUE); // set up floats for the width and height of our domain float fWidth = ExampleGridWorld.this.map.length; float fHeight = ExampleGridWorld.this.map[0].length; // determine the width of a single cell on our canvas // such that the whole map can be painted float width = cWidth / fWidth; float height = cHeight / fHeight; int ax = ob.getDiscValForAttribute(ATTX); int ay = ob.getDiscValForAttribute(ATTY); // left corrdinate of cell on our canvas float rx = ax * width; // top coordinate of cell on our canvas // coordinate system adjustment because the java canvas // origin is in the top left instead of the bottom right float ry = cHeight - height - ay * height; // paint the rectangle g2.fill(new Rectangle2D.Float(rx, ry, width, height)); }
@Override protected State performActionHelper(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); // sample directon with random roll double r = Math.random(); double sumProb = 0.; int dir = 0; for (int i = 0; i < this.directionProbs.length; i++) { sumProb += this.directionProbs[i]; if (r < sumProb) { dir = i; break; // found direction } } // get resulting position int[] newPos = this.moveResult(curX, curY, dir); // set the new position agent.setValue(ATTX, newPos[0]); agent.setValue(ATTY, newPos[1]); // return the state we just modified return s; }
/**
 * Builds an indicator feature vector for the map cell the agent occupies.
 *
 * <p>The vector has {@code getDim()} entries, all zero except that, when the map value at the
 * agent's position is positive, the entry at index {@code value - 1} is set to 1.
 *
 * @param s the state to extract features from
 * @return the feature vector
 */
@Override
public double[] generateFeatureVectorFrom(State s) {
  ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
  int agentX = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
  int agentY = agent.getDiscValForAttribute(GridWorldDomain.ATTY);

  double[] features = new double[this.getDim()];

  // a map value of 0 (or less) leaves the vector all zeros
  int cellValue = this.map[agentX][agentY];
  if (cellValue > 0) {
    features[cellValue - 1] = 1.;
  }

  return features;
}
@Override public boolean isTerminal(State s) { // get location of agent in next state ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return true; } return false; }
@Override public double reward(State s, GroundedAction a, State sprime) { // get location of agent in next state ObjectInstance agent = sprime.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return 100.; } return -1; }
@Override public List<TransitionProbability> getTransitions(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); List<TransitionProbability> tps = new ArrayList<TransitionProbability>(4); TransitionProbability noChangeTransition = null; for (int i = 0; i < this.directionProbs.length; i++) { int[] newPos = this.moveResult(curX, curY, i); if (newPos[0] != curX || newPos[1] != curY) { // new possible outcome State ns = s.copy(); ObjectInstance nagent = ns.getFirstObjectOfClass(CLASSAGENT); nagent.setValue(ATTX, newPos[0]); nagent.setValue(ATTY, newPos[1]); // create transition probability object and add to our list of outcomes tps.add(new TransitionProbability(ns, this.directionProbs[i])); } else { // this direction didn't lead anywhere new // if there are existing possible directions that wouldn't lead anywhere, aggregate with // them if (noChangeTransition != null) { noChangeTransition.p += this.directionProbs[i]; } else { // otherwise create this new state and transition noChangeTransition = new TransitionProbability(s.copy(), this.directionProbs[i]); tps.add(noChangeTransition); } } } return tps; }
@Override public double[] generateFeatureVectorFrom(State s) { ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT); int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX); int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY); double[] vec = new double[this.getDim()]; if (this.map[ax][ay] > 0) { vec[map[ax][ay] - 1] = 1.; } // now do distances // first seed to max val for (int i = this.numCells; i < vec.length; i++) { vec[i] = 61.; } // set goal (type 0) to its goal position assuming only 1 instance of it, so we don't scan // large distances for it if (this.gx != -1) { vec[this.numCells] = Math.abs(this.gx - ax) + Math.abs(this.gy - ay); } // now do scan for (int r = 0; r < 16; r++) { int x; // scan top int y = ay + r; if (y < 30) { for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) { this.updateNearest(vec, ax, ay, x, y); } } // scan bottom y = ay - r; if (y > -1) { for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) { this.updateNearest(vec, ax, ay, x, y); } } // scan left x = ax - r; if (x > -1) { for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) { this.updateNearest(vec, ax, ay, x, y); } } // scan right x = ax + r; if (x < 30) { for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) { this.updateNearest(vec, ax, ay, x, y); } } if (this.foundNearestForAll(vec)) { break; } } return vec; }
/**
 * Paints the policy's action glyph(s) for state {@code s} inside the cell that the state maps
 * to on the canvas.
 *
 * <p>The cell position and size are derived from the x and y attributes; they are only
 * computed for attributes of type {@code DISC} and stay at 0 otherwise. The number of cells
 * along an axis is {@code numXCells}/{@code numYCells} when set (not -1), else the number of
 * discrete values of the attribute.
 *
 * <p>Rendering depends on {@code renderStyle}:
 *
 * <ul>
 *   <li>{@code DISTSCALED}: every action with a registered glyph painter is drawn, with its
 *       rectangle rescaled by {@code pSelection / maxp}.
 *   <li>any other style: only actions whose selection probability is at least the maximum are
 *       drawn at full cell size; for {@code MAXACTIONSOFTTIE} the threshold is lowered by
 *       {@code softTieDelta}.
 * </ul>
 *
 * <p>Previously the style check was the constant {@code true}, which made the scaled
 * rendering branch unreachable.
 *
 * @param g2 graphics context to draw on
 * @param s the state whose policy is rendered
 * @param policy the policy queried for its action distribution in {@code s}
 * @param cWidth canvas width
 * @param cHeight canvas height
 */
@Override
public void paintStatePolicy(Graphics2D g2, State s, Policy policy, float cWidth, float cHeight) {
  ObjectInstance xOb = this.xObjectInstance(s);
  ObjectInstance yOb = this.yObjectInstance(s);

  Attribute xAtt = xOb.getObjectClass().getAttribute(xAttName);
  Attribute yAtt = yOb.getObjectClass().getAttribute(yAttName);

  float domainXScale = 0f;
  float domainYScale = 0f;
  float xval = 0f;
  float yval = 0f;
  float width = 0f;
  float height = 0f;

  if (xAtt.type == Attribute.AttributeType.DISC) {
    if (this.numXCells != -1) {
      domainXScale = this.numXCells;
    } else {
      domainXScale = xAtt.discValues.size();
    }
    width = cWidth / domainXScale;
    xval = xOb.getDiscValForAttribute(xAttName) * width;
  }

  if (yAtt.type == Attribute.AttributeType.DISC) {
    if (this.numYCells != -1) {
      domainYScale = this.numYCells;
    } else {
      domainYScale = yAtt.discValues.size();
    }
    height = cHeight / domainYScale;
    // flip the y axis: the canvas origin is at the top, cell row 0 is drawn at the bottom
    yval = cHeight - height - yOb.getDiscValForAttribute(yAttName) * height;
  }

  List<ActionProb> pdist = policy.getActionDistributionForState(s);

  // largest selection probability among all actions
  double maxp = 0.;
  for (ActionProb ap : pdist) {
    if (ap.pSelection > maxp) {
      maxp = ap.pSelection;
    }
  }

  if (this.renderStyle != PolicyGlyphRenderStyle.DISTSCALED) {
    // max-action styles: draw only the actions at (or, with soft ties, near) the maximum

    if (this.renderStyle == PolicyGlyphRenderStyle.MAXACTIONSOFTTIE) {
      maxp -= this.softTieDelta;
    }

    for (ActionProb ap : pdist) {
      if (ap.pSelection >= maxp) {
        ActionGlyphPainter agp = this.actionNameToGlyphPainter.get(ap.ga.actionName());
        if (agp != null) {
          agp.paintGlyph(g2, xval, yval, width, height);
        }
      }
    }
  } else {
    // distribution-scaled style: draw every action, sized relative to the most likely one

    for (ActionProb ap : pdist) {
      float[] scaledRect =
          this.rescaleRect(xval, yval, width, height, (float) (ap.pSelection / maxp));
      ActionGlyphPainter agp = this.actionNameToGlyphPainter.get(ap.ga.actionName());
      if (agp != null) {
        agp.paintGlyph(g2, scaledRect[0], scaledRect[1], scaledRect[2], scaledRect[3]);
      }
    }
  }
}