BOLD QA


In [1]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import sqlalchemy
import nibabel as nb
import os
import os.path as op
import numpy as np
from glob import glob
from dipy.segment.mask import median_otsu
from nipy.algorithms.registration import affine,Realign4d
import sys
import json
import argparse
import time
import shutil
import multiprocessing

Define some functions


In [2]:
def mask(d, raw_d=None, nskip=3):
    mn = d[:,:,:,nskip:].mean(3)
    masked_data, mask = median_otsu(mn, 3, 2)
    mask = np.concatenate((np.tile(True, (d.shape[0], d.shape[1], d.shape[2], nskip)),
                           np.tile(np.expand_dims(mask==False, 3), (1,1,1,d.shape[3]-nskip))),
                           axis=3)
    # Some runs have corrupt volumes at the end (e.g., mux scans that are stopped prematurely). Mask those too.
    # But... motion correction might have interpolated the empty slices such that they aren't exactly zero.
    # So use the raw data to find these bad volumes.
    if raw_d!=None:
        slice_max = raw_d.max(0).max(0)
    else:
        slice_max = d.max(0).max(0)
    bad = np.any(slice_max==0, axis=0)
    # We don't want to miss a bad volume somewhere in the middle, as that could be a valid artifact.
    # So, only mask bad vols that are contiguous to the end.
    mask_vols = np.array([np.all(bad[i:]) for i in range(bad.shape[0])])
    # Mask out the skip volumes at the beginning
    mask_vols[0:nskip] = True
    mask[:,:,:,mask_vols] = True
    brain = np.ma.masked_array(d, mask=mask)
    good_vols = np.logical_not(mask_vols)
    return brain,good_vols

def find_spikes(d, spike_thresh):
    slice_mean = d.mean(axis=0).mean(axis=0)
    t_z = (slice_mean - np.atleast_2d(slice_mean.mean(axis=1)).T) / np.atleast_2d(slice_mean.std(axis=1)).T
    spikes = np.abs(t_z)>spike_thresh
    spike_inds = np.transpose(spikes.nonzero())
    # mask out the spikes and recompute z-scores using variance uncontaminated with spikes.
    # This will catch smaller spikes that may have been swamped by big ones.
    d.mask[:,:,spike_inds[:,0],spike_inds[:,1]] = True
    slice_mean2 = d.mean(axis=0).mean(axis=0)
    t_z = (slice_mean - np.atleast_2d(slice_mean.mean(axis=1)).T) / np.atleast_2d(slice_mean2.std(axis=1)).T
    spikes = np.logical_or(spikes, np.abs(t_z)>spike_thresh)
    spike_inds = np.transpose(spikes.nonzero())
    return((spike_inds, t_z))

def estimate_motion(nifti_image):
    # BEGIN STDOUT SUPRESSION
    actualstdout = sys.stdout
    sys.stdout = open(os.devnull,'w')
    # We want to use the middle time point as the reference. But the algorithm does't allow that, so fake it.
    ref_vol = 7; #nifti_image.shape[3]/2 + 1
    ims = nb.four_to_three(nifti_image)
    reg = Realign4d(nb.concat_images([ims[ref_vol]] + ims)) # in the next release, we'll need to add tr=1.

    reg.estimate(loops=3) # default: loops=5
    aligned = reg.resample(0)[:,:,:,1:]
    sys.stdout = actualstdout
    # END STDOUT SUPRESSION
    abs_disp = []
    rel_disp = []
    transrot = []
    prev_T = None
    # skip the first one, since it's the reference volume
    for T in reg._transforms[0][1:]:
        # get the full affine for this volume by pre-multiplying by the reference affine
        #mc_affine = np.dot(ni.get_affine(), T.as_affine())
        transrot.append(T.translation.tolist()+T.rotation.tolist())
        # Compute the mean displacement
        # See http://www.fmrib.ox.ac.uk/analysis/techrep/tr99mj1/tr99mj1/node5.html
        # radius of the spherical head assumption (in mm):
        R = 80.
        # The center of the volume. Assume 0,0,0 in world coordinates.
        # Note: it might be better to use the center of mass of the brain mask.
        xc = np.matrix((0,0,0)).T
        T_error = T.as_affine() - np.eye(4)
        A = np.matrix(T_error[0:3,0:3])
        t = np.matrix(T_error[0:3,3]).T
        abs_disp.append(np.sqrt( R**2. / 5 * np.trace(A.T * A) + (t + A*xc).T * (t + A*xc) ).item())
        if prev_T!=None:
            T_error = T.as_affine() - prev_T.as_affine() # - np.eye(4)
            A = np.matrix(T_error[0:3,0:3])
            t = np.matrix(T_error[0:3,3]).T
            rel_disp.append(np.sqrt( R**2. / 5 * np.trace(A.T * A) + (t + A*xc).T * (t + A*xc) ).item())
        else:
            rel_disp.append(0.0)
        prev_T = T
    return aligned,np.array(abs_disp),np.array(rel_disp),np.array(transrot)

def add_subplot_axes(fig, ax, rect, axisbg='w'):
    box = ax.get_position()
    width = box.width
    height = box.height
    inax_position  = ax.transAxes.transform(rect[0:2])
    transFigure = fig.transFigure.inverted()
    infig_position = transFigure.transform(inax_position)
    x = infig_position[0]
    y = infig_position[1]
    width *= rect[2]
    height *= rect[3]  # <= Typo was here
    subax = fig.add_axes([x,y,width,height],axisbg=axisbg)


    x_labelsize = subax.get_xticklabels()[0].get_size()
    y_labelsize = subax.get_yticklabels()[0].get_size()
    x_labelsize *= rect[2]**0.5
    y_labelsize *= rect[3]**0.5
    subax.xaxis.set_tick_params(labelsize=x_labelsize)
    subax.yaxis.set_tick_params(labelsize=y_labelsize)
    return subax

def plot_motion(df):
    """Plot the timecourses of realignment parameters."""
    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

    # Trim off all but the axis name
    f = lambda s: s[-1]

    # Plot rotations
    pal = sns.color_palette("Reds_d", 3)
    rot_df = np.rad2deg(df.filter(like="rot")).rename(columns=f)
    rot_df.plot(ax=axes[0], color=pal, lw=2)

    # Plot translations
    pal = sns.color_palette("Blues_d", 3)
    trans_df = df.filter(like="trans").rename(columns=f)
    trans_df.plot(ax=axes[1], color=pal, lw=2)

    # Plot displacement
    f = lambda s: s[-3:]
    pal = sns.color_palette("Greens_d", 2)
    disp_df = df.filter(like="displace").rename(columns=f)
    disp_df.plot(ax=axes[2], color=pal, lw=2)

    # Label the graphs
    axes[0].set_xlim(0, len(df) - 1)
    axes[0].axhline(0, c=".4", ls="--", zorder=1)
    axes[1].axhline(0, c=".4", ls="--", zorder=1)

    for ax in axes:
        ax.legend(frameon=True, ncol=3, loc="best")
        ax.legend_.get_frame().set_color("white")

    axes[0].set_ylabel("Rotations (degrees)")
    axes[1].set_ylabel("Translations (mm)")
    axes[2].set_ylabel("Displacement (mm)")
    fig.tight_layout()

def plot_target():
    """Plot a mosaic of the motion correction target image."""
    m = Mosaic(res.outputs.out_file, step=1)
    return m

def plot_spikes():
    '''Plot the per-slice z-score timeseries represented by t_z.'''
    c = np.vstack((np.linspace(0,1.,t_z.shape[0]), np.linspace(1,0,t_z.shape[0]), np.ones((2,t_z.shape[0])))).T
    sl_num = np.tile(range(t_z.shape[0]), (t_z.shape[1], 1)).T
    fig = plt.figure(figsize=(16,8))
    ax1 = fig.add_subplot(211)
    for sl in range(t_z.shape[0]):
        ax1.plot(t_z[sl,:], color=c[sl,:])
    ax1.plot((0,t_z.shape[1]),(-spike_thresh,-spike_thresh),'k:')
    ax1.plot((0,t_z.shape[1]),(spike_thresh,spike_thresh),'k:')
    ax1.set_xlabel('time (TR)')
    ax1.set_ylabel('Signal Intensity (z-score)')
    ax1.axis('tight')
    ax1.grid()
    if num_spikes==1:
        ax1.set_title('Spike Plot (%d spike, tSNR=%0.2f)' % (num_spikes, median_tsnr))
    else:
        ax1.set_title('Spike Plot (%d spikes, tSNR=%0.2f)' % (num_spikes, median_tsnr))
    cbax = add_subplot_axes(fig, ax1, [.85,1.11, 0.25,0.05])
    plt.imshow(np.tile(c,(2,1,1)).transpose((0,1,2)), axes=cbax)
    cbax.set_yticks([])
    cbax.set_xlabel('Slice number')
    plt.tight_layout()
    
def plot_spikes_range(sub_range):

    '''Plot the per-slice z-score timeseries represented by t_z.'''
    c = np.vstack((np.linspace(0,1.,t_z.shape[0]), np.linspace(1,0,t_z.shape[0]), np.ones((2,t_z.shape[0])))).T
    sl_num = np.tile(range(t_z.shape[0]), (t_z.shape[1], 1)).T
    fig = plt.figure(figsize=(16,8))
    ax1 = fig.add_subplot(211)
    for sl in range(t_z.shape[0]):
        ax1.plot(sub_range, t_z[sl,sub_range], color=c[sl,:])
    ax1.set_xlabel('time (TR)')
    ax1.set_ylabel('Signal Intensity (z-score)')
    ax1.axis('tight')
    plt.tight_layout()

Experiment specific info:


In [11]:
spike_thresh = 5. # z-score threshold for spike detector
nskip = 6 # number of initial timepoints to skip

bold_file = 'scan01.nii.gz'
infile = op.join('/Volumes/group/awagner/sgagnon/AP/data/ap150/bold/raw', bold_file)

Run QA


In [12]:
ni = nb.load(infile)
tr = ni.get_header().get_zooms()[3]
dims = ni.get_shape()

In [13]:
# Remove some volumes from beginning of run, and possible corrupt vals at the end
brain,good_vols = mask(ni.get_data(), nskip=nskip)
t = np.arange(0.,brain.shape[3]) * tr

# Get the global mean signal (to be subtracted out for spike detection)
global_ts = brain.mean(0).mean(0).mean(0)

# Simple z-score-based spike detection
spike_inds,t_z = find_spikes(brain - global_ts, spike_thresh)
num_spikes = spike_inds.shape[0]

In [14]:
# Compute temporal snr on motion-corrected data
aligned,abs_disp,rel_disp,transrot = estimate_motion(ni)
brain_aligned = np.ma.masked_array(aligned.get_data(), brain.mask)

# Remove slow-drift (3rd-order polynomial) from the variance
global_ts_aligned = brain_aligned.mean(0).mean(0).mean(0)
global_trend = np.poly1d(np.polyfit(t[good_vols], global_ts_aligned[good_vols], 3))(t)
tsnr = brain_aligned.mean(axis=3) / (brain_aligned - global_trend).std(axis=3)
median_tsnr = np.ma.median(tsnr)[0]

In [15]:
# convert rotations to degrees
transrot[:,3:] *= 180./np.pi

Plot the data

Spike Plot


In [16]:
plot_spikes()



In [21]:
sub_range = range(190,200)
plot_spikes_range(sub_range)


Motion


In [18]:
import pandas as pd

pd.DataFrame([abs_disp, rel_disp]).transpose().plot()


Out[18]:
<matplotlib.axes._subplots.AxesSubplot at 0x113f19dd0>

In [31]:
np.max(rel_disp[2:227])


Out[31]:
0.52857677926518554

In [39]:
np.std(rel_disp[2:227])


Out[39]:
0.11984358058016671

In [32]:
median_tsnr


Out[32]:
81.06092515458711

Analyze group (all test runs)


In [36]:
spike_thresh = 5. # z-score threshold for spike detector
nskip = 6 # number of initial timepoints to skip
infile = op.join("/Volumes/group/awagner/sgagnon/AP/data/{subj}/bold/raw", "scan{run_num}.nii.gz")


df_mot = pd.DataFrame(columns=['subid', 'run', 'max_rel_disp', 
                               'max_abs_disp', 'num_spikes', 
                               'median_tsnr'])

for subj in ['ap100', 'ap101', 'ap102', 'ap103', 'ap104', 'ap150']:
    print subj
    for run_num in ['01', '02', '03', '04', '05', '06']:
        print run_num
        
        ni = nb.load(infile.format(subj=subj, run_num=run_num))
        tr = ni.get_header().get_zooms()[3]
        dims = ni.get_shape()
        
        # Remove some volumes from beginning of run, and possible corrupt vals at the end
        brain,good_vols = mask(ni.get_data(), nskip=nskip)
        t = np.arange(0.,brain.shape[3]) * tr

        # Get the global mean signal (to be subtracted out for spike detection)
        global_ts = brain.mean(0).mean(0).mean(0)

        # Simple z-score-based spike detection
        spike_inds,t_z = find_spikes(brain - global_ts, spike_thresh)
        num_spikes = spike_inds.shape[0]
        
        # Compute temporal snr on motion-corrected data
        aligned,abs_disp,rel_disp,transrot = estimate_motion(ni)
        brain_aligned = np.ma.masked_array(aligned.get_data(), brain.mask)

        # Remove slow-drift (3rd-order polynomial) from the variance
        global_ts_aligned = brain_aligned.mean(0).mean(0).mean(0)
        global_trend = np.poly1d(np.polyfit(t[good_vols], global_ts_aligned[good_vols], 3))(t)
        tsnr = brain_aligned.mean(axis=3) / (brain_aligned - global_trend).std(axis=3)
        median_tsnr = np.ma.median(tsnr)[0]
        
        row = {'subid': subj,
               'run': 'run_' + run_num,
               'max_rel_disp': np.max(rel_disp[2:227]),
               'mean_rel_disp': np.mean(rel_disp[2:227]),
               'std_rel_disp': np.std(rel_disp[2:227]),
               'max_abs_disp': np.max(abs_disp[2:227]),
               'num_spikes': num_spikes,
               'median_tsnr': median_tsnr}
        df_mot = df_mot.append(row, ignore_index=True)

In [37]:
df_mot.head()


Out[37]:
subid run max_rel_disp max_abs_disp num_spikes median_tsnr
0 ap100 run_01 0.199330 0.379180 1 104.895042
1 ap100 run_02 0.188954 0.503407 0 107.207038
2 ap100 run_03 0.222950 0.342913 0 92.634170
3 ap100 run_04 0.281031 0.451853 0 82.511517
4 ap100 run_05 0.372188 0.803519 0 79.767119

In [41]:
import seaborn as sns
sns.factorplot(x='subid', y='max_rel_disp', ci=68, data=df_mot)


Out[41]:
<seaborn.axisgrid.FacetGrid at 0x110ff7290>

In [43]:
sns.factorplot(x='subid', y='max_abs_disp', ci=68, data=df_mot)


Out[43]:
<seaborn.axisgrid.FacetGrid at 0x10ff7fe10>

In [42]:
sns.factorplot(x='subid', y='num_spikes', ci=68, data=df_mot)


Out[42]:
<seaborn.axisgrid.FacetGrid at 0x1133127d0>

In [44]:
sns.factorplot(x='subid', y='median_tsnr', ci=68, data=df_mot)


Out[44]:
<seaborn.axisgrid.FacetGrid at 0x10ff11b50>

Alternate motion correction analysis


In [99]:
from nipype.interfaces import fsl
import pandas as pd
import seaborn as sns

sns.set(style='whitegrid', context='poster')

In [173]:
mcflt = fsl.MCFLIRT(in_file=infile, cost="normcorr", ref_vol=7,
                    interpolation="spline",
                    save_mats=True,
                    save_rms=True,
                    save_plots=True)
res = mcflt.run()


INFO:interface:stderr 2015-05-07T13:18:21.709479:refnum = 7
INFO:interface:stderr 2015-05-07T13:18:21.709479:Original_refvol = 7

In [174]:
# Load the realignment parameters
rot = ["rot_" + dim for dim in ["x", "y", "z"]]
trans = ["trans_" + dim for dim in ["x", "y", "z"]]
df = pd.DataFrame(np.loadtxt(res.outputs.par_file),
                  columns=rot + trans)

In [175]:
# Load the RMS displacement parameters
abs, rel = res.outputs.rms_files
df["displace_abs"] = np.loadtxt(abs)
df["displace_rel"] = pd.Series(np.loadtxt(rel), index=df.index[1:])
df.loc[0, "displace_rel"] = 0

In [176]:
df[['rot_x', 'rot_y', 'rot_z']].plot()


Out[176]:
<matplotlib.axes._subplots.AxesSubplot at 0x133310b10>

In [177]:
df.columns


Out[177]:
Index([u'rot_x', u'rot_y', u'rot_z', u'trans_x', u'trans_y', u'trans_z', u'displace_abs', u'displace_rel'], dtype='object')

In [178]:
df[['trans_x', 'trans_y', 'trans_z']].plot()


Out[178]:
<matplotlib.axes._subplots.AxesSubplot at 0x13393ad90>

In [179]:
plot_motion(df)



In [ ]: