/**
 * Re-initializes {@code state} from the current domain, instance and (possibly absent)
 * non-fluents block, then prints a warning to stderr when the domain's partially-observed flag
 * disagrees with whether any observation fluents ended up registered in the state.
 */
void resetState() {
  state.init(
      nonFluents == null ? null : nonFluents._hmObjects,
      instance._hmObjects,
      domain._hmTypes,
      domain._hmPVariables,
      domain._hmCPF,
      instance._alInitState,
      nonFluents == null ? null : nonFluents._alNonFluents,
      domain._alStateConstraints,
      domain._exprReward,
      instance._nNonDefActions);

  // A partially observed domain should have observations, a fully observed one should not.
  boolean hasObservations = !state._alObservNames.isEmpty();
  if (domain._bPartiallyObserved != hasObservations) {
    System.err.println(
        "Domain '"
            + domain._sDomainName
            + "' partially observed flag and presence of observations mismatched.");
  }
}
public String getStateDescription(State s) throws EvalException { StringBuilder sb = new StringBuilder(); TYPE_NAME xpos_type = new TYPE_NAME("xpos"); ArrayList<LCONST> list_xpos = s._hmObject2Consts.get(xpos_type); TYPE_NAME ypos_type = new TYPE_NAME("ypos"); ArrayList<LCONST> list_ypos = s._hmObject2Consts.get(ypos_type); PVAR_NAME GOAL = new PVAR_NAME("GOAL"); PVAR_NAME robot_at = new PVAR_NAME("robot-at"); PVAR_NAME P = new PVAR_NAME("P"); if (_bd == null) { IS_OBFUSCATED = !list_xpos.get(0).toString().equals("x1"); int max_row = list_ypos.size() - 1; int max_col = list_xpos.size(); _bd = new BlockDisplay( "RDDL Navigation Simulation", "RDDL Navigation Simulation", max_row + 2, max_col + 2); } // Set up an arity-1 parameter list ArrayList<LCONST> params = new ArrayList<LCONST>(2); params.add(null); params.add(null); _bd.clearAllCells(); _bd.clearAllLines(); for (LCONST xpos : list_xpos) { for (LCONST ypos : list_ypos) { int col = new Integer(xpos.toString().substring(2, xpos.toString().length())); int row = new Integer(ypos.toString().substring(2, ypos.toString().length())); if (IS_OBFUSCATED) { row = (int) Math.sqrt(row - 11); col = (int) Math.sqrt(col - 5); } row = row - 1; params.set(0, xpos); params.set(1, ypos); boolean is_goal = (Boolean) s.getPVariableAssign(GOAL, params); boolean robot = (Boolean) s.getPVariableAssign(robot_at, params); float prob = 1f - ((Number) s.getPVariableAssign(P, params)).floatValue(); if (robot && is_goal) _bd.setCell(row, col, Color.red, "G!"); else if (is_goal) _bd.setCell(row, col, Color.cyan, "G"); else if (robot) _bd.setCell(row, col, Color.blue, null); else { Color cell_color = new Color(prob, prob, prob); _bd.setCell(row, col, cell_color, null); } } } _bd.repaint(); // Sleep so the animation can be viewed at a frame rate of 1000/_nTimeDelay per second try { Thread.currentThread().sleep(_nTimeDelay); } catch (InterruptedException e) { System.err.println(e); e.printStackTrace(System.err); } return sb.toString(); }
public ArrayList<PVAR_INST_DEF> getActions(State s) throws EvalException { ArrayList<PVAR_INST_DEF> actions = new ArrayList<PVAR_INST_DEF>(); ArrayList<PVAR_NAME> action_types = s._hmTypeMap.get("action-fluent"); boolean passed_constraints = false; for (int j = 0; j < s._alActionNames.size(); j++) { // Get a random action PVAR_NAME p = s._alActionNames.get(j); PVARIABLE_DEF pvar_def = s._hmPVariables.get(p); // Get term instantations for that action and select *one* ArrayList<ArrayList<LCONST>> inst = s.generateAtoms(p); int[] index_permutation = Permutation.permute(inst.size(), _random); for (int i = 0; i < index_permutation.length; i++) { ArrayList<LCONST> terms = inst.get(index_permutation[i]); // IMPORTANT: get random assignment that matches action type Object value = null; if (pvar_def._sRange.equals(RDDL.TYPE_NAME.BOOL_TYPE)) { // bool value = new Boolean(true); } else if (pvar_def._sRange.equals(RDDL.TYPE_NAME.INT_TYPE)) { // int value = new Integer(_random.nextInt(MAX_INT_VALUE)); } else if (pvar_def._sRange.equals(RDDL.TYPE_NAME.REAL_TYPE)) { // real value = new Double(_random.nextDouble() * MAX_REAL_VALUE); } else { // enum: only other option for a range ENUM_TYPE_DEF enum_type_def = (ENUM_TYPE_DEF) s._hmTypes.get(pvar_def._sRange); int rand_index = _random.nextInt(enum_type_def._alPossibleValues.size()); value = enum_type_def._alPossibleValues.get(rand_index); } // Now set the action actions.add(new PVAR_INST_DEF(p._sPVarName, value, terms)); passed_constraints = true; try { s.checkStateActionConstraints(actions); } catch (EvalException e) { // Got an eval exception, constraint violated passed_constraints = false; // System.out.println(actions + " : " + e); // System.out.println(s); // System.exit(1); } catch (Exception e) { // Got a real exception, something is wrong System.out.println( "\nERROR evaluating constraint on action set: " + actions + /*"\nConstraint: " +*/ e + "\n"); e.printStackTrace(); throw new EvalException(e.toString()); } if 
(!passed_constraints) actions.remove(actions.size() - 1); if (actions.size() == NUM_CONCURRENT_ACTIONS) break; } if (actions.size() == NUM_CONCURRENT_ACTIONS) break; } // Check if no single action passed constraint if (!passed_constraints) { // Try empty action passed_constraints = true; actions.clear(); try { s.checkStateActionConstraints(actions); } catch (EvalException e) { passed_constraints = false; System.out.println(actions + " : " + e); throw new EvalException("No actions (even a) satisfied state constraints!"); } } // Return the action list // System.out.println("**Action: " + actions); return actions; }
static String createXMLTurn( State state, int turn, DOMAIN domain, HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore) throws Exception { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(TURN); dom.appendChild(rootEle); Element turnNum = dom.createElement(TURN_NUM); Text textTurnNum = dom.createTextNode(turn + ""); turnNum.appendChild(textTurnNum); rootEle.appendChild(turnNum); // System.out.println("PO: " + domain._bPartiallyObserved); if (!domain._bPartiallyObserved || observStore != null) { for (PVAR_NAME pn : (domain._bPartiallyObserved ? observStore.keySet() : state._state.keySet())) { // System.out.println(turn + " check2 Partial Observ " + pn +" : "+ // domain._bPartiallyObserved); // No problem to overwrite observations, only ever read from if (domain._bPartiallyObserved && observStore != null) state._observ.put(pn, observStore.get(pn)); ArrayList<ArrayList<LCONST>> gfluents = state.generateAtoms(pn); for (ArrayList<LCONST> gfluent : gfluents) { // for ( Map.Entry<ArrayList<LCONST>,Object> gfluent : // (domain._bPartiallyObserved // ? 
observStore.get(pn).entrySet() // : state._state.get(pn).entrySet())) { Element ofEle = dom.createElement(OBSERVED_FLUENT); rootEle.appendChild(ofEle); Element pName = dom.createElement(FLUENT_NAME); Text pTextName = dom.createTextNode(pn.toString()); pName.appendChild(pTextName); ofEle.appendChild(pName); for (LCONST lc : gfluent) { Element pArg = dom.createElement(FLUENT_ARG); Text pTextArg = dom.createTextNode(lc.toString()); pArg.appendChild(pTextArg); ofEle.appendChild(pArg); } Element pValue = dom.createElement(FLUENT_VALUE); Object value = state.getPVariableAssign(pn, gfluent); if (value == null) { System.out.println("STATE:\n" + state); throw new Exception("ERROR: Could not retrieve value for " + pn + gfluent.toString()); } Text pTextValue = dom.createTextNode(value.toString()); pValue.appendChild(pTextValue); ofEle.appendChild(pValue); } } } else { // No observations (first turn of POMDP) Element ofEle = dom.createElement(NULL_OBSERVATIONS); rootEle.appendChild(ofEle); } if (SHOW_XML) { printXMLNode(dom); System.out.println(); System.out.flush(); } return (Client.serialize(dom)); } catch (Exception e) { System.out.println("FATAL SERVER EXCEPTION: " + e); e.printStackTrace(); throw e; // System.exit(1); // return null; } }
public void run() { DOMParser p = new DOMParser(); int numRounds = DEFAULT_NUM_ROUNDS; double timeAllowed = DEFAULT_TIME_ALLOWED; double timeUsed = 0; try { BufferedInputStream is = new BufferedInputStream(connection.getInputStream()); InputStreamReader isr = new InputStreamReader(is); InputSource isrc = readOneMessage(isr); requestedInstance = null; processXMLSessionRequest(p, isrc, this); System.out.println(requestedInstance); if (!rddl._tmInstanceNodes.containsKey(requestedInstance)) { System.out.println("Instance name '" + requestedInstance + "' not found."); return; } BufferedOutputStream os = new BufferedOutputStream(connection.getOutputStream()); OutputStreamWriter osw = new OutputStreamWriter(os, "US-ASCII"); String msg = createXMLSessionInit(numRounds, timeAllowed, this); sendOneMessage(osw, msg); initializeState(rddl, requestedInstance); // System.out.println("STATE:\n" + state); double accum_total_reward = 0; ArrayList<Double> rewards = new ArrayList<Double>(DEFAULT_NUM_ROUNDS * instance._nHorizon); int r = 0; for (; r < numRounds; r++) { isrc = readOneMessage(isr); if (!processXMLRoundRequest(p, isrc)) { break; } resetState(); msg = createXMLRoundInit(r + 1, numRounds, timeUsed, timeAllowed); sendOneMessage(osw, msg); System.out.println("Round " + (r + 1) + " / " + numRounds); if (SHOW_MEMORY_USAGE) System.out.print( "[ Memory usage: " + _df.format((RUNTIME.totalMemory() - RUNTIME.freeMemory()) / 1e6d) + "Mb / " + _df.format(RUNTIME.totalMemory() / 1e6d) + "Mb" + " = " + _df.format( ((double) (RUNTIME.totalMemory() - RUNTIME.freeMemory()) / (double) RUNTIME.totalMemory())) + " ]\n"); double accum_reward = 0.0d; double cur_discount = 1.0d; int h = 0; HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore = null; for (; h < instance._nHorizon; h++) { // if ( observStore != null) { // for ( PVAR_NAME pn : observStore.keySet() ) { // System.out.println("check3 " + pn); // for( ArrayList<LCONST> aa : observStore.get(pn).keySet()) { // 
System.out.println("check3 :" + aa + ": " + observStore.get(pn).get(aa)); // } // } // } msg = createXMLTurn(state, h + 1, domain, observStore); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); isrc = readOneMessage(isr); if (isrc == null) throw new Exception("FATAL SERVER EXCEPTION: EMPTY CLIENT MESSAGE"); ArrayList<PVAR_INST_DEF> ds = processXMLAction(p, isrc, state); if (ds == null) { break; } // Sungwook: this is not required. -Scott // if ( h== 0 && domain._bPartiallyObserved && ds.size() != 0) { // System.err.println("the first action for partial observable domain should be noop"); // } if (SHOW_ACTIONS) System.out.println("** Actions received: " + ds); try { state.computeNextState(ds, 0, rand); } catch (Exception ee) { System.out.println("FATAL SERVER EXCEPTION:\n" + ee); // ee.printStackTrace(); throw ee; // System.exit(1); } // for ( PVAR_NAME pn : state._observ.keySet() ) { // System.out.println("check1 " + pn); // for( ArrayList<LCONST> aa : state._observ.get(pn).keySet()) { // System.out.println("check1 :" + aa + ": " + state._observ.get(pn).get(aa)); // } // } if (domain._bPartiallyObserved) observStore = copyObserv(state._observ); // Calculate reward / objective and store double reward = ((Number) domain._exprReward.sample( new HashMap<LVAR, LCONST>(), state, rand, new BooleanPair())) .doubleValue(); rewards.add(reward); accum_reward += cur_discount * reward; // System.out.println("Accum reward: " + accum_reward + ", instance._dDiscount: " + // instance._dDiscount + // " / " + (cur_discount * reward) + " / " + reward); cur_discount *= instance._dDiscount; stateViz.display(state, h); state.advanceNextState(); } accum_total_reward += accum_reward; msg = createXMLRoundEnd(r, accum_reward, h, 0); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); } msg = createXMLSessionEnd(accum_total_reward, r, 0, this.clientName, this.id); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); 
sendOneMessage(osw, msg); BufferedWriter bw = new BufferedWriter(new FileWriter(LOG_FILE, true)); bw.write(msg); bw.newLine(); bw.flush(); // need to wait 10 seconds to pretend that we're processing something // try { // Thread.sleep(10000); // } // catch (Exception e){} // TimeStamp = new java.util.Date().toString(); // String returnCode = "MultipleSocketServer repsonded at "+ TimeStamp + (char) 3; } catch (Exception e) { e.printStackTrace(); System.out.println("\n>> TERMINATING TRIAL."); } finally { try { connection.close(); } catch (IOException e) { } } }