コード例 #1
0
  /**
   * Runs {@code MLIRL} on the expert episodes parsed from {@code this.expertDir}, then opens a
   * value-function visualization GUI over every state reachable from {@code this.initialState}.
   *
   * <p>The GUI is driven by a depth-limited {@code SparseSampling} planner over the
   * hand-specified {@code objectiveRF}. Values are drawn on a black (0) to white (1) color blend
   * with the numeric value labels turned off.
   *
   * <p>NOTE(review): the reward function fitted by the IRL run ({@code rf}) is not used by the
   * visualization below. Confirm whether the learned reward (e.g. a planner over
   * {@code request.getRf()}) was meant to be visualized instead of the objective one.
   */
  public void visualizeFunctions() {
    final int depth = 6;
    final double beta = 10;
    final double gamma = 0.99;

    PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    // rf is the reward function handed to MLIRL; objectiveRF is a fixed, hand-specified reward
    // over the same features.
    LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    // NOTE(review): this array presumably must have fvgen.getDim() entries; confirm for
    // PuddleMapFV(puddleMap, 5, 20, 20).
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});

    java.util.List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

    DifferentiableSparseSampling dss =
        new DifferentiableSparseSampling(
            domain,
            rf,
            new NullTermination(),
            gamma,
            new NameDependentStateHashFactory(),
            depth,
            -1,
            beta);
    dss.toggleDebugPrinting(false);

    MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rf);
    request.setBoltzmannBeta(beta);

    MLIRL irl = new MLIRL(request, 0.00001, 0.01, 10);
    irl.performIRL();

    SparseSampling ssObjective =
        new SparseSampling(
            this.domain,
            objectiveRF,
            new NullTermination(),
            gamma,
            new DiscreteStateHashFactory(),
            depth,
            -1);
    // The GUI is handed every reachable state together with this planner. Silence its debug
    // output the same way the IRL planner above is configured.
    ssObjective.toggleDebugPrinting(false);

    List<State> allStates =
        StateReachability.getReachableStates(
            this.initialState, (SADomain) this.domain, new DiscreteStateHashFactory());

    ValueFunctionVisualizerGUI gui =
        GridWorldDomain.getGridWorldValueFunctionVisualization(allStates, ssObjective, null);
    StateValuePainter2D vpainter = (StateValuePainter2D) gui.getSvp();
    vpainter.toggleValueStringRendering(false);

    // Map value 0 to black and value 1 to white.
    LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
    colorBlend.addNextLandMark(0., Color.BLACK);
    colorBlend.addNextLandMark(1., Color.WHITE);
    vpainter.setColorBlend(colorBlend);

    gui.initGUI();
  }
コード例 #2
0
  /**
   * Runs {@code MLIRL} on the expert episodes parsed from {@code this.expertDir}, fitting a
   * {@code LinearDiffRFVInit} that serves both as the planner's reward function and as its
   * leaf-node value estimate.
   *
   * <p>It then rolls out a greedy policy over a {@code SparseSampling} planner built from the
   * fitted function. Each rollout is scored with the hand-specified {@code objectiveRF}, and the
   * episodes are written under {@code trainedDir} with the prefix {@code "IRL" + baseName}:
   *
   * <ul>
   *   <li>{@code EpSimple}: agent starts at (18, 0), terminal location (20, 20);
   *   <li>{@code EpHardAgent}: agent starts at (0, 9), terminal location (20, 20);
   *   <li>{@code EpHardGoal}: agent starts at (0, 0), terminal location moved to (12, 14).
   * </ul>
   *
   * <p>Finally opens an {@code EpisodeSequenceVisualizer} over {@code this.trainedDir}.
   *
   * <p>Side effect: {@code this.puddleMap} is modified in place to relocate the goal marker
   * and is not restored afterwards.
   */
  public void runVFRFIRL() {
    final int depth = 4;
    final double beta = 10;
    final double gamma = 0.99;
    final int maxSteps = 200;
    final int goalX = 20;
    final int goalY = 20;

    PuddleMapExactFV fvgen = new PuddleMapExactFV(this.puddleMap, 5);
    PuddleMapDistOnlyFV vfFvGen = new PuddleMapDistOnlyFV(this.puddleMap, 5, 20, 20);
    GridWorldTerminalFunction tf = new GridWorldTerminalFunction(goalX, goalY);
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10});

    // rfvf is the object MLIRL fits. It is also installed as the leaf-node value function of
    // both planners below.
    LinearDiffRFVInit rfvf = new LinearDiffRFVInit(fvgen, vfFvGen, 5, 5);

    java.util.List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

    DifferentiableSparseSampling dss =
        new DifferentiableSparseSampling(
            domain,
            rfvf,
            new NullTermination(),
            gamma,
            new NameDependentStateHashFactory(),
            depth,
            -1,
            beta);
    dss.setValueForLeafNodes(rfvf);
    dss.toggleDebugPrinting(false);

    MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rfvf);
    request.setBoltzmannBeta(beta);

    MLIRL irl = new MLIRL(request, 0.001, 0.01, 10);
    irl.performIRL();

    // Encode the actual planning depth in the output label so the file names cannot drift
    // from the configuration (it was hard-coded as "SSRFVFD3" while depth was 4).
    String baseName = "SSRFVFD" + depth;
    String filePrefix = trainedDir + "/IRL" + baseName;

    SparseSampling ss =
        new SparseSampling(
            domain,
            rfvf,
            new NullTermination(),
            gamma,
            new NameDependentStateHashFactory(),
            depth,
            -1);
    ss.toggleDebugPrinting(false);
    ss.setValueForLeafNodes(rfvf);

    Policy p = new GreedyQPolicy(ss);

    recordRollout(p, 18, 0, objectiveRF, tf, maxSteps, filePrefix + "EpSimple");
    recordRollout(p, 0, 9, objectiveRF, tf, maxSteps, filePrefix + "EpHardAgent");

    // Cached planning results depend on the goal, so drop them before it moves.
    dss.resetPlannerResults();
    ss.resetPlannerResults();

    int ngx = 12;
    int ngy = 14;
    tf = new GridWorldTerminalFunction(ngx, ngy);
    // NOTE(review): this permanently mutates the shared map, so later experiments on this
    // instance see the relocated goal. The value 1 presumably marks the goal cell; confirm
    // against the map encoding.
    this.puddleMap[ngx][ngy] = 1;
    this.puddleMap[goalX][goalY] = 0;
    vfFvGen.setGoal(ngx, ngy);

    recordRollout(p, 0, 0, objectiveRF, tf, maxSteps, filePrefix + "EpHardGoal");

    new EpisodeSequenceVisualizer(this.v, this.domain, this.sp, this.trainedDir);
  }

  /**
   * Copies {@code this.initialState}, places the agent at ({@code agentX}, {@code agentY}), and
   * evaluates {@code policy} from there via
   * {@code policy.evaluateBehavior(start, rf, tf, maxSteps)}. The resulting episode is written
   * to {@code fileName} using {@code this.sp}.
   */
  private void recordRollout(
      Policy policy,
      int agentX,
      int agentY,
      LinearStateDifferentiableRF rf,
      GridWorldTerminalFunction tf,
      int maxSteps,
      String fileName) {
    State start = this.initialState.copy();
    GridWorldDomain.setAgent(start, agentX, agentY);
    EpisodeAnalysis episode = policy.evaluateBehavior(start, rf, tf, maxSteps);
    episode.writeToFile(fileName, this.sp);
  }