Imports


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from spn.factory import SpnFactory

from spn.linked.spn import Spn as SpnLinked

from spn.linked.layers import SumLayer as SumLayerLinked
from spn.linked.layers import ProductLayer as ProductLayerLinked
from spn.linked.layers import CategoricalIndicatorLayer
from spn.linked.layers import CategoricalSmoothedLayer

from spn.linked.nodes import SumNode
from spn.linked.nodes import ProductNode
from spn.linked.nodes import CategoricalIndicatorNode
from spn.linked.nodes import CategoricalSmoothedNode
from spn.linked.nodes import CLTreeNode

import time
from spn import MARG_IND
from spn import LOG_ZERO
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pprint import pprint
import numpy
import math
import dataset
import logging
import algo.learnspn
from algo.dataslice import DataSlice
import spn.linked.tests.test_spn as test
from pyomo.opt import ProblemFormat

import copy
import sys

MAXI = "m"  # sentinel marking a variable that is still free (not yet branched on)

Constructing an SPN

- by hand

This first one, over four binary variables:

In [7]:
dicts = [{'var': 0, 'freqs': [91, 1]},
         {'var': 1, 'freqs': [1, 19]},
         {'var': 2, 'freqs': [7, 76]},
         {'var': 3, 'freqs': [69, 2]}]

In [8]:
def create_valid_toy_spn():
    # root layer
    whole_scope = frozenset({0, 1, 2, 3})
    root_node = SumNode(var_scope=whole_scope)
    root_layer = SumLayerLinked([root_node])

    # prod layer
    prod_node_1 = ProductNode(var_scope=whole_scope)
    prod_node_2 = ProductNode(var_scope=whole_scope)
    prod_layer_1 = ProductLayerLinked([prod_node_1, prod_node_2])

    root_node.add_child(prod_node_1, 0.54)
    root_node.add_child(prod_node_2, 0.46)

    # sum layer
    scope_1 = frozenset({0, 1})
    scope_2 = frozenset({2})
    scope_3 = frozenset({3})
    scope_4 = frozenset({2, 3})

    sum_node_1 = SumNode(var_scope=scope_1)
    sum_node_2 = SumNode(var_scope=scope_2)
    sum_node_3 = SumNode(var_scope=scope_3)
    sum_node_4 = SumNode(var_scope=scope_4)

    prod_node_1.add_child(sum_node_1)
    prod_node_1.add_child(sum_node_2)
    prod_node_1.add_child(sum_node_3)

    prod_node_2.add_child(sum_node_1)
    prod_node_2.add_child(sum_node_4)

    sum_layer_1 = SumLayerLinked([sum_node_1, sum_node_2,
                                  sum_node_3, sum_node_4])

    # another product layer
    prod_node_3 = ProductNode(var_scope=scope_1)
    prod_node_4 = ProductNode(var_scope=scope_1)

    prod_node_5 = ProductNode(var_scope=scope_4)
    prod_node_6 = ProductNode(var_scope=scope_4)

    sum_node_1.add_child(prod_node_3, 0.8)
    sum_node_1.add_child(prod_node_4, 0.2)

    sum_node_4.add_child(prod_node_5, 0.5)
    sum_node_4.add_child(prod_node_6, 0.5)

    prod_layer_2 = ProductLayerLinked([prod_node_3, prod_node_4,
                                       prod_node_5, prod_node_6])

    # last sum one
    scope_5 = frozenset({0})
    scope_6 = frozenset({1})

    sum_node_5 = SumNode(var_scope=scope_5)
    sum_node_6 = SumNode(var_scope=scope_6)
    sum_node_7 = SumNode(var_scope=scope_5)
    sum_node_8 = SumNode(var_scope=scope_6)

    sum_node_9 = SumNode(var_scope=scope_2)
    sum_node_10 = SumNode(var_scope=scope_3)
    sum_node_11 = SumNode(var_scope=scope_2)
    sum_node_12 = SumNode(var_scope=scope_3)

    prod_node_3.add_child(sum_node_5)
    prod_node_3.add_child(sum_node_6)
    prod_node_4.add_child(sum_node_7)
    prod_node_4.add_child(sum_node_8)

    prod_node_5.add_child(sum_node_9)
    prod_node_5.add_child(sum_node_10)
    prod_node_6.add_child(sum_node_11)
    prod_node_6.add_child(sum_node_12)

    sum_layer_2 = SumLayerLinked([sum_node_5, sum_node_6,
                                  sum_node_7, sum_node_8,
                                  sum_node_9, sum_node_10,
                                  sum_node_11, sum_node_12])

    # input layer
    vars = [2, 2, 2, 2]
    input_layer = CategoricalSmoothedLayer(vars=vars, node_dicts=dicts)
    last_sum_nodes = [sum_node_2, sum_node_3,
                      sum_node_5, sum_node_6,
                      sum_node_7, sum_node_8,
                      sum_node_9, sum_node_10,
                      sum_node_11, sum_node_12]
    for sum_node in last_sum_nodes:
        (var_scope,) = sum_node.var_scope
        for input_node in input_layer.nodes():
            if input_node.var == var_scope:
                sum_node.add_child(input_node, 1.0)

    spn = SpnLinked(input_layer=input_layer,
                    layers=[sum_layer_2, prod_layer_2,
                            sum_layer_1, prod_layer_1,
                            root_layer])

    # print(spn)
    return spn

In [181]:
SPN = create_valid_toy_spn()
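
A quick sanity check, assuming exact_eval accepts a complete assignment and returns log-space values (it is used this way in the brute-force cell further below):

In [ ]:
# evaluate one complete assignment of the four binary variables
numpy.exp(SPN.exact_eval(numpy.array([0, 1, 1, 0])))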

or this smaller one, over two binary variables:


In [3]:
def create_valid_toy_spn2():
    # root layer
    whole_scope = frozenset({0, 1})
    root_node = SumNode(var_scope=whole_scope)
    root_layer = SumLayerLinked([root_node])

    # prod layer
    prod_node_1 = ProductNode(var_scope=whole_scope)
    prod_node_2 = ProductNode(var_scope=whole_scope)
    prod_node_3 = ProductNode(var_scope=whole_scope)
    prod_layer = ProductLayerLinked([prod_node_1, prod_node_2, prod_node_3])

    root_node.add_child(prod_node_1, 0.2)
    root_node.add_child(prod_node_2, 0.5)
    root_node.add_child(prod_node_3, 0.3)

    # input layer
    node1 = CategoricalSmoothedNode(var=1, var_values=[2], freqs=[0, 1], alpha=0)
    node2 = CategoricalSmoothedNode(var=0, var_values=[2, 2], freqs=[6, 4], alpha=0)
    node3 = CategoricalSmoothedNode(var=1, var_values=[2, 2], freqs=[8, 2], alpha=0)
    node4 = CategoricalSmoothedNode(var=0, var_values=[], freqs=[1, 9], alpha=0)
    
    prod_node_1.add_child(node1)
    prod_node_1.add_child(node2)
    prod_node_2.add_child(node2)
    prod_node_2.add_child(node3)
    prod_node_3.add_child(node3)
    prod_node_3.add_child(node4)
    
    input_layer = CategoricalSmoothedLayer([node1, node2, node3, node4])

    spn = SpnLinked(input_layer=input_layer,
                    layers=[prod_layer,
                            root_layer])

    # print(spn)
    return spn

In [191]:
SPN = create_valid_toy_spn2()

- by learning on a dataset


In [10]:
# Choose the dataset

In [1]:
dataset_name = 'nltcs'

with LearnSPN (random=False) or RandomLearnSPN (random=True)


In [146]:
seed = 8374849
numpy.random.seed(seed)
rand_gen = numpy.random.RandomState(seed)

def test_learnspn_oneshot(dataset_name, random=False):

    logging.basicConfig(level=logging.INFO)
    #
    # loading a very simple dataset
    train, valid, test = dataset.load_train_val_test_csvs(dataset_name)
    train_feature_vals = [2 for i in range(train.shape[1])]
    print('Loaded dataset', dataset_name)

    #
    # initing the algo
    if not random:
        learnSPN = algo.learnspn.LearnSPN(rand_gen=rand_gen)
    else:
        learnSPN = algo.learnspn.RandomLearnSPN(rand_gen=rand_gen)

    #
    # start learning
    spn = learnSPN.fit_structure(train,
                                 train_feature_vals)
    return spn

In [147]:
#ignoring warnings
import warnings; warnings.simplefilter('ignore')

In [148]:
SPN = test_learnspn_oneshot(dataset_name)


Loaded dataset nltcs

Getting info on the SPN


In [9]:
def print_info(SPN):
    print("is the SPN complete and decomposable? : {}".format(SPN.is_complete() and SPN.is_decomposable()))
    print("{} layers in the SPN (input layer included)".format(len(SPN._layers) + 1))
    print("structured like {}".format(([SPN._input_layer._n_nodes] + [layer._n_nodes for layer in SPN._layers])[::-1]))
    print("{} variables in its scope".format(len(SPN._root_layer._nodes[0].var_scope)))

In [142]:
print_info(SPN)


is the SPN complete and decomposable? : True
3 layers in the SPN (input layer included)
structured like [1, 3, 4]
2 variables in its scope

The Poon algorithm (max-product algorithm)
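
For reference, a minimal standalone sketch of the max-product idea (replace every sum node by a weighted max, then backtrack the argmax choices down to the leaves). This is a hypothetical toy, independent of the spn library's MPE_poon_eval:

In [ ]:
# Toy max-product pass over a hand-rolled SPN: a root sum over two
# products of Bernoulli leaves (all names here are made up for the sketch)
leaf_probs = {'l0': 0.9, 'l1': 0.2,   # leaves of product 0 (vars 0 and 1)
              'l2': 0.1, 'l3': 0.8}   # leaves of product 1 (vars 0 and 1)
weights = [0.6, 0.4]                  # root sum weights

def leaf_max(p):
    # max over the leaf's two states, and the maximising state
    return (p, 1) if p >= 0.5 else (1 - p, 0)

# bottom-up: product nodes multiply their children's maxima
m0, m1 = leaf_max(leaf_probs['l0']), leaf_max(leaf_probs['l1'])
m2, m3 = leaf_max(leaf_probs['l2']), leaf_max(leaf_probs['l3'])
prod_vals = [m0[0] * m1[0], m2[0] * m3[0]]

# at the root, the sum becomes a weighted max
branch = int(numpy.argmax([w * v for w, v in zip(weights, prod_vals)]))

# top-down backtracking recovers the (approximate) MPE assignment
mpe = [m0[1], m1[1]] if branch == 0 else [m2[1], m3[1]]
print(branch, mpe)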


In [93]:
start_time = time.time()
poon_max, poon_evid = SPN1.MPE_poon_eval()
timing = round(time.time() - start_time, 6)


---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-93-66a679934970> in <module>()
      1 start_time = time.time()
----> 2 poon_max1, poon_evid1 = SPN1.MPE_poon_eval()
      3 timing = round(time.time() - start_time, 6)
      4 calculation_time['poon'].append(timing)

/Users/leonardhussenot/Documents/Informatique/USP/spn/spn/linked/spn.py in MPE_poon_eval(self)
    677             evidence.append(evidence_dic[i])
    678             i+=1
--> 679         poon_max = self.exact_eval(numpy.array(evidence))
    680         print("the maximum found by Poon algo is {}".format(poon_max))
    681         return poon_max, evidence

/Users/leonardhussenot/Documents/Informatique/USP/spn/spn/linked/spn.py in exact_eval(self, input)
    146 
    147     def exact_eval(self, input):
--> 148         self._input_layer.eval(input)
    149         for layer in self._layers:
    150             layer.exact_eval()

/Users/leonardhussenot/Documents/Informatique/USP/spn/spn/linked/layers.py in eval(self, input)
    275         for node in self._nodes:
    276             # get the observed value
--> 277             obs = input[node.var]
    278             # and eval the node
    279             node.eval(obs)

IndexError: index 1 is out of bounds for axis 0 with size 1

In [160]:
timing


Out[160]:
0.008323

Branch & Bound


In [13]:
%matplotlib inline

## LAZY version
def branch_and_bound(spn, hot_start=True, heuristic="frequency", verbose=False):

    # Find the branching order (from the variable appearing most frequently
    # in the leaf nodes to the least frequent one)
    scope = spn._root_layer._nodes[0].var_scope
    assert len(scope) == max(scope) + 1
    var_presence = [0] * len(scope)
    for node in spn._input_layer._nodes:
        probs = numpy.exp(node._var_probs)
        if heuristic == "frequency":
            var_presence[node.var] += 1
        elif heuristic == "conflict":
            var_presence[node.var] += abs(probs[0] - probs[1])
        else:
            print("unimplemented heuristic")
            return 0
    branching_order = [i[0] for i in sorted(enumerate(var_presence), key=lambda x: -x[1])]

    # the root node of the B&B
    root_evidence = [MAXI] * len(scope)

    # the stack
    stack = [(root_evidence, 0)]

    # the initial incumbent
    if hot_start:
        max_so_far, realized_by = spn.MPE_poon_eval()
    else:
        max_so_far, realized_by = 0, []

    # statistics on the B&B
    nbr_cut_off = 0
    nbr_nodes_cut = 0
    nbr_nodes_visited = 0
    nbr_nodes = 2**(len(scope) + 1) - 1
    current_bound = []
    current_max = []

    # start counting time
    start_time = time.time()
    # B&B
    while stack:

        try:
            # current node
            evidence, branching_level = stack.pop()
            nbr_nodes_visited += 1

            # bounding (done before the verbose report so that `bound` is defined)
            bound, is_a_feasible_solution, evidence_in_case_yes = spn.bound_eval(evidence)

            if verbose and nbr_nodes_visited % 15 == 0:
                print("{}% of nodes passed (cut or visited)".format(100 * (nbr_nodes_cut + nbr_nodes_visited) / nbr_nodes))
                print("current level explored : {}".format(branching_level))
                print("current max_so_far is {}".format(max_so_far))
                print("current bound is {}".format(bound))

            if is_a_feasible_solution and bound > max_so_far:
                max_so_far, realized_by = bound, evidence_in_case_yes

            if bound <= max_so_far:  # the bound cannot beat the incumbent: cut the subproblem
                nbr_cut_off += 1
                nbr_nodes_cut += 2**(len(scope) - branching_level + 1) - 2

            elif branching_level == len(scope):
                assert is_a_feasible_solution  # a feasible solution not better than max_so_far: discard it
            else:
                # next branching is going to be on this variable
                var_to_branch = branching_order[branching_level]

                branch0 = copy.deepcopy(evidence)
                branch0[var_to_branch] = 0

                branch1 = evidence
                branch1[var_to_branch] = 1

                stack.append((branch0, branching_level + 1))
                stack.append((branch1, branching_level + 1))

            current_max.append(max_so_far)
            current_bound.append(bound)

        except KeyboardInterrupt:
            # return a result of the same shape as the normal exit
            return max_so_far, realized_by, (current_bound, current_max), None
    timing = round(time.time() - start_time, 6)
    print("calculation time was {}".format(timing))
    print("--------------------------------------------------------------")
    print("B&B INFO :")
    print("{} cut off during B&B".format(nbr_cut_off))
    print("{}% of nodes cut".format(round(100 * nbr_nodes_cut / nbr_nodes, 4)))
    print("{} :  number of nodes visited".format(nbr_nodes_visited))
    print("{} :  number of nodes cut".format(nbr_nodes_cut))
    print("{} :  total number of nodes".format(nbr_nodes))
    print("--------------------------------------------------------------")
    print("MPE INFO :")
    print("{} : MPE found".format(max_so_far))
    pprint("realized by evidence: {}".format(realized_by))

    return max_so_far, realized_by, (current_bound, current_max), timing
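
The same stack-based scheme on a toy separable objective, to make the bounding step concrete: maximise f(x) = prod_i p_i(x_i) over x in {0,1}^n, using the product of per-variable maxima as an upper bound for the free variables. This is a hypothetical standalone toy; in the function above the bound comes from spn.bound_eval:

In [ ]:
p = numpy.array([[0.3, 0.7], [0.9, 0.1], [0.4, 0.6]])  # p[i, v] = p_i(x_i = v)
n = len(p)

def toy_bound(partial):
    # fixed variables contribute their value, free ones their best case
    b = 1.0
    for i, v in enumerate(partial):
        b *= p[i].max() if v is None else p[i, v]
    return b

best, best_x = 0.0, None
stack = [[None] * n]
while stack:
    x = stack.pop()
    b = toy_bound(x)
    if b <= best:
        continue                 # prune the subtree rooted at x
    if None not in x:
        best, best_x = b, x      # complete assignment: new incumbent
        continue
    i = x.index(None)            # branch on the first free variable
    for v in (0, 1):
        y = list(x)
        y[i] = v
        stack.append(y)

print(best, best_x)              # 0.7 * 0.9 * 0.6 ≈ 0.378 at [1, 0, 1]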

In [14]:
def plot_evolution(curs, names):

    plt.figure(1, figsize=(20, 10))
    i = 1
    for cur in curs:
        ax = plt.subplot(2, 2, i)
        ax.set_title(names[i - 1])
        for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontsize(17)

        x = numpy.arange(len(cur[0]))
        plt.plot(x, cur[0], color="r", linewidth=1)   # bound at each visited node
        plt.plot(x, cur[1], color="b", linewidth=2)   # incumbent (max_so_far)
        i += 1

    plt.show()

SPN : computing the MPE with B&B


In [ ]:
# hot_start=False, heuristic="frequency" (do NOT run on big SPNs)

In [59]:
bb_max1, bb_evid1, curs0, t0= branch_and_bound(SPN1, hot_start=False, heuristic="frequency")
print("-----------")
pprint("we enhanced the maximum by a factor {}".format(bb_max1/poon_max1))


calculation time was 0.000185
--------------------------------------------------------------
B&B INFO :
3 cut off during B&B
28.5714% of nodes cut
5 :  number of nodes visited
2 :  number of nodes cut
7 :  total number of nodes
--------------------------------------------------------------
MPE INFO :
0.3999999999999998 : MPE found
'realized by evidence: [1, 0]'
-----------
'we enhanced the maximum by a factor 1.3333333333333337'

In [ ]:
# hot_start=False, heuristic="conflict" (do NOT run on big SPNs)

In [60]:
bb_max1, bb_evid1, curs1, t1 = branch_and_bound(SPN1, hot_start=False, heuristic="conflict")
print("-----------")
pprint("we enhanced the maximum by a factor {}".format(bb_max1/poon_max1))


calculation time was 0.000262
--------------------------------------------------------------
B&B INFO :
3 cut off during B&B
28.5714% of nodes cut
5 :  number of nodes visited
2 :  number of nodes cut
7 :  total number of nodes
--------------------------------------------------------------
MPE INFO :
0.3999999999999998 : MPE found
'realized by evidence: [1, 0]'
-----------
'we enhanced the maximum by a factor 1.3333333333333337'

In [ ]:
#hot_start=True, heuristic="frequency"

In [61]:
bb_max1, bb_evid1, curs2, t2 = branch_and_bound(SPN1, hot_start=True, heuristic="frequency")
print("-----------")
pprint("we enhanced the maximum by a factor {}".format(bb_max1/poon_max1))


the maximum found by Poon algo is 0.29999999999999977
calculation time was 0.000185
--------------------------------------------------------------
B&B INFO :
3 cut off during B&B
28.5714% of nodes cut
5 :  number of nodes visited
2 :  number of nodes cut
7 :  total number of nodes
--------------------------------------------------------------
MPE INFO :
0.3999999999999998 : MPE found
'realized by evidence: [1, 0]'
-----------
'we enhanced the maximum by a factor 1.3333333333333337'

In [ ]:
#hot_start=True, heuristic="conflict"

In [134]:
bb_max1, bb_evid1, curs3, t3 = branch_and_bound(SPN1, hot_start=True, heuristic="conflict")
print("-----------")
pprint("we enhanced the maximum by a factor {}".format(bb_max1/poon_max1))


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-134-aa828bd38cfe> in <module>()
----> 1 bb_max1, bb_evid1, curs3, t3 = branch_and_bound(SPN1, hot_start=True, heuristic="conflict")
      2 print("-----------")
      3 pprint("we enhanced the maximum by a factor {}".format(bb_max1/poon_max1))
      4 calculation_time['bb_con_hs'].append(t3)

NameError: name 'SPN1' is not defined

In [ ]:
#Plot evolution of the different computations

In [63]:
plot_evolution([curs0, curs1, curs2, curs3], names=["Hot_Start : False, Heuristic : Frequency",
                                                    "Hot_Start : False, Heuristic : Conflict",
                                                    "Hot_Start : True, Heuristic : Frequency",
                                                    "Hot_Start : True, Heuristic : Conflict"])


External Solver

Initialization


In [80]:
def get_val_dic_for_leaves(evidence):
    return {var: val for var, val in enumerate(evidence)}

In [81]:
dic_leaves = get_val_dic_for_leaves(poon_evid)

Implementation


In [1]:
from pyomo.environ import *
from pyomo.opt import SolverFactory

def solve(SPN, init=None, write_name=None, solver="couenne"):  # init: dict of initial values (e.g. from Poon or B&B), or None
    # Scroll through the graph
    var_node, var_indic, constraints = SPN.get_variables_and_constraints1()
    root_eq = SPN.get_big_equation()

    # Set up the model
    model = ConcreteModel()

    # Variables
    if init:
        model.l = Var(var_indic, domain=Boolean, initialize=init)
    else:
        model.l = Var(var_indic, domain=Boolean)

    # Defining the objective
    model.obj = Objective(expr=eval(root_eq), sense=maximize)

    if write_name:
        model.write("solving_files/{}.nl".format(write_name), format=ProblemFormat.nl)

    # Solving
    opt = SolverFactory(solver)
    results = opt.solve(model)

    # model.pprint()
    print("--------")
    print("Termination is {}".format(results.Solver.Termination_condition))
    print("Calculation time is {}".format(results.Solver.Time))
    print("MPE found is {}".format(value(model.obj)))

    return model, results.Solver.Time
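
A tiny standalone Pyomo sketch in the same spirit, using the names imported above: Boolean indicators, hand-written indicator-consistency constraints, and a toy SPN root polynomial as objective (hypothetical numbers; assumes the couenne binary is available, as in the run below):

In [ ]:
m = ConcreteModel()
m.l = Var(range(4), domain=Boolean)
m.c0 = Constraint(expr=m.l[0] + m.l[1] == 1)   # indicators of variable 0
m.c1 = Constraint(expr=m.l[2] + m.l[3] == 1)   # indicators of variable 1
# root polynomial of a 2-variable toy SPN
m.obj = Objective(expr=0.54 * m.l[0] * m.l[2] + 0.46 * m.l[1] * m.l[3],
                  sense=maximize)
SolverFactory('couenne').solve(m)
print(value(m.obj), [round(value(m.l[i])) for i in range(4)])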

In [83]:
model, solve_time = solve(SPN, init=dic_leaves)


scrolling through the graph
there are 2 layers to visit
end of scrolling through the graph
--------
Termination is optimal
Calculation time is 0.09786701202392578
MPE found is 0.3999999999999998

Testing all possibilities (only for small SPNs)


In [ ]:
evs = []
pourcent = 0.1
for i in range(2**16):

    if i >= 2**16 * pourcent:
        print("{}% done".format(int(100 * pourcent)))
        pourcent += 0.1

    # left-pad the binary expansion so that every i maps to a distinct
    # assignment of the 16 variables (appending zeros on the right would
    # map different i onto the same evidence)
    binary = bin(i)[2:].zfill(16)
    ev = numpy.array([int(j) for j in binary])

    evs.append(numpy.exp(SPN.exact_eval(ev))[0])

In [ ]:
numpy.argmax(evs)
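
The argmax above is just an integer index; the corresponding assignment is its 16-bit, left-padded binary expansion:

In [ ]:
best = int(numpy.argmax(evs))
numpy.array([int(j) for j in bin(best)[2:].zfill(16)])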


In [ ]:
numpy.exp(SPN._input_layer._nodes[0]._var_probs)

In [2]:
scope = SPN._root_layer._nodes[0].var_scope
assert(len(scope) == max(scope) + 1)
var_sum_conflict = [0]*len(scope)
for node in SPN._input_layer._nodes:
    probs = numpy.exp(node._var_probs)
    var_sum_conflict[node.var] += 1  # or: abs(probs[0] - probs[1])**4
branching_order2 = [i[0] for i in sorted(enumerate(var_sum_conflict), key=lambda x: -x[1])]


---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-2-0c0a8d443f54> in <module>()
----> 1 scope = SPN._root_layer._nodes[0].var_scope
      2 assert(len(scope) == max(scope) + 1)
      3 var_sum_conflict =  [0]*len(scope)
      4 for node in SPN._input_layer._nodes:
      5     probs = numpy.exp(node._var_probs)

NameError: name 'SPN' is not defined

In [ ]:
var_presence.sort()

In [ ]:
pprint(branching_order)

In [ ]:
pprint(branching_order2)

idSPN : transform an AC into an SPN


In [ ]:
# Use parse_spn to generate a networkX graph from an SPN learned with Libra.
# Libra generates a .spn folder containing .ac and .m files.

# Then use nx_to_spyn on the generated graph to build an spn object.

# TO-DO : improve the speed of nx_to_spyn : the algorithm needs to find the
# layer structure, and for now it runs in O(e^n)

In [3]:
import networkx as nx

def parse_spn(spn_name):

    # the .m structure file; the spac-*.ac files are looked up under
    # ../ocaml/{spn_name}.spn/ (see parse_ac)
    with open('{}'.format(spn_name)) as f:
        content = f.readlines()
    # remove whitespace characters like `\n` at the end of each line
    content = [x.strip().split(" ") for x in content]

    G = nx.DiGraph()
    nodes = {}
    i = len(content) - 1
    while i > 1:
        node_info = content[i]
        while node_info[0] != 'n':
            i -= 1
            node_info = content[i]

        # sanity check: every node added so far must carry an n_type
        for node_id in G.nodes():
            try:
                a = G.node[node_id]['n_type']
            except KeyError:
                print("{} has no n_type : IMPOSSIBLE".format(node_id))

        n_id, node_type, p_id = int(node_info[1]), node_info[2], int(node_info[3])
        print("{}..".format(n_id))

        if p_id == -1:
            root_id = n_id

        if node_type == "null":
            i = i - 1
        elif node_type == "+":

            children = list(map(int, content[i + 1]))
            log_weights = list(map(float, content[i + 2]))
            scope = frozenset(map(int, content[i + 3]))

            G.add_node(n_id, n_type=node_type)
            for child, log_weight in zip(children, log_weights):
                G.add_edge(n_id, child, weight=math.exp(log_weight))
            i = i - 1

            #nodes[n_id] = ('+', SumNode(var_scope = scope), child_ids, log_weights)

        elif node_type == "*":
            children = list(map(int, content[i + 1]))
            scope = frozenset(map(int, content[i + 2]))

            G.add_node(n_id, n_type=node_type)
            for child in children:
                G.add_edge(n_id, child)
            i = i - 1

        elif node_type == "ac":
            scope = list(map(int, content[i + 1]))
            sc = {k: s for k, s in enumerate(scope)}
            i = i - 1
            # parse the corresponding spac-{n_id}.ac file into G
            parse_ac(G, spn_name, n_id, sc)

    # delete all * and + nodes left without children by the suppression of parameter nodes:
    n_nodes = G.number_of_nodes() + 1
    while n_nodes != G.number_of_nodes():
        n_nodes = G.number_of_nodes()
        for node_id in list(G.nodes()):
            if G.node[node_id]["n_type"] != "v" and not G.successors(node_id):
                G.remove_node(node_id)

    return G, root_id

In [84]:
def find_scope_c(content, i):
    line = content[i]
    if line[0] == 'v':
        return set([int(line[1])])
    elif line[0] == '+':
        s = find_scope_c(content, int(line[1]))
        for child in line[2:]:
            b = find_scope_c(content, int(child))
            if b != s:
                print("prob")   # children of a sum must share the same scope
        return s

    elif line[0] == '*':
        s = find_scope_c(content, int(line[1]))
        for child in line[2:]:
            b = find_scope_c(content, int(child))
            for elem in b:
                if elem in s:
                    print('prob x')   # children of a product must have disjoint scopes
                else:
                    s = s | set([elem])
        return s

    elif line[0] == 'n':
        return set([])

In [4]:
import networkx as nx
import math

def parse_ac(G, spn_name, n_id, sc):
    # add every node of this circuit to the graph G
    print("treating spac-{}".format(n_id))

    with open('../ocaml/{}.spn/spac-{}.ac'.format(spn_name, n_id)) as f:
        content = f.readlines()
    content = [x.strip().split(" ") for x in content]
    nbr_var = len(content[0])
    eof = content.index(["EOF"])

    content = content[1:eof]  ## stopping at the first EOF and removing the variable description (line 0)

    root_id = len(content) - 1

    # put everything in the graph G

    for node_id, node in enumerate(content):
        node_type = node[0]
        if node_type == 'n':
            G.add_node("spac{}-{}".format(n_id, node_id), n_type=node_type, value=float(node[1]))
        if node_type == 'v':
            ind_variable = int(node[1])
            global_variable = sc[ind_variable]
            G.add_node("spac{}-{}".format(n_id, node_id),
                       n_type=node_type,
                       variable=global_variable,
                       value=int(node[2]))
        if node_type == '*':
            G.add_node("spac{}-{}".format(n_id, node_id), n_type=node_type)
            for child in node[1:]:
                G.add_edge("spac{}-{}".format(n_id, node_id), "spac{}-{}".format(n_id, int(child)))

        if node_type == '+':
            G.add_node("spac{}-{}".format(n_id, node_id), n_type=node_type)
            for child in node[1:]:
                G.add_edge("spac{}-{}".format(n_id, node_id), "spac{}-{}".format(n_id, int(child)),
                           weight=1)

    # for every parameter node, find the right edges and multiply them by the parameter

    for node_id in list(G.nodes()):

        try:
            a = G.node[node_id]['n_type']
        except KeyError:
            print("{} has no n_type".format(node_id))

        if G.node[node_id]['n_type'] == "n":
            value = G.node[node_id]['value']
            edges = edges_to_be_multiplied(G, node_id)
            for edge in edges:
                G[edge[1]][edge[0]]["weight"] *= value

            G.remove_node(node_id)

    # if the root is not a sum node, add a sum node on top of it
    if G.node["spac{}-{}".format(n_id, root_id)]['n_type'] != "+":
        G.add_node("spac{}-{}".format(n_id, root_id + 1), n_type='+')
        G.add_edge("spac{}-{}".format(n_id, root_id + 1),
                   "spac{}-{}".format(n_id, root_id),
                   weight=1)
        root_id += 1

    root = "spac{}-{}".format(n_id, root_id)
    nx.relabel_nodes(G, {root: n_id}, copy=False)
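
For reference, a tiny synthetic spac-*.ac file consistent with what parse_ac expects: line 0 describes the variables, then one node per line ('n' parameter, 'v' indicator as "v var value", '*'/'+' followed by child line indices), terminated by EOF. Hypothetical contents, not an actual Libra dump:

2 2
v 0 0
v 0 1
n 0.3
n 0.7
* 0 2
* 1 3
+ 4 5
EOF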

In [5]:
## find, on each path from the node to the root, the first edge entering a sum node
def edges_to_be_multiplied(G, node_id):
    edges = [(node_id, parent_id) for parent_id in G.predecessors(node_id)]
    final = []

    while edges:
        edge = edges.pop()
        if G.node[edge[1]]['n_type'] == "+":
            final.append(edge)
        else:
            p = edge[1]
            edges.extend((p, parent_id) for parent_id in G.predecessors(p))
    return list(set(final))

In [6]:
import collections

def nx_to_spyn(H):

    # define the first (input) layer
    print("finding layers")
    counter = 0
    node_to_layer = {}
    queue = collections.deque()   # deque: O(1) pops from the left
    for node_id in H.nodes():
        counter += 1
        if H.node[node_id]["n_type"] == "v":
            node_to_layer[node_id] = 0
            queue.append((node_id, 0))

    # find the layer of each node
    while queue:
        node_id, layer_nbr = queue.popleft()
        counter += 1

        if counter % 300000 == 0:
            print(len(queue))

        for parent in H.predecessors(node_id):
            l = layer_nbr + 1
            # keep sum layers at odd depths and product layers at even depths
            if (H.node[parent]["n_type"] == "+" and l % 2 == 0) \
            or (H.node[parent]["n_type"] == "*" and l % 2 == 1):
                l += 1
            node_to_layer[parent] = l
            queue.append((parent, l))

    # create the spyn
    nodes = {}
    nbr_of_layers = node_to_layer[0] + 1   # assumes the root has id 0 in the parsed graph
    layers = [[] for _dummy in range(nbr_of_layers)]
    print("creating nodes")
    # create the nodes
    for node_id in H.nodes():
        node = H.node[node_id]
        if node['n_type'] == "+":
            nodes[node_id] = SumNode(var_scope=frozenset(find_scope(node_id, H)))
        elif node['n_type'] == "*":
            nodes[node_id] = ProductNode(var_scope=frozenset(find_scope(node_id, H)))
        elif node['n_type'] == "v":
            v = node["value"]
            f = [1 - v, v]
            nodes[node_id] = CategoricalSmoothedNode(var=node["variable"],
                                                     var_values=[],
                                                     freqs=f,
                                                     alpha=0)
        layers[node_to_layer[node_id]].append(nodes[node_id])

    print("creating edges")
    # create the edges
    for edge in H.edges():
        if H.node[edge[0]]["n_type"] == '+':
            nodes[edge[0]].add_child(nodes[edge[1]], H.edge[edge[0]][edge[1]]["weight"])
        if H.node[edge[0]]["n_type"] == '*':
            nodes[edge[0]].add_child(nodes[edge[1]])

    print("creating layers")

    # create the layers, alternating sum and product layers above the input layer
    layer_spyn = []
    for i, layer in enumerate(layers):
        if not layer:
            print('passing')
            continue
        if i == 0:
            input_layer = CategoricalSmoothedLayer(layers[0])
        elif i % 2 == 1:
            l = SumLayerLinked(layers[i])
            l.normalize()
            layer_spyn.append(l)
        else:
            layer_spyn.append(ProductLayerLinked(layers[i]))

    print(len(layer_spyn))
    print("constructing SPN")
    spn = SpnLinked(input_layer=input_layer,
                    layers=layer_spyn)

    return spn

In [7]:
def find_scope(node_id, G):
    if G.node[node_id]["n_type"] == 'v':
        return set([G.node[node_id]["variable"]])
    else:
        scope = set([])
        for child in G.successors(node_id):
            scope = scope | find_scope(child, G)
        return scope

In [8]:
G2, r1 = parse_spn("nltcsk2")


12..
11..
10..
9..
8..
treating spac-8
7..
treating spac-7
6..
5..
treating spac-5
4..
3..
2..
1..
treating spac-1
0..
19..
18..
17..
16..
15..
treating spac-15
14..
treating spac-14
13..
treating spac-13
12..
treating spac-12
11..
treating spac-11
10..
treating spac-10
9..
treating spac-9
8..
7..
treating spac-7
6..
5..
treating spac-5
4..
3..
treating spac-3
2..
1..
0..
22..
21..
20..
19..
18..
treating spac-18
17..
treating spac-17
16..
treating spac-16
15..
treating spac-15
14..
13..
treating spac-13
12..
treating spac-12
11..
treating spac-11
10..
treating spac-10
9..
treating spac-9
8..
treating spac-8
7..
treating spac-7
6..
5..
treating spac-5
4..
3..
2..
1..
treating spac-1
0..
25..
24..
23..
22..
21..
20..
19..
18..
17..
16..
15..
14..
13..
12..
treating spac-12
11..
treating spac-11
10..
treating spac-10
9..
treating spac-9
8..
treating spac-8
7..
treating spac-7
6..
5..
treating spac-5
4..
treating spac-4
3..
treating spac-3
2..
1..
treating spac-1
0..

In [9]:
SPN = nx_to_spyn(G2)


finding layers
42474
48348
44915
20135
creating nodes
creating edges
creating layers
passing
passing
passing
46
constructing SPN
finding layers
66344
90243
103737
115809
117531
112991
102219
74209
11885
creating nodes
creating edges
creating layers
passing
passing
passing
46
constructing SPN
finding layers
45135
50725
41908
creating nodes
creating edges
creating layers
passing
passing
41
constructing SPN
finding layers
63437
74251
81112
82269
75455
53971
creating nodes
creating edges
creating layers
passing
passing
41
constructing SPN

In [10]:
print(SPN.is_valid())


True
True
True
True

In [232]:
SPN.MPE_poon_eval()


the maximum found by Poon algo is 0.17582527487691932
Out[232]:
(0.17582527487691932, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [16]:
branch_and_bound(SPN)


the maximum found by Poon algo is 0.0064287181459060086
calculation time was 30.587583
--------------------------------------------------------------
B&B INFO :
300 cut off during B&B
99.543% of nodes cut
599 :  number of nodes visited
130472 :  number of nodes cut
131071 :  total number of nodes
--------------------------------------------------------------
MPE INFO :
0.0064287181459060086 : MPE found
'realized by evidence: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]'
Out[16]:
(0.0064287181459060086,
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 ([0.9178321768848117,
   0.22543593423428918,
   0.11881310747416744,
   ...
   3.239742793894682e-05,
   0.0064287181459060086],
  [0.0064287181459060086,
   ...
   0.0064287181459060086]))
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086,
   0.0064287181459060086]),
 30.587583)

From tree to graph


In [ ]:
# TODO: finish the algorithm by merging the subtrees that represent
# almost the same distribution.
# cf. the paper "Merging Strategies for Sum-Product Networks: From Trees
# to Graphs" by T. Rahman & V. Gogate.
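
As a hedged sketch of the missing merging criterion: two candidate sub-SPNs can be considered "almost the same distribution" when their likelihoods agree on every complete state of their shared scope (the variables in this notebook are binary). The `eval_node` callable below is an assumed evaluation entry point, not part of the linked-SPN API; the actual call depends on how the library evaluates a sub-network rooted at a node.

In [ ]:
import itertools
import numpy

def node_distribution(node, eval_node):
    # Log-likelihood of the sub-SPN rooted at `node` on every complete
    # binary assignment of its scope (scopes here are small, e.g. 2 vars).
    scope = sorted(node.var_scope)
    lls = [eval_node(node, dict(zip(scope, assignment)))
           for assignment in itertools.product([0, 1], repeat=len(scope))]
    return numpy.array(lls)

def almost_same_distribution(node_a, node_b, eval_node, tol=1e-2):
    # Only nodes over the same scope are merge candidates.
    if node_a.var_scope != node_b.var_scope:
        return False
    # Merge when the two log-likelihood vectors differ by at most `tol`.
    diff = (node_distribution(node_a, eval_node)
            - node_distribution(node_b, eval_node))
    return numpy.max(numpy.abs(diff)) <= tol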

In [177]:
def cluster_spn(spn, i):
    """Group the nodes of `spn` whose scope contains exactly `i` variables,
    keyed by their scope.

    Layers are visited bottom-up; since scopes only grow towards the root,
    we can stop as soon as every node in a layer has a scope larger than i.
    """
    clustering = {}
    for layer in [spn._input_layer] + spn._layers:
        all_nodes_in_layer_higher = True
        for node in layer._nodes:
            len_scope = len(node.var_scope)
            all_nodes_in_layer_higher = all_nodes_in_layer_higher and (len_scope > i)
            if len_scope == i:
                if node.var_scope in clustering:
                    clustering[node.var_scope].append(node)
                else:
                    clustering[node.var_scope] = [node]
        if all_nodes_in_layer_higher:
            return clustering
    return clustering

In [181]:
cluster_spn(SPN2, 2)


Out[181]:
{frozenset({5, 7}): [Prod Node id: 4378 scope: frozenset({5, 7})
   (4721) (4722), Prod Node id: 4379 scope: frozenset({5, 7})
   (4723) (4724), Sum Node id: 3896 scope: frozenset({5, 7})
   (4378 0.689922480620155) (4379 0.31007751937984496)],
 frozenset({1, 11}): [Prod Node id: 4726 scope: frozenset({1, 11})
   (4985) (4986), Prod Node id: 4727 scope: frozenset({1, 11})
   (4987) (4988), Sum Node id: 4381 scope: frozenset({1, 11})
   (4726 0.3023255813953488) (4727 0.6976744186046512), Prod Node id: 4432 scope: frozenset({1, 11})
   (4787) (4788), Prod Node id: 4433 scope: frozenset({1, 11})
   (4789) (4790), Sum Node id: 3964 scope: frozenset({1, 11})
   (4432 0.7866666666666666) (4433 0.21333333333333335)],
 frozenset({6, 7}): [Prod Node id: 3803 scope: frozenset({6, 7})
   (4319) (4320), Prod Node id: 3804 scope: frozenset({6, 7})
   (4321) (4322), Sum Node id: 3245 scope: frozenset({6, 7})
   (3803 0.544) (3804 0.456)],
 frozenset({3, 13}): [Prod Node id: 4961 scope: frozenset({3, 13})
   (5135) (5136), Prod Node id: 4962 scope: frozenset({3, 13})
   (5137) (5138), Sum Node id: 4686 scope: frozenset({3, 13})
   (4961 0.8181818181818182) (4962 0.18181818181818182)],
 frozenset({10, 13}): [Prod Node id: 5146 scope: frozenset({10, 13})
   (5241) (5242), Prod Node id: 5147 scope: frozenset({10, 13})
   (5243) (5244), Sum Node id: 4975 scope: frozenset({10, 13})
   (5146 0.6481481481481481) (5147 0.35185185185185186)],
 frozenset({1, 9}): [Prod Node id: 3877 scope: frozenset({1, 9})
   (4371) (4372), Prod Node id: 3878 scope: frozenset({1, 9})
   (4373) (4374), Sum Node id: 3311 scope: frozenset({1, 9})
   (3877 0.8898305084745762) (3878 0.11016949152542373)],
 frozenset({3, 5}): [Prod Node id: 4869 scope: frozenset({3, 5})
   (5061) (5062), Prod Node id: 4870 scope: frozenset({3, 5})
   (5063) (5064), Sum Node id: 4541 scope: frozenset({3, 5})
   (4869 0.6173913043478261) (4870 0.3826086956521739), Prod Node id: 4525 scope: frozenset({3, 5})
   (4864) (4865), Prod Node id: 4526 scope: frozenset({3, 5})
   (4866) (4867), Prod Node id: 4396 scope: frozenset({3, 5})
   (4751) (4752), Prod Node id: 4397 scope: frozenset({3, 5})
   (4753) (4754), Sum Node id: 4072 scope: frozenset({3, 5})
   (4525 0.6747967479674797) (4526 0.3252032520325203), Sum Node id: 3937 scope: frozenset({3, 5})
   (4396 0.3161764705882353) (4397 0.6838235294117647), Prod Node id: 3514 scope: frozenset({3, 5})
   (4068) (4069), Prod Node id: 3515 scope: frozenset({3, 5})
   (4070) (4071), Prod Node id: 4756 scope: frozenset({3, 5})
   (5000) (5001), Prod Node id: 4757 scope: frozenset({3, 5})
   (5002) (5003), Sum Node id: 2926 scope: frozenset({3, 5})
   (3514 0.7007299270072993) (3515 0.29927007299270075), Sum Node id: 4398 scope: frozenset({3, 5})
   (4756 0.17475728155339806) (4757 0.8252427184466019), Prod Node id: 2253 scope: frozenset({3, 5})
   (2911) (2912), Prod Node id: 2254 scope: frozenset({3, 5})
   (2913) (2914), Sum Node id: 1740 scope: frozenset({3, 5})
   (2253 0.4154727793696275) (2254 0.5845272206303725)],
 frozenset({10, 14}): [Prod Node id: 5016 scope: frozenset({10, 14})
   (5171) (5172), Prod Node id: 5017 scope: frozenset({10, 14})
   (5173) (5174), Sum Node id: 4772 scope: frozenset({10, 14})
   (5016 0.7813953488372093) (5017 0.2186046511627907), Prod Node id: 4413 scope: frozenset({10, 14})
   (4776) (4777), Prod Node id: 4414 scope: frozenset({10, 14})
   (4778) (4779), Sum Node id: 3956 scope: frozenset({10, 14})
   (4413 0.27896995708154504) (4414 0.721030042918455), Prod Node id: 2924 scope: frozenset({10, 14})
   (3509) (3510), Prod Node id: 2925 scope: frozenset({10, 14})
   (3511) (3512), Sum Node id: 2263 scope: frozenset({10, 14})
   (2924 0.5182481751824818) (2925 0.48175182481751827)],
 frozenset({2, 11}): [Prod Node id: 4094 scope: frozenset({2, 11})
   (4531) (4532), Prod Node id: 4095 scope: frozenset({2, 11})
   (4533) (4534), Sum Node id: 3563 scope: frozenset({2, 11})
   (4094 0.19337016574585636) (4095 0.8066298342541437)],
 frozenset({10, 15}): [Prod Node id: 1902 scope: frozenset({10, 15})
   (2509) (2510), Prod Node id: 1903 scope: frozenset({10, 15})
   (2511) (2512), Sum Node id: 1448 scope: frozenset({10, 15})
   (1902 0.821656050955414) (1903 0.17834394904458598)],
 frozenset({10, 11}): [Prod Node id: 5239 scope: frozenset({10, 11})
   (5291) (5292), Prod Node id: 5240 scope: frozenset({10, 11})
   (5293) (5294), Sum Node id: 5118 scope: frozenset({10, 11})
   (5239 0.2830188679245283) (5240 0.7169811320754716)],
 frozenset({3, 11}): [Prod Node id: 2732 scope: frozenset({11, 3})
   (3392) (3393), Prod Node id: 2733 scope: frozenset({11, 3})
   (3394) (3395), Sum Node id: 2089 scope: frozenset({11, 3})
   (2732 0.8181818181818182) (2733 0.18181818181818182)],
 frozenset({1, 3}): [Prod Node id: 3149 scope: frozenset({1, 3})
   (3703) (3704), Prod Node id: 3150 scope: frozenset({1, 3})
   (3705) (3706), Sum Node id: 2517 scope: frozenset({1, 3})
   (3149 0.8529411764705882) (3150 0.14705882352941177)],
 frozenset({9, 14}): [Prod Node id: 5021 scope: frozenset({9, 14})
   (5175) (5176), Prod Node id: 5022 scope: frozenset({9, 14})
   (5177) (5178), Sum Node id: 4780 scope: frozenset({9, 14})
   (5021 0.34439834024896265) (5022 0.6556016597510373)],
 frozenset({11, 14}): [Prod Node id: 3541 scope: frozenset({11, 14})
   (4089) (4090), Prod Node id: 3542 scope: frozenset({11, 14})
   (4091) (4092), Sum Node id: 2956 scope: frozenset({11, 14})
   (3541 0.766798418972332) (3542 0.233201581027668)],
 frozenset({7, 13}): [Prod Node id: 4971 scope: frozenset({13, 7})
   (5141) (5142), Prod Node id: 4972 scope: frozenset({13, 7})
   (5143) (5144), Sum Node id: 4703 scope: frozenset({13, 7})
   (4971 0.7029702970297029) (4972 0.297029702970297)],
 frozenset({3, 10}): [Prod Node id: 3530 scope: frozenset({10, 3})
   (4082) (4083), Prod Node id: 3531 scope: frozenset({10, 3})
   (4084) (4085), Prod Node id: 4087 scope: frozenset({10, 3})
   (4527) (4528), Prod Node id: 4088 scope: frozenset({10, 3})
   (4529) (4530), Sum Node id: 2948 scope: frozenset({10, 3})
   (3530 0.6298507462686567) (3531 0.3701492537313433), Sum Node id: 3535 scope: frozenset({10, 3})
   (4087 0.6306620209059234) (4088 0.3693379790940767)],
 frozenset({0, 10}): [Prod Node id: 4549 scope: frozenset({0, 10})
   (4871) (4872), Prod Node id: 4550 scope: frozenset({0, 10})
   (4873) (4874), Prod Node id: 4254 scope: frozenset({0, 10})
   (4595) (4596), Prod Node id: 4255 scope: frozenset({0, 10})
   (4597) (4598), Sum Node id: 4138 scope: frozenset({0, 10})
   (4549 0.6893203883495146) (4550 0.3106796116504854), Sum Node id: 3709 scope: frozenset({0, 10})
   (4254 0.8403361344537815) (4255 0.15966386554621848)],
 frozenset({3, 9}): [Prod Node id: 2618 scope: frozenset({9, 3})
   (3252) (3253), Prod Node id: 2619 scope: frozenset({9, 3})
   (3254) (3255), Prod Node id: 3257 scope: frozenset({9, 3})
   (3825) (3826), Prod Node id: 3258 scope: frozenset({9, 3})
   (3827) (3828), Sum Node id: 1975 scope: frozenset({9, 3})
   (2618 0.75) (2619 0.25), Sum Node id: 2626 scope: frozenset({9, 3})
   (3257 0.7768595041322314) (3258 0.2231404958677686)],
 frozenset({3, 4}): [Prod Node id: 2993 scope: frozenset({3, 4})
   (3593) (3594), Prod Node id: 3596 scope: frozenset({3, 4})
   (4125) (4126), Prod Node id: 3597 scope: frozenset({3, 4})
   (4127) (4128), Prod Node id: 3501 scope: frozenset({3, 4})
   (4064) (4065), Prod Node id: 3502 scope: frozenset({3, 4})
   (4066) (4067), Sum Node id: 2355 scope: frozenset({3, 4})
   (2993 0.8537234042553191) (3596 0.07047872340425532) (3597 0.07579787234042554), Sum Node id: 2898 scope: frozenset({3, 4})
   (3501 0.9251700680272109) (3502 0.07482993197278912), Prod Node id: 2246 scope: frozenset({3, 4})
   (2904) (2905), Prod Node id: 2247 scope: frozenset({3, 4})
   (2906) (2907), Prod Node id: 1445 scope: frozenset({3, 4})
   (1896) (1897), Prod Node id: 1899 scope: frozenset({3, 4})
   (2505) (2506), Prod Node id: 1900 scope: frozenset({3, 4})
   (2507) (2508), Prod Node id: 1586 scope: frozenset({3, 4})
   (2020) (2021), Prod Node id: 1587 scope: frozenset({3, 4})
   (2022) (2023), Prod Node id: 2892 scope: frozenset({3, 4})
   (3493) (3494), Prod Node id: 2893 scope: frozenset({3, 4})
   (3495) (3496), Sum Node id: 1736 scope: frozenset({3, 4})
   (2246 0.6039076376554174) (2247 0.3960923623445826), Sum Node id: 944 scope: frozenset({3, 4})
   (1445 0.8369426751592357) (1899 0.07643312101910828) (1900 0.08662420382165605), Sum Node id: 1147 scope: frozenset({3, 4})
   (1586 0.6193895870736086) (1587 0.38061041292639136), Sum Node id: 2235 scope: frozenset({3, 4})
   (2892 0.8620689655172413) (2893 0.13793103448275862)],
 frozenset({4, 9}): [Prod Node id: 2909 scope: frozenset({9, 4})
   (3505) (3506), Prod Node id: 2910 scope: frozenset({9, 4})
   (3507) (3508), Prod Node id: 2256 scope: frozenset({9, 4})
   (2915) (2916), Prod Node id: 2257 scope: frozenset({9, 4})
   (2917) (2918), Prod Node id: 2259 scope: frozenset({9, 4})
   (2919) (2920), Prod Node id: 2260 scope: frozenset({9, 4})
   (2921) (2922), Prod Node id: 1973 scope: frozenset({9, 4})
   (2613) (2614), Prod Node id: 1974 scope: frozenset({9, 4})
   (2615) (2616), Prod Node id: 1531 scope: frozenset({9, 4})
   (1985) (1986), Prod Node id: 1532 scope: frozenset({9, 4})
   (1987) (1988), Prod Node id: 1591 scope: frozenset({9, 4})
   (2024) (2025), Prod Node id: 1592 scope: frozenset({9, 4})
   (2026) (2027), Prod Node id: 1600 scope: frozenset({9, 4})
   (2033) (2034), Prod Node id: 1601 scope: frozenset({9, 4})
   (2035) (2036), Sum Node id: 2250 scope: frozenset({9, 4})
   (2909 0.43982494529540483) (2910 0.5601750547045952), Sum Node id: 1741 scope: frozenset({9, 4})
   (2256 0.498567335243553) (2257 0.501432664756447), Sum Node id: 1742 scope: frozenset({9, 4})
   (2259 0.5053763440860215) (2260 0.4946236559139785), Sum Node id: 1522 scope: frozenset({9, 4})
   (1973 0.45714285714285713) (1974 0.5428571428571428), Sum Node id: 1115 scope: frozenset({9, 4})
   (1531 0.5024038461538461) (1532 0.49759615384615385), Sum Node id: 1149 scope: frozenset({9, 4})
   (1591 0.5552941176470588) (1592 0.4447058823529412), Sum Node id: 1154 scope: frozenset({9, 4})
   (1600 0.5507900677200903) (1601 0.4492099322799097), Prod Node id: 1734 scope: frozenset({9, 4})
   (2240) (2241), Prod Node id: 2243 scope: frozenset({9, 4})
   (2900) (2901), Prod Node id: 2244 scope: frozenset({9, 4})
   (2902) (2903), Prod Node id: 683 scope: frozenset({9, 4})
   (1104) (1105), Prod Node id: 684 scope: frozenset({9, 4})
   (1106) (1107), Prod Node id: 439 scope: frozenset({9, 4})
   (685) (686), Prod Node id: 460 scope: frozenset({9, 4})
   (710) (711), Prod Node id: 713 scope: frozenset({9, 4})
   (1143) (1144), Prod Node id: 714 scope: frozenset({9, 4})
   (1145) (1146), Sum Node id: 1320 scope: frozenset({9, 4})
   (1734 0.6911337209302325) (2243 0.12015503875968993) (2244 0.18871124031007752), Sum Node id: 244 scope: frozenset({9, 4})
   (683 0.12285368802902057) (684 0.19975816203143895) (439 0.6773881499395406), Sum Node id: 259 scope: frozenset({9, 4})
   (460 0.6928261916225325) (713 0.1136254212806933) (714 0.1935483870967742)],
 frozenset({1, 2}): [Prod Node id: 3973 scope: frozenset({1, 2})
   (4438) (4439), Prod Node id: 3974 scope: frozenset({1, 2})
   (4440) (4441), Sum Node id: 3402 scope: frozenset({1, 2})
   (3973 0.23636363636363636) (3974 0.7636363636363637)],
 frozenset({1, 5}): [Prod Node id: 2740 scope: frozenset({1, 5})
   (3404) (3405), Prod Node id: 2741 scope: frozenset({1, 5})
   (3406) (3407), Sum Node id: 2100 scope: frozenset({1, 5})
   (2740 0.46987951807228917) (2741 0.5301204819277109)],
 frozenset({0, 1}): [Prod Node id: 4135 scope: frozenset({0, 1})
   (4544) (4545), Prod Node id: 4136 scope: frozenset({0, 1})
   (4546) (4547), Sum Node id: 3612 scope: frozenset({0, 1})
   (4135 0.3617021276595745) (4136 0.6382978723404256), Prod Node id: 2859 scope: frozenset({0, 1})
   (3484) (3485), Prod Node id: 2860 scope: frozenset({0, 1})
   (3486) (3487), Sum Node id: 2222 scope: frozenset({0, 1})
   (2859 0.663594470046083) (2860 0.33640552995391704)]}
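
The clustering above makes the merge candidates visible: any scope mapped to more than one sum node (e.g. frozenset({3, 5}) or frozenset({1, 11})) holds several sub-SPNs over the same variables, which may encode nearly identical distributions. A small sketch to extract just those scopes, using `cluster_spn` from above and the SumNode class already imported in this notebook:

In [ ]:
clustering = cluster_spn(SPN2, 2)
merge_candidates = {}
for scope, nodes in clustering.items():
    # Only sum nodes are kept: they are the roots of the alternative
    # sub-SPNs over `scope` that the merging step would compare.
    sum_nodes = [n for n in nodes if isinstance(n, SumNode)]
    if len(sum_nodes) > 1:
        merge_candidates[scope] = sum_nodes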

In [179]:
SPN1


Out[179]:
[sum layer:]
Sum Node id: 64915 scope: frozenset({0, 1})
 (64916 0.2) (64917 0.5) (64918 0.3)
**********************************************************

[prod layer:]
Prod Node id: 64916 scope: frozenset({0, 1})
 (64919) (64920)
Prod Node id: 64917 scope: frozenset({0, 1})
 (64920) (64921)
Prod Node id: 64918 scope: frozenset({0, 1})
 (64921) (64922)
**********************************************************

[input layer:]
Categorical Smoothed Node id: 64919 scope: frozenset({1})
            var: 1 val: 2 [[3, 7]] [[-1.2039728043259361, -0.35667494393873267]]
Categorical Smoothed Node id: 64920 scope: frozenset({0})
            var: 0 val: 2 [[6, 4]] [[-0.51082562376599094, -0.91629073187415533]]
Categorical Smoothed Node id: 64921 scope: frozenset({1})
            var: 1 val: 2 [[8, 2]] [[-0.22314355131421015, -1.6094379124341005]]
Categorical Smoothed Node id: 64922 scope: frozenset({0})
            var: 0 val: 2 [[1, 9]] [[-2.3025850929940459, -0.10536051565782634]]
**********************************************************
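
The step that actually turns the learned tree into a graph is the rewiring: once two sub-SPNs are judged almost identical, every parent edge pointing at the duplicate is redirected to the kept representative, so the sub-network is shared instead of copied (as in SPN1 above, where the product nodes already share children such as 64920 and 64921). A minimal sketch, assuming the linked nodes expose their child list as `node.children` (the real attribute name in the library may differ):

In [ ]:
def redirect_parents(spn, duplicate, keep):
    # Replace every edge into `duplicate` by an edge into `keep`.
    # Sum-node weights live on the parent side, so they are preserved.
    # `node.children` is an assumed attribute of the linked nodes.
    for layer in spn._layers:
        for node in layer._nodes:
            node.children = [keep if child is duplicate else child
                             for child in node.children]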