In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [3]:
data_mean = 4
data_stddev = 1.25

In [4]:
g_input_size = 1
g_hidden_size = 50
g_output_size = 1
d_input_size = 100
d_hidden_size = 50
d_output_size = 1
minibatch_size = d_input_size

In [5]:
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1
g_steps = 1

In [6]:
# dataからdata[0]を引いて2乗したリストを結合している?
# 何のために?
def decorate_with_diffs(data, exponent):
    # data.data = Variable => Tensor
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [7]:
before = Variable(torch.rand(10, 1))
after = decorate_with_diffs(before, 2.0)
print(before)
print(after)


Variable containing:
 0.1432
 0.1751
 0.5798
 0.3055
 0.5783
 0.4844
 0.7658
 0.1784
 0.3332
 0.0914
[torch.FloatTensor of size 10x1]

Variable containing:
 0.1432  0.0000
 0.1751  0.0010
 0.5798  0.1906
 0.3055  0.0263
 0.5783  0.1892
 0.4844  0.1164
 0.7658  0.3876
 0.1784  0.0012
 0.3332  0.0361
 0.0914  0.0027
[torch.FloatTensor of size 10x2]


In [8]:
name = 'Data and variances'
preprocess = lambda data: decorate_with_diffs(data, 2.0)
d_input_func = lambda x: x * 2

In [9]:
print('Using data [%s]' % name)


Using data [Data and variances]

In [10]:
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

In [11]:
d_sampler = get_distribution_sampler(data_mean, data_stddev)

In [12]:
d_sampler(10)


Out[12]:
 5.0650  7.6520  3.2748  3.9452  4.0405  3.3511  3.8134  5.8915  5.2322  3.8391
[torch.FloatTensor of size 1x10]

In [13]:
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)

In [14]:
gi_sampler = get_generator_input_sampler()

In [15]:
gi_sampler(3, 2)


Out[15]:
 0.0010  0.9491
 0.6421  0.1529
 0.5812  0.0512
[torch.FloatTensor of size 3x2]

In [16]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

In [17]:
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)

In [18]:
G


Out[18]:
Generator (
  (map1): Linear (1 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

In [19]:
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        # 本物画像のとき1、偽物画像のとき0 = 本物画像の確率を出力
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

In [20]:
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)

In [21]:
D


Out[21]:
Discriminator (
  (map1): Linear (200 -> 50)
  (map2): Linear (50 -> 50)
  (map3): Linear (50 -> 1)
)

In [22]:
criterion = nn.BCELoss()

In [23]:
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)

In [24]:
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [25]:
d_optimizer


Out[25]:
<torch.optim.adam.Adam at 0x10b562390>

In [39]:
def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

In [43]:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real + fake
        D.zero_grad()
        
        # 1A. Train D on real
        # 指定した平均・標準偏差のランダムデータを生成
        # Gがこの分布を学習するのが最終目的
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        # 本物画像を入力したときは1を出力する
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))
        d_real_error.backward()

        # 1B. Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))
        d_fake_error.backward()
        d_optimizer.step()  # update parameters

    for g_index in range(g_steps):
        G.zero_grad()
        
        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # Gは騙したいので本物=1に分類されるように学習する
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))
        g_error.backward()
        g_optimizer.step()
    
    if epoch % print_interval == 0:
        print('%s: D: %s/%s G: %s (Real: %s, Fake: %s)' % (
            epoch,
            extract(d_real_error)[0],
            extract(d_fake_error)[0],
            extract(g_error)[0],
            stats(extract(d_real_data)),
            stats(extract(d_fake_data))))


/Users/koichiro.mori/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/torch/nn/functional.py:767: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])) is deprecated. Please ensure they have the same size.
  "Please ensure they have the same size.".format(target.size(), input.size()))
0: D: 0.05512115731835365/0.4082205593585968 G: 1.105771541595459 (Real: [4.1565990126132961, 1.2711233205301873], Fake: [0.67944827854633327, 0.038399155626412371])
200: D: 0.0028105496894568205/0.1890818178653717 G: 1.7559853792190552 (Real: [3.9789160776138304, 1.3551562150017891], Fake: [0.4557408770918846, 0.045124946576745199])
400: D: 0.004272036254405975/0.08585026860237122 G: 2.6675233840942383 (Real: [3.7743640394695102, 1.2353088212241417], Fake: [0.28140656068921088, 0.10168814482323617])
600: D: 6.556531388923759e-06/0.21870669722557068 G: 1.951550006866455 (Real: [3.9861805975437163, 1.3794096915423726], Fake: [0.051964230462908746, 0.28537573942563749])
800: D: 0.0017126555321738124/0.029417188838124275 G: 3.291212320327759 (Real: [4.0462950253486634, 1.0573092765582288], Fake: [0.96370529353618617, 0.40037159610085621])
1000: D: 0.06592345237731934/0.48201489448547363 G: 1.9867020845413208 (Real: [3.8995013171434403, 1.1036350894963858], Fake: [1.7540165454521774, 1.0199131791413849])
1200: D: 0.8194631934165955/0.3244119882583618 G: 2.7068188190460205 (Real: [3.960039269924164, 1.2113224091748127], Fake: [3.7605122864246368, 1.0512350664244468])
1400: D: 1.2368741035461426/0.35094499588012695 G: 0.6602744460105896 (Real: [3.950193212032318, 1.3529799237266706], Fake: [4.509739962816238, 1.1629432918945652])
1600: D: 1.0757906436920166/0.5475696921348572 G: 0.8456125855445862 (Real: [3.9613199806213379, 1.284405246903924], Fake: [4.6178848588466641, 1.3620039776261466])
1800: D: 0.6149595379829407/0.746460497379303 G: 0.6704933643341064 (Real: [4.0865075969696045, 1.168936620775342], Fake: [4.659805953502655, 1.4002544334435112])
2000: D: 0.6139893531799316/0.48221027851104736 G: 1.0073984861373901 (Real: [3.835126178264618, 1.0464889796120826], Fake: [5.1052543914318083, 1.1107255571844996])
2200: D: 0.6595320701599121/0.5641043782234192 G: 0.7994662523269653 (Real: [3.9929641062021255, 1.4401799695450672], Fake: [4.4001675677299499, 1.2610804125936785])
2400: D: 0.7465106248855591/0.8436891436576843 G: 0.6915052533149719 (Real: [3.889567503929138, 1.1463142810064744], Fake: [3.8551893633604051, 1.2221866539331669])
2600: D: 0.5809812545776367/0.6643964648246765 G: 0.4946178197860718 (Real: [4.1319891977310181, 1.1245013920096214], Fake: [3.1957498496770858, 1.3267394670670463])
2800: D: 0.5868017673492432/0.533968985080719 G: 0.7199199795722961 (Real: [3.9951232337951659, 1.2060742020904986], Fake: [3.5135262829065321, 1.1624202422115406])
3000: D: 0.5850471258163452/0.7998604774475098 G: 0.7080922722816467 (Real: [3.9457317173480986, 1.1214501847306897], Fake: [3.7142029261589049, 1.3199320019322376])
3200: D: 0.6538683772087097/0.5268505811691284 G: 0.7574522495269775 (Real: [3.8888055527210237, 1.3931550006705282], Fake: [4.495435053110123, 1.2410368130311582])
3400: D: 0.6665226221084595/0.6928370594978333 G: 0.6804638504981995 (Real: [3.9803959506750108, 1.1998332026864647], Fake: [4.55838235616684, 1.2129527394277271])
3600: D: 0.9926980137825012/0.6907450556755066 G: 0.6778942346572876 (Real: [4.0564133253693582, 1.4599903081510077], Fake: [3.9072418761253358, 1.2211109224564927])
3800: D: 0.7116783261299133/0.5940961241722107 G: 0.5612344741821289 (Real: [4.0341494393348691, 1.2905073214795435], Fake: [3.348288506269455, 1.312958326433552])
4000: D: 0.2271018773317337/0.6376434564590454 G: 0.8291262984275818 (Real: [4.0005604088306423, 1.344657694725935], Fake: [3.7718676543235778, 1.2012874968170462])
4200: D: 0.8780814409255981/0.7653242349624634 G: 0.737142026424408 (Real: [3.8665670120716094, 1.3917266704823412], Fake: [4.020186582803726, 1.4296876395215463])
4400: D: 0.46425172686576843/0.8098177909851074 G: 0.6808632016181946 (Real: [4.1853560698032375, 1.1932873646985405], Fake: [4.1275411260128019, 1.2148076973808268])
4600: D: 0.5544388890266418/0.7796315550804138 G: 0.7196078300476074 (Real: [3.7232158377021549, 1.1666334492078858], Fake: [4.1103524172306063, 0.89873836298196474])
4800: D: 0.7088126540184021/0.6320845484733582 G: 0.6691069006919861 (Real: [3.8212835961580276, 1.2189921569485314], Fake: [3.6935389176011086, 1.5030029646512337])
5000: D: 0.27461397647857666/0.6773078441619873 G: 0.6972963213920593 (Real: [4.1667515313625332, 1.3931213108476754], Fake: [4.2575979053974153, 1.2389080239690042])
5200: D: 0.39945241808891296/0.7015224099159241 G: 0.7870875000953674 (Real: [3.8199763894081116, 1.3062564122644471], Fake: [4.1510762500762937, 1.0601728510039021])
5400: D: 0.7777531743049622/0.710725724697113 G: 0.8341628909111023 (Real: [3.9446107482910158, 1.2300690776168441], Fake: [3.9905421769618989, 1.2519079896867291])
5600: D: 0.4527195394039154/0.6649509072303772 G: 0.8792474865913391 (Real: [3.9777761828899383, 1.3668677602780441], Fake: [3.7485285645723341, 1.3771053778206703])
5800: D: 0.8747773766517639/0.9393650889396667 G: 0.7043017148971558 (Real: [3.941868385076523, 1.2394382014743384], Fake: [4.0025003623962405, 1.2992998925798585])
6000: D: 0.5818036198616028/0.6261497735977173 G: 0.9190885424613953 (Real: [3.8127291327714921, 1.252257079512737], Fake: [3.8998659515380858, 1.2232444877848665])
6200: D: 0.8303660750389099/1.188339352607727 G: 0.8699972629547119 (Real: [3.8198971956968308, 1.1170101934893755], Fake: [4.2230133044719693, 1.1248272361900744])
6400: D: 0.5782439112663269/0.5568006634712219 G: 0.6650205254554749 (Real: [4.1426924622058872, 1.1865515167573897], Fake: [4.1017770373821261, 1.1986582675355193])
6600: D: 0.6717461347579956/0.5138940811157227 G: 1.2079883813858032 (Real: [4.0142995595932005, 1.2258296103230957], Fake: [4.235463209152222, 1.1324445263038612])
6800: D: 0.5396738052368164/0.7430610656738281 G: 0.7751320004463196 (Real: [4.0621238195896145, 1.1308373716537385], Fake: [3.9511045134067535, 1.1577291738796165])
7000: D: 0.4284682869911194/0.5349873900413513 G: 0.8274132013320923 (Real: [3.9336301231384279, 1.3015859330029003], Fake: [4.2760844683647159, 1.1222353838446733])
7200: D: 0.6257358193397522/0.4107515215873718 G: 0.7833272814750671 (Real: [3.9244627773761751, 1.2858862498856629], Fake: [4.2389549696445465, 1.2103063965888243])
7400: D: 0.878644585609436/0.5876164436340332 G: 0.6143878698348999 (Real: [4.1161256623268123, 1.3170366495580206], Fake: [3.7328793567419054, 1.2102158012014872])
7600: D: 0.7890059351921082/0.4947899580001831 G: 0.9793896675109863 (Real: [3.9871299457550049, 1.236836703062574], Fake: [4.4991523402929303, 1.2306246107750076])
7800: D: 0.4976951479911804/0.623058021068573 G: 0.7209433913230896 (Real: [3.9165514234825967, 1.3777187076462156], Fake: [3.7824044573307036, 1.3025031819969088])
8000: D: 0.7289396524429321/1.0010366439819336 G: 0.7624824047088623 (Real: [4.1150783216953277, 1.1975378371376413], Fake: [3.6601270785927773, 1.346977557190058])
8200: D: 0.3480786979198456/0.5865159034729004 G: 1.0244852304458618 (Real: [4.0595623004436492, 1.3001833399550102], Fake: [4.0489757508039474, 1.4216529915900986])
8400: D: 0.5007491111755371/0.5047882199287415 G: 0.6774538159370422 (Real: [4.1386166179180144, 1.1530787591887126], Fake: [3.7198043078184129, 1.2826657498986946])
8600: D: 0.7487382292747498/0.6503969430923462 G: 0.5163914561271667 (Real: [3.83702227383852, 1.3206772255649857], Fake: [4.2506998574733732, 1.4166877240317919])
8800: D: 0.942085325717926/0.7767379879951477 G: 0.6842425465583801 (Real: [4.0171367186307911, 1.3355391895628519], Fake: [3.9077478060126305, 1.2638226350673385])
9000: D: 0.645487368106842/0.5761956572532654 G: 0.4877888560295105 (Real: [3.7721445500850677, 1.2229381758458426], Fake: [3.9650374603271485, 1.2028985563253132])
9200: D: 0.8316420912742615/0.7499564290046692 G: 0.8374991416931152 (Real: [3.9749160194396973, 1.2843367682658258], Fake: [4.1352110421657562, 1.3930234042783367])
9400: D: 0.5844429135322571/0.5902791023254395 G: 0.5813054442405701 (Real: [3.9881389915943144, 1.2047527435471803], Fake: [3.7815706551074983, 1.3003322800619854])
9600: D: 0.4257098138332367/0.46111661195755005 G: 0.8408671617507935 (Real: [3.9948023343086243, 1.2334321687513012], Fake: [4.3323624193668362, 1.127262567050022])
9800: D: 0.4880388677120209/0.6008905172348022 G: 0.8387617468833923 (Real: [3.9061585950851438, 1.2108708415509157], Fake: [4.1334000766277317, 1.4915696260587377])
10000: D: 0.3732897639274597/0.5901905298233032 G: 0.6366154551506042 (Real: [3.9055146288871767, 1.2133980806255134], Fake: [3.9228189826011657, 1.251756625402791])
10200: D: 0.8453251719474792/1.135027527809143 G: 0.6891510486602783 (Real: [3.8702863442897795, 1.2698450804467942], Fake: [4.0397981590032579, 1.2074211475023717])
10400: D: 0.5000208020210266/0.2970240116119385 G: 0.6849726438522339 (Real: [4.1309327584505082, 1.3221326231243986], Fake: [4.3132912802696231, 1.2856482801055216])
10600: D: 0.5760613083839417/0.43661582469940186 G: 1.6333136558532715 (Real: [3.9388609640300274, 1.2079466504449969], Fake: [3.7228925043344496, 1.3377923580420676])
10800: D: 0.3816162347793579/0.5216460227966309 G: 1.1398087739944458 (Real: [4.1427898854017258, 1.3228319016418078], Fake: [4.205437598824501, 1.2380966878920887])
11000: D: 0.5926294326782227/0.7510080337524414 G: 0.4377184212207794 (Real: [3.815283213853836, 1.1191236293088271], Fake: [4.2560344040393829, 1.0614220921090431])
11200: D: 0.7483011484146118/0.9954952001571655 G: 0.7852908372879028 (Real: [3.9567468088865279, 1.1795057326604892], Fake: [3.8529889345169068, 1.1912669797771063])
11400: D: 1.042837142944336/0.6091429591178894 G: 0.6324095726013184 (Real: [4.0356894898414613, 1.2709853949269188], Fake: [4.2195229411125181, 1.0958390111941956])
11600: D: 0.37291574478149414/0.6550974249839783 G: 1.0800412893295288 (Real: [3.9672647404670713, 1.2901472924879065], Fake: [4.0814898514747622, 1.234707197261022])
11800: D: 0.6646013855934143/0.23992286622524261 G: 1.3690134286880493 (Real: [4.0708577740192418, 1.073572705379451], Fake: [3.5742011445760729, 1.4643125918252167])
12000: D: 0.7876962423324585/0.48974359035491943 G: 0.6581987142562866 (Real: [3.8636049747467043, 1.240300356122559], Fake: [3.9380668580532072, 1.3015699891441792])
12200: D: 0.9523488283157349/0.6090847849845886 G: 0.8085739612579346 (Real: [4.0226041042804717, 1.0724749475598045], Fake: [4.0381029129028319, 1.2943712458859025])
12400: D: 0.15638814866542816/0.6392780542373657 G: 0.7199721336364746 (Real: [4.0183075094223026, 1.1366302526542953], Fake: [3.9197939491271971, 1.2368866242654757])
12600: D: 0.6898838877677917/0.8038758039474487 G: 0.7936872839927673 (Real: [4.1059937608242034, 1.2282101562663148], Fake: [3.8194100993871687, 1.2864648099913392])
12800: D: 0.7003756761550903/0.18970702588558197 G: 1.9971809387207031 (Real: [3.868145455121994, 1.299615629233696], Fake: [4.3206208842992782, 1.2941047704023918])
13000: D: 0.43471577763557434/0.5627126693725586 G: 0.5715929269790649 (Real: [4.1260269659757611, 1.2807163698299016], Fake: [3.9320609283447268, 1.2723336480591507])
13200: D: 0.8981638550758362/0.47058287262916565 G: 1.7347959280014038 (Real: [4.0046638453006747, 1.2993799008089277], Fake: [3.7681052583456038, 1.3529035361570654])
13400: D: 0.6059008836746216/0.4301181137561798 G: 0.7209514379501343 (Real: [3.9827207377552987, 1.1986134687840519], Fake: [4.5166181302070614, 1.0041566989781869])
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-43-729073d0ec3b> in <module>()
     26         gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
     27         g_fake_data = G(gen_input)
---> 28         dg_fake_decision = D(preprocess(g_fake_data.t()))
     29         # Gは騙したいので本物=1に分類されるように学習する
     30         g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))

KeyboardInterrupt: 

In [ ]: