Scalable Nonparametric Directed Graphical Model Inference and Learning

Author: Kai Londenberg (Kai.Londenberg@gmail.com), June 2014

Abstract

This short article gives an overview of techniques complementary to MCMC for general inference in arbitrary directed probabilistic graphical models. The focus lies on techniques and algorithms for creating hybrid models which can be scaled to high-dimensional problems and huge data sets, and which can be distributed among multiple machines.

Motivation

Two papers got me thinking. The first is Dauwels et al: Particle Methods as Message Passing, which gives a nice overview of how to generalize Message Passing methods by freely mixing sampling-based methods with exact or fast approximate inference algorithms.

The second is Neiswanger et al: Embarrassingly Parallel MCMC, which describes an algorithm that can be used to scale MCMC to problems with huge data sets.

Both algorithms actually suffer from the same problem, namely the Message Fusion problem (described further below), a problem which has luckily been solved successfully before. I add to the existing approaches by proposing a new technique for efficient density estimation from MCMC samples, called Density Mapping.

My hope is that these ideas can lead to a practical implementation of a general, flexible inference system with semantics similar to those found in common MCMC packages for Bayesian inference, but with the capability to outperform these on problems with huge data sets or a large number of dimensions, provided the distribution can somehow be factorized into smaller problems.

I hope to extend the existing Bayesian modeling toolkit for Python, PyMC3, via my side project PyPGMc to support the algorithms mentioned in this article.

So this article is both an overview and a collection of ideas and roadmap items.

Introduction to Directed Graphical Models

Directed Probabilistic Graphical Models (DGMs) provide a flexible and powerful framework for probabilistic reasoning. In all generality, they are a way to efficiently represent complex probability distributions in high-dimensional spaces by factorizing the joint distribution into conditional distributions.

Given a set of random variables $ x_i \sim X_i $ and their joint random vector $ x \sim X $ with $ x = \{ x_1, \ldots, x_N \} $, we represent their joint distribution as a product of conditional distributions:

$$ P(X) = \prod_{i=1}^{N} P(X_i \mid \mathrm{pa}(X_i)) $$

where $ \mathrm{pa}(X_i) $ is the set of parents of variable $ X_i $ in a directed graph $ G $, in which each vertex represents one of the random variables.

Each vertex in this graph can be assigned an arbitrary probability distribution for its random variable, but this distribution may only depend on the variable's parents in the graph.
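As a toy illustration of this factorization (just a sketch, not part of any library), a tiny rain/sprinkler network can be represented as a Python dict mapping each variable to its parents and a conditional probability function:

```python
def p_wet(a):
    # P(wet | rain, sprinkler): 'wet' is likely iff at least one cause is active.
    p_true = 0.9 if (a['rain'] or a['sprinkler']) else 0.05
    return p_true if a['wet'] else 1.0 - p_true

# Each variable maps to (parents, conditional probability of a full assignment).
model = {
    'rain':      ([], lambda a: 0.2 if a['rain'] else 0.8),
    'sprinkler': ([], lambda a: 0.1 if a['sprinkler'] else 0.9),
    'wet':       (['rain', 'sprinkler'], p_wet),
}

def joint(assignment, model):
    # P(X) = prod_i P(X_i | pa(X_i)): multiply one conditional per variable.
    p = 1.0
    for _parents, cond in model.values():
        p *= cond(assignment)
    return p

print(joint({'rain': True, 'sprinkler': False, 'wet': True}, model))
# 0.2 * 0.9 * 0.9 = 0.162
```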

Such a representation has many advantages; even listing them all would be beyond the scope of this article. An excellent overview is given in the book Probabilistic Graphical Models by Daphne Koller, who also offers a Coursera course of the same name, which is highly recommended.

D-Separation / Conditional Independence

Most importantly, the graph encodes information about (conditional) independencies among the variables. Using a property called D-Separation, which can be determined with a few simple rules, we can safely decide whether the probability distribution of one set of variables in the graph can be affected by changes in the probability of another set of variables, given a third set of variables that are held fixed.

If the conditional independence assumptions made by the graph are a subset of the conditional independencies that hold in a joint probability distribution, this distribution is compatible with the graph, in the sense that it can be faithfully represented by a factorization along that graph.
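Continuing the rain/sprinkler example, a d-separation query can be sketched with networkx, assuming a version that ships nx.d_separated (2.8+; later releases rename it to nx.is_d_separator):

```python
import networkx as nx

# The classic v-structure: rain -> wet <- sprinkler.
G = nx.DiGraph([('rain', 'wet'), ('sprinkler', 'wet')])

# rain and sprinkler are marginally d-separated (independent) ...
print(nx.d_separated(G, {'rain'}, {'sprinkler'}, set()))    # True
# ... but observing the common effect 'wet' couples them.
print(nx.d_separated(G, {'rain'}, {'sprinkler'}, {'wet'}))  # False
```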

Causal Models

A very common form of these models restricts parent/child relationships to cause/effect pairs. Furthermore, the network must be complete in the sense that no common causes of any two variables are missing from the model.

While DGMs do not need to be causal models for their machinery to work, causality can, under certain circumstances, be exploited to perform so-called causal inference or causal reasoning with DGMs. More on that can be found in Judea Pearl's excellent book on Causality.

One important rule to note: in order to perform causal inference in such a causal network, i.e. to estimate the impact of an explicit action where a variable is forced to take a certain value (in contrast to observing it having that value), it is necessary to sever all ties from the parents of said variable to it, since they are no longer causally connected.

In Pearl's see/do calculus, he distinguishes between $ P(x|y) $ (the probability of $x$ given that I see $y$) and $ P(x|do(y)) $ (the probability of $x$ given that I do $y$).
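Continuing the toy dict representation sketched earlier, the graph surgery behind $do(\cdot)$ can be illustrated with a hypothetical helper (this is only the mutilated-graph idea, not Pearl's full calculus):

```python
def do(model, var, value):
    # Graph surgery: drop var's parents and replace its conditional with a
    # point mass on the forced value. Observing would instead mean
    # conditioning on the unmodified model.
    mutated = dict(model)
    mutated[var] = ([], lambda a: 1.0 if a[var] == value else 0.0)
    return mutated

# P(. | do(sprinkler=True)) is computed in the mutilated model:
intervened = do(model, 'sprinkler', True)
print(joint({'rain': True, 'sprinkler': True, 'wet': True}, intervened))
# 0.2 * 1.0 * 0.9 = 0.18
```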

While not of further concern here, Pearl provides a great deal of informal insight and formal rules on when and how observations can be converted into causal claims, and on how to transfer the results of studies from one setting to another.

Common types of DGMs

Some common types of DGMs that you might have heard about include:

  • Bayesian Networks (BNs)
  • Hierarchical Bayesian Models
  • (Gaussian) Mixture Models (GMMs)

Also a lot of common models for time-series can be thought of as DGMs, among them:

  • Vector Auto-Regressive Models (VAR)
  • Hidden Markov Models (HMMs)
  • State Space Models (SSM)
  • Dynamic Bayesian Networks (DBNs)

Many of these models have their own set of specialized inference and learning algorithms, and their own advantages and disadvantages.

Inference / Reasoning in Graphical Models

Generally speaking, we can use these graphical models to reason about marginal probability distributions, most likely configurations (MLE/MAP configurations), and similar quantities for variables of interest given evidence. This, in turn, has many applications, from decision support (making decisions under uncertainty) to serving as a key component of supervised, unsupervised and semi-supervised learning.

Parametric vs. Nonparametric Representations

If we can restrict our probability distributions to come from specific families of distributions, inference can be made very efficient in some cases. But if we want a general model which can capture arbitrary multi-modal and non-continuous distributions, we are limited to very slow inference using MCMC methods. Scalability is also very limited in these cases, because most MCMC algorithms are not designed to be distributed.

| Family | Advantages | Disadvantages |
| --- | --- | --- |
| Conditional Linear Gaussian (CLG) | Very fast | Are your distributions really linear combinations of Gaussians? |
| Discrete (Categorical) | Fast | High number of parameters if the number of parents of any variable, or the number of discrete "bins" per variable, becomes too large; quickly becomes intractable in these cases. |
| Generic (arbitrary functions of parents) | Most flexible | Very slow inference (MCMC) or strongly biased approximate inference; hard to determine convergence/mixing; usually intractable in high dimensions. |

Hierarchical Bayesian Models and MCMC

In Bayesian Hierarchical Modeling, we are usually either interested in estimating marginals of certain model parameters in order to gain insight into specific problems, or in evaluating the expected value of some (utility or loss) function over the posterior of a set of random variables.

In the Bayesian view, the unknown parameters of a model are random variables like any other, so we can use the machinery of DGM inference. Since these models can have almost arbitrary functional relations between variables, it is common to perform Markov Chain Monte Carlo (MCMC) simulation to sample from the posterior distribution.
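As a minimal sketch in PyMC3 (the toolkit mentioned above; the data and variable names here are made up for illustration), a two-level hierarchical model sampled with MCMC might look like this:

```python
import numpy as np
import pymc3 as pm

# Hypothetical data: noisy measurements from three groups.
group = np.array([0, 0, 1, 1, 2, 2])
y = np.array([2.1, 1.9, 3.2, 3.0, 4.1, 3.8])

with pm.Model():
    # Shared hyperpriors: the per-group means are themselves random variables.
    mu = pm.Normal('mu', mu=0.0, sd=10.0)
    sigma_group = pm.HalfNormal('sigma_group', sd=5.0)
    group_mu = pm.Normal('group_mu', mu=mu, sd=sigma_group, shape=3)
    sigma = pm.HalfNormal('sigma', sd=5.0)
    pm.Normal('obs', mu=group_mu[group], sd=sigma, observed=y)
    trace = pm.sample(2000)  # MCMC: posterior samples for all parameters
```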

What is problematic about these methods is that inference is usually slow, and it is hard to determine whether the chain has converged (mixed) to a stable posterior. Generally, MCMC does not (yet) scale well to high-dimensional problems using established methods, despite some special areas where MCMC methods could be applied to high-dimensional problems, such as large-scale matrix factorization for recommender systems.

MCMC, while an approximate approach, is asymptotically exact if applied correctly.

Message Passing Algorithms

Among the most efficient algorithms for exact and approximate inference in discrete and conditional linear Gaussian DGMs are so-called Message Passing (MP) or Belief Propagation algorithms. These algorithms operate on a so-called factor graph, which is very similar to a DGM, except that vertices (factors) may represent joint distributions of multiple variables. If factors share variables, they have to be connected (at least indirectly) via a chain of factors, each of which contains that variable.

Correspondingly, edges (along which messages are passed) need to be able to convey information about joint distributions of several variables at once.

By collapsing a DGM into a factor graph tree (a so-called Clique Tree or Junction Tree), it is possible to perform efficient exact inference on discrete and conditional linear Gaussian networks using message passing. That is, unless the resulting tree at some point has too large a tree-width (loosely speaking, the result of a graph that is large and too densely connected), which can make exact inference intractable.
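To make the mechanics concrete, here is a minimal sum-product sketch on a chain $A \to B \to C$ of binary variables: the marginal of $C$ is obtained by passing a single forward message instead of enumerating the joint.

```python
import numpy as np

p_a = np.array([0.6, 0.4])            # P(A)
p_b_given_a = np.array([[0.7, 0.3],   # rows: value of A, cols: value of B
                        [0.2, 0.8]])
p_c_given_b = np.array([[0.9, 0.1],
                        [0.4, 0.6]])

msg_a_to_b = p_a @ p_b_given_a        # message mu_{A->B}(b) = sum_a P(a) P(b|a)
p_c = msg_a_to_b @ p_c_given_b        # marginal P(C), already normalized
print(p_c)                            # [0.65 0.35]
```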

Even in those cases where exact inference is intractable, the Loopy Belief Propagation algorithm can provide very fast approximate (asymptotically biased) inference, and empirically provides good solutions in cases where other algorithms fail or are too slow.

It is important to note that the core algorithm (message passing) is the same for approximate Loopy Belief Propagation and exact Clique Tree inference.

Again, I refer to the book Probabilistic Graphical Models by Daphne Koller and her Coursera course for details.

Nonparametric and Particle Belief Propagation / Message Passing

As Dauwels et al. have pointed out in the paper Particle Methods as Message Passing, it is actually possible to combine parametric exact inference and nonparametric approximate inference by using particle lists as messages in the message passing algorithm.

They used this approach to show that common MCMC procedures such as Gibbs Sampling, Metropolis-Hastings and Importance Sampling can be viewed as special cases of Particle Message Passing.

What is also important here: each factor could be an independent MCMC sampler working on a subset of the variables and/or evidence, or a faster parametric inference algorithm if the problem allows it.
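A particle message can be sketched as nothing more than weighted samples of the shared variable; a factor fuses an incoming message with its own density by importance re-weighting. This is a generic sketch of the idea, not the exact scheme from the paper:

```python
import numpy as np

def factor_update(particles, weights, factor_logpdf):
    # Re-weight incoming particles by the factor's own (log-)density,
    # producing the outgoing message over the shared variable.
    log_w = np.log(weights) + factor_logpdf(particles)
    log_w -= log_w.max()               # for numerical stability
    w = np.exp(log_w)
    return particles, w / w.sum()
```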

Other fast inference algorithms, such as Expectation Propagation and other forms of Variational Inference such as Mean-Field based methods, can also be cast as variants of Message Passing procedures.

There are several key papers which describe important aspects and approaches that might be taken; they are discussed in the sections below.

The Message Fusion Problem

All of these papers identify the same performance bottleneck in these algorithms.

If we have two factors which share at least one continuous variable $ x $ with a smooth pdf, the probability that the two factors will independently choose the same value for that variable is essentially zero.

So if we have two factors represented using discrete particle lists $ \phi(x) $ and $ \theta(x) $, their product will, with near certainty, be zero everywhere.

What we need to do is perform so-called Message Fusion, for which several approaches have been proposed.

The standard approach is to use some form of Kernel Density Estimation (KDE), which represents each particle/sample not as a discrete probability spike but smoothes the density using an appropriate kernel; usually, Gaussian kernels are used. Given an efficient procedure for sampling from products of Gaussian mixtures, such as Ihler, Sudderth et al: Efficient Multiscale Sampling from Products of Gaussian Mixtures, we can sample from these fused messages efficiently. More on this approach can be found in Ihler, Sudderth et al: Nonparametric Belief Propagation.
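In one dimension, the fusion of two Gaussian-mixture messages can be written down exactly: the product of two mixtures is again a mixture, with one reweighted Gaussian per component pair. A sketch follows; the quadratic blow-up in the number of components is exactly what the multiscale sampling paper works around.

```python
import numpy as np

def gaussian_product(m1, v1, m2, v2):
    # N(x; m1, v1) * N(x; m2, v2) = z * N(x; m, v)
    v = 1.0 / (1.0 / v1 + 1.0 / v2)
    m = v * (m1 / v1 + m2 / v2)
    # z is the density of N(m1; m2, v1 + v2): how much the pair "agrees".
    z = np.exp(-0.5 * (m1 - m2) ** 2 / (v1 + v2)) / np.sqrt(2 * np.pi * (v1 + v2))
    return z, m, v

def fuse(msg_a, msg_b):
    # Messages are lists of (weight, mean, variance) triples.
    fused = []
    for wa, ma, va in msg_a:
        for wb, mb, vb in msg_b:
            z, m, v = gaussian_product(ma, va, mb, vb)
            fused.append((wa * wb * z, m, v))
    total = sum(w for w, _, _ in fused)
    return [(w / total, m, v) for w, m, v in fused]
```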

But a problem remains: it is computationally expensive to evaluate the probability density and its curvature (Jacobian and Hessian) for these messages. And that matters if we would like to sample from the product of such a mixture density message with the probability density function of a factor.

Particle Belief Propagation Approach

In Ihler, McAllester: Particle Belief Propagation it has been proposed to sample from the particles in the messages themselves. This is similar to what happens in Particle Filters. While potentially a good idea, it (like Particle Filters) suffers from the thinning problem: if the message and the factor do not agree, the number of usable particles gets very low and the procedure produces unreliable and/or inaccurate results.
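Thinning can at least be diagnosed cheaply: the effective sample size (ESS) of the importance weights collapses when message and factor disagree. A one-function sketch:

```python
import numpy as np

def effective_sample_size(weights):
    # ESS = 1 / sum(w_i^2) for normalized weights; ESS << N signals thinning.
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    return 1.0 / np.sum(w ** 2)
```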

As with Particle Filters, one approach to fix the thinning problem is resampling. Loosely speaking, we tell the factor the message originated from that we would like samples at a finer resolution in certain regions. The originating factor then replies with a new (importance sampled) particle list, providing more (but down-weighted) samples in the regions of interest.

This procedure can already provide accurate distributed inference. It has just one problem: it is probably rather slow (all this re-sampling) and requires a lot of communication.

If we let the number of particles in each list become low, the procedure can be seen as a form of Gibbs Sampling in which we exchange not just one particle (sample) at a time, but several of them.

This procedure, as inefficient as it might seem, has a distinct potential advantage over most MCMC algorithms: it allows for much easier convergence diagnostics. By measuring convergence on a per-message level, we can probably determine automatically when the algorithm has converged to a final solution, given that we can determine this for each factor individually.

While this may not sound like a big deal, it actually is: for MCMC you usually need a human expert to decide whether the number of samples has been sufficient, whether all relevant states have been visited, and so on. Even then, that person can never be sure, much less so if the problem is high-dimensional. Having a clear convergence diagnostic opens the door for novel applications in large-scale risk analysis.

Compact Message Density Estimation for Message Fusion

Another approach, which has also been taken or at least proposed by several researchers (see Ihler, McAllester: Particle Belief Propagation for an overview), is to estimate message densities using some form of nonparametric density estimation technique which both smoothes the distribution and compresses the amount of data required to transfer the message. See Koller et al: A General Algorithm for Approximate Inference and Its Application to Hybrid Bayes Nets for a more thorough discussion.

In that paper, Koller et al. successfully used Density Estimation Trees (DETs) with GMMs at the leaves as density estimators, so that might be a good choice here as well. They refined these density estimates iteratively, similar to the resampling mentioned above.

Generally, multivariate Gaussian Mixture Models (GMMs) trained with regularized Expectation Maximization (EM) might be another good choice; see Kevin Murphy: Machine Learning: A Probabilistic Perspective, Chapter 11. These would lend themselves to the fast methods in Ihler, Sudderth et al: Efficient Multiscale Sampling from Products of Gaussian Mixtures.

But how do we determine the optimal number of mixture components? Maybe we can make the algorithm choose the number of components automatically, based on the data?

One obvious approach would be to use cross-validation to select the optimal number of components, but this gets prohibitively slow. A different approach is Dirichlet Process Clustering, also called the Infinite Gaussian Mixture Model, which chooses a data-dependent number of components. A possibly better alternative is to replace the Dirichlet Process prior over the number of components with a Pitman-Yor Process, a more flexible two-parameter generalization of the Dirichlet Process which allows for power-law (fat) tails.

There are a lot of ready-made implementations of these (except for Pitman-Yor Process clustering); see the Scikit-Learn documentation on mixture models.
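With current scikit-learn, the Dirichlet Process variant can be sketched as follows (the API has changed over the years; in modern versions it is BayesianGaussianMixture, which prunes unused components on its own):

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

# Toy data: two well-separated clusters; the DP prior should switch off
# most of the 10 allowed components automatically.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-3, 1, size=(200, 2)),
               rng.normal(+3, 1, size=(200, 2))])

dpgmm = BayesianGaussianMixture(
    n_components=10,  # upper bound only; effective count is data-dependent
    weight_concentration_prior_type='dirichlet_process',
).fit(X)
print(np.round(dpgmm.weights_, 2))  # most weights shrink towards zero
```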

Another (novel) approach is the following, which might be more efficient if the particle lists are created using MCMC sampling.

Embarrassingly Parallel MCMC

In a recent paper (Neiswanger et al: Embarrassingly Parallel MCMC) an asymptotically exact algorithm for embarrassingly parallel distributed MCMC was presented. Interestingly, the main problem solved in that paper is almost exactly the Message Fusion problem stated above. So by solving one problem, we get to solve inference for both high-dimensional and big-data problems.
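The simplest (parametric) combination rule from that paper is easy to sketch: fit a Gaussian to each machine's subposterior samples and multiply the Gaussians analytically. (The paper also gives asymptotically exact nonparametric and semiparametric combination rules, which this sketch omits.)

```python
import numpy as np

def combine_gaussian(subposterior_samples):
    # subposterior_samples: list of (n_m, D) arrays, one per machine.
    precisions, weighted_means = [], []
    for samples in subposterior_samples:
        mu = samples.mean(axis=0)
        prec = np.linalg.inv(np.atleast_2d(np.cov(samples, rowvar=False)))
        precisions.append(prec)
        weighted_means.append(prec @ mu)
    cov = np.linalg.inv(sum(precisions))   # product-of-Gaussians covariance
    mean = cov @ sum(weighted_means)       # product-of-Gaussians mean
    return mean, cov
```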

Density Mapping (DRAFT)

The core idea of Density Mapping is to combine MCMC, gradient ascent or EM, and Kernel Density Estimation (KDE) into a single, more efficient density estimation algorithm which can be used in the context of large-scale distributed nonparametric message passing inference engines.

We sample from a density function, and then modify the sampling density by subtracting from it density estimates around local modes. This way, the probability density gets mapped out incrementally, ensuring that the MCMC chain spends its computational time efficiently by exploring regions of the probability space not covered so far.

Let $ f^*: \mathbb{R}^D \to \mathbb{R} $ be our unnormalized posterior density function, which we can evaluate at any point. The corresponding normalized density is $ P^* $ with $ P^*(x) = \frac{1}{Z} f^*(x) $ for $ x \in \mathbb{R}^D $, where $ Z $ is the normalization constant, i.e. $ Z = \int f^*(x) \, dx $.

We assume that we can consistently sample from the density $ P^* $ (usually a posterior) using a suitable Markov Chain Monte Carlo algorithm, such as a Metropolis-Hastings sampler, given the unnormalized density function $ f^*(x) $.

Now let us assume we have a kernel probability density estimate which has been built step by step from $N$ point probability masses:

$$ P^{E}(x) = \frac{1}{H} \sum_{i=1}^{N} \gamma_i \cdot K_i(x) $$

where $ H = \sum_i \gamma_i $ and each $ K_i $ is itself a properly normalized density function (kernel) with a single mode $ k_i = \operatorname{argmax}_{x} K_i(x) $, $ k_i \in \mathbb{R}^D $. We define the unnormalized kernel density at time step $N$ as

$$ f_N^E(x) = \sum_{i=1}^{N} \frac{f_i^E(k_i)}{K_i(k_i)} \cdot K_i(x) $$

with $ f_0^E(x) = f^*(x) $. Correspondingly, we define $ \gamma_i = \frac{f^*(k_i)}{f_i^E(k_i)} $, which ensures that $ f_N^E(x) = f^*(x) $ at all points $ x \in \{ k_1, \ldots, k_N \} $.

We define the Unnormalized Sampling Function as:

$$ f^F(x) = \min\left(\max\left(f^*(x)-f^E(x),\ f^*(x)^{\frac{1}{s}}\right),\ f^*(x)\right) $$

with $ s > 1 $ being a cooling factor which flattens the original distribution. Plausible initial values for $s$ might range from 2 to 100, depending on how flat we would like the distribution to become.

If we choose the density estimation kernels $ K_i $ such that they (at least approximately) have limited support around their mode $ k_i $, this sampling function can be calculated (or approximated) quickly even if $N$ (the number of kernels in the density estimate) is large, by taking only the kernels nearest to a given point into account. Such a lookup can be performed in $ O(D \cdot \log(N)) $ time using KD-Trees or Cover-Trees, with $ N $ being the number of components and $ D $ the dimensionality of the points.
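A sketch of that lookup with SciPy's cKDTree; here kernels is the hypothetical list of (coefficient, kernel) pairs making up $f^E$, and kernel.pdf / kernel.mode are assumed helpers:

```python
import numpy as np
from scipy.spatial import cKDTree

tree = cKDTree(np.array([k.mode for _, k in kernels]))  # index kernel modes

def f_e_approx(x, n_neighbors=8):
    # Evaluate f^E(x) using only the kernels whose modes are nearest to x;
    # valid if the kernels have (approximately) limited support.
    _, idx = tree.query(x, k=min(n_neighbors, len(kernels)))
    return sum(kernels[i][0] * kernels[i][1].pdf(x) for i in np.atleast_1d(idx))
```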

During or after MCMC sampling, we should check for each sample whether $ f^*(x)-f^E(x) $ becomes negative at that point, which would indicate regions where we over-estimate the density. In such a case, it should be possible to shrink the variance of the responsible component of $ f^E(x) $ in that direction. Since this is computationally intensive (we have to re-compute all affected mixture components), it is better to prevent it in the first place by conservatively scaling down the variance of the kernel components in directions of high variance.

Density Mapping Algorithm (DRAFT)

Now, after initializing $s$ to a sensible value and starting with $ f^E = 0 $ (constant), the density estimation procedure works like this (a Python skeleton follows the list):

  1. Run MCMC to collect a number of samples from $ f^F $. Discard burn-in samples.
  2. Check all sampled points for negative values of $ f^*(x)-f^E(x) $. Shrink the variance of the responsible components of $ f^E $.
  3. Pick a random sample and perform gradient ascent or EM to find a local optimum/mode of $ f^F $, called $ k_i $.
  4. Check whether this is a new local optimum. If not: increase $ s $ and continue at 1.
  5. Create a kernel density estimate $ K_i $ around the local optimum $ k_i $ (use the Hessian of $ f^* $ at $ k_i $ to derive a scale/precision matrix, and apply sensible regularization).
  6. (Optional) Add $ k_i $ to a KD-Tree index to speed up nearest-neighbour lookups.
  7. Update $ f^E $ and $ f^F $ using the new kernel estimate $ K_i $.
  8. Stopping criterion: has the density been flattened enough? Then stop.
  9. Otherwise go to 3. or 1.
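A hypothetical Python skeleton of this loop; run_mcmc, find_mode, fit_kernel, flat_enough and the kernel objects are assumed helpers, not an existing API, and the coefficient shown is one plausible choice that makes $f^E$ match $f^*$ at the new mode:

```python
import numpy as np

def density_mapping(f_star, s=10.0, max_kernels=100, seed=0):
    rng = np.random.default_rng(seed)
    kernels = []                        # list of (coefficient, kernel K_i)

    def f_e(x):                        # current unnormalized estimate f^E
        return sum(c * k.pdf(x) for c, k in kernels)

    def f_f(x):                        # unnormalized sampling function f^F
        return min(max(f_star(x) - f_e(x), f_star(x) ** (1.0 / s)), f_star(x))

    while len(kernels) < max_kernels:
        samples = run_mcmc(f_f, rng)                 # step 1 (burn-in discarded)
        # step 2: shrink components where f_star - f_e < 0 (omitted here)
        start = samples[rng.integers(len(samples))]
        k_i = find_mode(f_f, start)                  # step 3: ascend to a mode
        if any(np.allclose(k_i, k.mode) for _, k in kernels):
            s *= 2.0                                 # step 4: flatten further
            continue
        kernel = fit_kernel(f_star, k_i)             # step 5: local Gaussian fit
        coeff = (f_star(k_i) - f_e(k_i)) / kernel.pdf(k_i)
        kernels.append((coeff, kernel))              # steps 6-7: update f^E, f^F
        if flat_enough(f_f, samples):                # step 8: stopping criterion
            break
    return kernels
```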

The result (so I hope) is a rather good density estimate. This algorithm has yet to be tried in practice.

Conclusions

It seems like everything is ready for building a generic framework for large-scale inference in directed probabilistic graphical models. Someone just has to do it.