In [1]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import warnings
from pybbn.generator.bbngenerator import generate_singly_bbn, generate_multi_bbn, convert_for_exact_inference
from pybbn.generator.bbngenerator import convert_for_drawing
from pybbn.pptc.inferencecontroller import InferenceController
# Seed NumPy's global random state once, up front, before anything is generated.
# NOTE(review): presumably the pybbn generators draw from this global state, in which
# case the order of the two generate_* calls below affects the result — confirm.
np.random.seed(37)
# Multi-connected network: generate it, convert it for exact inference, then derive
# the object that is later passed to nx.draw.
g, p = generate_multi_bbn(5, max_iter=5)
m_bbn = convert_for_exact_inference(g, p)
nx_multi_bbn = convert_for_drawing(m_bbn)
# Singly-connected network: same three steps. Note that `g` and `p` are rebound here,
# so from this point on they no longer refer to the multi-connected network.
g, p = generate_singly_bbn(5, max_iter=10)
s_bbn = convert_for_exact_inference(g, p)
nx_singly_bbn = convert_for_drawing(s_bbn)
Here, we visualize the generated multi- and singly-connected BBNs.
In [2]:
with warnings.catch_warnings():
    # Silence warnings emitted while drawing; the previous filters come back on exit.
    warnings.simplefilter('ignore')

    # One figure, two side-by-side axes: multi-connected on the left,
    # singly-connected on the right.
    fig, (multi_ax, singly_ax) = plt.subplots(1, 2, figsize=(10, 5))

    nx.draw(nx_multi_bbn, with_labels=True, font_weight='bold', ax=multi_ax)
    multi_ax.set_title('Multi-connected BBN')

    nx.draw(nx_singly_bbn, with_labels=True, font_weight='bold', ax=singly_ax)
    singly_ax.set_title('Singly-connected BBN')
Now, let's print out the probabilities of each node for the multi- and singly-connected BBNs.
In [3]:
# Run the inference controller on the multi-connected BBN, then for every node emit
# three lines: the node, its potential, and a '>' separator.
join_tree = InferenceController.apply(m_bbn)
for node in join_tree.get_bbn_nodes():
    print(node, join_tree.get_bbn_potential(node), '>', sep='\n')
In [4]:
# Run the inference controller on the singly-connected BBN, then for every node emit
# three lines: the node, its potential, and a '>' separator.
join_tree = InferenceController.apply(s_bbn)
for node in join_tree.get_bbn_nodes():
    print(node, join_tree.get_bbn_potential(node), '>', sep='\n')
In [5]:
def generate_graphs(n=10, prog='neato', multi=True):
    """Generate ``n`` random BBNs together with a drawing layout for each one.

    The size arguments for every graph are drawn from NumPy's global random
    state (``np.random.randint``), so the output is reproducible only when the
    caller has seeded NumPy beforehand.

    :param n: Number of graphs to generate.
    :param prog: Graphviz layout program passed to ``graphviz_layout``.
    :param multi: Truthy to generate with ``generate_multi_bbn``; falsy to
        generate with ``generate_singly_bbn``.
    :return: Dict keyed by ``0 .. n-1``. Each value is a dict with keys
        ``'g'`` and ``'p'`` (the pair returned by the generator), ``'bbn'``
        (the result of ``convert_for_exact_inference``) and ``'pos'`` (the
        Graphviz layout computed for ``g``).
    """
    # Plain truthiness instead of `multi is True`: the identity test sent truthy
    # flags that are not the `True` singleton (1, numpy booleans, ...) down the
    # singly-connected branch without any warning.
    generator = generate_multi_bbn if multi else generate_singly_bbn

    graphs = {}
    for i in range(n):
        # Keep this draw order (node count first, iteration count second) so that
        # seeded runs keep producing the same sequence of graphs.
        max_nodes = np.random.randint(3, 8)
        max_iter = np.random.randint(10, 100)

        g, p = generator(max_nodes, max_iter=max_iter)
        graphs[i] = {
            'g': g,
            'p': p,
            'bbn': convert_for_exact_inference(g, p),
            'pos': nx.nx_agraph.graphviz_layout(g, prog=prog),
        }
    return graphs
def draw_graphs(graphs, prefix):
    """Draw every graph in ``graphs`` on a two-column grid of subplots.

    The grid is sized from ``len(graphs)`` rather than fixed at 5x2, so any
    number of graphs can be drawn. Ten graphs still produce a 5x2 grid on a
    15x20 inch figure.

    :param graphs: Collection indexable by ``0 .. len(graphs) - 1`` (for
        example the dict returned by ``generate_graphs``). Each entry must
        provide ``'g'`` (the graph to draw) and ``'pos'`` (its node layout).
    :param prefix: Text placed in front of ``"Graph <k>"`` in each subplot title.
    :return: None.
    """
    n_cols = 2
    n_graphs = len(graphs)
    if n_graphs == 0:
        # Nothing to draw; do not ask matplotlib for a zero-row grid.
        return

    # Ceiling division: an odd number of graphs still gets a final, half-used row.
    n_rows = -(-n_graphs // n_cols)
    # Four inches of height per row reproduces the original 15x20 figure at 10 graphs.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))

    # np.ravel flattens both the 2-D grid and the 1-D array returned for a single row.
    for i, ax in enumerate(np.ravel(axes)):
        if i >= n_graphs:
            # Leftover cell when the graph count is odd: blank it out.
            ax.set_axis_off()
            continue
        graph = graphs[i]
        nx.draw(graph['g'], pos=graph['pos'], with_labels=True, ax=ax)
        ax.set_title('{} Graph {}'.format(prefix, i + 1))
# Build the two demo collections using generate_graphs' defaults (n=10, prog='neato').
# generate_graphs draws its size arguments from NumPy's global random state, so
# swapping these two calls would change which graphs land in each collection.
multi_graphs = generate_graphs(multi=True)
singly_graphs = generate_graphs(multi=False)
In [6]:
with warnings.catch_warnings():
    # Silence warnings emitted while drawing; the previous filters come back on exit.
    warnings.simplefilter('ignore')

    # One figure per collection, multi-connected first.
    for graphs, prefix in (
        (multi_graphs, 'Multi-connected'),
        (singly_graphs, 'Singly-connected'),
    ):
        draw_graphs(graphs, prefix)
In [ ]: