In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
In [2]:
def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
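A quick sanity check of the initializer (my addition, not part of the original notebook): apply it to a throwaway model and confirm the weights end up roughly N(0, 0.02) and the biases at zero.
In [ ]:
net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.Linear(16, 4))  # throwaway model just for this check
initialize_weights(net)
# weight std should be close to 0.02, bias sum exactly 0
net[0].weight.data.std(), net[0].bias.data.sum()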
In [3]:
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62   # dimensionality of the latent vector z
        self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU()
        )

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )

        initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        return x
In [4]:
gen = generator()
gen
Out[4]:
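A minimal shape check (added here for illustration, not in the original): push a random batch of latent vectors through the generator and confirm it produces 1×28×28 images. The batch size of 8 is arbitrary; it just has to be larger than 1 because of the BatchNorm layers.
In [ ]:
z = Variable(torch.rand(8, 62))   # 8 random latent vectors of dimension 62
fake = gen(z)
fake.size()                       # expected: torch.Size([8, 1, 28, 28])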
In [5]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )

        initialize_weights(self)

    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)
        return x
In [6]:
disc = discriminator()
disc
Out[6]:
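Same idea for the critic (again an addition for illustration): a random batch of 28×28 "images" should map to a single score per example.
In [ ]:
x = Variable(torch.rand(8, 1, 28, 28))   # fake MNIST-sized batch
disc(x).size()                           # expected: torch.Size([8, 1])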
In [23]:
class WGAN(object):
    def __init__(self):
        self.epoch = 5
        self.batch_size = 64
        self.c = 0.01       # clipping value
        self.n_critic = 5   # the number of iterations of the critic per generator iteration

        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))

        d = datasets.MNIST('data/mnist', train=True, download=True,
                           transform=transforms.Compose([transforms.ToTensor()]))
        self.data_loader = DataLoader(d, batch_size=self.batch_size, shuffle=True)

        self.z_dim = 62
        self.sample_z = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        # real/fake labels (not actually used by the WGAN loss below)
        self.y_real = Variable(torch.ones(self.batch_size, 1))
        self.y_fake = Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!')
        for epoch in range(self.epoch):
            self.G.train()
            for iter, (x_, _) in enumerate(self.data_loader):
                # x_: (64, 1, 28, 28)
                # skip the last, incomplete batch
                if iter == len(self.data_loader.dataset) // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)  # [64, 1]
                print('D_real:', D_real.size())
                # we minimize this loss, so D_real should become large:
                # the critic should output high values for real images
                D_real_loss = - torch.mean(D_real)

                G_ = self.G(z_)  # [64, 1, 28, 28]
                print('G_:', G_.size())
                D_fake = self.D(G_)
                # we minimize this loss, so D_fake should become small:
                # the critic should output low values for generated images
                D_fake_loss = torch.mean(D_fake)

                # D learns to output large values for real images and small values for fakes
                D_loss = D_real_loss + D_fake_loss
                D_loss.backward()
                self.D_optimizer.step()

                # clipping D
                # clamp every critic parameter into the range [-0.01, 0.01];
                # why? this weight clipping is WGAN's way of keeping the critic
                # approximately 1-Lipschitz, which the Wasserstein distance requires
                for p in self.D.parameters():
                    p.data.clamp_(-self.c, self.c)

                # update G only once every n_critic (= 5) critic iterations
                if ((iter + 1) % self.n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    # G_loss should be small, i.e. D_fake should be large:
                    # G wants the critic to assign high values to its fakes
                    G_loss = - torch.mean(D_fake)
                    G_loss.backward()
                    self.G_optimizer.step()

                break  # stop after one iteration (debugging)
            break      # stop after one epoch (debugging)
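For reference, the objective this loop implements (the standard WGAN objective, matching the comments above):

$$\min_G \; \max_{\|D\|_L \le 1} \; \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{z \sim p(z)}[D(G(z))]$$

so the critic step minimizes $-\mathbb{E}[D(x)] + \mathbb{E}[D(G(z))]$ (`D_real_loss + D_fake_loss`), the generator step minimizes $-\mathbb{E}[D(G(z))]$ (`G_loss`), and weight clipping enforces the Lipschitz constraint on $D$.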
In [24]:
wgan = WGAN()
wgan.train()
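The class keeps a fixed batch of latent vectors in `sample_z` but never uses it; below is a sketch of how one might render samples from the (barely trained) generator with it. `save_image` from torchvision and the file name `wgan_samples.png` are my choices, not part of the original notebook.
In [ ]:
from torchvision.utils import save_image

wgan.G.eval()                      # switch BatchNorm layers to evaluation mode
samples = wgan.G(wgan.sample_z)    # (64, 1, 28, 28), values in [0, 1] thanks to the Sigmoid
save_image(samples.data, 'wgan_samples.png', nrow=8)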
In [25]:
a = torch.randn(4)
a
Out[25]:
In [26]:
torch.clamp(a, min=-0.5, max=0.5)
Out[26]:
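The training loop uses the in-place variant, `clamp_`, which overwrites the tensor instead of returning a clipped copy; this is exactly what `p.data.clamp_(-self.c, self.c)` does to each critic parameter:
In [ ]:
a.clamp_(-0.01, 0.01)   # in-place, mirrors the weight-clipping step above
a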
In [ ]: