Mixture Density Networks with PyTorch

Standard Data Fitting

Before we talk about Mixture Density Networks, let's perform some standard data fitting using PyTorch to make sure everything works. Neural nets with even one hidden layer can be universal function approximators, so let's try to fit a sinusoidal function.

$y_{true}(x)=7 \sin( 0.75 x ) + 0.5 x + \epsilon$

  • $y_{true}(x)$: this function $y_{true}$ takes $x$ as input.
  • $7 \sin( 0.75 x)$: a large periodic $\sin$ wave.
  • $0.5 x$: add a slight upward slope.
  • $\epsilon$: add some random noise "epsilon".

First we import the libraries we need.


In [1]:
import matplotlib.pyplot as plt # creating visualizations
import numpy as np # basic math and random numbers
import torch # package for building functions with learnable parameters
import torch.nn as nn # prebuilt functions specific to neural networks
from torch.autograd import Variable # storing data while learning

Then we generate random inputs $x$ to get random samples of $y_{true}(x)$. Later we will train a neural net on this data.


In [2]:
def generate_data(n_samples):
    epsilon = np.random.normal(size=(n_samples))
    x_data = np.random.uniform(-10.5, 10.5, n_samples)
    y_data = 7*np.sin(0.75*x_data) + 0.5*x_data + epsilon
    return x_data, y_data
    
n_samples = 1000
x_data, y_data = generate_data(n_samples)

In [3]:
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.2)
plt.show()


We will use this data to train a neural network with one hidden layer. This neural network is described by the following equation:

$y_{pred}(x) = w_{out} \tanh( w_{in} x + b_{in}) + b_{out}$

  • $y_{pred}(x)$: this function $y_{pred}$ takes $x$ as input.
  • $w_{in}, w_{out}$: weights for the input and output layers.
  • $b_{in}, b_{out}$: biases for the input and output layers.
  • $\tanh$: a nonlinear activation function.

Let's create this network using 20 hidden nodes.


In [4]:
n_input = 1
n_hidden = 20
n_output = 1

# create the network (also called a "model" of the data)
network = nn.Sequential(nn.Linear(n_input, n_hidden),
                        nn.Tanh(),
                        nn.Linear(n_hidden, n_output))

To train the network we must define a loss function. Also called a "cost function", this is a description of what counts as "better" or "worse" results, and allows us to modify the network to achieve the best results.

We will use the mean squared error (MSE) loss function.
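Mean squared error is just the average of the squared differences between the predictions and the targets. As a quick illustration with made-up numbers (not our training data):

# hand-rolled mean squared error, for illustration only
y_pred_example = np.array([1.0, 2.0, 3.0])
y_true_example = np.array([1.5, 2.0, 2.0])
mse = np.mean((y_pred_example - y_true_example) ** 2)  # (0.25 + 0.0 + 1.0) / 3 ≈ 0.417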


In [5]:
loss_fn = nn.MSELoss()

We also need to pick an optimizer. Optimizers use the loss to determine which parameters in the network should be changed, and how much.

We will use the RMSprop optimizer, which happens to work well for this problem.
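Roughly speaking, RMSprop keeps a running average of each parameter's squared gradient and divides the gradient by the square root of that average, so parameters with consistently large gradients take smaller steps. A rough numpy sketch of a single update (the essential idea only, not PyTorch's exact internals):

# rough sketch of one RMSprop update for a single parameter, for intuition only
lr, alpha, eps = 0.01, 0.99, 1e-8        # PyTorch's default hyperparameters
def rmsprop_step(param, grad, avg_sq_grad):
    avg_sq_grad = alpha * avg_sq_grad + (1 - alpha) * grad ** 2
    param = param - lr * grad / (np.sqrt(avg_sq_grad) + eps)
    return param, avg_sq_grad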


In [6]:
optimizer = torch.optim.RMSprop(network.parameters())

Now let's move our data from numpy to PyTorch. This requires:

  1. Converting the data from np.float64 (the numpy default) to np.float32 (the Torch default).
  2. Reshaping the data from (n_samples) to (n_samples, n_input). Typically the first dimension is the batch size, but here we are processing all the samples in one batch.
  3. Converting from a numpy array to a PyTorch tensor.
  4. Initializing an input and output Variable that wrap the x and y tensors.

We will use this naming convention:

  • numpy arrays x_data
  • PyTorch tensors x_tensor
  • PyTorch Variables x_variable

In [7]:
# change data type and shape, move from numpy to torch
# note that we need to convert all data to np.float32 for pytorch
x_tensor = torch.from_numpy(np.float32(x_data).reshape(n_samples, n_input))
y_tensor = torch.from_numpy(np.float32(y_data).reshape(n_samples, n_input))
x_variable = Variable(x_tensor)
y_variable = Variable(y_tensor, requires_grad=False)

Now let's define a training loop. It will use the optimizer to minimize the loss function by modifying the network's parameters.


In [8]:
def train():
    for epoch in range(3000):
        y_pred = network(x_variable) # make a prediction
        loss = loss_fn(y_pred, y_variable) # compute the loss
        optimizer.zero_grad() # prepare the optimizer
        loss.backward() # compute the contribution of each parameter to the loss
        optimizer.step() # modify the parameters

        if epoch % 300 == 0:
            print(epoch, loss.data[0])

train()


0 35.52162170410156
300 5.057154178619385
600 1.7624188661575317
900 1.1605937480926514
1200 1.1055160760879517
1500 1.0992330312728882
1800 1.0915052890777588
2100 1.0879061222076416
2400 1.0838695764541626
2700 1.080744981765747

Let's see how it performs by processing some evenly spaced samples.


In [9]:
# evenly spaced samples from -10 to 10
x_test_data = np.linspace(-10, 10, n_samples)

# change data shape, move from numpy to torch
x_test_tensor = torch.from_numpy(np.float32(x_test_data).reshape(n_samples, n_input))
x_test_variable = Variable(x_test_tensor)
y_test_variable = network(x_test_variable)

# move from torch back to numpy
y_test_data = y_test_variable.data.numpy()

# plot the original data and the test data
plt.figure(figsize=(8, 8))
plt.scatter(x_data, y_data, alpha=0.2)
plt.scatter(x_test_data, y_test_data, alpha=0.2)
plt.show()


The network can fit this sinusoidal data quite well, as expected. However, this type of fitting only works when we want to approximate a one-to-one or many-to-one relationship.

Suppose we invert the training data so we are predicting $x(y)$ instead of $y(x)$.


In [10]:
# plot x against y instead of y against x
plt.figure(figsize=(8, 8))
plt.scatter(y_data, x_data, alpha=0.2)
plt.show()


If we use the same method to fit this data, it won't work well because the network can't output multiple values for each input. Because we used mean squared error loss, the network will try to output the average output value for each input.
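
To see why the average is the best a single-valued network can do, note that for a fixed input the constant $c$ that minimises the expected squared error is the conditional mean:

$\frac{d}{dc} E[(y - c)^2] = -2 E[y - c] = 0 \implies c = E[y \mid x]$

So wherever several $y$ values share the same $x$, the network is pulled towards their average.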


In [11]:
x_variable.data = y_tensor
y_variable.data = x_tensor

train()


0 57.528987884521484
300 21.487085342407227
600 21.335010528564453
900 21.292551040649414
1200 21.242443084716797
1500 21.197223663330078
1800 21.15885353088379
2100 21.115901947021484
2400 21.06930923461914
2700 21.02263641357422

In [12]:
x_test_data = np.linspace(-15, 15, n_samples)
x_test_tensor = torch.from_numpy(np.float32(x_test_data).reshape(n_samples, n_input))
x_test_variable.data = x_test_tensor

y_test_variable = network(x_test_variable)

# move from torch back to numpy
y_test_data = y_test_variable.data.numpy()

# plot the original data and the test data
plt.figure(figsize=(8, 8))
plt.scatter(y_data, x_data, alpha=0.2)
plt.scatter(x_test_data, y_test_data, alpha=0.2)
plt.show()


Because our network only predicts one output value for each input, this approach will fail miserably.

What we want is a network that can predict multiple output values for each input. In the next section we implement a Mixture Density Network (MDN) to achieve this.

Mixture Density Networks

Mixture Density Networks, developed by Christopher Bishop in the 1990s, are one way to produce multiple outputs from a single input. An MDN predicts a probability distribution over possible output values, and we can then sample several different output values for a given input.

This concept is quite powerful and can be employed in many current areas of machine learning research. It also allows us to calculate a sort of confidence factor in the predictions the network is making.

The inverted sinusoid data we chose is not just a toy problem. In the paper introducing MDNs, an inverted sinusoid is used to describe the angle a robot arm must move through to reach a target location. MDNs are also used to model handwriting, where the next stroke is drawn from a probability distribution over multiple possibilities rather than locked to a single prediction.

Bishop's implementation of MDNs predicts a class of probability distributions called a Mixture of Gaussians (also known as a Gaussian Mixture Model), where the output value is modelled as a weighted sum of multiple Gaussians, each with its own mean and standard deviation.

So for each input $x$, we will predict a probability distribution function $P(y|x)$:

$P(y|x) = \sum_{k}^{K} \Pi_{k}(x) \phi(y, \mu_{k}(x), \sigma_{k}(x))$

  • $k$ is an index describing which Gaussian we are referencing. There are $K$ Gaussians total.
  • $\sum_{k}^{K}$ is the summation operator. We sum every $k$ Gaussian across all $K$. You might also see $\sum_{k=0}^{K-1}$ or $\sum_{k=1}^{K}$ depending on whether an author is using zero-based numbering or not.
  • $\Pi_k$ acts as a weight, or multiplier, for mixing the $k$-th Gaussian. It is a function of the input $x$: $\Pi_k(x)$
  • $\phi$ is the Gaussian density function: it returns the probability density at $y$ for a given mean and standard deviation (written out just below this list).
  • $\mu_k$ and $\sigma_k$ are the parameters of the $k$-th Gaussian: mean $\mu_k$ and standard deviation $\sigma_k$. Instead of being fixed for each Gaussian, they are also functions of the input $x$: $\mu_k(x)$ and $\sigma_k(x)$
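
For reference, the Gaussian density $\phi$ has the standard form:

$\phi(y, \mu_{k}(x), \sigma_{k}(x)) = \frac{1}{\sigma_{k}(x) \sqrt{2 \pi}} \exp \left( -\frac{(y - \mu_{k}(x))^2}{2 \sigma_{k}(x)^2} \right)$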

All of the $\sigma_{k}$ are positive, and the weights $\Pi_{k}$ sum to one:

$\sum_{k}^{K} \Pi_{k} = 1$

First our network must learn the functions $\Pi_{k}(x), \mu_{k}(x), \sigma_{k}(x)$ for every $k$ Gaussian. Then these functions can be used to generate individual parameters $\mu_k, \sigma_k, \Pi_k$ for a given input $x$. These parameters will be used to generate our pdf $P(y|x)$. Finally, to make a prediction, we will need to sample (pick a value) from this pdf.
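
As a small numerical illustration of this pdf (with made-up parameters, not values produced by a trained network), evaluating a two-Gaussian mixture at a single point looks like this:

# evaluate P(y|x) for one input, given made-up mixture parameters (illustration only)
pi_example = np.array([0.3, 0.7])       # mixing weights, sum to 1
mu_example = np.array([-2.0, 2.0])      # means
sigma_example = np.array([1.0, 0.5])    # standard deviations
y = 1.0
gauss = np.exp(-0.5 * ((y - mu_example) / sigma_example) ** 2) / (sigma_example * np.sqrt(2 * np.pi))
p_y_given_x = np.sum(pi_example * gauss)  # weighted sum of the K Gaussian densities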

In our implementation, we will use a neural network with one hidden layer of 20 nodes. This will feed into another layer that generates the parameters for a mixture of 5 Gaussians: 3 parameters $\Pi_k$, $\mu_k$, $\sigma_k$ for each Gaussian $k$.

Our definition will be split into three parts.

First we will compute 20 hidden values $z_h$ from our input $x$.

$z_h(x) = \tanh( W_{in} x + b_{in})$

Second, we will use these hidden values $z_h$ to compute our three sets of parameters $\Pi, \sigma, \mu$:

$z_\Pi = W_{\Pi} z_h + b_{\Pi}$
$z_\sigma = W_{\sigma} z_h + b_{\sigma}$
$z_\mu = W_{\mu} z_h + b_{\mu}$

Third, we will use the output of these layers to determine the parameters of the Gaussians.

$\Pi = \frac{\exp(z_{\Pi})}{\sum_{k}^{K} \exp(z_{\Pi_{k}})}$
$\sigma = \exp(z_{\sigma})$
$\mu = z_{\mu}$

  • $\exp(x)$ is the exponential function also written as $e^x$

We use a softmax operator to ensure that $\Pi$ sums to one across all $k$, and the exponential function ensures that each weight $\Pi_k$ is positive. We also use the exponential function to ensure that every $\sigma_k$ is positive.

Let's define our MDN network.


In [13]:
class MDN(nn.Module):
    def __init__(self, n_hidden, n_gaussians):
        super(MDN, self).__init__()
        self.z_h = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)  

    def forward(self, x):
        z_h = self.z_h(x)
        pi = nn.functional.softmax(self.z_pi(z_h), -1)
        sigma = torch.exp(self.z_sigma(z_h))
        mu = self.z_mu(z_h)
        return pi, sigma, mu

We cannot use the MSELoss() function for this task, because the output is an entire description of the probability distribution and not just a single value. A more suitable loss function is the negative logarithm of the likelihood of the training data under the predicted distribution:

$loss(y | x) = -\log[ \sum_{k}^{K} \Pi_{k}(x) \phi(y, \mu_{k}(x), \sigma_{k}(x)) ]$

So for every $x$ input and $y$ output pair in the training data set, we can compute a loss based on the predicted distribution versus the observed value, and then attempt to minimise the sum of all these losses combined. For those familiar with logistic regression and cross-entropy minimisation with softmax, this is a similar approach, but with non-discretised states.

We have to implement this cost function ourselves:


In [14]:
oneDivSqrtTwoPI = 1.0 / np.sqrt(2.0*np.pi) # normalization factor for Gaussians
def gaussian_distribution(y, mu, sigma):
    # make |mu|=K copies of y, subtract mu, divide by sigma
    result = (y.expand_as(mu) - mu) * torch.reciprocal(sigma)
    result = -0.5 * (result * result)
    return (torch.exp(result) * torch.reciprocal(sigma)) * oneDivSqrtTwoPI

def mdn_loss_fn(pi, sigma, mu, y):
    result = gaussian_distribution(y, mu, sigma) * pi
    result = torch.sum(result, dim=1)
    result = -torch.log(result)
    return torch.mean(result)
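
Note that when all the Gaussians assign a very small density to a point, the sum inside the log can underflow to zero. A more numerically stable variant (a sketch only, assuming a PyTorch version that provides torch.logsumexp; not the code used in this post) would work in log-space:

# sketch of a log-space version of the loss, assuming torch.logsumexp is available
def mdn_loss_fn_stable(pi, sigma, mu, y):
    # log of the Gaussian density, computed directly in log-space
    log_gauss = -0.5 * ((y.expand_as(mu) - mu) / sigma) ** 2 \
                - torch.log(sigma) - 0.5 * np.log(2.0 * np.pi)
    # log of sum_k pi_k * N(y; mu_k, sigma_k), via log-sum-exp
    log_mix = torch.logsumexp(torch.log(pi) + log_gauss, dim=1)
    return -torch.mean(log_mix)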

Let's create our MDN with 20 hidden nodes and 5 Gaussians.


In [15]:
network = MDN(n_hidden=20, n_gaussians=5)

We'll use a different optimizer this time, called Adam, which is better suited to this task.


In [16]:
optimizer = torch.optim.Adam(network.parameters())

We could generate more data to train the MDN, but what we have is nearly enough.


In [17]:
mdn_x_data = y_data
mdn_y_data = x_data

mdn_x_tensor = y_tensor
mdn_y_tensor = x_tensor

x_variable = Variable(mdn_x_tensor)
y_variable = Variable(mdn_y_tensor, requires_grad=False)

Finally, let's define a new training loop. We need a training loop that can handle the new loss function, and the MDN needs to train longer than the previous network.


In [18]:
def train_mdn():
    for epoch in range(10000):
        pi_variable, sigma_variable, mu_variable = network(x_variable)
        loss = mdn_loss_fn(pi_variable, sigma_variable, mu_variable, y_variable)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % 500 == 0:
            print(epoch, loss.data[0])

train_mdn()


0 7.357309341430664
500 3.0533816814422607
1000 2.575230121612549
1500 2.0565884113311768
2000 1.6556810140609741
2500 1.5370887517929077
3000 1.4593229293823242
3500 1.39969003200531
4000 1.3698593378067017
4500 1.3465501070022583
5000 1.304410696029663
5500 1.2787706851959229
6000 1.2707456350326538
6500 1.2627036571502686
7000 1.2542991638183594
7500 1.248332142829895
8000 1.241523027420044
8500 1.231286644935608
9000 1.2280000448226929
9500 1.2195545434951782

Once the training is finished, we can observe all the parameters for the Gaussians and see how they vary with respect to the input $x$.


In [19]:
pi_variable, sigma_variable, mu_variable = network(x_test_variable)

pi_data = pi_variable.data.numpy()
sigma_data = sigma_variable.data.numpy()
mu_data = mu_variable.data.numpy()

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(8,8))
ax1.plot(x_test_data, pi_data)
ax1.set_title('$\Pi$')
ax2.plot(x_test_data, sigma_data)
ax2.set_title('$\sigma$')
ax3.plot(x_test_data, mu_data)
ax3.set_title('$\mu$')
plt.xlim([-15,15])
plt.show()


We can also plot the $\mu$ of each Gaussian with respect to $x$, and show the spread of each Gaussian by highlighting the region between $\mu-\sigma$ and $\mu+\sigma$.


In [20]:
plt.figure(figsize=(8, 8), facecolor='white')
for mu_k, sigma_k in zip(mu_data.T, sigma_data.T):
    plt.plot(x_test_data, mu_k)
    plt.fill_between(x_test_data, mu_k-sigma_k, mu_k+sigma_k, alpha=0.1)
plt.scatter(mdn_x_data, mdn_y_data, marker='.', lw=0, alpha=0.2, c='black')
plt.xlim([-10,10])
plt.ylim([-10,10])
plt.show()


In the plot above, we see that for every point on the $x$-axis, there are multiple lines or states where $y$ may be, and we select these states with probabilities modelled by $\Pi$. Note that the network won't find an ideal solution every time. It's possible to get lower loss by using more Gaussians, but the results are harder to interpret.

If we want to sample from the network we will need to pick a Gaussian $k$ and then pick a value (sample) from that Gaussian. Here we use a trick called Gumbel-max sampling to pick our $k$: we treat the $\Pi$ weights as a discrete distribution of probabilities, and sample one $k$ for each row of pi_data.


In [21]:
def gumbel_sample(x, axis=1):
    z = np.random.gumbel(loc=0, scale=1, size=x.shape)
    return (np.log(x) + z).argmax(axis=axis)

k = gumbel_sample(pi_data)

Now that we have selected $k$ for each row, we can select $\sigma$ and $\mu$ as well. We will use np.random.randn to sample from each Gaussian, scaling it by $\sigma$ and offsetting it by $\mu$.


In [22]:
indices = (np.arange(n_samples), k)
rn = np.random.randn(n_samples)
sampled = rn * sigma_data[indices] + mu_data[indices]

With these sampled $y$ values, we can overlay them on the original distribution to see how accurately the network captures the shape.


In [23]:
plt.figure(figsize=(8, 8))
plt.scatter(mdn_x_data, mdn_y_data, alpha=0.2)
plt.scatter(x_test_data, sampled, alpha=0.2, color='red')
plt.show()


Some other things to try:

  • What other constraints might we enforce? Is it ever helpful to encourage $\Pi$ to be sparse to maintain a "simple" distribution?
  • What kinds of regularization can we add to the network?
  • Very small values of sigma can often create numerical problems; it might be helpful to add sigma += 0.01 inside the MDN.
  • How do more or fewer Gaussians affect the result?
  • How do more or fewer hidden nodes affect the result?
  • What other kinds of nonlinear activation functions work besides Tanh()? Try Sigmoid() too.
  • How do these plots move if you draw them in realtime while the network is training?
  • What happens if you use the softplus function, instead of $\exp(x)$, to make $\sigma$ positive? (See the sketch below.)
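
For the last question, softplus(x) = log(1 + exp(x)) is positive everywhere like exp(x), but only grows linearly for large inputs. A minimal sketch of the idea (the commented line is a hypothetical change inside MDN.forward, not code from this post):

# softplus stays positive but grows linearly instead of exponentially
z = np.linspace(-5, 5, 11)
softplus_z = np.log1p(np.exp(z))
# hypothetical change inside MDN.forward:
#     sigma = nn.functional.softplus(self.z_sigma(z_h))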