In [ ]:

Fit models for illustration #1


In [ ]:
# Setup for "illustration #1": imports, Cython compile-on-import, switching
# into the local pylds code folder (needed for the local imports below), and
# finally into this experiment's data folder.
%matplotlib inline
import numpy as np
from scipy.io import loadmat
import glob, os

from scipy.io import savemat # store results for comparison with Matlab code   

# NOTE(review): in a plain .py module a __future__ import must be the first
# statement; keep this in mind if the cell is ever exported to a script.
from __future__ import division
import matplotlib.pyplot as plt

from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression

# pyximport compiles .pyx sources on import; installed before the local pylds
# modules are imported (presumably some of them are Cython -- confirm).
import pyximport
pyximport.install()

# The local modules below (models, distributions, obs_scheme, user_util) are
# imported from the current working directory, so chdir there first.
# NOTE(review): machine-specific absolute path; adjust on other machines.
absolute_code_path = '/home/mackelab/Desktop/Projects/Stitching/code/pyLDS_dev/'
os.chdir(absolute_code_path +'pylds')

from models import LDS, DefaultLDS
from distributions import Regression_diag, AutoRegression_input
from obs_scheme import ObservationScheme
from user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats

# (savemat is imported a second time here; harmless duplicate of the import above)
from scipy.io import savemat # store results for comparison with Matlab code   
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation

# Data folder path is relative to pylds/ (the directory chdir'ed into above).
relative_data_path = '../../../results/cosyne_poster/illustration_1/data'
os.chdir(relative_data_path)
# NOTE: glob order is arbitrary, so positions in `filenames` / idx_exps are not
# tied to the idx encoded in the file names.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

def update(model):
    """Run a single EM iteration on `model` and report the resulting fit.

    Parameters
    ----------
    model : object providing ``EM_step()`` and ``log_likelihood()``
        Advanced in place by the EM step (its return value is discarded).

    Returns
    -------
    The value of ``model.log_likelihood()`` evaluated after the step.
    """
    model.EM_step()
    log_lik = model.log_likelihood()
    return log_lik
        
import traceback  # local import: used to report why an EM run broke

eps = np.log(1.01)  # stop EM once the log-likelihood gain per step drops below log(1.01)
max_iter = 1000     # hard cap on EM iterations per run

initialisers = ['random', 'params', 'params_flip', 'params_naive', 'params_naive_flip']
#initialisers = ['params', 'params_flip', 'random']

for idx in [0,6,9]:

    ##################
    # load the data  #
    ##################

    filename = 'LDS_save_idx' + str(idx) + '.npz'
    print(filename)

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T,p = data.shape
    n   = loadfile['x'].shape[1]
    # The sub_pops stored with the data are overridden (see the commented
    # keyword below) by two identical subpopulations covering units 0-8.
    # NOTE(review): this assumes p == 9 -- confirm.
    sub_pops = [list(range(9)), list(range(9))]
    obs_scheme = ObservationScheme(p=p, T=T,
#                                   sub_pops=tuple([item for item in loadfile['sub_pops']]),
                                   sub_pops=sub_pops,
                                   obs_pops=loadfile['obs_pops'],
                                   obs_time=loadfile['obs_time'])

    print(obs_scheme.sub_pops)

    # truePars is stored as a single-element object array wrapping a dict;
    # unwrap it and normalise the keys to plain str (.items() pairs keys and
    # values and works under both Python 2 and 3)
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = {str(key): value for key, value in pars_true.items()}

    print('dataset #', idx)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ', initialiser)

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')

            num_reps = 1

            loadinit = np.load('init_'+filename)

            # '<key>_flip' starts from the '<key>' parameters, with part of C sign-flipped below
            initkey = initialiser[:-5] if initialiser.endswith('_flip') else initialiser

            pars_init = loadinit[initkey].reshape(1,)[0]
            pars_init = {str(key): value for key, value in pars_init.items()}

            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            # (the rows of units that are not in the second subpopulation)
            if initialiser.endswith('_flip'):
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap,:] *= -1

        elif initialiser=='random':

            num_reps = 10

        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print( 'run #' + str(rep+1) +'/' +str(num_reps) )

            if initialiser=='random':

                pars_init, _ = gen_pars(n, p, u_dim=0,
                                     pars_in=None,
                                     obs_scheme=obs_scheme,
                                     gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                     gen_B='random',
                                     gen_Q='identity',
                                     gen_mu0='random',
                                     gen_V0='identity',
                                     gen_C='random',
                                     gen_d='mean',
                                     gen_R='fractionObserved',
                                     diag_R_flag=True,
                                     x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # reset per run so that a failing run never picks up the model or
            # likelihood trace of an earlier run (or an earlier dataset)
            model = None
            likes = [-np.inf]
            try:
                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                print('fitting #' + str(idx))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1]-likes[-2] < eps:
                        break

                stats_hat,pars_hat = collect_LDS_stats(model)

                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                model.E_step()
                stats_true,_ = collect_LDS_stats(model)
                model.M_step()

                broken = False
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                # solutions Pi of A Pi A' - Pi + Q = 0 (stationary state
                # covariance) and the lag-one versions A Pi, true and fitted
                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(Pi)
                Pi_t_h = pars_hat['A'].dot(Pi_h)

            except Exception:
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print('')
                traceback.print_exc()

                broken = True
                y_out = model.states_list[0].data if model is not None else data
                x_out = []
                pars_hat, stats_hat, stats_true = [],[],[]
                Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

            if not broken:
                # data covariance vs. C Pi C' + R implied by the initial and
                # by the fitted parameters
                plt.figure()
                plt.subplot(1,3,1)
                plt.imshow(np.cov(data.T), interpolation='none')
                plt.subplot(1,3,2)
                if initialiser=='random':
                    pars_init['Pi'] = dtlyap(pars_init['A'], pars_init['Q'])
                    R = pars_init['R']
                else:
                    R = np.diag(pars_init['R'])
                plt.imshow(pars_init['C'].dot(pars_init['Pi']).dot(pars_init['C'].T)+R, interpolation='none')
                plt.subplot(1,3,3)
                plt.imshow(pars_hat['C'].dot(Pi_h).dot(pars_hat['C'].T)+pars_hat['R'], interpolation='none')
                plt.show()

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_rep' + str(rep) + '_' + filename

            # ObservationScheme is accessed by attribute (as elsewhere in this cell)
            save_file_m = {'ifBroken':broken,
                           'll' : likes,
                           'T' : T,
                           'Trial': 1,
                           'epsilon':eps,
                           'initPars':pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'sub_pops':obs_scheme.sub_pops,
                           'obs_time':obs_scheme.obs_time,
                           'obs_pops':obs_scheme.obs_pops}
            savemat(save_file,save_file_m) # does the actual saving

            np.savez(save_file,
                    broken=broken,
                    ll=likes,
                    T=T,
                    Trial=1,
                    epsilon=eps,
                    initPars=pars_init,
                    estPars =pars_hat,
                    stats_h = stats_hat,
                    sub_pops=obs_scheme.sub_pops,
                    obs_time=obs_scheme.obs_time,
                    obs_pops=obs_scheme.obs_pops)

Fit models for illustration #2


In [ ]:
# Setup for "illustration #2": this cell fits with the pyRRHDLDS ssm_fit code,
# so it first imports those modules from their folder, then switches to the
# pylds folder for the remaining local imports, and finally into the data folder.
%matplotlib inline
import numpy as np
from scipy.io import loadmat
import glob, os

# The ssm_* modules are imported from the current working directory.
# NOTE(review): machine-specific absolute path; adjust on other machines.
absolute_code_path = '/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core/'
os.chdir(absolute_code_path)
import ssm_timeSeries as ts  # my time series overhead
import ssm_fit               # my library for state-space model fitting
from ssm_scripts import setup_fit_lds

from scipy.io import savemat # store results for comparison with Matlab code   

# NOTE(review): in a plain .py module a __future__ import must be the first
# statement; keep this in mind if the cell is ever exported to a script.
from __future__ import division
import matplotlib.pyplot as plt

from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression

# pyximport compiles .pyx sources on import (presumably needed by the pylds
# modules below -- confirm).
import pyximport
pyximport.install()

# Switch to the pylds folder so its local modules can be imported.
absolute_code_path = '/home/mackelab/Desktop/Projects/Stitching/code/pyLDS_dev/'
os.chdir(absolute_code_path +'pylds')

from models import LDS, DefaultLDS
from distributions import Regression_diag, AutoRegression_input
from obs_scheme import ObservationScheme
from user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats

# (savemat is imported a second time here; harmless duplicate of the import above)
from scipy.io import savemat # store results for comparison with Matlab code   
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation

# Data folder path is relative to pylds/ (the directory chdir'ed into above).
relative_data_path = '../../../results/cosyne_poster/illustration_2/data'
os.chdir(relative_data_path)
# NOTE: glob order is arbitrary, so positions in `filenames` / idx_exps are not
# tied to the idx encoded in the file names.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

import traceback  # local import: used to report why a fit broke

initialisers = ['random']
#initialisers = ['params', 'params_flip', 'random']

eps = np.log(1.01)  # convergence threshold passed to ssm_fit.fit_lds
max_iter = 200      # cap on EM iterations passed to ssm_fit.fit_lds

for idx in [0,6,9]: #[0,6,9]:

    ##################
    # load the data  #
    ##################

    filename = 'LDS_save_idx' + str(idx) + '.npz'
    print(filename)

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T,p = data.shape
    # fit at most the first 1000 time steps; never claim more steps than exist
    T = min(T, 1000)
    data = data[:T,:]
    n   = loadfile['x'].shape[1]

    # Observation schedule: subpopulation 0 is empty, subpopulation 1 is the
    # full population. obs_time holds the end points of successive
    # observation windows and obs_pops the subpopulation seen in each: the
    # full population is seen at t = 0 and then for one step every
    # stitch_order steps, with nothing observed in between.
    sub_pops = [[], list(range(0,p))]
    obs_pops = [1]
    obs_time = [1] # population observed at t = 0
    stitch_order = 2
    i = 1
    t = 0
    while t < T:
        t = int(stitch_order * i)
        obs_pops.append(0)   # observe empty subpop
        obs_time.append(t)   # until round(i*stitch_order)
        obs_pops.append(1)   # then observe full subpop
        obs_time.append(t+1) # for 1 more time step
        i += 1
    # ended while loop because last entry was >= T: drop windows that end
    # beyond T and let the final window end exactly at T
    while obs_time[-1] > T:
        obs_time.pop()
        obs_pops.pop()
    obs_time[-1] = T
    # in this cell the observation scheme is a plain dict (not ObservationScheme)
    obs_scheme={'sub_pops':sub_pops,
                'obs_pops':obs_pops,
                'obs_time':obs_time}

    print(obs_scheme['sub_pops'])

    print('dataset #', idx)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ', initialiser)

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')

            num_reps = 1

            loadinit = np.load('init_'+filename)

            # '<key>_flip' starts from the '<key>' parameters, with part of C sign-flipped below
            initkey = initialiser[:-5] if initialiser.endswith('_flip') else initialiser

            # stored as a single-element object array wrapping a dict; unwrap
            # and normalise the keys to plain str
            pars_init = loadinit[initkey].reshape(1,)[0]
            pars_init = {str(key): value for key, value in pars_init.items()}

            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            # (obs_scheme is a dict here, so it is indexed by key)
            if initialiser.endswith('_flip'):
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme['sub_pops'][1])
                pars_init['C'][idx0_no_overlap,:] *= -1

        elif initialiser=='random':

            num_reps = 10

        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print( 'run #' + str(rep+1) +'/' +str(num_reps) )

            if initialiser=='random':

                pars_init, _ = gen_pars(n, p, u_dim=0,
                                     pars_in=None,
                                     obs_scheme=obs_scheme,
                                     gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                     gen_B='random',
                                     gen_Q='identity',
                                     gen_mu0='random',
                                     gen_V0='identity',
                                     gen_C='random',
                                     gen_d='mean',
                                     gen_R='fractionObserved',
                                     diag_R_flag=True,
                                     x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # reset per run so a broken fit never reports an earlier run's trace
            ll = []
            try:
                # fit the model to data
                print('fitting model to data')

                pars_hat, ll = ssm_fit.fit_lds(y=data.T.reshape(p,T,1),
                                               u=[],
                                               x_dim=n,
                                               obs_scheme=obs_scheme,
                                               pars=pars_init,
                                               max_iter=max_iter,
                                               epsilon=eps,
                                               eps_cov=0,
                                               plot_flag=False,
                                               trace_pars_flag=False,
                                               trace_stats_flag=False,
                                               diag_R_flag=True,
                                               use_A_flag=True,
                                               use_B_flag=False,
                                               save_file=None)
                broken = False
                print('fit_successful')
                pars_hat['R'] = np.diag(pars_hat['R'])

            except Exception:
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print('')
                traceback.print_exc()

                broken = True
                pars_hat = []

            if not broken:
                # data covariance vs. C Pi C' + R implied by the initial and
                # by the fitted parameters
                plt.figure()
                plt.subplot(1,3,1)
                plt.imshow(np.cov(data.T), interpolation='none')
                plt.subplot(1,3,2)
                if initialiser=='random':
                    pars_init['Pi'] = dtlyap(pars_init['A'], pars_init['Q'])
                    R = pars_init['R']
                else:
                    R = np.diag(pars_init['R'])
                plt.imshow(pars_init['C'].dot(pars_init['Pi']).dot(pars_init['C'].T)+R, interpolation='none')
                plt.subplot(1,3,3)
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                plt.imshow(pars_hat['C'].dot(Pi_h).dot(pars_hat['C'].T)+pars_hat['R'], interpolation='none')
                plt.show()

            ###################
            #  store results  #
            ###################

            # max(..., 0): a run that broke before fit_lds returned has ll == []
            print('finished in ' + str(max(len(ll)-1, 0)) + ' many steps.')

            os.chdir('../fits/')

            save_file = initialiser + '_rep' + str(rep) + '_' + filename

            save_file_m = {'ifBroken':broken,
                           'll' : ll,
                           'T' : T,
                           'Trial': 1,
                           'epsilon':eps,
                           'initPars':pars_init,
                           'estPars': pars_hat,
                           'sub_pops':obs_scheme['sub_pops'],
                           'obs_time':obs_scheme['obs_time'],
                           'obs_pops':obs_scheme['obs_pops']}
            savemat(save_file,save_file_m) # does the actual saving

            np.savez(save_file,
                    broken=broken,
                    ll=ll,
                    T=T,
                    Trial=1,
                    epsilon=eps,
                    initPars=pars_init,
                    estPars =pars_hat,
                    sub_pops=obs_scheme['sub_pops'],
                    obs_time=obs_scheme['obs_time'],
                    obs_pops=obs_scheme['obs_pops'])

Fit models for simulation #1


In [ ]:
# Setup for "simulation #1": imports, Cython compile-on-import, switching into
# the local pylds code folder (needed for the local imports below), and
# finally into this experiment's data folder.
import numpy as np
from scipy.io import loadmat
import glob, os

from scipy.io import savemat # store results for comparison with Matlab code   

# NOTE(review): in a plain .py module a __future__ import must be the first
# statement; keep this in mind if the cell is ever exported to a script.
from __future__ import division
import matplotlib.pyplot as plt

from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression

# pyximport compiles .pyx sources on import (presumably needed by the pylds
# modules below -- confirm).
import pyximport
pyximport.install()

# The local modules below are imported from the current working directory.
# NOTE(review): machine-specific absolute path; adjust on other machines.
absolute_code_path = '/home/mackelab/Desktop/Projects/Stitching/code/pyLDS_dev/'
os.chdir(absolute_code_path +'pylds')

from models import LDS, DefaultLDS
from distributions import Regression_diag, AutoRegression_input
from obs_scheme import ObservationScheme
from user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats

# (savemat is imported a second time here; harmless duplicate of the import above)
from scipy.io import savemat # store results for comparison with Matlab code   
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation

# Data folder path is relative to pylds/ (the directory chdir'ed into above).
relative_data_path = '../../../results/cosyne_poster/simulation_1/data'
os.chdir(relative_data_path)
# NOTE: glob order is arbitrary; idx_exps indexes into `filenames`, not the idx
# encoded in the file names.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

def update(model):
    """Run a single EM iteration on `model` and report the resulting fit.

    Parameters
    ----------
    model : object providing ``EM_step()`` and ``log_likelihood()``
        Advanced in place by the EM step (its return value is discarded).

    Returns
    -------
    The value of ``model.log_likelihood()`` evaluated after the step.
    """
    model.EM_step()
    log_lik = model.log_likelihood()
    return log_lik
        
import traceback  # local import: used to report why an EM run broke

eps = np.log(1.01)  # stop EM once the log-likelihood gain per step drops below log(1.01)
max_iter = 25       # hard cap on EM iterations per run

#initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']
initialisers = ['params', 'params_flip', 'random']

for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T,p = data.shape
    n   = loadfile['x'].shape[1]
    obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=tuple([item for item in loadfile['sub_pops']]),
                                   obs_pops=loadfile['obs_pops'],
                                   obs_time=loadfile['obs_time'])
    # truePars is stored as a single-element object array wrapping a dict;
    # unwrap it and normalise the keys to plain str (.items() pairs keys and
    # values and works under both Python 2 and 3)
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = {str(key): value for key, value in pars_true.items()}

    print('dataset #', i)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    for initialiser in initialisers:

        print('initialiser ', initialiser)

        if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:
            os.chdir('../init/')

            num_reps = 1

            loadinit = np.load('init_'+filename)

            # '<key>_flip' starts from the '<key>' parameters, with part of C sign-flipped below
            initkey = initialiser[:-5] if initialiser.endswith('_flip') else initialiser

            pars_init = loadinit[initkey].reshape(1,)[0]
            pars_init = {str(key): value for key, value in pars_init.items()}

            pars_init['R'] = np.diag(pars_init['R'])

            # for SSID-derived initialisers, also try out flipping parts of C
            # (the rows of units that are not in the second subpopulation)
            if initialiser.endswith('_flip'):
                idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                pars_init['C'][idx0_no_overlap,:] *= -1

        elif initialiser=='random':

            num_reps = 10

        else:
            raise Exception('unexpected initialiser!')

        for rep in range(num_reps):

            print( 'run #' + str(rep+1) +'/' +str(num_reps) )

            if initialiser=='random':

                pars_init, _ = gen_pars(n, p, u_dim=0,
                                     pars_in=None,
                                     obs_scheme=obs_scheme,
                                     gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                     gen_B='random',
                                     gen_Q='identity',
                                     gen_mu0='random',
                                     gen_V0='identity',
                                     gen_C='random',
                                     gen_d='mean',
                                     gen_R='fractionObserved',
                                     diag_R_flag=True,
                                     x=None, y=data.T, u=None)

            ###################
            #    EM cycles    #
            ###################

            # reset per run so that a failing run never reuses the model or
            # likelihood trace of an earlier run
            model = None
            likes = [-np.inf]
            try:
                # get EM-step results after m iterations
                model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                print('fitting #' + str(i))
                for t in progprint_xrange(max_iter):
                    likes.append(update(model))
                    if likes[-1]-likes[-2] < eps:
                        break

                stats_hat,pars_hat = collect_LDS_stats(model)

                # get EM-step results from true parameters
                model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                model.E_step()
                stats_true,_ = collect_LDS_stats(model)
                model.M_step()

                broken = False
                y_out = model.states_list[0].data
                x_out = model.states_list[0].stateseq
                # solutions Pi of A Pi A' - Pi + Q = 0 (stationary state
                # covariance) and the lag-one versions A Pi, true and fitted
                Pi = dtlyap(pars_true['A'], pars_true['Q'])
                Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                Pi_t = pars_true['A'].dot(Pi)
                Pi_t_h = pars_hat['A'].dot(Pi_h)

            except Exception:
                print('')
                print('############')
                print('#RUN BROKE!#')
                print('############')
                print('')
                traceback.print_exc()

                broken = True
                y_out = model.states_list[0].data if model is not None else data
                x_out = []
                # clear the fitted stats too, so a broken run never saves the
                # stats of the previous run
                pars_hat, stats_hat, stats_true = [],[],[]
                Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

            ###################
            #  store results  #
            ###################

            print('finished in ' + str(len(likes)-1) + ' many steps.')

            os.chdir('../fits/')

            # NOTE(review): this cell names its fits '_idx<rep>' while the other
            # cells use '_rep<rep>'; kept so existing result files still match.
            save_file = initialiser + '_idx' + str(rep) + '_' + filename

            # model is None only if initialising it failed; fall back to the
            # length of the data and a single trial then
            T_out = model.states_list[0].T if model is not None else T
            n_trials = len(model.states_list) if model is not None else 1

            save_file_m = {'ifBroken':broken,
                           'x': x_out,
                           'y': y_out,
                           'u' : [],
                           'll' : likes,
                           'T' : T_out,
                           'Trial': n_trials,
                           'ifUseB':False,
                           'ifUseA':True,
                           'epsilon':eps,
                           'truePars':pars_true,
                           'initPars':pars_init,
                           'estPars': pars_hat,
                           'stats_h': stats_hat,
                           'stats_true': stats_true,
                           'Pi':Pi,
                           'Pi_h':Pi_h,
                           'Pi_t':Pi_t,
                           'Pi_t_h':Pi_t_h,
                           'obsScheme' : obs_scheme}
            savemat(save_file,save_file_m) # does the actual saving

            np.savez(save_file,
                    broken=broken,
                    x=x_out,
                    y=y_out,
                    ll=likes,
                    T=T_out,
                    Trial=n_trials,
                    ifUseA=True,
                    ifUseB=False,
                    epsilon=eps,
                    initPars=pars_init,
                    truePars=pars_true,
                    estPars =pars_hat,
                    stats_h = stats_hat,
                    stats_true = stats_true,
                    Pi=Pi,
                    Pi_h=Pi_h,
                    Pi_t=Pi_t,
                    Pi_t_h=Pi_t_h,
                    sub_pops=obs_scheme.sub_pops,
                    obs_time=obs_scheme.obs_time,
                    obs_pops=obs_scheme.obs_pops)

Fit models for simulation #2


In [1]:
# Setup for "simulation #2": imports, Cython compile-on-import, switching into
# the local pylds code folder (needed for the local imports below), and
# finally into this experiment's data folder.
# NOTE(review): this cell plots but has no %matplotlib magic; it relies on the
# frontend's default backend.
import numpy as np
from scipy.io import loadmat
import glob, os

from scipy.io import savemat # store results for comparison with Matlab code   

# NOTE(review): in a plain .py module a __future__ import must be the first
# statement; keep this in mind if the cell is ever exported to a script.
from __future__ import division
import matplotlib.pyplot as plt

from pybasicbayes.distributions import Regression
from pybasicbayes.util.text import progprint_xrange
from autoregressive.distributions import AutoRegression

# pyximport compiles .pyx sources on import (presumably needed by the pylds
# modules below -- confirm).
import pyximport
pyximport.install()

# The local modules below are imported from the current working directory.
# NOTE(review): machine-specific absolute path; adjust on other machines.
absolute_code_path = '/home/mackelab/Desktop/Projects/Stitching/code/pyLDS_dev/'
os.chdir(absolute_code_path +'pylds')

from models import LDS, DefaultLDS
from distributions import Regression_diag, AutoRegression_input
from obs_scheme import ObservationScheme
from user_util import gen_pars, rand_rotation_matrix, init_LDS_model, collect_LDS_stats

# (savemat is imported a second time here; harmless duplicate of the import above)
from scipy.io import savemat # store results for comparison with Matlab code   
from scipy.linalg import solve_discrete_lyapunov as dtlyap # solve discrete-time Lyapunov equation

# Data folder path is relative to pylds/ (the directory chdir'ed into above).
relative_data_path = '../../../results/cosyne_poster/simulation_2/data'
os.chdir(relative_data_path)
# NOTE: glob order is arbitrary; idx_exps indexes into `filenames`, not the idx
# encoded in the file names.
filenames = glob.glob("*.npz")
num_exps = len(filenames)
idx_exps = range(num_exps)

def update(model):
    """Run a single EM iteration on `model` and report the resulting fit.

    Parameters
    ----------
    model : object providing ``EM_step()`` and ``log_likelihood()``
        Advanced in place by the EM step (its return value is discarded).

    Returns
    -------
    The value of ``model.log_likelihood()`` evaluated after the step.
    """
    model.EM_step()
    log_lik = model.log_likelihood()
    return log_lik
        
import traceback  # local import: used to report why an EM run broke

eps = np.log(1.01)  # stop EM once the log-likelihood gain per step drops below log(1.01)
max_iter = 1000     # hard cap on EM iterations per run

initialisers = ['params', 'params_flip', 'params_naive', 'params_naive_flip', 'random']

for i in idx_exps:

    ##################
    # load the data  #
    ##################

    filename = filenames[i]

    os.chdir('../data/')
    loadfile = np.load(filename)

    data = loadfile['y']
    T,p = data.shape
    n   = loadfile['x'].shape[1]

    # truePars is stored as a single-element object array wrapping a dict;
    # unwrap it and normalise the keys to plain str (.items() pairs keys and
    # values and works under both Python 2 and 3)
    pars_true = loadfile['truePars'].reshape(1,)[0]
    pars_true = {str(key): value for key, value in pars_true.items()}

    print('dataset #', i)
    print('(T, p, n, eps) = ', (T, p, n, eps))

    ####################
    # pick initialiser #
    ####################

    # Each init file for this dataset carries an observation scheme plus the
    # initial parameters. Match init files by the dataset's file name: glob
    # order is arbitrary, so the loop counter i is not the idx encoded in
    # `filename`. Sorting makes the rep numbering reproducible.
    os.chdir('../init/')
    initfiles = sorted(glob.glob('*_' + filename))
    num_reps = len(initfiles)
    if num_reps == 0:
        print('no init files found for ' + filename)

    for rep in range(num_reps):

        initfile = initfiles[rep]
        print('rep ' + str(rep) + '/' + str(num_reps))
        os.chdir('../init/')
        loadinit = np.load(initfile)

        obs_scheme = ObservationScheme(p=p, T=T,
                                   sub_pops=tuple([item for item in loadinit['sub_pops']]),
                                   obs_pops=loadinit['obs_pops'],
                                   obs_time=loadinit['obs_time'])

        for initialiser in initialisers:

            print('initialiser ', initialiser)
            if initialiser in ['params', 'params_flip', 'params_naive', 'params_naive_flip']:

                num_repets = 1
                # '<key>_flip' starts from the '<key>' parameters, with part of C sign-flipped below
                initkey = initialiser[:-5] if initialiser.endswith('_flip') else initialiser

                pars_init = loadinit[initkey].reshape(1,)[0]
                pars_init = {str(key): value for key, value in pars_init.items()}

                pars_init['R'] = np.diag(pars_init['R'])

                # for SSID-derived initialisers, also try out flipping parts of C
                # (the rows of units that are not in the second subpopulation)
                if initialiser.endswith('_flip'):
                    idx0_no_overlap = np.setdiff1d(np.arange(p), obs_scheme.sub_pops[1])
                    pars_init['C'][idx0_no_overlap,:] *= -1

            elif initialiser=='random':

                num_repets = 5

            else:
                raise Exception('unexpected initialiser!')

            for repet in range(num_repets):

                print( 'run #' + str(repet+1) +'/' +str(num_repets) )

                if initialiser=='random':

                    pars_init, _ = gen_pars(n, p, u_dim=0,
                                         pars_in=None,
                                         obs_scheme=obs_scheme,
                                         gen_A='diagonal', lts=0.99 * np.ones((n,)),
                                         gen_B='random',
                                         gen_Q='identity',
                                         gen_mu0='random',
                                         gen_V0='identity',
                                         gen_C='random',
                                         gen_d='mean',
                                         gen_R='fractionObserved',
                                         diag_R_flag=True,
                                         x=None, y=data.T, u=None)

                ###################
                #    EM cycles    #
                ###################

                # reset per run so that a failing run never reuses the model or
                # likelihood trace of an earlier run
                model = None
                likes = [-np.inf]
                try:
                    # get EM-step results after m iterations
                    model = init_LDS_model(pars_init, data, obs_scheme) # reset to initialisation
                    print('fitting #' + str(i))
                    for t in progprint_xrange(max_iter):
                        likes.append(update(model))
                        if likes[-1]-likes[-2] < eps:
                            break

                    stats_hat,pars_hat = collect_LDS_stats(model)

                    # get EM-step results from true parameters
                    model = init_LDS_model(pars_true, data, obs_scheme) # reset to true pars
                    model.E_step()
                    stats_true,_ = collect_LDS_stats(model)
                    model.M_step()

                    broken = False
                    y_out = model.states_list[0].data
                    x_out = model.states_list[0].stateseq
                    # solutions Pi of A Pi A' - Pi + Q = 0 (stationary state
                    # covariance) and the lag-one versions A Pi, true and fitted
                    Pi = dtlyap(pars_true['A'], pars_true['Q'])
                    Pi_h = dtlyap(pars_hat['A'], pars_hat['Q'])
                    Pi_t = pars_true['A'].dot(Pi)
                    Pi_t_h = pars_hat['A'].dot(Pi_h)

                except Exception:
                    print('')
                    print('############')
                    print('#RUN BROKE!#')
                    print('############')
                    print('')
                    traceback.print_exc()

                    broken = True
                    y_out = model.states_list[0].data if model is not None else data
                    x_out = []
                    pars_hat, stats_hat, stats_true = [],[],[]
                    Pi, Pi_h, Pi_t, Pi_t_h = 0,0,0,0

                if not broken:
                    # data covariance vs. C Pi C' + R implied by the initial
                    # and by the fitted parameters, in three side-by-side panels
                    plt.figure()
                    plt.subplot(1,3,1)
                    plt.imshow(np.cov(data.T),interpolation='none')
                    plt.subplot(1,3,2)
                    if initialiser=='random':
                        pars_init['Pi'] = dtlyap(pars_init['A'], pars_init['Q'])
                        R = pars_init['R']
                    else:
                        R = np.diag(pars_init['R'])
                    plt.imshow(pars_init['C'].dot(pars_init['Pi']).dot(pars_init['C'].T) + R,interpolation='none')
                    plt.subplot(1,3,3)
                    plt.imshow(pars_hat['C'].dot(Pi_h).dot(pars_hat['C'].T) + pars_hat['R'],interpolation='none')
                    plt.show()

                ###################
                #  store results  #
                ###################

                print('finished in ' + str(len(likes)-1) + ' many steps.')

                os.chdir('../fits/')

                save_file = initialiser + '_rep' + str(rep) + '_repet' + str(repet) + '_' + filename

                # model is None only if initialising it failed; fall back to
                # the length of the data and a single trial then
                T_out = model.states_list[0].T if model is not None else T
                n_trials = len(model.states_list) if model is not None else 1

                save_file_m = {'ifBroken':broken,
                               'x': x_out,
                               'y': y_out,
                               'u' : [],
                               'll' : likes,
                               'T' : T_out,
                               'Trial': n_trials,
                               'ifUseB':False,
                               'ifUseA':True,
                               'epsilon':eps,
                               'truePars':pars_true,
                               'initPars':pars_init,
                               'estPars': pars_hat,
                               'stats_h': stats_hat,
                               'stats_true': stats_true,
                               'Pi':Pi,
                               'Pi_h':Pi_h,
                               'Pi_t':Pi_t,
                               'Pi_t_h':Pi_t_h,
                               'sub_pops':obs_scheme.sub_pops,
                               'obs_time':obs_scheme.obs_time,
                               'obs_pops':obs_scheme.obs_pops}
                savemat(save_file,save_file_m) # does the actual saving

                np.savez(save_file,
                        broken=broken,
                        x=x_out,
                        y=y_out,
                        ll=likes,
                        T=T_out,
                        Trial=n_trials,
                        ifUseA=True,
                        ifUseB=False,
                        epsilon=eps,
                        initPars=pars_init,
                        truePars=pars_true,
                        estPars =pars_hat,
                        stats_h = stats_hat,
                        stats_true = stats_true,
                        Pi=Pi,
                        Pi_h=Pi_h,
                        Pi_t=Pi_t,
                        Pi_t_h=Pi_t_h,
                        sub_pops=obs_scheme.sub_pops,
                        obs_time=obs_scheme.obs_time,
                        obs_pops=obs_scheme.obs_pops)

In [ ]:
# Display the name of the last init file loaded by the "simulation #2" loop.
# NOTE(review): looks like a leftover debugging cell; consider removing it.
initfile