Model cross-evaluation using hmmlearn

Etienne Ackermann, 08/05/2015

The data can be downloaded from the CRCNS (Collaborative Research in Computational Neuroscience) website; this notebook uses the hc-3 data set in particular.

Summary

Note that both gor01 and vvp01 were recorded with a Neuralynx recording system at 32,552 Hz, amplified 1000x, and bandpass filtered between 1 Hz and 5 kHz. The signal was then downsampled to 1252 Hz and stored in the .eeg file, a binary file with the same number of channels as the raw data. Each time step holds up to 32 (four-shank probe) or 64 (eight-shank probe) signed short integers (2 bytes each), one from each recording site (i.e. channel). The actual number of channels is specified in the .xml file and is not always a multiple of 8, because bad channels (due to very high impedance or a broken shank) were removed from the data.


Load observation data

At this point we are interested in learning from the data, and we do this by considering sequences of spike count observations. That is, each observation is a $u$-dimensional vector, where $u$ is the number of units that we recorded, with each element corresponding to the number of spikes observed within a time bin of 250 or 10 ms, depending on whether we are considering the behavioral or the replay timescale.


In [94]:
import pickle


def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    only use this on trusted, locally produced data files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# load SWR spikes for both environments (binned in 10 ms):
SWR_train1 = _load_pickle('../../Data/SWR_train1rn.pickle')
SWR_train2 = _load_pickle('../../Data/SWR_train2rn.pickle')
SWR_test1 = _load_pickle('../../Data/SWR_test1rn.pickle')
SWR_test2 = _load_pickle('../../Data/SWR_test2rn.pickle')

# load training spike count data, binned in 250 ms bins:
# NOTE(review): environment-1 BVR files carry the "_noth1" suffix while
# environment-2 files carry "_th1" -- confirm this asymmetry is intended.
BVR_train1 = _load_pickle('../../Data/BVR_train1rn_noswr_noth1.pickle')
BVR_train2 = _load_pickle('../../Data/BVR_train2rn_noswr_th1.pickle')

# load testing spike count data, binned in 250 ms bins:
BVR_test1 = _load_pickle('../../Data/BVR_test1rn_noswr_noth1.pickle')
BVR_test2 = _load_pickle('../../Data/BVR_test2rn_noswr_th1.pickle')

Load Python modules and helper functions


In [2]:
import sys
sys.path.insert(0, '../')

import numpy as np
import pandas as pd
import pickle 
import seaborn as sns
#import yahmm as ym
from hmmlearn import hmm # see https://github.com/ckemere/hmmlearn


from matplotlib import pyplot as plt
from pandas import Series, DataFrame

from efunctions import * # load my helper functions

%matplotlib inline


function saveFigure(filename) loaded

Tip: to save a figure, call saveFigure("path/figure.pdf")

In [99]:
def remove_empty_sequences(data):
    """Delete zero-length sequences from ``data`` in place.

    Parameters
    ----------
    data : list
        List of sequences (e.g. lists of per-bin observation vectors).
        It is modified in place.

    Returns
    -------
    list of int
        Positions (in the original list) of the sequences that were removed.
        The same list is also printed, as before.
    """
    # Iterate the list directly rather than converting it with np.array():
    # a list of sequences with different lengths is ragged, and recent NumPy
    # versions raise ValueError instead of building an object array from it.
    emptySequences = [idx for idx, seq in enumerate(data) if len(seq) == 0]

    print(emptySequences)

    # Delete from the back so the remaining indices stay valid.
    for idx in reversed(emptySequences):
        del data[idx]

    return emptySequences

In [104]:
def yahmmdata_to_hmmlearn_data(yahmmdata):
    """Convert a list of observation sequences into hmmlearn's stacked format.

    hmmlearn expects all sequences concatenated row-wise into one 2-D matrix,
    plus a list giving the number of rows belonging to each sequence.

    Parameters
    ----------
    yahmmdata : sequence of array-likes
        Each element is one sequence, a (time bins x cells) array-like of
        spike counts. Empty sequences are allowed and contribute length 0.

    Returns
    -------
    SequenceLengths : list of int
        Number of time bins in each sequence, in input order.
    StackedData : numpy.ndarray, float64, shape (sum(SequenceLengths), cells)
        All sequences stacked vertically.

    Raises
    ------
    ValueError
        If there is no non-empty sequence (the number of cells cannot be
        determined), or if the sequences are not 2-D with a common number
        of columns.
    """
    # Convert each sequence on its own: the outer list is usually ragged
    # (different sequence lengths), which np.array() cannot represent.
    sequences = [np.asarray(seq, dtype=float) for seq in yahmmdata]
    SequenceLengths = [len(seq) for seq in sequences]

    # Take the cell count from the first non-empty sequence; an empty
    # first sequence has no second dimension to read it from.
    nonempty = [seq for seq in sequences if seq.size > 0]
    if not nonempty:
        raise ValueError("yahmmdata contains no non-empty sequences")
    if nonempty[0].ndim != 2:
        raise ValueError("each sequence must be a 2-D (time bins x cells) array-like")
    numCells = nonempty[0].shape[1]

    # Give empty sequences an explicit (0, numCells) shape so they concatenate
    # cleanly; np.concatenate raises ValueError on any column mismatch.
    blocks = [seq if seq.size > 0 else np.empty((0, numCells)) for seq in sequences]
    StackedData = np.concatenate(blocks, axis=0)

    print("{0} sequences stacked for hmmlearn".format(len(sequences)))

    return SequenceLengths, StackedData

In [105]:
# Convert every loaded data set into hmmlearn's (sequence lengths, stacked
# observation matrix) form. Each call prints how many sequences it stacked.

# environment 2
TrainingSequenceLengthsBVRtr2, StackedTrainingDataBVRtr2 = yahmmdata_to_hmmlearn_data(BVR_train2)
TrainingSequenceLengthsSWRtr2, StackedTrainingDataSWRtr2 = yahmmdata_to_hmmlearn_data(SWR_train2)
TestingSequenceLengthsSWRts2, StackedTestingDataSWRts2 = yahmmdata_to_hmmlearn_data(SWR_test2)
TestingSequenceLengthsBVRts2, StackedTestingDataBVRts2 = yahmmdata_to_hmmlearn_data(BVR_test2)

# environment 1
TrainingSequenceLengthsBVRtr1, StackedTrainingDataBVRtr1 = yahmmdata_to_hmmlearn_data(BVR_train1)
TestingSequenceLengthsSWRts1, StackedTestingDataSWRts1 = yahmmdata_to_hmmlearn_data(SWR_test1)
TrainingSequenceLengthsSWRtr1, StackedTrainingDataSWRtr1 = yahmmdata_to_hmmlearn_data(SWR_train1)
# remove_empty_sequences mutates BVR_test1 in place, so it must run before
# BVR_test1 is stacked on the next line.
remove_empty_sequences(BVR_test1)
TestingSequenceLengthsBVRts1, StackedTestingDataBVRts1 = yahmmdata_to_hmmlearn_data(BVR_test1) # NOTE(review): earlier note read "error ??? will investigate"; in the recorded run nothing was removed ([]) and 155 sequences were stacked -- root cause not documented, confirm


63 sequences stacked for hmmlearn
357 sequences stacked for hmmlearn
321 sequences stacked for hmmlearn
48 sequences stacked for hmmlearn
171 sequences stacked for hmmlearn
348 sequences stacked for hmmlearn
314 sequences stacked for hmmlearn
[]
155 sequences stacked for hmmlearn

In [5]:
# learn (fit) models to data:

NStates = 15


def _fit_poisson_hmm(stacked_obs, seq_lengths):
    """Fit and return a fresh NStates-state PoissonHMM (20 EM iterations, verbose).

    ``stacked_obs`` holds all sequences concatenated row-wise, and
    ``seq_lengths`` gives the number of rows in each sequence (the format
    produced by yahmmdata_to_hmmlearn_data).
    """
    model = hmm.PoissonHMM(n_components=NStates, n_iter=20, init_params='stm',
                           params='stm', verbose=True)
    model.fit(stacked_obs, lengths=seq_lengths)
    return model


# One model per training set: BVR (250 ms bins) and SWR (10 ms bins),
# for environments 1 and 2.
hmm15_1 = _fit_poisson_hmm(StackedTrainingDataBVRtr1, TrainingSequenceLengthsBVRtr1)
hmm15_2 = _fit_poisson_hmm(StackedTrainingDataBVRtr2, TrainingSequenceLengthsBVRtr2)
hmm15_swr1 = _fit_poisson_hmm(StackedTrainingDataSWRtr1, TrainingSequenceLengthsSWRtr1)
hmm15_swr2 = _fit_poisson_hmm(StackedTrainingDataSWRtr2, TrainingSequenceLengthsSWRtr2)

# Display the last fitted model as the cell output, as the trailing fit() call did.
hmm15_swr2


         2     -192983.0077       +6859.4766
         3     -191034.6994       +1948.3082
         4     -190210.8798        +823.8196
         5     -189822.5356        +388.3442
         6     -189579.9562        +242.5794
         7     -189351.6029        +228.3532
         8     -189142.5243        +209.0786
         9     -188882.4264        +260.0979
        10     -188648.6795        +233.7469
        11     -188487.2899        +161.3896
        12     -188368.0032        +119.2867
        13     -188259.3804        +108.6228
        14     -188190.7318         +68.6486
        15     -188144.5432         +46.1886
        16     -188083.5215         +61.0217
        17     -188038.3913         +45.1303
        18     -187999.9554         +38.4359
        19     -187968.2485         +31.7069
         2      -44840.4063       +2295.9752
         3      -44446.3322        +394.0740
         4      -44278.9509        +167.3813
         5      -44161.9095        +117.0414
         6      -44088.1965         +73.7130
         7      -44039.6664         +48.5301
         8      -44007.8297         +31.8367
         9      -43984.9506         +22.8791
        10      -43968.8606         +16.0900
        11      -43957.6005         +11.2601
        12      -43940.2676         +17.3329
        13      -43929.7632         +10.5044
        14      -43927.8606          +1.9027
        15      -43924.6036          +3.2569
        16      -43919.3725          +5.2311
        17      -43911.3415          +8.0310
        18      -43906.2735          +5.0680
        19      -43905.1911          +1.0824
         2      -51834.0285       +2053.3614
         3      -51003.0629        +830.9656
         4      -50526.0182        +477.0447
         5      -50193.0266        +332.9916
         6      -49912.1096        +280.9170
         7      -49665.2745        +246.8351
         8      -49450.4899        +214.7846
         9      -49309.0941        +141.3958
        10      -49183.9617        +125.1324
        11      -49104.4720         +79.4897
        12      -49027.3816         +77.0904
        13      -48959.7111         +67.6704
        14      -48908.7675         +50.9436
        15      -48851.1675         +57.6000
        16      -48816.0561         +35.1115
        17      -48776.9430         +39.1130
        18      -48742.8423         +34.1007
        19      -48712.7076         +30.1347
         2      -54537.1550       +2535.1693
         3      -53214.6539       +1322.5010
         4      -52555.9173        +658.7367
         5      -52187.6441        +368.2732
         6      -51949.9239        +237.7202
         7      -51812.4731        +137.4509
         8      -51713.5846         +98.8885
         9      -51651.9634         +61.6211
        10      -51575.3406         +76.6229
        11      -51529.1759         +46.1647
        12      -51474.7657         +54.4102
        13      -51450.7858         +23.9799
        14      -51420.5676         +30.2182
        15      -51409.1769         +11.3907
        16      -51384.5015         +24.6754
        17      -51366.0742         +18.4273
        18      -51336.3360         +29.7381
        19      -51305.7993         +30.5367
Out[5]:
PoissonHMM(algorithm='viterbi', init_params='stm', means_prior=0,
      means_weight=0, n_components=15, n_iter=20, params='stm',
      random_state=None, startprob_prior=1.0, tol=0.01, transmat_prior=1.0,
      verbose=True)

In [106]:
####################################################################
# Helper: score each sequence of a stacked observation matrix
####################################################################

def _score_stacked_sequences(model, stacked, lengths, scale=1):
    """Return the per-sequence log likelihood of ``stacked`` under ``model``.

    Parameters
    ----------
    model : fitted hmmlearn model (anything with a ``score(obs)`` method)
    stacked : 2-D array whose rows are the concatenated sequences
    lengths : sequence of int, the number of rows in each sequence
    scale : factor applied to every observation before scoring

    Returns
    -------
    numpy.ndarray of shape (len(lengths), 1)

    Raises
    ------
    ValueError
        If ``lengths`` does not account for exactly the rows of ``stacked``.
        This catches pairing a length list with the wrong data matrix.
    """
    lengths = list(lengths)
    bounds = np.cumsum([0] + lengths)
    if bounds[-1] != stacked.shape[0]:
        raise ValueError(
            "sequence lengths sum to {0} but the stacked data has {1} rows".format(
                bounds[-1], stacked.shape[0]))
    scores = np.zeros((len(lengths), 1))
    for idx, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
        scores[idx] = model.score(stacked[start:stop, :] * scale)
    return scores


# SWR counts (10 ms bins) are multiplied by 25 before being scored by the BVR
# models (trained on 250 ms bins) -- presumably to match the 25x bin width.
SWR_TO_BVR_SCALE = 25

####################################################################
# BVR test {1,2} in BVR train {1}
####################################################################

bvr2_in_bvr1_log_prob_test = _score_stacked_sequences(
    hmm15_1, StackedTestingDataBVRts2, TestingSequenceLengthsBVRts2)
# Score the BVR1 *testing* matrix: TestingSequenceLengthsBVRts1 describes
# StackedTestingDataBVRts1, not the training matrix.
bvr1_in_bvr1_log_prob_test = _score_stacked_sequences(
    hmm15_1, StackedTestingDataBVRts1, TestingSequenceLengthsBVRts1)

####################################################################
# SWR test {1,2} in BVR train {2}
####################################################################

swr2_in_bvr2_log_prob_test = _score_stacked_sequences(
    hmm15_2, StackedTestingDataSWRts2, TestingSequenceLengthsSWRts2,
    scale=SWR_TO_BVR_SCALE)
swr1_in_bvr2_log_prob_test = _score_stacked_sequences(
    hmm15_2, StackedTestingDataSWRts1, TestingSequenceLengthsSWRts1,
    scale=SWR_TO_BVR_SCALE)

####################################################################
# SWR test {1,2} in BVR train {1,2} --- log odds
####################################################################

swr1ts_in_bvr1tr_log_prob = _score_stacked_sequences(
    hmm15_1, StackedTestingDataSWRts1, TestingSequenceLengthsSWRts1,
    scale=SWR_TO_BVR_SCALE)
swr2ts_in_bvr1tr_log_prob = _score_stacked_sequences(
    hmm15_1, StackedTestingDataSWRts2, TestingSequenceLengthsSWRts2,
    scale=SWR_TO_BVR_SCALE)

# The BVR2-model scores of the scaled SWR test data were computed above with
# identical inputs; copy them so the two arrays stay independent.
swr1ts_in_bvr2tr_log_prob = swr1_in_bvr2_log_prob_test.copy()
swr2ts_in_bvr2tr_log_prob = swr2_in_bvr2_log_prob_test.copy()

In [107]:
# Plot settings shared by this figure.
_rc = {
    'figure.figsize': (6, 4),
    'lines.linewidth': 3,
    'font.size': 16,
    'axes.labelsize': 14,
    'legend.fontsize': 12,
    'ytick.labelsize': 12,
    'xtick.labelsize': 12,
}
sns.set(rc=_rc)
sns.set_style("white")
f, ax1 = plt.subplots(1, 1)

# KDE of per-sequence log likelihoods of BVR test sets 2 and 1 under the BVR1 model.
for _scores, _label in ((bvr2_in_bvr1_log_prob_test, "log p(Y=2|e=1)"),
                        (bvr1_in_bvr1_log_prob_test, "log p(Y=1|e=1)")):
    sns.distplot(_scores, bins=20, hist=False, kde=True, rug=False,
                 axlabel='log probability', ax=ax1,
                 kde_kws={"lw": 3, "label": _label})

ax1.set_title('BVR_test{1,2} in BVR_train{1}')

# To save: saveFigure("figures/?.pdf")


Out[107]:
<matplotlib.text.Text at 0x112a6ba90>

In [108]:
sns.set(rc=dict([('figure.figsize', (6, 4)), ('lines.linewidth', 3),
                 ('font.size', 16), ('axes.labelsize', 14),
                 ('legend.fontsize', 12), ('ytick.labelsize', 12),
                 ('xtick.labelsize', 12)]))
sns.set_style("white")
f, ax1 = plt.subplots(1, 1)

# Options common to both KDE curves; only the data and legend label differ.
_kde_opts = dict(bins=20, hist=False, kde=True, rug=False,
                 axlabel='log probability', ax=ax1)

# Per-sequence log likelihoods of (scaled) SWR test sets 1 and 2 under the BVR2 model.
sns.distplot(swr1_in_bvr2_log_prob_test, kde_kws={"lw": 3, "label": "log p(Y=1|e=2)"}, **_kde_opts)
sns.distplot(swr2_in_bvr2_log_prob_test, kde_kws={"lw": 3, "label": "log p(Y=2|e=2)"}, **_kde_opts)

ax1.set_title('SWR_test{1,2} in BVR_train{2}')

# To save: saveFigure("figures/?.pdf")


Out[108]:
<matplotlib.text.Text at 0x112a044a8>

In [24]:
sns.set(rc={'figure.figsize': (6, 4),'lines.linewidth': 3, 'font.size': 16, 'axes.labelsize': 14, 'legend.fontsize': 12, 'ytick.labelsize': 12, 'xtick.labelsize': 12 })
sns.set_style("white")
f, ax1 = plt.subplots(1, 1)

# Per-sequence log odds: log p(Y|e=1) - log p(Y|e=2). Positive values mean the
# BVR1-trained model assigns the SWR sequence a higher likelihood than the
# BVR2-trained model.
swr1ts_log_odds = swr1ts_in_bvr1tr_log_prob - swr1ts_in_bvr2tr_log_prob
swr2ts_log_odds = swr2ts_in_bvr1tr_log_prob - swr2ts_in_bvr2tr_log_prob

# The x-axis is a difference of log probabilities, i.e. a log odds ratio.
sns.distplot( swr1ts_log_odds, bins=20, hist=False, kde=True, rug=False, axlabel='log odds', ax=ax1, kde_kws={"lw": 3, "label": "log p(Y=1|e=1)/p(Y=1|e=2)"} );
sns.distplot( swr2ts_log_odds, bins=20, hist=False, kde=True, rug=False, axlabel='log odds', ax=ax1, kde_kws={"lw": 3, "label": "log p(Y=2|e=1)/p(Y=2|e=2)"}  );

ax1.set_title('log odds of SWR (test) sequences in BVR (train) models')

#saveFigure("figures/?.pdf")


Out[24]:
<matplotlib.text.Text at 0x1153de5f8>