In [1]:
import nltk
import matplotlib.pyplot as plt
%matplotlib inline

In [10]:
# Running example: a bracketed parse of "I am Sam ." with a ROOT wrapper
# and single-child NP nodes above the PRP / NNP preterminals.
string = "(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))"
# Parse the bracketed string into an nltk Tree (uncompressed form).
tree = nltk.Tree.fromstring(string)

def load_compressed_tree(s):
    """Parse a bracketed tree string and collapse its unary chains.

    Any node whose only child is itself a tree is replaced by that child,
    so the parent label is dropped, e.g. ``(NP (PRP I))`` becomes
    ``(PRP I)`` and a ``ROOT`` wrapper around a single ``S`` disappears.
    Nodes whose only child is a token (preterminals) are kept as they are.

    Args:
        s: Bracketed tree string accepted by ``nltk.tree.Tree.fromstring``.

    Returns:
        The compressed ``nltk.tree.Tree``. The freshly parsed tree is
        modified in place while it is being compressed.
    """

    def compress_tree(tree):
        # Bare tokens can sit directly under a node with several children,
        # e.g. "(NP the dog)". They are plain strings, so there is nothing
        # to collapse and they must not be indexed or assigned into.
        if not isinstance(tree, nltk.tree.Tree):
            return tree

        if len(tree) == 1:
            only_child = tree[0]
            if isinstance(only_child, nltk.tree.Tree):
                # Unary chain: skip this node and keep compressing below it.
                return compress_tree(only_child)
            # Preterminal (single token child): keep the node unchanged.
            return tree

        # Zero or several children: compress each child in place.
        for i, child in enumerate(tree):
            tree[i] = compress_tree(child)
        return tree

    return compress_tree(nltk.tree.Tree.fromstring(s))
# Rebuild `tree` from the example string, this time with unary chains collapsed.
tree = load_compressed_tree(string)

# Show every subtree of the compressed parse, one per line.
for t in tree.subtrees():
    print(t)

# Finally show the whole compressed tree (print applies str() itself).
print(tree)


(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))
(PRP I)
(VP (VBP am) (NNP Sam))
(VBP am)
(NNP Sam)
(. .)
(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))

In [3]:
# Print the flattened form of `tree`; in the recorded output this is the
# root label followed directly by the tokens.
# NOTE(review): the recorded output is rooted at ROOT, whereas the compressed
# tree printed earlier is rooted at S, so this cell was presumably executed
# while `tree` still held the uncompressed parse. Re-run in order to confirm.
print(tree.flatten())


(ROOT I am Sam .)

In [10]:
# List the label of every subtree of `tree`, in the order subtrees() yields them.
print([t.label() for t in tree.subtrees()])


['ROOT', 'S', 'NP', 'PRP', 'VP', 'VBP', 'NP', 'NNP', '.']

In [11]:
import json
# Load the shared dev-set JSON. The context manager closes the file handle
# as soon as parsing is done; the previous bare open() call never closed it.
with open("data/squad/shared_dev.json", 'r') as shared_file:
    d = json.load(shared_file)

In [12]:
# Number of distinct keys stored under 'pos_counter' in the loaded JSON.
len(d['pos_counter'])


Out[12]:
73

In [13]:
# Display the full 'pos_counter' mapping; per the recorded output its keys are
# tag / constituent labels and its values are integer counts.
d['pos_counter']


Out[13]:
{'#': 6,
 '$': 80,
 "''": 1291,
 ',': 14136,
 '-LRB-': 1926,
 '-RRB-': 1925,
 '.': 9505,
 ':': 1455,
 'ADJP': 3426,
 'ADVP': 4936,
 'CC': 9300,
 'CD': 6216,
 'CONJP': 191,
 'DT': 26286,
 'EX': 288,
 'FRAG': 107,
 'FW': 96,
 'IN': 32564,
 'INTJ': 12,
 'JJ': 21452,
 'JJR': 563,
 'JJS': 569,
 'LS': 7,
 'LST': 1,
 'MD': 1051,
 'NAC': 19,
 'NN': 34750,
 'NNP': 28392,
 'NNPS': 1400,
 'NNS': 16716,
 'NP': 91636,
 'NP-TMP': 236,
 'NX': 108,
 'PDT': 89,
 'POS': 1451,
 'PP': 33278,
 'PRN': 2085,
 'PRP': 2320,
 'PRP$': 1959,
 'PRT': 450,
 'QP': 838,
 'RB': 7611,
 'RBR': 301,
 'RBS': 252,
 'ROOT': 9587,
 'RP': 454,
 'RRC': 19,
 'S': 21557,
 'SBAR': 5009,
 'SBARQ': 6,
 'SINV': 135,
 'SQ': 5,
 'SYM': 17,
 'TO': 5167,
 'UCP': 143,
 'UH': 15,
 'VB': 4197,
 'VBD': 8377,
 'VBG': 3570,
 'VBN': 7218,
 'VBP': 2897,
 'VBZ': 4146,
 'VP': 33696,
 'WDT': 1368,
 'WHADJP': 5,
 'WHADVP': 439,
 'WHNP': 1927,
 'WHPP': 153,
 'WP': 482,
 'WP$': 50,
 'WRB': 442,
 'X': 23,
 '``': 1269}

In [3]:
from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span
# Same example parse as above, compressed with the project helper from
# my.nltk_utils (the import shadows the notebook-local definition).
string = "(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))"
tree = load_compressed_tree(string)

# Target span handed to find_max_f1_subtree.
# NOTE(review): presumably token indices into the sentence -- confirm against
# my.nltk_utils.
span = (1, 3)

# set_span must run on the tree before the subtree search below.
set_span(tree)
subtree = find_max_f1_subtree(tree, span)


def f(t):
    # True only for the node that equals the selected subtree.
    return t == subtree


def g(t):
    # 1 for plain string nodes, 2 for everything else.
    return 1 if isinstance(t, str) else 2


# tree2matrix returns a pair; only the first element of each is printed here.
a, b = tree2matrix(tree, f, dtype='bool')
c, d = tree2matrix(tree, g, dtype='int32')
print(a)
print(c)


[[False False False False]
 [False  True False False]
 [False False False False]]
[[0 2 2 0]
 [2 2 0 2]
 [2 0 0 0]]

In [ ]: