In [1]:
%matplotlib inline

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


##################
# hyperparameters
##################
learning_rate = 1e-3
training_epoch = 100
batch_size = 100
# neural network properties
n_latent = 10   # size of the latent noise vector z
n_input = 28*28 # number of image pixels (28x28)
n_cat = 10      # number of categorical codes (one per digit class)
n_con = 1       # number of continuous codes

is_cuda = torch.cuda.is_available() # True if CUDA is available

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=16, shuffle=True)


Files already downloaded
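
As a quick sanity check (an added cell, not part of the original run), one batch from the loader can be displayed to confirm it yields 28x28 grayscale digits:

In [ ]:
# show the first few images of one training batch
images, labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img, lbl in zip(axes, images, labels):
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.set_title(int(lbl))
    ax.axis('off')
plt.show()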

In [3]:
##################
# helper functions
##################
def sample_c(size):
    cat = np.random.multinomial(1, n_cat*[1.0/n_cat], size=size) # one-hot vector over the n_cat categories
    con = np.random.rand(size, n_con) # continuous values, uniform in [0, 1)
    c = np.concatenate((cat, con), axis=1)
    c = Variable(torch.from_numpy(c.astype('float32')))
    return c

def plot_fakeimages():
    nx = 20  # columns: sweep of the continuous code from 0 to 1
    ny = 10  # rows: one per categorical code (digit)
    canvas = np.empty((28*ny, 28*nx))
    x_values = np.linspace(0, 1, nx)
    for j in range(ny):
        for i, xi in enumerate(x_values):
            z = Variable(torch.randn(1, n_latent))
            con = Variable(torch.Tensor([[xi]])) # continuous code
            c = Variable(torch.zeros(1, n_cat))  # one-hot categorical code
            c.data[0, j] = 1
            if is_cuda: z, c, con = z.cuda(), c.cuda(), con.cuda()
            img = model.G(torch.cat([z, c, con], 1))
            img = img.view(-1, 28, 28)
            canvas[j*28:(j+1)*28, (nx-i-1)*28:(nx-i)*28] = img.data.squeeze().cpu().numpy()
    plt.figure(figsize=(8, 10))        
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.tight_layout()
    plt.show()
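
A minimal check of the code sampler (an added cell, not in the original notebook): each row of sample_c should be a 10-way one-hot vector followed by one continuous value in [0, 1).

In [ ]:
c = sample_c(4)
print(c.size())                  # expected: (4, n_cat + n_con) = (4, 11)
print(c.data[:, :n_cat].sum(1))  # each categorical block sums to 1 (one-hot)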

In [4]:
class GAN(nn.Module):
    def __init__(self):
        super(GAN, self).__init__()
        # G: generator, maps [z, c] to a flattened 28x28 image
        self.G = nn.Sequential(
            nn.Linear(n_latent + n_cat + n_con, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, n_input, bias=True),
            nn.Sigmoid(),
        )
        # D: discriminator, image -> realness score in (0, 1)
        self.D = nn.Sequential(
            nn.Linear(n_input, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, 1, bias=True),
            nn.Sigmoid(),
        )
        # Q: auxiliary head, recovers the latent code c from an image
        self.Q = nn.Sequential(
            nn.Linear(n_input, 500, bias=True),
            nn.ReLU(True),
            nn.Linear(500, n_cat + n_con, bias=True),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        # full G -> D pipeline (the training loop below calls G, D, Q directly instead)
        x = self.G(x)
        x = self.D(x)
        return x 

model = GAN()
#loss_fn = nn.BCELoss()
loss_fn = nn.MSELoss() # least-squares loss against the soft targets below
if is_cuda: model, loss_fn = model.cuda(), loss_fn.cuda()

label_real = Variable(0.5*torch.ones(batch_size, 1)) # boundary-seeking target: D(x) = D(G(z)) = 0.5
label_fake = Variable(torch.zeros(batch_size, 1))    # shaped (batch_size, 1) to match D's output
if is_cuda:
    label_real, label_fake = label_real.cuda(), label_fake.cuda()
    
# optimizerQ updates both G and Q so the code-reconstruction loss also trains the generator
GQ_parameters = [
    {'params': model.G.parameters()},
    {'params': model.Q.parameters()},
]

optimizerD = torch.optim.Adam(model.D.parameters(), lr=learning_rate)
optimizerG = torch.optim.Adam(model.G.parameters(), lr=learning_rate)
optimizerQ = torch.optim.Adam(GQ_parameters, lr=learning_rate)
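
Because both targets go through an MSE loss, label_real = 0.5 pushes D toward the decision boundary D(x) = D(G(z)) = 0.5 rather than toward 0/1 saturation. A minimal wiring check (an added sketch using a dummy batch of 2) confirms the three networks fit together:

In [ ]:
z = Variable(torch.randn(2, n_latent))
c = sample_c(2)
if is_cuda: z, c = z.cuda(), c.cuda()
x_fake = model.G(torch.cat([z, c], 1))
print(x_fake.size())           # expected: (2, 784)
print(model.D(x_fake).size())  # expected: (2, 1)
print(model.Q(x_fake).size())  # expected: (2, 11)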

In [ ]:
# training
model.train()
train_loss_D = []
train_loss_G = []
for epoch in range(training_epoch):
    for image, label in train_loader:
        ################################
        #  input data
        ################################
        x_real = Variable(image.view(-1, n_input)) # flatten images to (batch_size, 784)
        z = Variable(torch.randn(batch_size, n_latent))
        c = sample_c(batch_size)
        if is_cuda: x_real, z, c = x_real.cuda(), z.cuda(), c.cuda()
        
        #################################
        ## Update discriminator network
        #################################
        optimizerD.zero_grad()
        x_fake = model.G(torch.cat([z, c], 1))
        D_real = model.D(x_real)
        D_fake = model.D(x_fake.detach()) # detach: D's update must not backprop into G
        loss_real = loss_fn(D_real, label_real)
        loss_fake = loss_fn(D_fake, label_fake)
        loss_D = loss_real + loss_fake
        ################################
        loss_D.backward()
        optimizerD.step() # update parameter
        
        ################################
        # Update generator network
        ################################
        optimizerG.zero_grad()
        D_fake = model.D(x_fake)             # no detach here: gradients must reach G
        loss_G = loss_fn(D_fake, label_real) # G also targets the 0.5 boundary
        ################################
        loss_G.backward()
        optimizerG.step() # update parameter
        
        ################################
        # Update Q network
        ################################
        optimizerQ.zero_grad()
        x_fake = model.G(torch.cat([z, c], 1)) # fresh fake batch (new graph for this update)
        Q_fake = model.Q(x_fake)
        Q_real = model.Q(x_real)
        label = label.unsqueeze(1)
        c_label = Variable(torch.zeros(batch_size, n_cat).scatter_(1, label, 1)) # one-hot digit labels
        if is_cuda: c_label = c_label.cuda()
        # reconstruct c from fakes (continuous term weighted 2x) plus a supervised categorical loss on real images
        loss_Q = loss_fn(Q_fake[:, :n_cat], c[:, :n_cat]) \
               + 2*loss_fn(Q_fake[:, n_cat:], c[:, n_cat:]) \
               + loss_fn(Q_real[:, :n_cat], c_label)
        ################################
        loss_Q.backward()
        optimizerQ.step() # update parameter
        
        train_loss_D.append(loss_D.data[0])
        train_loss_G.append(loss_G.data[0])

    print('Epoch : {:04d}   loss_D = {:.6f}  loss_G = {:.6f}'.format(epoch, loss_D.data[0], loss_G.data[0]))
    if (epoch + 1) % 10 == 0: plot_fakeimages()


Epoch : 0000   loss_D = 0.066753  loss_G = 0.128502
Epoch : 0001   loss_D = 0.066550  loss_G = 0.135320
Epoch : 0002   loss_D = 0.079164  loss_G = 0.153788
Epoch : 0003   loss_D = 0.071069  loss_G = 0.130318

In [ ]:
plt.plot(train_loss_D)
plt.title('Discriminator loss per batch')
plt.show()
plt.plot(train_loss_G)
plt.title('Generator loss per batch')
plt.show()
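
The per-batch curves above are noisy; a simple moving average (an added convenience sketch; the window of 100 batches is an arbitrary choice) makes the trends easier to read:

In [ ]:
def smooth(xs, k=100):
    # moving average over a window of k batches
    return np.convolve(xs, np.ones(k)/k, mode='valid')

plt.plot(smooth(train_loss_D), label='D')
plt.plot(smooth(train_loss_G), label='G')
plt.legend()
plt.show()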
