WNixalo – 2016/6/24-5
In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
In [2]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from fastai.conv_learner import *
In [3]:
input_size = 784
encoding_dim = 32
bs = 16
tfm0 = torchvision.transforms.ToTensor() # convert [0,255] -> [0.0,1.0]
In [4]:
train_dataset = torchvision.datasets.MNIST('data/MNIST/',train=True, transform=tfm0)
test_dataset = torchvision.datasets.MNIST('data/MNIST/',train=False,transform=tfm0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs)
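A quick sanity check (sketch): one batch from `train_loader` should come out as [0.0,1.0] floats of shape (bs,1,28,28), with `y` still a class index.
In [ ]:
xb, yb = next(iter(train_loader))
print(xb.shape, xb.dtype, xb.min().item(), xb.max().item())  # torch.Size([16, 1, 28, 28]) torch.float32 0.0 1.0
print(yb.shape, yb.dtype)                                    # torch.Size([16]) torch.int64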
In [5]:
# create copies of the dataloaders for ModelData
train_loadermd = copy.deepcopy(train_loader)
test_loadermd  = copy.deepcopy(test_loader)
# set y to be x and convert [0,255] ints to [0.0,1.0] floats (the DataLoader doesn't transform `y` by default)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_data.type(torch.FloatTensor)/255
test_loadermd.dataset.test_labels   = test_loadermd.dataset.test_data.type(torch.FloatTensor)/255
# add a channel dimension for compatibility: (n,h,w) -> (n,1,h,w)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_labels.reshape((len(train_loadermd.dataset),1,28,28))
test_loadermd.dataset.test_labels   = test_loadermd.dataset.test_labels.reshape((len(test_loadermd.dataset),1,28,28))
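Sanity check (sketch): after the edits above, a batch from `train_loadermd` should yield targets identical to the transformed inputs.
In [ ]:
xb, yb = next(iter(train_loadermd))
print(xb.shape, yb.shape)      # both torch.Size([16, 1, 28, 28])
print(torch.allclose(xb, yb))  # expected: True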
In [6]:
md = ModelData('data/MNIST', train_loadermd, test_loadermd)
In [22]:
def compare_batch(x, z, bs=16, figsize=(16,2)):
    bs = min(len(x), bs)  # digits to display
    fig = plt.figure(figsize=figsize)
    for i in range(bs):
        # display original
        ax = plt.subplot(2, bs, i+1); ax.imshow(x[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)
        # display reconstruction
        ax = plt.subplot(2, bs, i+1+bs); ax.imshow(z[i].reshape(28,28))
        ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)

def space_plot(model):
    model.eval()
    plt.figure(figsize=(6,6)); plt.style.use('classic');
    for x,y in iter(test_loader):
        z, μ, logvar, enc_x = model(x)
        plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
    plt.colorbar(); plt.style.use('default');

def model_plots(model, dataloader):
    x,y = next(iter(dataloader))
    z = model(x)
    if len(z) > 1: print([zi.shape for zi in z]);
    compare_batch(x, z[0].detach())
    space_plot(model)
In [8]:
# working pytorch example
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
In [48]:
# my version I'm trying to get to work
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(784,512)
self.fc10 = nn.Linear(512,2)
self.fc11 = nn.Linear(512,2)
def forward(self, x):
h1 = F.relu(self.fc0(x))
mu,logv = self.fc10(h1), self.fc11(h1)
return mu, logv
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.fc2 = nn.Linear(2,512)
self.fc3 = nn.Linear(512,784)
def forward(self, x):
h3 = F.relu(self.fc2(x))
z = F.sigmoid(self.fc3(h3))
return z
class VariationalAutoencoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def reparameterize(self, mu, logv):
if self.training:
sig = torch.exp(0.5*logv)
eps = torch.randn_like(sig)
return eps*sig + mu ## eps.mul(sig).add_(mu)
# return eps.mul(sig).add_(mu)
else: return mu
def forward(self, x):
shape = x.shape
x = x.view(x.size(0),-1)
mu, logv = self.encoder(x)
h = self.reparameterize(mu,logv)
z = self.decoder(h)
z = z.view(z.size(0), *shape[1:])
# z = z.view(z.size(0),1,28,28)
return z, mu, logv, h
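A quick shape check on an untrained VariationalAutoencoder (sketch): the forward pass should return the reconstruction plus the 2-d mu, logv, and sampled latent code.
In [ ]:
m = VariationalAutoencoder()
xb, _ = next(iter(test_loader))
out = m(xb)
print([o.shape for o in out])  # expected: (16,1,28,28), (16,2), (16,2), (16,2)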
In [38]:
def bce_flat(preds, targs):
    return F.binary_cross_entropy(preds.view(preds.size(0),-1),
                                  targs.view(targs.size(0),-1), size_average=True)

def vae_loss(output, xtra, raw_loss):
    mu, logv, *_ = xtra
    BCE = raw_loss
    KLD = -0.5 * torch.sum(1 + logv - mu**2 - torch.exp(logv))
    return BCE + KLD
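For reference, the KLD term above is the standard closed-form KL divergence between the encoder's diagonal Gaussian q(z|x) = N(mu, sigma^2) and the unit-Gaussian prior p(z) = N(0, I), with logv = log(sigma^2):

$$ D_{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\big) \;=\; -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right) $$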
In [41]:
learner = Learner.from_model_data(VariationalAutoencoder(), md)
# learner.crit = bce_flat
learner.crit = F.binary_cross_entropy  # raw reconstruction loss on the first model output
learner.opt_fn = torch.optim.Adam
learner.reg_fn = vae_loss              # combines the raw loss with the KL term via the extra outputs
In [42]:
learner.fit(1e-3, 1)
Out[42]:
In [43]:
x,y = next(iter(learner.data.val_dl))
z = learner.predict()
compare_batch(x,z)
In [44]:
learner.model(x)[1]
Out[44]:
In [45]:
learner.model(x)[2]
Out[45]:
In [46]:
len(learner.model(x))
Out[46]:
In [52]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_function(recon_x, data, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x.view(-1, 784), data.view(-1, 784),
                                 size_average=False)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train(epoch, log_interval=800):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar, enc = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

for i in range(3): train(i+1)
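Since the decoder maps 2-d latent codes back to pixels, new digits can be generated by decoding draws from the prior N(0, I). A minimal sketch, assuming the VariationalAutoencoder trained above:
In [ ]:
model.eval()
with torch.no_grad():
    z_sample = torch.randn(16, 2).to(device)                      # draws from the prior
    samples  = model.decoder(z_sample).view(-1, 1, 28, 28).cpu()  # decode to images
fig = plt.figure(figsize=(16, 2))
for i in range(16):
    ax = plt.subplot(1, 16, i+1); ax.imshow(samples[i].reshape(28, 28))
    ax.axis('off')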
In [56]:
x,y = next(iter(learner.data.val_dl))
z = model(x.to(device))
compare_batch(x, z[0].detach().cpu())
In [64]:
z[1]
Out[64]:
In [65]:
z[2]
Out[65]:
In [66]:
for i in range(3,6): train(i+1)
x,y = next(iter(learner.data.val_dl))
z = model(x.to(device))
compare_batch(x, z[0].detach().cpu())
In [68]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def loss_function(recon_x, data, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x.view(-1, 784), data.view(-1, 784),
                                 size_average=False)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train(epoch, log_interval=800):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

for i in range(3): train(i+1)
In [78]:
x,y = next(iter(learner.data.val_dl))
z = model(x.to(device))
compare_batch(x, z[0].detach().cpu())
In [70]:
z[1]
Out[70]:
In [71]:
z[2]
Out[71]:
In [72]:
model
Out[72]: