In [222]:
import json
import subprocess
from collections import defaultdict

Part 1. Parser Setup


In [223]:
subprocess.call('python count_cfg_freq.py parse_train.dat > cfg.counts', shell=True)
# Each line of cfg.counts has one of three formats, e.g.:
#   17 NONTERMINAL NP           - nonterminal counts Count(X)
#   8 UNARYRULE NP+NOUN place   - unary rule counts Count(X -> w)
#   918 BINARYRULE NP DET NOUN  - binary rule counts Count(X -> Y1 Y2)


Out[223]:
0

In [224]:
def load_counts(file_path):
    # Parse the output of count_cfg_freq.py into count dictionaries
    word_counts = defaultdict(int)       # Count(w), summed over all tags
    unary_counts = defaultdict(dict)     # Count(X -> w)
    nonterminal_counts = {}              # Count(X)
    emission_counts = defaultdict(dict)  # Count(X -> Y1 Y2), binary rules
    with open(file_path, 'rt') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            count = int(parts[0])
            count_type = parts[1]
            nonterminal = parts[2]
            if count_type == 'UNARYRULE':
                word = parts[3]
                word_counts[word] += count
                unary_counts[nonterminal][word] = count
            elif count_type == 'NONTERMINAL':
                nonterminal_counts[nonterminal] = count
            elif count_type == 'BINARYRULE':
                if len(parts) != 5:
                    raise Exception('CFG must be in Chomsky Normal Form (CNF)')
                y1, y2 = parts[3], parts[4]
                emission_counts[nonterminal][y1, y2] = count
    return word_counts, nonterminal_counts, emission_counts, unary_counts

def branches_iterator(tree):
    # Pre-order traversal over the branches of a CNF tree:
    # yields [X, Y, Z] for binary branches and [X, w] for unary branches
    if not isinstance(tree, list):
        raise TypeError('"tree" must be a list')
        
    if len(tree) == 3:
        # binary rule branch
        yield tree
        # left subtree
        for branch in branches_iterator(tree[1]):
            yield branch
        # right subtree
        for branch in branches_iterator(tree[2]):
            yield branch
    elif len(tree) == 2:
        # unary rule branch
        yield tree
        
def trees_iterator(file_path):
    with open(file_path, 'rt') as f:
        for line in f:
            tree = json.loads(line)
            yield tree

def replace_rare_words_in_tree(tree, word_counts):
    for branch in branches_iterator(tree):
        if len(branch) == 2:
            # unary rule branch
            word = branch[1]
            if is_rare_word(word, word_counts):
                branch[1] = '_RARE_'
                
def is_rare_word(word, word_counts):
    # a word is rare if it occurs fewer than 5 times in the training data
    return word_counts.get(word, 0) < 5
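
A quick check of the tree helpers on a toy CNF tree (the tree, words, and counts below are invented for illustration, not taken from the data):


In [ ]:
toy_tree = ['S', ['NP', ['DET', 'the'], ['NOUN', 'aardvark']], ['VERB', 'sleeps']]
toy_counts = {'the': 100, 'sleeps': 20, 'aardvark': 1}
replace_rare_words_in_tree(toy_tree, toy_counts)
print toy_tree
# -> ['S', ['NP', ['DET', 'the'], ['NOUN', '_RARE_']], ['VERB', 'sleeps']]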

In [225]:
# count rare words regardless of part of speech
# https://class.coursera.org/nlangp-001/forum/thread?thread_id=922
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('cfg.counts')
with open('parse_train2.dat', 'wt') as f:
    for tree in trees_iterator('parse_train.dat'):
        replace_rare_words_in_tree(tree, word_counts)
        f.write(json.dumps(tree) + '\n')

In [226]:
subprocess.call('python count_cfg_freq.py parse_train2.dat > parse_train.counts.out', shell=True)


Out[226]:
0

Part 2. CKY Decoder


In [227]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('parse_train.counts.out')

In [228]:
def check_probabilities(probabilities):
    # sanity check: each nonterminal's rule probabilities must sum to 1
    for state, transitions in probabilities.iteritems():
        total_prob = round(sum(transitions.itervalues()), 10)
        if total_prob != 1.0:
            raise Exception(state + ': probability = ' + str(total_prob))

def calc_model_parameters(unary_counts, nonterminal_counts, emission_counts):
    emission_q = defaultdict(dict)
    unary_q = defaultdict(dict)
    for nonterminal, transitions in emission_counts.iteritems():
        for transit, count in transitions.iteritems():
            emission_q[nonterminal][transit] = float(count) / nonterminal_counts[nonterminal]
            
    for nonterminal, transitions in unary_counts.iteritems():
        for word, count in transitions.iteritems():
            unary_q[nonterminal][word] = float(count) / nonterminal_counts[nonterminal]
            
    check_probabilities(emission_q)
    check_probabilities(unary_q)
    return emission_q, unary_q

def replace_rare_words(words, word_counts):
    return [word if not is_rare_word(word, word_counts) else '_RARE_' for word in words]

In [229]:
# Maximum-likelihood estimates:
# emission_q[X][(Y,Z)] = binary rule probability q(X -> Y Z) = Count(X -> Y Z) / Count(X)
# unary_q[X][w] = unary rule probability q(X -> w) = Count(X -> w) / Count(X)
emission_q, unary_q = calc_model_parameters(unary_counts, nonterminal_counts, emission_counts)
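
A toy check of the estimates with invented counts (note that check_probabilities requires each nonterminal's rules to be all binary or all unary, which holds for this grammar):


In [ ]:
toy_nonterminals = {'NP': 4, 'DET': 4, 'NOUN': 4}
toy_binary = {'NP': {('DET', 'NOUN'): 3, ('NP', 'NP'): 1}}
toy_unary = {'DET': {'the': 3, 'a': 1}, 'NOUN': {'dog': 4}}
toy_emission_q, toy_unary_q = calc_model_parameters(toy_unary, toy_nonterminals, toy_binary)
print toy_emission_q['NP'][('DET', 'NOUN')]  # 0.75
print toy_unary_q['DET']['the']              # 0.75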

In [234]:
def backtrace(root, bp):
    # Reconstruct the parse tree from the back pointers
    if len(root) == 6:
        # binary rule X -> Y Z with split point s
        (X, Y, Z, i, s, j) = root
        return [X, backtrace(bp[i, s, Y], bp), backtrace(bp[s + 1, j, Z], bp)]
    # unary rule X -> w
    (X, w, i, j) = root
    return [X, w]

def cky_decoder(words, emission_q, unary_q):
    n = len(words)
    pi = {} # pi[i,j,X] = max probability of an X-rooted subtree spanning words[i..j]
    bp = {} # back pointers to the argmax rule and split point
    # initialization: unary rules X -> w for each word
    for i, w in enumerate(words):
        for X in unary_q.keys():
            if w in unary_q[X]:
                pi[i,i,X] = unary_q[X][w]
                bp[i,i,X] = (X, w, i, i)
    # dynamic programming over increasing span widths
    for l in xrange(1, n):     # span width l = j - i, 1 <= l < n
        for i in xrange(n-l):  # span start
            j = i + l          # span end
            for X, rules in emission_q.iteritems():
                max_pi = 0.0
                max_pi_params = None
                for rule, q in rules.iteritems():
                    Y, Z = rule
                    for s in xrange(i,j):  # split point
                        pi_left = pi.get((i,s,Y), 0.0)
                        pi_right = pi.get((s+1,j,Z), 0.0)
                        tmp_pi = q * pi_left * pi_right
                        if tmp_pi > max_pi:
                            max_pi = tmp_pi
                            max_pi_params = (X, Y, Z, i, s, j)

                if max_pi > 0:
                    pi[i,j,X] = max_pi
                    bp[i,j,X] = max_pi_params

    # The parse tree must be rooted at SBARQ; returns None if no parse exists
    if (0, n-1, 'SBARQ') in pi:
        return backtrace(bp[0, n-1, 'SBARQ'], bp)
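
The recurrence is pi[i,j,X] = max over rules X -> Y Z and split points s of q(X -> Y Z) * pi[i,s,Y] * pi[s+1,j,Z]. A minimal smoke test on a hand-built toy grammar (invented rules and sentence, rooted at SBARQ because the decoder hardcodes that root):


In [ ]:
toy_unary_q = {'DET': {'the': 1.0}, 'NOUN': {'dog': 0.5, 'cat': 0.5}, 'VERB': {'sleeps': 1.0}}
toy_emission_q = {'NP': {('DET', 'NOUN'): 1.0}, 'SBARQ': {('NP', 'VERB'): 1.0}}
print cky_decoder(['the', 'dog', 'sleeps'], toy_emission_q, toy_unary_q)
# -> ['SBARQ', ['NP', ['DET', 'the'], ['NOUN', 'dog']], ['VERB', 'sleeps']]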

In [235]:
def decode_file(in_file, out_file, word_counts, emission_q, unary_q):
    with open(in_file, 'rt') as in_f, open(out_file, 'wt') as out_f:
        for sentence in in_f:
            words = sentence.split()
            # rare words must be replaced exactly as they were in training
            words_with_rare = replace_rare_words(words, word_counts)
            tree = cky_decoder(words_with_rare, emission_q, unary_q)
            out_f.write(json.dumps(tree) + '\n')

In [236]:
decode_file('parse_dev.dat', 'parse_dev.out', word_counts, emission_q, unary_q)

In [237]:
decode_file('parse_test.dat', 'parse_test.p2.out', word_counts, emission_q, unary_q)

Part 3. Vertical Markovization
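
Vertical markovization weakens the PCFG's independence assumptions by annotating each nonterminal with its parent's label, so rule probabilities are conditioned on one level of vertical context. parse_train_vert.dat already contains the annotated training trees; the sketch below only illustrates the idea, and the X^<parent> label notation is an assumption, not necessarily the file's exact convention.


In [ ]:
def annotate_parents(tree, parent='TOP'):
    # second-order vertical markovization: rewrite nonterminal X as X^<parent>
    # (illustrative sketch; POS tags on unary branches are left unchanged)
    if len(tree) == 3:
        label = tree[0]
        return [label + '^<' + parent + '>',
                annotate_parents(tree[1], label),
                annotate_parents(tree[2], label)]
    return tree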


In [238]:
subprocess.call('python count_cfg_freq.py parse_train_vert.dat > cfg_vert.counts', shell=True)


Out[238]:
0

In [239]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('cfg_vert.counts')
with open('parse_train_vert2.dat', 'wt') as f:
    for tree in trees_iterator('parse_train_vert.dat'):
        replace_rare_words_in_tree(tree, word_counts)
        f.write(json.dumps(tree) + '\n')

In [240]:
subprocess.call('python count_cfg_freq.py parse_train_vert2.dat > parse_train_vert.counts.out', shell=True)


Out[240]:
0

In [241]:
word_counts, nonterminal_counts, emission_counts, unary_counts = load_counts('parse_train_vert.counts.out')

In [242]:
emission_q, unary_q = calc_model_parameters(unary_counts, nonterminal_counts, emission_counts)

In [243]:
decode_file('parse_dev.dat', 'parse_dev_vert.out', word_counts, emission_q, unary_q)

In [ ]:
decode_file('parse_test.dat', 'parse_test.p3.out', word_counts, emission_q, unary_q)

In [ ]: