Code Example #1
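Generates the expert demonstrations: value iteration plans against a hand-coded linear reward over the puddle-map features, and a greedy policy is rolled out from several start states and saved to the expert trajectory directory.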
  public void generateExpertTrajectories() {

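    // Feature vector generator over the puddle map and the hand-coded "true" linear reward
    // defined over those features; episodes terminate at the goal cell (20, 20).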
    PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    rf.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});
    TerminalFunction tf = new GridWorldTerminalFunction(20, 20);

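    // Plan with value iteration under the true reward (discount 0.99, max delta 0.01, up to 500 iterations).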
    ValueIteration vi =
        new ValueIteration(this.domain, rf, tf, 0.99, new DiscreteStateHashFactory(), 0.01, 500);
    vi.planFromState(this.initialState);

    System.out.println("Rolling out episodes");
    Policy p = new GreedyQPolicy(vi);

    EpisodeAnalysis e1 = p.evaluateBehavior(this.initialState, rf, tf, 100);

    State s2 = this.initialState.copy();
    GridWorldDomain.setAgent(s2, 3, 14);
    EpisodeAnalysis e2 = p.evaluateBehavior(s2, rf, tf, 100);

    State s3 = this.initialState.copy();
    GridWorldDomain.setAgent(s3, 11, 27);
    EpisodeAnalysis e3 = p.evaluateBehavior(s3, rf, tf, 100);

    State s4 = this.initialState.copy();
    GridWorldDomain.setAgent(s4, 28, 6);
    EpisodeAnalysis e4 = p.evaluateBehavior(s4, rf, tf, 100);

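    // Save the rolled-out episodes to the expert trajectory directory for later use as demonstrations.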
    e1.writeToFile(this.expertDir + "/ex1", this.sp);
    e2.writeToFile(this.expertDir + "/ex2", this.sp);
    e3.writeToFile(this.expertDir + "/ex3", this.sp);
    e4.writeToFile(this.expertDir + "/ex4", this.sp);
  }
Code Example #2
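A behavioral-cloning baseline: a Weka classifier is trained on the expert episodes to predict the expert's actions, and the resulting policy is evaluated from new start states and against a relocated goal.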
  public void runSupervised() {

    MyTimer timer = new MyTimer();
    timer.start();

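    // Feature generators over the puddle map, the expert episodes parsed from disk, and the
    // hand-coded true reward, which is used only to score the cloned policy's rollouts.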
    PuddleMapFV agentfv = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    PuddleMapFVComponent agentCompFV = new PuddleMapFVComponent(this.puddleMap, 5, 20, 20);
    StateToFeatureVectorGenerator svar =
        new ConcatenatedObjectFeatureVectorGenerator(false, GridWorldDomain.CLASSAGENT);
    java.util.List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(agentfv, agentfv.getDim());
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});

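    // Train a J48 decision tree to imitate the expert's action choices; the commented-out
    // alternative uses logistic regression over the agent's raw attribute values instead.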
    WekaPolicy p = new WekaPolicy(agentfv, new J48(), this.domain.getActions(), eas);
    // WekaPolicy p = new WekaPolicy(svar, new Logistic(), this.domain.getActions(), eas);

    timer.stop();

    System.out.println("Training Time: " + timer.getTime());

    String baseName = "Svar";

    GridWorldTerminalFunction tf = new GridWorldTerminalFunction(20, 20);

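    // Roll out the cloned policy from new start states, scoring each episode with the true reward.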
    State simple = this.initialState.copy();
    GridWorldDomain.setAgent(simple, 18, 0);
    EpisodeAnalysis trainedEp1 = p.evaluateBehavior(simple, objectiveRF, tf, 200);
    trainedEp1.writeToFile(trainedDir + "/j48" + baseName + "EpSimple", this.sp);

    State hardAgent = this.initialState.copy();
    GridWorldDomain.setAgent(hardAgent, 0, 9);
    EpisodeAnalysis trainedEp2 = p.evaluateBehavior(hardAgent, objectiveRF, tf, 200);
    trainedEp2.writeToFile(trainedDir + "/j48" + baseName + "EpHardAgent", this.sp);

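    // Relocate the goal to (12, 14), updating the terminal function, the goal marker in the
    // puddle map, and the feature generators, to test generalization to a new goal.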
    int ngx = 12;
    int ngy = 14;
    tf = new GridWorldTerminalFunction(ngx, ngy);
    this.puddleMap[ngx][ngy] = 1;
    this.puddleMap[20][20] = 0;
    agentfv.setGoal(ngx, ngy);
    agentCompFV.setGoal(ngx, ngy);
    State hardGoal = this.initialState.copy();
    GridWorldDomain.setAgent(hardGoal, 0, 0);

    EpisodeAnalysis trainedEp3 = p.evaluateBehavior(hardGoal, objectiveRF, tf, 200);
    trainedEp3.writeToFile(trainedDir + "/j48" + baseName + "EpHardGoal", this.sp);

    new EpisodeSequenceVisualizer(this.v, this.domain, this.sp, this.trainedDir);
  }
Code Example #3
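Runs MLIRL on the expert episodes using a differentiable sparse-sampling planner, then renders a grayscale value-function visualization of the grid world (the true-reward value function as written; a planner under the learned reward is built alongside it).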
  public void visualizeFunctions() {

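    // rf is the reward function whose parameters MLIRL estimates; objectiveRF is the hand-coded
    // true reward, kept for comparison.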
    PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});

    java.util.List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

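    // Differentiable sparse sampling supplies the Boltzmann value estimates that MLIRL
    // differentiates through: horizon 6, -1 transition samples (full transition dynamics),
    // Boltzmann temperature beta = 10.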
    int depth = 6;
    double beta = 10;
    DifferentiableSparseSampling dss =
        new DifferentiableSparseSampling(
            domain,
            rf,
            new NullTermination(),
            0.99,
            new NameDependentStateHashFactory(),
            depth,
            -1,
            beta);
    // DifferentiableZeroStepPlanner dss = new DifferentiableZeroStepPlanner(domain, rf);
    dss.toggleDebugPrinting(false);

    MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rf);
    request.setBoltzmannBeta(beta);

    // MLIRL irl = new MLIRL(request, 0.001, 0.01, 10); //use this for only the given features
    MLIRL irl = new MLIRL(request, 0.00001, 0.01, 10);
    // MLIRL irl = new MLIRL(request, 0.0001, 0.01, 10);

    irl.performIRL();

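    // Planners under the true and the learned reward, plus reward-value wrappers, so any of
    // them can be handed to the value function visualizer below.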
    TerminalFunction tf = new GridWorldTerminalFunction(20, 20);

    ValueIteration vi =
        new ValueIteration(
            this.domain,
            objectiveRF,
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            0.01,
            200);
    // vi.planFromState(this.initialState);

    SparseSampling ssLearned =
        new SparseSampling(
            this.domain,
            request.getRf(),
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            depth,
            -1);
    SparseSampling ssObjective =
        new SparseSampling(
            this.domain,
            objectiveRF,
            new NullTermination(),
            0.99,
            new DiscreteStateHashFactory(),
            depth,
            -1);

    StateRewardFunctionValue objectiveRFVis =
        new StateRewardFunctionValue(this.domain, objectiveRF);
    StateRewardFunctionValue learnedRFVis = new StateRewardFunctionValue(this.domain, rf);

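    // Enumerate all reachable states and render the value function as a grayscale map
    // (black = lowest value, white = highest); ssObjective can be swapped for ssLearned or a
    // reward wrapper above to inspect the other functions.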
    List<State> allStates =
        StateReachability.getReachableStates(
            this.initialState, (SADomain) this.domain, new DiscreteStateHashFactory());

    ValueFunctionVisualizerGUI gui =
        GridWorldDomain.getGridWorldValueFunctionVisualization(allStates, ssObjective, null);
    StateValuePainter2D vpainter = (StateValuePainter2D) gui.getSvp();
    vpainter.toggleValueStringRendering(false);
    LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
    colorBlend.addNextLandMark(0., Color.BLACK);
    colorBlend.addNextLandMark(1., Color.WHITE);
    vpainter.setColorBlend(colorBlend);
    gui.initGUI();
  }
Code Example #4
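Runs MLIRL with a state-action linear reward defined over the learner's features and a zero-step (myopic) differentiable planner; the commented-out block at the end mirrors the evaluation rollouts from the supervised example.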
  public void runIRL() {

    MyTimer timer = new MyTimer();
    timer.start();

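    // ofvgen provides the features of the true reward; fvgen provides the features the learned
    // reward is expressed in (the commented-out lines are alternative feature generators).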
    PuddleMapFV ofvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);

    // PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
    PuddleMapFVComponent fvgen = new PuddleMapFVComponent(this.puddleMap, 5, 20, 20);
    // StateToFeatureVectorGenerator fvgen = new ConcatenatedObjectFeatureVectorGenerator(false,
    // GridWorldDomain.CLASSAGENT);
    // PuddleMapExactFV fvgen = new PuddleMapExactFV(this.puddleMap, 5);
    // LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
    LinearStateActionDifferentiableRF rf =
        new LinearStateActionDifferentiableRF(
            fvgen,
            fvgen.getDim(),
            new GroundedAction(this.domain.getAction(GridWorldDomain.ACTIONNORTH), ""),
            new GroundedAction(this.domain.getAction(GridWorldDomain.ACTIONSOUTH), ""),
            new GroundedAction(this.domain.getAction(GridWorldDomain.ACTIONEAST), ""),
            new GroundedAction(this.domain.getAction(GridWorldDomain.ACTIONWEST), ""));

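    // Hand-coded true reward; in this method it is only used by the commented-out test rollouts below.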
    LinearStateDifferentiableRF objectiveRF =
        new LinearStateDifferentiableRF(ofvgen, ofvgen.getDim());
    objectiveRF.setParameters(new double[] {1., -10, -10, 0, -10, 0, 0, 0, 0, 0});
    // objectiveRF.setParameters(new double[]{1., -10, -10, 0, -10});

    java.util.List<EpisodeAnalysis> eas =
        EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

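    // The zero-step planner scores actions by their immediate differentiable reward only (the
    // commented-out lines swap in the depth-limited DifferentiableSparseSampling planner);
    // MLIRL then runs gradient ascent on the log-likelihood of the expert episodes.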
    int depth = 6;
    double beta = 10;
    // DifferentiableSparseSampling dss = new DifferentiableSparseSampling(domain, rf, new
    // NullTermination(), 0.99, new NameDependentStateHashFactory(), depth, -1, beta);
    DifferentiableZeroStepPlanner dss = new DifferentiableZeroStepPlanner(domain, rf);
    dss.toggleDebugPrinting(false);

    MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rf);
    request.setBoltzmannBeta(beta);

    // MLIRL irl = new MLIRL(request, 0.001, 0.01, 10); //use this for only the given features
    // MLIRL irl = new MLIRL(request, 0.00001, 0.01, 10);
    MLIRL irl = new MLIRL(request, 0.0001, 0.01, 10);

    irl.performIRL();

    // System.out.println(this.getFVAndShapeString(rf.getParameters()));

    timer.stop();
    System.out.println("Training time: " + timer.getTime());

    /*//uncomment to run test examples

    String baseName = "RH45";

    Policy p = new GreedyQPolicy(dss);
    GridWorldTerminalFunction tf = new GridWorldTerminalFunction(20, 20);

    State simple = this.initialState.copy();
    GridWorldDomain.setAgent(simple, 18, 0);
    EpisodeAnalysis trainedEp1 = p.evaluateBehavior(simple, objectiveRF, tf, 200);
    trainedEp1.writeToFile(trainedDir+"/IRL" + baseName + "EpSimple", this.sp);



    State hardAgent = this.initialState.copy();
    GridWorldDomain.setAgent(hardAgent, 0, 9);
    EpisodeAnalysis trainedEp2 = p.evaluateBehavior(hardAgent, objectiveRF, tf, 200);
    trainedEp2.writeToFile(trainedDir+"/IRL" + baseName + "EpHardAgent", this.sp);



    dss.resetPlannerResults();

    int ngx = 12;
    int ngy = 14;

    tf = new GridWorldTerminalFunction(ngx, ngy);
    this.puddleMap[ngx][ngy] = 1;
    this.puddleMap[20][20] = 0;
    //fvgen.setGoal(ngx, ngy);
    State hardGoal = this.initialState.copy();
    GridWorldDomain.setAgent(hardGoal, 0, 0);



    EpisodeAnalysis trainedEp3 = p.evaluateBehavior(hardGoal, objectiveRF, tf, 200);
    trainedEp3.writeToFile(trainedDir+"/IRL" + baseName + "EpHardGoal", this.sp);



    new EpisodeSequenceVisualizer(this.v, this.domain, this.sp, this.trainedDir);
    */

  }