WNixalo – 2016/6/20
In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
In [2]:
import copy                       # used below to deepcopy the DataLoaders
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import optim           # used below for optim.Adam
from fastai.conv_learner import *
In [3]:
input_size = 784
encoding_dim = 32
bs = 16
tfm0 = torchvision.transforms.ToTensor() # convert [0,255] -> [0.0,1.0]
In [4]:
train_dataset = torchvision.datasets.MNIST('data/MNIST/',train=True, transform=tfm0)
test_dataset = torchvision.datasets.MNIST('data/MNIST/',train=False,transform=tfm0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs)
In [5]:
# create copies of dataloaders for ModelData
train_loadermd = copy.deepcopy(train_loader)
test_loadermd = copy.deepcopy(test_loader)
# set y to be x and convert [0,255] int to [0.0,1.0] float (the DataLoader doesn't transform `y` by default)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_data.type(torch.FloatTensor)/255
test_loadermd.dataset.test_labels = test_loadermd.dataset.test_data.type(torch.FloatTensor)/255
# add channel dimension for compatibility: (bs,h,w) -> (bs,ch,h,w)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_labels.reshape((len(train_loadermd.dataset),1,28,28))
test_loadermd.dataset.test_labels = test_loadermd.dataset.test_labels.reshape((len(test_loadermd.dataset),1,28,28))
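A quick sanity check (a hypothetical cell, not part of the original run): after the edits above, the targets should be the same images as the inputs, so an autoencoder trained on this ModelData reconstructs its input.
In [ ]:
xb, yb = next(iter(train_loadermd))
print(xb.shape, yb.shape)       # both should be torch.Size([16, 1, 28, 28])
print(torch.allclose(xb, yb))   # True: the target is the input itself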
In [6]:
md = ModelData('data/MNIST', train_loadermd, test_loadermd)
In [8]:
def compare_batch(x, z, bs=16, figsize=(16,2)):
    bs = min(len(x), bs)  # digits to display
    fig = plt.figure(figsize=figsize)
    for i in range(bs):
        # display original
        ax = plt.subplot(2, bs, i+1); ax.imshow(x[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)
        # display reconstruction
        ax = plt.subplot(2, bs, i+1+bs); ax.imshow(z[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)

def space_plot(model):
    model.eval()
    plt.figure(figsize=(6,6)); plt.style.use('classic')
    for x,y in iter(test_loader):
        z, μ, logvar, enc_x = model(x)
        plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y)
    plt.colorbar(); plt.style.use('default')

def model_plots(model, dataloader):
    x,y = next(iter(dataloader))
    z = model(x)
    if len(z) > 1: print([zi.shape for zi in z])
    compare_batch(x, z[0].detach())
    space_plot(model)
In [ ]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return F.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
In [62]:
class VEncoder(nn.Module):
    def __init__(self, input_size, interm_size, latent_size):
        super().__init__()
        self.fc0  = nn.Linear(input_size, interm_size)
        self.fc10 = nn.Linear(interm_size, latent_size)
        self.fc11 = nn.Linear(interm_size, latent_size)
        # self.debug=False

    def forward(self, x):
        # if self.debug: pdb.set_trace() ################################# TRACE
        h1 = F.relu(self.fc0(x))
        # NB: unlike the reference VAE.encode above, μ and logv are passed
        # through a ReLU here, which constrains both to be non-negative.
        μ    = F.relu(self.fc10(h1))
        logv = F.relu(self.fc11(h1))
        return μ, logv

# class VSampler(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # self.debug=False
#     def forward(self, μ, logv):
#         # if self.debug: pdb.set_trace() ################################# TRACE
#         if self.training:
#             σ = torch.exp(0.5*logv)
#             ε = torch.randn_like(σ)
#             return ε.mul(σ).add_(μ)
#         else:
#             return μ

class VDecoder(nn.Module):
    def __init__(self, output_size, interm_size, latent_size):
        super().__init__()
        self.fc2 = nn.Linear(latent_size, interm_size)
        self.fc3 = nn.Linear(interm_size, output_size)

    def forward(self, x):
        h3 = F.relu(self.fc2(x))
        z = F.sigmoid(self.fc3(h3))
        return z

class VariationalAutoencoder(nn.Module):
    def __init__(self, orign_shape=784, interm_shape=512, latent_shape=2):
        super().__init__()
        self.encoder = VEncoder(orign_shape, interm_shape, latent_shape)
        # self.sampler = VSampler()   # VSampler is commented out above; reparameterize() is used instead
        self.decoder = VDecoder(orign_shape, interm_shape, latent_shape)

    def reparameterize(self, μ, logv):
        if self.training:
            σ = torch.exp(0.5*logv)
            ε = torch.randn_like(σ)
            return ε.mul(σ).add_(μ)
        else:
            return μ

    def forward(self, x):
        # pdb.set_trace()
        x = x.view(x.size(0), -1)           # flatten (bs,1,28,28) -> (bs,784)
        μ, logv = self.encoder(x)
        # enc = self.sampler(μ, logv)
        enc = self.reparameterize(μ, logv)  # reparameterization trick: enc = μ + σ·ε, ε ~ N(0,1)
        z = self.decoder(enc)
        z = z.view(z.size(0), 1, 28, 28)    # back to image shape
        return z, μ, logv, enc
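A quick shape check (a hypothetical cell, not part of the original run) to confirm the forward pass returns the reconstruction plus the extras the loss will need:
In [ ]:
m = VariationalAutoencoder()
xb, _ = next(iter(test_loader))
zb, mu_b, logv_b, enc_b = m(xb)
# expect torch.Size([16, 1, 28, 28]), torch.Size([16, 2]), torch.Size([16, 2]), torch.Size([16, 2])
print(zb.shape, mu_b.shape, logv_b.shape, enc_b.shape)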
In [61]:
def vae_loss(z, xtra, raw_loss):
    # used as learn.reg_fn: fastai passes (output, xtra, raw_loss), where xtra
    # holds the model's extra outputs -- here (μ, logv, enc)
    μ, logv, *_ = xtra
    BCE = raw_loss
    KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
    return KLD + BCE
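For reference, the KLD term above is the standard closed-form KL divergence between the encoder's diagonal Gaussian and the unit-Gaussian prior (from Kingma & Welling's VAE paper), with `logv` standing in for $\log\sigma^2$:

$$D_{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big) = -\tfrac{1}{2}\sum_j\big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$$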
In [11]:
def flat_bce(preds, targs):
    return F.binary_cross_entropy(preds.view(preds.size(0),-1),
                                  targs.view(targs.size(0),-1), size_average=True)
In [81]:
def loss_function(recon_x, mu, logvar, x):
    BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0),-1), x.view(x.size(0),-1), size_average=False)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train(model, epoch, log_interval=100):
    # relies on the module-level `train_loader`, `device`, and `optimizer`
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar, enc = model(data)
        loss = loss_function(recon_batch, mu, logvar, data)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
In [73]:
def learner_test(eps=3, crit=F.binary_cross_entropy, learn=None):
    if learn is None: learn = Learner.from_model_data(VariationalAutoencoder(), md)
    learn.crit = crit
    learn.opt_fn = torch.optim.Adam
    learn.reg_fn = vae_loss
    learn.fit(1e-3, eps)
    return learn

def pytorch_test(eps=3):
    global device, optimizer  # train() reads these at module level
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VariationalAutoencoder().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for i in range(1, eps+1): train(model, i, log_interval=800)
    return model
In [30]:
learner = learner_test(1)
In [31]:
model_plots(learner.model, learner.data.val_dl)
In [33]:
learner.model.encoder.debug=True
learner.fit(1e-3,1)
In [35]:
learner.model.encoder.debug=False
In [36]:
x,y = next(iter(test_loader))
learner.model(x)
Out[36]:
In [38]:
learner.model.encoder(x.view(x.size(0),-1))
Out[38]:
In [50]:
def vae_loss(z, xtra, raw_loss):
    μ, logv, *_ = xtra
    BCE = raw_loss
    # KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
    # return KLD + BCE
    return BCE
learner = learner_test(1)
In [51]:
learner.model.encoder(x.view(x.size(0),-1))
Out[51]:
In [63]:
def vae_loss(z, xtra, raw_loss):
    μ, logv, *_ = xtra
    BCE = raw_loss
    # KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
    # return KLD + BCE
    return BCE
learner = learner_test(1)
learner.model.encoder(x.view(x.size(0),-1))
Out[63]:
In [82]:
eps=1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torchmodel = VariationalAutoencoder().to(device)
optimizer = optim.Adam(torchmodel.parameters(), lr=1e-3)
for i in range(1,eps+1): train(torchmodel, i, log_interval=800)
In [83]:
torchmodel.encoder(x.view(x.size(0),-1))
Out[83]:
In [84]:
model_plots(torchmodel, learner.data.val_dl)
In [85]:
for i in range(5): train(torchmodel, i, log_interval=800)
In [86]:
torchmodel.encoder(x.view(x.size(0),-1))
model_plots(torchmodel, learner.data.val_dl)
Alright, the jury's out. I'll have to redo this later.