In [1]:
from GridWorld import GridWorld, OptimalPolicy
import matplotlib.pyplot as plt

World configuration


In [2]:
# Size of the grid handed to GridWorld.
world_size = (5, 5)

# One tuple per special state. Each entry appears to be
# (state, destination, reward) -- TODO confirm against GridWorld.
special_state = [
    ([0, 1], [4, 1], 10),
    ([0, 3], [2, 3], 5),
]

Initialize grid world


In [3]:
# Build the grid world from the size and special-state list configured above.
world = GridWorld(world_size, special_state)

Policy configuration


In [4]:
# NOTE(review): 0.9 is presumably the discount factor (gamma) -- confirm
# against the OptimalPolicy constructor.
policy = OptimalPolicy(0.9)

Iterate until convergence


In [5]:
# Repeatedly step the world under the policy until the change reported by
# `world.diff` drops below TOLERANCE, recording `world.diff` after every step.
TOLERANCE = 1e-4
# Safety cap: without it the loop never ends if `world.diff` stays at or above
# TOLERANCE or turns into NaN (any comparison with NaN is False).
MAX_ITERATIONS = 10000

iteration = 0  # number of steps performed so far
diffs = []     # world.diff after each step, in order
while not (world.diff < TOLERANCE):
    if iteration >= MAX_ITERATIONS:
        raise RuntimeError(
            "world.diff did not drop below {} within {} iterations "
            "(last value: {!r})".format(TOLERANCE, MAX_ITERATIONS, world.diff)
        )
    world.step(policy)
    iteration += 1
    diffs.append(world.diff)

Show value matrix


In [6]:
# Print the value matrix. The argument 3 presumably sets the number of decimal
# places (the printed values have three) -- confirm against GridWorld.show_value.
world.show_value(3)


|-------+-------+-------+-------+-------|
| 21.977| 24.419| 21.977| 19.419| 17.477|
|-------+-------+-------+-------+-------|
| 19.780| 21.977| 19.780| 17.802| 16.022|
|-------+-------+-------+-------+-------|
| 17.802| 19.780| 17.802| 16.022| 14.419|
|-------+-------+-------+-------+-------|
| 16.022| 17.802| 16.022| 14.419| 12.977|
|-------+-------+-------+-------+-------|
| 14.419| 16.022| 14.419| 12.977| 11.680|
|-------+-------+-------+-------+-------|

Plot convergence curve (`world.diff` at each iteration)


In [7]:
# Plot the recorded `world.diff` values against the iteration index so the
# convergence of the loop above can be inspected visually.
plt.figure(1)
plt.plot(range(iteration), diffs)
# Label the figure so it can be read on its own when the notebook is skimmed.
plt.xlabel("Iteration")
plt.ylabel("world.diff")
plt.title("Convergence of world.diff")
plt.show()



In [ ]: