Gaussian Mixture Models and the E-M Algorithm

Now let's look at Gaussian Mixture Models (GMMs) and how to fit them with the Expectation-Maximization (E-M) algorithm.
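
A GMM models the data as a weighted sum of $K$ Gaussian components with amplitudes (mixing weights) $a_k$, means $\mu_k$, and covariances $\Sigma_k$, so the density of a data point $x_i$ is

$$ p(x_i) = \sum_{k=1}^{K} a_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k) $$

E-M alternates two steps. The E step computes the "responsibility" of component $k$ for point $i$,

$$ z_{ik} = \frac{a_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} a_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)} $$

and the M step re-estimates the parameters from those responsibilities:

$$ a_k = \frac{1}{N} \sum_i z_{ik}, \qquad \mu_k = \frac{\sum_i z_{ik}\, x_i}{\sum_i z_{ik}}, \qquad \Sigma_k = \frac{\sum_i z_{ik}\, (x_i - \mu_k)(x_i - \mu_k)^\top}{\sum_i z_{ik}} $$

These are exactly the updates implemented in the em_gmm function below.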


In [1]:
# Boilerplate setup!
%matplotlib inline
import numpy as np
import pylab as plt
# Some utility code for this notebook
import utils

In [2]:
# Generate some clustered data
(amps,means,covs),ax = utils.get_clusters_D()
N = 1000
Xi = utils.sample_clusters(amps, means, covs, N=N)
X = np.vstack(Xi)

# Plot the true cluster memberships
plt.clf()
for i,x in enumerate(Xi):
    plt.plot(x[:,0], x[:,1], 'o', color=utils.colors[i])
plt.title('True cluster labels');
plt.show();
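
The utilities used above aren't shown in this notebook. From the way utils.sample_clusters is called, it presumably draws points from each Gaussian component in proportion to its amplitude and returns one array per cluster. A minimal sketch under that assumption (not the actual utils code):

In [ ]:
# Illustrative sketch only: an assumed implementation of sample_clusters,
# based on how it is called above.  Each component contributes a share of
# the N points proportional to its amplitude.
def sample_clusters(amps, means, covs, N=1000):
    amps = np.array(amps, dtype=float)
    amps /= amps.sum()
    Xi = []
    for amp, mean, cov in zip(amps, means, covs):
        nk = int(round(amp * N))
        Xi.append(np.random.multivariate_normal(mean, cov, size=nk))
    return Xi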



In [36]:
# Here is our Expectation-Maximization algorithm for
# Gaussian Mixture Models, with the M-step update filled in below.

def em_gmm(X, K, diagnostics=None):
    '''
    Expectation-Maximization (E-M) for Gaussian Mixture Models

    *X*: (N,D) data points
    *K*: integer number of components
    
    Returns: (amps, means, covs) describing the K Gaussian components:
    amplitudes (K,), means (K,D), and covariances (K,D,D)
    '''
    N,D = X.shape
    
    # Random initialization: means scattered around the data mean,
    # diagonal covariances set from the per-dimension variance.
    amps = np.ones(K) / K
    std = np.std(X, axis=0)
    means = np.random.normal(loc=np.mean(X, axis=0), scale=std, size=(K,D))
    covs = [np.diag(std**2) for k in range(K)]

    # Run a fixed number of E-M iterations.
    # (In practice you would iterate until the updates converge,
    # which can take many iterations.)
    for step in range(25):
        # E step: compute the "responsibility" (indicator) z[i,k] -- the
        # probability that data point i was drawn from Gaussian component k.
        z = np.zeros((N,K))

        for k,(amp,mean,cov) in enumerate(zip(amps, means, covs)):
            #print('Component', k, ': amp', amp, 'mean', mean, 'cov', cov)
            z[:,k] = amp * utils.gaussian_probability(X, mean, cov)
    
        # Normalize the indicator over *components* --
        # ("this data point had to come from somewhere")
        z /= np.sum(z, axis=1)[:,np.newaxis]

        #### M step: update the amplitudes, means, and covariances ####
        # from the responsibilities z.  The update equations are in:
        # http://dstn.astrometry.net/talks/2015-06-03-phat-michigan-p18.pdf

        # A compact, vectorized version of the update:
        #newamps = np.sum(z, axis=0)
        #newamps /= np.sum(newamps)
        #newmeans = [np.sum(z[:,k][:,np.newaxis] * X, axis=0) / np.sum(z[:,k])
        #            for k in range(K)]
        #newcovs = [np.dot((X - mean).T, z[:,k][:,np.newaxis] * (X - mean)) / np.sum(z[:,k])
        #           for k,mean in enumerate(newmeans)]

        # You could also write it more explicitly:
        newamps = np.sum(z, axis=0)
        newamps /= np.sum(newamps)

        newmeans = np.zeros((K,D))
        for k in range(K):
            zk = z[:,k]
            for d in range(D):
                newmeans[k,d] = np.sum(zk * X[:,d]) / np.sum(zk)

        newcovs = np.zeros((K,D,D))
        for k in range(K):
            zk = z[:,k]
            dx = X - newmeans[k]
            # Responsibility-weighted outer products, centered on the
            # *updated* means.
            newcovs[k,:,:] = np.dot(zk * dx.T, dx) / np.sum(zk)

        if diagnostics is not None:
            diagnostics(step, X, K, amps, means, covs, z,
                        newamps, newmeans, newcovs)
            
        amps, means, covs = newamps, newmeans, newcovs

    return amps, means, covs
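
The E step above relies on utils.gaussian_probability(X, mean, cov) to evaluate the multivariate normal density at each row of X. The utils module isn't shown here, so the sketch below is an assumption about what that helper computes, not its actual code:

In [ ]:
# Illustrative sketch only: an assumed implementation of the multivariate
# normal density, matching how gaussian_probability is called in the E step.
def gaussian_probability(X, mean, cov):
    N, D = X.shape
    dx = X - mean                       # (N,D) offsets from the mean
    icov = np.linalg.inv(cov)           # (D,D) inverse covariance
    # Squared Mahalanobis distance for each data point
    mahal = np.sum(np.dot(dx, icov) * dx, axis=1)
    norm = 1.0 / np.sqrt((2. * np.pi)**D * np.linalg.det(cov))
    return norm * np.exp(-0.5 * mahal)  # (N,) density values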

In [37]:
K=2

# It's handy to plot the progress of the algorithm while testing, like this:

#amps,means,covs = em_gmm(X, K, diagnostics=utils.plot_em)

amps,means,covs = em_gmm(X, K)

# Plot the final fit
utils.plot_em(0, X, K, amps, means, covs, None, None, None, None,
              show=False)
plt.title('E-M result');
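
A natural follow-up is to assign each point to its most probable component by recomputing the responsibilities from the fitted parameters. A short sketch, reusing utils.gaussian_probability and utils.colors as above:

In [ ]:
# Recompute responsibilities from the fitted parameters and take the most
# probable component for each point (a "hard" assignment).
z = np.zeros((len(X), K))
for k in range(K):
    z[:,k] = amps[k] * utils.gaussian_probability(X, means[k], covs[k])
z /= np.sum(z, axis=1)[:,np.newaxis]
labels = np.argmax(z, axis=1)

plt.clf()
for k in range(K):
    plt.plot(X[labels == k, 0], X[labels == k, 1], 'o', color=utils.colors[k])
plt.title('Hard assignments from the E-M fit');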


