In [1]:
#!/usr/bin/env python
__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2012, Universite de Montreal"
__credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow"
"""

Usage: python show_samples <path_to_a_saved_DBM.pkl>
Displays a batch of data from the DBM's training set.
Then interactively allows the user to run Gibbs steps
starting from that seed data to see how the DBM's MCMC
sampling changes the data.

"""

from pylearn2.utils import serial
import sys
from pylearn2.config import yaml_parse
from pylearn2.gui.patch_viewer import PatchViewer
import time
from theano import function
from theano.sandbox.rng_mrg import MRG_RandomStreams
import numpy as np
from pylearn2.expr.basic import is_binary
import scipy.io.wavfile

In [2]:
rows = 10
cols = 10
m = rows * cols

model_path = 'timit_gdbm_pcd.pkl'

print 'Loading model...'
model = serial.load(model_path)
model.set_batch_size(m)


dataset_yaml_src = model.dataset_yaml_src

print 'Loading data (used for setting up visualization and seeding gibbs chain) ...'
dataset = yaml_parse.load(dataset_yaml_src)


Loading model...
Loading data (used for setting up visualization and seeding gibbs chain) ...

In [5]:
# vis_batch = dataset.get_batch_topo(m)

# _, patch_rows, patch_cols, channels = vis_batch.shape

# assert _ == m

# mapback = hasattr(dataset, 'mapback_for_viewer')

# pv = PatchViewer((rows,cols*(1+mapback)), (patch_rows,patch_cols), is_color = (channels==3))

# def show():
#     display_batch = dataset.adjust_for_viewer(vis_batch)
#     if display_batch.ndim == 2:
#         display_batch = dataset.get_topological_view(display_batch)
#     if mapback:
#         design_vis_batch = vis_batch
#         if design_vis_batch.ndim != 2:
#             design_vis_batch = dataset.get_design_matrix(design_vis_batch)
#         mapped_batch_design = dataset.mapback_for_viewer(design_vis_batch)
#         mapped_batch = dataset.get_topological_view(mapped_batch_design)
#     for i in xrange(rows):
#         row_start = cols * i
#         for j in xrange(cols):
#             pv.add_patch(display_batch[row_start+j,:,:,:], rescale = False)
#             if mapback:
#                 pv.add_patch(mapped_batch[row_start+j,:,:,:], rescale = False)
#     pv.show()


if hasattr(model.visible_layer, 'beta'):
    beta = model.visible_layer.beta.get_value()
# #model.visible_layer.beta.set_value(beta * 100.)
#     print 'beta: ',(beta.min(), beta.mean(), beta.max())

# print 'showing seed data...'
# show()

# print 'How many Gibbs steps should I run with the seed data clamped? (negative = ignore seed data) '
# x = int(input())
x=15


# Make shared variables representing the sampling state of the model
layer_to_state = model.make_layer_to_state(m)
# Seed the sampling with the data batch
vis_sample = layer_to_state[model.visible_layer]

def validate_all_samples():
    # Run some checks on the samples, this should help catch any bugs
    layers = [ model.visible_layer ] + model.hidden_layers

    def check_batch_size(l):
        if isinstance(l, (list, tuple)):
            map(check_batch_size, l)
        else:
            assert l.get_value().shape[0] == m


    for layer in layers:
        state = layer_to_state[layer]
        space = layer.get_total_state_space()
        space.validate(state)
        if 'DenseMaxPool' in str(type(layer)):
            p, h = state
            p = p.get_value()
            h = h.get_value()
            assert np.all(p == h)
            assert is_binary(p)
        if 'BinaryVisLayer' in str(type(layer)):
            v = state.get_value()
            assert is_binary(v)
        if 'Softmax' in str(type(layer)):
            y = state.get_value()
            assert is_binary(y)
            s = y.sum(axis=1)
            assert np.all(s == 1 )
        if 'Ising' in str(type(layer)):
            s = state.get_value()
            assert is_binary((s + 1.) / 2.)



validate_all_samples()

# if x >= 0:
#     if vis_sample.ndim == 4:
#         vis_sample.set_value(vis_batch)
#     else:
#         vis_sample.set_value(dataset.get_design_matrix(vis_batch))

In [6]:
validate_all_samples()

theano_rng = MRG_RandomStreams(2012+9+18)

if x > 0:
    sampling_updates = model.get_sampling_updates(layer_to_state, theano_rng,
            layer_to_clamp = { model.visible_layer : True }, num_steps = x)

    t1 = time.time()
    sample_func = function([], updates=sampling_updates)
    t2 = time.time()
    print 'Clamped sampling function compilation took',t2-t1
    sample_func()


# Now compile the full sampling update
sampling_updates = model.get_sampling_updates(layer_to_state, theano_rng)
assert layer_to_state[model.visible_layer] in sampling_updates

t1 = time.time()
sample_func = function([], updates=sampling_updates)
t2 = time.time()

print 'Sampling function compilation took',t2-t1

# while True:
#     print 'Displaying samples. How many steps to take next? (q to quit, ENTER=1)'
#     while True:
#         x = raw_input()
#         print x
#         if x == 'q':
#             quit()
#         if x == '':
#             x = 1
#             break
#         else:
#             try:
#                 x = int(x)
#                 break
#             except:
#                 print 'Invalid input, try again'


Clamped sampling function compilation took 2.19332385063
Sampling function compilation took 12.7259421349
/usr/local/Cellar/python/2.7.6/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/Theano-0.6.0-py2.7.egg/theano/sandbox/rng_mrg.py:1169: UserWarning: MRG_RandomStreams Can't determine #streams from size ((Elemwise{add,no_inplace}.0,)), guessing 60*256
  nstreams = self.n_streams(size)

In [7]:
print 'Displaying samples. How many steps to take next?'           
x = raw_input()
x = int(x)
for i in xrange(x):
    print i
    sample_func()

validate_all_samples()

vis_batch = vis_sample.get_value()
    #show()

#     if 'Softmax' in str(type(model.hidden_layers[-1])):
#         state = layer_to_state[model.hidden_layers[-1]]
#         value = state.get_value()
#         y = np.argmax(value, axis=1)
#         assert y.ndim == 1
#         for i in xrange(0, y.shape[0], cols):
#             print y[i:i+cols]


Displaying samples. How many steps to take next?
15
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14

In [9]:
scipy.io.wavfile.write("vis_sample1F.wav", 16000, vis_batch.flatten(order='F'))

In [ ]: