In [ ]:
%pylab inline
from __future__ import print_function, division
import os
import os.path as osp
import matplotlib.pyplot as plt
from warnings import warn
import datetime, time
import glob as gb
from six import string_types
import argparse
import json
import time
import numpy as np
import scipy.linalg as lin
import scipy.stats as sst
import nibabel as nib
import nipy
from nilearn._utils import concat_niimgs
In [ ]:
import importlib
from smpce_data_to_corr import get_params
import utils._utils as ucr
import utils.setup_filenames as suf
import correlation2results as c2r
#import tests.test_smpce_data_to_corr as tts
# Re-import the project modules so that edits made on disk are picked up
# without restarting the kernel.
# `reload` is a builtin only on Python 2; on Python 3 it lives in importlib.
try:
    reload
except NameError:
    from importlib import reload
ucr = reload(ucr)
suf = reload(suf)
c2r = reload(c2r)
In [ ]:
# Absolute path of the current working directory, used as the parameter directory.
param_dir = osp.abspath('.')
# Guard: the notebook must be started from this exact, machine-specific checkout.
assert param_dir == '/home/jb/code/simpace/simpace'
# or assert on neuro
# Analysis parameters, obtained from get_params (imported from smpce_data_to_corr).
params = get_params(param_dir)
# Hard-coded data root that is passed to the c2r helpers in the following cells.
basedir = '/home/jb/data/simpace/data/rename_files'
In [ ]:
# Short aliases for the two parameter sub-dictionaries.
djdata = params['data']
djlayo = params['layout']
# Counts are read from the 'data' section of the parameters.
nb_sess = djdata['nb_sess']
nb_sub = djdata['nb_sub']
nb_run = djdata['nb_run']
print(nb_sub, nb_sess, nb_run)
In [ ]:
def _get_common_labels(conds, idx0=0):
    """Return the labels present in every session file of one condition.

    Parameters
    ----------
    conds : dict
        Maps a condition name to a sequence of file names, one per session.
        Each file must be loadable with ``np.load`` and expose a
        ``'labels_sig'`` entry.
    idx0 : int, optional
        Position, in the dictionary's key order, of the condition to inspect.

    Returns
    -------
    set
        Intersection of the ``'labels_sig'`` values over all sessions of the
        selected condition; an empty set if that condition has no session.
    """
    # Materialise the keys: dict views cannot be indexed on Python 3.
    cond0 = list(conds.keys())[idx0]
    print(cond0)

    label_sets = []
    for fname in conds[cond0]:
        # The loaded archive holds an open file handle; release it promptly.
        with np.load(fname) as data:
            label_sets.append(set(data['labels_sig']))

    # set.intersection() needs at least one argument.
    if not label_sets:
        return set()
    return set.intersection(*label_sets)
In [ ]:
# File names of the signal files, as returned by c2r._get_signals_filenames;
# _get_common_labels indexes the result as conds[condition][session].
conds = c2r._get_signals_filenames(basedir, params)
# Common label sets taken from the conditions at key positions 0, 3 and 2
# (assumes at least four conditions — TODO confirm).
aaa = _get_common_labels(conds)
bbb = _get_common_labels(conds, idx0=3)
ccc = _get_common_labels(conds, idx0=2)
# The common labels must not depend on which condition was inspected.
assert aaa == bbb
assert aaa == ccc
In [ ]:
# Print the file names of condition 'none' from index 8 onwards, one per line.
print("\n".join(conds['none'][8:]))
# Last expression of the cell: the available condition keys.
conds.keys()
In [ ]:
# Rebuild the condition -> file names mapping and show its keys.
conds = c2r._get_signals_filenames(basedir, params)
print(conds.keys())
# Library implementation of the common-label computation ...
common_labels = c2r._get_common_labels(conds)
# ... cross-checked against the notebook-local version on the condition at key position 3.
assert common_labels == _get_common_labels(conds, idx0=3)
In [ ]:
# conds_arr is indexed by condition name in the cells below; presumably it holds
# per-condition correlation matrices (values are later checked to lie in [-1, 1])
# — confirm against c2r.compute_corr_mtx.
conds_arr, stored_param = c2r.compute_corr_mtx(conds, common_labels)
In [ ]:
# Delegates to c2r.save_results. NOTE(review): the two string arguments look like
# an output directory name and an analysis label — confirm against that function.
c2r.save_results(basedir, "results", "all_corrections", params)
In [ ]:
# Load a single signal file (condition 'med', index 7) for interactive inspection.
tmp = np.load(conds['med'][7])
#tmp.keys()
In [ ]:
# Shapes of the two entries read from the example file.
# NOTE(review): only the last expression of a notebook cell is displayed, so the
# value of the first line is computed and discarded.
tmp['labels_sig'].shape
tmp['arr_sig_f'].shape
In [ ]:
In [ ]:
# Quick overview of the computed arrays.
print(conds_arr.keys())
print(conds_arr['high'].shape)
# Every value of every condition array must lie within [-1, 1].
# NOTE(review): `ordered_conds` is not defined or imported in the visible cells
# — confirm where it comes from.
for cond in ordered_conds():  # [none_c, low_c, med_c, high_c]:
    corr = conds_arr[cond]
    assert np.all(corr <= 1.) and np.all(corr >= -1.)
In [ ]:
# One panel per axis: mean over axis 0 of each condition array
# (axis 0 is presumably the session axis — TODO confirm).
f, axes = plt.subplots(1, 4)
arr = [conds_arr[c] for c in ordered_conds()]
# Shared display options, identical for all panels.
imshow_opts = dict(aspect='equal', interpolation='nearest',
                   vmin=-.5, vmax=1.)
for idx, ax in enumerate(axes):
    mean_mtx = arr[idx].mean(axis=0)
    ax.imshow(mean_mtx, **imshow_opts)
In [ ]:
# One panel per axis: axis-0 mean of each condition array minus the axis-0 mean
# of the first array in `arr`, shown on a fixed symmetric colour scale.
f, axes = plt.subplots(1, 4)
arr = [conds_arr[c] for c in ordered_conds()]
# The reference mean is the same for every panel, so compute it once.
ref_mean = arr[0].mean(axis=0)
for idx, ax in enumerate(axes):
    ax.imshow(arr[idx].mean(axis=0) - ref_mean,
              aspect='equal', interpolation='nearest',
              vmin=-.5, vmax=.5)
In [ ]:
# Smallest and largest deviation of each array's axis-0 mean from the
# axis-0 mean of condition 'none'.
# `arr` is the list of condition arrays built in a preceding cell.
a0 = conds_arr['none'].mean(axis=0)
for a in arr:
    # Compute the difference once and reuse it for both extrema.
    delta = a.mean(axis=0) - a0
    print(delta.min(), delta.max())
In [ ]:
# One panel per axis: standard deviation over axis 0 of each condition array.
f, axes = plt.subplots(1, 4)
arr = [conds_arr[c] for c in ordered_conds()]
for idx, ax in enumerate(axes):
    std_mtx = arr[idx].std(axis=0)
    # Colour limits are left automatic; optional fixed limits: vmin=0., vmax=.5
    ax.imshow(std_mtx, aspect='equal', interpolation='nearest')
In [ ]:
# Average of the axis-0 standard deviation, one value per array in `arr`
# (`arr` is the list of condition arrays built in a preceding cell).
for a in arr:
    print(a.std(axis=0).mean())