In [36]:
import os
import time
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
In [2]:
def initialize_weights(net):
    """Apply DCGAN-style initialization to every conv / deconv / linear layer.

    Weights are drawn from N(0, 0.02); biases are zeroed. Layers constructed
    with ``bias=False`` have ``m.bias is None`` and are skipped instead of
    crashing (the original code assumed a bias was always present).

    Args:
        net: any ``nn.Module``; ``net.modules()`` walks it recursively.
    """
    for m in net.modules():
        # The three layer types share the same init scheme, so one branch suffices.
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            if m.bias is not None:  # guard: bias=False layers expose bias=None
                m.bias.data.zero_()
In [53]:
class generator(nn.Module):
    """ACGAN generator for MNIST: (noise, one-hot label) -> 1x28x28 image.

    The 62-dim noise vector is concatenated with a 10-dim class vector, fed
    through a fully connected stack that projects to a 128 x 7 x 7 feature
    map, then upsampled by two stride-2 transposed convolutions to 28 x 28.
    Output passes through Sigmoid, so pixels lie in [0, 1].
    """
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62 + 10  # z dimension + class dimension (10 classes)
        self.output_dim = 1
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            # Project to a feature map at 1/4 of the target spatial size.
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),  # match the /255 scaling of the real data
        )
        initialize_weights(self)

    def forward(self, input, label):
        """Generate images from noise ``input`` (N, 62) and one-hot ``label`` (N, 10)."""
        x = torch.cat([input, label], 1)
        # (debug print of x.size() removed — leftover development output)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        return x
# Instantiate the generator; the bare expression displays the architecture.
gen = generator()
gen
Out[53]:
In [9]:
# Scratch demo: how torch.cat behaves along each dimension.
x = torch.randn(2, 3)
x
Out[9]:
In [11]:
# Concatenate along dim 0: stacks rows -> shape (6, 3).
torch.cat((x, x, x), 0)
Out[11]:
In [12]:
# Concatenate along dim 1: stacks columns -> shape (2, 9).
torch.cat((x, x, x), 1)
Out[12]:
In [41]:
class discriminator(nn.Module):
    """ACGAN discriminator for MNIST with two output heads.

    ``dc`` produces a sigmoid real/fake probability; ``c1`` produces raw
    logits over the 10 digit classes for the auxiliary classifier.
    """
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1
        self.class_num = 10
        # Flattened size of the conv features: 128 channels at 1/4 resolution.
        feat_dim = 128 * (self.input_height // 4) * (self.input_width // 4)
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(feat_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        self.dc = nn.Sequential(
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        self.c1 = nn.Sequential(
            nn.Linear(1024, self.class_num),
        )
        initialize_weights(self)

    def forward(self, input):
        """Return (real/fake probability, class logits) for a batch of images."""
        features = self.conv(input)
        flat = features.view(input.size(0), -1)
        hidden = self.fc1(flat)
        return self.dc(hidden), self.c1(hidden)
# Instantiate the discriminator; the bare expression displays the architecture.
disc = discriminator()
disc
Out[41]:
In [37]:
def load_mnist(dataset):
    """Load the raw MNIST gz files from ``./data/<dataset>`` as torch tensors.

    Returns:
        X:     FloatTensor of shape (70000, 1, 28, 28), pixel values in [0, 1].
        y_vec: FloatTensor of shape (70000, 10), one-hot labels.

    Train and test splits are concatenated, then shuffled with a fixed seed
    applied twice so images and labels stay paired.
    """
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the idx-format header, then read the raw uint8 payload.
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
        # np.float / np.int were removed from NumPy (>=1.24); use explicit dtypes.
        return np.frombuffer(buf, dtype=np.uint8).astype(np.float64)

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))
    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))
    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))
    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    trY = np.asarray(trY).astype(np.int64)
    teY = np.asarray(teY)
    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int64)

    # Shuffle images and labels identically by reseeding before each shuffle.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # Vectorized one-hot encoding (replaces the per-row Python loop).
    y_vec = np.zeros((len(y), 10), dtype=np.float64)
    y_vec[np.arange(len(y)), y] = 1

    # NHWC -> NCHW and scale pixels to [0, 1].
    X = X.transpose(0, 3, 1, 2) / 255.
    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec
In [61]:
class ACGAN(object):
    """Minimal ACGAN training harness for MNIST.

    The discriminator learns a real/fake score (BCE) plus an auxiliary class
    prediction (cross-entropy); the generator is trained to fool the
    real/fake head while producing the requested class.
    """
    def __init__(self):
        self.epoch = 5
        self.sample_num = 100
        self.batch_size = 64
        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()
        self.data_X, self.data_Y = load_mnist('mnist')
        self.z_dim = 62
        self.y_dim = 10
        print(self.data_X.shape, self.data_Y.shape)

    def train(self):
        """Run the adversarial training loop (currently truncated to one batch
        by the trailing ``break`` statements, left in for debugging)."""
        # Fixed targets for the real/fake head.
        self.y_real_ = Variable(torch.ones(self.batch_size, 1))
        self.y_fake_ = Variable(torch.zeros(self.batch_size, 1))
        print(self.y_real_.size(), self.y_fake_.size())
        self.D.train()
        print('training start!!')
        for epoch in range(self.epoch):
            self.G.train()
            # `step` instead of `iter` — avoids shadowing the builtin.
            for step in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[step * self.batch_size: (step + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = self.data_Y[step * self.batch_size: (step + 1) * self.batch_size]
                x_, z_, y_vec_ = Variable(x_), Variable(z_), Variable(y_vec_)

                # ---- update D network ----
                self.D_optimizer.zero_grad()
                D_real, C_real = self.D(x_)  # feed real data
                D_real_loss = self.BCE_loss(D_real, self.y_real_)
                C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1])
                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
                D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
                D_loss.backward()
                self.D_optimizer.step()

                # ---- update G network ----
                self.G_optimizer.zero_grad()
                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
                G_loss += C_fake_loss
                G_loss.backward()
                self.G_optimizer.step()

                # .item() replaces the removed `loss.data[0]` pattern, which
                # raises on 0-dim tensors in PyTorch >= 0.5.
                print(D_loss.item(), G_loss.item())
                break
            break
# Build the model (this loads MNIST from ./data/mnist) and run training.
acgan = ACGAN()
acgan.train()
In [32]:
In [ ]: