@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { float domainXScale = Sokoban2Domain.maxRoomXExtent(s) + 1f; float domainYScale = Sokoban2Domain.maxRoomYExtent(s) + 1f; if (maxX != -1) { domainXScale = maxX; domainYScale = maxY; } // determine then normalized width float width = (1.0f / domainXScale) * cWidth; float height = (1.0f / domainYScale) * cHeight; int x = ob.getIntValForAttribute(Sokoban2Domain.ATTX); int y = ob.getIntValForAttribute(Sokoban2Domain.ATTY); float rx = x * width; float ry = cHeight - height - y * height; String dir = null; Attribute dirAtt = ob.getObjectClass().getAttribute(Sokoban2Domain.ATTDIR); if (dirAtt != null) { dir = ob.getStringValForAttribute(Sokoban2Domain.ATTDIR); } else { dir = "south"; } BufferedImage img = this.dirToImage.get(dir); g2.drawImage(img, (int) rx, (int) ry, (int) width, (int) height, this); }
@Override protected State performActionHelper(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); // sample directon with random roll double r = Math.random(); double sumProb = 0.; int dir = 0; for (int i = 0; i < this.directionProbs.length; i++) { sumProb += this.directionProbs[i]; if (r < sumProb) { dir = i; break; // found direction } } // get resulting position int[] newPos = this.moveResult(curX, curY, dir); // set the new position agent.setValue(ATTX, newPos[0]); agent.setValue(ATTY, newPos[1]); // return the state we just modified return s; }
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { float domainXScale = Sokoban2Domain.maxRoomXExtent(s) + 1f; float domainYScale = Sokoban2Domain.maxRoomYExtent(s) + 1f; if (maxX != -1) { domainXScale = maxX; domainYScale = maxY; } // determine then normalized width float width = (1.0f / domainXScale) * cWidth; float height = (1.0f / domainYScale) * cHeight; int top = ob.getIntValForAttribute(Sokoban2Domain.ATTTOP); int left = ob.getIntValForAttribute(Sokoban2Domain.ATTLEFT); int bottom = ob.getIntValForAttribute(Sokoban2Domain.ATTBOTTOM); int right = ob.getIntValForAttribute(Sokoban2Domain.ATTRIGHT); g2.setColor(Color.white); for (int i = left; i <= right; i++) { for (int j = bottom; j <= top; j++) { float rx = i * width; float ry = cHeight - height - j * height; g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } } }
/**
 * Samples the observation produced by taking {@code a} in {@code s}.
 *
 * <ul>
 *   <li>If the indexer's {@code ATTR_INDEX} equals {@code iterations}, returns
 *       {@code OBS_COMPLETE}.
 *   <li>For the listen action, returns a left- or right-door observation whose name is the
 *       door prefix followed by a uniformly random suffix in {@code [0, observationsPerState)}.
 *       With probability {@code 1 - noise} it names the door whose {@code ATTR_TIGERNESS} is 1
 *       (the left door if its tigerness is 1, otherwise the right door); otherwise it names the
 *       other door.
 *   <li>Any other action yields {@code OBS_NULL}.
 * </ul>
 *
 * @param d the domain used to look up observations by name
 * @param a the action just taken
 * @param s the resulting state
 * @return the sampled observation
 */
public static Observation makeObservationFor(POMDPDomain d, GroundedAction a, POMDPState s) {
    ObjectInstance indexer = s.getObject(Names.OBJ_INDEXER);
    int index = indexer.getDiscValForAttribute(Names.ATTR_INDEX);
    if (index == iterations) {
        return d.getObservation(Names.OBS_COMPLETE);
    }
    if (!a.action.getName().equals(Names.ACTION_LISTEN)) {
        return d.getObservation(Names.OBS_NULL);
    }

    ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
    boolean tigerBehindLeft = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS) == 1;

    // ThreadLocalRandom avoids allocating and seeding a new Random on every call.
    java.util.Random random = java.util.concurrent.ThreadLocalRandom.current();
    boolean accurate = random.nextDouble() < 1 - noise;
    // An accurate reading names the tiger's door; a noisy reading names the other one.
    boolean hearLeft = tigerBehindLeft == accurate;
    // Only the observation actually returned is looked up.
    return d.getObservation(
            (hearLeft ? Names.OBS_LEFT_DOOR : Names.OBS_RIGHT_DOOR)
                    + random.nextInt(observationsPerState));
}
/**
 * Holds when the object named by {@code params[0]} sits on the same (x, y) cell as the
 * object named by {@code params[1]}.
 */
@Override
public boolean isTrue(State s, String[] params) {
    ObjectInstance agentOb = s.getObject(params[0]);
    ObjectInstance targetOb = s.getObject(params[1]);

    int agentX = agentOb.getDiscValForAttribute(ATTX);
    int agentY = agentOb.getDiscValForAttribute(ATTY);
    int targetX = targetOb.getDiscValForAttribute(ATTX);
    int targetY = targetOb.getDiscValForAttribute(ATTY);

    boolean sameColumn = agentX == targetX;
    boolean sameRow = agentY == targetY;
    return sameColumn && sameRow;
}
/**
 * Creates a fresh tiger-problem state.
 *
 * <p>The state contains an indexer with {@code ATTR_INDEX} set to 0, and a left and a right
 * door. Exactly one door, chosen uniformly at random, gets {@code ATTR_TIGERNESS} 1; the other
 * gets 0.
 */
public static POMDPState getNewState(Domain d) {
    POMDPState state = new POMDPState();
    ObjectClass doorClass = d.getObjectClass(Names.CLASS_DOOR);
    ObjectClass indexerClass = d.getObjectClass(Names.CLASS_INDEXER);

    ObjectInstance indexer = new ObjectInstance(indexerClass, Names.OBJ_INDEXER);
    indexer.setValue(Names.ATTR_INDEX, 0);
    state.addObject(indexer);

    ObjectInstance leftDoor = new ObjectInstance(doorClass, Names.OBJ_LEFT_DOOR);
    ObjectInstance rightDoor = new ObjectInstance(doorClass, Names.OBJ_RIGHT_DOOR);
    leftDoor.setValue(Names.ATTR_POSITION, Names.LEFT);
    rightDoor.setValue(Names.ATTR_POSITION, Names.RIGHT);

    // the two doors always get complementary tigerness values
    int leftTigerness = new java.util.Random().nextBoolean() ? 0 : 1;
    leftDoor.setValue(Names.ATTR_TIGERNESS, leftTigerness);
    rightDoor.setValue(Names.ATTR_TIGERNESS, 1 - leftTigerness);

    state.addObject(leftDoor);
    state.addObject(rightDoor);
    return state;
}
@Override public double reward(State s, GroundedAction a, State sprime) { // get location of agent in next state ObjectInstance agent = sprime.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return 100.; } return -1; }
/**
 * Builds a small example state: an agent named "agent0" at (0, 0) and a location named
 * "location0" at (10, 10).
 */
public static State getExampleState(Domain domain) {
    ObjectInstance agentOb = new ObjectInstance(domain.getObjectClass(CLASSAGENT), "agent0");
    agentOb.setValue(ATTX, 0);
    agentOb.setValue(ATTY, 0);

    ObjectInstance locationOb =
            new ObjectInstance(domain.getObjectClass(CLASSLOCATION), "location0");
    locationOb.setValue(ATTX, 10);
    locationOb.setValue(ATTY, 10);

    State example = new State();
    example.addObject(agentOb);
    example.addObject(locationOb);
    return example;
}
/**
 * Returns a vector of length {@code getDim()} that one-hot encodes the type of the agent's
 * current cell.
 *
 * <p>The cell type is {@code map[ax][ay]}. A positive type {@code t} sets entry {@code t - 1}
 * to 1; type 0 leaves the vector all zeros.
 */
@Override
public double[] generateFeatureVectorFrom(State s) {
    ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
    int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
    int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY);

    double[] features = new double[this.getDim()];
    int cellType = this.map[ax][ay];
    if (cellType > 0) {
        features[cellType - 1] = 1.;
    }
    return features;
}
@Override public boolean isTerminal(State s) { // get location of agent in next state ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return true; } return false; }
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { g2.setColor(Color.darkGray); float domainXScale = Sokoban2Domain.maxRoomXExtent(s) + 1f; float domainYScale = Sokoban2Domain.maxRoomYExtent(s) + 1f; if (maxX != -1) { domainXScale = maxX; domainYScale = maxY; } // determine then normalized width float width = (1.0f / domainXScale) * cWidth; float height = (1.0f / domainYScale) * cHeight; int x = ob.getIntValForAttribute(Sokoban2Domain.ATTX); int y = ob.getIntValForAttribute(Sokoban2Domain.ATTY); float rx = x * width; float ry = cHeight - height - y * height; g2.fill(new Rectangle2D.Float(rx, ry, width, height)); }
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { // agent will be filled in blue g2.setColor(Color.BLUE); // set up floats for the width and height of our domain float fWidth = ExampleGridWorld.this.map.length; float fHeight = ExampleGridWorld.this.map[0].length; // determine the width of a single cell on our canvas // such that the whole map can be painted float width = cWidth / fWidth; float height = cHeight / fHeight; int ax = ob.getDiscValForAttribute(ATTX); int ay = ob.getDiscValForAttribute(ATTY); // left corrdinate of cell on our canvas float rx = ax * width; // top coordinate of cell on our canvas // coordinate system adjustment because the java canvas // origin is in the top left instead of the bottom right float ry = cHeight - height - ay * height; // paint the rectangle g2.fill(new Rectangle2D.Float(rx, ry, width, height)); }
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { float domainXScale = Sokoban2Domain.maxRoomXExtent(s) + 1f; float domainYScale = Sokoban2Domain.maxRoomYExtent(s) + 1f; if (maxX != -1) { domainXScale = maxX; domainYScale = maxY; } // determine then normalized width float width = (1.0f / domainXScale) * cWidth; float height = (1.0f / domainYScale) * cHeight; int top = ob.getIntValForAttribute(Sokoban2Domain.ATTTOP); int left = ob.getIntValForAttribute(Sokoban2Domain.ATTLEFT); int bottom = ob.getIntValForAttribute(Sokoban2Domain.ATTBOTTOM); int right = ob.getIntValForAttribute(Sokoban2Domain.ATTRIGHT); Color rcol = colorForName(ob.getStringValForAttribute(Sokoban2Domain.ATTCOLOR)); float[] hsb = new float[3]; Color.RGBtoHSB(rcol.getRed(), rcol.getGreen(), rcol.getBlue(), hsb); hsb[1] = 0.4f; rcol = Color.getHSBColor(hsb[0], hsb[1], hsb[2]); for (int i = left; i <= right; i++) { for (int j = bottom; j <= top; j++) { float rx = i * width; float ry = cHeight - height - j * height; if (i == left || i == right || j == bottom || j == top) { if (Sokoban2Domain.doorContainingPoint(s, i, j) == null) { g2.setColor(Color.black); g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } } else { g2.setColor(rcol); g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } } } }
@Override public void paintObject( Graphics2D g2, State s, ObjectInstance ob, float cWidth, float cHeight) { float domainXScale = Sokoban2Domain.maxRoomXExtent(s) + 1f; float domainYScale = Sokoban2Domain.maxRoomYExtent(s) + 1f; if (maxX != -1) { domainXScale = maxX; domainYScale = maxY; } // determine then normalized width float width = (1.0f / domainXScale) * cWidth; float height = (1.0f / domainYScale) * cHeight; int x = ob.getIntValForAttribute(Sokoban2Domain.ATTX); int y = ob.getIntValForAttribute(Sokoban2Domain.ATTY); float rx = x * width; float ry = cHeight - height - y * height; String colName = ob.getStringValForAttribute(Sokoban2Domain.ATTCOLOR); String shapeName = ob.getStringValForAttribute(Sokoban2Domain.ATTSHAPE); String key = this.shapeKey(shapeName, colName); BufferedImage img = this.shapeAndColToImages.get(key); if (img == null) { Color col = colorForName(ob.getStringValForAttribute(Sokoban2Domain.ATTCOLOR)).darker(); g2.setColor(col); g2.fill(new Rectangle2D.Float(rx, ry, width, height)); } else { g2.drawImage(img, (int) rx, (int) ry, (int) width, (int) height, this); } }
@Override public List<TransitionProbability> getTransitions(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); List<TransitionProbability> tps = new ArrayList<TransitionProbability>(4); TransitionProbability noChangeTransition = null; for (int i = 0; i < this.directionProbs.length; i++) { int[] newPos = this.moveResult(curX, curY, i); if (newPos[0] != curX || newPos[1] != curY) { // new possible outcome State ns = s.copy(); ObjectInstance nagent = ns.getFirstObjectOfClass(CLASSAGENT); nagent.setValue(ATTX, newPos[0]); nagent.setValue(ATTY, newPos[1]); // create transition probability object and add to our list of outcomes tps.add(new TransitionProbability(ns, this.directionProbs[i])); } else { // this direction didn't lead anywhere new // if there are existing possible directions that wouldn't lead anywhere, aggregate with // them if (noChangeTransition != null) { noChangeTransition.p += this.directionProbs[i]; } else { // otherwise create this new state and transition noChangeTransition = new TransitionProbability(s.copy(), this.directionProbs[i]); tps.add(noChangeTransition); } } } return tps; }
/**
 * Plans and executes the preparation of {@code ingredient}, alternating the reward function
 * between a "human" and a "robot" agent on each step, and returns the resulting state.
 *
 * <p>Outline of what the code below does:
 * <ol>
 *   <li>Recursively plans every non-simple sub-ingredient first, threading the state through.
 *   <li>Adds the ingredient's simple ingredients and their containers to the state if they are
 *       not already present. Newly added containers are moved to the "shelf" space.
 *   <li>Repeatedly: toggles the active agent, runs A* (zero heuristic) to the
 *       {@code RecipeFinished} goal, rolls out the resulting policy, records only the FIRST
 *       action/reward, and advances {@code currentState} by that one step.
 *   <li>In the rollout's final state, replaces any complex ingredient that satisfies
 *       {@code Recipe.isSuccess} with a copy named after {@code ingredient}, and updates its
 *       container's contents accordingly.
 *   <li>Stops once the rollout contains at most one action.
 * </ol>
 *
 * <p>NOTE(review): {@code actionSequence.get(0)} and {@code stateSequence.get(1)} assume every
 * rollout takes at least one action; an empty rollout would throw. Confirm that A* always
 * reaches the goal or the terminal function here.
 *
 * @param domain the recipe domain
 * @param startingState the state to plan from; presumably copied (not mutated) by
 *     {@code new State(startingState)} — confirm against State's copy constructor
 * @param ingredient the (possibly complex) ingredient to prepare
 * @return the final state of the last rollout, with the finished ingredient renamed
 */
public State PlanIngredient(Domain domain, State startingState, IngredientRecipe ingredient) {
    State currentState = new State(startingState);
    List<IngredientRecipe> contents = ingredient.getContents();
    // Complex sub-ingredients are planned (depth-first) before this ingredient.
    for (IngredientRecipe subIngredient : contents) {
        if (!subIngredient.isSimple()) {
            System.out.println("Planning ingredient " + subIngredient.getName());
            currentState = this.PlanIngredient(domain, currentState, subIngredient);
        }
    }
    ObjectClass simpleIngredientClass = domain.getObjectClass(IngredientFactory.ClassNameSimple);
    ObjectClass containerClass = domain.getObjectClass(ContainerFactory.ClassName);
    ObjectInstance shelfSpace = currentState.getObject("shelf");
    List<ObjectInstance> ingredientInstances =
            IngredientFactory.getSimpleIngredients(simpleIngredientClass, ingredient);
    List<ObjectInstance> containerInstances =
            Recipe.getContainers(containerClass, ingredientInstances, shelfSpace.getName());
    // Add only objects not already in the state (a sub-ingredient's plan may have added them).
    for (ObjectInstance ingredientInstance : ingredientInstances) {
        if (currentState.getObject(ingredientInstance.getName()) == null) {
            currentState.addObject(ingredientInstance);
        }
    }
    for (ObjectInstance containerInstance : containerInstances) {
        if (currentState.getObject(containerInstance.getName()) == null) {
            ContainerFactory.changeContainerSpace(containerInstance, shelfSpace.getName());
            currentState.addObject(containerInstance);
        }
    }
    // isSuccess is final because the anonymous goal test below captures it.
    final PropositionalFunction isSuccess = new RecipeFinished("success", domain, ingredient);
    PropositionalFunction isFailure = new RecipeBotched("botched", domain, ingredient);
    // RewardFunction recipeRewardFunction = new RecipeRewardFunction(brownies);
    // RewardFunction recipeRewardFunction = new RecipeRewardFunction();
    RewardFunction humanRewardFunction = new RecipeAgentSpecificMakeSpanRewardFunction("human");
    RewardFunction robotRewardFunction = new RecipeAgentSpecificMakeSpanRewardFunction("robot");
    TerminalFunction recipeTerminalFunction = new RecipeTerminalFunction(isSuccess, isFailure);
    StateHashFactory hashFactory = new NameDependentStateHashFactory();
    // A* goal: some grounding of the "success" propositional function holds.
    StateConditionTest goalCondition = new StateConditionTest() {
        @Override
        public boolean satisfies(State s) {
            return isSuccess.somePFGroundingIsTrue(s);
        }
    };
    // final int numSteps = Recipe.getNumberSteps(ingredient);
    // Heuristic is currently constant 0, so A* behaves like uniform-cost search. The
    // commented-out code below is an earlier, disabled progress-based heuristic.
    Heuristic heuristic = new Heuristic() {
        @Override
        public double h(State state) {
            return 0;
            // List<ObjectInstance> objects =
            // state.getObjectsOfTrueClass(Recipe.ComplexIngredient.className);
            // double max = 0;
            // for (ObjectInstance object : objects)
            // {
            // max = Math.max(max, this.getSubIngredients(state, object));
            // }
            // return numSteps - max;
        }
        /*
        public int getSubIngredients(State state, ObjectInstance object)
        {
            int count = 0;
            count += IngredientFactory.isBakedIngredient(object) ? 1 : 0;
            count += IngredientFactory.isMixedIngredient(object) ? 1 : 0;
            count += IngredientFactory.isMeltedIngredient(object) ? 1 : 0;
            if (IngredientFactory.isSimple(object))
            {
                return count;
            }
            Set<String> contents = IngredientFactory.getContentsForIngredient(object);
            for (String str: contents)
            {
                count += this.getSubIngredients(state, state.getObject(str));
            }
            return count;
        }*/
    };
    boolean finished = false;
    State endState = startingState;
    // Running log of every executed action and its reward, across all loop iterations.
    List<GroundedAction> fullActions = new ArrayList<GroundedAction>();
    List<Double> fullReward = new ArrayList<Double>();
    // Toggled at the top of each iteration: true selects the human reward, false the robot's.
    boolean currentAgent = false;
    while (!finished) {
        currentAgent = !currentAgent;
        RewardFunction recipeRewardFunction =
                (currentAgent) ? humanRewardFunction : robotRewardFunction;
        // Replan from scratch every step with the active agent's reward function.
        AStar aStar = new AStar(domain, recipeRewardFunction, goalCondition, hashFactory, heuristic);
        aStar.planFromState(currentState);
        Policy policy = new DDPlannerPolicy(aStar);
        EpisodeAnalysis episodeAnalysis =
                policy.evaluateBehavior(currentState, recipeRewardFunction, recipeTerminalFunction);
        // Execute only the first action of the plan; the rest is discarded and replanned.
        System.out.println("Taking action " + episodeAnalysis.actionSequence.get(0).action.getName());
        fullActions.add(episodeAnalysis.actionSequence.get(0));
        fullReward.add(episodeAnalysis.rewardSequence.get(0));
        currentState = episodeAnalysis.stateSequence.get(1);
        // NOTE(review): when the rollout has exactly one action, endState and currentState
        // are the same object, so the renaming below also affects currentState — confirm
        // this is intended.
        endState = episodeAnalysis.getState(episodeAnalysis.stateSequence.size() - 1);
        // Copies are taken so endState can be modified while iterating.
        List<ObjectInstance> finalObjects = new ArrayList<ObjectInstance>(
                endState.getObjectsOfTrueClass(IngredientFactory.ClassNameComplex));
        List<ObjectInstance> containerObjects =
                new ArrayList<ObjectInstance>(endState.getObjectsOfTrueClass(ContainerFactory.ClassName));
        ObjectInstance namedIngredient = null;
        // Replace each successful complex ingredient with one named after the recipe, and
        // make its container hold the renamed ingredient.
        for (ObjectInstance obj : finalObjects) {
            if (Recipe.isSuccess(endState, ingredient, obj)) {
                namedIngredient =
                        DualAgentIndependentPlan.getNewNamedComplexIngredient(obj, ingredient.getName());
                String container = IngredientFactory.getContainer(obj);
                DualAgentIndependentPlan.switchContainersIngredients(
                        containerObjects, obj, namedIngredient);
                ObjectInstance containerInstance = endState.getObject(container);
                ContainerFactory.removeContents(containerInstance);
                ContainerFactory.addIngredient(containerInstance, ingredient.getName());
                endState.removeObject(obj);
                endState.addObject(namedIngredient);
                // return endState;
            }
        }
        // A rollout of at most one action means that action completes the plan.
        if (episodeAnalysis.actionSequence.size() <= 1) {
            System.out.println("Action sequence size: " + episodeAnalysis.actionSequence.size());
            finished = true;
        }
        // Debug output: reprints the WHOLE action log on every iteration, so total output
        // grows quadratically with the number of steps.
        for (int i = 0; i < fullActions.size(); ++i) {
            GroundedAction action = fullActions.get(i);
            double reward = fullReward.get(i);
            System.out.print("Cost: " + reward + " " + action.action.getName() + " ");
            for (int j = 0; j < action.params.length; ++j) {
                System.out.print(action.params[j] + " ");
            }
            System.out.print("\n");
        }
    }
    return endState;
}
/**
 * Builds the feature vector for the agent's current position.
 *
 * <p>As used below:
 * <ul>
 *   <li>The leading entries one-hot encode the type of the agent's cell: a positive
 *       {@code map[ax][ay]} sets entry {@code map[ax][ay] - 1}.
 *   <li>Entries from index {@code numCells} onward are seeded with 61 and then lowered with
 *       per-type nearest distances.
 *   <li>Entry {@code numCells} is set directly to the Manhattan distance to ({@code gx},
 *       {@code gy}) when {@code gx != -1}.
 *   <li>The remaining entries are filled by {@code updateNearest} during an outward scan of
 *       square rings of radius 0..15 around the agent. The scan stops early once
 *       {@code foundNearestForAll} reports success.
 * </ul>
 *
 * <p>NOTE(review): the bounds 30/29 hard-code a 30x30 map, and 61. appears to be a
 * "not found" sentinel — confirm both against map's actual size and foundNearestForAll.
 *
 * <p>NOTE(review): the rings are square (Chebyshev radius). If updateNearest records
 * Manhattan distance, as the goal seeding does, stopping at the first ring where every type
 * has been seen may not yield the true Manhattan minimum — confirm.
 */
@Override
public double[] generateFeatureVectorFrom(State s) {
    ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
    int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX);
    int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY);
    double[] vec = new double[this.getDim()];
    // one-hot type of the agent's own cell (type 0 sets nothing)
    if (this.map[ax][ay] > 0) {
        vec[map[ax][ay] - 1] = 1.;
    }
    // now do distances
    // first seed to max val
    for (int i = this.numCells; i < vec.length; i++) {
        vec[i] = 61.;
    }
    // set goal (type 0) to its goal position assuming only 1 instance of it, so we don't scan
    // large distances for it
    if (this.gx != -1) {
        vec[this.numCells] = Math.abs(this.gx - ax) + Math.abs(this.gy - ay);
    }
    // now do scan: visit the ring of cells at square radius r, clipped to [0, 29] on each axis
    for (int r = 0; r < 16; r++) {
        int x;
        // scan top
        int y = ay + r;
        if (y < 30) {
            for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) {
                this.updateNearest(vec, ax, ay, x, y);
            }
        }
        // scan bottom
        y = ay - r;
        if (y > -1) {
            for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) {
                this.updateNearest(vec, ax, ay, x, y);
            }
        }
        // scan left (the ring's corner cells are visited again here and in "scan right")
        x = ax - r;
        if (x > -1) {
            for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) {
                this.updateNearest(vec, ax, ay, x, y);
            }
        }
        // scan right
        x = ax + r;
        if (x < 30) {
            for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) {
                this.updateNearest(vec, ax, ay, x, y);
            }
        }
        // stop early once every tracked type has been located
        if (this.foundNearestForAll(vec)) {
            break;
        }
    }
    return vec;
}
/**
 * Paints the policy's action glyph(s) in the grid cell occupied by state {@code s}.
 *
 * <p>The cell position and size come from the discrete attributes {@code xAttName} and
 * {@code yAttName}. A non-discrete attribute leaves that coordinate and extent at 0. Every
 * action whose selection probability reaches the maximum is drawn; under
 * {@code MAXACTIONSOFTTIE} the threshold is lowered by {@code softTieDelta}.
 */
@Override
public void paintStatePolicy(Graphics2D g2, State s, Policy policy, float cWidth, float cHeight) {
    ObjectInstance xOb = this.xObjectInstance(s);
    ObjectInstance yOb = this.yObjectInstance(s);
    Attribute xAtt = xOb.getObjectClass().getAttribute(xAttName);
    Attribute yAtt = yOb.getObjectClass().getAttribute(yAttName);

    float cellLeft = 0f;
    float cellW = 0f;
    if (xAtt.type == Attribute.AttributeType.DISC) {
        float columns;
        if (this.numXCells != -1) {
            columns = this.numXCells;
        } else {
            columns = xAtt.discValues.size();
        }
        cellW = cWidth / columns;
        cellLeft = xOb.getDiscValForAttribute(xAttName) * cellW;
    }

    float cellTop = 0f;
    float cellH = 0f;
    if (yAtt.type == Attribute.AttributeType.DISC) {
        float rows;
        if (this.numYCells != -1) {
            rows = this.numYCells;
        } else {
            rows = yAtt.discValues.size();
        }
        cellH = cHeight / rows;
        // canvas y grows downward, so the row index is flipped
        cellTop = cHeight - cellH - yOb.getDiscValForAttribute(yAttName) * cellH;
    }

    List<ActionProb> pdist = policy.getActionDistributionForState(s);
    double maxp = 0.;
    for (ActionProb ap : pdist) {
        if (ap.pSelection > maxp) {
            maxp = ap.pSelection;
        }
    }

    // Max-action rendering is pinned on, so the probability-scaled branch below never runs
    // regardless of renderStyle.
    final boolean paintMaxActionsOnly = true;
    if (paintMaxActionsOnly) {
        double threshold = maxp;
        if (this.renderStyle == PolicyGlyphRenderStyle.MAXACTIONSOFTTIE) {
            threshold -= this.softTieDelta;
        }
        for (ActionProb ap : pdist) {
            if (ap.pSelection >= threshold) {
                ActionGlyphPainter painter = this.actionNameToGlyphPainter.get(ap.ga.actionName());
                if (painter != null) {
                    painter.paintGlyph(g2, cellLeft, cellTop, cellW, cellH);
                }
            }
        }
    } else {
        // glyph size proportional to each action's probability relative to the best one
        for (ActionProb ap : pdist) {
            float[] scaled =
                    this.rescaleRect(cellLeft, cellTop, cellW, cellH, (float) (ap.pSelection / maxp));
            ActionGlyphPainter painter = this.actionNameToGlyphPainter.get(ap.ga.actionName());
            if (painter != null) {
                painter.paintGlyph(g2, scaled[0], scaled[1], scaled[2], scaled[3]);
            }
        }
    }
}