In [5]:
# Make the notebook's parent directory importable so the project-local
# `browser` module resolves regardless of the working directory.
import os
import sys
import inspect

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# NOTE(review): wildcard import kept because downstream cells rely on names
# from `browser` (e.g. `load_many`) whose full list isn't visible here;
# prefer an explicit import list once those names are confirmed.
from browser import *

In [12]:
# Hyperparameters varied across the sweeps. `fpartition` is excluded from the
# numeric-only list used by the scatter plots below.
TUNABLE = [
    'm_groups',
    'k_winners',
    'eps',
    'fpartition',
    'forget_mu',
    'boost_strength',
    'boost_strength_factor',
]
NUMERIC_TUNABLE = [param for param in TUNABLE if param != 'fpartition']

In [7]:
# Load every run from the three PTB sweeps into a single results DataFrame.
exps = [
    'Flat_PTB_Explore',
    'Flat_PTB_Snipe',
    'Flat_PTB_Snipe2',
]

# One Ray results directory per sweep under the synced S3 root.
paths = []
for exp_name in exps:
    paths.append(os.path.expanduser("~/s3_jgordon/ray/results/" + exp_name))

df = load_many(paths)


Loaded 96 experiments

In [13]:
# Derive an absolute k_winners for runs that only recorded a percentage.
has_pct = df.k_winners_pct.notna()
df.loc[has_pct, 'k_winners'] = df.k_winners_pct * df.m_groups
# NOTE(review): assumes k_winners has no NaN left at this point, otherwise
# astype(int) raises — confirm every run logged either k_winners or
# k_winners_pct. Bracket assignment (not `df.k_winners = ...`) is the
# recommended way to set a column.
df['k_winners'] = df.k_winners.astype(int)

# Guarantee every tunable column exists so the column selections below
# don't raise KeyError for sweeps that never logged a given parameter.
for tp in TUNABLE:
    if tp not in df:
        df[tp] = None

# Fill defaults for params older sweeps didn't log. Explicit reassignment
# instead of inplace=True (no performance benefit, and inplace hides the
# mutation from readers and breaks method chaining).
df = df.fillna(value={
    'forget_mu': 0.0,
    'dropout_p': 0.0,
    'balance_part_winners': False,
})

In [14]:
def scatter_all_tunable(df, params=NUMERIC_TUNABLE, ppl_cutoff=220):
    """Scatter each numeric tunable against the best validation perplexity.

    Args:
        df: results frame; must contain every column in ``params`` plus
            ``val_pred_ppl_min`` and ``train_pred_ppl_min``.
        params: tunable columns to plot, one panel per column.
        ppl_cutoff: exclude runs whose min val PPL exceeds this value
            (hides diverged runs that would compress the color scale).
    """
    # squeeze=False keeps axs a 2-D array even for a single row, so the zip
    # below works for any len(params) (plt.subplots would otherwise return a
    # bare Axes object for one panel, which is not iterable).
    fig, axs = plt.subplots(len(params), 1,
                            figsize=(8, max(2.4 * len(params), 3)),
                            dpi=144,
                            gridspec_kw={'hspace': 0.7},
                            squeeze=False)
    converged = df[df.val_pred_ppl_min <= ppl_cutoff]  # filter once, not per panel
    for ax, p in zip(axs[:, 0], params):
        converged.plot(kind='scatter', x=p, y='val_pred_ppl_min',
                       c='train_pred_ppl_min', colormap='viridis', ax=ax)
        ax.set_title("%s vs min val PPL" % p)
    plt.show()

In [15]:
# One scatter panel per numeric tunable, colored by min train PPL.
scatter_all_tunable(df)


Without Embedding


In [20]:
# Runs WITHOUT a pretrained embedding, best (lowest) validation PPL first.
detail_cols = (TUNABLE
               + ['weight_sparsity', 'x_b_norm', 'balance_part_winners',
                  'mult_integration', 'embedding_kind']
               + ['val_pred_ppl_min', 'val_pred_acc_max',
                  'train_pred_ppl_min', 'epoch_val_pred_ppl'])
no_embedding = df[df.embedding_kind.isnull()]
no_embedding.sort_values('val_pred_ppl_min')[detail_cols]


Out[20]:
m_groups k_winners eps fpartition forget_mu boost_strength boost_strength_factor weight_sparsity x_b_norm balance_part_winners mult_integration embedding_kind val_pred_ppl_min val_pred_acc_max train_pred_ppl_min epoch_val_pred_ppl
86 4000 15 0.500000 0.9 0.025000 1.000000 0.97 None True False True NaN 149.533360 23.207171 65.229777 50
90 4000 15 0.500000 0.9 0.020000 1.000000 0.80 None True True True NaN 155.133020 22.570667 84.297827 47
49 4000 15 0.500000 0.8 0.025000 1.000000 0.97 None True False True NaN 157.728945 22.343958 57.851861 45
57 4000 15 0.500000 0.9 0.020000 1.000000 1.00 None True True True NaN 160.709427 22.469333 84.384822 45
48 5000 15 0.500000 0.85 0.010000 1.000000 1.00 0.3 True False True NaN 160.872428 22.691899 68.133196 55
43 7000 15 0.500000 0.85 0.010000 1.500000 1.00 NaN False False True NaN 162.030516 22.363878 45.408591 11
47 7000 15 0.500000 0.85 0.010000 1.000000 1.00 0.3 True False True NaN 162.083793 22.406375 86.599488 54
50 3000 15 0.500000 0.85 0.035000 1.000000 0.97 None True False True NaN 164.965017 22.027888 110.877298 30
51 2000 15 0.500000 0.5 0.035000 1.000000 0.97 None True False True NaN 165.441337 22.013280 89.712456 36
45 8000 15 0.500000 0.8 0.010000 1.000000 1.00 NaN True False True NaN 168.384175 21.865870 50.688927 34
87 3000 15 0.500000 0.85 0.035000 1.000000 0.97 None False False True NaN 168.891307 22.479416 63.550150 15
44 7000 15 0.500000 0.85 0.010000 1.000000 1.00 NaN True False True NaN 171.233200 21.512616 88.577685 24
58 4000 20 0.500000 0.3 0.020000 1.000000 1.00 None False True True NaN 172.745524 22.017333 87.597685 25
39 7000 20 0.500000 0.5 0.010000 1.500000 1.00 NaN False False True NaN 173.047570 21.809326 57.430864 10
31 7000 24 0.593219 0.6 0.016094 0.754832 1.00 NaN False False True NaN 173.582597 21.734666 59.114142 12
40 7000 20 0.500000 0.5 0.010000 1.500000 1.00 NaN False False True NaN 175.468927 21.870941 57.844657 12
42 7000 15 0.500000 0.7 0.010000 1.500000 1.00 NaN False False True NaN 175.497827 21.528552 68.451441 20
92 4000 50 0.200000 0.4 0.020000 1.000000 0.80 None False True True NaN 175.782292 21.996296 78.870546 23
34 7000 30 0.669489 0.6 0.018861 1.169682 1.00 NaN False False True NaN 177.845409 21.659173 61.190414 8
41 7000 7 0.500000 0.5 0.010000 1.000000 1.00 NaN False False True NaN 178.192506 20.714475 112.657826 9
18 5221 22 0.487196 0.452876 0.010000 1.475382 1.00 NaN False False True NaN 179.147110 21.379961 55.514601 7
88 2000 15 0.500000 0.5 0.035000 1.000000 0.97 0.3 True False True NaN 180.802274 20.818061 109.780877 42
32 7000 7 0.588474 0.6 0.022133 1.014024 1.00 NaN False False True NaN 181.118973 20.636692 98.470202 19
30 7000 8 0.634966 0.6 0.019176 0.844524 1.00 NaN False False True NaN 181.916666 20.655010 94.701623 20
89 4000 80 0.500000 0.5 0.020000 1.000000 1.00 None False True True NaN 184.283028 21.446215 85.464933 14
60 4000 50 0.000000 0.4 0.020000 1.000000 0.80 None False True True NaN 184.735396 21.672840 88.586975 21
38 7000 30 0.500000 0.5 0.005000 2.000000 1.00 NaN False False True NaN 186.189591 21.035803 54.449358 6
53 4000 80 0.500000 0.5 0.020000 1.000000 1.00 None True True True NaN 188.182702 20.460823 155.715283 17
35 7000 34 0.796611 0.6 0.022308 0.813055 1.00 NaN False False True NaN 188.414072 20.977241 55.959882 7
4 5000 40 0.450389 0.119279 0.000000 1.861009 1.00 NaN False False True NaN 188.646976 20.774355 99.860362 10
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
65 5000 50 0.000000 1 0.020000 1.000000 1.00 None False True True NaN 244.667202 17.392593 198.199052 7
21 7000 13 0.458370 0.437357 0.007977 0.766594 1.00 NaN False False True NaN 252.210784 17.983902 180.376648 5
85 8000 15 0.500000 0.9 0.010000 1.000000 1.00 0.3 True False True NaN 252.771757 16.804781 219.026779 12
46 8000 15 0.500000 0.8 0.010000 1.000000 1.00 0.3 True False True NaN 263.288485 16.568393 244.956938 10
5 5000 439 0.000000 0.476656 0.000000 1.116153 1.00 NaN False False True NaN 264.066107 18.587843 174.432677 2
95 4000 50 0.000000 None 0.020000 1.000000 0.80 None False False False NaN 267.040836 17.700000 199.254060 73
7 5466 271 0.363649 0.389563 0.000000 0.522465 1.00 NaN False False True NaN 277.124967 15.998612 257.620266 3
25 7000 10 0.402180 0.397973 0.009692 1.125165 1.00 NaN False False True NaN 281.013971 16.723009 205.949667 7
67 4000 40 0.000000 0.15 0.020000 1.000000 1.00 None False True True NaN 282.159515 18.082716 248.627352 10
23 7000 26 0.643324 0.328484 0.005265 0.948042 1.00 NaN False False True NaN 282.259688 16.156814 211.187379 4
22 7000 10 0.672613 0.301799 0.006074 0.685615 1.00 NaN False False True NaN 286.045803 15.492090 233.932968 7
24 7000 26 0.594523 0.364394 0.009932 0.678338 1.00 NaN False False True NaN 287.956944 15.714405 241.470804 6
10 5190 164 0.000000 0.17483 0.000000 1.105501 1.00 NaN False False True NaN 288.382356 17.359700 328.284816 0
27 7000 6 0.483391 0.310452 0.006410 0.936584 1.00 NaN False False True NaN 299.312030 15.799611 236.572892 7
28 7000 21 0.452550 0.30867 0.009381 0.766783 1.00 NaN False False True NaN 313.493275 15.946711 243.237225 4
12 5453 15 0.000000 0.211701 0.000000 1.067430 1.00 NaN False False True NaN 322.009379 16.187344 257.327728 2
73 4000 150 0.000000 0.35 0.020000 1.000000 0.80 None False True True NaN 322.287556 16.612346 305.678443 4
20 7000 23 0.691022 0.434081 0.006510 1.033591 1.00 NaN False False True NaN 323.921014 14.665279 281.367261 4
63 5000 50 0.500000 0.35 0.020000 1.000000 1.00 None False True True NaN 343.306769 16.451852 345.009104 12
1 5000 899 0.000000 0.491606 0.000000 1.552303 1.00 NaN False False True NaN 356.124656 14.759367 234.401898 1
14 5910 190 0.548464 0.266009 0.010000 1.261804 1.00 NaN False False True NaN 360.805184 12.948931 526.293090 6
2 5000 921 0.472143 0.0910026 0.000000 0.502225 1.00 NaN False False True NaN 361.819155 13.656120 260.168279 5
72 4000 150 0.000000 0.35 0.020000 1.000000 0.80 None True True True NaN 416.070079 13.750617 454.064077 3
64 5000 50 0.000000 1 0.020000 1.000000 1.00 None False True True NaN 455.708744 12.832099 591.038453 0
71 4000 15 0.500000 0.45 0.020000 1.000000 0.80 None True True True NaN 458.519413 12.670370 542.036964 2
59 4000 80 0.500000 0.4 0.020000 1.000000 0.80 None False True True NaN 478.466163 12.965432 603.494092 0
61 5000 50 0.500000 0.35 0.020000 1.000000 1.00 None False True True NaN 507.003820 12.579012 596.373978 0
70 2000 6 0.000000 0.25 0.020000 1.000000 1.00 None True True True NaN 509.343218 10.407407 560.517984 2
62 5000 50 0.500000 0.35 0.020000 1.000000 1.00 None False True True NaN 527.973285 12.235802 649.732477 0
0 5000 249 0.840724 0.0998685 0.000000 0.561221 1.00 NaN False False True NaN 581.734947 8.320289 309.324191 0

83 rows × 16 columns

With Embedding


In [21]:
# Runs WITH a pretrained embedding (glove / ptb_fasttext), best first.
emb_cols = (TUNABLE
            + ['weight_sparsity', 'x_b_norm', 'balance_part_winners',
               'mult_integration', 'embedding_kind']
            + ['val_pred_ppl_min', 'val_pred_acc_max',
               'train_pred_ppl_min', 'epoch_val_pred_ppl'])
with_embedding = df[df.embedding_kind.notnull()]
with_embedding.sort_values('val_pred_ppl_min')[emb_cols]


Out[21]:
m_groups k_winners eps fpartition forget_mu boost_strength boost_strength_factor weight_sparsity x_b_norm balance_part_winners mult_integration embedding_kind val_pred_ppl_min val_pred_acc_max train_pred_ppl_min epoch_val_pred_ppl
36 1000 25 0.00 None 0.02 1.0 0.8 None True False False ptb_fasttext 134.686125 23.059259 65.109254 23
81 4000 100 0.50 None 0.02 1.0 0.8 None False False False ptb_fasttext 177.360827 20.409877 123.929287 32
84 4000 100 0.50 None 0.02 1.0 0.8 None False False False ptb_fasttext 178.504244 20.375309 111.730235 47
75 4000 50 0.00 0.35 0.02 1.0 0.8 None False False True glove 180.837370 21.213580 128.259115 70
78 4000 100 0.00 0.35 0.02 1.0 0.9 None False False False glove 191.500516 21.032099 141.036298 30
94 5000 50 0.50 0.35 0.02 1.0 1.0 None False True True glove 248.209167 18.882716 250.072538 9
82 4000 100 0.50 None 0.02 1.0 0.8 None False False False ptb_fasttext 285.422338 18.162963 296.663770 2
76 4000 100 0.85 None 0.02 1.0 0.8 None False False False glove 288.996243 14.880247 273.512578 23
77 4000 100 0.00 0.35 0.02 1.0 0.9 None False False False glove 291.829646 18.622222 323.001734 3
93 5000 50 0.50 0.35 0.02 1.0 1.0 None False True True glove 434.616214 13.854321 547.006341 0
80 4000 100 0.50 None 0.02 1.0 0.8 None False False False ptb_fasttext 938.758771 10.223457 1161.010146 0
79 4000 100 0.85 None 0.02 1.0 0.8 None False False False ptb_fasttext 961.044750 7.590123 1264.578652 0
83 4000 100 0.50 None 0.02 1.0 0.8 None False False False ptb_fasttext 985.821118 9.286420 1054.394551 0

In [41]:
# How k_winners interacts with fpartition and m_groups among converged runs.
PPL_CUTOFF = 200  # ignore runs that never reached a reasonable perplexity
converged = df[df.val_pred_ppl_min < PPL_CUTOFF]  # filter once for both panels

fig, axs = plt.subplots(1, 2, dpi=144, figsize=(10, 5), gridspec_kw={'wspace': 0.5})
converged.plot(kind='scatter', x='k_winners', y='fpartition',
               c='val_pred_ppl_min', colormap='viridis', ax=axs[0])
converged.plot(kind='scatter', x='k_winners', y='m_groups',
               c='val_pred_ppl_min', colormap='viridis', ax=axs[1])
plt.show()



In [42]:
# Encode the x_b_norm flag as 0/1 so it can be used as a scatter x-axis.
# NOTE(review): the rendered tables above show x_b_norm holding mixed values
# (True/False/0.3/NaN) — confirm that boolean indexing with df.x_b_norm
# selects exactly the intended rows before trusting this plot.
df['norm_int'] = 0
df.loc[df.x_b_norm, 'norm_int'] = 1
df.plot(kind='scatter', x='norm_int', y='epoch_val_pred_ppl')
plt.title("Normalization slows down time-to-peak")


Out[42]:
Text(0.5, 1.0, 'Normalization slows down time-to-peak')