In [1]:
# Load packages
import numpy as np
import matplotlib.pyplot as plt
import dpmm
%matplotlib inline
In [2]:
# Define a mixture model of N-dimensional Gaussians to play with.
class gaussND(object):
    """A multivariate Gaussian with mean vector `mu` and covariance matrix `Sigma`."""

    def __init__(self, mu, Sigma):
        """Store the parameters; `d` is the dimensionality, taken from len(mu)."""
        self.mu, self.Sigma = mu, Sigma
        self.d = len(mu)

    def sample(self, size=None):
        """Draw from the Gaussian using numpy's global random state.

        `size` is handed straight to np.random.multivariate_normal: None gives
        a single d-vector, an integer n gives an (n, d) array.
        """
        draw = np.random.multivariate_normal
        return draw(self.mu, self.Sigma, size)
class MM(object):
    """Finite mixture model.

    components   sequence of objects exposing a `d` attribute and a
                 `sample(size=None)` method (e.g. gaussND instances)
    proportions  mixture weights, passed as `pvals` to np.random.multinomial
    """

    def __init__(self, components, proportions):
        self.components = components
        self.proportions = proportions
        # Dimensionality is taken from the first component.
        self.d = self.components[0].d

    def sample(self, size=None):
        """Draw from the mixture using numpy's global random state.

        size=None  returns a single draw from one randomly chosen component.
        size=n     returns an (n, d) float array.  Rows are grouped by
                   component (all draws of the first component come first,
                   and so on); they are not shuffled.
        """
        if size is None:
            nums = np.random.multinomial(1, self.proportions)
            # Exactly one entry of `nums` equals 1: the class that got picked.
            # (multinomial returns an ndarray, which has no .index method.)
            c = list(nums).index(1)
            return self.components[c].sample()
        else:
            out = np.empty((size, self.d), dtype=float)
            # How many of the `size` draws come from each component.
            nums = np.random.multinomial(size, self.proportions)
            i = 0
            for component, num in zip(self.components, nums):
                out[i:i+num] = component.sample(size=num)
                i += num
            return out
In [3]:
# Mixture model parameters: three 2-D Gaussians with isotropic covariances.
mu = [np.array([0.0, 0.0]), np.array([1.0, 1.0]), np.array([0.0, 1.0])]  # component means
Sigma = [0.1**2 * np.eye(2), 0.2**2 * np.eye(2), 0.2**2 * np.eye(2)]  # component covariances
p = [0.25, 0.4, 0.35]  # mixture proportions
# Pair each mean with its covariance to build the components.
model = MM([gaussND(m, S) for m, S in zip(mu, Sigma)], p)
In [4]:
# Generate some data and plot it.
data = model.sample(size=100)
f = plt.figure(figsize=(6,4))
ax = f.add_subplot(111)
ax.scatter(data[:,0], data[:,1], alpha=0.5)
# Figure.show() needs a GUI backend and only emits a warning under
# `%matplotlib inline`; plt.show() renders the figure in the notebook.
plt.show()
In [5]:
# Build the Dirichlet Process Mixture Model. The conjugate prior is
# Normal-Inverse-Wishart, which models the mean and covariance matrix of an
# N-dimensional Gaussian.
# Hyperparameters, handed positionally to dpmm.NormInvWish below:
mu_0 = np.array([0.5, 0.5])
kappa_0 = 1.1
Lam_0 = 2 * np.eye(2)
nu_0 = 2
# Dirichlet Process clustering parameter.
# Set higher to infer more components, lower to infer fewer.
alpha = 1.0
cp = dpmm.NormInvWish(mu_0, kappa_0, Lam_0, nu_0)
dp = dpmm.DPMM(cp, alpha, data)
In [6]:
# Burn-in: run the sampler before any state is recorded.
n_burn_in = 100
dp.update(n_burn_in)
In [7]:
# Sample: record 10 snapshots of the sampler state, calling dp.update(10)
# before each one.
phis = []
nphis = []
for i in range(10):  # range (not xrange) so the cell runs on Python 2 and 3
    dp.update(10)
    phis.append(list(dp.phi)) # Use list() to get a copy
    nphis.append(list(dp.nphi))
In [8]:
# Modified code from http://stackoverflow.com/questions/9081553/python-scatter-plot-size-and-style-of-the-marker/24567352#24567352
def ellipses(x, y, s, q, pa, c='b', ax=None, vmin=None, vmax=None, **kwargs):
    """Scatter plot of ellipses, added to an axes as one PatchCollection.

    x, y        center coordinates: scalars for a single ellipse, or
                equal-length sequences for several.
    s           size.  Each ellipse gets width s*sqrt(q) and height
                s/sqrt(q), so s**2 is the product of the two full axis lengths.
    q           minor-to-major axes ratio b/a.
    pa          position angle CCW in deg.
    c           a color string applied to every ellipse, or numeric values
                that are color-mapped through the collection's cmap/norm.
    ax          axes to draw on; defaults to the current axes.
    vmin, vmax  color limits; only used when `c` is not a string.
    kwargs      forwarded to PatchCollection (e.g. alpha).  Any `color`
                entry is overwritten by the value derived from `c`.

    Returns the PatchCollection that was added to the axes.
    """
    from matplotlib.patches import Ellipse
    from matplotlib.collections import PatchCollection

    # `basestring` only exists on Python 2; on Python 3 `str` is the string type.
    try:
        string_types = basestring  # noqa: F821
    except NameError:
        string_types = str

    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()

    if isinstance(c, string_types):
        color = c     # ie. use colors.colorConverter.to_rgba_array(c)
    else:
        color = None  # use cmap, norm after collection is created
    kwargs.update(color=color)

    w, h = s*np.sqrt(q), s/np.sqrt(q)
    # `angle` is passed by keyword: newer matplotlib releases no longer accept
    # it positionally, while the keyword form works on old releases as well.
    if np.isscalar(x):
        patches = [Ellipse((x, y), w, h, angle=pa)]
    else:
        patches = [Ellipse((x_, y_), w_, h_, angle=pa_)
                   for x_, y_, w_, h_, pa_ in zip(x, y, w, h, pa)]
    collection = PatchCollection(patches, **kwargs)

    if color is None:
        collection.set_array(np.asarray(c))
        if vmin is not None or vmax is not None:
            collection.set_clim(vmin, vmax)

    ax.add_collection(collection)
    ax.autoscale_view()
    return collection
In [9]:
def plot_sample(phi, nphi, axis=None, **kwargs):
    """Draw the 95% covariance ellipse of every component in one model sample.

    phi     sequence of component parameters; each item is indexed as
            item[0] -> (x, y) mean and item[1] -> 2x2 covariance matrix.
    nphi    sequence iterated in step with `phi` (presumably the number of
            data points per component -- TODO confirm); not used for drawing.
    axis    matplotlib axes to draw on; defaults to the current axes.
    kwargs  forwarded to ellipses() (e.g. c, alpha).
    """
    if axis is None:
        axis = plt.gca()
    # Forward the target axes; otherwise ellipses() falls back to plt.gca()
    # and `axis` would be silently ignored.
    kwargs.setdefault('ax', axis)
    # 0.95 quantile of the chi-square distribution with 2 degrees of freedom.
    chi2_95 = 5.991
    for ph, n in zip(phi, nphi):
        # eigh returns eigenvalues in ascending order; column vec[:, i] is the
        # eigenvector belonging to val[i].
        val, vec = np.linalg.eigh(ph[1])
        # The 95% ellipse has semi-axes sqrt(chi2_95*val[i]).  ellipses() draws
        # full axis lengths w = s*sqrt(q) and h = s/sqrt(q), so s is chosen to
        # give w = 2*sqrt(chi2_95*val[0]) and h = 2*sqrt(chi2_95*val[1]).
        s = 2.0*np.sqrt(chi2_95*np.sqrt(val[0]*val[1]))
        q = np.sqrt(val[0]/val[1])
        # The width w lies along the eigenvector of val[0], i.e. vec[:, 0];
        # its direction is atan2(y-component, x-component).
        pa = np.arctan2(vec[1,0], vec[0,0])*180/np.pi
        ellipses(ph[0][0], ph[0][1], s, q, pa, **kwargs)
In [10]:
# Plot data and model realizations. Note that the model components are all shaded the same, even though some of
# them are actually weighted higher than others. (Need to come up with a better visualization...)
f = plt.figure(figsize=(6,4))
ax = f.add_subplot(111)
ax.scatter(data[:,0], data[:,1], alpha=0.5)
# One call per recorded sampler state; the ellipses of all states are overlaid.
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, c='r', alpha=0.1)
# Figure.show() needs a GUI backend and only emits a warning under
# `%matplotlib inline`; plt.show() renders the figure in the notebook.
plt.show()
In [11]:
# Sample some more and plot again.
phis = []
nphis = []
# Record 10 further snapshots, calling dp.update(10) before each one.
for i in range(10):  # range (not xrange) so the cell runs on Python 2 and 3
    dp.update(10)
    phis.append(list(dp.phi)) # Need list() to get a copy
    nphis.append(list(dp.nphi))
f = plt.figure(figsize=(6,4))
ax = f.add_subplot(111)
ax.scatter(data[:,0], data[:,1], alpha=0.5)
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, c='r', alpha=0.1)
# Figure.show() needs a GUI backend and only emits a warning under
# `%matplotlib inline`; plt.show() renders the figure in the notebook.
plt.show()
In [12]:
# Sample some more and plot again.
phis = []
nphis = []
# Record 10 further snapshots, calling dp.update(10) before each one.
for i in range(10):  # range (not xrange) so the cell runs on Python 2 and 3
    dp.update(10)
    phis.append(list(dp.phi)) # Need list() to get a copy
    nphis.append(list(dp.nphi))
f = plt.figure(figsize=(6,4))
ax = f.add_subplot(111)
ax.scatter(data[:,0], data[:,1], alpha=0.5)
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, c='r', alpha=0.1)
# Figure.show() needs a GUI backend and only emits a warning under
# `%matplotlib inline`; plt.show() renders the figure in the notebook.
plt.show()
In [ ]: