In [1]:
import numpy as np
from mpl_toolkits.mplot3d import axes3d

import cPickle
import matplotlib.pyplot as plt
import os
os.environ["FONTCONFIG_PATH"]="/etc/fonts"

from easy21 import *

In [2]:
def runepisode():

    #initialize the state randomly
    state = np.random.randint(low = 1, high=10, size=None),np.random.randint(low = 1, high=10, size=None) #player, dealer

    terminated = False
    while( not terminated):
        # action = policy(Qtable[state])
        action = policy(state)
        reward, successor, terminated = step(state[0],state[1],action)
        episode.append((state,action,reward))

        #counting state visits
        Nsa[state[0]-1,state[1]-1,action] += 1

        state = successor

In [3]:
def policy(state):

    Ns = Nsa[state[0]-1,state[1]-1,0] + Nsa[state[0]-1,state[1]-1,1]
    N_0 = 100
    epsilon = N_0/(N_0 + Ns)

    explore = np.random.choice([1,0],p=[epsilon, 1-epsilon])
    if not explore:
        return np.argmax(Qtable[state[0]-1,state[1]-1,:])
    else:
        return np.random.choice([1,0])

In [4]:
Qtable = np.zeros((21,10,2))
Nsa = np.zeros((21,10,2))
# policy = np.zeros(21,10)


numiter = 10000

# policy = np.argmax(Qtable[state])
for i in range(numiter):
    episode = []  #just one episode
    runepisode()
    Gt = episode[-1][2]

    for state, action, reward in episode:
        Qtable[state[0]-1,state[1]-1,action] += (1/Nsa[state[0]-1,state[1]-1,action])*(Gt - Qtable[state[0]-1,state[1]-1,action])

In [6]:
## save to file
save = True
if save:
    output = open('opt_Qfcn_'+str(numiter)+'_iter.pkl', 'wb')
    cPickle.dump('Qtable', output)
    output.close()

In [10]:
## read from file
numiter = 1000000

pkl_file = open('opt_Qfcn_'+str(numiter)+'_iter.pkl', 'rb')
Qtable = cPickle.load(pkl_file)
pkl_file.close()

In [6]:
opt_Valuefunction = np.max(Qtable,2)
#plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y = np.meshgrid(range(1,11), range(1,22))
# print(X.shape,Y.shape)
ax.plot_wireframe(X,Y, opt_Valuefunction, legend='asdf')
ax.set_xlabel("dealer")
ax.set_ylabel("player")
ax.set_zlabel("value")

fig = plt.figure()
opt_policy = np.argmax(Qtable,2)
plt.imshow(opt_policy,cmap=plt.get_cmap('gray'),interpolation='none')
plt.show()


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-6-ea5025fe3d2a> in <module>()
      5 X, Y = np.meshgrid(range(1,11), range(1,22))
      6 # print(X.shape,Y.shape)
----> 7 ax.plot_wireframe(X,Y, opt_Valuefunction, legend='asdf')
      8 ax.set_xlabel("dealer")
      9 ax.set_ylabel("player")

/home/guille/Projects/DeepRL/python/local/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.pyc in plot_wireframe(self, X, Y, Z, *args, **kwargs)
   1803                   zip(txlines, tylines, tzlines)]
   1804 
-> 1805         linec = art3d.Line3DCollection(lines, *args, **kwargs)
   1806         self.add_collection(linec)
   1807         self.auto_scale_xyz(X, Y, Z, had_data)

/home/guille/Projects/DeepRL/python/local/lib/python2.7/site-packages/mpl_toolkits/mplot3d/art3d.pyc in __init__(self, segments, *args, **kwargs)
    207         Keyword arguments are passed onto :func:`~matplotlib.collections.LineCollection`.
    208         '''
--> 209         LineCollection.__init__(self, segments, *args, **kwargs)
    210 
    211     def set_sort_zpos(self, val):

/home/guille/Projects/DeepRL/python/local/lib/python2.7/site-packages/matplotlib/collections.pyc in __init__(self, segments, linewidths, colors, antialiaseds, linestyles, offsets, transOffset, norm, cmap, pickradius, zorder, facecolors, **kwargs)
   1145             pickradius=pickradius,
   1146             zorder=zorder,
-> 1147             **kwargs)
   1148 
   1149         self.set_segments(segments)

/home/guille/Projects/DeepRL/python/local/lib/python2.7/site-packages/matplotlib/collections.pyc in __init__(self, edgecolors, facecolors, linewidths, linestyles, antialiaseds, offsets, transOffset, norm, cmap, pickradius, hatch, urls, offset_position, zorder, **kwargs)
    137 
    138         self._path_effects = None
--> 139         self.update(kwargs)
    140         self._paths = None
    141 

/home/guille/Projects/DeepRL/python/local/lib/python2.7/site-packages/matplotlib/artist.pyc in update(self, props)
    854                 func = getattr(self, 'set_' + k, None)
    855                 if func is None or not six.callable(func):
--> 856                     raise AttributeError('Unknown property %s' % k)
    857                 func(v)
    858             changed = True

AttributeError: Unknown property legend

In [5]:
import sarsa_lambda

In [8]:
sarsa_lambda.policy


Out[8]:
<function sarsa_lambda.policy>

In [ ]: