In [2]:
# Standard plotting setup
%matplotlib inline
import matplotlib.pyplot as plt
from pylab import rcParams
rcParams['figure.figsize'] = 10, 5
plt.style.use('ggplot')

Multi-armed bandit as a Markov decision process

Let's model the Bernouilli multi-armed bandit. The Bernoulli MBA is an $N$-armed bandit where each arm gives binary rewards according to some probability:

$r_i \sim Bernouilli(\mu_i)$

Here $i$ is the index of the arm. Let's model this as a Markov decision process. The state is going to be defined as:

$s(t) = (\alpha_1, \beta_1, \ldots, \alpha_N, \beta_N, r_t)$

$\alpha_i$ is the number of successes encountered so far when pulling arm $i$. $\beta_i$ is, similarly, the number of failures encountered when pulling that arm. $r_t$ is the reward, either 0 or 1, from the last trial.

Assuming a uniform prior on $\mu_i$, the posterior distribution of the $\mu_i$ in a given state are:

$p(\mu_i|s(t)) = Beta(\alpha_i+1,\beta_i+1)$

When we're in a given state, we have the choice of performing one of $N$ actions, corresponding to pulling each of the arms. Let's call pulling the $i$'th arm $a_i$. This will put us in a new state, with a certain probability. The new state will be same for arms not equal to i. For the $i$'th arm, we have:

$s(t+1) = (\ldots \alpha_i + 1, \beta_i \ldots 1)$ with probability $(\alpha_i+1)/(\alpha_i+\beta_i+2)$

$s(t+1) = (\ldots \alpha_i, \beta_i + 1 \ldots 0)$ with probability $(\beta_i+1)/(\alpha_i+\beta_i+2)$

We can solve exactly for this MDP, e.g. using value iteration, given that it's small enough. For $M$ trials, the state space is $M^{2N}$ - it's possible to solve the 2-armed bandit for 10-20 trials this way, but it grows exponentially fast.


In [3]:
import itertools
import numpy as np
from pprint import pprint

def sorted_values(dict_):
    return [dict_[x] for x in sorted(dict_)]

def solve_bmab_value_iteration(N_arms, M_trials, gamma=1,
                               max_iter=10, conv_crit = .01):
    util = {}
    
    # Initialize every state to utility 0.
    state_ranges = [range(M_trials+1) for x in range(N_arms*2)]
    # The reward state
    state_ranges.append(range(2))
    for state in itertools.product(*state_ranges):
        # Some states are impossible to reach.
        if sum(state[:-1]) > M_trials:
            # A state with the total of alphas and betas greater than 
            # the number of trials.
            continue
            
        if sum(state[:-1:2]) == 0 and state[-1] == 1:
            # A state with a reward but alphas all equal to 0.
            continue
            
        if sum(state[:-1:2]) == M_trials and state[-1] == 0:
            # A state with no reward but alphas adding up to M_trials.
            continue
            
        if sum(state[:-1]) == 1 and sum(state[:-1:2]) == 1 and state[-1] == 0:
            # A state with an initial reward according to alphas but not according
            # the reward index
            continue
            
        util[state] = 0
    
    # Main loop.
    converged = False
    new_util = util.copy()
    opt_actions = {}
    for j in range(max_iter):
        # Line 5 of value iteration
        for state in util.keys():
            reward = state[-1]
            
            # Terminal state.
            if sum(state[:-1]) == M_trials:
                new_util[state] = reward
                continue
            
            values = np.zeros(N_arms)
            
            # Consider every action
            for i in range(N_arms):
                # Successes and failure for this state.
                alpha = state[i*2]
                beta  = state[i*2+1]
                
                # Two possible outcomes: either that arm gets rewarded,
                # or not.
                # Transition to unrewarded state:
                state0 = list(state)
                state0[-1] = 0
                state0[2*i+1] += 1
                state0 = tuple(state0)
                
                # The probability that we'll transition to this unrewarded state.
                p_state0 = (beta + 1) / float(alpha + beta + 2)
                
                # Rewarded state.
                state1 = list(state)
                state1[-1] = 1
                state1[2*i] += 1
                state1 = tuple(state1)
                
                p_state1 = 1 - p_state0
                try:
                    value = gamma*(util[state0]*p_state0 + 
                                   util[state1]*p_state1)
                except KeyError,e:
                    print state
                    print state0
                    print state1
                    raise e
                    
                #print state0, util[state0], p_state0
                #print state1, util[state1], p_state1
                values[i] = value
                
            #print state, values, reward
            new_util[state] = reward + np.max(values)
            opt_actions[state] = np.argmax(values)
            
        # Consider the difference between the new util
        # and the old util.
        max_diff = np.max(abs(np.array(sorted_values(util)) - np.array(sorted_values(new_util))))
        util = new_util.copy()
        
        print "Iteration %d, max diff = %.5f" % (j, max_diff)
        if max_diff < conv_crit:
            converged = True
            break
            
        #pprint(util)
            
    if converged:
        print "Converged after %d iterations" % j
    else:
        print "Not converged after %d iterations" % max_iter
        
    return util, opt_actions

util, opt_actions = solve_bmab_value_iteration(2, 2, max_iter=5)


Iteration 0, max diff = 1.00000
Iteration 1, max diff = 0.66667
Iteration 2, max diff = 0.58333
Iteration 3, max diff = 0.00000
Converged after 3 iterations

In [4]:
opt_actions


Out[4]:
{(0, 0, 0, 0, 0): 0,
 (0, 0, 0, 1, 0): 0,
 (0, 0, 1, 0, 1): 1,
 (0, 1, 0, 0, 0): 1,
 (1, 0, 0, 0, 1): 0}

For the 2-armed, 2-trial Bernoulli bandit, the strategy is simple: pick the first arm. If it rewards, then pick it again. If not, pick the other. Note that this is the same as most sensible strategies, for instance greedy or UCB.


In [5]:
util


Out[5]:
{(0, 0, 0, 0, 0): 1.0833333333333335,
 (0, 0, 0, 1, 0): 0.5,
 (0, 0, 0, 2, 0): 0,
 (0, 0, 1, 0, 1): 1.6666666666666667,
 (0, 0, 1, 1, 0): 0,
 (0, 0, 1, 1, 1): 1,
 (0, 0, 2, 0, 1): 1,
 (0, 1, 0, 0, 0): 0.5,
 (0, 1, 0, 1, 0): 0,
 (0, 1, 1, 0, 0): 0,
 (0, 1, 1, 0, 1): 1,
 (0, 2, 0, 0, 0): 0,
 (1, 0, 0, 0, 1): 1.6666666666666667,
 (1, 0, 0, 1, 0): 0,
 (1, 0, 0, 1, 1): 1,
 (1, 0, 1, 0, 1): 1,
 (1, 1, 0, 0, 0): 0,
 (1, 1, 0, 0, 1): 1,
 (2, 0, 0, 0, 1): 1}

Note that the utility of the root node is 1.08 - what does that mean? If we get rewarded in the initial trial, that means that the posterior for the mean of that arm is .67. OTOH, when we fail on the first trial, we can still pick the other arm, which still has a posterior mean of .5. Thus, we have rewards:

  • +2 with probability .5*2/3
  • +1 with prob .5*1/3
  • +1 with prob .5*.5
  • +0 with prob .5*.5

That means the expected total reward is:


In [28]:
2*.5*2.0/3.0 + .5/3.0 + .5*.5


Out[28]:
1.0833333333333333

And that's what utility means in this context. Let's see about the 3-trial 2-armed bandit:


In [6]:
util, opt_actions = solve_bmab_value_iteration(2, 3, max_iter=5)
opt_actions


Iteration 0, max diff = 1.00000
Iteration 1, max diff = 0.75000
Iteration 2, max diff = 0.66667
Iteration 3, max diff = 0.58333
Iteration 4, max diff = 0.00000
Converged after 4 iterations
Out[6]:
{(0, 0, 0, 0, 0): 0,
 (0, 0, 0, 1, 0): 0,
 (0, 0, 0, 2, 0): 0,
 (0, 0, 1, 0, 1): 1,
 (0, 0, 1, 1, 0): 0,
 (0, 0, 1, 1, 1): 0,
 (0, 0, 2, 0, 0): 1,
 (0, 0, 2, 0, 1): 1,
 (0, 1, 0, 0, 0): 1,
 (0, 1, 0, 1, 0): 0,
 (0, 1, 1, 0, 0): 1,
 (0, 1, 1, 0, 1): 1,
 (0, 2, 0, 0, 0): 1,
 (1, 0, 0, 0, 1): 0,
 (1, 0, 0, 1, 0): 0,
 (1, 0, 0, 1, 1): 0,
 (1, 0, 1, 0, 0): 0,
 (1, 0, 1, 0, 1): 0,
 (1, 1, 0, 0, 0): 0,
 (1, 1, 0, 0, 1): 0,
 (2, 0, 0, 0, 0): 0,
 (2, 0, 0, 0, 1): 0}

The optimal strategy goes: pick arm 0. If it rewards, pick it again for the next 2 trials. If it doesn't reward, then pick arm 1. If that rewards, keep that one. If it doesn't, pick 0 again.

Let's see with 4:


In [7]:
util, opt_actions = solve_bmab_value_iteration(2, 4, max_iter=6)


Iteration 0, max diff = 1.00000
Iteration 1, max diff = 0.80000
Iteration 2, max diff = 0.75000
Iteration 3, max diff = 0.69444
Iteration 4, max diff = 0.61111
Iteration 5, max diff = 0.00000
Converged after 5 iterations

What's interesting here is that value iteration always converges in M_trials + 1 iterations - information only travels backwards through time - much as in Viterbi in the context of HMMs. If we're only interested in the next best action given the current state, it might be possible to iterate backwards through time, starting from the terminal states, throwing away the latest data as we go along. Before we get into premature optimization, however, let's see how far we can look ahead without crashing with Chromebook.


In [8]:
M_trials = 16
%time util, opt_actions = solve_bmab_value_iteration(2, M_trials, max_iter=M_trials+2)


Iteration 0, max diff = 1.00000
Iteration 1, max diff = 0.94118
Iteration 2, max diff = 0.93750
Iteration 3, max diff = 0.93333
Iteration 4, max diff = 0.92857
Iteration 5, max diff = 0.92308
Iteration 6, max diff = 0.91667
Iteration 7, max diff = 0.90909
Iteration 8, max diff = 0.90000
Iteration 9, max diff = 0.88896
Iteration 10, max diff = 0.87550
Iteration 11, max diff = 0.85897
Iteration 12, max diff = 0.83861
Iteration 13, max diff = 0.81231
Iteration 14, max diff = 0.77726
Iteration 15, max diff = 0.72543
Iteration 16, max diff = 0.64213
Iteration 17, max diff = 0.00000
Converged after 17 iterations
CPU times: user 9.89 s, sys: 21 ms, total: 9.92 s
Wall time: 10.2 s

In [9]:
M_trials
bad_keys = [k for k in opt_actions.keys() if sum(k[:-1]) > 15]
assert(len(bad_keys) == 0)

It seems like my Chromebook can look ahead at least sixteen steps into the future without dying - pretty good!

Optimal versus UCB

Let's try and figure out how the optimal strategy relates to the upper confidence bound heuristic. Let's train a logistic regression model with the same inputs as a UCB strategy - mean, standard deviation, time - and see how well it can approximate the optimal strategy.


In [10]:
# Create a design matrix related to the optimal strategies.
X = []
y = []
seen_keys = {}
for key, val in opt_actions.iteritems():
    if key[:-1] in seen_keys:
        # We've already seen this, continue.
        continue
        
    alpha0 = float(key[0] + 1)
    beta0  = float(key[1] + 1)
    alpha1 = float(key[2] + 1)
    beta1  = float(key[3] + 1)
    
    if alpha0 == alpha1 and beta0 == beta1:
        # We're in a perfectly symmetric situtation, skip this then.
        continue
        
    seen_keys = key[:-1]
    
    # Standard results for the Beta distribution.
    # https://en.wikipedia.org/wiki/Beta_distribution
    mean0 = alpha0/(alpha0 + beta0)
    mean1 = alpha1/(alpha1 + beta1)
    
    std0  = np.sqrt(alpha0*beta0 / (alpha0 + beta0 + 1)) / (alpha0 + beta0)
    std1  = np.sqrt(alpha1*beta1 / (alpha1 + beta1 + 1)) / (alpha1 + beta1)
    
    t = alpha0 + beta0 + alpha1 + beta1
    X.append([mean0,mean1,std0,std1,t,1,alpha0 - 1,beta0 - 1,alpha1 - 1,beta1 - 1])
    y.append(val)
    
X = np.array(X)
y = np.array(y)

Let's train a supervised network a see how well it can predict the correct move based on a purely greedy heuristic - and based on a heuristic which takes into account the uncertainty in the estimate.


In [11]:
from sklearn.linear_model import LogisticRegression

the_model = LogisticRegression(C=100.0)
X_ = X[:,:2]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)

print ("Greedy: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_

the_model = LogisticRegression(C=100.0)
X_ = X[:,:4]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)

print ("UCB: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_

the_model = LogisticRegression(C=100000.0)
X_ = X[:,:4]
X_ = np.hstack((X_,(X[:,4]).reshape((-1,1))*X[:,2:4]))
the_model.fit(X_,y)
y_pred = the_model.predict(X_)

print ("UCB X time: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_


Greedy: 2.9669% of moves are incorrect
[[-38.98482874  38.98567838]]
UCB: 1.8013% of moves are incorrect
[[-57.73453836  57.73453317 -29.49326918  29.49299601]]
UCB X time: 0.5298% of moves are incorrect
[[-201.97514467  201.97538323 -496.78194952  496.77733617   24.26601785
   -24.26590183]]

We see that the greedy strategy misses the right move 3% of the time, while UCB shaves that down to 1.8%. Pretty significant. The UCB parameter - a parameter which determines how much "bonus" should be given to uncertainty - is suspiciously low at (29.49 / 57.7 ~= .5). In the literature, people use something around 2-3.

Adding a parameter which is the cross of time and the standard deviation of the estimate reveals the source of this discrepancy: at the initial time point, the UCB parameter is high (496.7 / 201 ~ 2.5) and it ramps down linearly as a function of time to (496 - 26.26*16) / 200 ~= 0.4. Thus, the optimal strategy is similar to a UCB strategy, with a twist: the exploration bonus should ramp down as a function of time. This makes sense: new information is more valuable in the initial trials.

This UCB X time strategy misses only .5% of moves, which is quite good, all things considered.

Monte Carlo Tree Search

The dynamic programming approach is of theoretical interest, but it doesn't scale well to other kinds of problems, like contextual bandits, problems with continuous-valued rewards, or problems with larger state spaces. Rather than exhaustively determining the outcome of every path, we can sample outcomes at random. Let's start by implement vanilla MCTS, where actions at every junction are sampled uniformly. We'll later upgrade to UCT.


In [ ]:
import collections

class VanillaMCTS:
    
    def __init__(self, N_arms, M_trials, stochastic_rewards = True):
        self.N_arms = N_arms
        self.M_trials = M_trials
        self.stochastic_rewards = stochastic_rewards
        self.state_action_reward = collections.defaultdict(int)
        self.state_action_visit  = collections.defaultdict(int)
        
        
    def find_best_action(self,
                         current_state, 
                         max_horizon = 10,
                         max_samples = 100):
        max_depth = min(self.M_trials - sum(current_state[::-1]), max_horizon)
        for n in range(max_samples):
            self.mcts_search(current_state, max_depth)
            
        state_rewards = [(self.state_action_reward[state_action] / 
                         float(self.state_action_visit[state_action]),state_action) for 
                         state_action in self.state_action_reward.keys() if
                         state_action[0] == current_state]
        
        max_reward, best_state_action = max(state_rewards)
        _, best_action = best_state_action
        return max_reward, best_action, state_rewards

    def mcts_search(self, state, max_depth):
        if max_depth == 0:
            return 0
        
        # Select an action
        action = self.select_action(state)
        
        # Pull that arm
        next_state, expected_reward = self.perform_action(state, action)
        
        if self.stochastic_rewards:
            r = next_state[-1]
        else:
            r = expected_reward
        
        reward = self.mcts_search(next_state, max_depth - 1) + r
        
        # Memo-ize
        state_action = (state,action)
        self.state_action_reward[state_action] += reward
        self.state_action_visit[state_action] += 1
        
        return reward
        
    def perform_action(self, state, action):
        # Pull the arm in question; only valid for Bernouilli arm
        alpha = state[action*2]
        beta = state[action*2 + 1]
        
        expected_reward = (alpha + 1)/float(alpha + beta + 2)
        
        rewarded = np.random.rand() < expected_reward
        
        state = list(state)
        if rewarded:
            state[action*2] += 1
            state[-1] = 1
        else:
            state[action*2+1] += 1
            state[-1] = 0
            
        return (tuple(state), expected_reward)
    
    def select_action(self, state):
        # Select uniformily at random
        return int(np.random.rand() * (len(state) / 2))

In [32]:
# Action 1 is better both from an exploration and an exploitation
# perspective, but not by that much
ambiguous_state = (2,2,2,1,0)

ndraws = 10
nsims_per = 1000
best_actions = np.zeros((nsims_per,ndraws,2))
max_rewards = np.zeros((nsims_per,2))

for k,stochastic_rewards in enumerate([False, True]):
    for j in range(ndraws):
        mcts = VanillaMCTS(2, 16, stochastic_rewards = stochastic_rewards)

        for i in range(nsims_per):
            _, best_action, state_reward = mcts.find_best_action(ambiguous_state, max_samples = 1)
            best_actions[i,j,k] = best_action
            
plt.subplot(121)
plt.plot(best_actions[:,:,0].mean(1))
plt.xlabel('# draws')
plt.ylabel('p(picking) right arm after N draws')
plt.ylim([0,1])
plt.title('deterministic rewards')
plt.subplot(122)
plt.plot(best_actions[:,:,1].mean(1))
plt.xlabel('#draws')
plt.title('stochastic rewards')
plt.ylim([0,1])


Out[32]:
(0, 1)

Generally, the probability of picking the right arm increases with the number of draws -- but it still takes several hundred trials to pick the right arm consistently, despite the fact that we're simulating a rather large difference. This gets better if we return deterministic rewards rather than stochastic. Let's see what happens with a harder case:


In [33]:
# Action 1 is better only from an exploitation perspective
ambiguous_state = (2,2,1,1,0)

ndraws = 10
nsims_per = 10000
best_actions = np.zeros((nsims_per,ndraws,2))
max_rewards = np.zeros((nsims_per,2))

for k,stochastic_rewards in enumerate([False, True]):
    for j in range(ndraws):
        mcts = VanillaMCTS(2, 16, stochastic_rewards = stochastic_rewards)

        for i in range(nsims_per):
            _, best_action, state_reward = mcts.find_best_action(ambiguous_state, max_samples = 1)
            best_actions[i,j,k] = best_action
            
plt.subplot(121)
plt.plot(best_actions[:,:,0].mean(1))
plt.xlabel('# draws')
plt.ylabel('p(picking) right arm after N draws')
plt.ylim([0,1])
plt.title('deterministic rewards')
plt.subplot(122)
plt.plot(best_actions[:,:,1].mean(1))
plt.xlabel('#draws')
plt.title('stochastic rewards')
plt.ylim([0,1])


Out[33]:
(0, 1)

In this case, the method does not converge to the correct action -- subsequent actions to the first are selected uniformily -- the method answers the question, which action should I pick next, if I picking actions at random later? The answer, in this case, is that it doesn't matter

UCT - Upper confidence tree


In [184]:
%matplotlib inline
import matplotlib.pyplot as plt

plt.subplot(121)
plt.plot(max_rewards)
plt.subplot(122)
plt.plot(best_actions)

best_actions


Out[184]:
array([[ 0.,  1.,  1., ...,  0.,  0.,  0.],
       [ 0.,  1.,  1., ...,  0.,  0.,  0.],
       [ 0.,  1.,  1., ...,  0.,  1.,  1.],
       ..., 
       [ 1.,  0.,  1., ...,  0.,  0.,  0.],
       [ 1.,  0.,  1., ...,  0.,  0.,  0.],
       [ 1.,  0.,  1., ...,  0.,  0.,  0.]])

instead focus on the most promising paths. Let's do this via MCTS - the initial value of a branch is going to be set as the expected outcome plus an exploration bonus, and after the horizon of the tree search is reached, the path will be scored by its mean value. We'll compare this to a UCB strategy.


In [119]:
max_reward
best_action


Out[119]:
1

In [ ]:
# Find a case where the greedy strategy is incorrect
the_model = LogisticRegression(C=100.0)
X_ = X[:,:2]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)

print X[np.where(y_pred != y)[0][0],:]

TODO

A few things to try:

  • Find a more efficient way of implement value iteration - can we throw away the later data once we've backpropagated the future? Can we look into the far future in this way?
  • How good is a strategy that is myopic, that is, only looks at a few steps in advance?
  • What does policy iteration look like in this context?
  • How does this all compare to MCTS?