In [47]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt

import itertools
import os
import time
import gzip
import pickle

In [5]:
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

In [54]:
class generator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62 + 12  # 12は何?
        self.output_dim = 1
    
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),  # TODO: 論文だとReLU => BatchNormになってる。順番変えるとどうなる?
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )

        initialize_weights(self)
    
    def forward(self, input, cont_code, dist_code):
        x = torch.cat([input, cont_code, dist_code], 1)
        print('***', x.size())
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        
        return x

In [11]:
class discriminator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1
        self.len_discrete_code = 10   # categorical distribution (label)
        self.len_continuous_code = 2  # gaussian distribution (rotation, thickness)
        
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim + self.len_continuous_code + self.len_discrete_code),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)
        a = F.sigmoid(x[:, self.output_dim])  # for real or fake
        b = x[:, self.output_dim:self.output_dim + self.len_continuous_code]  # cont_code
        c = x[:, self.output_dim + self.len_continuous_code:]  # disc_code
        
        return a, b, c

In [27]:
def load_mnist(dataset):
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float)
        return data

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    trY = np.asarray(trY).astype(np.int)
    teY = np.asarray(teY)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)

    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float)
    for i, label in enumerate(y):
        y_vec[i, y[i]] = 1

    X = X.transpose(0, 3, 1, 2) / 255.
    # y_vec = y_vec.transpose(0, 3, 1, 2)

    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec

In [60]:
class InfoGAN(object):
    def __init__(self):
        self.epoch = 5
        self.batch_size = 64
        self.len_discrete_code = 10
        self.len_continuous_code = 2
        
        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # 2つのparametersのgeneratorを連結している
        self.info_optimizer = optim.Adam(itertools.chain(self.G.parameters(), self.D.parameters()), lr=0.0002, betas=(0.5, 0.999))

        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.MSE_loss = nn.MSELoss()

        self.data_X, self.data_Y = load_mnist('mnist')
        self.z_dim = 62
        self.y_dim = 10

    def train(self):
        self.y_real = Variable(torch.ones(self.batch_size, 1))
        self.y_fake = Variable(torch.zeros(self.batch_size, 1))
        
        self.D.train()
        print('training start!')
        for epoch in range(self.epoch):
            self.G.train()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size: (iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                print('x_', x_.size())
                print('z_', z_.size())
                
                # TODO: SUPERVISEDではないパターンも検討する
                # disc codeには1-of-Kのラベルを与える
                y_disc = self.data_Y[iter * self.batch_size: (iter + 1) * self.batch_size]
                # TODO: y_contは何を意味している?
                y_cont = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor)

                x_, z_, y_disc, y_cont = Variable(x_), Variable(z_), Variable(y_disc), Variable(y_cont)
                
                # updte D network
                self.D_optimizer.zero_grad()
                
                # cont_codeとdisc_codeは使わない???
                D_real, _, _ = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real)
                print('D_real:', D_real.size())
                
                G_ = self.G(z_, y_cont, y_disc)  # generatorが生成した画像 (N, 1, 28, 28)
                print('y_cont:', y_cont.size())
                print('y_disc:', y_disc.size())
                print('G_:', G_.size())
                D_fake, _, _ = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake)
                
                D_loss = D_real_loss + D_fake_loss
                D_loss.backward(retain_graph=True)  # TODO: retain_graphは何?
                self.D_optimizer.step()
                
                # update G network
                self.G_optimizer.zero_grad()
                
                G_ = self.G(z_, y_cont, y_disc)
                D_fake, D_cont, D_disc = self.D(G_)
                print('D_fake:', D_fake.size())
                print('D_cont:', D_cont.size())
                print('D_disc:', D_disc.size())
                
                G_loss = self.BCE_loss(D_fake, self.y_real)
                G_loss.backward(retain_graph=True)
                self.G_optimizer.step()
                
                # information loss
                disc_loss = self.CE_loss(D_disc, torch.max(y_disc, 1)[1])  # ラベルのCrossEntropyLoss
                # TODO: D_contを入力のランダムのy_contに近づけようとしている???
                cont_loss = self.MSE_loss(D_cont, y_cont)
                info_loss = disc_loss + cont_loss
                print('info_loss:', info_loss.data[0])
                info_loss.backward()
                self.info_optimizer.step()
                break
            break

In [61]:
infogan = InfoGAN()
infogan.train()


training start!
x_ torch.Size([64, 1, 28, 28])
z_ torch.Size([64, 62])
D_real: torch.Size([64])
*** torch.Size([64, 74])
y_cont: torch.Size([64, 2])
y_disc: torch.Size([64, 10])
G_: torch.Size([64, 1, 28, 28])
/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/torch/nn/functional.py:767: UserWarning: Using a target size (torch.Size([64, 1])) that is different to the input size (torch.Size([64])) is deprecated. Please ensure they have the same size.
  "Please ensure they have the same size.".format(target.size(), input.size()))
*** torch.Size([64, 74])
D_fake: torch.Size([64])
D_cont: torch.Size([64, 2])
D_disc: torch.Size([64, 10])
info_loss: 2.9132559299468994

In [ ]: