In [1]:
import re

In [80]:
# Input locations: the three annotated question splits, then the label file.
trainp, testp, validp = (
    "original/annotated_fb_data_train.txt",
    "original/annotated_fb_data_test.txt",
    "original/annotated_fb_data_valid.txt",
)
labelp = "aux/labels.map"

In [ ]:


In [ ]:


In [81]:
def getdatafor(datap):
    """Read one tab-separated question file and collect counts and length maxima.

    Every line must consist of exactly four tab-separated fields, in the
    order subject, predicate, object, question.

    Args:
        datap: path of the file to read.

    Returns:
        An 8-tuple ``(questions, subjc, entc, relc, wordc, maxsentlen,
        maxwordlen, maxsentcharlen)``:

        * ``questions``: list of ``(question, subject, predicate, object)``
          tuples in file order.
        * ``subjc``: dict counting how often each value appears as subject.
        * ``entc``: dict counting appearances as subject or object (a value
          that is both on the same line is counted twice).
        * ``relc``: dict counting predicates.
        * ``wordc``: dict counting whitespace-separated question tokens.
        * ``maxsentlen``: largest number of tokens in a question.
        * ``maxwordlen``: largest ``len()`` of a single token.
        * ``maxsentcharlen``: largest ``len()`` of a whole question.

    Raises:
        ValueError: if a line does not split into exactly four fields.
    """
    questions = []
    entc = {}
    relc = {}
    subjc = {}
    # Created per call: the word counter must not leak in from (or be shared
    # through) the notebook's global namespace.
    wordc = {}
    maxsentlen = 0
    maxwordlen = 0
    maxsentcharlen = 0
    # The context manager closes the handle even when a malformed line raises.
    with open(datap) as f:
        for c, line in enumerate(f):
            # Only the line terminator is removed; the last line may lack one.
            s, p, o, q = line.rstrip("\n").split("\t")
            questions.append((q, s, p, o))
            maxsentcharlen = max(maxsentcharlen, len(q))
            # The subject feeds both the entity counter and the subject counter.
            for e, col in ((s, entc), (o, entc), (p, relc), (s, subjc)):
                col[e] = col.get(e, 0) + 1
            ws = q.split()
            maxsentlen = max(maxsentlen, len(ws))
            for w in ws:
                maxwordlen = max(maxwordlen, len(w))
                wordc[w] = wordc.get(w, 0) + 1
            # Progress marker every 10,000 lines (including line 0).
            if c % 10000 == 0:
                print(c)
    return questions, subjc, entc, relc, wordc, maxsentlen, maxwordlen, maxsentcharlen

In [83]:
# Per-split statistics; name prefixes: t = train, x = test, v = valid.
# The first return value (the question list) is bound to `_` and not used.
(_, tsubjc, tentc, trelc,
 twc, tmsl, tmwl, tmscl) = getdatafor(trainp)
(_, xsubjc, xentc, xrelc,
 xwc, xmsl, xmwl, xmscl) = getdatafor(testp)
(_, vsubjc, ventc, vrelc,
 vwc, vmsl, vmwl, vmscl) = getdatafor(validp)


0
10000
20000
30000
40000
50000
60000
70000
0
10000
20000
0
10000

In [84]:
# Vocabulary size over all three splits, followed by the per-split maxima.
allwords = set(vwc).union(xwc, twc)
print("DISTINCT WORDS:\t%d" % len(allwords))
for heading, (train_max, valid_max, test_max) in (
        ("MAX SENT LENS", (tmsl, vmsl, xmsl)),
        ("MAX SENT CHAR LENS", (tmscl, vmscl, xmscl)),
        ("MAX WORD LENS", (tmwl, vmwl, xmwl))):
    print(heading)
    print("train:\t%d" % train_max)
    print("valid:\t%d" % valid_max)
    print("test:\t%d" % test_max)


DISTINCT WORDS:	77167
MAX SENT LENS
train:	33
valid:	20
test:	24
MAX SENT CHAR LENS
train:	196
valid:	113
test:	141
MAX WORD LENS
train:	67
valid:	49
test:	57

In [50]:
# Distinct entities and relations: overall, then per split.
# The union order is kept as train, test, valid for both sets.
allents = set(tentc).union(xentc, ventc)
allrels = set(trelc).union(xrelc, vrelc)
for heading, merged, (train_c, valid_c, test_c) in (
        ("TOTAL NUMBER OF DISTINCT ENTITIES", allents, (tentc, ventc, xentc)),
        ("TOTAL NUMBER OF DISTINCT RELATIONS", allrels, (trelc, vrelc, xrelc))):
    print(heading)
    print("total: %d" % len(merged))
    print("train: %d" % len(train_c))
    print("valid: %d" % len(valid_c))
    print("test:  %d" % len(test_c))


TOTAL NUMBER OF DISTINCT ENTITIES
total: 131684
train: 95832
valid: 16000
test:  30476
TOTAL NUMBER OF DISTINCT RELATIONS
total: 1837
train: 1629
valid: 783
test:  1034

In [54]:
# Entities and relations that occur in the test split but never in training.
xonlyents = set(xentc) - set(tentc)
xonlyrels = set(xrelc) - set(trelc)
print("NUMBER OF ENTITIES IN TEST BUT NOT IN TRAIN")
print(len(xonlyents))
print("NUMBER OF RELATIONS IN TEST BUT NOT IN TRAIN")
print(len(xonlyrels))


NUMBER OF ENTITIES IN TEST BUT NOT IN TRAIN
24121
NUMBER OF RELATIONS IN TEST BUT NOT IN TRAIN
148

In [56]:
# Write every distinct entity and relation identifier, one per line.
with open("allents.col", "w") as out:
    out.writelines("%s\n" % name for name in allents.union(allrels))

In [ ]:


In [ ]:


In [59]:
# Load the label file: each line is "<key>\t<label>".
labels = {}
with open(labelp) as f:  # closed automatically, also on a malformed line
    for line in f:
        # Strip only the line terminator. Slicing off the last character
        # unconditionally would eat a real character when the final line
        # has no trailing newline.
        x, y = line.rstrip("\n").split("\t")
        labels[x] = y

In [63]:
# Display how many entries were loaded into `labels`.
len(labels)


Out[63]:
130822

In [78]:
# Raw strings so that "\." reaches the regex engine as an escaped dot
# instead of being an (invalid) string escape sequence.
# sre: captures whatever follows "www.freebase.com/m/".
# pre: captures whatever follows "www.freebase.com", leading "/" included.
sre = re.compile(r"www\.freebase\.com/m/(.+)")
pre = re.compile(r"www\.freebase\.com(.+)")

In [69]:
# Preview: print the first training subjects next to their labels.
# 12 rows is what the previous countdown (c = 10, tested after printing)
# actually produced.
preview_rows = 12
with open(trainp) as f:  # closed automatically when the preview stops
    for i, line in enumerate(f):
        if i >= preview_rows:
            break
        s = line.split("\t")[0]
        # "www.freebase.com/m/<id>" -> "m.<id>", the key format used in `labels`.
        s = "m." + sre.match(s).group(1)
        # Not every subject is guaranteed a label; show a placeholder
        # rather than aborting the preview with a KeyError.
        print("%s %s" % (s, labels.get(s, "<no label>")))


m.04whkz5 E
m.0tp2p24 Cardiac Arrest
m.04j0t75 The Debt
m.0ftqr Nobuo Uematsu
m.036p007 Eve-Olution
m.0ms5mg Most of Us Are Sad
m.086k8 Warner Bros. Entertainment
m.02vnx8y Don Graham
m.01smm Columbus
m.0mgb6cl Tibet
m.02dtg Detroit
m.0275d7v RYNA

In [ ]:


In [ ]:


In [70]:
# Output files, one per split.
traino, testo, valido = (
    "fb_train.tsv",
    "fb_test.tsv",
    "fb_valid.tsv",
)

In [79]:
# Rewrite each split as lines of "<question>\t<subject> <relation>".
for (x, y) in zip([trainp, validp, testp], [traino, valido, testo]):
    # getdatafor returns eight values; only the question tuples (index 0)
    # are needed. Unpacking into seven names raised ValueError.
    questions = getdatafor(x)[0]
    with open(y, "w") as f:
        # Each entry is (question, subject, predicate, object).
        for question, subj, pred, _obj in questions:
            # "www.freebase.com/m/<id>" -> "m.<id>"
            subj_id = "m." + sre.match(subj).group(1)
            # "www.freebase.com/<path>" -> "/<path>"
            rel = pre.match(pred).group(1)
            f.write("%s\t%s %s\n" % (question, subj_id, rel))


0
10000
20000
30000
40000
50000
60000
70000
0
10000
0
10000
20000

In [ ]: