In [1]:
%matplotlib inline
In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
##################
# option variables
##################
learning_rate = 1e-3
training_epoch = 100
batch_size = 100
# neural network properties
n_latent = 10   # dimension of the latent noise vector z
n_input = 28*28 # number of image pixels
n_cat = 10      # number of categorical codes (one per digit class)
n_con = 1       # number of continuous codes
is_cuda = torch.cuda.is_available() # True if CUDA is available
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=16, shuffle=True)
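A quick sanity check, optional and not part of the original notebook: pull one batch and confirm that the loader's output matches the n_input = 784 flattening used below.
In [ ]:
# optional sanity check: inspect one batch from the train loader
images, labels = next(iter(train_loader))
print(images.size())                   # torch.Size([100, 1, 28, 28])
print(images.view(-1, n_input).size()) # torch.Size([100, 784])
print(labels.size())                   # torch.Size([100])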
In [3]:
##################
# helper functions
##################
def sample_c(size):
    cat = np.random.multinomial(1, n_cat*[1/n_cat], size=size) # one-hot vector of a category
    con = np.random.rand(size, n_con) # continuous value in [0, 1)
    c = np.concatenate((cat, con), axis=1)
    c = Variable(torch.from_numpy(c.astype('float32')))
    return c
def plot_fakeimages():
    # Each row of the canvas fixes one categorical code; each column
    # sweeps the continuous code from 0 to 1 (right to left).
    nx = 20
    ny = 10
    canvas = np.empty((28*ny, 28*nx))
    x_values = np.linspace(0, 1, nx)
    for j in range(ny):
        for i, xi in enumerate(x_values):
            z = Variable(torch.randn(1, n_latent))
            con = Variable(torch.Tensor([[xi]]))
            c = Variable(torch.zeros(1, n_cat))
            c.data[0, j] = 1
            if is_cuda: z, c, con = z.cuda(), c.cuda(), con.cuda()
            img = model.G(torch.cat([z, c, con], 1))
            img = img.view(-1, 28, 28)
            canvas[j*28:(j+1)*28, (nx-i-1)*28:(nx-i)*28] = img.data.squeeze().cpu().numpy()
    plt.figure(figsize=(8, 10))
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.tight_layout()
    plt.show()
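To make the code layout concrete, another optional check (a sketch, not from the original notebook): the first n_cat columns of a sampled code are a one-hot category and the remaining n_con columns are uniform continuous values.
In [ ]:
# optional: verify the layout of a sampled latent code
c = sample_c(3)
print(c.size())                 # torch.Size([3, 11]), i.e. n_cat + n_con
print(c.data[:, :n_cat].sum(1)) # each one-hot block sums to 1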
In [4]:
class GAN(nn.Module):
    def __init__(self):
        super(GAN, self).__init__()
        self.G = nn.Sequential(
            nn.Linear(n_latent + n_cat + n_con, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, n_input, bias=True),
            nn.Sigmoid(),
        )
        self.D = nn.Sequential(
            nn.Linear(n_input, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, 1, bias=True),
            nn.Sigmoid(),
        )
        self.Q = nn.Sequential(
            nn.Linear(n_input, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, n_cat + n_con, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.G(x)
        x = self.D(x)
        return x
model = GAN()
#loss_fn = nn.BCELoss()
loss_fn = nn.MSELoss()
if is_cuda: model, loss_fn = model.cuda(), loss_fn.cuda()
# labels are shaped (batch_size, 1) to match D's output
label_real = Variable(0.5*torch.ones(batch_size, 1)) # boundary seeking: D(x) = D(G(z)) = 0.5
label_fake = Variable(torch.zeros(batch_size, 1))
if is_cuda:
    label_real, label_fake = label_real.cuda(), label_fake.cuda()
GQ_parameters = [
    {'params': model.G.parameters()},
    {'params': model.Q.parameters()},
]
optimizerD = torch.optim.Adam(model.D.parameters(), lr=learning_rate)
optimizerG = torch.optim.Adam(model.G.parameters(), lr=learning_rate)
optimizerQ = torch.optim.Adam(GQ_parameters, lr=learning_rate)
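For reference, with loss_fn = nn.MSELoss() and the 0.5 target above, the updates in the training loop below minimize least-squares objectives that push the generator toward the decision boundary rather than toward D = 1:

$$\mathcal{L}_D = \big(D(x) - 0.5\big)^2 + \big(D(G(z, c))\big)^2, \qquad \mathcal{L}_G = \big(D(G(z, c)) - 0.5\big)^2$$

so at the optimum D(x) = D(G(z, c)) = 0.5, matching the label_real comment.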
In [ ]:
# training
model.train()
train_loss_D = []
train_loss_G = []
for epoch in range(training_epoch):
    for image, label in train_loader:
        ################################
        # input data
        ################################
        x_real = Variable(image.view(-1, n_input))
        z = Variable(torch.randn(batch_size, n_latent))
        c = sample_c(batch_size)
        if is_cuda: x_real, z, c = x_real.cuda(), z.cuda(), c.cuda()
        #################################
        # Update discriminator network
        #################################
        optimizerD.zero_grad()
        x_fake = model.G(torch.cat([z, c], 1))
        D_real = model.D(x_real)
        D_fake = model.D(x_fake.detach()) # detach: do not backprop into G here
        loss_real = loss_fn(D_real, label_real)
        loss_fake = loss_fn(D_fake, label_fake)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizerD.step()
        ################################
        # Update generator network
        ################################
        optimizerG.zero_grad()
        D_fake = model.D(x_fake)
        loss_G = loss_fn(D_fake, label_real) # push D(G(z)) toward the 0.5 boundary
        loss_G.backward()
        optimizerG.step()
        ################################
        # Update Q network
        ################################
        optimizerQ.zero_grad()
        x_fake = model.G(torch.cat([z, c], 1))
        Q_fake = model.Q(x_fake)
        Q_real = model.Q(x_real)
        label = label.unsqueeze(1)
        c_label = Variable(torch.zeros(batch_size, n_cat).scatter_(1, label, 1))
        if is_cuda: c_label = c_label.cuda()
        # recover the categorical code and the continuous code (weighted 2x)
        # from fakes, and supervise Q's category head with the real labels
        loss_Q = (loss_fn(Q_fake[:, :n_cat], c[:, :n_cat])
                  + 2*loss_fn(Q_fake[:, n_cat:], c[:, n_cat:])
                  + loss_fn(Q_real[:, :n_cat], c_label))
        loss_Q.backward()
        optimizerQ.step()
    train_loss_D.append(loss_D.data[0])
    train_loss_G.append(loss_G.data[0])
    print('Epoch : {:04d} loss_D = {:.6f} loss_G = {:.6f}'.format(epoch, loss_D.data[0], loss_G.data[0]))
    if (epoch + 1) % 10 == 0: plot_fakeimages()
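If a run is worth keeping, the weights can be saved once training finishes. A minimal sketch; the filename infogan_mnist.pth is arbitrary, not from the original:
In [ ]:
# optional: persist / restore the trained weights (filename is arbitrary)
torch.save(model.state_dict(), 'infogan_mnist.pth')
# model.load_state_dict(torch.load('infogan_mnist.pth'))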
In [ ]:
plt.plot(train_loss_D)
plt.show()
plt.plot(train_loss_G)
plt.show()
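The two separate, unlabeled plots above are hard to compare; an optional overlay with labels (a presentation tweak, not in the original):
In [ ]:
# optional: overlay both per-epoch loss curves on one labeled figure
plt.plot(train_loss_D, label='loss_D')
plt.plot(train_loss_G, label='loss_G')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()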
In [ ]: