In [ ]:
import os
import numpy as np

In [ ]:
# Choose which saved prediction file to evaluate (the commented entries are
# earlier experiment outputs that can be swapped back in).
#test_file = 'dev.1_all.hdf5_base.npz'
#test_file = 'test.1_all.hdf5_base.npz'
#test_file = 'test.1_all.hdf5_fac1.0.npz'

# Has :2 'focus' for non-answers
#test_file = 'test.1_all.hdf5_fac0.1.npz' 
#test_file = 'test.1_all.hdf5_fac0.02.npz'
#test_file = 'test.1_all.hdf5_baseno.npz' 

# Has token_clf 'focus' for non-answers, and extra_block
#test_file = 'test.1_all.hdf5_base-clf-extra1.npz'
#test_file = 'test.1_all.hdf5_base-clf-extra2.npz'
#test_file = 'test.1_all.hdf5_base-clf-extra4.npz' 
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra1.npz'
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra2.npz'
#test_file = 'test.1_all.hdf5_fac0.02-clf-extra4.npz' 

# Has token_clf 'focus' for non-answers, and extra_block, with dropouts
#test_file = 'test.1_all.hdf5_base-clf-extra-do1.npz'
test_file = 'test.1_all.hdf5_base-clf-extra-do03.npz'

test_path = './orig/omerlevy-bidaf_no_answer-2e9868b224e4/relation_splits/'

# Load the saved model outputs; the npz archive is keyed 'predictions' and 'targets'
data = np.load( os.path.join(test_path,test_file) )
data.files  # ['predictions', 'targets']

In [ ]:
# Load the bpe-tokenised test set.  A space is inserted before each '@@'
# marker so that .split(' ') yields one list entry per bpe location; a
# substring can be converted back to plain text with .replace(' @@', '').
#with open(os.path.join(test_path, 'test.1_all.hdf5_base.npz.txt'),'rt') as f:
with open(os.path.join(test_path, 'test.1_all.bpe'),'rt') as f:
    bpes = [ t.replace('@@', ' @@').strip() for t in f.readlines()]
bpes[1] # These split on ' ' for bpe locations, but a substring can be converted back .replace(' @@', '')

In [ ]:
# Unpack: preds holds per-example, per-token logits over 5 classes;
# targs holds the per-token target class sequences
preds, targs = data['predictions'], data['targets']
preds.shape, targs.shape  # ((10, 1, 5, 128), (10, 1, 128))

In [ ]:
# This is a great method : good accuracy *if* it returns results
def pred_argmax_valid(pred):
    """Span guess from the per-timestep argmax classes.

    pred : (5, T) logit array; the per-column argmax gives each timestep's
           predicted class (3 = answer start, 4 = answer end).

    Returns (valid, a_start_best, a_end_best); valid only when a class-3
    timestep exists, a class-4 timestep exists, and start < end.
    Good accuracy *if* it returns results.
    """
    classes = list(np.argmax(pred, axis=0))

    has_start = 3 in classes
    has_end = 4 in classes
    print(" - pred_argmax_start/end = ", has_start, has_end)

    if not (has_start and has_end):
        return False, 0, 0

    # First occurrence of each marker class
    start = classes.index(3)
    end = classes.index(4)
    ok = start < end
    print(" - Found ideal", ok, start, end)
    return ok, start, end

In [ ]:
def softmax(x, axis=None):
    """Numerically stable softmax along `axis` (whole array when axis=None)."""
    # Shift by the max so exp() never overflows; softmax is shift-invariant.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)

def get_pred_len( bpe_str ):
    """Token count of a bpe string, clipped so that two positions remain
    free past the end (the no-answer zone) within the prediction width
    taken from the global `preds` array."""
    n_tokens = len(bpe_str.split(' ')) - 1
    # Leave room for the [pred_len:pred_len+2] no-answer slots
    return min(n_tokens, preds.shape[-1] - 2)

def pred_probs_ij(pred, pred_len=-1, debug=False):  # Defaults to :2 no-answer zone
    """Joint (start, end) span selection from the outer product of softmaxed logits.

    pred     : (5, T) logit array; row 3 = answer-start logits, row 4 = answer-end logits.
    pred_len : real token count of the example.  -1 means the no-answer 'focus'
               occupies positions [0:2]; otherwise it sits at [pred_len:pred_len+2].
    debug    : NOTE(review): accepted but never read inside the function — confirm intent.

    Returns (valid, a_start_best, a_end_best); valid is True when the best
    in-span joint probability beats the no-answer probability.
    """
    valid=False; a_start_best=a_end_best=0
    
    # Per-position probabilities of being the answer start / end
    p_starts = softmax(pred[3,:])
    p_ends   = softmax(pred[4,:])
    
    # Joint probability of each (start, end) pair, assuming independence
    p_ij = np.outer(p_starts, p_ends)
    #p_ij = np.triu(p_ij)  # Force start<=end
    p_ij = np.triu(p_ij,k=1)  # Force start<end
    
    #print( p_starts.shape, p_ends.shape, p_ij.shape )
    #print( p_ij[3,5], p_ij[5,3], )  #

    if pred_len<0:
        # No-answer zone is the first two positions
        p_ij[0:2, :]=0.  # Kill off the joint no-go-zone
        p_ij[:, 0:2]=0.  # Kill off the joint no-go-zone
        p_no_answer = np.max(p_starts[:2])*np.max(p_ends[:2])
    else:
        # No-answer zone is the two positions just past the real tokens
        p_ij[pred_len:, :]=0.  # Kill off the joint no-go-zone
        p_ij[:, pred_len:]=0.  # Kill off the joint no-go-zone
        p_no_answer = np.max(p_starts[pred_len:pred_len+2])*np.max(p_ends[pred_len:pred_len+2])
        #print(p_starts[pred_len:pred_len+2], p_ends[pred_len:pred_len+2])
        
    # Get the n-dimensional argmax
    p_ij_argmax = np.unravel_index(p_ij.argmax(), p_ij.shape)
    p_ij_max = p_ij[ p_ij_argmax ]
    
    # Now work out the combined probability of start+end in the no-go-zone
    
    print("      p_none=%.8f, p_ij_max=%.8f" % (p_no_answer, p_ij_max, ), 
          ", logit_i=%+.4f, logit_j=%+.4f" % (pred[3,p_ij_argmax[0]],pred[4,p_ij_argmax[1]],),
          p_ij_argmax)
    
    # Accept the span only if it beats the no-answer probability; the commented
    # conditions below are alternative acceptance thresholds that were tried
    # (see the results log at the bottom of the notebook).
    if p_ij_max>p_no_answer:
    #if p_ij_max>0.01:
    #if p_ij_max>0.03:
    #if p_ij_max>0.3:
    #if p_no_answer<0.9:
        valid=True
        a_start_best, a_end_best = p_ij_argmax
    return valid, a_start_best, a_end_best

# Spot-check a single example
idx=11994; pred_probs_ij( preds[idx, 0, :, :], pred_len=get_pred_len(bpes[idx]), debug=True)

In [ ]:
def guess_best_timestep(pred):
    """Naive fallback span guesser over raw logits (ignores predicted classes).

    Tries two candidate spans -- (best start, best end after it) and
    (best end, best start before it) -- and keeps whichever has the larger
    summed logit.  The guess is valid only if start < end and both endpoint
    logits are positive.

    pred : (5, T) logit array; row 3 = answer-start logits, row 4 = answer-end.

    Returns (valid, a_start_best, a_end_best).
    """
    # Candidate A: go from the best start onwards
    a_start_0 = np.argmax(pred[3,:])
    a_end_0   = np.argmax(pred[4,a_start_0:])+a_start_0

    # Candidate B: go from the best end backwards.
    # BUGFIX: when the best end is at position 0, pred[3,:0] is empty and
    # np.argmax raised ValueError; fall back to start position 0 instead
    # (which makes the candidate invalid via the start<end check below).
    a_end_1   = np.argmax(pred[4,:])
    a_start_1 = np.argmax(pred[3,:a_end_1]) if a_end_1 > 0 else 0

    # Keep the candidate with the larger combined endpoint logit
    if pred[3,a_start_0]+pred[4,a_end_0] > pred[3,a_start_1]+pred[4,a_end_1]:
        a_start_best, a_end_best = a_start_0, a_end_0
    else:
        a_start_best, a_end_best = a_start_1, a_end_1

    valid = a_start_best<a_end_best

    # Require both endpoint logits to be positive
    if pred[3, a_start_best]<+0.0:
        valid=False
    if pred[4, a_end_best]<+0.0:
        valid=False

    print(" - Trying naive : ", a_start_best, a_end_best, 
          pred[3, a_start_best], pred[4, a_end_best])

    return valid, a_start_best, a_end_best

In [ ]:
def interval_intersect_union(a_start_best, a_end_best, targ_a_arr):
    """Best intersection-over-union between the predicted interval
    (a_start_best, a_end_best) and any target interval in targ_a_arr.

    Non-overlapping intervals give a negative 'intersection' and are skipped,
    so the result is 0. when nothing overlaps.
    """
    best = 0.
    pred_lo, pred_hi = a_start_best, a_end_best
    for targ_lo, targ_hi in targ_a_arr:
        inter = min(pred_hi, targ_hi) - max(pred_lo, targ_lo)
        union = max(pred_hi, targ_hi) - min(pred_lo, targ_lo)
        iou = inter / float(union)
        if iou > 0. and iou > best:
            best = iou
    return best

In [ ]:
import string
"""  
We ignore word order, case, punctuation, and articles (“a”, “an”, “the”). 
We also ignore “and”, which often appears when a single span captures multiple correct answers
(e.g. “United States and Canada”).    
"""
PUNCTUATION = set(string.punctuation)
def bpe_arr_to_word_set(bpe_arr):
    """Convert a list of bpe tokens to a normalised word set for comparison.

    Re-joins the bpe pieces (removing the ' @@' continuation markers), strips
    punctuation characters from each word, lower-cases, and drops the
    articles/'and' and empty strings, per the evaluation description above.

    BUGFIX: the stated metric ignores case, but the original code never
    lower-cased, so e.g. 'Canada' vs 'canada' failed to match.
    """
    text = ' '.join(bpe_arr).replace(' @@', '') # Remove @@bpe markers
    text_clean = [ ''.join(c for c in t if c not in PUNCTUATION).lower()
                      for t in text.split(' ') ]
    words = set(text_clean) - {'the', 'a', 'an', 'and', ''}
    #print("'%s'-> '%s'" % (text, '|'.join(list(words)),))
    return words

def target_hit_by_word(a_start, a_end, targ_a_arr, bpe_str, debug=False):
    """Exact word-set match: 1. if the predicted span [a_start:a_end) has the
    same normalised word set as any target span in targ_a_arr, else 0."""
    tokens = bpe_str.split(' ')
    pred_words = bpe_arr_to_word_set(tokens[a_start:a_end])
    if debug : print("  a_words", pred_words)

    score = 0.
    for targ_start, targ_end in targ_a_arr:
        targ_words = bpe_arr_to_word_set(tokens[targ_start:targ_end])
        if debug : print("  b_words", targ_words)
        if pred_words == targ_words:
            score = 1.
    return score

# test set equality (sanity check: set comparison is order-insensitive, and a
# proper superset is not equal) -> expected output: (True, False)
set(['foo', 'bar']) == set(['bar', 'foo']), set(['foo', 'bar']) == set(['bar', 'foo', 'baz'])

In [ ]:
def assess(idx, pred, targ, debug=False):
    """Score a single example against its target spans.

    idx  : example index (used to look up the matching bpe string in global `bpes`)
    pred : (5, T) logit array for the example (rows 3/4 = answer start/end logits)
    targ : (T,) target class sequence; 1/2 bracket the question span,
           3/4 bracket each answer span

    Returns (tp, non_null, have_ans):
      tp       : 1. if the predicted span word-matches a target answer span
      non_null : 1 if the system produced any answer
      have_ans : 1 if the target actually contains an answer
    """
    #print(pred.shape, targ.shape)  # (5, 128) (128,)
    
    #  Get the starts and ends from targ
    targ_q_starts = [i for i,v in enumerate(list(targ)) if v==1]
    targ_a_starts = [i for i,v in enumerate(list(targ)) if v==3]
    #print(targ_q_starts);print(targ_a_starts)
    
    # Close each span at its next end marker; spans without one are skipped
    targ_q_arr, targ_a_arr = [],[]
    for i in targ_q_starts:                    # find next '2'
        # problem at idx=1218 ??
        if 2 in list(targ[i:]):
            targ_q_arr.append( (i, i+list(targ[i:]).index(2) ) )
    for i in targ_a_starts:                    # find next '4'
        # problem at idx=4480 ??
        if 4 in list(targ[i:]):
            targ_a_arr.append( (i, i+list(targ[i:]).index(4) ) )
    print( idx, 'targets : q=', targ_q_arr, ', a=', targ_a_arr )
    
    # This is pretty 'audacious', since it wasn't trained to succeed at doing this
    #print("                      answer: ", np.argmax(pred[3,:]), np.argmax(pred[4,:]),  )
    
    """
    Evaluation Metrics : Each instance is evaluated by comparing 
    the tokens in the labeled answer set with those of the predicted span. 
    Precision is the true positive count divided by the number of times the system returned a non-null answer. 
    Recall is the true positive count divided by the number of instances that have an answer.
    """
    have_ans = 1 if len(targ_a_arr)>0 else 0
    
    tp=0.
    non_null=0

    a_start_best = a_end_best = 0
    valid=False    
    
    # Normalise the predictions row-wise, since we're going to be softmaxing column-wise.
    # BUGFIX: this was the in-place `pred -= np.mean(...)`, which mutated the caller's
    # view into the global `preds` array — re-running the evaluation loop then
    # double-normalised the logits.  The out-of-place form is idempotent.
    pred = pred - np.mean(pred, axis=0)
    
    # Try the p_ij scheme (requires 'deadzone' predictions at start)
    pred_len=-1
    if True and not valid:
        pred_len = get_pred_len(bpes[idx])  # For token_clf for non-answers
        v, a0, a1 = pred_probs_ij(pred, pred_len=pred_len)
        if v: valid, a_start_best, a_end_best = v, a0, a1
    
    # Ok, so how about looking at the list of predicted classes?
    if False and not valid:
        v, a0, a1 = pred_argmax_valid(pred)
        if v: valid, a_start_best, a_end_best = v, a0, a1
        
    # Try guessing over all the timesteps (even if not chosen class)
    if False and not valid:
        v, a0, a1 = guess_best_timestep(pred)        
        if v: valid, a_start_best, a_end_best = v, a0, a1

    if valid: # This appears to be a valid guess
        print("                      answer: [%d, %d]" % (a_start_best, a_end_best))
        non_null=1 # The system thinks it found something
        
        # Work out the overlap between (a_start_best,a_end_best) and each element of (targ_a_arr)
        #tp = interval_intersect_union(a_start_best, a_end_best, targ_a_arr)
        tp = target_hit_by_word(a_start_best, a_end_best, targ_a_arr, bpes[idx], debug=True) #debug
        
    else:  # Could we have guessed some other way?
        print("                      answer: []")

    # Debug dump: per-timestep logits next to the corresponding bpe token
    if True and debug:
        for i,v in enumerate(list(pred.T)):
            #if not bpes[idx].split(' ')[i] == '<unk>':
            if i<len(bpes[idx].split(' '))+1:
                print("%3d" % i, ["%+7.2f" % f for f in list(v)], 
                      (bpes[idx]+' ').split(' ')[i], 
                      ('token_clf' if i==pred_len else ''))
            
    if False and debug:
        for i,v in enumerate(list(pred.T)):
            c = np.argmax(v)
            if c==0: continue
            print(i, c)  # 1->2 is subject, 3->4 is required answer(s)
    
    return tp, non_null, have_ans

# idx=0 should be a None, idx=1 should have an Answer, 6 seems correct
idx=69; assess(idx, preds[idx, 0, :, :], targs[idx, 0, :], debug=True)

In [ ]:
# Run the whole test set through assess() and accumulate precision/recall counts
tp_tot, non_null_tot, have_ans_tot = 0.,0,0
for idx in range(preds.shape[0]):
    #print("idx=", idx)
    tp, non_null, have_ans = assess( idx, preds[idx, 0, :, :], targs[idx, 0, :] )
    #print("       #%d assessment :tp=%.1f, non_null=%1d, have_ans=%1d" % (idx, tp, non_null, have_ans,) )
    tp_tot += tp
    non_null_tot += non_null
    have_ans_tot += have_ans

# Guard the ratios: the system may never answer (non_null_tot == 0), the set may
# contain no answerable examples (have_ans_tot == 0), and F1 is undefined when
# precision + recall == 0 — report 0 in each degenerate case instead of crashing.
precis = tp_tot / non_null_tot if non_null_tot else 0.
recall = tp_tot / have_ans_tot if have_ans_tot else 0.

f1 = 2. * (precis*recall)/(precis+recall) if (precis+recall) > 0. else 0.

print("precision=%.2f%% recall=%.2f%% F1=%.2f%%" % (100.*precis, 100.*recall, 100.*f1))

In [ ]:
#Paper (BEAT THIS!)
# precision=43.61% recall=36.45% F1=39.61%


#test.1_all.hdf5_base.npz 
# precision=30.97% recall=33.06% F1=31.98%  # No restriction on start,end logit values
# precision=43.03% recall=23.35% F1=30.27%  # Force start,end logit values >0

#test.1_all.hdf5_fac1.0.npz 
# precision=24.67% recall=43.13% F1=31.39%  # No restriction on start,end logit values
# precision=29.29% recall=40.50% F1=34.00%  # Force start,end logit values >0

#test.1_all.hdf5_fac0.1.npz  # This has [0:2] focus for non-answers (but only ~1 epoch of training)
# precision=55.43% recall=19.11% F1=28.42%  # using p_ij and intersection_over_union
# precision=52.02% recall=17.93% F1=26.67%  # using p_ij and cleaned word sets


#test.1_all.hdf5_baseno.npz' # Has :2 'focus' for non-answers
# precision=44.55% recall=21.22% F1=28.74% # using p_ij and cleaned word sets

#test.1_all.hdf5_fac0.02.npz' # Has :2 'focus' for non-answers
# precision=45.45% recall=19.91% F1=27.69%  # using p_ij and cleaned word sets


# Has token_clf 'focus' for non-answers, and extra_block
# using p_ij and cleaned word sets

#test.1_all.hdf5_base-clf-extra1.npz
# precision=42.13% recall=31.78% F1=36.23%

#test.1_all.hdf5_base-clf-extra2.npz
# precision=44.86% recall=30.68% F1=36.44%

#test.1_all.hdf5_base-clf-extra4.npz 
# precision=42.41% recall=30.13% F1=35.23%  


#test.1_all.hdf5_fac0.02-clf-extra1.npz
# precision=47.16% recall=27.77% F1=34.95%
# precision=52.54% recall=24.89% F1=33.78%  # accept p_ij>0.3
# precision=44.23% recall=30.54% F1=36.14%  # accept p_ij>0.05
# precision=43.28% recall=31.50% F1=36.46%  # accept p_ij>0.03

# precision=46.14% recall=28.40% F1=35.16%  # Normalize the pred rows
# precision=43.74% recall=31.36% F1=36.53%  # Normalize the pred rows and accept p_ij>0.05
# precision=42.43% recall=32.17% F1=36.59%  # Normalize the pred rows and accept p_ij>0.03

#test.1_all.hdf5_fac0.02-clf-extra2.npz
# precision=43.33% recall=27.03% F1=33.29%

#test.1_all.hdf5_fac0.02-clf-extra4.npz
# precision=39.75% recall=27.13% F1=32.25% 


# Idea (currently running) : Add drop-out 0.1, since the net seems to be overfitting quickly
#test.1_all.hdf5_base-clf-extra-do1.npz
# precision=47.87% recall=29.51% F1=36.51%
# precision=48.05% recall=29.36% F1=36.45%  # Normalize the pred rows
# precision=44.74% recall=31.58% F1=37.03%  # Normalize the pred rows and accept p_ij>0.03
# precision=42.09% recall=33.09% F1=37.05%  # Normalize the pred rows and accept p_ij>0.01

# Idea (currently running) : Add drop-out 0.3, since 0.1 did some good...
#test.1_all.hdf5_base-clf-extra-do03.npz
# precision=43.59% recall=32.57% F1=37.28%  # Normalize the pred rows and accept p_ij>0.01
#   better.  but still not > 39.7...
# precision=35.98% recall=37.96% F1=36.94%  # Normalize the pred rows and accept p_none<0.9

In [ ]: