In [1]:
# Render matplotlib figures as static images inside the notebook
%matplotlib inline
# Integrate a Qt event loop with the kernel; presumably needed so the PySurfer
# `Brain` windows created below can be opened (TODO confirm backend requirement)
%gui qt
In [2]:
import os.path as op
import numpy as np
from scipy import stats
import nibabel as nib
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from surfer import Brain
from scipy.ndimage import binary_erosion
In [3]:
import lyman

# Subject list and analysis root come from the lyman project configuration
subjects = lyman.determine_subjects()
project_info = lyman.gather_project_info()
analysis_dir = project_info["analysis_dir"]
In [4]:
from surfer import Brain
In [5]:
# Seaborn "ticks" style at "paper" scale for all figures
sns.set(context="paper", style="ticks")
# Default resolution for figures saved without an explicit dpi
mpl.rcParams.update({"savefig.dpi": 200})
In [6]:
# Hemisphere labels, iterated in this order by every per-hemisphere loop below
hemis = "lh rh".split()
The searchlight computation itself is handled by an external script (`process_searchlight.py`), whose per-subject results are cached to disk. You can activate the cell below and run it.
Load the cached per-subject searchlight accuracy maps
In [7]:
# Per-hemisphere lists of subject searchlight accuracy maps, appended in the
# order of `subjects` (each map squeezed; presumably 1-D over vertices).
group = {hemi: [] for hemi in hemis}
temp = op.join(analysis_dir, "dksort/{subj}/mvpa/searchlight/{hemi}.dimension_dksort_pfc.mgz")
for subj in subjects:
    for hemi in hemis:
        img = nib.load(temp.format(subj=subj, hemi=hemi))
        # get_data() was deprecated in nibabel 3.0 and removed in 5.0;
        # get_fdata() returns the image values as a floating-point array.
        group[hemi].append(img.get_fdata().squeeze())
Run a one-sample group t-test of decoding accuracy against chance (1/3) at each vertex
In [8]:
# Chance-level accuracy the group is tested against (1/3)
chance = 1. / 3

group_accs = dict()   # hemi -> stacked (n_subjects, n_vertices) accuracies
group_means = dict()  # hemi -> mean accuracy at each vertex
group_masks = dict()  # hemi -> True where every subject has a nonzero value
group_ts = dict()     # hemi -> one-sample t statistic vs. chance
vertex_accs = []      # mean accuracies of in-mask vertices, pooled over hemis
for hemi in hemis:
    accs = np.vstack(group[hemi])
    means = accs.mean(axis=0)
    # Vertices where no subject's value is zero (presumably the vertices
    # covered by every subject's searchlight — TODO confirm)
    masks = accs.all(axis=0)
    # One-sample t: population sd / sqrt(n - 1) equals sample sd / sqrt(n).
    # Zero-variance vertices (e.g. all-zero outside the coverage) divide by
    # zero; silence those warnings here and clean the result with nan_to_num.
    with np.errstate(divide="ignore", invalid="ignore"):
        ts = (means - chance) / (accs.std(axis=0) / np.sqrt(len(accs) - 1))
    ts = np.nan_to_num(ts)
    vertex_accs.extend(means[masks])
    group_accs[hemi] = accs
    group_means[hemi] = means
    group_masks[hemi] = masks
    group_ts[hemi] = ts
Set a one-tailed threshold on the group t statistic (alpha = 0.005) that determines which vertices are shown; the surface overlay itself displays mean accuracy, not t
In [9]:
# One-tailed significance level for the vertexwise group t-test
alpha = 0.005
# Degrees of freedom follow the number of subjects entering the t-test
# (previously hard-coded to 14, which is only right for exactly 15 subjects)
dof = len(subjects) - 1
# Upper-tail critical t value; equivalent to -ppf(alpha) by symmetry
thresh = stats.t(dof).isf(alpha)
Use PySurfer to plot the data on the surface
In [10]:
# Render each hemisphere with PySurfer and keep a screenshot for the figure
views = dict()
for hemi in hemis:
    b = Brain("fsaverage", hemi, "semi7",
              config_opts={"background": "white", "width": 500, "height": 420})
    # Copy before thresholding so group_means is not modified in place;
    # otherwise this cell silently changes state used by other cells.
    data = group_means[hemi].copy()
    data[group_ts[hemi] < thresh] = 0
    # Mean accuracy overlay; thresh=0.1 keeps the zeroed sub-threshold
    # vertices from being drawn
    b.add_data(data, min=0.3, max=0.5, thresh=0.1, colormap="OrRd_r", colorbar=False)
    # Semi-transparent shading on vertices outside the group mask
    b.add_data(~group_masks[hemi], min=.5, max=1.05, thresh=.5, alpha=.5,
               colormap="bone_r", colorbar=False)
    b.add_label("yeo17_ifs", borders=True, color=".3")
    b.show_view(dict(elevation=80,
                     azimuth=dict(lh=150, rh=30)[hemi],
                     focalpoint=[0, 10, 10]), distance=325)
    views[hemi] = b.screenshot()
    b.close()
Grab an example slice through one subject's mean functional image and searchlight mask
In [11]:
# Slice index (along the third voxel axis) and example subject for panel B
slc = 12
subj_dir = op.join(analysis_dir, "dksort/dk11")

# get_fdata() replaces the removed get_data() and always returns floats, so
# the NaN assignments below work even if a file is stored as integers.
# Each slice is transposed for display with imshow.
epi = nib.load(op.join(subj_dir, "preproc/run_1/mean_func.nii.gz")).get_fdata()
epi = epi[..., slc].T

# Functional mask, eroded (2 iterations) to trim the edge of the displayed slice
mask = nib.load(op.join(subj_dir, "preproc/run_1/functional_mask.nii.gz")).get_fdata()
mask = binary_erosion(mask[..., slc].astype(bool).T, iterations=2)

# Searchlight ROI mask
roi = nib.load(op.join(subj_dir, "mvpa/searchlight/dksort_pfc_mask.nii.gz")).get_fdata()
roi = roi[..., slc].T

# NaN values are left undrawn by imshow, so only in-mask EPI voxels and
# ROI voxels (>= .5) appear in the figure
epi[~mask] = np.nan
roi[roi < .5] = np.nan
Plot the data and assemble the multi-panel figure (Figure 5)
In [12]:
fig = plt.figure(figsize=(3.34, 2.8))

# Panel A: surface screenshots of the thresholded mean accuracy
# (right hemisphere drawn on the left of the figure, left hemisphere on the right)
rh_ax = fig.add_axes([-.04, .5, .54, .5])
rh_ax.imshow(views["rh"])
rh_ax.set_axis_off()
lh_ax = fig.add_axes([.5, .5, .54, .5])
lh_ax.imshow(views["lh"])
lh_ax.set_axis_off()

# Colorbar for the surface overlay; the end labels must match the min/max
# passed to Brain.add_data when the views were rendered (0.3 and 0.5)
with mpl.rc_context({"axes.linewidth": .4}):
    cbar_ax = fig.add_axes([.35, .49, .3, .035])
    cbar_ax.pcolormesh(np.atleast_2d(np.linspace(0, 1, 100)), cmap="OrRd_r")
    cbar_ax.set(xticks=[], yticks=[])
fig.text(.34, .505, "0.3", ha="right", va="center", size=7)
fig.text(.66, .505, "0.5", ha="left", va="center", size=7)
fig.text(.5, .53, "Mean accuracy", ha="center", va="bottom", size=7)

# Panel B: example slice through the mean functional with the searchlight ROI
mask_ax = fig.add_axes([-.04, 0, .45, .45])
mask_ax.imshow(epi, cmap="Greys_r")
roi_cmap = mpl.colors.ListedColormap(["steelblue"])
mask_ax.imshow(roi, cmap=roi_cmap, alpha=.7, interpolation="nearest")
mask_ax.set_axis_off()

# Panel C: distribution of group-mean accuracy over in-mask vertices.
# np.histogram ignores values outside the [0.3, 0.4] bin range.
# Valid rc keys are "xtick.*"/"ytick.*"; "xticks.*" raises KeyError.
with mpl.rc_context({"axes.linewidth": .5,
                     "xtick.major.width": .5,
                     "ytick.major.width": .5}):
    hist_ax = fig.add_axes([.38, .12, .56, .33])
    counts, bins = np.histogram(vertex_accs, np.linspace(.3, .4, 25))
    hist_pal = sns.color_palette("OrRd_r", 50)[:len(counts)]
    # Bars start at their left bin edge (explicit align="edge", since the
    # default is "center") and fill ~80% of the bin width
    bin_width = bins[1] - bins[0]
    hist_ax.bar(bins[:-1], counts, width=.8 * bin_width, align="edge",
                color=hist_pal, alpha=.8)
    hist_ax.set(yticks=[], xlim=(.3, .4))
    hist_ticks = np.linspace(.3, .4, 6)
    hist_ax.set_xticks(hist_ticks)
    # Format labels explicitly to avoid floating-point noise from linspace
    hist_ax.set_xticklabels(["{:g}".format(t) for t in hist_ticks], size=7)
    hist_ax.set_xlabel("Cross-validated decoding accuracy", labelpad=1.5)
    # Chance level (1/3)
    hist_ax.axvline(x=1. / 3, ymax=1, ls=":", c=".2")
    sns.despine(ax=hist_ax, left=True)

# Panel letters
fig.text(.02, .94, "A", size=12)
fig.text(.02, .41, "B", size=12)
fig.text(.36, .41, "C", size=12)
fig.savefig("figures/figure_5.tiff", dpi=300)
fig.savefig("figures/figure_5.pdf", dpi=300)
In [ ]: