In [ ]:
import os
import numpy as np
In [ ]:
#test_file = 'dev.1_all.hdf5_base.npz'
#test_file = 'test.1_all.hdf5_base.npz'
#test_file = 'test.1_all.hdf5_fac1.0.npz'
# Has :2 'focus' for non-answers
#test_file = 'test.1_all.hdf5_fac0.1.npz'
#test_file = 'test.1_all.hdf5_fac0.02.npz'
#test_file = 'test.1_all.hdf5_baseno.npz'
# Has token_clf 'focus' for non-answers, and extra_block
#test_file = 'test.1_all.hdf5_base-clf-extra1.npz'
#test_file = 'test.1_all.hdf5_base-clf-extra2.npz'
#test_file = 'test.1_all.hdf5_base-clf-extra4.npz'
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra1.npz'
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra2.npz'
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra4.npz'
# Has token_clf 'focus' for non-answers, and extra_block, with dropouts
#test_file = 'test.1_all.hdf5_base-clf-extra-do1.npz'
# Selected experiment output: an .npz archive holding the model's raw
# per-timestep logits ('predictions') and the gold labelings ('targets').
test_file = 'test.1_all.hdf5_base-clf-extra-do03.npz'
test_path = './orig/omerlevy-bidaf_no_answer-2e9868b224e4/relation_splits/'
# np.load on an .npz returns a lazy NpzFile archive (arrays read on access)
data = np.load( os.path.join(test_path,test_file) )
data.files # ['predictions', 'targets']
In [ ]:
#with open(os.path.join(test_path, 'test.1_all.hdf5_base.npz.txt'),'rt') as f:
# Load the bpe-tokenised test sentences, one per line.
# '@@' marks a bpe continuation; shifting it onto the following piece
# (' @@piece') means any substring can later be re-joined via .replace(' @@', '')
with open(os.path.join(test_path, 'test.1_all.bpe'),'rt') as f:
    bpes = [ t.replace('@@', ' @@').strip() for t in f.readlines()]
bpes[1] # These split on ' ' for bpe locations, but a substring can be converted back .replace(' @@', '')
In [ ]:
# predictions: (N, 1, 5, T) per-class logits; targets: (N, 1, T) label ids
# (rows/labels: 1..2 bracket the question span, 3..4 bracket answer spans)
preds, targs = data['predictions'], data['targets']
preds.shape, targs.shape # ((10, 1, 5, 128), (10, 1, 128))
In [ ]:
# This is a great method : good accuracy *if* it returns results
def pred_argmax_valid(pred):
valid=False; a_start_best=a_end_best=0
pred_argmax = np.argmax(pred[:,:], axis=0)
#print(pred_argmax)
pred_argmax_list = list(pred_argmax)
print(" - pred_argmax_start/end = ", 3 in pred_argmax_list, 4 in pred_argmax_list)
if 3 in pred_argmax_list and 4 in pred_argmax_list:
a_start_best = pred_argmax_list.index(3)
a_end_best = pred_argmax_list.index(4)
valid = a_start_best<a_end_best
print(" - Found ideal", valid, a_start_best, a_end_best)
return valid, a_start_best, a_end_best
In [ ]:
def softmax(x, axis=None):
    """Numerically stable softmax of x along `axis` (max-shifted before exp)."""
    exps = np.exp(x - x.max(axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)
def get_pred_len( bpe_str ):
    """Token count of bpe_str, clipped to the model window width minus 2.

    The clip leaves room for the two 'no-answer' positions that follow
    the sentence inside the (global) preds window.
    """
    # Count of ' ' separators == len(bpe_str.split(' ')) - 1
    n_tokens = bpe_str.count(' ')
    return min(n_tokens, preds.shape[-1]-2)
def pred_probs_ij(pred, pred_len=-1, debug=False): # Defaults to :2 no-answer zone
    """Pick the most probable (start, end) answer span from joint probabilities.

    pred     : (5, T) logit array; row 3 = answer-start logits, row 4 = answer-end.
    pred_len : actual token count; if <0 the no-answer zone is positions [0:2],
               otherwise it is the two positions at [pred_len:pred_len+2].
    debug    : accepted but unused in this body.
    Returns (valid, a_start_best, a_end_best); valid is False when the
    no-answer probability beats the best in-sentence span probability.
    """
    valid=False; a_start_best=a_end_best=0
    # Independent distributions over timesteps for span start and span end
    p_starts = softmax(pred[3,:])
    p_ends = softmax(pred[4,:])
    # Joint probability of every (i=start, j=end) pair via outer product
    p_ij = np.outer(p_starts, p_ends)
    #p_ij = np.triu(p_ij) # Force start<=end
    p_ij = np.triu(p_ij,k=1) # Force start<end
    #print( p_starts.shape, p_ends.shape, p_ij.shape )
    #print( p_ij[3,5], p_ij[5,3], ) #
    if pred_len<0:
        # No-answer zone is the first two positions
        p_ij[0:2, :]=0. # Kill off the joint no-go-zone
        p_ij[:, 0:2]=0. # Kill off the joint no-go-zone
        p_no_answer = np.max(p_starts[:2])*np.max(p_ends[:2])
    else:
        # No-answer zone is the two positions just past the sentence
        p_ij[pred_len:, :]=0. # Kill off the joint no-go-zone
        p_ij[:, pred_len:]=0. # Kill off the joint no-go-zone
        p_no_answer = np.max(p_starts[pred_len:pred_len+2])*np.max(p_ends[pred_len:pred_len+2])
        #print(p_starts[pred_len:pred_len+2], p_ends[pred_len:pred_len+2])
    # Get the n-dimensional argmax
    p_ij_argmax = np.unravel_index(p_ij.argmax(), p_ij.shape)
    p_ij_max = p_ij[ p_ij_argmax ]
    # Now work out the combined probability of start+end in the no-go-zone
    print(" p_none=%.8f, p_ij_max=%.8f" % (p_no_answer, p_ij_max, ),
        ", logit_i=%+.4f, logit_j=%+.4f" % (pred[3,p_ij_argmax[0]],pred[4,p_ij_argmax[1]],),
        p_ij_argmax)
    # Answer only if the best span beats the no-answer score
    # (commented alternatives are fixed-threshold variants that were tried)
    if p_ij_max>p_no_answer:
    #if p_ij_max>0.01:
    #if p_ij_max>0.03:
    #if p_ij_max>0.3:
    #if p_no_answer<0.9:
        valid=True
        a_start_best, a_end_best = p_ij_argmax
    return valid, a_start_best, a_end_best
# Spot-check one instance (NOTE(review): idx=11994 exceeds the (10,...) shapes
# noted above — presumably those shapes came from a smaller debug file; verify)
idx=11994; pred_probs_ij( preds[idx, 0, :, :], pred_len=get_pred_len(bpes[idx]), debug=True)
In [ ]:
def guess_best_timestep(pred):
    """Fallback span guess directly from the start/end logit rows.

    pred is a (5, T) logit array (row 3 = start logits, row 4 = end logits).
    Builds two candidate spans — best-start-then-best-end-after, and
    best-end-then-best-start-before — keeps the one with the higher summed
    logits, and only calls it valid when start < end and both endpoint
    logits are non-negative.  Returns (valid, a_start_best, a_end_best).

    Change vs original: the two permanently-dead `if False:` branches
    (naive argmax variants) were unreachable and have been removed; the
    live `if True:` branch is preserved unchanged.
    """
    valid = False; a_start_best = a_end_best = 0
    # Candidate 0: go from the best start onwards
    a_start_0 = np.argmax(pred[3, :])
    a_end_0 = np.argmax(pred[4, a_start_0:]) + a_start_0
    # Candidate 1: go from the best end backwards
    # NOTE(review): if the best end is at index 0, pred[3,:0] is empty and
    # np.argmax raises — same behaviour as the original; confirm inputs
    # never place the maximum end logit at position 0.
    a_end_1 = np.argmax(pred[4, :])
    a_start_1 = np.argmax(pred[3, :a_end_1])
    # Keep whichever candidate has the larger combined endpoint logit
    if pred[3, a_start_0] + pred[4, a_end_0] > pred[3, a_start_1] + pred[4, a_end_1]:
        a_start_best, a_end_best = a_start_0, a_end_0
    else:
        a_start_best, a_end_best = a_start_1, a_end_1
    valid = a_start_best < a_end_best
    # Require non-negative logits at both endpoints, else reject the guess
    if pred[3, a_start_best] < +0.0:
        valid = False
    if pred[4, a_end_best] < +0.0:
        valid = False
    print(" - Trying naive : ", a_start_best, a_end_best,
        pred[3, a_start_best], pred[4, a_end_best])
    return valid, a_start_best, a_end_best
In [ ]:
def interval_intersect_union(a_start_best, a_end_best, targ_a_arr):
    """Best intersection-over-union of the predicted interval vs. each target.

    targ_a_arr is a list of (start, end) pairs; non-overlapping targets
    contribute nothing (their IoU is <= 0).  Returns the best fraction,
    or 0. when targ_a_arr is empty or nothing overlaps.
    """
    best_overlap = 0.
    for b0, b1 in targ_a_arr:
        inter = min(a_end_best, b1) - max(a_start_best, b0)
        union = max(a_end_best, b1) - min(a_start_best, b0)
        frac = inter / float(union)
        if frac > 0. and frac > best_overlap:
            best_overlap = frac
    return best_overlap
In [ ]:
import string
"""
We ignore word order, case, punctuation, and articles (“a”, “an”, “the”).
We also ignore “and”, which often appears when a single span captures multiple correct answers
(e.g. “United States and Canada”).
"""
PUNCTUATION = set(string.punctuation)
def bpe_arr_to_word_set(bpe_arr):
    """Rebuild words from bpe pieces and return them as a cleaned set.

    ' @@piece' marks a bpe continuation, so deleting ' @@' glues the pieces
    back into whole words.  Punctuation characters are stripped from each
    word, and articles / 'and' / empty strings are dropped.
    NOTE(review): the accompanying prose says matching ignores case, but no
    lowercasing happens here — confirm whether the inputs are pre-normalised.
    """
    text = ' '.join(bpe_arr).replace(' @@', '')
    cleaned = [''.join(ch for ch in tok if ch not in PUNCTUATION)
               for tok in text.split(' ')]
    return set(cleaned) - {'the', 'a', 'an', 'and', ''}
def target_hit_by_word(a_start, a_end, targ_a_arr, bpe_str, debug=False):
    """Exact-match score at the word-set level.

    Returns 1. if the predicted span [a_start:a_end) yields exactly the same
    cleaned word set as any gold span in targ_a_arr, else 0.  Word sets come
    from bpe_arr_to_word_set, so order/punctuation/articles are ignored.
    """
    pieces = bpe_str.split(' ')
    hit = 0.
    a_words = bpe_arr_to_word_set( pieces[a_start:a_end] )
    if debug : print(" a_words", a_words)
    for b_start, b_end in targ_a_arr:
        b_words = bpe_arr_to_word_set( pieces[b_start:b_end] )
        if debug : print(" b_words", b_words)
        if a_words == b_words:
            hit = 1.
    return hit
# test set equality
set(['foo', 'bar']) == set(['bar', 'foo']), set(['foo', 'bar']) == set(['bar', 'foo', 'baz'])
In [ ]:
def assess(idx, pred, targ, debug=False):
    """Score one instance: extract the gold spans, predict a span, compare.

    idx  : row index into the global bpes list (token text / length lookup).
    pred : (5, T) logit array for this instance (rows 3/4 = answer start/end).
           NOTE: normalised in place below, so the caller's array is mutated.
    targ : (T,) label array; labels 1..2 bracket the question/subject span
           and 3..4 bracket each gold answer span.
    Returns (tp, non_null, have_ans) for precision/recall accumulation.
    """
    #print(pred.shape, targ.shape) # (5, 128) (128,)
    # Get the starts and ends from targ
    targ_q_starts = [i for i,v in enumerate(list(targ)) if v==1]
    targ_a_starts = [i for i,v in enumerate(list(targ)) if v==3]
    #print(targ_q_starts);print(targ_a_starts)
    targ_q_arr, targ_a_arr = [],[]
    for i in targ_q_starts: # find next '2' to close the question span
        # problem at idx=1218 ??
        if 2 in list(targ[i:]):
            targ_q_arr.append( (i, i+list(targ[i:]).index(2) ) )
    for i in targ_a_starts: # find next '4' to close the answer span
        # problem at idx=4480 ??
        if 4 in list(targ[i:]):
            targ_a_arr.append( (i, i+list(targ[i:]).index(4) ) )
    print( idx, 'targets : q=', targ_q_arr, ', a=', targ_a_arr )
    # This is pretty 'audacious', since it wasn't trained to succeed at doing this
    #print(" answer: ", np.argmax(pred[3,:]), np.argmax(pred[4,:]), )
    """
    Evaluation Metrics : Each instance is evaluated by comparing
    the tokens in the labeled answer set with those of the predicted span.
    Precision is the true positive count divided by the number of times the system returned a non-null answer.
    Recall is the true positive count divided by the number of instances that have an answer.
    """
    have_ans = 1 if len(targ_a_arr)>0 else 0
    tp=0.
    non_null=0
    a_start_best = a_end_best = 0
    valid=False
    # Normalise the predictions row-wise, since we're going to be softmaxing column-wise
    # (in-place: this mutates the slice of the global preds array passed in)
    pred -= np.mean(pred, axis=0)
    # Try the p_ij scheme (requires 'deadzone' predictions at start)
    pred_len=-1
    if True and not valid:
        pred_len = get_pred_len(bpes[idx]) # For token_clf for non-answers
        v, a0, a1 = pred_probs_ij(pred, pred_len=pred_len)
        if v: valid, a_start_best, a_end_best = v, a0, a1
    # Ok, so how about looking at the list of predicted classes? (disabled)
    if False and not valid:
        v, a0, a1 = pred_argmax_valid(pred)
        if v: valid, a_start_best, a_end_best = v, a0, a1
    # Try guessing over all the timesteps (even if not chosen class) (disabled)
    if False and not valid:
        v, a0, a1 = guess_best_timestep(pred)
        if v: valid, a_start_best, a_end_best = v, a0, a1
    if valid: # This appears to be a valid guess
        print(" answer: [%d, %d]" % (a_start_best, a_end_best))
        non_null=1 # The system thinks it found something
        # Work out the overlap between (a_start_best,a_end_best) and each element of (targ_a_arr)
        #tp = interval_intersect_union(a_start_best, a_end_best, targ_a_arr)
        tp = target_hit_by_word(a_start_best, a_end_best, targ_a_arr, bpes[idx], debug=True) #debug
    else: # Could we have guessed some other way?
        print(" answer: []")
    if True and debug:
        # Dump per-timestep logits next to the bpe token at each position
        for i,v in enumerate(list(pred.T)):
            #if not bpes[idx].split(' ')[i] == '<unk>':
            if i<len(bpes[idx].split(' '))+1:
                print("%3d" % i, ["%+7.2f" % f for f in list(v)],
                    (bpes[idx]+' ').split(' ')[i],
                    ('token_clf' if i==pred_len else ''))
    if False and debug:
        # Alternative dump: only positions whose argmax class is non-zero
        for i,v in enumerate(list(pred.T)):
            c = np.argmax(v)
            if c==0: continue
            print(i, c) # 1->2 is subject, 3->4 is required answer(s)
    return tp, non_null, have_ans
# idx=0 should be a None, idx=1 should have an Answer, 6 seems correct
idx=69; assess(idx, preds[idx, 0, :, :], targs[idx, 0, :], debug=True)
In [ ]:
# Accumulate per-instance scores over the whole test set, then report
# precision / recall / F1 as defined in the docstring inside assess().
tp_tot, non_null_tot, have_ans_tot = 0., 0, 0
for idx in range(preds.shape[0]):
    tp, non_null, have_ans = assess( idx, preds[idx, 0, :, :], targs[idx, 0, :] )
    tp_tot += tp
    non_null_tot += non_null
    have_ans_tot += have_ans
# Guard against ZeroDivisionError in the degenerate cases: the system never
# answers (non_null_tot==0), no instance has an answer (have_ans_tot==0),
# or both precision and recall are zero (F1 denominator zero).
precis = tp_tot / non_null_tot if non_null_tot else 0.
recall = tp_tot / have_ans_tot if have_ans_tot else 0.
f1 = 2. * (precis*recall)/(precis+recall) if (precis+recall) else 0.
print("precision=%.2f%% recall=%.2f%% F1=%.2f%%" % (100.*precis, 100.*recall, 100.*f1))
In [ ]:
#Paper (BEAT THIS!)
# precision=43.61% recall=36.45% F1=39.61%
#test.1_all.hdf5_base.npz
# precision=30.97% recall=33.06% F1=31.98% # No restriction on start,end logit values
# precision=43.03% recall=23.35% F1=30.27% # Force start,end logit values >0
#test.1_all.hdf5_fac1.0.npz
# precision=24.67% recall=43.13% F1=31.39% # No restriction on start,end logit values
# precision=29.29% recall=40.50% F1=34.00% # Force start,end logit values >0
#test.1_all.hdf5_fac0.1.npz # This has [0:2] focus for non-answers (but only ~1 epoch of training)
# precision=55.43% recall=19.11% F1=28.42% # using p_ij and intersection_over_union
# precision=52.02% recall=17.93% F1=26.67% # using p_ij and cleaned word sets
#test.1_all.hdf5_baseno.npz' # Has :2 'focus' for non-answers
# precision=44.55% recall=21.22% F1=28.74% # using p_ij and cleaned word sets
#test.1_all.hdf5_fac0.02.npz' # Has :2 'focus' for non-answers
# precision=45.45% recall=19.91% F1=27.69% # using p_ij and cleaned word sets
# Has token_clf 'focus' for non-answers, and extra_block
# using p_ij and cleaned word sets
#test.1_all.hdf5_base-clf-extra1.npz
# precision=42.13% recall=31.78% F1=36.23%
#test.1_all.hdf5_base-clf-extra2.npz
# precision=44.86% recall=30.68% F1=36.44%
#test.1_all.hdf5_base-clf-extra4.npz
# precision=42.41% recall=30.13% F1=35.23%
#test.1_all.hdf5_fac0.02-clf-extra1.npz
# precision=47.16% recall=27.77% F1=34.95%
# precision=52.54% recall=24.89% F1=33.78% # accept p_ij>0.3
# precision=44.23% recall=30.54% F1=36.14% # accept p_ij>0.05
# precision=43.28% recall=31.50% F1=36.46% # accept p_ij>0.03
# precision=46.14% recall=28.40% F1=35.16% # Normalize the pred rows
# precision=43.74% recall=31.36% F1=36.53% # Normalize the pred rows and accept p_ij>0.05
# precision=42.43% recall=32.17% F1=36.59% # Normalize the pred rows and accept p_ij>0.03
#test.1_all.hdf5_fac0.02-clf-extra2.npz
# precision=43.33% recall=27.03% F1=33.29%
#test.1_all.hdf5_fac0.02-clf-extra4.npz
# precision=39.75% recall=27.13% F1=32.25%
# Idea (currently running) : Add drop-out 0.1, since the net seems to be overfitting quickly
#test.1_all.hdf5_base-clf-extra-do1.npz
# precision=47.87% recall=29.51% F1=36.51%
# precision=48.05% recall=29.36% F1=36.45% # Normalize the pred rows
# precision=44.74% recall=31.58% F1=37.03% # Normalize the pred rows and accept p_ij>0.03
# precision=42.09% recall=33.09% F1=37.05% # Normalize the pred rows and accept p_ij>0.01
# Idea (currently running) : Add drop-out 0.3, since 0.1 did some good...
#test.1_all.hdf5_base-clf-extra-do03.npz
# precision=43.59% recall=32.57% F1=37.28% # Normalize the pred rows and accept p_ij>0.01
# better. but still not > 39.7...
# precision=35.98% recall=37.96% F1=36.94% # Normalize the pred rows and accept p_none<0.9
In [ ]: