public void visualizeFunctions() {

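    // Feature-vector generator over the puddle map (the numeric arguments presumably set the
    // feature count and the 20x20 map resolution) and two linear differentiable reward functions
    // over those features: `rf` is the one MLIRL will fit, while `objectiveRF` gets fixed,
    // hand-set weights to serve as the ground-truth reward for comparison.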
    PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});

    List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

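    // Differentiable finite-horizon planner used inside MLIRL: sparse sampling to depth 6 with
    // c = -1 (use the full transition distribution rather than sampled transitions) and a
    // Boltzmann (softmax) temperature beta for the differentiable value backups.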
    int depth = 6;
    double beta = 10;
    DifferentiableSparseSampling dss =
        new DifferentiableSparseSampling(
            domain,
            rf,
            new NullTermination(),
            0.99,
            new NameDependentStateHashFactory(),
            depth,
            -1,
            beta);
    // DifferentiableZeroStepPlanner dss = new DifferentiableZeroStepPlanner(domain, rf);
    dss.toggleDebugPrinting(false);

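    // Bundle the domain, differentiable planner, expert episodes, and reward function into an
    // MLIRL request, then run gradient-based maximum-likelihood IRL to fit the reward
    // parameters to the expert behavior.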
    MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rf);
    request.setBoltzmannBeta(beta);

    // MLIRL irl = new MLIRL(request, 0.001, 0.01, 10); //use this for only the given features
    MLIRL irl = new MLIRL(request, 0.00001, 0.01, 10);
    // MLIRL irl = new MLIRL(request, 0.0001, 0.01, 10);

    irl.performIRL();

    TerminalFunction tf = new GridWorldTerminalFunction(20, 20);

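    // Value iteration under the hand-set objective reward; its planFromState call is left
    // commented out, so it is only constructed here.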
    ValueIteration vi =
        new ValueIteration(
            this.domain,
            objectiveRF,
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            0.01,
            200);
    // vi.planFromState(this.initialState);

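    // Finite-horizon sparse sampling value functions at the same depth: one under the learned
    // reward (request.getRf()) and one under the objective reward.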
    SparseSampling ssLearned =
        new SparseSampling(
            this.domain,
            request.getRf(),
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            depth,
            -1);
    SparseSampling ssObjective =
        new SparseSampling(
            this.domain,
            objectiveRF,
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            depth,
            -1);

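    // Wrappers that expose each reward function as a value function (state -> reward) so the
    // reward landscapes themselves could be visualized; the GUI below instead plots the
    // sparse-sampling value function under the objective reward.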
    StateRewardFunctionValue objectiveRFVis =
        new StateRewardFunctionValue(this.domain, objectiveRF);
    StateRewardFunctionValue learnedRFVis = new StateRewardFunctionValue(this.domain, rf);

    List<State> allStates =
        StateReachability.getReachableStates(
            this.initialState, (SADomain) this.domain, new DiscreteStateHashFactory());

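    // Visualize the objective-reward value function over all reachable states as a grayscale
    // gradient from black (lowest value) to white (highest value).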
    ValueFunctionVisualizerGUI gui =
        GridWorldDomain.getGridWorldValueFunctionVisualization(allStates, ssObjective, null);
    StateValuePainter2D vpainter = (StateValuePainter2D) gui.getSvp();
    vpainter.toggleValueStringRendering(false);
    LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
    colorBlend.addNextLandMark(0., Color.BLACK);
    colorBlend.addNextLandMark(1., Color.WHITE);
    vpainter.setColorBlend(colorBlend);
    gui.initGUI();
  }
Example #2
  public static void main(String[] args) {

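    // Mountain car domain with the classic terminal function (reaching the right hill) and a
    // goal-based reward: +100 at the goal, -1 per step otherwise.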
    MountainCar mc = new MountainCar();
    Domain domain = mc.generateDomain();
    TerminalFunction tf = new MountainCar.ClassicMCTF();
    RewardFunction rf = new GoalBasedRF(tf, 100., -1.);

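    // Grid the continuous state space (50 samples per dimension) to produce the sample states
    // that fitted value iteration will train on.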
    State valleyState = mc.getCleanState(domain);
    StateGridder sg = new StateGridder();
    int samplesPerDimension = 50;
    sg.gridEntireDomainSpace(domain, samplesPerDimension);
    List<State> samples = sg.gridInputState(valleyState);

    // make sure the gridded samples include at least one terminal (goal) state
    int terminalStates = 0;
    for (State s : samples) {
      if (tf.isTerminal(s)) {
        terminalStates++;
      }
    }

    if (terminalStates == 0) {
      throw new RuntimeException("Did not find termainal state in gridding");
    } else {
      System.out.println("found " + terminalStates + " terminal states");
    }

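    // Fitted value iteration: values at the sampled states are generalized with a
    // k-nearest-neighbor regressor (k = 4) trained through Weka on the agent object's
    // concatenated attribute vector; -1 for the transition sample count (i.e., use the full
    // transition model), a 0.01 convergence threshold, and at most 150 iterations.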
    FittedVI fvi =
        new FittedVI(
            domain,
            rf,
            tf,
            0.99,
            WekaVFATrainer.getKNNTrainer(
                new ConcatenatedObjectFeatureVectorGenerator(false, MountainCar.CLASSAGENT), 4),
            samples,
            -1,
            0.01,
            150);

    /*FittedVI fvi = new FittedVI(domain, rf, tf, 0.99,
    WekaVFATrainer.getKNNTrainer(new ConcatenatedObjectFeatureVectorGenerator(false, MountainCar.CLASSAGENT), 4),
    samples, -1, 0.01, 150);*/
    // fvi.setPlanningDepth(4);

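    // Use a one-step lookahead for both the VI backups and control, then run fitted VI to
    // convergence or the iteration cap.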
    fvi.setPlanningAndControlDepth(1);
    fvi.runVI();

    System.out.println("Starting policy eval");
    // fvi.setControlDepth(4);
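    // Roll out a greedy policy over the fitted value function from the valley state for at
    // most 500 steps and record the resulting episode.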
    Policy p = new GreedyDeterministicQPolicy(fvi);
    // MountainCar.setAgent(valleyState, -1.1, 0.);
    EpisodeAnalysis ea = p.evaluateBehavior(valleyState, rf, tf, 500);
    List<EpisodeAnalysis> eas = new ArrayList<EpisodeAnalysis>();
    eas.add(ea);

    System.out.println("Episode size: " + ea.maxTimeStep());

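    // Replay the recorded episode in the mountain car visualizer.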
    EpisodeSequenceVisualizer evis =
        new EpisodeSequenceVisualizer(MountainCarVisualizer.getVisualizer(mc), domain, eas);
    evis.initGUI();

    System.out.println("Starting value funciton vis");

    // fvi.setControlDepth(1);

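    // Visualize the fitted value function over the sampled states, using the agent's position
    // (x) and velocity (v) attributes as the plot axes.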
    ValueFunctionVisualizerGUI gui =
        ValueFunctionVisualizerGUI.createGridWorldBasedValueFunctionVisualizerGUI(
            samples,
            fvi,
            null,
            MountainCar.CLASSAGENT,
            MountainCar.ATTX,
            MountainCar.ATTV,
            MountainCar.ACTIONFORWARD,
            MountainCar.ACTIONBACKWARDS,
            MountainCar.ACTIONCOAST,
            MountainCar.ACTIONBACKWARDS);

    StateValuePainter2D sv2 = (StateValuePainter2D) gui.getSvp();
    sv2.toggleValueStringRendering(false);
    LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
    colorBlend.addNextLandMark(0., Color.BLACK);
    colorBlend.addNextLandMark(1., Color.WHITE);
    sv2.setColorBlend(colorBlend);
    sv2.setNumXCells(samplesPerDimension - 1);
    sv2.setNumYCells(samplesPerDimension - 1);

    gui.initGUI();
  }