In [26]:
%matplotlib inline
import os
import sys
from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt

sys.path.append("../src")
from graphprobe import loadviz, summarize

In [27]:
# Line color for each concept category. OrderedDict fixes the iteration order,
# which is the order the plotting cell walks the categories (and so the legend order).
category_colors = OrderedDict((
    ('scene', '#3288bd'),
    ('object', '#99d594'),
    ('part', '#e6f598'),
    ('material', '#fee08b'),
    ('texture', '#fc8d59'),
    ('color', '#d53e4f'),
    ('total', '#aaaaaa'),
))
# Threshold forwarded to graphprobe.summarize(); its exact meaning is defined
# there — TODO(review): document what quantity it is compared against.
threshold = 0.05

In [28]:
# Dissection output directory passed to loadviz() for every layer.
directory = "../dissection/alexnet_imagenet_full_conv_384/"
# Layers to compare, in network order; also reused as the x-axis tick labels.
blobs = "conv1 conv2 conv3 conv4 conv5 fc6-conv fc7-conv fc8-conv".split()

In [29]:
# Quick look at the first layer's summary; the returned value is the cell output.
summarize(
    loadviz(directory, blobs[0]),
    threshold,
)


Out[29]:
defaultdict(float,
            {'color': 0.05208333333333333,
             'object': 0.010416666666666666,
             'texture': 0.05208333333333333,
             'total': 0.11458333333333334})

In [38]:
# Collect one summary dict per layer in `blobs`, in the same order.
data = []
# Names of every category that summarize() reported for at least one layer.
# Start from an empty set (not from category_colors): the plotting loop
# iterates category_colors, so seeding with those keys made its
# "skip absent categories" check always false.
categories = set()
for blob in blobs:
    # top_only=True is forwarded to graphprobe.summarize — TODO(review): confirm its semantics.
    stats = summarize(loadviz(directory, blob), threshold, top_only=True)
    data.append(stats)
    categories.update(stats.keys())
# Line plot, one curve per concept category: the summarize() value for that
# category at each layer. Layers sit at x = 1..len(data), labeled with `blobs`.
x = range(1, len(data) + 1)
maxval = 0
fig, ax = plt.subplots(figsize=(12, 8))
for cat in category_colors.keys():
    # Skip categories no layer reported, and the aggregate 'total' series,
    # which this figure deliberately leaves out.
    if cat not in categories or cat == "total":
        continue
    # .get() works for both a defaultdict and a plain dict from summarize().
    dat = [d.get(cat, 0.0) for d in data]
    if dat:
        maxval = max(maxval, max(dat))
    ax.plot(x, dat, 'o-', color=category_colors[cat], label=cat)

ax.set_xticks(list(x))
ax.set_xticklabels(blobs)
ax.margins(0.1)
# Only draw a legend when at least one labeled curve exists.
if ax.get_legend_handles_labels()[0]:
    ax.legend(loc="upper left")
# With nothing positive plotted, maxval is 0 and (0, 0) would be a singular
# y-range; fall back to a unit-height axis in that case.
ax.set_ylim(-maxval * 0.05, maxval * 1.5 if maxval > 0 else 1.0)
ax.yaxis.grid(True)
for side in ['top', 'bottom', 'right', 'left']:
    ax.spines[side].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.set_title("alexnet_imagenet_number of detectors")
ax.set_ylabel('portion of units aligned to a category concept')

# Provenance stamp, in figure coordinates so it stays on the canvas whatever
# the y-range. Data coordinates like (2.0, 0.7) fall outside the plotted range
# whenever maxval * 1.5 < 0.7.
# NOTE(review): hardcoded absolute local path — consider a relative or configurable value.
file_path = '/home/nakamura/network_dissection/NetDissect/research/alexnet_imagenet_num_detectors_1027.ipynb'
fig.text(0.01, 0.01, file_path)
fig.savefig(os.path.join(directory, "graph.pdf"))
plt.show()