In [197]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import csv
In [198]:
# Load the per-model benchmark results; every field arrives as a string.
STATS_CSV = 'stats_bees.csv'
# The csv module requires newline='' so that quoted fields containing
# newlines and '\r\n' line endings are parsed correctly.
with open(STATS_CSV, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    stats_full = list(reader)
In [200]:
# Number of examples in the evaluation run. NOTE(review): eval_time is
# presumably the total seconds to evaluate all of them — TODO confirm.
eval_examples = 153

# Annotate every row in place with a training-mode label and a throughput
# figure; the plotting cell below reads both keys.
for s in stats_full:
    # CSV values are strings, so the boolean flags are compared to 'True'.
    if s['retrained'] == 'True':
        train_mode = 'shallow' if s['shallow_retrain'] == 'True' else 'deep'
    else:
        train_mode = 'from_scratch'
    s['train_mode'] = train_mode

    # fps = examples / seconds. A non-positive time would give an infinite
    # or negative marker size, so fail loudly and name the offending row.
    eval_time = float(s['eval_time'])
    if eval_time <= 0:
        raise ValueError('eval_time must be positive, got %r for model %r'
                         % (s['eval_time'], s.get('name')))
    s['fps'] = eval_examples / eval_time
In [255]:
# Colours
# Number of per-model colours drawn from the rainbow colormap.
N = 6
# One fixed colour per training mode, sampled deterministically (no RNG) so a
# fresh Restart & Run All reproduces the same figure. The plotting cell
# iterates over c_map's keys to decide which training modes to draw.
mode_colors = cm.rainbow(np.linspace(0, 1, 3))
c_map = {'shallow': mode_colors[0], 'deep': mode_colors[1], 'from_scratch': mode_colors[2]}
# Integer inputs index the colormap lookup table directly (0, 50, ..., 250),
# giving N well-separated RGBA colours.
cmx = cm.rainbow(range(0, N * 50, 50))
In [291]:
# Bubble chart: x = training time, y = accuracy, marker shape = training mode,
# colour = model, bubble area proportional to evaluation throughput (fps).
FPS_TO_AREA = 20.0  # scatter `s` is marker area in points^2

fig, ax = plt.subplots(figsize=(10, 6))
marker = {'shallow': 'o', 'deep': 'h', 'from_scratch': 's'}
props = dict(boxstyle='round', facecolor='w', alpha=0.7)

# One colour per model name (first-seen order), shared across training modes,
# so a model keeps its colour in every group regardless of how many rows each
# group has. Wraps around if there are more models than colours in cmx.
model_names = list(dict.fromkeys(s['name'] for s in stats_full))
name_color = {name: cmx[i % len(cmx)] for i, name in enumerate(model_names)}

for t in c_map.keys():
    stats = [s for s in stats_full if s['train_mode'] == t]
    # Labels
    names = [s['name'] for s in stats]
    # reshape keeps an (n, 4) RGBA array even when the group is empty.
    colors = np.array([name_color[name] for name in names]).reshape(-1, 4)
    # Size
    sizes = [float(s['fps']) * FPS_TO_AREA for s in stats]
    # X
    training_time = [float(s['training_time']) for s in stats]
    # Y
    accuracy = [float(s['accuracy']) for s in stats]
    ax.scatter(training_time, accuracy, s=sizes, c=colors,
               label=t, edgecolor='black', marker=marker[t])
    # Only the deep-retrained group is annotated, to avoid repeating each name.
    if t == 'deep':
        for name, x, y in zip(names, training_time, accuracy):
            ax.annotate(name, xy=(x, y), fontsize=12, va='center', bbox=props)

ax.set_xlabel('Training time (secs)', fontsize=16)
ax.set_ylabel('Accuracy', fontsize=16)
ax.legend(loc='lower right', fontsize=12)
# The size caption must cover every plotted point, not only the last group.
all_fps = [float(s['fps']) for s in stats_full]
min_fps, max_fps = min(all_fps), max(all_fps)
# NOTE(review): (400, 0.55) is in data coordinates and may fall off-axes for
# other datasets.
ax.text(400, 0.55, 'size = fps(%d-%d)' % (max_fps, min_fps), fontsize=14,
        va='bottom', ha='center')
ax.set_title('Transfer learning shootout for PyTorch\'s model zoo', y=1.05, fontsize=16)
plt.show()
In [ ]: