In [1]:
push!(LOAD_PATH, pwd())
using JoinTreeInference: Node, Potential, parse_net, triangulate_graph
using DataStructures

In [2]:
"""
    Sepset(first, second, nodes, mass, cost)

Candidate separator set linking two clusters of a triangulated graph.

# Fields
- `first::Int64`, `second::Int64`: indices of the two clusters this sepset joins.
- `nodes::Set{String}`: names of the variables shared by the two clusters.
- `mass::Int64`: number of shared variables (`create_sepsets` stores `length(nodes)`).
- `cost::Int64`: combined weight of the two clusters (`create_sepsets` stores the
  sum of their state-space sizes).
"""
# `mutable struct` is the Julia >= 0.6/1.0 spelling of the removed `type` keyword;
# it keeps the original mutable semantics so existing field assignments still work.
mutable struct Sepset
    first::Int64
    second::Int64
    nodes::Set{String}
    mass::Int64
    cost::Int64
end

In [3]:
"""
    create_sepsets(clusters, node_list)

Enumerate one candidate `Sepset` for every unordered pair of clusters and return
them sorted best-first.

# Arguments
- `clusters`: indexable collection of clusters; each cluster is iterated for the
  variable names it contains.
- `node_list`: collection of nodes exposing `name` and `states`; a variable's
  weight is `length(node.states)`.

# Returns
A new vector of `Sepset`s ordered by descending `mass` (number of shared
variables). Ties are broken by ascending `cost`, which is the sum of the two clusters'
weights; a cluster's weight is the product of its variables' weights.
"""
function create_sepsets(clusters, node_list)
    # variable name => number of states
    state_count = Dict{String, Int64}(nd.name => length(nd.states) for nd in node_list)

    # size of a cluster's joint state space
    function cluster_weight(cluster)
        w = 1
        for var in cluster
            w *= state_count[var]
        end
        return w
    end
    weights = Int64[cluster_weight(c) for c in clusters]

    n = length(clusters)
    candidates = Sepset[]
    for i in 1:n, j in (i + 1):n
        # create a new sepset for the pair (i, j)
        shared = intersect(clusters[i], clusters[j])
        pair_cost = weights[i] + weights[j]
        push!(candidates, Sepset(i, j, shared, length(shared), pair_cost))
    end

    # larger mass first; among equal masses, the cheaper sepset first
    before(a, b) = a.mass != b.mass ? a.mass > b.mass : a.cost < b.cost
    return sort(candidates; lt = before)
end


Out[3]:
create_sepsets (generic function with 1 method)

In [4]:
"""
    create_junction_tree(clusters, sepsets)

Connect `clusters` into a tree by accepting `sepsets` greedily in the given order.
A sepset is skipped when its two clusters are already connected. Selection stops
once `length(clusters) - 1` edges have been accepted.

# Returns
A `Dict{Int, Set{Int}}` mapping every cluster index in `1:length(clusters)` to the
set of indices of its neighbouring clusters. If `sepsets` cannot connect every
cluster, the result is a forest.
"""
function create_junction_tree(clusters::Array{Set{String}, 1}, sepsets::Array{Sepset, 1})
    n = length(clusters)
    adjacency = Dict{Int, Set{Int}}()
    for k in 1:n
        adjacency[k] = Set{Int}()   # every cluster gets an entry, even if isolated
    end

    components = IntDisjointSets(n)   # tracks which clusters are already linked
    accepted = 0
    for s in sepsets
        accepted == n - 1 && break    # a spanning tree on n vertices has n - 1 edges
        a, b = s.first, s.second
        in_same_set(components, a, b) && continue   # this edge would close a cycle
        union!(components, a, b)
        push!(adjacency[a], b)
        push!(adjacency[b], a)
        accepted += 1
    end
    return adjacency
end

"""
    create_junction_tree(clusters, node_list)

Convenience method. It derives the ordered candidate sepsets from `clusters` and the
per-variable state counts in `node_list` via `create_sepsets`, then builds the
tree from them.
"""
create_junction_tree(clusters::Array{Set{String}, 1}, node_list::Array{Node, 1}) =
    create_junction_tree(clusters, create_sepsets(clusters, node_list))


Out[4]:
create_junction_tree (generic function with 2 methods)

In [5]:
# Parse the network file; the path is relative to the current working directory.
node_list, potential_list = parse_net("data/asia.net")
# Triangulate the network; only `clusters` is used in this cell (`tg` is kept for later use, presumably).
tg, clusters = triangulate_graph(node_list, potential_list)
# Build the junction tree: cluster index => set of neighbouring cluster indices.
jt = create_junction_tree(clusters, node_list)


Out[5]:
Dict{Int64,Set{Int64}} with 6 entries:
  4 => Set([6,1])
  2 => Set([3])
  3 => Set([2,5])
  5 => Set([3,6])
  6 => Set([4,5])
  1 => Set([4])