In [2]:
import numpy as np
from sklearn.linear_model import LogisticRegression
import scipy.sparse
import time
import itertools
import sys
import pickle
import helper
import train
import predict

In [3]:
# Evaluation split to score. A previously used alternative, kept for reference:
#   test_file = "../data/tagged_data/EMA2/dev.3.tag"
# (the script form of this notebook read the path from sys.argv[1])
test_file = "../data/tagged_data/whole_text_full_city/dev.tag"

# load_data yields a pair: the examples (iterated as (sentence, gold) records by
# the prediction cell) and a parallel sequence of identifiers indexed the same way.
(test_data, test_identifier) = train.load_data(test_file)

In [10]:
# Path of the pickled model bundle that the prediction cell unpickles.
trained_model = "trained_model.large.y.p" #sys.argv[2]

# Start-of-run timestamp. time.clock() was removed in Python 3.8, so prefer
# time.perf_counter() where it exists and only fall back to time.clock() on
# interpreters that predate it (Python 2).
try:
    tic = time.perf_counter()
except AttributeError:
    tic = time.clock()

In [ ]:


In [11]:
# Forwarded as the first argument of predict.predict_tags_n in the prediction cell.
# NOTE(review): presumably switches Viterbi decoding on/off — confirm in predict.py.
viterbi = False

In [12]:
# Re-import the project modules so on-disk edits are picked up without
# restarting the kernel (reload is a builtin on the Python 2 kernel used here).
reload(train)
reload(predict)

# The pickle is unpacked into the five values that predict.predict_tags_n
# consumes below.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load model files that you produced yourself.
# The context manager guarantees the file handle is closed after loading.
with open(trained_model, "rb") as model_file:
    clf, previous_n, next_n, word_vocab, other_features = pickle.load(model_file)

ptags = []  # predicted tags, concatenated over every successfully tagged article
goldY = []  # gold tags, appended in lock-step with ptags
for i, article in enumerate(test_data):
    sentence = article[0]
    gold = article[1]
    try:
        dataY, dataYconfidences = predict.predict_tags_n(viterbi, previous_n, next_n, clf, sentence, word_vocab, other_features)
        ptags.extend(dataY)
        goldY.extend(gold)
    except Exception as e:
        # Best-effort evaluation: skip the article, but report both the
        # offending input and the error so the failure can be diagnosed.
        print("%s %s %s" % (sentence, gold, test_identifier[i]))
        print("  skipped: %r" % (e,))

In [15]:
# Quick sanity totals: sums of the gold and of the predicted tag values,
# followed by the tag-id -> tag-name table used for the per-tag report.
print(sum(goldY))
print(sum(ptags))
print(train.int2tags)


def _precision_recall_f1(n_correct, n_guessed, n_total):
    """Return (precision, recall, f1) computed from raw counts.

    n_correct -- positions where prediction and gold both carry the tag
    n_guessed -- positions where the prediction carries the tag
    n_total   -- positions where the gold carries the tag

    Any ratio whose denominator is zero is reported as 0.
    """
    precision = n_correct * 1. / n_guessed if n_guessed > 0 else 0
    recall = n_correct * 1. / n_total if n_total > 0 else 0
    # With no correct predictions precision + recall is 0 and the harmonic mean
    # is undefined; report 0 instead of raising ZeroDivisionError.
    if precision + recall > 0:
        f1 = precision * recall * 2. / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1


# NOTE(review): tag id 0 is skipped by the range below; in the recorded output
# that id is 'shooterName', so it never gets a row — confirm this is intended.
for ent in range(1, len(train.int2tags)):
    # Single pass over the aligned (predicted, gold) pairs to collect the counts.
    n_correct = n_guessed = n_total = 0
    for p_tag, g_tag in zip(ptags, goldY):
        if p_tag == ent:
            n_guessed += 1
            if g_tag == ent:
                n_correct += 1
        if g_tag == ent:
            n_total += 1

    # The first ratio is precision (correct / guessed), not accuracy.
    precision, recall, f1 = _precision_recall_f1(n_correct, n_guessed, n_total)

    # Row layout: name ( correct , guessed , total ) ( precision , recall , f1 )
    print("%s ( %s , %s , %s ) ( %s , %s , %s )" % (
        train.int2tags[ent], n_correct, n_guessed, n_total, precision, recall, f1))


4171
3544.0
['shooterName', 'killedNum', 'woundedNum', 'city']
killedNum ( 26 , 79 , 131 ) ( 0.329113924051 , 0.198473282443 , 0.247619047619 )
woundedNum ( 136 , 164 , 224 ) ( 0.829268292683 , 0.607142857143 , 0.701030927835 )
city ( 244 , 315 , 336 ) ( 0.774603174603 , 0.72619047619 , 0.749615975422 )

In [9]:
# NOTE(review): the labels in this string (Affected_Food_Product, Produced_Location,
# Consumer_Brand, Adulterant) do not appear in the train.int2tags list printed by
# the evaluation cell, so this looks like a result table pasted from a different
# tagging task — confirm its origin or remove the cell.
# The numbers are consistent with a "(correct, guessed, total) (precision, recall, f1)"
# layout, e.g. 347/534 ~= 0.6498 and 347/835 ~= 0.4156.
"""Affected_Food_Product: (347, 534, 835) (0.6498, 0.4156, 0.5069)
Produced_Location: (125, 206, 240) (0.6068, 0.5208, 0.5605)
Consumer_Brand: (28, 37, 605) (0.7568, 0.0463, 0.0872)
Adulterant: (68, 115, 334) (0.5913, 0.2036, 0.3029)
"""


Out[9]:
'Affected_Food_Product: (347, 534, 835) (0.6498, 0.4156, 0.5069)\nProduced_Location: (125, 206, 240) (0.6068, 0.5208, 0.5605)\nConsumer_Brand: (28, 37, 605) (0.7568, 0.0463, 0.0872)\nAdulterant: (68, 115, 334) (0.5913, 0.2036, 0.3029)\n'

In [ ]: