In [ ]:
%matplotlib inline
%env CHAINER_TYPE_CHECK 0
import dlt
import numpy as np
import seaborn as sns
import chainer as C
import matplotlib.pyplot as plt
import itertools as it

train = dlt.load_hdf5('/data/uji/train.hdf')
valid = dlt.load_hdf5('/data/uji/valid.hdf')
print("  Training: %s" % train)
print("Validation: %s" % valid)

In [ ]:
# Utility functions

def show(batch, im_size=1.5):
    '''Show a batch of images.

    Accepts a 2-d array (one row of images) or a 3-d array
    (rows x images x pixels); each image is a flattened square.
    '''
    if batch.ndim == 2:
        batch = batch[np.newaxis, ...]  # promote to a single-row grid
    size = int(np.sqrt(batch.shape[-1]))  # images are flattened squares
    plt.figure(figsize=(im_size * batch.shape[1], im_size * batch.shape[0]))
    for plot_index, x in zip(it.count(1), batch.reshape(-1, batch.shape[-1])):
        plt.subplot(batch.shape[0], batch.shape[1], plot_index)
        plt.imshow(x.reshape(size, size))
        plt.gca().set_xticks([])
        plt.gca().set_yticks([])
        plt.gca().grid(False)
        
class GradientFlip(C.Function):
    '''Identity on the forward pass; multiplies gradients by f on the
    backward pass (f < 0 flips their sign).
    '''
    def __init__(self, f=-1):
        self.f = f
    def forward(self, inputs):
        return inputs
    def backward(self, inputs, grad_outputs):
        return tuple(self.f * x for x in grad_outputs)
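
A minimal check of `GradientFlip`, which the training loop later applies with a factor of -0.01 to both flip and scale the generator's gradient:

In [ ]:
# Forward is the identity; backward multiplies incoming gradients by f
v = C.Variable(np.ones((2, 3), dtype=np.float32))
C.functions.sum(GradientFlip(-0.5)(v)).backward()
print(v.grad)  # expect every entry to be -0.5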

In [ ]:
# Optimization objective (Wasserstein GAN)
#   min_G [ max_C [ mean_{x~X} [ C(x) ] - mean_{z~G} [ C(z) ] ] ]
# requiring C to be 1-Lipschitz (enforced crudely below by clipping the
# critic's weights)

class HiddenLayers(C.ChainList):
    '''A stack of n equal-width fully connected layers with ELU activations.'''
    def __init__(self, d, n=1):
        super().__init__(*(C.links.Linear(d, d) for _ in range(n)))
    def __call__(self, x):
        for layer in self:
            x = C.functions.elu(layer(x))
        return x

class Generator(C.Chain):
    '''Maps uniform noise to a batch of images via an MLP; the sigmoid
    output keeps pixels in (0, 1).
    '''
    def __init__(self, D_noise, D_hidden, D_output, nhidden=1):
        self.D_noise = D_noise
        super().__init__(
            initial=C.links.Linear(D_noise, D_hidden),
            hidden=HiddenLayers(D_hidden, nhidden),
            final=C.links.Linear(D_hidden, D_output),
        )
    def __call__(self, N):
        # Sample N uniform noise vectors, then push them through the MLP
        noise = C.Variable(np.random.rand(N, self.D_noise).astype(np.float32))
        hidden = self.hidden(C.functions.elu(self.initial(noise)))
        return C.functions.sigmoid(self.final(hidden))

class Critic(C.Chain):
    '''Scores a batch of images, returning the mean score as a scalar
    Variable. The final layer has no bias: a bias would add the same
    constant to both terms of the loss and cancel in the difference.
    '''
    def __init__(self, D_input, D_hidden, nhidden=1):
        super().__init__(
            initial=C.links.Linear(D_input, D_hidden),
            hidden=HiddenLayers(D_hidden, nhidden),
            final=C.links.Linear(D_hidden, 1, nobias=True),
        )
    def __call__(self, batch):
        hidden = self.hidden(C.functions.elu(self.initial(batch)))
        return C.functions.sum(self.final(hidden)) / batch.shape[0]  # mean score
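
Before training, a quick smoke test of the two networks; the sizes here are arbitrary, the only constraint being that the generator's output width matches the critic's input width:

In [ ]:
# Shape check: the generator emits (N, D_output), the critic a scalar
g = Generator(10, 32, 256)
c = Critic(256, 32)
fake = g(4)
print(fake.shape)     # (4, 256)
print(c(fake).shape)  # () -- a scalar Variable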

In [ ]:
batch_size = 128
max_batches = 100000
clip = 1           # weight-clip scale for the critic (the crude Lipschitz constraint)
backratio = -0.01  # gradient factor through the generator: the negative sign turns
                   # the critic's minimization into the generator's maximization,
                   # and the magnitude scales the generator's effective learning rate

sample_every = max_batches // 10
network = C.Chain(
    generator=Generator(10, 128, 256, nhidden=1),
    critic=Critic(256, 128, nhidden=2),
)
opt = C.optimizers.Adam()
opt.use_cleargrads()
opt.setup(network)

log = dlt.Log()
samples = []
# Cycle over minibatch offsets for max_batches steps
for step, i in enumerate(it.islice(it.cycle(range(0, len(train.x), batch_size)), 0, max_batches)):
    network.cleargrads()
    real = C.Variable(train.x[i:i+batch_size])
    # One backward pass trains both players: the critic descends on the loss,
    # while the flipped (and scaled) gradient makes the generator ascend
    fake = GradientFlip(backratio)(network.generator(batch_size))
    loss = network.critic(fake) - network.critic(real)
    loss.backward()
    opt.update()
    # Clip the critic's weights after each update, so every forward pass sees
    # clipped weights; the radius shrinks with parameter count so larger
    # layers stay comparably bounded
    for p in network.critic.params():
        r = clip / np.sqrt(p.data.size)
        p.data = np.clip(p.data, -r, r)
    log.add('loss', 'train', loss)
    if step % sample_every == 0:
        samples.append(fake.data)

log.show()
show(train.x[:5])                  # real examples
show(np.stack(samples)[:, :5, :])  # generator samples, one row per snapshot

In [ ]:
# Histogram each critic parameter: clipping should keep every weight
# within +/- clip / sqrt(p.data.size)
n = sum(1 for _ in network.critic.params())
plt.figure(figsize=(12, 4*n))
for i, (name, param) in enumerate(network.critic.namedparams()):
    plt.subplot(n, 1, i+1)
    plt.title(name)
    sns.distplot(param.data.flatten(), kde=False)

In [ ]:
# An 8x8 grid of fresh samples from the trained generator
show(network.generator(64).data.reshape((8,8,-1)))