In [36]:
import os
import time
import gzip
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

In [2]:
def initialize_weights(net):
    # DCGAN-style initialization: weights drawn from N(0, 0.02), biases zeroed.
    # BatchNorm layers keep their PyTorch defaults.
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

In [53]:
class generator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 62 + 10  # z dimension + class dimension (10 classes)
        self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            # project to a feature map at 1/4 of the input height/width (7x7); the deconvs upsample it back to 28x28
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        
        initialize_weights(self)
    
    def forward(self, input, label):
        x = torch.cat([input, label], 1)  # concatenate z with the one-hot label
        print('***', x.size())  # debug: size of the concatenated (z, label) batch
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)
        
        return x

gen = generator()
gen


Out[53]:
generator (
  (fc): Sequential (
    (0): Linear (72 -> 1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): Linear (1024 -> 6272)
    (4): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU ()
  )
  (deconv): Sequential (
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Sigmoid ()
  )
)
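
As a quick sanity check (a sketch, not part of the original run): a batch of random z vectors plus one-hot labels should come out of the untrained generator as 28x28 single-channel images. The batch size of 8 is arbitrary; BatchNorm1d just needs more than one sample.

In [ ]:
z = Variable(torch.rand(8, 62))
y = torch.zeros(8, 10)
y[:, 3] = 1                  # condition every sample on class 3
out = gen(z, Variable(y))
out.size()                   # expected: torch.Size([8, 1, 28, 28])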

In [9]:
x = torch.randn(2, 3)
x


Out[9]:
-0.4825  0.4603 -0.1724
-0.7229  1.3024  0.6668
[torch.FloatTensor of size 2x3]

In [11]:
torch.cat((x, x, x), 0)


Out[11]:
-0.4825  0.4603 -0.1724
-0.7229  1.3024  0.6668
-0.4825  0.4603 -0.1724
-0.7229  1.3024  0.6668
-0.4825  0.4603 -0.1724
-0.7229  1.3024  0.6668
[torch.FloatTensor of size 6x3]

In [12]:
torch.cat((x, x, x), 1)


Out[12]:
-0.4825  0.4603 -0.1724 -0.4825  0.4603 -0.1724 -0.4825  0.4603 -0.1724
-0.7229  1.3024  0.6668 -0.7229  1.3024  0.6668 -0.7229  1.3024  0.6668
[torch.FloatTensor of size 2x9]
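
Concatenation along dimension 1 is exactly how the generator conditions on the label: a 62-dim z and a 10-dim one-hot vector form the 72-dim input of the first Linear layer.

In [ ]:
z = torch.rand(64, 62)
y = torch.zeros(64, 10)
torch.cat([z, y], 1).size()  # torch.Size([64, 72]) -- matches Linear (72 -> 1024)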

In [41]:
class discriminator(nn.Module):
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        self.input_height = 28
        self.input_width = 28
        self.input_dim = 1
        self.output_dim = 1
        self.class_num = 10
        
        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        
        self.fc1 = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )

        # real/fake head
        self.dc = nn.Sequential(
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )

        # auxiliary classifier head (raw class logits; CrossEntropyLoss applies log-softmax)
        self.c1 = nn.Sequential(
            nn.Linear(1024, self.class_num),
        )
        
        initialize_weights(self)
        
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc1(x)
        d = self.dc(x)
        c = self.c1(x)
        
        return d, c

disc = discriminator()
disc


Out[41]:
discriminator (
  (conv): Sequential (
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU (0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2)
  )
  (fc1): Sequential (
    (0): Linear (6272 -> 1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): LeakyReLU (0.2)
  )
  (dc): Sequential (
    (0): Linear (1024 -> 1)
    (1): Sigmoid ()
  )
  (c1): Sequential (
    (0): Linear (1024 -> 10)
  )
)
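
A matching sanity check for the discriminator (again a sketch): for each input image it returns a real/fake probability d and 10 class logits c.

In [ ]:
imgs = Variable(torch.rand(8, 1, 28, 28))  # random stand-in for an MNIST batch
d, c = disc(imgs)
d.size(), c.size()  # expected: (torch.Size([8, 1]), torch.Size([8, 10]))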

In [37]:
def load_mnist(dataset):
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
        return data

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    trY = np.asarray(trY).astype(int)
    teY = np.asarray(teY)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(int)

    # shuffle images and labels identically by resetting the seed in between
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    # one-hot encode the labels
    y_vec = np.zeros((len(y), 10), dtype=np.float32)
    for i, label in enumerate(y):
        y_vec[i, label] = 1

    # NHWC -> NCHW, scale pixel values to [0, 1]
    X = X.transpose(0, 3, 1, 2) / 255.

    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec
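
load_mnist expects the four raw gzip archives from http://yann.lecun.com/exdb/mnist/ under ./data/mnist. A minimal check of what it returns:

In [ ]:
# Assumes ./data/mnist/ contains train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz,
# t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz.
X, y_vec = load_mnist('mnist')
print(X.size(), y_vec.size())   # expected: 70000x1x28x28 images, 70000x10 one-hot labels
print(X.min(), X.max())         # pixel values scaled to [0, 1]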

In [61]:
class ACGAN(object):
    def __init__(self):
        self.epoch = 5
        self.sample_num = 100
        self.batch_size = 64
        
        self.G = generator()
        self.D = discriminator()
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.BCE_loss = nn.BCELoss()
        self.CE_loss = nn.CrossEntropyLoss()

        self.data_X, self.data_Y = load_mnist('mnist')
        self.z_dim = 62
        self.y_dim = 10
        print(self.data_X.shape, self.data_Y.shape)

    def train(self):
        # BCE targets: 1 for real, 0 for fake
        self.y_real_ = Variable(torch.ones(self.batch_size, 1))
        self.y_fake_ = Variable(torch.zeros(self.batch_size, 1))
        print(self.y_real_.size(), self.y_fake_.size())

        self.D.train()
        print('training start!!')
        for epoch in range(self.epoch):
            self.G.train()
            for it in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[it * self.batch_size: (it + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = self.data_Y[it * self.batch_size: (it + 1) * self.batch_size]
                
                x_, z_, y_vec_ = Variable(x_), Variable(z_), Variable(y_vec_)
                
                # update D network
                self.D_optimizer.zero_grad()
                D_real, C_real = self.D(x_)  # feed real data
                D_real_loss = self.BCE_loss(D_real, self.y_real_)
                # torch.max(..., 1)[1]: one-hot labels -> class indices for CrossEntropyLoss
                C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1])
                
                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
                
                # total D loss: adversarial terms + auxiliary classification on real and fake
                D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
                D_loss.backward()
                self.D_optimizer.step()
                
                # update G network
                self.G_optimizer.zero_grad()
                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)
                
                G_loss = self.BCE_loss(D_fake, self.y_real_)  # fool D: label the fakes as real
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])

                G_loss += C_fake_loss  # plus the auxiliary classification term
                G_loss.backward()
                self.G_optimizer.step()

                print(D_loss.data[0], G_loss.data[0])
                break  # demo: stop after a single batch
            break  # demo: stop after a single epoch

acgan = ACGAN()
acgan.train()


torch.Size([70000, 1, 28, 28]) torch.Size([70000, 10])
torch.Size([64, 1]) torch.Size([64, 1])
training start!!
*** torch.Size([64, 72])
*** torch.Size([64, 72])
5.925044059753418 2.528146505355835
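
After a real training run (remove the two demo break statements so all batches and epochs execute), class-conditional digits can be sampled by fixing the one-hot label. A minimal sketch, assuming torchvision is available for save_image:

In [ ]:
from torchvision.utils import save_image

acgan.G.eval()                  # use BatchNorm running statistics
z = Variable(torch.rand(10, acgan.z_dim))
y = Variable(torch.eye(10))     # one-hot rows: one sample per digit 0-9
samples = acgan.G(z, y)
save_image(samples.data, 'acgan_samples.png', nrow=10)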
