Code forked from github user rdlester.
His Readme file states:
Python implementation of Sum-product (aka Belief-Propagation) for discrete Factor Graphs.
See the sum-product paper (the notebook's original hyperlink was lost in conversion; presumably Kschischang, Frey & Loeliger, "Factor Graphs and the Sum-Product Algorithm") for more details on the Factor Graph framework and the sum-product algorithm. This code was originally written as part of a grad student seminar taught by Erik Sudderth at Brown University; the seminar web page is an excellent resource for learning more about graphical models.
My purpose here is to go through rdlester's implementation in order to better understand the mechanics of factor graphs and the Sum-Product Algorithm.
The Node class uses the following API:
Constant
Fields
Methods
In [ ]:
import numpy as np
class Node(object):
    """ Superclass for factor-graph nodes.

    Tracks this node's neighbors and the incoming/outgoing message
    lists used by sum-product; message index i always corresponds to
    neighbors[i].
    """

    # convergence threshold: message entries moving less than this
    # between iterations count as unchanged
    epsilon = 10**(-4)

    def __init__(self, nid):
        """ nid -- integer id of this node within its graph """
        self.enabled = True
        self.nid = nid
        self.neighbors = []     # adjacent nodes; parallel to message lists
        self.incoming = []      # latest message received from each neighbor
        self.outgoing = []      # message to be sent to each neighbor
        self.old_outgoing = []  # previous outgoing, for convergence check

    def reset(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def enable(self):
        """ Enable this node and its immediate neighbors. """
        self.enabled = True
        for n in self.neighbors:
            # don't call enable() as it will recursively enable entire graph
            n.enabled = True

    def next_step(self):
        """ Snapshot current outgoing messages before computing new ones.

        (Original author note: used to live in prep_messages but
        "didn't work" there.)
        """
        self.old_outgoing = self.outgoing[:]

    def normalize_messages(self):
        """ Normalize every outgoing message to sum to 1. """
        self.outgoing = [x / np.sum(x) for x in self.outgoing]

    def receive_message(self, node, message):
        """ Store message from neighbor `node` in the matching slot. """
        if self.enabled:
            i = self.neighbors.index(node)
            self.incoming[i] = message

    def send_messages(self):
        """ Push every outgoing message to the corresponding neighbor. """
        for neighbor, message in zip(self.neighbors, self.outgoing):
            neighbor.receive_message(self, message)

    def check_convergence(self):
        """ Return True iff no outgoing message moved more than epsilon. """
        if not self.enabled:
            # Always return True if disabled to avoid interrupting check
            return True
        for new, old in zip(self.outgoing, self.old_outgoing):
            # coerce old message to the new message's shape so the
            # elementwise comparison below is valid
            old.shape = new.shape
            delta = np.absolute(new - old)
            if (delta > Node.epsilon).any():  # something still changing
                return False
        return True
Factor Graphs connect alternating layers of variable nodes and factor nodes. The factor nodes contain marginal functions of the full function we are trying to compute, and the variable nodes contain the values that are either inputs or outputs of those functions.
Here is the API for the VarNode class (in addition to the Node class API):
Fields
Room for improvement: how do we replace observations with continuous values?
Methods
To really understand how this all fits together, debug the prep_messages function and observe what happens in each step.
In [ ]:
class VarNode(Node):
    """ Variable node in factor graph.

    Holds a discrete variable with `dim` states; messages on each edge
    are (dim, 1) column vectors.
    """

    def __init__(self, name, dim, nid):
        """ name -- variable name used as key in the graph
            dim  -- number of discrete states of this variable
            nid  -- integer id within the graph
        """
        super(VarNode, self).__init__(nid)
        self.name = name
        self.dim = dim
        self.observed = -1  # only >= 0 if variable is observed

    def reset(self):
        """ Reset all edge messages to uniform and clear observation. """
        super(VarNode, self).reset()
        n_edges = len(self.incoming)
        self.incoming = [np.ones((self.dim, 1)) for _ in range(n_edges)]
        self.outgoing = [np.ones((self.dim, 1)) for _ in range(n_edges)]
        self.old_outgoing = [np.ones((self.dim, 1)) for _ in range(n_edges)]
        self.observed = -1

    def condition(self, observation):
        """ Condition on observing state `observation` (index in [0, dim)).

        Outgoing messages become an indicator on the observed state;
        prep_messages then skips this node so they never change.
        """
        self.enable()
        self.observed = observation
        for i in range(len(self.outgoing)):
            msg = np.zeros((self.dim, 1))
            msg[observation] = 1.
            self.outgoing[i] = msg
        self.next_step()  # copy into old_outgoing

    def prep_messages(self):
        """ Multiplies together incoming messages to make new outgoing.

        Message to neighbor i = elementwise product of every incoming
        message except the one from neighbor i. Skipped for observed
        variables (messages fixed by condition()) and for leaves.
        """
        if self.enabled and self.observed < 0 and len(self.neighbors) > 1:
            # switch reference for old messages
            self.next_step()
            for i in range(len(self.incoming)):
                # product over all incoming excluding index i
                # (np.prod over axis 0 of the stacked (dim, 1) vectors)
                others = self.incoming[:i] + self.incoming[i + 1:]
                self.outgoing[i] = np.prod(others, axis=0)
            # normalize once finished with all messages
            self.normalize_messages()
Here is the API for the FacNode class (in addition to the Node class API):
Fields
Methods
In [ ]:
class FacNode(Node):
    """ Factor node in factor graph.

    Stores the potential table P; P has one axis per connected
    variable, in the same order as the neighbors passed to __init__.
    """

    def __init__(self, P, nid, *args):
        """ P    -- numpy potential table, one axis per neighbor variable
            nid  -- integer id within the graph
            args -- the VarNode instances this factor connects to
        """
        super(FacNode, self).__init__(nid)
        self.P = P
        self.neighbors = list(args)  # list storing refs to variable nodes
        # num of edges
        n_neighbors = len(self.neighbors)
        n_dependencies = self.P.squeeze().ndim

        # init one uniform (dim, 1) message per edge, on both endpoints
        for v in self.neighbors:
            vdim = v.dim
            # init for factor
            self.incoming.append(np.ones((vdim, 1)))
            self.outgoing.append(np.ones((vdim, 1)))
            self.old_outgoing.append(np.ones((vdim, 1)))
            # init for variable --> ideally an add_neighbor method on VarNode
            v.neighbors.append(self)
            v.incoming.append(np.ones((vdim, 1)))
            v.outgoing.append(np.ones((vdim, 1)))
            v.old_outgoing.append(np.ones((vdim, 1)))

        # error check: P must have exactly one axis per connected variable
        assert (n_neighbors == n_dependencies), "Factor dimensions does not match size of domain."

    def reset(self):
        """ Reset all edge messages to uniform. """
        super(FacNode, self).reset()
        for i in range(len(self.incoming)):
            vdim = self.neighbors[i].dim
            self.incoming[i] = np.ones((vdim, 1))
            self.outgoing[i] = np.ones((vdim, 1))
            self.old_outgoing[i] = np.ones((vdim, 1))

    def prep_messages(self):
        """ Multiplies incoming messages w/ P to make new outgoing.

        Each incoming (dim, 1) message is first expanded (tile +
        rollaxis) to the full shape of P so everything can be
        multiplied elementwise; the message to neighbor i is then
        P times all incoming messages except i's, summed over every
        axis but axis i.
        """
        if self.enabled:
            # switch references for old messages
            self.next_step()
            n_messages = len(self.incoming)

            # do tiling in advance: expand each incoming message to P's shape
            for i in range(n_messages):
                # tiling size: 1 along variable i's own axis (moved to
                # front), full extent along every other axis of P
                next_shape = list(self.P.shape)
                del next_shape[i]
                next_shape.insert(0, 1)
                # expand incoming message to the right ndim to tile properly
                prep_shape = [1 for _ in next_shape]
                prep_shape[0] = self.incoming[i].shape[0]
                self.incoming[i].shape = prep_shape
                # tile, then roll variable i's axis back to position i
                self.incoming[i] = np.tile(self.incoming[i], next_shape)
                self.incoming[i] = np.rollaxis(self.incoming[i], 0, i + 1)

            # loop over subsets
            for i in range(n_messages):
                # product of P with all incoming messages except i's
                new_message = self.P
                for j in range(n_messages):
                    if j != i:
                        new_message = np.multiply(new_message, self.incoming[j])
                # sum over all vars except i: roll axis i to front,
                # then sum over all remaining axes
                new_message = np.rollaxis(new_message, i, 0)
                new_message = np.sum(new_message, tuple(range(1, n_messages)))
                new_message.shape = (new_message.shape[0], 1)
                # store new message
                self.outgoing[i] = new_message

            # normalize once finished with all messages
            self.normalize_messages()
Now let's take a look at how we build factor graphs out of variable and factor nodes.
Here is the API for the Graph class:
Fields
Methods
In [ ]:
class Graph(object):
    """ Bipartite factor graph.

    Variable nodes are keyed by name in `var`; factor nodes are kept
    in insertion order in `fac`. Runs loopy sum-product message
    passing and, for validation, exact brute-force enumeration.
    """

    def __init__(self):
        self.var = {}           # name -> VarNode
        self.fac = []           # FacNode list; list index == nid
        self.dims = []          # dims of variables, in order added
        self.converged = False  # set by sum_product()

    def add_var_node(self, name, dim):
        """ Create, register and return a VarNode with `dim` states. """
        new_id = len(self.var)
        new_var = VarNode(name, dim, new_id)
        self.var[name] = new_var
        self.dims.append(dim)
        return new_var

    def add_fac_node(self, P, *args):
        """ Create, register and return a FacNode over VarNodes `args`
        with potential table P (one axis per variable, in arg order).
        """
        new_id = len(self.fac)
        new_fac = FacNode(P, new_id, *args)
        self.fac.append(new_fac)
        return new_fac

    def disable_all(self):
        """ Disable all nodes in graph
            Useful for switching on small subnetworks
            of bayesian nets
        """
        for v in self.var.values():
            v.disable()
        for f in self.fac:
            f.disable()

    def reset(self):
        """ Reset messages to original state """
        for v in self.var.values():
            v.reset()
        for f in self.fac:
            f.reset()
        self.converged = False

    def sum_product(self, max_steps=500):
        """ This is the algorithm!
        Each time_step:
          take incoming messages and multiply together to produce
          outgoing for all nodes, then push outgoing to neighbors'
          incoming; check outgoing v. previous outgoing for convergence
        """
        time_step = 0
        while time_step < max_steps and not self.converged:
            time_step += 1
            print(time_step)
            for f in self.fac:
                # start with factor-to-variable
                # can send immediately since not sending to any other factors
                # BUG FIX: original called f.prepMessages()/f.sendMessages(),
                # which do not exist (methods are snake_case)
                f.prep_messages()
                f.send_messages()
            for v in self.var.values():
                # variable-to-factor
                v.prep_messages()
                v.send_messages()
            # converged iff every enabled node's messages stopped moving
            # (all() short-circuits, like the original early breaks)
            t = all(v.check_convergence() for v in self.var.values())
            if t:
                t = all(f.check_convergence() for f in self.fac)
            if t:  # we have convergence!
                self.converged = True
        # ran for max_steps cycles without converging
        if not self.converged:
            print("No convergence!")

    def marginals(self, max_steps=500):
        """ Return dictionary of all marginal distributions
            indexed by corresponding variable name
        """
        # Message pass
        self.sum_product(max_steps)
        marginals = {}
        for name, v in self.var.items():
            if v.enabled:  # only include enabled variables
                # unnormalized marginal = product of all incoming messages
                v_marginal = 1
                for message in v.incoming:
                    v_marginal = v_marginal * message
                # normalize to sum to one
                marginals[name] = v_marginal / np.sum(v_marginal)
        return marginals

    def brute_force(self):
        """ Brute force method. Only here for completeness.
        Don't use unless you want your code to take forever to produce results.
        Note: index corresponding to var determined by order added
        Problem: max number of dims in numpy is 32???
        Limit to enabled vars as work-around
        """
        # Figure out what is enabled and save dimensionality
        enabled_dims = []
        enabled_nids = []
        enabled_names = []
        enabled_observed = []
        for name, v in self.var.items():
            if v.enabled:
                enabled_nids.append(v.nid)
                enabled_names.append(name)
                enabled_observed.append(v.observed)
                # observed variables contribute a single (fixed) state
                enabled_dims.append(v.dim if v.observed < 0 else 1)
        # matrix over all joint configurations
        joint = np.zeros(enabled_dims)
        # fill it by looping over all configurations
        self.configuration_loop(joint, enabled_nids, enabled_observed, [])
        # normalize
        joint = joint / np.sum(joint)
        return {'joint': joint, 'names': enabled_names}

    def configuration_loop(self, joint, enabled_nids, enabled_observed, current_state):
        """ Recursive loop over all configurations
        Used for brute force computation
        joint -- matrix storing joint probabilities
        enabled_nids -- nids of enabled variables
        enabled_observed -- observed state per variable (-1 if unobserved)
        current_state -- current configuration of vars up to this point
        """
        current_var = len(current_state)
        if current_var != len(enabled_nids):
            # need to continue assembling current configuration
            if enabled_observed[current_var] < 0:
                # recurse over every state of this unobserved variable
                for i in range(joint.shape[current_var]):
                    current_state.append(i)
                    self.configuration_loop(joint, enabled_nids, enabled_observed, current_state)
                    # remove it for next value
                    current_state.pop()
            else:
                # do the same thing but only once w/ observed value!
                current_state.append(enabled_observed[current_var])
                self.configuration_loop(joint, enabled_nids, enabled_observed, current_state)
                current_state.pop()
        else:
            # full configuration assembled: multiply participating factors
            potential = 1.
            for f in self.fac:
                if f.enabled and all(x.enabled for x in f.neighbors):
                    # current values of this factor's variables, in order
                    args = [current_state[enabled_nids.index(x.nid)] for x in f.neighbors]
                    potential = potential * f.P[tuple(args)]
            # store it in joint; observed variables occupy a size-1 axis,
            # so their index collapses to 0
            ind = [current_state[i] if enabled_observed[i] < 0 else 0 for i in range(current_var)]
            joint[tuple(ind)] = potential

    @staticmethod
    def marginalize_brute(brute, var):
        """ Util for marginalizing over joint configuration arrays
            produced by brute_force
        """
        # BUG FIX: list(...) so `del` works when range is lazy (Python 3)
        sum_out = list(range(len(brute['names'])))
        del sum_out[brute['names'].index(var)]
        marg = np.sum(brute['joint'], tuple(sum_out))
        return marg / np.sum(marg)  # normalize to sum to one