Mean field inference for $\mathbb{Z}_2$ Synchronization

We illustrate the $\mathbb{Z}_2$ synchronization inference problem using Pyro.


In [1]:
# import some dependencies
import torch
from torch.autograd import Variable

import numpy as np

import pyro
import pyro.distributions as dist

from pyro.infer import SVI

The model

Our model is

$$ Y_{ij} = \frac{\lambda}{n}\sigma_i\sigma_j+ \frac{W_{ij}}{\sqrt{n}}, $$

with $\sigma_{i}\in\{\pm 1\}$, $i=1,\ldots,n$, where $W_{ij}\sim \mathcal{N}(0,1)$ for $i>j$, with $W_{ij}=W_{ji}$ and $W_{ii}=0$. Thus we need to sample from the distribution

$$ p(\sigma,Y;m) = \prod_i p(\sigma_i;m_i) \prod_{i>j} \sqrt{\frac{n}{2\pi}}\exp\left[-\frac{n\left(Y_{ij} - \lambda \sigma_i \sigma_j/n\right)^2}{2}\right], $$

where the first factor is a product of Bernoulli distributions, parameterized in terms of their expectations $m_i$:

$$ p(\sigma_i=\pm 1;m_i) = \frac{1\pm m_i}{2}. $$
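The parameterization above can be checked numerically. The following is a minimal NumPy sketch, not part of the original notebook: the helper name `sample_spins`, the seed, and the example values of `m` are all illustrative assumptions. It draws spins with $P(\sigma_i=+1)=(1+m_i)/2$ and compares the empirical mean with $m_i$.

```python
import numpy as np

def sample_spins(m, num_samples, rng):
    """Draw spins in {-1, +1} with P(sigma_i = +1) = (1 + m_i) / 2."""
    m = np.asarray(m, dtype=float)
    p_plus = (1.0 + m) / 2.0                         # Bernoulli success probability
    b = rng.random((num_samples, m.size)) < p_plus   # Bernoulli draws as booleans
    return 2 * b.astype(int) - 1                     # map {0, 1} -> {-1, +1}

rng = np.random.default_rng(0)                       # arbitrary seed (assumption)
m = np.array([-1.0, 0.0, 0.6, 1.0])                  # illustrative expectations
draws = sample_spins(m, 20000, rng)
empirical_mean = draws.mean(axis=0)                  # should be close to m
```

Note that a Bernoulli sampler takes the probability $(1+m_i)/2$ as its argument, not the expectation $m_i$ itself.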

What we actually want is the posterior $p(\sigma|Y)$; in the mean-field approximation this amounts to determining the posterior expectations $m_i(Y)$.
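Expanding the square in the Gaussian factor makes the structure of this posterior explicit. The short derivation below is added for illustration and uses only the model definition and the fact that $(\sigma_i\sigma_j)^2=1$:

$$ -\frac{n}{2}\left(Y_{ij} - \frac{\lambda}{n}\sigma_i\sigma_j\right)^2 = -\frac{n}{2}Y_{ij}^2 + \lambda Y_{ij}\sigma_i\sigma_j - \frac{\lambda^2}{2n}. $$

Only the middle term depends on $\sigma$, so

$$ p(\sigma|Y) \propto \prod_i p(\sigma_i;m_i)\, \exp\left[\lambda\sum_{i>j} Y_{ij}\sigma_i\sigma_j\right], $$

which has the form of an Ising model with couplings $\lambda Y_{ij}$.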

The planted ensemble

First we need to generate some observations from the above model. We observe the $Y_{ij}$ for $i>j$ with a Gaussian likelihood whose mean $\lambda\sigma_i\sigma_j/n$ is set by the spins. This is what is called the planted ensemble in this review.


In [12]:
np.random.standard_normal([2,3])


Out[12]:
array([[ 0.70092258,  1.48287668, -0.86865993],
       [-0.86309617,  0.40523175, -0.62124852]])
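The planted observations described above could be generated along the following lines. This is a hedged NumPy sketch rather than the notebook's own code: the function name `sample_planted`, the seed, and the values of `lam` and `n` are assumptions chosen for illustration, and the planted spins are drawn from a uniform prior.

```python
import numpy as np

def sample_planted(lam, n, rng):
    """Return (sigma, Y) drawn from Y = lam/n * sigma sigma^T + W / sqrt(n)."""
    sigma = rng.choice([-1, 1], size=n)      # planted spins, uniform prior (assumption)
    G = rng.standard_normal((n, n))
    W = np.tril(G, k=-1)                     # keep the independent entries W_{i>j}
    W = W + W.T                              # symmetrize; the diagonal stays zero
    Y = lam / n * np.outer(sigma, sigma) + W / np.sqrt(n)
    return sigma, Y

rng = np.random.default_rng(0)               # arbitrary seed (assumption)
sigma_true, Y = sample_planted(lam=2.0, n=50, rng=rng)
```

Only the entries with $i>j$ carry information; the diagonal of `Y` equals $\lambda/n$ and is not used as an observation.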

Setting it up in Pyro

As per this guide, to do variational inference in Pyro, we need to define a model and a guide. The model consists of

  1. Observations (pyro.observe), in our case $Y_{ij}$
  2. Latent random variables (pyro.sample), $\sigma_j$
  3. Parameters (pyro.param), $m_i$

The guide is the variational distribution. It is also a stochastic function, but without pyro.observe statements.
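For this particular model, the objective that a factorized Bernoulli guide optimizes can be written in closed form, which may help make the roles of model and guide concrete. The sketch below is an illustration in plain NumPy, not Pyro's implementation (which estimates the objective by sampling). It assumes a uniform prior on each $\sigma_i$, drops every term that does not depend on $m$, and uses the hypothetical name `mean_field_elbo`.

```python
import numpy as np

def mean_field_elbo(m, Y, lam):
    """ELBO of the factorized guide, up to a constant independent of m."""
    m = np.asarray(m, dtype=float)
    p = (1.0 + m) / 2.0                          # P(sigma_i = +1) under the guide
    # Expected log-likelihood: lam * sum_{i>j} Y_ij m_i m_j
    energy = lam * np.sum(np.tril(Y, k=-1) * np.outer(m, m))
    # Entropy of each Bernoulli factor, treating 0 * log(0) as 0
    with np.errstate(divide="ignore", invalid="ignore"):
        h = -(np.where(p > 0, p * np.log(p), 0.0)
              + np.where(p < 1, (1 - p) * np.log(1 - p), 0.0))
    return energy + h.sum()

# Small illustrative example (values are arbitrary assumptions)
Y_example = np.array([[0.0, 0.3, -0.1],
                      [0.3, 0.0, 0.2],
                      [-0.1, 0.2, 0.0]])
elbo_at_zero = mean_field_elbo(np.zeros(3), Y_example, lam=1.5)  # pure entropy term
elbo_at_ones = mean_field_elbo(np.ones(3), Y_example, lam=1.5)   # pure energy term
```

At $m=0$ only the entropy contributes, and at $m_i=\pm 1$ only the energy term does, so the optimum balances the two.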


In [5]:
def Z2_model(λ, n, data):
    m_0 = Variable(torch.ones(n) * 0.5)  # prior: P(σ_i = +1) = 0.5
    noise_std = Variable(torch.ones(1)) / float(np.sqrt(n))  # std of W_ij / sqrt(n)
    σ = 2 * pyro.sample('σ', dist.bernoulli, m_0) - 1  # σ variables live in {-1,1}
    for i in range(n):
        for j in range(i):  # observe only the entries with i > j
            pyro.observe(f"obs_{i}_{j}",
                         dist.normal, data[i][j], λ * σ[i] * σ[j] / n, noise_std)

In [7]:
def Z2_guide(λ, n, data):
    # Variational parameter: P(σ_i = +1) = (1 + m_i) / 2, initialized at 0.5
    m_var_0 = Variable(torch.ones(n) * 0.5, requires_grad=True)
    m_var = pyro.param("m_var", m_var_0)
    pyro.sample('σ', dist.bernoulli, m_var)

In [ ]:
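from pyro.optim import Adam  # assumed optimizer; the original leaves `optimizer` undefined
svi = SVI(Z2_model, Z2_guide, Adam({"lr": 0.01}), loss="ELBO")  # lr is a placeholder value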