In [5]:
# based on https://en.wikipedia.org/wiki/Baum%E2%80%93Welch_algorithm
import numpy as np
# Initial HMM parameter guesses (per the Wikipedia Baum–Welch example).
# trans[i][j]: probability of moving from state i to state j.
trans = np.array([[0.5, 0.5],
                  [0.3, 0.7]])
# emit[i][k]: probability that state i emits observation symbol k.
emit = np.array([[0.3, 0.7],
                 [0.8, 0.2]])
# init[i]: probability that a sequence starts in state i.
init = np.array([0.2, 0.8])

num_states = len(trans)
num_out = emit.shape[-1]

# Echo the model so the notebook output shows what training starts from.
for label, value in (
    ('num_states', num_states),
    ('num_out', num_out),
    ('trans', trans),
    ('emit', emit),
    ('init', init),
):
    print(label, value)
# Observation sequences, one per line. 'N' and 'E' are presumably the
# "no eggs" / "eggs" observations of the Wikipedia example.
sequences_strings = """
NN
NN
NN
NN
NE
EE
EN
NN
NN
"""

# Column index in `emit` for each observation symbol.
SYMBOL_TO_INDEX = {'N': 0, 'E': 1}


def parse_sequences(text):
    """Parse one observation sequence per non-blank line into an int32 array.

    Blank lines and surrounding whitespace are ignored. Each character is
    mapped through SYMBOL_TO_INDEX.

    Args:
        text: Multi-line string of observation sequences.

    Returns:
        An int32 array of shape (number of sequences, T), where T is the
        common length of the lines.

    Raises:
        ValueError: if the lines differ in length, or a line contains a
            symbol not in SYMBOL_TO_INDEX.
    """
    lines = [line.strip() for line in text.splitlines() if line.strip()]
    lengths = {len(line) for line in lines}
    # The training loop reads T from each row, so a ragged input would be
    # silently padded with 0 (an 'N' observation) -- reject it instead.
    if len(lengths) > 1:
        raise ValueError(
            f'all sequences must have the same length, got lengths {sorted(lengths)}'
        )
    width = lengths.pop() if lengths else 0
    parsed = np.zeros((len(lines), width), dtype=np.int32)
    for row, line in enumerate(lines):
        for t, symbol in enumerate(line):
            try:
                parsed[row, t] = SYMBOL_TO_INDEX[symbol]
            except KeyError:
                raise ValueError(
                    f'unknown observation symbol {symbol!r} in line {line!r}'
                ) from None
    return parsed


sequences = parse_sequences(sequences_strings)
num_sequences = sequences.shape[0]
print(sequences)
new_trans = np.zeros((num_states, num_states), dtype=np.float32)
new_emit = np.zeros((num_states, num_out), dtype=np.float32)
for i in range(num_states):
for j in range(num_states):
p_tot = 0.0
for n in range(num_sequences):
seq = sequences[n]
T = seq.shape[0]
# print('T', T)
p = 1.0
for t in range(T):
s = seq[t]
if t == 0:
p *= init[s]
else:
s_prev = seq[t - 1]
p *= trans[s_prev][s]
p *=
# print('t', t)