In [1]:
from sarsa_lambda_fa import *
In [5]:
# Create the agent and invoke its loop_over_lambda routine.
# NOTE(review): the two arguments are presumably (number of episodes,
# number of lambda values) -- confirm against sarsa_lambda_fa.
sarsa_test = sarsa_fa()
sarsa_test.loop_over_lambda(10 ** 6, 3)
In [ ]:
# class sarsa_fa:
# def visualize(self):
# Stand-alone plotting cell: appears to be a scratch copy of the body of
# sarsa_fa.visualize(), reading from the notebook-level `sarsa_test`
# instead of `self`. It opens four figures, shows them, then closes them.

# Lambda values used to label the learning curves. The count (3) is assumed
# to match the second argument passed to loop_over_lambda -- TODO confirm.
n_lambdas = 3
lambda_values = np.linspace(0, 1, n_lambdas)

# Figure 1: wireframe of opt_Valuefunction over a grid with x = 1..10 and
# y = 1..21, so the plotted array must have shape (21, 10).
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
X, Y = np.meshgrid(range(1, 11), range(1, 22))
ax.plot_wireframe(X, Y, sarsa_test.opt_Valuefunction)
ax.set_xlabel("dealer")
ax.set_ylabel("player")
ax.set_zlabel("value")

# Figure 2: index of the maximising entry along axis 2 of calcQtable(),
# drawn as a grayscale image.
fig = plt.figure()
opt_policy = np.argmax(sarsa_test.calcQtable(), 2)
plt.imshow(opt_policy, cmap=plt.get_cmap('gray'), interpolation='none')

# Figure 3: mse_1000 plotted against its own index.
# NOTE(review): the x-axis is labelled "lambda" but the x values are sample
# indices, not lambda values -- confirm whether mse_1000 holds one entry per
# lambda and, if so, consider plotting against the lambda values instead.
fig = plt.figure()
plt.plot(sarsa_test.mse_1000)
plt.xlabel("lambda")
plt.ylabel("mean squared error from 1e6 monte carlo")

# Figure 4: one curve per lambda value, taken from error_lists[i].
fig = plt.figure()
for i, lam in enumerate(lambda_values):
    errors = sarsa_test.error_lists[i]
    plt.plot(range(len(errors)), errors, label="lambda =" + str(lam))
plt.legend()  # build the legend once, after every curve has its label
plt.xlabel("1000 episodes")
plt.ylabel("Mean squared error to Monte Carlo 1e6 ")

plt.show()
plt.close('all')  # release all figures so repeated runs do not accumulate them
In [4]:
# Call the agent's own visualize() method on the notebook-level instance.
# NOTE(review): presumably this needs loop_over_lambda to have been run on
# `sarsa_test` first; the recorded execution counts are out of order -- verify
# with Restart Kernel -> Run All.
sarsa_test.visualize()
In [ ]:
# Show the current value of the agent's `epsilon` attribute as the cell output.
sarsa_test.epsilon