Experimental interactive visualisation, written by B. Soergel for the Autism Gradients project at Brainhack Cambridge 2017.


In [1]:
%matplotlib inline
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import colors
import nilearn
import nilearn.plotting
from __future__ import print_function,division

In [2]:
#copied from Richard's notebook
def rebuild_nii(num):

    data = np.load('Mean_Vec.npy')
    a = data[:,num].copy()
    nim = nib.load('cc400_roi_atlas.nii')
    imdat=nim.get_data()
    imdat_new = imdat.copy()

    for n, i in enumerate(np.unique(imdat)):
        if i != 0:
            imdat_new[imdat == i] = a[n-1] * 100000 # scaling factor. Could also try to get float values in nifti...

    nim_out = nib.Nifti1Image(imdat_new, nim.get_affine(), nim.get_header())
    nim_out.set_data_dtype('float32')
    # to save:
    nim_out.to_filename('Gradient_'+ str(num) +'_res.nii')

    nilearn.plotting.plot_epi(nim_out)
    return(nim_out)

In [3]:
# Rebuild, save and plot every gradient column, keeping all images in a list
# indexed by gradient number. mmap_mode='r' reads only the .npy header here.
n_gradients = np.load('Mean_Vec.npy', mmap_mode='r').shape[1]
nims = [rebuild_nii(i) for i in range(n_gradients)]



In [ ]:


In [ ]:

The actual brain-visualisation widget


In [4]:
from ipywidgets import widgets
from ipywidgets import interact,fixed
from IPython.display import display

In [5]:
def rebuild_nii_compute(data,nim,num,scale=100000):
    """
    Precompute the gradient image for column `num` (no saving, no plotting)
    so the interactive widgets below only have to draw.

    Parameters
    ----------
    data : ndarray
        Gradient matrix with one row per non-zero atlas label, in ascending
        label order.
    nim : nibabel image
        Atlas parcellation; label 0 is treated as background.
    num : int
        Column of `data` to project.
    scale : float
        Multiplier applied to the values (kept for backward compatibility).

    Returns
    -------
    (nim_out, imdat_new) : the Nifti1Image and its float32 voxel array.

    Raises
    ------
    ValueError
        If the atlas has more non-zero labels than `data` has rows.
    """
    values = np.asarray(data[:, num], dtype=np.float64)
    # np.asanyarray(img.dataobj) replaces img.get_data(), which nibabel removed.
    atlas = np.asanyarray(nim.dataobj)

    labels, inverse = np.unique(atlas, return_inverse=True)
    is_roi = labels != 0
    n_rois = int(is_roi.sum())
    if n_rois > values.shape[0]:
        raise ValueError('atlas has %d ROIs but only %d values were given'
                         % (n_rois, values.shape[0]))
    # Lookup table: entry k is the output value for labels[k]; one gather
    # replaces a full-volume comparison per ROI.
    lut = np.zeros(labels.shape[0], dtype=np.float64)
    lut[is_roi] = values[:n_rois] * scale
    # Float output avoids truncating values into an integer-typed atlas copy.
    imdat_new = lut[inverse].reshape(atlas.shape).astype(np.float32)

    nim_out = nib.Nifti1Image(imdat_new, nim.affine, nim.header)
    nim_out.set_data_dtype('float32')
    return nim_out,imdat_new

def rebuild_nii_plot(num,nims,cutc_x,cutc_y,cutc_z):
    """
    Widget callback: draw the precomputed image nims[num] with nilearn's
    plot_epi, passing (cutc_x, cutc_y, cutc_z) as its cut_coords.
    """
    selected = nims[num]
    nilearn.plotting.plot_epi(selected, cut_coords=(cutc_x, cutc_y, cutc_z))
    plt.show()
    
def rebuild_nii_plot_new(num,imdats_new,cutc_x,cutc_y,cutc_z,cmap,figsize):
    """
    Widget callback: draw three orthogonal voxel slices of imdats_new[num]
    with plain matplotlib (an alternative to the nilearn-based callback).

    Parameters
    ----------
    num : int
        Index of the precomputed volume to show.
    imdats_new : sequence of 3-D ndarrays
        Precomputed gradient volumes.
    cutc_x, cutc_y, cutc_z : int
        Voxel indices of the three cuts.
    cmap : str
        Matplotlib colormap name.
    figsize : tuple
        Figure size passed to plt.subplots.

    All three panels share one colour scale (min/max of the whole volume), so
    equal colours mean equal values across panels.
    """
    volume = imdats_new[num]
    # Without explicit limits each imshow autoscales to its own slice.
    vmin, vmax = np.nanmin(volume), np.nanmax(volume)
    panels = (
        (volume[:, cutc_y, :], 'y', cutc_y),
        (volume[cutc_x, :, :], 'x', cutc_x),
        (volume[:, :, cutc_z], 'z', cutc_z),
    )
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (section, label, index) in zip(axes, panels):
        # Transpose so the slice's second array axis runs vertically;
        # origin='lower' puts index 0 at the bottom.
        ax.imshow(section.T, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
        ax.set_title('%s = %d' % (label, index))
    plt.show()

In [6]:
#prepare data
data = np.load('Mean_Vec.npy')
nim = nib.load('cc400_roi_atlas.nii')
# np.asanyarray(img.dataobj) replaces img.get_data(), which nibabel removed;
# it keeps the on-disk dtype, as get_data() did.
imdat = np.asanyarray(nim.dataobj)
print(data.shape,imdat.shape)


(392, 10) (63, 75, 61)

In [7]:
#precompute to speed up widget
# One (image, voxel array) pair per gradient column of `data`.
computed = [rebuild_nii_compute(data, nim, j) for j in range(data.shape[1])]
nims = [pair[0] for pair in computed]
imdats = [pair[1] for pair in computed]

In [8]:
#testing plotting function
first_gradient = nims[0]
nilearn.plotting.plot_epi(first_gradient)  # cut_coords left for nilearn to choose
nilearn.plotting.plot_epi(first_gradient, cut_coords=(58, -15, -20))  # cut_coords chosen by hand


Out[8]:
<nilearn.plotting.displays.OrthoSlicer at 0x10ff68a90>

In [9]:
#build widget
# Half-range of the cut-coordinate sliders (value chosen by hand).
coords_max = 60
i = interact(rebuild_nii_plot,
             # one slider position per precomputed image instead of a fixed 0..9
             num=(0, len(nims) - 1),
             nims=fixed(nims),
             cutc_x=(-coords_max,coords_max),
             cutc_y=(-coords_max,coords_max),
             cutc_z=(-coords_max,coords_max),
            )


Another attempt, using a pure-matplotlib version of the plotting function


In [10]:
# Slider ranges are voxel indices, so they come from the volume shape.
vol_shape = imdats[0].shape
figsize=(15,10)
i = interact(rebuild_nii_plot_new,
             num=(0, len(imdats) - 1),
             imdats_new=fixed(imdats),
             cutc_x=(0, vol_shape[0] - 1),
             cutc_y=(0, vol_shape[1] - 1),
             cutc_z=(0, vol_shape[2] - 1),
             cmap = ['jet','viridis','gray'],
             figsize=fixed(figsize)
            )



In [12]:
epi_img = nib.load('./ROIs_Mask/someones_epi.nii.gz')
# np.asanyarray(img.dataobj) replaces img.get_data(), which nibabel removed.
epi_data = np.asanyarray(epi_img.dataobj)

In [13]:
def plot_projection(data,axis,pos0,pos1,pos2,cmap,figsize):
    if axis=='x':
        imdat = data[pos0,:,:]
    elif axis=='y':
        imdat = data[:,pos1,:]
    elif axis=='z':
        imdat = data[:,:,pos2]
    
    fig,ax = plt.subplots(figsize=figsize)
    ax.imshow(imdat,cmap=cmap, origin="lower")
    plt.show()

In [14]:
#some settings for widget
# NOTE: this rebinds `data` (the gradient matrix loaded from Mean_Vec.npy
# earlier) to the EPI volume; the widget cell below reads `data`.
data = epi_data
figsize=(6,6)
data.shape  # last expression of the cell, so the volume shape is displayed

In [15]:
# One index slider per spatial axis, each spanning that axis' full extent.
slider_ranges = {'pos%d' % k: (0, data.shape[k] - 1) for k in range(3)}
i = interact(
    plot_projection,
    data=fixed(data),
    axis=['x', 'y', 'z'],
    cmap=['gray', 'jet', 'viridis'],
    figsize=fixed(figsize),
    **slider_ranges
)



In [16]:
def show_slices_grid(slices,gridsize):
    """ Function to display row of image slices """
    fig, axes = plt.subplots(gridsize, len(slices)//gridsize+1,figsize=(12,12))
    for i, slice in enumerate(slices):
        row = i//gridsize
        column = i%gridsize
        axes[row][column].imshow(slice.T, cmap="gray", origin="lower")
    fig.tight_layout()
    
n_axial = epi_data.shape[2]
slices = [epi_data[:, :, k] for k in range(n_axial)]
show_slices_grid(slices, 6)



In [ ]:


In [ ]: