In [1]:
%matplotlib inline
#import networkx as nx
#import pygraphviz
import pyparsing
import numpy as np
import matplotlib.pylab as plt
from IPython.display import Math
In [7]:
# An implementation of the forward backward algorithm
# For numerical stability, we calculate everything in the log domain
def normalize_exp(log_P, axis=None):
    """Exponentiate log-probabilities and normalize along `axis`.

    Shifts by the maximum before exponentiating so that np.exp never
    overflows; the shift cancels out after normalization.
    """
    shift = np.max(log_P, keepdims=True, axis=axis)
    return normalize(np.exp(log_P - shift), axis=axis)
def normalize(A, axis=None):
    """Divide A by its sum along `axis`; slices summing to zero are left as zeros."""
    Z = np.sum(A, axis=axis, keepdims=True)
    Z[Z == 0] = 1  # guard against division by zero; all-zero slices pass through unchanged
    return A / Z
def randgen(pr, N=1):
    """Draw N i.i.d. samples from the discrete distribution `pr`.

    Returns an integer array of indices into pr (with replacement).
    """
    # np.random.choice with an int first argument samples from arange(len(pr)),
    # which is equivalent to passing range(len(pr)).
    return np.random.choice(len(pr), size=N, replace=True, p=pr)
def predict(A, lp):
    """One-step forward prediction in the log domain: log(A @ exp(lp)).

    The max of lp is factored out before exponentiating for numerical stability.
    """
    scaled = np.exp(lp - np.max(lp))
    return np.max(lp) + np.log(np.dot(A, scaled))
def postdict(A, lp):
    """One-step backward propagation in the log domain: log(exp(lp) @ A).

    Same max-shift trick as `predict`, but the vector multiplies A from the left.
    """
    scaled = np.exp(lp - np.max(lp))
    return np.max(lp) + np.log(np.dot(scaled, A))
def update(y, logB, lp):
    """Measurement update in the log domain: add the log-likelihood row for observation y."""
    return lp + logB[y]
def log_sum_exp_naive(l):
    """Direct log-sum-exp with no overflow protection (kept for comparison)."""
    return np.log(np.exp(l).sum())
def log_sum_exp(l, axis=0):
    """Numerically stable log-sum-exp along `axis`; the reduced axis is kept (keepdims)."""
    m = np.max(l, axis=axis, keepdims=True)
    return m + np.log(np.exp(l - m).sum(axis=axis, keepdims=True))
In [12]:
# Generate parameters
def do_something_magical(S, R, T):
    """Full HMM forward-backward demo.

    Draws a random HMM with S hidden states and R observation symbols
    (Dirichlet(0.7) rows), samples a state/observation trajectory of
    length T, computes smoothed posteriors two ways (alpha-beta and the
    correction smoother), and plots the forward/backward log messages.

    Parameters
    ----------
    S : int
        Number of hidden states.
    R : int
        Number of observation symbols.
    T : int
        Length of the sampled sequence.
    """
    # Column-stochastic parameterization: A[i, j] = p(x_t = i | x_{t-1} = j),
    # B[r, i] = p(y_t = r | x_t = i), p = initial state distribution.
    A = np.random.dirichlet(0.7 * np.ones(S), S).T
    B = np.random.dirichlet(0.7 * np.ones(R), S).T
    p = np.random.dirichlet(0.7 * np.ones(S)).T
    logA = np.log(A)
    logB = np.log(B)

    # Generate data
    x = np.zeros(T, int)
    y = np.zeros(T, int)
    for t in range(T):
        if t == 0:
            x[t] = randgen(p)
        else:
            x[t] = randgen(A[:, x[t - 1]])
        y[t] = randgen(B[:, x[t]])
    # FIX: the original used Python 2 print statements (`print "x", x`),
    # which are a syntax error in Python 3; everything else in this
    # notebook is Python-3 compatible.
    print("x", x)
    print("y", y)

    # Forward pass.
    # Python indexes starting from zero, so (for 1-based math index k)
    # log \alpha_{k|k}   is stored in log_alpha[:, k-1]
    # log \alpha_{k|k-1} is stored in log_alpha_pred[:, k-1]
    log_alpha = np.zeros((S, T))
    log_alpha_pred = np.zeros((S, T))
    for k in range(T):
        if k == 0:
            log_alpha_pred[:, 0] = np.log(p)
        else:
            log_alpha_pred[:, k] = predict(A, log_alpha[:, k - 1])
        log_alpha[:, k] = update(y[k], logB, log_alpha_pred[:, k])

    # Backward pass
    log_beta = np.zeros((S, T))
    log_beta_post = np.zeros((S, T))
    for k in range(T - 1, -1, -1):
        if k == T - 1:
            log_beta_post[:, k] = np.zeros(S)
        else:
            log_beta_post[:, k] = postdict(A, log_beta[:, k + 1])
        log_beta[:, k] = update(y[k], logB, log_beta_post[:, k])
    # Smoothed posterior: p(x_k | y_{1:T}) ∝ alpha_{k|k} * beta_{k|k+1:T}
    log_gamma = log_alpha + log_beta_post

    # Correction smoother (backward recursion over the filtered messages).
    # For numerical stability, everything stays in the log domain.
    log_gamma_corr = np.zeros_like(log_alpha)
    log_gamma_corr[:, T - 1] = log_alpha[:, T - 1]
    for k in range(T - 2, -1, -1):
        # entry (i, j): log p(x_{k+1} = i, x_k = j | y_{1:k})
        log_old_pairwise_marginal = log_alpha[:, k].reshape(1, S) + logA
        log_old_marginal = predict(A, log_alpha[:, k])
        log_new_pairwise_marginal = (log_old_pairwise_marginal
                                     + log_gamma_corr[:, k + 1].reshape(S, 1)
                                     - log_old_marginal.reshape(S, 1))
        # marginalize over x_{k+1} (the row index)
        log_gamma_corr[:, k] = log_sum_exp(log_new_pairwise_marginal, axis=0).reshape(S)

    # Verify that the two smoothers coincide (gam vs gam_corr)
    gam = normalize_exp(log_gamma, axis=0)
    gam_corr = normalize_exp(log_gamma_corr, axis=0)
    plt.figure(figsize=(20, 10))
    plt.subplot(4, 1, 1)
    plt.imshow(log_alpha, interpolation='nearest')
    plt.subplot(4, 1, 2)
    plt.imshow(log_beta, interpolation='nearest')
    plt.show()
    #print(log_gamma)
    #print(log_gamma_corr)
In [13]:
do_something_magical(8, 3, 10)
In [14]:
do_something_magical(8, 3, 20)
In [15]:
do_something_magical(8, 3, 50)
In [16]:
do_something_magic(8, 3, 200)
In [ ]:
do_something_magical(8, 3, 20)
In [ ]:
do_something_magical(8, 400, 20)
In [ ]:
do_something_magical(8, 1000, 50)
In [17]:
do_something_magical(8, 10000, 50)
In [18]:
do_something_magical(10, 1000, 100)
In [20]:
do_something_magical(2, 1000, 100)
In [ ]: