public void visualizeFunctions() {

	// Feature vector generator over the puddle map; both the learned and the
	// hand-set objective reward functions are linear in these features
	PuddleMapFV fvgen = new PuddleMapFV(this.puddleMap, 5, 20, 20);
	LinearStateDifferentiableRF rf = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
	LinearStateDifferentiableRF objectiveRF = new LinearStateDifferentiableRF(fvgen, fvgen.getDim());
	objectiveRF.setParameters(new double[]{1., -10, -10, 0, -10, 0, 0, 0, 0, 0});

	// Expert demonstrations to fit with MLIRL
	List<EpisodeAnalysis> eas = EpisodeAnalysis.parseFilesIntoEAList(this.expertDir, domain, this.sp);

	int depth = 6;
	double beta = 10;
	DifferentiableSparseSampling dss = new DifferentiableSparseSampling(domain, rf, new NullTermination(), 0.99,
			new NameDependentStateHashFactory(), depth, -1, beta);
	//DifferentiableZeroStepPlanner dss = new DifferentiableZeroStepPlanner(domain, rf);
	dss.toggleDebugPrinting(false);

	MLIRLRequest request = new MLIRLRequest(domain, dss, eas, rf);
	request.setBoltzmannBeta(beta);

	//MLIRL irl = new MLIRL(request, 0.001, 0.01, 10); //use this for only the given features
	MLIRL irl = new MLIRL(request, 0.00001, 0.01, 10);
	//MLIRL irl = new MLIRL(request, 0.0001, 0.01, 10);
	irl.performIRL();

	TerminalFunction tf = new GridWorldTerminalFunction(20, 20);
	ValueIteration vi = new ValueIteration(this.domain, objectiveRF, new NullTermination(), 0.99,
			new DiscreteStateHashFactory(), 0.01, 200);
	//vi.planFromState(this.initialState);

	// Sparse sampling value estimates under the learned RF (see the sketch
	// after this method) and under the objective RF
	SparseSampling ssLearned = new SparseSampling(this.domain, request.getRf(), new NullTermination(), 0.99,
			new DiscreteStateHashFactory(), depth, -1);
	SparseSampling ssObjective = new SparseSampling(this.domain, objectiveRF, new NullTermination(), 0.99,
			new DiscreteStateHashFactory(), depth, -1);

	StateRewardFunctionValue objectiveRFVis = new StateRewardFunctionValue(this.domain, objectiveRF);
	StateRewardFunctionValue learnedRFVis = new StateRewardFunctionValue(this.domain, rf);

	List<State> allStates = StateReachability.getReachableStates(this.initialState, (SADomain)this.domain,
			new DiscreteStateHashFactory());

	// Render the objective value function as a grayscale grid
	ValueFunctionVisualizerGUI gui = GridWorldDomain.getGridWorldValueFunctionVisualization(allStates, ssObjective, null);
	StateValuePainter2D vpainter = (StateValuePainter2D)gui.getSvp();
	vpainter.toggleValueStringRendering(false);
	LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
	colorBlend.addNextLandMark(0., Color.BLACK);
	colorBlend.addNextLandMark(1., Color.WHITE);
	vpainter.setColorBlend(colorBlend);
	gui.initGUI();

}
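
/*
 * A minimal follow-up sketch, not part of the original method: ssLearned and
 * learnedRFVis are constructed in visualizeFunctions() but never rendered.
 * One way to inspect the IRL result is to open a second GUI for the learned
 * value function, reusing exactly the calls made for the objective one above.
 * The method name and its parameter list are hypothetical.
 */
public void visualizeLearnedFunction(List<State> allStates, SparseSampling ssLearned) {
	ValueFunctionVisualizerGUI gui = GridWorldDomain.getGridWorldValueFunctionVisualization(allStates, ssLearned, null);
	StateValuePainter2D vpainter = (StateValuePainter2D)gui.getSvp();
	vpainter.toggleValueStringRendering(false);
	LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
	colorBlend.addNextLandMark(0., Color.BLACK);
	colorBlend.addNextLandMark(1., Color.WHITE);
	vpainter.setColorBlend(colorBlend);
	gui.initGUI();
}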
public static void main(String[] args) {

	MountainCar mc = new MountainCar();
	Domain domain = mc.generateDomain();
	TerminalFunction tf = new MountainCar.ClassicMCTF();
	RewardFunction rf = new GoalBasedRF(tf, 100., -1.);

	State valleyState = mc.getCleanState(domain);

	// Grid the continuous state space to get sample states for fitted VI
	StateGridder sg = new StateGridder();
	int samplesPerDimension = 50;
	sg.gridEntireDomainSpace(domain, samplesPerDimension);
	List<State> samples = sg.gridInputState(valleyState);

	// Make sure the gridding captured at least one goal (terminal) state
	int terminalStates = 0;
	for(State s : samples){
		if(tf.isTerminal(s)){
			terminalStates++;
		}
	}
	if(terminalStates == 0){
		throw new RuntimeException("Did not find terminal state in gridding");
	}
	else{
		System.out.println("found " + terminalStates + " terminal states");
	}

	FittedVI fvi = new FittedVI(domain, rf, tf, 0.99,
			WekaVFATrainer.getKNNTrainer(new ConcatenatedObjectFeatureVectorGenerator(false, MountainCar.CLASSAGENT), 4),
			samples, -1, 0.01, 150);

	//fvi.setPlanningDepth(4);
	fvi.setPlanningAndControlDepth(1);
	fvi.runVI();

	System.out.println("Starting policy eval");
	//fvi.setControlDepth(4);
	Policy p = new GreedyDeterministicQPolicy(fvi);
	//MountainCar.setAgent(valleyState, -1.1, 0.);
	EpisodeAnalysis ea = p.evaluateBehavior(valleyState, rf, tf, 500);
	List<EpisodeAnalysis> eas = new ArrayList<EpisodeAnalysis>();
	eas.add(ea);

	System.out.println("Episode size: " + ea.maxTimeStep());

	EpisodeSequenceVisualizer evis = new EpisodeSequenceVisualizer(MountainCarVisualizer.getVisualizer(mc), domain, eas);
	evis.initGUI();

	System.out.println("Starting value function vis");

	//fvi.setControlDepth(1);
	// MountainCar has only three actions, so ACTIONBACKWARDS fills the fourth direction slot
	ValueFunctionVisualizerGUI gui = ValueFunctionVisualizerGUI.createGridWorldBasedValueFunctionVisualizerGUI(
			samples, fvi, null,
			MountainCar.CLASSAGENT, MountainCar.ATTX, MountainCar.ATTV,
			MountainCar.ACTIONFORWARD, MountainCar.ACTIONBACKWARDS, MountainCar.ACTIONCOAST, MountainCar.ACTIONBACKWARDS);

	StateValuePainter2D sv2 = (StateValuePainter2D)gui.getSvp();
	sv2.toggleValueStringRendering(false);
	LandmarkColorBlendInterpolation colorBlend = new LandmarkColorBlendInterpolation();
	colorBlend.addNextLandMark(0., Color.BLACK);
	colorBlend.addNextLandMark(1., Color.WHITE);
	sv2.setColorBlend(colorBlend);
	sv2.setNumXCells(samplesPerDimension - 1);
	sv2.setNumYCells(samplesPerDimension - 1);

	gui.initGUI();

}
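
/*
 * A hedged sketch, not in the original: main() only evaluates the greedy
 * policy from the single valley state. This helper averages episode length
 * over a list of start states (e.g., the gridded samples), using only the
 * Policy.evaluateBehavior and EpisodeAnalysis.maxTimeStep calls already seen
 * in main. The method name is hypothetical.
 */
public static double averageEpisodeLength(Policy p, List<State> starts, RewardFunction rf, TerminalFunction tf) {
	double total = 0.;
	int n = 0;
	for(State s : starts){
		if(tf.isTerminal(s)){
			continue; // skip start states that are already at the goal
		}
		EpisodeAnalysis ea = p.evaluateBehavior(s, rf, tf, 500);
		total += ea.maxTimeStep();
		n++;
	}
	return n > 0 ? total / n : 0.;
}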