In [5]:
import torch
from torch import ones, zeros, tensor, eye

from torch.distributions import Categorical

import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina')
sns.set(style='white', palette='colorblind', color_codes=True, font_scale=1.5)
%matplotlib inline

import sys
import os

import os
cwd = os.getcwd()

sys.path.append(cwd[:-len('befit/examples/control_dilemmas')])

from befit.simulate import Simulator
from befit.tasks import bandits
from befit.agents import AIBandits, Random

from setup_environment import *  # load relevant parameters, like the number of options, segments, trials, etc.

In [6]:
# Define environment: one bandit environment per simulated agent (EFE, IV,
# random), all driven by the same pre-generated context and offer sequences.
context = torch.from_numpy(np.load('context_{}.npy'.format(blocks)))
offers = torch.from_numpy(np.load('offers_{}.npy'.format(blocks)))

envs = [bandits.MultiArmedBandit(priors,
                                  transitions,
                                  context,
                                  offers,
                                  arm_types,
                                  nsub=nsub,
                                  blocks=blocks,
                                  trials=trials)
        for _ in range(3)]

#### Define agents ###

n_contexts = 6  # number of contexts (pars['nc'] and size of the context transition matrices)

# Initial likelihood prior: 3 pseudo-counts on the feature matching each arm
# type and 1 on every other feature (presumably Dirichlet concentrations).
# Row k of eye(nf).repeat(2, 1) is the one-hot vector of feature k % nf, so
# arm types 0..nf-1 and nf..2*nf-1 share the same profile. Using nf here (not a
# hard-coded 3) keeps the prior correct for any number of features.
a = 2*torch.eye(nf).repeat(2, 1)[arm_types] + 1
a = a.repeat(nsub, 1, 1, 1)

nd = 20
pars = {
    'nd': nd,
    'ns': na,  # number of arms
    'na': na,  # number of choices
    'nc': n_contexts,  # number of contexts
    'nf': nf,  # number of features
    'ni': 1,  # internal states
}

# Duration transitions: row 0 is a softmax over durations 1..nd with decay
# rate beta (shorter durations more probable); Bdd[d, d-1] = 1 for d >= 1,
# i.e. a deterministic count-down (assuming rows index the current duration).
Bdd = zeros(nd, nd)
d = torch.arange(1., nd + 1.)
beta = .2
Bdd[0] = (-d * beta).softmax(-1)
Bdd[range(1, nd), range(nd-1)] = 1.

# Context transitions per duration index: at index 0 the context moves
# uniformly to one of the other n_contexts - 1 contexts; otherwise it is kept.
Bccd = zeros(nd, n_contexts, n_contexts)
Bccd[0] = (ones(n_contexts, n_contexts) - eye(n_contexts))/(n_contexts - 1)
Bccd[1:] = eye(n_contexts).repeat(nd-1, 1, 1)

tm_higher = {
    'context': Bccd,
    'duration': Bdd
}


def _make_ai_agent(prior, **extra_parameters):
    """Build an AIBandits agent with the shared settings of this cell.

    prior: likelihood prior passed as `x=[prior]` to set_parameters.
    extra_parameters: forwarded unchanged to `set_parameters`.
    """
    agent = AIBandits(pars,
                      runs=nsub,
                      blocks=blocks,
                      trials=trials,
                      tm={'higher': tm_higher})
    agent.set_parameters(x=[prior], depth=trials, **extra_parameters)
    return agent


efe_agent = _make_ai_agent(a)                   # labelled 'EFE' in the figures
iv_agent = _make_ai_agent(a, epistemic=False)   # labelled 'IV'; presumably no epistemic term

random_agent = Random(pars, runs=nsub, blocks=blocks, trials=trials)

In [7]:
# Agent order matters: results of agents[i] are read back later under
# sim.responses / sim.stimuli['pair_{i}'] (0 = EFE, 1 = IV, 2 = random),
# presumably paired with envs[i].
agents = [efe_agent, iv_agent, random_agent]

sim = Simulator(envs, agents,
                runs=nsub,
                blocks=blocks,
                trials=trials)
sim.simulate_experiment()

In [14]:
import pandas as pd

# Choice statistics use only the segments after the first `cut`
# (presumably to skip the initial learning phase).
cut = 100
offs = offers[cut:, 0]

# Indexed by offer type o: its context number and variant label in the figure.
ctxt = [1, 1, 2, 2, 3, 3]
tp = ['A', 'B', 'A', 'B', 'A', 'B']
name = ['EFE', 'IV']
labels = np.array(['80%-0', '80%-1', '100%-0', '100%-1'])
# Maps each arm type (0..5) to an index into `labels`.
groups = torch.tensor([0, 1, 1, 2, 3, 3])

# Build one small frame per (agent, offer type) and concatenate once at the
# end: DataFrame.append was removed in pandas 2.0 and growing a frame inside a
# loop is quadratic.
frames = []
for i in range(2):
    responses = torch.stack([torch.stack(res[1:])
                             for res in sim.responses['pair_{}'.format(i)]][cut:])
    choices = arm_types[:, responses].transpose(dim0=-1, dim1=-2)

    for o in range(no):
        loc = o == offs
        vals, count = np.unique(groups[choices[o, loc]], return_counts=True)
        frames.append(pd.DataFrame({'option type': labels[vals],
                                    'probability': count/count.sum(),
                                    'context': ctxt[o],
                                    'variant': tp[o],
                                    'agent': name[i]}))

# pd.concat raises on an empty list; fall back to an empty frame in that case.
dfs = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

g = sns.catplot(x='option type', y='probability', data=dfs, col='context',
                row='variant', hue='agent', kind="bar")

titles = ['context 1:EV-', 'context 2:EV-', 'context 3:EV-', 'context 1:EV+', 'context 2:EV+', 'context 3:EV+']
for title, ax in zip(titles, g.axes.flatten()):
    ax.set_title(title)

g.fig.savefig('Fig4.png', bbox_inches='tight', transparent=True, dpi=300)



In [13]:
# Per-block success indicator of every agent, shape (blocks, nsub, 3) with the
# agents (EFE, IV, RC) on the last axis -- the reshape below relies on this.
# Taken from the last entry of the last trial's outcome.
outcomes = [sim.stimuli['pair_{}'.format(i)]['outcomes'] for i in range(3)]
successes = torch.stack([torch.stack([out[-1][-1] for out in per_agent], -1)
                         for per_agent in zip(*outcomes)]).float()

# Reuse the 2 x 3 (variant x context) layout of the choice-probability data.
g = sns.FacetGrid(dfs, col='context', row='variant', height=5)
for ax in g.axes[:, 0]:
    ax.set_ylabel('success rate')
for ax in g.axes[-1]:
    ax.set_xlabel('relative segment number')

axes = g.axes.flatten()
colors = sns.color_palette(palette='colorblind')[:3]
# Separate name so the `labels` array of the choice-probability cell is not clobbered.
agent_names = ['EFE', 'IV', 'RC']

# Panels 0-2 show offer types 0, 2, 4 (EV-), panels 3-5 offer types 1, 3, 5
# (EV+), matching the order of `titles`.
for c, vc in enumerate([0, 2, 4, 1, 3, 5]):
    loc = offers[:, 0, 0] == vc
    # Split the selected blocks into runs of 5 consecutive segments (presumably
    # one context period each) and average over subjects -> (K, 5, 3).
    sccs = successes[loc].reshape(-1, 5, nsub, 3).mean(-2)
    K = sccs.shape[0]
    # Exactly K line widths, evenly spaced in (0, 3): later runs are thicker.
    lws = (np.arange(K) + .5) * 3 / K
    axes[c].set_title(titles[c])
    # Dashed reference level: 1 on EV+ panels. On EV- panels 0.73728, which
    # equals rho**trials + rho**(trials-1)*(1-rho)*trials for rho=0.8 and
    # trials=5 -- TODO confirm against setup_environment.
    reference = 0.73728 if c < 3 else 1.
    axes[c].hlines(reference, 0, 4, 'k', linestyle='--', lw=3)

    for i in range(3):
        for j, lw in enumerate(lws):
            # sccs[j, ..., i] is 1-D (5 relative segments).
            axes[c].plot(sccs[j, ..., i].numpy(), c=colors[i], lw=lw)

import matplotlib as mpl

# The per-run lines carry no labels, so the legend is built from colour patches.
legend_data = {agent: mpl.patches.Patch(color=color, linewidth=1)
               for agent, color in zip(agent_names, colors)}
g.add_legend(legend_data, title='agent', label_order=agent_names)

g.fig.savefig('Fig3.png', bbox_inches='tight', dpi=300)



In [10]:
# Per-block success indicator of every agent, shape (blocks, nsub, 3) with the
# agents (EFE, IV, RC) on the last axis.
outcomes = [sim.stimuli['pair_{}'.format(i)]['outcomes'] for i in range(3)]
successes = torch.stack([torch.stack([out[-1][-1] for out in per_agent], -1)
                         for per_agent in zip(*outcomes)]).float()

# Binomial probability of at most one failure in `trials` independent draws
# with success probability rho; drawn as the reference ceiling (presumably the
# best achievable success rate -- confirm).
maximum = rho**trials + rho**(trials - 1)*(1-rho)*trials

colors = sns.color_palette(palette='colorblind')[:3]
fig, axes = plt.subplots(2, 1, figsize=(15, 10), sharex=True)
blks = np.arange(1, blocks + 1)

# (a) success rate averaged over subjects, one line per agent.
for k, agent in enumerate(['EFE', 'IV', 'RC']):
    axes[0].plot(blks, successes[..., k].mean(-1).numpy(), color=colors[k], label=agent)

axes[0].hlines(maximum, 1, blocks, 'k', linestyle='--')
axes[0].set_ylabel('success rate')
axes[0].legend(loc=8, title='agent', fontsize=12)

# (b) context sequence; dots mark segments whose offer type is 0, 2 or 4
# (the EV- variants). Converted to numpy so it can index the numpy `blks`.
offer_type = offers[:, 0, 0]
ev_minus = ((offer_type == 0) | (offer_type == 2) | (offer_type == 4)).numpy()
context_id = (context + 1).numpy()

axes[1].plot(blks, context_id, 'k')
axes[1].plot(blks[ev_minus], context_id[ev_minus], 'ko')
axes[1].set_ylabel('context')
axes[1].set_yticks([1, 2, 3])
axes[1].set_xlabel('segment')

axes[0].text(-10, 1.05, '(a)', fontsize=16)
axes[1].text(-10, 3.15, '(b)', fontsize=16)

# x-axis is shared, so limiting the bottom panel limits both.
axes[1].set_xlim([1, blocks]);



In [16]:
import matplotlib.lines as lns

# Legend handles: line style identifies the arm (1-4).
styles = [lns.Line2D([], [], color='k', linestyle='-.', label='1'),
          lns.Line2D([], [], color='k', linestyle='--', label='2'),
          lns.Line2D([], [], color='k', linestyle=':', label='3'),
          lns.Line2D([], [], color='k', label='4')]
# Matching format-string suffixes for arms 0..3 ('' = default solid line).
arm_styles = ['-.', '--', ':', '']

fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)

sub = 1  # index of the subject whose learning trajectory is shown

# Likelihood parameters recorded by each agent (`agent.a`), dropping the first
# entry (presumably the prior) and normalised over the feature axis.
a1 = torch.stack(efe_agent.a)[1:].reshape(blocks, trials, nsub, 6, 4, nf)
A1 = a1/a1.sum(-1, keepdim=True)

a2 = torch.stack(iv_agent.a)[1:].reshape(blocks, trials, nsub, 6, 4, nf)
A2 = a2/a2.sum(-1, keepdim=True)

t = np.arange(1, blocks+1)

# Offer types 1, 3, 5 (EV+): top row EFE agent, bottom row IV agent. For each
# arm, blue is the probability of feature 1 and red of feature 2, read at the
# last trial of every segment. Nothing here reassigns the global prior `a`.
for i, c in enumerate([1, 3, 5]):
    for row, probs in enumerate([A1, A2]):
        for arm, ls in enumerate(arm_styles):
            axes[row, i].plot(t, probs[:, -1, sub, c, arm, 1].numpy(), 'b' + ls)
            axes[row, i].plot(t, probs[:, -1, sub, c, arm, 2].numpy(), 'r' + ls)
    axes[0, i].set_title('context {}:EV+'.format(i+1))

axes[0, 0].legend(handles=styles, title='arms')
axes[-1, -1].set_xlim([1, blocks])
for ax in axes[-1]:
    ax.set_xlabel('segment')
for ax in axes[:, 0]:
    ax.set_ylabel('probability')

for ax, agent_label in zip(axes[:, -1], ['EFE agent', 'IV agent']):
    ax.text(1.05, .5, agent_label,
            horizontalalignment='center',
            verticalalignment='center',
            rotation=-90,
            transform=ax.transAxes)



In [17]:
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)

# Format-string suffixes for arms 0..3, matching the `styles` legend handles
# ('' = default solid line).
arm_styles = ['-.', '--', ':', '']

# Offer types 0, 2, 4 (EV-): top row EFE agent (A1), bottom row IV agent (A2).
# For each arm, blue is the probability of feature 1 and red of feature 2,
# read at the last trial of every segment for subject `sub`. Nothing here
# reassigns the global prior `a`.
for i, c in enumerate([0, 2, 4]):
    for row, probs in enumerate([A1, A2]):
        for arm, ls in enumerate(arm_styles):
            axes[row, i].plot(t, probs[:, -1, sub, c, arm, 1].numpy(), 'b' + ls)
            axes[row, i].plot(t, probs[:, -1, sub, c, arm, 2].numpy(), 'r' + ls)
    axes[0, i].set_title('context {}:EV-'.format(i+1))

axes[0, 0].legend(handles=styles, title='arms')
axes[-1, -1].set_xlim([1, blocks])
for ax in axes[-1]:
    ax.set_xlabel('segment')
for ax in axes[:, 0]:
    ax.set_ylabel('probability')

for ax, agent_label in zip(axes[:, -1], ['EFE agent', 'IV agent']):
    ax.text(1.05, .5, agent_label,
            horizontalalignment='center',
            verticalalignment='center',
            rotation=-90,
            transform=ax.transAxes)



In [ ]: