In [1]:
import re
In [80]:
trainp = "original/annotated_fb_data_train.txt"
testp = "original/annotated_fb_data_test.txt"
validp = "original/annotated_fb_data_valid.txt"
labelp = "aux/labels.map"
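Before parsing anything, a quick peek at the raw lines (just a sanity check; the parsing code below assumes each line is subject, predicate, object and question, tab-separated, with www.freebase.com URLs).
In [ ]:
# peek at the first few raw lines to confirm the assumed
# subject \t predicate \t object \t question layout
with open(trainp) as f:
    for i, line in enumerate(f):
        print repr(line)
        if i >= 2:
            break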
In [81]:
def getdatafor(datap):
    """Collect questions plus entity/relation/word counts and length statistics for one split."""
    questions = []
    entc = {}            # counts for subject and object entities
    relc = {}            # counts for relations
    subjc = {}           # counts for subject entities only
    wordc = {}           # counts for question words
    maxsentlen = 0       # longest question, in words
    maxwordlen = 0       # longest word, in characters
    maxsentcharlen = 0   # longest question, in characters
    c = 0
    for line in open(datap):
        s, p, o, q = line.rstrip("\n").split("\t")
        maxsentcharlen = max(maxsentcharlen, len(q))
        ws = q.split()
        questions.append((q, s, p, o))
        for (e, col) in zip([s, o, p, s], [entc, entc, relc, subjc]):
            if e not in col:
                col[e] = 0
            col[e] += 1
        maxsentlen = max(maxsentlen, len(ws))
        for w in ws:
            maxwordlen = max(maxwordlen, len(w))
            if w not in wordc:
                wordc[w] = 0
            wordc[w] += 1
        if c % 1e4 == 0:    # progress every 10k lines
            print c
        c += 1
    return questions, subjc, entc, relc, wordc, maxsentlen, maxwordlen, maxsentcharlen
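A minimal smoke test of getdatafor on a synthetic one-line file; the ids and question below are made up purely to show the shape of the returned values.
In [ ]:
# smoke test on a made-up line (illustrative only)
with open("tmp_smoke.txt", "w") as f:
    f.write("www.freebase.com/m/abc\twww.freebase.com/r/rel\twww.freebase.com/m/def\twho wrote this\n")
qs, subjc_, entc_, relc_, wc_, msl_, mwl_, mscl_ = getdatafor("tmp_smoke.txt")
print qs                  # [(question, subject, predicate, object)]
print entc_, relc_        # subject/object counts and relation counts
print msl_, mwl_, mscl_   # 3 words, longest word 5 chars, 14 chars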
In [83]:
_, tsubjc, tentc, trelc, twc, tmsl, tmwl, tmscl = getdatafor(trainp)
_, xsubjc, xentc, xrelc, xwc, xmsl, xmwl, xmscl = getdatafor(testp)
_, vsubjc, ventc, vrelc, vwc, vmsl, vmwl, vmscl = getdatafor(validp)
In [84]:
# vocabulary size and maximum lengths per split
allwords = set(vwc.keys()).union(set(xwc.keys())).union(set(twc.keys()))
print "DISTINCT WORDS:\t%d" % len(allwords)
print "MAX SENT LENS"
print "train:\t%d" % tmsl
print "valid:\t%d" % vmsl
print "test:\t%d" % xmsl
print "MAX SENT CHAR LENS"
print "train:\t%d" % tmscl
print "valid:\t%d" % vmscl
print "test:\t%d" % xmscl
print "MAX WORD LENS"
print "train:\t%d" % tmwl
print "valid:\t%d" % vmwl
print "test:\t%d" % xmwl
In [50]:
# total number of distinct entities and relations:
allents = set(tentc.keys()).union(set(xentc.keys())).union(set(ventc.keys()))
print "TOTAL NUMBER OF DISTINCT ENTITIES"
print "total: %d" % len(allents)
print "train: %d" % len(tentc)
print "valid: %d" % len(ventc)
print "test: %d" % len(xentc)
allrels = set(trelc.keys()).union(set(xrelc.keys())).union(set(vrelc.keys()))
print "TOTAL NUMBER OF DISTINCT RELATIONS"
print "total: %d" % len(allrels)
print "train: %d" % len(trelc)
print "valid: %d" % len(vrelc)
print "test: %d" % len(xrelc)
In [54]:
# entities and relations that appear in test but not in train
xonlyents = set(xentc.keys()).difference(set(tentc.keys()))
print "NUMBER OF ENTITIES IN TEST BUT NOT IN TRAIN"
print len(xonlyents)
xonlyrels = set(xrelc.keys()).difference(set(trelc.keys()))
print "NUMBER OF RELATIONS IN TEST BUT NOT IN TRAIN"
print len(xonlyrels)
In [56]:
with open("allents.col", "w") as f:
for ent in allents.union(allrels):
f.write("%s\n" % ent)
In [59]:
# load labels: each line of labels.map maps an entity id to its surface-form label (tab-separated)
labels = {}
for line in open(labelp):
    x, y = line.rstrip("\n").split("\t")
    labels[x] = y
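Spot-check a few entries to confirm the assumed mapping from "m.<mid>" ids to label strings.
In [ ]:
# show a handful of (id, label) pairs
for i, (k, v) in enumerate(labels.items()):
    print k, v
    if i >= 4:
        break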
In [63]:
len(labels)
Out[63]:
In [78]:
# entity URLs look like www.freebase.com/m/<mid>; relation URLs carry the property path
sre = re.compile(r"www\.freebase\.com/m/(.+)")
pre = re.compile(r"www\.freebase\.com(.+)")
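A quick illustration of what the two patterns capture; the URLs below are made up but follow the same www.freebase.com scheme.
In [ ]:
# illustrative URLs only: entity URLs become "m.<mid>", relation URLs keep their path
print "m." + sre.match("www.freebase.com/m/0abc12").group(1)
print pre.match("www.freebase.com/people/person/place_of_birth").group(1)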
In [69]:
# spot-check: map the first few subject URLs to "m.<mid>" ids and look up their labels
c = 10
for line in open(trainp):
    s, _, _, _ = line.split("\t")
    s = "m." + sre.match(s).group(1)
    print s, labels[s]
    if c < 0:
        break
    c -= 1
In [70]:
traino = "fb_train.tsv"
testo = "fb_test.tsv"
valido = "fb_valid.tsv"
In [79]:
# convert each split to "question \t <subject-id> <relation-path>" and write it out
for (x, y) in zip([trainp, validp, testp], [traino, valido, testo]):
    questions, _, _, _, _, _, _, _ = getdatafor(x)
    with open(y, "w") as f:
        for q in questions:
            s = q[1]
            p = q[2]
            s = "m." + sre.match(s).group(1)
            p = pre.match(p).group(1)
            f.write("%s\t%s %s\n" % (q[0], s, p))
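Read back the first few exported lines as a sanity check of the question \t "id relation" layout.
In [ ]:
# sanity-check the exported TSV format
with open(traino) as f:
    for i, line in enumerate(f):
        print repr(line)
        if i >= 2:
            break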