This notebook will guide you through a very simple case of generative adversarial networks.
Like... veeery simple. A generative adversarial network that learns to convert a 1d uniform noise distribution into a 1d normal data distribution.
Unlike the main notebooks (adversarial_*.ipynb), this one features a lot of useful visualizations that will help you both get acquainted with the behavior of the two networks and debug common errors without having to wait hours of CPU time.
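As a quick refresher (this is just the standard GAN formulation, nothing specific to this notebook), the discriminator D and generator G play the minimax game

$$ \min_G \max_D \; \mathbb{E}_{x \sim p_{data}} \big[\log D(x)\big] + \mathbb{E}_{z \sim p(z)} \big[\log\big(1 - D(G(z))\big)\big] $$

Below we implement this as two separate losses: the discriminator maximizes the log-probability of the correct real/fake labels, and the generator maximizes the log-probability of its samples being labeled real (the common "non-saturating" variant).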
In [ ]:
import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
def sample_noise(batch_size):
    """ Uniform noise of shape [batch_size, 1] in range [0, 1] """
    return torch.rand(batch_size, 1)

def sample_real_data(batch_size):
    """ Normal "real" data of shape [batch_size, 1], mu=5, std=1.5 """
    return torch.randn(batch_size, 1) * 1.5 + 5
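If you want a quick sanity check that the two samplers produce what the docstrings promise, a throwaway cell like the one below (optional, not needed for training) prints the empirical statistics:

In [ ]:
# optional sanity check: noise should lie in [0, 1], real data should have mean ~5 and std ~1.5
z = sample_noise(10000)
x = sample_real_data(10000)
print("noise min/max:", z.min().item(), z.max().item())
print("real  mean/std:", x.mean().item(), x.std().item())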
In [ ]:
# Generator converts 1d noise into 1d data
gen = nn.Sequential(nn.Linear(1, 16), nn.ELU(), nn.Linear(16, 1))
gen_opt = torch.optim.SGD(gen.parameters(), lr=1e-3)
# Discriminator converts 1d data into two logits (0th for real, 1st for fake).
# It is deliberately made stronger than generator to make sure disc is slightly "ahead in the game".
disc = nn.Sequential(nn.Linear(1, 64), nn.ELU(), nn.Linear(64, 2))
disc_opt = torch.optim.SGD(disc.parameters(), lr=1e-2)
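Before training, it can be worth confirming that both networks accept and produce tensors of the expected shapes. The cell below is just an optional shape check:

In [ ]:
# optional: generator maps [batch, 1] noise -> [batch, 1] data,
# discriminator maps [batch, 1] data -> [batch, 2] logits
z = sample_noise(4)
print(gen(z).shape)        # expected: torch.Size([4, 1])
print(disc(gen(z)).shape)  # expected: torch.Size([4, 2])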
In [ ]:
# we define 0-th output of discriminator as "is_fake" output and 1-st as "is_real"
IS_FAKE, IS_REAL = 0, 1
def train_disc(batch_size):
    """ trains discriminator for one step """
    # compute logp(real | x)
    real_data = sample_real_data(batch_size)
    logp_real_is_real = F.log_softmax(disc(real_data), dim=1)[:, IS_REAL]

    # compute logp(fake | G(z)). We detach to avoid computing gradients through G(z)
    noise = sample_noise(batch_size)
    gen_data = gen(noise).detach()
    logp_gen_is_fake = F.log_softmax(disc(gen_data), dim=1)[:, IS_FAKE]

    # discriminator maximizes log-likelihood of the correct labels, i.e. minimizes the negative
    disc_loss = -(logp_real_is_real + logp_gen_is_fake).mean()

    # sgd step. We zero_grad first to clear any gradients left from generator training
    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()
    return disc_loss.item()
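If you want to convince yourself that the .detach() call really blocks gradients from leaking into the generator, a minimal check (assuming it is run right after initializing the networks, before any generator step) is:

In [ ]:
# optional: one discriminator step should leave the generator's gradients untouched
train_disc(32)
for p in gen.parameters():
    assert p.grad is None or torch.all(p.grad == 0), \
        "discriminator step leaked gradients into the generator"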
In [ ]:
def train_gen(batch_size):
    """ trains generator for one step """
    # compute logp(real | G(z)). No detach here: gradients must flow through G(z)
    noise = sample_noise(batch_size)
    gen_data = gen(noise)
    logp_gen_is_real = F.log_softmax(disc(gen_data), dim=1)[:, IS_REAL]

    # generator maximizes logp(real | G(z)), i.e. minimizes the negative ("non-saturating" loss)
    gen_loss = -logp_gen_is_real.mean()

    gen_opt.zero_grad()
    gen_loss.backward()
    gen_opt.step()
    return gen_loss.item()
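Before launching the long training loop with plotting, it can save time to run a tiny smoke test first (again, purely optional):

In [ ]:
# optional smoke test: a handful of alternating updates should run without errors
# and produce finite losses
for _ in range(10):
    d_loss = train_disc(128)
    g_loss = train_gen(128)
print("disc loss: %.3f, gen loss: %.3f" % (d_loss, g_loss))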
In [ ]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
%matplotlib inline
for i in range(100000):
    # train discriminator for 5 steps per generator step to keep it slightly "ahead"
    for _ in range(5):
        train_disc(128)
    train_gen(128)

    if i % 250 == 0:
        clear_output(True)
        plt.figure(figsize=[14, 6])

        plt.subplot(1, 2, 1)
        plt.title("Data distributions")
        plt.hist(gen(sample_noise(1000)).data.numpy()[:, 0],
                 range=[0, 10], alpha=0.5, density=True, label='gen')
        plt.hist(sample_real_data(1000).data.numpy()[:, 0],
                 range=[0, 10], alpha=0.5, density=True, label='real')

        x = np.linspace(0, 10, dtype='float32')
        disc_preal = F.softmax(disc(torch.from_numpy(x[:, None])), dim=1)[:, IS_REAL]
        plt.plot(x, disc_preal.data.numpy(), label='disc P(real)')
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.title("Discriminator readout on real vs gen")
        plt.hist(F.softmax(disc(gen(sample_noise(100))), dim=1)[:, IS_REAL].data.numpy(),
                 range=[0, 1], alpha=0.5, label='D(G(z))')
        plt.hist(F.softmax(disc(sample_real_data(100)), dim=1)[:, IS_REAL].data.numpy(),
                 range=[0, 1], alpha=0.5, label='D(x)')
        plt.legend()
        plt.show()
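Once the histograms start to overlap, you can quantify how close the generator got. The cell below is just a rough convergence check against the known target distribution:

In [ ]:
# optional: the generated distribution should approach the target N(mu=5, std=1.5)
samples = gen(sample_noise(10000)).data.numpy()
print("generated mean: %.2f (target 5.0)" % samples.mean())
print("generated std:  %.2f (target 1.5)" % samples.std())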
What to expect: