In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Value table indexed by state 0..100. Every entry starts at 0 except the last
# one (index 100), which is set to 1.
V = np.zeros((101,))
V[100] = 1

# Weight that backup_action puts on V[s+a]; V[s-a] gets the complement (1-p).
# A plain float literal is used instead of np.float(...): np.float was merely
# an alias for the builtin float and no longer exists in NumPy 1.24+.
p = 0.45

def backup_action(s, a):
    """Return the one-step backup value of taking action ``a`` in state ``s``.

    The result mixes the current estimates ``V[s + a]`` and ``V[s - a]`` with
    weights ``p`` and ``1 - p``. Both ``V`` and ``p`` are module-level globals
    looked up at call time, so rebinding ``p`` changes later results.
    """
    upper = V[s + a]
    lower = V[s - a]
    return p * upper + (1 - p) * lower

def value_update():
    """Run one in-place sweep over states 1..99 and return the largest change.

    Each state's entry in the global ``V`` is overwritten with the best
    ``backup_action(s, a)`` over actions ``a`` in ``1..min(s, 100 - s)``.
    Updates are written immediately, so states later in the sweep already see
    the new values of earlier ones. States 0 and 100 are never touched.

    Returns:
        The maximum absolute difference between a state's old and new value
        across the sweep.
    """
    changes = []
    # range (not the Python 2-only xrange) keeps this runnable on Python 3.
    for s in range(1, 100):
        previous = V[s]
        V[s] = max(backup_action(s, a) for a in range(1, min(s, 100 - s) + 1))
        changes.append(abs(previous - V[s]))
    return max(changes)

# History of value-table snapshots: vi() appends a copy of V after every sweep.
# Nothing in the visible code clears it, so repeated vi() runs keep accumulating.
V_h = list()

def vi(epsilon=1e-20):
    """Iterate ``value_update`` on the global ``V`` until a sweep changes little.

    ``V[0:100]`` is first reset to 100 evenly spaced values from 0 to 0.99
    (``V[100]`` is left alone). After every sweep a copy of ``V`` is appended
    to the global ``V_h``. Iteration stops after the first sweep whose largest
    change is strictly below ``epsilon``.

    Args:
        epsilon: Stopping threshold on the largest per-sweep change.
    """
    V[:100] = np.linspace(0, 0.99, 100)
    converged = False
    while not converged:
        sweep_change = value_update()
        # Record the snapshot before testing convergence so the final sweep
        # is part of the history too.
        V_h.append(V.copy())
        converged = sweep_change < epsilon

def policy(s, epsilon=1e-8):
    """Return the greedy action for state ``s`` under the current value table.

    Candidate actions ``1..min(s, 100 - s)`` are scored with
    ``backup_action`` in increasing order. A later action replaces the
    incumbent only when its score exceeds the incumbent's by more than
    ``epsilon``, so near-ties are resolved in favour of the smaller action.
    With ``epsilon=0`` the first action reaching the strict maximum wins.

    Args:
        s: State to choose an action for.
        epsilon: Minimum improvement required to switch to a later action.

    Returns:
        The selected action, or 0 when ``s`` has no candidate actions
        (``s`` equal to 0 or 100). Previously that case raised
        UnboundLocalError.
    """
    best_value = -1
    best_action = 0  # fallback for states with an empty action set
    # range (not the Python 2-only xrange) keeps this runnable on Python 3.
    for a in range(1, min(s, 100 - s) + 1):
        this_value = backup_action(s, a)
        if this_value > best_value + epsilon:
            best_value = this_value
            best_action = a
    return best_action

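# Alternative policy() kept for reference: picks the argmax over all backups,
# i.e. the first action attaining the exact maximum, with no epsilon tolerance.
# def policy(s):
#     a = range(1, min(s, 100-s)+1)
#     i = np.argmax([backup_action(s=s, a=ai) for ai in a])
#     return a[i]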

In [3]:
# Run value iteration with the current p; vi() resets V[0:100] to a linear ramp
# and sweeps until the largest per-sweep change drops below 1e-20.
vi(epsilon=1e-20)

In [4]:
# Plot how far each entry of the value table lies from the straight line that
# runs from 0 (state 0) to 1 (state 100).
baseline = np.linspace(0, 1, 101)
plt.plot(V - baseline)


Out[4]:
[<matplotlib.lines.Line2D at 0x7f6171e1ab50>]

In [5]:
# Bar chart of the action policy() selects (default tie tolerance) for each
# state 1..99. range replaces the Python 2-only xrange so this runs on Python 3.
plt.bar(range(1, 100), [policy(s) for s in range(1, 100)])


Out[5]:
<Container object of 99 artists>

In [6]:
# Same chart with epsilon=0: any strictly larger backup value, however small
# the margin, makes policy() switch to the later action.
# range replaces the Python 2-only xrange so this runs on Python 3.
plt.bar(range(1, 100), [policy(s, epsilon=0) for s in range(1, 100)])


Out[6]:
<Container object of 99 artists>

In [7]:
# Backup value of every available action (1..min(s, 100-s)) in state s = 49.
# range replaces the Python 2-only xrange so this runs on Python 3.
s = 49
plt.plot(range(1, min(s, 100-s)+1), [backup_action(s=s, a=a) for a in range(1, min(s, 100-s)+1)])


Out[7]:
[<matplotlib.lines.Line2D at 0x7f61717e02d0>]

In [8]:
# Re-run value iteration with p = 0.5 and chart the resulting policy.
# Rebinding the global p is what backup_action picks up on its next call.
# A float literal replaces np.float (removed in NumPy 1.24+), and range
# replaces the Python 2-only xrange.
p = 0.5
vi(epsilon=1e-20)
plt.bar(range(1, 100), [policy(s) for s in range(1, 100)])


Out[8]:
<Container object of 99 artists>

In [9]:
# Re-run value iteration with p = 0.49 and chart the resulting policy.
# Rebinding the global p is what backup_action picks up on its next call.
# A float literal replaces np.float (removed in NumPy 1.24+), and range
# replaces the Python 2-only xrange.
p = 0.49
vi(epsilon=1e-20)
plt.bar(range(1, 100), [policy(s) for s in range(1, 100)])


Out[9]:
<Container object of 99 artists>

In [10]:
# Re-run value iteration with p = 0.51 and chart the resulting policy.
# Rebinding the global p is what backup_action picks up on its next call.
# A float literal replaces np.float (removed in NumPy 1.24+), and range
# replaces the Python 2-only xrange.
p = 0.51
vi(epsilon=1e-20)
plt.bar(range(1, 100), [policy(s) for s in range(1, 100)])


Out[10]:
<Container object of 99 artists>

In [11]:
# Plot how far each entry of the value table (as left by the most recent vi()
# run) lies from the straight line that runs from 0 to 1.
baseline = np.linspace(0, 1, 101)
plt.plot(V - baseline)


Out[11]:
[<matplotlib.lines.Line2D at 0x7f6170ef70d0>]