We illustrate the $\mathbb{Z}_2$ synchronization inference problem using Pyro.
In [1]:
# import some dependencies
import torch
from torch.autograd import Variable
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.infer import SVI
Our model is
$$ Y_{ij} = \frac{\lambda}{n}\sigma_i\sigma_j + \frac{W_{ij}}{\sqrt{n}}, $$
with $\sigma_i\in\{\pm 1\}$ for $i=1,\ldots,n$, where the $W_{ij}$ with $i>j$ are independent draws from $\mathcal{N}(0,1)$, $W_{ij}=W_{ji}$, and $W_{ii}=0$. Thus we need to sample from the joint distribution
$$ p(\sigma,Y;m) = \prod_i p(\sigma_i;m_i) \prod_{i>j} \sqrt{\frac{n}{2\pi}}\exp\left[-\frac{n\left(Y_{ij} - \lambda \sigma_i \sigma_j/n\right)^2}{2}\right], $$
where the first factor is a product of Bernoulli distributions on $\{\pm 1\}$, parameterized by their expectations $m_i = \mathbb{E}[\sigma_i]$:
$$ p(\sigma_i=\pm 1;m_i) = \frac{1\pm m_i}{2}. $$
What we actually want is the posterior $p(\sigma|Y)$; approximating it by a product of such Bernoulli factors amounts to determining the posterior expectations $m_i(Y)$.
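As a concrete check of the formulas above, here is a minimal NumPy sketch of $\log p(\sigma,Y;m)$. The function name log_joint and its argument layout are illustrative assumptions, not part of the notebook; Y is taken to be an $n\times n$ array of which only the entries with $i>j$ are read, and the Bernoulli factor is written as $(1+m_i\sigma_i)/2$, which equals $(1\pm m_i)/2$ for $\sigma_i=\pm 1$.

```python
import numpy as np


def log_joint(sigma, Y, m, lam):
    """Return log p(sigma, Y; m) for the Z2 synchronization model (illustrative helper).

    sigma: length-n array of spins in {-1, +1}
    Y:     n x n array of observations; only entries with i > j are used
    m:     length-n array of prior expectations m_i in [-1, 1]
    lam:   signal strength lambda
    """
    n = len(sigma)
    # Bernoulli prior: (1 + m_i sigma_i) / 2 equals (1 +/- m_i) / 2 for sigma_i = +/-1
    log_prior = np.sum(np.log((1.0 + m * sigma) / 2.0))
    # Gaussian likelihood for i > j: mean lam * sigma_i * sigma_j / n, variance 1 / n
    i, j = np.tril_indices(n, k=-1)
    resid = Y[i, j] - lam * sigma[i] * sigma[j] / n
    log_lik = np.sum(0.5 * np.log(n / (2.0 * np.pi)) - n * resid**2 / 2.0)
    return log_prior + log_lik
```

With $m=0$ the value is unchanged when every spin is flipped, $\sigma\to-\sigma$: this is the $\mathbb{Z}_2$ symmetry that gives the problem its name.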
First we need some observations generated from the above model. We observe each $Y_{ij}$ with $i>j$ under a Gaussian likelihood whose mean $\lambda\sigma_i\sigma_j/n$ is set by the planted spins $\sigma$. This is what this review calls the planted ensemble.
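The planted ensemble can be sampled directly with NumPy. This is a sketch under our own choices: the helper name sample_planted is hypothetical, the planted spins are drawn uniformly from $\{\pm 1\}$ (i.e. $m_i=0$, matching the success probability 0.5 used in the Pyro model), and the diagonal of Y is left at $\lambda/n$ since $W_{ii}=0$ and only $i>j$ is ever observed.

```python
import numpy as np


def sample_planted(n, lam, seed=None):
    """Draw (sigma, Y) from Y = (lam / n) sigma sigma^T + W / sqrt(n) (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Planted spins, uniform on {-1, +1} (prior expectation m_i = 0)
    sigma = rng.choice([-1, 1], size=n)
    # Symmetric noise: W_ij = W_ji ~ N(0, 1) for i > j, and W_ii = 0
    W = np.tril(rng.standard_normal((n, n)), k=-1)
    W = W + W.T
    Y = lam / n * np.outer(sigma, sigma) + W / np.sqrt(n)
    return sigma, Y
```

To pass these observations to the Pyro model they would still need converting to a torch Variable, for example Variable(torch.from_numpy(Y).float()).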
In [12]:
np.random.standard_normal([2,3])
Out[12]:
Following this guide, variational inference in Pyro requires defining a model and a guide. The model consists of:
- observations (pyro.observe), in our case $Y_{ij}$;
- latent random variables (pyro.sample), in our case $\sigma_j$;
- parameters (pyro.param), in our case $m_i$.

The guide is the variational distribution. It is also a stochastic function, but it contains no pyro.observe statements.
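Because a guide that factorizes over spins makes every expectation tractable, the ELBO that SVI estimates by sampling can, for this model, also be written in closed form, which is handy for checking results. A sketch, writing the guide in the expectation parameterization $q(\sigma_i=\pm 1)=(1\pm\mu_i)/2$ (the symbol $\mu_i$ and the function name are our own) and assuming all $\mu_i$ and $m_i$ lie strictly inside $(-1,1)$. It uses $\mathbb{E}_q[\sigma_i\sigma_j]=\mu_i\mu_j$ for $i\neq j$ and $\sigma_i^2\sigma_j^2=1$.

```python
import numpy as np


def mean_field_elbo(mu, Y, m, lam):
    """Closed-form ELBO for a factorized guide q(sigma_i = +/-1) = (1 +/- mu_i) / 2.

    Assumes every mu_i and m_i lies strictly inside (-1, 1) and that only
    the entries Y[i, j] with i > j carry observations.
    """
    n = len(mu)
    q_plus, q_minus = (1 + mu) / 2, (1 - mu) / 2
    # E_q[log p(sigma_i; m_i)], summed over spins
    e_log_prior = np.sum(q_plus * np.log((1 + m) / 2) + q_minus * np.log((1 - m) / 2))
    # Entropy of the factorized guide, -E_q[log q(sigma)]
    entropy = -np.sum(q_plus * np.log(q_plus) + q_minus * np.log(q_minus))
    # E_q[(Y_ij - lam sigma_i sigma_j / n)^2] for i > j, using sigma_i^2 sigma_j^2 = 1
    # and E_q[sigma_i sigma_j] = mu_i mu_j (spins are independent under q)
    i, j = np.tril_indices(n, k=-1)
    e_sq = Y[i, j] ** 2 - 2 * lam * Y[i, j] * mu[i] * mu[j] / n + lam**2 / n**2
    e_log_lik = np.sum(0.5 * np.log(n / (2 * np.pi)) - n * e_sq / 2)
    return e_log_prior + e_log_lik + entropy
```

The success probabilities m_var in Z2_guide correspond to $\mu_i = 2\,\mathrm{m\_var}_i - 1$ in this parameterization.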
In [5]:
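def Z2_model(λ, n, data):
    # Uniform prior on the spins: Bernoulli success probability 0.5,
    # i.e. m_i = 0 in the expectation parameterization above
    p_0 = Variable(torch.ones(n) * 0.5)
    # Standard deviation of the noise term W_ij / sqrt(n)
    noise_std = Variable(torch.ones(1)) / np.sqrt(n)
    # Map Bernoulli samples in {0, 1} to spins in {-1, +1}
    σ = 2 * pyro.sample('σ', dist.bernoulli, p_0) - 1
    # Gaussian likelihood for each observed Y_ij with i > j
    for i in range(n):
        for j in range(i):
            pyro.observe(f"obs_{i}_{j}",
                         dist.normal, data[i][j], λ * σ[i] * σ[j] / n, noise_std)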
In [7]:
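def Z2_guide(λ, n, data):
    # Variational parameters: one Bernoulli success probability per spin,
    # i.e. (1 + m_i) / 2 in the notation above, initialized at 0.5
    m_var_0 = Variable(torch.ones(n) * 0.5, requires_grad=True)
    m_var = pyro.param("m_var", m_var_0)
    # Factorized guide over the same site 'σ' as the model;
    # note that m_var is not constrained to stay inside [0, 1] here
    pyro.sample('σ', dist.bernoulli, m_var)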
In [ ]:
from pyro.optim import Adam
svi = SVI(Z2_model, Z2_guide, Adam({"lr": 0.01}), loss="ELBO")  # learning rate chosen for illustration
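For comparison with the stochastic optimization set up above, the factorized (mean-field) objective for this model can also be maximized deterministically. Setting the derivative of the mean-field ELBO with respect to the guide expectation $\mu_i$ to zero gives the naive mean-field equations $\mu_i=\tanh\big(\operatorname{artanh} m_i + \lambda\sum_{j\neq i}Y_{ij}\mu_j\big)$, and updating one $\mu_i$ at a time is coordinate ascent on that objective. This is a swapped-in deterministic technique, not what Pyro's SVI does; the function name naive_mean_field and its defaults are our own, and Y is assumed symmetric.

```python
import numpy as np


def naive_mean_field(Y, lam, m, mu0, n_sweeps=200, tol=1e-12):
    """Coordinate ascent on the mean-field objective for the Z2 model (illustrative).

    Each update replaces mu_i by the maximizer of the objective in mu_i with the
    other coordinates held fixed:
        mu_i = tanh(artanh(m_i) + lam * sum_{j != i} Y_ij mu_j).
    Assumes Y is symmetric and every m_i lies strictly inside (-1, 1).
    """
    mu = np.array(mu0, dtype=float)
    h_prior = np.arctanh(m)             # field contributed by the prior expectations m_i
    Y_off = Y - np.diag(np.diag(Y))     # only pairs i != j enter the likelihood
    for _ in range(n_sweeps):
        max_change = 0.0
        for i in range(len(mu)):
            new = np.tanh(h_prior[i] + lam * (Y_off[i] @ mu))
            max_change = max(max_change, abs(new - mu[i]))
            mu[i] = new                 # in-place update: later coordinates see the new value
        if max_change < tol:
            break
    return mu
```

Starting from mu0 = 0 with m = 0 leaves every $\mu_i$ at 0, since that is always a fixed point by the $\mathbb{Z}_2$ symmetry; a small random or biased initialization is needed to break it.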