In [39]:
import itertools
import networkx as nx
from networkx import dfs_tree
from networkx.algorithms import isomorphism as iso
# install with: sudo pip install git+http://github.com/chebee7i/nxpd/#egg=nxpd
from nxpd import draw
from discoursekernels.tree import (
    generate_all_unique_subtrees, contains_only_complete_productions,
    get_subtrees, count_tree_fragment_occurances,
    get_production_rules, common_subtrees,
    tree_kernel_naive, tree_kernel_polynomial, find_all_common_subtrees_bruteforce)
from discoursekernels.tree import (
    is_rooted_at_node, is_leave, is_proper, is_treefragment)
from discoursekernels.test_tree import (example_tree, tree_jeff_ate_cookies,
                                        tree_alex_died, tree_steve_ate_bananas, tree_the_man_drank_wine,
                                        tree_the_man_killed_the_woman, tree_fragment_alex, tree_fragment_npn)
from discoursekernels.util import print_source, label_nodes, draw_multiple_graphs

Tree Kernel (Collins and Duffy 2002)

  • first, implicitly enumerate all tree fragments that occur in the training data (i.e. in all trees): 1, ..., n

In [3]:
print_source(generate_all_unique_subtrees)


Out[3]:

def generate_all_unique_subtrees(*trees):
    """
    Collect the subtrees of all given trees, skipping every subtree of the
    second and later trees that duplicates one which was already collected.

    Two subtrees count as duplicates if they are isomorphic and the 'label'
    attributes of the matched nodes agree ('' is passed to
    ``categorical_node_match`` as the default label).

    NOTE(review): the subtrees of the first tree are taken over as they come
    from ``get_subtrees``; they are not compared against each other. If a
    single tree can yield isomorphic subtrees, the result is not strictly
    "unique" -- confirm whether that is intended.

    Returns
    -------
    subtrees : list
        an empty list if no trees are given, otherwise the collected subtrees
    """
    label_attrib = 'label'
    labels_match = iso.categorical_node_match(label_attrib, '')

    if not trees:
        return []

    first_tree, remaining_trees = trees[0], trees[1:]
    collected = list(get_subtrees(first_tree, node_attrib=label_attrib))
    for tree in remaining_trees:
        for candidate in get_subtrees(tree, node_attrib=label_attrib):
            # a candidate is only kept if none of the subtrees collected
            # so far is isomorphic to it (incl. matching node labels)
            already_known = any(
                nx.is_isomorphic(candidate, known, node_match=labels_match)
                for known in collected)
            if not already_known:
                collected.append(candidate)
    return collected

In [4]:
unique_subtrees = generate_all_unique_subtrees(tree_alex_died, tree_jeff_ate_cookies,
                                               tree_steve_ate_bananas, tree_the_man_drank_wine)
  • each tree is represented by an n-dimensional vector, where the i-th component counts how often the i-th tree fragment occurs

In [5]:
# h_i(T): the i-th component counts how often the i-th unique subtree occurs
# in the tree (is_treefragment() would only tell us whether it occurs at all)
tree_alex_died_vector = [count_tree_fragment_occurances(tree_alex_died, ust) for ust in unique_subtrees]
  • function $h_i(T)$ counts how often the i-th tree fragment occurs in tree T

  • T is now represented as:

$${\bf h}(T) = (h_1(T), h_2(T), ..., h_n(T))$$
  • WARNING: n will be huge (number of subtrees grows exponentially)

How to calculate a tree kernel efficiently?

  • tree kernel: inner product between two Trees $T_1$ and $T_2$:
$$K(T_1, T_2) = {\bf h}(T_1) \cdot {\bf h}(T_2)$$
  • $N_1$: set of nodes in $T_1$
  • indicator function $I_i(n)$:
$$ I_i(n) = \begin{cases} 1 & \text{if subtree } i \text{ is rooted at node } n \\ 0 & \text{otherwise} \end{cases} $$

In [6]:
print_source(is_rooted_at_node)


Out[6]:

def is_rooted_at_node(tree, subtree, tree_node, node_attrib=None):
    """
    Indicator function $I_i(n)$: tells whether subtree i is rooted at
    node n (``tree_node``) of the tree.

    Returns
    -------
    is_rooted : int
        1, iff the tree node n and the root node of the subtree are equal
        (same node labels if node_attrib is given, otherwise: same node IDs)
        and every production rule of the subtree is contained in the set of
        production rules of the tree (starting at node n / ``tree_node``).
        0 in all other cases.
    """
    # a topological sort of a tree starts with its root node
    root_of_subtree = topological_sort(subtree)[0]

    # the subtree can't be rooted at n if n and the subtree's root differ
    if node_attrib:
        roots_differ = (tree.node[tree_node][node_attrib]
                        != subtree.node[root_of_subtree][node_attrib])
    else:
        roots_differ = (tree_node != root_of_subtree)
    if roots_differ:
        return 0

    rules_below_node = get_production_rules(tree, root_node=tree_node, node_attrib=node_attrib)
    rules_of_subtree = get_production_rules(subtree, node_attrib=node_attrib)
    # one production of the subtree that is missing below n is enough to fail
    for rule in rules_of_subtree:
        if rule not in rules_below_node:
            return 0
    return 1
  • function $h_i(T_1)$ can now be calculated as follows:
$$h_i(T_1) = \sum_{n_1 \in N_1} I_i(n_1)$$

In [7]:
print_source(count_tree_fragment_occurances)


Out[7]:

def count_tree_fragment_occurances(tree, subtree, node_attrib='label'):
    """
    $h_i(T_1)$ : number of nodes of the tree at which the subtree is rooted,
    i.e. the sum of the indicator function is_rooted_at_node() over all
    nodes of the tree.
    """
    # is_rooted_at_node() contributes 1 for each node at which the subtree
    # is rooted and 0 for every other node
    return sum(is_rooted_at_node(tree, subtree, tree_node=node,
                                 node_attrib=node_attrib)
               for node in tree.nodes_iter())

In [8]:
count_tree_fragment_occurances(tree_alex_died, tree_alex_died)


Out[8]:
1

Calculate how often the fragment 'NP -> D N' occurs in 'the man drank wine' and in 'the man killed the woman'


In [9]:
# tree fragment for the production NP -> D N:
# node 1 (NP) is the parent of node 2 (D) and node 3 (N)
npdn = nx.DiGraph()
npdn.add_nodes_from(
    label_nodes([(1, 'NP'), (2, 'D'), (3, 'N')]))
npdn.add_edges_from(
    [(1, 2), (1, 3)])

In [10]:
count_tree_fragment_occurances(tree=tree_the_man_drank_wine, subtree=npdn)


Out[10]:
1

In [11]:
count_tree_fragment_occurances(tree=tree_the_man_killed_the_woman, subtree=npdn)


Out[11]:
2
  • which lets us calculate the inner product of two trees as follows:
$$ \begin{aligned} {\bf h}(T_1) \cdot {\bf h}(T_2) &= \sum_i h_i(T_1) h_i(T_2) \\ &= \sum_{n_1 \in N_1} \sum_{n_2 \in N_2} \sum_i I_i(n_1) I_i(n_2) \\ &= \sum_{n_1 \in N_1} \sum_{n_2 \in N_2} C(n_1, n_2) \end{aligned} $$
  • function $C(n_1, n_2)$ simply counts the number of common subtrees rooted at both $n_1$ and $n_2$ and is defined as $\sum_i I_i(n_1) I_i(n_2)$

  • NOTE: $C(n_1, n_2)$ can be computed in polynomial time:

$$C(n_1, n_2) = \begin{cases} 0 & \text{if the productions at } n_1 \text{ and } n_2 \text{ are different} \\ 1 & \text{if the productions at } n_1 \text{ and } n_2 \text{ are the same and } n_1 \text{ and } n_2 \text{ are pre-terminals} \\ \prod_{j=1}^{nc(n_1)} \left(1 + C(ch(n_1, j), ch(n_2, j))\right) & \text{if the productions at } n_1 \text{ and } n_2 \text{ are the same but } n_1 \text{ and } n_2 \text{ are } \textbf{not} \text{ pre-terminals} \end{cases} $$
  • $nc(n_1)$: number of children of $n_1$ in the tree (if the productions at $n_1$/$n_2$ are the same, they have the same number of children, as well)
  • $ch(n_1, i)$: i-th child of $n_1$

In [12]:
print_source(common_subtrees)


Out[12]:

def common_subtrees(tree1, tree2, n1, n2, node_attrib='label'):
    r"""
    function $C(n_1, n_2)$ simply counts the number of
    _common subtrees_ rooted at both $n_1$ and $n_2$
    and is defined as $\sum_i I_i(n_1) I_i(n_2)$

    Parameters
    ----------
    tree1, tree2 : the two trees to be compared
    n1 : a node of tree1
    n2 : a node of tree2
    node_attrib : str
        node attribute that is handed to get_production_rules() when
        building the production rules (default: 'label'); it is used
        on every level of the recursion

    Returns
    -------
    count : int
        0 if one of the nodes has no production rules or if the production
        rules of n1 and n2 differ; 1 if they are equal and both nodes are
        pre-terminals; otherwise the product of
        ``1 + common_subtrees(child of n1, child of n2)`` over all
        pairs of children (paired by position).

    NOTE(review): the rule sets compared here are whatever
    get_production_rules() returns for n1 / n2. If that covers all rules
    below the node rather than only the production at the node itself,
    this test is stricter than the "same production" condition of
    Collins and Duffy -- confirm.
    """
    n1_rules = get_production_rules(tree1, n1, node_attrib=node_attrib)
    n2_rules = get_production_rules(tree2, n2, node_attrib=node_attrib)

    if not n1_rules or not n2_rules:
        # this condition isn't explicitly mentioned in Collins and Duffy (2001),
        # but they state that a valid subtree must have more than one node
        # if a subtree has no production rules, it only consists of leaf nodes
        return 0

    if n1_rules != n2_rules:
        return 0

    # from here on: n1_rules == n2_rules
    if is_preterminal(tree1, n1) and is_preterminal(tree2, n2):
        return 1

    # n1 and/or n2 aren't preterminals
    # list(): successors() may return an iterator, but we need len()
    # and a stable pairing of the children
    n1_children = list(tree1.successors(n1))
    # the children of n2 have to be looked up in tree2 (not in tree1)
    n2_children = list(tree2.successors(n2))
    assert len(n1_children) == len(n2_children)

    result = 1  # neutral element of multiplication
    for n1_child_node, n2_child_node in zip(n1_children, n2_children):
        # pass node_attrib on, so that the recursion uses the same
        # node attribute as the top-level call
        result *= 1 + common_subtrees(tree1, tree2, n1_child_node, n2_child_node,
                                      node_attrib=node_attrib)
    return result

Calculating a kernel between two trees


In [13]:
print_source(tree_kernel_naive)


Out[13]:

def tree_kernel_naive(tree1, tree2, node_attrib='label'):
    r"""
    \sum_{n_1 \in N_1} \sum_{n_2 \in N_2} \sum_i I_i(n_1) I_i(n_2)

    Naive tree kernel: for every pair of nodes (one from each tree) and
    every subtree returned by generate_all_unique_subtrees(tree1, tree2),
    add 1 if that subtree is rooted at both nodes.

    NOTE(review): ``node_attrib`` is accepted but not forwarded to
    is_rooted_at_node(), which therefore falls back to its own default
    (node_attrib=None, i.e. comparison by node IDs) -- confirm whether
    this is intended.
    """
    fragments = generate_all_unique_subtrees(tree1, tree2)
    total = 0
    for node1 in tree1.nodes_iter():
        for node2 in tree2.nodes_iter():
            # I_i(n_1) * I_i(n_2), summed over all fragments i
            total += sum(
                is_rooted_at_node(tree1, fragment, node1) * is_rooted_at_node(tree2, fragment, node2)
                for fragment in fragments)
    return total

In [14]:
print_source(tree_kernel_polynomial)


Out[14]:

def tree_kernel_polynomial(tree1, tree2, node_attrib='label'):
    r"""
    \sum_{n_1 \in N_1} \sum_{n_2 \in N_2} C(n_1, n_2)

    Tree kernel computed by adding up common_subtrees() -- the function
    $C(n_1, n_2)$ -- over all pairs of nodes (one from each tree).
    """
    return sum(
        common_subtrees(tree1, tree2, node1, node2, node_attrib=node_attrib)
        for node1 in tree1.nodes_iter()
        for node2 in tree2.nodes_iter())

In [15]:
print_source(find_all_common_subtrees_bruteforce)


Out[15]:

def find_all_common_subtrees_bruteforce(tree1, tree2, node_attrib=None):
    """
    Brute-force search for the valid subtrees (Collins and Duffy 2001)
    that the two given trees have in common.

    Every subtree of tree1 is compared with every subtree of tree2.
    A pair matches, iff both subtrees have the same structure and
    identical node labels ('label' attribute).

    Returns
    -------
    common : list
        for each matching pair, the subtree taken from tree1
        (one entry per matching pair)
    """
    labels_match = iso.categorical_node_match('label', '')
    subtrees_of_tree1 = get_subtrees(tree1, node_attrib=node_attrib)
    subtrees_of_tree2 = get_subtrees(tree2, node_attrib=node_attrib)
    # itertools.product() reads both inputs completely before pairing them,
    # so this also works if get_subtrees() returns one-shot iterators
    return [left
            for (left, right) in itertools.product(subtrees_of_tree1, subtrees_of_tree2)
            if nx.is_isomorphic(left, right, node_match=labels_match)]

Speed comparison: Tree kernel implementations

  • tree_kernel_naive() is really slow
  • even find_all_common_subtrees_bruteforce() is 2x faster
  • tree_kernel_polynomial() is more than 300x faster

In [16]:
%timeit tree_kernel_naive(tree_the_man_drank_wine, tree_the_man_killed_the_woman)


1 loops, best of 3: 2.22 s per loop

In [17]:
%timeit tree_kernel_polynomial(tree_the_man_drank_wine, tree_the_man_killed_the_woman)


100 loops, best of 3: 7.38 ms per loop

In [18]:
%timeit find_all_common_subtrees_bruteforce(tree_the_man_drank_wine, tree_the_man_killed_the_woman)


1 loops, best of 3: 933 ms per loop

In [19]:
tree_kernel_naive(tree_alex_died, tree_alex_died) \
== tree_kernel_polynomial(tree_alex_died, tree_alex_died) \
== len(find_all_common_subtrees_bruteforce(tree_alex_died, tree_alex_died))


Out[19]:
True

In [20]:
from IPython.display import Image

In [21]:
Image('img/a-dog.png')


Out[21]:

In [22]:
Image('img/a-cat.png')


Out[22]:

3 out of 5 structures are identical, therefore the similarity is equal to 3.


In [43]:
def _determiner_noun_phrase(noun):
    """
    Builds the tree NP -> D N with D -> 'a' and N -> ``noun``.
    Nodes 1-5 carry the labels NP, D, N, 'a' and the given noun.
    """
    phrase = nx.DiGraph()
    phrase.add_nodes_from(label_nodes([(1, 'NP'), (2, 'D'), (3, 'N'), (4, 'a'), (5, noun)]))
    phrase.add_edges_from([(1, 2), (1, 3), (2, 4), (3, 5)])
    return phrase


# two trees that only differ in their noun
a_dog = _determiner_noun_phrase('dog')
a_cat = _determiner_noun_phrase('cat')

In [48]:
draw_multiple_graphs([a_dog, a_cat])



In [51]:
# draw_multiple_graphs(generate_all_unique_subtrees(a_dog, a_cat))

In [24]:
# print tree_kernel(tree1, tree2)
# print tree_kernel_recursive(tree1, tree2)

In [25]:
tree3 = nx.DiGraph()
tree3.add_edges_from([('D', 'a')])

# print tree_kernel(tree3, tree3)
# print tree_kernel_recursive(tree3, tree3)

In [26]:
# `tree1` is not defined anywhere else in this notebook (it was left over in
# the kernel from an earlier session), so build it here: the 'a dog' tree
# with the node labels used directly as node IDs
tree1 = nx.DiGraph()
tree1.add_edges_from([('NP', 'D'), ('NP', 'N'), ('D', 'a'), ('N', 'dog')])

[s.edges() for s in get_subtrees(tree1)]


Out[26]:
[[('D', 'a')],
 [('N', 'dog')],
 [('NP', 'D'), ('NP', 'N')],
 [('NP', 'D'), ('NP', 'N'), ('D', 'a')],
 [('NP', 'D'), ('NP', 'N'), ('N', 'dog')],
 [('NP', 'D'), ('NP', 'N'), ('D', 'a'), ('N', 'dog')]]

In [27]:
draw(tree_jeff_ate_cookies, show='ipynb')


Out[27]:

In [28]:
draw(tree_steve_ate_bananas, show='ipynb')


Out[28]:

In [29]:
draw(tree_alex_died, show='ipynb')


Out[29]:

In [30]:
draw(tree_the_man_drank_wine, show='ipynb')


Out[30]:

In [31]:
draw(tree_the_man_killed_the_woman, show='ipynb')


Out[31]:

production rules from (syntax) node labels vs. production rules from node IDs

  • syntax production rules are what we want

Caveat: this might produce subtrees we don't want. E.g., if a tree contains both NP -> N and NP -> D N rules, the algorithm will produce an NP -> N -> Noun subtree even from an NP -> D N subtree.


In [32]:
get_production_rules(tree_the_man_drank_wine, node_attrib='label')


Out[32]:
{('D', ('the',)),
 ('N', ('man',)),
 ('N', ('wine',)),
 ('NP', ('D', 'N')),
 ('NP', ('N',)),
 ('S', ('NP', 'VP')),
 ('V', ('drank',)),
 ('VP', ('V', 'NP'))}

In [33]:
len(list(get_subtrees(tree_the_man_drank_wine, node_attrib='label')))


Out[33]:
67

In [34]:
len(list(get_subtrees(tree_the_man_drank_wine)))


Out[34]:
51

In [35]:
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_jeff_ate_cookies))
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_jeff_ate_cookies, node_attrib='label'))


36
36

In [36]:
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_steve_ate_bananas))
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_steve_ate_bananas, node_attrib='label'))


19
19

In [37]:
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_the_man_drank_wine))
print len(find_all_common_subtrees_bruteforce(tree_jeff_ate_cookies, tree_the_man_drank_wine, node_attrib='label'))


7
12