In [56]:
import numpy as np
import matplotlib.pyplot as plt
# Equiprobable policy pi[s, a]: in each of the 16 states every one of the
# 4 actions is chosen with probability 1/4, so each row sums to 1.
pi = np.full((16, 4), 0.25)

# Reward table with the same axis order as the transition table (s', s, a):
# -1 for every transition, except that entries whose source state s is 0 or 15
# are zero.
R = np.full((16, 16, 4), -1.0)
R[:, [0, 15], :] = 0

# Discount factor applied to successor values (1 = undiscounted).
g = 1
In [57]:
# State numbering on the 4x4 grid. The two blank corners are states 0 and 15,
# which have no outgoing transitions in the table below:
#
#        1   2   3
#    4   5   6   7
#    8   9  10  11
#   12  13  14
#
# Action numbering:
#
#        0
#      1   2
#        3
#
# P[s', s, a] = P(s' | s, a). A move that would leave the grid keeps the
# state unchanged.
P = np.zeros((16, 16, 4))

# (row step, column step) for actions 0..3, i.e. up, left, right, down.
steps = ((-1, 0), (0, -1), (0, 1), (1, 0))

# Build the (s', s, a) triples grouped by action, with the source state
# running over 1..14 inside each group.
triples = []
for act, (d_row, d_col) in enumerate(steps):
    for state in range(1, 15):
        row, col = divmod(state, 4)
        new_row, new_col = row + d_row, col + d_col
        if 0 <= new_row < 4 and 0 <= new_col < 4:
            target = 4 * new_row + new_col
        else:
            # Bumping into the border leaves the state where it is.
            target = state
        triples.append((target, state, act))
P_elem = tuple(triples)

# Every listed transition is deterministic: set its probability to 1.
s_next, s_from, action = np.array(P_elem).T
P[s_next, s_from, action] = 1
In [58]:
# Iterative policy evaluation: repeatedly apply the update
#   V(s) <- sum_a pi[s, a] * sum_s' P[s', s, a] * (R[s', s, a] + g * V(s'))
# to all 16 states at once, starting from V = 0.
V = np.zeros((16,))

# Expected one-step reward for each (s, a). It does not depend on V, so it is
# computed once rather than on every sweep.
expected_reward = np.sum(P * R, axis=0)

n_sweeps = 1000
for sweep in range(n_sweeps):
    # Expected discounted successor value for each (s, a): contracts the s'
    # axis of P against g * V.
    expected_next_value = np.tensordot(g * V, P, axes=[[0], [0]])
    # The right-hand side is fully evaluated before V is overwritten, so each
    # sweep reads only the values produced by the previous sweep.
    V[:] = np.sum(pi * (expected_reward + expected_next_value), axis=1)

# Show the state values laid out like the grid. print(...) with a single
# argument gives the same output on Python 2 and Python 3.
print(V.reshape((4, 4)))