public class Server implements Runnable { public static final boolean SHOW_ACTIONS = true; public static final boolean SHOW_XML = false; public static final boolean SHOW_MSG = false; private static final String LOG_FILE = "rddl.log"; /** following is XML definitions */ public static final String SESSION_REQUEST = "session-request"; public static final String CLIENT_NAME = "client-name"; public static final String PROBLEM_NAME = "problem-name"; public static final String SESSION_INIT = "session-init"; public static final String SESSION_ID = "session-id"; public static final String SESSION_END = "session-end"; public static final String TOTAL_REWARD = "total-reward"; public static final String TIME_SPENT = "time-spent"; public static final String NUM_ROUNDS = "num-rounds"; public static final String TIME_ALLOWED = "time-allowed"; public static final String ROUNDS_USED = "rounds-used"; public static final String ROUND_REQUEST = "round-request"; public static final String ROUND_INIT = "round-init"; public static final String ROUND_NUM = "round-num"; public static final String ROUND_LEFT = "round-left"; public static final String TIME_LEFT = "time-left"; public static final String ROUND_END = "round-end"; public static final String ROUND_REWARD = "round-reward"; public static final String TURNS_USED = "turns-used"; public static final String TIME_USED = "time-used"; public static final String TURN = "turn"; public static final String TURN_NUM = "turn-num"; public static final String OBSERVED_FLUENT = "observed-fluent"; public static final String NULL_OBSERVATIONS = "no-observed-fluents"; public static final String FLUENT_NAME = "fluent-name"; public static final String FLUENT_ARG = "fluent-arg"; public static final String FLUENT_VALUE = "fluent-value"; public static final String ACTIONS = "actions"; public static final String ACTION = "action"; public static final String ACTION_NAME = "action-name"; public static final String ACTION_ARG = "action-arg"; public static final String ACTION_VALUE = "action-value"; public static final String DONE = "done"; public static final int PORT_NUMBER = 2323; public static final String HOST_NAME = "localhost"; // public static final int PORT_NUMBER = 2309; // public static final String HOST_NAME = "ec2-50-16-103-243.compute-1.amazonaws.com"; public static final int DEFAULT_SEED = 0; public static final String NO_XML_HEADER = "no-header"; public static boolean NO_XML_HEADING = false; public static final boolean SHOW_MEMORY_USAGE = true; public static final Runtime RUNTIME = Runtime.getRuntime(); private static DecimalFormat _df = new DecimalFormat("0.##"); private Socket connection; private String TimeStamp; private RDDL rddl = null; private static int ID = 0; private static int DEFAULT_NUM_ROUNDS = 30; private static double DEFAULT_TIME_ALLOWED = 30; public int id; public String clientName = null; public String requestedInstance = null; public static Random rand; public State state; public INSTANCE instance; public NONFLUENTS nonFluents; public DOMAIN domain; public StateViz stateViz; /** * @param args 1. rddl description file name (can be directory), in RDDL format, with complete * path 2. (optional) port number 3. (optional) random seed */ public static void main(String[] args) { // StateViz state_viz = new GenericScreenDisplay(true); StateViz state_viz = new NullScreenDisplay(false); ArrayList<RDDL> rddls = new ArrayList<RDDL>(); int port = PORT_NUMBER; if (args.length < 1) { System.out.println( "usage: rddlfilename (optional) portnumber random-seed state-viz-class-name"); System.out.println("\nexample 1: Server rddlfilename"); System.out.println("example 2: Server rddlfilename 2323"); System.out.println("example 3: Server rddlfilename 2323 0 rddl.viz.GenericScreenDisplay"); System.exit(1); } try { // Load RDDL files RDDL rddl = new RDDL(); File f = new File(args[0]); if (f.isDirectory()) { for (File f2 : f.listFiles()) if (f2.getName().endsWith(".rddl")) { System.out.println("Loading: " + f2); rddl.addOtherRDDL(parser.parse(f2)); } } else rddl.addOtherRDDL(parser.parse(f)); if (args.length > 1) { port = Integer.valueOf(args[1]); } ServerSocket socket1 = new ServerSocket(port); if (args.length > 2) { Server.rand = new Random(Integer.valueOf(args[2])); } else { Server.rand = new Random(DEFAULT_SEED); } if (args.length > 3) { state_viz = (StateViz) Class.forName(args[3]).newInstance(); } System.out.println("RDDL Server Initialized"); while (true) { Socket connection = socket1.accept(); Runnable runnable = new Server(connection, ++ID, rddl, state_viz); Thread thread = new Thread(runnable); thread.start(); } } catch (Exception e) { // TODO Auto-generated catch block System.out.println(e); e.printStackTrace(); } } Server(Socket s, int i, RDDL rddl) { this(s, i, rddl, new NullScreenDisplay(false)); } Server(Socket s, int i, RDDL rddl, StateViz state_viz) { this.connection = s; this.id = i; this.rddl = rddl; this.stateViz = state_viz; } public void run() { DOMParser p = new DOMParser(); int numRounds = DEFAULT_NUM_ROUNDS; double timeAllowed = DEFAULT_TIME_ALLOWED; double timeUsed = 0; try { BufferedInputStream is = new BufferedInputStream(connection.getInputStream()); InputStreamReader isr = new InputStreamReader(is); InputSource isrc = readOneMessage(isr); requestedInstance = null; processXMLSessionRequest(p, isrc, this); System.out.println(requestedInstance); if (!rddl._tmInstanceNodes.containsKey(requestedInstance)) { System.out.println("Instance name '" + requestedInstance + "' not found."); return; } BufferedOutputStream os = new BufferedOutputStream(connection.getOutputStream()); OutputStreamWriter osw = new OutputStreamWriter(os, "US-ASCII"); String msg = createXMLSessionInit(numRounds, timeAllowed, this); sendOneMessage(osw, msg); initializeState(rddl, requestedInstance); // System.out.println("STATE:\n" + state); double accum_total_reward = 0; ArrayList<Double> rewards = new ArrayList<Double>(DEFAULT_NUM_ROUNDS * instance._nHorizon); int r = 0; for (; r < numRounds; r++) { isrc = readOneMessage(isr); if (!processXMLRoundRequest(p, isrc)) { break; } resetState(); msg = createXMLRoundInit(r + 1, numRounds, timeUsed, timeAllowed); sendOneMessage(osw, msg); System.out.println("Round " + (r + 1) + " / " + numRounds); if (SHOW_MEMORY_USAGE) System.out.print( "[ Memory usage: " + _df.format((RUNTIME.totalMemory() - RUNTIME.freeMemory()) / 1e6d) + "Mb / " + _df.format(RUNTIME.totalMemory() / 1e6d) + "Mb" + " = " + _df.format( ((double) (RUNTIME.totalMemory() - RUNTIME.freeMemory()) / (double) RUNTIME.totalMemory())) + " ]\n"); double accum_reward = 0.0d; double cur_discount = 1.0d; int h = 0; HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore = null; for (; h < instance._nHorizon; h++) { // if ( observStore != null) { // for ( PVAR_NAME pn : observStore.keySet() ) { // System.out.println("check3 " + pn); // for( ArrayList<LCONST> aa : observStore.get(pn).keySet()) { // System.out.println("check3 :" + aa + ": " + observStore.get(pn).get(aa)); // } // } // } msg = createXMLTurn(state, h + 1, domain, observStore); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); isrc = readOneMessage(isr); if (isrc == null) throw new Exception("FATAL SERVER EXCEPTION: EMPTY CLIENT MESSAGE"); ArrayList<PVAR_INST_DEF> ds = processXMLAction(p, isrc, state); if (ds == null) { break; } // Sungwook: this is not required. -Scott // if ( h== 0 && domain._bPartiallyObserved && ds.size() != 0) { // System.err.println("the first action for partial observable domain should be noop"); // } if (SHOW_ACTIONS) System.out.println("** Actions received: " + ds); try { state.computeNextState(ds, 0, rand); } catch (Exception ee) { System.out.println("FATAL SERVER EXCEPTION:\n" + ee); // ee.printStackTrace(); throw ee; // System.exit(1); } // for ( PVAR_NAME pn : state._observ.keySet() ) { // System.out.println("check1 " + pn); // for( ArrayList<LCONST> aa : state._observ.get(pn).keySet()) { // System.out.println("check1 :" + aa + ": " + state._observ.get(pn).get(aa)); // } // } if (domain._bPartiallyObserved) observStore = copyObserv(state._observ); // Calculate reward / objective and store double reward = ((Number) domain._exprReward.sample( new HashMap<LVAR, LCONST>(), state, rand, new BooleanPair())) .doubleValue(); rewards.add(reward); accum_reward += cur_discount * reward; // System.out.println("Accum reward: " + accum_reward + ", instance._dDiscount: " + // instance._dDiscount + // " / " + (cur_discount * reward) + " / " + reward); cur_discount *= instance._dDiscount; stateViz.display(state, h); state.advanceNextState(); } accum_total_reward += accum_reward; msg = createXMLRoundEnd(r, accum_reward, h, 0); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); } msg = createXMLSessionEnd(accum_total_reward, r, 0, this.clientName, this.id); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); BufferedWriter bw = new BufferedWriter(new FileWriter(LOG_FILE, true)); bw.write(msg); bw.newLine(); bw.flush(); // need to wait 10 seconds to pretend that we're processing something // try { // Thread.sleep(10000); // } // catch (Exception e){} // TimeStamp = new java.util.Date().toString(); // String returnCode = "MultipleSocketServer repsonded at "+ TimeStamp + (char) 3; } catch (Exception e) { e.printStackTrace(); System.out.println("\n>> TERMINATING TRIAL."); } finally { try { connection.close(); } catch (IOException e) { } } } HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> copyObserv( HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observ) { HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> r = new HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>>(); // System.out.println("Observation pvars: " + observ); for (PVAR_NAME pn : observ.keySet()) { HashMap<ArrayList<LCONST>, Object> v = new HashMap<ArrayList<LCONST>, Object>(); for (ArrayList<LCONST> aa : observ.get(pn).keySet()) { ArrayList<LCONST> raa = new ArrayList<LCONST>(); for (LCONST lc : aa) { raa.add(lc); } v.put(raa, observ.get(pn).get(aa)); } r.put(pn, v); } return r; } void initializeState(RDDL rddl, String requestedInstance) { state = new State(); instance = rddl._tmInstanceNodes.get(requestedInstance); nonFluents = null; if (instance._sNonFluents != null) { nonFluents = rddl._tmNonFluentNodes.get(instance._sNonFluents); } domain = rddl._tmDomainNodes.get(instance._sDomain); if (nonFluents != null && !instance._sDomain.equals(nonFluents._sDomain)) { try { throw new Exception( "Domain name of instance and fluents do not match: " + instance._sDomain + " vs. " + nonFluents._sDomain); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } } void resetState() { state.init( nonFluents != null ? nonFluents._hmObjects : null, instance._hmObjects, domain._hmTypes, domain._hmPVariables, domain._hmCPF, instance._alInitState, nonFluents == null ? null : nonFluents._alNonFluents, domain._alStateConstraints, domain._exprReward, instance._nNonDefActions); if ((domain._bPartiallyObserved && state._alObservNames.size() == 0) || (!domain._bPartiallyObserved && state._alObservNames.size() > 0)) System.err.println( "Domain '" + domain._sDomainName + "' partially observed flag and presence of observations mismatched."); } static Object getValue(String pname, String pvalue, State state) { TYPE_NAME tname = state._hmPVariables.get(new PVAR_NAME(pname))._sRange; // TYPE_NAMES are interned so that equality can be tested directly // (also helps enforce better type safety) if (TYPE_NAME.INT_TYPE.equals(tname)) { return Integer.valueOf(pvalue); } if (TYPE_NAME.BOOL_TYPE.equals(tname)) { return Boolean.valueOf(pvalue); } if (TYPE_NAME.REAL_TYPE.equals(tname)) { return Double.valueOf(pvalue); } if (state._hmObject2Consts.containsKey(tname)) { return new OBJECT_VAL(pvalue); // for( LCONST lc : state._hmObject2Consts.get(tname)) { // if ( lc.toString().equals(pvalue)) { // return lc; // } // } } if (state._hmTypes.containsKey(tname)) { return new ENUM_VAL(pvalue); // if ( state._hmTypes.get(tname) instanceof ENUM_TYPE_DEF ) { // ENUM_TYPE_DEF etype = (ENUM_TYPE_DEF)state._hmTypes.get(tname); // for ( ENUM_VAL ev : etype._alPossibleValues) { // if ( ev.toString().equals(pvalue)) { // return ev; // } // } // } } return null; } static ArrayList<PVAR_INST_DEF> processXMLAction(DOMParser p, InputSource isrc, State state) throws Exception { try { // showInputSource(isrc); System.exit(1); // TODO p.parse(isrc); Element e = p.getDocument().getDocumentElement(); if (SHOW_XML) { System.out.println("Received action msg:"); printXMLNode(e); } if (!e.getNodeName().equals(ACTIONS)) { System.out.println("ERROR: NO ACTIONS NODE"); System.exit(1); return null; } NodeList nl = e.getElementsByTagName(ACTION); // System.out.println(nl); if (nl != null) { // && nl.getLength() > 0) { // TODO: Scott ArrayList<PVAR_INST_DEF> ds = new ArrayList<PVAR_INST_DEF>(); for (int i = 0; i < nl.getLength(); i++) { Element el = (Element) nl.item(i); String name = getTextValue(el, ACTION_NAME).get(0); ArrayList<String> args = getTextValue(el, ACTION_ARG); ArrayList<LCONST> lcArgs = new ArrayList<LCONST>(); for (String arg : args) { if (arg.startsWith("@")) lcArgs.add(new RDDL.ENUM_VAL(arg)); else lcArgs.add(new RDDL.OBJECT_VAL(arg)); } String pvalue = getTextValue(el, ACTION_VALUE).get(0); Object value = getValue(name, pvalue, state); PVAR_INST_DEF d = new PVAR_INST_DEF(name, value, lcArgs); ds.add(d); } return ds; } else return new ArrayList<PVAR_INST_DEF>(); // FYI: May be unreachable. -Scott // } else { // TODO: Removed by Scott, NOOP should not be handled differently // nl = e.getElementsByTagName(NOOP); // if ( nl != null && nl.getLength() > 0) { // ArrayList<PVAR_INST_DEF> ds = new ArrayList<PVAR_INST_DEF>(); // return ds; // } // } } catch (Exception e) { // TODO Auto-generated catch block System.out.println("FATAL SERVER ERROR:\n" + e); // t.printStackTrace(); throw e; // System.exit(1); } } public static void sendOneMessage(OutputStreamWriter osw, String msg) throws IOException { // System.out.println(msg); if (NO_XML_HEADING) { // System.out.println(msg.substring(39)); osw.write(msg.substring(39)); } else { osw.write(msg + '\0'); } osw.flush(); } public static InputSource readOneMessage(InputStreamReader isr) { StringBuffer message = new StringBuffer(); int character; try { while ((character = isr.read()) != '\0') { message.append((char) character); } // System.out.println(message); ByteArrayInputStream bais = new ByteArrayInputStream(message.toString().getBytes()); InputSource isrc = new InputSource(); isrc.setByteStream(bais); return isrc; } catch (IOException e) { // TODO Auto-generated catch block e.printStackTrace(); return null; } } static String createXMLSessionInit(int numRounds, double timeAllowed, Server server) { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(SESSION_INIT); dom.appendChild(rootEle); addOneText(dom, rootEle, SESSION_ID, server.id + ""); addOneText(dom, rootEle, NUM_ROUNDS, numRounds + ""); addOneText(dom, rootEle, TIME_ALLOWED, timeAllowed + ""); return Client.serialize(dom); } catch (Exception e) { System.out.println(e); return null; } } static void processXMLSessionRequest(DOMParser p, InputSource isrc, Server server) { try { p.parse(isrc); Element e = p.getDocument().getDocumentElement(); if (e.getNodeName().equals(SESSION_REQUEST)) { server.requestedInstance = getTextValue(e, PROBLEM_NAME).get(0); server.clientName = getTextValue(e, CLIENT_NAME).get(0); NodeList nl = e.getElementsByTagName(NO_XML_HEADER); if (nl.getLength() > 0) { NO_XML_HEADING = true; } } return; } catch (SAXException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } catch (IOException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } return; } static boolean processXMLRoundRequest(DOMParser p, InputSource isrc) { try { p.parse(isrc); Element e = p.getDocument().getDocumentElement(); if (e.getNodeName().equals(ROUND_REQUEST)) { return true; } return false; } catch (SAXException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } catch (IOException e1) { // TODO Auto-generated catch block e1.printStackTrace(); } return false; } public static ArrayList<String> getTextValue(Element ele, String tagName) { ArrayList<String> returnVal = new ArrayList<String>(); // NodeList nll = ele.getElementsByTagName("*"); NodeList nl = ele.getElementsByTagName(tagName); if (nl != null && nl.getLength() > 0) { for (int i = 0; i < nl.getLength(); i++) { Element el = (Element) nl.item(i); returnVal.add(el.getFirstChild().getNodeValue()); } } return returnVal; } static String createXMLTurn( State state, int turn, DOMAIN domain, HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore) throws Exception { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(TURN); dom.appendChild(rootEle); Element turnNum = dom.createElement(TURN_NUM); Text textTurnNum = dom.createTextNode(turn + ""); turnNum.appendChild(textTurnNum); rootEle.appendChild(turnNum); // System.out.println("PO: " + domain._bPartiallyObserved); if (!domain._bPartiallyObserved || observStore != null) { for (PVAR_NAME pn : (domain._bPartiallyObserved ? observStore.keySet() : state._state.keySet())) { // System.out.println(turn + " check2 Partial Observ " + pn +" : "+ // domain._bPartiallyObserved); // No problem to overwrite observations, only ever read from if (domain._bPartiallyObserved && observStore != null) state._observ.put(pn, observStore.get(pn)); ArrayList<ArrayList<LCONST>> gfluents = state.generateAtoms(pn); for (ArrayList<LCONST> gfluent : gfluents) { // for ( Map.Entry<ArrayList<LCONST>,Object> gfluent : // (domain._bPartiallyObserved // ? observStore.get(pn).entrySet() // : state._state.get(pn).entrySet())) { Element ofEle = dom.createElement(OBSERVED_FLUENT); rootEle.appendChild(ofEle); Element pName = dom.createElement(FLUENT_NAME); Text pTextName = dom.createTextNode(pn.toString()); pName.appendChild(pTextName); ofEle.appendChild(pName); for (LCONST lc : gfluent) { Element pArg = dom.createElement(FLUENT_ARG); Text pTextArg = dom.createTextNode(lc.toString()); pArg.appendChild(pTextArg); ofEle.appendChild(pArg); } Element pValue = dom.createElement(FLUENT_VALUE); Object value = state.getPVariableAssign(pn, gfluent); if (value == null) { System.out.println("STATE:\n" + state); throw new Exception("ERROR: Could not retrieve value for " + pn + gfluent.toString()); } Text pTextValue = dom.createTextNode(value.toString()); pValue.appendChild(pTextValue); ofEle.appendChild(pValue); } } } else { // No observations (first turn of POMDP) Element ofEle = dom.createElement(NULL_OBSERVATIONS); rootEle.appendChild(ofEle); } if (SHOW_XML) { printXMLNode(dom); System.out.println(); System.out.flush(); } return (Client.serialize(dom)); } catch (Exception e) { System.out.println("FATAL SERVER EXCEPTION: " + e); e.printStackTrace(); throw e; // System.exit(1); // return null; } } static String createXMLRoundInit(int round, int numRounds, double timeUsed, double timeAllowed) { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(ROUND_INIT); dom.appendChild(rootEle); addOneText(dom, rootEle, ROUND_NUM, round + ""); addOneText(dom, rootEle, ROUND_LEFT, (numRounds - round) + ""); addOneText(dom, rootEle, TIME_LEFT, (timeAllowed - timeUsed) + ""); return Client.serialize(dom); } catch (Exception e) { System.out.println(e); return null; } } static String createXMLRoundEnd(int round, double reward, int turnsUsed, double timeUsed) { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(ROUND_END); dom.appendChild(rootEle); addOneText(dom, rootEle, ROUND_NUM, round + ""); addOneText(dom, rootEle, ROUND_REWARD, reward + ""); addOneText(dom, rootEle, TURNS_USED, turnsUsed + ""); addOneText(dom, rootEle, TIME_USED, timeUsed + ""); return Client.serialize(dom); } catch (Exception e) { System.out.println(e); return null; } } public static void addOneText(Document dom, Element p, String name, String value) { Element e = dom.createElement(name); Text text = dom.createTextNode(value); e.appendChild(text); p.appendChild(e); } static String createXMLSessionEnd( double reward, int roundsUsed, double timeUsed, String clientName, int sessionId) { DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); try { DocumentBuilder db = dbf.newDocumentBuilder(); Document dom = db.newDocument(); Element rootEle = dom.createElement(SESSION_END); dom.appendChild(rootEle); addOneText(dom, rootEle, TOTAL_REWARD, reward + ""); addOneText(dom, rootEle, ROUNDS_USED, roundsUsed + ""); addOneText(dom, rootEle, TIME_USED, timeUsed + ""); addOneText(dom, rootEle, CLIENT_NAME, clientName + ""); addOneText(dom, rootEle, SESSION_ID, sessionId + ""); return Client.serialize(dom); } catch (Exception e) { System.out.println(e); return null; } } /////////////////////////////////////////////////////////////////////// // DEBUG /////////////////////////////////////////////////////////////////////// public static void showInputSource(InputSource isrc) { InputStream is = isrc.getByteStream(); byte[] bytes; try { int size = is.available(); bytes = new byte[size]; is.read(bytes); System.out.println("==BEGIN IS=="); System.out.write(bytes, 0, size); System.out.println("\n==END IS=="); } catch (IOException e2) { System.out.println(">>> Inputstream error"); e2.printStackTrace(); } } public static void printXMLNode(Node n) { printXMLNode(n, "", 0); } public static void printXMLNode(Node n, String prefix, int depth) { try { System.out.print("\n" + Pad(depth) + "[" + n.getNodeName()); NamedNodeMap m = n.getAttributes(); for (int i = 0; m != null && i < m.getLength(); i++) { Node item = m.item(i); System.out.print(" " + item.getNodeName() + "=" + item.getNodeValue()); } System.out.print("] "); NodeList cn = n.getChildNodes(); for (int i = 0; cn != null && i < cn.getLength(); i++) { Node item = cn.item(i); if (item.getNodeType() == Node.TEXT_NODE) { String val = item.getNodeValue().trim(); if (val.length() > 0) System.out.print(" \"" + item.getNodeValue().trim() + "\""); } else printXMLNode(item, prefix, depth + 2); } } catch (Exception e) { System.out.println(Pad(depth) + "Exception e: "); } } public static StringBuffer Pad(int depth) { StringBuffer sb = new StringBuffer(); for (int i = 0; i < depth; i++) sb.append(" "); return sb; } }
public void run() { DOMParser p = new DOMParser(); int numRounds = DEFAULT_NUM_ROUNDS; double timeAllowed = DEFAULT_TIME_ALLOWED; double timeUsed = 0; try { BufferedInputStream is = new BufferedInputStream(connection.getInputStream()); InputStreamReader isr = new InputStreamReader(is); InputSource isrc = readOneMessage(isr); requestedInstance = null; processXMLSessionRequest(p, isrc, this); System.out.println(requestedInstance); if (!rddl._tmInstanceNodes.containsKey(requestedInstance)) { System.out.println("Instance name '" + requestedInstance + "' not found."); return; } BufferedOutputStream os = new BufferedOutputStream(connection.getOutputStream()); OutputStreamWriter osw = new OutputStreamWriter(os, "US-ASCII"); String msg = createXMLSessionInit(numRounds, timeAllowed, this); sendOneMessage(osw, msg); initializeState(rddl, requestedInstance); // System.out.println("STATE:\n" + state); double accum_total_reward = 0; ArrayList<Double> rewards = new ArrayList<Double>(DEFAULT_NUM_ROUNDS * instance._nHorizon); int r = 0; for (; r < numRounds; r++) { isrc = readOneMessage(isr); if (!processXMLRoundRequest(p, isrc)) { break; } resetState(); msg = createXMLRoundInit(r + 1, numRounds, timeUsed, timeAllowed); sendOneMessage(osw, msg); System.out.println("Round " + (r + 1) + " / " + numRounds); if (SHOW_MEMORY_USAGE) System.out.print( "[ Memory usage: " + _df.format((RUNTIME.totalMemory() - RUNTIME.freeMemory()) / 1e6d) + "Mb / " + _df.format(RUNTIME.totalMemory() / 1e6d) + "Mb" + " = " + _df.format( ((double) (RUNTIME.totalMemory() - RUNTIME.freeMemory()) / (double) RUNTIME.totalMemory())) + " ]\n"); double accum_reward = 0.0d; double cur_discount = 1.0d; int h = 0; HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore = null; for (; h < instance._nHorizon; h++) { // if ( observStore != null) { // for ( PVAR_NAME pn : observStore.keySet() ) { // System.out.println("check3 " + pn); // for( ArrayList<LCONST> aa : observStore.get(pn).keySet()) { // System.out.println("check3 :" + aa + ": " + observStore.get(pn).get(aa)); // } // } // } msg = createXMLTurn(state, h + 1, domain, observStore); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); isrc = readOneMessage(isr); if (isrc == null) throw new Exception("FATAL SERVER EXCEPTION: EMPTY CLIENT MESSAGE"); ArrayList<PVAR_INST_DEF> ds = processXMLAction(p, isrc, state); if (ds == null) { break; } // Sungwook: this is not required. -Scott // if ( h== 0 && domain._bPartiallyObserved && ds.size() != 0) { // System.err.println("the first action for partial observable domain should be noop"); // } if (SHOW_ACTIONS) System.out.println("** Actions received: " + ds); try { state.computeNextState(ds, 0, rand); } catch (Exception ee) { System.out.println("FATAL SERVER EXCEPTION:\n" + ee); // ee.printStackTrace(); throw ee; // System.exit(1); } // for ( PVAR_NAME pn : state._observ.keySet() ) { // System.out.println("check1 " + pn); // for( ArrayList<LCONST> aa : state._observ.get(pn).keySet()) { // System.out.println("check1 :" + aa + ": " + state._observ.get(pn).get(aa)); // } // } if (domain._bPartiallyObserved) observStore = copyObserv(state._observ); // Calculate reward / objective and store double reward = ((Number) domain._exprReward.sample( new HashMap<LVAR, LCONST>(), state, rand, new BooleanPair())) .doubleValue(); rewards.add(reward); accum_reward += cur_discount * reward; // System.out.println("Accum reward: " + accum_reward + ", instance._dDiscount: " + // instance._dDiscount + // " / " + (cur_discount * reward) + " / " + reward); cur_discount *= instance._dDiscount; stateViz.display(state, h); state.advanceNextState(); } accum_total_reward += accum_reward; msg = createXMLRoundEnd(r, accum_reward, h, 0); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); } msg = createXMLSessionEnd(accum_total_reward, r, 0, this.clientName, this.id); if (SHOW_MSG) System.out.println("Sending msg:\n" + msg); sendOneMessage(osw, msg); BufferedWriter bw = new BufferedWriter(new FileWriter(LOG_FILE, true)); bw.write(msg); bw.newLine(); bw.flush(); // need to wait 10 seconds to pretend that we're processing something // try { // Thread.sleep(10000); // } // catch (Exception e){} // TimeStamp = new java.util.Date().toString(); // String returnCode = "MultipleSocketServer repsonded at "+ TimeStamp + (char) 3; } catch (Exception e) { e.printStackTrace(); System.out.println("\n>> TERMINATING TRIAL."); } finally { try { connection.close(); } catch (IOException e) { } } }