A while ago, Dahua Lin posted a message on julia-stats calling for the construction of a new language (within Julia) that would take general graphical models (in abstract form) and 'compile' them to Julia code. This would free modellers from the hassle of re-implementing a whole bunch of inference algorithms (like EM) over and over again. It's a very intriguing idea, and it would be awesome if it could be made to work. His slides are here (I'm linking them because I will be referring to them in this write-up).

Inspired by that, I decided to use the Distributions.jl package as a starting point and see how far I could go just by defining types and generic functions on those types to do various inference procedures (such as message passing). Note that I'm not defining a new language, just exploring what the Julia compiler itself can do. I found, pleasantly enough, that the interface Distributions.jl provides is rich enough to implement EM on both models Lin mentions in his presentation, with only a small amount of generic EM code and a high-level model specification, and very little 'grunt' work. To whet your appetite, here's a generic implementation of the EM algorithm that works, without modification, on any type of mixture. Disclaimer: none of the code listed here is meant to be the 'optimal' way of doing things; it's just given as an example and, mostly, for fun.

First, let's define a generic mixture type:


In [1]:
using Distributions

# A 'closed' mixture model defining a full generative model
type MixtureModel{T} <: Distribution
	mixing::Distribution   # Distribution over mixing components
	component::Vector{T}   # Individual component distributions
end

Here, T is the type of the mixture components. So, for instance, if T were MvNormal, we'd get a multivariate normal mixture. Now that we have that, we can do EM updates on this generic mixture type. The following function, fit_mm_em, does that:


In [2]:
# For a model m and observations x (one data point per column), return the
# unnormalized log-posterior over component assignments:
#   lq[n,k] = log p(x_n | component k) + log p(component k)
function infer(m::MixtureModel, x)
	K = length(m.component)  # number of mixture components
	N = size(x,2)            # number of data points
	
	lq = Array(Float64, N, K)
	for k = 1:K
		lq[:,k] = logpdf(m.component[k], x) .+ logpdf(m.mixing, k)
	end
	return lq
end

# Convert log-probabilities to (unnormalized) probabilities, subtracting the
# maximum first to avoid overflow in exp
logp_to_p(lp) = exp(lp .- maximum(lp))

function fit_mm_em{T}(m::MixtureModel{T}, x)
	# Expectation step
	lq = infer(m, x)

	# Convert to probabilities and normalize each row (responsibilities sum to 1)
	q = logp_to_p(lq)
	q = q ./ sum(q,2)

	# Maximization step: refit each component with its responsibilities as weights,
	# and refit the mixing distribution to the per-component responsibility totals
	cr = 1:length(m.component)
	comps = [fit_em(m.component[k], x, q[:,k]) for k = cr]
	mix   =  fit_em(m.mixing, [cr], vec(sum(q,1)))

	MixtureModel{T}(mix, comps)
end

# 'Fallback': by default, fitting a distribution to weighted data is just a weighted maximum-likelihood fit
fit_em(m::Distribution, x, w) = fit_mle(typeof(m), x, w)


Out[2]:
fit_em (generic function with 1 method)
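
For the record, these are just the textbook EM updates for a mixture. Writing \pi_k for the mixing weights, \theta_k for the parameters of component k and q_{nk} for the responsibility of component k for observation x_n, infer computes the (log of the) numerator of the E-step below, and the weighted fit_mle calls carry out the two M-step maximizations:

q_{nk} = \frac{\pi_k \, p(x_n \mid \theta_k)}{\sum_j \pi_j \, p(x_n \mid \theta_j)}

\theta_k \leftarrow \arg\max_{\theta} \sum_n q_{nk} \log p(x_n \mid \theta)

\pi_k \leftarrow \frac{1}{N} \sum_n q_{nk}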

We can test this on a simple Gaussian mixture:


In [8]:
# Data: two 4-dimensional Gaussian clusters whose centres are 6 apart along each dimension
x = cat(2, randn(4,64).-3, randn(4,32).+3)

# Initial guess for the components (random means) and mixture weights (uniform)
sigma = 0.1*cov(x')   # empirical guesstimate of a component covariance
comps = [MvNormal(randn(4),sigma), MvNormal(randn(4),sigma)]
m     = MixtureModel(Categorical([0.5, 0.5]), comps)

# Run EM algorithm for a few iterations
for i = 1:10
	m = fit_mm_em(m, x)
end

println(m.component[1].μ)
println(m.component[2].μ)
println(m.component[2].Σ)
println(m.mixing.prob)


[-3.2631184236532436,-2.7006729386547077,-2.919006575979322,-2.9576009114761903]
[1.2706768333967895,0.6933642515141178,0.6758533469320562,0.819694302494376]
PDMat(4,[5.871224385261959 6.904157012061517 6.448055535687115 6.721782065067239
 6.904157012061517 9.952228780963178 8.080981415827772 8.570001248647637
 6.448055535687115 8.080981415827772 8.921371836004822 7.9947845428335755
 6.721782065067239 8.570001248647637 7.9947845428335755 9.21604456035357],Cholesky{Float64} with factor:
[2.42306095368275 2.8493534186864182 2.6611198227955746 2.774087071499778
 0.0 1.3540361429382892 0.36816635494620326 0.491601920087216
 0.0 0.0 1.3054756450485387 0.3306190822774627
 0.0 0.0 0.0 1.0814361075403056])
[0.47608155916462025,0.5239184408353796]
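
One thing the loop above doesn't do is check for convergence; it just runs ten iterations. Since infer already returns the joint log-probabilities, the data log-likelihood is only a log-sum-exp away. Here's a quick sketch of how one might compute it (untested, and the helper name loglik is mine):

# Data log-likelihood under the current model: sum_n log sum_k exp(lq[n,k]).
# (A sketch for monitoring convergence; not used in the runs in this post.)
function loglik(m::MixtureModel, x)
	lq = infer(m, x)         # lq[n,k] = log( p(x_n | component k) * p(k) )
	mx = maximum(lq, 2)      # per-row maximum, for numerical stability
	sum(mx .+ log(sum(exp(lq .- mx), 2)))
end

Since EM never decreases this quantity, one could keep calling fit_mm_em until the change in loglik(m, x) drops below some tolerance, instead of using a fixed iteration count.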

OK, so we have a generic EM procedure for mixtures, which Julia compiles down to a concrete, specialized method once it's called. But will it work on wacky new custom distributions? Let's find out. Let's define a new distribution type that encapsulates a distribution along with a prior on its parameters. We use this to model the π and μ variables on page 9 of Lin's presentation. We also have to import logpdf from Distributions, so that we can add a method for it on the new type.


In [4]:
import Distributions:logpdf

type DistWithPrior <: Distribution
	pri                    # a prior over the parameters of dist (tuple)
	dist::Distribution     # the distribution itself
end

Now let's also define fit_em and logpdf on these distributions, using Distributions.jl's own fit_map function for the MAP fit.


In [5]:
fit_em{T<:DistWithPrior}(m::T, x, w) = 
  T(m.pri, fit_map(m.pri, typeof(m.dist), x, w))

logpdf{T<:DistWithPrior}(m::T, x) = logpdf(m.dist, x)


Out[5]:
logpdf (generic function with 48 methods)
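
With these two methods in place, the generic fit_mm_em above works unchanged. Conceptually, the only difference is in the M-step: for a DistWithPrior, instead of a weighted maximum-likelihood fit we do a weighted MAP fit, i.e.

\theta_k \leftarrow \arg\max_{\theta} \left[ \sum_n q_{nk} \log p(x_n \mid \theta) + \log p(\theta) \right]

The next cell repeats the Gaussian-mixture experiment, this time with a Gaussian prior on each component mean and a Dirichlet prior on the mixing weights (the covariances stay fixed at the identity).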

In [6]:
# Dimensionality of the data
K = 4

# sample data
x = cat(2, randn(K,64).-3, randn(K,32).+3)

# Gaussian prior over the component means, with diagonal covariance
# (note that we can't use IsoNormal here)
v = 0.0       # prior mean, per dimension
sig = 1.0     # prior variance, per dimension
mu = MvNormal(fill(v,K), diagm(fill(sig,K)))

# Initial guess for the components (random means).
# Note that sigma is shared across the components and is not re-estimated here;
# only the means are fit (via MAP).
comps = Array(DistWithPrior, 2)
sigma = eye(K)
comps[1] = DistWithPrior((mu, sigma), MvNormal(randn(K), sigma))
comps[2] = DistWithPrior((mu, sigma), MvNormal(randn(K), sigma))

# Dirichlet prior on the mixing weights
alpha = 2.0
mix = DistWithPrior(Dirichlet(fill(alpha/2,2)), Categorical([0.5, 0.5]))

m = MixtureModel(mix, comps)

# Run EM algorithm for a few iterations
for i = 1:10
	m = fit_mm_em(m, x)
end

println(m.component[1].dist.μ)
println(m.component[2].dist.μ)
println(m.component[2].dist.Σ)
println(m.mixing.dist.prob)


[-2.8119260099645573,-3.083135355368115,-3.087427589202475,-2.6317364980559486]
[3.027306305523113,2.98335879132994,2.8756961657062967,2.7799416039825484]
PDMat(4,[1.0 0.0 0.0 0.0
 0.0 1.0 0.0 0.0
 0.0 0.0 1.0 0.0
 0.0 0.0 0.0 1.0],Cholesky{Float64} with factor:
[1.0 0.0 0.0 0.0
 0.0 1.0 0.0 0.0
 0.0 0.0 1.0 0.0
 0.0 0.0 0.0 1.0])
[0.6666666666666666,0.3333333333333333]