In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook uses Lucid to reproduce the results in Activation Atlas.
This notebook doesn't introduce the abstractions behind Lucid; you may wish to also read the Lucid tutorial.
Note: The easiest way to use this tutorial is as a Colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU by going to:
Runtime → Change runtime type → Hardware Accelerator: GPU
In [0]:
# Installations
# Quote the version specifiers so the shell doesn't treat `>=` as a redirect.
!pip -q install "lucid>=0.3.8"
!pip -q install "umap-learn>=0.3.7"
In [0]:
# General support
import math
import tensorflow as tf
import numpy as np
# For plots
import matplotlib.pyplot as plt
# Dimensionality reduction
import umap
from sklearn.manifold import TSNE
# General lucid code
from lucid.misc.io import save, show, load
import lucid.modelzoo.vision_models as models
# For rendering feature visualizations
import lucid.optvis.objectives as objectives
import lucid.optvis.param as param
import lucid.optvis.render as render
import lucid.optvis.transform as transform
Load paginated data:
In [0]:
from progressbar import progressbar
from os.path import join
# URL Helpers

def path_segment(key_value_pairs):
  segments = [f"{key}={value}" for key, value in key_value_pairs]
  return "/".join(segments)

def parameterized_url(root_url, path, kv_pairs):
  if path is not None:
    url = join(root_url, path)
  else:
    url = root_url
  path = path_segment(kv_pairs)
  return join(url, path)
# Paginated Data Loading

def load_data(data_type, layer_name, length=1_048_576, paginated=True):
  parameters = ('dataset', 'ImageNet-train'), ('model', 'InceptionV1'), ('layer', layer_name)
  root_url = parameterized_url("gs://modelzoo", "aligned-activations-and-attributions", parameters)
  activations_per_page = int(2**14)  # 16_384
  number_of_pages = int(2**6)  # 64
  number_of_pages_to_load = math.ceil(length / activations_per_page)
  results = []
  for page_index in progressbar(range(number_of_pages_to_load)):
    url = f"{root_url}/{data_type}-{page_index:05}-of-{number_of_pages:05}.npy"
    results.append(load(url))
  return np.concatenate(results)[:length, ...]

# Convenience functions

def load_activations(layer, length=1_048_576):
  return load_data("activations", layer, length)

def load_attributions(layer, length=1_048_576):
  return load_data("attributions", layer, length)
The following functions are explained in more detail in the simple-activation-atlas example.
In [0]:
def whiten(full_activations):
  # Inverse of the (uncentered) correlation matrix of the activations.
  correl = np.matmul(full_activations.T, full_activations) / len(full_activations)
  correl = correl.astype("float32")
  S = np.linalg.inv(correl)
  S = S.astype("float32")
  return S
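# Minimal sanity check of whiten (illustrative, random data, not from the
# original notebook): S inverts the activation correlation matrix, so their
# product should be near the identity.
_demo_acts = np.random.randn(10_000, 8).astype("float32")
_demo_correl = _demo_acts.T @ _demo_acts / len(_demo_acts)
assert np.allclose(_demo_correl @ whiten(_demo_acts), np.eye(8), atol=0.1)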
def normalize_layout(layout, min_percentile=1, max_percentile=99, relative_margin=0.1):
  """Removes outliers and scales layout to between [0,1]."""
  # compute percentiles
  mins = np.percentile(layout, min_percentile, axis=0)
  maxs = np.percentile(layout, max_percentile, axis=0)
  # add relative margins on both sides, computed from the original spread
  spread = maxs - mins
  mins -= relative_margin * spread
  maxs += relative_margin * spread
  # `clip` broadcasts over the point dimension
  clipped = np.clip(layout, mins, maxs)
  # embed within [0,1] along both axes
  clipped -= clipped.min(axis=0)
  clipped /= clipped.max(axis=0)
  return clipped
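# Quick illustrative check (demo data is hypothetical): even with an extreme
# outlier, the normalized layout stays inside the unit square, because the
# outlier is clipped before rescaling.
_demo_layout = np.array([[0.0, 0.0], [0.5, 0.5], [100.0, 100.0]])
assert normalize_layout(_demo_layout).min() >= 0.0
assert normalize_layout(_demo_layout).max() <= 1.0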
def grid(xpts=None, ypts=None, grid_size=(8, 8), x_extent=(0., 1.), y_extent=(0., 1.)):
  """Buckets 2-D points into grid cells; returns per-cell arrays of point indices."""
  xpx_length = grid_size[0]
  ypx_length = grid_size[1]
  xpt_extent = x_extent
  ypt_extent = y_extent
  xpt_length = xpt_extent[1] - xpt_extent[0]
  ypt_length = ypt_extent[1] - ypt_extent[0]
  # map points into grid-cell ("pixel") coordinates
  xpxs = ((xpts - xpt_extent[0]) / xpt_length) * xpx_length
  ypxs = ((ypts - ypt_extent[0]) / ypt_length) * ypx_length
  ix_s = range(grid_size[0])
  iy_s = range(grid_size[1])
  xs = []
  for xi in ix_s:
    ys = []
    for yi in iy_s:
      xpx_extent = (xi, (xi + 1))
      ypx_extent = (yi, (yi + 1))
      in_bounds_x = np.logical_and(xpx_extent[0] <= xpxs, xpxs <= xpx_extent[1])
      in_bounds_y = np.logical_and(ypx_extent[0] <= ypxs, ypxs <= ypx_extent[1])
      in_bounds = np.logical_and(in_bounds_x, in_bounds_y)
      in_bounds_indices = np.where(in_bounds)[0]
      ys.append(in_bounds_indices)
    xs.append(ys)
  return np.asarray(xs)
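# Illustrative (hypothetical demo points): on a 4x4 grid over the unit square,
# (0.05, 0.05) lands in cell (0, 0) and (0.9, 0.9) lands in cell (3, 3).
_demo_cells = grid(xpts=np.array([0.05, 0.9]), ypts=np.array([0.05, 0.9]), grid_size=(4, 4))
assert list(_demo_cells[0, 0]) == [0]
assert list(_demo_cells[3, 3]) == [1]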
@objectives.wrap_objective
def direction_neuron_S(layer_name, vec, batch=None, x=None, y=None, S=None):
  def inner(T):
    layer = T(layer_name)
    shape = tf.shape(layer)
    x_ = shape[1] // 2 if x is None else x
    y_ = shape[2] // 2 if y is None else y
    if batch is None:
      raise RuntimeError("requires batch")
    acts = layer[batch, x_, y_]
    vec_ = vec
    if S is not None:
      vec_ = tf.matmul(vec_[None], S)[0]
    # mag = tf.sqrt(tf.reduce_sum(acts**2))
    dot = tf.reduce_mean(acts * vec_)
    # cossim = dot / (1e-4 + mag)
    return dot
  return inner
@objectives.wrap_objective
def direction_neuron_cossim_S(layer_name, vec, batch=None, x=None, y=None, cossim_pow=1, S=None):
  def inner(T):
    layer = T(layer_name)
    shape = tf.shape(layer)
    x_ = shape[1] // 2 if x is None else x
    y_ = shape[2] // 2 if y is None else y
    if batch is None:
      raise RuntimeError("requires batch")
    acts = layer[batch, x_, y_]
    vec_ = vec
    if S is not None:
      vec_ = tf.matmul(vec_[None], S)[0]
    mag = tf.sqrt(tf.reduce_sum(acts**2))
    dot = tf.reduce_mean(acts * vec_)
    cossim = dot / (1e-4 + mag)
    cossim = tf.maximum(0.1, cossim)
    return dot * cossim ** cossim_pow
  return inner
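# In words (our reading of the code above, not text from the original): with
# a = layer[batch, x, y] and v' = S v, the objective is
#   dot * max(0.1, dot / (1e-4 + ||a||)) ** cossim_pow, where dot = mean(a * v'),
# i.e. it rewards activations that are both large and well-aligned with the
# (whitened) direction; cossim_pow trades off magnitude against alignment.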
def render_icons(directions, model, layer, size=80, n_steps=128, verbose=False, S=None, num_attempts=2, cossim=True, alpha=True):
  image_attempts = []
  loss_attempts = []

  # Render multiple attempts, and keep the best-scoring one per icon.
  for attempt in range(num_attempts):

    # Render an image for each activation vector
    param_f = lambda: param.image(size, batch=directions.shape[0], fft=True, decorrelate=True, alpha=alpha)

    if S is not None:
      if cossim:
        obj_list = [
            direction_neuron_cossim_S(layer, v, batch=n, S=S, cossim_pow=4)
            for n, v in enumerate(directions)
        ]
      else:
        obj_list = [
            direction_neuron_S(layer, v, batch=n, S=S)
            for n, v in enumerate(directions)
        ]
    else:
      obj_list = [
          objectives.direction_neuron(layer, v, batch=n)
          for n, v in enumerate(directions)
      ]

    obj = objectives.Objective.sum(obj_list)

    transforms = []
    if alpha:
      transforms.append(transform.collapse_alpha_random())
    transforms.append(transform.pad(2, mode='constant', constant_value=1))
    transforms.append(transform.jitter(4))
    transforms.append(transform.jitter(4))
    transforms.append(transform.jitter(8))
    transforms.append(transform.jitter(8))
    transforms.append(transform.jitter(8))
    transforms.append(transform.random_scale([0.995**n for n in range(-5, 80)] + [0.998**n for n in 2 * list(range(20, 40))]))
    transforms.append(transform.random_rotate(list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]))
    transforms.append(transform.jitter(2))

    # This is the tensorflow optimization process.
    print("attempt: ", attempt)
    with tf.Graph().as_default(), tf.Session() as sess:
      learning_rate = 0.05
      losses = []
      trainer = tf.train.AdamOptimizer(learning_rate)
      T = render.make_vis_T(model, obj, param_f, trainer, transforms)
      loss_t, vis_op, t_image = T("loss"), T("vis_op"), T("input")
      losses_ = [obj_part(T) for obj_part in obj_list]
      tf.global_variables_initializer().run()
      for i in range(n_steps):
        loss, _ = sess.run([losses_, vis_op])
        losses.append(loss)
        if i % 100 == 0:
          print(i)
      img = t_image.eval()
      img_rgb = img[:, :, :, :3]
      if alpha:
        print("alpha true")
        # Composite the learned alpha channel onto a dark background.
        k = 0.8
        bg_color = 0.0
        img_a = img[:, :, :, 3:]
        img_merged = img_rgb * ((1 - k) + k * img_a) + bg_color * k * (1 - img_a)
        image_attempts.append(img_merged)
      else:
        print("alpha false")
        image_attempts.append(img_rgb)
      loss_attempts.append(losses[-1])

  # Use only the icons with the best objective score
  loss_attempts = np.asarray(loss_attempts)
  loss_final = []
  image_final = []
  print("merging best scores from attempts...")
  for i, d in enumerate(directions):
    # note: argmax, since this is an objective score, not a traditional loss
    mi = np.argmax(loss_attempts[:, i])
    image_final.append(image_attempts[mi][i])
    loss_final.append(loss_attempts[mi, i])
  return (image_final, loss_final)
def render_layout(model, layer, S, xs, ys, activ, n_steps=512, n_attempts=2, min_density=10, grid_size=(10, 10), icon_size=80, x_extent=(0., 1.0), y_extent=(0., 1.0)):
  grid_layout = grid(xpts=xs, ypts=ys, grid_size=grid_size, x_extent=x_extent, y_extent=y_extent)
  icons = []
  for x in range(grid_size[0]):
    for y in range(grid_size[1]):
      indices = grid_layout[x, y]
      # Skip cells that are too sparsely populated to give a meaningful average.
      if len(indices) > min_density:
        average_activation = np.average(activ[indices], axis=0)
        icons.append((average_activation, x, y))
  icons = np.asarray(icons)
  icon_batch, losses = render_icons(icons[:, 0], model, alpha=False, layer=layer, S=S, n_steps=n_steps, size=icon_size, num_attempts=n_attempts)
  canvas = np.ones((icon_size * grid_size[0], icon_size * grid_size[1], 3))
  for i, icon in enumerate(icon_batch):
    # Convert layout coordinates to image row/col: the grid y-index becomes the
    # (flipped) row so that y increases upward, and the x-index becomes the column.
    y = int(icons[i, 1])
    x = int(icons[i, 2])
    canvas[(grid_size[0] - x - 1) * icon_size:(grid_size[0] - x) * icon_size, y * icon_size:(y + 1) * icon_size] = icon
  return canvas
In [0]:
model = models.InceptionV1()
model.load_graphdef()
In [6]:
layer = "mixed5b"
num_activations = 400000
activations = load_activations(layer=layer, length=num_activations)
attributions = load_attributions(layer=layer, length=num_activations)
In [0]:
S = whiten(activations)
In [0]:
def get_winning_indices(attributions, class_id):
  """Indices of activations whose top-attributed class is class_id."""
  matching_indices = []
  for i, attr in enumerate(attributions):
    class_indices = np.argsort(-attr)
    if class_id in class_indices[0:1]:
      matching_indices.append(i)
  return np.array(matching_indices)
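# For reference (an equivalent we believe matches the loop above, up to ties):
#   np.where(np.argmax(attributions, axis=1) == class_id)[0]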
def get_top_indices(attributions, class_id, n=2000):
  """Indices of the n activations with the largest positive attribution to class_id."""
  top_indices = np.argsort(-attributions[:, class_id])[0:n]
  positive = attributions[top_indices, class_id] > 0
  return top_indices[positive]
In [0]:
# 6 grey whale
# 442 great white shark
# 671 frying pan
# 672 wok
# 235 fireboat
# 287 streetcar
# 982 scuba diver
# 507 snorkel
# 547 abacus
# 930 comic book
# 600 spider web
class_filter = 982
In [0]:
winning_indices = get_winning_indices(attributions, class_filter)
top_indices = get_top_indices(attributions, class_filter, 3000)
combined_indices = np.concatenate((top_indices, winning_indices), axis=0)
In [0]:
filtered_activations = activations[combined_indices]
filtered_attributions = attributions[combined_indices]
In [12]:
points = umap.UMAP(n_components=2, n_neighbors=20, min_dist=0.01, metric="cosine", verbose=True).fit_transform(filtered_activations)
# You could use TSNE as well if you like
# points = TSNE(n_components=2, verbose=False, metric="cosine", learning_rate=10, perplexity=50).fit_transform(filtered_activations)
In [0]:
layout = normalize_layout(points, min_percentile=1, max_percentile=99)
In [14]:
plt.figure(figsize=(10, 10))
plt.scatter(x=layout[:, 0], y=layout[:, 1], s=2)
plt.show()
In [15]:
xs = layout[:, 0]
ys = layout[:, 1]
canvas = render_layout(model, layer, S, xs, ys, filtered_activations, grid_size=(8, 8), n_steps=512, n_attempts=2, min_density=4, x_extent=(0, 1.), y_extent=(0, 1.))
show(canvas)
In [16]:
# 6 grey whale
# 442 great white shark
# 671 frying pan
# 672 wok
# 235 fireboat
# 287 streetcar
# 982 scuba diver
# 507 snorkel
# 547 abacus
# 930 comic book
# 600 spider web
left_class = 507
right_class = 982
left_winning_indices = get_winning_indices(attributions, left_class)
left_top_indices = get_top_indices(attributions, left_class, 2000)
left_combined_indices = np.concatenate((left_top_indices, left_winning_indices), axis=0)
right_winning_indices = get_winning_indices(attributions, right_class)
right_top_indices = get_top_indices(attributions, right_class, 2000)
right_combined_indices = np.concatenate((right_top_indices, right_winning_indices), axis=0)
matching_indices = np.concatenate((left_combined_indices, right_combined_indices),axis=0)
print(len(matching_indices))
# filter activations/attributions
compare_activations = activations[matching_indices]
compare_attributions = attributions[matching_indices]
# one-dimensional tsne for y-axis
compare_points = TSNE(n_components=1, verbose=False, metric="cosine", learning_rate=10, perplexity=50).fit_transform(compare_activations)
compare_layout = normalize_layout(compare_points, min_percentile=1, max_percentile=99)
compare_layout = compare_layout[:,0]
# difference in attribution for x-axis
attr_diff = compare_attributions[:,right_class] - compare_attributions[:,left_class]
# normalize x-axis without re-centering zero
boundary_value = np.percentile(np.abs(attr_diff), 99, axis=(0))
attr_diff /= boundary_value
attr_diff = np.clip(attr_diff, -1, 1)
In [17]:
plt.figure(figsize=(10, 10))
plt.scatter(x=attr_diff, y=compare_layout, s=2)
plt.show()
In [20]:
xs = attr_diff
ys = compare_layout
canvas = render_layout(model, layer, S, xs, ys, compare_activations, grid_size=(10, 10), n_steps=512, n_attempts=2, min_density=4, x_extent=(-1, 1.), y_extent=(0, 1.))
show(canvas)
In [0]: