In [119]:
%pylab inline
import seaborn as sns
sns.set_style('white')
sns.set_context('paper', font_scale=2.5)
from pomegranate import HiddenMarkovModel, NormalDistribution, PoissonDistribution
model = HiddenMarkovModel()


Populating the interactive namespace from numpy and matplotlib

In [5]:
# Load pre-binned spike counts.  Out[5] shows shape (10, 250, 27); a later
# cell uses binned_spikes.shape[2] as the number of units, so presumably the
# axes are (trials, time_bins, units) — TODO confirm with the data producer.
# NOTE(review): hardcoded absolute path breaks portability — prefer a
# configurable DATA_DIR at the top of the notebook.
binned_spikes = np.load('/home/saket/Downloads/binned_spikes.npy')
binned_spikes.shape


Out[5]:
(10, 250, 27)

In [13]:
# Two candidate firing rates; draw N Poisson counts from each rate.
lambdas = [4, 10]
N = 10
draws = [np.random.poisson(lam, size=N) for lam in lambdas]
X_T = np.array(draws)
X_T.shape


Out[13]:
(2, 10)

In [48]:
#lambdas = [4, 10]
#N = 10

def generate_X():
    lambdas = [4, 10]
    N = 10

    X_T = list(np.random.poisson(lambdas[0], size=N))
    X_T += list(np.random.poisson(lambdas[1], size=1))
    X_T += list(np.random.poisson(lambdas[0], size=int(N/2)))
    X_T += list(np.random.poisson(lambdas[1], size=1))
    Y_T = [0]*N + [1] + [0]*int(N/2) + [1]  
    return X_T, Y_T
    #X_T = np.array(X_T).reshape(len(X_T),1)
#X_T.shape
# Build a dataset of n_samples independent synthetic trials.
# X becomes (n_samples, 17); Y holds the matching 0/1 state labels.
n_samples = 10
X, Y = [], []
for n in range(n_samples):
    x, y = generate_X()
    X.append(x)
    Y.append(y)
X = np.array(X)
X.shape


Out[48]:
(10, 17)

In [49]:
plt.plot(X)


Out[49]:
[<matplotlib.lines.Line2D at 0x7fd9054056a0>,
 <matplotlib.lines.Line2D at 0x7fd9050c4128>,
 <matplotlib.lines.Line2D at 0x7fd9050c43c8>,
 <matplotlib.lines.Line2D at 0x7fd9050c4710>,
 <matplotlib.lines.Line2D at 0x7fd9050c4908>,
 <matplotlib.lines.Line2D at 0x7fd9050c4cc0>,
 <matplotlib.lines.Line2D at 0x7fd90d53b198>,
 <matplotlib.lines.Line2D at 0x7fd90d54f3c8>,
 <matplotlib.lines.Line2D at 0x7fd90540e7f0>,
 <matplotlib.lines.Line2D at 0x7fd9050c9080>,
 <matplotlib.lines.Line2D at 0x7fd9050b05f8>,
 <matplotlib.lines.Line2D at 0x7fd9050c97b8>,
 <matplotlib.lines.Line2D at 0x7fd9050c9b38>,
 <matplotlib.lines.Line2D at 0x7fd9050c9e80>,
 <matplotlib.lines.Line2D at 0x7fd90710ab38>,
 <matplotlib.lines.Line2D at 0x7fd909969748>,
 <matplotlib.lines.Line2D at 0x7fd9050cb080>]

In [19]:
# Build a left-to-right HMM over the binned spike data.
# FIX: replaced `from pomegranate import *` (star import hides where names
# come from) with explicit imports; tabs normalised to 4-space indents.
from pomegranate import (HiddenMarkovModel, State,
                         IndependentComponentsDistribution, PoissonDistribution)
import numpy as np

model = HiddenMarkovModel('%i' % 1)
states = []
n_states = 1

# One state per hidden phase; each recorded unit gets an independent
# Poisson emission whose rate is initialised uniformly at random.
for i in range(n_states):
    states.append(State(
        IndependentComponentsDistribution(
            [PoissonDistribution(np.random.rand())
             for unit in range(binned_spikes.shape[2])]),
        name='State%i' % (i + 1)))

model.add_states(states)

# All starting probability mass goes to the first state.
for i in range(n_states):
    if i == 0:
        model.add_transition(model.start, states[i], 1.0)
    else:
        model.add_transition(model.start, states[i], 0.0)

# Left-to-right transitions: stay with probability ~0.95-0.999, otherwise
# advance to the next state.
# NOTE(review): the last state has no outgoing "advance" edge, so its
# leftover (1 - p) mass is presumably renormalised at bake() — confirm
# against pomegranate's bake() semantics.
for i in range(n_states):
    not_transitioning_prob = (0.999 - 0.95) * np.random.random() + 0.95
    for j in range(n_states):
        if i - j == 0:
            model.add_transition(states[i], states[j], not_transitioning_prob)
        elif j - i == 1:
            model.add_transition(states[i], states[j], 1.0 - not_transitioning_prob)
        else:
            model.add_transition(states[i], states[j], 0.0)

model.bake()


model.dense_transition_matrix()


Out[19]:
array([[1., 0., 0.],
       [1., 0., 0.],
       [0., 0., 0.]])

In [22]:
# Hand-specified 2-state Poisson HMM: symmetric 0.7/0.3 transitions,
# all start mass in state 0, a 0.1 exit probability from state 1.
trans_mat = numpy.array([
    [0.7, 0.3],
    [0.3, 0.7],
])
starts = numpy.array([0.5, 0.0])
ends = numpy.array([0.0, 0.1])
dists = [PoissonDistribution(2), PoissonDistribution(3)]
model = HiddenMarkovModel.from_matrix(trans_mat, dists, starts, ends)

In [121]:
# Fit a 2-state Poisson HMM to the synthetic trials with Baum-Welch.
# NOTE(review): from_samples() presumably returns an already-baked model,
# making the explicit bake() below redundant — confirm for the pomegranate
# version in use before removing it.
model = HiddenMarkovModel.from_samples(
    PoissonDistribution,
    n_components=2,
    X=X,
    algorithm='baum-welch',
    state_names=['normal', 'stalling'],
    verbose=True,
)
model.bake()


[1] Improvement: 26.663072218839204	Time (s): 0.0002174
[2] Improvement: 1.0303674970031125	Time (s): 0.00052
[3] Improvement: 0.6115659918766596	Time (s): 0.0003312
[4] Improvement: 0.44279456991080224	Time (s): 0.0003188
[5] Improvement: 0.34550262243874386	Time (s): 0.000345
[6] Improvement: 0.27821388423990356	Time (s): 0.0003171
[7] Improvement: 0.22557641760477054	Time (s): 0.0003307
[8] Improvement: 0.1820842809208898	Time (s): 0.0003078
[9] Improvement: 0.1458413473354767	Time (s): 0.0003068
[10] Improvement: 0.11595710218421118	Time (s): 0.0003057
[11] Improvement: 0.09167178510375606	Time (s): 0.0003288
[12] Improvement: 0.07218594497607	Time (s): 0.0003092
[13] Improvement: 0.05669975193455912	Time (s): 0.0003049
[14] Improvement: 0.04447384685738598	Time (s): 0.0003052
[15] Improvement: 0.03486418377354994	Time (s): 0.000303
[16] Improvement: 0.027331563368932166	Time (s): 0.0003171
[17] Improvement: 0.021436118841620555	Time (s): 0.0003107
[18] Improvement: 0.016825082163961724	Time (s): 0.0003126
[19] Improvement: 0.013218728403330715	Time (s): 0.0003092
[20] Improvement: 0.010396927708484327	Time (s): 0.0003071
[21] Improvement: 0.00818731489999891	Time (s): 0.0003078
[22] Improvement: 0.006455354270144653	Time (s): 0.0003054
[23] Improvement: 0.005096230615151853	Time (s): 0.0003023
[24] Improvement: 0.004028353329545098	Time (s): 0.0003016
[25] Improvement: 0.003188218150285138	Time (s): 0.0003495
[26] Improvement: 0.0025263777300210677	Time (s): 0.0003223
[27] Improvement: 0.002004299921054553	Time (s): 0.0003152
[28] Improvement: 0.0015919271261850554	Time (s): 0.0003514
[29] Improvement: 0.0012657842037242517	Time (s): 0.0003054
[30] Improvement: 0.0010075129743540856	Time (s): 0.000443
[31] Improvement: 0.0008027372375636332	Time (s): 0.0004044
[32] Improvement: 0.0006401833223890208	Time (s): 0.0003841
[33] Improvement: 0.0005109980520501267	Time (s): 0.0003839
[34] Improvement: 0.00040821922294753676	Time (s): 0.0003867
[35] Improvement: 0.00032636399254215576	Time (s): 0.0003934
[36] Improvement: 0.00026110849512406276	Time (s): 0.000385
[37] Improvement: 0.00020903814157691158	Time (s): 0.0003877
[38] Improvement: 0.00016745273802598604	Time (s): 0.0003581
[39] Improvement: 0.0001342141743521097	Time (s): 0.0004599
[40] Improvement: 0.00010762722882873277	Time (s): 0.0003579
[41] Improvement: 8.634613959657145e-05	Time (s): 0.0003555
[42] Improvement: 6.930128103022071e-05	Time (s): 0.0003777
[43] Improvement: 5.564152672832279e-05	Time (s): 0.0003762
[44] Improvement: 4.468887692610224e-05	Time (s): 0.000361
[45] Improvement: 3.5902688978239894e-05	Time (s): 0.0003579
[46] Improvement: 2.8851439139998547e-05	Time (s): 0.0003593
[47] Improvement: 2.319038890163938e-05	Time (s): 0.0003903
[48] Improvement: 1.8643908106241724e-05	Time (s): 0.0003572
[49] Improvement: 1.4991451791956933e-05	Time (s): 0.0003779
[50] Improvement: 1.205643172852433e-05	Time (s): 0.0003765
[51] Improvement: 9.697366124328255e-06	Time (s): 0.0003576
[52] Improvement: 7.800836499427533e-06	Time (s): 0.0003572
[53] Improvement: 6.275874341099552e-06	Time (s): 0.0003579
[54] Improvement: 5.049485366726003e-06	Time (s): 0.0003564
[55] Improvement: 4.06307214007029e-06	Time (s): 0.0003595
[56] Improvement: 3.2695798495296913e-06	Time (s): 0.0003805
[57] Improvement: 2.631209099490661e-06	Time (s): 0.0003855
[58] Improvement: 2.1175868596401415e-06	Time (s): 0.0003777
[59] Improvement: 1.7043014963746828e-06	Time (s): 0.0003922
[60] Improvement: 1.3717292972614814e-06	Time (s): 0.0003757
[61] Improvement: 1.1040904723813583e-06	Time (s): 0.0003517
[62] Improvement: 8.886963200893661e-07	Time (s): 0.0003762
[63] Improvement: 7.153404339987901e-07	Time (s): 0.0003717
[64] Improvement: 5.758126917498885e-07	Time (s): 0.0003726
[65] Improvement: 4.6350817228812957e-07	Time (s): 0.0003126
[66] Improvement: 3.731129254447296e-07	Time (s): 0.0005782
[67] Improvement: 3.003508481924655e-07	Time (s): 0.000371
[68] Improvement: 2.417812652311113e-07	Time (s): 0.0003901
[69] Improvement: 1.9463459466351196e-07	Time (s): 0.0003741
[70] Improvement: 1.5668274500058033e-07	Time (s): 0.0004246
[71] Improvement: 1.2613236322067678e-07	Time (s): 0.0003846
[72] Improvement: 1.0153871699003503e-07	Time (s): 0.0003865
[73] Improvement: 8.174112053893623e-08	Time (s): 0.000381
[74] Improvement: 6.580427225344465e-08	Time (s): 0.0003839
[75] Improvement: 5.2974201025790535e-08	Time (s): 0.0003769
[76] Improvement: 4.264597919245716e-08	Time (s): 0.0002744
[77] Improvement: 3.4331549159105634e-08	Time (s): 0.0003893
[78] Improvement: 2.763817974482663e-08	Time (s): 0.0003746
[79] Improvement: 2.2249821540754056e-08	Time (s): 0.0003777
[80] Improvement: 1.7912043404066935e-08	Time (s): 0.0003772
[81] Improvement: 1.4419867966353195e-08	Time (s): 0.0003574
[82] Improvement: 1.1608676686591934e-08	Time (s): 0.0003762
[83] Improvement: 9.34522859097342e-09	Time (s): 0.0004036
[84] Improvement: 7.523397016484523e-09	Time (s): 0.0003808
[85] Improvement: 6.056836809875676e-09	Time (s): 0.0003934
[86] Improvement: 4.87591478304239e-09	Time (s): 0.000376
[87] Improvement: 3.925379132851958e-09	Time (s): 0.0003545
[88] Improvement: 3.1600961847289e-09	Time (s): 0.0003591
[89] Improvement: 2.543913524277741e-09	Time (s): 0.0003557
[90] Improvement: 2.0481820683926344e-09	Time (s): 0.0004294
[91] Improvement: 1.648572833801154e-09	Time (s): 0.0003331
[92] Improvement: 1.3274075172375888e-09	Time (s): 0.0003099
[93] Improvement: 1.0687131180020515e-09	Time (s): 0.0003028
[94] Improvement: 8.602683010394685e-10	Time (s): 0.0002935
Total Training Improvement: 30.469438433876462
Total Training Time (s): 0.0377

In [76]:
model.predict(X[6,:])


Out[76]:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]

In [123]:
def _plot_X(index, ax, highlight=(10, 16)):
    """Stem-plot trial `index` of the global X, colouring `highlight` bins red.

    Parameters
    ----------
    index : int
        Row of the global X array to plot.
    ax : matplotlib Axes
        Target axes.
    highlight : iterable of int, default (10, 16)
        0-based bin indices drawn in red.  Defaults match the generator,
        whose high-rate insertions land at bins 10 and 16 — previously
        these were hard-coded magic numbers.

    NOTE(review): indexing `stemlines[i]` assumes an older matplotlib where
    stem() returns a list of Line2D; newer versions return a LineCollection —
    confirm the pinned matplotlib version.
    """
    counts = X[index, :]
    markerline, stemlines, baseline = ax.stem(range(1, len(counts) + 1), counts, '-')
    _ = plt.setp(stemlines, color='b', linewidth=2)
    for h in highlight:
        _ = plt.setp(stemlines[h], color='r', linewidth=2)
    _ = plt.setp(markerline, color='b', linewidth=2)
    # White baseline hides it against the white seaborn background.
    _ = plt.setp(baseline, color='w', linewidth=2)
    
def _mark_states(index, ax):
    """Stem-plot the HMM's decoded state sequence for trial `index` of the
    global X, colouring predicted state-1 ("stalling") bins red.

    Uses the globals `model` and `X` defined in earlier cells.
    """
    prediction = np.array(model.predict(X[index, :]))
    # FIX: range was sized by len(X[0,:]) — use the prediction's own length
    # so the plot stays correct even if rows ever differ in length.
    markerline, stemlines, baseline = ax.stem(range(1, len(prediction) + 1), prediction, '-')
    _ = plt.setp(stemlines, color='b', linewidth=2)
    for x in np.where(prediction == 1)[0]:
        _ = plt.setp(stemlines[x], color='r', linewidth=2)
    _ = plt.setp(markerline, color='b', linewidth=2)
    # White baseline hides it against the white seaborn background.
    _ = plt.setp(baseline, color='w', linewidth=2)

In [130]:
# Compare raw counts (top row) against HMM state predictions (bottom row)
# for three trials.
# FIX: the "S3" panels previously re-plotted trial 1 (copy-paste bug:
# _plot_X(1, ax) / _mark_states(1, ax)); they now show trial 2.
# The triplicated panel code is collapsed into one loop.
fig = plt.figure(figsize=(15, 5))
for col, trial in enumerate(range(3)):
    ax = plt.subplot(2, 3, col + 1)
    ax.set_title('S%d-counts' % (trial + 1))
    _plot_X(trial, ax)
    ax = plt.subplot(2, 3, col + 4)
    ax.set_title('S%d-prediction' % (trial + 1))
    _mark_states(trial, ax)
fig.tight_layout()
fig.savefig('./poisson_hmm.pdf')



In [97]:
x


Out[97]:
array([ 6, 10, 12, 16])

In [23]:
model.fit(X_T[0], algorithm = 'baum-welch', min_iterations = 10, max_iterations = 10, verbose = True)


[1] Improvement: 28.630255148194124	Time (s): 0.0001838
[2] Improvement: 1.849677236425869	Time (s): 0.0004764
[3] Improvement: 0.9042281453633834	Time (s): 0.0003605
[4] Improvement: 0.29265315082331256	Time (s): 0.000571
[5] Improvement: 0.09268132273710705	Time (s): 0.0001519
[6] Improvement: 0.02668330529071028	Time (s): 0.0001123
[7] Improvement: 0.007185721437707571	Time (s): 0.0001051
[8] Improvement: 0.0018950440458738171	Time (s): 0.0008717
[9] Improvement: 0.0004969385193192011	Time (s): 0.0003617
[10] Improvement: 0.00013011720595557108	Time (s): 0.0003211
Total Training Improvement: 31.805886130043362
Total Training Time (s): 0.0051
Out[23]:
{
    "class" : "HiddenMarkovModel",
    "name" : "None",
    "start" : {
        "class" : "State",
        "distribution" : null,
        "name" : "None-start",
        "weight" : 1.0
    },
    "end" : {
        "class" : "State",
        "distribution" : null,
        "name" : "None-end",
        "weight" : 1.0
    },
    "states" : [
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "PoissonDistribution",
                "parameters" : [
                    3.443460502751991
                ],
                "frozen" : false
            },
            "name" : "s0",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : {
                "class" : "Distribution",
                "name" : "PoissonDistribution",
                "parameters" : [
                    7.931120040335618
                ],
                "frozen" : false
            },
            "name" : "s1",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : null,
            "name" : "None-start",
            "weight" : 1.0
        },
        {
            "class" : "State",
            "distribution" : null,
            "name" : "None-end",
            "weight" : 1.0
        }
    ],
    "end_index" : 3,
    "start_index" : 2,
    "silent_index" : 2,
    "edges" : [
        [
            2,
            0,
            1.0,
            1.0,
            null
        ],
        [
            0,
            0,
            0.7988590418598281,
            0.7,
            null
        ],
        [
            0,
            1,
            0.2011409581401719,
            0.3,
            null
        ],
        [
            1,
            0,
            1.6258879978290453e-06,
            0.3,
            null
        ],
        [
            1,
            1,
            0.8011248537035781,
            0.7,
            null
        ],
        [
            1,
            3,
            0.198873520408424,
            0.1,
            null
        ]
    ],
    "distribution ties" : []
}

In [33]:
print("\n".join( "{}: {}".format( state.name, state.distribution ) 	for state in model.states if not state.is_silent() ))


normal: {
    "class" :"Distribution",
    "name" :"PoissonDistribution",
    "parameters" :[
        3.0000000000000764
    ],
    "frozen" :false
}
stalling: {
    "class" :"Distribution",
    "name" :"PoissonDistribution",
    "parameters" :[
        8.399999999495781
    ],
    "frozen" :false
}

In [ ]: