In this post, we use a trained AlexNet model (trained on the ImageNet dataset). AlexNet has 8 parameterized layers: 5 convolutional and 3 fully connected:
conv1: 96 11x11 kernels - 3 channels
conv2: 256 5x5 kernels - 48 channels
conv3: 384 3x3 kernels - 256 channels
conv4: 384 3x3 kernels - 192 channels
conv5: 256 3x3 kernels - 192 channels
fc6: 4096x9216 matrix
fc7: 4096x4096 matrix
fc8: 1000x4096 matrix
Each of these layers is saved as a numpy 2D array.
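As a quick back-of-the-envelope check on the shapes listed above, we can count the parameters per group (a sketch; the exact stored shapes are printed from the weight files below). The fully connected layers hold the vast majority of the weights.
In [ ]:
# Parameter counts implied by the layer list above (rough sketch)
conv_params = 96*3*11*11 + 256*48*5*5 + 384*256*3*3 + 384*192*3*3 + 256*192*3*3
fc_params = 4096*9216 + 4096*4096 + 1000*4096
print("conv:", conv_params, "fc:", fc_params)  # the fc layers dominate (~58.6M vs ~2.3M)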
In [1]:
import numpy as np
import os
import sys
weights_path = '/'.join(os.getcwd().split('/')[:-1]) + '/local-trained/alexnet/weights/'
print(weights_path)
In [2]:
os.listdir(weights_path)
Out[2]:
In [3]:
keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']
weights = {}
for k in keys:
weights[k] = np.load(weights_path + k + '.npy')
The shape of each layer:
In [4]:
for k in keys:
print("Layer " + k + ": " + str(weights[k].shape))
In [5]:
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
We use the ggplot (R-style) theme for all the plots. The color cycle has 7 colors, so we use a global variable i to cycle through them. The histogram function below plots onto the axis ax passed to it.
In [6]:
plt.style.use('ggplot')
i = 0
def histogram(ax, x, num_bins=1000):
"""Plot a histogram onto ax"""
global i
i = (i + 1) % 7
clr = list(plt.rcParams['axes.prop_cycle'])[i]['color']
return ax.hist(x, num_bins, normed=1, color=clr, alpha=0.8)
In [7]:
# Create figure and 8 axes (4-by-2)
fig, ax = plt.subplots(nrows=4, ncols=2, figsize=(12.8,19.2))
# Flatten each layer
conv1_f = weights['conv1'].flatten()
conv2_f = weights['conv2'].flatten()
conv3_f = weights['conv3'].flatten()
conv4_f = weights['conv4'].flatten()
conv5_f = weights['conv5'].flatten()
fc6_f = weights['fc6'].flatten()
fc7_f = weights['fc7'].flatten()
fc8_f = weights['fc8'].flatten()
# Plot histogram
histogram(ax[0,0], conv1_f)
ax[0,0].set_title("conv1")
histogram(ax[0,1], conv2_f)
ax[0,1].set_title("conv2")
histogram(ax[1,0], conv3_f)
ax[1,0].set_title("conv3")
histogram(ax[1,1], conv4_f)
ax[1,1].set_title("conv4")
histogram(ax[2,0], conv5_f)
ax[2,0].set_title("conv5")
histogram(ax[2,1], fc6_f)
ax[2,1].set_title("fc6")
histogram(ax[3,0], fc7_f)
ax[3,0].set_title("fc7")
histogram(ax[3,1], fc8_f)
ax[3,1].set_title("fc8")
fig.tight_layout()
plt.show()
plt.close()
The plots show that the weights in every AlexNet layer seem to have zero mean and roughly follow a normal (or gamma-like) distribution. Next, we use violin plots to show the statistical properties of these weights in more detail.
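As a quick numeric sanity check of that observation (a sketch using the weights dictionary loaded above):
In [ ]:
# Per-layer mean and standard deviation of the weights
for k in keys:
    flat = weights[k].flatten()
    print("%s: mean = %.5f, std = %.5f" % (k, flat.mean(), flat.std()))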
In [8]:
def violin(ax, x, pos):
"""Plot a histogram onto ax"""
ax.violinplot(x, showmeans=True, showextrema=True, showmedians=True, positions=[pos])
In [288]:
# Create a single figure
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(12.8,6.4))
# Plot violin
violin(ax, conv1_f, pos=0)
violin(ax, conv2_f, pos=1)
violin(ax, conv3_f, pos=2)
violin(ax, conv4_f, pos=3)
violin(ax, conv5_f, pos=4)
violin(ax, fc6_f, pos=5)
violin(ax, fc7_f, pos=6)
violin(ax, fc8_f, pos=7)
# Labels
ax.set_xticks(np.arange(0, len(keys)))
ax.set_xticklabels(keys)
fig.tight_layout()
plt.show()
fig.savefig('violin_alexnet.pdf')
plt.close()
This section provides a simple pseudo-pruning function. We call it pseudo pruning because no re-training is involved, so the accuracy of the network would drop sharply compared to a prune-then-retrain scheme. The prun function here is merely used to create a fake sparse matrix for testing the compression packing.
In [10]:
def prun(o_weights, thres=None, percentile=0.8):
    """Zero out weights whose magnitude is below a threshold (in place).
    If the threshold is not provided, `thres` is
    inferred from `percentile`, the fraction of weights to keep."""
    w_weights = o_weights.reshape(1, -1)
    if thres is None:
        # Threshold at the (1 - percentile) quantile of the absolute values
        args = np.abs(w_weights[0]).argsort()
        thres = abs(w_weights[0][args[int((len(args)-1)*(1-percentile))]])
    # Zero every weight whose magnitude is at or below the threshold
    w_weights[0][np.abs(w_weights[0]) <= thres] = 0.0
Test the effect of pruning:
In [11]:
print("Before pruning:")
for layer_name in keys:
print(layer_name + " total size: " + str(weights[layer_name].size))
print(layer_name + " non-zero count: " + str(np.count_nonzero(weights[layer_name])))
print("Density: " + str(float(np.count_nonzero(weights[layer_name]))/weights[layer_name].size))
print("Cloning layers...")
clone_w = {}
for layer_name in keys:
clone_w[layer_name] = weights[layer_name].copy()
keep_per = 0.3
print("Prunning... Keeping " + str(keep_per*100) + "%")
for layer_name in keys:
prun(clone_w[layer_name], percentile=keep_per)
print(layer_name + " total size: " + str(clone_w[layer_name].size))
print(layer_name + " non-zero count: " + str(np.count_nonzero(clone_w[layer_name])))
print("Density: " + str(float(np.count_nonzero(clone_w[layer_name])*1.0)/clone_w[layer_name].size))
In [13]:
from sklearn.cluster import KMeans
Function quantize_kmeans performs k-means clustering on the weight values. The cluster_centers_ returned by k-means form the codebook, and the built-in KMeans.predict function acts as the encoder, mapping each weight to the index of its nearest center.
In [29]:
def quantize_kmeans(weight, ncluster=256, rs=0):
    """Quantize `weight` with k-means; return (num_bits, codebook, encoded)."""
    org_shape = weight.shape
    km = KMeans(n_clusters=ncluster, random_state=rs).fit(weight.reshape(-1, 1))
    num_bits = int(np.ceil(np.log2(ncluster)))
    codebook = km.cluster_centers_
    # Encode every weight as the index of its nearest cluster center
    encoded = km.predict(weight.reshape(-1, 1)).astype(np.int32)
    return num_bits, codebook, encoded.reshape(org_shape)
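A minimal usage sketch of quantize_kmeans (illustrative only; clustering every weight of a layer can be slow):
In [ ]:
# Illustrative: quantize conv1 into 8-bit cluster indices
nbits, cb, enc = quantize_kmeans(weights['conv1'], ncluster=256)
print("bits per weight:", nbits, "| codebook shape:", cb.shape, "| encoded shape:", enc.shape)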
We manually run the clustering for each layer: convolutional layers get 256 centers each, fully connected layers get 16 centers each.
In [15]:
print("Clustering conv1 ...")
conv1_k = KMeans(n_clusters=256, random_state=0).fit(weights['conv1'].reshape(-1,1))
print("Clustering conv2 ...")
conv2_k = KMeans(n_clusters=256, random_state=0).fit(weights['conv2'].reshape(-1,1))
print("Clustering conv3 ...")
conv3_k = KMeans(n_clusters=256, random_state=0).fit(weights['conv3'].reshape(-1,1))
print("Clustering conv4 ...")
conv4_k = KMeans(n_clusters=256, random_state=0).fit(weights['conv4'].reshape(-1,1))
print("Clustering conv5 ...")
conv5_k = KMeans(n_clusters=256, random_state=0).fit(weights['conv5'].reshape(-1,1))
print("Clustering fc6 ...")
fc6_k = KMeans(n_clusters=16, random_state=0).fit(weights['fc6'].reshape(-1,1))
print("Clustering fc7 ...")
fc7_k = KMeans(n_clusters=16, random_state=0).fit(weights['fc7'].reshape(-1,1))
print("Clustering fc8 ...")
fc8_k = KMeans(n_clusters=16, random_state=0).fit(weights['fc8'].reshape(-1,1))
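With 256 centers each convolutional index fits in 8 bits, and with 16 centers each fully connected index fits in 4 bits. A rough storage estimate under those assumptions (a sketch that ignores the codebooks and any sparse-index overhead):
In [ ]:
# Rough size estimate: 32-bit floats vs. quantized cluster indices
for k, nbits in zip(keys, [8]*5 + [4]*3):
    orig_mb = weights[k].size * 4 / 1e6          # float32 storage
    quant_mb = weights[k].size * nbits / 8.0 / 1e6
    print("%s: %.2f MB -> %.2f MB" % (k, orig_mb, quant_mb))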
We plot the cluster centers on top of the weight-value histograms. For each cluster center, the x-axis represents its value and the y-axis its index. Plotting this way makes it easier to see where the cluster centers concentrate.
In [21]:
def histogram_kmeans(ax, flat, kmeans, norm=20):
histogram(ax, flat)
tmp = np.ones_like(kmeans.cluster_centers_)
idx = ((np.cumsum(tmp)) - 1) / norm
ax.scatter(sorted(kmeans.cluster_centers_), idx, s=16, alpha=0.6)
plt.close()
fig, ax = plt.subplots(nrows=4, ncols=2, figsize=(12.8,19.2))
histogram_kmeans(ax[0,0], conv1_f, conv1_k)
ax[0,0].set_title("conv1")
histogram_kmeans(ax[0,1], conv2_f, conv2_k)
ax[0,1].set_title("conv2")
histogram_kmeans(ax[1,0], conv3_f, conv3_k)
ax[1,0].set_title("conv3")
histogram_kmeans(ax[1,1], conv4_f, conv4_k)
ax[1,1].set_title("conv4")
histogram_kmeans(ax[2,0], conv5_f, conv5_k)
ax[2,0].set_title("conv5")
histogram_kmeans(ax[2,1], fc6_f, fc6_k, norm=0.5)
ax[2,1].set_title("fc6")
histogram_kmeans(ax[3,0], fc7_f, fc7_k, norm=0.5)
ax[3,0].set_title("fc7")
histogram_kmeans(ax[3,1], fc8_f, fc8_k, norm=0.5)
ax[3,1].set_title("fc8")
fig.tight_layout()
plt.show()
plt.close()
In [34]:
def encode_kmeans(kmeans, weights):
w = weights.reshape(-1,1)
codebook = kmeans.cluster_centers_
encoded = kmeans.predict(w)
return codebook, encoded.reshape(weights.shape)
In [41]:
cb_conv1, conv1_e = encode_kmeans(conv1_k, weights['conv1'])
cb_conv2, conv2_e = encode_kmeans(conv2_k, weights['conv2'])
cb_conv3, conv3_e = encode_kmeans(conv3_k, weights['conv3'])
cb_conv4, conv4_e = encode_kmeans(conv4_k, weights['conv4'])
cb_conv5, conv5_e = encode_kmeans(conv5_k, weights['conv5'])
cb_fc6, fc6_e = encode_kmeans(fc6_k, weights['fc6'])
cb_fc7, fc7_e = encode_kmeans(fc7_k, weights['fc7'])
cb_fc8, fc8_e = encode_kmeans(fc8_k, weights['fc8'])
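Decoding is simply a codebook lookup: each stored index selects its cluster center. A minimal sketch for conv1, assuming the variables produced above:
In [ ]:
# Reconstruct conv1 from its codebook and encoded indices, then measure the error
conv1_dec = cb_conv1.flatten()[conv1_e]
print("Max absolute quantization error:", np.abs(conv1_dec - weights['conv1']).max())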
The encoded data can be stored using the codebooks (cb_*) and the encoded matrices (*_e). Since each encoded matrix fits in 8-bit data, we think it's a good idea to save it as a PNG image.
In [97]:
from scipy import misc
def save_image(encoded_w, name, ext='.png'):
encoded_w = encoded_w.reshape(encoded_w.shape[0], -1)
misc.imsave('./' + name + ext, encoded_w)
save_image(conv1_e, 'conv1')
save_image(conv2_e, 'conv2')
save_image(conv3_e, 'conv3')
save_image(conv4_e, 'conv4')
save_image(conv5_e, 'conv5')
save_image(fc6_e, 'fc6')
save_image(fc7_e, 'fc7')
save_image(fc8_e, 'fc8')
We check whether any data is lost when the images are saved and loaded back:
In [272]:
def check_image(img_name, encoded, is_fc=False):
data = misc.imread(img_name)
if is_fc:
        data = data / 17  # imsave rescaled the 4-bit values (0-15) to 0-255; divide by 17 to undo it
print(np.all(data == encoded.reshape(encoded.shape[0], -1)))
check_image('conv1.png', conv1_e)
check_image('conv2.png', conv2_e)
check_image('conv3.png', conv3_e)
check_image('conv4.png', conv4_e)
check_image('conv5.png', conv5_e)
check_image('fc6.png', fc6_e, is_fc=True)
check_image('fc7.png', fc7_e, is_fc=True)
check_image('fc8.png', fc8_e, is_fc=True)
We now encode the data as a list of non-zero indices plus the encoded values.
In [245]:
def encode_index(nz_index, bits=4):
"""Encode nonzero indices using 4-bit"""
max_val = 2**bits
if bits == 4 or bits == 8:
data_type = np.uint8
elif bits == 16:
data_type = np.uint16
else:
print("Unimplemented index encoding with " + str(bits) + " bits.")
sys.exit(1)
code = np.zeros_like(nz_index, dtype=np.uint32)
adv = 0
# Encode with relative to array index
for i, val in enumerate(nz_index):
cur_i = i + adv
code[i] = val - cur_i
if (val - cur_i != 0):
adv += val - cur_i
# Check if there is overflow
if (code.max() >= max_val):
print("Overflow index codebook. Unimplemented handling.")
sys.exit(1)
    # Special case of 4-bit encoding: pack two 4-bit codes into each byte
    if (bits == 4):
        if code.size % 2 == 1:
            code = np.append(code, 0)  # pad to an even length
        code_4bit = code[np.arange(0, code.size, 2)]*(2**bits) + code[np.arange(1, code.size, 2)]
        return np.asarray(code_4bit, dtype=data_type)
return np.asarray(code, dtype=data_type)
In [267]:
def decode_index(encoded_ind, org_size=None, bits=4):
    """Decode relative-encoded nonzero indices back to absolute indices"""
    if org_size is None:
        print("Original size must be specified.")
        sys.exit(1)
    decode = np.zeros(org_size, dtype=np.uint32)
    if (bits == 4):
        # Unpack two 4-bit codes from each byte (high nibble first)
        decode[np.arange(0, org_size, 2)] = encoded_ind // (2**bits)
        decode[np.arange(1, org_size, 2)] = (encoded_ind % (2**bits))[:org_size // 2]
    else:
        # 8- and 16-bit codes are stored one per element
        decode[:] = encoded_ind[:org_size]
    # Offsets were stored as (gap - 1); a cumulative sum restores absolute indices
    decode = np.cumsum(decode + 1) - 1
    return np.asarray(decode, dtype=np.uint32)
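A tiny round-trip check of the two functions on hypothetical indices:
In [ ]:
# Encode a toy index list, then decode it back
toy_ind = np.array([2, 5, 6, 9])
toy_enc = encode_index(toy_ind, bits=4)
toy_dec = decode_index(toy_enc, org_size=toy_ind.size, bits=4)
print(toy_enc)  # packed 4-bit relative offsets
print(toy_dec)  # should reproduce [2 5 6 9]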
In [280]:
# It should be the nonzero indices for real data; these are pseudo indices
conv1_ind = np.arange(weights['conv1'].size)
conv2_ind = np.arange(weights['conv2'].size)
conv3_ind = np.arange(weights['conv3'].size)
conv4_ind = np.arange(weights['conv4'].size)
conv5_ind = np.arange(weights['conv5'].size)
fc6_ind = np.arange(weights['fc6'].size)
fc7_ind = np.arange(weights['fc7'].size)
fc8_ind = np.arange(weights['fc8'].size)
In [281]:
# Encode the indices
conv1_ie = encode_index(conv1_ind)
conv2_ie = encode_index(conv2_ind)
conv3_ie = encode_index(conv3_ind)
conv4_ie = encode_index(conv4_ind)
conv5_ie = encode_index(conv5_ind)
fc6_ie = encode_index(fc6_ind)
fc7_ie = encode_index(fc7_ind)
fc8_ie = encode_index(fc8_ind)
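As a quick check on the index-encoding overhead (a sketch using the arrays above), we compare plain 32-bit indices with the packed 4-bit relative offsets:
In [ ]:
# Bytes needed for plain 32-bit indices vs. packed 4-bit relative codes
for name, raw, packed in [('conv1', conv1_ind, conv1_ie), ('fc8', fc8_ind, fc8_ie)]:
    print("%s: %d bytes -> %d bytes" % (name, raw.astype(np.uint32).nbytes, packed.nbytes))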