예제 #1
0
  /**
   * Re-initializes {@code state} from the current domain, instance and (optional) non-fluents
   * definitions, then prints a warning to stderr when the domain's partially-observed flag
   * disagrees with whether the initialized state has any observation names.
   */
  void resetState() {
    // The non-fluents block is optional; its two arguments are passed as null when it is absent.
    final boolean haveNonFluents = (nonFluents != null);

    state.init(
        haveNonFluents ? nonFluents._hmObjects : null,
        instance._hmObjects,
        domain._hmTypes,
        domain._hmPVariables,
        domain._hmCPF,
        instance._alInitState,
        haveNonFluents ? nonFluents._alNonFluents : null,
        domain._alStateConstraints,
        domain._exprReward,
        instance._nNonDefActions);

    // The flag and the presence of observations must agree: either both set or both unset.
    final boolean hasObservations = state._alObservNames.size() > 0;
    if (domain._bPartiallyObserved != hasObservations) {
      String warning =
          "Domain '"
              + domain._sDomainName
              + "' partially observed flag and presence of observations mismatched.";
      System.err.println(warning);
    }
  }
예제 #2
0
  /**
   * Draws the navigation grid for state {@code s} on the block display (creating the display on
   * first use), then sleeps for {@code _nTimeDelay} milliseconds so the frame can be seen.
   *
   * <p>For every (xpos, ypos) pair the cell is painted red with "G!" when the robot is on the
   * goal, cyan with "G" for the goal, blue for the robot, and otherwise a grey level of
   * {@code 1 - P(xpos, ypos)}.
   *
   * @param s state whose {@code GOAL}, {@code robot-at} and {@code P} values are displayed
   * @return the contents of a string builder to which nothing is appended, i.e. an empty string
   * @throws EvalException declared by this method; presumably propagated from the state lookups
   */
  public String getStateDescription(State s) throws EvalException {
    StringBuilder sb = new StringBuilder();

    TYPE_NAME xpos_type = new TYPE_NAME("xpos");
    ArrayList<LCONST> list_xpos = s._hmObject2Consts.get(xpos_type);

    TYPE_NAME ypos_type = new TYPE_NAME("ypos");
    ArrayList<LCONST> list_ypos = s._hmObject2Consts.get(ypos_type);

    PVAR_NAME GOAL = new PVAR_NAME("GOAL");
    PVAR_NAME robot_at = new PVAR_NAME("robot-at");
    PVAR_NAME P = new PVAR_NAME("P");

    // Lazily create the display on the first call, sized from the number of positions
    if (_bd == null) {
      // Object names are treated as obfuscated when the first xpos is not literally "x1"
      IS_OBFUSCATED = !list_xpos.get(0).toString().equals("x1");
      int max_row = list_ypos.size() - 1;
      int max_col = list_xpos.size();

      _bd =
          new BlockDisplay(
              "RDDL Navigation Simulation", "RDDL Navigation Simulation", max_row + 2, max_col + 2);
    }

    // Set up a reusable arity-2 parameter list: slot 0 holds xpos, slot 1 holds ypos
    ArrayList<LCONST> params = new ArrayList<LCONST>(2);
    params.add(null);
    params.add(null);

    _bd.clearAllCells();
    _bd.clearAllLines();
    for (LCONST xpos : list_xpos) {
      for (LCONST ypos : list_ypos) {
        // The numeric index is everything after the first two characters of the object name
        int col = Integer.parseInt(xpos.toString().substring(2));
        int row = Integer.parseInt(ypos.toString().substring(2));
        if (IS_OBFUSCATED) {
          // Map obfuscated indices back to grid coordinates
          row = (int) Math.sqrt(row - 11);
          col = (int) Math.sqrt(col - 5);
        }
        row = row - 1;
        params.set(0, xpos);
        params.set(1, ypos);
        boolean is_goal = (Boolean) s.getPVariableAssign(GOAL, params);
        boolean robot = (Boolean) s.getPVariableAssign(robot_at, params);
        float prob = 1f - ((Number) s.getPVariableAssign(P, params)).floatValue();

        if (robot && is_goal) _bd.setCell(row, col, Color.red, "G!");
        else if (is_goal) _bd.setCell(row, col, Color.cyan, "G");
        else if (robot) _bd.setCell(row, col, Color.blue, null);
        else {
          // Grey level: cells with a higher P value are drawn darker
          Color cell_color = new Color(prob, prob, prob);
          _bd.setCell(row, col, cell_color, null);
        }
      }
    }

    _bd.repaint();

    // Sleep so the animation can be viewed at a frame rate of 1000/_nTimeDelay per second
    try {
      Thread.sleep(_nTimeDelay);
    } catch (InterruptedException e) {
      System.err.println(e);
      e.printStackTrace(System.err);
      // Restore the interrupt flag so callers further up can still observe the interruption
      Thread.currentThread().interrupt();
    }

    return sb.toString();
  }
  /**
   * Builds a set of up to {@code NUM_CONCURRENT_ACTIONS} concurrent actions for state {@code s}.
   *
   * <p>Action fluents are visited in the order of {@code s._alActionNames}; for each one, its
   * ground instantiations are tried in a random permutation. A candidate is given a value that
   * matches its range (true for bool, a random int/real, or a random enum value), appended to the
   * action set, and kept only if {@code s.checkStateActionConstraints} accepts the whole set.
   *
   * <p>If no candidate at all was accepted, the empty action set is checked instead. A non-empty
   * set that has already passed the constraint check is returned as is, even when the last
   * candidate tried was rejected.
   *
   * @param s current state, used to enumerate actions and check constraints
   * @return the selected actions; empty when only the empty action set satisfies the constraints
   * @throws EvalException if constraint checking fails with an unexpected exception, or if not
   *     even the empty action set satisfies the state-action constraints
   */
  public ArrayList<PVAR_INST_DEF> getActions(State s) throws EvalException {

    ArrayList<PVAR_INST_DEF> actions = new ArrayList<PVAR_INST_DEF>();

    for (int j = 0; j < s._alActionNames.size(); j++) {

      // Visit action fluents in declaration order; the randomness is in the instantiation order
      PVAR_NAME p = s._alActionNames.get(j);
      PVARIABLE_DEF pvar_def = s._hmPVariables.get(p);

      // Get term instantiations for that action and try them in random order
      ArrayList<ArrayList<LCONST>> inst = s.generateAtoms(p);
      int[] index_permutation = Permutation.permute(inst.size(), _random);

      for (int i = 0; i < index_permutation.length; i++) {
        ArrayList<LCONST> terms = inst.get(index_permutation[i]);

        // IMPORTANT: get random assignment that matches action type
        Object value = null;
        if (pvar_def._sRange.equals(RDDL.TYPE_NAME.BOOL_TYPE)) {
          // bool
          value = Boolean.TRUE;
        } else if (pvar_def._sRange.equals(RDDL.TYPE_NAME.INT_TYPE)) {
          // int
          value = Integer.valueOf(_random.nextInt(MAX_INT_VALUE));
        } else if (pvar_def._sRange.equals(RDDL.TYPE_NAME.REAL_TYPE)) {
          // real
          value = Double.valueOf(_random.nextDouble() * MAX_REAL_VALUE);
        } else {
          // enum: only other option for a range
          ENUM_TYPE_DEF enum_type_def = (ENUM_TYPE_DEF) s._hmTypes.get(pvar_def._sRange);
          int rand_index = _random.nextInt(enum_type_def._alPossibleValues.size());

          value = enum_type_def._alPossibleValues.get(rand_index);
        }

        // Tentatively add the action, then keep it only if the whole set is still legal
        actions.add(new PVAR_INST_DEF(p._sPVarName, value, terms));
        boolean passed_constraints = true;
        try {
          s.checkStateActionConstraints(actions);
        } catch (EvalException e) {
          // Got an eval exception, constraint violated
          passed_constraints = false;
        } catch (Exception e) {
          // Got a real exception, something is wrong
          System.out.println(
              "\nERROR evaluating constraint on action set: "
                  + actions
                  + /*"\nConstraint: " +*/ e
                  + "\n");
          e.printStackTrace();
          throw new EvalException(e.toString());
        }
        if (!passed_constraints) actions.remove(actions.size() - 1);
        if (actions.size() == NUM_CONCURRENT_ACTIONS) break;
      }
      if (actions.size() == NUM_CONCURRENT_ACTIONS) break;
    }

    // Only fall back to the empty action when no single action passed the constraints.
    // Every element still in the list was validated together with its predecessors, so a
    // non-empty list must not be discarded just because the last candidate tried was rejected.
    if (actions.isEmpty()) {
      try {
        s.checkStateActionConstraints(actions);
      } catch (EvalException e) {
        System.out.println(actions + " : " + e);
        throw new EvalException("No actions (even the empty action) satisfied state constraints!");
      }
    }

    // Return the action list
    return actions;
  }
예제 #4
0
  /**
   * Builds and serializes the XML message for one turn.
   *
   * <p>The document has a {@code TURN} root holding the turn number followed by one
   * {@code OBSERVED_FLUENT} element (name, arguments, value) per ground fluent. For a partially
   * observed domain the fluents come from {@code observStore}, which is first copied into
   * {@code state._observ}; otherwise they come from {@code state._state}. When the domain is
   * partially observed and {@code observStore} is null, a single {@code NULL_OBSERVATIONS}
   * element is emitted instead.
   *
   * @param state state the fluent values are read from
   * @param turn turn number written into the message
   * @param domain domain definition, consulted for its partially-observed flag
   * @param observStore stored observations, or null when there are none yet
   * @return the serialized document as produced by {@code Client.serialize}
   * @throws Exception if a fluent value cannot be retrieved or document construction fails; the
   *     exception is printed to stdout before being rethrown
   */
  static String createXMLTurn(
      State state,
      int turn,
      DOMAIN domain,
      HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore)
      throws Exception {
    DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
    try {
      Document doc = factory.newDocumentBuilder().newDocument();

      Element turnEle = doc.createElement(TURN);
      doc.appendChild(turnEle);

      Element turnNumEle = doc.createElement(TURN_NUM);
      turnNumEle.appendChild(doc.createTextNode(turn + ""));
      turnEle.appendChild(turnNumEle);

      if (domain._bPartiallyObserved && observStore == null) {
        // No observations (first turn of POMDP)
        turnEle.appendChild(doc.createElement(NULL_OBSERVATIONS));
      } else {
        for (PVAR_NAME pn :
            (domain._bPartiallyObserved ? observStore.keySet() : state._state.keySet())) {

          // observStore is known to be non-null here whenever the domain is partially observed.
          // No problem to overwrite observations, they are only ever read from.
          if (domain._bPartiallyObserved) {
            state._observ.put(pn, observStore.get(pn));
          }

          for (ArrayList<LCONST> gfluent : state.generateAtoms(pn)) {
            Element fluentEle = doc.createElement(OBSERVED_FLUENT);
            turnEle.appendChild(fluentEle);

            // Fluent name
            Element nameEle = doc.createElement(FLUENT_NAME);
            nameEle.appendChild(doc.createTextNode(pn.toString()));
            fluentEle.appendChild(nameEle);

            // One argument element per term of the ground fluent
            for (LCONST lc : gfluent) {
              Element argEle = doc.createElement(FLUENT_ARG);
              argEle.appendChild(doc.createTextNode(lc.toString()));
              fluentEle.appendChild(argEle);
            }

            // Fluent value; a missing value is treated as a fatal error
            Element valueEle = doc.createElement(FLUENT_VALUE);
            Object value = state.getPVariableAssign(pn, gfluent);
            if (value == null) {
              System.out.println("STATE:\n" + state);
              throw new Exception("ERROR: Could not retrieve value for " + pn + gfluent.toString());
            }
            valueEle.appendChild(doc.createTextNode(value.toString()));
            fluentEle.appendChild(valueEle);
          }
        }
      }

      if (SHOW_XML) {
        printXMLNode(doc);
        System.out.println();
        System.out.flush();
      }
      return Client.serialize(doc);

    } catch (Exception e) {
      // Report on the server console, then let the caller decide how to terminate
      System.out.println("FATAL SERVER EXCEPTION: " + e);
      e.printStackTrace();
      throw e;
    }
  }
예제 #5
0
  /**
   * Serves one client session over {@code connection}.
   *
   * <p>Sequence, as implemented below: read the session request and look up the requested
   * instance (returning early if it is unknown); send the session-init message and initialize the
   * state; then, for each round, read the round request, reset the state, send the round-init
   * message and step through the horizon. Each step sends the turn message, reads the client's
   * actions, computes the next state, stores a copy of the observations for partially observed
   * domains, samples the reward, accumulates it with the running discount, displays the state
   * and advances it. A round-end message follows each round, and a session-end message is sent
   * at the end and appended to {@code LOG_FILE}.
   *
   * <p>Any exception terminates the trial after printing the stack trace. The connection is
   * always closed on exit.
   */
  public void run() {
    DOMParser p = new DOMParser();
    int numRounds = DEFAULT_NUM_ROUNDS;
    double timeAllowed = DEFAULT_TIME_ALLOWED;
    // NOTE(review): timeUsed is never updated in this method, so 0 is always reported
    double timeUsed = 0;
    try {
      BufferedInputStream is = new BufferedInputStream(connection.getInputStream());
      // NOTE(review): reader uses the platform charset while the writer below is US-ASCII
      InputStreamReader isr = new InputStreamReader(is);
      InputSource isrc = readOneMessage(isr);
      requestedInstance = null;
      processXMLSessionRequest(p, isrc, this);
      System.out.println(requestedInstance);

      if (!rddl._tmInstanceNodes.containsKey(requestedInstance)) {
        System.out.println("Instance name '" + requestedInstance + "' not found.");
        return;
      }

      BufferedOutputStream os = new BufferedOutputStream(connection.getOutputStream());
      OutputStreamWriter osw = new OutputStreamWriter(os, "US-ASCII");
      String msg = createXMLSessionInit(numRounds, timeAllowed, this);
      sendOneMessage(osw, msg);

      initializeState(rddl, requestedInstance);

      double accum_total_reward = 0;
      ArrayList<Double> rewards = new ArrayList<Double>(DEFAULT_NUM_ROUNDS * instance._nHorizon);
      int r = 0;
      for (; r < numRounds; r++) {
        isrc = readOneMessage(isr);
        if (!processXMLRoundRequest(p, isrc)) {
          // Client did not request another round; finish the session with the rounds played
          break;
        }
        resetState();
        msg = createXMLRoundInit(r + 1, numRounds, timeUsed, timeAllowed);
        sendOneMessage(osw, msg);

        System.out.println("Round " + (r + 1) + " / " + numRounds);
        if (SHOW_MEMORY_USAGE)
          System.out.print(
              "[ Memory usage: "
                  + _df.format((RUNTIME.totalMemory() - RUNTIME.freeMemory()) / 1e6d)
                  + "Mb / "
                  + _df.format(RUNTIME.totalMemory() / 1e6d)
                  + "Mb"
                  + " = "
                  + _df.format(
                      ((double) (RUNTIME.totalMemory() - RUNTIME.freeMemory())
                          / (double) RUNTIME.totalMemory()))
                  + " ]\n");

        double accum_reward = 0.0d;
        double cur_discount = 1.0d;
        int h = 0;
        // Observations from the previous step; null until the first action has been applied
        HashMap<PVAR_NAME, HashMap<ArrayList<LCONST>, Object>> observStore = null;
        for (; h < instance._nHorizon; h++) {

          msg = createXMLTurn(state, h + 1, domain, observStore);
          if (SHOW_MSG) System.out.println("Sending msg:\n" + msg);
          sendOneMessage(osw, msg);

          isrc = readOneMessage(isr);
          if (isrc == null) throw new Exception("FATAL SERVER EXCEPTION: EMPTY CLIENT MESSAGE");

          ArrayList<PVAR_INST_DEF> ds = processXMLAction(p, isrc, state);
          if (ds == null) {
            // No usable action list; end this round early
            break;
          }
          // Note (kept from the original authors): requiring a noop as the first action of a
          // partially observed domain is not required, so no such check is made here.
          if (SHOW_ACTIONS) System.out.println("** Actions received: " + ds);

          try {
            state.computeNextState(ds, 0, rand);
          } catch (Exception ee) {
            System.out.println("FATAL SERVER EXCEPTION:\n" + ee);
            throw ee;
          }
          // Snapshot the observations so the next turn message reports them
          if (domain._bPartiallyObserved) observStore = copyObserv(state._observ);

          // Calculate reward / objective and store
          double reward =
              ((Number)
                      domain._exprReward.sample(
                          new HashMap<LVAR, LCONST>(), state, rand, new BooleanPair()))
                  .doubleValue();
          rewards.add(reward);
          accum_reward += cur_discount * reward;
          cur_discount *= instance._dDiscount;

          stateViz.display(state, h);
          state.advanceNextState();
        }
        accum_total_reward += accum_reward;
        msg = createXMLRoundEnd(r, accum_reward, h, 0);
        if (SHOW_MSG) System.out.println("Sending msg:\n" + msg);
        sendOneMessage(osw, msg);
      }
      msg = createXMLSessionEnd(accum_total_reward, r, 0, this.clientName, this.id);
      if (SHOW_MSG) System.out.println("Sending msg:\n" + msg);
      sendOneMessage(osw, msg);

      // Append the session summary to the log; always close the writer so the file handle is
      // released even if the write fails.
      BufferedWriter bw = new BufferedWriter(new FileWriter(LOG_FILE, true));
      try {
        bw.write(msg);
        bw.newLine();
        bw.flush();
      } finally {
        bw.close();
      }
    } catch (Exception e) {
      e.printStackTrace();
      System.out.println("\n>> TERMINATING TRIAL.");
    } finally {
      try {
        connection.close();
      } catch (IOException ignored) {
        // Best effort: the session is over and nothing useful can be done if close fails
      }
    }
  }