In [119]:
%pylab inline
import seaborn as sns
sns.set_style('white')
sns.set_context('paper', font_scale=2.5)
from pomegranate import HiddenMarkovModel, NormalDistribution, PoissonDistribution
model = HiddenMarkovModel()
In [5]:
import os

# Directory containing binned_spikes.npy. Set the BINNED_SPIKES_DIR environment
# variable to point elsewhere instead of editing an absolute path.
DATA_DIR = os.environ.get('BINNED_SPIKES_DIR', '/home/saket/Downloads')
binned_spikes = np.load(os.path.join(DATA_DIR, 'binned_spikes.npy'))
binned_spikes.shape
Out[5]:
In [13]:
# One row of N Poisson draws for each rate in `lambdas` -> shape (len(lambdas), N).
lambdas = [4, 10]
N = 10
X_T = np.vstack([np.random.poisson(lam, size=N) for lam in lambdas])
X_T.shape
Out[13]:
In [48]:
def generate_X(lambdas=(4, 10), N=10):
    """Simulate one Poisson count sequence containing two single-step "stalls".

    The sequence is ``N`` draws at the baseline rate, one draw at the stall
    rate, ``N // 2`` baseline draws, then one more stall draw, so it has
    ``N + N // 2 + 2`` counts in total. With the defaults this reproduces the
    original hard-coded behaviour exactly (17 counts, stalls at indices 10, 16).

    Parameters
    ----------
    lambdas : sequence of two numbers, default (4, 10)
        ``(baseline_rate, stall_rate)`` Poisson means.
    N : int, default 10
        Length of the first baseline run; the second run has ``N // 2`` steps.

    Returns
    -------
    X_T : list
        The simulated counts.
    Y_T : list of int
        Per-step labels: 1 at the two stall steps, 0 elsewhere.
    """
    base_rate, stall_rate = lambdas[0], lambdas[1]
    half = N // 2  # equals int(N / 2) for the non-negative N used here
    X_T = list(np.random.poisson(base_rate, size=N))
    X_T += list(np.random.poisson(stall_rate, size=1))
    X_T += list(np.random.poisson(base_rate, size=half))
    X_T += list(np.random.poisson(stall_rate, size=1))
    Y_T = [0] * N + [1] + [0] * half + [1]
    return X_T, Y_T
# Assemble the synthetic data set: X gets one row of counts per simulated
# sequence, Y the matching 0/1 labels. The final draw remains in x / y.
n_samples = 10
X, Y = [], []
for n in range(n_samples):
    x, y = generate_X()
    X += [x]
    Y += [y]
X = np.asarray(X)
X.shape
Out[48]:
In [49]:
# X has one row per simulated sequence; plt.plot draws one line per *column*,
# so transpose to get one line per sequence over time.
fig, ax = plt.subplots()
ax.plot(X.T)
ax.set(xlabel='time step', ylabel='count', title='Simulated Poisson sequences');
Out[49]:
In [19]:
from pomegranate import (HiddenMarkovModel, IndependentComponentsDistribution,
                         PoissonDistribution, State)
import numpy as np

# Left-to-right HMM over binned_spikes: each state may only stay put or advance
# to the next state. Each state emits a vector of independent Poisson counts,
# one per entry of binned_spikes' third axis (presumably one per recorded
# unit -- confirm), with random initial rates drawn from [0, 1).
n_states = 1
n_units = binned_spikes.shape[2]
model = HiddenMarkovModel('%i' % n_states)
states = [
    State(
        IndependentComponentsDistribution(
            [PoissonDistribution(np.random.rand()) for unit in range(n_units)]
        ),
        name='State%i' % (i + 1),
    )
    for i in range(n_states)
]
model.add_states(states)

# Always begin in the first state. Zero-probability edges are not added; in a
# graph HMM an absent edge already means "no transition".
if states:
    model.add_transition(model.start, states[0], 1.0)
for i, state in enumerate(states):
    # Self-transition probability drawn uniformly from [0.95, 0.999).
    stay_prob = (0.999 - 0.95) * np.random.random() + 0.95
    model.add_transition(state, state, stay_prob)
    if i + 1 < n_states:
        model.add_transition(state, states[i + 1], 1.0 - stay_prob)
# NOTE(review): the last state has no edge to model.end, so its out-edges sum
# to stay_prob < 1; pomegranate's bake() appears to renormalise such rows --
# confirm for the installed version.
model.bake()
model.dense_transition_matrix()
Out[19]:
In [22]:
# Two-state Poisson HMM specified directly from matrices.
# NOTE(review): `starts` sums to 0.5, and the second state's outgoing
# probabilities (0.3 + 0.7 + end 0.1) sum to 1.1; presumably bake() inside
# from_matrix renormalises these -- confirm this is intended.
trans_mat = np.array([
    [0.7, 0.3],
    [0.3, 0.7],
])
starts = np.array([0.5, 0.0])
ends = np.array([0.0, 0.1])
dists = [PoissonDistribution(rate) for rate in (2, 3)]
model = HiddenMarkovModel.from_matrix(trans_mat, dists, starts, ends)
In [121]:
# Learn a 2-state Poisson HMM from the simulated sequences in X via Baum-Welch.
# NOTE(review): nothing here forces 'stalling' to be the higher-rate component;
# the name-to-rate mapping presumably depends on initialisation -- check the
# fitted distributions before trusting the labels.
model = HiddenMarkovModel.from_samples(
    PoissonDistribution,
    n_components=2,
    X=X,
    algorithm='baum-welch',
    state_names=['normal', 'stalling'],
    verbose=True,
)
model.bake()
In [76]:
# Predicted state labels for the 7th simulated sequence (row 6 of X).
model.predict(X[6])
Out[76]:
In [123]:
def _stem_highlight(ax, values, highlight, base_color='b', highlight_color='r', linewidth=2):
    """Stem-plot ``values`` at x = 1..len(values), recolouring selected stems.

    ``highlight`` holds 0-based positions into ``values`` to draw in
    ``highlight_color``; all other stems, and the markers, use ``base_color``.
    Handles both return types of ``Axes.stem``: a list of Line2D (older
    matplotlib) or a single LineCollection (default since matplotlib 3.3),
    which cannot be indexed per stem.

    Returns the ``(markerline, stemlines, baseline)`` artists.
    """
    markerline, stemlines, baseline = ax.stem(range(1, len(values) + 1), values, linefmt='-')
    highlight = {int(i) for i in highlight}
    colors = [highlight_color if i in highlight else base_color for i in range(len(values))]
    if isinstance(stemlines, list):
        # Older matplotlib: one Line2D per stem.
        for line, color in zip(stemlines, colors):
            line.set_color(color)
            line.set_linewidth(linewidth)
    else:
        # Newer matplotlib: one LineCollection; colour all segments at once.
        stemlines.set_color(colors)
        stemlines.set_linewidth(linewidth)
    markerline.set_color(base_color)
    markerline.set_linewidth(linewidth)
    # Baseline drawn white, presumably to hide it against the white background.
    baseline.set_color('w')
    baseline.set_linewidth(linewidth)
    return markerline, stemlines, baseline


def _plot_X(index, ax, stall_positions=None):
    """Stem-plot the simulated counts ``X[index]`` on ``ax``.

    Stall steps are drawn red and all other steps blue. By default the stall
    steps are taken from the matching labels ``Y[index]``, rather than the
    fixed positions 10 and 16 that only hold for ``generate_X``'s default N.

    Parameters
    ----------
    index : int
        Row of the notebook-level ``X`` / ``Y`` to draw.
    ax : matplotlib Axes
        Target axes.
    stall_positions : iterable of int, optional
        0-based time steps to highlight; overrides ``Y[index]``.
    """
    counts = X[index, :]
    if stall_positions is None:
        stall_positions = np.flatnonzero(Y[index])
    _stem_highlight(ax, counts, stall_positions)


def _mark_states(index, ax):
    """Stem-plot the HMM's predicted state labels for ``X[index]`` on ``ax``.

    Steps predicted as state 1 are drawn red and others blue. The x-axis
    length comes from the prediction itself, not from ``X[0]``.
    NOTE(review): assumes state index 1 is the 'stalling' state -- confirm
    against the fitted model.
    """
    prediction = np.asarray(model.predict(X[index, :]))
    _stem_highlight(ax, prediction, np.flatnonzero(prediction == 1))
In [130]:
# Top row: simulated counts per sample; bottom row: HMM state predictions.
# One column per sample (the original copy-pasted S3 panels re-plotted sample 1).
n_panels = 3
fig, axes = plt.subplots(2, n_panels, figsize=(15, 5))
for index in range(n_panels):
    label = 'S%i' % (index + 1)
    counts_ax, pred_ax = axes[0, index], axes[1, index]
    counts_ax.set_title(label + '-counts')
    _plot_X(index, counts_ax)
    pred_ax.set_title(label + '-prediction')
    _mark_states(index, pred_ax)
fig.tight_layout()
fig.savefig('./poisson_hmm.pdf')
In [97]:
# Last simulated sequence, read from X instead of the leaked loop variable `x`.
X[-1]
Out[97]:
In [23]:
# HiddenMarkovModel.fit takes a collection of sequences. X_T[0] is one 1-D
# sequence, so wrap it in a list; passed bare, it would be iterated element by
# element, with each count treated as its own sequence.
model.fit([X_T[0]], algorithm='baum-welch', min_iterations=10, max_iterations=10, verbose=True)
Out[23]:
In [33]:
# One line per emitting (non-silent) state: "<name>: <distribution>".
print("\n".join(
    "{}: {}".format(state.name, state.distribution)
    for state in model.states
    if not state.is_silent()
))
In [ ]: