In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
def initialize_weights(net):
    # DCGAN-style initialization: weights ~ N(0, 0.02), biases set to zero
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
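
As a quick sanity check (an added sketch, not part of the original notebook), one can confirm what initialize_weights does: after calling it, a layer's weights should have a standard deviation close to 0.02 and its biases should be exactly zero. The layer size below is arbitrary.

lin = nn.Linear(100, 100)
initialize_weights(nn.Sequential(lin))
print(lin.weight.data.std())      # roughly 0.02
print(lin.bias.data.abs().max())  # 0.0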

In [3]:
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62  # dimension of the noise vector z
        self.output_dim = 1
        
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU()
        )
        
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
        
    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        
        return x

In [4]:
gen = generator()
gen


Out[4]:
generator (
  (fc): Sequential (
    (0): Linear (62 -> 1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): Linear (1024 -> 6272)
    (4): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU ()
  )
  (deconv): Sequential (
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Sigmoid ()
  )
)
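
As a shape check (an added sketch, not in the original run), feeding a batch of random noise vectors through the generator should produce a batch of 1x28x28 images in (0, 1), since input_dim is 62 and the final layer is a Sigmoid. Variable is used to match the pre-0.4 PyTorch style of the rest of the notebook.

z = Variable(torch.rand(8, 62))
fake = gen(z)
print(fake.size())   # expected: torch.Size([8, 1, 28, 28])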

In [5]:
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1
        
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)
        
        return x

In [6]:
disc = discriminator()
disc


Out[6]:
discriminator (
  (conv): Sequential (
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU (0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2)
  )
  (fc): Sequential (
    (0): Linear (6272 -> 1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): LeakyReLU (0.2)
    (3): Linear (1024 -> 1)
    (4): Sigmoid ()
  )
)
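
Similarly (an added sketch, not in the original run), the discriminator maps a batch of 1x28x28 images to one score per image, squashed into (0, 1) by the final Sigmoid.

x = Variable(torch.rand(8, 1, 28, 28))
score = disc(x)
print(score.size())   # expected: torch.Size([8, 1])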

In [23]:
class WGAN(object):
    def __init__(self):
        self.epoch = 5
        self.batch_size = 64
        self.c = 0.01      # clipping value
        self.n_critic = 5  # the number of iterations of the critic per generator iteration
        
        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        
        d = datasets.MNIST('data/mnist', train=True, download=True,
                           transform=transforms.Compose([transforms.ToTensor()]))
        self.data_loader = DataLoader(d, batch_size=self.batch_size, shuffle=True)
        self.z_dim = 62
        
        # fixed noise vectors, e.g. for visualizing generator samples (not used in train() below)
        self.sample_z = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)
    
    def train(self):
        # labels from the standard GAN setup; not actually used by the WGAN losses below
        self.y_real = Variable(torch.ones(self.batch_size, 1))
        self.y_fake = Variable(torch.zeros(self.batch_size, 1))
        
        self.D.train()
        print('training start!')
        for epoch in range(self.epoch):
            self.G.train()
            for iter, (x_, _) in enumerate(self.data_loader):
                # x_: (64, 1, 28, 28)
                if iter == len(self.data_loader.dataset) // self.batch_size:
                    break
                
                z_ = torch.rand((self.batch_size, self.z_dim))
                x_, z_ = Variable(x_), Variable(z_)
                
                # update D network
                self.D_optimizer.zero_grad()
                
                D_real = self.D(x_)  # [64, 1]
                print('D_real:', D_real.size())
                # we minimize the loss, so we want D_real to be large:
                # the critic's output on real images should be high
                D_real_loss = - torch.mean(D_real)
                
                G_ = self.G(z_)  # [64, 1, 28, 28]
                print('G_:', G_.size())
                D_fake = self.D(G_)
                # we minimize the loss, so we want D_fake to be small:
                # the critic's output on fake images should be low
                D_fake_loss = torch.mean(D_fake)

                # D is trained to output large values for real images and small values for fake ones

                D_loss = D_real_loss + D_fake_loss
                D_loss.backward()
                self.D_optimizer.step()
                
                # clipping D
                # clamp every critic parameter to the range [-0.01, 0.01];
                # this enforces the Lipschitz constraint the WGAN critic requires
                for p in self.D.parameters():
                    p.data.clamp_(-self.c, self.c)
                
                # update G once every n_critic (= 5) critic iterations
                if ((iter + 1) % self.n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()
                    
                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    # G_loss should be small = D_fake should be large:
                    # G wants the critic to output high values ("looks real") on its samples
                    G_loss = - torch.mean(D_fake)
                    G_loss.backward()
                    self.G_optimizer.step()
                break  # debugging: stop after a single iteration
            break      # debugging: stop after a single epoch
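
Note that the original WGAN paper recommends RMSProp with a small learning rate (5e-5) rather than a momentum-based optimizer such as Adam when weight clipping is used. A minimal sketch of that variant, reusing the gen and disc instances from above and keeping everything else unchanged:

# hypothetical alternative to the Adam optimizers defined in __init__ above,
# following the RMSProp setting from the WGAN paper
G_optimizer = optim.RMSprop(gen.parameters(), lr=5e-5)
D_optimizer = optim.RMSprop(disc.parameters(), lr=5e-5)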

In [24]:
wgan = WGAN()
wgan.train()


training start!
D_real: torch.Size([64, 1])
G_: torch.Size([64, 1, 28, 28])
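
A sketch (not in the original notebook) of how the otherwise unused self.sample_z could be used to inspect the generator's output after train(); matplotlib is an added dependency, and the 28x28 reshape matches the MNIST size used throughout.

import matplotlib.pyplot as plt

wgan.G.eval()                       # use running batch-norm statistics at inference time
samples = wgan.G(wgan.sample_z)     # [64, 1, 28, 28]
plt.imshow(samples.data[0].numpy().reshape(28, 28), cmap='gray')
plt.show()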

In [25]:
a = torch.randn(4)
a


Out[25]:
 3.3007
 0.2319
-0.4022
 0.3157
[torch.FloatTensor of size 4]

In [26]:
torch.clamp(a, min=-0.5, max=0.5)


Out[26]:
 0.5000
 0.2319
-0.4022
 0.3157
[torch.FloatTensor of size 4]
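
For comparison (an added sketch), the in-place variant used in the clipping loop above: torch.clamp returns a new tensor, while clamp_ modifies a directly, which is what p.data.clamp_(-self.c, self.c) does to every critic parameter.

a.clamp_(min=-0.5, max=0.5)
a   # a itself now holds the clipped values shown above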

In [ ]: