This demo shows how to use the What-If Tool (WIT) with image recognition models. The task is to predict whether a person is smiling or not. We provide a CNN trained on a subset of the CelebA dataset and visualize its results on a separate test subset.
Copyright 2019 Google LLC. SPDX-License-Identifier: Apache-2.0
In [ ]:
# Check that a TensorFlow version starting with 2.1 is installed: this prints
# the matching `pip freeze` line, or nothing when no such version is present.
# NOTE: this only checks -- it does not install or change anything.
!pip freeze | grep tensorflow==2.1
In [ ]:
# Download the provided smile model and save it locally as smile-model.hdf5
# (-L follows redirects).
!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/smile-colab-model.hdf5 -o ./smile-model.hdf5
# Download the test-subset archive (presumably the source of the celeba/ files
# read later in this notebook -- confirm against the archive contents).
!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/test_subset.zip -o ./test_subset.zip
# Unpack quietly (-qq) and overwrite existing files without prompting (-o).
!unzip -qq -o test_subset.zip
In [ ]:
import numpy as np
import tensorflow as tf
import os
from PIL import Image
from io import BytesIO
def df_to_examples(df, columns=None, images_path=''):
    """Convert a dataframe into a list of tf.Example protos.

    Integer, unsigned and boolean columns are stored as int64_list features,
    floating-point columns as float_list features, and every other non-missing
    value is UTF-8 encoded into a bytes_list feature (missing values, i.e.
    NaN, are skipped for those columns).

    If images_path is specified, the dataframe is assumed to have a special
    column "image_id" such that "images_path/image_id" points to an image
    file. Each image is re-encoded as PNG and stored under 'image/encoded',
    which is a reserved field in WIT for encoded image features.

    Args:
        df: Dataframe with one row per example.
        columns: Names of the columns to convert; defaults to all columns.
        images_path: Directory containing the image files; leave empty to
            skip image loading.

    Returns:
        A list with one tf.train.Example per dataframe row, in row order.
    """
    if columns is None:
        columns = df.columns.values.tolist()
    # Materialize once so a one-shot iterable is not exhausted after row one.
    columns = list(columns)

    # Classify each column a single time by dtype kind instead of comparing
    # dtype identity per cell; this also covers int32/float32/unsigned/bool
    # columns, which previously fell through to the string branch and failed.
    column_kinds = {col: getattr(df[col].dtype, 'kind', 'O') for col in columns}

    examples = []
    for _, row in df.iterrows():
        example = tf.train.Example()
        features = example.features.feature
        for col in columns:
            value = row[col]
            kind = column_kinds[col]
            if kind in ('i', 'u', 'b'):
                features[col].int64_list.value.append(int(value))
            elif kind == 'f':
                features[col].float_list.value.append(value)
            elif value == value:  # NaN != NaN, so this skips missing values.
                features[col].bytes_list.value.append(value.encode('utf-8'))
        if images_path:
            image_file = os.path.join(images_path, row['image_id'])
            with open(image_file, 'rb') as f:
                # Re-encode as PNG while the source file is still open.
                png_buffer = BytesIO()
                Image.open(f).save(png_buffer, format='PNG')
            features['image/encoded'].bytes_list.value.append(
                png_buffer.getvalue())
        examples.append(example)
    return examples
def make_label_column_numeric(df, label_column, test):
    """Rewrite a label column in place as 0's and 1's.

    Used to force label columns to be numeric for binary classification
    using a TF estimator.

    Args:
        df: Dataframe to modify in place.
        label_column: Name of the column to overwrite.
        test: Callable applied to the whole column; entries where its result
            is truthy become 1, all others become 0.
    """
    is_positive = test(df[label_column])
    numeric_labels = np.where(is_positive, 1, 0)
    df[label_column] = numeric_labels
In [ ]:
import pandas as pd
# Read the test-subset table, then turn each row (together with its image file
# from the resized-image directory) into a tf.Example.
data = pd.read_csv('celeba/data_test_subset.csv')
examples = df_to_examples(
    data,
    images_path='celeba/img_test_subset_resized/',
)
In [ ]:
from tensorflow.keras.models import load_model
# Load the provided smile classifier from the downloaded HDF5 file; it is the
# model that custom_predict queries for predictions.
model1 = load_model('smile-model.hdf5')
In [ ]:
def custom_predict(examples_to_infer):
    """Run the smile model on the images stored in a list of tf.Examples.

    Each example's 'image/encoded' feature (the reserved key for the encoded
    image byte list) is read into BytesIO and decoded back to an image using
    PIL. The model expects an array of images that are floats in range 0.0 to
    1.0, so pixel values are divided by 255.

    Args:
        examples_to_infer: tf.Examples, each carrying an 'image/encoded'
            bytes feature.

    Returns:
        The model's predictions, a numpy array of (n_samples, n_labels).
    """
    def _to_unit_floats(example):
        raw = example.features.feature['image/encoded'].bytes_list.value[0]
        pixels = np.array(Image.open(BytesIO(raw)), dtype=np.float64)
        return pixels / 255.

    batch = np.array([_to_unit_floats(ex) for ex in examples_to_infer])
    return model1.predict(batch)
In [ ]:
from witwidget.notebook.visualization import WitWidget, WitConfigBuilder, display
# Number of test examples (taken from the front of the list) shown in the tool.
num_datapoints = 250
# Height of the rendered What-If Tool widget, in pixels.
tool_height_in_px = 700
def decode_image(ex):
    """Decode the image stored in a tf.Example's 'image/encoded' bytestring.

    Args:
        ex: tf.Example whose 'image/encoded' feature holds encoded image bytes.

    Returns:
        The PIL image opened from the first value of that feature.
    """
    encoded = ex.features.feature['image/encoded'].bytes_list.value[0]
    return Image.open(BytesIO(encoded))
def image_mean_distance(ex, exs, params):
    """Custom distance function that compares the average color of images.

    Args:
        ex: The selected tf.Example to measure distances from.
        exs: The tf.Examples to measure distances to.
        params: Unused by this function.

    Returns:
        A list with one entry per element of exs: the norm of the difference
        between the mean color of ex's image and that element's image.
    """
    def _mean_color(example):
        # Average over the two leading image axes.
        return np.mean(decode_image(example), axis=(0, 1))

    reference_color = _mean_color(ex)
    distances = []
    for candidate in exs:
        distances.append(
            np.linalg.norm(reference_color - _mean_color(candidate)))
    return distances
# Set up the tool with the first num_datapoints test examples, the trained
# classifier as the prediction function, and mean image color as the custom
# distance between examples.
config_builder = WitConfigBuilder(examples[:num_datapoints])
config_builder = config_builder.set_custom_predict_fn(custom_predict)
config_builder = config_builder.set_custom_distance_fn(image_mean_distance)

# Build the widget at the configured height and show it in the notebook.
wv = WitWidget(config_builder, height=tool_height_in_px)
display(wv)