In [1]:
from __future__ import print_function, division
import numpy as np
import theano
import theano.tensor as T
In [2]:
# Toy batch: N_SEQ_PER_BATCH sequences x SEQ_LENGTH timesteps x N_OUTPUTS
# channels. `t` (presumably targets) and `error` are both standard-normal.
# Seed the legacy global RNG so Restart & Run All reproduces the same arrays
# and therefore the same downstream outputs.
SEED = 42
np.random.seed(SEED)
N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS = 2, 8, 3
t = np.random.randn(N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS)
error = np.random.randn(N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS)
In [3]:
# Inspect the NumPy array `t`, shape (N_SEQ_PER_BATCH, SEQ_LENGTH, N_OUTPUTS).
t
Out[3]:
In [4]:
# NumPy reference for a "balanced" error: for every (sequence, output) pair
# whose `t` exceeds THRESHOLD at least once, average the mean error over the
# above-threshold timesteps with the mean error over the remaining timesteps;
# `err` is the mean of those per-pair scores.
THRESHOLD = 0.0
error_accumulator = 0.0
n_active_seqs = 0  # number of (sequence, output) pairs that contributed
for seq_i in range(N_SEQ_PER_BATCH):
    for output_i in range(N_OUTPUTS):
        above_thresh = t[seq_i, :, output_i] > THRESHOLD
        if not above_thresh.any():
            continue  # never exceeds THRESHOLD: this pair is skipped
        n_active_seqs += 1
        # Index with the boolean mask directly. Invert it with `~`: NumPy
        # raises TypeError for unary `-` on boolean arrays.
        below_thresh = ~above_thresh
        error_above_thresh = error[seq_i, above_thresh, output_i].mean()
        if below_thresh.any():
            error_below_thresh = error[seq_i, below_thresh, output_i].mean()
            error_accumulator += (error_above_thresh + error_below_thresh) / 2.0
        else:
            # Every timestep is above THRESHOLD; the mean of an empty slice
            # would be NaN, so score this pair on the above-threshold half only.
            error_accumulator += error_above_thresh
# Undefined (NaN) when no pair ever exceeds THRESHOLD, rather than raising
# ZeroDivisionError.
err = error_accumulator / n_active_seqs if n_active_seqs else float('nan')
In [34]:
# Wrap both NumPy arrays as Theano shared variables so the following cells can
# build symbolic expressions over them.
# NOTE: this rebinds `t` and `error`. Re-running the NumPy reference cell after
# this one would operate on the shared variables instead of the arrays.
t, error = theano.shared(t), theano.shared(error)
In [35]:
# Symbolic elementwise mask: nonzero where `t` exceeds THRESHOLD.
# Rebinds `above_thresh`, which previously held the NumPy mask from the loop.
above_thresh = T.gt(t, THRESHOLD)
In [78]:
# Show the symbolic mask variable itself (its repr), not its values.
above_thresh
Out[78]:
In [79]:
# above_thresh = above_thresh.astype(np.float32)
In [90]:
# Evaluate the symbolic mask to concrete values.
above_thresh.eval()
Out[90]:
In [94]:
# Alternative kept for reference: fill masked-out entries with NaN instead of 0.
# masked_error = T.switch(t > THRESHOLD, error, np.nan)
# Active variant: keep `error` where the mask is set, zero elsewhere.
masked_error = T.mul(error, above_thresh)
masked_error.eval()
Out[94]:
In [ ]:
In [95]:
# masked_error = error * above_thresh
In [103]:
# Evaluate `masked_error` again (re-check after experimenting with the
# commented-out definition in the previous cell).
masked_error.eval()
Out[103]:
In [86]:
# Mean over axis 1 (SEQ_LENGTH). The zeroed-out entries still count in the
# denominator, so this is NOT the mean over above-threshold timesteps.
masked_error.mean(axis=1).eval()
Out[86]:
In [97]:
# Sum of the masked error over the time axis, per (sequence, output) pair.
T.sum(masked_error, axis=1).eval()
Out[97]:
In [98]:
# Count of above-threshold timesteps per (sequence, output) pair.
T.sum(above_thresh, axis=1).eval()
Out[98]:
In [99]:
# Mean error over only the above-threshold timesteps: masked sum / mask count.
# Where a pair has no above-threshold timestep this is 0 / 0, i.e. NaN.
error_sum_above = masked_error.sum(axis=1)
n_above = above_thresh.sum(axis=1)
(error_sum_above / n_above).eval()
Out[99]:
In [102]:
# Same quantity as the earlier mean over axis 1 (zeros included in the
# denominator); kept only for comparison with the masked mean.
T.mean(masked_error, axis=1).eval()
Out[102]:
In [110]:
# Elementwise "masked_error != 0". Use T.neq, not `!=`: Theano variables do not
# overload `==`/`!=` as elementwise comparisons.
T.neq(0, masked_error).eval()
Out[110]:
In [114]:
# Elementwise "not isnan(masked_error)". With the zero-filling definition
# (error * above_thresh) and finite `error`, this is all true; it only becomes
# informative with the NaN-filling T.switch variant.
T.eq(0, T.isnan(masked_error)).eval()
Out[114]:
In [ ]: