In [1]:
from scipy.special import expit
from rbmpy.rbm import RBM
from rbmpy.sampler import DirtyCorrectionMulDimSampler,VanillaSampler,ContinuousSampler,ContinuousApproxSampler, ContinuousApproxMulDimSampler, ApproximatedSampler, LayerWiseApproxSampler,ApproximatedMulDimSampler
from rbmpy.trainer import VanillaTrainier
from rbmpy.performance import Result
import numpy as np
import rbmpy.datasets, rbmpy.performance, rbmpy.plotter, pickle, rbmpy.rbm, os, logging, rbmpy.sampler,math
import math
from rbmpy.rbm import weights_into_hiddens
from rbmpy.progress import Progress

import rbmpy.plotter as pp
from numpy import newaxis
from collections import Counter

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

logger = logging.getLogger()
# Set the logging level to logging.DEBUG 
logger.setLevel(logging.INFO)

%matplotlib inline

case = 1 point = (3,2) pp.image(a[point]["DS_COMP"][case]) pp.image(a[point]["RBM_RECONS"][case][1].reshape(28,28)) pp.image(a[point]["ORBM_RECONS"][case][1].reshape(28,28))


In [2]:
ds_zeros = np.load("datasets/0.npy")
ds_ones = np.load('datasets/1.npy')
ds_twos = np.load('datasets/2.npy')
ds_three = np.load('datasets/3.npy')
ds_four = np.load('datasets/4.npy')
ds_five = np.load('datasets/5.npy')
ds_six = np.load('datasets/6.npy')
ds_seven = np.load('datasets/7.npy')
ds_eight = np.load('datasets/8.npy')
ds_nine = np.load('datasets/9.npy') 
# ds_bar = np.load('datasets/bar.npy')

with open("models/all_models", 'rb') as f: models = pickle.load(f)


In [23]:
pp.images(ds_four[:5],cmap=plt.cm.Greys)



In [3]:
def comp_of_size(one, two, size):
    min_size = min(one.shape[0], two.shape[0])
    shaped_one = one[:min(min_size,size)]
    shaped_two = two[:min(min_size,size)]
#     return shaped_one + shaped_two
    return np.maximum(one[:min_size],two[:min_size])[:size]

def reconstructions_from_comp(mnist_data,model_ids, ds_size,num_gibbs = 100, sampler_class = ApproximatedMulDimSampler):
    
    
    ds_comp = comp_of_size(mnist_data[model_ids[0]][1],mnist_data[model_ids[1]][1] ,size = ds_size)
    
    one_model = mnist_data[model_ids[0]][0]
    two_model = mnist_data[model_ids[1]][0]

    rand_h_a = np.random.randint(0,2,size=( one_model.num_hid()))
    rand_h_b = np.random.randint(0,2,size=( two_model.num_hid()))
    
    orbm_sampler = sampler_class(one_model.weights, two_model.weights ,one_model.hidden_bias, two_model.hidden_bias)
    d = orbm_sampler.v_to_v(rand_h_a, rand_h_b, rbmpy.datasets.squash_images(ds_comp), num_gibbs=num_gibbs)
    
    one_sampler = ContinuousSampler(one_model)
    two_sampler = ContinuousSampler(two_model)
    
    d_a = one_sampler.reconstruction_given_visible(rbmpy.datasets.squash_images(ds_comp), return_sigmoid = True)
    d_b = two_sampler.reconstruction_given_visible(rbmpy.datasets.squash_images(ds_comp), return_sigmoid = True)
    
    return {"DS_COMP": ds_comp.copy(), "RECON": d, "VAN-RECON-A":d_a, "VAN-RECON-B":d_b}

def perform_for_digits(mnist_data,model_ids, ds_size,num_gibbs = 100):
    prog_log =  Progress("EVAL{}".format(model_ids), times)
    prog_log.set_percentage_update_frequency(20)
    ds_comp = comp_of_size(mnist_data[model_ids[0]][1],mnist_data[model_ids[1]][1] ,size = ds_size)
    
    
    one_model = mnist_data[model_ids[0]][0]
    two_model = mnist_data[model_ids[1]][0]

    rand_h_a = np.random.randint(0,2,size=( one_model.num_hid()))
    rand_h_b = np.random.randint(0,2,size=( two_model.num_hid()))
    
    orbm_sampler = ContinuousApproxSampler(one_model.weights, two_model.weights ,one_model.hidden_bias, two_model.hidden_bias)
    
    squashed_data = rbmpy.datasets.squash_images(ds_comp)
    one_sampler = ContinuousSampler(one_model)
    two_sampler = ContinuousSampler(two_model)
    

    for i in range(ds_size):
        d = orbm_sampler.v_to_v(rand_h_a, rand_h_b, squashed_data[i,:], num_gibbs=num_gibbs)
        orbm_recons.append(d)
        d_a = one_sampler.reconstruction_given_visible(squashed_data[i,:], return_sigmoid = True)
        d_b = two_sampler.reconstruction_given_visible(squashed_data[i,:], return_sigmoid = True)
        van_recons.append((d_a,d_b))
    recon_times.append(orbm_recons.copy())
    van_recon_times.append(van_recons.copy())
    prog_log.set_completed_units(j)
   
    return {"model_ids":model_ids ,"DS_COMP": ds_comp.copy(), "ORBM_RECONS": recon_times, "RBM_RECONS":van_recon_times}
    
    
class MonitoredSampler(ContinuousApproxSampler):
    
    def v_to_h(self, h_a, h_b, v , num_gibbs = 100, logging_freq = None):
        """return the hidden representations for the supplied visible pattern"""
        hid_a = h_a
        hid_b = h_b

        if logging_freq:
            progess_logger = Progress(self.__class__.__name__, num_gibbs)
            progess_logger.set_percentage_update_frequency(logging_freq)

        hidden_states = []
        for epoch in range(num_gibbs):
            # get the bentness of the coin used for the bernoulli trial
            psi_a, psi_b = self.p_hid(hid_a, hid_b, self.w_a, self.w_b, v)
            hid_a = self.__bernoulli_trial__(psi_a)
            hid_b = self.__bernoulli_trial__(psi_b)
            hidden_states.append((hid_a, hid_b))
            if logging_freq:
                progess_logger.set_completed_units(epoch)
        return hid_a, hid_b, hidden_states
    
def visible_states_for_hiddens(hidden_states, sampler):
    visible_states = []
    for state in hidden_states:
        v_a, v_b = sampler.h_to_v(state[0], state[1])
        visible_states.append((v_a, v_b))
    return visible_states

def collect_gibbs_chain_data(mnist_data,model_ids,num_gibbs = 100):

    ds_comp = comp_of_size(mnist_data[model_ids[0]][1],mnist_data[model_ids[1]][1], 2)
    one_model = mnist_data[model_ids[0]][0]
    two_model = mnist_data[model_ids[1]][0]

    # random hidden start
    rand_h_a = np.random.randint(0,2,size=( one_model.num_hid()))
    rand_h_b = np.random.randint(0,2,size=( two_model.num_hid()))

    #make our ORBM Sampler
    orbm_sampler = MonitoredSampler(one_model.weights, two_model.weights ,one_model.hidden_bias, two_model.hidden_bias)

    squashed_data = rbmpy.datasets.squash_images(ds_comp)

    ds_size = ds_comp.shape[0]

    prog_log =  Progress("EVAL{}".format(model_ids), ds_size)
    prog_log.set_percentage_update_frequency(50)

    hidden_states_over_dataset = []
    visible_states_over_dataset = []

    for i in range(ds_size):
        h_a,h_b, hidden_states = orbm_sampler.v_to_h(rand_h_a,rand_h_b, squashed_data[i,:],num_gibbs = num_gibbs,logging_freq = 10)        
        visible_states = visible_states_for_hiddens(hidden_states, orbm_sampler)
        hidden_states_over_dataset.append(hidden_states.copy())
        visible_states_over_dataset.append(visible_states)
        prog_log.set_completed_units(i)

    return (hidden_states_over_dataset, visible_states_over_dataset)
    
def plot_recons(recons):
    n_items = recons.shape[0]
    size = round(math.sqrt(recons.shape[1]))
    pp.images(recons.reshape(n_items,size,size), color_range=(0,1))
    
def plot_recons_from_results(full_dict, plot_range):
    recons = full_dict["RECON"]
    plot_recons(recons[0][plot_range[0]:plot_range[1]])
    plot_recons(recons[1][plot_range[0]:plot_range[1]])
    
def plot_recon_from_result(full_dict):
    pp.image(full_dict["DS_COMP"].reshape(28,28))
    recons = full_dict["RECON"]
    plt.suptitle("ORBM-A")
    pp.image(recons[0].reshape(28,28))
    plt.suptitle("ORBM-B")
    pp.image(recons[1].reshape(28,28))
    plt.suptitle("RBM-A")
    pp.image(full_dict["VAN-RECON-A"].reshape(28,28))
    plt.suptitle("RBM-B")
    pp.image(full_dict["VAN-RECON-B"].reshape(28,28))
    
def get_data(size, path):
    mnist_data = dict()
    for i in range(0,10):
        ds = np.load("datasets/{}.npy".format(i))[:size]
        try:
            with open("models/{}{}_models".format(path,i), 'rb') as f:
                model = pickle.load(f)
            mnist_data[i] = (model, ds)
        except FileNotFoundError:
            logging.warn("There is no model for {}".format(i))
    return mnist_data

In [4]:
# mnist_data = get_data(500,'attempt_two/')
# mnist_data["bar"] = (np.load("models/bar_models"), ds_bar)
with open("mnist_data", 'rb') as f:
    mnist_data = pickle.load(f)

with open("mnist_data", "wb") as f: pickle.dump(mnist_data,f)


In [13]:
plot_recon_from_result(reconstructions_from_comp(mnist_data,(3,5), 1 , num_gibbs=100, sampler_class= ContinuousApproxSampler))


I should be called

In [14]:
pp.images(mnist_data[2][1][:5])



In [5]:
plt.imshow(mnist_data[2][1][1].reshape(28,28), interpolation='nearest', cmap=plt.cm.Greys)
# plt.imshow(comp_of_size(mnist_data[2][1],mnist_data[3][1] ,2)[1].reshape(28,28), interpolation='nearest', cmap=plt.cm.Greys)
plt.savefig("test-1")
plt.show()



In [6]:
pp.images(mnist_data[1][0].weights.reshape(100,28,28)[:5])



In [17]:
# compose the current digit with all other digits
a =reconstructions_from_comp(mnist_data,(2,"bar"), 1 , num_gibbs=200, sampler_class= ContinuousApproxSampler)


---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-17-c5019a0ffe6c> in <module>()
      1 # compose the current digit with all other digits
----> 2 a =reconstructions_from_comp(mnist_data,(2,"bar"), 1 , num_gibbs=200, sampler_class= ContinuousApproxSampler)

<ipython-input-11-f918a33d2c99> in reconstructions_from_comp(mnist_data, model_ids, ds_size, num_gibbs, sampler_class)
      9 
     10 
---> 11     ds_comp = comp_of_size(mnist_data[model_ids[0]][1],mnist_data[model_ids[1]][1] ,size = ds_size)
     12 
     13     one_model = mnist_data[model_ids[0]][0]

KeyError: 'bar'

In [ ]:
a.keys()
pp.image(a["DS_COMP"].reshape(28,28))
pp.image(a["RECON"][1].reshape(28,28))
pp.image(a["RECON"][0].reshape(28,28))
pp.image(a["VAN-RECON-A"].reshape(28,28))
pp.image(a["VAN-RECON-B"].reshape(28,28))

results = {} for digit in range(1,9): if digit != 7 and digit != 9: for other_digit in range(1,9): if other_digit != 7 and other_digit != 9: results[(digit, other_digit)] = perform_for_digits(mnist_data,(digit,other_digit), 100 , num_gibbs=100, times=20) other_digit += 1 digit += 1

with open("results.pickilo","wb") as f: pickle.dump(results, f)


In [7]:
num_gibbs = 150
results = {}
# for digit in range(0,9):
#         for other_digit in range(0,9):
#                 results[(digit, other_digit)] = collect_gibbs_chain_data(mnist_data,(digit,other_digit), num_gibbs=num_gibbs)
#                 other_digit += 1
#         print("finished digit")
#         digit += 1
results[(2, 3)] = collect_gibbs_chain_data(mnist_data,(2,3), num_gibbs=500)


INFO:EVAL(2, 3):Created Progress logger for task - EVAL(2, 3)
INFO:MonitoredSampler:Created Progress logger for task - MonitoredSampler
INFO:MonitoredSampler:0.0% complete
INFO:MonitoredSampler:10.0% complete
INFO:MonitoredSampler:20.0% complete
INFO:MonitoredSampler:30.0% complete
INFO:MonitoredSampler:40.0% complete
INFO:MonitoredSampler:50.0% complete
INFO:MonitoredSampler:60.0% complete
INFO:MonitoredSampler:70.0% complete
INFO:MonitoredSampler:80.0% complete
INFO:MonitoredSampler:90.0% complete
INFO:EVAL(2, 3):0.0% complete
INFO:MonitoredSampler:Created Progress logger for task - MonitoredSampler
INFO:MonitoredSampler:0.0% complete
INFO:MonitoredSampler:10.0% complete
INFO:MonitoredSampler:20.0% complete
INFO:MonitoredSampler:30.0% complete
INFO:MonitoredSampler:40.0% complete
INFO:MonitoredSampler:50.0% complete
INFO:MonitoredSampler:60.0% complete
INFO:MonitoredSampler:70.0% complete
INFO:MonitoredSampler:80.0% complete
INFO:MonitoredSampler:90.0% complete
INFO:EVAL(2, 3):50.0% complete

In [8]:
len(results[2,3][0][0])


Out[8]:
500

In [9]:
hidden_over_d = results[(2,3)][0]
visible_over_d = results[(2,3)][1]
gibbs = 2
pp.image(visible_over_d[0][gibbs][0].reshape(28,28))
pp.image(visible_over_d[0][gibbs][1].reshape(28,28))
data_item = 0
images_model_one = []
images_model_two = []
for gibbs in visible_over_d[data_item]:
    images_model_one.append(gibbs[0].reshape(28,28))
    images_model_two.append(gibbs[1].reshape(28,28))



In [42]:
np.array(visible_over_d[0])[0].shape


Out[42]:
(2, 784)

In [10]:
for key in results:
    digit_combo_stats = results[key]
    for stats in digit_combo_stats[1][0]:
        reconstruction = stats[0].reshape(28,28) # GRAB FIRST OF TUP;E
        pp.image(reconstruction)



In [57]:
import matplotlib.animation as animation

In [13]:
# images = results[(4,2)][1]
fig,((ax, ax2),( ax3, ax4)) = plt.subplots(2,2,sharex='col', sharey='row')
all_axes = [ax, ax2, ax3, ax4]

fig.suptitle("Gibbs Iter:1", fontsize=10)
im = ax.imshow(images_model_one[0], cmap=plt.get_cmap('Greys'))
im2 = ax2.imshow(images_model_two[0], cmap=plt.get_cmap('Greys'))
im3 = ax3.imshow(ds_twos[0],cmap=plt.cm.Greys)
im4 = ax4.imshow(ds_three[0],cmap=plt.cm.Greys)

for axes in all_axes:
    axes.get_xaxis().set_ticks([])
    axes.get_yaxis().set_ticks([])

# ax.imshow
ax.set_title("2 Model Reconstruction",loc="left")
ax2.set_title("3 Model Reconstruction", loc= 'right')
ax3.set_title("Goals: 2 Ground Truth", loc = "left")
ax4.set_title("3 Ground Truth", loc = 'right')


plt.tight_layout()
def updatefig(i):
    fig.suptitle("Gibbs Iter:{}".format(i+1), fontsize= 10)
    im.set_array(images_model_one[i])
    im2.set_array(images_model_two[i])
    return [im]

ani = animation.FuncAnimation(fig, updatefig,frames=range(len(images_model_one)), interval=20, blit=True)
# plt.show()

# Set up formatting for the movie files
Writer = animation.writers['ffmpeg']
writer = Writer(fps=15, metadata=dict(artist='Max Godfrey'), bitrate=1800)
ani.save('11test11.mp4', writer=writer)



In [38]:
len(images)


Out[38]:
100

In [142]:
pp.image(images_model_one[0])



In [168]:
ds_twos[0]


Out[168]:
array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.05098039,  0.09803922,  0.39215686,  0.47843137,  0.02745098,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.12941176,  0.59215686,
         0.81568627,  0.98823529,  0.98823529,  0.98823529,  0.57254902,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.15686275,  0.59607843,  0.95686275,  0.98823529,
         0.99215686,  0.87843137,  0.82745098,  0.98823529,  0.90980392,
         0.15686275,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.05882353,
         0.59607843,  0.9372549 ,  0.98823529,  0.98823529,  0.98823529,
         0.84705882,  0.12156863,  0.14509804,  0.98823529,  0.98823529,
         0.23529412,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.37647059,
         0.98823529,  0.98823529,  0.98823529,  0.98823529,  0.85098039,
         0.11372549,  0.        ,  0.14509804,  0.98823529,  0.98823529,
         0.23529412,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.70980392,
         0.98823529,  0.98823529,  0.8627451 ,  0.65490196,  0.11764706,
         0.        ,  0.        ,  0.30196078,  0.98823529,  0.98823529,
         0.23529412,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.10196078,
         0.50196078,  0.22745098,  0.08627451,  0.        ,  0.        ,
         0.        ,  0.        ,  0.39215686,  0.98823529,  0.98823529,
         0.23529412,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.61568627,  0.98823529,  0.98823529,
         0.23529412,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.43137255,  0.4745098 ,
         0.47843137,  0.4745098 ,  0.79215686,  0.98823529,  0.76078431,
         0.01176471,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.03921569,  0.20784314,  0.70196078,  0.99215686,  0.99215686,
         1.        ,  0.99215686,  0.99215686,  0.89411765,  0.1372549 ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.01960784,  0.21176471,
         0.89019608,  0.98823529,  0.95294118,  0.89411765,  0.66666667,
         0.94901961,  0.98823529,  0.98823529,  0.90588235,  0.45882353,
         0.02352941,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.02352941,  0.30588235,  0.98823529,
         0.98823529,  0.49019608,  0.23137255,  0.        ,  0.07058824,
         0.81568627,  0.98823529,  0.98823529,  0.98823529,  0.98823529,
         0.34117647,  0.02745098,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.01960784,  0.52941176,  0.98823529,  0.98823529,
         0.70588235,  0.0627451 ,  0.        ,  0.08235294,  0.79607843,
         0.99215686,  0.96862745,  0.50588235,  0.67843137,  0.98823529,
         0.98823529,  0.72156863,  0.25882353,  0.19215686,  0.19215686,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.01176471,  0.53333333,  0.98823529,  0.94509804,  0.41568627,
         0.06666667,  0.        ,  0.20784314,  0.78431373,  0.98823529,
         0.84705882,  0.25490196,  0.        ,  0.05490196,  0.28235294,
         0.63921569,  0.94509804,  0.98823529,  0.98823529,  0.8745098 ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.41176471,  0.98823529,  0.94901961,  0.34509804,  0.07058824,
         0.28627451,  0.66666667,  0.95686275,  0.98823529,  0.49411765,
         0.11372549,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.34901961,  0.70588235,  0.70588235,  0.14509804,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.90588235,  0.98823529,  0.96078431,  0.80392157,  0.84705882,
         0.98823529,  0.98823529,  0.98823529,  0.48627451,  0.01176471,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.81176471,  0.98823529,  0.98823529,  0.98823529,  0.98823529,
         0.69803922,  0.45490196,  0.14117647,  0.01568627,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.05098039,  0.36470588,  0.56078431,  0.4745098 ,  0.09019608,
         0.02352941,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ]])

In [199]:
help(plt.subplots)


Help on function subplots in module matplotlib.pyplot:

subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **fig_kw)
    Create a figure with a set of subplots already made.
    
    This utility wrapper makes it convenient to create common layouts of
    subplots, including the enclosing figure object, in a single call.
    
    Keyword arguments:
    
      *nrows* : int
        Number of rows of the subplot grid.  Defaults to 1.
    
      *ncols* : int
        Number of columns of the subplot grid.  Defaults to 1.
    
      *sharex* : string or bool
        If *True*, the X axis will be shared amongst all subplots.  If
        *True* and you have multiple rows, the x tick labels on all but
        the last row of plots will have visible set to *False*
        If a string must be one of "row", "col", "all", or "none".
        "all" has the same effect as *True*, "none" has the same effect
        as *False*.
        If "row", each subplot row will share a X axis.
        If "col", each subplot column will share a X axis and the x tick
        labels on all but the last row will have visible set to *False*.
    
      *sharey* : string or bool
        If *True*, the Y axis will be shared amongst all subplots. If
        *True* and you have multiple columns, the y tick labels on all but
        the first column of plots will have visible set to *False*
        If a string must be one of "row", "col", "all", or "none".
        "all" has the same effect as *True*, "none" has the same effect
        as *False*.
        If "row", each subplot row will share a Y axis and the y tick
        labels on all but the first column will have visible set to *False*.
        If "col", each subplot column will share a Y axis.
    
      *squeeze* : bool
        If *True*, extra dimensions are squeezed out from the
        returned axis object:
    
        - if only one subplot is constructed (nrows=ncols=1), the
          resulting single Axis object is returned as a scalar.
    
        - for Nx1 or 1xN subplots, the returned object is a 1-d numpy
          object array of Axis objects are returned as numpy 1-d
          arrays.
    
        - for NxM subplots with N>1 and M>1 are returned as a 2d
          array.
    
        If *False*, no squeezing at all is done: the returned axis
        object is always a 2-d array containing Axis instances, even if it
        ends up being 1x1.
    
      *subplot_kw* : dict
        Dict with keywords passed to the
        :meth:`~matplotlib.figure.Figure.add_subplot` call used to
        create each subplots.
    
      *gridspec_kw* : dict
        Dict with keywords passed to the
        :class:`~matplotlib.gridspec.GridSpec` constructor used to create
        the grid the subplots are placed on.
    
      *fig_kw* : dict
        Dict with keywords passed to the :func:`figure` call.  Note that all
        keywords not recognized above will be automatically included here.
    
    Returns:
    
    fig, ax : tuple
    
      - *fig* is the :class:`matplotlib.figure.Figure` object
    
      - *ax* can be either a single axis object or an array of axis
        objects if more than one subplot was created.  The dimensions
        of the resulting array can be controlled with the squeeze
        keyword, see above.
    
    Examples::
    
        x = np.linspace(0, 2*np.pi, 400)
        y = np.sin(x**2)
    
        # Just a figure and one subplot
        f, ax = plt.subplots()
        ax.plot(x, y)
        ax.set_title('Simple plot')
    
        # Two subplots, unpack the output array immediately
        f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
        ax1.plot(x, y)
        ax1.set_title('Sharing Y axis')
        ax2.scatter(x, y)
    
        # Four polar axes
        plt.subplots(2, 2, subplot_kw=dict(polar=True))
    
        # Share a X axis with each column of subplots
        plt.subplots(2, 2, sharex='col')
    
        # Share a Y axis with each row of subplots
        plt.subplots(2, 2, sharey='row')
    
        # Share a X and Y axis with all subplots
        plt.subplots(2, 2, sharex='all', sharey='all')
        # same as
        plt.subplots(2, 2, sharex=True, sharey=True)


In [207]:
fig,yerp = plt.subplots(2,3,sharex='col', sharey='row', squeeze=True)



In [11]:
def no_ticks(x):
    x.get_xaxis().set_ticks([])
    x.get_yaxis().set_ticks([])
    return x

In [12]:
def make_me_a_movie_star(item, visible_over_d, hidden_over_d, filename):
    images_model_one,images_model_two = image_data_for_item(visible_over_d, item)
    h_a, h_b = image_data_for_hidden(hidden_over_d, item) 
    
    fig = plt.figure()
    ax1 = plt.subplot2grid((3,3), (0,1))
    ax2 = plt.subplot2grid((3,3), (0,0))
    ax3 = plt.subplot2grid((3,3), (0, 2))
    ax4 = plt.subplot2grid((3,3), (1, 0))
    ax5 = plt.subplot2grid((3,3), (1, 2))
    ax6 = plt.subplot2grid((3,3), (2, 0))
    ax7 = plt.subplot2grid((3,3), (2, 2))
    list(map(no_ticks, fig.get_axes()))
    fig.suptitle("Gibbs Iter:1", fontsize=10)
    ax1.set_title("Composite Input")
    ax2.set_title("Ground Truth 2")
    ax3.set_title("Ground Truth 3")
    ax4.set_title("ORBM Recon 2")
    ax5.set_title("ORBM Recon 3")
    
    ax6.set_title("Hidden 2")
    ax7.set_title("Hidden 3")

    ax1.imshow(comp_of_size(ds_twos,ds_three,500)[item].reshape(28,28), cmap=plt.cm.Greys)
    im1 = ax4.imshow(images_model_one[0], cmap=plt.get_cmap('Greys'), interpolation = 'nearest')
    im2 = ax5.imshow(images_model_two[0], cmap=plt.get_cmap('Greys'),  interpolation = 'nearest')
    im3 = ax6.imshow(h_a[0], cmap = plt.cm.Greys, interpolation = 'nearest')
    im4 = ax7.imshow(h_b[0], cmap = plt.cm.Greys, interpolation = 'nearest')
    
    ax2.imshow(ds_twos[item],cmap=plt.cm.Greys,  interpolation = 'nearest')
    ax3.imshow(ds_three[item],cmap=plt.cm.Greys,  interpolation = 'nearest')


    def updatefig(i):
        fig.suptitle("Gibbs Iter:{}".format(i+1), fontsize= 10)
        im1.set_array(images_model_one[i])
        im2.set_array(images_model_two[i])
        im3.set_array(h_a[i])
        im4.set_array(h_b[i])
        return [im1]

    ani = animation.FuncAnimation(fig, updatefig,frames=range(len(images_model_one)), interval=20, blit=True)
    # plt.show()

    # Set up formatting for the movie files
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Max Godfrey'), bitrate=1800)
    ani.save(filename, writer=writer)

In [13]:
def image_data_for_item(visible_over_d,item):
    images_model_one = []
    images_model_two = []
    for gibbs in visible_over_d[item]:
        images_model_one.append(gibbs[0].reshape(28,28))
        images_model_two.append(gibbs[1].reshape(28,28))
    return images_model_one, images_model_two

def image_data_for_hidden(hidden_over_d, item):
    hid_a = []
    hid_b = []
    for gibbs in hidden_over_d[item]:
        hid_a.append(gibbs[0].reshape(10,10))
        hid_b.append(gibbs[1].reshape(10,10))
    return hid_a, hid_b

In [14]:
hidden_over_d = results[(2,3)][0]
visible_over_d = results[(2,3)][1]

In [15]:
images_model_one,images_model_two = image_data_for_item(visible_over_d, 2)
h_a, h_b = image_data_for_hidden(hidden_over_d, 0)


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-15-8c29d85d38bd> in <module>()
----> 1 images_model_one,images_model_two = image_data_for_item(visible_over_d, 2)
      2 h_a, h_b = image_data_for_hidden(hidden_over_d, 0)

<ipython-input-13-555e1d8bef23> in image_data_for_item(visible_over_d, item)
      2     images_model_one = []
      3     images_model_two = []
----> 4     for gibbs in visible_over_d[item]:
      5         images_model_one.append(gibbs[0].reshape(28,28))
      6         images_model_two.append(gibbs[1].reshape(28,28))

IndexError: list index out of range

In [18]:
data_item = 1
gibbs_steps_of_interest = [0,2,5,10,50,100,499]
for gibbs in range(1,2):
    for step in gibbs_steps_of_interest:
        a =np.array(image_data_for_item(visible_over_d, gibbs)[data_item])[step]
        plt.imshow(a, cmap="Greys",interpolation="nearest")
        plt.colorbar()
#         plt.savefig("test{}".format(step))

        plt.show()



In [19]:
hidden_over_d = results[(2,3)][0]
visible_over_d = results[(2,3)][1]



make_me_a_movie_star(1, visible_over_d,hidden_over_d,"mixing1.mp4")
# make_me_a_movie_star(2, visible_over_d,hidden_over_d,"mixing2.mp4")
# make_me_a_movie_star(3, visible_over_d,hidden_over_d,"mixing3.mp4")
# make_me_a_movie_star(4, visible_over_d,hidden_over_d,"mixing4.mp4")
# make_me_a_movie_star(5, visible_over_d,hidden_over_d,"mixing5.mp4")


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-19-45b73b2e0fd3> in <module>()
      4 
      5 
----> 6 make_me_a_movie_star(1, visible_over_d,hidden_over_d,"mixing1.mp4")
      7 # make_me_a_movie_star(2, visible_over_d,hidden_over_d,"mixing2.mp4")
      8 # make_me_a_movie_star(3, visible_over_d,hidden_over_d,"mixing3.mp4")

<ipython-input-12-abd550fb44d9> in make_me_a_movie_star(item, visible_over_d, hidden_over_d, filename)
     40         return [im1]
     41 
---> 42     ani = animation.FuncAnimation(fig, updatefig,frames=range(len(images_model_one)), interval=20, blit=True)
     43     # plt.show()
     44 

NameError: name 'animation' is not defined

In [60]:
!open .

In [27]:
weights = mnist_data[2][0].weights.copy()

In [28]:
pp.image(weights.reshape(100, 28, 28)[0])



In [30]:
for i in range(5):
    pp.hinton(weights.reshape(100, 28, 28)[i])
    plt.show()



In [ ]: