Make a standard RBM that makes vertical lines
Make a standard RBM that makes horizontal lines
Make a visible layer that ORs them together
Use as inputs to a normal RBM and watch it fail
Use as inputs to the partitioned RBM and cross fingers
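The bars data themselves come from a separate notebook (learn_vanilla_RBM_bars), so nothing below generates them. For orientation only, here is a minimal sketch of what a vertical bar, a horizontal bar, and their OR could look like on an N*N binary image; the helper name make_bar and the one-bar-per-sample assumption are illustrative, not taken from that notebook.

import numpy as np
import numpy.random as rng

def make_bar(N=10, orientation='vertical'):
    # Illustrative only: an N*N binary image containing a single bar in a random row/column.
    img = np.zeros((N, N), dtype=int)
    k = rng.randint(N)
    if orientation == 'vertical':
        img[:, k] = 1
    else:
        img[k, :] = 1
    return img

# A composite pattern is just the element-wise OR of one vertical and one horizontal bar.
composite = make_bar(10, 'vertical') | make_bar(10, 'horizontal')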
In [11]:
import numpy as np
import numpy.random as rng
from matplotlib.pyplot import subplot, imshow, title, axis  # plotting calls used bare further down

def sigmoid(x):
    return 1.0/(1 + np.exp(-x))
In [12]:
Nhid=20 # num units in each hidden layer.
N = 10 # The visible layer will correspond to an N*N image.
global visibles
In [13]:
# Read in some pickled answers from learn_vanilla_RBM_bars instead of trying to engineer a toy model!
weightsA = np.load('A_Wgts')
Nhid, N2 = weightsA.shape
N = int(np.sqrt(N2))
hidA_bias_weights = np.load('A_hBias')
vis_biasA_weights = np.load('A_vBias')
weightsB = np.load('B_Wgts')
hidB_bias_weights = np.load('B_hBias')
vis_biasB_weights = np.load('B_vBias')
weightsA = np.reshape(weightsA,(Nhid,N*N))
weightsB = np.reshape(weightsB,(Nhid,N*N))
vis_biasA_weights = np.reshape(vis_biasA_weights,(1,N*N))
vis_biasB_weights = np.reshape(vis_biasB_weights,(1,N*N))
#EVIL HACKING ALERT:..........................
#vis_biasA_weights = vis_biasA_weights / 2
#vis_biasB_weights = vis_biasB_weights / 2
visibles = np.zeros(N*N)
hiddenA = np.ones(Nhid)
hiddenB = np.ones(Nhid)
In [4]:
def do_vanilla_Gibbs_single_RBM(num_iterations, w, bv, bh, v, h):
    for t in range(num_iterations):
        # Sample the visible nodes.
        phi_vis = np.dot(h,w) + bv  # This is phi at the visibles, arising from the hidden layer.
        prob1_vis = sigmoid(phi_vis)
        # Resample the visible node activities.
        v = np.where(prob1_vis > rng.rand(N*N), 1, 0)  # Bernoulli samples based on phi
        # Sample hidden nodes.
        # First, get the standard RBM input to each of the hidden units.
        psi = np.dot(v, np.transpose(w)) + bh  # This is the weighted input to the hidden layer.
        # Now do the actual sampling.
        prob1_hid = sigmoid(psi)
        h = np.where(prob1_hid > rng.rand(1,Nhid), 1, 0)  # Bernoulli samples based on psi
    return v, h, phi_vis
In [5]:
def do_Gibbs(flavour, num_iterations, clamped, N, wA, wB, bvA, bvB, bhA, bhB, v, hA, hB):  # Gibbs sampling on the new architecture.
    for t in range(num_iterations):
        # Sample the visible nodes.
        phiA = np.dot(hA,wA) + bvA  # This is phi at the visibles, arising from the A hidden layer only.
        phiB = np.dot(hB,wB) + bvB  # Ditto for B.
        phi_vis = phiA + phiB
        prob1_vis = sigmoid(phi_vis)
        if not clamped:
            # Resample the visible node activities.
            r = rng.rand(N*N)
            v = np.where(prob1_vis > r, 1, 0)  # Bernoulli samples based on phi
        # Sample hidden nodes.
        # First, get the standard RBM input to each of the hidden units, in both RBMs.
        psiA = np.dot(v, np.transpose(wA)) + bhA  # This is the weighted input to the A hidden layer.
        psiB = np.dot(v, np.transpose(wB)) + bhB  # Ditto for B.
        if flavour == "partitioned":  # in this case we need to calculate and add a "correction" to the vanilla RBM's psi.
            # For the A updates, the "phi" is phiA, and the "epsilon" is phiB.
            hAT = np.transpose(hA)
            hBT = np.transpose(hB)
            # We need to calculate the matrix C (for A and for B).
            CA = ( np.log(sigmoid(phiA - hAT*wA))
                 + np.log(sigmoid(phiA - hAT*wA + wA + phiB))
                 - np.log(sigmoid(phiA - hAT*wA + wA))
                 - np.log(sigmoid(phiA - hAT*wA + phiB)) )
            psiA = psiA + CA.sum(1)
            CB = ( np.log(sigmoid(phiB - hBT*wB))
                 + np.log(sigmoid(phiB - hBT*wB + wB + phiA))
                 - np.log(sigmoid(phiB - hBT*wB + wB))
                 - np.log(sigmoid(phiB - hBT*wB + phiA)) )
            psiB = psiB + CB.sum(1)
        # Now do the actual sampling.
        prob1_hidA = sigmoid(psiA)
        prob1_hidB = sigmoid(psiB)
        hA = np.where(prob1_hidA > rng.rand(Nhid), 1, 0)  # Bernoulli samples based on psi
        hB = np.where(prob1_hidB > rng.rand(Nhid), 1, 0)  # Bernoulli samples based on psi
        #if flavour == "partitioned": imshow(CA[0].reshape(N,N), interpolation='nearest', cmap='RdYlGn')
    return v, hA, hB
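For reference, the correction added in the "partitioned" branch above can be written out explicitly. This is just a transcription of the CA / CB expressions in the code: for hidden unit $i$ of partition A and visible unit $j$, with $\sigma$ the logistic sigmoid, $w_{ij}$ the A-weights, $\phi^A_j$ and $\phi^B_j$ the visible fields contributed by each hidden layer, and $\tilde\phi_{ij} = \phi^A_j - h^A_i w_{ij}$ (the A field with unit $i$'s own contribution removed),

$$C^A_{ij} = \log\sigma(\tilde\phi_{ij}) + \log\sigma(\tilde\phi_{ij} + w_{ij} + \phi^B_j) - \log\sigma(\tilde\phi_{ij} + w_{ij}) - \log\sigma(\tilde\phi_{ij} + \phi^B_j),$$

and the input to hidden unit $i$ becomes $\psi^A_i + \sum_j C^A_{ij}$. The B correction is the same with the roles of A and B swapped.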
In [6]:
def show_samples(flavour):  # Show some samples that each start from a random i.c.
    global visibles
    nr, nc = 3, 6
    i = 0
    num_Gibbs_itns = 100
    for r in range(nr):
        for c in range(nc):
            subplot(nr, nc, i+1)
            # Start the hidden nodes off at random patterns.
            hiddenA = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
            hiddenB = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
            if flavour == 'just add':
                # Run Gibbs for a sample from RBM "A", and again for one from "B". Then merge them in a sigmoid visible pass.
                vA = visibles.copy()
                vB = visibles.copy()
                vA, hiddenA, phiA = do_vanilla_Gibbs_single_RBM(num_Gibbs_itns, weightsA, vis_biasA_weights, hidA_bias_weights, vA, hiddenA)
                vB, hiddenB, phiB = do_vanilla_Gibbs_single_RBM(num_Gibbs_itns, weightsB, vis_biasB_weights, hidB_bias_weights, vB, hiddenB)
                phi_total = phiA + phiB
                prob1_vis = sigmoid(phi_total)
                visibles = np.where(prob1_vis > rng.rand(N*N), 1, 0)  # Bernoulli samples based on phi
            if flavour in ['vanilla','partitioned']:
                kwargs = {"N":N, "wA":weightsA, "wB":weightsB, "bvA":vis_biasA_weights, "bvB":vis_biasB_weights, "bhA":hidA_bias_weights, "bhB":hidB_bias_weights, "v":visibles, "hA":hiddenA, "hB":hiddenB}
                visibles, hiddenA, hiddenB = do_Gibbs(flavour, num_Gibbs_itns, False, **kwargs)
            imshow(np.reshape(visibles,(N,N)), interpolation='nearest', cmap='copper', vmin=0, vmax=1)
            axis('off')
            i = i + 1
In [7]:
show_samples('just add')
In [8]:
show_samples('vanilla')
In [9]:
show_samples('partitioned')
In [10]:
def new_vis_pattern():  # Generate a new plausible visible pattern.
    num_Gibbs_itns = 100
    global visibles
    hiddenA = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
    hiddenB = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
    # Run Gibbs for a sample from RBM "A", and again for one from "B". Then merge them in a sigmoid visible pass.
    vA = visibles.copy()
    vB = visibles.copy()
    vA, hiddenA, phiA = do_vanilla_Gibbs_single_RBM(num_Gibbs_itns, weightsA, vis_biasA_weights, hidA_bias_weights, vA, hiddenA)
    vB, hiddenB, phiB = do_vanilla_Gibbs_single_RBM(num_Gibbs_itns, weightsB, vis_biasB_weights, hidB_bias_weights, vB, hiddenB)
    r = rng.random()
    if r > 0.666:
        phi_total = phiA  # i.e. just ONE, not the sum of both.
    elif r > 0.333:
        phi_total = phiB
    else:
        phi_total = phiA + phiB
    prob1_vis = sigmoid(phi_total)
    visibles = np.where(prob1_vis > rng.rand(N*N), 1, 0)  # Bernoulli samples based on phi
In [11]:
def clamped_samples():  # Show some samples that each start from a random i.c.
    global visibles
    num_Gibbs_itns = 40  # hoping that's enough to see the difference!
    new_vis_pattern()
    i = 1
    nr = 4
    for row in range(1, nr+1):
        subplot(nr, 5, i)
        if row == 1:
            imshow(np.reshape(visibles,(N,N)), interpolation='nearest', cmap='copper', vmin=0, vmax=1)
            title('clamped')
        axis('off')
        i = i + 1
        #----------------------------
        flavour = 'vanilla'
        # Run a Gibbs Sampler for a while, starting from random hidden patterns.
        hA = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
        hB = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
        kwargs = {"N":N, "wA":weightsA, "wB":weightsB, "bvA":vis_biasA_weights, "bvB":vis_biasB_weights, "bhA":hidA_bias_weights, "bhB":hidB_bias_weights, "v":visibles, "hA":hA, "hB":hB}
        visibles, hA, hB = do_Gibbs(flavour, 1, True, **kwargs)
        # Make reconstructions from hiddenA and hiddenB, respectively.
        subplot(nr, 5, i)
        prob1_vis = sigmoid(np.dot(hA,weightsA) + vis_biasA_weights)
        imshow(np.reshape(prob1_vis,(N,N)), interpolation='nearest', cmap='BuGn', vmin=0, vmax=1)
        axis('off')
        if row == 1: title(flavour[:5] + ' A')
        i = i + 1
        subplot(nr, 5, i)
        prob1_vis = sigmoid(np.dot(hB,weightsB) + vis_biasB_weights)
        imshow(np.reshape(prob1_vis,(N,N)), interpolation='nearest', cmap='RdPu', vmin=0, vmax=1)
        axis('off')
        if row == 1: title(flavour[:5] + ' B')
        i = i + 1
        #----------------------------
        flavour = 'partitioned'
        # Run a Gibbs Sampler for a while, starting from random hidden patterns.
        hA = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
        hB = np.where(rng.rand(1,Nhid) < 2./N, 1, 0)
        kwargs = {"N":N, "wA":weightsA, "wB":weightsB, "bvA":vis_biasA_weights, "bvB":vis_biasB_weights, "bhA":hidA_bias_weights, "bhB":hidB_bias_weights, "v":visibles, "hA":hA, "hB":hB}
        visibles, hA, hB = do_Gibbs(flavour, num_Gibbs_itns, True, **kwargs)
        # Make reconstructions from hiddenA and hiddenB, respectively.
        subplot(nr, 5, i)
        prob1_vis = sigmoid(np.dot(hA,weightsA) + vis_biasA_weights)
        imshow(np.reshape(prob1_vis,(N,N)), interpolation='nearest', cmap='BuGn', vmin=0, vmax=1)
        axis('off')
        if row == 1: title(flavour[:5] + ' A')
        i = i + 1
        subplot(nr, 5, i)
        prob1_vis = sigmoid(np.dot(hB,weightsB) + vis_biasB_weights)
        imshow(np.reshape(prob1_vis,(N,N)), interpolation='nearest', cmap='RdPu', vmin=0, vmax=1)
        axis('off')
        if row == 1: title(flavour[:5] + ' B')
        i = i + 1
In [14]:
clamped_samples()