In [1]:
%matplotlib inline

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np


##################
# option variables 
##################
learning_rate = 1e-3
training_epoch = 20
batch_size = 100
# neural network properties
n_latent = 2      # dimensionality of the latent space
n_input = 28*28   # number of input pixels per image


is_cuda = torch.cuda.is_available()  # True when CUDA is available

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=16, shuffle=True)


Files already downloaded
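
As a quick sanity check (a sketch, not part of the original notebook), we can pull one batch from train_loader and confirm the tensor shapes the model will see:

In [ ]:
# Optional sanity check: inspect one batch from the training loader.
images, labels = next(iter(train_loader))
print(images.size())   # expected: torch.Size([100, 1, 28, 28])
print(labels.size())   # expected: torch.Size([100])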

In [3]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # input is 28x28
        self.fc1 = nn.Linear(n_input, 500, bias=True)    # encoder hidden layer
        self.fc21 = nn.Linear(500, n_latent, bias=True)  # mu head
        self.fc22 = nn.Linear(500, n_latent, bias=True)  # logvar head
        self.fc3 = nn.Linear(n_latent, 500, bias=True)   # decoder hidden layer
        self.fc4 = nn.Linear(500, n_input, bias=True)    # reconstruction head
        
    def encode(self, x):
        x = x.view(-1, n_input)
        x = F.relu(self.fc1(x))
        return self.fc21(x), self.fc22(x)

    def reparametrize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so the sampling step stays differentiable w.r.t. mu and logvar.
        std = logvar.mul(0.5).exp_()  # sigma = exp(logvar / 2)
        if is_cuda:
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
         
    def decode(self, x):
        x = F.relu(self.fc3(x))
        x = F.sigmoid(self.fc4(x))
        return x 
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        x = self.decode(z)
        return x, mu, logvar

# Summed (not averaged) binary cross-entropy, matching the summed KL term below.
BCELoss = nn.BCELoss(size_average=False)
    
def loss_fn(recon_x, x, mu, logvar):
    BCE = BCELoss(recon_x, x.view(-1, n_input))  # flatten target to match recon_x
    # Closed-form KL divergence against the unit-Gaussian prior; see Appendix B of
    # Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014:
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    return BCE + KLD

model = VAE()
if is_cuda: model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
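
For reference, the KLD term computed in loss_fn is the closed-form KL divergence between the diagonal-Gaussian posterior q(z|x) = N(mu, sigma^2 I) and the unit-Gaussian prior (Kingma and Welling, Appendix B):

$$
\mathrm{KL}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_{j=1}^{n_{\text{latent}}} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

The total loss BCE + KLD is therefore the negative evidence lower bound (ELBO), with the summed binary cross-entropy acting as the reconstruction term.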

In [4]:
def plot_fakeimages():
    # Sample random latent vectors z ~ N(0, I), decode them, and tile the
    # generated digits onto a single canvas.
    nx = ny = 20
    canvas = np.empty((28*ny, 28*nx))

    for i in range(nx):
        for j in range(ny):
            z = Variable(torch.randn(1, n_latent))
            if is_cuda: z = z.cuda()
            img = model.decode(z)
            img = img.view(-1, 28, 28)
            canvas[(nx-i-1)*28:(nx-i)*28, j*28:(j+1)*28] = img.data.squeeze().cpu().numpy()

    plt.figure(figsize=(8, 10))        
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.tight_layout()
    plt.show()

In [5]:
# training
model.train()
train_loss = []

for epoch in range(training_epoch):
    for image, _ in train_loader:
        image = Variable(image)            # input batch
        if is_cuda: image = image.cuda()
        output, mu, logvar = model(image)  # forward pass
        loss = loss_fn(output, image, mu, logvar)
        optimizer.zero_grad()              # clear accumulated gradients
        loss.backward()                    # backpropagate
        optimizer.step()                   # update parameters

        train_loss.append(loss.data[0])

    print('Epoch : {:04d}   loss = {:.6f}'.format(epoch, loss.data[0]))  # last-batch loss
    if (epoch + 1) % 10 == 0: plot_fakeimages()


Epoch : 0000   loss = 16634.398438
Epoch : 0001   loss = 16198.857422
Epoch : 0002   loss = 16839.330078
Epoch : 0003   loss = 16672.638672
Epoch : 0004   loss = 16437.287109
Epoch : 0005   loss = 16563.291016
Epoch : 0006   loss = 15447.595703
Epoch : 0007   loss = 15398.488281
Epoch : 0008   loss = 16562.339844
Epoch : 0009   loss = 14959.602539
Epoch : 0010   loss = 15675.335938
Epoch : 0011   loss = 16222.604492
Epoch : 0012   loss = 15306.953125
Epoch : 0013   loss = 15829.148438
Epoch : 0014   loss = 14549.642578
Epoch : 0015   loss = 15313.612305
Epoch : 0016   loss = 15638.687500
Epoch : 0017   loss = 15322.027344
Epoch : 0018   loss = 15407.458008
Epoch : 0019   loss = 14352.448242
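
With training done, a quick qualitative check (a hedged sketch, not part of the original notebook) is to push a few test digits through the full encode-decode path and compare inputs with reconstructions:

In [ ]:
# Sketch: top row shows test images, bottom row their reconstructions.
model.eval()
images, _ = next(iter(test_loader))
inputs = Variable(images)
if is_cuda: inputs = inputs.cuda()
recon, mu, logvar = model(inputs)
recon = recon.view(-1, 28, 28).data.cpu().numpy()
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for k in range(8):
    axes[0, k].imshow(images[k].squeeze().numpy(), cmap="gray")
    axes[1, k].imshow(recon[k], cmap="gray")
    axes[0, k].axis("off")
    axes[1, k].axis("off")
plt.show()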

In [6]:
plt.plot(train_loss)  # per-batch loss over the whole run


Out[6]:
[<matplotlib.lines.Line2D at 0x7fc9b00dc588>]

In [7]:
if n_latent == 2:
    # Sweep a 20x20 grid over the 2-D latent space and decode each point,
    # visualizing the manifold the decoder has learned.
    model.eval()
    nx = ny = 20
    x_values = np.linspace(-3, 3, nx)
    y_values = np.linspace(-3, 3, ny)

    canvas = np.empty((28*ny, 28*nx))
    for i, yi in enumerate(y_values):
        for j, xi in enumerate(x_values):
            z = Variable(torch.FloatTensor([[xi, yi]]))
            if is_cuda: z = z.cuda()
            img = model.decode(z)
            img = img.view(-1, 28, 28)
            canvas[(ny-i-1)*28:(ny-i)*28, j*28:(j+1)*28] = img.data.squeeze().cpu().numpy()

    plt.figure(figsize=(8, 10))
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.tight_layout()
    plt.show()
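
A complementary view (again a sketch, assuming the 2-D latent space above): encode a few hundred test digits and scatter their posterior means, colored by digit class, to see how the encoder arranges the classes in latent space.

In [ ]:
# Sketch: scatter the latent means of test digits, colored by label.
if n_latent == 2:
    model.eval()
    zs, ys = [], []
    for k, (images, labels) in enumerate(test_loader):
        if k >= 20: break   # ~320 digits is plenty for a scatter plot
        images = Variable(images)
        if is_cuda: images = images.cuda()
        mu, _ = model.encode(images)
        zs.append(mu.data.cpu().numpy())
        ys.append(labels.numpy())
    zs, ys = np.concatenate(zs), np.concatenate(ys)
    plt.figure(figsize=(6, 6))
    plt.scatter(zs[:, 0], zs[:, 1], c=ys, cmap="jet", s=8)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()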



In [ ]: