In [1]:
from __future__ import print_function, division
import numpy as np
import theano
import theano.tensor as T

In [2]:
# Toy problem dimensions: sequences per batch, sequence length, number of outputs.
N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS = 2, 8, 3

# Draw the toy data from a seeded generator so that Restart & Run All
# regenerates exactly the same arrays.
# NOTE: the outputs saved in this notebook come from an unseeded run, so the
# numbers displayed after re-running will differ from the saved ones.
SEED = 42
rng = np.random.RandomState(SEED)
t = rng.randn(N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS)
error = rng.randn(N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS)

In [3]:
# Display the random array `t`; it is the array compared against THRESHOLD
# further down (presumably the targets -- TODO confirm).
t


Out[3]:
array([[[-1.87347891,  0.52018866,  2.14607176],
        [-0.02940302, -0.90604823, -2.06917314],
        [-0.55205747,  0.52891592,  2.0292411 ],
        [-0.55563829,  0.84442958, -0.86873026],
        [-0.50906761, -0.30227186,  0.43699785],
        [ 0.46817846, -0.77619128,  0.11565922],
        [ 1.42977919,  0.27736022, -2.09792215],
        [ 0.83361268, -1.2146074 ,  1.00099365]],

       [[ 0.68847109,  1.47052387, -0.70513472],
        [ 0.07405826,  3.8343556 ,  1.0336292 ],
        [ 0.11565304, -1.37147437,  1.05797102],
        [-1.74104643, -0.45801851,  0.68005134],
        [-1.11734962,  1.07181886, -1.45159929],
        [ 0.71204261,  1.49806891, -1.01176102],
        [-0.24920642, -0.20566547,  2.58757519],
        [-0.80400528, -0.27101056, -0.82399376]]])

In [4]:
# Reference (pure NumPy) implementation of the thresholded error.
#
# For every (sequence, output) series:
#   * split the time steps into those where `t` is above THRESHOLD and the rest;
#   * take the mean of `error` over each group;
#   * average the two group means, so each group carries equal weight
#     regardless of how many steps fall into it.
# Series with no step above THRESHOLD are skipped entirely.  `err` is the mean
# of the per-series values over the series that were not skipped.
#
# This cell needs `t` and `error` to be NumPy arrays, so it must run before
# they are wrapped in Theano shared variables.
THRESHOLD = 0.0
error_accumulator = 0.0
n_active_seqs = 0  # number of (sequence, output) series with >= 1 step above THRESHOLD
for seq_i in range(N_SEQ_PER_BATCH):
    for output_i in range(N_OUTPUTS):
        above_thresh = t[seq_i, :, output_i] > THRESHOLD
        if not above_thresh.any():
            # Nothing above threshold: this series does not contribute.
            continue
        n_active_seqs += 1

        series_error = error[seq_i, :, output_i]
        error_above_thresh = series_error[above_thresh].mean()

        # `~` is the logical NOT of a boolean array.  Unary minus on a boolean
        # array is rejected by current NumPy releases.
        below_thresh = ~above_thresh
        if below_thresh.any():
            error_below_thresh = series_error[below_thresh].mean()
            error_accumulator += (error_above_thresh + error_below_thresh) / 2.0
        else:
            # Every step is above threshold: the mean over the (empty)
            # below-threshold group would be NaN and poison the total, so the
            # series is scored by its above-threshold mean alone.
            error_accumulator += error_above_thresh

# Avoid a ZeroDivisionError when no series had any step above THRESHOLD.
err = error_accumulator / n_active_seqs if n_active_seqs else 0.0

In [34]:
# Wrap both arrays in Theano shared variables, rebinding the same names.
# From here on `t` and `error` are symbolic, so the NumPy-only cells above
# cannot be re-run without first regenerating the arrays.
# NOTE(review): run this cell once per fresh kernel -- a second run would pass
# the already-wrapped variables back into theano.shared.
t, error = theano.shared(t), theano.shared(error)

In [35]:
# Symbolic element-wise mask over the whole of `t`: 1 where t > THRESHOLD, else 0.
# Rebinds `above_thresh`, which the loop-based reference cell used for a
# per-series NumPy boolean mask.
above_thresh = t > THRESHOLD

In [78]:
# Show the symbolic expression itself (an un-evaluated graph node, not values).
above_thresh


Out[78]:
Elemwise{gt,no_inplace}.0

In [79]:
# above_thresh = above_thresh.astype(np.float32)

In [90]:
# Evaluate the mask to concrete values; the saved output has dtype int8.
above_thresh.eval()


Out[90]:
array([[[0, 1, 1],
        [0, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 0],
        [1, 0, 1]],

       [[1, 1, 0],
        [1, 1, 1],
        [1, 0, 1],
        [0, 0, 1],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [0, 0, 0]]], dtype=int8)

In [94]:
# Alternative masking, currently disabled: NaN instead of 0 wherever `t` is
# not above THRESHOLD.
# masked_error = T.switch(t > THRESHOLD, error, np.nan)
# Active version: multiply by the 0/1 mask, so entries where `t` is not above
# THRESHOLD become 0 (shown as -0. where the original error was negative).
masked_error = error * above_thresh
masked_error.eval()


Out[94]:
array([[[-0.        ,  0.63391603,  0.04010954],
        [-0.        , -0.        ,  0.        ],
        [-0.        , -1.18162216, -1.4576782 ],
        [ 0.        ,  0.01867809,  0.        ],
        [ 0.        ,  0.        , -0.75736401],
        [ 0.75839996,  0.        ,  0.20373484],
        [-1.38148993, -0.51837588, -0.        ],
        [-1.50413076, -0.        ,  0.32015957]],

       [[ 2.06418989,  0.41493627,  0.        ],
        [ 0.33135649,  1.29259349,  1.4161259 ],
        [-0.94384364,  0.        ,  0.9548283 ],
        [ 0.        , -0.        , -0.45555172],
        [ 0.        , -0.8568483 , -0.        ],
        [ 0.30008834,  1.07369502, -0.        ],
        [ 0.        ,  0.        , -0.95485029],
        [-0.        , -0.        , -0.        ]]])

In [ ]:


In [95]:
# masked_error = error * above_thresh

In [103]:
# Re-display masked_error; the saved output matches the multiplication-based
# definition (zeros, no NaNs).
masked_error.eval()


Out[103]:
array([[[-0.        ,  0.63391603,  0.04010954],
        [-0.        , -0.        ,  0.        ],
        [-0.        , -1.18162216, -1.4576782 ],
        [ 0.        ,  0.01867809,  0.        ],
        [ 0.        ,  0.        , -0.75736401],
        [ 0.75839996,  0.        ,  0.20373484],
        [-1.38148993, -0.51837588, -0.        ],
        [-1.50413076, -0.        ,  0.32015957]],

       [[ 2.06418989,  0.41493627,  0.        ],
        [ 0.33135649,  1.29259349,  1.4161259 ],
        [-0.94384364,  0.        ,  0.9548283 ],
        [ 0.        , -0.        , -0.45555172],
        [ 0.        , -0.8568483 , -0.        ],
        [ 0.30008834,  1.07369502, -0.        ],
        [ 0.        ,  0.        , -0.95485029],
        [-0.        , -0.        , -0.        ]]])

In [86]:
# Mean of masked_error over axis 1.
# NOTE(review): the saved all-NaN output looks stale. Its execution count (86)
# precedes the cell that defines the multiplication-based masked_error (94),
# so it was presumably produced with the NaN-based T.switch variant, where a
# single NaN makes the whole mean NaN. With the current definition this should
# match masked_error.mean(axis=1) -- re-run to confirm.
T.mean(masked_error, axis=1).eval()


Out[86]:
array([[ nan,  nan,  nan],
       [ nan,  nan,  nan]])

In [97]:
# Sum of the error over axis 1; masked-out entries are 0 and add nothing, so
# this is the sum over the above-threshold entries only.
masked_error.sum(axis=1).eval()


Out[97]:
array([[-2.12722073, -1.04740391, -1.65103825],
       [ 1.75179108,  1.92437647,  0.9605522 ]])

In [98]:
# Count of above-threshold entries along axis 1 (the mask is 0/1, so its sum
# is a count).
above_thresh.sum(axis=1).eval()


Out[98]:
array([[3, 4, 5],
       [4, 4, 4]])

In [99]:
# Mean error over the above-threshold entries only: masked sum / mask count.
# NOTE(review): where no entry along axis 1 is above THRESHOLD the count is 0,
# so the quotient is not finite; the loop-based reference cell skips such
# series instead. Guard the denominator before using this as a loss.
(masked_error.sum(axis=1) / above_thresh.sum(axis=1)).eval()


Out[99]:
array([[-0.70907358, -0.26185098, -0.33020765],
       [ 0.43794777,  0.48109412,  0.24013805]])

In [102]:
# Plain mean over axis 1. This divides by the full axis length (the masked
# zeros are counted too), so it is NOT the mean over above-threshold entries;
# compare with the sum / count ratio.
masked_error.mean(axis=1).eval()


Out[102]:
array([[-0.26590259, -0.13092549, -0.20637978],
       [ 0.21897388,  0.24054706,  0.12006902]])

In [110]:
# Recover the mask from masked_error by testing for non-zero entries.
# Caveat: an above-threshold entry whose error is exactly 0 would be reported
# as masked out.
T.neq(masked_error, 0).eval()


Out[110]:
array([[[0, 1, 1],
        [0, 0, 0],
        [0, 1, 1],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 0],
        [1, 0, 1]],

       [[1, 1, 0],
        [1, 1, 1],
        [1, 0, 1],
        [0, 0, 1],
        [0, 1, 0],
        [1, 1, 0],
        [0, 0, 1],
        [0, 0, 0]]], dtype=int8)

In [114]:
# 1 where masked_error is not NaN, 0 where it is NaN. With the
# multiplication-based masked_error there are no NaNs, hence the all-ones
# saved output; this test would only recover the mask for the NaN-based
# T.switch variant.
(T.eq(T.isnan(masked_error), 0)).eval()


Out[114]:
array([[[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]],

       [[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]]], dtype=int8)

In [ ]: