In [1]:
import nltk
import matplotlib.pyplot as plt
%matplotlib inline
In [10]:
# Bracketed constituency parse of the sentence "I am Sam .".
string = "(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))"

# Uncompressed parse. `tree` is rebound to the compressed version further
# down in this cell, once load_compressed_tree has been defined.
tree = nltk.Tree.fromstring(string)
def load_compressed_tree(s):
    """Parse a bracketed tree string and collapse its unary chains.

    The string is parsed with ``nltk.tree.Tree.fromstring`` and then
    compressed recursively:

    * a node whose only child is itself a ``Tree`` is replaced by that
      (compressed) child, so the label of the replaced node is dropped;
    * a node whose only child is a leaf token is kept unchanged;
    * a node with several children is kept, and each child is compressed
      in place.

    For example ``"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))"``
    becomes ``(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))``.

    Args:
        s: Bracketed tree string accepted by ``nltk.tree.Tree.fromstring``.

    Returns:
        The compressed ``nltk.tree.Tree``. It is built from a freshly parsed
        tree, which is mutated in place during compression.
    """

    def compress_tree(tree):
        # Leaf tokens are plain strings, not Trees. Hand them back untouched:
        # recursing into a multi-character string would attempt item
        # assignment on a str and raise TypeError.
        if not isinstance(tree, nltk.tree.Tree):
            return tree

        if len(tree) == 1:
            only_child = tree[0]
            if isinstance(only_child, nltk.tree.Tree):
                # Unary chain: skip this node and keep compressing below it.
                return compress_tree(only_child)
            # Single leaf child: keep the node as is.
            return tree

        # Several (or zero) children: compress each one in place.
        for i, child in enumerate(tree):
            tree[i] = compress_tree(child)
        return tree

    return compress_tree(nltk.tree.Tree.fromstring(s))
# Replace the raw parse with its compressed form.
tree = load_compressed_tree(string)

# Dump every subtree yielded by subtrees(), one print call per subtree.
for t in tree.subtrees():
    print(t)

# print() applies str() itself, so no explicit conversion is needed.
print(tree)
In [3]:
# Print the result of Tree.flatten() for the current `tree`.
print(tree.flatten())
In [10]:
# Labels of every subtree of `tree`, in the order subtrees() yields them.
print([node.label() for node in tree.subtrees()])
In [11]:
import json
# Load data/squad/shared_dev.json into `d`. The context manager closes the
# file handle as soon as parsing finishes instead of leaving it open until
# garbage collection.
with open("data/squad/shared_dev.json", 'r') as fh:
    d = json.load(fh)
In [12]:
# Size of the 'pos_counter' entry of the loaded JSON data
# (presumably per-POS-tag counts -- verify against the data file).
len(d['pos_counter'])
Out[12]:
In [13]:
# Display the whole 'pos_counter' entry as this cell's output.
d['pos_counter']
Out[13]:
In [3]:
from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span
# Rebuild the example tree with the helper imported from my.nltk_utils.
string = "(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))"
tree = load_compressed_tree(string)

# Span to look up in the tree.
# NOTE(review): presumably token indices -- confirm the convention
# (inclusive/exclusive) in my.nltk_utils.
span = (1, 3)

# NOTE(review): set_span presumably annotates the tree's nodes in place so
# that find_max_f1_subtree can compare them against `span` -- confirm.
set_span(tree)
subtree = find_max_f1_subtree(tree, span)


def f(t):
    """Return the result of comparing node `t` with the selected `subtree`."""
    return t == subtree


def g(t):
    """Return 1 for plain-string nodes and 2 for every other node."""
    return 1 if isinstance(t, str) else 2


# tree2matrix returns two values; only the first of each pair is printed.
a, b = tree2matrix(tree, f, dtype='bool')
# The second value is bound to `c_aux` rather than `d`, so the JSON data
# stored in `d` by the loading cell is not overwritten and the cells that
# read d['pos_counter'] can still be re-run.
c, c_aux = tree2matrix(tree, g, dtype='int32')
print(a)
print(c)
In [ ]: