In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2
In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fastai.conv_learner import *
In [3]:
input_size = 784
encoding_dim = 32
bs = 16
tfm0 = torchvision.transforms.ToTensor() # convert [0,255] -> [0.0,1.0]
In [4]:
train_dataset = torchvision.datasets.MNIST('data/MNIST/',train=True, transform=tfm0)
test_dataset = torchvision.datasets.MNIST('data/MNIST/',train=False,transform=tfm0)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs)
In [5]:
# create copies of dataloaders for ModelData
train_loadermd = copy.deepcopy(train_loader)
test_loadermd = copy.deepcopy(test_loader)
# set y to be x, converting [0,255] ints to [0.0,1.0] floats (the DataLoader doesn't transform `y` by default)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_data.type(torch.FloatTensor)/255
test_loadermd.dataset.test_labels = test_loadermd.dataset.test_data.type(torch.FloatTensor)/255
# add channel dimension for compatibility: (bs,h,w) -> (bs,ch,h,w)
train_loadermd.dataset.train_labels = train_loadermd.dataset.train_labels.reshape((len(train_loadermd.dataset),1,28,28))
test_loadermd.dataset.test_labels = test_loadermd.dataset.test_labels.reshape((len(test_loadermd.dataset),1,28,28))
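A quick optional sanity check that the copied loaders really hand back the image as its own target (this assumes this torchvision MNIST version returns the stored label tensor unchanged):
In [ ]:
# Sketch: for these autoencoder loaders, y should equal x, both float tensors in [0,1].
xb, yb = next(iter(train_loadermd))
print(xb.shape, yb.shape)        # expect torch.Size([16, 1, 28, 28]) for both
print(torch.allclose(xb, yb))    # expect True if x and y really match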
In [6]:
md = ModelData('data/MNIST', train_loadermd, test_loadermd)
In [7]:
def compare_batch(x, z, bs=16, figsize=(16,2)):
bs = min(len(x), bs) # digits to display
fig = plt.figure(figsize=figsize)
for i in range(bs):
# display original
ax = plt.subplot(2, bs, i+1); ax.imshow(x[i].reshape(28,28))
ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)
# display reconstruction
ax = plt.subplot(2, bs, i+1+bs); ax.imshow(z[i].reshape(28,28))
ax.get_xaxis().set_visible(False); ax.get_yaxis().set_visible(False)
In [8]:
class VEncoder(nn.Module):
def __init__(self, input_size, interm_size, latent_size):
#TODO
def forward(self, x):
#TODO
class VSampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
#TODO
class VDecoder(nn.Module):
def __init__(self, output_size, interm_size, latent_size):
#TODO
def forward(self, x):
#TODO
class VariationalAutoencoder(nn.Module):
def __init__(self, orign_shape=784, interm_shape=32, latent_shape=16):
super().__init__()
#TODO
def forward(self, x):
#TODO
def vae_loss(z, xtra, raw_loss):
#TODO
In [8]:
class VEncoder(nn.Module):
"""Returns intermediate encodings, mean, and log(stdev) tensors."""
def __init__(self, input_size, interm_size, latent_size):
super().__init__()
self.intermediate = nn.Linear(input_size, interm_size)
self.mean_layer = nn.Linear(interm_size, latent_size)
self.stdv_layer = nn.Linear(interm_size, latent_size)
def forward(self, x):
x = F.relu(self.intermediate(x))
μ = F.relu(self.mean_layer(x)) # Mean vector
log_σ = F.relu(self.stdv_layer(x)) # log(stdev) vector
return x, μ, log_σ
class VSampler(nn.Module):
"""
Multiplies the standard deviation vector by noise ε drawn from N(0,1).
Returns μ + σ·ε (the reparameterization trick).
For theory see: https://youtu.be/uaaqyVS9-rM?t=19m42s
"""
def __init__(self):
super().__init__()
def forward(self, x):
μ, log_σ = x
eps = torch.randn(μ.shape) # should I set `requires_grad=True`?
return μ + torch.exp(log_σ)*eps
class VDecoder(nn.Module):
"""Decodes sampled """
def __init__(self, output_size, interm_size, latent_size):
super().__init__()
self.intermediate = nn.Linear(latent_size, interm_size)
self.out = nn.Linear(interm_size, output_size)
def forward(self, x):
x = F.relu(self.intermediate(x))
x = F.sigmoid(self.out(x))
return x
class VariationalAutoencoder(nn.Module):
def __init__(self, orign_shape=784, interm_shape=32, latent_shape=16):
super().__init__()
self.encoder = VEncoder(orign_shape, interm_shape, latent_shape)
self.sampler = VSampler()
self.decoder = VDecoder(orign_shape, interm_shape, latent_shape)
def forward(self, x):
x = x.view(x.size(0), -1) # flatten
enc_x, *μ_log_σ = self.encoder(x) # encode
x = self.sampler(μ_log_σ) # sample
x = self.decoder(x) # decode
x = x.reshape(x.size(0),1,28,28) # 'unflatten' -- could I use x.view(..)?
return x, μ_log_σ, enc_x
def vae_loss(z, xtra, raw_loss):
μ_log_σ, _ = xtra
μ, log_σ = μ_log_σ
σ = torch.exp(log_σ)
reconstruction_loss = raw_loss
kl_divergence_loss = -0.5 * torch.sum(1. + torch.log(σ**2) - μ**2 - σ**2) # KL(N(μ,σ²) || N(0,1))
return reconstruction_loss + kl_divergence_loss
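A throwaway numeric sketch of what the sampler and the KL term are doing (illustrative values only, not part of the model):
In [ ]:
# Illustrative check of the reparameterization trick and the closed-form KL term:
# KL( N(μ,σ²) || N(0,1) ) = -0.5 * Σ (1 + log σ² - μ² - σ²), summed over latent dims.
μ = torch.zeros(4, 2) + 0.5
log_σ = torch.zeros(4, 2) - 1.0
σ = torch.exp(log_σ)
sample = μ + σ * torch.randn_like(σ)   # same sampling rule as VSampler.forward
kld = -0.5 * torch.sum(1. + torch.log(σ**2) - μ**2 - σ**2)
print(sample.shape, kld.item())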
In [67]:
class VAE(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21= nn.Linear(400, 20)
self.fc22= nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def decode(self, z):
h3 = F.relu(self.fc3(z))
return F.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x.view(-1, 784))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar, z
def loss_function(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def train(epoch, log_interval=100):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar, enc = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
In [66]:
train(1)
In [54]:
interm_shape = 512
latent_shape = 2
learn = Learner.from_model_data(
VariationalAutoencoder(interm_shape=interm_shape, latent_shape=latent_shape), md)
learn.opt_fn = torch.optim.RMSprop
learn.crit = F.binary_cross_entropy
learn.reg_fn = vae_loss
In [55]:
learn.lr_find()
learn.sched.plot()
In [56]:
%time learn.fit(2e-5, 5)
Out[56]:
In [59]:
x,y = next(iter(test_loader))
In [62]:
z = learn.model(x)
compare_batch(x,z[0].detach())
In [65]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μlogvar, enc_x = learn.model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar(); plt.style.use('default')
In [ ]:
In [ ]:
In [36]:
x,y = next(iter(test_loader))
In [37]:
z = model(x)
compare_batch(x,z[0].detach())
In [38]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar();
In [40]:
for i in range(2,4): train(i, log_interval=400)
In [41]:
z = model(x)
compare_batch(x,z[0].detach())
In [42]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar();
In [43]:
for i in range(4,9): train(i, log_interval=800)
In [45]:
z = model(x)
compare_batch(x,z[0].detach())
In [46]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar();
In [68]:
model.eval()
z = model(x)
compare_batch(x,z[0].detach())
In [69]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar();
In [71]:
for i in range(9, 16): train(i, log_interval=600)
In [72]:
model.eval()
z = model(x)
compare_batch(x,z[0].detach())
In [73]:
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar();
In [74]:
for i in range(16, 26): train(i, log_interval=1200)
In [75]:
model.eval()
z = model(x)
compare_batch(x,z[0].detach())
In [80]:
model.eval()
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar(); plt.style.use('default');
In [79]:
model.train()
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y, alpha=0.1);
plt.colorbar(); plt.style.use('default');
Here it is:
In [ ]:
def reparameterize(self, μ, logv):
if self.training:
σ = torch.exp(0.5*logv)
ε = torch.randn_like(σ)
return ε.mul(σ).add_(μ)
else:
return μ
def loss_function(recon_x, x, μ, logv):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
KLD = -0.5 * torch.sum(1 + logv - μ.pow(2) - logv.exp())
return BCE + KLD
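(The `self.training` check above is toggled by `model.train()` / `model.eval()`, so in eval mode the sampler just returns μ. A tiny illustration:)
In [ ]:
# self.training is the flag flipped by .train()/.eval() on any nn.Module.
m = nn.Dropout()
print(m.training)              # True by default
m.eval();  print(m.training)   # False -> a sampler like the above would just return μ
m.train(); print(m.training)   # True again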
In [105]:
tensor = torch.randn(5,5); tensor
tensor.add_(1) # in-place: add_ mutates `tensor` (and returns it)
Out[105]:
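(The trailing underscore marks in-place ops: `add_` mutates the tensor it's called on, while `add` returns a new one. A quick illustration:)
In [ ]:
t = torch.zeros(2)
t.add_(1)          # in-place: t becomes tensor([1., 1.])
u = t.add(1)       # out-of-place: u is tensor([2., 2.]), t unchanged
print(t, u)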
In [143]:
def space_plot(model):
model.eval()
plt.figure(figsize=(6,6)); plt.style.use('classic');
for x,y in iter(test_loader):
z, μ, logvar, enc_x = model(x)
plt.scatter(enc_x.detach()[:,0], enc_x.detach()[:,1], c=y);
plt.colorbar(); plt.style.use('default');
In [155]:
class VEncoder(nn.Module):
def __init__(self, input_size, interm_size, latent_size):
super().__init__()
self.fc0 = nn.Linear(input_size, interm_size)
self.fc10 = nn.Linear(interm_size, latent_size)
self.fc11 = nn.Linear(interm_size, latent_size)
def forward(self, x):
h1 = F.relu(self.fc0(x))
μ = F.relu(self.fc10(h1))
logv = F.relu(self.fc11(h1))
return μ, logv
class VSampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, μ, logv):
σ = torch.exp(0.5*logv)
ε = torch.randn_like(σ)
return ε.mul(σ).add_(μ)
class VDecoder(nn.Module):
def __init__(self, output_size, interm_size, latent_size):
super().__init__()
self.fc2 = nn.Linear(latent_size, interm_size)
self.fc3 = nn.Linear(interm_size, output_size)
def forward(self, x):
h3 = F.relu(self.fc2(x))
z = F.sigmoid(self.fc3(h3))
return z
class VariationalAutoencoder(nn.Module):
def __init__(self, orign_shape=784, interm_shape=512, latent_shape=2):
super().__init__()
self.encoder = VEncoder(orign_shape, interm_shape, latent_shape)
self.sampler = VSampler()
self.decoder = VDecoder(orign_shape, interm_shape, latent_shape)
def forward(self, x):
x = x.view(x.size(0), -1) # flatten
μ,logv = self.encoder(x)
enc = self.sampler(μ, logv)
z = self.decoder(enc)
z = z.view(z.size(0), 1, 28, 28)
return z, μ, logv, enc
def vae_loss(z, xtra, raw_loss):
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
return KLD + BCE
In [156]:
learn = Learner.from_model_data(VariationalAutoencoder(), md)
learn.crit = F.binary_cross_entropy
learn.opt_fn = torch.optim.Adam
learn.reg_fn = vae_loss
In [157]:
learn.lr_find()
learn.sched.plot()
In [158]:
%time learn.fit(1e-4, 2)
Out[158]:
In [159]:
z = learn.model(x)
compare_batch(x,z[0].detach())
In [160]:
space_plot(learn.model)
Can this train in plain PyTorch?
In [170]:
def loss_function(recon_x, mu, logvar, x):
BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0),-1), x.view(x.size(0), -1), size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def train(epoch, log_interval=100):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar, enc = model(data)
loss = loss_function(recon_batch, mu, logvar, data)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
for i in range(2): train(i+1, log_interval=400)
In [171]:
z = model(x)
compare_batch(x,z[0].detach())
In [172]:
space_plot(model)
It's starting to train...
That means my architecture is fine.
The issue is the training mechanics. I need to figure out what in fastai's training process is causing problems.
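For reference, my rough mental model of how fastai's training step wires a tuple-returning model into `crit` and `reg_fn` (a paraphrase inferred from the signatures used above, not fastai's actual source):
In [ ]:
# Rough sketch, as I understand it: the first element of the model's output tuple goes
# to `crit`, the remaining elements become `xtra`, and `reg_fn` (if set) combines them
# with the raw loss. Names here are illustrative.
def training_step(model, crit, reg_fn, opt, x, y):
    out = model(x)
    out, *xtra = out if isinstance(out, tuple) else (out,)
    raw_loss = crit(out, y)                               # e.g. BCE on the reconstruction
    loss = reg_fn(out, xtra, raw_loss) if reg_fn else raw_loss
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()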
In [173]:
x,y = next(iter(learn.data.val_dl))
In [174]:
y.shape
Out[174]:
Is the problem that y is not flattened to match x when they're compared in BCE?
In [180]:
model(x)[0].shape, x.shape
Out[180]:
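For reference, flattening both sides with `.view(bs, -1)` makes the shapes line up before BCE (a trivial illustration, not from the run above):
In [ ]:
# .view(bs, -1) flattens (bs,1,28,28) to (bs,784), so reconstruction and target
# can be compared elementwise by binary cross-entropy.
xb = torch.rand(16, 1, 28, 28)
print(xb.view(xb.size(0), -1).shape)   # torch.Size([16, 784])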
In [198]:
def loss_function(recon_x, mu, logvar, x):
BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0),-1), x.view(x.size(0), -1), size_average=False)
# BCE = F.binary_cross_entropy(
# recon_x.view(recon_x.size(0),1,28,28), x.view(x.size(0),1,28,28), size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
# def vae_loss(z, xtra, raw_loss):
# μ, logv, *_ = xtra
# BCE = raw_loss
# KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
# return KLD + BCE
In [182]:
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for i in range(1,3): train(i, log_interval=400)
In [183]:
z = model(x)
compare_batch(x,z[0].detach())
In [184]:
space_plot(model)
In [192]:
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=False)
In [193]:
learn = Learner.from_model_data(VariationalAutoencoder(), md)
learn.crit = flat_bce
learn.opt_fn = torch.optim.Adam
learn.reg_fn = vae_loss
In [194]:
learn.lr_find()
learn.sched.plot()
In [195]:
%time learn.fit(1e-3, 2)
Out[195]:
In [196]:
z = learn.model(x)
compare_batch(x,z[0].detach())
In [197]:
space_plot(learn.model)
In [199]:
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for i in range(1,3): train(i, log_interval=800)
In [202]:
model.eval(); z = model(x)
compare_batch(x,z[0].detach())
space_plot(model)
In [200]:
z = learn.model(x)
compare_batch(x,z[0].detach())
space_plot(learn.model)
I need to see what tensor shapes are going into and out of the loss functions. There has to be a reason why the fastai Learner isn't training.
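Rather than stepping through with pdb every time, one throwaway option is a wrapper that prints the shapes the criterion receives (a sketch of a hypothetical helper, not a fastai API):
In [ ]:
# Hypothetical debugging helper: wraps any criterion and prints the shapes it is called with.
def shape_logging_crit(crit):
    def wrapped(preds, targs):
        print('crit received:', preds.shape, targs.shape)
        return crit(preds, targs)
    return wrapped

# e.g. learn.crit = shape_logging_crit(flat_bce)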
In [205]:
def vae_loss(z, xtra, raw_loss):
# pdb.set_trace()
μ, logv, *_ = xtra
BCE = raw_loss
# KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
# return KLD + BCE
return BCE
In [206]:
def loss_function(recon_x, mu, logvar, x):
# pdb.set_trace()
BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0),-1), x.view(x.size(0), -1), size_average=False)
# KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# return BCE + KLD
return BCE
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for i in range(1,3): train(i, log_interval=800)
In [207]:
z = model(x)
compare_batch(x,z[0].detach())
space_plot(model)
In [208]:
learn = Learner.from_model_data(VariationalAutoencoder(), md)
learn.crit = flat_bce
learn.opt_fn = torch.optim.Adam
learn.reg_fn = vae_loss
learn.fit(1e-3, 3)
Out[208]:
In [211]:
x,y = next(iter(learn.data.val_dl))
z = learn.model(x)
compare_batch(x,z[0].detach())
space_plot(learn.model)
In [9]:
class VEncoder(nn.Module):
def __init__(self, input_size, interm_size, latent_size):
super().__init__()
self.fc0 = nn.Linear(input_size, interm_size)
self.fc10 = nn.Linear(interm_size, latent_size)
self.fc11 = nn.Linear(interm_size, latent_size)
def forward(self, x):
h1 = F.relu(self.fc0(x))
μ = F.relu(self.fc10(h1))
logv = F.relu(self.fc11(h1))
return μ, logv
class VSampler(nn.Module):
def __init__(self):
super().__init__()
def forward(self, μ, logv):
if self.training:
σ = torch.exp(0.5*logv)
ε = torch.randn_like(σ)
return ε.mul(σ).add_(μ)
else:
return μ
class VDecoder(nn.Module):
def __init__(self, output_size, interm_size, latent_size):
super().__init__()
self.fc2 = nn.Linear(latent_size, interm_size)
self.fc3 = nn.Linear(interm_size, output_size)
def forward(self, x):
h3 = F.relu(self.fc2(x))
z = F.sigmoid(self.fc3(h3))
return z
class VariationalAutoencoder(nn.Module):
def __init__(self, orign_shape=784, interm_shape=512, latent_shape=2):
super().__init__()
self.encoder = VEncoder(orign_shape, interm_shape, latent_shape)
self.sampler = VSampler()
self.decoder = VDecoder(orign_shape, interm_shape, latent_shape)
def forward(self, x):
x = x.view(x.size(0), -1) # flatten
μ,logv = self.encoder(x)
enc = self.sampler(μ, logv)
z = self.decoder(enc)
z = z.view(z.size(0), 1, 28, 28)
return z, μ, logv, enc
def vae_loss(z, xtra, raw_loss):
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
return KLD + BCE
In [10]:
def model_plots(model, dataloader=learn.data.val_dl):
x,y = next(iter(dataloader))
z = model(x)
if len(z) > 1: print([zi.shape for zi in z]);
compare_batch(x,z[0].detach())
space_plot(model)
In [14]:
def learner_test(eps=3, crit=flat_bce):
learn = Learner.from_model_data(VariationalAutoencoder(), md)
learn.crit = crit
learn.opt_fn = torch.optim.Adam
learn.reg_fn = vae_loss
learn.fit(1e-3, eps)
return learn
def pytorch_test(eps=3):
model = VariationalAutoencoder().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for i in range(1,eps): train(i, log_interval=800)
return model
fastai Learner loss functions
In [11]:
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=False)
def vae_loss(z, xtra, raw_loss):
# pdb.set_trace()
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
return BCE + KLD
# return BCE
PyTorch model loss function
In [12]:
def loss_function(recon_x, mu, logvar, x):
# pdb.set_trace()
BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0),-1), x.view(x.size(0), -1), size_average=False)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
# return BCE
In [242]:
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=True)
def vae_loss(z, xtra, raw_loss):
μ, logv, *_ = xtra
BCE = raw_loss
return BCE
learn = learner_test(crit=flat_bce)
model_plots(learn.model)
In [243]:
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=False)
learn = learner_test(crit=flat_bce)
model_plots(learn.model)
In [244]:
learn = learner_test(crit=F.binary_cross_entropy)
model_plots(learn.model)
In [245]:
def vae_loss(z, xtra, raw_loss):
# pdb.set_trace()
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
return BCE + KLD
learn = learner_test(crit=F.binary_cross_entropy)
model_plots(learn.model)
In [246]:
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=True)
learn = learner_test(crit=flat_bce)
model_plots(learn.model)
In [247]:
def vae_loss(z, xtra, raw_loss):
pdb.set_trace()
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
return BCE + KLD
def flat_bce(preds, targs):
return F.binary_cross_entropy(preds.view(preds.size(0),-1),
targs.view(targs.size(0),-1), size_average=True)
learn = learner_test(crit=flat_bce)
model_plots(learn.model)
In [ ]:
In [ ]:
In [ ]:
# class saveBCEKLD(Callback):
# def __init__(self):
# self.bce = []
# self.kld = []
# def on_batch_end(self)
In [ ]:
# def flat_bce(preds, targs):
# return F.binary_cross_entropy(preds.view(preds.size(0),-1),
# targs.view(targs.size(0),-1), size_average=True)
def vae_loss(z, xtra, raw_loss):
μ, logv, *_ = xtra
BCE = raw_loss
KLD = -0.5*torch.sum(1 + logv - μ.pow(2) - logv.exp())
# if np.isclose(KLD.detach(), 0.0): pdb.set_trace()
# print(f'\nBCE: {BCE.data}, KLD: {KLD.data}')
return BCE + KLD
learn = learner_test(crit=F.binary_cross_entropy)
model_plots(learn.model)