In [1]:
import sys
# Put the parent directory on the module search path so the project-local
# imports further down in this cell can be resolved (presumably that is where
# `reinforce_baseline` and `utils` live -- confirm against the repo layout).
# The membership test keeps the cell idempotent: re-running it never appends
# a duplicate entry.
if "../" not in sys.path:
    sys.path.append("../")
import gym
import torch.optim as optim
import matplotlib
%matplotlib inline
matplotlib.style.use('ggplot')
from reinforce_baseline import PolicyEstimator, ValueEstimator, reinforce_baseline
from utils import plotting
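# Automatically reload external modules before executing each cell, so edits
# to the imported project modules are picked up without restarting the kernel.
# See http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython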
%load_ext autoreload
%autoreload 2
In [2]:
# Create the Gym environment from its registered id; `env` is consumed by the
# cells that follow.
# NOTE(review): no RNG seed is set in the visible cells, so results may differ
# between runs unless seeding happens inside the imported helpers -- confirm.
env_id = 'CartPole-v0'
env = gym.make(env_id)
In [3]:
# Sizes handed to the estimators: the first entry of the observation-space
# shape, and the `n` attribute of the action space.
state_D, action_D = env.observation_space.shape[0], env.action_space.n

# Policy estimator, trained with its own Adam optimizer.
policy_estimator = PolicyEstimator(state_D, action_D, hidden_size=128)
policy_optimizer = optim.Adam(policy_estimator.parameters(), lr=0.01)

# Value estimator, trained with a separate Adam optimizer and a larger
# learning rate than the policy's.
value_estimator = ValueEstimator(state_D, hidden_size=128)
value_optimizer = optim.Adam(value_estimator.parameters(), lr=0.05)
In [4]:
# Hyperparameters for the training run, gathered in one place so they are easy
# to spot and tweak.
training_options = {
    "num_episodes": 200,
    "discount_factor": 0.99,
    "render": False,
}

# Run `reinforce_baseline` on the environment with the policy/value estimators
# and their optimizers; the returned `stats` object is plotted in the next cell.
stats = reinforce_baseline(
    env,
    policy_estimator, policy_optimizer,
    value_estimator, value_optimizer,
    **training_options
)
In [5]:
# Visualise the statistics returned by the training run. The call is left as
# the final bare expression so its return value becomes the cell's output.
# `smoothing_window=25` is presumably a rolling-average window measured in
# episodes -- confirm in utils.plotting.
plotting.plot_episode_stats(
    stats,
    smoothing_window=25,
)
Out[5]: