In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 29 15:26:06 2013

@author: ruifang
"""

#################DECLARATION##############
import nipype.interfaces.spm as spm         # the spm interfaces
import nipype.interfaces.spm.utils as spmu
import nipype.interfaces.fsl as fsl          # fsl
import nipype.pipeline.engine as pe         # the workflow and node wrappers
import nipype.interfaces.matlab as mlab
mlab.MatlabCommand.set_default_paths('/SCR/material/spm8/spm8')

# Tell fsl to generate all output in uncompressed nifti format
fsl.FSLCommand.set_default_output_type('NIFTI')
import os
import re
import glob    
from nipype.interfaces.io import DataSink
import os.path as op

from nipype.algorithms.misc import TSNR      #Computes the time-course SNR for a time series
import nipype.interfaces.utility as util
##########################################



In [ ]:
subjects = ["M5","M12"]
fun_file = dict()
anat_file = dict()


for name in subjects:
    base_dir = '/SCR/datasets-for-tracer-validation/tracer-validation/'+name+'/reoriented/'
    input_file = []
    for fname in os.listdir(base_dir):
        if os.path.isdir(base_dir+fname):
            m = re.match("^S[0-9]+", fname)   # session directories: S1, S2, ...
            if m:
                os.chdir(base_dir+fname)
                for files in glob.glob("*.nii"):
                    input_file.append(base_dir+fname+"/"+files)
    input_file.sort()
    fun_file[name] = input_file

    input_file2 = []
    os.chdir(base_dir+"anat")
    for files in glob.glob("*.nii"):
        input_file2.append(base_dir+"anat/"+files)
    input_file2.sort()
    anat_file[name] = input_file2
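

# An equivalent, side-effect-free way to collect the session niftis with a glob
# pattern (a sketch assuming the same S<number>/ directory layout as above; it
# avoids the os.chdir calls, which change the process working directory):
def collect_session_niftis(base_dir):
    import glob
    import os
    files = glob.glob(os.path.join(base_dir, 'S[0-9]*', '*.nii'))
    files.sort()
    return files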


def extract_noise_components(realigned_file, noise_mask_file, num_components, motion_file):
    """Derive components most reflective of physiological noise and prepend
    them to the motion parameters as nuisance regressors.
    """
    import os
    from nibabel import load
    import numpy as np
    from scipy import linalg
    from scipy.signal import detrend
    imgseries = load(realigned_file)
    noise_mask = load(noise_mask_file)
    # time courses of all voxels inside the noise mask
    voxel_timecourses = imgseries.get_data()[np.nonzero(noise_mask.get_data())]
    for timecourse in voxel_timecourses:
        timecourse[:] = detrend(timecourse, type='constant')
    u, s, v = linalg.svd(voxel_timecourses, full_matrices=False)
    components_file = os.path.join(os.getcwd(), 'noise_components.txt')
    np.savetxt(components_file, v[:num_components, :].T)

    # prepend each volume's noise components to its motion parameters
    with open(motion_file, 'r') as mc_file:
        lines = [line.strip() for line in mc_file]
    with open(components_file, 'r') as noise_file:
        for count, line in enumerate(noise_file):
            lines[count] = line.strip() + ' ' + lines[count]
    with open(components_file, 'w') as noise_file:
        for line in lines:
            noise_file.write(line + "\n")

    return components_file
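

# For intuition, a minimal standalone sketch of the CompCor idea used above
# (an illustration, not pipeline code): the top singular vectors of the
# detrended noise-ROI time series serve as the nuisance regressors.
def compcor_sketch(noise_ts, num_components=6):
    """noise_ts: (voxels, timepoints) array sampled from a noise mask."""
    from scipy import linalg
    from scipy.signal import detrend
    noise_ts = detrend(noise_ts, type='constant', axis=1)   # demean each voxel
    u, s, v = linalg.svd(noise_ts, full_matrices=False)
    return v[:num_components, :].T   # one row per timepoint, one column per component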


def session_cmb(ts_norm_list, full_cor):
    """Concatenate the normalized mean time series of all sessions and
    compute the correlation matrix over the combined series."""
    import csv
    import numpy as np
    import collections
    from MaxCorrelationTransformer import partial_correlation
    print partial_correlation  # touch the import so a failure surfaces early
    from scipy.stats import pearsonr
    import os

    combined = dict()
    for sfile in ts_norm_list:
        with open(sfile, 'rb') as csv_file:
            table = [row for row in csv.reader(csv_file, delimiter='\t')]
        for line in table:
            arr = np.array([x for x in line[1:] if x != ''], dtype='float')
            key = int(line[0])
            if key in combined:
                # append this session's series to the roi's running series
                combined[key] = np.concatenate((combined[key], arr), axis=0)
            else:
                combined[key] = arr

    sort_onefile = collections.OrderedDict(sorted(combined.items()))
    
    cmb_file = os.path.join(os.getcwd(), 'cmb_mean_ts.txt')
    with open(cmb_file, 'w') as f:
        for key in sort_onefile:
            f.write("%s\t" % key)
            for i in sort_onefile[key]:
                f.write("%f\t" %i)
            f.write("\n") 
                       
    cor_mat = [sort_onefile[key] for key in sort_onefile]

    if full_cor:
        mat_name, pval_name = 'full_cor_mat.txt', 'full_cor_pval.txt'
    else:
        mat_name, pval_name = 'partial_cor_mat.txt', 'partial_cor_pval.txt'
    out_full_cor_mat = os.path.join(os.getcwd(), mat_name)
    out_full_cor_pval = os.path.join(os.getcwd(), pval_name)
    f_full_mat = open(out_full_cor_mat, 'wt')
    full_pval = open(out_full_cor_pval, 'wt')
    
    n = len(cor_mat)
    if full_cor:
        for i in range(n):
            for j in range(n):
                val1 = np.array(cor_mat[i], np.float32, ndmin=1)
                val2 = np.array(cor_mat[j], np.float32, ndmin=1)
                (ret, pval) = pearsonr(val1, val2)
                f_full_mat.write('%f\t' % ret)
                full_pval.write('%f\t' % pval)
            f_full_mat.write('\n')
            full_pval.write('\n')
    else:
        for i in range(n):
            for j in range(n):
                if i == j:
                    f_full_mat.write('1\t')
                    full_pval.write('1\t')
                else:
                    # control for every region other than i and j
                    z = np.array([cor_mat[k] for k in range(n) if k != i and k != j],
                                 np.float32)
                    val1 = np.array(cor_mat[i], np.float32, ndmin=1)
                    val2 = np.array(cor_mat[j], np.float32, ndmin=1)
                    (ret, pval) = partial_correlation(val1, val2, np.transpose(z))
                    f_full_mat.write('%f\t' % ret)
                    full_pval.write('%f\t' % pval)
            f_full_mat.write('\n')
            full_pval.write('\n')

    f_full_mat.close()
    full_pval.close()
    return (cmb_file, out_full_cor_mat, out_full_cor_pval)
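

# MaxCorrelationTransformer is a local module not included here. As a reference
# for readers, a residual-based partial correlation with the same (x, y, z)
# calling convention might look like the sketch below -- this is an assumption
# about its behavior, not the actual implementation.
def partial_correlation_sketch(x, y, z):
    """Correlate x and y after regressing out the covariates in z.

    x, y: 1D arrays of length T; z: 2D array of shape (T, k).
    Returns (r, p) like scipy.stats.pearsonr.
    """
    import numpy as np
    from scipy.stats import pearsonr
    design = np.column_stack((np.ones(len(x)), z))   # covariates plus an intercept
    beta_x = np.linalg.lstsq(design, x)[0]
    beta_y = np.linalg.lstsq(design, y)[0]
    return pearsonr(x - design.dot(beta_x), y - design.dot(beta_y))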


def session_norm(ts_mean_file):
    """Normalize each roi's mean time series within a session."""
    import csv
    import numpy as np
    import collections
    import os
    with open(ts_mean_file, 'rb') as csv_file:
        table = [row for row in csv.reader(csv_file, delimiter='\t')]

    norm_file = dict()

    for line in table:
        arr = np.array([x for x in line[1:] if x != ''], dtype='float')
        # note: scales by the variance (arr.var()), not the standard deviation
        norm_arr = [(x - arr.mean()) / arr.var() for x in arr]
        norm_file[int(line[0])] = norm_arr

    sort_norm_file = collections.OrderedDict(sorted(norm_file.items()))
    norm_mean_ts = os.path.join(os.getcwd(), 'norm_mean_ts.txt')
    with open(norm_mean_ts, 'w') as f:
        for key in sort_norm_file:
            f.write("%s\t" % key)
            for i in sort_norm_file[key]:
                f.write("%f\t" % i)
            f.write("\n")
    return norm_mean_ts
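

# As noted above, session_norm scales each series by its variance; conventional
# z-scoring divides by the standard deviation instead. A minimal sketch of the
# conventional form, in case that was the intent:
def zscore_row(arr):
    """Subtract the mean and divide by the standard deviation."""
    import numpy as np
    arr = np.asarray(arr, dtype=float)
    return (arr - arr.mean()) / arr.std()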


def extract_subrois(timeseries_file, label_file, full_cor):
    """Extract voxel time courses for each subcortical roi index

    Parameters
    ----------

    timeseries_file: a 4D Nifti file
    label_file: a 3D file containing rois in the same space/size of the 4D file
    full_cor: True for full (Pearson) correlation, False for partial correlation

    """
    from MaxCorrelationTransformer import partial_correlation
    print partial_correlation  # touch the import so a failure surfaces early
    import nibabel as nb
    import numpy as np
    import os
    from scipy.stats import pearsonr
    
    img = nb.load(timeseries_file)
    data = img.get_data()

    roiimg = nb.load(label_file)
    rois = roiimg.get_data()

    vol_num = img.get_header().get_data_shape()[3]   # number of volumes

    region_list = list(np.unique(rois))
    region_list.sort()
    if 0 in region_list:    # drop the background label
        region_list.remove(0)
    region_list = [x for x in region_list if str(x) != 'nan']
    region_list = map(int, region_list)
    cor_mat = []

    out_ts_file = os.path.join(os.getcwd(), 'roi_timeseries.txt')
    out_ts_mean_file = os.path.join(os.getcwd(), 'roi_timeseries_mean.txt')
    f = open(out_ts_mean_file, 'wt')
    
    if full_cor:
        mat_name, pval_name = 'full_cor_mat.txt', 'full_cor_pval.txt'
    else:
        mat_name, pval_name = 'partial_cor_mat.txt', 'partial_cor_pval.txt'
    out_full_cor_mat = os.path.join(os.getcwd(), mat_name)
    out_full_cor_pval = os.path.join(os.getcwd(), pval_name)
    f_full_mat = open(out_full_cor_mat, 'wt')
    full_pval = open(out_full_cor_pval, 'wt')
        
    with open(out_ts_file, 'wt') as fp:
        for fsindex in region_list:
            ijk = np.nonzero(rois == fsindex)
            ts = data[ijk]
            arr_sum = [0] * vol_num
            for i0, row in enumerate(ts):
                if type(row).__name__ == 'float':
                    # defensive: row is a scalar when the input has no time dimension
                    fp.write('%d,%d,%d,%d' % (fsindex, ijk[0][i0], ijk[1][i0], ijk[2][i0]) + '\n')
                else:
                    fp.write('%d,%d,%d,%d,' % (fsindex, ijk[0][i0], ijk[1][i0], ijk[2][i0])
                             + ','.join(['%.10f' % val for val in row]) + '\n')
                    for count, val in enumerate(row):
                        arr_sum[count] = arr_sum[count] + val
            # write the mean time course of this roi
            final = [val / len(ts) for val in arr_sum]
            f.write('%d\t' % fsindex)
            for v in final:
                f.write('%f\t' % v)
            f.write('\n')
            cor_mat.append(final)
        # calculate correlation matrix
        n = len(cor_mat)
        if full_cor:
            for i in range(n):
                for j in range(n):
                    val1 = np.array(cor_mat[i], np.float32, ndmin=1)
                    val2 = np.array(cor_mat[j], np.float32, ndmin=1)
                    (ret, pval) = pearsonr(val1, val2)
                    f_full_mat.write('%f\t' % ret)
                    full_pval.write('%f\t' % pval)
                f_full_mat.write('\n')
                full_pval.write('\n')
        else:
            for i in range(n):
                for j in range(n):
                    if i == j:
                        f_full_mat.write('1\t')
                        full_pval.write('1\t')
                    else:
                        # control for every region other than i and j
                        z = np.array([cor_mat[k] for k in range(n) if k != i and k != j],
                                     np.float32)
                        val1 = np.array(cor_mat[i], np.float32, ndmin=1)
                        val2 = np.array(cor_mat[j], np.float32, ndmin=1)
                        (ret, pval) = partial_correlation(val1, val2, np.transpose(z))
                        f_full_mat.write('%f\t' % ret)
                        full_pval.write('%f\t' % pval)
                f_full_mat.write('\n')
                full_pval.write('\n')

    f.close()
    f_full_mat.close()
    full_pval.close()
    return (out_ts_file, out_ts_mean_file, out_full_cor_mat, out_full_cor_pval)
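

# The per-roi mean time courses written above can also be computed in a single
# numpy step; a minimal equivalent sketch (same data/rois/region_list as in
# extract_subrois, assuming a 4D input):
def roi_mean_timeseries(data, rois, region_list):
    """Return an (n_rois, n_volumes) array of mean time courses."""
    import numpy as np
    means = []
    for label in region_list:
        mask = rois == label
        # data[mask] is (n_voxels_in_roi, n_volumes); average over voxels
        means.append(data[mask].mean(axis=0))
    return np.array(means)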


def pval_correction(full_pval):
    """Adjust the p-value matrix for multiple comparisons (Benjamini-Hochberg)."""
    from rpy2.robjects.packages import importr
    from rpy2.robjects.vectors import FloatVector
    import csv
    import os

    stats = importr('stats')
    with open(full_pval, 'rb') as csv_file:
        data = csv.reader(csv_file, delimiter='\t')
        table = [filter(None, row) for row in data]

    pval = []
    for line in table:
        for cell in line:
            pval.append(float(cell))

    pval_adjust = stats.p_adjust(FloatVector(pval), method='BH')

    col = len(table)   # the matrix is square, so rows == columns

    # write the adjusted values back into the (row-major) table
    for count, v in enumerate(pval_adjust):
        table[int(count / col)][int(count % col)] = v

    name = os.path.basename(full_pval).split('.')
    new_name = str(name[0]) + "_adjust.txt"
    out_full_cor_pval = os.path.join(os.getcwd(), new_name)

    with open(out_full_cor_pval, 'wt') as f:
        for line in table:
            for cell in line:
                f.write(str(cell) + "\t")
            f.write("\n")

    return out_full_cor_pval
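

# If rpy2 is unavailable, the Benjamini-Hochberg adjustment can be reproduced
# in plain numpy; a sketch intended to match R's p.adjust(method='BH'):
def p_adjust_bh_sketch(pvals):
    """Benjamini-Hochberg adjusted p-values for a flat list of p-values."""
    import numpy as np
    p = np.asarray(pvals, dtype=float)
    m = len(p)
    order = np.argsort(p)[::-1]                  # largest p-value first
    ranked = p[order] * m / np.arange(m, 0, -1)  # scale by m / rank
    adjusted = np.minimum(1.0, np.minimum.accumulate(ranked))
    out = np.empty(m)
    out[order] = adjusted                        # restore the original order
    return out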


def ResultPlot(fun_cor_mat, full_pval):
    """Compare functional correlations against the anatomical tracer matrix
    and plot an ROC curve plus a strength-vs-correlation scatter plot."""
    import xlrd

    nii_acr = dict()
    acr_nii = dict()

    # map nifti label codes <-> region acronyms from the parcellation sheet
    workbook = xlrd.open_workbook('/SCR/datasets-for-tracer-validation/Project/Markov_Parcellation_niftii_codes.xls')
    worksheet = workbook.sheet_by_name('Sheet1')
    for curr_row in range(1, worksheet.nrows):   # skip the header row
        row = worksheet.row(curr_row)
        if type(row[1].value).__name__ == 'float':
            acr = str(int(row[1].value))
        else:
            acr = str(row[1].value).lower()
        nii_acr[int(row[0].value)] = acr
        acr_nii[acr] = int(row[0].value)
    
    row_idx = dict()
    col_idx = dict()
    true_val = dict()
    true_val_binary = dict()
    
    table = []
    binary_table = []
    
    #conn_file = open('/SCR/datasets-for-tracer-validation/Project/Connectivity.Macaque.DBV.23.45.txt', 'rb')
    conn_file = open('/SCR/datasets-for-tracer-validation/Project/Connectivity.Macaque.DBV.23.45.max.txt', 'rb')
    #conn_file = open('/SCR/datasets-for-tracer-validation/Project/Connectivity.Macaque.DBV.23.45.avg.txt', 'rb')
    row_count = 0
    for row in conn_file:
        arr = row.strip().split('\t')
        line = []
        binary_line = []
        if row_count == 0:
            # header row: remember the column acronyms
            row_count += 1
            idx = 0
            for val in arr:
                col_idx[idx] = str(val)
                idx += 1
        else:  
            arr_count = 0
            for i in arr:
                if arr_count ==0:
                    row_idx[str(arr[0])] = int(row_count-1)
                    arr_count +=1
                else:
                    line.append(float(i))    
                    true_val[str(acr_nii.get(str(arr[0]).lower()))+"_"+str(acr_nii.get(str(col_idx[arr_count-1]).lower()))] = float(i)
                    if(float(i)>0):
                        binary_line.append(1)
                        true_val_binary[str(acr_nii.get(str(arr[0]).lower()))+"_"+str(acr_nii.get(str(col_idx[arr_count-1]).lower()))] = 1
                    else:
                        binary_line.append(0)
                        true_val_binary[str(acr_nii.get(str(arr[0]).lower()))+"_"+str(acr_nii.get(str(col_idx[arr_count-1]).lower()))] = 0
                    arr_count += 1
                    
            row_count += 1
            table.append(line)
            binary_table.append(binary_line)
    conn_file.close()

    sort_true_val = sorted(true_val.iteritems(), key=lambda kv: (kv[1], kv[0]), reverse=True)

    thresh = len(sort_true_val) * 0.01   # binarize: only the top 1% count as connected
    count = 0
    for k in sort_true_val:
        k_name = str(k[0])
        k_val = float(k[1])
        if(count<thresh) and k_val>0:
            true_val_binary[k_name] = 1
        else:
            true_val_binary[k_name] = 0
        count = count+1
   
    
    import pylab as pl
    import nibabel as nb
    import numpy as np
    from sklearn.metrics import roc_curve, auc
    import os
    from math import log
    from matplotlib.gridspec import GridSpec
    ###############
    roiimg = nb.load('/SCR/datasets-for-tracer-validation/Project/Markov_Parcellation_3D.nii')
    rois = roiimg.get_data()  
    Rois = np.unique(rois)
    region_list = list(Rois)
    region_list.sort()
    if 0 in region_list:    # drop the background label
        region_list.remove(0)
    region_list = map(int, region_list)

    region_map = dict()
    count = 0
    for i in region_list:
        region_map[count] = i
        count += 1         
    
    fun_cor_map = dict()
    fun_cor_map_real = dict()
    fun_pval_map_real = dict()

    with open(full_pval, 'r') as f_pval:
        pval = [line.strip().split('\t') for line in f_pval]
    pval_mat = np.array(pval, np.float32)

    with open(fun_cor_mat, 'r') as f:
        row = 0
        for line in f:
            arr = line.strip().split('\t')
            col = 0
            for cell in arr:
                t = 0
                p_val = pval_mat[row, col]
                # a connection counts as detected if significant and positive
                if p_val <= 0.05 and float(cell) > 0:
                    t = 1
                key = str(region_map.get(row)) + '_' + str(region_map.get(col))
                fun_cor_map[key] = t
                fun_cor_map_real[key] = float(cell)
                fun_pval_map_real[key] = p_val
                col += 1
            row += 1

    ##############
    # ROC: anatomical presence/absence as labels, functional correlation as score
    keys = true_val_binary.keys()
    test = [fun_cor_map_real.get(k) for k in keys]
    real = [true_val_binary.get(k) for k in keys]

    fpr, tpr, thresholds = roc_curve(real, test, pos_label=1)

    # scatter plot: ln(1 + anatomical strength) vs functional correlation
    real_scatter = [log(float(val) + 1) for val in true_val.values()]
    test_scatter = [float(fun_cor_map_real.get(k)) for k in true_val.keys()]
    
    roc_auc = auc(fpr, tpr)
    print "Area under the ROC curve : %f" % roc_auc

    pl.figure(figsize=(24, 12), dpi=80)
    #ROC
    gs1 = GridSpec(3, 3)
    gs1.update(left=0.05, right=0.48, wspace=0.03)
    ax1 = pl.subplot(gs1[:, :])  
    # Plot ROC curve
    ax1.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    ax1.plot([0, 1], [0, 1], 'k--')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.0])
    ax1.set_xlabel('False Positive Rate')
    ax1.set_ylabel('True Positive Rate')
    ax1.set_title('Receiver operating characteristic')
    ax1.legend(loc="lower right")

    x = real_scatter
    y = test_scatter

    # Fixed number of histogram bins for both axes
    num_x_bins = 50
    num_y_bins = 50

    nullfmt = pl.NullFormatter()
    
    gs2 = GridSpec(3, 3)
    gs2.update(left=0.53, right=0.95, hspace=0.16,wspace=0.08)
    axScatter = pl.subplot(gs2[:-1, :-1])
    axScatter.set_xlabel('ln(1+A_strength)')
    axScatter.set_ylabel('F_correlation')
    axScatter.set_title('Anatomical strength vs. Functional correlation')
    axScatter.set_xlim(0,1)
    axScatter.set_ylim(-1,1)
    
    #rotate histogram of y-axis of scatter plot
    axHistY = pl.subplot(gs2[:-1, -1])
    axHistY.set_ylim(-1, 1)
    
    #histogram of x-axis of scatter plot
    axHistX = pl.subplot(gs2[-1, :-1])
    axHistX.set_xlim(0, 1)
 
    # Remove labels from histogram edges touching scatter plot
    axHistX.xaxis.set_major_formatter(nullfmt)
    axHistY.yaxis.set_major_formatter(nullfmt)
    
    # Draw scatter plot
    axScatter.scatter(x, y, marker='o', color = 'darkblue', edgecolor='none', s=50, alpha=0.5)
    
    # Draw x-axis histogram
    axHistX.hist(x, num_x_bins, ec='red',fc='green', histtype='bar')
    
    # Draw y-axis histogram
    axHistY.hist(y, num_y_bins, ec='red', fc='red', histtype='step', orientation='horizontal')
    
    out_image = os.path.join(os.getcwd(),'result_plot.png')
    pl.savefig(out_image)
    #pl.show()    
    return out_image


def Benchmark_Calculation(subject, glm, full_cor, combined):

    sinker = pe.Node(DataSink(), name='sinker')

    # remove the first three volumes of each session
    remove_vol = pe.MapNode(fsl.ExtractROI(t_min=3, t_size=-1), iterfield=['in_file'], name="remove_volumes")
    remove_vol.inputs.in_file = fun_file[subject]
    
    """
    Preprocessing-realignment: The first image of each session is aligned to the first image of the first session, and then to the mean of each session
    """
    realigner = pe.Node(interface = spm.Realign(),name = 'realign')
    #realigner.inputs.in_files = fun_file[subject]
    realigner.inputs.register_to_mean = True
    realigner.inputs.jobtype = 'estwrite'
    realigner.inputs.write_which = [2,1]
    
    """
    Preprocessing-coregistration: ref = mean functional from realignment & source = anatomical image
    """
    coreg = pe.Node(interface = spm.Coregister(),name = 'coregister')
    #coreg.inputs.target = realigner.outputs.mean_image    #mean functional image
    coreg.inputs.source = anat_file[subject]    #anatomical image
    coreg.inputs.jobtype = 'estimate'     #without resclice
    
    """
    Preprocessing-slice timing
    """    
    slicetime = pe.Node(interface = spm.SliceTiming(), name = "SliceTiming")
    slicetime.inputs.num_slices = 30
    slicetime.inputs.time_repetition = 2
    slicetime.inputs.time_acquisition = slicetime.inputs.time_repetition - (slicetime.inputs.time_repetition/slicetime.inputs.num_slices)
    slicetime.inputs.slice_order = range(1,31,1)
    slicetime.inputs.ref_slice = 1    
    
    
    """
    Noise ROI from the highest TSNR
    """    
    tsnr = pe.MapNode(TSNR(regress_poly=2),iterfield=['in_file'],name='tsnr')
    getthresh = pe.MapNode(interface=fsl.ImageStats(op_string='-p 98'),iterfield=['in_file'],
                           name='getthreshold')
    threshold_stddev = pe.MapNode(fsl.Threshold(), iterfield=['in_file','thresh'],name='threshold')    
    
    """
    noise mask with white matter & csf
    """
    threshCSFseg = pe.Node(interface = fsl.ImageMaths(op_string = ' -thr .80 -uthr 1 -bin '),
                       name = 'threshcsfsegmask')
                       
    threshWMseg = pe.Node(interface = fsl.ImageMaths(op_string = ' -thr .80 -uthr 1 -bin '),
                       name = 'threshwmsegmask')                   
    
    threshbothseg = pe.Node(interface = fsl.BinaryMaths(),name = "threshwmcsfmask")
    threshbothseg.inputs.operation = 'add'
    threshbothseg.inputs.terminal_output = 'file'
    
    reslicemask = pe.Node(interface = spmu.ResliceToReference(), name = "ReslicedMask")
    reslicemask.inputs.interpolation = 0    
    
    """
    Compcor algorithms
    """    
    compcor = pe.MapNode(util.Function(input_names=['realigned_file',
                                                 'noise_mask_file',
                                                 'num_components',
                                                 'motion_file'],
                                     output_names=['noise_components'],
                                     function=extract_noise_components),iterfield = ['realigned_file','motion_file'],
                       name='compcorr')
    compcor.inputs.num_components = 6
    
    """
    bandpass filter
    """    
    remove_noise = pe.MapNode(fsl.FilterRegressor(filter_all=True),iterfield = ['in_file','design_file'],name='remove_noise')
    bandpass_filter = pe.MapNode(fsl.TemporalFilter(),iterfield = ['in_file'],name='bandpass_filter')
    bandpass_filter.inputs.highpass_sigma = 0.01
    bandpass_filter.inputs.lowpass_sigma = 0.1      
    bandpass_filter.inputs.output_datatype = 'short'    
    bandpass_filter.inputs.output_type = 'NIFTI'                    
    
    """
    Preprocessing-segmentation
    """
    seg = pe.Node(interface = spm.Segment(),name = "Segmentation")
    seg.inputs.data = anat_file[subject]
    seg.inputs.tissue_prob_maps = ['/SCR/datasets-for-tracer-validation/macaque_pri/grey.nii',
                                   '/SCR/datasets-for-tracer-validation/macaque_pri/white.nii',
                                   '/SCR/datasets-for-tracer-validation/macaque_pri/csf.nii']
    seg.inputs.gm_output_type = [False,False,True]  #native space
    seg.inputs.csf_output_type = [False,False,True]   #native space
    seg.inputs.wm_output_type = [False,False,True]    #native space        
    seg.inputs.gaussians_per_class = [2,2,2,4]
    seg.inputs.affine_regularization = 'none'
    seg.inputs.warping_regularization = 1
    seg.inputs.warp_frequency_cutoff = 25
    seg.inputs.bias_regularization = 0.0001
    seg.inputs.bias_fwhm = 60
    seg.inputs.sampling_distance = 3
    seg.inputs.save_bias_corrected = True
  
    """
    Preprocessing-Normalization
    """
    norm = pe.Node(interface = spm.Normalize(),name = "Normalization")
    norm.inputs.template = os.path.abspath('/SCR/datasets-for-tracer-validation/tracer-validation/F99/F99_edit_13_cut_T1.nii')
    norm.inputs.jobtype = "estwrite"
    norm.inputs.affine_regularization_type = 'none'
    norm.inputs.write_voxel_sizes = [1, 1, 1] 
    norm.inputs.write_bounding_box = [[-35,-56,-28],[36,38,31]]
   
    """
    Extract one volume
    """    
    merge_session = pe.Node(interface =fsl.Merge(), name = 'MergeSession')
    merge_session.inputs.dimension = 't'
    merge_session.inputs.output_type = 'NIFTI'
    
    extract_ref = pe.Node(interface=fsl.ExtractROI(t_min=0, t_size=1),name = 'extractref')
    
    """
    Preprocessing-Smoothing
    """
    #set up a smoothing node
    smoother = pe.Node(interface = spm.Smooth(),name='smoother')
    smoother.inputs.fwhm = [3,3,3]

    """
    Functional connectivity
    """    
    ratlas = pe.Node(interface = spmu.ResliceToReference(),name = "ReslicedAtlas")
    ratlas.inputs.interpolation = 0
    ratlas.inputs.in_files = '/SCR/datasets-for-tracer-validation/Project/Markov_Parcellation_3D.nii'
    
    FConnect = pe.MapNode(util.Function(input_names=['timeseries_file', 'label_file','full_cor'],
                                       output_names=['out_ts_file','out_ts_mean_file','out_full_cor_mat','out_full_cor_pval'],
                                        function=extract_subrois),iterfield = ['timeseries_file'], name = 'fun_connect')                                                                    
            
    """
    Combine sessions and prerequisite normalization
    """
    Session_norm = pe.MapNode(util.Function(input_names=['ts_mean_file'],
                                       output_names=['norm_mean_ts'],
                                        function=session_norm),iterfield = ['ts_mean_file'], name = 'session_norm')
    
    Session_cmb = pe.Node(interface = util.Function(input_names=['ts_norm_list','full_cor'],
                                        output_names=['cmb_file','out_full_cor_mat','out_full_cor_pval'],function = session_cmb),name = 'session_cmb')
                                        
    FConnect.inputs.full_cor = full_cor
    Session_cmb.inputs.full_cor = full_cor
    
    corrected_pval = pe.MapNode(util.Function(input_names=['full_pval'],
                                       output_names=['out_full_cor_pval'],
                                        function=pval_correction),iterfield = ['full_pval'], name = 'adjust_pval')    
        
    """
    Result plot
    """
    PlotResult = pe.MapNode(util.Function(input_names=['fun_cor_mat','full_pval'],output_names=['out_image'],function=ResultPlot),iterfield = ['fun_cor_mat','full_pval'],name = 'Result_plotting')
    #create and configure a workflow: preproc_<subject>_<glm-variant>_<correlation>[_cmb]
    glm_tag = {0: 'noglm', 1: 'wmglm', 2: 'csfglm', 3: 'wmcsfglm'}.get(glm, 'tsnr')
    cor_tag = 'fullcor' if full_cor else 'partial'
    workflow_name = 'preproc_' + subject + '_' + glm_tag + '_' + cor_tag

    if combined:
        workflow_name = workflow_name + '_cmb'

        
    workflow = pe.Workflow(name=workflow_name)
    workflow.base_dir='/SCR/datasets-for-tracer-validation/tracer-validation/'+subject+'/reoriented/'

    sinker.inputs.base_directory = op.abspath('/SCR/datasets-for-tracer-validation/tracer-validation/'+subject+'/reoriented/')
    
    workflow.connect(remove_vol,'roi_file',realigner,'in_files')    
    
    workflow.connect(realigner, 'realigned_files',sinker, 'realigned')
    workflow.connect(realigner, 'realignment_parameters', sinker, 'realigned.@parameters')
    
    #realign -- coreg
    workflow.connect(realigner, "mean_image", coreg, "target")
    #realign--slice timing
    workflow.connect(realigner, "realigned_files", slicetime, "in_files")
        
    if glm == 0:    # no glm: bandpass the slice-time corrected data directly

        workflow.connect(coreg, "coregistered_source", seg, "data")
        workflow.connect(coreg, "coregistered_source", norm, "source")
        workflow.connect(seg, "transformation_mat", norm, "parameter_file")
        workflow.connect(slicetime, "timecorrected_files", bandpass_filter, "in_file")

    elif glm in (1, 2, 3):  # noise ROI from WM, CSF, or WM & CSF
        workflow.connect(coreg, "coregistered_source", seg, "data")
        #seg--norm
        workflow.connect(coreg, "coregistered_source", norm, "source")
        workflow.connect(seg, "transformation_mat", norm, "parameter_file")

        # build the noise mask from the requested tissue class(es)
        if glm == 1:    # WM only
            workflow.connect(seg, 'native_wm_image', threshWMseg, 'in_file')
            workflow.connect(threshWMseg, 'out_file', sinker, 'wm_mask')
            mask_node = threshWMseg
        elif glm == 2:  # CSF only
            workflow.connect(seg, 'native_csf_image', threshCSFseg, 'in_file')
            workflow.connect(threshCSFseg, 'out_file', sinker, 'csf_mask')
            mask_node = threshCSFseg
        else:           # WM & CSF combined
            workflow.connect(seg, 'native_csf_image', threshCSFseg, 'in_file')
            workflow.connect(threshCSFseg, 'out_file', sinker, 'csf_mask')
            workflow.connect(seg, 'native_wm_image', threshWMseg, 'in_file')
            workflow.connect(threshWMseg, 'out_file', sinker, 'wm_mask')
            workflow.connect(threshCSFseg, 'out_file', threshbothseg, 'in_file')
            workflow.connect(threshWMseg, 'out_file', threshbothseg, 'operand_file')
            workflow.connect(threshbothseg, 'out_file', sinker, 'both_mask')
            mask_node = threshbothseg

        workflow.connect(realigner, 'mean_image', reslicemask, 'target')
        workflow.connect(mask_node, 'out_file', reslicemask, 'in_files')
        workflow.connect(reslicemask, 'out_files', sinker, 'reslicedmask')

        workflow.connect(slicetime, 'timecorrected_files', compcor, 'realigned_file')
        workflow.connect(reslicemask, 'out_files', compcor, 'noise_mask_file')
        workflow.connect(realigner, 'realignment_parameters', compcor, 'motion_file')

        workflow.connect(slicetime, 'timecorrected_files', remove_noise, 'in_file')
        workflow.connect(compcor, 'noise_components', remove_noise, 'design_file')
        workflow.connect(remove_noise, 'out_file', bandpass_filter, 'in_file')

    else:   # tsnr: noise ROI from the most variable voxels
        workflow.connect(slicetime, 'timecorrected_files', tsnr, 'in_file')
        workflow.connect(tsnr, 'stddev_file', threshold_stddev, 'in_file')
        workflow.connect(tsnr, 'stddev_file', getthresh, 'in_file')
        workflow.connect(getthresh, 'out_stat', threshold_stddev, 'thresh')
        workflow.connect(threshold_stddev, 'out_file', sinker, 'noise_mask')

        workflow.connect(slicetime, 'timecorrected_files', compcor, 'realigned_file')
        workflow.connect(threshold_stddev, 'out_file', compcor, 'noise_mask_file')
        workflow.connect(realigner, 'realignment_parameters', compcor, 'motion_file')

        workflow.connect(tsnr, 'detrended_file', remove_noise, 'in_file')
        workflow.connect(compcor, 'noise_components', remove_noise, 'design_file')
        workflow.connect(remove_noise, 'out_file', bandpass_filter, 'in_file')

        workflow.connect(coreg, "coregistered_source", seg, "data")
        #seg--norm
        workflow.connect(coreg, "coregistered_source", norm, "source")
        workflow.connect(seg, "transformation_mat", norm, "parameter_file")
    
   
    workflow.connect(seg,'bias_corrected_image',sinker,'seg_bias_corrected')
    workflow.connect(seg,'inverse_transformation_mat',sinker,'seg_inverse_trans')
    workflow.connect(seg,'native_gm_image',sinker,'seg_mask_gm')
    workflow.connect(seg,'native_wm_image',sinker,'seg_mask_wm')
    workflow.connect(seg,'native_csf_image',sinker,'seg_mask_csf')
 
    workflow.connect(bandpass_filter,"out_file",norm,"apply_to_files")    
    
    workflow.connect(norm,"normalized_files",sinker,"norm")       
    workflow.connect(norm,"normalized_files",smoother,"in_files")   
    
    workflow.connect(norm,'normalized_files',merge_session,'in_files')
    workflow.connect(merge_session,'merged_file',sinker,'merged_file')
    workflow.connect(merge_session,'merged_file',extract_ref,'in_file')
    
    workflow.connect(extract_ref,'roi_file',sinker,'refvol')
    workflow.connect(extract_ref,'roi_file',ratlas,'target')
    workflow.connect(ratlas,'out_files',sinker,'RAtlas')     
    
    # functional connectivity per session
    workflow.connect(ratlas, 'out_files', FConnect, "label_file")
    workflow.connect(smoother, "smoothed_files", FConnect, "timeseries_file")

    workflow.connect(FConnect, "out_ts_file", sinker, "whole_timeseries")
    workflow.connect(FConnect, "out_ts_mean_file", sinker, "mean_timeseries")
    workflow.connect(FConnect, "out_full_cor_mat", sinker, "full_cor_mat")

    if combined:
        # normalize each session, concatenate, and correlate the combined series
        workflow.connect(FConnect, 'out_ts_mean_file', Session_norm, 'ts_mean_file')
        workflow.connect(Session_norm, 'norm_mean_ts', sinker, 'norm_mean_ts')
        workflow.connect(Session_norm, 'norm_mean_ts', Session_cmb, 'ts_norm_list')

        workflow.connect(Session_cmb, "out_full_cor_mat", PlotResult, "fun_cor_mat")
        workflow.connect(Session_cmb, 'out_full_cor_pval', corrected_pval, 'full_pval')
    else:
        # correlate each session separately
        workflow.connect(FConnect, "out_full_cor_mat", PlotResult, "fun_cor_mat")
        workflow.connect(FConnect, 'out_full_cor_pval', corrected_pval, 'full_pval')

    workflow.connect(corrected_pval, 'out_full_cor_pval', sinker, 'corrected_pval')
    workflow.connect(corrected_pval, 'out_full_cor_pval', PlotResult, 'full_pval')
    workflow.connect(PlotResult, "out_image", sinker, "plot")
    #visualize the workflow
    workflow.write_graph()               
    workflow.run()  

"""
Parameters for Benchmark_Calculation:
1. Subject id
2. GLM options {0,1,2,3,4}
    0: no glm
    1: noise ROI from white matter (WM) only
    2: noise ROI from CSF alone
    3: noise ROI from WM & CSF
    4: noise ROI from the functional-image voxels with the highest temporal variability (TSNR-based)
3. Correlation
    True: full correlation
    False: partial correlation
4. Session-combined correlation (True) / session-separate correlation (False)
"""

glm_options = [1, 2, 3]
cor_options = [True, False]
session_options = [True, False]
for glm_opt in glm_options:
    for full_cor in cor_options:
        for combined in session_options:
            Benchmark_Calculation("M5", glm_opt, full_cor, combined)