In [1]:
# Load packages
import numpy as np
import matplotlib.pyplot as plt
import dpmm
%matplotlib inline
In [2]:
# Make a 1-D Gaussian mixture model with all component means = 0 and specified variances.
class gaussV(object):
    """Zero-mean 1-D Gaussian distribution with variance V.

    Provides `sample` for drawing random variates and `__call__` for
    evaluating the pdf, so instances can serve as mixture components.
    """
    def __init__(self, V):
        # V is the variance (not the standard deviation).
        self.V = V

    def sample(self, size=None):
        """Draw sample(s); scalar when size is None, ndarray otherwise."""
        sigma = np.sqrt(self.V)
        return np.random.normal(scale=sigma, size=size)

    def __call__(self, x):
        """Evaluate the Gaussian pdf at x (scalar or array)."""
        norm = 1.0 / np.sqrt(2.0 * np.pi * self.V)
        return norm * np.exp(-(x * x) / (2.0 * self.V))
class MM(object):
    """Finite mixture model.

    Parameters
    ----------
    components : list
        Objects exposing `sample(size=None)` and `__call__(x)` (a pdf),
        e.g. gaussV instances.
    proportions : list of float
        Mixing weights; must sum to 1 (passed to np.random.multinomial).
    """
    def __init__(self, components, proportions):
        self.components = components
        self.proportions = proportions

    def sample(self, size=None):
        """Draw sample(s) from the mixture.

        Returns a scalar when size is None, else a float ndarray of shape
        (size,).  Note: in the array case samples are ordered by component,
        not shuffled.
        """
        if size is None:
            counts = np.random.multinomial(1, self.proportions)
            # BUG FIX: was `num.index(1)` — `num` was undefined (NameError)
            # and multinomial returns an ndarray, which has no .index method.
            c = int(np.argmax(counts))
            return self.components[c].sample()
        else:
            out = np.empty((size,), dtype=float)
            # Partition `size` draws among the components by their weights.
            counts = np.random.multinomial(size, self.proportions)
            i = 0
            for component, n in zip(self.components, counts):
                out[i:i+n] = component.sample(size=n)
                i += n
            return out

    def __call__(self, x):
        """Evaluate the mixture pdf at x: sum of weighted component pdfs."""
        return np.sum([p*c(x) for p, c in zip(self.proportions, self.components)], axis=0)

    def plot(self, axis=None, **kwargs):
        """Plot the mixture model pdf on the grid [-2, 2)."""
        if axis is None:
            axis = plt.gca()
        x = np.arange(-2, 2, 0.01)
        y = self(x)
        axis.plot(x, y, **kwargs)
In [3]:
# Generative distribution: a two-component, zero-mean Gaussian mixture
# with equal weights — one broad component and one narrow one.
V = [0.3 ** 2, 0.05 ** 2]   # component variances
p = [0.5, 0.5]              # mixing proportions
model = MM([gaussV(v) for v in V], p)
In [4]:
# Visualize the true (generative) mixture pdf.
f = plt.figure(figsize=(6, 4))
ax = f.add_subplot(111)
model.plot(axis=ax)
f.show()
In [5]:
# Draw a data set from the true mixture for the inference below, then
# overlay each sample as a faint vertical line on the true pdf.
data = model.sample(size=100)
f = plt.figure(figsize=(6, 4))
ax = f.add_subplot(111)
model.plot(axis=ax)
for sample in data:
    ax.axvline(sample, alpha=0.1)
f.show()
In [39]:
# Create the Dirichlet Process Mixture Model.
# Use the Inverse Gamma prior for a Gaussian with known mean and unknown variance.
mu = 0.0 # Known (fixed) component mean
alpha = 1.0 # How well we know beta
beta = 100.0 # 1/typical-variance
dp_alpha = 0.1 # Dirichlet Process clustering parameter. Set lower to infer fewer components.
cp = dpmm.InvGamma(alpha, beta, mu)  # conjugate prior over each component's unknown variance
dp = dpmm.DPMM(cp, dp_alpha, data)  # DP mixture conditioned on the 100 samples drawn above
In [40]:
# Burn in
# Run 100 Gibbs sweeps so the sampler forgets its initialization before we
# start recording posterior samples.
dp.update(100)
In [41]:
# Record posterior samples, thinning by keeping every 10th Gibbs sweep.
phis = []   # per-sample lists of component variances
nphis = []  # per-sample lists of component occupation counts
# Sample
for _ in range(50):  # `range` is Python 2/3 portable; `xrange` is Py2-only
    dp.update(10)
    phis.append(list(dp.phi))  # Need list() to get a copy, not a live reference
    nphis.append(list(dp.nphi))
In [42]:
def plot_sample(phi, nphi, axis=None, xlo=-1.0, xhi=1.0, dx=0.01, **kwargs):
    """Plot one posterior draw of the DP mixture density.

    Parameters
    ----------
    phi : sequence of float
        Component variances (components are zero-mean Gaussians).
    nphi : sequence of int
        Data points assigned to each component; used as mixture weights.
    axis : matplotlib axis, optional
        Axis to draw on; defaults to the current axis.
    xlo, xhi, dx : float, optional
        Evaluation grid bounds and step (defaults match the original
        hard-coded arange(-1, 1, 0.01)).
    **kwargs
        Forwarded to axis.plot (e.g. alpha, c).
    """
    x = np.arange(xlo, xhi, dx)
    y = np.zeros_like(x)
    total = float(sum(nphi))  # hoisted: normalization is loop-invariant
    for var, n in zip(phi, nphi):
        y += n * np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var) / total
    if axis is None:
        axis = plt.gca()
    axis.plot(x, y, **kwargs)
In [43]:
# Overlay: true pdf, data (vertical lines), posterior density draws (black),
# and a normalized histogram of the data (red).
f = plt.figure(figsize=(6, 4))
ax = f.add_subplot(111)
model.plot(axis=ax)
for x in data:
    ax.axvline(x, alpha=0.1)
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, alpha=0.1, c='k')
ax.set_xlim(-1, 1)
# `normed` was removed from Axes.hist in matplotlib 3.1; `density` is the equivalent.
ax.hist(data, 20, alpha=0.3, color='r', density=True)
f.show()
In [52]:
# Try a large-ish data set.
# For WL, we can concatenate the e1 and e2 samples together after deshearing. Then we have 20000 e samples per
# GREAT3 field.
data = model.sample(size=20000)
# Same prior and concentration parameters as the small-data run above.
mu = 0.0
alpha = 1.0
beta = 100.0
dp_alpha = 0.1
cp = dpmm.InvGamma(alpha, beta, mu)
dp = dpmm.DPMM(cp, dp_alpha, data)
In [53]:
prun dp.update(100) # about 35 sec
In [54]:
# a few minutes to generate 500 samples and store every 10th one (i.e., store 50).
phis = []   # per-sample lists of component variances
nphis = []  # per-sample lists of component occupation counts
# Sample
for _ in range(50):  # `range` is Python 2/3 portable; `xrange` is Py2-only
    dp.update(10)
    phis.append(list(dp.phi))  # Need list() to get a copy
    nphis.append(list(dp.nphi))
In [55]:
# Overlay: true pdf, posterior density draws (black), and a normalized
# histogram of the 20000 samples (red).  Individual data lines are skipped
# here — 20000 axvlines would be unreadably dense and slow.
f = plt.figure(figsize=(6, 4))
ax = f.add_subplot(111)
model.plot(axis=ax)
# for x in data:
#     ax.axvline(x, alpha=0.002)
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, alpha=0.1, c='k')
ax.set_xlim(-1, 1)
# `normed` was removed from Axes.hist in matplotlib 3.1; `density` is the equivalent.
ax.hist(data, 100, alpha=0.3, color='r', density=True)
f.show()
In [56]:
# Zoom in on the narrow central component to check how well it is recovered.
f = plt.figure(figsize=(6, 4))
ax = f.add_subplot(111)
model.plot(axis=ax)
# for x in data:
#     ax.axvline(x, alpha=0.002)
for phi, nphi in zip(phis, nphis):
    plot_sample(phi, nphi, axis=ax, alpha=0.1, c='k')
ax.set_xlim(-0.1, 0.1)
# `normed` was removed from Axes.hist in matplotlib 3.1; `density` is the equivalent.
ax.hist(data, 1000, alpha=0.3, color='r', density=True)
f.show()
In [57]:
nphis
Out[57]:
In [ ]: