In [47]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import itertools
import os
import time
import gzip
import pickle
In [5]:
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
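A quick sanity check of the initializer (a minimal sketch on a throwaway layer): weights should come out with standard deviation ≈ 0.02 and biases exactly zero.
In [ ]:
layer = nn.Linear(256, 256)
initialize_weights(layer)  # net.modules() yields the module itself, so a bare layer works
print(layer.weight.data.std())      # ≈ 0.02
print(layer.bias.data.abs().max())  # 0.0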
In [54]:
class generator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62 + 12  # 62-dim noise z + 12 latent code dims (10 discrete + 2 continuous)
        self.output_dim = 1
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),  # TODO: the paper has ReLU => BatchNorm. What happens if the order is swapped?
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        initialize_weights(self)

    def forward(self, input, cont_code, dist_code):
        # concatenate noise z with the continuous and discrete latent codes
        x = torch.cat([input, cont_code, dist_code], 1)
        print('***', x.size())
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        return x
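A shape smoke test for the generator (a minimal sketch): 62 noise dims plus 2 continuous and 10 discrete code dims go in, a (N, 1, 28, 28) image in [0, 1] comes out.
In [ ]:
G_test = generator()
G_test.eval()  # put BatchNorm in eval mode for a small fixed batch
z = torch.rand(4, 62)
c_cont = torch.rand(4, 2) * 2 - 1  # uniform in [-1, 1], like the training codes
c_disc = torch.eye(10)[:4]         # classes 0-3 as one-hot rows
out = G_test(z, c_cont, c_disc)
print(out.size())  # torch.Size([4, 1, 28, 28])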
In [11]:
class discriminator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1
        self.len_discrete_code = 10  # categorical code (digit label)
        self.len_continuous_code = 2  # gaussian code (e.g. rotation, thickness)
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim + self.len_continuous_code + self.len_discrete_code),
            # raw outputs: sigmoid is applied only to the real/fake head in forward,
            # and CrossEntropyLoss expects unnormalized logits for the discrete code
        )
        initialize_weights(self)

    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)
        a = F.sigmoid(x[:, :self.output_dim])  # real/fake probability, shape (N, 1)
        b = x[:, self.output_dim:self.output_dim + self.len_continuous_code]  # cont_code
        c = x[:, self.output_dim + self.len_continuous_code:]  # disc_code
        return a, b, c
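The shared fc trunk emits a single (N, 1 + 2 + 10) tensor that forward slices into the three heads; a minimal sketch to confirm the split:
In [ ]:
D_test = discriminator()
D_test.eval()
a, b, c = D_test(torch.rand(4, 1, 28, 28))
print(a.size(), b.size(), c.size())  # (4, 1) (4, 2) (4, 10)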
In [27]:
def load_mnist(dataset):
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        # skip the IDX header, then read num_data records of data_size bytes
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
        return data

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))
    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))
    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))
    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))
    trY = np.asarray(trY).astype(np.int64)
    teY = np.asarray(teY)
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int64)

    # shuffle images and labels with the same seed so they stay aligned
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # one-hot encode the labels
    y_vec = np.zeros((len(y), 10), dtype=np.float32)
    for i, label in enumerate(y):
        y_vec[i, label] = 1

    X = X.transpose(0, 3, 1, 2) / 255.  # NHWC -> NCHW, scaled to [0, 1]
    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec
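A loader sanity check (assumes the four raw MNIST .gz files are under ./data/mnist): 70,000 images in NCHW order scaled to [0, 1], with one-hot labels.
In [ ]:
X, y_vec = load_mnist('mnist')
print(X.size(), y_vec.size())          # (70000, 1, 28, 28) (70000, 10)
print(X.min().item(), X.max().item())  # 0.0 1.0
print(y_vec.sum(1).min().item(), y_vec.sum(1).max().item())  # each row sums to 1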
In [60]:
class InfoGAN(object):
    def __init__(self):
        self.epoch = 5
        self.batch_size = 64
        self.len_discrete_code = 10
        self.len_continuous_code = 2
        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        # the mutual-information term updates both networks, so this optimizer
        # chains the parameters of G and D together
        self.info_optimizer = optim.Adam(itertools.chain(self.G.parameters(), self.D.parameters()), lr=0.0002, betas=(0.5, 0.999))
        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.MSE_loss = nn.MSELoss()
        self.data_X, self.data_Y = load_mnist('mnist')
        self.z_dim = 62
        self.y_dim = 10

    def train(self):
        self.y_real = Variable(torch.ones(self.batch_size, 1))
        self.y_fake = Variable(torch.zeros(self.batch_size, 1))
        self.D.train()
        print('training start!')
        for epoch in range(self.epoch):
            self.G.train()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size: (iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                print('x_', x_.size())
                print('z_', z_.size())
                # TODO: also try the unsupervised variant (sample the discrete
                # code randomly instead of using the true labels)
                # here the discrete code is given the 1-of-K ground-truth label
                y_disc = self.data_Y[iter * self.batch_size: (iter + 1) * self.batch_size]
                # the continuous codes are sampled uniformly from [-1, 1]; after
                # training they should capture factors such as rotation and thickness
                y_cont = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor)
                x_, z_, y_disc, y_cont = Variable(x_), Variable(z_), Variable(y_disc), Variable(y_cont)

                # update D network
                self.D_optimizer.zero_grad()
                # only the real/fake head matters for this loss; the code heads are ignored
                D_real, _, _ = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real)
                print('D_real:', D_real.size())
                G_ = self.G(z_, y_cont, y_disc)  # generated images, shape (N, 1, 28, 28)
                print('y_cont:', y_cont.size())
                print('y_disc:', y_disc.size())
                print('G_:', G_.size())
                D_fake, _, _ = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake)
                D_loss = D_real_loss + D_fake_loss
                # retain_graph=True keeps the saved activations alive so a later
                # backward() can traverse the graph again
                D_loss.backward(retain_graph=True)
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()
                G_ = self.G(z_, y_cont, y_disc)
                D_fake, D_cont, D_disc = self.D(G_)
                print('D_fake:', D_fake.size())
                print('D_cont:', D_cont.size())
                print('D_disc:', D_disc.size())
                G_loss = self.BCE_loss(D_fake, self.y_real)
                # retain the graph: info_loss.backward() below reuses this same forward pass
                G_loss.backward(retain_graph=True)
                self.G_optimizer.step()

                # information loss: the variational lower bound on I(c; G(z, c))
                # reduces to reconstructing the codes from the generated image
                disc_loss = self.CE_loss(D_disc, torch.max(y_disc, 1)[1])  # cross entropy against the class labels
                # the MSE term pushes D_cont back toward the sampled y_cont,
                # i.e. a Gaussian posterior with fixed variance
                cont_loss = self.MSE_loss(D_cont, y_cont)
                info_loss = disc_loss + cont_loss
                print('info_loss:', info_loss.item())
                info_loss.backward()
                self.info_optimizer.step()
                break  # single-batch debug run
            break
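nn.CrossEntropyLoss expects integer class indices rather than one-hot rows, which is why the info loss converts y_disc with torch.max(y_disc, 1)[1]; a tiny worked example:
In [ ]:
y_onehot = torch.FloatTensor([[0, 0, 1], [1, 0, 0]])
print(torch.max(y_onehot, 1)[1])  # tensor([2, 0]): argmax recovers the class index
logits = torch.randn(2, 3)
print(nn.CrossEntropyLoss()(logits, torch.max(y_onehot, 1)[1]))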
In [61]:
infogan = InfoGAN()
infogan.train()
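With the break statements still in place the run above stops after one batch, so the model is effectively untrained; after a full training run, a sweep like the sketch below (reusing the infogan object from above) should show the first continuous code smoothly varying a style factor such as rotation for a fixed digit.
In [ ]:
infogan.G.eval()
z = torch.rand(1, 62).repeat(10, 1)              # same noise in every column
c1 = torch.linspace(-1, 1, 10).unsqueeze(1)      # sweep the first continuous code
c_cont = torch.cat([c1, torch.zeros(10, 1)], 1)  # hold the second code at 0
c_disc = torch.zeros(10, 10)
c_disc[:, 3] = 1                                 # fix the digit class
samples = infogan.G(z, c_cont, c_disc).data
fig, axes = plt.subplots(1, 10, figsize=(10, 1.2))
for ax, img in zip(axes, samples):
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.show()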
In [ ]: