In [1]:
from collections import Counter
from scipy.special import expit
from rbmpy.rbm import RBM
from rbmpy.sampler import VanillaSampler, DirtyCorrectionMulDimSampler,PartitionedSampler, ApproximatedSampler, LayerWiseApproxSampler,ApproximatedMulDimSampler, goodnight, orbm_goodnight, FullCorrectionMulDimSampler
from rbmpy.trainer import VanillaTrainier, ORBMTrainer
from rbmpy.performance import Result
import numpy as np
import rbmpy.datasets as datasets
import rbmpy.performance, pickle, rbmpy.rbm, os, math, logging
import rbmpy.plotter as pp
from rbmpy.datasets import SquareToyData
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
logger = logging.getLogger()
# Emit INFO-level messages (e.g. the plot captions logged throughout this notebook);
# switch to logging.DEBUG for more detail.
logger.setLevel(logging.INFO)

# Seed numpy's global RNG so training/sampling is reproducible on Restart & Run All.
# NOTE(review): assumes the rbmpy samplers/trainers draw from numpy's global RNG -- confirm.
SEED = 42
np.random.seed(SEED)
# Render matplotlib figures inline in the notebook output (the default on modern
# Jupyter; kept explicit for older front-ends).
%matplotlib inline
In [2]:
def image_composition(a, b):
    """Overlay two images (or image stacks) by taking their element-wise maximum.

    For 0/1 images this behaves like a pixel-wise logical OR: a pixel is on in the
    composite if it is on in either input.
    """
    composite = np.maximum(a, b)
    return composite
In [3]:
# Toy data: a 1x1 square on a 1x2 canvas.
# NOTE(review): gen_training presumably yields each placement of the square -- confirm.
sq_shape, img_size = (1, 1), (1, 2)
square_factory = SquareToyData()
dataset_one = square_factory.gen_training(sq_shape, img_size)
In [4]:
# Training set: the single-square images, followed by each image overlaid with the image
# at the mirrored position along axis 0 (np.flipud reverses axis 0).
dataset_composite = np.append(
    dataset_one,
    image_composition(dataset_one, np.flipud(dataset_one)),
    axis=0,
)
pp.images(dataset_composite, color_range=(0, 1))
logger.info("Training Dataset")
In [5]:
def plot_rbm_vanilla_dreams(rbm, hours_of_sleep=2000, num_gibbs_per_hour=200):
    """Plot the free-running "dreams" of a single RBM.

    Runs rbmpy's ``goodnight`` with a ``VanillaSampler`` for ``rbm`` and passes the first
    element of its result to ``pp.plot_dict`` (presumably a reconstruction -> count
    mapping -- confirm against rbmpy).

    Args:
        rbm: the RBM to sample from.
        hours_of_sleep: third argument to ``goodnight`` (default 2000, the value that used
            to be hard-coded).
        num_gibbs_per_hour: fourth argument to ``goodnight`` (default 200, previously
            hard-coded).
    """
    reconstructions = goodnight(rbm, VanillaSampler(rbm), hours_of_sleep, num_gibbs_per_hour)[0]
    pp.plot_dict(reconstructions)
def plot_orbm_dreams(rbm_a, rbm_b, sampler, hours_of_sleep=2000, num_gibbs_per_hour=100):
    """Plot and return the combined free-running "dreams" of an ORBM.

    The library function is reached via its module path (``rbmpy.sampler.orbm_goodnight``)
    rather than the bare name. A later cell in this notebook defines its own
    ``orbm_goodnight(clamped_v, model_a, model_b, hours_of_sleep, num_gibbs_per_hour)``.
    Once that cell has run, the bare name would silently resolve to it and be called with
    the wrong arguments.

    Args:
        rbm_a, rbm_b: the two component RBMs.
        sampler: ORBM sampler built from the two RBMs.
        hours_of_sleep: fourth argument to ``orbm_goodnight`` (default 2000, previously
            hard-coded).
        num_gibbs_per_hour: fifth argument to ``orbm_goodnight`` (default 100, previously
            hard-coded).

    Returns:
        The first element of ``orbm_goodnight``'s result (the mapping given to
        ``pp.plot_dict``).
    """
    ab_orbm_dreams = rbmpy.sampler.orbm_goodnight(
        rbm_a, rbm_b, sampler, hours_of_sleep, num_gibbs_per_hour
    )[0]
    pp.plot_dict(ab_orbm_dreams)
    return ab_orbm_dreams
def evaluate_orbm_training(num_hid, sampler_class, dataset_composite, epochs):
    """Train an ORBM on ``dataset_composite``, then plot single-RBM and combined dreams.

    NOTE: both ORBM components are deliberately the *same* RBM object (a warning is
    logged to say so), so ``rbm_a`` and ``rbm_b`` share their weights and biases.

    Args:
        num_hid: number of hidden units, passed as the first argument to ``RBM``.
        sampler_class: ORBM sampler class, constructed as
            ``sampler_class(weights_a, weights_b, hidden_bias_a, hidden_bias_b)``
            (e.g. ``ApproximatedMulDimSampler``).
        dataset_composite: image stack, flattened with ``datasets.squash_images`` before use.
        epochs: training length, passed as the first argument to ``ORBMTrainer.train``.

    Returns:
        ``(rbm_a, rbm_b)``, currently the same RBM object twice.
    """
    # Flatten once: the result both sizes the RBM and serves as the training data.
    squashed = datasets.squash_images(dataset_composite)
    (num_items, num_vis) = squashed.shape
    rbm_a = RBM(num_hid, num_vis, num_items, zerod_bias=True)
    # Logger.warn is a deprecated alias; Logger.warning is the supported spelling.
    logger.warning("Using a single RBM here!!")
    rbm_b = rbm_a  # RBM(num_hid, num_vis, num_items, zerod_bias=True)
    sampler = sampler_class(rbm_a.weights, rbm_b.weights, rbm_a.hidden_bias, rbm_b.hidden_bias)
    trainer = ORBMTrainer(rbm_a, rbm_b, sampler)
    trainer.train(epochs, squashed, logging_freq=10)

    logger.info("RBM_A Dreams")
    plot_rbm_vanilla_dreams(rbm_a)
    logger.info("RBM_B Dreams")
    plot_rbm_vanilla_dreams(rbm_b)
    logger.info("ORBM Dreams (combined reconstructions)")
    plot_orbm_dreams(rbm_a, rbm_b, sampler)
    return (rbm_a, rbm_b)
In [6]:
# ORBM with 2 hidden units and the approximated sampler, 1,000 training epochs.
rbm_a, rbm_b = evaluate_orbm_training(2, ApproximatedMulDimSampler, dataset_composite, 1000)
In [6]:
# Baseline: an ordinary 2-hidden-unit RBM trained only on the single-square images
# (dataset_one stacked twice).
pure_dataset = np.concatenate((dataset_one, dataset_one), axis=0)
(num_items, num_vis) = datasets.squash_images(pure_dataset).shape
rbm_vanilla = RBM(2, num_vis, num_items)
van_sampler = VanillaSampler(rbm_vanilla)
van_trainer = VanillaTrainier(rbm_vanilla, van_sampler)
van_trainer.train(10000, datasets.squash_images(pure_dataset))
plot_rbm_vanilla_dreams(rbm_vanilla)
print(rbm_vanilla.weights)
# Snapshot of the weights before any ORBM training, for later comparison.
original_van_weights = np.copy(rbm_vanilla.weights)
In [7]:
# Three ORBM samplers, each using the SAME vanilla RBM for both components (a == b).
orbm_sampler = ApproximatedMulDimSampler(
    rbm_vanilla.weights, rbm_vanilla.weights,
    rbm_vanilla.hidden_bias, rbm_vanilla.hidden_bias,
)
unapprox_sampler = FullCorrectionMulDimSampler(
    rbm_vanilla.weights, rbm_vanilla.weights,
    rbm_vanilla.hidden_bias, rbm_vanilla.hidden_bias,
)
# NOTE(review): the trainer uses the full-correction sampler, while the dream plots in
# this notebook use the approximated one -- confirm this mix is intended.
orbm_trainer = ORBMTrainer(rbm_vanilla, rbm_vanilla, unapprox_sampler)
dirty_sampler = DirtyCorrectionMulDimSampler(
    rbm_vanilla.weights, rbm_vanilla.weights,
    rbm_vanilla.hidden_bias, rbm_vanilla.hidden_bias,
)
logger.info("Approx ORBM Generative Samples - Free Phase Sampling")
# Disabled comparisons with the other two samplers:
#   logger.info("Un-Approx ORBM Generative Samples - Free Phase Sampling")
#   plot_orbm_dreams(rbm_vanilla, rbm_vanilla, unapprox_sampler)
#   logger.info("Dirty ORBM Generative Samples - Free Phase Sampling")
#   plot_orbm_dreams(rbm_vanilla, rbm_vanilla, dirty_sampler)
plot_orbm_dreams(rbm_vanilla, rbm_vanilla, orbm_sampler)
Out[7]:
In [8]:
# Short ORBM run (100 epochs on the composite images); print weights before and after.
print(original_van_weights)
orbm_trainer.train(
    100,
    datasets.squash_images(dataset_composite),
    logging_freq=50,
)
print(rbm_vanilla.weights)
In [9]:
# Long ORBM run: 10,000 epochs on the composite images.
orbm_trainer.train(10000, datasets.squash_images(dataset_composite), logging_freq=10)
In [10]:
# Combined ORBM dreams after the long training run (returned mapping is the cell output).
plot_orbm_dreams(rbm_vanilla, rbm_vanilla, sampler=orbm_sampler)
Out[10]:
In [19]:
# Weights before vs. after ORBM training; the difference is displayed as the cell output.
print(original_van_weights)
print(rbm_vanilla.weights)
(rbm_vanilla.weights - original_van_weights)
Out[19]:
In [126]:
# Hand-set the 2x2 weights in place: -3 on the diagonal, +6 off the diagonal.
# In-place, so any sampler holding a reference to this array (rather than a copy) sees it.
rbm_vanilla.weights[[0, 1], [0, 1]] = -3
rbm_vanilla.weights[[1, 0], [0, 1]] = 6
In [23]:
# Single-RBM dreams with the hand-set weights.
plot_rbm_vanilla_dreams(rbm=rbm_vanilla)
In [24]:
# Combined ORBM dreams with the hand-set weights.
plot_orbm_dreams(rbm_a=rbm_vanilla, rbm_b=rbm_vanilla, sampler=orbm_sampler)
Out[24]:
In [19]:
# Retrain the ORBM from the hand-set weights for 10,000 epochs.
orbm_trainer.train(10000, datasets.squash_images(dataset_composite), logging_freq=10)
In [20]:
# Combined ORBM dreams after retraining.
plot_orbm_dreams(rbm_a=rbm_vanilla, rbm_b=rbm_vanilla, sampler=orbm_sampler)
In [26]:
# Single-RBM dreams after retraining.
plot_rbm_vanilla_dreams(rbm=rbm_vanilla)
In [13]:
# Another 1,000 ORBM epochs, then look at the combined dreams again.
orbm_trainer.train(1000, datasets.squash_images(dataset_composite), logging_freq=10)
plot_orbm_dreams(rbm_a=rbm_vanilla, rbm_b=rbm_vanilla, sampler=orbm_sampler)
In [11]:
# Same experiment as before, but with the full-correction sampler (1,000 epochs).
rbm_a, rbm_b = evaluate_orbm_training(2, FullCorrectionMulDimSampler, dataset_composite, 1000)
In [94]:
def van_dream_clamped(sampler, clamped_v, model, num_gibbs = 1000):
    """Sample ``model``'s hidden units with the visible layer clamped to ``clamped_v``.

    The previous version called ``sampler.visible_to_hidden(clamped_v)`` ``num_gibbs``
    times in a loop and kept only the last result. The visible vector never changes, so
    (assuming ``visible_to_hidden`` depends only on its argument) every call samples the
    same conditional distribution. One draw is therefore equivalent, at 1/num_gibbs of
    the cost.

    Args:
        sampler: object with a ``visible_to_hidden(v)`` method (e.g. ``VanillaSampler``).
        clamped_v: visible vector held fixed.
        model: RBM used to build random hiddens when ``num_gibbs <= 0``.
        num_gibbs: kept for caller compatibility; any value >= 1 now means a single draw.

    Returns:
        ``sampler.visible_to_hidden(clamped_v)``. When ``num_gibbs <= 0`` it returns
        ``rbmpy.rbm.random_hiddens_for_rbm(model)`` instead, matching the old behaviour
        when the loop body never ran.
    """
    if num_gibbs <= 0:
        return rbmpy.rbm.random_hiddens_for_rbm(model)
    return sampler.visible_to_hidden(clamped_v)
def dream_clamped(clamped_v, model_a, model_b, num_gibbs=1000):
    """Sample a combined visible vector from two RBMs, given a clamped visible input.

    Each model's hiddens are sampled with the visible layer clamped to ``clamped_v``.
    The two top-down inputs ``h @ W`` are summed and passed through the logistic
    function, and each visible unit gets an independent Bernoulli draw.

    Bug fix: this used to end with ``return self.__bernoulli_trial__(sig_ab)``. This is a
    module-level function with no ``self``, so every call raised NameError.

    Args:
        clamped_v: visible vector held fixed while sampling hiddens.
        model_a, model_b: RBMs exposing ``weights``; ``np.dot(h, weights)`` is assumed
            to give a visible-sized vector.
        num_gibbs: forwarded to ``van_dream_clamped``.

    Returns:
        An integer 0/1 array with the shape of ``expit(phi_a + phi_b)``.
    """
    a_vanilla = VanillaSampler(model_a)
    b_vanilla = VanillaSampler(model_b)
    a_dream_h = van_dream_clamped(a_vanilla, clamped_v, model_a, num_gibbs=num_gibbs)
    b_dream_h = van_dream_clamped(b_vanilla, clamped_v, model_b, num_gibbs=num_gibbs)
    # Top-down input each model's hiddens send to the visible layer (no visible bias term).
    phi_a = np.dot(a_dream_h, model_a.weights)
    phi_b = np.dot(b_dream_h, model_b.weights)
    sig_ab = expit(phi_a + phi_b)
    # One Bernoulli trial per visible unit with success probability sig_ab.
    return np.random.binomial(1, sig_ab)
def orbm_clamped_dream_a(clamped_v, num_gibbs = 50):
    """Infer ORBM hidden states for both components given a clamped visible vector.

    Builds an ``ApproximatedSampler`` with the notebook's ``rbm_vanilla`` as both
    components and calls its ``v_to_h`` once, starting from all-zero hidden states.

    Bug fix: the visible vector given to ``v_to_h`` was hard-coded to ``dataset_one[0]``.
    ``clamped_v`` was ignored, so every caller got hidden states for the same image.

    Args:
        clamped_v: visible vector to condition on.
        num_gibbs: currently unused (``v_to_h`` is called once); kept for caller
            compatibility.

    Returns:
        ``(h_a, h_b)`` as returned by ``ApproximatedSampler.v_to_h``.
    """
    sampler = ApproximatedSampler(rbm_vanilla.weights, rbm_vanilla.weights,
                                  rbm_vanilla.hidden_bias, rbm_vanilla.hidden_bias)
    # NOTE(review): zero initial hiddens of length 2 match rbm_vanilla (built with 2 hidden
    # units); confirm before reusing with other RBMs.
    h_a, h_b = sampler.v_to_h(np.zeros(2), np.zeros(2), clamped_v)
    return (h_a, h_b)
In [108]:
def key_for_hiddens(h_a, h_b):
    """Return a counting label for a pair of hidden-state vectors.

    If each vector sums to exactly 1 (one active unit, for 0/1 vectors), the label embeds
    both vectors. Every other pair is lumped under ``"other"``.
    """
    exactly_one_each = (h_a.sum() == 1) and (h_b.sum() == 1)
    return "h_a{} h_b{}".format(h_a, h_b) if exactly_one_each else "other"
def orbm_goodnight(clamped_v, model_a, model_b, hours_of_sleep, num_gibbs_per_hour):
    """Count which hidden-state pairs the ORBM produces for a clamped visible vector.

    NOTE: this shadows ``rbmpy.sampler.orbm_goodnight`` imported at the top of the
    notebook, and has a different signature. After this cell runs, any code calling the
    bare name ``orbm_goodnight`` gets this version.

    Args:
        clamped_v: visible vector passed to ``orbm_clamped_dream_a``.
        model_a, model_b: currently unused, because ``orbm_clamped_dream_a`` always uses
            ``rbm_vanilla``. Kept for signature compatibility.
        hours_of_sleep: number of samples to draw.
        num_gibbs_per_hour: forwarded to ``orbm_clamped_dream_a`` as ``num_gibbs``.

    Returns:
        ``(result_dict, reconstruction_dict)``:
        - ``result_dict`` is a Counter mapping ``key_for_hiddens(h_a, h_b)`` labels to
          their number of occurrences.
        - ``reconstruction_dict`` is always empty (nothing fills it). It is kept so callers
          can index ``[0]`` as they do with the library goodnight helpers.
    """
    result_dict = Counter()
    reconstruction_dict = {}  # never populated; see Returns
    # Exactly hours_of_sleep draws (an extra, discarded draw used to precede this loop).
    for _ in range(hours_of_sleep):
        h_prime_a, h_prime_b = orbm_clamped_dream_a(clamped_v, num_gibbs_per_hour)
        result_dict[key_for_hiddens(h_prime_a, h_prime_b)] += 1
    return result_dict, reconstruction_dict
In [127]:
# Hidden-pair counts for each of the four possible 2-element visible vectors.
a, b, c, d = (
    orbm_goodnight(np.array(visible), rbm_vanilla, rbm_vanilla, 2000, 200)
    for visible in ([1, 1], [0, 1], [1, 0], [0, 0])
)
In [128]:
from collections import Counter
In [129]:
# Plot the hidden-pair counts for each clamped visible vector, in the order a, b, c, d.
for goodnight_result in (a, b, c, d):
    pp.plot_dict(goodnight_result[0])
In [130]:
# Display the current weight matrix (the previous `rbm_vanilla.weights[]` was a SyntaxError).
rbm_vanilla.weights
In [ ]: