First, we import the required modules and functions, and define a helper function which will be used for Figures 1 and 2.
In [1]:
%matplotlib inline
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpl_patches
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib import cm
from matplotlib.text import TextPath
from scipy import interpolate
from pydlv import data_reader, dl_model_3, dl_generator, dl_plotter, trajectory_plotter, data_analyser
mpl.rcParams['figure.dpi'] = 300
def get_choice_patches():
screen_patch = mpl_patches.Rectangle((-1.1, -0.1), 2.2, 1.3, fill=False, lw=3,
edgecolor='black')
left_patch = mpl_patches.Rectangle((-1.05, 0.85), 0.3, 0.3, fill=True, lw=1,
facecolor='white', edgecolor='black', alpha=0.2)
right_patch = mpl_patches.Rectangle((0.75, 0.85), 0.3, 0.3, fill=True, lw=1,
facecolor='white', edgecolor='black', alpha=0.2)
left_text = mpl_patches.PathPatch(TextPath((-1.0, 0.9), 'A', size=0.3,
backgroundcolor='white'))
right_text = mpl_patches.PathPatch(TextPath((0.8, 0.9), 'B', size=0.3))
return screen_patch, left_patch, right_patch, left_text, right_text
Create a directory for the figures if it is not already there
In [2]:
if not os.path.exists('figures'):
os.makedirs('figures')
Then we read the data from O'Hora et al. (2013) (see DLV tutorial on how to get and prepare the data) and create the objects we use in preparing the figures.
In [3]:
dr = data_reader.DataReader()
data = dr.get_processed_data(path='csv/processed_data_high_low.csv')
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model)
da = data_analyser.DataAnalyser()
In [4]:
def plot_sample_trajectory(data, subj_id, trial_no):
x_lim = [-1.2, 1.2]
y_lim = [-0.2, 1.2]
tickLabelFontSize = 20
axisLabelFontSize = 24
sns.set_style('white')
fig = plt.figure(tight_layout=True)
ax = fig.add_subplot(111)
ax.set_xlabel(r'x coordinate', fontsize=axisLabelFontSize)
ax.set_ylabel(r'y coordinate', fontsize=axisLabelFontSize)
ax.set_xlim(x_lim)
ax.set_ylim(y_lim)
ax.tick_params(axis='both', which='major', labelsize=tickLabelFontSize)
traj_color = cm.viridis(0.1)
trajectory = data[(data.subj_id==subj_id) & (data.trial_no==trial_no)]
ax.plot(trajectory.x, trajectory.y, color=traj_color, ls='none', marker='o', ms=15,
markerfacecolor='none', markeredgewidth=2, markeredgecolor=traj_color,
label='Mouse trajectory')
# draw screen above the surface and choice options on it
patches = get_choice_patches()
for patch in patches:
ax.add_patch(patch)
ax.set_axis_off()
plt.savefig('figures/sample_traj.pdf')
plot_sample_trajectory(data=data, subj_id=8792, trial_no=13)
In [5]:
def plot_baseline_landscape_overlay(dlg, data, subj_id, trial_no):
dlp = dl_plotter.DLPlotter(elev=55, azim=-65)
x_grid, y_grid, dl = dlg.get_model_dl(dlg.model.get_baseline_params()*4)
x=(x_grid[1:]+x_grid[:-1])/2
y=(y_grid[1:]+y_grid[:-1])/2
f = interpolate.interp2d(x,y,dl,kind='cubic')
ax = dlp.plot_surface(x_grid, y_grid, dl, cmap=cm.viridis, alpha=0.9)
ax.set_xticks([-1, -0.5, 0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ax.w_xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax.w_zaxis.set_ticklabels([])
sns.set_style('white')
cmap = cm.viridis
traj_color = cmap(0.1)
trajectory = data[(data.subj_id==subj_id) & (data.trial_no==trial_no)]
z = f(trajectory.x.values, trajectory.y.values)
if trajectory.x.values[-1]>0:
z= np.diag(z)
else:
z=np.diag(np.fliplr(z))
# plot trajectory on a surface
ax.plot(trajectory.x.values, trajectory.y.values, z, color='black', alpha=0.5)
# plot marble
ax.plot([0.], [0.], [0.], marker='o', markersize=15, color = 'black', alpha=0.7)
ax.plot([trajectory.x.values[-1]], [trajectory.y.values[-1]], [z[-1]],
marker='o', markersize=15, color='black', alpha=0.7)
#
# draw screen above the surface and choice options on it
patches = get_choice_patches()
for patch in patches:
ax.add_patch(patch)
art3d.pathpatch_2d_to_3d(patch, z=0, zdir='z')
# plot trajectory on a screen
ax.plot(trajectory.x, trajectory.y, zs=0, zdir='z', color=traj_color, ls='none',
alpha=1.0, marker='o', ms=15, markerfacecolor='none', markeredgewidth=2,
markeredgecolor=traj_color, label='Mouse trajectory')
plt.savefig('figures/baseline_dlv.pdf')
plot_baseline_landscape_overlay(dlg=dlg, data=data, subj_id=8792, trial_no=13)
Figure 3 was assembled in PowerPoint, so there is no source code for generating it.
In [6]:
def plot_dl_variations(dlg, param_names, param_values, colors):
labels = [plt.Rectangle((0, 0), 1, 1, fc=color) for color in colors]
for i, param_name in enumerate(param_names):
dlp = dl_plotter.DLPlotter(elev=30, azim=120)
if i>0:
param_range = np.linspace(param_values[i][0], param_values[i][1], 3)
else:
param_range = np.array(param_values[i])
current_params = model.get_baseline_params()
for n, param_value in enumerate(param_range):
current_params[i] = param_value
x_grid, y_grid, dl = dlg.get_model_dl(current_params)
scale_z = True if n==1 else False
dlp.plot_surface(x_grid, y_grid, dl, color=colors[n], alpha=0.7, scale_z=scale_z)
title_format = r'$\%s$' if param_name=='tau' else r'$%s$'
dlp.ax.set_title(title_format % (param_name), loc='left', fontsize=42)
dlp.ax.legend(labels, param_range, fontsize=32)
plt.savefig('figures/param_variations_%s.pdf' % param_name)
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model, x_lim = [-1.25, 1.25], y_lim = [0.0, 1.25])
param_names = ['tau', 'c_{11}', 'c_{21}', 'c_{12}']
k = 0.2
param_values = [[0.04, 0.05, 0.07], [-k*1.0, k*1.0], [-k*1.0, k*1.0], [-k*1.0, k*1.0]]
colors = [cm.inferno(0.2), cm.Greys(0.7), cm.inferno(0.8)]
plot_dl_variations(dlg, param_names, param_values, colors)
For Figures 5 to 8 we need the csv files with the fitted parameters. Let's generate those (you can also run demos/fit_dl_to_data.py
). Not that the following code takes a while to execute, so be patient. This is because the parameters are generated for every subject, not only those illustrated in the figures.
In [7]:
def get_fit_parameters(data, methods=[9], by='subject', csv_path='csv'):
'''
The "methods" parameter defines which optimization routines are used to fit the model
parameters to the trajectories (see dl_generator.py for details). The recommended
methods are 6 (L-BFGS-B) or 9 (SLSQP). Multiple method codes can be supplied to find best-fit
parameters by several methods.
The "by" defines whether the decision landscape is fitted to each trial
individually (by='trial'), to blocks of trials (by='block'),
or to all trajectories of each subject (by='subject')
'''
if not os.path.exists(csv_path):
os.makedirs(csv_path)
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model)
for method in methods:
print('By %s, method %i' % (by, method))
if by == 'trial':
fit_dl = lambda traj: dlg.fit_dl_single_traj(traj, method)
params = data.groupby(by=['subj_id', 'trial_no']).apply(fit_dl)
params.index = params.index.droplevel(-1)
elif by == 'block':
fit_dl = lambda trajs: dlg.fit_dl_mult_traj(trajs, method)
params = data.groupby(by=['subj_id', 'block_no']).apply(fit_dl)
params.index = params.index.droplevel(-1)
elif by == 'subject':
fit_dl = lambda trajs: dlg.fit_dl_mult_traj(trajs, method)
params = data.groupby(by='subj_id').apply(fit_dl)
params.index = params.index.droplevel(-1)
print('By %s, method %i, median error %f' % (by, method, params.error.median()))
file_name = csv_path + '/fit_params_by_%s_method_%i.csv' % (by, method)
params.to_csv(file_name)
dr = data_reader.DataReader()
data = dr.get_processed_data(path='csv/processed_data_high_low.csv')
get_fit_parameters(data, by='trial')
get_fit_parameters(data, by='block')
get_fit_parameters(data, by='subject')
Now that parameters are generated, use them to visualise decision landscapes
In [8]:
def plot_surfaces(dlg, params, subj_id, trials, colors):
dlp = dl_plotter.DLPlotter(elev=10, azim=61)
for i, trial_no in enumerate(trials):
x, y, dl = dlg.get_model_dl(params.loc[subj_id, trial_no][2:2+dlg.model.n_params])
dlp.plot_surface(x, y, dl, color=colors[i], alpha=0.8)
dlp.add_legend(colors, trials)
plt.savefig('figures/trials_%i_dlv.pdf' % (subj_id))
def plot_trajectories(data, subj_id, trials, colors):
tp = trajectory_plotter.TrajectoryPlotter()
for i, trial_no in enumerate(trials):
trial_data = data[(data.subj_id==subj_id) & (data.trial_no==trial_no)]
tp.plot_trajectory(trial_data, color=colors[i], label=trial_no)
print(trial_data .iloc[0][['high_chosen', 'motion_time', 'max_d']])
tp.add_legend()
plt.savefig('figures/trials_%i_traj.pdf' % (subj_id))
def compare_dlv(subj_id, trials):
fit_params = pd.read_csv('csv/fit_params_by_trial_method_9.csv',
index_col=['subj_id', 'trial_no'], header=0)
cmap = cm.viridis
colors = [cmap(0.7), cmap(0.35), cmap(0.1)]
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model)
plot_surfaces(dlg, fit_params, subj_id, trials, colors)
dr = data_reader.DataReader()
data = dr.get_processed_data(path='csv/processed_data_high_low.csv')
plot_trajectories(data, subj_id, trials, colors)
subj_id = 90
trials = [28, 2, 15]
compare_dlv(subj_id=subj_id, trials=trials)
In [9]:
def plot_surfaces(dlg, fit_params, subj_id, blocks, colors, labels):
dlp = dl_plotter.DLPlotter(elev=10, azim=69)
for i, block_no in enumerate(blocks):
x, y, dl = dlg.get_model_dl(fit_params.loc[subj_id, block_no][2:2+dlg.model.n_params])
dlp.plot_surface(x, y, dl, color=colors[i], alpha=0.8)
dlp.add_legend(colors, labels)
plt.savefig('figures/blocks_%i_dlv.pdf' % (subj_id))
def plot_trajectories(data, subj_id, blocks, colors, labels):
tp = trajectory_plotter.TrajectoryPlotter()
for i, block_no in enumerate(blocks):
block_trajectories = data[(data.subj_id==subj_id) & (data.block_no==block_no)]
tp.plot_mean_trajectories(block_trajectories, colors[i], labels[i])
block_info = block_trajectories.groupby('trial_no').first().groupby('high_chosen') \
.mean()[['motion_time', 'max_d']]
print('\n %s\n' % (labels[i]))
print(block_info)
tp.add_legend_mean_traj(colors, labels)
plt.savefig('figures/blocks_%i_traj.pdf' % (subj_id))
def compare_dlv(subj_id, blocks):
fit_params = pd.read_csv('csv/fit_params_by_block_method_9.csv',
index_col=['subj_id', 'block_no'], header=0)
labels = ['Block %i' % (block) for block in blocks]
cmap = cm.viridis
colors = [cmap(0.7), cmap(0.35), cmap(0.1)]
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model)
plot_surfaces(dlg, fit_params, subj_id, blocks, colors, labels)
dr = data_reader.DataReader()
data = dr.get_processed_data(path='csv/processed_data_high_low.csv')
plot_trajectories(data, subj_id, blocks, colors, labels)
da = data_analyser.DataAnalyser()
stats = da.get_block_stats(data)
print('\n %s\n' % ('Block stats'))
print(stats.loc[subj_id])
#subj_id = 233
subj_id = 1334
blocks = [1, 2, 3]
compare_dlv(subj_id, blocks)
In [10]:
def plot_surfaces(dlg, fit_params, subjects, colors, labels):
dlp = dl_plotter.DLPlotter(elev=33, azim=107)
for i, subj_id in enumerate(subjects):
x, y, dl = dlg.get_model_dl(fit_params.loc[subj_id][2:2+dlg.model.n_params])
dlp.plot_surface(x, y, dl, color=colors[i], alpha=0.8)
dlp.add_legend(colors, labels)
plt.savefig('figures/subjects_%i_%i_dlv.pdf' % (subjects[0], subjects[1]))
def plot_trajectories(data, subjects, colors, labels):
tp = trajectory_plotter.TrajectoryPlotter()
for i, subj_id in enumerate(subjects):
subj_trajectories = data[(data.subj_id==subj_id)]
tp.plot_mean_trajectories(subj_trajectories, colors[i], labels[i])
subj_info = subj_trajectories.groupby('trial_no').first().groupby('high_chosen') \
.mean()[['motion_time', 'max_d']]
print('\n %s\n' % (labels[i]))
print(subj_info)
tp.add_legend_mean_traj(colors, labels)
plt.savefig('figures/subjects_%i_%i_traj.pdf' % (subjects[0], subjects[1]))
def compare_dlv(subjects):
fit_params = pd.read_csv('csv/fit_params_by_subject_method_9.csv',
index_col='subj_id', header=0)
labels = ['Participant %i' % (subj_id) for subj_id in subjects]
cmap = cm.viridis
colors = [cmap(0.1), cmap(0.7)]
model = dl_model_3.DLModel3()
dlg = dl_generator.DLGenerator(model)
plot_surfaces(dlg, fit_params, subjects, colors, labels)
dr = data_reader.DataReader()
data = dr.get_processed_data(path='csv/processed_data_high_low.csv')
plot_trajectories(data, subjects, colors, labels)
da = data_analyser.DataAnalyser()
stats = da.get_subjects_stats(data)
print('\n %s\n' % ('Subject stats'))
print(stats.loc[subjects])
compare_dlv(subjects=[9276, 9424])
In [11]:
compare_dlv(subjects=[1395, 1962])