In [326]:
    
from IPython.display import Markdown, display
%load_ext autoreload
%autoreload 2
    
    
In [327]:
    
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import tabulate
import pprint
import click
import numpy as np
import pandas as pd
from ray.tune.commands import *
from nupic.research.frameworks.dynamic_sparse.common.browser import *
import matplotlib.pyplot as plt
from matplotlib import rcParams
%config InlineBackend.figure_format = 'retina'
import seaborn as sns
# Plot styling for every figure below: white background with grid lines and a
# colorblind-friendly palette.
sns.set(style="whitegrid")
sns.set_palette("colorblind")
    
In [328]:
    
# Experiments to load. Each entry is a sub-directory of `base`; the results
# themselves live under ~/nta/results/<base>/<experiment>.
base = 'gsc-dsnn-2019-10-11-G-reproduce'
experiment_names = [
    # 'gsc-Static',  # presumably the static-sparse baseline; commented out, so not loaded
    'gsc-SET',
    'gsc-WeightedMag',
]
exps = list(map(lambda name: os.path.join(base, name), experiment_names))
paths = [os.path.expanduser("~/nta/results/{}".format(rel_path)) for rel_path in exps]
# Load every trial found under `paths` into a single DataFrame.
# NOTE(review): `load_many` comes from the star import of the nupic browser
# module; later cells treat each row as one trial — confirm.
df = load_many(paths)
    
    
In [329]:
    
# Replace missing pruning percentages with 0.0.
# NOTE(review): presumably NaN means the run did not configure that kind of
# pruning — confirm against the experiment configs.
# `fillna` is the idiomatic NaN replacement; `replace(..., regex=True)` expects
# a string/regex pattern, not NaN.
df['hebbian_prune_perc'] = df['hebbian_prune_perc'].fillna(0.0)
df['weight_prune_perc'] = df['weight_prune_perc'].fillna(0.0)
display(df[['hebbian_prune_perc', 'weight_prune_perc']].head(5))
    
    
    
In [330]:
    
# Spot-check a handful of rows by position (200-204); shows fewer or no rows
# if fewer than 205 trials were loaded.
df.iloc[200:205]
    
    Out[330]:
In [331]:
    
# List every column of the loaded results.
df.columns
    
    Out[331]:
In [332]:
    
# (rows, columns) of the loaded results, before incomplete trials are filtered out below.
df.shape
    
    Out[332]:
In [333]:
    
# Inspect a single trial (the row at position 1) with all of its fields.
df.iloc[1]
    
    Out[333]:
In [334]:
    
# Number of rows (trials) per raw model class name.
df.groupby('model')['model'].count()
    
    Out[334]:
Experiment Details
In [335]:
    
# Did any trials fail (or are some still running)? Count the trials with fewer
# than 30 recorded epochs.
(df["epochs"] < 30).sum()
    
    Out[335]:
In [336]:
    
# Keep only trials with at least 30 epochs; `df_origin` keeps every trial.
# `.copy()` makes `df` an independent frame, so later column assignments
# (e.g. `df['model2'] = ...`) do not raise SettingWithCopyWarning.
# NOTE: re-running this cell after it has run once makes `df_origin` the
# already-filtered frame; re-run the loading cells first.
df_origin = df.copy()
df = df_origin[df_origin["epochs"] >= 30].copy()
df.shape
    
    Out[336]:
In [337]:
    
# Which trials fell short of 30 epochs — did they fail, or are they still running?
# Flag them on the unfiltered frame and show their epoch counts.
df_origin['failed'] = df_origin["epochs"].lt(30)
df_origin.loc[df_origin['failed'], 'epochs']
    
    Out[337]:
In [338]:
    
# helper functions
def mean_and_std(s):
    """Format a Series as "mean ± std" with three decimals."""
    return "{:.3f} ± {:.3f}".format(s.mean(), s.std())


def round_mean(s):
    """Return the mean of a Series rounded to the nearest integer, as a string."""
    return "{:.0f}".format(round(s.mean()))


# Statistics reported for `val_acc_max` in every summary table.
stats = ['min', 'max', 'mean', 'std']


def agg(columns, filter=None, round=3, data=None):
    """Summarize trials grouped by `columns`.

    For each group, reports `val_acc_max_epoch` (mean, rounded to an integer
    string), min/max/mean/std of `val_acc_max`, and the row count.

    Args:
        columns: column name(s) to group by.
        filter: optional boolean mask selecting the rows to include.
        round: decimals for the numeric result columns (string-valued columns
            are left unchanged by `DataFrame.round`).
        data: DataFrame to summarize; defaults to the notebook-global `df`,
            looked up at call time.

    Returns:
        The aggregated DataFrame, indexed by `columns`.
    """
    # `filter` and `round` shadow builtins but are kept for backward
    # compatibility with keyword callers such as `agg(..., filter=fltr)`.
    frame = df if data is None else data
    if filter is not None:
        frame = frame[filter]
    summary = frame.groupby(columns).agg({
        'val_acc_max_epoch': round_mean,
        'val_acc_max': stats,
        'model': ['count'],
    })
    return summary.round(round)
    
In [339]:
    
# Summary table (see `agg`) per raw model class name.
agg(['model'])
    
    Out[339]:
In [340]:
    
# Same summary table, grouped by `on_perc` instead of model.
agg(['on_perc'])
    
    Out[340]:
In [341]:
    
def model_name(row):
    """Map a trial's raw `model` class name to the short label used in plots.

    Args:
        row: a trial record (e.g. a DataFrame row) with a 'model' key; the
            'hebbian_prune_perc' and 'weight_prune_perc' keys are read only
            to build the error message.

    Returns:
        'DSNN', 'SET' or 'Static'.

    Raises:
        ValueError: if the model name is not a known class. An explicit raise
            is used instead of `assert`, which is stripped under `python -O`
            and would let unknown models silently map to None.
    """
    labels = {
        'DSNNWeightedMag': 'DSNN',
        'SET': 'SET',
        'SparseModel': 'Static',
    }
    # NOTE(review): confirm the SET runs report model == 'SET'
    # (and not another class name such as 'DSNNMixedHeb').
    model = row['model']
    if model in labels:
        return labels[model]
    raise ValueError(
        "This should cover all cases. Got {} h - {} w - {}".format(
            row['model'], row['hebbian_prune_perc'], row['weight_prune_perc']))
# Add the short label from `model_name` ('DSNN' / 'SET' / 'Static') as column
# 'model2'; later plots use it as the `hue`.
df['model2'] = df.apply(model_name, axis=1)
    
In [342]:
    
# Exclude the static-sparse baseline and keep only MultiStepLR runs.
# `model_name` labels SparseModel trials 'Static' (there is no 'Sparse'
# label), so the comparison must use 'Static' or the filter is a no-op.
fltr = (df['model2'] != 'Static') & (df['lr_scheduler'] == "MultiStepLR")
agg(['on_perc', 'model2'], filter=fltr)
    
    Out[342]:
In [343]:
    
# Default size for the figures below. (Model-name translation is done by
# `model_name`, which fills the `model2` column.)
rcParams['figure.figsize'] = 16, 8
    
In [344]:
    
# `val_acc_max` vs. `on_perc`, one line per model, using the short `model2`
# labels like the other plots in this notebook.
# NOTE(review): the black error-bar points are hard-coded reference values
# whose source is not recorded here — document where 0.75 / 0.85
# (± 0.1 / 0.01) come from.
fig, ax = plt.subplots()
sns.lineplot(data=df, x='on_perc', y='val_acc_max', hue='model2', ax=ax)
ax.errorbar(x=[0.02, 0.04], y=[0.75, 0.85], yerr=[0.1, 0.01], color='k', marker='.', lw=0)
ax.set(title='val_acc_max vs. on_perc', xlabel='on_perc', ylabel='val_acc_max');
    
    Out[344]:
    
In [345]:
    
rcParams['figure.figsize'] = 16, 8
# Exclude the static-sparse baseline. The 'Static' label exists only in
# `model2`; the raw `model` column holds 'SparseModel', so filtering `model`
# against 'Static' excluded nothing. Named `not_static` so the builtin
# `filter` is not shadowed for the rest of the notebook.
not_static = df['model2'] != 'Static'
# NOTE(review): these hard-coded reference points (y = 85 / 95 ± 1) are drawn
# on the same axes as `val_acc_max_epoch`; confirm they belong to this metric
# (the accuracy plot above uses fractional values such as 0.75 / 0.85).
plt.errorbar(
    x=[0.02, 0.04],
    y=[85, 95],
    yerr=[1, 1],
    color='k',
    marker='*',
    lw=0,
    elinewidth=2,
    capsize=2,
    markersize=10,
)
sns.lineplot(data=df[not_static], x='on_perc', y='val_acc_max_epoch', hue='model2')
    
    Out[345]:
    
In [346]:
    
# Draws only the two reference points with error bars (no model curves).
# NOTE(review): the values are hard-coded and their source is not recorded in
# the notebook — confirm before reusing them.
reference_style = dict(color='k', marker='.', lw=0, elinewidth=1, capsize=1)
plt.errorbar(x=[0.02, 0.04], y=[0.85, 0.95], yerr=[0.01, 0.01], **reference_style)
    
    Out[346]:
    
In [ ]:
    
    
In [347]:
    
# `val_acc_last` (presumably validation accuracy at the final epoch) vs.
# `on_perc`, one line per `model2` label.
sns.lineplot(data=df, x='on_perc', y='val_acc_last', hue='model2')
    
    Out[347]: