In [ ]:
%matplotlib inline
%env CHAINER_TYPE_CHECK 0
import dlt
import numpy as np
import seaborn as sns
import chainer as C
import matplotlib.pyplot as plt
import itertools as it
train = dlt.load_hdf5('/data/uji/train.hdf')
valid = dlt.load_hdf5('/data/uji/valid.hdf')
print(" Training: %s" % train)
print("Validation: %s" % valid)
In [ ]:
# Utility functions
def show(batch, im_size=1.5):
    '''Show a batch of flattened square images.

    batch -- array of shape (rows, cols, pixels) or (cols, pixels)
    '''
    if batch.ndim == 2:
        batch = batch[np.newaxis, ...]
    size = int(np.sqrt(batch.shape[-1]))
    plt.figure(figsize=(im_size * batch.shape[1], im_size * batch.shape[0]))
    for plot_index, x in zip(it.count(1), batch.reshape(-1, batch.shape[-1])):
        plt.subplot(batch.shape[0], batch.shape[1], plot_index)
        plt.imshow(x.reshape(size, size))
        plt.gca().set_xticks([])
        plt.gca().set_yticks([])
        plt.gca().grid(False)

class GradientFlip(C.Function):
    '''Identity on the forward pass; scales the gradient by `f` on the backward pass.'''
    def __init__(self, f=-1):
        self.f = f

    def forward(self, inputs):
        return inputs

    def backward(self, inputs, grad_outputs):
        return tuple(self.f * x for x in grad_outputs)
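In [ ]:
# Sanity check (added sketch): GradientFlip leaves the forward pass untouched
# but scales the incoming gradient by `f` on the way back, which is what lets
# a single backward pass later train critic and generator in opposite directions.
v = C.Variable(np.ones((2, 3), dtype=np.float32))
y = C.functions.sum(GradientFlip(-1)(v))
y.backward()
print(v.grad)  # all -1: the upstream gradient of +1 arrives sign-flipped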
In [ ]:
# Optimization objective
# min_G [ max_C [ mean_{x~X} [ C(x) ] - mean_{z~G} [ C(z) ] ] ]
# Requiring C to be 1-Lipschitz
class HiddenLayers(C.ChainList):
    '''A stack of `n` square Linear layers with ELU activations.'''
    def __init__(self, d, n=1):
        super().__init__(*(C.links.Linear(d, d) for _ in range(n)))

    def __call__(self, x):
        for layer in self:
            x = C.functions.elu(layer(x))
        return x

class Generator(C.Chain):
    '''Maps uniform noise of width D_noise to sigmoid outputs of width D_output.'''
    def __init__(self, D_noise, D_hidden, D_output, nhidden=1):
        self.D_noise = D_noise
        super().__init__(
            initial=C.links.Linear(D_noise, D_hidden),
            hidden=HiddenLayers(D_hidden, nhidden),
            final=C.links.Linear(D_hidden, D_output),
        )

    def __call__(self, N):
        noise = C.Variable(np.random.rand(N, self.D_noise).astype(np.float32))
        hidden = self.hidden(C.functions.elu(self.initial(noise)))
        return C.functions.sigmoid(self.final(hidden))

class Critic(C.Chain):
    '''Scores a batch, returning the mean (unbounded) critic output.'''
    def __init__(self, D_input, D_hidden, nhidden=1):
        super().__init__(
            initial=C.links.Linear(D_input, D_hidden),
            hidden=HiddenLayers(D_hidden, nhidden),
            final=C.links.Linear(D_hidden, 1, nobias=True),
        )

    def __call__(self, batch):
        hidden = self.hidden(C.functions.elu(self.initial(batch)))
        return C.functions.sum(self.final(hidden)) / batch.shape[0]
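In [ ]:
# Shape check (added sketch, with throwaway sizes): the generator turns N
# noise vectors into N flattened samples, and the critic reduces a batch to a
# single mean score; these are the two quantities the objective plays off
# against each other.
_g = Generator(4, 8, 16, nhidden=1)
_c = Critic(16, 8, nhidden=1)
_fake = _g(3)
print(_fake.shape)      # (3, 16)
print(_c(_fake).data)   # a scalar mean critic score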
In [ ]:
batch_size = 128
max_batches = 100000
clip = 1           # weight-clipping radius for the critic (scaled per parameter below)
backratio = -0.01  # gradient scale fed back to the generator: flipped sign, smaller step
sample_every = max_batches // 10
network = C.Chain(
    generator=Generator(10, 128, 256, nhidden=1),
    critic=Critic(256, 128, nhidden=2),
)
opt = C.optimizers.Adam()
opt.use_cleargrads()
opt.setup(network)
log = dlt.Log()
samples = []
for step, i in enumerate(it.islice(it.cycle(range(0, len(train.x), batch_size)), 0, max_batches)):
    network.cleargrads()
    real = C.Variable(train.x[i:i+batch_size])
    # GradientFlip lets a single backward pass push the critic and the
    # generator in opposite directions
    fake = GradientFlip(backratio)(network.generator(batch_size))
    loss = network.critic(fake) - network.critic(real)
    # clip the critic weights to (roughly) enforce the Lipschitz constraint
    for p in network.critic.params():
        r = clip / np.sqrt(p.data.size)
        p.data = np.clip(p.data, -r, r)
    loss.backward()
    opt.update()
    log.add('loss', 'train', loss)
    if step % sample_every == 0:
        samples.append(fake.data)
log.show()
show(train.x[:5])
show(np.stack(samples)[:, :5, :])
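In [ ]:
# Held-out check (added sketch, assuming valid.x has the same float32 layout
# as train.x): the gap between the critic's mean score on real and generated
# batches is the quantity the objective drives apart, so it gives a rough
# sense of how easily the trained critic still separates the two.
real_v = C.Variable(valid.x[:batch_size])
fake_v = network.generator(batch_size)
print('critic(real) - critic(fake) = %.4f'
      % float((network.critic(real_v) - network.critic(fake_v)).data))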
In [ ]:
n = sum(1 for _ in network.critic.params())
plt.figure(figsize=(12, 4*n))
for i, (name, param) in enumerate(network.critic.namedparams()):
    plt.subplot(n, 1, i+1)
    plt.title(name)
    sns.distplot(param.data.flatten(), kde=False)
In [ ]:
show(network.generator(64).data.reshape((8,8,-1)))