In [1]:
require 'torch'
require 'nn'
require 'nngraph'
nngraph.setDebug(false)
-- load custom modules
require 'load'
require 'Sampler'
sampler = nn.Sampler()
In [2]:
-- get data
data = loadmnist()
train = data.train
masked_train = train:clone()
masked_train[{{},{1,392}}] = 0
In [3]:
-- load saved cvae model
saved = torch.load('save/CVAE_z2.t7')
model = saved.model
prior = saved.prior
encoder = saved.encoder
decoder = saved.decoder
-- load saved mlp model
saved = torch.load('save/MLP.t7')
mlp = saved.model
In [63]:
-- select images
N = 10
indices = torch.randperm(train:size(1)):long():split(N)
v = indices[1]
x_input = masked_train:index(1, v)
y_input = train:index(1, v)
-- cvae
pmu, plogv, mu, logv, recon = unpack(model:forward({x_input, y_input}))
-- original image
itorch.image(y_input:reshape(N, 1, 28, 28))
-- masked image
itorch.image(x_input:reshape(N, 1, 28, 28))
-- reconstructed image
itorch.image(recon:reshape(N, 1, 28, 28))
In [76]:
-- select images
N = 54
indices = torch.randperm(train:size(1)):long():split(N)
v = indices[1]
x_input = masked_train:index(1, v)
y_input = train:index(1, v)
-- original image
itorch.image(y_input:reshape(N, 1, 28, 28))
-- masked image
itorch.image(x_input:reshape(N, 1, 28, 28))
-- cvae
pmu, plogv = unpack(prior:forward(x_input))
code = sampler:forward({pmu, plogv})
recon = decoder:forward({x_input, code})
itorch.image(recon:reshape(N, 1, 28, 28))
-- mlp
mlp_recon = mlp:forward(x_input)
itorch.image(mlp_recon:reshape(N, 1, 28, 28))