In [21]:
# load Python/Theano stuff
# Show figures inline with the code
%matplotlib inline   

import theano
import theano.tensor as T
import theano.tensor.nlinalg as Tla
import lasagne       # the library we're using for NN's
# import the nonlinearities we might use 
from lasagne.nonlinearities import leaky_rectify, softmax, linear, tanh, rectify
from theano.tensor.shared_randomstreams import RandomStreams
import numpy as np
from numpy.random import *
from matplotlib import pyplot as plt

import cPickle
import sys

# import kmeans clustering algorithm from scikit-learn
from sklearn.cluster import KMeans

In [22]:
# Load our code

# Add all the paths that should matter right now
sys.path.append('lib/') 
from GenerativeModel import *       # Class file for generative models. 
from RecognitionModel import *      # Class file for recognition models
from NVIL import *                  # The meat of the algorithm - define the cost function and initialize Gen/Rec model

# import our covariance-plotting software
from plot_cov import *

In [23]:
# Use the fast compile option
theano.config.optimizer = 'fast_compile'

In [24]:
# Choose Simulation Parameters and Generate Data
xDim = 3 # number of latent classes
yDim = 2 # dimensionality of Gaussian observations
_N = 2500 # number of datapoints to generate
gmm = MixtureOfGaussians(dict([]), xDim, yDim)  # instantiate our 'true' generative model
[xsamp, ysamp] = gmm.sampleXY(_N)
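
For reference, sampleXY draws each point by first sampling a one-hot class label from the mixture weights and then drawing from that component's Gaussian. A minimal NumPy sketch of the same process (illustrative only; the real implementation lives in GenerativeModel.py, and sample_gmm is a hypothetical name):

def sample_gmm(N, pi, mu, RChol):
    # draw a class z ~ Categorical(pi), then y ~ N(mu[z], RChol[z] RChol[z]^T)
    K = len(pi)
    z = np.random.choice(K, size=N, p=pi)               # latent class per point
    eps = np.random.randn(N, mu.shape[1])               # standard-normal noise
    y = mu[z] + np.einsum('nij,nj->ni', RChol[z], eps)  # shift/scale by Cholesky factor
    x = np.eye(K)[z]                                    # one-hot labels, like xsamp
    return x, y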

In [25]:
# Set up Lasagne Recognition Network 
rec_is_training = theano.shared(value = 1) 
rec_nn = lasagne.layers.InputLayer((None, yDim))
rec_nn = lasagne.layers.DenseLayer(rec_nn, 100, nonlinearity=leaky_rectify, W=lasagne.init.Orthogonal())
rec_nn = lasagne.layers.DenseLayer(rec_nn, xDim, nonlinearity=softmax, W=lasagne.init.Orthogonal(), b=-5*np.ones(xDim, dtype=theano.config.floatX))
NN_Params = dict([('network', rec_nn)])
recDict = dict([('NN_Params', NN_Params)])
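
The softmax output layer makes the network a map from an observation y to a categorical posterior q(x|y). As a quick sanity check (a sketch; GMMRecognition presumably wraps a similar forward pass), the class probabilities can be read out with lasagne.layers.get_output:

Y_in = T.matrix('Y_in')
q_probs = lasagne.layers.get_output(rec_nn, inputs=Y_in)  # (N, xDim); rows sum to 1
q_fn = theano.function([Y_in], q_probs)
# q_fn(some_data.astype(theano.config.floatX)) returns the class posteriors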

In [26]:
# Now we get to try it!

# center the simulated data by subtracting the mean
ysamp_mean = ysamp.mean(axis=0, dtype=theano.config.floatX)
ytrain = ysamp - ysamp_mean

# construct a BuildModel object that represents the method
opt_params = dict({'c0': -0.0, 'v0': 1.0, 'alpha': 0.9})
model = BuildModel(opt_params, dict([]), MixtureOfGaussians, recDict, GMMRecognition, xDim, yDim, nCUnits = 100)

# Initialize generative model at the k-means solution
km = KMeans(n_clusters=xDim, n_init=10, max_iter=500)
kmpred = km.fit_predict(ytrain)

km_mu = np.zeros([xDim, yDim])
km_chol = np.zeros([xDim, yDim, yDim])
for cl in np.unique(kmpred):
    km_mu[cl] = ytrain[kmpred == cl].mean(axis=0)
    km_chol[cl] = np.linalg.cholesky(np.cov(ytrain[kmpred == cl].T))
    
model.mprior.mu.set_value(km_mu.astype(theano.config.floatX))
model.mprior.RChol.set_value(km_chol.astype(theano.config.floatX))

km_pi = np.histogram(kmpred, bins=xDim)[0] / (1.0 * kmpred.shape[0])
model.mprior.pi_un.set_value(km_pi.astype(theano.config.floatX))

# Initialize with *true* means and covariances
# model.mprior.mu.set_value(gmm.mu.get_value()-ysamp_mean)
# model.mprior.RChol.set_value(gmm.RChol.get_value())
# model.mprior.pi_un.set_value(gmm.pi_un.get_value())

print ysamp.shape


(2500, 2)
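
Since the generative model stores per-class Cholesky factors rather than full covariances, it is worth checking that the k-means initialization round-trips: each factor times its transpose should reproduce the empirical cluster covariance.

# Sanity check on the initialization above:
for cl in np.unique(kmpred):
    emp_cov = np.cov(ytrain[kmpred == cl].T)
    assert np.allclose(km_chol[cl].dot(km_chol[cl].T), emp_cov)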

In [27]:
# Fit the model
costs = model.fit(ytrain, batch_size = 10, max_epochs=5, learning_rate = 3e-4)


0.00%
(c,v,L): (-0.000000,1.000000,-381.454417)
(c,v,L): (-230.646957,151331.906250,-624.763506)
(c,v,L): (-244.593155,137709.609375,-262.128073)
...
20.00%
(c,v,L): (-31.523136,4290.281250,-32.292181)
...
40.00%
(c,v,L): (-11.789690,1431.319824,-6.438772)
...
60.00%
(c,v,L): (-3.022371,78.097969,-6.162159)
...
80.00%
(c,v,L): (-2.175656,77.590683,-0.536156)
...
(c,v,L): (-1.203324,9.155673,0.461475)
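
Each (c,v,L) line above reports NVIL's running statistics: c is the moving-average baseline of the learning signal, v its moving-average variance, and L the minibatch ELBO. Assuming the standard updates from Mnih & Gregor (2014), which the c0, v0, and alpha entries of opt_params initialize and smooth, the baseline evolves roughly as in this sketch (not the NVIL class internals):

def update_baseline(l, c, v, alpha=0.9):
    # l: per-sample learning signals on one minibatch
    c = alpha * c + (1 - alpha) * l.mean()         # running mean of the signal
    v = alpha * v + (1 - alpha) * l.var()          # running variance
    lhat = (l - c) / np.maximum(1.0, np.sqrt(v))   # centered, normalized signal
    return lhat, c, v

The centered signal lhat is what scales the REINFORCE-style gradient for the discrete latent. Note in the trace how v shrinks by several orders of magnitude as the fit improves, which is exactly the variance reduction the baseline exists to provide.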


In [28]:
# Plot ELBO (variational lower bound objective) against iteration

plt.figure()
plt.plot(costs)
plt.axis('tight')
plt.xlabel('iteration')
plt.ylabel('ELBO\n(averaged over minibatch)')


Out[28]:
<matplotlib.text.Text at 0x7f905948e3d0>

In [29]:
clr = ['b', 'r', 'c', 'g', 'm', 'y']

plt.figure()
plt.subplot(121)
plt.plot(ysamp[:,0], ysamp[:,1],'k.', alpha=.1)
plt.hold('on')
for ii in xrange(xDim):
    Rc = gmm.RChol[ii].eval()
    plot_cov_ellipse(Rc.dot(Rc.T), gmm.mu[ii].eval(), nstd=2, color=clr[ii%5], alpha=.3)
    
plt.title('True Distribution')
plt.ylabel(r'$y_1$')
plt.xlabel(r'$y_0$')

plt.subplot(122)
plt.hold('on')
plt.plot(ytrain[:,0], ytrain[:,1],'k.', alpha=.1)
for ii in xrange(xDim):
    Rc = model.mprior.RChol[ii].eval()
    plot_cov_ellipse(Rc.dot(Rc.T), model.mprior.mu[ii].eval(), nstd=2, color=clr[ii%5], alpha=.3)
    
plt.title('Learned Distribution')
plt.ylabel(r'$y_1$')
plt.xlabel(r'$y_0$')

plt.show()
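
plot_cov_ellipse comes from the local plot_cov module. Assuming it follows the usual eigendecomposition recipe for drawing an nstd-standard-deviation contour of a Gaussian, an equivalent sketch is:

from matplotlib.patches import Ellipse

def cov_ellipse(cov, pos, nstd=2, **kwargs):
    vals, vecs = np.linalg.eigh(cov)               # principal axes of the Gaussian
    order = vals.argsort()[::-1]
    vals, vecs = vals[order], vecs[:, order]
    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    width, height = 2 * nstd * np.sqrt(vals)       # axis lengths at nstd sigmas
    ell = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)
    plt.gca().add_artist(ell)
    return ell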



In [30]:
xlbl = xsamp.nonzero()[1]
#learned_lbl = model.mrec.h.eval({model.Y:ytrain}).argmax(axis=1)
#learned_lbl = model.mrec.getSample(ytrain).argmax(axis=1)
learned_lbl = model.mrec.h.argmax(axis=1).eval({model.Y:ytrain})

clr = ['b', 'r', 'c', 'g', 'm', 'y']

plt.figure()
for ii in np.random.permutation(xrange(500)):
    plt.subplot(121)
    plt.hold('on')
    plt.plot(ysamp[ii,0], ysamp[ii,1], '.', color=clr[xlbl[ii]%5])
    plt.subplot(122)
    plt.hold('on')
    plt.plot(ysamp[ii,0], ysamp[ii,1], '.', color=clr[learned_lbl[ii]%5])
    
plt.subplot(121)
plt.title('True Label')
plt.ylabel(r'$y_1$')
plt.xlabel(r'$y_0$')
plt.subplot(122)
plt.title('Inferred Label')
plt.ylabel(r'$y_1$')
plt.xlabel(r'$y_0$')
    
plt.show()
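
Because mixture components are only identified up to relabeling, a fair accuracy number matches inferred labels to true ones over all label permutations (cheap here, since xDim = 3):

from itertools import permutations
acc = max(np.mean(np.array(perm)[learned_lbl] == xlbl)
          for perm in permutations(range(xDim)))
print('best-permutation accuracy: %.3f' % acc)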



In [31]:
n = 25

x = np.linspace(-3, 3, n)
y = np.linspace(-3, 3, n)
xv, yv = np.meshgrid(x, y)
grid = np.vstack([xv.flatten(), yv.flatten()]).T

# evaluate the posterior class probabilities (as in the previous cell) so that
# argmax gives the highest-probability label rather than a stochastic sample
gridlabel = model.mrec.h.argmax(axis=1).eval({model.Y: grid.astype(theano.config.floatX)})

plt.figure()
plt.hold('on')
for ii in xrange(n*n):
    plt.plot(grid[ii,0], grid[ii,1], '.', color=clr[gridlabel[ii]%5])
plt.ylabel(r'$y_1$')
plt.xlabel(r'$y_0$')
plt.title('Highest-Probability Label Over Sampled Grid')
plt.show()
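
The per-point loop above can also be collapsed into a single call, which is much faster on finer grids (an equivalent alternative, not a change in the result):

plt.scatter(grid[:, 0], grid[:, 1],
            c=[clr[l % 5] for l in gridlabel], edgecolors='none')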