/**
 * Tests whether an agent and a location occupy the same grid cell.
 *
 * @param s the state to evaluate
 * @param params {@code params[0]} is the agent object name, {@code params[1]} the location name
 * @return {@code true} when the two objects have equal x and equal y attribute values
 */
@Override
public boolean isTrue(State s, String[] params) {
  ObjectInstance agentObj = s.getObject(params[0]);
  ObjectInstance locObj = s.getObject(params[1]);

  int agentCol = agentObj.getDiscValForAttribute(ATTX);
  int agentRow = agentObj.getDiscValForAttribute(ATTY);
  int locCol = locObj.getDiscValForAttribute(ATTX);
  int locRow = locObj.getDiscValForAttribute(ATTY);

  boolean sameColumn = agentCol == locCol;
  return sameColumn && agentRow == locRow;
}
/**
 * Builds a small example state: one agent named "agent0" at (0, 0) and one location named
 * "location0" at (10, 10).
 *
 * @param domain the domain supplying the agent and location object classes
 * @return a new state containing the agent followed by the location
 */
public static State getExampleState(Domain domain) {
  ObjectInstance agentObj = new ObjectInstance(domain.getObjectClass(CLASSAGENT), "agent0");
  agentObj.setValue(ATTX, 0);
  agentObj.setValue(ATTY, 0);

  ObjectInstance goalLoc = new ObjectInstance(domain.getObjectClass(CLASSLOCATION), "location0");
  goalLoc.setValue(ATTX, 10);
  goalLoc.setValue(ATTY, 10);

  State example = new State();
  example.addObject(agentObj);
  example.addObject(goalLoc);
  return example;
}
@Override public List<QValue> getQs(State s) { StateHashTuple sh = this.stateHash(s); Map<String, String> matching = null; StateHashTuple indexSH = mapToStateIndex.get(sh); if (indexSH == null) { // then this is an unexplored state indexSH = sh; mapToStateIndex.put(indexSH, indexSH); } if (this.containsParameterizedActions && !this.domain.isNameDependent()) { matching = sh.s.getObjectMatchingTo(indexSH.s, false); } List<QValue> res = new ArrayList<QValue>(); for (Action a : actions) { List<GroundedAction> applications = s.getAllGroundedActionsFor(a); for (GroundedAction ga : applications) { res.add(this.getQ(sh, ga, matching)); } } return res; }
@Override protected State performActionHelper(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); // sample directon with random roll double r = Math.random(); double sumProb = 0.; int dir = 0; for (int i = 0; i < this.directionProbs.length; i++) { sumProb += this.directionProbs[i]; if (r < sumProb) { dir = i; break; // found direction } } // get resulting position int[] newPos = this.moveResult(curX, curY, dir); // set the new position agent.setValue(ATTX, newPos[0]); agent.setValue(ATTY, newPos[1]); // return the state we just modified return s; }
/**
 * Compares two states by object values rather than object names.
 *
 * <p>The states are equal when they hold the same total number of objects, and every object in
 * this state can be paired with a distinct, value-equal object of the same true class in the other
 * state. Pairing is greedy in iteration order, and each object of the other state is used at most
 * once (tracked by name).
 *
 * @param other the object to compare with
 * @return {@code true} if {@code other} is a {@link State} with value-identical contents
 */
@Override
public boolean equals(Object other) {
  if (this == other) {
    return true;
  }
  if (!(other instanceof State)) {
    return false;
  }
  State that = (State) other;
  if (this.numTotalObjets() != that.numTotalObjets()) {
    return false;
  }

  Set<String> claimed = new HashSet<String>();
  for (List<ObjectInstance> ours : objectIndexByTrueClass.values()) {
    // Assumes each per-class list is non-empty (get(0)); TODO confirm the index never keeps empty lists.
    String trueClass = ours.get(0).getTrueClassName();
    List<ObjectInstance> theirs = that.getObjectsOfTrueClass(trueClass);
    if (ours.size() != theirs.size()) {
      return false;
    }

    nextObject:
    for (ObjectInstance mine : ours) {
      for (ObjectInstance candidate : theirs) {
        String candidateName = candidate.getName();
        if (!claimed.contains(candidateName) && mine.valueEquals(candidate)) {
          claimed.add(candidateName);
          continue nextObject;
        }
      }
      // No unclaimed value-equal counterpart exists for this object.
      return false;
    }
  }
  return true;
}
public void PlanRecipeTwoAgents(Domain domain, Recipe recipe) { System.out.println("Creating two-agent initial start state"); State state = new State(); Action mix = new MixAction(domain, recipe.topLevelIngredient); // Action bake = new BakeAction(domain); Action pour = new PourAction(domain, recipe.topLevelIngredient); Action move = new MoveAction(domain, recipe.topLevelIngredient); state.addObject(AgentFactory.getNewHumanAgentObjectInstance(domain, "human")); state.addObject(AgentFactory.getNewHumanAgentObjectInstance(domain, "robot")); state.addObject(MakeSpanFactory.getNewObjectInstance(domain, "makeSpan", 2)); List<String> containers = Arrays.asList("mixing_bowl_1"); state.addObject(SpaceFactory.getNewWorkingSpaceObjectInstance(domain, "shelf", null, null)); state.addObject( SpaceFactory.getNewWorkingSpaceObjectInstance( domain, "counter_human", containers, "human")); state.addObject( SpaceFactory.getNewWorkingSpaceObjectInstance( domain, "counter_robot", containers, "robot")); for (String container : containers) { state.addObject( ContainerFactory.getNewMixingContainerObjectInstance(domain, container, null, "shelf")); } this.PlanIngredient(domain, state, recipe.topLevelIngredient); }
/**
 * Looks up the initial Q-value for {@code a} at the graph node occupied by the first agent object.
 *
 * @param s the state; its first agent object's node attribute selects the row of {@code qInit}
 * @param a the grounded action; {@code actionId(a)} selects the column of {@code qInit}
 * @return {@code qInit[node][actionId(a)]}
 */
@Override
public double qValue(State s, GroundedAction a) {
  List<ObjectInstance> agents = s.getObjectsOfTrueClass(GraphDefinedDomain.CLASSAGENT);
  int node = agents.get(0).getDiscValForAttribute(GraphDefinedDomain.ATTNODE);
  return qInit[node][this.actionId(a)];
}
/** * This method computes a matching from objects in the receiver to value-identical objects in the * parameter state so. The matching is returned as a map from the object names in the receiving * state to the matched objects in state so. If enforceStateExactness is set to true, then the * returned matching will be an empty map if the two states are not OO-MDP-wise identical (i.e., * if there is a not a bijection between value-identical objects of the two states). If * enforceExactness is false and the states are not identical, the the method will return the * largest matching between objects that can be made. * * @param so the state to whose objects the receiving state's objects should be matched * @param enforceStateExactness whether to require that states are identical to return a matching * @return a matching from this receiving state's objects to objects in so that have identical * values. */ public Map<String, String> getObjectMatchingTo(State so, boolean enforceStateExactness) { Map<String, String> matching = new HashMap<String, String>(); if (this.numTotalObjets() != so.numTotalObjets() && enforceStateExactness) { return new HashMap<String, String>(); // states are not equal and therefore cannot be matched } Set<String> matchedObs = new HashSet<String>(); for (List<ObjectInstance> objects : objectIndexByTrueClass.values()) { String oclass = objects.get(0).getTrueClassName(); List<ObjectInstance> oobjects = so.getObjectsOfTrueClass(oclass); if (objects.size() != oobjects.size() && enforceStateExactness) { return new HashMap< String, String>(); // states are not equal and therefore cannot be matched } for (ObjectInstance o : objects) { boolean foundMatch = false; for (ObjectInstance oo : oobjects) { if (matchedObs.contains(oo.getName())) { continue; // already matched this one; check another } if (o.valueEquals(oo)) { foundMatch = true; matchedObs.add(oo.getName()); matching.put(o.getName(), oo.getName()); break; } } if (!foundMatch && enforceStateExactness) 
{ return new HashMap< String, String>(); // states are not equal and therefore cannot be matched } } } return matching; }
@Override public List<TransitionProbability> getTransitions(State s, String[] params) { // get agent and current position ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int curX = agent.getDiscValForAttribute(ATTX); int curY = agent.getDiscValForAttribute(ATTY); List<TransitionProbability> tps = new ArrayList<TransitionProbability>(4); TransitionProbability noChangeTransition = null; for (int i = 0; i < this.directionProbs.length; i++) { int[] newPos = this.moveResult(curX, curY, i); if (newPos[0] != curX || newPos[1] != curY) { // new possible outcome State ns = s.copy(); ObjectInstance nagent = ns.getFirstObjectOfClass(CLASSAGENT); nagent.setValue(ATTX, newPos[0]); nagent.setValue(ATTY, newPos[1]); // create transition probability object and add to our list of outcomes tps.add(new TransitionProbability(ns, this.directionProbs[i])); } else { // this direction didn't lead anywhere new // if there are existing possible directions that wouldn't lead anywhere, aggregate with // them if (noChangeTransition != null) { noChangeTransition.p += this.directionProbs[i]; } else { // otherwise create this new state and transition noChangeTransition = new TransitionProbability(s.copy(), this.directionProbs[i]); tps.add(noChangeTransition); } } } return tps; }
@Override public boolean isTerminal(State s) { // get location of agent in next state ObjectInstance agent = s.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return true; } return false; }
@Override public double reward(State s, GroundedAction a, State sprime) { // get location of agent in next state ObjectInstance agent = sprime.getFirstObjectOfClass(CLASSAGENT); int ax = agent.getDiscValForAttribute(ATTX); int ay = agent.getDiscValForAttribute(ATTY); // are they at goal location? if (ax == this.goalX && ay == this.goalY) { return 100.; } return -1; }
/**
 * Produces a one-hot vector describing the type of cell the agent occupies.
 *
 * <p>A positive value {@code t} in {@code map} sets index {@code t - 1}. A non-positive value
 * leaves the vector all zeros.
 *
 * @param s the state containing the grid-world agent
 * @return a new vector of length {@code getDim()}
 */
@Override
public double[] generateFeatureVectorFrom(State s) {
  ObjectInstance agentObj = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT);
  int col = agentObj.getDiscValForAttribute(GridWorldDomain.ATTX);
  int row = agentObj.getDiscValForAttribute(GridWorldDomain.ATTY);

  double[] features = new double[this.getDim()];
  int cellType = this.map[col][row];
  if (cellType > 0) {
    features[cellType - 1] = 1.;
  }
  return features;
}
/**
 * Returns the two equally likely outcomes of this action: the tiger behind the left door, or the
 * tiger behind the right door.
 *
 * <p>Each outcome is produced by running {@code performActionHelper} on its own copy of {@code s}.
 * This keeps the caller's state unmodified and ensures the two outcomes are distinct objects. If
 * the helper mutated and returned its argument, passing {@code s} directly would make both
 * outcomes the same object, with the second door assignment overwriting the first.
 *
 * @param s the source state (not modified)
 * @param params the action parameters forwarded to {@code performActionHelper}
 * @return two transitions, each with probability 0.5 (tiger-left first, then tiger-right)
 */
@Override
public List<TransitionProbability> getTransitions(State s, String[] params) {
  List<TransitionProbability> outcomes = new ArrayList<TransitionProbability>(2);

  State tigerLeft = performActionHelper(s.copy(), params);
  tigerLeft.getObject(Names.OBJ_LEFT_DOOR).setValue(Names.ATTR_TIGERNESS, 1);
  tigerLeft.getObject(Names.OBJ_RIGHT_DOOR).setValue(Names.ATTR_TIGERNESS, 0);
  outcomes.add(new TransitionProbability(tigerLeft, 0.5));

  State tigerRight = performActionHelper(s.copy(), params);
  tigerRight.getObject(Names.OBJ_LEFT_DOOR).setValue(Names.ATTR_TIGERNESS, 0);
  tigerRight.getObject(Names.OBJ_RIGHT_DOOR).setValue(Names.ATTR_TIGERNESS, 1);
  outcomes.add(new TransitionProbability(tigerRight, 0.5));

  return outcomes;
}
/**
 * Reports whether the indexer object's counter has reached {@code iterations}.
 *
 * @param s the state containing the indexer object
 * @return {@code true} when the indexer's index attribute equals {@code iterations}
 */
public static boolean isTerminal(State s) {
  ObjectInstance indexer = s.getObject(Names.OBJ_INDEXER);
  int index = indexer.getDiscValForAttribute(Names.ATTR_INDEX);
  return index == iterations;
}
@Override public double[] generateFeatureVectorFrom(State s) { ObjectInstance agent = s.getFirstObjectOfClass(GridWorldDomain.CLASSAGENT); int ax = agent.getDiscValForAttribute(GridWorldDomain.ATTX); int ay = agent.getDiscValForAttribute(GridWorldDomain.ATTY); double[] vec = new double[this.getDim()]; if (this.map[ax][ay] > 0) { vec[map[ax][ay] - 1] = 1.; } // now do distances // first seed to max val for (int i = this.numCells; i < vec.length; i++) { vec[i] = 61.; } // set goal (type 0) to its goal position assuming only 1 instance of it, so we don't scan // large distances for it if (this.gx != -1) { vec[this.numCells] = Math.abs(this.gx - ax) + Math.abs(this.gy - ay); } // now do scan for (int r = 0; r < 16; r++) { int x; // scan top int y = ay + r; if (y < 30) { for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) { this.updateNearest(vec, ax, ay, x, y); } } // scan bottom y = ay - r; if (y > -1) { for (x = Math.max(ax - r, 0); x <= Math.min(ax + r, 29); x++) { this.updateNearest(vec, ax, ay, x, y); } } // scan left x = ax - r; if (x > -1) { for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) { this.updateNearest(vec, ax, ay, x, y); } } // scan right x = ax + r; if (x < 30) { for (y = Math.max(ay - r, 0); y <= Math.min(ay + r, 29); y++) { this.updateNearest(vec, ax, ay, x, y); } } if (this.foundNearestForAll(vec)) { break; } } return vec; }
/**
 * Returns the object instance that holds the y-position information.
 *
 * <p>When {@code yClassName} is set, the first object of that class is used. Otherwise the object
 * named {@code yObjectName} is looked up.
 *
 * @param s the state from which to fetch the y-position object
 * @return the object instance in {@code s} that holds the y-position information
 */
protected ObjectInstance yObjectInstance(State s) {
  return this.yClassName == null
      ? s.getObject(yObjectName)
      : s.getFirstObjectOfClass(yClassName);
}
/**
 * Describes which door hides the tiger.
 *
 * @param s the tiger-domain state
 * @return {@code "<TIGER LEFT>"} when the left door's tigerness attribute is 1, otherwise
 *     {@code "<TIGER RIGHT>"}
 */
public String stateToString(State s) {
  ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
  boolean tigerBehindLeft = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS) == 1;
  if (tigerBehindLeft) {
    return "<TIGER LEFT>";
  }
  return "<TIGER RIGHT>";
}
public State PlanIngredient(Domain domain, State startingState, IngredientRecipe ingredient) { State currentState = new State(startingState); List<IngredientRecipe> contents = ingredient.getContents(); for (IngredientRecipe subIngredient : contents) { if (!subIngredient.isSimple()) { System.out.println("Planning ingredient " + subIngredient.getName()); currentState = this.PlanIngredient(domain, currentState, subIngredient); } } ObjectClass simpleIngredientClass = domain.getObjectClass(IngredientFactory.ClassNameSimple); ObjectClass containerClass = domain.getObjectClass(ContainerFactory.ClassName); ObjectInstance shelfSpace = currentState.getObject("shelf"); List<ObjectInstance> ingredientInstances = IngredientFactory.getSimpleIngredients(simpleIngredientClass, ingredient); List<ObjectInstance> containerInstances = Recipe.getContainers(containerClass, ingredientInstances, shelfSpace.getName()); for (ObjectInstance ingredientInstance : ingredientInstances) { if (currentState.getObject(ingredientInstance.getName()) == null) { currentState.addObject(ingredientInstance); } } for (ObjectInstance containerInstance : containerInstances) { if (currentState.getObject(containerInstance.getName()) == null) { ContainerFactory.changeContainerSpace(containerInstance, shelfSpace.getName()); currentState.addObject(containerInstance); } } final PropositionalFunction isSuccess = new RecipeFinished("success", domain, ingredient); PropositionalFunction isFailure = new RecipeBotched("botched", domain, ingredient); // RewardFunction recipeRewardFunction = new RecipeRewardFunction(brownies); // RewardFunction recipeRewardFunction = new RecipeRewardFunction(); RewardFunction humanRewardFunction = new RecipeAgentSpecificMakeSpanRewardFunction("human"); RewardFunction robotRewardFunction = new RecipeAgentSpecificMakeSpanRewardFunction("robot"); TerminalFunction recipeTerminalFunction = new RecipeTerminalFunction(isSuccess, isFailure); StateHashFactory hashFactory = new 
NameDependentStateHashFactory(); StateConditionTest goalCondition = new StateConditionTest() { @Override public boolean satisfies(State s) { return isSuccess.somePFGroundingIsTrue(s); } }; // final int numSteps = Recipe.getNumberSteps(ingredient); Heuristic heuristic = new Heuristic() { @Override public double h(State state) { return 0; // List<ObjectInstance> objects = // state.getObjectsOfTrueClass(Recipe.ComplexIngredient.className); // double max = 0; // for (ObjectInstance object : objects) // { // max = Math.max(max, this.getSubIngredients(state, object)); // } // return numSteps - max; } /* public int getSubIngredients(State state, ObjectInstance object) { int count = 0; count += IngredientFactory.isBakedIngredient(object) ? 1 : 0; count += IngredientFactory.isMixedIngredient(object) ? 1 : 0; count += IngredientFactory.isMeltedIngredient(object) ? 1 : 0; if (IngredientFactory.isSimple(object)) { return count; } Set<String> contents = IngredientFactory.getContentsForIngredient(object); for (String str: contents) { count += this.getSubIngredients(state, state.getObject(str)); } return count; }*/ }; boolean finished = false; State endState = startingState; List<GroundedAction> fullActions = new ArrayList<GroundedAction>(); List<Double> fullReward = new ArrayList<Double>(); boolean currentAgent = false; while (!finished) { currentAgent = !currentAgent; RewardFunction recipeRewardFunction = (currentAgent) ? 
humanRewardFunction : robotRewardFunction; AStar aStar = new AStar(domain, recipeRewardFunction, goalCondition, hashFactory, heuristic); aStar.planFromState(currentState); Policy policy = new DDPlannerPolicy(aStar); EpisodeAnalysis episodeAnalysis = policy.evaluateBehavior(currentState, recipeRewardFunction, recipeTerminalFunction); System.out.println("Taking action " + episodeAnalysis.actionSequence.get(0).action.getName()); fullActions.add(episodeAnalysis.actionSequence.get(0)); fullReward.add(episodeAnalysis.rewardSequence.get(0)); currentState = episodeAnalysis.stateSequence.get(1); endState = episodeAnalysis.getState(episodeAnalysis.stateSequence.size() - 1); List<ObjectInstance> finalObjects = new ArrayList<ObjectInstance>( endState.getObjectsOfTrueClass(IngredientFactory.ClassNameComplex)); List<ObjectInstance> containerObjects = new ArrayList<ObjectInstance>(endState.getObjectsOfTrueClass(ContainerFactory.ClassName)); ObjectInstance namedIngredient = null; for (ObjectInstance obj : finalObjects) { if (Recipe.isSuccess(endState, ingredient, obj)) { namedIngredient = DualAgentIndependentPlan.getNewNamedComplexIngredient(obj, ingredient.getName()); String container = IngredientFactory.getContainer(obj); DualAgentIndependentPlan.switchContainersIngredients( containerObjects, obj, namedIngredient); ObjectInstance containerInstance = endState.getObject(container); ContainerFactory.removeContents(containerInstance); ContainerFactory.addIngredient(containerInstance, ingredient.getName()); endState.removeObject(obj); endState.addObject(namedIngredient); // return endState; } } if (episodeAnalysis.actionSequence.size() <= 1) { System.out.println("Action sequence size: " + episodeAnalysis.actionSequence.size()); finished = true; } for (int i = 0; i < fullActions.size(); ++i) { GroundedAction action = fullActions.get(i); double reward = fullReward.get(i); System.out.print("Cost: " + reward + " " + action.action.getName() + " "); for (int j = 0; j < 
action.params.length; ++j) { System.out.print(action.params[j] + " "); } System.out.print("\n"); } } return endState; }