Convergent Learning: Do different neural networks learn the same representations?

This IPython Notebook contains source code necessary to reproduce the results presented in the following paper:

@inproceedings{li_2016_ICLR,
  title={Convergent Learning: Do different neural networks learn the same representations?},
  author={Li, Yixuan and Yosinski, Jason and Clune, Jeff and Lipson, Hod and Hopcroft, John},
  booktitle={International Conference on Learning Representations (ICLR '16)},
  year={2016}
}

arXiv link: http://arxiv.org/abs/1511.07543

Import libs and functions


In [2]:
from load_data import *
from plotting import *
from pylab import *
from match_unit import *
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

In [57]:
from sklearn.cluster import AgglomerativeClustering

Download data necessary to reproduce the experimental results


In [3]:
# URL to download
# TO BE ADDED

Load data for all nets in group


In [4]:
# Directories with the saved data for each independently trained network
net_dirs = [
    'net0',
    'net1',
    'net2',
    'net3',
    ]
net_paths = ['../data/' + dd for dd in net_dirs]

In [5]:
# All the layers in AlexNet
layers=['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']

# Load activation mean values, covariance matrices and self-correlation matrices for all the nets, all layers
l_means, l_outers, l_cors = dict(), dict(), dict()
for layer in layers:
    means,outers,cors = [], [], []
    for net_path in net_paths:
        moc = read_val_single_moc(net_path, layer)
        means.append(moc[0])
        outers.append(moc[1])
        cors.append(moc[2])
    l_means[layer] = means
    l_outers[layer] = outers
    l_cors[layer] = cors

In [6]:
# Load cross-correlation matrices for all pairs of nets, all layers
l_xcors = {}
for layer in layers:
    l_xcors[layer] = {}
    for net0 in range(len(net_dirs)-1):
        l_xcors[layer][net0] = {}
        for net1 in range(net0+1, len(net_dirs)):
            # To make life simpler, symlinks are all in first net_path directory!
            combo = '%d%d' % (net0, net1)
            l_xcors[layer][net0][net1] = read_val_double_cors(net_paths[0], layer, combo)
            #print 'l_xcors[%s][%d][%d] = %s' % (layer, net0, net1, l_xcors[layer][net0][net1].shape)

Plot unit mean activations

Below we show the mean activations for each unit of the four networks, plotted in sorted order from highest to lowest. First and most saliently, we see widely varying mean activation values across units, with a gap of one to two orders of magnitude (depending on the layer) between the most and least active units. Second, we observe a rough overall correspondence in the spectrum of activations across the networks. The correspondence is not perfect, however: although much of the spectrum matches well, the most active filters converged to solutions of somewhat different magnitudes. For example, the average activation of the most active conv2 filter varies from 49 to 120 over the four networks; the conv1 range is 98 to 130. This effect is all the more interesting given that all filters were trained with the same constant weight decay, which pushes every filter's weights and biases (and thus its activations) toward zero with equal force.


In [7]:
netidxs = range(4)

In [8]:
def plot_means(layer, do_legend = True, light = False, fileout=None):
    figsize(12,6)
    for ii,idx in enumerate(netidxs):
        rr = ii/(len(netidxs)-1+1e-6)
        clr = (1, .8*rr, 0)
        dat = -np.sort(-l_means[layer][idx])
        if light:
            rcParams.update({'font.size': 22})
            plot(dat, '-', color=clr, lw=3)
            plot(0, dat[0], 'o', color=clr, ms=12)
            plot(len(dat), dat[-1], 'o', color=clr, ms=12)
        else:
            rcParams.update({'font.size': 16})
            plot(dat, 'o-', color=clr, lw=2, ms=9)
    axis('tight')
    ax = looser(axis(), .01, .02)
    axis(ax[0:2] + (min(ax[2],0),) + ax[3:4])
    if do_legend:
        legend(('Net1', 'Net2', 'Net3', 'Net4'))
    if not light:
        xlabel('Channel number (sorted)')
        ylabel('Mean activation')
    title(layer)
    if fileout:
        savefig(fileout + '.pdf')
        savefig(fileout + '.png')

In [9]:
# Conv1 layer
plot_means('conv1', True, False)



In [10]:
# Conv2 layer
plot_means('conv2', True, False)



In [11]:
# Conv3 layer
plot_means('conv3', True, False)



In [12]:
# Conv4 layer
plot_means('conv4', True, False)



In [13]:
# Conv5 layer
plot_means('conv5', True, False)



In [14]:
# Fc6 layer
plot_means('fc6', True, False)



In [15]:
# Fc7 layer
plot_means('fc7', True, False)


Plot correlation matrices


In [16]:
net0 = 0
net1 = 1
layer = 'conv1'
means, outers, cors = l_means[layer], l_outers[layer], l_cors[layer]

In [17]:
# Plot the within-net correlation matrices of net0 and net1 (conv1 layer). 
# See Fig 1 (a)(b) in http://arxiv.org/pdf/1511.07543v2.pdf
figsize(16,8)
subplot(1,2,1)
showimagesc(cors[net0])
subplot(1,2,2)
showimagesc(cors[net1])



In [18]:
# Find max bipartite matching of between-net correlation matrix
match_no_diag = max_match_order(l_xcors["conv1"][0][1])
loop_order, loop_len = follow_loops(match_no_diag)


After   0 objects hashed, hash is da39
    After   1 objects hashed, hash is 71f3 (latest <type 'numpy.ndarray'>)
After   2 objects hashed, hash is 71f3 (latest <type 'tuple'>)
After   3 objects hashed, hash is f322 (latest <type 'dict'>)
 -> cache.py: max_match_order: trying to load file /home/yli24/.pycache/f3/f322fb10a9089222.max_match_order.pkl.gz
 -> cache.py: max_match_order: cache hit (0.0004s hash overhead, 0.0018s to load, saved 10.5678s)
   -> loaded /home/yli24/.pycache/f3/f322fb10a9089222.max_match_order.pkl.gz

In [19]:
# Between-net correlation for Net1 vs. Net2. 
figsize(16,8)
subplot(1,2,1)
showimagesc(l_xcors["conv1"][0][1])
subplot(1,2,2)
showimagesc(permute_matrix(l_xcors["conv1"][0][1],loop_order))


Load feature visualizations


In [20]:
vis_layer_net0 = vis_for_layer(net_paths[net0], layer)
vis_layer_net1 = vis_for_layer(net_paths[net1], layer)

In [21]:
# Load unit visualization
# The loading process might take a few minutes...
l_vis_unit_all = {}
for ll in ['conv1', 'conv2', 'conv3','conv4','conv5']:
    print 'loading vis for layer', ll
    l_vis_unit_all[ll] = {}
    for idx in netidxs:
        l_vis_unit_all[ll][idx] = [stacked_vis_for_unit(net_paths[idx], ll, uu) for uu in range(l_means[ll][idx].shape[0])]


loading vis for layer conv1
loading vis for layer conv2
loading vis for layer conv3
loading vis for layer conv4
loading vis for layer conv5

In [22]:
l_descending_means = {}
for ll in layers:
    l_descending_means[ll] = {}
    for idx in netidxs:
        l_descending_means[ll][idx] = (-l_means[ll][idx]).argsort()

One-to-one alignment between features learned by different neural networks

Alignment via correlation - greedy matching

In greedy matching, for each unit in the first net, we find the unit in the second net with maximum correlation to it, which is the max along each row in the between-net correlation matrix.
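The greedy procedure amounts to a row-wise argmax. As a minimal sketch on a hypothetical toy matrix (plain NumPy, not the notebook's `match_unit` helpers):

```python
import numpy as np

rng = np.random.RandomState(0)
xcor = rng.rand(5, 5)   # toy 5x5 between-net correlation matrix

# Greedy matching: each net-1 unit takes the net-2 unit with the highest
# correlation to it, i.e. the argmax along its row. Nothing prevents two
# net-1 units from claiming the same net-2 unit.
greedy_partner = xcor.argmax(axis=1)
greedy_corr = xcor[np.arange(5), greedy_partner]
print(greedy_partner)
print(greedy_corr)
```

Because the greedy assignment need not be a permutation, a max bipartite matching variant is also compared later in this notebook.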

For all layers visualized, (1) the most correlated filters are near perfect matches, showing that many similar features are learned by independently trained neural networks, and (2) the least correlated features show that many features are learned by one network and are not learned by the other network, at least not by a single neuron in the other network.


In [23]:
def plot_top_bot_matches(layer, xcor, top = True, fileout = None, crop=False):
    """We show for each unit the top 9 image patches that cause the highest activations to it,
       as well as the deconv visualization (Zeiler et al. 2014) associated with each."""
    if crop: 
        figsize(15,3.5)
    else:
        figsize(15,7)
    
        
    fig=gcf()
    
    order = (-xcor.max(1)).argsort()
    N = 8
    for kk in range(N):
        if top:
            unit_ii = order[kk]
        else:
            unit_ii = order[kk-N]
        unit_jj = xcor[unit_ii].argmax()
        subplot(2, N, kk+1)
        if crop:
            imshow(crop_one_patch(l_vis_unit_all[layer][net0][unit_ii]))
        else:
            imshow(l_vis_unit_all[layer][net0][unit_ii])
        axis('off')
        title('%.2f' % xcor[unit_ii,unit_jj])
        
        subplot(2, N, N+kk+1)
        if crop:
            imshow(crop_one_patch(l_vis_unit_all[layer][net1][unit_jj]))
        else:
            imshow(l_vis_unit_all[layer][net1][unit_jj])
        axis('off')
    subplots_adjust(wspace=0.05, hspace=.05)
    if fileout:
        savefig(fileout + '.png',dpi=200)
        savefig(fileout + '.pdf',dpi=200)

In [24]:
# Conv1, best match
plot_top_bot_matches('conv1', l_xcors['conv1'][net0][net1], True)



In [25]:
# Conv1, worst match
plot_top_bot_matches('conv1', l_xcors['conv1'][net0][net1], False)



In [26]:
# Conv2, best match
plot_top_bot_matches('conv2', l_xcors['conv2'][net0][net1], True)



In [27]:
# Conv2, worst match
plot_top_bot_matches('conv2', l_xcors['conv2'][net0][net1], False)



In [28]:
# Conv3, best match
plot_top_bot_matches('conv3', l_xcors['conv3'][net0][net1], True)



In [29]:
# Conv3, worst match
plot_top_bot_matches('conv3', l_xcors['conv3'][net0][net1], False)



In [30]:
# Conv4, best match
plot_top_bot_matches('conv4', l_xcors['conv4'][net0][net1], True)



In [31]:
# Conv4, worst match
plot_top_bot_matches('conv4', l_xcors['conv4'][net0][net1], False)


Alignment via mutual information

Because correlation is a relatively simple mathematical metric that may miss some forms of statistical dependence, we also performed one-to-one alignments of neurons by measuring the mutual information between them. Mutual information measures how much knowledge one gains about one variable by knowing the value of another.
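As a rough sketch of the idea (not the exact estimator behind the precomputed matrices loaded below; `mutual_information` is a hypothetical helper), mutual information between two units can be estimated from a 2-D histogram of their joint activations:

```python
import numpy as np

def mutual_information(x, y, bins=10):
    """Estimate I(X;Y) in nats from a 2-D histogram of the joint density."""
    pxy, _, _ = np.histogram2d(x, y, bins=bins)
    pxy = pxy / pxy.sum()                 # joint distribution p(x, y)
    px = pxy.sum(axis=1, keepdims=True)   # marginal p(x), shape (bins, 1)
    py = pxy.sum(axis=0, keepdims=True)   # marginal p(y), shape (1, bins)
    nz = pxy > 0                          # skip empty bins to avoid log(0)
    return float((pxy[nz] * np.log(pxy[nz] / (px * py)[nz])).sum())

rng = np.random.RandomState(0)
a = rng.randn(10000)
b = a + 0.1 * rng.randn(10000)   # strongly dependent on a
c = rng.randn(10000)             # independent of a
print(mutual_information(a, b), mutual_information(a, c))
```

The dependent pair yields a much larger estimate than the independent pair, which is the property the between-net mutual information matrices exploit.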

We apply the same matching technique (greedy matching) to the between-net mutual information matrix, and compare the highest and lowest mutual information matches to those obtained above via correlation. The results are qualitatively the same. For example, seven out of eight best matched pairs in the conv1 layer stay the same. These results suggest that correlation is an adequate measurement of the similarity between two neurons, and that switching to a mutual information metric would not qualitatively change the correlation-based conclusions presented above.

See Fig S4 in http://arxiv.org/pdf/1511.07543v2.pdf


In [32]:
# Here we only consider the mutual information matrices between net0 and net1, focusing on conv1 and conv2 layers
conv1_mi = np.load(net_paths[0]+"/val_mi_01/"+"conv1_final.pkl")
conv2_mi = np.load(net_paths[0]+"/val_mi_01/"+"conv2_final.pkl")

In [33]:
# Conv1, best match
plot_top_bot_matches('conv1', conv1_mi, True)



In [34]:
# Conv1, worst match
plot_top_bot_matches('conv1', conv1_mi, False)



In [35]:
# Conv2, best match
plot_top_bot_matches('conv2', conv2_mi, True)



In [36]:
# Conv2, worst match
plot_top_bot_matches('conv2', conv2_mi, False)


Greedy matching vs. max matching

Below we show comparisons of assignments produced by the greedy matching and max bipartite matching methods for the conv1-conv5 layers.
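The two strategies can be contrasted on a toy matrix. This sketch uses SciPy's Hungarian-algorithm solver in place of the notebook's cached `max_match_order` (an assumption about the method, not the same code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.RandomState(0)
xcor = rng.rand(6, 6)   # toy between-net correlation matrix

# Greedy: each row takes its own argmax; two rows may claim the same column.
greedy = xcor.argmax(axis=1)

# Max bipartite matching: a permutation maximizing total correlation
# (SciPy minimizes cost, so negate the correlations).
rows, perm = linear_sum_assignment(-xcor)

greedy_avg = xcor[np.arange(6), greedy].mean()
match_avg = xcor[rows, perm].mean()

# Per row the greedy pick is by definition >= the permutation's pick, so the
# greedy average always upper-bounds the max-matching average, as seen in
# the plots below.
print(greedy_avg, match_avg)
```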


In [37]:
def plot_match_vs_max(match, xcor, fileout = None):
    figsize(15,6)
    rcParams.update({'font.size': 18})

    match_vals = xcor[arange(xcor.shape[0]),match]
    diag_order = (-xcor.max(1)).argsort()
    
    hmax,=plot(xcor[diag_order,:].max(1), 'o-', color=(.21/.84,.84/.84,.57/.84), lw=2, ms=10)
    hmatch,=plot(match_vals[diag_order], 'o-', color=(0,.45,.25), lw=2, ms=5)
    print "Avg. correlation (greedy matching)",xcor[diag_order,:].max(1).mean()
    print "Avg. correlation (max matching)",match_vals[diag_order].mean()
    count = 0.0
    for i in range(len(match_vals[diag_order])):
        if xcor[diag_order,:].max(1)[i] == match_vals[diag_order][i]:
            count += 1
    print "Overlapping ratio (when greedy matching and max matching return the same result):",count / len(match_vals[diag_order])
    xlabel('unit index (sorted by correlation of greedy max assignment)')
    ylabel('correlation with assigned unit')
    #_=title('Match correlation high to low')
    
    axis('tight')
    ax = looser(axis(), .015, .03)
    axis(ax[0:2] + (min(ax[2],0),) + ax[3:4])
    #legend(('Net1', 'Net2'), position='best')
    legend((hmax,hmatch), ('greedy match assignment', 'max match assignment'), loc=3)
    if fileout:
        savefig(fileout + '.pdf')
        savefig(fileout + '.png')

In [38]:
print 'net0/net1 are', net0, net1
l_matches = {}
for ll in ['conv1','conv2','conv3','conv4','conv5']:
    # Caching will be used when computing the max bipartite matchings.
    l_matches[ll] = max_match_order(l_xcors[ll][net0][net1], ignore_diag=False)


net0/net1 are 0 1
After   0 objects hashed, hash is da39
    After   1 objects hashed, hash is 71f3 (latest <type 'numpy.ndarray'>)
After   2 objects hashed, hash is 71f3 (latest <type 'tuple'>)
    After   3 objects hashed, hash is 7161 (latest <type 'bool'>)
After   4 objects hashed, hash is 7161 (latest <type 'dict'>)
 -> cache.py: max_match_order: trying to load file /home/yli24/.pycache/71/7161993862e05054.max_match_order.pkl.gz
 -> cache.py: max_match_order: cache hit (0.0004s hash overhead, 0.0013s to load, saved 11.0964s)
   -> loaded /home/yli24/.pycache/71/7161993862e05054.max_match_order.pkl.gz

In [39]:
# Conv1 (see Fig 3 in paper)
plot_match_vs_max(l_matches['conv1'], l_xcors['conv1'][net0][net1])


Avg. correlation (greedy matching) 0.70299
Avg. correlation (max matching) 0.662868
Overlapping ratio (when greedy matching and max matching return the same result): 0.729166666667

In [40]:
plot_match_vs_max(l_matches['conv2'], l_xcors['conv2'][net0][net1])


Avg. correlation (greedy matching) 0.55595
Avg. correlation (max matching) 0.50739
Overlapping ratio (when greedy matching and max matching return the same result): 0.578125

In [41]:
plot_match_vs_max(l_matches['conv3'], l_xcors['conv3'][net0][net1])


Avg. correlation (greedy matching) 0.477509
Avg. correlation (max matching) 0.445778
Overlapping ratio (when greedy matching and max matching return the same result): 0.6171875

In [42]:
plot_match_vs_max(l_matches['conv4'], l_xcors['conv4'][net0][net1])


Avg. correlation (greedy matching) 0.368536
Avg. correlation (max matching) 0.341896
Overlapping ratio (when greedy matching and max matching return the same result): 0.580729166667

In [43]:
plot_match_vs_max(l_matches['conv5'], l_xcors['conv5'][net0][net1])


Avg. correlation (greedy matching) 0.41557
Avg. correlation (max matching) 0.38489
Overlapping ratio (when greedy matching and max matching return the same result): 0.64453125

The most and least active units

The most active and least active conv1 filters for Net1 – Net4, with average activation values printed above each filter. The most active filters generally respond to low spatial frequencies, and the least active filters to high spatial frequencies, but the lack of alignment is interesting. (see Fig S.11 in http://arxiv.org/pdf/1511.07543v2.pdf)


In [44]:
def plot_top_bot_few(layer, top = True, fileout=None):
    figsize(20,8)
    fig=gcf()
    #fig.set_facecolor((.5,.5,.5))
    N = 8
    for ii,idx in enumerate(netidxs):
        for jj in range(N):
            subplot(len(netidxs),N,ii*N+jj+1)
            if top:
                unit_idx = l_descending_means[layer][idx][jj]
            else:
                unit_idx = l_descending_means[layer][idx][jj-N]
            imshow(crop_one_patch(l_vis_unit_all[layer][idx][unit_idx]))
            axis('off')
            #title('%g' % int(round(l_means[layer][idx][unit_idx])), fontsize=20)
            title('%.1f' % l_means[layer][idx][unit_idx], fontsize=20)
    subplots_adjust(wspace=0, hspace=.25)
    if fileout:
        savefig(fileout + '.png')
        savefig(fileout + '.pdf')

In [45]:
# Conv1, most active filters
plot_top_bot_few('conv1', True)



In [46]:
# Conv1, least active filters
plot_top_bot_few('conv1', False)



In [47]:
# Conv2, most active filters
plot_top_bot_few('conv2', True)



In [48]:
# Conv2, least active filters
plot_top_bot_few('conv2', False)


Find sparse, few-to-one mappings between features learned by different neural networks

We can relax this one-to-one constraint to various degrees by learning a mapping layer with an L1 penalty (known as a LASSO model; Tibshirani, 1996), where stronger penalties lead to sparser (more few-to-one or one-to-one) mappings.

More specifically, to predict one layer’s representation from another, we learn a single mapping layer from one to the other (similar to the “stitching layer” in Lenc & Vedaldi (2015)). In the case of convolutional layers, this mapping layer is a convolutional layer with 1 × 1 kernel size and number of output channels equal to the number of input channels. The mapping layer’s parameters can be considered as a square weight matrix with side length equal to the number of units in the layer; the layer learns to predict any unit in one network via a linear weighted sum of any number of units in the other.
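On synthetic data, this per-unit LASSO view of the mapping layer can be sketched as follows (hypothetical activations and penalty strength; the notebook's actual models were trained in Caffe and are loaded below):

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
n_samples, n_channels = 500, 8
acts_a = rng.randn(n_samples, n_channels)                   # net A activations (toy)
acts_b = acts_a + 0.01 * rng.randn(n_samples, n_channels)   # net B = near copy of net A

# A 1x1 conv stitching layer is, per spatial location, a linear map on
# channels, so it can be fit as one L1-penalized regression per target unit.
# Rows of `weights` play the role of the square mapping matrix below.
weights = np.zeros((n_channels, n_channels))
for jj in range(n_channels):
    lasso = Lasso(alpha=0.05)
    lasso.fit(acts_a, acts_b[:, jj])
    weights[jj] = lasso.coef_
```

With a strong enough L1 penalty, the off-diagonal entries are driven to (near) zero and the recovered mapping is sparse and roughly one-to-one, which is exactly what the sparse prediction matrices visualized below are probing.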

Load the trained sparse prediction ("stitching") models


In [49]:
def read_losses(filename):
    ret = {
        'train_idx': [],
        'train_loss': [],
        'val_idx': [],
        'val_loss': [],
    }
    with open(filename, 'r') as ff:
        for line in ff:
            # Looking for lines like this:
            # 0 val loss is 1.01618254185
            # 0 train loss is 1.01618254185
            fields = line.split()
            if len(fields) != 5:
                continue
            if (fields[2],fields[3]) == ('loss','is') and fields[1] in ('train', 'val'):
                idx = int(fields[0])
                loss = float(fields[4])
                if fields[1] == 'train':
                    ret['train_idx'].append(idx)
                    ret['train_loss'].append(loss)
                else:
                    ret['val_idx'].append(idx)
                    ret['val_loss'].append(loss)
    ret['train_idx'] = array(ret['train_idx'])
    ret['train_loss'] = array(ret['train_loss'])
    ret['val_idx'] = array(ret['val_idx'])
    ret['val_loss'] = array(ret['val_loss'])
    print "sparse prediction model validation loss:", ret['val_loss'][-1]
    return ret
    
def load_stitch_net(neta=0, netb=1, layer='conv1', L1=0, iter=5000, with_loss=False):
    stitch_dir = ('../sparse_prediction/net%d_net%d_L1_%g_%s/'
                  % (neta, netb, L1, layer))
    stitch_proto = 'proto/stitch_%s.prototxt' % layer
    stitch_weights = stitch_dir + 'stitch_iter_%d.caffemodel' % iter
    loss_file = stitch_dir + 'log.log'
    snet = np.load(stitch_dir + "stitching.pkl")
    print 'Loaded net: %s' % stitch_weights
    if with_loss:
        losses = read_losses(loss_file)
        return snet, losses
    else:
        return snet

In [50]:
neta = 0
netb = 1
layer = 'conv1'
L1 = -2.6
#layer = 'conv2'
#L1 = -2.8
trainiter = 4000
sparse_mat,losses = load_stitch_net(neta, netb, layer, L1, iter=trainiter, with_loss=True)
sparse_mat = abs(sparse_mat)


Loaded net: ../sparse_prediction/net0_net1_L1_-2.6_conv1/stitch_iter_4000.caffemodel
sparse prediction model validation loss: 0.234768867493

In [51]:
# Plot the sparse prediction matrix for conv1 layer (net0 and net1)
figsize(18,8)
subplot(1,2,1)
axhline(43,color='orange')
showimagesc(sparse_mat)
subplot(1,2,2)
showimagesc(permute_cols(sparse_mat, max_match_order(sparse_mat)))


After   0 objects hashed, hash is da39
    After   1 objects hashed, hash is 362f (latest <type 'numpy.ndarray'>)
After   2 objects hashed, hash is 362f (latest <type 'tuple'>)
After   3 objects hashed, hash is cf96 (latest <type 'dict'>)
 -> cache.py: max_match_order: trying to load file /home/yli24/.pycache/cf/cf961343fe74e176.max_match_order.pkl.gz
 -> cache.py: max_match_order: cache miss, computing function
 -> cache.py: max_match_order: function execution finished, saving result to file /home/yli24/.pycache/cf/cf961343fe74e176.max_match_order.pkl.gz
 -> cache.py: max_match_order: cache miss (0.0003s hash overhead, 0.0035s to save, 12.5905s to compute)
   -> saved to /home/yli24/.pycache/cf/cf961343fe74e176.max_match_order.pkl.gz

Look at filters


In [52]:
def get_unit_vis(net, layer, unit, which='deconv'):
    if which == 'all':
        return l_vis_unit_all[layer][net][unit]
    else:
        return crop_one_patch(l_vis_unit_all[layer][net][unit], which)

In [53]:
# Which unit in the second network is being predicted? (i.e., which row of sparse_mat to inspect)
iib = 43
figsize(18,6)
plot(sparse_mat[iib,:])
thresh = np.percentile(sparse_mat.flatten(), 50) * 100
jjbig = np.where(sparse_mat[iib,:] > thresh)
jjbig = jjbig[0]    # Just take first dimension indices
figsize(18,6)
for cc,jj in enumerate(jjbig):
    plot(jj,sparse_mat[iib,jj],'ro', ms=10)
    print sparse_mat[iib,jj]
    text(jj,sparse_mat[iib,jj],'  %d' % jj, fontsize=14)
_=title('net %d, unit %d' % (netb, iib))


0.00844492
0.352437
0.265495

In [54]:
showimagesc(get_unit_vis(netb, layer, iib, 'all'))
axis("off")


Out[54]:
(-0.5, 38.5, 77.5, -0.5)

In [55]:
N = len(jjbig)
print jjbig
for cc,jj in enumerate(jjbig):
    subplot(1,N,cc+1)
    axis("off")
    showimagesc(get_unit_vis(neta, layer, jj, 'all'))


[28 33 51]

Find many-to-many mappings using Hierarchical Agglomerative Clustering (HAC)
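The construction used in the next cells (a joint block matrix of within- and between-net similarities, converted to distances and clustered) can be sketched on a toy example. This sketch uses SciPy's hierarchical clustering rather than the sklearn class used below, as an illustrative substitute:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

n = 4
aw = np.eye(n)[::-1]   # toy cross-net similarity: unit i in net A <-> unit n-1-i in net B
# Joint block matrix: identity within each net, cross-net weights between nets.
mat = np.vstack((np.hstack((np.eye(n), aw)),
                 np.hstack((aw.T, np.eye(n)))))

# Similarity -> distance, then average-linkage agglomerative clustering.
dists = mat.max() - mat
Z = linkage(squareform(dists, checks=False), method='average')
labels = fcluster(Z, t=n, criterion='maxclust')
print(labels)   # each cross-net matched pair lands in its own cluster
```

Merging continues up the tree, so cutting at coarser levels yields the many-to-many groupings explored below.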


In [56]:
aw = sparse_mat

figsize(18,12)
subplot(1,2,1)
n_channels = aw.shape[0]
mat = vstack((hstack((eye(n_channels),aw)), hstack((aw.T,eye(n_channels)))))
showimagesc(mat)
subplot(1,2,2)
unit_dists = mat.max() - mat

showimagesc(unit_dists)



In [90]:
np.random.seed(0)
clust = AgglomerativeClustering(n_clusters=2, linkage='average', affinity='precomputed')
clust.fit(unit_dists)


Out[90]:
AgglomerativeClustering(affinity='precomputed', compute_full_tree='auto',
            connectivity=None, linkage='average',
            memory=Memory(cachedir=None), n_clusters=2, n_components=None,
            pooling_func=<function mean at 0x26ad668>)

In [91]:
import networkx as nx  # the tree helpers below use networkx

def children_to_tree(children):
    N = len(children) + 1
    ret = nx.DiGraph()
    ret.add_nodes_from(range(N))
    for cc,node in enumerate(children):
        ll,rr = node
        ret.add_edge(ll,cc+N)
        ret.add_edge(rr,cc+N)
    return ret

In [92]:
unittree = children_to_tree(clust.children_)
unittree_rev = unittree.reverse()
root = [n for n,d in unittree_rev.in_degree().items() if d==0][0]
ordering = [nn for nn in nx.dfs_preorder_nodes(unittree_rev, root) if nn < n_channels*2]
print ordering


[96, 80, 154, 67, 9, 51, 181, 73, 166, 187, 126, 62, 1, 74, 139, 116, 29, 54, 17, 103, 0, 125, 148, 70, 41, 191, 89, 119, 10, 179, 6, 153, 19, 39, 124, 84, 151, 11, 149, 185, 93, 140, 37, 26, 99, 167, 105, 52, 90, 118, 91, 88, 183, 108, 53, 60, 173, 190, 8, 161, 5, 49, 142, 107, 4, 145, 20, 168, 36, 158, 160, 76, 137, 83, 143, 7, 68, 141, 157, 33, 102, 13, 32, 122, 48, 130, 66, 25, 156, 129, 43, 58, 147, 23, 162, 2, 121, 34, 100, 94, 135, 128, 46, 65, 42, 175, 16, 189, 106, 40, 24, 165, 28, 113, 146, 35, 110, 47, 163, 44, 72, 133, 120, 57, 152, 50, 59, 171, 180, 77, 155, 15, 104, 12, 170, 86, 30, 56, 184, 178, 71, 176, 45, 159, 75, 182, 131, 92, 134, 150, 174, 22, 136, 55, 132, 87, 21, 117, 82, 138, 27, 109, 123, 95, 169, 85, 101, 69, 64, 164, 127, 31, 3, 186, 78, 188, 38, 81, 97, 144, 79, 61, 111, 112, 18, 115, 98, 172, 177, 63, 14, 114]

In [93]:
ordering_inv = {}
for pos,uu in enumerate(ordering):
    ordering_inv[uu] = pos
print ordering_inv


{0: 20, 1: 12, 2: 95, 3: 172, 4: 64, 5: 60, 6: 30, 7: 75, 8: 58, 9: 4, 10: 28, 11: 37, 12: 133, 13: 81, 14: 190, 15: 131, 16: 106, 17: 18, 18: 184, 19: 32, 20: 66, 21: 156, 22: 151, 23: 93, 24: 110, 25: 87, 26: 43, 27: 160, 28: 112, 29: 16, 30: 136, 31: 171, 32: 82, 33: 79, 34: 97, 35: 115, 36: 68, 37: 42, 38: 176, 39: 33, 40: 109, 41: 24, 42: 104, 43: 90, 44: 119, 45: 142, 46: 102, 47: 117, 48: 84, 49: 61, 50: 125, 51: 5, 52: 47, 53: 54, 54: 17, 55: 153, 56: 137, 57: 123, 58: 91, 59: 126, 60: 55, 61: 181, 62: 11, 63: 189, 64: 168, 65: 103, 66: 86, 67: 3, 68: 76, 69: 167, 70: 23, 71: 140, 72: 120, 73: 7, 74: 13, 75: 144, 76: 71, 77: 129, 78: 174, 79: 180, 80: 1, 81: 177, 82: 158, 83: 73, 84: 35, 85: 165, 86: 135, 87: 155, 88: 51, 89: 26, 90: 48, 91: 50, 92: 147, 93: 40, 94: 99, 95: 163, 96: 0, 97: 178, 98: 186, 99: 44, 100: 98, 101: 166, 102: 80, 103: 19, 104: 132, 105: 46, 106: 108, 107: 63, 108: 53, 109: 161, 110: 116, 111: 182, 112: 183, 113: 113, 114: 191, 115: 185, 116: 15, 117: 157, 118: 49, 119: 27, 120: 122, 121: 96, 122: 83, 123: 162, 124: 34, 125: 21, 126: 10, 127: 170, 128: 101, 129: 89, 130: 85, 131: 146, 132: 154, 133: 121, 134: 148, 135: 100, 136: 152, 137: 72, 138: 159, 139: 14, 140: 41, 141: 77, 142: 62, 143: 74, 144: 179, 145: 65, 146: 114, 147: 92, 148: 22, 149: 38, 150: 149, 151: 36, 152: 124, 153: 31, 154: 2, 155: 130, 156: 88, 157: 78, 158: 69, 159: 143, 160: 70, 161: 59, 162: 94, 163: 118, 164: 169, 165: 111, 166: 8, 167: 45, 168: 67, 169: 164, 170: 134, 171: 127, 172: 187, 173: 56, 174: 150, 175: 105, 176: 141, 177: 188, 178: 139, 179: 29, 180: 128, 181: 6, 182: 145, 183: 52, 184: 138, 185: 39, 186: 173, 187: 9, 188: 175, 189: 107, 190: 57, 191: 25}

In [94]:
node_row = {}
for ii in range(n_channels*2):
    node_row[ii] = ordering_inv[ii]
for nn in nx.dfs_postorder_nodes(unittree_rev, root):
    if not nn in node_row:
        node_row[nn] = min([node_row[n] for n in unittree_rev.successors(nn)])

In [95]:
leaves_under = {}
for ii in range(n_channels*2):
    leaves_under[ii] = 1
for nn in nx.dfs_postorder_nodes(unittree_rev, root):
    if not nn in leaves_under:
        leaves_under[nn] = sum([leaves_under[n] for n in unittree_rev.successors(nn)])

In [103]:
def plot_tree_mat(mat,ordering,unittree_rev,root,ax=None,savename=None):
    figsize(18,18)
    epsw = .5
    epsi = -.2
    epsj = -.08
    permmat = permute_matrix(mat, ordering) ** .5
    permmat = tile(permmat.reshape((permmat.shape) + (1,)), ((1,1,3)))
    for ii in range(permmat.shape[0]):
        clr = (1,0,0) if (ordering[ii] >= permmat.shape[0]//2) else (0,1,0)
        permmat[ii,ii,:] = clr
    showimagesc(permmat)
    for nn in nx.dfs_postorder_nodes(unittree_rev, root):
        if nn >= n_channels*2:
            ii = node_row[nn]
            jj = ii + leaves_under[nn] - 1
            #text(jj,ii,'  %d,%d' % (ii,jj), color=(.5,1,.5))
            plot([ii-epsw+epsj,jj+epsw+epsj,jj+epsw+epsj], [ii-epsw+epsi,ii-epsw+epsi,jj+epsw+epsi], '-', color="royalblue")
    _=axis('tight')
    if ax:
        # ax is an (xmin, xmax, ymin, ymax) tuple; don't overwrite it with gca()
        cur_ax = gca()
        cur_ax.set_xlim(ax[0:2])
        cur_ax.set_ylim(ax[2:4])
    if savename:
        savefig(savename)

In [104]:
plot_tree_mat(mat,ordering,unittree_rev,root)


Look at the clusters


In [98]:
def vis_block_unit(indices, savename=None):
    # Map positions in the HAC ordering back to unit indices; indices below
    # n_channels are rows of the joint matrix (netb units), the rest are
    # columns (neta units, offset by n_channels).
    unit_indices = [ordering[ii] for ii in indices]
    print unit_indices
    N = len(indices)
    print indices
    for cc,jj in enumerate(unit_indices):
        subplot(1,N,cc+1)
        axis("off")
        if jj >= n_channels:
            title("net2",fontsize=18,color='r')
            showimagesc(get_unit_vis(neta, layer, jj-n_channels, 'all'))
        else:
            title("net1",fontsize=18,color='g')
            showimagesc(get_unit_vis(netb, layer, jj, 'all'))
    if savename:
        savefig(savename)

In [99]:
zoomin_indices = np.arange(51,59)
vis_block_unit(zoomin_indices)


[88, 183, 108, 53, 60, 173, 190, 8]
[51 52 53 54 55 56 57 58]

In [100]:
zoomin_indices = np.arange(148,156)
vis_block_unit(zoomin_indices)


[134, 150, 174, 22, 136, 55, 132, 87]
[148 149 150 151 152 153 154 155]

In [ ]: