In [1]:
from GridWorld import GridWorld, OptimalPolicy
import matplotlib.pyplot as plt
World configuration
In [2]:
# Grid dimensions.
# NOTE(review): axis order (rows, cols) vs (x, y) is not visible here — confirm in GridWorld.
world_size = (5, 5)

# Special states, one triple per entry: (position, second position, number).
# NOTE(review): presumably (source cell, destination cell, reward) — confirm against GridWorld.
special_state = [
    ([0, 1], [4, 1], 10),
    ([0, 3], [2, 3], 5),
]
Initialize grid world
In [3]:
# Build the grid world from the size and special-state configuration defined above.
world = GridWorld(world_size, special_state)
Policy configuration
In [4]:
# Policy passed to world.step() by the iteration cell below.
# NOTE(review): 0.9 is presumably the discount factor (gamma) — confirm against OptimalPolicy.
policy = OptimalPolicy(0.9)
Iterate until the per-step change (`world.diff`) falls below the tolerance
In [5]:
# Apply `policy` to `world` step by step until the per-step change reported by
# `world.diff` falls below CONVERGENCE_TOL.
# `diffs` records world.diff after every step (used by the convergence plot);
# `iteration` counts the steps taken.
# NOTE(review): assumes GridWorld initialises `diff` to a value >= the tolerance,
# so at least one step runs — confirm in GridWorld.
CONVERGENCE_TOL = 1e-4
# Safety cap: without it a run that never converges (or whose diff becomes NaN)
# would hang the kernel and break Restart & Run All.
MAX_ITERATIONS = 10000

iteration = 0
diffs = []
# Deliberately written as `not (diff < tol)` rather than `diff >= tol`: a NaN diff
# keeps the loop going (and then trips the cap below) instead of silently
# counting as "converged".
while not (world.diff < CONVERGENCE_TOL):
    if iteration >= MAX_ITERATIONS:
        raise RuntimeError(
            f"No convergence after {MAX_ITERATIONS} iterations "
            f"(last diff = {world.diff!r}, tolerance = {CONVERGENCE_TOL})"
        )
    world.step(policy)
    iteration += 1
    diffs.append(world.diff)
Show value matrix
In [6]:
# Display the grid world's value matrix after the iteration loop has finished.
# NOTE(review): the argument 3 presumably sets the display precision (decimal places) — confirm in GridWorld.show_value.
world.show_value(3)
Plot convergence curve (per-step change `world.diff` vs. iteration)
In [7]:
# Convergence curve: the per-step change `world.diff` recorded in `diffs`.
# A curve that keeps falling towards the tolerance shows the iteration converging.
# A fresh figure (instead of reusing figure number 1) keeps re-runs from drawing
# onto stale axes.
fig, ax = plt.subplots()
# diffs[k] is recorded after step k + 1, so number the x-axis from 1.
steps = range(1, len(diffs) + 1)
ax.plot(steps, diffs)
# The per-step change typically shrinks by orders of magnitude, so a log y-axis
# keeps the tail readable (geometric decay appears as a straight line).
# Non-positive entries, e.g. an exact 0, cannot be represented on a log axis.
ax.set_yscale("log")
ax.set(
    title="Convergence of value estimates",
    xlabel="Iteration",
    ylabel="Change per iteration (world.diff)",
)
plt.show()
In [ ]: