In [1]:
from sarsa_lambda_fa import *

In [5]:
sarsa_test = sarsa_fa()

sarsa_test.loop_over_lambda(1000000,3)


using epsilon  0.05
numiter 1000000 n_lambda 3
0.0
0.5
1.0

In [ ]:
# class sarsa_fa:

#     def visualize(self):

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y = np.meshgrid(range(1,11), range(1,22))
# print(X.shape,Y.shape)
ax.plot_wireframe(X,Y, sarsa_test.opt_Valuefunction)
ax.set_xlabel("dealer")
ax.set_ylabel("player")
ax.set_zlabel("value")

fig = plt.figure()
opt_policy = np.argmax(sarsa_test.calcQtable(),2)
plt.imshow(opt_policy,cmap=plt.get_cmap('gray'),interpolation='none')

fig = plt.figure()
plt.plot(sarsa_test.mse_1000)
plt.xlabel("lambda")
plt.ylabel("mean squared errror from 1e6 monte carlo")


fig = plt.figure()
for i in range(3):
    line1, = plt.plot(range(len(sarsa_test.error_lists[i])), sarsa_test.error_lists[i], label="lambda ="+str(np.linspace(0,1,3)[i]))
plt.legend()

plt.xlabel("1000 episodes")
plt.ylabel("Mean squared error to Monte Carlo 1e6 ")

plt.show()

plt.close('all')

In [4]:
sarsa_test.visualize()


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-4-caa8439fbd87> in <module>()
----> 1 sarsa_test.visualize()

/home/frederik/Dokumente/DeepRL/easy21/sarsa_lambda_fa.pyc in visualize(self)
    198         fig = plt.figure()
    199         for i in range(11):
--> 200             line1, = plt.plot(range(len(self.error_lists[i])), self.error_lists[i], label="lambda ="+str(np.linspace(0,1,3)[i]))
    201             plt.legend(handles=[line1], loc=1)
    202 

IndexError: list index out of range

In [ ]:
sarsa_test.epsilon