Toward a domain-specific language for variational inference

John Pearson
DukeML group meeting
2-18-16

<img src="http://pearsonlab.github.io/images/plab_logo_dark.svg" width="300" align="left">

What's variational inference?

Generative model for data: $p(y|\theta)/Z$
Approximate the model posterior with $q(\theta)$

Maximize the Evidence Lower Bound (ELBO) with respect to the parameters of $q$:

$$ \log Z \ge -KL\left(q \middle\| p\right) = \mathcal{L} = \mathbb{E}_q[\log p(y|\theta)] + \mathcal{H}[q] $$

Why variational inference?

  • Scales well
  • Can use well-studied optimization techniques

Drawbacks:

  • !@$*&# hard to code
  • Can't quickly spec out a model like with Stan or JAGS/BUGS

Why is it difficult?

  • Traditionally, conjugate models $\Longrightarrow$ lots of algebra
  • Gradient descent requires gradient calculation
  • for non-stochastic models, $\mathcal{L}$ should increase on every iteration, but checking this requires an extra evaluation of the objective and is tricky to get right

Lots of great VI tricks

... but hard to mix and match

  • Stan does (only) BBVI
    • no discrete params
    • only mean field or full (Gaussian) covariance
    • custom Stan requires C++
  • VIBES (is abandonware?)

What's the ideal?

  • write math, get code — a domain-specific language (DSL)
  • easily generalize to different numbers of indices, structures
  • only weakly opinionated about model structure or inference
  • model code should be hackable
    • easy to use prefab pieces
    • not hard to write custom VB tricks
    • fast prototyping
  • no (or minimal) algebra
    • simple expectations
    • automatic gradients

Introducing...

VinDsl.jl: Fast and furious variational inference

What makes VinDsl special?

  • written in Julia
  • Sensible model primitives
  • Automatic index bookkeeping
  • Expectation calculus
  • Exploiting conjugacy
  • Automatic gradients$^*$
  • Multiple inference strategies$^*$

*: Coming Soon

Model structure:

Main idea: Factor graphs

  • idea from Dahua Lin in this talk
  • Nodes: arrays of distributions
  • Factors $\leftrightarrow$ terms in variational objective
    • but not locked in to graphical model structure!

Nodes can be generated from any distribution type in Julia

  • indices inferred automagically
  • expectations, entropy, etc. just work

In [4]:
push!(LOAD_PATH, "/Users/jmxp/code/VinDsl.jl/src")
using VinDsl
using Distributions

In [6]:
dims = (20, 6)

μ[j] ~ Normal(zeros(dims[2]), ones(dims[2]))
τ[j] ~ Gamma(1.1 * ones(dims[2]), ones(dims[2]))
μ0[j] ~ Const(zeros(dims[2]))

y[i, j] ~ Const(rand(dims));

Nodes: under the hood

  • nodes define $q$: the approximate posterior / recognition model
  • ~ defines a node
  • can use any distribution defined in the Distributions package
  • the code parses the left- and right-hand sides
    • indices on the left get tracked and assigned to dimensions of the parameter arrays
    • the expression is rewritten as a call to a node constructor (toy sketch below)
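A toy version of the rewrite, just to illustrate the idea (the Node type, its fields, and the @tilde name here are hypothetical, not VinDsl's internals):

using Distributions

# a node: a name, its tracked indices, and the distribution it wraps
struct Node
    name::Symbol
    indices::Vector{Symbol}
    dist
end

# parse `name[indices...] ~ Dist(params...)` at expansion time and
# rewrite it as an ordinary constructor call
macro tilde(ex)
    lhs, rhs = ex.args[2], ex.args[3]     # ex is :(lhs ~ rhs)
    name = lhs.args[1]                    # e.g. :μ
    indices = Symbol[lhs.args[2:end]...]  # e.g. [:j]; purely syntactic
    :( $(esc(name)) = Node($(QuoteNode(name)), $indices, $(esc(rhs))) )
end

@tilde μ[j] ~ Normal(0.0, 1.0)
μ.indices    # => [:j]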

Factors

  • Factors are terms in the generative model
  • Right now:

In [8]:
f = @factor LogNormalFactor y μ τ;

In the future:


In [ ]:
@pmodel begin
    y ~ Normal(μ, τ)
end

New factor types can be defined with yet another macro:


In [ ]:
@deffactor LogNormalFactor [x, μ, τ] begin
    -(1/2) * (E(τ) * (V(x) + V(μ) + (E(x) - E(μ))^2) + log(2π) - Elog(τ))
end

@deffactor LogGammaCanonFactor [x, α, β] begin
    (E(α) - 1) * Elog(x) - E(β) * E(x) + E(α) * Elog(β) - Eloggamma(α)
end
  • Uses a "mini-language" with E(x) $\equiv \mathbb{E}[X]$, V(x) $\equiv \textrm{cov}[X]$, etc.
  • Again, no need to track indices
    • multivariate distributions (Dirichlet, MvNormal) are automatically multivariate in these expressions
  • VinDsl generates a value(f) function that handles indices appropriately and sums over the dimensions of the array (toy sketch below)
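For intuition, roughly what such a generated function could compute for LogNormalFactor (illustrative only; the function and argument names here are not VinDsl's):

# each argument is an array of moments from the factor's nodes; the
# expected log density is evaluated elementwise, then summed over all
# outer indices (e.g. i and j for y[i, j])
function value_lognormal(Ex, Vx, Eμ, Vμ, Eτ, Elogτ)
    elem = -(1/2) .* (Eτ .* (Vx .+ Vμ .+ (Ex .- Eμ).^2) .+ log(2π) .- Elogτ)
    sum(elem)
end

# broadcasting does the index matching: y is 20×6, μ and τ are length 6
value_lognormal(rand(20, 6), ones(20, 6), zeros(6)', ones(6)', ones(6)', zeros(6)')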

Models are just factor graphs:


In [10]:
dims = (20, 6)

# note: it won't matter much how we initialize here
μ[j] ~ Normal(zeros(dims[2]), ones(dims[2]))
τ[j] ~ Gamma(1.1 * ones(dims[2]), ones(dims[2]))
μ0[j] ~ Const(zeros(dims[2]))
τ0[j] ~ Const(2 * ones(dims[2]))
a0[j] ~ Const(1.1 * ones(dims[2]))
b0[j] ~ Const(ones(dims[2]))

y[i, j] ~ Const(rand(dims))

# make factors
obs = @factor LogNormalFactor y μ τ
μ_prior = @factor LogNormalFactor μ μ0 τ0
τ_prior = @factor LogGammaCanonFactor τ a0 b0

m = VBModel([μ, τ, μ0, τ0, a0, b0, y], [obs, μ_prior, τ_prior]);
  • Models have a separate update strategy for each node
  • allows mix-and-match inference

Index Bookkeeping

  • nodes have associated indices
  • factors know which indices go with which nodes, which indices to sum over
    • inner indices belong to, e.g., elements of a multivariate normal (should not be separated)
    • outer indices correspond to replicates of "atomic" variables

So this is easy: here i is an inner index:


In [ ]:
d = 5
μ[i] ~ MvNormalCanon(zeros(d), diagm(ones(d)))
Λ[i, i] ~ Wishart(float(d), diagm(ones(d)))

But here, i is inner for $\mu$ but not for $\tau$. In any factor combining these two, $\tau$ will be treated like a vector because it matches an inner index for some node:


In [ ]:
μ[i] ~ MvNormalCanon(zeros(d), diagm(ones(d)))
τ[i] ~ Gamma(1.1 * ones(d), ones(d))

Expression Nodes

  • We want to define nodes that combine other nodes (ExprNodes)
  • But we also want E(x) to work for these cases
  • ExprNodes are like a cross between Factors and Nodes
    • represent variables in the model, not ELBO terms
    • but need to track multiple indices like factors

Solution: expectation calculus

  • because Julia allows us to parse Julia code natively, we can rewrite expressions
  • define macros that "wrap" E, etc. using linearity

In [11]:
x ~ Normal(rand(), rand())
y ~ Normal(rand(), rand())

@expandE E(x.data[1] + y.data[1])


Out[11]:
1.0612025838935821

In [13]:
macroexpand(:(@expandE E(x + y * z + 5)))


Out[13]:
:(E(x) + E(y) * E(z) + 5)

Note: assumes all nodes are independent!
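Under that assumption, a toy version of the rewrite looks something like this (illustrative, not VinDsl's implementation):

expand_E(ex) = ex                  # numeric literals pass through
expand_E(s::Symbol) = :(E($s))     # bare nodes get wrapped in E
function expand_E(ex::Expr)
    if ex.head == :call && ex.args[1] in (:+, :*)
        # E[a + b] = E[a] + E[b]; E[ab] = E[a]E[b] for independent a, b
        Expr(:call, ex.args[1], map(expand_E, ex.args[2:end])...)
    else
        :(E($ex))                  # anything else stays wrapped in E
    end
end

expand_E(:(x + y * z + 5))         # => :(E(x) + E(y) * E(z) + 5)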

Future

  • implement rules for V, etc.
  • allow Julia expressions like sum and product to work over selected indices
  • Eventually

In [ ]:
@pmodel begin
    x[i, k] ~ ...
    y[j, k] ~ ...
    
    z := sum(x, i) + sum(y, j)
end

Conjugacy

  • right now VinDsl goes out of its way to handle conjugacy between nodes
  • conjugate relationships not automatically detected, but easy to define
  • @defnaturals defines the natural parameter contribution a factor makes to a given target node's distribution

In [ ]:
@defnaturals LogNormalFactor μ Normal begin
    Ex, Eτ = E(x), E(τ)
    (Ex * Eτ, -Eτ/2)
end

@defnaturals LogNormalFactor τ Gamma begin
    v = V(x) + V(μ) + (E(x) - E(μ))^2
    (1/2, v/2)
end
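Where do these come from? Treating the factor's expected log density as a function of $\mu$ alone (with $x$, $\mu$, $\tau$ independent under $q$),

$$ \mathbb{E}\left[-\frac{\tau}{2}(x - \mu)^2\right] = \mathbb{E}[x]\mathbb{E}[\tau]\,\mu - \frac{\mathbb{E}[\tau]}{2}\mu^2 + \text{const}, $$

so matching coefficients against the Normal's sufficient statistics $(\mu, \mu^2)$ gives the naturals $(\mathbb{E}[x]\mathbb{E}[\tau], -\mathbb{E}[\tau]/2)$ defined above. The Gamma case follows the same pattern using the statistics $(\log \tau, \tau)$.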

Automatic gradients

  • we want a lot more than conjugacy-based coordinate ascent
  • at the very least, be able to perform a brute-force optimization step (coming very soon!)
  • automatic differentiation is an exact (to machine tolerance) way of calculating gradients of arbitrary code
  • multiple packages in Julia (one sketched below), but some changes to Distributions needed
    • very high on my list
  • will allow SVI, etc.
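For a flavor of what this buys us, a minimal standalone sketch using ForwardDiff.jl, one such package (not yet wired into VinDsl):

using ForwardDiff

# toy ELBO for q = Normal(λ[1], exp(λ[2])) against a Normal(3, 1) target:
# E_q[log p] + H[q], up to an additive constant
elbo(λ) = -((λ[1] - 3.0)^2 + exp(2 * λ[2])) / 2 + λ[2]

# exact gradient (to machine precision), no hand algebra required
ForwardDiff.gradient(elbo, [0.0, 0.0])    # => [3.0, 0.0]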

Future plans:

  • two models for my own work
    • factorial HMM (gamma and log-normal)
    • linear state space
  • just conjugacy + simple convex opt steps
  • eventually:
    • nicer DSL: @qmodel and @pmodel, := for ExprNodes
    • minibatch support, SVI
    • local expectation gradients (LEG), control variates
    • Jacobians for all distributions $\longrightarrow$ BBVI
  • Optimize code generation for speed

VinDsl needs your help!

Open sourcing soon:

  • docs
  • tests
  • better ideas!