Expectation Maximization

To illustrate the Expectation Maximization (EM) algorithm, we first draw samples from four 2-D Gaussian distributions and plot them:


In [ ]:
# Generate data: sample from a mixture of four 2-D Gaussians
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

class Distribution(object):
    """Holds the plotting color and the samples drawn from one Gaussian."""
    def __init__(self, color):
        self.color = color

num_gaussians = 4
pi = np.array([0.1, 0.3, 0.2, 0.3])
pi = pi / pi.sum()  # mixing weights must sum to 1 (the raw values sum to 0.9)
num_samples = 10000

mu = ([1.7, .5],
      [2, 4],
      [0, 6],
      [5, 6]
     )
sigma = ([[.9, 0], [0, .5]],
         [[.4, .3], [.3, .5]],
         [[2, .7], [.2, .8]],
         [[.6, .6], [.3, .6]]
        )
# A covariance matrix must be symmetric; symmetrize the two that are not
sigma = [(np.array(s) + np.array(s).T) / 2 for s in sigma]

distributions = {}
colors = ['r', 'g', 'b', 'y']
for i in range(num_gaussians):
    name = 'Sampled Distribution {}'.format(i + 1)
    distributions[name] = Distribution(colors[i])

    # Draw a share of the samples proportional to this component's weight
    distributions[name].samples = np.random.multivariate_normal(
        mu[i], sigma[i], int(pi[i] * num_samples))

# Plot everything
fig, ax = plt.subplots()
for name, distribution in distributions.items():  # .iteritems() exists only in Python 2
    ax.scatter(distribution.samples[:, 0],
               distribution.samples[:, 1],
               c=distribution.color,
               s=20,
               lw=0,
               label=name
               )
ax.set_title('Sampled distributions')
ax.legend()

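Next, we try to recover these distributions from the pooled, unlabeled samples using EM, modeling the data as a mixture of Gaussians. At a high level, the algorithm alternates between two steps:

Expectation step: Holding the current parameters fixed, compute for each data point the posterior probability that each Gaussian generated it (its responsibility). These responsibilities determine the expected value of the complete-data log-likelihood.

Maximization step: Holding the responsibilities fixed, re-estimate the parameters (mixing weights, means, and covariances) that maximize this expected log-likelihood.

After each maximization step, check for convergence (for example, whether the log-likelihood of the observed data has stopped increasing); if it has converged, stop, otherwise return to the expectation step.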
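
For a mixture of K Gaussians with weights π_k, means μ_k and covariances Σ_k (the quantities `pi`, `mu` and `sigma` in the data cell above), both steps have standard closed forms. Here γ_ik is the responsibility of component k for point x_i, N_k is the effective number of points assigned to component k, and n is the total number of points; the symbol names are chosen for this note. In the covariance update, μ_k is the mean just computed in the same M-step.

```latex
\begin{aligned}
\text{E-step:}\quad & \gamma_{ik} = \frac{\pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)}
                      {\sum_{j=1}^{K} \pi_j\,\mathcal{N}(x_i \mid \mu_j, \Sigma_j)} \\[4pt]
\text{M-step:}\quad & N_k = \sum_{i=1}^{n} \gamma_{ik}, \qquad
                      \pi_k = \frac{N_k}{n}, \qquad
                      \mu_k = \frac{1}{N_k} \sum_{i=1}^{n} \gamma_{ik}\, x_i, \\
                    & \Sigma_k = \frac{1}{N_k} \sum_{i=1}^{n} \gamma_{ik}\,(x_i - \mu_k)(x_i - \mu_k)^{\top} \\[4pt]
\text{Convergence:}\quad & \log L = \sum_{i=1}^{n} \log \sum_{k=1}^{K} \pi_k\,\mathcal{N}(x_i \mid \mu_k, \Sigma_k)
\end{aligned}
```

Each E-step/M-step pair never decreases log L, so monitoring its change is a natural stopping test.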


In [ ]:
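# Initial setup: random means, identity covariances, uniform mixing weights
K = 4  # number of components; assumed known here (how to choose K in general is an open question)
mu_hats = []
sigma_hats = []
pi_hats = []
for k in range(K):
    mu_hats.append(np.random.uniform(-10, 10, size=2))  # random 2-D mean (np.rand does not exist)
    sigma_hats.append(np.eye(2))                         # start each covariance at the identity
    pi_hats.append(1.0 / K)                              # equal initial weight for every component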
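
Starting from initial guesses like these, the EM iteration can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the notebook's own implementation: the function names `log_gaussian_pdf` and `em_gmm`, the tolerance `tol`, the iteration cap `max_iter`, and the small covariance ridge `reg` are hypothetical choices made here. The E-step works in log space (log-sum-exp) so that points far from every component do not underflow to zero probability.

```python
import numpy as np


def log_gaussian_pdf(X, mu, sigma):
    """Log-density of N(mu, sigma) at every row of X (shape (n, d))."""
    d = X.shape[1]
    diff = X - mu
    # Sigma^{-1} (x - mu) for all rows, via a linear solve instead of an explicit inverse
    sol = np.linalg.solve(sigma, diff.T).T
    maha = np.sum(diff * sol, axis=1)              # squared Mahalanobis distances
    _, logdet = np.linalg.slogdet(sigma)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)


def em_gmm(X, mu_hats, sigma_hats, pi_hats, max_iter=200, tol=1e-6, reg=1e-6):
    """Fit a K-component Gaussian mixture to X with EM, starting from the given guesses.

    Returns the fitted means, covariances, weights, and the log-likelihood history.
    """
    mu = np.array(mu_hats, dtype=float)            # (K, d)
    sigma = np.array(sigma_hats, dtype=float)      # (K, d, d)
    pi = np.array(pi_hats, dtype=float)            # (K,)
    n, d = X.shape
    K = len(pi)
    history = []
    for _ in range(max_iter):
        # E-step: log(pi_k) + log N(x_i | mu_k, Sigma_k), shape (n, K)
        log_w = np.column_stack(
            [np.log(pi[k]) + log_gaussian_pdf(X, mu[k], sigma[k]) for k in range(K)])
        # Log-sum-exp over components gives log p(x_i) without underflow
        m = log_w.max(axis=1, keepdims=True)
        log_px = m[:, 0] + np.log(np.exp(log_w - m).sum(axis=1))
        gamma = np.exp(log_w - log_px[:, None])    # responsibilities; each row sums to 1
        history.append(log_px.sum())               # log-likelihood under the current parameters

        # M-step: closed-form maximizers given the responsibilities
        Nk = gamma.sum(axis=0)                     # effective number of points per component
        pi = Nk / n
        mu = (gamma.T @ X) / Nk[:, None]
        for k in range(K):
            diff = X - mu[k]
            sigma[k] = (gamma[:, k, None] * diff).T @ diff / Nk[k]
            sigma[k] += reg * np.eye(d)            # small ridge: an added safeguard, not part of plain EM

        # Stop once the log-likelihood has stopped improving
        if len(history) > 1 and abs(history[-1] - history[-2]) < tol:
            break
    return mu, sigma, pi, history
```

To apply it to the data above, the samples would first be pooled into one unlabeled array, for example `X = np.vstack([d.samples for d in distributions.values()])`, and passed in together with initial means, covariances, and weights of matching shapes. EM only reaches a local optimum, so the fitted components depend on the initialization.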

In [1]:
from IPython.core.display import HTML

# Borrowed style from Probabilistic Programming and Bayesian Methods for Hackers
def css_styling():
    styles = open("../styles/custom.css", "r").read()
    return HTML(styles)
css_styling()