In [2]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.sparse import lil_matrix
import pickle
%matplotlib inline
# Keep only triples whose middle (tab-separated) relation field contains no
# comma, i.e. presumably lines that carry a single relation rather than a
# comma-separated list of relations.
# NOTE(review): both paths are relative to the current working directory, but
# the next cell reads '../data/freebase/train_single_relation.txt' -- presumably
# this cell was run from that directory; confirm before Restart & Run All.
with open('train', 'r') as f_in, open('train_single_relation.txt', 'w') as f_out:
    for line in f_in:
        # Unpacking enforces exactly three tab-separated fields per line.
        head, relation_field, tail = line.split('\t')
        if ',' not in relation_field:
            f_out.write(line)
In [2]:
# Collect the vocabulary of entity and relation names from the
# single-relation triple file (one "start<TAB>relation<TAB>end" per line).
datafile = '../data/freebase/train_single_relation.txt'
entities = set()
relations = set()
with open(datafile, 'r') as f:
    for line in f:
        start, relation, end = line.split('\t')
        # set.add is idempotent, so no membership check is needed.
        entities.add(start.strip())
        entities.add(end.strip())
        # Store the stripped name: relation ids are later looked up with
        # relation.strip(), so the keys must be stripped as well.
        relations.add(relation.strip())
In [3]:
# Assign each entity a stable integer id. Sorting (instead of list(set))
# makes the ids reproducible across kernel restarts, because set iteration
# order for strings depends on per-process hash randomization.
n_entities = len(entities)
entities = sorted(entities)
entity_dic = {name: idx for idx, name in enumerate(entities)}
In [4]:
# Assign each relation a stable integer id (sorted for reproducibility, since
# set iteration order of strings varies between runs).
n_relations = len(relations)
relations = sorted(relations)
relation_dic = {name: idx for idx, name in enumerate(relations)}

# Manually selected relations; triples with these relations are counted
# per entity when the relation tensor is built.
SELECTED_RELATION_NAMES = ['place_of_birth', 'place_of_death', 'nationality', 'location']
selected_relations = [relation_dic[name] for name in SELECTED_RELATION_NAMES]
In [5]:
# Build one sparse n_entities x n_entities 0/1 matrix per relation:
# T[r][i, j] == 1 iff the triple (entity i, relation r, entity j) occurs in
# datafile. Also count, for each entity, the triples over the selected
# relations that touch it (a self-loop e -> e counts once).
entity_count = np.zeros(n_entities)
T = [lil_matrix((n_entities, n_entities), dtype=int) for k in range(n_relations)]
selected_relation_set = set(selected_relations)  # O(1) membership in the loop
with open(datafile, 'r') as f:
    for line in f:
        start, relation, end = line.split('\t')
        e_i = entity_dic[start.strip()]
        e_j = entity_dic[end.strip()]
        r_k = relation_dic[relation.strip()]
        T[r_k][e_i, e_j] = 1
        if r_k in selected_relation_set:
            # The set {e_i, e_j} collapses a self-loop to a single entity.
            for e in {e_i, e_j}:
                entity_count[e] += 1
# LIL is efficient for incremental construction; CSR for downstream use.
T = [X.tocsr() for X in T]
entities = np.array(entities)
relations = np.array(relations)
In [6]:
# Bar chart of how many (distinct) triples each relation has.
plt.figure(figsize=(8, 6))
triple_counts = [T[k].nnz for k in range(n_relations)]
positions = np.arange(n_relations)
# align='center' is explicit so the ticks line up on any matplotlib version;
# ticks therefore go at the integer positions, not at +0.5.
plt.bar(positions, triple_counts, align='center')
plt.xticks(positions, relations, rotation='vertical')
plt.ylabel('Number of triples')
plt.title('Number of triples for each relation');
Out[6]:
In [7]:
# Dataset summary; the total nonzero count is computed once and reused.
total_triples = np.sum([T[k].nnz for k in range(n_relations)])
print('num entity', n_entities)
print('num triples', total_triples)
# Fraction of the n_relations x n_entities x n_entities tensor that is nonzero.
print('sparsity', total_triples / (n_relations * n_entities**2))
In [16]:
# Load a precomputed subset (tensor slices, entity names, relation names) and
# plot the per-relation triple counts.
# NOTE: this rebinds entities, relations and n_relations, replacing the
# full-dataset values built in the cells above.
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# if it comes from a trusted source.
with open('../data/freebase/subset_3000.pkl', 'rb') as fp:
    newT, entities, relations = pickle.load(fp)
n_relations = len(relations)
subset_counts = [np.sum(newT[k]) for k in range(n_relations)]
positions = np.arange(n_relations)
plt.figure(figsize=(8, 6))
# Centered bars with ticks at the integer positions (not +0.5).
plt.bar(positions, subset_counts, align='center')
plt.xticks(positions, relations, rotation='vertical')
plt.ylabel('Number of triples')
plt.title('Number of triples for each relation');
Out[16]:
In [ ]: