In [1]:
%reset -f
%matplotlib inline
import matplotlib as mpl
import numpy as np
from numpy import array as a
import matplotlib.pyplot as plt
import numpy.random as rng
from scipy.special import expit as sigmoid
# Print numpy arrays with 2 decimal places and without scientific notation.
np.set_printoptions(precision = 2, suppress = True)
import time
# Seed numpy's global random generator from the wall clock (whole seconds),
# so any random draws will differ from one run to the next.
rng.seed(int(time.time())) # seed the random number generator
In [7]:
# Specify the N x N weights matrix: hiWgt on the diagonal, loWgt everywhere else.
N = 2
hiWgt, loWgt = 8.0, -6.0
# Start from a matrix filled with the off-diagonal value...
W = np.full((N, N), loWgt, dtype=float)
# ...then overwrite the diagonal entries in place.
np.fill_diagonal(W, hiWgt)
print(W)
In [8]:
# Make up an array with each row being one of the binary patterns over N units
# (column `bit` of row i holds bit `bit` of the integer i). Do 'em all.
hidpats = np.array([[(i >> bit) & 1 for bit in range(N)] for i in range(2**N)])
# The visible patterns are the same enumeration; a copy keeps the two arrays
# independent objects.
vispats = hidpats.copy()

# Calculate the true probability distribution over hidden pats for each RBM,
# under the generative model. The unnormalised log-probability of a hidden
# pattern h is
#     log P*(h) = sum_j log(1 + exp(phi_j)),   with phi = W^T h.
# np.logaddexp(0, x) evaluates log(1 + exp(x)) without overflowing for large x.
logPStarAll = np.array(
    [np.sum(np.logaddexp(0.0, np.dot(W.T, pat))) for pat in hidpats]
)
# Normalise in log space: subtracting the largest log-probability before
# exponentiating keeps exp() finite, so large weights yield proper
# probabilities instead of inf / inf = nan.
pStarAll = np.exp(logPStarAll - np.max(logPStarAll))
pStarSum = np.sum(pStarAll)
# Map each hidden pattern (as a tuple, so it is hashable) to its probability.
pHid = {tuple(pat): pStar / pStarSum for pat, pStar in zip(hidpats, pStarAll)}
for pat in hidpats:
    print (pat, pHid[tuple(pat)])
In [9]:
# Form the joint distribution over hiddens AND visibles.
# Keys are (hA, hB, vis) triples of tuples; values are
# pHid[hA] * pHid[hB] * P(vis | hA, hB).
pHV = {}
for vis in vispats:
    visKey = tuple(vis)
    for hA in hidpats:
        for hB in hidpats:
            # Input to each visible unit: the contributions of both hidden
            # patterns through the same weights W are added together.
            phi = np.dot(W.T, hA) + np.dot(W.T, hB)
            pOn = sigmoid(phi)
            # Per-unit probability of the observed bit (pOn where vis is 1,
            # 1 - pOn where vis is 0), multiplied across the visible units.
            pVis = np.prod(vis * pOn + (1 - vis) * (1 - pOn))
            pHV[(tuple(hA), tuple(hB), visKey)] = pHid[tuple(hA)] * pHid[tuple(hB)] * pVis
In [10]:
# Marginal probability of each visible pattern: sum the joint pHV over every
# (hA, hB) pair of hidden patterns.
print('visible probabilities under generative model:')
for vis in vispats:
    visKey = tuple(vis)
    # Start from 0.0 and add in the same (hA outer, hB inner) order.
    total = sum(
        (pHV[(tuple(hA), tuple(hB), visKey)] for hA in hidpats for hB in hidpats),
        0.0,
    )
    print(vis, ' prob: ',total)
In [11]:
# Posterior over (hA, hB) for each visible pattern: the joint pHV divided by
# its sum over all hidden pairs. Only pairs with posterior above 0.01 are shown.
print('hidden probabilities, given each visible in turn:')
for vis in vispats:
    print('vis: ',vis)
    visKey = tuple(vis)
    # Normalising constant for this visible pattern, accumulated from 0.0 in
    # (hA outer, hB inner) order.
    normalisation = sum(
        (pHV[(tuple(hA), tuple(hB), visKey)] for hA in hidpats for hB in hidpats),
        0.0,
    )
    for hA in hidpats:
        for hB in hidpats:
            posterior = pHV[(tuple(hA), tuple(hB), visKey)] / normalisation
            if posterior > 0.01:
                print ('\t hA,hB: ', hA, hB, ' ',posterior)