## Expectation Maximization

To illustrate the Expectation Maximization (EM) process, we first draw samples from four two-dimensional Gaussian distributions and plot the sample points:



In [ ]:

# data
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


class Distribution(object):
    """Holds the plotting color and sampled points of one Gaussian."""
    def __init__(self, color):
        self.color = color


num_gaussians = 4
pi = np.array([0.1, 0.3, 0.2, 0.3])
pi = pi / pi.sum()  # mixing weights, normalized so they sum to 1
num_samples = 10000

mu = ([1.7, .5],
      [2, 4],
      [0, 6],
      [5, 6]
      )
# Covariance matrices; each must be symmetric positive definite to sample from
sigma = ([[.9, 0], [0, .5]],
         [[.4, .3], [.3, .5]],
         [[2, .45], [.45, .8]],
         [[.6, .45], [.45, .6]]
         )

distributions = {}
colors = ['r', 'g', 'b', 'y']
for i in range(num_gaussians):
    name = 'Sampled Distribution {}'.format(i + 1)
    distributions[name] = Distribution(colors[i])

    # Draw a number of points proportional to this component's mixing weight
    distributions[name].samples = np.random.multivariate_normal(
        mu[i], sigma[i], int(pi[i] * num_samples))

# Plot everything
fig, ax = plt.subplots()
for name, distribution in distributions.items():
    ax.scatter(distribution.samples[:, 0],
               distribution.samples[:, 1],
               c=distribution.color,
               s=20,
               lw=0
               )
ax.set_title('Sampled distributions')



Next, we try to approximate these distributions using EM. At a high level, the algorithm iterates between two steps:

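1. **Expectation step:** Using the current (fixed) parameters, compute for each data point the posterior probability that each Gaussian generated it, called that component's *responsibility* for the point. Together, the responsibilities define the expected value of the complete-data log-likelihood.

2. **Maximization step:** Re-estimate the parameters (means, covariances, and mixing weights) that maximize this expected log-likelihood.

After each maximization step, check for convergence (for example, the log-likelihood no longer increasing meaningfully). If converged, stop; otherwise, return to the expectation step.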
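
For a Gaussian mixture, both steps have standard closed-form updates. Using the notebook's symbols $\pi_k$, $\mu_k$, and $\Sigma_k$ for the mixing weight, mean, and covariance of component $k$, and introducing $\gamma_{nk}$ for the responsibility of component $k$ for sample $x_n$, $N_k$ for the effective number of samples assigned to component $k$, and $N$ for the total number of samples, the expectation step computes

$$\gamma_{nk} = \frac{\pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j)}$$

and the maximization step updates

$$N_k = \sum_{n=1}^{N} \gamma_{nk}, \qquad \mu_k = \frac{1}{N_k} \sum_{n=1}^{N} \gamma_{nk} \, x_n, \qquad \Sigma_k = \frac{1}{N_k} \sum_{n=1}^{N} \gamma_{nk} \, (x_n - \mu_k)(x_n - \mu_k)^\top, \qquad \pi_k = \frac{N_k}{N}.$$

A common convergence check monitors the log-likelihood of the data, which EM never decreases from one iteration to the next:

$$\ln p(X) = \sum_{n=1}^{N} \ln \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)$$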



In [ ]:

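# Initial setup
import numpy as np

K = 4  # number of components; assumed known here (but how do we know?)
mu_hats = []
sigma_hats = []
pi_hats = []
for k in range(K):
    # Random 2-D starting mean, unit covariance, and (assumed) equal weights
    mu_hats.append(np.random.randint(-10, 10, size=2).astype(float))
    sigma_hats.append(np.eye(2))
    pi_hats.append(1.0 / K)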
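
Given initial estimates such as `mu_hats`, `sigma_hats`, and `pi_hats` from the setup cell, the expectation/maximization loop described earlier can be sketched in plain NumPy. This is a minimal sketch, not the notebook's own implementation: it assumes full covariance matrices and stops when the log-likelihood changes by less than a tolerance, and the names `gaussian_pdf`, `em_gmm`, `max_iter`, and `tol` are hypothetical.

```python
import numpy as np


def gaussian_pdf(X, mean, cov):
    """Multivariate normal density evaluated at every row of X."""
    d = X.shape[1]
    diff = X - mean
    # Squared Mahalanobis distance of each point from the mean
    maha = np.einsum('ij,jk,ik->i', diff, np.linalg.inv(cov), diff)
    norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return np.exp(-0.5 * maha) / norm


def em_gmm(X, mu_hats, sigma_hats, pi_hats, max_iter=200, tol=1e-6):
    """Fit a K-component Gaussian mixture to X by EM from initial estimates."""
    mu_hats = np.array(mu_hats, dtype=float)        # shape (K, d)
    sigma_hats = np.array(sigma_hats, dtype=float)  # shape (K, d, d)
    pi_hats = np.array(pi_hats, dtype=float)        # shape (K,)
    n = X.shape[0]
    K = len(pi_hats)
    log_likelihoods = []

    for _ in range(max_iter):
        # E-step: pi_k * N(x_n | mu_k, Sigma_k) for every point and component
        weighted = np.column_stack([
            pi_hats[k] * gaussian_pdf(X, mu_hats[k], sigma_hats[k])
            for k in range(K)
        ])
        totals = weighted.sum(axis=1, keepdims=True)
        resp = weighted / totals  # responsibilities; each row sums to 1
        # Log-likelihood of the data under the parameters used in this E-step
        log_likelihoods.append(np.log(totals).sum())

        # M-step: closed-form updates weighted by the responsibilities
        n_k = resp.sum(axis=0)  # effective number of points per component
        mu_hats = (resp.T @ X) / n_k[:, None]
        for k in range(K):
            diff = X - mu_hats[k]
            sigma_hats[k] = (resp[:, k, None] * diff).T @ diff / n_k[k]
        pi_hats = n_k / n

        # Stop once the log-likelihood changes by less than tol
        if (len(log_likelihoods) > 1
                and abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol):
            break

    return mu_hats, sigma_hats, pi_hats, log_likelihoods
```

To run it on the samples from the first cell, one could pool them with `X = np.vstack([d.samples for d in distributions.values()])` and call `em_gmm(X, mu_hats, sigma_hats, pi_hats)`. Because densities are computed directly rather than in log space, a poor random start can leave a component with almost no responsibility and a degenerate covariance; working with log densities (log-sum-exp) and restarting from several initializations are common safeguards.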



