In [1]:
import copy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
import seaborn as sns
import time
import warnings

import scipy.ndimage.filters
import scipy.stats as stats

from IPython.display import display, clear_output

import nelpy as nel
import nelpy.plotting as npl

from nelpy.analysis.hmm_sparsity import HMMSurrogate

from sklearn.model_selection import train_test_split
from mpl_toolkits.axes_grid1 import make_axes_locatable

from nelpy import hmmutils
from nelpy.decoding import k_fold_cross_validation
from nelpy.decoding import decode1D

# Set default figure aesthetics (nelpy plotting defaults with fonts scaled 2x)
npl.setup(font_scale=2.0)

# Render matplotlib figures inline in the notebook output
%matplotlib inline

# NOTE(review): this silences *every* warning for the rest of the session,
# including ones that may point at real problems -- consider a narrower filter.
warnings.filterwarnings("ignore")


/opt/conda/lib/python3.6/site-packages/matplotlib/cbook/deprecation.py:106: MatplotlibDeprecationWarning: The mpl_toolkits.axes_grid module was deprecated in version 2.1. Use mpl_toolkits.axes_grid1 and mpl_toolkits.axisartist provies the same functionality instead.

In [2]:
import gcsfs
import pandas as pd
import os

# Load the session metadata (two DataFrames) and the processed nelpy "jar",
# either from a Google Cloud Storage bucket or from a local directory.
load_local = False

if not load_local:
    # Remote path: authenticate against GCS and list the bucket as a sanity check.
    fs = gcsfs.GCSFileSystem(project='polar-program-784', token='cloud')
    print(fs.ls('kemerelab-data/diba'))

    # The bytes read from the bucket are handed to HDFStore as `driver_core_image`,
    # so the store is opened from that in-memory image (presumably without a local
    # file being written, given driver_core_backing_store=0 -- TODO confirm).
    with fs.open('kemerelab-data/diba/gor01vvp01pin01-metadata.h5', 'rb') as fid:
        with pd.HDFStore('gor01vvp01pin01-metadata.h5', mode="r", driver="H5FD_CORE",
                driver_core_backing_store=0,
                driver_core_image=fid.read()
                ) as store:
            df = store['Session_Metadata']
            df2 = store['Subset_Metadata']
            
    # NOTE(review): load_pkl presumably unpickles the file -- only use trusted data.
    with fs.open('kemerelab-data/diba/gor01vvp01pin01_processed_speed.nel', 'rb') as fid:
        jar = nel.load_pkl('',fileobj=fid) # currently requires a specific nelpy branch

else:
    # Local path: use the first candidate directory that exists.
    # NOTE(review): hard-coded absolute user path; consider a configurable DATA_DIR.
    datadirs = ['/Users/ckemere/Development/Data/Buzsaki/']

    fileroot = next( (dir for dir in datadirs if os.path.isdir(dir)), None)
    if fileroot is None:
        raise FileNotFoundError('datadir not found')

    with pd.HDFStore(fileroot + 'gor01vvp01pin01-metadata.h5') as store:
        df = store.get('Session_Metadata')
        df2 = store.get('Subset_Metadata')
        
    jar = nel.load_pkl(fileroot + 'gor01vvp01pin01_processed_speed.nel')


# Keep only the two attributes used downstream and drop the reference to the jar.
exp_data = jar.exp_data
aux_data = jar.aux_data
del jar


['kemerelab-data/diba/', 'kemerelab-data/diba/gor01vvp01-metadata.h5', 'kemerelab-data/diba/gor01vvp01_processed_speed.nel', 'kemerelab-data/diba/gor01vvp01pin01-metadata.h5', 'kemerelab-data/diba/gor01vvp01pin01_processed_speed.nel', 'kemerelab-data/diba/score_all_sessions_5000_35000.nel']

Draw real place fields


In [3]:
# Other (session, segment) pairs that can be selected instead:
# session_time, segment = '1-22-43', 'long'
# session_time, segment = '16-40-19', 'short'

session_time, segment = '22-24-40', 'short'

# Per-segment objects stored in the auxiliary data for the chosen session.
segment_aux = aux_data[session_time][segment]
PBEs = segment_aux['PBEs']
st_run = segment_aux['st_run']
tc = segment_aux['tc']
tc_placecells = segment_aux['tc_placecells']

#####################################################################

# Sample the colormap at four evenly spaced points per place-cell unit.
NUM_COLORS = 4 * tc_placecells.n_units

cm = plt.get_cmap('Spectral_r')
# np.roll with a shift of 0 leaves the order untouched but returns an ndarray.
clist = np.roll([cm(1.*i/NUM_COLORS) for i in range(NUM_COLORS)], 0, axis=0)

npl.set_palette(clist)

# Smoothed 1D tuning curves of the place cells.
with npl.FigureManager(show=True, figsize=(4,6)) as (fig, ax):
    ax = npl.plot_tuning_curves1D(tc_placecells.smooth(sigma=3), pad=2.5)
    ax.set_xlim(0,250)


Compute cross-validated sequence score distributions for RUN data in 4 sessions


In [4]:
#from dask.distributed import Client
# client = Client('tcp://127.0.0.1:38306')  # set up local cluster on your laptop
#client

In [5]:
import dask
from dask import delayed

def est_model(data, num_states, seed):
    """Fit a Poisson HMM with ``num_states`` components to ``data`` and return it.

    ``seed`` is forwarded to the model as its ``random_state``.
    """
    model = nel.hmmutils.PoissonHMM(
        n_components=num_states,
        random_state=seed,
        verbose=False,
    )
    model.fit(data)
    return model

def score_data(data, hmm):
    """Return ``hmm.score(data)`` divided by the bin count of each sequence.

    ``data`` is iterated to read ``n_bins`` from every sequence; the resulting
    array is used as the elementwise divisor of the model's score.
    """
    bins_per_sequence = np.array([sequence.n_bins for sequence in data])
    raw_scores = hmm.score(data)
    return raw_scores / bins_per_sequence

In [6]:
# Bin sizes (seconds)
ds_run = 0.125 # 125 ms bin size for Run
ds_50ms = 0.05 # 50 ms bins; RUN spikes are binned at this size before smoothing
ds = 0.02 # 20 ms bin size for PBEs

sigma = 0.25 # 250 ms spike smoothing

# Number of hidden states used for every HMM in this notebook
num_states = 30

# Number of cross-validation folds
k_folds = 5

In [7]:
print('Building model for Session {}, {} segment'.format(session_time, segment))
# Index of the entry in this session's `segment_labels` that matches `segment`.
s = np.argwhere([segment == segment_label for segment_label in df[df.time==session_time]['segment_labels'].values.tolist()[0]])
# NOTE: this rebinds `st_run`, which was earlier read from aux_data; from here on
# it is the session spike train indexed by `s` and restricted to the run epochs.
st_run = exp_data[session_time]['spikes'][s][exp_data[session_time]['run_epochs']]

# smooth and re-bin RUN data:
# bin at ds_50ms, smooth with `sigma`, then merge ds_run/ds_50ms bins into one.
bst = st_run.bin(ds=ds_50ms).smooth(sigma=sigma, inplace=True).rebin(w=ds_run/ds_50ms)


Building model for Session 22-24-40, short segment

Analyze model parameters for RUN


In [ ]:
# Spikes of the selected segment, restricted to the run epochs.
segment_labels = df[df.time==session_time]['segment_labels'].values.tolist()[0]
s = np.argwhere([segment == segment_label for segment_label in segment_labels])
session_data = exp_data[session_time]
run_spks = session_data['spikes'][s][session_data['run_epochs']]

random_state = 1
test_size = 0.2
description = (session_time, segment)
verbose = False

# The three surrogates share every setting; only `kind` differs between them.
run_surrogate_kws = dict(st=run_spks,
                         num_states=num_states,
                         ds=ds_run,
                         test_size=test_size,
                         random_state=random_state,
                         description=description,
                         verbose=verbose)

hmm_actual = HMMSurrogate(kind='actual', **run_surrogate_kws)
hmm_coherent = HMMSurrogate(kind='coherent', **run_surrogate_kws)
hmm_poisson = HMMSurrogate(kind='poisson', **run_surrogate_kws)

run_hmms = [hmm_actual, hmm_coherent, hmm_poisson]

In [21]:
%%time 
n_shuffles = 1

# Gini coefficients that are calculated and aggregated after each fit, in order.
gini_kinds = ('tmat',
              'lambda',
              'tmat_arrival',
              'tmat_departure',
              'lambda_across_states',
              'lambda_across_units')

for nn in range(n_shuffles):
    print('starting {}'.format(nn))
    for hmm in run_hmms:
        print("shuffling", hmm.label)
        hmm.shuffle()
        print("fitting", hmm.label)
        hmm.fit()
        print("scoring", hmm.label)

        for gini_kind in gini_kinds:
            hmm.score_gini(kind=gini_kind)

        # calculate and aggregate bottleneck_ratios
        hmm.score_bottleneck_ratio(n_samples=20000)

    print('completed {}'.format(nn))


starting 0
shuffling actual
fitting actual
scoring actual
shuffling coherent
fitting coherent
scoring coherent
shuffling poisson
fitting poisson
scoring poisson
completed 0
CPU times: user 1min 4s, sys: 1min 4s, total: 2min 8s
Wall time: 58.2 s

In [34]:
import model_plotting
import matplotlib.colors as colors  # needed for PowerNorm below

# NOTE(review): the plot_* helpers below are called unqualified, so they must
# already be defined in the kernel; `model_plotting` itself is not referenced.

## define figure parameters and color palette
text_kws = dict(ha="center", size=7)
fig_kws = dict(text_kws=text_kws, cmap=plt.cm.seismic)
# Styling for plot_lambda, defined here so this cell does not rely on a later one.
lambda_kws = dict(text_kws=text_kws, cmap=plt.cm.seismic, norm=colors.PowerNorm(0.5))

# Plot the model fitted to the actual RUN data: transition matrix and rates.
fig, axes = plt.subplots(1,2, figsize=(10, 6))
# Order units by their summed rate in the RUN model that is being plotted
# (previously this was taken from the PBE models in `hmms`).
lambda_order = np.argsort(run_hmms[0].hmm.means.sum(axis=0))
plot_transmat(axes[0], hmm=run_hmms[0], title=run_hmms[0].label, cbar=True, **fig_kws)
plot_lambda(axes[1], hmm=run_hmms[0], title=run_hmms[0].label, ylabel=False, lo=lambda_order, 
            cbar=True, cb_ticks=[0.1, 4, 14], **lambda_kws)
fig.tight_layout(w_pad=10, rect=[0, 0, 1, 1])



In [25]:
import matplotlib.colors as colors  # needed for PowerNorm below

# --- Transition matrices of the actual / coherent / poisson RUN models ---
fig, axes = plt.subplots(1,3, figsize=(15, 12))
# Order units by their summed rate in the actual RUN model
# (previously this was taken from the PBE models in `hmms`).
lambda_order = np.argsort(run_hmms[0].hmm.means.sum(axis=0))
plot_transmat(axes[0], hmm=run_hmms[0], title=run_hmms[0].label, cbar=False, **fig_kws)
plot_transmat(axes[1], hmm=run_hmms[1], title=run_hmms[1].label, cbar=False, ylabel=False, **fig_kws)
plot_transmat(axes[2], hmm=run_hmms[2], title=run_hmms[2].label, cbar=True, ylabel=False, **fig_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# --- Transition graphs of the same three models ---
fig, axes = plt.subplots(1,3, figsize=(15, 12))
plot_sun_graph(axes[0], hmm=run_hmms[0], nc=npl.colors.sweet.green, **fig_kws)
plot_sun_graph(axes[1], hmm=run_hmms[1], nc=npl.colors.sweet.red, **fig_kws)
plot_sun_graph(axes[2], hmm=run_hmms[2], nc=npl.colors.sweet.red, **fig_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# --- Rate matrices, all sharing the same unit order ---
fig, axes = plt.subplots(1,3, figsize=(15,12))
lambda_kws = dict(text_kws=text_kws, cmap=plt.cm.seismic, norm=colors.PowerNorm(0.5))
plot_lambda(axes[0], hmm=run_hmms[0], title=run_hmms[0].label, cbar=False, lo=lambda_order, **lambda_kws)
plot_lambda(axes[1], hmm=run_hmms[1], title=run_hmms[1].label, ylabel=False, lo=lambda_order, cbar=False, **lambda_kws)
plot_lambda(axes[2], hmm=run_hmms[2], title=run_hmms[2].label, ylabel=False, lo=lambda_order, cbar=True, cb_ticks=[0.1, 4, 14], **lambda_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# --- Summary distributions; only the first of the three axes is used for now ---
fig, axes = plt.subplots(1,3, figsize=(16,6))
plot_transmat_gini_departure(axes[0], run_hmms, **fig_kws)
# # plot_transmat_gini_arrival(axes[19], hmms, **fig_kws)
#plot_lambda_gini_across_states(axes[1], run_hmms, **fig_kws)
# # plot_gini_lambda(axes[7], [hmms[0], hmms[2], hmms[3]], **fig_kws)
#plot_bottleneck(axes[2], run_hmms, **fig_kws)


Analyze sparsity of PBE model


In [8]:
# Session spike train, then its restriction to the support of the PBEs.
st = exp_data[session_time]['spikes']
pbe_support = aux_data[session_time][segment]['PBEs'].support
mua_spks = st[pbe_support]

random_state = 1
test_size = 0.2
description = (session_time, segment)
verbose = False

# The three surrogates share every setting; only `kind` differs between them.
pbe_surrogate_kws = dict(st=mua_spks,
                         num_states=num_states,
                         ds=ds,
                         test_size=test_size,
                         random_state=random_state,
                         description=description,
                         verbose=verbose)

hmm_actual = HMMSurrogate(kind='actual', **pbe_surrogate_kws)
hmm_coherent = HMMSurrogate(kind='coherent', **pbe_surrogate_kws)
hmm_poisson = HMMSurrogate(kind='poisson', **pbe_surrogate_kws)

hmms = [hmm_actual, hmm_coherent, hmm_poisson]

In [9]:
n_shuffles = 1

# Gini coefficients that are calculated and aggregated after each fit, in order.
gini_kinds = ('tmat',
              'lambda',
              'tmat_arrival',
              'tmat_departure',
              'lambda_across_states',
              'lambda_across_units')

for nn in range(n_shuffles):
    print('starting {}'.format(nn))
    for hmm in hmms:
        print("shuffling", hmm.label)
        hmm.shuffle()
        print("fitting", hmm.label)
        hmm.fit()
        print("scoring", hmm.label)

        for gini_kind in gini_kinds:
            hmm.score_gini(kind=gini_kind)

        # calculate and aggregate bottleneck_ratios
        hmm.score_bottleneck_ratio(n_samples=20000)

    print('completed {}'.format(nn))

# TODO: calculate and aggregate mixing time
# TODO: calculate and aggregate spectrum (or are we only interested in lambda2
#       and the spectral gap?)


starting 0
shuffling actual
fitting actual
scoring actual
shuffling coherent
fitting coherent
scoring coherent
shuffling poisson
fitting poisson
scoring poisson
completed 0

In [10]:
def score_bottleneck_ratio(transmat, n_samples=50000, verbose=False):
    """Estimate the bottleneck ratio of a Markov chain by random subset search.

    For a subset S of states the score is

        Phi(S) = Q(S, S^c) / pi(S),   Q(A, B) = sum_{i in A, j in B} pi_i * P[i, j]

    where ``pi`` is the steady-state distribution returned by
    ``nelpy.analysis.ergodic.steady_state`` and S^c is the complement of S.
    Only subsets with pi(S) <= 0.5 are scored. The function returns the
    smallest score found among ``n_samples`` randomly drawn subsets, i.e. an
    upper bound on the minimum over all subsets.

    Parameters
    ----------
    transmat : array_like, shape (num_states, num_states)
        Transition matrix, indexed as ``P[i, j]``.
    n_samples : int, optional
        Number of random subsets to score.
    verbose : bool, optional
        If True, print every new minimum together with the subset size.

    Returns
    -------
    min_Phi : float
        Smallest Phi(S) observed, or 1 if no sampled subset scored below 1.

    Notes
    -----
    Subsets are drawn from the global ``np.random`` state, with sizes from
    ``np.random.randint(1, num_states - 1)``; numpy therefore raises
    ``ValueError`` for matrices with fewer than three states.
    """
    from nelpy.analysis.ergodic import steady_state

    P = np.asarray(transmat)
    num_states = P.shape[0]
    # Flattened so that a row/column-vector result indexes like a 1-D array.
    pi_ = np.asarray(steady_state(transmat).real).ravel()

    # Q[i, j] = pi_i * P[i, j]. Computed once instead of re-multiplying inside
    # Python loops for every sampled subset.
    Q = pi_[:, np.newaxis] * P

    min_Phi = 1
    for nn in range(n_samples):
        n_samp_in_subset = np.random.randint(1, num_states-1)
        S = np.random.choice(num_states, n_samp_in_subset, replace=False)
        # Shrink (down to a single state) and redraw until pi(S) <= 0.5.
        while pi_[S].sum() > 0.5:
            n_samp_in_subset = max(n_samp_in_subset - 1, 1)
            S = np.random.choice(num_states, n_samp_in_subset, replace=False)

        mass = pi_[S].sum()
        if mass == 0:
            # Phi(S) is undefined (0/0); such a subset can never be the minimum.
            continue

        in_S = np.zeros(num_states, dtype=bool)
        in_S[S] = True
        # Total stationary flow leaving S: rows in S, columns in its complement.
        candidate_Phi = Q[np.ix_(in_S, ~in_S)].sum() / mass

        if candidate_Phi < min_Phi:
            min_Phi = candidate_Phi
            if verbose:
                print("{}: {} (|S| = {})".format(nn, min_Phi, len(S)))
    return min_Phi

import numpy.linalg as LA

def spectral_gap(transmat):
    """Return the gap between the two largest-modulus eigenvalues of ``transmat``.

    Eigenvalues are ranked by absolute value; the result is the real part of
    (largest eigenvalue) - |second-largest eigenvalue|.
    """
    eigenvalues = LA.eigvals(transmat)
    by_modulus = np.argsort(np.abs(eigenvalues))
    leading = eigenvalues[by_modulus[-1]]
    runner_up = eigenvalues[by_modulus[-2]]
    return np.real(leading - np.abs(runner_up))

In [11]:
class ColorBarLocator(object):
    """Axes locator that places an axes directly to the right of a parent axes.

    The returned bounds share the parent's bottom edge and height. ``pad`` and
    ``width`` are converted with the inverse of the parent figure's
    ``transFigure`` before being used as the horizontal gap and bar width.
    """

    def __init__(self, pax, pad=5, width=10):
        self.pax, self.pad, self.width = pax, pad, width

    def __call__(self, ax, renderer):
        """Return ``[left, bottom, width, height]`` for the located axes."""
        left, bottom, parent_w, parent_h = self.pax.get_position().bounds
        to_figure = self.pax.get_figure().transFigure.inverted()
        # Only the x component of each transformed point is needed.
        gap = to_figure.transform([self.pad, 0])[0]
        bar_w = to_figure.transform([self.width, 0])[0]
        return [left+parent_w+gap, bottom, bar_w, parent_h]

def plot_transmat(ax, hmm, edge_threshold=0.0, title='', cbar=True, ylabel=True, **fig_kws):
    """Draw the transition matrix of ``hmm.hmm`` as an image on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    hmm : object
        Wrapper exposing ``hmm.hmm.transmat`` and ``hmm.hmm.n_components``.
    edge_threshold : float, optional
        Entries not greater than this value are drawn as 0.
    title : str, optional
        Prefix of the axes title (``' A'`` is appended).
    cbar : bool, optional
        If True, add a colorbar axes to the right of ``ax``.
    ylabel : bool, optional
        If True, label and tick the y axis; otherwise remove its ticks.
    **fig_kws
        Only ``cmap`` is read (default ``plt.cm.viridis``).
    """
    cmap = fig_kws.get('cmap', plt.cm.viridis)

    num_states = hmm.hmm.n_components
    transmat = hmm.hmm.transmat

    img = ax.matshow(np.where(transmat > edge_threshold, transmat, 0), cmap=cmap, vmin=0, vmax=1, interpolation='none', aspect='equal')
    ax.set_aspect('equal')

    if cbar:
        # NOTE(review): `divider` is unused; the call is kept in case its side
        # effects on `ax` matter for the layout.
        divider = make_axes_locatable(ax)
        # Use the figure that owns `ax` rather than whatever global `fig`
        # happens to exist when the function is called.
        cax = ax.get_figure().add_axes([0,0,0,0], axes_locator=ColorBarLocator(ax))
        cb = plt.colorbar(img, cax=cax)
        cb.set_label('probability', labelpad=-8)
        cb.set_ticks([0,1])
        npl.utils.no_ticks(cax)

    if ylabel:
        ax.set_yticks([0.5, num_states-1.5])
        ax.set_yticklabels(['1', str(num_states)])
        ax.set_ylabel('state $i$', labelpad=-16)
    else:
        ax.set_yticks([])
        ax.set_yticklabels('')

    ax.set_xticks([0.5, num_states-1.5])
    ax.set_xticklabels(['1', str(num_states)])
    ax.set_xbound(lower=0.0, upper=num_states-1)
    ax.set_ybound(lower=0.0, upper=num_states-1)

    ax.set_xlabel('state $j$', labelpad=-16)

    ax.set_title(title + ' A')
    sns.despine(ax=ax)
    
def plot_lambda(ax, hmm, cbar=True, ylabel=True, title='', lo=None, **fig_kws):
    """Draw the transposed ``hmm.hmm.means`` matrix as an image on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    hmm : object
        Wrapper exposing ``hmm.hmm.means``, ``hmm.hmm.n_components`` and
        ``hmm.hmm.n_features``.
    cbar : bool, optional
        If True, append a colorbar axes to the right of ``ax``.
    ylabel : bool, optional
        If True, label and tick the y axis; otherwise remove its ticks.
    title : str, optional
        Prefix of the axes title (a Lambda symbol is appended).
    lo : index array, optional
        If given, the columns of ``means`` are reordered as ``means[:, lo]``
        before plotting.
    **fig_kws
        ``cmap`` (default ``plt.cm.viridis``), ``norm`` (default ``LogNorm()``)
        and ``cb_ticks`` (colorbar tick positions; left to matplotlib's
        defaults when omitted) are read.
    """
    import matplotlib.colors as colors

    cmap = fig_kws.get('cmap', plt.cm.viridis)
    norm = fig_kws.get('norm', colors.LogNorm())
    cb_ticks = fig_kws.get('cb_ticks')

    num_states = hmm.hmm.n_components
    num_units = hmm.hmm.n_features

    ax.set_aspect(num_states/num_units)

    means = hmm.hmm.means if lo is None else hmm.hmm.means[:,lo]
    img = ax.matshow(means.T, cmap=cmap, norm=norm, interpolation='none', aspect='auto')

    if cbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size=0.1, pad=0.1)
        cb = plt.colorbar(img, cax=cax)
        # Only override the ticks when the caller supplied some; passing None
        # to set_ticks is not a valid tick specification.
        if cb_ticks is not None:
            cb.set_ticks(cb_ticks)
        npl.utils.no_ticks(cax)

    if ylabel:
        ax.set_yticks([0.5, num_units-1.5])
        ax.set_yticklabels(['1', str(num_units)])
        ax.set_ylabel('unit', labelpad=-16)
    else:
        ax.set_yticks([])
        ax.set_yticklabels('')

    ax.set_xticks([0.5, num_states-1.5])
    ax.set_xticklabels(['1', str(num_states)])

    ax.set_ybound(lower=0.0, upper=num_units-1)
    ax.set_xbound(lower=0.0, upper=num_states-1)

    ax.set_xlabel('state', labelpad=-16)
    # Raw string: '\L' is not a valid escape sequence in a normal literal.
    ax.set_title(title + r' $\Lambda$')
    sns.despine(ax=ax)
    
def plot_sun_graph(ax, hmm, edge_threshold=0.0, lw=2, ec='k', nc='k', node_size=3, **fig_kws):
    """Draw the inner/outer transition graphs of ``hmm.hmm.transmat`` on ``ax``.

    Both graphs are built with ``npx`` from the transition matrix and drawn with
    the given edge threshold, line width (``lw``), edge colour (``ec``) and node
    size; ``nc`` is passed to the outer graph only, which is drawn with twice the
    node size. The title shows ``spectral_gap`` of the transition matrix.
    ``**fig_kws`` is accepted for call compatibility but not used.
    """
    plt.sca(ax)

    transmat = hmm.hmm.transmat
    Gi = npx.inner_graph_from_transmat(transmat)
    Go = npx.outer_graph_from_transmat(transmat)

    npx.draw_transmat_graph_inner(Gi, edge_threshold=edge_threshold, lw=lw, ec=ec, node_size=node_size)
    npx.draw_transmat_graph_outer(Go, Gi, edge_threshold=edge_threshold, lw=lw, ec=ec, nc=nc, node_size=node_size*2)

    ax.set_xlim(-1.4,1.4)
    ax.set_ylim(-1.4,1.4)
    npl.utils.clear_left_right(ax)
    npl.utils.clear_top_bottom(ax)

    # Raw string: '\g' is not a valid escape sequence in a normal literal.
    ax.set_title(r'$\gamma^*=$ {0:.3f}'.format(float(spectral_gap(hmm.hmm.transmat))), y=1.02)

    ax.set_aspect('equal')
    
def plot_connectivity_graph(ax, hmm, edge_threshold=0.0, lw=2, ec='k', node_size=3, **fig_kws):
    """Draw the graph built from ``hmm.hmm.transmat`` on ``ax``.

    The graph is created and drawn with ``npx``; the axes are limited to
    [-1, 1] in both directions, cleared of decorations and given an equal
    aspect ratio. ``**fig_kws`` is accepted for call compatibility but not used.
    """
    plt.sca(ax)

    G = npx.graph_from_transmat(hmm.hmm.transmat)

    npx.draw_transmat_graph(G, edge_threshold=edge_threshold, lw=lw, ec=ec, node_size=node_size)
    ax.set_xlim(-1,1)
    ax.set_ylim(-1,1)
    npl.utils.clear_left_right(ax)
    npl.utils.clear_top_bottom(ax)
    # Was `ax1.set_aspect(...)`: `ax1` is not defined in this function.
    ax.set_aspect('equal')
    
def plot_transmat_gini_departure(ax, hmms, n_max=500, **fig_kws):
    """Plot, per model, the density of its averaged 'gini_tmat_departure' results.

    For every entry of ``hmms`` the stored results are truncated to the first
    ``n_max`` rows and averaged over the first axis; the averages are passed to
    ``sns.distplot`` (density only, no histogram) on ``ax``.

    NOTE(review): the title hard-codes 'N=250' regardless of how many rows
    the results actually hold.
    """
    distplot_kws = dict(hist=False, hist_kws={"range": (0.5, 1)}, bins=50, ax=ax)

    with sns.color_palette("Blues_d", 8):
        for surrogate in hmms:
            ginis = np.array(surrogate.results['gini_tmat_departure'])[:n_max,:]
            column_means = ginis.sum(axis=0)/len(ginis)
            sns.distplot(column_means, label=surrogate.label, **distplot_kws)

    ax.set_title('tmat gini departure, N=250')
    ax.set_xlim(0.6, 0.9)
    sns.despine(ax=ax)
    
def plot_transmat_gini_arrival(ax, hmms, n_max=500, **fig_kws):
    """Plot, per model, the density of its averaged 'gini_tmat_arrival' results.

    For every entry of ``hmms`` the stored results are truncated to the first
    ``n_max`` rows and averaged over the first axis; the averages are passed to
    ``sns.distplot`` (density only, no histogram) on ``ax``.

    NOTE(review): the title hard-codes 'N=250' regardless of how many rows
    the results actually hold.
    """
    distplot_kws = dict(hist=False, hist_kws={"range": (0.8, 1)}, bins=50, ax=ax)

    with sns.color_palette("Blues_d", 8):
        for surrogate in hmms:
            ginis = np.array(surrogate.results['gini_tmat_arrival'])[:n_max,:]
            column_means = ginis.sum(axis=0)/len(ginis)
            sns.distplot(column_means, label=surrogate.label, **distplot_kws)

    ax.set_title('tmat gini arrival, N=250')
    ax.legend('')
    ax.set_xlim(0.7, 1)
    sns.despine(ax=ax)
    
def plot_bottleneck(ax, hmms, n_max=500, **fig_kws):
    """Plot, per model, the density of its stored 'bottleneck' results.

    For every entry of ``hmms`` the first ``n_max`` results are passed to
    ``sns.distplot`` (density only, no histogram) on ``ax``.

    NOTE(review): the title hard-codes 'N=250' regardless of how many
    results are actually stored.
    """
    distplot_kws = dict(hist=False, hist_kws={"range": (0, 0.5)}, bins=50, ax=ax)

    for surrogate in hmms:
        ratios = np.array(surrogate.results['bottleneck'])[:n_max]
        sns.distplot(ratios, label=surrogate.label, **distplot_kws)

    ax.set_title('bottleneck, N=250')
    ax.legend('')
    ax.set_xlim(0, 0.5)
    sns.despine(ax=ax)
    
def plot_gini_lambda(ax, hmms, n_max=500, **fig_kws):
    """Plot, per model, the density of its stored 'gini_lambda' results.

    For every entry of ``hmms`` the first ``n_max`` results are passed to
    ``sns.distplot`` (density only, no histogram) on ``ax``.

    NOTE(review): the title hard-codes 'N=250' regardless of how many
    results are actually stored.
    """
    distplot_kws = dict(hist=False, hist_kws={"range": (0.7, 0.9)}, bins=50, ax=ax)

    for surrogate in hmms:
        ginis = np.array(surrogate.results['gini_lambda'])[:n_max]
        sns.distplot(ginis, label=surrogate.label, **distplot_kws)

    ax.set_title('lambda gini, N=250')
    ax.legend('')
    ax.set_xlim(0.7, 0.9)
    sns.despine(ax=ax)
    
def plot_lambda_gini_across_states(ax, hmms, n_max=5000, **fig_kws):
    """Plot, per model, the KDE of its averaged 'gini_lambda_across_states' results.

    For every entry of ``hmms`` the stored results are truncated to the first
    ``n_max`` rows and averaged over the first axis; the averages are passed to
    ``sns.distplot`` (KDE with ``bw=0.05``, no histogram) on ``ax``.

    NOTE(review): the title hard-codes 'N=250' regardless of how many rows
    the results actually hold.
    """
    distplot_kws = dict(hist_kws={"range": (0.0, 1)}, bins=30, hist=False,
                        kde=True, ax=ax, kde_kws={'bw':0.05})

    for surrogate in hmms:
        ginis = np.array(surrogate.results['gini_lambda_across_states'])[:n_max,:]
        column_means = ginis.sum(axis=0)/len(ginis)
        sns.distplot(column_means, label=surrogate.label, **distplot_kws)

    ax.set_title('lambda gini across states, N=250')
    ax.legend('')
    ax.set_xlim(0., 1)
    sns.despine(ax=ax)

In [12]:
import nelpy.plotting.graph as npx
import matplotlib.colors as colors


# Rebuild the 'Spectral_r' palette with NUM_COLORS entries and activate it.
cm = plt.get_cmap('Spectral_r')
clist = np.roll([cm(1.*i/NUM_COLORS) for i in range(NUM_COLORS)], 0, axis=0)
npl.set_palette(clist)

## define figure parameters and color palette
text_kws = dict(ha="center", size=7)
fig_kws = dict(text_kws=text_kws, cmap=plt.cm.seismic)
lambda_kws = dict(text_kws=text_kws, cmap=plt.cm.seismic, norm=colors.PowerNorm(0.5))

# Unit order shared by all rate-matrix panels: ascending summed rate in hmms[0].
lambda_order = np.argsort(hmms[0].hmm.means.sum(axis=0))

# --- Transition matrices: y labels on the first panel, colorbar on the last ---
fig, axes = plt.subplots(1,3, figsize=(16, 14))
transmat_opts = (dict(cbar=False),
                 dict(cbar=False, ylabel=False),
                 dict(cbar=True, ylabel=False))
for panel, surrogate, opts in zip(axes, hmms, transmat_opts):
    plot_transmat(panel, hmm=surrogate, title=surrogate.label, **opts, **fig_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# --- Transition graphs: green nodes for the first model, red for the others ---
fig, axes = plt.subplots(1,3, figsize=(16, 14))
node_colors = (npl.colors.sweet.green, npl.colors.sweet.red, npl.colors.sweet.red)
for panel, surrogate, node_color in zip(axes, hmms, node_colors):
    plot_sun_graph(panel, hmm=surrogate, nc=node_color, **fig_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# --- Rate matrices: y labels on the first panel, colorbar on the last ---
fig, axes = plt.subplots(1,3, figsize=(16,8))
lambda_opts = (dict(cbar=False),
               dict(ylabel=False, cbar=False),
               dict(ylabel=False, cbar=True, cb_ticks=[0.01,1,4]))
for panel, surrogate, opts in zip(axes, hmms, lambda_opts):
    plot_lambda(panel, hmm=surrogate, title=surrogate.label, lo=lambda_order, **opts, **lambda_kws)
fig.tight_layout(h_pad=.5, w_pad=0.75, rect=[0, .05, 1, 1])

# plot_transmat_gini_departure(axes[3], [hmms[0], hmms[1], hmms[2]], **fig_kws)
# # plot_transmat_gini_arrival(axes[19], hmms, **fig_kws)
# plot_lambda_gini_across_states(axes[7], [hmms[0], hmms[1], hmms[2]], **fig_kws)
# # plot_gini_lambda(axes[7], [hmms[0], hmms[2], hmms[3]], **fig_kws)
# plot_bottleneck(axes[11], [hmms[0], hmms[1], hmms[2]], **fig_kws)