コード例 #1
0
ファイル: InfinitigerDomain.java プロジェクト: h2r/diapers
  /**
   * Builds the Infinitiger POMDP domain.
   *
   * <p>The returned object is an anonymous {@code POMDPDomain} whose state sampling, observation
   * generation, success and terminal checks delegate to the static helpers on {@code
   * InfinitigerDomain}. After the domain object is created this method attaches to it:
   *
   * <ul>
   *   <li>three discrete attributes (tigerness, index, position),
   *   <li>a door class (tigerness + position) and an indexer class (index),
   *   <li>the open-door and listen actions,
   *   <li>{@code observationsPerState} left-door and right-door observations, a null observation
   *       and a completion observation.
   * </ul>
   *
   * <p>NOTE(review): the actions and observations are constructed but never referenced again, so
   * this method presumably relies on their constructors registering them with {@code domain} —
   * confirm against the {@code Action} / {@code Observation} constructors.
   *
   * @return the populated domain
   */
  public Domain generateDomain() {
    Domain domain =
        new POMDPDomain() {
          @Override
          public POMDPState sampleInitialState() {
            return InfinitigerDomain.getNewState(this);
          }

          @Override
          public Observation makeObservationFor(GroundedAction a, POMDPState s) {
            return InfinitigerDomain.makeObservationFor(this, a, s);
          }

          @Override
          public boolean isSuccess(Observation o) {
            // A missing observation is never a success.
            return o != null && InfinitigerDomain.isSuccess(o);
          }

          @Override
          public boolean isTerminal(POMDPState s) {
            return InfinitigerDomain.isTerminal(this, s);
          }

          /**
           * Collects initial states by repeated sampling rather than by enumeration: states that
           * are never drawn within the sample budget will be absent from the result.
           */
          @Override
          public List<POMDPState> getAllInitialStates() {
            NameDependentStateHashFactory hashFactory = new NameDependentStateHashFactory();
            // Loop-invariant sample budget, computed once. Kept as a double so the loop
            // comparison behaves exactly as it did when the expression lived in the condition.
            final double sampleBudget = Math.pow(iterations, 2) * 10;

            // Hash tuples are collected in a set to drop repeated samples
            // (presumably StateHashTuple defines value-based equals/hashCode — confirm).
            Set<StateHashTuple> sampledHashes = new HashSet<StateHashTuple>();
            for (int i = 0; i < sampleBudget; i++) {
              sampledHashes.add(hashFactory.hashState(InfinitigerDomain.getNewState(this)));
            }

            Set<POMDPState> distinctStates = new HashSet<POMDPState>();
            for (StateHashTuple hashed : sampledHashes) {
              distinctStates.add(new POMDPState(hashed.s));
            }

            return new ArrayList<POMDPState>(distinctStates);
          }

          @Override
          public List<Observation> getObservations() {
            // Defensive copy so callers cannot mutate the backing collection.
            return new ArrayList<Observation>(observations);
          }

          @Override
          public Observation getObservation(String name) {
            return observationMap.get(name);
          }

          @Override
          public void addObservation(Observation observation) {
            // Observations are keyed by name; a second observation with the same name is ignored.
            if (!observationMap.containsKey(observation.getName())) {
              observations.add(observation);
              observationMap.put(observation.getName(), observation);
            }
          }
        };

    Attribute tigerness = new Attribute(domain, Names.ATTR_TIGERNESS, Attribute.AttributeType.DISC);
    tigerness.setDiscValuesForRange(0, 1, 1);

    Attribute index = new Attribute(domain, Names.ATTR_INDEX, Attribute.AttributeType.DISC);
    index.setDiscValuesForRange(0, iterations + 1, 1);

    Attribute position = new Attribute(domain, Names.ATTR_POSITION, Attribute.AttributeType.DISC);
    // Plain list instead of double-brace initialization, which would create an extra anonymous
    // class holding a reference to the enclosing instance. Declared as ArrayList so overload
    // resolution on setDiscValues is unchanged.
    ArrayList<String> positionValues = new ArrayList<String>();
    positionValues.add(Names.LEFT);
    positionValues.add(Names.RIGHT);
    position.setDiscValues(positionValues);

    ObjectClass doorClass = new ObjectClass(domain, Names.CLASS_DOOR);
    doorClass.addAttribute(tigerness);
    doorClass.addAttribute(position);

    ObjectClass indexerClass = new ObjectClass(domain, Names.CLASS_INDEXER);
    indexerClass.addAttribute(index);

    // Created for their construction side effects only; see the NOTE(review) in the Javadoc.
    new OpenAction(domain, Names.ACTION_OPEN_DOOR);
    new ListenAction(domain, Names.ACTION_LISTEN);

    // Each side gets observationsPerState distinct observations (name suffixed with i); the
    // per-side probability mass is split evenly between them.
    for (int i = 0; i < observationsPerState; ++i) {
      // "Left" observation: likely (1 - noise) when the left door's tigerness is 1.
      new Observation(domain, Names.OBS_LEFT_DOOR + i) {
        @Override
        public double getProbability(State s, GroundedAction a) {
          if (!a.action.getName().equals(Names.ACTION_LISTEN)) {
            return 0.0;
          }
          ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
          int leftDoorTiger = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS);
          if (leftDoorTiger == 1) {
            return (1 - noise) / observationsPerState;
          }
          return (noise) / observationsPerState;
        }
      };

      // "Right" observation: also derived from the LEFT door's tigerness, with the condition
      // mirrored — likely (1 - noise) when the left door's tigerness is 0.
      new Observation(domain, Names.OBS_RIGHT_DOOR + i) {
        @Override
        public double getProbability(State s, GroundedAction a) {
          if (!a.action.getName().equals(Names.ACTION_LISTEN)) {
            return 0.0;
          }
          ObjectInstance leftDoor = s.getObject(Names.OBJ_LEFT_DOOR);
          int leftDoorTiger = leftDoor.getDiscValForAttribute(Names.ATTR_TIGERNESS);
          if (leftDoorTiger == 0) {
            return (1 - noise) / observationsPerState;
          }
          return (noise) / observationsPerState;
        }
      };
    }

    // Null observation: probability 0.5 after an open-door action, 0 otherwise.
    new Observation(domain, Names.OBS_NULL) {
      @Override
      public double getProbability(State s, GroundedAction a) {
        if (a.action.getName().equals(Names.ACTION_OPEN_DOOR)) {
          return 0.5;
        }
        return 0.0;
      }
    };

    // Completion observation: certain once the indexer reaches `iterations`, regardless of the
    // action taken.
    new Observation(domain, Names.OBS_COMPLETE) {
      @Override
      public double getProbability(State s, GroundedAction a) {
        ObjectInstance indexer = s.getObject(Names.OBJ_INDEXER);
        // Named currentIndex to avoid shadowing the enclosing method's `index` attribute.
        int currentIndex = indexer.getDiscValForAttribute(Names.ATTR_INDEX);
        if (currentIndex == iterations) {
          return 1.0;
        }
        return 0.0;
      }
    };

    return domain;
  }