In [2]:
%load_ext autoreload
%autoreload 2
In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
In [26]:
import algos
import features
import parametric
import policy
import chicken
from agents import OnPolicyAgent
from rlbench import *
In [12]:
# define the experiment
num_states = 8
num_features = 8
# set up environment
env = chicken.Chicken(num_states)
# set up policy
pol_pi = policy.FixedPolicy({s: {0: 1} if s < 4 else {0: 0.5, 1: 0.5} for s in env.states})
# set feature mapping
# phi = features.RandomBinary(num_features, num_features // 2, random_seed=101011)
phi = features.Int2Unary(num_states)
# run the algorithms for enough time to get reliable convergence
num_steps = 100000
# the TD(1) solution should minimize the mean-squared error
update_params = {
'gm': 0.9,
'gm_p': 0.9,
'lm': 0.0,
}
lstd_1 = OnPolicyAgent(algos.LSTD(phi.length), pol_pi, phi, update_params)
run_episode(lstd_1, env, num_steps)
mse_values = lstd_1.get_values(env.states)
# the TD(0) solution should minimize the MSPBE
update_params = {
'gm': 0.9,
'gm_p': 0.9,
'lm': 0.0,
}
lstd_0 = OnPolicyAgent(algos.LSTD(phi.length), pol_pi, phi, update_params)
run_episode(lstd_0, env, num_steps)
mspbe_values = lstd_0.get_values(env.states)
In [30]:
# Plot the states against their target values
xvals = list(sorted(env.states))
y_mse = [mse_values[s] for s in xvals]
y_mspbe = [mspbe_values[s] for s in xvals]
# Mean-square error optimal values
plt.bar(xvals, y_mse)
plt.show()
# MSPBE optimal values
plt.bar(xvals, y_mspbe)
plt.show()
In [16]:
# Display the registry of available learning algorithms (name -> class);
# the comparison loop below iterates over this mapping.
algos.algo_registry
Out[16]:
These algorithms are given to the `OnPolicyAgent`, which takes care of the function approximation and manages the parameters passed to the learning algorithm.
In [39]:
# set up algorithm parameters
update_params = {
'alpha': 0.01,
'beta': 0.001,
'gm': 0.9,
'gm_p': 0.9,
'lm': 0.0,
'lm_p': 0.0,
'interest': 1.0,
}
# Run all available algorithms
max_steps = 10000
for name, alg in algos.algo_registry.items():
# Set up the agent, run the experiment, get state-values
agent = OnPolicyAgent(alg(phi.length), pol_pi, phi, update_params)
mse_lst = run_errors(agent, env, max_steps, mse_values)
mspbe_lst = run_errors(agent, env, max_steps, mspbe_values)
# Plot the errors
xdata = np.arange(max_steps)
# plt.plot(xdata, mse_lst)
# plt.plot(xdata, mspbe_lst)
plt.plot(xdata, np.log(mspbe_lst))
# Add information to the graph
plt.title(name)
plt.xlabel('Timestep')
plt.ylabel('Error')
plt.show()
In [36]:
states
In [38]:
# Square root of the mean of the MSPBE-optimal state values computed above.
# NOTE(review): presumably intended as an RMS-style summary statistic of the
# TD(0) value function — confirm this is the quantity actually wanted.
np.sqrt(np.mean(y_mspbe))
Out[38]:
In [ ]: