NOTE: After some discussion, Caleb and I reformulated our desired score: we now evaluate the sequential part using the Kullback-Leibler (KL) divergence. The score has not quite been finalized, and I will write a more comprehensive introduction/motivation here soon, but this notebook serves as a testbed and proof of concept for the KL-based approach. So far it seems to do remarkably well, and it is conceptually nicer. The contextual score has also been modified slightly, in that we now use more information than simply the single most likely state at each time point.
Here I will compute and evaluate the Kullback–Leibler (KL) divergence-based sequence score for HMMs.
In particular, I base the contextual component on $$\dfrac{1}{|\mathcal{S}|}\sum_{S\in\mathcal{S}} \text{Pr}(\mathbf{y}_t|S),$$ where the sum runs over all possible states. The choice not to weight by $p(S)$ is intentional (usually we would marginalize as $p(\mathbf{y}_t) = \langle p(\mathbf{y}_t | S), p(S) \rangle$).
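A hypothetical two-line sketch of both alternatives (nothing here is computed in this notebook; obs_lik is assumed to be a $(T \times m)$ array with obs_lik[t, s] $= \text{Pr}(\mathbf{y}_t|S=s)$, and p_S a length-$m$ state distribution):

import numpy as np
ctx_unweighted = obs_lik.mean(axis=1)  # (1/|S|) * sum_S Pr(y_t | S), per time bin
ctx_marginal = np.dot(obs_lik, p_S)    # p(y_t) = <p(y_t | S), p(S)>, shown only for contrast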
The sequential component is based on the average KL divergence from the expected state evolution to the posterior state distribution.
If we start with $\boldsymbol{\pi}$, we can compute (and visualize!) its state distribution evolution $\{S'_0, S'_1, S'_2, \ldots\} \equiv \{\boldsymbol{\pi}, \boldsymbol{\pi}\mathbf{A}, \boldsymbol{\pi}\mathbf{A}^2, \ldots \}$. This is the a priori expected state evolution. Note that we assume $\mathbf{A}_{ij} \equiv \text{Pr}(S_{t+1}=j|S_t=i)$, so that each row of $\mathbf{A}$ sums to one. If, however, we use our observations to arrive at posterior state estimates (using, e.g., the forward–backward algorithm), then we have the posterior state evolution. For the sequential component, we then consider $$ D_\text{KL}(U\,\|\,V) \quad \text{with} \quad U \stackrel{\Delta}{=} p(S_{t+1}|\mathbf{y}_{1:T}) \quad \text{and} \quad V \stackrel{\Delta}{=} p(S_t|\mathbf{y}_{1:T})\,\mathbf{A}, $$ where $p(S_{t+1}|\mathbf{y}_{1:T})$ and $p(S_t|\mathbf{y}_{1:T})$ are posterior state distributions (we will have to be a little more careful with this notation later). Note that in general $p(S_t|\mathbf{y}_{1:T})\,\mathbf{A} \neq S'_{t+1}$.
Let $m=|\mathcal{S}|$, so that $\mathbf{A}\in \mathbb{R}^{m\times m}$ and $p(S_t)\in \mathbb{R}^{1\times m}$.
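As a minimal sketch of the sequential component under these definitions (a hypothetical helper, not the code used later in this notebook; post is the $(T \times m)$ posterior matrix and pi the $(1 \times m)$ initial distribution):

import numpy as np
from scipy.stats import entropy  # entropy(p, q) returns D_KL(p || q)

def sequential_kl(post, A, pi):
    # one-step predictions: row t is the previous bin's distribution pushed through A
    pred = np.dot(np.vstack([pi, post[:-1]]), A)
    # skip bin 0, which has no preceding posterior to predict from
    return np.array([entropy(post[t], pred[t]) for t in range(1, post.shape[0])])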
See also https://www.quora.com/What-is-a-good-laymans-explanation-for-the-Kullback-Leibler-Divergence
In [133]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import os  # needed for os.path.isdir when locating the data directory below
import sys
sys.path.insert(0, 'helpers')
from efunctions import * # load my helper function(s) to save pdf figures, etc.
from hc3 import load_data, get_sessions
import klabtools as klab
import seqtools as sq
%matplotlib inline
#mpld3.enable_notebook()
sns.set(rc={'figure.figsize': (12, 4),'lines.linewidth': 1, 'font.size': 18, 'axes.labelsize': 16,
'legend.fontsize': 12, 'ytick.labelsize': 12, 'xtick.labelsize': 12 })
sns.set_style("white")
In [134]:
from hmmlearn import hmm # see https://github.com/ckemere/hmmlearn
import importlib
importlib.reload(sq) # reload module here only while prototyping...
importlib.reload(klab) # reload module here only while prototyping...
Out[134]:
In [135]:
import sys
import time
from IPython.display import display, clear_output
#for i in range(10):
# time.sleep(0.25)
# clear_output(wait=True)
# print(i)
# sys.stdout.flush()
In [136]:
datadirs = ['/home/etienne/Dropbox/neoReader/Data',
'C:/etienne/Dropbox/neoReader/Data',
'/Users/etienne/Dropbox/neoReader/Data']
fileroot = next( (dir for dir in datadirs if os.path.isdir(dir)), None)
In [137]:
#animal = 'gor01'; month,day = (6,7); session = '11-26-53' # WARNING! POSITION DATA INCOMPLETE!
animal = 'gor01'; month,day = (6,7); session = '16-40-19' # 91 units
#animal = 'gor01'; month,day = (6,12); session = '15-55-31' # 55 units
#animal = 'gor01'; month,day = (6,12); session = '16-53-46' # 55 units
#animal = 'gor01'; month,day = (6,13); session = '14-42-6' # 58 units
#animal = 'gor01'; month,day = (6,13); session = '15-22-3' # 58 units
#animal = 'vvp01'; month,day = (4,9); session = '16-40-54' # ?? units
#animal = 'vvp01'; month,day = (4,9); session = '17-29-30' # ?? units
#animal = 'vvp01'; month,day = (4,10); session = '12-25-50' # lin1; ?? units
#animal = 'vvp01'; month,day = (4,10); session = '12-58-3' # lin2; ?? units
#animal = 'vvp01'; month,day = (4,10); session = '19-11-57' # lin2; ?? units
#animal = 'vvp01'; month,day = (4,10); session = '21-2-40' # lin1; ?? units
#animal = 'vvp01'; month,day = (4,18); session = '13-06-01' # lin1; ?? units
#animal = 'vvp01'; month,day = (4,18); session = '13-28-57' # lin2; ?? units
#animal = 'vvp01'; month,day = (4,18); session = '15-23-32' # lin1; ?? units
#animal = 'vvp01'; month,day = (4,18); session = '15-38-02' # lin2; ?? units
spikes = load_data(fileroot=fileroot, datatype='spikes',animal=animal, session=session, month=month, day=day, fs=32552, verbose=True)
eeg = load_data(fileroot=fileroot, datatype='eeg', animal=animal, session=session, month=month, day=day,channels=[0,1,2], fs=1252, starttime=0, verbose=True)
posdf = load_data(fileroot=fileroot, datatype='pos',animal=animal, session=session, month=month, day=day, verbose=True)
speed = klab.get_smooth_speed(posdf,fs=60,th=8,cutoff=0.5,showfig=True,verbose=True)
In [138]:
## bin ALL spikes
ds = 0.125 # bin spikes into 125 ms bins (theta-cycle inspired)
binned_spikes = klab.bin_spikes(spikes.data, ds=ds, fs=spikes.samprate, verbose=True)
## identify boundaries for running (active) epochs and then bin those observations into separate sequences:
runbdries = klab.get_boundaries_from_bins(eeg.samprate,bins=speed.active_bins,bins_fs=60)
binned_spikes_bvr = klab.bin_spikes(spikes.data, fs=spikes.samprate, boundaries=runbdries, boundaries_fs=eeg.samprate, ds=ds)
## stack data for hmmlearn:
seq_stk_bvr = sq.data_stack(binned_spikes_bvr, verbose=True)
seq_stk_all = sq.data_stack(binned_spikes, verbose=True)
## split data into train, test, and validation sets:
tr_b,vl_b,ts_b = sq.data_split(seq_stk_bvr, tr=60, vl=10, ts=30, randomseed = 0, verbose=True)
## train HMM on active behavioral data; training set (with a fixed, arbitrary number of states for now):
myhmm = sq.hmm_train(tr_b, num_states=35, n_iter=50, verbose=False)
In [139]:
print(myhmm.transmat_.sum(axis=1)) # rows should each sum to 1, confirming A[i,j] = Pr(S_{t+1}=j | S_t=i)
A = myhmm.transmat_.copy()
fig, ax = plt.subplots(figsize=(3.5, 3))
im = ax.matshow(A, interpolation='none', cmap='RdPu')
# Make an axis for the colorbar on the right side
cax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
fig.colorbar(im, cax=cax)
Out[139]:
In [140]:
m = myhmm.n_components
Pi = myhmm.startprob_.copy()
Pi = np.reshape(Pi,(1,m))
fig, ax = plt.subplots(figsize=(6, 2))
ax.stem(np.transpose(Pi),':k')
Out[140]:
In [141]:
fig, ax = plt.subplots(figsize=(6, 2))
ax.matshow(Pi)
PiA = np.dot(Pi,A)
ax.matshow(np.vstack([Pi,PiA,np.dot(PiA,A)]))
Out[141]:
In [142]:
def advance_states(St, A, n):
    """Yield the next n state distributions St A, St A^2, ..., St A^n."""
    count = 1
    St = np.dot(St, A)
    while count <= n:
        yield St
        count += 1
        St = np.dot(St, A)
In [143]:
numsteps = 50
prior_evo = np.zeros((numsteps+1,m))
prior_evo[0,:] = Pi
for ii, S in enumerate(advance_states(Pi, A, numsteps)):
prior_evo[ii+1,:] = S
fig, ax = plt.subplots(figsize=(10, 3))
ax.matshow(np.transpose(prior_evo))
Out[143]:
Remark: Note that the steady-state behavior is related to the eigenvectors of $\mathbf{A}$: $p(S)\mathbf{A} = p(S)$ means that $p(S)$ is a steady-state distribution, i.e., a left eigenvector of $\mathbf{A}$ with associated eigenvalue $\lambda = 1$.
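A quick sanity check of this remark (a sketch only, not part of the score), using the A computed above:

evals, evecs = np.linalg.eig(A.T)      # left eigenvectors of A = right eigenvectors of A.T
idx = np.argmin(np.abs(evals - 1))     # eigenvalue closest to 1
pi_ss = np.real(evecs[:, idx])
pi_ss = pi_ss / pi_ss.sum()            # normalize to a probability distribution
np.allclose(np.dot(pi_ss, A), pi_ss)   # should be True: p(S) A = p(S)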
In [144]:
seq_id = 0
tmpseqbdries = [0]; tmpseqbdries.extend(np.cumsum(ts_b.sequence_lengths).tolist());
obs = ts_b.data[tmpseqbdries[seq_id]:tmpseqbdries[seq_id+1],:]
ll, pp = myhmm.score_samples(obs)
fig, ax = plt.subplots(figsize=(10, 3))
ax.matshow(np.transpose(pp))
Out[144]:
In [145]:
def advance_states_one(pp, A):
    # one-step prediction: propagate each state distribution (row of pp) through A
    return np.dot(pp, A)
In [146]:
ppp = advance_states_one(np.vstack([Pi,pp[:pp.shape[0]-1,:]]), A)
numsteps = pp.shape[0]-1
prior_evo = np.zeros((numsteps+1,m))
prior_evo[0,:] = Pi
for ii, S in enumerate(advance_states(Pi, A, numsteps)):
prior_evo[ii+1,:] = S
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6, 3))
ax1.matshow(np.transpose(prior_evo),cmap='OrRd'); ax1.set_xlabel('prior')
ax2.matshow(np.transpose(ppp),cmap='OrRd'); ax2.set_xlabel('predicted')
ax3.matshow(np.transpose(pp),cmap='OrRd'); ax3.set_xlabel('posterior')
Out[146]:
In [147]:
# sort model states greedily: starting from state 0, repeatedly follow the most
# probable transition to a not-yet-visited state
new_order = [0]
rem_states = np.arange(1, m).tolist()
cs = 0
for ii in np.arange(0, m-1):
    nstilde = np.argmax(A[cs, rem_states])
    ns = rem_states[nstilde]
    rem_states.remove(ns)
    cs = ns
    new_order.append(cs)
Anew = A[:, new_order][new_order]  # transition matrix with rows and columns reordered
In [148]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6, 3))
ax1.matshow(np.transpose(prior_evo)[new_order,:],cmap='OrRd'); ax1.set_xlabel('prior')
ax2.matshow(np.transpose(ppp)[new_order,:],cmap='OrRd'); ax2.set_xlabel('predicted')
ax3.matshow(np.transpose(pp)[new_order,:],cmap='OrRd'); ax3.set_xlabel('posterior')
Out[148]:
In [149]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
im = ax1.matshow(myhmm.means_, cmap='OrRd', vmin=0, vmax=16); ax1.set_xlabel('before sorting states')
im = ax2.matshow(myhmm.means_[new_order,:], cmap='OrRd', vmin=0, vmax=16); ax2.set_xlabel('after sorting states')
# Make an axis for the colorbar on the right side
cax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
fig.colorbar(im, cax=cax)
Out[149]:
In [150]:
ds = 1/60 # bin spikes into 1/60 s (~16.7 ms) bins, corresponding to the video sampling period
binned_spikes = klab.bin_spikes(spikes.data, ds=ds, fs=spikes.samprate, verbose=True)
runidx = speed.active_bins
lin_pos = (posdf.x1.values + posdf.x2.values)/2
pfs, pfbincenters, pindex = klab.estimate_place_fields(lin_pos[runidx],binned_spikes.data[runidx],fs=60,
x0=0,xl=100, max_meanfiringrate = 5,min_maxfiringrate=3,num_pos_bins=100,sigma=1, verbose=True,showfig=True)
In [151]:
klab.show_place_fields(pfs,pfbincenters,pindex,min_maxfiringrate=2)
In [152]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 4))
MM = myhmm.means_.copy()
for cell in np.arange(0,MM.shape[1]):
if cell not in pindex:
MM[:,cell] = np.nan
im = ax1.matshow(MM, cmap='OrRd', vmin=0, vmax=15); ax1.set_xlabel('before sorting states, only place cells')
im = ax2.matshow(MM[new_order,:], cmap='OrRd', vmin=0, vmax=15); ax2.set_xlabel('after sorting states, only place cells')
# Make an axis for the colorbar on the right side
cax = fig.add_axes([0.9, 0.1, 0.03, 0.8])
fig.colorbar(im, cax=cax)
Out[152]:
In [153]:
from random import shuffle
lp, pth = myhmm.decode(obs,algorithm='map')
trj_shfl_idx = np.arange(0,len(pth))
shuffle(trj_shfl_idx)
obs_shfl = np.array([obs[i] for i in trj_shfl_idx])
lp_shfl, pp_shfl = myhmm.score_samples(obs_shfl)
ppp_shfl = advance_states_one(np.vstack([Pi,pp_shfl[:pp_shfl.shape[0]-1,:]]), A)
In [154]:
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(15, 3))
ax1.matshow(np.transpose(prior_evo)[new_order,:],cmap='OrRd'); ax1.set_xlabel('prior'); ax1.set_ylabel('State')
ax2.matshow(np.transpose(ppp)[new_order,:],cmap='OrRd'); ax2.set_xlabel('predicted')
ax3.matshow(np.transpose(pp)[new_order,:],cmap='OrRd'); ax3.set_xlabel('posterior')
ax4.matshow(np.transpose(ppp_shfl)[new_order,:],cmap='OrRd'); ax4.set_xlabel('pred shfl')
ax5.matshow(np.transpose(pp_shfl)[new_order,:],cmap='OrRd'); ax5.set_xlabel('post shfl')
Out[154]:
Remark: What if our pmfs contain zeros? Then the KL divergence blows up, which is a problem.
One way to think about this is that we don't really have full confidence in a PMF calculated from a histogram, so we might need a slight prior in the model. If we were fully confident in the PMFs, then an infinite KL divergence would be appropriate, since one PMF assigns mass to outcomes that are impossible under the other. If, on the other hand, we had a slight, uninformative prior, then there would always be some small probability of seeing any outcome. One way of introducing this is to add a vector of ones times some small scalar to the histogram. The theoretical prior being used is the Dirichlet distribution, the conjugate prior of the categorical distribution, but for practical purposes we can do something like
pmf_unnorm = np.histogram(samples, bins=bins, density=True)[0] + w * np.ones(len(bins) - 1)
pmf = pmf_unnorm / sum(pmf_unnorm)
where w is some positive weight, depending on how strong a prior we want.
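In our setting, the same trick can be applied directly to the state distributions before taking the KL divergence; a minimal sketch (the pseudocount w is a free, hypothetical parameter, and this smoothing is not used in the cells below):

def smooth_pmf(p, w=1e-6):
    p = np.asarray(p, dtype=float) + w  # small pseudocount so no entry is exactly zero
    return p / p.sum()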
In [155]:
from scipy.stats import entropy as KLD
KLlist = []
KLlist_shfl = []
for ii in np.arange(1,len(pth)):
KLlist.append(KLD(pp[ii,:],ppp[ii,:]))
KLlist_shfl.append(KLD(pp_shfl[ii,:],ppp_shfl[ii,:]))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 3))
ax1.plot(np.cumsum(KLlist), label='sequence', lw=1.5)
ax1.plot(np.cumsum(KLlist_shfl), label='trajectory shuffled', lw=1.5)
ax1.legend()
ax1.set_xlabel('bin')
ax1.set_title('Cumulative KL divergence')
seqscore = np.cumsum(KLlist) / np.arange(1,len(pth))
seqscore_shfl = np.cumsum(KLlist_shfl) / np.arange(1,len(pth))
ax2.plot(seqscore, label='sequence', lw=1.5)
ax2.plot(seqscore_shfl, label='trajectory shuffled', lw=1.5)
ax2.legend()
ax2.set_xlabel('bin')
ax2.set_title('Running average KL divergence')
Out[155]:
In [156]:
def KL(distr_matU, distr_matV):
    from scipy.stats import entropy as KLD
    num_bins = distr_matU.shape[0]
    KLarray = np.zeros(num_bins - 1)  # bin 0 has no one-step prediction, so it is skipped
    for ii in np.arange(1, num_bins):
        KLarray[ii-1] = KLD(distr_matU[ii, :], distr_matV[ii, :])
    return KLarray.mean()
In [157]:
from random import shuffle
###########################################################
stacked_data = ts_b
seq_id = 0
n_shuffles = 500
###########################################################
tmpseqbdries = [0]; tmpseqbdries.extend(np.cumsum(stacked_data.sequence_lengths).tolist());
obs = stacked_data.data[tmpseqbdries[seq_id]:tmpseqbdries[seq_id+1],:]
ll, pp = myhmm.score_samples(obs)
num_bins = obs.shape[0]
ppp = advance_states_one(np.vstack([Pi,pp[:num_bins-1,:]]), A)
trj_shfl_idx = np.arange(0, num_bins);
KL_true = KL(pp,ppp)
KL_shuffles = np.zeros(n_shuffles)
for nn in np.arange(0,n_shuffles):
shuffle(trj_shfl_idx)
obs_shfl = np.array([obs[i] for i in trj_shfl_idx])
ll_shfl, pp_shfl = myhmm.score_samples(obs_shfl)
ppp_shfl = advance_states_one(np.vstack([Pi,pp_shfl[:num_bins-1,:]]), A)
KL_shuffles[nn] = KL(pp_shfl,ppp_shfl)
fig, ax1 = plt.subplots(1, 1, figsize=(6, 2))
sns.distplot(KL_shuffles, ax=ax1, bins=40)
ax1.axvline(x=KL_true, ymin=0, ymax=1, linewidth=2, color = 'k', linestyle='dashed', label='true sequence')
ax1.legend()
Out[157]:
In [168]:
from random import shuffle
###########################################################
stacked_data = ts_b
n_shuffles = 250 # shuffles PER sequence in data set
###########################################################
num_sequences = len(stacked_data.sequence_lengths)
tmpseqbdries = [0]; tmpseqbdries.extend(np.cumsum(stacked_data.sequence_lengths).tolist());
KL_true = np.zeros(num_sequences)
KL_shuffles = np.zeros((num_sequences,n_shuffles))
for seq_id in np.arange(0,num_sequences):
obs = stacked_data.data[tmpseqbdries[seq_id]:tmpseqbdries[seq_id+1],:]
ll, pp = myhmm.score_samples(obs)
num_bins = obs.shape[0]
ppp = advance_states_one(np.vstack([Pi,pp[:num_bins-1,:]]), A)
trj_shfl_idx = np.arange(0, num_bins);
KL_true[seq_id] = KL(pp,ppp)
for nn in np.arange(0,n_shuffles):
shuffle(trj_shfl_idx)
obs_shfl = np.array([obs[i] for i in trj_shfl_idx])
ll_shfl, pp_shfl = myhmm.score_samples(obs_shfl)
ppp_shfl = advance_states_one(np.vstack([Pi,pp_shfl[:num_bins-1,:]]), A)
KL_shuffles[seq_id,nn] = KL(pp_shfl,ppp_shfl)
fig, ax1 = plt.subplots(1, 1, figsize=(6, 2))
sns.distplot(KL_true, ax=ax1, label='true sequences')
sns.distplot(KL_shuffles.flatten(), bins=80, ax=ax1, label='trajectory shuffled')
#ax1.axvline(x=KL_true, ymin=0, ymax=1, linewidth=2, color = 'k', linestyle='dashed', label='true sequence')
#ax1.set_xlim([0,5])
ax1.legend()
Out[168]:
The KL-based sequence score actually works remarkably well when looking at individual examples. At the population level the results seem good but not great, which might be due to several "bad" sequences in the data. For example, some sequences are quite short (4 bins), with only one or two states visited; for these, the trajectory shuffle does essentially nothing to the sequence, and we get false positives where shuffled data are classified as true sequences.
Remaining important characterizations: I should consider other distributions here, such as the RUN > 8 vs NORUN < 4 distributions, and I should also split the data into lin1a and lin1b (splitting when the track was shortened) and into lin2a and lin2b. Each of those subsets can again be split into RUN > 8 and NORUN < 4 subsets to look for finer discrimination.
Of course, I should also characterize this score with the numerous other shuffling strategies that are commonly employed, including the unit-ID shuffle, etc.
KL is not symmetric: There can be some debate about whether it is more appropriate to use $D_\text{KL}(U\,\|\,V)$ or $D_\text{KL}(V\,\|\,U)$. As of now, I am leaning toward $D_\text{KL}(U\,\|\,V)$, where $V$ is the expected model evolution and $U$ is the actually observed behavior. I have tested both, and both consistently give a lower score (greater similarity) to true sequences than to shuffled sequences.
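For reference, the two directions can be compared directly on whichever sequence pp and ppp currently hold (a sketch, reusing the KLD import from above):

kl_UV = np.mean([KLD(pp[t], ppp[t]) for t in range(1, pp.shape[0])])  # D_KL(U || V)
kl_VU = np.mean([KLD(ppp[t], pp[t]) for t in range(1, pp.shape[0])])  # D_KL(V || U)
print(kl_UV, kl_VU)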
About the place fields: The place fields here also don't look that great, but that is partly because I don't consider directionality here, and I don't split the experiment into the parts before and after the track was shortened. However, we don't actually use any of the place field information in this analysis, so it should not be of any real concern. We can of course estimate them better, if we really have to.
Notation: Importantly, I should re-write the introduction to be more friendly, making my notation consistent and final, and I should demonstrate why this approach makes both intuitive and mathematical sense. I do like the notation used above, where $p(S_t|\mathbf{y}_{1:T})$ is the posterior state distribution at time $t$ having observed the sequence $\mathbf{y}_{1:T}$, and I also like the notation distinguishing between distributions $p(\cdot)$ and probabilities $P(\cdot)$, although if I draw attention to this distinction, I have to be very careful to follow the convention consistently.
What's next? How can we derive an effective final score? Finally, what then, is my final sequence score? I still have both sequential and contextual components, but I need to think more carefully about the best way to make the KL-score into the sequential component. For one thing, a score is typically better if it is larger, but so far we have opposite desired directions for the contextual and sequential components (larger ctx is good, smaller KL is good). Another issue is interpretability, both local and global. Local interpretability might answer how likely we are to observe a KL score that small or smaller, based on a shuffle distribution of the sequence under consideration, but it would be more appealing to not have to compute shuffle distributions first, and also to have global interpretability, which would allow us to say "this sequence is more consistent with the underlying model than that sequence".
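As one concrete form of local interpretability, an empirical (Monte Carlo) p-value per sequence can already be read off the shuffle distributions computed above; a sketch using KL_true and KL_shuffles from the previous cell (smaller KL means a better match, so we count shuffles that score at least as low as the true sequence):

pvals = (1 + (KL_shuffles <= KL_true[:, None]).sum(axis=1)) / (1 + KL_shuffles.shape[1])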