A while ago Dahua Lin posted a message on julia-stats calling for the construction of a new language (within Julia) that would take general graphical models (in abstract form) and 'compile' them to Julia code. This would free modellers from the hassle of re-implementing a whole bunch of inference algorithms (like EM) over and over again. It's a very intriguing idea and would be awesome if it could be made to work. His slides are here (I'm linking them because I will be referring to them in this write-up).
Inspired by that, I decided to use the Distributions.jl package as a starting point and see how far I could get just by defining types and generic functions on those types to carry out various inference procedures (such as message passing). Note that I'm not defining a new language here, just exploring what the Julia compiler itself can do. Pleasantly enough, I found that the interface Distributions.jl provides is rich enough to implement EM on both of the models Lin mentions in his presentation: a small amount of generic EM code plus a high-level specification of each model is all it takes, with very little 'grunt' work. To whet your appetite, below is a generic implementation of the EM algorithm that works, without modification, on any type of mixture. Disclaimer: none of the code listed here is meant to be the 'optimal' way of doing things; it's given as an example and, mostly, for fun.
First, let's define a generic mixture type:
In [1]:
using Distributions
# A 'closed' mixture model defining a full generative model
type MixtureModel{T} <: Distribution
    mixing::Distribution   # Distribution over mixing components
    component::Vector{T}   # Individual component distributions
end
Here, T is the type of the mixture components. So, for instance, if T were MvNormal, you would get a multivariate normal mixture.
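For concreteness, here is a toy instantiation of this type (my own illustrative snippet, not from the original post); any component type from Distributions.jl that supports logpdf and weighted fitting should slot in the same way:

# A hypothetical two-component MixtureModel{MvNormal} in two dimensions,
# with equal mixing weights and unit covariances.
toy = MixtureModel(Categorical([0.5, 0.5]),
                   [MvNormal(zeros(2), eye(2)), MvNormal(ones(2), eye(2))])

Now that we have the type in place, we can do EM updates on it generically. The following function, fit_mm_em, does that: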
In [2]:
# For a model m and observations x (one data point per column), return the
# N x K matrix of unnormalized log-probabilities over the latent component
# assignments (these get normalized into responsibilities in the E-step).
function infer(m::MixtureModel, x)
    K = length(m.component)   # number of mixture components
    N = size(x, 2)            # number of data points
    lq = Array(Float64, N, K)
    for k = 1:K
        lq[:,k] = logpdf(m.component[k], x) .+ logpdf(m.mixing, k)
    end
    return lq
end
# Convert log-probabilities to (unnormalized) probabilities, shifting by
# the maximum for numerical stability
logp_to_p(lp) = exp(lp .- maximum(lp))

function fit_mm_em{T}(m::MixtureModel{T}, x)
    # Expectation step: responsibilities q[n,k] of component k for point n
    lq = infer(m, x)
    # Normalize the log-probabilities and convert to probabilities
    q = logp_to_p(lq)
    q = q ./ sum(q, 2)
    # Maximization step: refit each component with its column of
    # responsibilities, and the mixing distribution with the column sums
    cr = 1:length(m.component)
    comps = [fit_em(m.component[k], x, q[:,k]) for k = cr]
    mix = fit_em(m.mixing, [cr], vec(sum(q, 1)))
    MixtureModel{T}(mix, comps)
end

# 'Fallback' fit: plain weighted maximum-likelihood estimation
fit_em(m::Distribution, x, w) = fit_mle(typeof(m), x, w)
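For reference, the E-step in fit_mm_em is computing the standard mixture responsibilities (this formula is my own recap of textbook EM, not something from the original post):

$$ q_{nk} = \frac{\pi_k \, p(x_n \mid \theta_k)}{\sum_j \pi_j \, p(x_n \mid \theta_j)} $$

where $\pi_k$ is the mixing probability of component $k$ and $\theta_k$ its parameters. The M-step then refits each component by weighted maximum likelihood with the $q_{nk}$ as weights, which is exactly what the fit_mle fallback above does.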
We can test this on a simple Gaussian mixture:
In [8]:
# Data: two clusters of 4-D points, with means at -3 and +3 in every dimension
x = cat(2, randn(4,64).-3, randn(4,32).+3)
# Initial guess for the components and mixture weights (random means)
sigma = 0.1*cov(x')   # empirical guesstimate of the covariance
comps = [MvNormal(randn(4), sigma), MvNormal(randn(4), sigma)]
m = MixtureModel(Categorical([0.5, 0.5]), comps)
# Run the EM algorithm for a few iterations
for i = 1:10
    m = fit_mm_em(m, x)
end
println(m.component[1].μ)
println(m.component[2].μ)
println(m.component[2].Σ)
println(m.mixing.prob)
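Before moving on, here is a small convenience I find handy (my own sketch, not part of the original post): the mixture log-likelihood, computed from the same unnormalized log-probabilities that infer returns, via a row-wise log-sum-exp. Printing it inside the loop above is an easy way to check that an EM iteration never decreases the likelihood.

function mixture_loglik(m::MixtureModel, x)
    lq = infer(m, x)        # N x K unnormalized log-probabilities
    mx = maximum(lq, 2)     # row-wise maximum, for numerical stability
    # log-sum-exp over components, summed over data points
    sum(mx .+ log(sum(exp(lq .- mx), 2)))
end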
OK, so we have a generic EM procedure for mixtures, which gets compiled to a concrete function once it's called with concrete types. But will it work on wacky new custom distributions? Let's find out. We'll define a new distribution type that bundles a distribution together with a prior over its parameters; we use this to model the π and μ variables on page 9 of Lin's presentation. We also have to import logpdf from Distributions so that we can extend it for the new type:
In [4]:
import Distributions: logpdf

type DistWithPrior <: Distribution
    pri                  # a prior over the parameters of dist (tuple)
    dist::Distribution   # the distribution itself
end
Now let's also define fit_em and logpdf for these distributions, using Distributions.jl's own fit_map function.
In [5]:
fit_em{T<:DistWithPrior}(m::T, x, w) =
    T(m.pri, fit_map(m.pri, typeof(m.dist), x, w))
logpdf{T<:DistWithPrior}(m::T, x) = logpdf(m.dist, x)
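Conceptually, what fit_map buys us in the M-step is a MAP estimate instead of a plain weighted maximum-likelihood estimate: for each component k it maximizes the responsibility-weighted log-likelihood plus the log-prior (again, my own summary of the standard MAP-EM update, not a formula from the original post):

$$ \theta_k \leftarrow \arg\max_{\theta} \; \sum_n q_{nk} \log p(x_n \mid \theta) + \log p(\theta) $$

and analogously for the mixing weights under the Dirichlet prior.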
In [6]:
# Dimensionality of the data
K = 4
# Sample data: two clusters, as before
x = cat(2, randn(K,64).-3, randn(K,32).+3)
# Diagonal Gaussian prior on the component means (note that we can't use IsoNormal)
v = 0.0
sig = 1.0
mu = MvNormal(fill(v,K), diagm(fill(sig,K)))
# Initial guess for the components and mixture weights (random means);
# note that, here, sigma is shared between the components
comps = Array(DistWithPrior, 2)
sigma = eye(K)
comps[1] = DistWithPrior((mu, sigma), MvNormal(randn(K), sigma))
comps[2] = DistWithPrior((mu, sigma), MvNormal(randn(K), sigma))
# Dirichlet prior on the mixing weights
alpha = 2.0
mix = DistWithPrior(Dirichlet(fill(alpha/2,2)), Categorical([0.5, 0.5]))
m = MixtureModel(mix, comps)
# Run the EM algorithm for a few iterations
for i = 1:10
    m = fit_mm_em(m, x)
end
println(m.component[1].dist.μ)
println(m.component[2].dist.μ)
println(m.component[2].dist.Σ)
println(m.mixing.dist.prob)