In [222]:
import json
import subprocess
from collections import defaultdict
In [223]:
subprocess.call('python count_cfg_freq.py parse_train.dat > cfg.counts', shell=True)
# Example lines in cfg.counts:
#   17 NONTERMINAL NP           - nonterminal counts Count(X)
#   8 UNARYRULE NP+NOUN place   - unary rule counts Count(X -> w)
#   918 BINARYRULE NP DET NOUN  - binary rule counts Count(X -> Y1 Y2)
Out[223]:
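A quick sanity check (a sketch, assuming cfg.counts was written to the working directory): print the first few lines of the counts file to confirm the three record types above.

from itertools import islice
with open('cfg.counts', 'rt') as f:
    for line in islice(f, 5):
        print(line.rstrip())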
In [224]:
def load_counts(file_path):
    word_counts = defaultdict(int)
    unary_counts = defaultdict(dict)
    nonterminal_counts = {}
    emission_counts = defaultdict(dict)
    with open(file_path, 'rt') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            count = int(parts[0])
            count_type = parts[1]
            nonterminal = parts[2]
            if count_type == 'UNARYRULE':
                word = parts[3]
                word_counts[word] += count
                unary_counts[nonterminal][word] = count
            elif count_type == 'NONTERMINAL':
                nonterminal_counts[nonterminal] = count
            elif count_type == 'BINARYRULE':
                if len(parts) != 5:
                    raise Exception('CFG must be in Chomsky Normal Form (CNF)')
                y1, y2 = parts[3], parts[4]
                emission_counts[nonterminal][y1, y2] = count
    return word_counts, nonterminal_counts, emission_counts, unary_counts

def branches_iterator(tree):
    # Pre-order tree traversal.
    # It assumes the tree is in CNF: every branch is either
    # [X, left_subtree, right_subtree] (binary rule) or [X, word] (unary rule).
    if type(tree) is not list:
        raise TypeError('"tree" must be a list')
    if len(tree) == 3:
        # binary rule branch
        yield tree
        # left subtree
        for branch in branches_iterator(tree[1]):
            yield branch
        # right subtree
        for branch in branches_iterator(tree[2]):
            yield branch
    elif len(tree) == 2:
        # unary rule branch
        yield tree

def trees_iterator(file_path):
    # Each line in the file is one parse tree encoded as nested JSON lists.
    with open(file_path, 'rt') as f:
        for line in f:
            yield json.loads(line)

def replace_rare_words_in_tree(tree, word_counts):
    # Replace every rare word with the '_RARE_' token, in place.
    for branch in branches_iterator(tree):
        if len(branch) == 2:
            # unary rule branch: branch[1] is the terminal word
            word = branch[1]
            if is_rare_word(word, word_counts):
                branch[1] = '_RARE_'

def is_rare_word(word, word_counts):
    # A word is rare if it occurs fewer than 5 times in the training data.
    return word_counts.get(word, 0) < 5
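Trees are nested JSON lists: a binary branch is [X, left_subtree, right_subtree] and a unary branch is [X, word]. A minimal sketch of branches_iterator on a hypothetical toy tree:

toy_tree = ['S',
            ['NP', ['DET', 'the'], ['NOUN', 'dog']],
            ['VERB', 'barks']]
for branch in branches_iterator(toy_tree):
    if len(branch) == 2:
        print(branch[0], '->', branch[1])                   # unary rule
    else:
        print(branch[0], '->', branch[1][0], branch[2][0])  # binary rule
# S -> NP VERB
# NP -> DET NOUN
# DET -> the
# NOUN -> dog
# VERB -> barks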
In [225]:
# Count rare words regardless of part of speech:
# https://class.coursera.org/nlangp-001/forum/thread?thread_id=922
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('cfg.counts')
with open('parse_train2.dat', 'wt') as f:
    for tree in trees_iterator('parse_train.dat'):
        replace_rare_words_in_tree(tree, word_counts)
        f.write(json.dumps(tree) + '\n')
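The replacement happens in place. For illustration, a hypothetical tree with one frequent and one unseen word:

toy_tree = ['NP', ['DET', 'the'], ['NOUN', 'xylophone']]
replace_rare_words_in_tree(toy_tree, {'the': 100})  # 'xylophone' has count 0 < 5
print(toy_tree)  # ['NP', ['DET', 'the'], ['NOUN', '_RARE_']]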
In [226]:
subprocess.call('python count_cfg_freq.py parse_train2.dat > parse_train.counts.out', shell=True)
Out[226]:
In [227]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('parse_train.counts.out')
In [228]:
def check_probabilities(probabilities):
    # Sanity check: rule probabilities conditioned on each nonterminal
    # must sum to 1 (up to floating-point rounding).
    for state, transitions in probabilities.items():
        total_prob = round(sum(transitions.values()), 10)
        if total_prob != 1.0:
            raise Exception(state + ': probability = ' + str(total_prob))

def calc_model_parameters(unary_counts, nonterminal_counts, emission_counts):
    # Maximum-likelihood estimates:
    #   q(X -> Y Z) = Count(X -> Y Z) / Count(X)
    #   q(X -> w)   = Count(X -> w) / Count(X)
    emission_q = defaultdict(dict)
    unary_q = defaultdict(dict)
    for nonterminal, transitions in emission_counts.items():
        for transit, count in transitions.items():
            emission_q[nonterminal][transit] = float(count) / nonterminal_counts[nonterminal]
    for nonterminal, transitions in unary_counts.items():
        for word, count in transitions.items():
            unary_q[nonterminal][word] = float(count) / nonterminal_counts[nonterminal]
    check_probabilities(emission_q)
    check_probabilities(unary_q)
    return emission_q, unary_q

def replace_rare_words(words, word_counts):
    # Same rare-word mapping as in training, applied to a raw sentence.
    return [word if not is_rare_word(word, word_counts) else '_RARE_' for word in words]
In [229]:
# emission_q[X][(Y,Z)] = binary rule probability P(X -> Y Z)
# unary_q[X][w] = unary rule probability P(X -> w)
emission_q, unary_q = calc_model_parameters(unary_counts, nonterminal_counts, emission_counts)
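The estimates are plain relative frequencies, so they can be checked by hand. A tiny worked example on hypothetical counts:

toy_nt = {'NP': 10, 'DET': 4, 'NOUN': 6}
toy_binary = {'NP': {('DET', 'NOUN'): 7, ('NP', 'NP'): 3}}
toy_unary = {'DET': {'the': 4}, 'NOUN': {'dog': 2, '_RARE_': 4}}
toy_emission_q, toy_unary_q = calc_model_parameters(toy_unary, toy_nt, toy_binary)
print(toy_emission_q['NP'][('DET', 'NOUN')])  # 7 / 10 = 0.7
print(toy_unary_q['NOUN']['dog'])             # 2 / 6 = 0.333...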
In [234]:
def backtrace(root, bp):
    # Rebuild the parse tree from the back pointers.
    if len(root) == 6:
        # binary rule: the span [i, j] was split at s
        (X, Y, Z, i, s, j) = root
        return [X, backtrace(bp[i, s, Y], bp), backtrace(bp[s + 1, j, Z], bp)]
    else:
        # unary rule: X emits the word Y
        (X, Y, i, _) = root
        return [X, Y]

def cky_decoder(words, emission_q, unary_q):
    n = len(words)
    pi = {}  # max probabilities of subtrees: pi[i, j, X]
    bp = {}  # back pointers to the subtrees with max probabilities
    # initialize length-1 spans with the unary rule probabilities
    for i, w in enumerate(words):
        for X in unary_q:
            if w in unary_q[X]:
                pi[i, i, X] = unary_q[X][w]
                bp[i, i, X] = (X, w, i, i)
    # dynamic programming over spans of increasing length
    for l in range(1, n):  # span covers l + 1 words, 1 <= l < n
        for i in range(n - l):  # span start
            j = i + l  # span end
            for X, rules in emission_q.items():
                max_pi = 0.0
                max_pi_params = None
                for rule, q in rules.items():
                    Y, Z = rule
                    for s in range(i, j):  # split point
                        pi_left = pi.get((i, s, Y), 0.0)
                        pi_right = pi.get((s + 1, j, Z), 0.0)
                        tmp_pi = q * pi_left * pi_right
                        if tmp_pi > max_pi:
                            max_pi = tmp_pi
                            max_pi_params = (X, Y, Z, i, s, j)
                if max_pi > 0:
                    pi[i, j, X] = max_pi
                    bp[i, j, X] = max_pi_params
    # The parse tree must be rooted at SBARQ; return None when there is no parse.
    if (0, n - 1, 'SBARQ') in pi:
        return backtrace(bp[0, n - 1, 'SBARQ'], bp)
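End-to-end on a hypothetical two-rule grammar (cky_decoder returns None when no SBARQ parse covers the whole sentence):

toy_emission_q = {'SBARQ': {('WHADVP', 'SQ'): 1.0}}
toy_unary_q = {'WHADVP': {'why': 1.0}, 'SQ': {'go': 1.0}}
print(cky_decoder(['why', 'go'], toy_emission_q, toy_unary_q))
# ['SBARQ', ['WHADVP', 'why'], ['SQ', 'go']]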
In [235]:
def decode_file(in_file, out_file, word_counts, emission_q, unary_q):
    with open(in_file, 'rt') as in_f, open(out_file, 'wt') as out_f:
        for sentence in in_f:
            words = sentence.split()
            words_with_rare = replace_rare_words(words, word_counts)
            tree = cky_decoder(words_with_rare, emission_q, unary_q)
            out_f.write(json.dumps(tree) + '\n')
In [236]:
decode_file('parse_dev.dat', 'parse_dev.out', word_counts, emission_q, unary_q)
In [237]:
decode_file('parse_test.dat', 'parse_test.p2.out', word_counts, emission_q, unary_q)
In [238]:
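# Repeat the same pipeline on parse_train_vert.dat (presumably the vertically
# markovized version of the training trees).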
subprocess.call('python count_cfg_freq.py parse_train_vert.dat > cfg_vert.counts', shell=True)
Out[238]:
In [239]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('cfg_vert.counts')
with open('parse_train_vert2.dat', 'wt') as f:
    for tree in trees_iterator('parse_train_vert.dat'):
        replace_rare_words_in_tree(tree, word_counts)
        f.write(json.dumps(tree) + '\n')
In [240]:
subprocess.call('python count_cfg_freq.py parse_train_vert2.dat > parse_train_vert.counts.out', shell=True)
Out[240]:
In [241]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('parse_train_vert.counts.out')
In [242]:
emission_q, unary_q = calc_model_parameters(unary_counts, nonterminal_counts, emission_counts)
In [243]:
decode_file('parse_dev.dat', 'parse_dev_vert.out', word_counts, emission_q, unary_q)
In [ ]:
decode_file('parse_test.dat', 'parse_test.p3.out', word_counts, emission_q, unary_q)