This IPython Notebook contains source code necessary to reproduce the results presented in the following paper:
@inproceedings{li_2016_ICLR,
title={Convergent Learning: Do different neural networks learn the same representations?},
author={Li, Yixuan and Yosinski, Jason and Clune, Jeff and Lipson, Hod and Hopcroft, John},
booktitle={International Conference on Learning Representations (ICLR '16)},
year={2016}
}
arXiv link: http://arxiv.org/abs/1511.07543
In [2]:
from load_data import *
from plotting import *
from pylab import *
from match_unit import *
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
In [57]:
from sklearn.cluster import AgglomerativeClustering
In [3]:
# URL to download
# TO BE ADDED
In [4]:
# Directory names for the four independently trained nets (no trailing slash)
net_dirs = [
'net0',
'net1',
'net2',
'net3',
]
net_paths = ['../data/' + dd for dd in net_dirs]
In [5]:
# All the layers in AlexNet
layers=['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']
# Load activation mean values, covariance matrices and self-correlation matrices for all the nets, all layers
l_means, l_outers, l_cors = dict(), dict(), dict()
for layer in layers:
means,outers,cors = [], [], []
for net_path in net_paths:
moc = read_val_single_moc(net_path, layer)
means.append(moc[0])
outers.append(moc[1])
cors.append(moc[2])
l_means[layer] = means
l_outers[layer] = outers
l_cors[layer] = cors
In [6]:
# Load cross-correlation matrices for all pair of nets, all layers
l_xcors = {}
for layer in layers:
l_xcors[layer] = {}
for net0 in range(len(net_dirs)-1):
l_xcors[layer][net0] = {}
for net1 in range(net0+1, len(net_dirs)):
# To make life simpler, symlinks are all in first net_path directory!
combo = '%d%d' % (net0, net1)
l_xcors[layer][net0][net1] = read_val_double_cors(net_paths[0], layer, combo)
#print 'l_xcors[%s][%d][%d] = %s' % (layer, net0, net1, l_xcors[layer][net0][net1].shape)
Below we show the mean activations for each unit of the four networks, plotted in sorted order from highest to lowest. First and most saliently, we see a pattern of widely varying mean activation values across units, with a gap of one to two orders of magnitude (depending on the layer) between the most active and least active units. Second, we observe a rough overall correspondence in the spectrum of activations across the networks. However, the correspondence is not perfect: although much of the spectrum matches well, the most active filters converged to solutions of somewhat different magnitudes. For example, the average activation value of the conv2 filter with the highest average activation varies from 49 to 120 across the four networks; the corresponding range for conv1 is 98 to 130. This effect is all the more interesting considering that all filters were learned with the same constant weight decay, which pushes all individual filter weights and biases (and thus subsequent activations) toward zero with the same force.
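As a quick numeric check of this spread, the minimal sketch below (using only the l_means dictionary loaded above) prints the largest and smallest mean activation per net for the first two convolutional layers:

# Minimal sketch: quantify the gap between the most and least active units.
for layer_name in ['conv1', 'conv2']:
    for idx in range(len(net_paths)):
        vals = np.sort(l_means[layer_name][idx])
        print '%s net%d: max / min mean activation = %.1f / %.3f (ratio ~%.0fx)' % (
            layer_name, idx, vals[-1], vals[0], vals[-1] / max(vals[0], 1e-6))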
In [7]:
netidxs = range(4)
In [8]:
def plot_means(layer, do_legend = True, light = False, fileout=None):
figsize(12,6)
for ii,idx in enumerate(netidxs):
rr = ii/(len(netidxs)-1+1e-6)
clr = (1, .8*rr, 0)
dat = -np.sort(-l_means[layer][idx])
if light:
rcParams.update({'font.size': 22})
plot(dat, '-', color=clr, lw=3)
plot(0, dat[0], 'o', color=clr, ms=12)
plot(len(dat), dat[-1], 'o', color=clr, ms=12)
else:
rcParams.update({'font.size': 16})
plot(dat, 'o-', color=clr, lw=2, ms=9)
axis('tight')
ax = looser(axis(), .01, .02)
axis(ax[0:2] + (min(ax[2],0),) + ax[3:4])
if do_legend:
legend(('Net1', 'Net2', 'Net3', 'Net4'))
if not light:
xlabel('Channel number (sorted)')
ylabel('Mean activation')
title(layer)
if fileout:
savefig(fileout + '.pdf')
savefig(fileout + '.png')
In [9]:
# Conv1 layer
plot_means('conv1', True, False)
In [10]:
# Conv2 layer
plot_means('conv2', True, False)
In [11]:
# Conv3 layer
plot_means('conv3', True, False)
In [12]:
# Conv4 layer
plot_means('conv4', True, False)
In [13]:
# Conv5 layer
plot_means('conv5', True, False)
In [14]:
# Fc6 layer
plot_means('fc6', True, False)
In [15]:
# Fc7 layer
plot_means('fc7', True, False)
In [16]:
net0 = 0
net1 = 1
layer = 'conv1'
means, outers, cors = l_means[layer], l_outers[layer], l_cors[layer]
In [17]:
# Plot the within-net correlation matrices of net0 and net1 (conv1 layer).
# See Fig 1 (a)(b) in http://arxiv.org/pdf/1511.07543v2.pdf
figsize(16,8)
subplot(1,2,1)
showimagesc(cors[net0])
subplot(1,2,2)
showimagesc(cors[net1])
In [18]:
# Find max bipartite matching of between-net correlation matrix
match_no_diag = max_match_order(l_xcors["conv1"][0][1])
loop_order, loop_len = follow_loops(match_no_diag)
In [19]:
# Between-net correlation for Net1 vs. Net2.
figsize(16,8)
subplot(1,2,1)
showimagesc(l_xcors["conv1"][0][1])
subplot(1,2,2)
showimagesc(permute_matrix(l_xcors["conv1"][0][1],loop_order))
In [20]:
vis_layer_net0 = vis_for_layer(net_paths[net0], layer)
vis_layer_net1 = vis_for_layer(net_paths[net1], layer)
In [21]:
# Load unit visualization
# The loading process might take a few minutes...
l_vis_unit_all = {}
for ll in ['conv1', 'conv2', 'conv3','conv4','conv5']:
print 'loading vis for layer', ll
l_vis_unit_all[ll] = {}
for idx in netidxs:
l_vis_unit_all[ll][idx] = [stacked_vis_for_unit(net_paths[idx], ll, uu) for uu in range(l_means[ll][idx].shape[0])]
In [22]:
l_descending_means = {}
for ll in layers:
l_descending_means[ll] = {}
for idx in netidxs:
l_descending_means[ll][idx] = (-l_means[ll][idx]).argsort()
In greedy matching, for each unit in the first net we find the unit in the second net that is most correlated with it, i.e., we take the maximum along each row of the between-net correlation matrix (sketched in the short snippet below).
For all layers visualized, (1) the most correlated filters are near-perfect matches, showing that many similar features are learned by independently trained neural networks, and (2) the least correlated filters show that many features are learned by one network but not by the other, at least not by any single neuron in the other network.
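As a minimal sketch of the greedy step itself (using the l_xcors matrices and the net0/net1 indices defined above), the assignment and its correlation values are simply a row-wise argmax/max:

xcor = l_xcors['conv1'][net0][net1]
greedy_partner = xcor.argmax(1)   # for each net0 unit, the index of its most correlated net1 unit
greedy_corr = xcor.max(1)         # the correlation of that best match
print 'conv1: unit 0 of net0 best matches unit %d of net1 (r = %.2f)' % (greedy_partner[0], greedy_corr[0])
print 'conv1: mean greedy-match correlation = %.3f' % greedy_corr.mean()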
In [23]:
def plot_top_bot_matches(layer, xcor, top = True, fileout = None, crop=False):
"""We show for each unit the top 9 image patches that cause the highest activations to it,
as well as the deconv visualization (Zeiler et al. 2014) associated with each."""
if crop:
figsize(15,3.5)
else:
figsize(15,7)
fig=gcf()
order = (-xcor.max(1)).argsort()
N = 8
for kk in range(N):
if top:
unit_ii = order[kk]
else:
unit_ii = order[kk-N]
unit_jj = xcor[unit_ii].argmax()
subplot(2, N, kk+1)
if crop:
imshow(crop_one_patch(l_vis_unit_all[layer][net0][unit_ii]))
else:
imshow(l_vis_unit_all[layer][net0][unit_ii])
axis('off')
title('%.2f' % xcor[unit_ii,unit_jj])
subplot(2, N, N+kk+1)
if crop:
imshow(crop_one_patch(l_vis_unit_all[layer][net1][unit_jj]))
else:
imshow(l_vis_unit_all[layer][net1][unit_jj])
axis('off')
subplots_adjust(wspace=0.05, hspace=.05)
if fileout:
savefig(fileout + '.png',dpi=200)
savefig(fileout + '.pdf',dpi=200)
In [24]:
# Conv1, best match
plot_top_bot_matches('conv1', l_xcors['conv1'][net0][net1], True)
In [25]:
# Conv1, worst match
plot_top_bot_matches('conv1', l_xcors['conv1'][net0][net1], False)
In [26]:
# Conv2, best match
plot_top_bot_matches('conv2', l_xcors['conv2'][net0][net1], True)
In [27]:
# Conv2, worst match
plot_top_bot_matches('conv2', l_xcors['conv2'][net0][net1], False)
In [28]:
# Conv3, best match
plot_top_bot_matches('conv3', l_xcors['conv3'][net0][net1], True)
In [29]:
# Conv3, worst match
plot_top_bot_matches('conv3', l_xcors['conv3'][net0][net1], False)
In [30]:
# Conv4, best match
plot_top_bot_matches('conv4', l_xcors['conv4'][net0][net1], True)
In [31]:
# Conv4, worst match
plot_top_bot_matches('conv4', l_xcors['conv4'][net0][net1], False)
Because correlation is a relatively simple mathematical metric that may miss some forms of statistical dependence, we also performed one-to-one alignments of neurons by measuring the mutual information between them. Mutual information measures how much knowledge one gains about one variable by knowing the value of another.
We apply the same matching technique (greedy matching) to the between-net mutual information matrix, and compare the highest and lowest mutual information matches to those obtained above via correlation. The results are qualitatively the same. For example, seven out of eight best matched pairs in the conv1 layer stay the same. These results suggest that correlation is an adequate measurement of the similarity between two neurons, and that switching to a mutual information metric would not qualitatively change the correlation-based conclusions presented above.
See Fig S4 in http://arxiv.org/pdf/1511.07543v2.pdf
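The between-net mutual information matrices loaded below were precomputed offline. As an illustrative sketch only (not the estimator used to build those cached matrices), the mutual information between two units' activation vectors can be estimated from a 2D histogram of their joint values:

def mutual_information_hist(x, y, bins=20):
    """Histogram-based estimate (in nats) of the mutual information between
    two 1-D activation vectors. Illustrative sketch only."""
    joint, _, _ = np.histogram2d(x, y, bins=bins)
    pxy = joint / joint.sum()
    px = pxy.sum(axis=1, keepdims=True)   # marginal distribution of x
    py = pxy.sum(axis=0, keepdims=True)   # marginal distribution of y
    nz = pxy > 0                          # skip empty bins to avoid log(0)
    return float((pxy[nz] * np.log(pxy[nz] / (px * py)[nz])).sum())

# Toy check on synthetic data (the raw per-image activations are not loaded in this notebook):
rng = np.random.RandomState(0)
a = rng.randn(10000)
b = a + 0.5 * rng.randn(10000)
print 'MI estimate, dependent toy data:   %.2f nats' % mutual_information_hist(a, b)
print 'MI estimate, independent toy data: %.2f nats' % mutual_information_hist(a, rng.randn(10000))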
In [32]:
# Here we only consider the mutual information matrices between net0 and net1, focusing on conv1 and conv2 layers
conv1_mi = np.load(net_paths[0]+"/val_mi_01/"+"conv1_final.pkl")
conv2_mi = np.load(net_paths[0]+"/val_mi_01/"+"conv2_final.pkl")
In [33]:
# Conv1, best match
plot_top_bot_matches('conv1', conv1_mi, True)
In [34]:
# Conv1, worst match
plot_top_bot_matches('conv1', conv1_mi, False)
In [35]:
# Conv2, best match
plot_top_bot_matches('conv2', conv2_mi, True)
In [36]:
# Conv2, worst match
plot_top_bot_matches('conv2', conv2_mi, False)
In [37]:
def plot_match_vs_max(match, xcor, fileout = None):
figsize(15,6)
rcParams.update({'font.size': 18})
match_vals = xcor[arange(xcor.shape[0]),match]
diag_order = (-xcor.max(1)).argsort()
hmax,=plot(xcor[diag_order,:].max(1), 'o-', color=(.21/.84,.84/.84,.57/.84), lw=2, ms=10)
hmatch,=plot(match_vals[diag_order], 'o-', color=(0,.45,.25), lw=2, ms=5)
print "Avg. correlation (greedy matching)",xcor[diag_order,:].max(1).mean()
print "Avg. correlation (max matching)",match_vals[diag_order].mean()
count = 0.0
for i in range(len(match_vals[diag_order])):
if xcor[diag_order,:].max(1)[i] == match_vals[diag_order][i]:
count += 1
print "Overlapping ratio (when greedy matching and max matching returns the same result):",count / len(match_vals[diag_order])
xlabel('unit index (sorted by correlation of greedy max assignment)')
ylabel('correlation with assigned unit')
#_=title('Match correlation high to low')
axis('tight')
ax = looser(axis(), .015, .03)
axis(ax[0:2] + (min(ax[2],0),) + ax[3:4])
#legend(('Net1', 'Net2'), position='best')
legend((hmax,hmatch), ('greedy match assignment', 'max match assignment'), loc=3)
if fileout:
savefig(fileout + '.pdf')
savefig(fileout + '.png')
In [38]:
print 'net0/net1 are', net0, net1
l_matches = {}
for ll in ['conv1','conv2','conv3','conv4','conv5']:
# Caching will be used when computing the max bipartite matchings.
l_matches[ll] = max_match_order(l_xcors[ll][net0][net1], ignore_diag=False)
In [39]:
# Conv1 (see Fig 3 in paper)
plot_match_vs_max(l_matches['conv1'], l_xcors['conv1'][net0][net1])
In [40]:
plot_match_vs_max(l_matches['conv2'], l_xcors['conv2'][net0][net1])
In [41]:
plot_match_vs_max(l_matches['conv3'], l_xcors['conv3'][net0][net1])
In [42]:
plot_match_vs_max(l_matches['conv4'], l_xcors['conv4'][net0][net1])
In [43]:
plot_match_vs_max(l_matches['conv5'], l_xcors['conv5'][net0][net1])
The most active and least active conv1 filters for Net1 – Net4, with average activation values printed above each filter. The most active filters generally respond to low spatial frequencies and the least active filters to high spatial frequencies, but the lack of alignment across the four networks is interesting. (See Fig S.11 in http://arxiv.org/pdf/1511.07543v2.pdf)
In [44]:
def plot_top_bot_few(layer, top = True, fileout=None):
figsize(20,8)
fig=gcf()
#fig.set_facecolor((.5,.5,.5))
N = 8
for ii,idx in enumerate(netidxs):
for jj in range(N):
subplot(len(netidxs),N,ii*N+jj+1)
if top:
unit_idx = l_descending_means[layer][idx][jj]
else:
unit_idx = l_descending_means[layer][idx][jj-N]
imshow(crop_one_patch(l_vis_unit_all[layer][idx][unit_idx]))
axis('off')
#title('%g' % int(round(l_means[layer][idx][unit_idx])), fontsize=20)
title('%.1f' % l_means[layer][idx][unit_idx], fontsize=20)
subplots_adjust(wspace=0, hspace=.25)
if fileout:
savefig(fileout + '.png')
savefig(fileout + '.pdf')
In [45]:
# Conv1, most active filters
plot_top_bot_few('conv1', True)
In [46]:
# Conv1, least active filters
plot_top_bot_few('conv1', False)
In [47]:
# Conv2, most active filters
plot_top_bot_few('conv2', True)
In [48]:
# Conv2, least active filters
plot_top_bot_few('conv2', False)
We can relax this one-to-one constraint to various degrees by learning a mapping layer with an L1 penalty (known as a LASSO model; Tibshirani, 1996), where stronger penalties lead to sparser (more few-to-one or one-to-one) mappings.
More specifically, to predict one layer’s representation from another, we learn a single mapping layer from one to the other (similar to the “stitching layer” in Lenc & Vedaldi (2015)). In the case of convolutional layers, this mapping layer is a convolutional layer with 1 × 1 kernel size and number of output channels equal to the number of input channels. The mapping layer’s parameters can be considered as a square weight matrix with side length equal to the number of units in the layer; the layer learns to predict any unit in one network via a linear weighted sum of any number of units in the other.
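The mapping ("stitching") layers analyzed below were trained offline with Caffe and are loaded from cached files. As a conceptual sketch only, assuming hypothetical activation matrices A and B (images × units, not loaded in this notebook) from the two networks, an equivalent L1-penalized linear mapping could be fit with scikit-learn:

from sklearn.linear_model import Lasso

def fit_sparse_mapping(A, B, alpha=0.01):
    """Fit a sparse linear map W such that A.dot(W.T) approximates B.
    A, B: hypothetical (n_images, n_units) activation matrices from two nets.
    Returns W of shape (n_units_B, n_units_A); row i holds the weights used to
    predict unit i of the second net from all units of the first net."""
    model = Lasso(alpha=alpha, fit_intercept=True, max_iter=5000)
    model.fit(A, B)        # with a 2-D target, one Lasso is fit per column of B
    return model.coef_

Larger values of alpha drive more entries of each row of W to zero, moving the learned mapping from many-to-one toward few-to-one or one-to-one, which is the effect of the L1 penalty in the cached models loaded below.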
In [49]:
def read_losses(filename):
ret = {
'train_idx': [],
'train_loss': [],
'val_idx': [],
'val_loss': [],
}
with open(filename, 'r') as ff:
for line in ff:
# Looking for lines like this:
# 0 val loss is 1.01618254185
# 0 train loss is 1.01618254185
fields = line.split()
if len(fields) != 5:
continue
if (fields[2],fields[3]) == ('loss','is') and fields[1] in ('train', 'val'):
idx = int(fields[0])
loss = float(fields[4])
if fields[1] == 'train':
ret['train_idx'].append(idx)
ret['train_loss'].append(loss)
else:
ret['val_idx'].append(idx)
ret['val_loss'].append(loss)
ret['train_idx'] = array(ret['train_idx'])
ret['train_loss'] = array(ret['train_loss'])
ret['val_idx'] = array(ret['val_idx'])
ret['val_loss'] = array(ret['val_loss'])
print "sparse prediction model training loss:", ret['val_loss'][-1]
return ret
def load_stitch_net(neta=0, netb=1, layer='conv1', L1=0, iter=5000, with_loss=False):
stitch_dir = ('../sparse_prediction/net%d_net%d_L1_%g_%s/'
% (neta, netb, L1, layer))
stitch_proto = 'proto/stitch_%s.prototxt' % layer
stitch_weights = stitch_dir + 'stitch_iter_%d.caffemodel' % iter
loss_file = stitch_dir + 'log.log'
    snet = np.load(stitch_dir + "stitching.pkl")
    print 'Loaded stitch weights: %s' % (stitch_dir + 'stitching.pkl')
if with_loss:
losses = read_losses(loss_file)
return snet, losses
else:
return snet
In [50]:
neta = 0
netb = 1
layer = 'conv1'
L1 = -2.6
#layer = 'conv2'
#L1 = -2.8
trainiter = 4000
sparse_mat,losses = load_stitch_net(neta, netb, layer, L1, iter=trainiter, with_loss=True)
sparse_mat = abs(sparse_mat)
In [51]:
# Plot the sparse prediction matrix for conv1 layer (net0 and net1)
figsize(18,8)
subplot(1,2,1)
axhline(43,color='orange')
showimagesc(sparse_mat)
subplot(1,2,2)
showimagesc(permute_cols(sparse_mat, max_match_order(sparse_mat)))
In [52]:
def get_unit_vis(net, layer, unit, which='deconv'):
if which == 'all':
return l_vis_unit_all[layer][net][unit]
else:
return crop_one_patch(l_vis_unit_all[layer][net][unit], which)
In [53]:
# Which unit in the second network is being predicted? (i.e., the row of sparse_mat)
iib = 43
figsize(18,6)
plot(sparse_mat[iib,:])
thresh = np.percentile(sparse_mat.flatten(), 50) * 100
jjbig = np.where(sparse_mat[iib,:] > thresh)
jjbig = jjbig[0] # Just take first dimension indices
figsize(18,6)
for cc,jj in enumerate(jjbig):
plot(jj,sparse_mat[iib,jj],'ro', ms=10)
print sparse_mat[iib,jj]
text(jj,sparse_mat[iib,jj],' %d' % jj, fontsize=14)
_=title('net %d, unit %d' % (netb, iib))
In [54]:
showimagesc(get_unit_vis(netb, layer, iib, 'all'))
axis("off")
Out[54]:
In [55]:
N = len(jjbig)
print jjbig
for cc,jj in enumerate(jjbig):
subplot(1,N,cc+1)
axis("off")
showimagesc(get_unit_vis(neta, layer, jj, 'all'))
In [56]:
aw = sparse_mat
figsize(18,12)
subplot(1,2,1)
n_channels = aw.shape[0]
mat = vstack((hstack((eye(n_channels),aw)), hstack((aw.T,eye(n_channels)))))
showimagesc(mat)
subplot(1,2,2)
unit_dists = mat.max() - mat
showimagesc(unit_dists)
In [90]:
np.random.seed(0)
clust = AgglomerativeClustering(n_clusters=2, linkage='average', affinity='precomputed')
clust.fit(unit_dists)
Out[90]:
In [91]:
import networkx as nx  # may already be provided by the star imports above

def children_to_tree(children):
N = len(children) + 1
ret = nx.DiGraph()
ret.add_nodes_from(range(N))
for cc,node in enumerate(children):
ll,rr = node
ret.add_edge(ll,cc+N)
ret.add_edge(rr,cc+N)
return ret
In [92]:
unittree = children_to_tree(clust.children_)
unittree_rev = unittree.reverse()
root = [n for n,d in unittree_rev.in_degree().items() if d==0][0]
ordering = [nn for nn in nx.dfs_preorder_nodes(unittree_rev, root) if nn < n_channels*2]
print ordering
In [93]:
ordering_inv = {}
for pos,uu in enumerate(ordering):
ordering_inv[uu] = pos
print ordering_inv
In [94]:
node_row = {}
for ii in range(n_channels*2):
node_row[ii] = ordering_inv[ii]
for nn in nx.dfs_postorder_nodes(unittree_rev, root):
if not nn in node_row:
node_row[nn] = min([node_row[n] for n in unittree_rev.successors(nn)])
In [95]:
leaves_under = {}
for ii in range(n_channels*2):
leaves_under[ii] = 1
for nn in nx.dfs_postorder_nodes(unittree_rev, root):
if not nn in leaves_under:
leaves_under[nn] = sum([leaves_under[n] for n in unittree_rev.successors(nn)])
In [103]:
def plot_tree_mat(mat,ordering,unittree_rev,root,ax=None,savename=None):
figsize(18,18)
epsw = .5
epsi = -.2
epsj = -.08
unit_dists = mat.max() - mat
permmat = permute_matrix(mat, ordering) ** .5
permmat = tile(permmat.reshape((permmat.shape) + (1,)), ((1,1,3)))
for ii in range(permmat.shape[0]):
        # red for units from the second block of `mat`, green for the first block
        clr = (1,0,0) if (ordering[ii] >= permmat.shape[0]//2) else (0,1,0)
permmat[ii,ii,:] = clr
showimagesc(permmat)
for nn in nx.dfs_postorder_nodes(unittree_rev, root):
if nn >= n_channels*2:
ii = node_row[nn]
jj = ii + leaves_under[nn] - 1
#text(jj,ii,' %d,%d' % (ii,jj), color=(.5,1,.5))
plot([ii-epsw+epsj,jj+epsw+epsj,jj+epsw+epsj], [ii-epsw+epsi,ii-epsw+epsi,jj+epsw+epsi], '-', color="royalblue")
_=axis('tight')
    if ax:
        # `ax` holds axis limits (xmin, xmax, ymin, ymax); apply them to the current axes
        cur_axes = gca()
        cur_axes.set_xlim(ax[0:2])
        cur_axes.set_ylim(ax[2:4])
if savename:
savefig(savename)
In [104]:
plot_tree_mat(mat,ordering,unittree_rev,root)
In [98]:
def vis_block_unit(indices, savename=None):
    # Map positions in the clustering order back to unit indices: units
    # 0..n_channels-1 index the first block of `mat`, and units
    # n_channels..2*n_channels-1 index the second block.
    unit_indices = [ordering[ii] for ii in indices]
    print unit_indices
    N = len(indices)
    for cc,jj in enumerate(unit_indices):
        subplot(1,N,cc+1)
        axis("off")
        if jj >= n_channels:
            title("net2",fontsize=18,color='r')
            showimagesc(get_unit_vis(neta, layer, jj-n_channels, 'all'))
        else:
            title("net1",fontsize=18,color='g')
            showimagesc(get_unit_vis(netb, layer, jj, 'all'))
    if savename:
        savefig(savename)
In [99]:
zoomin_indices = np.arange(51,59)
vis_block_unit(zoomin_indices)
In [100]:
zoomin_indices = np.arange(148,156)
vis_block_unit(zoomin_indices)
In [ ]: