In [ ]:
# import the classes we need
from SafeRLBench.envs import LinearCar
from SafeRLBench.policy import LinearPolicy
from SafeRLBench.algo import PolicyGradient
In [ ]:
# Create a `LinearCar` environment using only its default arguments.
linear_car = LinearCar()
# Create a linear policy mapping R^2 to R; the arguments (2, 1) are taken to
# be the input and output dimensions of that map.
policy = LinearPolicy(2, 1)
# Set the policy parameters explicitly before the optimization starts.
# NOTE(review): three values for a 2 -> 1 linear map, presumably two weights
# plus a bias term -- confirm against `LinearPolicy`.
policy.parameters = [-1, -1, 1]
# Plug the environment and the policy into the policy gradient algorithm.
# NOTE(review): 'central_fd' presumably selects a central finite-difference
# gradient estimator -- confirm against `PolicyGradient`.
optimizer = PolicyGradient(linear_car, policy, estimator='central_fd')
# Run the optimization; the results are read from `optimizer.monitor` later on.
optimizer.optimize()
Let's take a look at what happened during the run. To do so, we can access the optimizer's monitor and generate some plots.
In [ ]:
import matplotlib.pyplot as plt

# Rewards recorded by the optimizer's monitor during the run.
y = optimizer.monitor.rewards
# Plot each recorded reward against its position in the record.
iterations = range(len(y))
plt.plot(iterations, y)
plt.show()
In [ ]:
# import the configuration object
from SafeRLBench import config
In [ ]:
# Add a stream handler to the library logger through the global config object.
# NOTE(review): presumably this makes log records show up in the notebook
# output -- confirm which stream the handler writes to.
config.logger_add_stream_handler()
# Set the logger level to the `DEBUG` constant exposed by the config object.
config.logger_set_level(config.DEBUG)
# Raise the monitor verbosity to 2.
# NOTE(review): the meaning of the individual verbosity levels is not visible
# here -- see the library documentation.
config.monitor_set_verbosity(2)
After changing these values, re-run the cell that invokes `optimizer.optimize()` to see how the output changes.
In [ ]:
# import the best performance measure
from SafeRLBench.measure import BestPerformance
# import the Bench and BenchConfig
from SafeRLBench import Bench, BenchConfig
In [ ]:
# Environment configuration: a nested list of (environment class, keyword
# arguments) tuples. Here a single `LinearCar` with a horizon of 100.
envs = [[(LinearCar, {'horizon': 100})]]

# Algorithm configuration: a nested list of (algorithm class, list of keyword
# argument dictionaries) tuples. One dictionary is generated per value of
# `var`, each with its own freshly constructed `LinearPolicy`.
algs = [[
    (PolicyGradient, [
        {
            'policy': LinearPolicy(2, 1, par=[-1, -1, 1]),
            'estimator': 'central_fd',
            'var': var,
        }
        for var in [1, 1.5, 2, 2.5]
    ])
]]

# Instantiate the benchmark configuration under its own name. Binding it to
# `config` would shadow the library configuration object imported from
# `SafeRLBench` in an earlier cell, and the `jobs_set` call below is meant
# for that library object, not for the `BenchConfig` instance.
bench_config = BenchConfig(algs, envs)

# Instantiate the bench with the configuration and the measure to evaluate.
bench = Bench(bench_config, BestPerformance())

# Configure the library to use 4 jobs so the benchmark runs in parallel.
# Requires the cell containing `from SafeRLBench import config` to have run.
config.jobs_set(4)
In [ ]:
# Invoke the bench object created above.
# NOTE(review): presumably this executes every configured algorithm and
# environment combination and then evaluates the measures -- confirm.
bench()
In [ ]:
# Display the first entry of the bench's measures; as the last expression of
# the cell, its representation is shown by the notebook.
# NOTE(review): presumably this is the `BestPerformance` instance handed to
# `Bench` -- confirm.
bench.measures[0]
In [ ]:
import matplotlib.pyplot as plt

# Take the first entry of the first measure's result and fetch the monitor
# that belongs to its algorithm.
# NOTE(review): each result entry is presumably a tuple whose first element
# is the run object -- confirm.
best_run = bench.measures[0].result[0][0]
monitor = best_run.get_alg_monitor()

# Select the trace stored at the index of the highest recorded reward.
rewards = monitor.rewards
best_trace = monitor.traces[rewards.index(max(rewards))]

# From every trace entry take element [1][0] and plot it over the step index.
# NOTE(review): presumably the first state coordinate of each step -- confirm
# against the trace layout.
y = [step[1][0] for step in best_trace]
x = range(len(y))
plt.plot(x, y)
plt.show()
In [ ]: