In [1]:
#some boring bookkeeping

%matplotlib notebook
import matplotlib.pyplot as plt

#matplotlib.rcParams['figure.figsize'] = [10,4]

In [2]:
#from nhmm import *
from pomegranate import *
import pomegranate
import numpy as np
import scipy.signal
#from scipy.preprocessing import scale
from sklearn.preprocessing import scale

from tqdm import tqdm

# print(pomegranate.utils.is_gpu_enabled())
# pomegranate.utils.enable_gpu()
# print(pomegranate.utils.is_gpu_enabled())

import wfdb
import numpy as np
import scipy.signal

In [25]:
def load_lead(record_name, channel=0, decimate_factor=1):
    """Read one WFDB record and return ``(lead, signals, fields)``.

    ``signals`` and ``fields`` are exactly what ``wfdb.rdsamp`` returns for
    ``record_name``; ``lead`` is column ``channel`` of ``signals``. When
    ``decimate_factor > 1`` the lead is downsampled with
    ``scipy.signal.decimate(..., ftype='fir')``; the default of 1 leaves it
    untouched. ``fields`` is NOT adjusted for decimation.
    """
    record_signals, record_fields = wfdb.rdsamp(record_name)
    lead = record_signals[:, channel]
    if decimate_factor > 1:
        lead = scipy.signal.decimate(lead, decimate_factor, ftype='fir')
    return lead, record_signals, record_fields


# org* = raw first channel, s* = its first difference (one sample shorter).
# NOTE: the suffixes do not follow record order:
#   0 -> record 100, 2 -> record 102, 1 -> record 234.
org0, signals, fields = load_lead('data/mitdb/100')
s0 = np.diff(org0)

org2, signals, fields = load_lead('data/mitdb/102')
s2 = np.diff(org2)

org1, signals, fields = load_lead('data/mitdb/234')
s1 = np.diff(org1)

# As before, `signals`, `fields` and `signal` are left holding record 234.
signal = org1

In [26]:
# Raw first channel and its first difference for records 100 and 102.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, sharex=True)

ax1.plot(s0)
ax1.set_title('record 100: first difference (s0)')
ax2.plot(org0)
ax2.set_title('record 100: raw lead (org0)')

ax3.plot(s2)
ax3.set_title('record 102: first difference (s2)')
ax4.plot(org2)
ax4.set_title('record 102: raw lead (org2)')
ax4.set_xlabel('sample index');


Out[26]:
[<matplotlib.lines.Line2D at 0x7f2b3bd2d8d0>]

In [5]:
from pomegranate import *

# Ring-shaped HMM with one Gaussian state per position:
#   * state i starts as N(s0[MEAN_OFFSET + i // 2], sigma_i) with a random
#     sigma_i in [0, 1e-3) (floored by min_std), so each consecutive pair of
#     states shares an initial mean taken from s0;
#   * state i links to itself and the next FANOUT - 1 states (mod num) with
#     random weights, which bake() normalises per state;
#   * every state is an equally weighted start state.
np.random.seed(0)  # reproducible random initialisation

num = 512          # number of hidden states (later cells rely on this name)
MEAN_OFFSET = 880  # first index of s0 used for the initial means
FANOUT = 6         # out-edges per state: i -> i, ..., i + FANOUT - 1 (mod num)

states = [None] * num
for i in range(num):
    dist = NormalDistribution(s0[MEAN_OFFSET + i // 2], np.random.random() / 1000, min_std=1e-9)
    states[i] = State(dist, name="s{:03d}".format(i))

model = HiddenMarkovModel()
model.add_states(states)

# Uniform start distribution; bake() normalises the 1.0 weights.
for state in states:
    model.add_transition(model.start, state, 1.0)

# j == i supplies the self-loop; adding that edge a second time would only
# overwrite its weight, so no separate self-transition is added.
for i in range(num):
    for j in range(i, i + FANOUT):
        model.add_transition(states[i], states[j % num], np.random.random())

# Overwrites the random weight drawn above for the wrap-around edge
# (num-1 -> 0) with 1.0, before normalisation.
model.add_transition(states[num - 1], states[0], 1.0)

model.bake(verbose=True)


None : None-start summed to 512.0, normalized to 1.0
None : s000 summed to 2.49868512, normalized to 1.0
None : s001 summed to 1.92558172, normalized to 1.0
None : s002 summed to 2.8484566, normalized to 1.0
None : s003 summed to 2.96898479, normalized to 1.0
None : s004 summed to 2.43254374, normalized to 1.0
None : s005 summed to 3.75345125, normalized to 1.0
None : s006 summed to 3.33402646, normalized to 1.0
None : s007 summed to 1.20628477, normalized to 1.0
None : s008 summed to 2.96051345, normalized to 1.0
None : s009 summed to 3.03691896, normalized to 1.0
None : s010 summed to 2.48470169, normalized to 1.0
None : s011 summed to 3.26298589, normalized to 1.0
None : s012 summed to 3.61661047, normalized to 1.0
None : s013 summed to 2.96621964, normalized to 1.0
None : s014 summed to 3.3123143, normalized to 1.0
None : s015 summed to 3.11689224, normalized to 1.0
None : s016 summed to 2.10128317, normalized to 1.0
None : s017 summed to 3.89939765, normalized to 1.0
None : s018 summed to 2.98108186, normalized to 1.0
None : s019 summed to 3.28275256, normalized to 1.0
None : s020 summed to 2.69453044, normalized to 1.0
None : s021 summed to 4.29030395, normalized to 1.0
None : s022 summed to 3.29286836, normalized to 1.0
None : s023 summed to 2.8920855, normalized to 1.0
None : s024 summed to 2.99606032, normalized to 1.0
None : s025 summed to 2.58695813, normalized to 1.0
None : s026 summed to 2.73995001, normalized to 1.0
None : s027 summed to 3.04947939, normalized to 1.0
None : s028 summed to 2.23496283, normalized to 1.0
None : s029 summed to 2.44985184, normalized to 1.0
None : s030 summed to 3.20860443, normalized to 1.0
None : s031 summed to 3.15918259, normalized to 1.0
None : s032 summed to 1.97745889, normalized to 1.0
None : s033 summed to 3.7012475, normalized to 1.0
None : s034 summed to 2.84541616, normalized to 1.0
None : s035 summed to 3.82625331, normalized to 1.0
None : s036 summed to 2.68363711, normalized to 1.0
None : s037 summed to 3.24243299, normalized to 1.0
None : s038 summed to 3.80736996, normalized to 1.0
None : s039 summed to 2.28845165, normalized to 1.0
None : s040 summed to 2.51202027, normalized to 1.0
None : s041 summed to 4.33634288, normalized to 1.0
None : s042 summed to 2.88807515, normalized to 1.0
None : s043 summed to 4.8666723, normalized to 1.0
None : s044 summed to 3.84792499, normalized to 1.0
None : s045 summed to 3.03434748, normalized to 1.0
None : s046 summed to 4.0515849, normalized to 1.0
None : s047 summed to 2.18433557, normalized to 1.0
None : s048 summed to 3.35076255, normalized to 1.0
None : s049 summed to 3.30807808, normalized to 1.0
None : s050 summed to 3.58413489, normalized to 1.0
None : s051 summed to 3.78335331, normalized to 1.0
None : s052 summed to 2.71931942, normalized to 1.0
None : s053 summed to 3.43383021, normalized to 1.0
None : s054 summed to 3.74321822, normalized to 1.0
None : s055 summed to 3.26739783, normalized to 1.0
None : s056 summed to 2.87250162, normalized to 1.0
None : s057 summed to 2.6692319, normalized to 1.0
None : s058 summed to 3.03866127, normalized to 1.0
None : s059 summed to 2.88509862, normalized to 1.0
None : s060 summed to 2.96562934, normalized to 1.0
None : s061 summed to 2.78184272, normalized to 1.0
None : s062 summed to 2.73081486, normalized to 1.0
None : s063 summed to 3.31135402, normalized to 1.0
None : s064 summed to 2.94246577, normalized to 1.0
None : s065 summed to 2.97714679, normalized to 1.0
None : s066 summed to 3.39328151, normalized to 1.0
None : s067 summed to 2.47046863, normalized to 1.0
None : s068 summed to 2.68842516, normalized to 1.0
None : s069 summed to 3.22282362, normalized to 1.0
None : s070 summed to 3.22359986, normalized to 1.0
None : s071 summed to 2.94751136, normalized to 1.0
None : s072 summed to 2.3562152, normalized to 1.0
None : s073 summed to 3.08708657, normalized to 1.0
None : s074 summed to 1.95707229, normalized to 1.0
None : s075 summed to 2.60178314, normalized to 1.0
None : s076 summed to 2.86123645, normalized to 1.0
None : s077 summed to 2.5604639, normalized to 1.0
None : s078 summed to 1.94876045, normalized to 1.0
None : s079 summed to 2.03799435, normalized to 1.0
None : s080 summed to 2.79459079, normalized to 1.0
None : s081 summed to 2.322055, normalized to 1.0
None : s082 summed to 3.24394914, normalized to 1.0
None : s083 summed to 3.34840648, normalized to 1.0
None : s084 summed to 3.40298787, normalized to 1.0
None : s085 summed to 3.52471484, normalized to 1.0
None : s086 summed to 3.12274421, normalized to 1.0
None : s087 summed to 2.62911963, normalized to 1.0
None : s088 summed to 3.35487569, normalized to 1.0
None : s089 summed to 3.40144997, normalized to 1.0
None : s090 summed to 3.12541948, normalized to 1.0
None : s091 summed to 3.67772205, normalized to 1.0
None : s092 summed to 3.41815, normalized to 1.0
None : s093 summed to 3.2356102, normalized to 1.0
None : s094 summed to 2.46475177, normalized to 1.0
None : s095 summed to 2.38358567, normalized to 1.0
None : s096 summed to 4.15105527, normalized to 1.0
None : s097 summed to 2.41020037, normalized to 1.0
None : s098 summed to 3.9489766, normalized to 1.0
None : s099 summed to 3.14022904, normalized to 1.0
None : s100 summed to 3.02224517, normalized to 1.0
None : s101 summed to 3.05093866, normalized to 1.0
None : s102 summed to 2.48842737, normalized to 1.0
None : s103 summed to 2.82222236, normalized to 1.0
None : s104 summed to 3.23840723, normalized to 1.0
None : s105 summed to 3.67796234, normalized to 1.0
None : s106 summed to 2.5782051, normalized to 1.0
None : s107 summed to 3.36273288, normalized to 1.0
None : s108 summed to 2.60128751, normalized to 1.0
None : s109 summed to 3.32786727, normalized to 1.0
None : s110 summed to 2.42960036, normalized to 1.0
None : s111 summed to 2.80764465, normalized to 1.0
None : s112 summed to 3.30274374, normalized to 1.0
None : s113 summed to 3.06096445, normalized to 1.0
None : s114 summed to 3.15978883, normalized to 1.0
None : s115 summed to 1.43418618, normalized to 1.0
None : s116 summed to 3.13115996, normalized to 1.0
None : s117 summed to 3.21155355, normalized to 1.0
None : s118 summed to 2.04916578, normalized to 1.0
None : s119 summed to 3.60038317, normalized to 1.0
None : s120 summed to 3.04203395, normalized to 1.0
None : s121 summed to 2.40047374, normalized to 1.0
None : s122 summed to 2.1432326, normalized to 1.0
None : s123 summed to 2.95255888, normalized to 1.0
None : s124 summed to 3.62152294, normalized to 1.0
None : s125 summed to 4.22342802, normalized to 1.0
None : s126 summed to 3.40653026, normalized to 1.0
None : s127 summed to 3.77260551, normalized to 1.0
None : s128 summed to 3.46372318, normalized to 1.0
None : s129 summed to 3.70045391, normalized to 1.0
None : s130 summed to 3.29102247, normalized to 1.0
None : s131 summed to 3.05291764, normalized to 1.0
None : s132 summed to 1.81537353, normalized to 1.0
None : s133 summed to 3.56037476, normalized to 1.0
None : s134 summed to 2.43361943, normalized to 1.0
None : s135 summed to 2.84119329, normalized to 1.0
None : s136 summed to 3.45260922, normalized to 1.0
None : s137 summed to 2.39441991, normalized to 1.0
None : s138 summed to 3.35654143, normalized to 1.0
None : s139 summed to 3.05431965, normalized to 1.0
None : s140 summed to 2.60641959, normalized to 1.0
None : s141 summed to 2.43066935, normalized to 1.0
None : s142 summed to 2.50673287, normalized to 1.0
None : s143 summed to 4.51055325, normalized to 1.0
None : s144 summed to 3.56269904, normalized to 1.0
None : s145 summed to 2.55517786, normalized to 1.0
None : s146 summed to 3.33183642, normalized to 1.0
None : s147 summed to 3.15301045, normalized to 1.0
None : s148 summed to 1.52302555, normalized to 1.0
None : s149 summed to 2.37970277, normalized to 1.0
None : s150 summed to 2.95391657, normalized to 1.0
None : s151 summed to 2.13756806, normalized to 1.0
None : s152 summed to 3.91568873, normalized to 1.0
None : s153 summed to 3.627402, normalized to 1.0
None : s154 summed to 3.59560902, normalized to 1.0
None : s155 summed to 2.11856774, normalized to 1.0
None : s156 summed to 2.98925803, normalized to 1.0
None : s157 summed to 2.97997645, normalized to 1.0
None : s158 summed to 2.61704245, normalized to 1.0
None : s159 summed to 3.23586267, normalized to 1.0
None : s160 summed to 1.90909323, normalized to 1.0
None : s161 summed to 2.69623524, normalized to 1.0
None : s162 summed to 3.12789797, normalized to 1.0
None : s163 summed to 2.95202077, normalized to 1.0
None : s164 summed to 2.54470998, normalized to 1.0
None : s165 summed to 2.61255441, normalized to 1.0
None : s166 summed to 3.09850023, normalized to 1.0
None : s167 summed to 2.99097002, normalized to 1.0
None : s168 summed to 1.28724521, normalized to 1.0
None : s169 summed to 3.23363663, normalized to 1.0
None : s170 summed to 3.56609647, normalized to 1.0
None : s171 summed to 2.92827706, normalized to 1.0
None : s172 summed to 3.94128963, normalized to 1.0
None : s173 summed to 3.15340164, normalized to 1.0
None : s174 summed to 2.20788136, normalized to 1.0
None : s175 summed to 3.01310019, normalized to 1.0
None : s176 summed to 3.42887624, normalized to 1.0
None : s177 summed to 2.58704023, normalized to 1.0
None : s178 summed to 2.48548195, normalized to 1.0
None : s179 summed to 4.06627629, normalized to 1.0
None : s180 summed to 1.89888318, normalized to 1.0
None : s181 summed to 3.1143753, normalized to 1.0
None : s182 summed to 3.00040283, normalized to 1.0
None : s183 summed to 2.04409047, normalized to 1.0
None : s184 summed to 3.32938566, normalized to 1.0
None : s185 summed to 2.46917742, normalized to 1.0
None : s186 summed to 3.46786454, normalized to 1.0
None : s187 summed to 3.82325302, normalized to 1.0
None : s188 summed to 3.12118887, normalized to 1.0
None : s189 summed to 3.64081001, normalized to 1.0
None : s190 summed to 2.75310208, normalized to 1.0
None : s191 summed to 3.1783809, normalized to 1.0
None : s192 summed to 3.55980942, normalized to 1.0
None : s193 summed to 2.80662128, normalized to 1.0
None : s194 summed to 2.72565974, normalized to 1.0
None : s195 summed to 2.6863838, normalized to 1.0
None : s196 summed to 3.62821815, normalized to 1.0
None : s197 summed to 2.35216267, normalized to 1.0
None : s198 summed to 3.05027667, normalized to 1.0
None : s199 summed to 2.3517862, normalized to 1.0
None : s200 summed to 3.97668457, normalized to 1.0
None : s201 summed to 5.06237169, normalized to 1.0
None : s202 summed to 3.24972968, normalized to 1.0
None : s203 summed to 3.46217782, normalized to 1.0
None : s204 summed to 3.06836031, normalized to 1.0
None : s205 summed to 2.85697465, normalized to 1.0
None : s206 summed to 2.61657029, normalized to 1.0
None : s207 summed to 3.44687042, normalized to 1.0
None : s208 summed to 2.35257527, normalized to 1.0
None : s209 summed to 3.266258, normalized to 1.0
None : s210 summed to 2.29694567, normalized to 1.0
None : s211 summed to 3.39102555, normalized to 1.0
None : s212 summed to 3.18539237, normalized to 1.0
None : s213 summed to 3.79994193, normalized to 1.0
None : s214 summed to 2.42971073, normalized to 1.0
None : s215 summed to 2.28127668, normalized to 1.0
None : s216 summed to 2.46059942, normalized to 1.0
None : s217 summed to 2.69677428, normalized to 1.0
None : s218 summed to 1.8371083, normalized to 1.0
None : s219 summed to 2.41915447, normalized to 1.0
None : s220 summed to 2.39875565, normalized to 1.0
None : s221 summed to 3.55641532, normalized to 1.0
None : s222 summed to 2.69010767, normalized to 1.0
None : s223 summed to 2.52607381, normalized to 1.0
None : s224 summed to 2.95299383, normalized to 1.0
None : s225 summed to 1.58622977, normalized to 1.0
None : s226 summed to 2.86636646, normalized to 1.0
None : s227 summed to 2.38290643, normalized to 1.0
None : s228 summed to 3.77316784, normalized to 1.0
None : s229 summed to 3.26485005, normalized to 1.0
None : s230 summed to 4.34516467, normalized to 1.0
None : s231 summed to 2.32151214, normalized to 1.0
None : s232 summed to 3.60106835, normalized to 1.0
None : s233 summed to 3.53479541, normalized to 1.0
None : s234 summed to 2.64008676, normalized to 1.0
None : s235 summed to 3.26383952, normalized to 1.0
None : s236 summed to 3.85342399, normalized to 1.0
None : s237 summed to 4.34898065, normalized to 1.0
None : s238 summed to 2.88594509, normalized to 1.0
None : s239 summed to 3.6447586, normalized to 1.0
None : s240 summed to 2.72220429, normalized to 1.0
None : s241 summed to 1.85390839, normalized to 1.0
None : s242 summed to 2.66197662, normalized to 1.0
None : s243 summed to 3.80607463, normalized to 1.0
None : s244 summed to 3.31727869, normalized to 1.0
None : s245 summed to 2.20812551, normalized to 1.0
None : s246 summed to 2.03740943, normalized to 1.0
None : s247 summed to 3.31679125, normalized to 1.0
None : s248 summed to 3.73955091, normalized to 1.0
None : s249 summed to 2.61008931, normalized to 1.0
None : s250 summed to 3.60355416, normalized to 1.0
None : s251 summed to 1.6322936, normalized to 1.0
None : s252 summed to 2.37811616, normalized to 1.0
None : s253 summed to 2.68536767, normalized to 1.0
None : s254 summed to 3.88329516, normalized to 1.0
None : s255 summed to 3.06762288, normalized to 1.0
None : s256 summed to 3.91893143, normalized to 1.0
None : s257 summed to 1.86285857, normalized to 1.0
None : s258 summed to 3.75972731, normalized to 1.0
None : s259 summed to 1.61172697, normalized to 1.0
None : s260 summed to 3.23560119, normalized to 1.0
None : s261 summed to 3.86307847, normalized to 1.0
None : s262 summed to 4.41171246, normalized to 1.0
None : s263 summed to 3.31906072, normalized to 1.0
None : s264 summed to 4.10682884, normalized to 1.0
None : s265 summed to 3.48761614, normalized to 1.0
None : s266 summed to 2.08703442, normalized to 1.0
None : s267 summed to 3.05970792, normalized to 1.0
None : s268 summed to 3.48297201, normalized to 1.0
None : s269 summed to 1.96357929, normalized to 1.0
None : s270 summed to 3.12379509, normalized to 1.0
None : s271 summed to 3.05022789, normalized to 1.0
None : s272 summed to 4.36169481, normalized to 1.0
None : s273 summed to 3.29890775, normalized to 1.0
None : s274 summed to 4.3841783, normalized to 1.0
None : s275 summed to 3.41644276, normalized to 1.0
None : s276 summed to 3.42153591, normalized to 1.0
None : s277 summed to 2.82706428, normalized to 1.0
None : s278 summed to 3.47531245, normalized to 1.0
None : s279 summed to 2.17764315, normalized to 1.0
None : s280 summed to 2.98452244, normalized to 1.0
None : s281 summed to 3.62850554, normalized to 1.0
None : s282 summed to 2.84461981, normalized to 1.0
None : s283 summed to 2.54529127, normalized to 1.0
None : s284 summed to 3.60806868, normalized to 1.0
None : s285 summed to 2.4526414, normalized to 1.0
None : s286 summed to 3.31659519, normalized to 1.0
None : s287 summed to 1.94175734, normalized to 1.0
None : s288 summed to 2.45599643, normalized to 1.0
None : s289 summed to 2.80790323, normalized to 1.0
None : s290 summed to 2.679667, normalized to 1.0
None : s291 summed to 2.71243526, normalized to 1.0
None : s292 summed to 2.89741156, normalized to 1.0
None : s293 summed to 4.32614213, normalized to 1.0
None : s294 summed to 2.98851098, normalized to 1.0
None : s295 summed to 1.74967337, normalized to 1.0
None : s296 summed to 2.01966293, normalized to 1.0
None : s297 summed to 1.58680361, normalized to 1.0
None : s298 summed to 2.7779946, normalized to 1.0
None : s299 summed to 3.08165699, normalized to 1.0
None : s300 summed to 3.88836811, normalized to 1.0
None : s301 summed to 3.31485093, normalized to 1.0
None : s302 summed to 3.14917729, normalized to 1.0
None : s303 summed to 3.16633176, normalized to 1.0
None : s304 summed to 1.7288234, normalized to 1.0
None : s305 summed to 3.26825561, normalized to 1.0
None : s306 summed to 4.08049987, normalized to 1.0
None : s307 summed to 4.14041106, normalized to 1.0
None : s308 summed to 2.72916164, normalized to 1.0
None : s309 summed to 1.73284004, normalized to 1.0
None : s310 summed to 3.80125711, normalized to 1.0
None : s311 summed to 3.72217317, normalized to 1.0
None : s312 summed to 2.70381565, normalized to 1.0
None : s313 summed to 3.42103805, normalized to 1.0
None : s314 summed to 3.46332787, normalized to 1.0
None : s315 summed to 2.90261721, normalized to 1.0
None : s316 summed to 2.95932981, normalized to 1.0
None : s317 summed to 2.96443488, normalized to 1.0
None : s318 summed to 1.60809587, normalized to 1.0
None : s319 summed to 3.06137225, normalized to 1.0
None : s320 summed to 3.02916419, normalized to 1.0
None : s321 summed to 2.87709651, normalized to 1.0
None : s322 summed to 2.93570867, normalized to 1.0
None : s323 summed to 3.80505219, normalized to 1.0
None : s324 summed to 2.47060682, normalized to 1.0
None : s325 summed to 2.83466028, normalized to 1.0
None : s326 summed to 3.46290995, normalized to 1.0
None : s327 summed to 1.70791253, normalized to 1.0
None : s328 summed to 1.32669333, normalized to 1.0
None : s329 summed to 4.18686962, normalized to 1.0
None : s330 summed to 3.45266769, normalized to 1.0
None : s331 summed to 3.01362488, normalized to 1.0
None : s332 summed to 2.42049808, normalized to 1.0
None : s333 summed to 2.28675216, normalized to 1.0
None : s334 summed to 2.86090094, normalized to 1.0
None : s335 summed to 3.82274143, normalized to 1.0
None : s336 summed to 3.203522, normalized to 1.0
None : s337 summed to 2.26379758, normalized to 1.0
None : s338 summed to 2.68399581, normalized to 1.0
None : s339 summed to 3.21574108, normalized to 1.0
None : s340 summed to 4.08586281, normalized to 1.0
None : s341 summed to 3.9204567, normalized to 1.0
None : s342 summed to 2.70892884, normalized to 1.0
None : s343 summed to 3.76289397, normalized to 1.0
None : s344 summed to 3.09418301, normalized to 1.0
None : s345 summed to 3.61502285, normalized to 1.0
None : s346 summed to 3.19120736, normalized to 1.0
None : s347 summed to 3.36552261, normalized to 1.0
None : s348 summed to 1.97960775, normalized to 1.0
None : s349 summed to 3.71374558, normalized to 1.0
None : s350 summed to 3.01355873, normalized to 1.0
None : s351 summed to 4.64084074, normalized to 1.0
None : s352 summed to 2.09743514, normalized to 1.0
None : s353 summed to 3.46692832, normalized to 1.0
None : s354 summed to 3.38456302, normalized to 1.0
None : s355 summed to 3.2526641, normalized to 1.0
None : s356 summed to 3.01098898, normalized to 1.0
None : s357 summed to 3.83738947, normalized to 1.0
None : s358 summed to 4.03168031, normalized to 1.0
None : s359 summed to 1.65562215, normalized to 1.0
None : s360 summed to 3.15290426, normalized to 1.0
None : s361 summed to 3.33277187, normalized to 1.0
None : s362 summed to 3.32802801, normalized to 1.0
None : s363 summed to 2.80412495, normalized to 1.0
None : s364 summed to 4.06539443, normalized to 1.0
None : s365 summed to 3.19883292, normalized to 1.0
None : s366 summed to 3.22543986, normalized to 1.0
None : s367 summed to 1.95957436, normalized to 1.0
None : s368 summed to 3.49604635, normalized to 1.0
None : s369 summed to 3.58266715, normalized to 1.0
None : s370 summed to 2.32894819, normalized to 1.0
None : s371 summed to 2.12593779, normalized to 1.0
None : s372 summed to 3.60658543, normalized to 1.0
None : s373 summed to 2.4956652, normalized to 1.0
None : s374 summed to 4.21189186, normalized to 1.0
None : s375 summed to 3.25873684, normalized to 1.0
None : s376 summed to 2.07128097, normalized to 1.0
None : s377 summed to 3.83651152, normalized to 1.0
None : s378 summed to 2.4268369, normalized to 1.0
None : s379 summed to 3.44748565, normalized to 1.0
None : s380 summed to 2.82629438, normalized to 1.0
None : s381 summed to 2.63874518, normalized to 1.0
None : s382 summed to 1.48295465, normalized to 1.0
None : s383 summed to 3.88505661, normalized to 1.0
None : s384 summed to 2.78405824, normalized to 1.0
None : s385 summed to 3.37318043, normalized to 1.0
None : s386 summed to 3.39198106, normalized to 1.0
None : s387 summed to 3.53243667, normalized to 1.0
None : s388 summed to 2.42831667, normalized to 1.0
None : s389 summed to 2.05851227, normalized to 1.0
None : s390 summed to 2.55997968, normalized to 1.0
None : s391 summed to 2.89403157, normalized to 1.0
None : s392 summed to 3.91500734, normalized to 1.0
None : s393 summed to 1.59075388, normalized to 1.0
None : s394 summed to 2.04697827, normalized to 1.0
None : s395 summed to 3.25470494, normalized to 1.0
None : s396 summed to 2.77349092, normalized to 1.0
None : s397 summed to 1.38270415, normalized to 1.0
None : s398 summed to 3.04583083, normalized to 1.0
None : s399 summed to 2.39468901, normalized to 1.0
None : s400 summed to 3.01823993, normalized to 1.0
None : s401 summed to 2.20013305, normalized to 1.0
None : s402 summed to 2.17421658, normalized to 1.0
None : s403 summed to 2.69444853, normalized to 1.0
None : s404 summed to 2.00675172, normalized to 1.0
None : s405 summed to 3.39790398, normalized to 1.0
None : s406 summed to 2.88429775, normalized to 1.0
None : s407 summed to 3.39626373, normalized to 1.0
None : s408 summed to 3.94837924, normalized to 1.0
None : s409 summed to 2.97647275, normalized to 1.0
None : s410 summed to 3.96811033, normalized to 1.0
None : s411 summed to 2.91520358, normalized to 1.0
None : s412 summed to 3.53039923, normalized to 1.0
None : s413 summed to 3.11354158, normalized to 1.0
None : s414 summed to 4.4731641, normalized to 1.0
None : s415 summed to 1.78780234, normalized to 1.0
None : s416 summed to 2.53700754, normalized to 1.0
None : s417 summed to 2.41358366, normalized to 1.0
None : s418 summed to 3.92697595, normalized to 1.0
None : s419 summed to 2.39262933, normalized to 1.0
None : s420 summed to 3.2726697, normalized to 1.0
None : s421 summed to 1.70320617, normalized to 1.0
None : s422 summed to 3.09985986, normalized to 1.0
None : s423 summed to 2.84277546, normalized to 1.0
None : s424 summed to 2.9919398, normalized to 1.0
None : s425 summed to 3.16593432, normalized to 1.0
None : s426 summed to 2.93123389, normalized to 1.0
None : s427 summed to 2.5503811, normalized to 1.0
None : s428 summed to 2.6252183, normalized to 1.0
None : s429 summed to 2.19253865, normalized to 1.0
None : s430 summed to 3.45205227, normalized to 1.0
None : s431 summed to 4.41233073, normalized to 1.0
None : s432 summed to 2.29673521, normalized to 1.0
None : s433 summed to 2.02054789, normalized to 1.0
None : s434 summed to 3.747247, normalized to 1.0
None : s435 summed to 3.6599057, normalized to 1.0
None : s436 summed to 3.72636497, normalized to 1.0
None : s437 summed to 4.27620493, normalized to 1.0
None : s438 summed to 2.82868098, normalized to 1.0
None : s439 summed to 3.9033439, normalized to 1.0
None : s440 summed to 3.01723921, normalized to 1.0
None : s441 summed to 3.00227415, normalized to 1.0
None : s442 summed to 4.30950453, normalized to 1.0
None : s443 summed to 3.31378103, normalized to 1.0
None : s444 summed to 2.48525765, normalized to 1.0
None : s445 summed to 3.25848501, normalized to 1.0
None : s446 summed to 2.63435779, normalized to 1.0
None : s447 summed to 2.28096744, normalized to 1.0
None : s448 summed to 3.01392706, normalized to 1.0
None : s449 summed to 2.7804988, normalized to 1.0
None : s450 summed to 2.58880607, normalized to 1.0
None : s451 summed to 2.23935723, normalized to 1.0
None : s452 summed to 4.24332415, normalized to 1.0
None : s453 summed to 3.78096821, normalized to 1.0
None : s454 summed to 2.0849315, normalized to 1.0
None : s455 summed to 1.55952397, normalized to 1.0
None : s456 summed to 3.17697134, normalized to 1.0
None : s457 summed to 2.5631381, normalized to 1.0
None : s458 summed to 2.02745804, normalized to 1.0
None : s459 summed to 2.78901977, normalized to 1.0
None : s460 summed to 2.96810809, normalized to 1.0
None : s461 summed to 3.35950469, normalized to 1.0
None : s462 summed to 3.4269949, normalized to 1.0
None : s463 summed to 2.7022887, normalized to 1.0
None : s464 summed to 2.59719208, normalized to 1.0
None : s465 summed to 3.71808556, normalized to 1.0
None : s466 summed to 1.79958563, normalized to 1.0
None : s467 summed to 3.05890887, normalized to 1.0
None : s468 summed to 4.19203381, normalized to 1.0
None : s469 summed to 4.01531579, normalized to 1.0
None : s470 summed to 3.51949522, normalized to 1.0
None : s471 summed to 1.96210123, normalized to 1.0
None : s472 summed to 1.99285169, normalized to 1.0
None : s473 summed to 2.44158088, normalized to 1.0
None : s474 summed to 3.21182253, normalized to 1.0
None : s475 summed to 2.47327754, normalized to 1.0
None : s476 summed to 2.65013833, normalized to 1.0
None : s477 summed to 2.97475946, normalized to 1.0
None : s478 summed to 1.46866146, normalized to 1.0
None : s479 summed to 4.30586122, normalized to 1.0
None : s480 summed to 4.07254592, normalized to 1.0
None : s481 summed to 2.72070031, normalized to 1.0
None : s482 summed to 3.92797957, normalized to 1.0
None : s483 summed to 2.9173859, normalized to 1.0
None : s484 summed to 2.16289614, normalized to 1.0
None : s485 summed to 2.68234361, normalized to 1.0
None : s486 summed to 2.367328, normalized to 1.0
None : s487 summed to 3.48725954, normalized to 1.0
None : s488 summed to 2.58035402, normalized to 1.0
None : s489 summed to 3.67758554, normalized to 1.0
None : s490 summed to 2.20829664, normalized to 1.0
None : s491 summed to 3.89832967, normalized to 1.0
None : s492 summed to 2.79544317, normalized to 1.0
None : s493 summed to 3.53178015, normalized to 1.0
None : s494 summed to 4.22548144, normalized to 1.0
None : s495 summed to 2.61726411, normalized to 1.0
None : s496 summed to 2.65170268, normalized to 1.0
None : s497 summed to 2.15780608, normalized to 1.0
None : s498 summed to 3.67358011, normalized to 1.0
None : s499 summed to 4.46141755, normalized to 1.0
None : s500 summed to 3.00786071, normalized to 1.0
None : s501 summed to 2.40929571, normalized to 1.0
None : s502 summed to 1.81476848, normalized to 1.0
None : s503 summed to 2.47540461, normalized to 1.0
None : s504 summed to 2.75353636, normalized to 1.0
None : s505 summed to 3.16084215, normalized to 1.0
None : s506 summed to 2.37990877, normalized to 1.0
None : s507 summed to 3.01333883, normalized to 1.0
None : s508 summed to 4.02170179, normalized to 1.0
None : s509 summed to 1.933049, normalized to 1.0
None : s510 summed to 1.08685407, normalized to 1.0
None : s511 summed to 3.37373453, normalized to 1.0

In [6]:
# Emission parameters and transition structure of the current model.
# States without a distribution (silent states such as model.start) are skipped.
# NOTE: `covars` holds each NormalDistribution's second parameter, i.e. the
# standard deviation, not a variance; the name is kept because later cells use it.
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)

emitting = [state.distribution for state in model.states if state.distribution]
means = np.array([dist.parameters[0] for dist in emitting])
covars = np.array([dist.parameters[1] for dist in emitting])

ax1.plot(means)
ax1.set_title('emission means')
ax2.plot(covars)
ax2.set_title('emission standard deviations')
ax3.imshow(model.dense_transition_matrix(), aspect='auto')
ax3.set_title('dense transition matrix');


Out[6]:
<matplotlib.image.AxesImage at 0x7f2b552765f8>

In [31]:
# Viterbi training on all three records ('baum-welch' is the alternative),
# with small pseudocounts on both transitions and emissions.
model.fit(
    [s0, s1, s2],
    algorithm='viterbi',
    min_iterations=1,
    max_iterations=100,
    transition_pseudocount=0.0001,
    emission_pseudocount=0.0001,
    n_jobs=8,
    verbose=True,
)


---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-31-1a8eab29a350> in <module>()
      7 #       algorithm='baum-welch',
      8       algorithm='viterbi',
----> 9       n_jobs=8
     10      )

~/workspace/pomegranate/pomegranate/hmm.pyx in pomegranate.hmm.HiddenMarkovModel.fit()

~/workspace/pomegranate/pomegranate/hmm.pyx in pomegranate.hmm.HiddenMarkovModel.fit()

~/.pyenv/versions/3.6.1/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
    787                 # consumption.
    788                 self._iterating = False
--> 789             self.retrieve()
    790             # Make sure that we get a last message telling us we are done
    791             elapsed_time = time.time() - self._start_time

~/.pyenv/versions/3.6.1/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
    697             try:
    698                 if getattr(self._backend, 'supports_timeout', False):
--> 699                     self._output.extend(job.get(timeout=self.timeout))
    700                 else:
    701                     self._output.extend(job.get())

~/.pyenv/versions/3.6.1/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    600 
    601     def get(self, timeout=None):
--> 602         self.wait(timeout)
    603         if not self.ready():
    604             raise TimeoutError

~/.pyenv/versions/3.6.1/lib/python3.6/multiprocessing/pool.py in wait(self, timeout)
    597 
    598     def wait(self, timeout=None):
--> 599         self._event.wait(timeout)
    600 
    601     def get(self, timeout=None):

~/.pyenv/versions/3.6.1/lib/python3.6/threading.py in wait(self, timeout)
    549             signaled = self._flag
    550             if not signaled:
--> 551                 signaled = self._cond.wait(timeout)
    552             return signaled
    553 

~/.pyenv/versions/3.6.1/lib/python3.6/threading.py in wait(self, timeout)
    293         try:    # restore state no matter what (e.g., KeyboardInterrupt)
    294             if timeout is None:
--> 295                 waiter.acquire()
    296                 gotit = True
    297             else:

KeyboardInterrupt: 

In [7]:
# Snapshot the current parameters (transition matrix without its last row)
# so the comparison cell can show what the next fit changed.
p_trans, p_means, p_covars = model.dense_transition_matrix()[:-1], means, covars

In [29]:
# Compare the current parameters with the previous snapshot (p_*), then roll
# the snapshot forward so the next run of this cell shows only the latest change.
trans = model.dense_transition_matrix()[:-1]
emitting = [state.distribution for state in model.states if state.distribution]
means = np.array([dist.parameters[0] for dist in emitting])
# Standard deviations (NormalDistribution's second parameter), despite the name.
covars = np.array([dist.parameters[1] for dist in emitting])

fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1)

ax1.plot(means, label='current')
ax1.plot(p_means, label='previous')
ax1.set_title('emission means')
ax1.legend()

ax2.plot(covars, label='current')
ax2.plot(p_covars, label='previous')
ax2.set_title('emission standard deviations')
ax2.legend()

ax3.imshow(trans, aspect='auto')
ax3.set_title('transition matrix')
ax4.imshow(trans - p_trans, aspect='auto')
ax4.set_title('transition matrix change since snapshot')

p_trans = trans.copy()
p_means = means
p_covars = covars



In [30]:
# Decode record 100 with the current model and rebuild it from the
# per-sample state means.
# [1:] drops the first entry of the Viterbi path -- presumably the silent
# start state (TODO confirm: the residual below needs len(prediction) == len(s0)).
prediction = np.asarray(model.predict(s0, algorithm='viterbi')[1:])
fitted = means[prediction]  # mean of the decoded state at each sample

fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1, sharex=True)

ax1.plot(s0)
ax1.set_title('s0 (first difference)')
ax2.plot(org0)
ax2.set_title('org0 (raw lead)')
ax3.plot(prediction)
ax3.set_title('decoded state index')

ax4.plot(fitted, label='state mean')
ax4.plot(fitted - s0, label='state mean - s0')
ax4.set_title('fit and residual')
ax4.legend()

ax5.plot(np.cumsum(fitted))
ax5.set_title('cumulative sum of state means');


Out[30]:
[<matplotlib.lines.Line2D at 0x7f2b519b87b8>]

In [ ]:
# Occupancy histogram of the decoded states and its Shannon entropy in bits.
# With unit-width bins, density=True makes `hist` a probability vector.
fig, ax = plt.subplots()

hist = np.histogram(prediction, bins=range(num + 1), density=True)[0]
# 0 * log2(0) counts as 0; only occupied bins are kept to avoid log2(0) warnings.
occupied = hist[hist > 0]
entr = (-occupied * np.log2(occupied)).sum()
print(entr)

ax.plot(hist)
# Overlay the histogram from the previous run of this cell, if there was one.
if 'pr_hist' in globals():
    ax.plot(pr_hist)
pr_hist = hist

In [ ]:
def change(model, up=1.1, down=0.9, min_std=1e-6):
    """Yield every one-parameter perturbation of ``model`` (for greedy search).

    Candidates are yielded in this order:

    * for each of the first ``num`` (global) states: mean * ``up``,
      mean * ``down``, std * ``up``, and ``max(min_std, std * down)``;
    * for each edge of the serialized model: its weight (``edge[2]``) * ``up``.

    Parameters
    ----------
    model : HiddenMarkovModel
        Parent model; it is not modified.
    up, down : float
        Multiplicative step sizes (the defaults are the original 1.1 / 0.9).
    min_std : float
        Floor applied to the shrunken standard deviation.

    Yields
    ------
    HiddenMarkovModel
        A fresh model built with ``HiddenMarkovModel.from_json`` that differs
        from ``model`` in exactly one parameter.
    """
    import json

    # Parse once and restore the exact value after each variant. Re-serializing
    # the whole model for every state and edge is quadratic in model size, and
    # undoing a scale by division leaks float round-off into later variants.
    ser = json.loads(model.to_json())

    def _variant(container, key, value):
        # Build a model with container[key] = value, then put the original back
        # so the next variant starts from the unmodified parent.
        original = container[key]
        container[key] = value
        try:
            return HiddenMarkovModel.from_json(json.dumps(ser))
        finally:
            container[key] = original

    for state in ser['states'][:num]:
        params = state['distribution']['parameters']
        mean, std = params[0], params[1]
        yield _variant(params, 0, mean * up)
        yield _variant(params, 0, mean * down)
        yield _variant(params, 1, std * up)
        yield _variant(params, 1, max(min_std, std * down))

    for edge in ser['edges']:
        yield _variant(edge, 2, edge[2] * up)


from collections import Counter
counter = 0  # number of score evaluations; reset by the search loops and printed as a progress ticker


def entr_score(model, signal=None):
    """Shannon entropy (bits) of the MAP state-occupancy histogram of ``signal``.

    Parameters
    ----------
    model : HiddenMarkovModel
        Model whose ``predict(..., algorithm='map')`` labels each sample.
    signal : array-like, optional
        Sequence to decode. Defaults to the global ``s0``, looked up at call
        time so that re-running the data cell is picked up.

    Returns
    -------
    float
        ``-sum(p * log2(p))`` over occupied states, where ``p`` is the share of
        samples assigned to each of the state indices ``0..num-1`` (``num`` is global).

    Side effects: increments the global ``counter`` and prints it followed by a
    carriage return (an in-place progress ticker).
    """
    global counter
    counter += 1
    if signal is None:
        signal = s0
    prediction = model.predict(signal, algorithm='map')
    # unit-width bins, so density=True yields the empirical probability per state
    hist = np.histogram(prediction, bins=range(num + 1), density=True)[0]
    occupied = hist[hist > 0]  # 0 * log2(0) is taken as 0
    entr = -(occupied * np.log2(occupied)).sum()
    print(counter, end='\r')
    return entr

def mean_cycle_time(model, signal=None, wrap_state=6):
    """Mean number of samples per trip around the model's state ring.

    The Viterbi path of ``signal`` is decoded, and a "wrap" is recorded at every
    sample ``i`` where the state index drops from ``>= wrap_state`` to
    ``< wrap_state``. The transitions built at the top of the notebook only move
    forward (i -> i..i+5 mod num), so such a drop means the path went from the
    end of the ring back to its start. The cycle length is the spacing between
    consecutive wraps.

    Parameters
    ----------
    model : HiddenMarkovModel
        Model to decode with (``predict(..., algorithm='viterbi')``).
    signal : array-like, optional
        Sequence to decode; defaults to the global ``s0``, looked up at call time.
    wrap_state : int, default 6
        State-index boundary that defines a wrap.

    Returns
    -------
    float
        Mean spacing between consecutive wraps, or ``nan`` when fewer than two
        wraps occur. The stretch before the first wrap starts mid-cycle, so it
        is not counted as a cycle.
    """
    if signal is None:
        signal = s0
    # drop the first decoded entry, as the other scoring helpers do
    path = np.asarray(model.predict(signal, algorithm='viterbi')[1:])
    is_wrap = (path[:-1] >= wrap_state) & (path[1:] < wrap_state)
    wraps = np.flatnonzero(is_wrap)
    if wraps.size < 2:
        return np.nan
    return float(np.mean(np.diff(wraps)))

def _viterbi_residual(model, signal):
    """Return ``means[path] - signal`` for the Viterbi path of ``signal``.

    ``means`` is the first distribution parameter of every state that has a
    distribution, in ``model.states`` order. The first entry of the decoded
    path is dropped (presumably the start state) so the path lines up with
    ``signal`` sample-for-sample.
    """
    means = np.array([state.distribution.parameters[0] for state in model.states if state.distribution])
    path = np.array(model.predict(signal, algorithm='viterbi')[1:])
    return means[path] - signal


def l1_score(model, signal=None):
    """L1 reconstruction error ``sum(|means[path] - signal|)``.

    ``signal`` defaults to the global ``s0``, looked up at call time.
    """
    if signal is None:
        signal = s0
    return np.sum(np.abs(_viterbi_residual(model, signal)))


def dot_score(model, signal=None):
    """L2 reconstruction error ``sqrt(sum((means[path] - signal)**2))``.

    ``signal`` defaults to the global ``s0``, looked up at call time.
    """
    if signal is None:
        signal = s0
    return np.sqrt(np.sum(_viterbi_residual(model, signal) ** 2))


def norm(model, signal):
    """L2 reconstruction error via ``np.linalg.norm``; same quantity as ``dot_score``."""
    return np.linalg.norm(_viterbi_residual(model, signal), 2)

def score(model, signal=None, cycle=280, verbose=False):
    """Objective minimized by the search loops below (lower is better).

    ``norm(model, signal) + |mean_cycle_time(model, signal) - cycle| / cycle``:
    the L2 reconstruction error plus the relative deviation of the measured
    ring-cycle length from the target ``cycle`` (in samples).

    Parameters
    ----------
    model : HiddenMarkovModel
        Model to evaluate.
    signal : array-like, optional
        Sequence to decode; defaults to the global ``s0``, looked up at call time.
    cycle : float
        Target cycle length in samples.
    verbose : bool
        Also print the counter, reconstruction error, relative cycle error and
        measured cycle length.

    Returns
    -------
    float
        The objective, or ``inf`` when it is not finite. For example, when no
        cycle can be measured, ``mean_cycle_time`` returns nan; mapping that to
        ``inf`` keeps such candidates from poisoning ``<``/``min`` comparisons.

    Side effects: increments the global ``counter`` and prints it followed by a
    carriage return.
    """
    global counter
    counter += 1
    if signal is None:
        signal = s0

    d1 = norm(model, signal)
    m1 = mean_cycle_time(model, signal)
    cycle_err = np.abs(m1 - cycle) / cycle

    if verbose:
        print(counter, d1, cycle_err, m1)
    print(counter, end='\r')
    total = d1 + cycle_err
    return total if np.isfinite(total) else np.inf

score(model,verbose=True), mean_cycle_time(model), dot_score(model), entr_score(model)

In [33]:
# Force a full garbage-collection pass; the displayed value is the number of
# unreachable objects it found (0 in Out[33]).
import gc
gc.collect()


Out[33]:
0

In [ ]:
def average(pairs):
    """Return a model whose numeric parameters are the weighted mean of ``pairs``.

    Parameters
    ----------
    pairs : sequence of (weight, model) tuples
        The models must share one structure (same state order and edge order),
        for example candidates produced by ``change`` from a single parent.
        Weights are normalized by their sum, so only their ratios matter.

    Returns
    -------
    HiddenMarkovModel
        Built from the first model's JSON, with every edge weight
        (``edges[i][2]``) and both distribution parameters of the first
        ``num`` (global) states replaced by their weighted mean.

    Raises
    ------
    ValueError
        If ``pairs`` is empty, the weights sum to zero, or a model's edge count
        differs from the first model's.
    """
    import json

    pairs = list(pairs)
    if not pairs:
        raise ValueError("average() needs at least one (weight, model) pair")
    total = sum(weight for weight, _ in pairs)
    if total == 0:
        raise ValueError("average() weights sum to zero")

    # Accumulate into the first model's parsed JSON. The other models are
    # parsed one at a time, so at most two serialized models are in memory.
    first_weight, first_model = pairs[0]
    out = json.loads(first_model.to_json())
    out_edges = out['edges']
    out_params = [state['distribution']['parameters'] for state in out['states'][:num]]

    frac = first_weight / total
    for edge in out_edges:
        edge[2] *= frac
    for params in out_params:
        params[0] *= frac
        params[1] *= frac

    for weight, other_model in pairs[1:]:
        frac = weight / total
        other = json.loads(other_model.to_json())
        if len(other['edges']) != len(out_edges):
            raise ValueError("average() models have different edge counts")
        for edge, other_edge in zip(out_edges, other['edges']):
            edge[2] += other_edge[2] * frac
        for params, other_state in zip(out_params, other['states'][:num]):
            other_params = other_state['distribution']['parameters']
            params[0] += other_params[0] * frac
            params[1] += other_params[1] * frac

    return HiddenMarkovModel.from_json(json.dumps(out))

In [ ]:
# Replace the model with the unweighted mean of all its one-parameter perturbations.
model = average([(1, candidate) for candidate in change(model)])

In [ ]:
# Greedy coordinate search: score every one-parameter perturbation of the
# current model, stop when none beats it, otherwise move to the weighted
# average of the improving candidates.
for y in range(100000):
    counter = 0  # restart the progress ticker printed by score()
    o = score(model, verbose=True)
    cand = [(score(new_model), new_model) for new_model in change(model)]
    b = min(cand, key=lambda x: x[0])
    if b[0] >= o:
        break
    # score() is minimized, so weight each improving candidate by its
    # improvement (o - s). Weighting by the raw score would give the *least*
    # improving candidates the most weight. Use equal weights if o is not finite.
    fits = [(o - s if np.isfinite(o) else 1.0, m) for s, m in cand if s < o]
    model = average(fits)
    print(y, o, b[0], len(fits))

In [ ]:
def f(model):
    """Score one candidate for ``pool.map``.

    Returns a ``(score, model)`` tuple so results can be ranked and averaged
    directly. It is a module-level function so worker processes can unpickle it.
    """
    result = score(model)
    return result, model

In [ ]:
from multiprocessing import Pool

# Parallel version of the greedy search: candidate scoring is spread over 8
# worker processes via the module-level helper `f`. The `with` block terminates
# the workers on exit, including on KeyboardInterrupt (as in the traceback
# above), so an interrupted run does not leave the pool behind.
with Pool(8) as pool:
    for y in range(100000):
        counter = 0  # restart the progress ticker printed by score()
        o = score(model, verbose=True)
        cand = pool.map(f, list(change(model)))
        b = min(cand, key=lambda x: x[0])
        if b[0] >= o:
            break
        # score() is minimized: weight improving candidates by their
        # improvement (o - s), falling back to equal weights if o is not finite.
        fits = [(o - s if np.isfinite(o) else 1.0, m) for s, m in cand if s < o]
        model = average(fits)
        print(y, o, b[0], len(fits))

In [ ]:
# Release the worker pool used by the parallel search: no new tasks are
# accepted, and the workers exit once outstanding tasks finish.
pool.close()

In [ ]:
# Random local search: try a random perturbation and keep it if it lowers the score.
# NOTE(review): `random_change` is not defined anywhere in this notebook as
# shown; define it (model -> perturbed model) before running this cell.
fit = 0  # number of accepted moves
o = score(model)  # score of the current model; only changes when a move is accepted
for i in range(100000):
    new_model = random_change(model)
    n = score(new_model)
    print(i, 'entr', n, o, end='\r')
    # score() is minimized (the greedy loops keep candidates with score < o),
    # so accept strict improvements only; the old `n >= o` kept worse models.
    if n < o:
        fit += 1
        model, o = new_model, n

print(fit)

In [ ]:


In [ ]:
# Heat map of current-model transitions minus those of `new_model`
# (the last candidate bound by the random-search cell).
fig, ax = plt.subplots()
ax.imshow(np.subtract(model.dense_transition_matrix(), new_model.dense_transition_matrix()))

In [ ]:
import json  # json is otherwise only imported inside change()/average(), not at top level

# Inspect the model's serialized form as a plain dict.
json.loads(model.to_json())

In [21]:
# Draw 10000 emissions from the model. Since s0 is np.diff of the raw signal,
# the cumulative sum maps the sampled differences back to the signal domain.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
samp = model.sample(length=10000, path=False)

ax1.plot(samp)             # sampled first differences
ax2.plot(np.cumsum(samp))  # integrated sample


Out[21]:
[<matplotlib.lines.Line2D at 0x7f2b51cea940>]

In [ ]:
# Persist the model as JSON. The handle is named `fh` rather than `f` so this
# cell does not rebind the scoring helper `f` that the parallel search passes
# to pool.map.
with open('ff.json', 'w') as fh:
    fh.write(model.to_json())

In [ ]: