How to Build Factor Graphs in Python

Code forked from GitHub user rdlester.

His Readme file states:

Python implementation of Sum-product (aka Belief-Propagation) for discrete Factor Graphs.

See this paper for more details on the Factor Graph framework and the sum-product algorithm. This code was originally written as part of a grad student seminar taught by Erik Sudderth at Brown University; the seminar web page is an excellent resource for learning more about graphical models.

My purpose here is to go through rdlester's implementation in order to better understand the mechanics of factor graphs and the Sum-Product Algorithm.

The Node Class

The Node class uses the following API:

Constant

  • epsilon: threshold used to test convergence

Fields

  • enabled: is the node allowed to receive messages?
  • nid: an integer that uniquely identifies the node
  • neighbors: a list of other nodes connected to this one
  • incoming: a list of messages being received from other nodes
  • outgoing: a list of messages being sent to other nodes
  • old_outgoing: a list of messages sent to other nodes in the previous step

Methods

  • reset: sets the node's enabled state to True
  • disable: sets the node's enabled state to False
  • next_step: copies the data from outgoing into old_outgoing
  • normalize_messages: adjusts the outgoing messages so that their values sum to 1
  • receive_message: find the index of a node in neighbors, and set incoming[index] to the value of the received message
  • send_messages: send each message in outgoing to the corresponding neighbor (by calling neighbor.receive_message(self, m))
  • check_convergence: compare outgoing to old_outgoing, return True if every element-wise difference is smaller than epsilon

In [ ]:
import numpy as np


class Node(object):
    """ Superclass for graph nodes
    """
    epsilon = 10**(-4)
    
    def __init__(self, nid):
        self.enabled = True
        self.nid = nid
        self.neighbors = []
        self.incoming = []
        self.outgoing = []
        self.old_outgoing = []
    
    def reset(self):
        self.enabled = True
    
    def disable(self):
        self.enabled = False
    
    def enable(self):
        self.enabled = True
        for n in self.neighbors:
            # don't call enable() as it will recursively enable entire graph
            n.enabled = True
    
    def next_step(self):
        """ Used to have this line in prepMessages
            but it didn't work?
        """
        self.old_outgoing = self.outgoing[:]
    
    def normalize_messages(self):
        """ Normalize to sum to 1
        """
        self.outgoing = [x / np.sum(x) for x in self.outgoing]
    
    def receive_message(self, node, message):
        """ Places new message into correct location in new message list
        """
        if self.enabled:
            i = self.neighbors.index(node)
            self.incoming[i] = message
    
    def send_messages(self):
        """ Sends all outgoing messages
        """
        for i in xrange(0, len(self.outgoing)):
            self.neighbors[i].receive_message(self, self.outgoing[i])
    
    def check_convergence(self):
        """ Check if any messages have changed
        """
        if self.enabled:
            for i in xrange(0, len(self.outgoing)):
                # force the old message into the same shape as the new one before comparing
                self.old_outgoing[i].shape = self.outgoing[i].shape
                delta = np.absolute(self.outgoing[i] - self.old_outgoing[i])
                if (delta > Node.epsilon).any():  # if there has been change
                    return False
            return True
        else:
            # Always return True if disabled to avoid interrupting check
            return True
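
Before moving on to the subclasses, here is a minimal hand-wired sketch (not part of the original code) of two bare Node objects exchanging a single message. Normally the FacNode constructor fills in the neighbor lists and message slots; here we do it by hand just to exercise send_messages, receive_message, and check_convergence.

In [ ]:
# hand-wire two nodes so they can exchange a single 2-dimensional message
a = Node(0)
b = Node(1)
a.neighbors.append(b)
b.neighbors.append(a)
for n in (a, b):
    n.incoming.append(np.ones((2, 1)))
    n.outgoing.append(np.ones((2, 1)))
    n.old_outgoing.append(np.ones((2, 1)))

a.outgoing[0] = np.array([[0.3], [0.7]])  # the message a wants to send to b
a.send_messages()

print b.incoming[0]          # [[ 0.3] [ 0.7]] -- the message arrived at b
print a.check_convergence()  # False: a's outgoing now differs from old_outgoing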

Variable Nodes: The VarNode Class

Factor graphs connect alternating layers of variable nodes and factor nodes: every edge joins a variable to a factor, so the graph is bipartite. Each factor node holds one local function (a factor) of the overall function we are trying to compute, and the variable nodes hold the values that are either inputs or outputs of those functions.
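
For example, if a joint distribution factors as p(a, b, c) ∝ f1(a, b) · f2(b, c), the corresponding factor graph has three variable nodes (a, b, c) and two factor nodes (f1, f2); f1 connects to a and b, f2 connects to b and c, and every edge joins a factor to one of its arguments.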

Here is the API for the VarNode class (in addition to the Node class API):

Fields

  • name: the name of the variable
  • dim: the dimensionality of the variable, i.e. the number of discrete values it can take
  • observed: equals -1 if the variable is hidden, otherwise it is the index (between 0 and dim - 1) of the variable's observed value

Room for improvement: how do we replace observations with continuous values?

Methods

  • reset: extends Node.reset by re-initializing incoming, outgoing, and old_outgoing to vectors of 1's and resetting observed to -1
  • condition: set the variable's observed value, set every outgoing message to an indicator vector on that value, and copy the result into old_outgoing
  • prep_messages: if no observation has been made, then for each neighbor i multiply together all incoming messages except the one from i, store the product in outgoing[i], and normalize

To really understand how this all fits together, debug the prep_messages function and observe what happens in each step.
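
As a warm-up, here is a stand-alone sketch of that leave-one-out product, using three made-up incoming messages for a variable with dim = 2. This is the same computation VarNode.prep_messages performs, just written outside the class:

In [ ]:
# three hypothetical incoming messages for a variable with dim = 2
incoming = [np.array([[0.2], [0.8]]),
            np.array([[0.5], [0.5]]),
            np.array([[0.9], [0.1]])]

outgoing = []
for i in range(len(incoming)):
    others = incoming[:i] + incoming[i + 1:]  # every message except the i-th
    msg = reduce(np.multiply, others)         # elementwise product
    outgoing.append(msg / np.sum(msg))        # normalize to sum to 1

print outgoing[0]  # product of messages 1 and 2, normalized: [[ 0.9] [ 0.1]]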


In [ ]:
class VarNode(Node):
    """ Variable node in factor graph
    """
    def __init__(self, name, dim, nid):
        super(VarNode, self).__init__(nid)
        self.name = name
        self.dim = dim
        self.observed = -1  # only >= 0 if variable is observed
    
    def reset(self):
        super(VarNode, self).reset()
        size = range(0, len(self.incoming))
        self.incoming = [np.ones((self.dim, 1)) for i in size]
        self.outgoing = [np.ones((self.dim, 1)) for i in size]
        self.old_outgoing = [np.ones((self.dim, 1)) for i in size]
        self.observed = -1
    
    def condition(self, observation):
        """ Condition on observing certain value
        """
        self.enable()
        self.observed = observation
        # set messages (won't change)
        for i in xrange(0, len(self.outgoing)):
            self.outgoing[i] = np.zeros((self.dim, 1))
            self.outgoing[i][self.observed] = 1.
        self.next_step()  # copy into old_outgoing
    
    def prep_messages(self):
        """ Multiplies together incoming messages to make new outgoing
        """        
        # compute new messages if no observation has been made
        if self.enabled and self.observed < 0 and len(self.neighbors) > 1:
            # switch reference for old messages
            self.next_step()
            for i in xrange(0, len(self.incoming)):
                # multiply together all excluding message at current index
                curr = self.incoming[:]
                del curr[i]
                self.outgoing[i] = reduce(np.multiply, curr)
        
            # normalize once finished with all messages
            self.normalize_messages()
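
Here is a quick hand-wired sketch of condition. (Again, the FacNode constructor would normally allocate the message slots; here we fake a single neighbor so the VarNode can be used on its own.)

In [ ]:
x = VarNode('x', 3, 0)

# fake one neighbor and its message slots by hand
dummy = Node(1)
x.neighbors.append(dummy)
x.incoming.append(np.ones((3, 1)))
x.outgoing.append(np.ones((3, 1)))
x.old_outgoing.append(np.ones((3, 1)))

x.condition(2)       # observe x = 2
print x.outgoing[0]  # [[ 0.] [ 0.] [ 1.]] -- an indicator on the observed value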

Factor Nodes: The FacNode Class

Here is the API for the FacNode class (in addition to the Node class API):

Fields

  • P: a numpy array holding the factor's potential table; axis i corresponds to the i-th neighboring variable and has length equal to that variable's dim (so P.ndim must equal the number of neighbors)
  • neighbors: factor nodes are initialized with a list of variable node neighbors (variable nodes just get an empty list)

Methods

  • initialization: after storing its own values, a FacNode adds itself to each neighboring VarNode's neighbor list and allocates the corresponding message slots on both sides

In [ ]:
class FacNode(Node):
    """ Factor node in factor graph
    """
    def __init__(self, P, nid, *args):
        super(FacNode, self).__init__(nid)
        self.P = P
        self.neighbors = list(args)  # list storing refs to variable nodes
        
        # num of edges
        n_neighbors = len(self.neighbors)
        n_dependencies = self.P.squeeze().ndim
        
        # init messages
        for i in xrange(0, n_neighbors):
            v = self.neighbors[i]
            vdim = v.dim
            
            # init for factor
            self.incoming.append(np.ones((vdim, 1)))
            self.outgoing.append(np.ones((vdim, 1)))
            self.old_outgoing.append(np.ones((vdim, 1)))
            
            # init for variable  --> this should be done in an add_neighbor function in the VarNode class!
            v.neighbors.append(self)
            v.incoming.append(np.ones((vdim, 1)))
            v.outgoing.append(np.ones((vdim, 1)))
            v.old_outgoing.append(np.ones((vdim, 1)))
        
        # error check
        assert (n_neighbors == n_dependencies), "Factor dimensions do not match the number of neighboring variables."
    
    def reset(self):
        super(FacNode, self).reset()
        for i in xrange(0, len(self.incoming)):
            self.incoming[i] = np.ones((self.neighbors[i].dim, 1))
            self.outgoing[i] = np.ones((self.neighbors[i].dim, 1))
            self.old_outgoing[i] = np.ones((self.neighbors[i].dim, 1))
    
    def prep_messages(self):
        """ Multiplies incoming messages w/ P to make new outgoing
        """
        if self.enabled:
            # switch references for old messages
            self.next_step()
        
            n_messages = len(self.incoming)
        
            # do tiling in advance
            # roll axes to match shape of newMessage after
            for i in xrange(0, n_messages):
                # find tiling size
                next_shape = list(self.P.shape)
                del next_shape[i]
                next_shape.insert(0, 1)
                # need to expand incoming message to correct num of dims to tile properly
                prep_shape = [1 for x in next_shape]
                prep_shape[0] = self.incoming[i].shape[0]
                self.incoming[i].shape = prep_shape
                # tile and roll
                self.incoming[i] = np.tile(self.incoming[i], next_shape)
                self.incoming[i] = np.rollaxis(self.incoming[i], 0, i+1)
            
            # loop over subsets
            for i in xrange(0, n_messages):
                curr = self.incoming[:]
                del curr[i]
                new_message = reduce(np.multiply, curr, self.P)
                    
                # sum over all vars except i!
                # roll axis i to front then sum over all other axes
                new_message = np.rollaxis(new_message, i, 0)
                new_message = np.sum(new_message, tuple(range(1, n_messages)))
                new_message.shape = (new_message.shape[0], 1)
                    
                #store new message
                self.outgoing[i] = new_message
        
            # normalize once finished with all messages
            self.normalize_messages()
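
Here is a small sketch of a single factor over two variables. The 2x3 potential table is made up; its rows index a (dim 2) and its columns index b (dim 3). Because all incoming messages start out uniform, the first round of factor-to-variable messages is just the normalized row and column sums of P:

In [ ]:
va = VarNode('a', 2, 0)
vb = VarNode('b', 3, 1)

P = np.array([[0.1, 0.3, 0.6],
              [0.5, 0.4, 0.1]])
f = FacNode(P, 0, va, vb)   # wires itself to va and vb and allocates message slots

f.prep_messages()
f.send_messages()

print va.incoming[0]  # normalized row sums of P    -> [[ 0.5] [ 0.5]]
print vb.incoming[0]  # normalized column sums of P -> [[ 0.3] [ 0.35] [ 0.35]]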

The Graph Class

Now let's take a look at how we build factor graphs out of variable and factor nodes.

Here is the API for the Graph class:

Fields

  • var: a dictionary of VarNodes, keyed by variable name
  • fac: a list of FacNodes
  • dims: a list of the variables' dimensionalities, in the order they were added
  • converged: a boolean indicator of whether the nodes' messages have converged yet

Methods

  • add_var_node: add a VarNode to the graph, assigning it a name and a dimensionality
  • add_fac_node: add a FacNode to the graph, passing it a potential array P and the VarNodes it depends on
  • disable_all: call every node's disable method
  • reset: call every node's reset method and set converged to False
  • sum_product: repeatedly call prep_messages and send_messages on every FacNode and then every VarNode, until the messages converge or max_steps is reached
  • marginals: run sum_product, then return a dictionary of marginal distributions (one normalized vector per variable) indexed by variable name
  • brute_force: compute the joint distribution over all enabled variables by exhaustive enumeration; useful only for checking results on small graphs
  • configuration_loop: recursive helper used by brute_force to visit every configuration of the enabled variables
  • marginalize_brute: static utility that sums the brute-force joint over every variable except the named one

In [ ]:
class Graph(object):
    
    def __init__(self):
        self.var = {}
        self.fac = []
        self.dims = []
        self.converged = False
        
    def add_var_node(self, name, dim):
        new_id = len(self.var)
        new_var = VarNode(name, dim, new_id)
        self.var[name] = new_var
        self.dims.append(dim)
        
        return new_var
    
    def add_fac_node(self, P, *args):
        new_id = len(self.fac)
        new_fac = FacNode(P, new_id, *args)
        self.fac.append(new_fac)
        
        return new_fac
    
    def disable_all(self):
        """ Disable all nodes in graph
            Useful for switching on small subnetworks
            of bayesian nets
        """
        for k, v in self.var.iteritems():
            v.disable()
        for f in self.fac:
            f.disable()
    
    def reset(self):
        """ Reset messages to original state
        """
        for k, v in self.var.iteritems():
            v.reset()
        for f in self.fac:
            f.reset()
        self.converged = False
    
    def sum_product(self, max_steps=500):
        """ This is the algorithm!
            Each time_step:
            take incoming messages and multiply together to produce outgoing for all nodes
            then push outgoing to neighbors' incoming
            check outgoing v. previous outgoing to check for convergence
        """
        # loop to convergence
        time_step = 0
        while time_step < max_steps and not self.converged:  # run for max_steps cycles
            time_step += 1
            print time_step
            
            for f in self.fac:
                # start with factor-to-variable
                # can send immediately since not sending to any other factors
                f.prep_messages()
                f.send_messages()
            
            for k, v in self.var.iteritems():
                # variable-to-factor
                v.prep_messages()
                v.send_messages()
            
            # check for convergence
            t = True
            for k, v in self.var.iteritems():
                t = t and v.check_convergence()
                if not t:
                    break
            if t:        
                for f in self.fac:
                    t = t and f.check_convergence()
                    if not t:
                        break
            
            if t:  # we have convergence!
                self.converged = True
        
        # if we ran for max_steps steps and still did not converge
        if not self.converged:
            print "No convergence!"
        
    def marginals(self, max_steps=500):
        """ Return dictionary of all marginal distributions
            indexed by corresponding variable name
        """
        # Message pass
        self.sum_product(max_steps)
        
        marginals = {}
        # for each var
        for k, v in self.var.iteritems():
            if v.enabled:  # only include enabled variables
                # multiply together messages
                v_marginal = 1
                for i in xrange(0, len(v.incoming)):
                    v_marginal = v_marginal * v.incoming[i]
            
                # normalize
                n = np.sum(v_marginal)
                v_marginal = v_marginal / n
            
                marginals[k] = v_marginal
        
        return marginals
    
    def brute_force(self):
        """ Brute force method. Only here for completeness.
            Don't use unless you want your code to take forever to produce results.
            Note: index corresponding to var determined by order added
            Problem: numpy arrays support at most 32 dimensions
            Limit the enumeration to enabled vars as a work-around
        """
        # Figure out what is enabled and save dimensionality
        enabled_dims = []
        enabled_nids = []
        enabled_names = []
        enabled_observed = []
        for k, v in self.var.iteritems():
            if v.enabled:
                enabled_nids.append(v.nid)
                enabled_names.append(k)
                enabled_observed.append(v.observed)
                if v.observed < 0:
                    enabled_dims.append(v.dim)
                else:
                    enabled_dims.append(1)
        
        # initialize matrix over all joint configurations
        joint = np.zeros(enabled_dims)
        
        # loop over all configurations
        self.configuration_loop(joint, enabled_nids, enabled_observed, [])
        
        # normalize
        joint = joint / np.sum(joint)
        return {'joint': joint, 'names': enabled_names}
    
    def configuration_loop(self, joint, enabled_nids, enabled_observed, current_state):
        """ Recursive loop over all configurations
            Used for brute force computation
            joint - matrix storing joint probabilities
            enabled_nids - nids of enabled variables
            enabled_observed - observed variables (if observed!)
            current_state - list storing current configuration of vars up to this point
        """
        current_var = len(current_state)
        if current_var != len(enabled_nids):
            # need to continue assembling current configuration
            if enabled_observed[current_var] < 0:
                for i in xrange(0, joint.shape[current_var]):
                    # add new variable value to state
                    current_state.append(i)
                    self.configuration_loop(joint, enabled_nids, enabled_observed, current_state)
                    # remove it for next value
                    current_state.pop()
            else:
                # do the same thing but only once w/ observed value!
                current_state.append(enabled_observed[current_var])
                self.configuration_loop(joint, enabled_nids, enabled_observed, current_state)
                current_state.pop()
                
        else:
            # compute value for current configuration
            potential = 1.
            for f in self.fac:
                if f.enabled and False not in [x.enabled for x in f.neighbors]:
                    # figure out which vars are part of factor
                    # then get current values of those vars in correct order
                    args = [current_state[enabled_nids.index(x.nid)] for x in f.neighbors]
                
                    # get value and multiply in
                    potential = potential * f.P[tuple(args)]
            
            # now add it to joint after correcting state for observed nodes
            ind = [current_state[i] if enabled_observed[i] < 0 else 0 for i in range(0, current_var)]
            joint[tuple(ind)] = potential

    @staticmethod
    def marginalize_brute(brute, var):
        """ Util for marginalizing over joint configuration arrays produced by brute_force
        """
        sum_out = range(0, len(brute['names']))
        del sum_out[brute['names'].index(var)]
        marg = np.sum(brute['joint'], tuple(sum_out))
        return marg / np.sum(marg)  # normalize to sum to one
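
Finally, an end-to-end sketch on a tiny graph: two variables joined by the same made-up 2x3 factor used above. The sum-product marginals should match the brute-force enumeration (note that sum_product prints its step counter as it runs):

In [ ]:
g = Graph()
va = g.add_var_node('a', 2)
vb = g.add_var_node('b', 3)
g.add_fac_node(np.array([[0.1, 0.3, 0.6],
                         [0.5, 0.4, 0.1]]), va, vb)

marg = g.marginals()
print marg['a']  # [[ 0.5] [ 0.5]]
print marg['b']  # [[ 0.3] [ 0.35] [ 0.35]]

# sanity check against exhaustive enumeration
brute = g.brute_force()
print Graph.marginalize_brute(brute, 'a')  # [ 0.5  0.5 ]
print Graph.marginalize_brute(brute, 'b')  # [ 0.3  0.35  0.35]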