In [ ]:
from __future__ import division
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression
from scipy.io import loadmat
import glob, os
from scipy.io import savemat # store results for comparison with Matlab code
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation
absolute_code_path = '/home/marcel/Desktop/Projects/Stitching/code/pyLDS_dev/'
os.chdir(absolute_code_path +'pylds')
from pylds.models import LDS, DefaultLDS
from pylds.distributions import Regression_diag, AutoRegression_input
from pylds.obs_scheme import ObservationScheme
from pylds.user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats
def update(model):
    """Run a single EM iteration on ``model`` and report the fit quality.

    Parameters
    ----------
    model : object
        Any object exposing ``EM_step()`` and ``log_likelihood()``.

    Returns
    -------
    The value returned by ``model.log_likelihood()`` after the EM step.
    """
    model.EM_step()
    return model.log_likelihood()
In [ ]:
# Fit every dataset (*.npz found in the current directory) with EM, once per
# initialiser (several repetitions for random initialisations), and store each
# fit in ../fits/ as both a .mat and a .npz file.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)


def update(model):
    """Run one EM step on ``model`` and return its log-likelihood."""
    model.EM_step()
    return model.log_likelihood()


eps = np.log(1.01)  # EM stops once the log-likelihood gain drops below this
max_iter = 25
#initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']
initialisers = ['params', 'params_flip', 'random']

for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]
    os.chdir('../data/')
    loadfile = np.load(filename)
    data = loadfile['y']
    T, p = data.shape
    n = loadfile['x'].shape[1]
    obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=tuple(loadfile['sub_pops']),
                                   obs_pops=loadfile['obs_pops'],
                                   obs_time=loadfile['obs_time'])

    # 'truePars' is a dict wrapped in a one-element array; unwrap it and make
    # sure all keys are plain str.
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = {str(key): val for key, val in pars_true.items()}

    print('dataset #', i)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ', initialiser)

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')
            num_reps = 1
            loadfile = np.load('init_' + filename)
            flip = initialiser.endswith('_flip')
            initkey = initialiser[:-5] if flip else initialiser
            pars_init = loadfile[initkey].reshape(1,)[0]
            pars_init = {str(key): val for key, val in pars_init.items()}
            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            if flip:
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap, :] *= -1
        elif initialiser == 'random':
            num_reps = 10
        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print('run #' + str(rep+1) + '/' + str(num_reps))

            if initialiser == 'random':
                pars_init, _ = gen_pars(n, p, u_dim=0,
                                        pars_in=None,
                                        obs_scheme=obs_scheme,
                                        gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                        gen_B='random',
                                        gen_Q='identity',
                                        gen_mu0='random',
                                        gen_V0='identity',
                                        gen_C='random',
                                        gen_d='mean',
                                        gen_R='fractionObserved',
                                        diag_R_flag=True,
                                        x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # Set up everything the storing step reads *before* the try, so a
            # run that breaks early never saves values left over from an
            # earlier run (or hits an undefined name).
            likes = [-np.inf]
            T_out, num_trials = T, 1
            try:
                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme)  # reset to initialisation
                print('fitting #' + str(i))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1] - likes[-2] < eps:
                        break
                stats_hat, pars_hat = collect_LDS_stats(model)

                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme)  # reset to true pars
                model.E_step()
                stats_true, _ = collect_LDS_stats(model)
                model.M_step()

                broken = False
                # NOTE(review): at this point `model` is the one initialised
                # from the true parameters, so x/y are taken from that model,
                # not from the EM fit - confirm this is intended.
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                T_out = model.states_list[0].T
                num_trials = len(model.states_list)

                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(Pi)
                Pi_t_h = pars_hat['A'].dot(Pi_h)

            except Exception as err:
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print('')
                print('error: ' + repr(err))

                broken = True
                # the model may not exist here, so fall back to the raw input
                y_out = data
                x_out = []
                pars_hat, stats_hat, stats_true = [], [], []
                Pi, Pi_h, Pi_t, Pi_t_h = 0, 0, 0, 0

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_idx' + str(rep) + '_' + filename

            save_file_m = {'ifBroken': broken,
                           'x': x_out,
                           'y': y_out,
                           'u': [],
                           'll': likes,
                           'T': T_out,
                           'Trial': num_trials,
                           'ifUseB': False,
                           'ifUseA': True,
                           'epsilon': eps,
                           'truePars': pars_true,
                           'initPars': pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'stats_true': stats_true,
                           'Pi': Pi,
                           'Pi_h': Pi_h,
                           'Pi_t': Pi_t,
                           'Pi_t_h': Pi_t_h,
                           'obsScheme': obs_scheme}

            savemat(save_file, save_file_m)  # does the actual saving

            np.savez(save_file,
                     broken=broken,
                     x=x_out,
                     y=y_out,
                     ll=likes,
                     T=T_out,
                     Trial=num_trials,
                     ifUseA=True,
                     ifUseB=False,
                     epsilon=eps,
                     initPars=pars_init,
                     truePars=pars_true,
                     estPars=pars_hat,
                     stats_h=stats_hat,
                     stats_true=stats_true,
                     Pi=Pi,
                     Pi_h=Pi_h,
                     Pi_t=Pi_t,
                     Pi_t_h=Pi_t_h,
                     sub_pops=obs_scheme.sub_pops,
                     obs_time=obs_scheme.obs_time,
                     obs_pops=obs_scheme.obs_pops)
In [ ]:
# Fit all simulation_1 datasets with EM from every initialiser and store each
# fit in ../fits/ as both a .mat and a .npz file.
relative_data_path = '../../../results/cosyne_poster/simulation_1/data'
os.chdir(relative_data_path)
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)


def update(model):
    """Run one EM step on ``model`` and return its log-likelihood."""
    model.EM_step()
    return model.log_likelihood()


eps = np.log(1.01)  # EM stops once the log-likelihood gain drops below this
max_iter = 25
initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']

for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]
    os.chdir('../data/')
    loadfile = np.load(filename)
    data = loadfile['y']
    T, p = data.shape
    n = loadfile['x'].shape[1]
    sub_pops = tuple(loadfile['sub_pops'])
    obs_pops = loadfile['obs_pops']
    # Build the observation scheme for *this* dataset instead of relying on an
    # `obs_scheme` left over from a previously executed cell.
    # NOTE(review): assumes the data file stores 'obs_time' next to
    # 'sub_pops' and 'obs_pops' - confirm.
    obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=sub_pops,
                                   obs_pops=obs_pops,
                                   obs_time=loadfile['obs_time'])

    # 'truePars' is a dict wrapped in a one-element array; unwrap it and make
    # sure all keys are plain str.
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = {str(key): val for key, val in pars_true.items()}

    print('dataset #', i)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ', initialiser)

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')
            num_reps = 1
            loadfile = np.load('init_' + filename)
            flip = initialiser.endswith('_flip')
            initkey = initialiser[:-5] if flip else initialiser
            pars_init = loadfile[initkey].reshape(1,)[0]
            pars_init = {str(key): val for key, val in pars_init.items()}
            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            if flip:
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap, :] *= -1
        elif initialiser == 'random':
            num_reps = 10
        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print('run #' + str(rep+1) + '/' + str(num_reps))

            if initialiser == 'random':
                pars_init, _ = gen_pars(n, p, u_dim=0,
                                        pars_in=None,
                                        obs_scheme=obs_scheme,
                                        gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                        gen_B='random',
                                        gen_Q='identity',
                                        gen_mu0='random',
                                        gen_V0='identity',
                                        gen_C='random',
                                        gen_d='mean',
                                        gen_R='fractionObserved',
                                        diag_R_flag=True,
                                        x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # Set up everything the storing step reads *before* the try, so a
            # run that breaks early never saves values left over from an
            # earlier run (or hits an undefined name).
            likes = [-np.inf]
            T_out, num_trials = T, 1
            try:
                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme)  # reset to true pars
                model.E_step()
                stats_true, _ = collect_LDS_stats(model)

                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme)  # reset to initialisation
                print('fitting #' + str(i))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1] - likes[-2] < eps:
                        break
                stats_hat, pars_hat = collect_LDS_stats(model)

                broken = False
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                T_out = model.states_list[0].T
                num_trials = len(model.states_list)

                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(Pi)
                Pi_t_h = pars_hat['A'].dot(Pi_h)

            except Exception as err:
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print('')
                print('error: ' + repr(err))

                broken = True
                y_out = []
                x_out = []
                pars_hat, stats_hat, stats_true = [], [], []
                Pi, Pi_h, Pi_t, Pi_t_h = 0, 0, 0, 0

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_idx' + str(rep) + '_' + filename

            save_file_m = {'ifBroken': broken,
                           'x': x_out,
                           'y': y_out,
                           'u': [],
                           'll': likes,
                           'T': T_out,
                           'Trial': num_trials,
                           'ifUseB': False,
                           'ifUseA': True,
                           'epsilon': eps,
                           'truePars': pars_true,
                           'initPars': pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'stats_true': stats_true,
                           'Pi': Pi,
                           'Pi_h': Pi_h,
                           'Pi_t': Pi_t,
                           'Pi_t_h': Pi_t_h,
                           'obsScheme': obs_scheme}

            savemat(save_file, save_file_m)  # does the actual saving

            np.savez(save_file,
                     broken=broken,
                     ll=likes,
                     T=T_out,
                     Trial=num_trials,
                     ifUseA=True,
                     ifUseB=False,
                     epsilon=eps,
                     initPars=pars_init,
                     truePars=pars_true,
                     estPars=pars_hat,
                     stats_h=stats_hat,
                     stats_true=stats_true,
                     Pi=Pi,
                     Pi_h=Pi_h,
                     Pi_t=Pi_t,
                     Pi_t_h=Pi_t_h,
                     sub_pops=obs_scheme.sub_pops,
                     obs_time=obs_scheme.obs_time,
                     obs_pops=obs_scheme.obs_pops)
In [ ]:
%matplotlib inline
relative_data_path = '../../../results/cosyne_poster/simulation_2/data'
os.chdir(relative_data_path)
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)
eps = np.log(1.01)
max_iter = 50
initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']
for idx in [0,2,5,9]:
##################
# load the data #
##################
print('idx', idx)
filename = 'LDS_save_idx' + str(idx) + '.npz'
print(filename)
os.chdir('../data/')
loadfile = np.load(filename)
data = loadfile['y']
T,p = data.shape
n = loadfile['x'].shape[1]
pars_true = loadfile['truePars'].reshape(1,)[0] # some numpy (cross-version 3.x-2.x?)
tmp = {}
for j in range(len(pars_true.keys())):
tmp[str(pars_true.keys()[j])] = pars_true.values()[j]
pars_true = tmp
print('dataset #', idx)
print('(T, p, n, eps) = ', (T, p, n, eps))
####################
# pick initialiser #
####################
os.chdir('../init/')
initfiles = glob.glob("*_LDS_save_idx" + str(idx) + ".npz")
num_prots = len(initfiles)
for prot in [0,12]: #range(num_prots):
initfile = initfiles[prot]
print('prot ' + str(prot) + '/' + str(num_prots))
os.chdir('../init/')
loadinit = np.load(initfile)
obs_scheme = ObservationScheme(p=p, T=T,
sub_pops=tuple([item for item in loadinit['sub_pops']]),
obs_pops=loadinit['obs_pops'],
obs_time=loadinit['obs_time'])
for initialiser in initialisers:
print('initialiser ', initialiser)
if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
num_repets = 1
initkey = initialiser[:-5] if initialiser[-5:]=='_flip' else initialiser
pars_init = loadinit[initkey].reshape(1,)[0]
tmp = {}
for j in range(len(pars_init.keys())):
tmp[str(pars_init.keys()[j])] = pars_init.values()[j]
pars_init = tmp
pars_init['R'] = np.diag(pars_init['R'])
# for SSID-derived initialisers, also try out flipping parts of C
if initialiser[-5:]=='_flip':
idx0_no_overlap = np.setdiff1d(np.arange(p),obs_scheme.sub_pops[1])
pars_init['C'][idx0_no_overlap,:] *= -1
elif initialiser=='random':
num_repets = 5
else:
raise Exception('unexpected initialiser!')
for repet in range(num_repets):
print( 'run #' + str(repet+1) +'/' +str(num_repets) )
if initialiser=='random':
pars_init, _ = gen_pars(n, p, u_dim=0,
pars_in=None,
obs_scheme=obs_scheme,
gen_A='diagonal', lts=0.99 * np.ones((n,)),
gen_B='random',
gen_Q='identity',
gen_mu0='random',
gen_V0='identity',
gen_C='random',
gen_d='mean',
gen_R='fractionObserved',
diag_R_flag=True,
x=None, y=data.T, u=None)
###################
# EM cycles #
###################
likes = [-np.inf]
try:
# get EM-step results after m iterations
model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
print 'fitting #' + str(idx)
for t in progprint_xrange(max_iter):
likes.append(update(model))
if likes[-1]-likes[-2] < eps:
break
stats_hat,pars_hat = collect_LDS_stats(model)
# get EM-step results from true parameters
model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
model.E_step()
stats_true,_ = collect_LDS_stats(model)
model.M_step()
broken = False
Pi = dtlyap(pars_true['A'], pars_true['Q'])
Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
Pi_t = pars_true['A'].dot(dtlyap(pars_true['A'], pars_true['Q']))
Pi_t_h = pars_hat['A'].dot(dtlyap(pars_hat['A'], pars_hat['Q']))
except:
print('')
print('############')
print('#RUN BROKE!#')
print('############')
print('')
broken = True
pars_hat, stats_hat, stats_true = [],[],[]
Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0
if not broken:
plt.figure(figsize=(25,25))
plt.subplot(1,3,1)
plt.imshow(np.cov(data.T),interpolation='none')
plt.subplot(1,3,2)
if initialiser=='random':
pars_init['Pi'] = dtlyap(pars_init['A'], pars_init['Q'])
R = pars_init['R']
else:
R = np.diag(pars_init['R'])
plt.imshow(pars_init['C'].dot(pars_init['Pi']).dot(pars_init['C'].T) + R,interpolation='none')
plt.subplot(1,3,3)
plt.imshow(pars_hat['C'].dot(Pi_h).dot(pars_hat['C'].T) + pars_hat['R'],interpolation='none')
plt.show()
###################
# store results #
###################
print('finished in ' + str(len(likes)-1) + ' many steps.')
os.chdir('../fits/')
save_file = initialiser + '_prot' + str(prot) + '_rep' + str(repet) + '_' + filename
save_file_m = {'ifBroken':broken,
'll' : likes,
'T' : T,
'Trial': 1,
'ifUseB':False,
'ifUseA':True,
'epsilon':eps,
'truePars':pars_true,
'initPars':pars_init,
'estPars': pars_hat,
'stats_h': stats_hat,
'stats_true': stats_true,
'Pi':Pi,
'Pi_h':Pi_h,
'Pi_t':Pi_t,
'Pi_t_h':Pi_t_h,
'sub_pops':obs_scheme.sub_pops,
'obs_time':obs_scheme.obs_time,
'obs_pops':obs_scheme.obs_pops}
savemat(save_file,save_file_m) # does the actual saving
np.savez(save_file,
broken=broken,
ll=likes,
T=T,
Trial=1,
ifUseA=True,
ifUseB=False,
epsilon=eps,
initPars=pars_init,
truePars=pars_true,
estPars =pars_hat,
stats_h = stats_hat,
stats_true = stats_true,
Pi=Pi,
Pi_h=Pi_h,
Pi_t=Pi_t,
Pi_t_h=Pi_t_h,
sub_pops=obs_scheme.sub_pops,
obs_time=obs_scheme.obs_time,
obs_pops=obs_scheme.obs_pops)
In [ ]:
# Fit the spiking-network recording with 100 EM iterations, starting from the
# stored 'params_naive' initialisation.
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes = spikes['spikes_out']
p, T = 1000, 6000
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
# stack all trials along the first axis; the builtin `float` replaces the
# `np.float` alias, which newer NumPy releases no longer provide
data = np.vstack([spikes[i][0].T[:, idx_n] for i in range(spikes.size)]).astype(float)
T *= spikes.size
n = 10

#tmp = np.random.choice(np.arange(p),size=p,replace=False)
#sub_pops = (np.sort(tmp[:p//2 + 2]), np.sort(tmp[p//2 - 2:]))
# two index ranges that overlap in the 200 indices around p//2
sub_pops = (np.arange(p//2+100), np.arange(p//2-100, p))
obs_pops = np.array((0, 1))
obs_time = np.array((T//2, T))
obs_scheme = ObservationScheme(p, T, sub_pops, obs_pops, obs_time)

###################
#    EM cycles    #
###################

loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']
data = data[:, neuron_shuffle]

initkey = 'params_naive'
# the initialisation is a dict wrapped in a one-element array; unwrap it and
# make sure all keys are plain str
pars_init = loadfile[initkey].reshape(1,)[0]
pars_init = {str(key): val for key, val in pars_init.items()}
pars_init['R'] = np.diag(pars_init['R'])
pars_init['V0'] = pars_init['Q']

# Rescale the eigenvalues of A when any of them lies outside the unit circle.
# NOTE(review): the source lost its indentation; the extent of this `if` is
# reconstructed from the before/after prints - confirm.
D, V = np.linalg.eig(pars_init['A'])
if np.any(np.abs(D) > 1):
    print(np.abs(D))
    D /= np.maximum(np.abs(D), 1.0001)
    print(np.abs(D))
    pars_init['A'] = np.real(V.dot(np.diag(D).dot(np.linalg.inv(V))))

model = init_LDS_model(pars_init, data, obs_scheme)  # set to initialisation
#stats_init,_ = collect_LDS_stats(model)

print('fitting')
likes = [update(model) for _ in progprint_xrange(100)]
stats_hat, pars_hat = collect_LDS_stats(model)
In [ ]:
from scipy.io import savemat

# Write the EM fit currently held in `model`, `likes`, `pars_init`,
# `pars_hat`, `stats_hat` and `obs_scheme` to disk, once as .mat (for the
# Matlab code) and once as .npz.
broken = False
save_file = '/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter100'
eps = - np.inf

first_seq = model.states_list[0]
n_trials = len(model.states_list)

save_file_m = dict([
    ('ifBroken', broken),
    ('ll', likes),
    ('T', first_seq.T),
    ('Trial', n_trials),
    ('epsilon', eps),
    ('initPars', pars_init),
    ('estPars', pars_hat),
    ('stats_h', stats_hat),
    ('sub_pops', obs_scheme.sub_pops),
    ('obs_pops', obs_scheme.obs_pops),
    ('obs_time', obs_scheme.obs_time),
])
savemat(save_file, save_file_m)  # does the actual saving

np.savez(
    save_file,
    broken=broken,
    ll=likes,
    T=first_seq.T,
    Trial=n_trials,
    epsilon=eps,
    initPars=pars_init,
    estPars=pars_hat,
    stats_h=stats_hat,
    sub_pops=obs_scheme.sub_pops,
    obs_time=obs_scheme.obs_time,
    obs_pops=obs_scheme.obs_pops,
)
In [ ]:
%matplotlib inline
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']
# get neuron_shuffle
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']
# get param EM fit
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter400.npz')
pars_hat = loadfile['estPars'].reshape(1,)[0]
likes = loadfile['ll']
p, T = spikes[0][0].shape
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(np.float)
T *= spikes.size
n = 10
data = data[:,neuron_shuffle]
initkey = 'params_naive'
import scipy as sp
import matplotlib.pyplot as plt
Pi = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], pars_hat['Q'])
plt.figure(1,figsize=(15,22))
try:
Pi_true = sp.linalg.solve_discrete_lyapunov(pars_true['A'], pars_true['Q'])
tmp1 = pars_true['C'].dot(Pi_true.dot(pars_true['C'].transpose())) + np.diag(pars_true['R'])
except:
cov_all = np.cov(np.hstack([data[1:], data[:-1]]).T)
tmp1 = cov_all[np.ix_(np.arange(p), np.arange(p))]
tmp2 = pars_hat['C'].dot(Pi.dot(pars_hat['C'].transpose())) + pars_hat['R']
m = np.min((tmp1-np.diag(np.diag(tmp1))).min(),(tmp2-np.diag(np.diag(tmp2))).min())
M = np.max((tmp1-np.diag(np.diag(tmp1))).max(),(tmp2-np.diag(np.diag(tmp2))).max())
plt.subplot(2,3,1)
plt.imshow(tmp1-np.diag(np.diag(tmp1)),interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true instantaneous covs')
plt.subplot(2,3,2)
plt.imshow(tmp2-np.diag(np.diag(tmp2)), interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated instantaneous covs')
plt.subplot(2,3,3)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')
try:
tmp1 = pars_true['C'].dot(pars_true['A']).dot(Pi_true.dot(pars_true['C'].transpose()))
except:
tmp1 = cov_all[np.ix_(np.arange(0, p), np.arange(p+1 , 2*p))]
tmp2 = pars_hat['C'].dot(pars_hat['A']).dot(Pi.dot(pars_hat['C'].transpose()))
m = np.min(tmp1.min(),tmp2.min())
M = np.max(tmp1.max(),tmp2.max())
plt.subplot(2,3,4)
plt.imshow(tmp1,interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true time_lagged covs')
plt.subplot(2,3,5)
plt.imshow(tmp2, interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated time-lagged covs')
plt.subplot(2,3,6)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')
plt.figure(2,figsize=(15,15))
plt.subplot(2,2,1)
plt.plot(pars_hat['d'])
try:
plt.plot(pars_true['d'])
except:
pass
plt.legend(['true', 'est'])
plt.title('d')
plt.subplot(2,2,2)
plt.plot(pars_hat['R'])
try:
plt.plot(pars_true['R'])
except:
pass
plt.legend(['true', 'est'])
plt.title('R')
plt.subplot(2,2,3)
plt.plot(np.sort(np.linalg.eig(pars_hat['A'])[0]))
try:
plt.plot(np.sort(np.linalg.eig(pars_true['A'])[0]))
except:
pass
plt.legend(['true', 'est'])
plt.title('eig(A)')
plt.title('eigenvalues of A')
plt.show()
""" second batch of figures"""
covy_h= np.dot( np.dot(pars_hat['C'], Pi),pars_hat['C'].transpose()) + pars_hat['R']
try:
covy_t= np.dot(np.dot(pars_true['C'], Pi_true),pars_true['C'].transpose()) + np.diag(pars_true['R'])
covy_tl_t=(np.dot(np.dot(pars_true['C'],np.dot(pars_true['A'], Pi_true)),pars_true['C'].transpose()))
plot_truth = True
except:
plot_truth = False
y_tl = np.zeros([2*p,T-1])
y_tl[range(p),:] = data[range(0,T-1),:].T
y_tl[range(p,2*p),:] = data[range(1,T),:].T
covy = np.cov(y_tl)
covy_e= covy[np.ix_(range(p),range(p))]
covy_tl_e= covy[np.ix_(range(p,2*p),range(0,p))]
sub_pops = loadfile['sub_pops']
covy_tl_h= np.dot(np.dot(pars_hat['C'], np.dot(pars_hat['A'],Pi)), pars_hat['C'].transpose())
idx_stitched = np.ones([p,p],dtype = bool)
for i in range(len(sub_pops)):
if len(sub_pops[i])>0:
idx_stitched[np.ix_(sub_pops[i],sub_pops[i])] = False
plt.imshow(idx_stitched,interpolation='none')
plt.figure(3, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_e[np.invert(idx_stitched)], covy_t[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. true')
plt.ylabel('instantaneous')
plt.subplot(1,3,2)
plt.plot(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_t[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
plt.title('obs. true vs. obs. stitched')
plt.figure(4, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_t[np.invert(idx_stitched)], '.')
plt.ylabel('time-lagged')
plt.title('obs. emp vs. obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_tl_t[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
plt.title('obs. non-observed true vs. obs. stitched')
plt.figure(5, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_e[idx_stitched], covy_t[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_e[idx_stitched], covy_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_t[idx_stitched], covy_h[idx_stitched], '.')
plt.title('non-obs. true vs. non-obs. titched')
plt.figure(6, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_tl_e[idx_stitched], covy_tl_t[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_tl_t[idx_stitched], covy_tl_h[idx_stitched], '.')
plt.title('non-obs. true vs. non-obs. stitched')
print('corr(y_t,y_t) stitchted: ', np.corrcoef(covy_e[idx_stitched], covy_h[idx_stitched]))
print('corr(y_t,y_t) observed: ', np.corrcoef(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)]))
print('corr(y_t,y_t-1) stitchted: ', np.corrcoef(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched]))
print('corr(y_t,y_t-1) observed: ', np.corrcoef(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)]))
In [ ]:
In [ ]:
%matplotlib inline
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
p, T = spikes[0][0].shape
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(np.float)
T *= spikes.size
n = 10
neuron_shuffle = loadfile['neuron_shuffle']
data = data[:,neuron_shuffle]
initkey = 'params_naive'
pars_init = loadfile[initkey].reshape(1,)[0]
tmp = {}
for j in range(len(pars_init.keys())):
tmp[str(pars_init.keys()[j])] = pars_init.values()[j]
pars_init = tmp
pars_init['R'] = np.diag(pars_init['R'])
pars_init['V0'] = pars_init['Q']
pars_hat = pars_init.copy()
pars_hat['R'] = np.diag(pars_hat['R'])
import scipy as sp
import matplotlib.pyplot as plt
#Pi = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], pars_hat['Q'])
Pi = pars_hat['Pi']
plt.figure(1,figsize=(15,22))
try:
Pi_true = sp.linalg.solve_discrete_lyapunov(pars_true['A'], pars_true['Q'])
tmp1 = pars_true['C'].dot(Pi_true.dot(pars_true['C'].transpose())) + np.diag(pars_true['R'])
except:
cov_all = np.cov(np.hstack([data[1:], data[:-1]]).T)
tmp1 = cov_all[np.ix_(np.arange(p), np.arange(p))]
tmp2 = pars_hat['C'].dot(Pi.dot(pars_hat['C'].transpose())) + pars_hat['R']
m = np.min((tmp1-np.diag(np.diag(tmp1))).min(),(tmp2-np.diag(np.diag(tmp2))).min())
M = np.max((tmp1-np.diag(np.diag(tmp1))).max(),(tmp2-np.diag(np.diag(tmp2))).max())
plt.subplot(2,3,1)
plt.imshow(tmp1-np.diag(np.diag(tmp1)),interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true instantaneous covs')
plt.subplot(2,3,2)
plt.imshow(tmp2-np.diag(np.diag(tmp2)), interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated instantaneous covs')
plt.subplot(2,3,3)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')
try:
tmp1 = pars_true['C'].dot(pars_true['A']).dot(Pi_true.dot(pars_true['C'].transpose()))
except:
tmp1 = cov_all[np.ix_(np.arange(0, p), np.arange(p+1 , 2*p))]
tmp2 = pars_hat['C'].dot(pars_hat['A']).dot(Pi.dot(pars_hat['C'].transpose()))
m = np.min(tmp1.min(),tmp2.min())
M = np.max(tmp1.max(),tmp2.max())
plt.subplot(2,3,4)
plt.imshow(tmp1,interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true time_lagged covs')
plt.subplot(2,3,5)
plt.imshow(tmp2, interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated time-lagged covs')
plt.subplot(2,3,6)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')
plt.figure(2,figsize=(15,15))
plt.subplot(2,2,1)
plt.plot(pars_hat['d'])
try:
plt.plot(pars_true['d'])
except:
pass
plt.legend(['true', 'est'])
plt.title('d')
plt.subplot(2,2,2)
plt.plot(pars_hat['R'])
try:
plt.plot(pars_true['R'])
except:
pass
plt.legend(['true', 'est'])
plt.title('R')
plt.subplot(2,2,3)
plt.plot(np.sort(np.linalg.eig(pars_hat['A'])[0]))
try:
plt.plot(np.sort(np.linalg.eig(pars_true['A'])[0]))
except:
pass
plt.legend(['true', 'est'])
plt.title('eig(A)')
plt.title('eigenvalues of A')
plt.show()
""" second batch of figures"""
covy_h= np.dot( np.dot(pars_hat['C'], Pi),pars_hat['C'].transpose()) + pars_hat['R']
try:
covy_t= np.dot(np.dot(pars_true['C'], Pi_true),pars_true['C'].transpose()) + np.diag(pars_true['R'])
covy_tl_t=(np.dot(np.dot(pars_true['C'],np.dot(pars_true['A'], Pi_true)),pars_true['C'].transpose()))
plot_truth = True
except:
plot_truth = False
y_tl = np.zeros([2*p,T-1])
y_tl[range(p),:] = data[range(0,T-1),:].T
y_tl[range(p,2*p),:] = data[range(1,T),:].T
covy = np.cov(y_tl)
covy_e= covy[np.ix_(range(p),range(p))]
covy_tl_e= covy[np.ix_(range(p,2*p),range(0,p))]
covy_tl_h= np.dot(np.dot(pars_hat['C'], np.dot(pars_hat['A'],Pi)), pars_hat['C'].transpose())
idx_stitched = np.ones([p,p],dtype = bool)
for i in range(len(obs_scheme.sub_pops)):
if len(obs_scheme.sub_pops[i])>0:
idx_stitched[np.ix_(obs_scheme.sub_pops[i],obs_scheme.sub_pops[i])] = False
plt.imshow(idx_stitched,interpolation='none')
plt.figure(3, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_e[np.invert(idx_stitched)], covy_t[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. true')
plt.ylabel('instantaneous')
plt.subplot(1,3,2)
plt.plot(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_t[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
plt.title('obs. true vs. obs. stitched')
plt.figure(4, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_t[np.invert(idx_stitched)], '.')
plt.ylabel('time-lagged')
plt.title('obs. emp vs. obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_tl_t[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
plt.title('obs. non-observed true vs. obs. stitched')
plt.figure(5, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_e[idx_stitched], covy_t[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_e[idx_stitched], covy_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_t[idx_stitched], covy_h[idx_stitched], '.')
plt.title('non-obs. true vs. non-obs. titched')
plt.figure(6, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
plt.plot(covy_tl_e[idx_stitched], covy_tl_t[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
plt.plot(covy_tl_t[idx_stitched], covy_tl_h[idx_stitched], '.')
plt.title('non-obs. true vs. non-obs. stitched')
print('corr(y_t,y_t) stitchted: ', np.corrcoef(covy_e[idx_stitched], covy_h[idx_stitched]))
print('corr(y_t,y_t) observed: ', np.corrcoef(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)]))
print('corr(y_t,y_t-1) stitchted: ', np.corrcoef(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched]))
print('corr(y_t,y_t-1) observed: ', np.corrcoef(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)]))
In [ ]:
# Continue EM for another 100 iterations, starting from the parameters stored
# in the iter300 fit file and using the observation scheme saved with it.
print('loading data')
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes = spikes['spikes_out']

print('concatenating data')
p, T = 1000, 6000
#data = spikes[0][0].T
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
# stack all trials along the first axis; the builtin `float` replaces the
# `np.float` alias, which newer NumPy releases no longer provide
data = np.vstack([spikes[i][0].T[:, idx_n] for i in range(spikes.size)]).astype(float)
T *= spikes.size

loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']
data = data[:, neuron_shuffle]
n = 10
print(data.shape)

###################
#    EM cycles    #
###################

print('loading parameters')
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter300.npz')
# the previous estimate becomes the initialisation of this run
pars_init = loadfile['estPars'].reshape(1,)[0]
pars_init['R'] = np.diag(pars_init['R'])

sub_pops = tuple(loadfile['sub_pops'].tolist())
obs_pops = tuple(loadfile['obs_pops'].tolist())
obs_time = tuple(loadfile['obs_time'].tolist())
obs_scheme = ObservationScheme(p, T, sub_pops, obs_pops, obs_time)

model = init_LDS_model(pars_init, data, obs_scheme)  # set to initialisation
print('(p,T,n)', (p, T, n))
print('len sub_pops', [len(item) for item in obs_scheme.sub_pops])
#stats_init,_ = collect_LDS_stats(model)

print('fitting')
likes = [update(model) for _ in progprint_xrange(100)]
stats_hat, pars_hat = collect_LDS_stats(model)
In [ ]:
# Plot the log-likelihood values collected in `likes` (one per EM iteration).
plt.plot(likes)
plt.show()
In [ ]:
from scipy.io import savemat

# Write the EM fit currently held in `model`, `likes`, `pars_init`,
# `pars_hat`, `stats_hat` and `obs_scheme` to disk, once as .mat (for the
# Matlab code) and once as .npz.
broken = False
save_file = '/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter400'
eps = - np.inf

first_seq = model.states_list[0]
n_trials = len(model.states_list)

save_file_m = dict([
    ('ifBroken', broken),
    ('ll', likes),
    ('T', first_seq.T),
    ('Trial', n_trials),
    ('epsilon', eps),
    ('initPars', pars_init),
    ('estPars', pars_hat),
    ('stats_h', stats_hat),
    ('sub_pops', obs_scheme.sub_pops),
    ('obs_pops', obs_scheme.obs_pops),
    ('obs_time', obs_scheme.obs_time),
])
savemat(save_file, save_file_m)  # does the actual saving

np.savez(
    save_file,
    broken=broken,
    ll=likes,
    T=first_seq.T,
    Trial=n_trials,
    epsilon=eps,
    initPars=pars_init,
    estPars=pars_hat,
    stats_h=stats_hat,
    sub_pops=obs_scheme.sub_pops,
    obs_time=obs_scheme.obs_time,
    obs_pops=obs_scheme.obs_pops,
)
In [ ]: