In [1]:
from scipy.special import expit
from rbmpy.rbm import RBM
from rbmpy.sampler import DirtyCorrectionMulDimSampler,VanillaSampler,ContinuousSampler,ContinuousApproxSampler, ContinuousApproxMulDimSampler, ApproximatedSampler, LayerWiseApproxSampler,ApproximatedMulDimSampler
from rbmpy.trainer import VanillaTrainier
from rbmpy.performance import Result
import numpy as np
import rbmpy.datasets, rbmpy.performance, rbmpy.plotter, pickle, rbmpy.rbm, os, logging, rbmpy.sampler,math
import math
from rbmpy.rbm import weights_into_hiddens
from rbmpy.progress import Progress
from scipy.spatial.distance import cosine


import rbmpy.plotter as pp
from numpy import newaxis
from collections import Counter

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

logger = logging.getLogger()
# Set the logging level to logging.DEBUG 
logger.setLevel(logging.INFO)

%matplotlib inline

In [2]:
from IPython.core.debugger import Tracer

In [3]:
#load the data
import glob
result_file_name = glob.glob("Results/approx_correction/job-2618368--task-1")


raw_data = []
for result in result_file_name:
    with open(result,"rb") as f:
        raw_data.append(pickle.load(f))

other_file_name = glob.glob("Results/2613190/*") other_raw_data = [] for result in other_file_name: with open(result,"rb") as f: other_raw_data.append(pickle.load(f))

raw_data[0].keys()

other_raw_data[0].keys()

for i in range(len(other_raw_data)): raw_data[i].update(other_raw_data[i])


In [4]:
def get_data(size):
    mnist_data = dict()
    for i in range(0,10):
        ds = np.load("datasets/{}.npy".format(i))[:size]
        try:
            with open("models/{}_models".format(i), 'rb') as f:
                model = pickle.load(f)
            mnist_data[i] = (model, ds)
        except FileNotFoundError:
            logging.warn("There is no model for {}".format(i))
    return mnist_data

Structure of data

{(n,m) => {'ORBM_RECONS':..,'RBM_RECONS':..., 'model_ids':(n,m),'DS_COMP':... }}


In [28]:
class Result(object):
    
    def __init__(self):
        self.orbm_a_recons = {}
        self.orbm_b_recons = {}
        self.rbm_a_recons  = {}
        self.rbm_b_recons  = {}
        
        self.orbm_a_scores = {}
        self.orbm_b_scores = {}
        self.rbm_a_scores = {}
        self.rbm_b_scores = {}
        
    def safe_add(self, key, value, score_dict):
        if key not in score_dict:
            score_dict[key] = []
        score_dict[key].append(value)
    
#   orbm_a_recons, orbm_b_recons,rbm_a_recons, rbm_b_recons ,orbm_a_score, orbm_b_score, rbm_a_score, rbm_b_score
    def add_scores(self,model_ids, orbm_a_recons, orbm_b_recons, rbm_a_recons, rbm_b_recons,orbm_a, orbm_b, rbm_a,rbm_b):
        self.safe_add(model_ids, orbm_a_recons, self.orbm_a_recons)
        self.safe_add(model_ids, orbm_b_recons, self.orbm_b_recons)
        self.safe_add(model_ids, rbm_a_recons, self.rbm_a_recons)
        self.safe_add(model_ids, rbm_b_recons, self.rbm_b_recons)
        
        self.safe_add(model_ids, orbm_a, self.orbm_a_scores)
        self.safe_add(model_ids, orbm_b, self.orbm_b_scores)
    
        self.safe_add(model_ids, rbm_a, self.rbm_a_scores)
        self.safe_add(model_ids, rbm_b, self.rbm_b_scores)
    
    def hightest_scores(key, n):
        """Grab the n highest values from the key thing"""
        pass
    
    
    def ndarrayify(self, score_dict):
        nd_score_dict = {}
        for key in score_dict:
            nd_score_dict[key] = np.array(score_dict[key])
        return nd_score_dict

    def mean_scores(self):
        or_a = self.ndarrayify(self.orbm_a_scores)
        or_b = self.ndarrayify(self.orbm_b_scores)
        r_a  = self.ndarrayify(self.rbm_a_scores)
        r_b  = self.ndarrayify(self.rbm_b_scores)

        or_a = {k: v.mean(axis = 0) for k, v in or_a.items()}
        or_b = {k: v.mean(axis = 0) for k, v in or_b.items()}
        r_a  = {k: v.mean(axis = 0) for k, v in r_a.items()}
        r_b  = {k: v.mean(axis = 0) for k, v in r_b.items()}
        
        return or_a, or_b, r_a, r_b
    
    def matrix_for_scores(self,scores):
        x_idx, y_idx = list(map(list, zip(*scores.keys())))

        matrix = np.ones(shape=(max(x_idx)+1, max(y_idx)+1))
        for key in scores:
            matrix[key[0],key[1]] = scores[key].sum()

        return matrix
    
    def mean_score_matrices(self):
        or_a, or_b, r_a, r_b = self.mean_scores()
        
        return self.matrix_for_scores(or_a), self.matrix_for_scores(or_b), self.matrix_for_scores(r_a), self.matrix_for_scores(r_b)
    
def ll_image_wise_score(v, v_prime):

#     return -(abs(v - v_prime)).sum(1)
  
    return (v * np.log(v_prime) + ((1 - v) * np.log((1 - v_prime)))).sum(1)

def cosine_score(v, v_prime):
    thing = []
    for i in range(v.shape[0]):
        thing.append(1 - cosine(v[i,:],v_prime[i,:]))

        
    return np.array(thing)
    
def unzip_and_ndarrays(list_of_2d_tuple):
    a = list(map(list, zip(*list_of_2d_tuple)))
    return np.array(a[0]), np.array(a[1])
       
def scores_for_recons(target_recons,orbm_recons, rbm_recons):
    
    target_a_recons, target_b_recons = target_recons
    
    orbm_a_recons, orbm_b_recons = unzip_and_ndarrays(orbm_recons)
    rbm_a_recons, rbm_b_recons = unzip_and_ndarrays(rbm_recons)
    # Calculate the scores
    orbm_a_score = ll_image_wise_score(target_a_recons, orbm_a_recons)
    orbm_b_score = ll_image_wise_score(target_b_recons, orbm_b_recons)
    rbm_a_score = ll_image_wise_score(target_a_recons, rbm_a_recons)
    rbm_b_score = ll_image_wise_score(target_b_recons, rbm_b_recons)
    return orbm_a_recons, orbm_b_recons,rbm_a_recons, rbm_b_recons ,orbm_a_score, orbm_b_score, rbm_a_score, rbm_b_score


In [29]:
mnist_data = get_data(None)

In [30]:
def ds_from_mnist_ids(key, size):
    a_data = mnist_data[key[0]][1][:size]
    b_data = mnist_data[key[1]][1][:size]
    return rbmpy.datasets.flatten_data_set(a_data),rbmpy.datasets.flatten_data_set(b_data)

In [31]:
# # def scores_for_recons(target_recons,obrm_recons, rbm_recons, flattened_ds_comp)
# temp_orbm = raw_data[0][(1,1)]["ORBM_RECONS"] 
# temp_rbm = raw_data[0][(1,1)]["RBM_RECONS"]
# temp_ds_comp = raw_data[0][(1,1)]["DS_COMP"].reshape(temp_ds_comp.shape[0], temp_ds_comp.shape[1] * temp_ds_comp.shape[2])
# scores_for_recons((temp_ds_comp,temp_ds_comp),temp_orbm,temp_rbm, temp_ds_comp)[0].shape
# flattened_ds_comp = ds_comp.reshape(ds_comp.shape[0], ds_comp.shape[1] * ds_comp.shape[2])

In [32]:
result = Result()
for run in raw_data:
    for key in run:
# key = (2,1)
        current_comp = run[key]
        orbm_recons = current_comp["ORBM_RECONS"]
        rbm_recons = current_comp["RBM_RECONS"]

        size = current_comp["DS_COMP"].shape[0]
        target_recons = ds_from_mnist_ids(key, size)
        result.add_scores(key, *scores_for_recons(target_recons, orbm_recons,rbm_recons))

In [33]:
result.orbm_a_recons[(7,3)][0].shape


Out[33]:
(500, 784)

In [34]:
a = result.mean_scores()
o_a, o_b, r_a, r_b = result.mean_score_matrices()

In [35]:
def plot_matrix(m, ticks):
    plt.imshow(m, interpolation='nearest',cmap=plt.cm.RdYlBu)
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.colorbar()
    plt.show()

In [36]:
raw_data[0].keys()


Out[36]:
dict_keys([(7, 3), (6, 9), (1, 3), (4, 8), (3, 0), (2, 8), (9, 8), (8, 0), (0, 7), (6, 2), (1, 6), (3, 7), (2, 5), (8, 5), (5, 8), (4, 0), (9, 0), (6, 7), (5, 5), (7, 6), (5, 0), (0, 4), (7, 9), (1, 1), (3, 2), (2, 6), (8, 2), (4, 5), (9, 3), (6, 0), (1, 4), (7, 5), (2, 3), (1, 9), (8, 7), (4, 2), (9, 6), (6, 5), (5, 3), (0, 1), (7, 0), (6, 8), (3, 1), (9, 9), (0, 6), (1, 7), (0, 9), (3, 9), (7, 8), (2, 4), (8, 4), (5, 9), (4, 7), (9, 1), (6, 6), (5, 6), (7, 7), (2, 1), (8, 9), (9, 4), (5, 1), (0, 3), (7, 2), (1, 2), (7, 4), (4, 9), (3, 3), (2, 9), (8, 1), (4, 4), (6, 3), (1, 5), (3, 6), (2, 2), (8, 6), (4, 1), (9, 7), (6, 4), (5, 4), (0, 0), (7, 1), (0, 5), (1, 0), (0, 8), (3, 5), (2, 7), (8, 3), (4, 6), (9, 2), (3, 4), (6, 1), (5, 7), (3, 8), (2, 0), (1, 8), (8, 8), (4, 3), (9, 5), (5, 2), (0, 2)])

In [37]:
ticks = np.arange(10)
plot_matrix(o_a, ticks)
plot_matrix(o_b, ticks)


49 - np.where(np.isclose((o_b - r_b).T, (o_a - r_a)), 1,0).sum()


In [62]:



Out[62]:
(1, 500)

In [85]:
comp = (5,4)
or_score_of_interest = np.array(result.orbm_a_scores[comp])
or_recon_of_interest = np.array(result.orbm_a_recons[comp])
or_b_recon = np.array(result.orbm_b_recons[comp])
r_score_of_interest = np.array(result.rbm_a_scores[comp])
r_recon_of_interest = np.array(result.rbm_a_recons[comp])
r_b_recon = np.array(result.rbm_b_recons[comp])
or_max_idx = np.unravel_index(or_score_of_interest.argmax(), or_score_of_interest.shape)
# r_max_idx  = np.unravel_index(r_score_of_interest.argmax(), r_score_of_interest.shape)

# pp.image(or_recon_of_interest[or_max_idx[0], or_max_idx[1], :].reshape(28,28))
# pp.image(r_recon_of_interest[r_max_idx[0], or_max_idx[1], :].reshape(28,28))
# # pp.image(mnist_data[comp[0]][1][or_max_idx[1]])
# target = mnist_data[comp[0]][1][or_max_idx[1]]
# van_sampler = ContinuousSampler(mnist_data[comp[0]][0])
# pp.image(van_sampler.reconstruction_given_visible(target.reshape(28 * 28)).reshape(28, 28))

In [86]:
# or_score_of_interest[run].argsort()
or_score_of_interest[run][np.flipud(or_score_of_interest[run].argsort())]


Out[86]:
array([ -92.06121692, -104.41839821, -107.1560171 , -113.08683157,
       -121.27138166, -125.74664025, -130.68733325, -130.95171103,
       -132.61241714, -136.89081975, -137.31707639, -137.43197886,
       -139.94137854, -140.15873082, -141.24788151, -142.55452309,
       -144.75692634, -144.85080966, -145.86619368, -146.14141773,
       -146.14937857, -146.28593188, -147.04736269, -147.28282722,
       -147.45046988, -148.25129127, -149.52362399, -150.8597346 ,
       -150.94250971, -151.8429689 , -151.90065566, -152.2048932 ,
       -152.56945299, -155.1729493 , -155.63189115, -155.66342393,
       -155.79986061, -156.50229558, -156.89020022, -157.80871214,
       -158.11372586, -160.10565571, -160.49436559, -160.77454743,
       -161.15042437, -161.23640251, -162.535614  , -163.36967694,
       -164.07104581, -164.1088925 , -164.29692725, -164.66412995,
       -165.1763228 , -165.78631215, -166.4639064 , -167.74432935,
       -168.28757741, -170.10802493, -170.20565709, -170.3993574 ,
       -170.97920883, -171.67427393, -171.99094571, -172.26273787,
       -172.39337991, -172.90153792, -173.24255971, -173.5618571 ,
       -174.21060946, -174.74276723, -175.36993551, -175.44242244,
       -176.63453244, -176.70033303, -176.98420994, -177.26852388,
       -177.31982349, -177.55365051, -177.57955038, -177.60361519,
       -178.15383594, -178.45073218, -178.49000039, -178.70482078,
       -179.17278901, -179.25279218, -180.09064124, -180.36503849,
       -180.37521881, -180.60822915, -180.66073952, -181.04795328,
       -181.07213208, -181.47683236, -181.99495857, -182.15068094,
       -182.19040858, -182.21109764, -182.60008162, -183.07015146,
       -183.5194908 , -183.73504841, -183.92578886, -184.27295115,
       -184.31743497, -185.82624438, -185.83874574, -185.91099302,
       -185.96439321, -186.76527791, -187.13082378, -187.49926637,
       -187.50171476, -187.54001352, -187.88822074, -187.94269846,
       -188.12197369, -188.17713773, -188.26226205, -188.70296183,
       -188.89852932, -188.90307331, -189.12918901, -189.37479332,
       -189.41571094, -189.45351319, -189.56236376, -190.00909161,
       -191.00173579, -191.08800778, -191.52912439, -192.86565267,
       -193.15427361, -193.40096612, -193.56775759, -193.84393233,
       -194.02255925, -194.25842292, -194.82971252, -195.01479218,
       -195.40048313, -195.46330266, -196.00648238, -196.45457418,
       -197.00664429, -197.03136008, -197.14040441, -197.23291781,
       -197.32846092, -197.71540136, -198.75515252, -198.97453897,
       -199.72870372, -200.37444367, -201.36270162, -201.70618858,
       -201.90255451, -202.06192823, -202.44393985, -202.94578662,
       -203.60690477, -204.16192017, -204.58312076, -204.7054588 ,
       -204.76921612, -204.8222588 , -205.20562219, -206.01492524,
       -207.67539085, -208.20133379, -208.22617454, -208.35659733,
       -208.37996642, -208.56687508, -209.05060743, -209.19883846,
       -209.66530011, -209.76169524, -210.00026407, -210.06719258,
       -210.471295  , -210.89707524, -211.0507172 , -211.5423576 ,
       -211.54570018, -211.80237379, -212.15693021, -212.75012144,
       -213.07134623, -213.12128551, -213.32719372, -213.42721632,
       -213.54943249, -213.80650929, -214.39554479, -214.84043008,
       -214.87332395, -214.97331169, -215.90193958, -216.25418609,
       -216.41746333, -216.42397138, -216.59757636, -216.64055855,
       -217.45983067, -217.4672841 , -217.87709564, -217.9679255 ,
       -218.03607502, -218.0417631 , -218.56859871, -218.7194114 ,
       -218.76670865, -218.88096469, -220.1148259 , -220.25214521,
       -220.71738364, -221.23243239, -221.41682735, -221.49366748,
       -221.53299989, -222.19211283, -222.2302927 , -222.35647546,
       -222.38238199, -222.50576008, -222.53200296, -222.58924564,
       -222.77456479, -223.27068844, -223.32995053, -223.41862726,
       -223.75967046, -223.99142231, -224.43318054, -225.11806666,
       -225.19934213, -225.33540531, -225.83035665, -226.06514436,
       -226.42548379, -226.52201042, -226.99336306, -227.0680505 ,
       -227.28006967, -227.40205643, -227.44270738, -227.55665901,
       -227.7028541 , -228.13527984, -228.34512248, -228.37130833,
       -228.85847254, -228.95327229, -229.15039061, -229.52284401,
       -229.69677205, -230.01156723, -231.24425497, -231.41326554,
       -232.00061426, -232.05144145, -232.22645887, -232.41798113,
       -232.60462697, -232.89668821, -233.78186414, -234.98767176,
       -235.11428745, -235.19143506, -235.39710462, -235.45405377,
       -235.62225805, -235.94712811, -236.56851001, -236.57425096,
       -236.7157612 , -236.8815725 , -237.92198293, -238.20597028,
       -238.26371398, -238.34482022, -238.46150139, -239.09456405,
       -239.1868435 , -239.73790935, -239.90330906, -240.1533008 ,
       -240.23378774, -240.44977755, -242.34438507, -242.99701121,
       -243.45010163, -243.48411617, -243.73788061, -244.00347056,
       -244.1366912 , -244.14018121, -244.32587583, -244.57528057,
       -244.60839122, -245.9135097 , -246.26085513, -246.54827872,
       -246.96217943, -247.49518476, -247.77500182, -247.77706436,
       -248.35797138, -248.41535382, -248.75471974, -249.0829423 ,
       -249.29402596, -249.67795035, -249.73533719, -249.74364951,
       -249.92992876, -249.95842767, -250.10450298, -250.20038122,
       -250.52164115, -250.5623651 , -251.86440299, -252.10498096,
       -252.8035739 , -252.82820316, -252.99137333, -253.08587708,
       -253.87419206, -253.99858055, -254.63427355, -254.65612458,
       -254.87435934, -254.9660247 , -255.49176389, -255.76341627,
       -256.39073087, -257.02999257, -257.12014196, -259.51683014,
       -260.02818555, -260.1585896 , -260.47352186, -260.51478366,
       -260.9852685 , -261.58914232, -261.98164835, -262.1954725 ,
       -262.21029962, -262.64337748, -262.79777582, -262.86926116,
       -263.00344483, -263.14050584, -263.19841898, -263.93452723,
       -264.45748001, -264.99981674, -265.40090411, -266.5632534 ,
       -266.91209734, -266.94387793, -267.01935751, -267.2724506 ,
       -267.36936566, -267.66364819, -268.12682573, -268.29050669,
       -268.42472031, -268.69126049, -268.97627874, -269.0076541 ,
       -269.4543912 , -269.47973497, -269.65979002, -269.75647136,
       -270.41248063, -270.43770508, -271.20303586, -271.22944985,
       -271.49396286, -272.62988049, -272.73482016, -272.88662857,
       -273.16784045, -273.19945819, -273.40137266, -273.76711556,
       -273.83379152, -274.31606782, -274.70350021, -275.65397001,
       -275.77421435, -275.90116304, -276.56021836, -276.8715745 ,
       -276.9168352 , -277.20501134, -279.71703923, -280.54682316,
       -280.62218534, -281.0950615 , -281.59700851, -281.60481231,
       -282.14344831, -282.87316583, -284.34358159, -285.01935116,
       -285.7191539 , -286.35645595, -287.15757196, -287.42457518,
       -287.51914809, -287.55127736, -288.22426522, -288.83563159,
       -288.89015359, -289.27640893, -291.01535073, -291.0872904 ,
       -291.41064047, -293.43668984, -293.70046648, -293.82094887,
       -294.29044825, -295.26609006, -295.62965972, -295.96755989,
       -296.4081209 , -296.71060366, -297.28762087, -297.39741219,
       -297.44617097, -299.22147631, -300.55146114, -300.84055009,
       -301.6736839 , -301.96223717, -302.45773962, -302.5232722 ,
       -302.71029435, -303.78435321, -304.40240648, -304.42732404,
       -304.53188124, -306.5204011 , -306.92139267, -308.07827417,
       -308.40669312, -308.50047327, -308.74157858, -310.50771736,
       -311.66043744, -314.08598123, -314.6524489 , -315.40526028,
       -318.3090542 , -320.48714766, -320.7470907 , -320.74975805,
       -320.78465826, -321.41398453, -321.6097238 , -321.64072677,
       -322.36870506, -322.72873983, -326.93618291, -328.61474421,
       -330.99305903, -335.99429585, -336.96229579, -338.35453863,
       -338.43180924, -339.64386145, -341.39354814, -343.60772412,
       -344.14150023, -344.51203727, -344.85423841, -345.01011357,
       -345.9335704 , -348.14195826, -352.77763367, -354.40198737,
       -355.74059341, -361.15216816, -361.33876298, -375.40598447,
       -386.69393951, -387.9746456 , -389.34620661, -411.56933982,
       -419.30798672, -424.87616098, -447.44967931, -471.31893587,
       -506.32084443, -543.42738956, -543.42738956, -543.42738956])

In [89]:
def ppp(d,title):
    plt.suptitle(title)
    plt.imshow(d,interpolation='nearest',cmap = "Greys", vmin = 0, vmax =1)

In [ ]:


In [ ]:


In [91]:
run = 0
range_length = 5
# idx = np.flipud(or_score_of_interest[run].argsort())[:range_length]
idx = or_score_of_interest[run].argsort()[:range_length]
or_recon_of_interest[0].shape
cool_i = 1
for i in idx:
    print("orbm {}".format(or_score_of_interest[run][i]))
    print("rbm {}".format(r_score_of_interest[run][i]))
    ppp(r_recon_of_interest[run,i].reshape(28,28),"RBM A Reconstruction")
    plt.savefig("Assets/bad-rbm-{}-rank{}".format(comp[0],cool_i))
    plt.show()
    
    ppp(r_b_recon[run,i].reshape(28,28),"RBM B Reconstruction")
    plt.savefig("Assets/bad-rbm-{}-rank{}".format(comp[1],cool_i))
    plt.show()
    
    
    ppp(or_recon_of_interest[run,i].reshape(28,28),"ORBM A Reconstruction")
    plt.savefig("Assets/bad-orbm-{}-rank{}".format(comp[0],cool_i))
    plt.show()
    
    ppp(or_b_recon[run,i].reshape(28,28),"ORBM B Reconstruction")
    plt.savefig("Assets/bad-orbm-{}-rank{}".format(comp[1],cool_i))
    plt.show()
    
    
    d1 = mnist_data[comp[0]][1][i]
    d2 = mnist_data[comp[1]][1][i]
    ppp(d1,"")
    plt.savefig("Assets/comp-i-{}-rank{}".format(comp[0],cool_i))
    plt.show()
    ppp(d2,"")
    plt.savefig("Assets/comp-i-{}-rank{}".format(comp[1],cool_i))
    plt.show()
    ppp(np.maximum(d1,d2),"")
    plt.savefig("Assets/comp-i-{}-{}-rank{}".format(comp[0], comp[1], cool_i))
    plt.show()
    cool_i += 1


orbm -543.4273895589971
rbm -326.1818188001946
orbm -543.427389558997
rbm -348.4333408614226
orbm -543.427389558997
rbm -378.52371329826667
orbm -506.32084443296026
rbm -262.0368569411073
orbm -471.31893586652143
rbm -317.62979978412227

In [48]:
a = or_recon_of_interest[or_max_idx].reshape(28,28)
b = r_recon_of_interest[or_max_idx].reshape(28,28)
t = ds_from_mnist_ids(comp,size=200)[0][or_max_idx[1]].reshape(28,28)


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-48-470740203ee3> in <module>()
      1 a = or_recon_of_interest[or_max_idx].reshape(28,28)
      2 b = r_recon_of_interest[or_max_idx].reshape(28,28)
----> 3 t = ds_from_mnist_ids(comp,size=200)[0][or_max_idx[1]].reshape(28,28)

IndexError: index 437 is out of bounds for axis 0 with size 200

In [49]:
pp.image(a)
pp.image(b)
pp.image(t)
pp.image(t-a)
print(abs(t - a).sum())
print(abs(t - b).sum())
or_score_of_interest[or_max_idx]


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-49-603e367e3306> in <module>()
      1 pp.image(a)
      2 pp.image(b)
----> 3 pp.image(t)
      4 pp.image(t-a)
      5 print(abs(t - a).sum())

NameError: name 't' is not defined

In [ ]:
or_score_of_interest.max()

In [ ]:
or_max_idx

In [ ]:
print("orbm{}".format(np.array(result.orbm_a_scores[(1,2)]).mean(0).max()))
print("rbm{}".format(np.array(result.rbm_a_scores[(1,2)]).mean(0).max()))

In [ ]:
key = (1,7)
for key in result.rbm_a_scores.keys():
    plt.suptitle("{}".format(key))
    plt.plot(np.array(result.rbm_a_scores[key]).mean(0), np.array(result.orbm_a_scores[key]).mean(0),'.k')
#     plt.scatter(np.array(result.rbm_a_scores[key]).mean(0), np.array(result.orbm_a_scores[key]).mean(0),s=20,cmap=plt.cm.Greys)
    plt.plot([0.4,1],[0.4,1])
    plt.axis('equal')
    plt.ylabel("ORBM Scores")
    plt.xlabel("RBM Scores")
    plt.savefig("Results/plots/" + str(key) +".png")
    plt.show()



In [ ]:
mnist_data[2][0].visible_bias.max()

In [ ]:
!say "I'm finished!!!"

In [ ]: