In [1]:
%matplotlib inline
import gym  # NOTE(review): not referenced in the visible cells; kept in case later cells use it
import matplotlib
import numpy as np
import sys
from collections import defaultdict
# Make the parent directory importable so the project-local `lib` package
# (Blackjack environment + plotting helpers) resolves. Assumes the notebook
# is run from a subdirectory of the repository root — TODO confirm.
if "../" not in sys.path:
    sys.path.append("../")
from lib.envs.blackjack import BlackjackEnv
from lib import plotting
# Apply matplotlib's 'ggplot' style sheet to all figures in this notebook.
matplotlib.style.use('ggplot')
In [2]:
# Blackjack environment (from lib.envs.blackjack) shared by every
# mc_prediction run below.
env = BlackjackEnv()
In [3]:
def mc_prediction(policy, env, num_episodes, discount_factor=1.0):
    """
    First-visit Monte Carlo prediction. Estimates the state-value function
    of a given policy by averaging sampled (discounted) returns.

    Args:
        policy: A function that maps an observation to a vector of action
            probabilities (one entry per action, summing to 1).
        env: OpenAI gym-style environment: `reset()` returns an observation
            and `step(action)` returns `(observation, reward, done, info)`.
        num_episodes: Number of episodes to sample.
        discount_factor: Gamma discount factor applied to future rewards.

    Returns:
        A dictionary that maps from state -> value.
        States are the (hashable) observations returned by the environment
        (tuples for Blackjack) and values are floats.
    """
    # Keep a running sum and count of returns per state so V can be computed
    # as an average without storing every sampled return (the book's
    # list-of-returns formulation is memory inefficient).
    returns_sum = defaultdict(float)
    returns_count = defaultdict(float)

    # The final value function
    V = defaultdict(float)

    for _ in range(num_episodes):
        # Roll out one episode as a list of (state, action, reward) tuples.
        episode = []
        state = env.reset()
        done = False
        while not done:
            probs = policy(state)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, reward))
            state = next_state

        # Discounted return from every time step, accumulated backwards:
        # G_t = r_t + gamma * G_{t+1}. A single O(n) pass instead of
        # re-summing the tail of the episode for each state.
        returns = [0.0] * len(episode)
        G = 0.0
        for t in range(len(episode) - 1, -1, -1):
            G = episode[t][2] + discount_factor * G
            returns[t] = G

        # First-visit MC: only the first occurrence of each state in the
        # episode contributes a return sample.
        first_visit = {}
        for t, (s, _, _) in enumerate(episode):
            first_visit.setdefault(s, t)

        for s, t in first_visit.items():
            returns_sum[s] += returns[t]
            returns_count[s] += 1.0
            V[s] = returns_sum[s] / returns_count[s]

    return V
In [4]:
def sample_policy(observation):
    """
    A fixed policy that sticks if the player score is >= 20 and hits otherwise.

    Args:
        observation: A `(score, dealer_score, usable_ace)` tuple; only the
            player's `score` is used.

    Returns:
        A length-2 numpy array of action probabilities. Index 0 is taken to be
        "stick" and index 1 "hit" — assumes BlackjackEnv uses that action
        encoding; confirm in lib.envs.blackjack.
    """
    score, dealer_score, usable_ace = observation
    # Deterministic policy: all probability mass on a single action.
    return np.array([1.0, 0.0]) if score >= 20 else np.array([0.0, 1.0])
In [5]:
# Seed numpy's global RNG so the policy's action sampling is reproducible
# under Restart & Run All. NOTE(review): BlackjackEnv may draw cards from its
# own RNG; seed it as well if fully reproducible plots are required.
np.random.seed(0)

# Estimate V for the fixed policy with a small and a large number of sampled
# episodes so the two estimates can be compared side by side.
V_10k = mc_prediction(sample_policy, env, num_episodes=10000)
plotting.plot_value_function(V_10k, title="10,000 Episodes")
V_500k = mc_prediction(sample_policy, env, num_episodes=500000)
plotting.plot_value_function(V_500k, title="500,000 Episodes")
In [ ]: