In [1]:
import sys
sys.path.append("..")

In [2]:
import numpy as np
np.seterr(divide="ignore")
import logging
import pickle
import glob

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import RobustScaler
from sklearn.utils import check_random_state

from scipy import interp

%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (6, 6)

In [3]:
def evaluate_models(pattern):
    rocs = []
    fprs = []
    tprs = []
    
    for filename in glob.glob(pattern):            
        data = pickle.load(open(filename, "rb"))
        y = data[:, 0]
        y_pred = data[:, 1]
        
        # Roc
        rocs.append(roc_auc_score(y, y_pred))
        fpr, tpr, _ = roc_curve(y, y_pred)
        
        fprs.append(fpr)
        tprs.append(tpr)
        
    return rocs, fprs, tprs

def plot_rocs(rocs, fprs, tprs, label="", color="r", show_all=False):
    inv_fprs = []
    base_tpr = np.linspace(0.05, 1, 476)
    
    for fpr, tpr in zip(fprs, tprs):
        inv_fpr = interp(base_tpr, tpr, 1. / fpr)
        inv_fprs.append(inv_fpr)
        if show_all:
            plt.plot(base_tpr, inv_fpr, alpha=0.1, color=color)
        
    inv_fprs = np.array(inv_fprs)
    mean_inv_fprs = inv_fprs.mean(axis=0)


    plt.plot(base_tpr, mean_inv_fprs, color, 
             label="%s" % label)
    
def plot_show(filename=None):
    plt.xlabel("Signal efficiency")
    plt.ylabel("1 / Background efficiency")
    plt.xlim([0.1, 1.0])
    plt.yscale("log")
    plt.legend(loc="best")
    plt.grid()
    
    if filename:
        plt.savefig(filename)
    
    plt.show()
    
def report_score(rocs, fprs, tprs, label, latex=False):       
    inv_fprs = []
    base_tpr = np.linspace(0.05, 1, 476)
    
    for fpr, tpr in zip(fprs, tprs):
        inv_fpr = interp(base_tpr, tpr, 1. / fpr)
        inv_fprs.append(inv_fpr)
        
    inv_fprs = np.array(inv_fprs)
    mean_inv_fprs = inv_fprs.mean(axis=0)
    
    if not latex:
        print("%32s\tROC AUC=%.4f+-%.2f\t1/FPR@TPR=0.5=%.2f+-%.2f" %  (label, 
                                                                       np.mean(rocs), 
                                                                       np.std(rocs),
                                                                       np.mean(inv_fprs[:, 375]), # was 225 for 50%
                                                                       np.std(inv_fprs[:, 375])))
    else:
        print("%30s \t& %.4f $\pm$ %.4f \t& %.1f $\pm$ %.1f \\\\" % 
              (label,
               np.mean(rocs), 
               np.std(rocs),
               np.mean(inv_fprs[:, 375]),
               np.std(inv_fprs[:, 375])))

In [7]:
for n_jets, color in [(1, "r"), (2, "g"), (5, "b")]:
    pattern = "antikt-kt"

    r, f, t = evaluate_models("../models/event-study/predictions-reversed/predictions-e-full-%s-%d-*.pickle" % (pattern, n_jets))
    plot_rocs(r, f, t, label=r"$\leq$ %d jet%s" % (n_jets, "" if n_jets==1 else "s"), color=color, show_all=False)
    report_score(r, f, t, label=str(n_jets), latex=True)

plot_show(filename="event-n-jets.pdf")


                             1 	& 0.9602 $\pm$ 0.0004 	& 26.7 $\pm$ 0.7 \\
                             2 	& 0.9866 $\pm$ 0.0007 	& 156.9 $\pm$ 14.8 \\
                             5 	& 0.9867 $\pm$ 0.0004 	& 152.8 $\pm$ 10.4 \\

In [44]:
for n_jets in [1, 2, 5]:
    for pattern, color in [
            ("antikt-kt", "r"),
            ("antikt-seqpt-reversed", "r"),
            #("antikt-seqpt-reversed-towers", "towers", "g"),
            #("antikt-seqpt-reversed-towers-pflow", "towers+pflow", "b"),
        ]:

        r, f, t = evaluate_models("../models/event-study/predictions-reversed/predictions-e-full-%s-%d-*.pickle" % (pattern, n_jets))
        report_score(r, f, t, label=str(n_jets), latex=True)
        
for pattern, n_jets, label, color in [
        ("jet4v", 1, "1 jet4v", "r"),
        ("jet4v", 2, "2 jet4v", "b"),
        ("jet4v", 5, "5 jet4v", "g")
    ]:
    
    r, f, t = evaluate_models("../models/event-study/predictions-reversed/predictions-e-4v-%s-%d-*.pickle" % (pattern, n_jets))
    report_score(r, f, t, label=label, latex=True)
    
for pattern, n_jets, label, color in [
        ("all4v", 1, "1 all4v", "r"),
        ("all4v", 50, "50 all4v", "g"),
        ("all4v", 100, "100 all4v", "b"),
        ("all4v", 200, "200 all4v", "c"),
        ("all4v", 400, "400 all4v", "m")
    ]:
    
    r, f, t = evaluate_models("../models/event-study/predictions-reversed/predictions-e-4v-%s-%d-*.pickle" % (pattern, n_jets))
    report_score(r, f, t, label=label, latex=True)


                             1 	& 0.9602 $\pm$ 0.0004 	& 26.7 $\pm$ 0.7 \\
                             1 	& 0.9594 $\pm$ 0.0010 	& 25.6 $\pm$ 1.4 \\
                             2 	& 0.9866 $\pm$ 0.0007 	& 156.9 $\pm$ 14.8 \\
                             2 	& 0.9875 $\pm$ 0.0006 	& 174.5 $\pm$ 14.0 \\
                             5 	& 0.9867 $\pm$ 0.0004 	& 152.8 $\pm$ 10.4 \\
                             5 	& 0.9872 $\pm$ 0.0003 	& 167.8 $\pm$ 9.5 \\
                       1 jet4v 	& 0.8909 $\pm$ 0.0007 	& 5.6 $\pm$ 0.0 \\
                       2 jet4v 	& 0.9606 $\pm$ 0.0011 	& 21.1 $\pm$ 1.1 \\
                       5 jet4v 	& 0.9576 $\pm$ 0.0019 	& 20.3 $\pm$ 0.9 \\
                       1 all4v 	& 0.6501 $\pm$ 0.0023 	& 1.7 $\pm$ 0.0 \\
                      50 all4v 	& 0.8925 $\pm$ 0.0079 	& 5.6 $\pm$ 0.5 \\
                     100 all4v 	& 0.8781 $\pm$ 0.0180 	& 4.9 $\pm$ 0.6 \\
                     200 all4v 	& 0.8846 $\pm$ 0.0091 	& 5.2 $\pm$ 0.5 \\
                     400 all4v 	& 0.8780 $\pm$ 0.0132 	& 4.9 $\pm$ 0.5 \\

Count parameters


In [45]:
def count(params):
    def _count(thing):
        if isinstance(thing, list):
            c = 0
            for stuff in thing:
                c += _count(stuff)
            return c 

        elif isinstance(thing, np.ndarray):
            return np.prod(thing.shape)
    
    c = 0
    for k, v in params.items():
        c += _count(v)
    return c
    
# Simple vs gated
fd = open("../models/event-study/model-e-full-antikt-kt-1-1.pickle", "rb")
params = pickle.load(fd)
fd.close()
print("Simple =", count(params))


('Simple =', 18681)

In [46]:
params


Out[46]:
{'W_clf': [array([[-0.01879265,  0.12167948,  0.05035893, ..., -0.01513335,
           0.01769368,  0.0184784 ],
         [-0.09257675,  0.20137325,  0.07307229, ...,  0.00381305,
           0.02993886,  0.00911687],
         [-0.0496274 ,  0.12184974,  0.06705015, ...,  0.03035775,
           0.07714101,  0.03544785],
         ..., 
         [-0.07723944,  0.13333639,  0.03686497, ...,  0.00210926,
           0.05933862,  0.00400085],
         [-0.02276393, -0.01041212, -0.00182888, ..., -0.00364235,
          -0.00650449, -0.02148335],
         [-0.01336965,  0.14102805,  0.06182186, ..., -0.01678019,
           0.06684246,  0.03531396]]),
  array([[-0.01825787, -0.10482513, -0.076065  , ..., -0.04278567,
           0.01373406, -0.02729302],
         [-0.00989853, -0.03377796, -0.01971048, ..., -0.02333645,
          -0.02379274,  0.00634366],
         [-0.01391392, -0.00220749,  0.01460803, ..., -0.00829536,
          -0.02685779, -0.01533192],
         ..., 
         [ 0.05855084,  0.04087334,  0.07941288, ...,  0.04692357,
           0.01635977,  0.06657046],
         [ 0.079574  ,  0.05545069,  0.05553415, ...,  0.07535648,
          -0.01783014,  0.0582753 ],
         [ 0.00697943, -0.05255406, -0.01589425, ..., -0.0388641 ,
          -0.02876457, -0.0267121 ]]),
  array([  5.57005950e-01,   4.02156176e-01,   1.46203979e-02,
           2.67233287e-02,  -1.00618174e-01,  -8.12232419e-02,
          -9.00368618e-04,  -9.41193851e-02,   7.14083042e-03,
          -9.95708442e-05,   2.75106609e-03,   4.32685877e-01,
          -1.02409539e-01,  -7.91459800e-02,   5.89771955e-01,
           3.25767035e-02,   1.38920925e-02,   3.83699329e-01,
          -3.10565931e-03,  -9.25525415e-02,  -1.06399997e-01,
           9.30363144e-03,   5.47926505e-03,  -1.00221677e-01,
          -6.57951521e-02,  -2.62835429e-02,  -7.91666562e-02,
           1.15694028e-02,   4.19822683e-03,   1.45795745e-02,
           2.50424004e-02,  -1.60167816e-03,   4.18910134e-02,
           1.35461559e-02,   8.61196939e-03,  -1.12366122e-02,
          -9.75378819e-02,  -1.09984029e-01,  -1.05974155e-01,
           4.39441185e-01])],
 'W_h': array([[ 0.01974727,  0.14772742,  0.12372879, ..., -0.19854826,
         -0.12167051, -0.09815855],
        [ 0.36368217,  0.2283872 , -0.01843838, ...,  0.30200861,
          0.07037456, -0.34377685],
        [ 0.01616307, -0.17553144,  0.02151688, ..., -0.22608511,
         -0.13746089, -0.06271186],
        ..., 
        [ 0.03614766, -0.20570639, -0.07699956, ..., -0.13450901,
         -0.04444708,  0.07886513],
        [ 0.02847073, -0.0260854 ,  0.20598863, ..., -0.1909581 ,
         -0.05806769,  0.09483387],
        [ 0.05298898, -0.12581686, -0.08005072, ..., -0.11205202,
         -0.05625623, -0.06792623]]),
 'W_u': array([[ -8.00464763e-02,  -3.60905375e-02,   4.22563264e-03,
           2.86954230e-02,   4.68650841e-02,  -9.03029496e-02,
           7.98984963e-02],
        [ -4.74589143e-02,   7.69045454e-03,  -1.56312894e-01,
          -1.80271188e-02,   1.92583962e-03,  -6.05192750e-02,
          -1.48113889e-02],
        [ -2.49088334e-02,   2.39887191e-02,  -1.58960232e-01,
          -1.91376362e-03,  -6.71875444e-02,  -2.81395266e-02,
           6.73105097e-03],
        [  8.34806749e-03,   1.89845265e-02,   9.43713770e-02,
          -2.62720411e-02,   4.87153372e-02,  -2.75283579e-02,
           2.51243143e-02],
        [ -1.79948820e-02,  -4.48976317e-02,   2.22491012e-01,
           4.68127085e-03,   5.56808376e-02,  -3.83294049e-02,
           4.61097870e-02],
        [ -1.20775206e-01,   3.14426580e-02,   5.34080175e-02,
           1.12957662e-03,   6.85002031e-02,  -1.39631049e-01,
          -4.23380720e-02],
        [  9.22394683e-04,  -1.53588045e-02,   3.25121554e-01,
           1.69598584e-02,   4.13441644e-02,   4.08595208e-03,
          -2.44147896e-02],
        [  1.14009455e-03,  -8.62299862e-03,   1.49217906e-01,
          -2.85113346e-02,   5.60177949e-02,  -2.71596250e-02,
           3.00468432e-02],
        [  2.36333979e-02,  -9.96600090e-04,   8.01842731e-02,
          -3.27639562e-03,   5.31911733e-03,   3.90892719e-02,
           1.55724709e-02],
        [  3.98564806e-02,   1.28491278e-02,  -1.57499588e-03,
           1.92621122e-02,   6.91084800e-02,   5.79543501e-03,
           2.02317802e-02],
        [ -5.01252947e-02,   6.10581490e-02,   6.78237566e-02,
           1.97910478e-03,   3.70606121e-02,  -5.55315024e-02,
          -1.38372127e-02],
        [ -3.59648091e-02,   4.15676733e-02,  -4.59267226e-03,
           7.70007875e-03,  -5.05795937e-02,  -4.58131759e-02,
          -3.53157033e-02],
        [ -2.37838282e-02,   9.68973986e-03,   1.62553289e-01,
           6.13371050e-03,  -2.95196049e-02,  -1.82485557e-02,
           2.46825871e-03],
        [ -4.24779115e-02,  -4.09669693e-03,   5.52412952e-04,
          -7.29548832e-03,   3.37669562e-02,   1.49684708e-03,
          -8.07305555e-03],
        [  6.94983625e-03,  -1.43536094e-02,  -9.64291398e-02,
           1.09688614e-02,  -3.53499532e-02,   2.25653611e-02,
           5.77280534e-02],
        [ -8.61923770e-02,   2.50204288e-02,   9.67013259e-02,
          -3.52562771e-02,   6.63362648e-02,  -8.93768763e-02,
          -7.52996123e-03],
        [ -1.56437236e-02,   1.39715349e-02,   2.39069487e-01,
           2.09897319e-02,   9.27862675e-03,  -1.63755115e-02,
           1.04731517e-02],
        [ -9.04888569e-02,  -2.37942698e-03,   1.86282269e-01,
          -7.43379105e-03,  -2.20444063e-02,  -5.34474631e-02,
           9.77492172e-03],
        [ -6.94197326e-02,   4.89003249e-02,   5.64148471e-02,
          -1.73830769e-02,   9.26940887e-03,  -1.40573864e-01,
          -4.67450493e-02],
        [ -7.02777365e-02,  -4.22445450e-02,   5.77588303e-02,
          -2.14448252e-02,   9.85699829e-04,  -5.84470830e-02,
           4.66413506e-02],
        [ -4.31649053e-01,  -7.90451466e-03,   4.72421083e-02,
          -4.42562404e-01,  -4.99218879e-01,  -4.23103042e-01,
           1.58667674e-02],
        [ -4.16100182e-02,  -3.41266659e-03,   9.72125320e-02,
           1.35450115e-02,   8.05683111e-02,  -4.03768861e-02,
           1.96493546e-02],
        [  9.37884379e-03,  -4.56555474e-02,   1.39131094e-02,
          -2.58169853e-02,   8.45008701e-04,  -1.24594324e-03,
           6.10751051e-02],
        [ -3.98197830e-02,   1.20255878e-02,   5.94642456e-02,
          -3.23731594e-02,   6.90965055e-03,  -5.15956471e-02,
           3.44805597e-02],
        [  3.09640530e-02,  -3.42894535e-02,  -1.49886818e-01,
           1.42396956e-02,  -1.82418849e-02,  -1.26823293e-02,
           2.27614464e-02],
        [ -7.41689119e-03,  -1.14215291e-02,  -9.93698076e-02,
          -6.78927152e-03,   1.29424656e-02,   1.27745525e-02,
           7.97061854e-03],
        [  8.40046077e-04,   2.40483435e-02,  -1.04973604e-01,
           1.63641428e-02,  -5.64032362e-02,  -5.70743804e-02,
           1.27384542e-02],
        [  1.53989893e-02,   3.61180018e-04,   6.04684886e-03,
          -3.23258228e-02,   1.45710032e-02,  -9.93928302e-03,
          -2.11030666e-03],
        [ -4.85808157e-02,   1.56448205e-02,   1.00686509e-01,
          -1.45646819e-02,   2.28815671e-02,  -4.67932542e-04,
           3.73134531e-02],
        [ -4.48556621e-02,   7.20680971e-02,   2.24958141e-01,
          -2.28236485e-03,  -7.15693474e-02,  -6.41308955e-02,
          -1.01077241e-01],
        [ -3.75803605e-02,  -2.17606843e-02,   5.59959637e-02,
           1.02881707e-02,   2.44116640e-02,   4.74785381e-03,
           3.12104606e-02],
        [ -4.26656142e-02,   5.30082756e-03,   8.44249367e-02,
          -3.88443546e-02,  -2.79977062e-02,  -4.09966599e-02,
          -1.96432413e-02],
        [ -1.00132092e-01,  -5.13698069e-02,   3.93573264e-02,
          -6.16922340e-02,  -2.37568449e-02,  -9.21825816e-02,
           1.99664114e-02],
        [  6.86196778e-02,  -1.41279029e-02,   2.31559400e-02,
           3.22028741e-03,   9.37032379e-03,   4.95209385e-02,
          -5.83259200e-03],
        [ -1.10041195e-01,   6.52921392e-02,   2.22897603e-02,
          -6.91492736e-02,  -5.34563131e-02,  -7.05718439e-02,
           1.14383513e-02],
        [ -7.19024876e-02,   1.47952524e-02,  -1.09894281e-01,
           1.43654707e-02,   1.26428124e-02,  -7.92072382e-02,
          -2.19342738e-02],
        [ -3.44461186e-01,   3.54327963e-02,  -8.50446057e-03,
          -3.62193733e-01,  -3.21163629e-01,  -2.88715270e-01,
           1.16598395e-02],
        [ -2.37572600e-01,  -3.49844712e-02,   1.80208989e-01,
          -2.59816690e-01,  -1.82832846e-01,  -2.70701809e-01,
           1.19617849e-02],
        [ -7.00901272e-03,  -6.54688021e-02,   2.00072498e-02,
           2.40143056e-02,   2.41346891e-02,  -9.83634691e-02,
           3.40356741e-02],
        [ -2.77593971e-02,  -4.66968438e-03,   1.17551429e-01,
          -4.85769022e-03,   6.33988750e-02,  -2.42953661e-02,
           8.01621049e-02]]),
 'b_clf': [array([-0.00610265, -0.04067829, -0.01242777, -0.02530428, -0.00412801,
          0.01446729, -0.0048849 , -0.00798496, -0.0075703 , -0.01632144,
         -0.02463091, -0.02140203, -0.01096214, -0.03419683,  0.00561615,
          0.04000476, -0.02457413, -0.01247703, -0.01759678, -0.01621177,
         -0.00300231, -0.02857965, -0.0081975 , -0.00929255, -0.01124441,
         -0.02442649, -0.01639863, -0.00796442, -0.00377546,  0.35372837,
         -0.02340254, -0.02366901,  0.00561398,  0.00855931,  0.08752517,
         -0.00294854,  0.00764653, -0.02323009, -0.00399195, -0.00191849]),
  array([  2.55760530e-01,   2.47239546e-01,  -7.16976182e-03,
          -3.00275110e-03,   5.49836473e-02,   5.30676765e-02,
          -1.16429074e-02,   6.17413275e-02,  -9.57929948e-03,
          -1.66046812e-02,  -2.67346246e-03,   1.20416617e-01,
           5.51028835e-02,   5.17090494e-02,   3.27772365e-01,
          -2.77941810e-03,  -3.00274714e-03,   7.30988135e-02,
          -1.48455077e-02,   5.13030172e-02,   6.09062345e-02,
          -3.38595715e-03,  -4.24702329e-15,   6.21012153e-02,
           3.17397892e-02,  -1.26669228e-02,   5.22615191e-02,
          -3.00274600e-03,  -2.77930825e-03,  -3.40203588e-03,
          -3.00275026e-03,  -5.16655320e-03,  -1.07632505e-02,
          -3.00274886e-03,  -3.00274657e-03,  -2.02772299e-02,
           5.52939094e-02,   5.78714075e-02,   5.34073158e-02,
           2.83044053e-01]),
  array([ 1.02375476])],
 'b_h': array([-0.00202093,  0.02270088,  0.02970479,  0.0284185 ,  0.04998483,
         0.05593001,  0.01357678,  0.07409076,  0.03527855,  0.07112114,
         0.04451713,  0.00521564,  0.02694463,  0.11302489, -0.04083855,
         0.10900568,  0.01803328,  0.0428488 ,  0.0542678 , -0.00193536,
        -0.02523182,  0.04297089, -0.02625267,  0.02538289, -0.02008486,
        -0.03016164,  0.04344521,  0.01196962,  0.02011716, -0.02578376,
         0.04815035, -0.05220733, -0.09269258,  0.05446446, -0.01310238,
         0.02092805,  0.07637499, -0.08737712,  0.06530805, -0.04705259]),
 'b_u': array([ -2.02096981e-03,   3.76569763e-02,   1.26718224e-02,
          1.13294689e-01,  -6.67768434e-02,   7.91690218e-03,
          1.63022501e-02,   1.68371886e-01,   3.31162060e-02,
          9.67815729e-03,  -7.20634609e-02,  -1.05821970e-01,
         -1.05373887e-02,  -3.08249747e-02,  -1.19548975e-01,
          1.46303156e-01,  -1.42588236e-02,   9.20111077e-03,
          9.76400031e-05,  -1.94719047e-02,   6.94249692e-03,
          1.03893869e-01,  -1.21180674e-01,   3.05381385e-02,
         -1.09974438e-01,  -9.90072665e-02,  -7.47284247e-03,
         -5.97741529e-02,   1.09132740e-01,  -5.00926273e-02,
          5.74942468e-02,   8.42394005e-02,   2.85232015e-02,
          5.16776363e-02,  -6.65907516e-02,   2.81825016e-02,
          1.98942768e-02,  -3.14275103e-02,  -9.43333544e-02,
          1.17212149e-01]),
 'rnn_W_hh': array([[-0.03050821, -0.11088868,  0.09415387, ..., -0.04635854,
          0.06578052, -0.18449011],
        [ 0.16124935, -0.07988935,  0.0950801 , ..., -0.15530077,
         -0.16864672, -0.16331793],
        [-0.25760482, -0.16464015,  0.31409882, ..., -0.30256353,
         -0.24375452,  0.17209473],
        ..., 
        [-0.24034251,  0.01393573,  0.01907851, ...,  0.26003188,
         -0.16433851,  0.20995987],
        [ 0.15137618,  0.22750055, -0.07541834, ...,  0.18093537,
         -0.0916946 , -0.01962224],
        [ 0.17390238, -0.20194708, -0.09696517, ..., -0.01509385,
          0.25595243,  0.06582198]]),
 'rnn_W_hx': array([[ 0.01148763,  0.01235511,  0.04794213, ..., -0.02414485,
          0.04163358,  0.00123901],
        [-0.02127907,  0.0432209 ,  0.02570181, ...,  0.00736006,
         -0.07197025, -0.00010674],
        [ 0.01458835, -0.014257  ,  0.01567932, ..., -0.02196076,
         -0.04810935,  0.00368074],
        ..., 
        [ 0.00226206, -0.01155808, -0.02127378, ..., -0.00192063,
          0.01722458,  0.00783579],
        [-0.00346175,  0.04349359, -0.02002164, ...,  0.01915557,
         -0.04449929, -0.00956241],
        [-0.00108115, -0.00205046,  0.00461357, ...,  0.02407351,
         -0.02207489, -0.02253598]]),
 'rnn_W_rh': array([[-0.07635428,  0.0565838 ,  0.26388732, ...,  0.20654336,
         -0.03741565, -0.17219083],
        [-0.11286982, -0.13705611, -0.05184282, ...,  0.05279542,
         -0.21846888, -0.23087424],
        [-0.04563776,  0.07140586, -0.19570502, ...,  0.01644539,
         -0.29097748, -0.11916529],
        ..., 
        [ 0.08571318, -0.09089051,  0.03988208, ...,  0.26591707,
         -0.29649496, -0.02333242],
        [ 0.16111102,  0.17350133,  0.03775965, ..., -0.08441271,
          0.22445033, -0.22500197],
        [ 0.29378044, -0.06810597, -0.10991058, ...,  0.2169553 ,
          0.1654237 , -0.13737786]]),
 'rnn_W_rx': array([[-0.02610767,  0.00072695, -0.02124127, ...,  0.01866981,
         -0.01254115,  0.01378381],
        [ 0.01468792, -0.02271415,  0.0001277 , ...,  0.01494801,
          0.01064045, -0.01657248],
        [ 0.00432104,  0.0120913 , -0.0135704 , ..., -0.00144359,
          0.01563107, -0.01014731],
        ..., 
        [-0.00298919,  0.00295855,  0.0020107 , ..., -0.00157785,
          0.01459769, -0.00194195],
        [ 0.00232966, -0.02409382,  0.00918956, ..., -0.01226513,
          0.01226459,  0.01018336],
        [ 0.02333291, -0.01922097,  0.00800325, ..., -0.00878727,
          0.00631653, -0.00622069]]),
 'rnn_W_zh': array([[-0.18792148,  0.13439998,  0.01304053, ...,  0.01685205,
         -0.12767662, -0.09408914],
        [-0.06417356,  0.09061818, -0.06307769, ..., -0.28284992,
          0.02125995,  0.05533026],
        [ 0.05205944, -0.42202862,  0.01657311, ..., -0.20543427,
          0.0921075 ,  0.07182994],
        ..., 
        [-0.30809599, -0.22521354,  0.10574121, ..., -0.11064812,
          0.08345664,  0.12644231],
        [ 0.11669041,  0.18725834,  0.1139661 , ..., -0.02864991,
          0.0072553 , -0.015174  ],
        [-0.02862377,  0.29467108, -0.08744986, ...,  0.27074032,
          0.15784995,  0.16551227]]),
 'rnn_W_zx': array([[ 0.05121996, -0.19880803,  0.23932369, ..., -0.01991692,
         -0.02722738, -0.01050081],
        [ 0.03759565,  0.07122856, -0.18138496, ...,  0.00191475,
         -0.08406076,  0.0067242 ],
        [-0.00528527,  0.02770999, -0.12858089, ...,  0.00589256,
         -0.05851751,  0.0078375 ],
        ..., 
        [-0.01828347, -0.00610605,  0.00325965, ..., -0.01264076,
         -0.02159346, -0.00370445],
        [ 0.0498488 ,  0.07560126, -0.17213781, ...,  0.01810733,
         -0.08018063, -0.02205877],
        [ 0.02651625,  0.0162009 , -0.00373256, ...,  0.01162476,
         -0.00927315,  0.00218924]]),
 'rnn_b_h': array([  9.19690884e-02,  -9.29697891e-03,  -1.24427435e-02,
         -6.91722297e-03,   2.50307569e-02,  -7.86087892e-03,
          2.71294965e-02,   8.73395085e-02,   1.63606645e-02,
         -1.57216611e-02,  -2.59929222e-03,  -8.11954963e-03,
         -3.21676254e-03,  -4.46999581e-03,   7.38032390e-02,
         -1.66348004e-02,   6.00904926e-02,  -1.86295874e-02,
         -2.34373125e-02,  -6.26838855e-03,   7.49334289e-02,
          2.30002855e-02,   9.91031236e-02,  -2.95705422e-02,
         -5.36528463e-03,  -1.47618022e-02,   1.90646274e-03,
          3.05997541e-02,   9.92560465e-02,  -1.83732090e-02,
          2.00923286e-02,   8.40589932e-05,  -4.37752255e-02,
          1.46987863e-01,  -9.55784250e-03,   4.81485978e-02,
         -2.52190144e-02,  -5.71423120e-03,  -1.75512729e-03,
         -7.16212774e-03]),
 'rnn_b_r': array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]),
 'rnn_b_z': array([ 0.14764824,  0.05803031,  0.06029166, -0.00120783,  0.12270585,
         0.03824776,  0.15554893,  0.21257403, -0.01793576,  0.03947073,
         0.00775501,  0.09752469,  0.07868246, -0.0036088 ,  0.22129427,
         0.02265617,  0.02929012, -0.01244162, -0.01851299,  0.07246438,
         0.06037606,  0.04767583,  0.17476271,  0.03913692,  0.0457808 ,
        -0.00518948,  0.07862717,  0.02262896,  0.20947699,  0.08002256,
         0.11213885,  0.00677252,  0.04522424,  0.13368829,  0.03577564,
         0.10181523,  0.02318757, -0.00418356,  0.06517949, -0.00685795])}