In [1]:
import numpy as np
import sys
if "../" not in sys.path:
  sys.path.append("../")

import matplotlib
%matplotlib inline
matplotlib.style.use('ggplot')

from envs.blackjack import BlackjackEnv
from utils import plotting

from mc_prediction import mc_prediction

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Environment instance shared by the mc_prediction runs further down.
env = BlackjackEnv()

In [3]:
def sample_policy(observation):
    """
    A policy that sticks if the player score is >= 20 (i.e. on 20 or 21)
    and hits otherwise.

    Args:
        observation: A 3-tuple (score, dealer_score, usable_ace). Only the
            player's score influences the decision.

    Returns:
        A length-2 numpy array of action probabilities: [1.0, 0.0] when
        score >= 20, otherwise [0.0, 1.0]. Index 0 is taken to be "stick"
        and index 1 "hit" -- assumed to match BlackjackEnv's action
        encoding; confirm against the environment.
    """
    # Unpack all three fields so a malformed observation still fails loudly;
    # the dealer's card and the usable-ace flag are deliberately ignored.
    score, _dealer_score, _usable_ace = observation
    if score >= 20:
        return np.array([1.0, 0.0])
    return np.array([0.0, 1.0])

In [4]:
# Run mc_prediction for sample_policy with 10,000 and then 500,000 episodes,
# plotting each result. The titles report episodes (the unit of
# num_episodes), not individual environment steps.
V_10k = mc_prediction(sample_policy, env, num_episodes=10000)
plotting.plot_value_function(V_10k, title="10,000 Episodes")

V_500k = mc_prediction(sample_policy, env, num_episodes=500000)
plotting.plot_value_function(V_500k, title="500,000 Episodes")