In [ ]:
from __future__ import division
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt

from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression

from scipy.io import loadmat
import glob, os
from scipy.io import savemat # store results for comparison with Matlab code   
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation

# Location of the local pyLDS development checkout.
# NOTE(review): machine-specific absolute path - adjust before running elsewhere.
absolute_code_path = '/home/marcel/Desktop/Projects/Stitching/code/pyLDS_dev/'
# Switch into the 'pylds' folder; the relative '../data', '../init' and
# '../fits' paths used by later cells are resolved from the working directory.
os.chdir(os.path.join(absolute_code_path, 'pylds'))

from pylds.models import LDS, DefaultLDS
from pylds.distributions import Regression_diag, AutoRegression_input
from pylds.obs_scheme import ObservationScheme
from pylds.user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats

def update(model):
    """Run a single EM iteration on `model` and report the fit quality.

    Calls `model.EM_step()` (which is expected to update the model in place)
    and returns whatever `model.log_likelihood()` yields afterwards.
    """
    model.EM_step()
    current_loglik = model.log_likelihood()
    return current_loglik

Fit models for illustration #1


In [ ]:
# --- EM settings for illustration #1 ---------------------------------------
# EM stops once the log-likelihood gain of an iteration falls below `eps`
# (log(1.01)), or after `max_iter` iterations.
max_iter = 25
eps = np.log(1.01)

# Initialisation schemes to compare; the full list would be
# ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random'].
initialisers = ['params', 'params_flip', 'random']

# --- datasets ---------------------------------------------------------------
# One experiment per .npz file found in the current working directory.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

def update(model):
    """Perform one EM step on `model` and return its log-likelihood."""
    model.EM_step()
    loglik = model.log_likelihood()
    return loglik

# Fit every illustration dataset with every initialiser and store the results.
# For each (dataset, initialiser, repetition): run EM from the initialisation
# until the log-likelihood gain drops below `eps` (at most `max_iter` steps),
# collect E-step statistics under the true parameters for comparison, and save
# everything into ../fits/ in both MATLAB (.mat) and NumPy (.npz) format.
for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T, p = data.shape
    n = loadfile['x'].shape[1]
    obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=tuple([item for item in loadfile['sub_pops']]),
                                   obs_pops=loadfile['obs_pops'],
                                   obs_time=loadfile['obs_time'])
    # 'truePars' is an object array wrapping a dict: unwrap it and force the
    # keys to plain str. dict.items() works under both Python 2 and 3, unlike
    # indexing into keys()/values().
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = dict((str(key), val) for key, val in pars_true.items())

    print('dataset # ' + str(i))
    print('(T, p, n, eps) = ' + str((T, p, n, eps)))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ' + str(initialiser))
        is_flip = initialiser.endswith('_flip')

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')

            num_reps = 1

            loadfile = np.load('init_'+filename)

            # '<key>_flip' variants start from the same stored parameters as '<key>'
            initkey = initialiser[:-5] if is_flip else initialiser

            pars_init = loadfile[initkey].reshape(1,)[0]
            pars_init = dict((str(key), val) for key, val in pars_init.items())

            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            if is_flip:
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap,:] *= -1

        elif initialiser=='random':

            num_reps = 10

        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print('run #' + str(rep+1) + '/' + str(num_reps))

            if initialiser=='random':

                # draw a fresh random initialisation for every repetition
                pars_init, _ = gen_pars(n, p, u_dim=0,
                                     pars_in=None,
                                     obs_scheme=obs_scheme,
                                     gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                     gen_B='random',
                                     gen_Q='identity',
                                     gen_mu0='random',
                                     gen_V0='identity',
                                     gen_C='random',
                                     gen_d='mean',
                                     gen_R='fractionObserved',
                                     diag_R_flag=True,
                                     x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # Defined before the try block so that the result-saving code below
            # never sees undefined or stale values when a run breaks early.
            likes = [-np.inf]
            model = None

            try:
                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                print('fitting #' + str(i))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1]-likes[-2] < eps:
                        break

                stats_hat, pars_hat = collect_LDS_stats(model)

                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                model.E_step()
                stats_true, _ = collect_LDS_stats(model)
                model.M_step()

                broken = False
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                # stationary (Pi) and time-lagged (A * Pi) latent covariances
                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(dtlyap(pars_true['A'], pars_true['Q']))
                Pi_t_h = pars_hat['A'].dot(dtlyap(pars_hat['A'], pars_hat['Q']))

            except Exception as err:
                # `except Exception` (rather than a bare except) lets a kernel
                # interrupt stop the whole experiment loop.
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print(repr(err))
                print('')

                broken = True
                y_out = model.states_list[0].data if model is not None else []
                x_out = []
                pars_hat, stats_hat, stats_true = [],[],[]
                Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_idx' + str(rep) + '_' + filename

            # fall back to the data dimensions if no model could be built at all
            if model is not None:
                T_fit = model.states_list[0].T
                num_trials = len(model.states_list)
            else:
                T_fit, num_trials = T, 1

            save_file_m = {'ifBroken':broken,
                           'x': x_out,
                           'y': y_out,
                           'u' : [],
                           'll' : likes,
                           'T' : T_fit,
                           'Trial': num_trials,
                           'ifUseB':False,
                           'ifUseA':True,
                           'epsilon':eps,
                           'truePars':pars_true,
                           'initPars':pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'stats_true': stats_true,
                           'Pi':Pi,
                           'Pi_h':Pi_h,
                           'Pi_t':Pi_t,
                           'Pi_t_h':Pi_t_h,
                           'obsScheme' : obs_scheme}
            savemat(save_file,save_file_m) # does the actual saving

            np.savez(save_file,
                    broken=broken,
                    x=x_out,
                    y=y_out,
                    ll=likes,
                    T=T_fit,
                    Trial=num_trials,
                    ifUseA=True,
                    ifUseB=False,
                    epsilon=eps,
                    initPars=pars_init,
                    truePars=pars_true,
                    estPars =pars_hat,
                    stats_h = stats_hat,
                    stats_true = stats_true,
                    Pi=Pi,
                    Pi_h=Pi_h,
                    Pi_t=Pi_t,
                    Pi_t_h=Pi_t_h,
                    sub_pops=obs_scheme.sub_pops,
                    obs_time=obs_scheme.obs_time,
                    obs_pops=obs_scheme.obs_pops)

Fit models for simulation #1


In [ ]:
# --- EM settings for simulation #1 ------------------------------------------
# EM stops once the log-likelihood gain of an iteration falls below `eps`
# (log(1.01)), or after `max_iter` iterations.
max_iter = 25
eps = np.log(1.01)

# all initialisation schemes are compared in this simulation
initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']

def update(model):
    """Perform one EM step on `model` and return its log-likelihood."""
    model.EM_step()
    loglik = model.log_likelihood()
    return loglik

# --- datasets ---------------------------------------------------------------
# Move into the data folder of simulation #1 (relative to the current working
# directory) and treat every .npz file found there as one experiment.
relative_data_path = '../../../results/cosyne_poster/simulation_1/data'
os.chdir(relative_data_path)
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

# Fit every simulation-#1 dataset with every initialiser and store the results.
# For each (dataset, initialiser, repetition): collect E-step statistics under
# the true parameters, run EM from the initialisation until the log-likelihood
# gain drops below `eps` (at most `max_iter` steps), and save everything into
# ../fits/ in both MATLAB (.mat) and NumPy (.npz) format.
for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T, p = data.shape
    n = loadfile['x'].shape[1]
    sub_pops = tuple([item for item in loadfile['sub_pops']])
    obs_pops = loadfile['obs_pops']

    # Build the observation scheme for THIS dataset. Without this, the loop
    # silently reuses whatever `obs_scheme` an earlier cell left in the kernel.
    # NOTE(review): assumes the data file stores 'obs_time' next to 'sub_pops'
    # and 'obs_pops', as the illustration-#1 files do - confirm.
    obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=sub_pops,
                                   obs_pops=obs_pops,
                                   obs_time=loadfile['obs_time'])

    # 'truePars' is an object array wrapping a dict: unwrap it and force the
    # keys to plain str. dict.items() works under both Python 2 and 3, unlike
    # indexing into keys()/values().
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = dict((str(key), val) for key, val in pars_true.items())

    print('dataset # ' + str(i))
    print('(T, p, n, eps) = ' + str((T, p, n, eps)))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ' + str(initialiser))
        is_flip = initialiser.endswith('_flip')

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')

            num_reps = 1

            loadfile = np.load('init_'+filename)

            # '<key>_flip' variants start from the same stored parameters as '<key>'
            initkey = initialiser[:-5] if is_flip else initialiser

            pars_init = loadfile[initkey].reshape(1,)[0]
            pars_init = dict((str(key), val) for key, val in pars_init.items())

            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            if is_flip:
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap,:] *= -1

        elif initialiser=='random':

            num_reps = 10

        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print('run #' + str(rep+1) + '/' + str(num_reps))

            if initialiser=='random':

                # draw a fresh random initialisation for every repetition
                pars_init, _ = gen_pars(n, p, u_dim=0,
                                     pars_in=None,
                                     obs_scheme=obs_scheme,
                                     gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                     gen_B='random',
                                     gen_Q='identity',
                                     gen_mu0='random',
                                     gen_V0='identity',
                                     gen_C='random',
                                     gen_d='mean',
                                     gen_R='fractionObserved',
                                     diag_R_flag=True,
                                     x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # Defined before the try block so that the result-saving code below
            # never sees undefined or stale values when a run breaks early.
            likes = [-np.inf]
            model = None

            try:

                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                model.E_step()
                stats_true, _ = collect_LDS_stats(model)

                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                print('fitting #' + str(i))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1]-likes[-2] < eps:
                        break

                stats_hat, pars_hat = collect_LDS_stats(model)

                broken = False
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                # stationary (Pi) and time-lagged (A * Pi) latent covariances
                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(dtlyap(pars_true['A'], pars_true['Q']))
                Pi_t_h = pars_hat['A'].dot(dtlyap(pars_hat['A'], pars_hat['Q']))

            except Exception as err:
                # `except Exception` (rather than a bare except) lets a kernel
                # interrupt stop the whole experiment loop.
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print(repr(err))
                print('')

                broken = True
                y_out = []
                x_out = []
                pars_hat, stats_hat, stats_true = [],[],[]
                Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_idx' + str(rep) + '_' + filename

            # fall back to the data dimensions if no model could be built at all
            if model is not None:
                T_fit = model.states_list[0].T
                num_trials = len(model.states_list)
            else:
                T_fit, num_trials = T, 1

            save_file_m = {'ifBroken':broken,
                           'x': x_out,
                           'y': y_out,
                           'u' : [],
                           'll' : likes,
                           'T' : T_fit,
                           'Trial': num_trials,
                           'ifUseB':False,
                           'ifUseA':True,
                           'epsilon':eps,
                           'truePars':pars_true,
                           'initPars':pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'stats_true': stats_true,
                           'Pi':Pi,
                           'Pi_h':Pi_h,
                           'Pi_t':Pi_t,
                           'Pi_t_h':Pi_t_h,
                           'obsScheme' : obs_scheme}
            savemat(save_file,save_file_m) # does the actual saving

            np.savez(save_file,
                    broken=broken,
                    ll=likes,
                    T=T_fit,
                    Trial=num_trials,
                    ifUseA=True,
                    ifUseB=False,
                    epsilon=eps,
                    initPars=pars_init,
                    truePars=pars_true,
                    estPars =pars_hat,
                    stats_h = stats_hat,
                    stats_true = stats_true,
                    Pi=Pi,
                    Pi_h=Pi_h,
                    Pi_t=Pi_t,
                    Pi_t_h=Pi_t_h,
                    sub_pops=obs_scheme.sub_pops,
                    obs_time=obs_scheme.obs_time,
                    obs_pops=obs_scheme.obs_pops)

Fit models for simulation #2


In [ ]:
%matplotlib inline
relative_data_path = '../../../results/cosyne_poster/simulation_2/data'
os.chdir(relative_data_path)
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)
        
eps = np.log(1.01)
max_iter = 50

initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']

# Fit selected simulation-#2 datasets under selected observation protocols.
# For each (dataset, protocol, initialiser, repetition): run EM from the
# initialisation until the log-likelihood gain drops below `eps` (at most
# `max_iter` steps), collect E-step statistics under the true parameters,
# plot data / initial / fitted output covariances for successful runs, and
# save everything into ../fits/ in both MATLAB (.mat) and NumPy (.npz) format.
for idx in [0,2,5,9]:

    ##################
    # load the data  #
    ##################

    print('idx ' + str(idx))
    filename = 'LDS_save_idx' + str(idx) + '.npz'
    print(filename)

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T, p = data.shape
    n = loadfile['x'].shape[1]

    # 'truePars' is an object array wrapping a dict: unwrap it and force the
    # keys to plain str. dict.items() works under both Python 2 and 3, unlike
    # indexing into keys()/values().
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = dict((str(key), val) for key, val in pars_true.items())

    print('dataset # ' + str(idx))
    print('(T, p, n, eps) = ' + str((T, p, n, eps)))

    ####################
    # pick initialiser #
    ####################

    os.chdir('../init/')
    # NOTE(review): glob order is filesystem-dependent, so the meaning of a
    # protocol index `prot` is not guaranteed to be stable across machines.
    initfiles = glob.glob("*_LDS_save_idx" + str(idx) + ".npz")
    num_prots = len(initfiles)

    for prot in [0,12]: #range(num_prots):

        if prot >= num_prots:
            # avoid an IndexError aborting the whole experiment loop
            print('skipping prot ' + str(prot) + ': only ' + str(num_prots) + ' init files found')
            continue

        initfile = initfiles[prot]
        print('prot ' + str(prot) + '/' + str(num_prots))
        # the previous run left the working directory in ../fits/
        os.chdir('../init/')
        loadinit = np.load(initfile)

        obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=tuple([item for item in loadinit['sub_pops']]),
                                   obs_pops=loadinit['obs_pops'],
                                   obs_time=loadinit['obs_time'])

        for initialiser in initialisers:

            print('initialiser ' + str(initialiser))
            is_flip = initialiser.endswith('_flip')

            if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:

                num_repets = 1
                # '<key>_flip' variants start from the same stored parameters as '<key>'
                initkey = initialiser[:-5] if is_flip else initialiser

                pars_init = loadinit[initkey].reshape(1,)[0]
                pars_init = dict((str(key), val) for key, val in pars_init.items())

                pars_init['R'] = np.diag(pars_init['R'])

                # for SSID-derived initialisers, also try out flipping parts of C
                if is_flip:
                    idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                    pars_init['C'][idx0_no_overlap,:] *= -1

            elif initialiser=='random':

                num_repets = 5

            else:
                raise Exception('unexpected initialiser!')

            for repet in range(num_repets):

                print('run #' + str(repet+1) + '/' + str(num_repets))

                if initialiser=='random':

                    # draw a fresh random initialisation for every repetition
                    pars_init, _ = gen_pars(n, p, u_dim=0,
                                         pars_in=None,
                                         obs_scheme=obs_scheme,
                                         gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                         gen_B='random',
                                         gen_Q='identity',
                                         gen_mu0='random',
                                         gen_V0='identity',
                                         gen_C='random',
                                         gen_d='mean',
                                         gen_R='fractionObserved',
                                         diag_R_flag=True,
                                         x=None, y=data.T, u=None)

                ###################
                #    EM cycles    #
                ###################
                likes = [-np.inf]
                try:
                    # get EM-step results after m iterations
                    model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                    print('fitting #' + str(idx))
                    for t in progprint_xrange(max_iter):
                        likes.append(update(model))
                        if likes[-1]-likes[-2] < eps:
                            break

                    stats_hat, pars_hat = collect_LDS_stats(model)

                    # get EM-step results from true parameters
                    model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                    model.E_step()
                    stats_true, _ = collect_LDS_stats(model)
                    model.M_step()

                    broken = False
                    # stationary (Pi) and time-lagged (A * Pi) latent covariances
                    Pi = dtlyap(pars_true['A'], pars_true['Q'])
                    Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                    Pi_t = pars_true['A'].dot(dtlyap(pars_true['A'], pars_true['Q']))
                    Pi_t_h = pars_hat['A'].dot(dtlyap(pars_hat['A'], pars_hat['Q']))

                except Exception as err:
                    # `except Exception` (rather than a bare except) lets a
                    # kernel interrupt stop the whole experiment loop.
                    print('')
                    print('############')
                    print('#RUN BROKE!#')
                    print('############')
                    print(repr(err))
                    print('')

                    broken = True
                    pars_hat, stats_hat, stats_true = [],[],[]
                    Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

                if not broken:
                    # left: empirical data covariance, middle: output covariance
                    # implied by the initialisation, right: implied by the EM fit
                    plt.figure(figsize=(25,25))
                    plt.subplot(1,3,1)
                    plt.imshow(np.cov(data.T),interpolation='none')
                    plt.subplot(1,3,2)
                    if initialiser=='random':
                        pars_init['Pi'] = dtlyap(pars_init['A'], pars_init['Q'])
                        R = pars_init['R']
                    else:
                        # pars_init['R'] was reduced with np.diag above; expand it again
                        R = np.diag(pars_init['R'])
                    plt.imshow(pars_init['C'].dot(pars_init['Pi']).dot(pars_init['C'].T) + R,interpolation='none')
                    plt.subplot(1,3,3)
                    plt.imshow(pars_hat['C'].dot(Pi_h).dot(pars_hat['C'].T) + pars_hat['R'],interpolation='none')
                    plt.show()

                ###################
                #  store results  #
                ###################

                print('finished in ' + str(len(likes)-1) + ' many steps.')

                os.chdir('../fits/')

                save_file = initialiser + '_prot' + str(prot) + '_rep' + str(repet) + '_' + filename

                save_file_m = {'ifBroken':broken,
                               'll' : likes,
                               'T' : T,
                               'Trial': 1,
                               'ifUseB':False,
                               'ifUseA':True,
                               'epsilon':eps,
                               'truePars':pars_true,
                               'initPars':pars_init,
                               'estPars': pars_hat,
                               'stats_h': stats_hat,
                               'stats_true': stats_true,
                               'Pi':Pi,
                               'Pi_h':Pi_h,
                               'Pi_t':Pi_t,
                               'Pi_t_h':Pi_t_h,
                               'sub_pops':obs_scheme.sub_pops,
                               'obs_time':obs_scheme.obs_time,
                               'obs_pops':obs_scheme.obs_pops}
                savemat(save_file,save_file_m) # does the actual saving

                np.savez(save_file,
                        broken=broken,
                        ll=likes,
                        T=T,
                        Trial=1,
                        ifUseA=True,
                        ifUseB=False,
                        epsilon=eps,
                        initPars=pars_init,
                        truePars=pars_true,
                        estPars =pars_hat,
                        stats_h = stats_hat,
                        stats_true = stats_true,
                        Pi=Pi,
                        Pi_h=Pi_h,
                        Pi_t=Pi_t,
                        Pi_t_h=Pi_t_h,
                        sub_pops=obs_scheme.sub_pops,
                        obs_time=obs_scheme.obs_time,
                        obs_pops=obs_scheme.obs_pops)

Fit big sim


In [ ]:
# Fit an LDS to the large network simulation: load binned spike counts,
# define two overlapping sub-populations observed one after the other, start
# from the stored naive SSID initialisation and run 100 EM iterations.
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']

p, T = 1000, 6000
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
# concatenate all trials along time; builtin `float` replaces the `np.float`
# alias (identical dtype, but the alias was removed from recent NumPy)
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(float)
T *= spikes.size

n = 10

#tmp = np.random.choice(np.arange(p),size=p,replace=False)
#sub_pops = (np.sort(tmp[:p//2 + 2]), np.sort(tmp[p//2 - 2:]))
# two sub-populations sharing the 200 variables in the middle; the first is
# observed up to time T//2, the second from there up to T
sub_pops = (np.arange(p//2+100), np.arange(p//2-100, p))
obs_pops = np.array((0,1))
obs_time = np.array((T//2,T))
obs_scheme = ObservationScheme(p, T, sub_pops, obs_pops, obs_time)

###################
#    EM cycles    #
###################

loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']
# reorder the data columns with the permutation stored alongside the initialisation
data = data[:,neuron_shuffle]

initkey =  'params_naive'

# The stored parameters are an object array wrapping a dict: unwrap it and
# force the keys to plain str. dict.items() works under both Python 2 and 3,
# unlike indexing into keys()/values().
pars_init = loadfile[initkey].reshape(1,)[0]
pars_init = dict((str(key), val) for key, val in pars_init.items())

pars_init['R'] = np.diag(pars_init['R'])
pars_init['V0'] = pars_init['Q']

# Rescale the eigenvalues of A if any of them lies outside the unit circle.
# NOTE(review): eigenvalues with modulus above 1.0001 are divided by their own
# modulus and therefore end up ON the unit circle rather than strictly inside
# it - confirm that this is the intended stabilisation.
D, V = np.linalg.eig(pars_init['A'])
if np.any(np.abs(D) > 1):
    print(np.abs(D))
    D /= np.maximum(np.abs(D), 1.0001)
    print(np.abs(D))
    pars_init['A'] = np.real(V.dot(np.diag(D).dot(np.linalg.inv(V))))

model = init_LDS_model(pars_init, data, obs_scheme) # set to initialisation


#stats_init,_ = collect_LDS_stats(model)

print('fitting')
# fixed number of EM iterations, no early stopping
likes = [update(model) for _ in progprint_xrange(100)]
stats_hat,pars_hat = collect_LDS_stats(model)

In [ ]:
from scipy.io import savemat

# Store the fit held in `model`, `likes`, `pars_init`, `pars_hat` and
# `stats_hat` in both MATLAB (.mat) and NumPy (.npz) format.
save_file = '/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter100'
broken = False
# stored as 'epsilon'; presumably -inf marks that no convergence threshold was applied
eps = - np.inf

# MATLAB version
save_file_m = dict([('ifBroken', broken),
                    ('ll', likes),
                    ('T', model.states_list[0].T),
                    ('Trial', len(model.states_list)),
                    ('epsilon', eps),
                    ('initPars', pars_init),
                    ('estPars', pars_hat),
                    ('stats_h', stats_hat),
                    ('sub_pops', obs_scheme.sub_pops),
                    ('obs_pops', obs_scheme.obs_pops),
                    ('obs_time', obs_scheme.obs_time)])
savemat(save_file, save_file_m) # does the actual saving

# NumPy version (same content; the flag is called 'broken' instead of 'ifBroken')
np.savez(save_file,
         broken=broken,
         ll=likes,
         T=model.states_list[0].T,
         Trial=len(model.states_list),
         epsilon=eps,
         initPars=pars_init,
         estPars=pars_hat,
         stats_h=stats_hat,
         sub_pops=obs_scheme.sub_pops,
         obs_time=obs_scheme.obs_time,
         obs_pops=obs_scheme.obs_pops)

Visualize goodness of EM fit


In [ ]:
%matplotlib inline
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']

# get neuron_shuffle
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']

# get param EM fit
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter400.npz')
pars_hat = loadfile['estPars'].reshape(1,)[0]
likes = loadfile['ll']

p, T = spikes[0][0].shape
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(np.float)
T *= spikes.size

n = 10

data = data[:,neuron_shuffle]
initkey =  'params_naive'

import scipy as sp
import matplotlib.pyplot as plt

Pi = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], pars_hat['Q'])

plt.figure(1,figsize=(15,22))
try:
    Pi_true = sp.linalg.solve_discrete_lyapunov(pars_true['A'], pars_true['Q'])
    tmp1 = pars_true['C'].dot(Pi_true.dot(pars_true['C'].transpose())) + np.diag(pars_true['R'])
except:
    cov_all = np.cov(np.hstack([data[1:], data[:-1]]).T)
    tmp1 = cov_all[np.ix_(np.arange(p), np.arange(p))]
tmp2 = pars_hat['C'].dot(Pi.dot(pars_hat['C'].transpose())) + pars_hat['R']    
m = np.min((tmp1-np.diag(np.diag(tmp1))).min(),(tmp2-np.diag(np.diag(tmp2))).min())
M = np.max((tmp1-np.diag(np.diag(tmp1))).max(),(tmp2-np.diag(np.diag(tmp2))).max())

plt.subplot(2,3,1)
plt.imshow(tmp1-np.diag(np.diag(tmp1)),interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true instantaneous covs')
plt.subplot(2,3,2)
plt.imshow(tmp2-np.diag(np.diag(tmp2)), interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated instantaneous covs')
plt.subplot(2,3,3)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')

try:
    tmp1 = pars_true['C'].dot(pars_true['A']).dot(Pi_true.dot(pars_true['C'].transpose()))
except:
    tmp1 = cov_all[np.ix_(np.arange(0, p), np.arange(p+1 , 2*p))]
tmp2 = pars_hat['C'].dot(pars_hat['A']).dot(Pi.dot(pars_hat['C'].transpose()))     
m = np.min(tmp1.min(),tmp2.min())
M = np.max(tmp1.max(),tmp2.max())

plt.subplot(2,3,4)
plt.imshow(tmp1,interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true time_lagged covs')
plt.subplot(2,3,5)
plt.imshow(tmp2, interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated time-lagged covs')
plt.subplot(2,3,6)
plt.plot(tmp1[:], tmp2[:], '.')
plt.xlabel('true')
plt.ylabel('est')


plt.figure(2,figsize=(15,15))
plt.subplot(2,2,1)
plt.plot(pars_hat['d'])
try:
    plt.plot(pars_true['d'])
except:
    pass
plt.legend(['true', 'est'])
plt.title('d')
plt.subplot(2,2,2)
plt.plot(pars_hat['R'])
try:
    plt.plot(pars_true['R'])
except:
    pass
    
plt.legend(['true', 'est'])
plt.title('R')
plt.subplot(2,2,3)
plt.plot(np.sort(np.linalg.eig(pars_hat['A'])[0]))
try:
    plt.plot(np.sort(np.linalg.eig(pars_true['A'])[0]))
except:
    pass
plt.legend(['true', 'est'])
plt.title('eig(A)')
plt.title('eigenvalues of A')
plt.show()

""" second batch of figures"""

covy_h= np.dot( np.dot(pars_hat['C'], Pi),pars_hat['C'].transpose()) + pars_hat['R']

try: 
    covy_t= np.dot(np.dot(pars_true['C'], Pi_true),pars_true['C'].transpose()) + np.diag(pars_true['R'])
    covy_tl_t=(np.dot(np.dot(pars_true['C'],np.dot(pars_true['A'], Pi_true)),pars_true['C'].transpose()))
    plot_truth = True
except:
    plot_truth = False
    
y_tl = np.zeros([2*p,T-1])
y_tl[range(p),:] = data[range(0,T-1),:].T
y_tl[range(p,2*p),:] = data[range(1,T),:].T
covy = np.cov(y_tl)

covy_e=    covy[np.ix_(range(p),range(p))]
covy_tl_e= covy[np.ix_(range(p,2*p),range(0,p))]


sub_pops = loadfile['sub_pops']
covy_tl_h= np.dot(np.dot(pars_hat['C'], np.dot(pars_hat['A'],Pi)), pars_hat['C'].transpose())
idx_stitched = np.ones([p,p],dtype = bool)
for i in range(len(sub_pops)):
    if len(sub_pops[i])>0:
        idx_stitched[np.ix_(sub_pops[i],sub_pops[i])] = False
plt.imshow(idx_stitched,interpolation='none')

plt.figure(3, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
    plt.plot(covy_e[np.invert(idx_stitched)], covy_t[np.invert(idx_stitched)], '.')
    plt.title('obs. emp vs. obs. true')
plt.ylabel('instantaneous')
plt.subplot(1,3,2)
plt.plot(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
    plt.plot(covy_t[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)], '.')
    plt.title('obs. true vs. obs. stitched')

plt.figure(4, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
    plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_t[np.invert(idx_stitched)], '.')
plt.ylabel('time-lagged')
plt.title('obs. emp vs. obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
plt.title('obs. emp vs. obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
    plt.plot(covy_tl_t[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)], '.')
    plt.title('obs. non-observed true vs. obs. stitched')

plt.figure(5, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
    plt.plot(covy_e[idx_stitched], covy_t[idx_stitched], '.')
    plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_e[idx_stitched], covy_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
    plt.plot(covy_t[idx_stitched], covy_h[idx_stitched], '.')
    plt.title('non-obs. true vs. non-obs. titched')

plt.figure(6, figsize=(20,10))
plt.subplot(1,3,1)
if plot_truth:
    plt.plot(covy_tl_e[idx_stitched], covy_tl_t[idx_stitched], '.')
    plt.title('non-obs. emp vs. non-obs. true')
plt.subplot(1,3,2)
plt.plot(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched], '.')
plt.title('non-obs. emp vs. non-obs. stitched')
plt.subplot(1,3,3)
if plot_truth:
    plt.plot(covy_tl_t[idx_stitched], covy_tl_h[idx_stitched], '.')
    plt.title('non-obs. true vs. non-obs. stitched')
    
    
print('corr(y_t,y_t) stitchted: ', np.corrcoef(covy_e[idx_stitched], covy_h[idx_stitched]))
print('corr(y_t,y_t) observed: ', np.corrcoef(covy_e[np.invert(idx_stitched)], covy_h[np.invert(idx_stitched)]))

print('corr(y_t,y_t-1) stitchted: ', np.corrcoef(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched]))
print('corr(y_t,y_t-1) observed: ', np.corrcoef(covy_tl_e[np.invert(idx_stitched)], covy_tl_h[np.invert(idx_stitched)]))

In [ ]:

Visualize goodness of naive SSID initialisation


In [ ]:
%matplotlib inline
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')

p, T = spikes[0][0].shape
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(np.float)
T *= spikes.size

n = 10

neuron_shuffle = loadfile['neuron_shuffle']

data = data[:,neuron_shuffle]

initkey =  'params_naive'

pars_init = loadfile[initkey].reshape(1,)[0]
tmp = {}
for j in range(len(pars_init.keys())):
    tmp[str(pars_init.keys()[j])] =  pars_init.values()[j]
pars_init = tmp    

pars_init['R'] = np.diag(pars_init['R'])
pars_init['V0'] = pars_init['Q']

pars_hat = pars_init.copy()
pars_hat['R'] = np.diag(pars_hat['R'])

import scipy as sp
import matplotlib.pyplot as plt

# Figure 1: reference vs. estimated output covariances.
#   top row    : instantaneous covariances (diagonal removed in the images)
#   bottom row : time-lagged covariances
# The reference is computed from `pars_true` when that works, otherwise from
# the empirical covariance of the data.
# NOTE(review): `pars_true` is not defined in this cell; whatever an earlier
# cell left in the kernel is used -- confirm it belongs to this data set.

#Pi = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], pars_hat['Q'])
Pi = pars_hat['Pi']
plt.figure(1,figsize=(15,22))
try:
    Pi_true = sp.linalg.solve_discrete_lyapunov(pars_true['A'], pars_true['Q'])
    tmp1 = pars_true['C'].dot(Pi_true.dot(pars_true['C'].transpose())) + np.diag(pars_true['R'])
except Exception:
    # Fallback: covariance of the stacked pairs [y_t, y_{t-1}]; the first p
    # variables are y_t, the last p are y_{t-1}.
    cov_all = np.cov(np.hstack([data[1:], data[:-1]]).T)
    tmp1 = cov_all[np.ix_(np.arange(p), np.arange(p))]
# NOTE(review): assumes pars_hat['R'] is a (p, p) matrix here; a 1-D R would
# be broadcast across rows instead of added to the diagonal -- TODO confirm.
tmp2 = pars_hat['C'].dot(Pi.dot(pars_hat['C'].transpose())) + pars_hat['R']

# Shared colour range over the off-diagonal entries of both matrices.
# Builtin min/max are required: np.min(a, b) treats `b` as the `axis` argument.
offdiag1 = tmp1 - np.diag(np.diag(tmp1))
offdiag2 = tmp2 - np.diag(np.diag(tmp2))
m = min(offdiag1.min(), offdiag2.min())
M = max(offdiag1.max(), offdiag2.max())

plt.subplot(2,3,1)
plt.imshow(offdiag1,interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true instantaneous covs')
plt.subplot(2,3,2)
plt.imshow(offdiag2, interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated instantaneous covs')
plt.subplot(2,3,3)
# ravel() flattens the matrices (`[:]` does not), giving a single scatter series.
plt.plot(tmp1.ravel(), tmp2.ravel(), '.')
plt.xlabel('true')
plt.ylabel('est')

try:
    tmp1 = pars_true['C'].dot(pars_true['A']).dot(Pi_true.dot(pars_true['C'].transpose()))
except Exception:
    # Cross-covariance block (y_t rows, y_{t-1} columns).  The column range
    # must start at p, not p+1, to give a full (p, p) block matching tmp2.
    tmp1 = cov_all[np.ix_(np.arange(0, p), np.arange(p, 2*p))]
tmp2 = pars_hat['C'].dot(pars_hat['A']).dot(Pi.dot(pars_hat['C'].transpose()))
m = min(tmp1.min(), tmp2.min())
M = max(tmp1.max(), tmp2.max())

plt.subplot(2,3,4)
plt.imshow(tmp1,interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('true time_lagged covs')
plt.subplot(2,3,5)
plt.imshow(tmp2, interpolation='none')
plt.clim(m,M)
plt.colorbar()
plt.title('estimated time-lagged covs')
plt.subplot(2,3,6)
plt.plot(tmp1.ravel(), tmp2.ravel(), '.')
plt.xlabel('true')
plt.ylabel('est')


def _plot_est_vs_true(est_values, true_getter):
    """Plot the estimate and, if obtainable, the ground truth into the current axes.

    est_values  : array passed to plt.plot for the estimate.
    true_getter : zero-argument callable returning the ground-truth array; any
                  exception it raises (e.g. `pars_true` undefined) just skips
                  the truth curve.

    The legend is built from the line handles actually drawn, so the labels
    cannot be swapped and no 'true' entry appears when only the estimate is
    plotted.
    """
    handles = [plt.plot(est_values)[0]]
    labels = ['est']
    try:
        handles.append(plt.plot(true_getter())[0])
        labels.append('true')
    except Exception:
        pass
    plt.legend(handles, labels)


# Figure 2: estimated parameters d, R and the eigenvalues of A, each overlaid
# with the ground truth when `pars_true` is available.
plt.figure(2,figsize=(15,15))
plt.subplot(2,2,1)
_plot_est_vs_true(pars_hat['d'], lambda: pars_true['d'])
plt.title('d')
plt.subplot(2,2,2)
_plot_est_vs_true(pars_hat['R'], lambda: pars_true['R'])
plt.title('R')
plt.subplot(2,2,3)
_plot_est_vs_true(np.sort(np.linalg.eig(pars_hat['A'])[0]),
                  lambda: np.sort(np.linalg.eig(pars_true['A'])[0]))
plt.title('eigenvalues of A')
plt.show()

# ---- second batch of figures: covariances and the stitched-entry mask ----

# Instantaneous output covariance computed from the estimated parameters.
covy_h= np.dot( np.dot(pars_hat['C'], Pi),pars_hat['C'].transpose()) + pars_hat['R']

# Ground-truth covariances; plot_truth records whether they could be computed
# (they cannot when `pars_true` / `Pi_true` are missing from the kernel).
try:
    covy_t= np.dot(np.dot(pars_true['C'], Pi_true),pars_true['C'].transpose()) + np.diag(pars_true['R'])
    covy_tl_t=(np.dot(np.dot(pars_true['C'],np.dot(pars_true['A'], Pi_true)),pars_true['C'].transpose()))
    plot_truth = True
except Exception:
    plot_truth = False

# Empirical covariances from stacked pairs: rows 0..p-1 hold data[0:T-1],
# rows p..2p-1 hold data[1:T] (the same series shifted by one time step).
y_tl = np.zeros([2*p,T-1])
y_tl[range(p),:] = data[range(0,T-1),:].T
y_tl[range(p,2*p),:] = data[range(1,T),:].T
covy = np.cov(y_tl)

covy_e=    covy[np.ix_(range(p),range(p))]
covy_tl_e= covy[np.ix_(range(p,2*p),range(0,p))]

# Time-lagged output covariance computed from the estimated parameters.
covy_tl_h= np.dot(np.dot(pars_hat['C'], np.dot(pars_hat['A'],Pi)), pars_hat['C'].transpose())

# idx_stitched[i, j] is True unless i and j both belong to one of the
# sub-populations in obs_scheme.sub_pops.
# NOTE(review): `obs_scheme` is not created in this cell; it is whatever an
# earlier cell left in the kernel -- confirm it matches this data (same p).
idx_stitched = np.ones([p,p],dtype = bool)
for sub_pop in obs_scheme.sub_pops:
    if len(sub_pop)>0:
        idx_stitched[np.ix_(sub_pop,sub_pop)] = False
plt.imshow(idx_stitched,interpolation='none')

def _plot_cov_triptych(fig_num, mask, cov_emp, cov_true, cov_est, prefix, ylabel):
    """Draw one 1x3 figure of covariance scatter plots for the entries in `mask`.

    Panels: empirical vs. true, empirical vs. stitched (estimated), true vs.
    stitched.  Panels that need `cov_true` stay empty (and untitled) when it
    is None.

    fig_num  : matplotlib figure number.
    mask     : boolean (p, p) array selecting the matrix entries to plot.
    cov_emp, cov_true, cov_est : (p, p) covariance matrices; cov_true may be None.
    prefix   : 'obs.' or 'non-obs.', used to build the panel titles.
    ylabel   : y-axis label of the first panel.
    """
    plt.figure(fig_num, figsize=(20,10))
    panels = ((cov_emp, cov_true, 'emp', 'true'),
              (cov_emp, cov_est, 'emp', 'stitched'),
              (cov_true, cov_est, 'true', 'stitched'))
    for panel_no, (x_cov, y_cov, x_name, y_name) in enumerate(panels, start=1):
        plt.subplot(1,3,panel_no)
        if panel_no == 1:
            plt.ylabel(ylabel)
        if x_cov is None or y_cov is None:
            continue
        plt.plot(x_cov[mask], y_cov[mask], '.')
        plt.title('%s %s vs. %s %s' % (prefix, x_name, prefix, y_name))


# Complement of the stitched mask: pairs that share a sub-population.
idx_observed = np.invert(idx_stitched)
# covy_t / covy_tl_t only exist when the ground truth could be computed.
truth_inst = covy_t if plot_truth else None
truth_lag = covy_tl_t if plot_truth else None

_plot_cov_triptych(3, idx_observed, covy_e, truth_inst, covy_h, 'obs.', 'instantaneous')
_plot_cov_triptych(4, idx_observed, covy_tl_e, truth_lag, covy_tl_h, 'obs.', 'time-lagged')
_plot_cov_triptych(5, idx_stitched, covy_e, truth_inst, covy_h, 'non-obs.', 'instantaneous')
_plot_cov_triptych(6, idx_stitched, covy_tl_e, truth_lag, covy_tl_h, 'non-obs.', 'time-lagged')


# Correlation between empirical and estimated covariance entries, reported
# separately for stitched and observed pairs.
print('corr(y_t,y_t) stitched: ', np.corrcoef(covy_e[idx_stitched], covy_h[idx_stitched]))
print('corr(y_t,y_t) observed: ', np.corrcoef(covy_e[idx_observed], covy_h[idx_observed]))

print('corr(y_t,y_t-1) stitched: ', np.corrcoef(covy_tl_e[idx_stitched], covy_tl_h[idx_stitched]))
print('corr(y_t,y_t-1) observed: ', np.corrcoef(covy_tl_e[idx_observed], covy_tl_h[idx_observed]))

Continue the big simulation (resume EM from a previously saved fit)


In [ ]:
# Load the binned spike trains and build the (time x neurons) data matrix.
print('loading data')
spikes = loadmat('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/gb_net/spikes_20trials_10msBins')
spikes= spikes['spikes_out']



print('concatenating data')
# NOTE(review): p and the per-trial length T are hard-coded here; presumably
# they match the shape of each spikes[i][0] -- verify against the file.
p, T = 1000, 6000
#data = spikes[0][0].T
idx_n = np.sort(np.random.choice(1000, size=p, replace=False))
# `float` replaces the alias `np.float`, which was removed in NumPy 1.24
# (it always meant the builtin float, so the resulting dtype is unchanged).
data = np.vstack([spikes[i][0].T[:,idx_n] for i in range(spikes.size)]).astype(float)
T *= spikes.size  # total length after concatenating all trials

# Apply the neuron permutation stored with the initialisation experiment.
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/init/init_experiment.npz')
neuron_shuffle = loadfile['neuron_shuffle']
data = data[:,neuron_shuffle]

n = 10

print(data.shape)

###################
#    EM cycles    #
###################

# Start from the parameters of a previously saved fit and reuse the
# observation scheme stored alongside them.
print('loading parameters')
loadfile = np.load('/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter300.npz')
pars_init = loadfile['estPars'].reshape(1,)[0]  # unwrap the one-element object array
pars_init['R'] = np.diag(pars_init['R'])

# Convert the three stored scheme arrays into (nested) tuples.
sub_pops, obs_pops, obs_time = [tuple(loadfile[key].tolist())
                                for key in ('sub_pops', 'obs_pops', 'obs_time')]
obs_scheme = ObservationScheme(p, T, sub_pops, obs_pops, obs_time)


model = init_LDS_model(pars_init, data, obs_scheme) # set to initialisation

print('(p,T,n)', (p,T,n))
print('len sub_pops', [len(sub_pop) for sub_pop in obs_scheme.sub_pops])

#stats_init,_ = collect_LDS_stats(model)

# Run 100 EM steps, recording the value returned by update() after each one.
print('fitting')
likes = []
for _ in progprint_xrange(100):
    likes.append(update(model))
stats_hat,pars_hat = collect_LDS_stats(model)

In [ ]:
# Log-likelihood returned by update(model) after each EM step of the fit.
plt.plot(likes)
plt.xlabel('EM iteration')
plt.ylabel('log-likelihood')
plt.show()

In [ ]:
from scipy.io import savemat

# Store the fit twice under the same base name: once as a MATLAB file via
# savemat and once as a NumPy archive via np.savez.
broken = False
save_file = '/home/marcel/Desktop/Projects/Stitching/results/cosyne_poster/experiment_1/fits/params_naive_p1000_iter400'

eps = - np.inf

# Fields written identically to both files.
shared_fields = {
    'll': likes,
    'T': model.states_list[0].T,
    'Trial': len(model.states_list),
    'epsilon': eps,
    'initPars': pars_init,
    'estPars': pars_hat,
    'stats_h': stats_hat,
    'sub_pops': obs_scheme.sub_pops,
    'obs_pops': obs_scheme.obs_pops,
    'obs_time': obs_scheme.obs_time,
}

# The MATLAB file stores the flag under 'ifBroken' ...
mat_fields = {'ifBroken': broken}
mat_fields.update(shared_fields)
savemat(save_file, mat_fields) # does the actual saving

# ... while the NumPy archive stores it under 'broken'.
np.savez(save_file, broken=broken, **shared_fields)

In [ ]: