In [1]:
%pylab inline


Populating the interactive namespace from numpy and matplotlib

In [2]:
from io import BytesIO
from PIL import Image as PIL_Image
import numpy as np
from IPython.display import display, Image

def display_img_array(ima, **kwargs):
    if ima.dtype == np.float32 or ima.dtype == np.float64:
        ima = (np.clip(ima, 0., 1.)*255).astype(np.uint8)
    im = PIL_Image.fromarray(ima)
    bio = BytesIO()
    im.save(bio, format='png')
    display(Image(bio.getvalue(), format='png', **kwargs))

In [3]:
import os
import urllib
from urllib.request import urlretrieve
dataset = 'mnist.pkl.gz'
def reporthook(a,b,c):
    print("\rdownloading: %5.1f%%"%(a*b*100.0/c), end="")
    
if not os.path.isfile(dataset):
        origin = "https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz"
        print('Downloading data from %s' % origin)
        urlretrieve(origin, dataset, reporthook=reporthook)

In [4]:
import gzip
import pickle
with gzip.open(dataset, 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

In [5]:
print("train_set", train_set[0].shape, train_set[1].shape)
print("test_set", test_set[0].shape, test_set[1].shape)


train_set (50000, 784) (50000,)
test_set (10000, 784) (10000,)

In [6]:
IMG_C, IMG_W, IMG_H=1,28,28

In [7]:
def show(x):
    x = x.reshape(-1, IMG_C, IMG_H, IMG_W)
    w = min(x.shape[0]*100, 800)
    x = x.swapaxes(0, 1).swapaxes(1,2)
    if IMG_C == 1:
        img = x.reshape(IMG_H, -1)
    else:
        x = x.reshape(IMG_C, IMG_H, -1)
        img = np.swapaxes(x, 0, 1)
        img = np.swapaxes(img, 1, 2)
    display_img_array(img, width=w)    
for i in range(3):
    show(train_set[0][i])
show(train_set[0][:10])



In [8]:
import sys
sys.setrecursionlimit(10000)

In [9]:
import numpy as np
import theano
import theano.tensor as T
import lasagne
from lasagne.layers import DenseLayer, DropoutLayer, ReshapeLayer, InputLayer, FlattenLayer, Upscale2DLayer, LocalResponseNormalization2DLayer
floatX = theano.config.floatX


Using gpu device 0: GeForce GTX 1080 (CNMeM is disabled, cuDNN 5105)

In [10]:
from lasagne.layers.dnn import MaxPool3DDNNLayer, Conv3DDNNLayer, MaxPool2DDNNLayer, Conv2DDNNLayer

In [11]:
from lasagne.layers import batch_norm, ElemwiseSumLayer, NonlinearityLayer, GlobalPoolLayer, ConcatLayer
from lasagne.nonlinearities import rectify
from lasagne.layers import get_output

In [12]:
from lasagne.objectives import categorical_crossentropy, binary_crossentropy, categorical_accuracy, binary_accuracy

In [13]:
input_Y = T.matrix()  # discrete latent variable (1  10-category)
input_C = T.matrix()  # continuous latent variable (3 gaussian)
input_Z = T.matrix()  # noise (NU gaussian)
input_X = T.matrix()  # image
target_var = T.matrix() # discriminator target

def conv(*args, **kargs):
    return batch_norm(Conv2DDNNLayer(*args, pad='same', nonlinearity=rectify, W=lasagne.init.HeNormal(), **kargs))

_ = InputLayer(shape=(None, IMG_C*IMG_H*IMG_W), input_var=input_X)
_ = ReshapeLayer(_, ([0], IMG_C, IMG_H, IMG_W))
_ = batch_norm(Conv2DDNNLayer(_, 64, 3, pad='same'))
_ = batch_norm(Conv2DDNNLayer(_, 64, 3, pad='same'))
_ = MaxPool2DDNNLayer(_, 2)
_ = batch_norm(Conv2DDNNLayer(_, 128, 3, pad='same'))
_ = MaxPool2DDNNLayer(_, 2)
_ = batch_norm(Conv2DDNNLayer(_, 256, 3, pad='same'))
_ = batch_norm(Conv2DDNNLayer(_, 64, 3, pad='same'))
_ = FlattenLayer(_)
_ = DenseLayer(_, num_units=1000, nonlinearity=lasagne.nonlinearities.rectify)
l_discriminator = DenseLayer(_, num_units=1, nonlinearity=lasagne.nonlinearities.sigmoid)
l_Q_Y = DenseLayer(_, num_units=10, nonlinearity=lasagne.nonlinearities.softmax)
l_Q_C = DenseLayer(_, num_units=4, nonlinearity=lasagne.nonlinearities.linear)


NU = 32 # dim of noise
input_var2 = T.matrix()
_Y = InputLayer(shape=(None, 10), input_var=input_Y)
#_C = InputLayer(shape=(None, 3), input_var=input_C)
_Z = InputLayer(shape=(None, NU), input_var=input_Z)
_ = ConcatLayer((_Y, _Z))
_ = batch_norm(DenseLayer(_, num_units=1000, nonlinearity=lasagne.nonlinearities.rectify))
_ = batch_norm(DenseLayer(_, num_units=64*(IMG_H//4)*(IMG_W//4), nonlinearity=lasagne.nonlinearities.rectify))
_ = ReshapeLayer(_, ([0], 64, IMG_H//4, IMG_W//4))
_ = batch_norm(Conv2DDNNLayer(_, 128, 3, pad='same'))
_ = Upscale2DLayer(_, 2)
_ = batch_norm(Conv2DDNNLayer(_, 128, 3, pad='same'))
_ = Upscale2DLayer(_, 2)
_ = batch_norm(Conv2DDNNLayer(_, 128, 3, pad='same'))
_ = batch_norm(Conv2DDNNLayer(_, 128, 3, pad='same'))
_ = batch_norm(Conv2DDNNLayer(_, IMG_C, 3, pad='same', nonlinearity=lasagne.nonlinearities.rectify))
l_generator = FlattenLayer(_)

def clip(x):
    return T.clip(x, 1e-7, 1-1e-7)

output_generator_deterministic = get_output(l_generator, deterministic=True)
# 
output_discriminator = get_output(l_discriminator) #, batch_norm_use_averages=True)
output_Q_Y = get_output(l_Q_Y)

output_generator = get_output(l_generator)
output_generator_discriminator = get_output(l_discriminator, inputs=output_generator, deterministic=True)
output_generator_Q_Y = get_output(l_Q_Y, inputs=output_generator, deterministic=True)

loss_discriminator0 = loss_discriminator = binary_crossentropy(clip(output_discriminator), target_var).mean()
loss_generator0 = loss_generator = binary_crossentropy(clip(output_generator_discriminator), T.ones_like(output_generator_discriminator)).mean()
loss_generator_Q_Y = categorical_crossentropy(clip(output_generator_Q_Y), input_Y).mean()
loss_discriminator_Q_Y =  loss_generator_Q_Y

loss_discriminator += 0.01 *loss_discriminator_Q_Y
loss_generator += loss_generator_Q_Y


accuracy_discriminator = binary_accuracy(output_discriminator, target_var).mean()
accuracy_generator = binary_accuracy(output_generator_discriminator, 
                                     T.zeros_like(output_generator_discriminator)).mean()
accuracy_generator_Q_Y = categorical_accuracy(output_generator_Q_Y, input_Y).mean()

params_discriminator = lasagne.layers.get_all_params([l_discriminator, l_Q_Y], trainable=True) 
params_generator = lasagne.layers.get_all_params(l_generator, trainable=True) 

updates_generator = lasagne.updates.adam(loss_generator, 
                                            params_generator, 
                                            learning_rate=4e-4,  beta1=0.5)
updates_discriminator = lasagne.updates.adam(loss_discriminator, 
                                                params_discriminator, 
                                                learning_rate=2e-4,  beta1=0.5)

train_generator_fn = theano.function([input_Y, input_Z], 
                                     (loss_generator0, loss_generator_Q_Y, 
                                      accuracy_generator, accuracy_generator_Q_Y), 
                                     updates=updates_generator)
train_discriminator_fn = theano.function([input_X, input_Y, input_Z, target_var], 
                                         (loss_discriminator0, loss_generator_Q_Y, 
                                          accuracy_discriminator, accuracy_generator_Q_Y), 
                                         updates=updates_discriminator)

generator_fn = theano.function([input_Y, input_Z], output_generator_deterministic)

In [14]:
logf = open('mnist-infogan.log', 'w')
import sys
def printx(*args, **kwargs):
    print(*args, **kwargs)    
    print(*args, file=logf, **kwargs)
    logf.flush()
    sys.stdout.flush()

In [15]:
import sys
from random import randint
X=train_set[0] #[train_set[1]==5]
Y=train_set[1].astype('int32')
X_test = test_set[0]
Y_test = test_set[1].astype('int32')
last_imgs = None
j = 0
batch_size=256

In [16]:
def sample_X(batch_size=batch_size):
    idx = np.random.randint(0, X.shape[0], size=batch_size)
    return X[idx]

def sample_Y(batch_size=batch_size):
    input_Y = np.zeros((batch_size, 10), dtype='float32')
    random_y=np.random.randint(0, 10, size=batch_size)
    input_Y[np.arange(batch_size), random_y] = 1
    return input_Y

def sample_Z(batch_size=batch_size):
    return (np.random.normal(size=(batch_size, NU))).astype('float32')

In [17]:
for j in range(j, 100*100):
    x = sample_X()
    x_fake = generator_fn(sample_Y(), sample_Z())
    is_real = np.random.randint(0,2,size=batch_size)
    x_mixed = np.array([x[i] if is_real[i] else x_fake[i] for i in range(batch_size)], dtype='float32')
    is_real = is_real.reshape((-1,1)).astype('float32')
    d_err, q_err, d_acc, q_acc = train_discriminator_fn(x_mixed, sample_Y(), sample_Z(), is_real)     
    #print("generator phase")
    for __ in range(2):
        g_err, q_err2, g_acc, q_acc2 = train_generator_fn(sample_Y(), sample_Z())
    if j%100==0:
        printx("j=", j)
        Y=np.zeros((100,10),dtype='float32')
        Y[arange(100), arange(100)//10]=1
        imgs = generator_fn(Y, sample_Z(100))
        for i in range(0, 100, 20):
            show(imgs[i:i+20])
        printx("d_err", d_err, d_acc)
        printx("q_err", q_err, q_acc)
        printx("g_err", g_err, g_acc)
        printx("q_err", q_err2, q_acc2)


j= 0
d_err 1.126915693283081 0.5390625
q_err 2.3030595779418945 0.109375
g_err 0.40123459696769714 0.0
q_err 2.3029041290283203 0.12109375
j= 100
d_err 0.4136562943458557 0.83984375
q_err 0.269430935382843 0.9609375
g_err 15.982287406921387 1.0
q_err 0.3230594992637634 0.9453125
j= 200
d_err 0.59886634349823 0.71484375
q_err 0.21705777943134308 0.99609375
g_err 1.4825172424316406 0.90625
q_err 0.20919616520404816 0.9921875
j= 300
d_err 0.6689809560775757 0.65625
q_err 0.15956971049308777 0.9921875
g_err 0.8719648122787476 0.60546875
q_err 0.14308972656726837 0.9921875
j= 400
d_err 0.5350631475448608 0.73828125
q_err 0.05353807285428047 0.99609375
g_err 0.7865806818008423 0.6015625
q_err 0.08310800790786743 0.984375
j= 500
d_err 0.6470885872840881 0.6796875
q_err 0.07600826025009155 1.0
g_err 0.30614548921585083 0.140625
q_err 0.08282946050167084 0.9921875
j= 600
d_err 0.5865745544433594 0.69140625
q_err 0.11684250831604004 0.99609375
g_err 0.9007103443145752 0.76171875
q_err 0.07109084725379944 1.0
j= 700
d_err 0.4380419850349426 0.796875
q_err 0.02450738102197647 1.0
g_err 1.5348422527313232 0.94921875
q_err 0.02113475278019905 1.0
j= 800
d_err 0.5908045768737793 0.6953125
q_err 0.014777855947613716 1.0
g_err 0.7422816753387451 0.4765625
q_err 0.015718460083007812 1.0
j= 900
d_err 0.4592567980289459 0.8046875
q_err 0.08788764476776123 0.97265625
g_err 5.117801189422607 1.0
q_err 0.12110379338264465 0.984375
j= 1000
d_err 0.4852098226547241 0.76171875
q_err 0.032900989055633545 1.0
g_err 0.25300800800323486 0.03125
q_err 0.02352563664317131 1.0
j= 1100
d_err 0.5729730129241943 0.6796875
q_err 0.021316003054380417 1.0
g_err 0.5522302389144897 0.37890625
q_err 0.021190796047449112 1.0
j= 1200
d_err 0.41268596053123474 0.8046875
q_err 0.03399375081062317 0.9921875
g_err 3.086933135986328 1.0
q_err 0.030035056173801422 0.9921875
j= 1300
d_err 0.294484943151474 0.8359375
q_err 0.014657936990261078 1.0
g_err 0.011762451380491257 0.0
q_err 0.029092270880937576 0.9921875
j= 1400
d_err 0.5319502949714661 0.71875
q_err 0.06752363592386246 0.984375
g_err 1.8191304206848145 0.99609375
q_err 0.04738617688417435 0.984375
j= 1500
d_err 0.40533924102783203 0.8515625
q_err 0.01681305654346943 0.99609375
g_err 0.15574148297309875 0.05078125
q_err 0.0184317734092474 1.0
j= 1600
d_err 0.5064029693603516 0.75
q_err 0.022598247975111008 0.9921875
g_err 0.18691974878311157 0.06640625
q_err 0.02213870920240879 1.0
j= 1700
d_err 0.5204455852508545 0.734375
q_err 0.02759394235908985 1.0
g_err 0.7366073131561279 0.5390625
q_err 0.04159180074930191 1.0
j= 1800
d_err 0.6854312419891357 0.609375
q_err 0.026641175150871277 0.99609375
g_err 0.9151043891906738 0.65625
q_err 0.02440647967159748 0.99609375
j= 1900
d_err 0.610980749130249 0.6484375
q_err 0.014175672084093094 0.99609375
g_err 1.0513451099395752 0.796875
q_err 0.0077676549553871155 1.0
j= 2000
d_err 0.5784474015235901 0.69140625
q_err 0.021977422758936882 0.99609375
g_err 0.8172820806503296 0.58203125
q_err 0.01430157944560051 0.99609375
j= 2100
d_err 0.5166738033294678 0.75390625
q_err 0.008731553331017494 1.0
g_err 0.32905763387680054 0.07421875
q_err 0.008509327657520771 1.0
j= 2200
d_err 0.541424572467804 0.72265625
q_err 0.03100326657295227 0.9921875
g_err 0.7478556036949158 0.51953125
q_err 0.03460661321878433 0.9921875
j= 2300
d_err 0.4688061475753784 0.7734375
q_err 0.014500228688120842 0.99609375
g_err 0.5018768310546875 0.25390625
q_err 0.013650273904204369 0.99609375
j= 2400
d_err 0.6152700185775757 0.7109375
q_err 0.022561602294445038 1.0
g_err 0.24742969870567322 0.125
q_err 0.0237193014472723 0.99609375
j= 2500
d_err 0.6007158756256104 0.64453125
q_err 0.017583981156349182 1.0
g_err 0.82810378074646 0.56640625
q_err 0.023541979491710663 0.99609375
j= 2600
d_err 0.42232686281204224 0.83984375
q_err 0.011682438664138317 0.99609375
g_err 0.3060358762741089 0.1171875
q_err 0.014138803817331791 1.0
j= 2700
d_err 0.44004884362220764 0.82421875
q_err 0.014183267019689083 0.9921875
g_err 0.4870809316635132 0.265625
q_err 0.02141808345913887 0.9921875
j= 2800
d_err 0.5621567964553833 0.68359375
q_err 0.012314742431044579 1.0
g_err 0.47647354006767273 0.2265625
q_err 0.010923675261437893 1.0
j= 2900
d_err 0.599244236946106 0.66015625
q_err 0.04067698493599892 0.98828125
g_err 2.7251224517822266 1.0
q_err 0.034405745565891266 0.99609375
j= 3000
d_err 0.5989581942558289 0.6796875
q_err 0.012062561698257923 1.0
g_err 0.5581036806106567 0.30859375
q_err 0.02138214185833931 0.9921875
j= 3100
d_err 0.32428818941116333 0.8515625
q_err 0.017212536185979843 1.0
g_err 0.3618454337120056 0.16796875
q_err 0.017205871641635895 0.99609375
j= 3200
d_err 0.13148203492164612 0.98828125
q_err 2.1445257663726807 0.4375
g_err 4.309508323669434 1.0
q_err 2.037902593612671 0.32421875
j= 3300
d_err 0.16583207249641418 0.953125
q_err 0.0026981141418218613 1.0
g_err 0.00023306001094169915 0.0
q_err 0.001766779227182269 1.0
j= 3400
d_err 0.4438985586166382 0.76953125
q_err 0.036509934812784195 0.984375
g_err 1.7869658470153809 0.92578125
q_err 0.03236635401844978 0.9921875
j= 3500
d_err 0.6021798849105835 0.66015625
q_err 0.02847929298877716 0.99609375
g_err 1.4560483694076538 0.9296875
q_err 0.014221020974218845 0.99609375
j= 3600
d_err 0.6225249171257019 0.640625
q_err 0.05020720884203911 0.99609375
g_err 1.1402411460876465 0.94140625
q_err 0.054560013115406036 1.0
j= 3700
d_err 0.5020806789398193 0.7734375
q_err 0.11989369988441467 0.97265625
g_err 0.3569372594356537 0.1015625
q_err 0.10553847253322601 0.98828125
j= 3800
d_err 0.12960687279701233 0.94140625
q_err 0.0009343489073216915 1.0
g_err 0.007967859506607056 0.0
q_err 0.0022483952343463898 1.0
j= 3900
d_err 0.11944469809532166 0.98828125
q_err 0.006804435979574919 1.0
g_err 0.04051656275987625 0.0
q_err 0.009790647774934769 1.0
j= 4000
d_err 0.010965949855744839 1.0
q_err 0.0002627667272463441 1.0
g_err 16.11809539794922 1.0
q_err 0.003805197309702635 1.0
j= 4100
d_err 0.6315703988075256 0.59765625
q_err 0.015396513976156712 1.0
g_err 0.5173242092132568 0.24609375
q_err 0.004911016672849655 1.0
j= 4200
d_err 0.6167187690734863 0.65234375
q_err 0.01775602623820305 1.0
g_err 1.0496021509170532 0.8203125
q_err 0.008907206356525421 1.0
j= 4300
d_err 0.5054545998573303 0.76171875
q_err 0.011140765622258186 1.0
g_err 0.30422788858413696 0.13671875
q_err 0.013323429971933365 0.99609375
j= 4400
d_err 0.64670729637146 0.6015625
q_err 0.014357518404722214 0.99609375
g_err 1.7588756084442139 1.0
q_err 0.02918228507041931 0.9921875
j= 4500
d_err 0.6195592880249023 0.640625
q_err 0.035440243780612946 0.984375
g_err 1.5126769542694092 0.953125
q_err 0.0278715118765831 0.9921875
j= 4600
d_err 0.6031675338745117 0.6953125
q_err 0.022218884900212288 0.9921875
g_err 1.4624748229980469 0.984375
q_err 0.02921098843216896 0.98828125
j= 4700
d_err 0.5324845314025879 0.75
q_err 0.01909060962498188 0.99609375
g_err 0.05609864741563797 0.0
q_err 0.004148810636252165 1.0
j= 4800
d_err 0.1548750251531601 0.96875
q_err 0.008205143734812737 1.0
g_err 7.499237108277157e-05 0.0
q_err 0.013815892860293388 1.0
j= 4900
d_err 0.4793142080307007 0.79296875
q_err 0.0078065707348287106 1.0
g_err 0.01641407236456871 0.0
q_err 0.004550071433186531 1.0
j= 5000
d_err 0.5799251198768616 0.6875
q_err 0.022974684834480286 0.99609375
g_err 1.0481146574020386 0.8671875
q_err 0.01591268926858902 0.99609375
j= 5100
d_err 0.6026491522789001 0.6640625
q_err 0.01711229979991913 0.99609375
g_err 1.4368555545806885 0.96484375
q_err 0.028992675244808197 0.984375
j= 5200
d_err 0.6194998621940613 0.6328125
q_err 0.012569364160299301 0.99609375
g_err 0.9758998155593872 0.81640625
q_err 0.015846338123083115 0.99609375
j= 5300
d_err 0.5882796049118042 0.6796875
q_err 0.0040613156743347645 1.0
g_err 1.03365159034729 0.76171875
q_err 0.016620445996522903 0.99609375
j= 5400
d_err 0.4829257130622864 0.7578125
q_err 0.0357079841196537 0.98828125
g_err 2.181743621826172 1.0
q_err 0.02644052915275097 0.99609375
j= 5500
d_err 0.15089082717895508 0.96484375
q_err 0.008504107594490051 1.0
g_err 2.3004199647402856e-06 0.0
q_err 0.006826210767030716 1.0
j= 5600
d_err 0.5143919587135315 0.765625
q_err 0.005243751220405102 1.0
g_err 1.5503950119018555 0.875
q_err 0.017973333597183228 0.9921875
j= 5700
d_err 0.5823131203651428 0.6796875
q_err 0.011814659461379051 1.0
g_err 1.0270307064056396 0.765625
q_err 0.013195635750889778 1.0
j= 5800
d_err 0.4724949598312378 0.80078125
q_err 0.03686525672674179 0.9921875
g_err 0.60578453540802 0.41796875
q_err 0.02053913101553917 0.98828125
j= 5900
d_err 0.32083529233932495 0.859375
q_err 0.009057688526809216 0.99609375
g_err 0.01130110677331686 0.0
q_err 0.0035958504304289818 1.0
j= 6000
d_err 0.4809652864933014 0.75390625
q_err 0.004762311466038227 1.0
g_err 0.43530455231666565 0.2421875
q_err 0.007431278005242348 1.0
j= 6100
d_err 0.6065901517868042 0.625
q_err 0.012344437651336193 0.99609375
g_err 0.6215358972549438 0.3984375
q_err 0.005772797390818596 1.0
j= 6200
d_err 0.5771956443786621 0.6875
q_err 0.024504588916897774 0.9921875
g_err 1.9639694690704346 1.0
q_err 0.0211443193256855 0.9921875
j= 6300
d_err 0.48166319727897644 0.76171875
q_err 0.005819057580083609 1.0
g_err 0.025020653381943703 0.0
q_err 0.012243460863828659 1.0
j= 6400
d_err 0.3880472183227539 0.87890625
q_err 0.030305912718176842 1.0
g_err 0.5148653388023376 0.22265625
q_err 0.017408426851034164 1.0
j= 6500
d_err 0.18827778100967407 0.92578125
q_err 0.004105717875063419 1.0
g_err 0.016727304086089134 0.0
q_err 0.0012775524519383907 1.0
j= 6600
d_err 0.4958047866821289 0.73828125
q_err 0.008683166466653347 0.99609375
g_err 2.5059757232666016 0.953125
q_err 0.02522527053952217 0.99609375
j= 6700
d_err 0.4897058308124542 0.76171875
q_err 0.00430284533649683 1.0
g_err 6.196772575378418 1.0
q_err 0.031418055295944214 0.9921875
j= 6800
d_err 0.7823935747146606 0.53125
q_err 0.0015711404848843813 1.0
g_err 3.596282482147217 1.0
q_err 0.03232293948531151 0.9921875
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-17-edc50dfb4bc3> in <module>()
      8     #print("generator phase")
      9     for __ in range(2):
---> 10         g_err, q_err2, g_acc, q_acc2 = train_generator_fn(sample_Y(), sample_Z())
     11     if j%100==0:
     12         printx("j=", j)

/usr/local/lib/python3.5/dist-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    871         try:
    872             outputs =\
--> 873                 self.fn() if output_subset is None else\
    874                 self.fn(output_subset=output_subset)
    875         except Exception:

KeyboardInterrupt: 

In [18]:
Y=np.zeros((100,10),dtype='float32')
Y[arange(100), arange(100)//10]=1
imgs = generator_fn(Y, sample_Z(100))
for i in range(0, 100, 10):
    show(imgs[i:i+10])



In [19]:
#np.savez('cifar10_gan_classifier_generator.npz', lasagne.layers.get_all_param_values(l_generator))
#np.savez('cifar10_gan_classifier_discriminator.npz', lasagne.layers.get_all_param_values(l_discriminator))
#np.savez('cifar10_gan_classifier_classifier.npz', lasagne.layers.get_all_param_values(l_classifier))

In [20]:
import scipy.stats
ppf = scipy.stats.norm.ppf
pic=None
N = 16
for x in range(N):
    row = None    
    z = np.asarray([ [ppf(0.05*(x+1)), ppf(0.05*(y+1))]+[0]*(NU-2) for y in range(N)] , dtype=theano.config.floatX)
    row = generator_fn(z).reshape(-1, IMG_H, IMG_W)
    row = row.swapaxes(0,1).reshape(IMG_H, -1)
    pic = row if pic is None else np.concatenate((pic,row), axis=0)
display_img_array(pic)


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-20-8fdcf0171b33> in <module>()
      6     row = None
      7     z = np.asarray([ [ppf(0.05*(x+1)), ppf(0.05*(y+1))]+[0]*(NU-2) for y in range(N)] , dtype=theano.config.floatX)
----> 8     row = generator_fn(z).reshape(-1, IMG_H, IMG_W)
      9     row = row.swapaxes(0,1).reshape(IMG_H, -1)
     10     pic = row if pic is None else np.concatenate((pic,row), axis=0)

/usr/local/lib/python3.5/dist-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    856                     raise TypeError("Missing required input: %s" %
    857                                     getattr(self.inv_finder[c], 'variable',
--> 858                                             self.inv_finder[c]))
    859                 if c.provided > 1:
    860                     raise TypeError("Multiple values for input: %s" %

TypeError: Missing required input: <TensorType(float32, matrix)>

In [ ]: