In [1]:
from pyspark import SparkConf, SparkContext
import re
import numpy as np

# Load the English stopword list into a set for O(1) membership tests.
with open("/datasets/stop_words_en.txt") as f:
    stopwords = set(line.strip() for line in f)

sc = SparkContext(conf=SparkConf().setAppName("MyApp").setMaster("local"))

def parse_rdd1(line):
    # One record is "article_id<TAB>article_text"; emit the lowercased,
    # non-stopword tokens of the text.
    try:
        article_id, text = line.rstrip().split('\t', 1)
        text = re.sub(r"^\W+|\W+$", "", text, flags=re.UNICODE)
        words = re.split(r"\W*\s+\W*", text, flags=re.UNICODE)
        good_words = []
        for word in words:
            lower_word = word.lower()
            if lower_word not in stopwords:
                good_words.append(lower_word)
        return good_words
    except ValueError:
        return []
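
# Sanity check on a synthetic record (assumes "the" and "of" are in the
# stopword file; not a real corpus line):
# parse_rdd1("12\tThe Fall of the Roman Empire")  ->  ['fall', 'roman', 'empire']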

# Unigram counts: (word, count) over the whole corpus.
word_count = (
    sc.textFile("/data/wiki/en_articles_part/articles-part")
        .flatMap(parse_rdd1)
        .map(lambda x: (x, 1))
        .reduceByKey(lambda a, b: a + b)
)
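
# Optional peek at a few counted words (take() is lazy-friendly and only
# triggers a small Spark job):
# word_count.take(3)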

# Total number of non-stopword tokens, used as the probability denominator.
total_count = (
    word_count
        .map(lambda kv: ("total", kv[1]))
        .reduceByKey(lambda a, b: a + b)
        .collect()
)[0][1]

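# Equivalent one-liner using the standard RDD API (same value, shown for
# reference):
# total_count = word_count.values().sum()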

def parse_rdd2(line):
    # Same tokenization as parse_rdd1, but emit adjacent-token bigrams joined
    # with "_". Stopwords are dropped first, so tokens separated only by
    # stopwords count as adjacent.
    try:
        article_id, text = line.rstrip().split('\t', 1)
        text = re.sub(r"^\W+|\W+$", "", text, flags=re.UNICODE)
        words = re.split(r"\W*\s+\W*", text, flags=re.UNICODE)
        good_words = []
        for word in words:
            lower_word = word.lower()
            if lower_word not in stopwords:
                good_words.append(lower_word)

        result = []
        for i in range(len(good_words) - 1):
            result.append(good_words[i] + "_" + good_words[i + 1])
        return result
    except ValueError:
        return []
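
# Sanity check on a synthetic record ("the"/"of" assumed to be stopwords);
# removal happens first, so "fall" and "roman" become adjacent:
# parse_rdd2("42\tThe Fall of the Roman Empire")  ->  ['fall_roman', 'roman_empire']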

def split_key(p):
    # ("a_b", c_ab) -> ("a", (c_ab, "b")): re-key the bigram by its first
    # word so it can be joined against the unigram counts.
    k, v = p
    a = k.split("_")
    return a[0], (v, "_".join(a[1:]))

def swap_pair_and_union(p):
    # ("a", ((c_ab, "b"), c_a)) -> ("b", (c_ab, c_a, "a")): after the first
    # join, re-key by the second word for the second join.
    k1, ((c1, k2), c2) = p
    return k2, (c1, c2, k1)

def join_and_fix(p):
    # ("b", ((c_ab, c_a, "a"), c_b)) -> ("a_b", (c_ab, c_a, c_b)): restore
    # the bigram key with all three counts attached.
    k2, ((c1, c2, k1), c3) = p
    return k1 + "_" + k2, (c1, c2, c3)
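
# Illustrative trace for one bigram (counts are made up, not from the corpus):
#   ("roman_empire", 600)                 after reduceByKey + filter
#   ("roman", (600, "empire"))            after split_key
#   ("empire", (600, 4000, "roman"))      after join + swap_pair_and_union
#   ("roman_empire", (600, 4000, 3000))   after join + join_and_fix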

def NPMI(c1, c2, c3):
    # PMI(a, b)  = ln( P(ab) / (P(a) * P(b)) )
    # NPMI(a, b) = PMI(a, b) / -ln P(ab)
    # Approximation: total_count (the unigram total) is also used as the
    # bigram denominator; the true bigram total is slightly smaller.
    p_ab = float(c1) / total_count
    p_a = float(c2) / total_count
    p_b = float(c3) / total_count
    pmi = np.log(p_ab / (p_a * p_b))
    return pmi / (-np.log(p_ab))
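
# Sanity check: a bigram that always co-occurs with both of its words
# (c1 == c2 == c3) has P(ab) = P(a) = P(b), so NPMI == 1 exactly;
# independent words would score ~ 0.
assert abs(NPMI(50, 50, 50) - 1.0) < 1e-9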
    

# Count bigrams, join in the unigram counts of both words, score with NPMI.
result = (
    sc.textFile("/data/wiki/en_articles_part/articles-part")
        .flatMap(parse_rdd2)
        .map(lambda x: (x, 1))
        .reduceByKey(lambda a, b: a + b)
        .filter(lambda kv: kv[1] >= 500)  # keep bigrams seen at least 500 times
        .map(split_key)
        .join(word_count)                 # attach the first word's count
        .map(swap_pair_and_union)
        .join(word_count)                 # attach the second word's count
        .map(join_and_fix)
        .map(lambda kv: (NPMI(*kv[1]), kv[0]))
        .sortByKey(False)                 # descending NPMI
        .collect()
)
# Top 39 collocations by NPMI:
for pair in result[:39]:
    print(pair[1])


los_angeles
external_links
united_states
prime_minister
san_francisco
et_al
new_york
supreme_court
19th_century
20th_century
references_external
soviet_union
air_force
baseball_player
university_press
roman_catholic
united_kingdom
references_reading
notes_references
award_best
north_america
new_zealand
civil_war
catholic_church
world_war
war_ii
south_africa
took_place
roman_empire
united_nations
american_singer-songwriter
high_school
american_actor
american_actress
american_baseball
york_city
american_football
years_later
north_american

In [ ]:
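# Release cluster resources once the analysis is finished:
sc.stop()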