What-If Tool Image Smile Detection

In this demo, we show how to use the What-If Tool (WIT) with image recognition models. The task is to predict whether a person is smiling. We provide a CNN trained on a subset of the CelebA dataset and visualize its predictions on a separate test subset.

Copyright 2019 Google LLC. SPDX-License-Identifier: Apache-2.0

#@title Install the What-If Tool widget if running in Colab
try:
  import google.colab
  !pip install --upgrade witwidget -q
except Exception:
  pass

#@title Download the pretrained Keras model file and a subset of CelebA images

!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/smile-colab-model.hdf5 -o ./smile-model.hdf5
!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/test_subset.zip -o ./test_subset.zip
!unzip -qq -o test_subset.zip
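The `!curl` and `!unzip` lines are IPython shell escapes, so they only run inside a notebook. If you run the demo as a plain Python script instead, a standard-library equivalent could look like the sketch below. It uses only `urllib.request.urlretrieve` and `zipfile`; `download_and_extract` is a hypothetical helper, not part of WIT.

```python
import urllib.request
import zipfile


def download_and_extract(url, zip_path, extract_dir="."):
    """Download `url` to `zip_path`, then extract it into `extract_dir`.

    Rough stand-in for `!curl -L URL -o FILE` followed by `!unzip -qq -o FILE`:
    urlretrieve follows HTTP redirects, and extractall overwrites existing files.
    Returns the list of archive member names.
    """
    urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_dir)
        return zf.namelist()


# Example usage with the URLs from the cell above (requires network access):
# base = "https://storage.googleapis.com/what-if-tool-resources/smile-demo/"
# urllib.request.urlretrieve(base + "smile-colab-model.hdf5", "./smile-model.hdf5")
# download_and_extract(base + "test_subset.zip", "./test_subset.zip")
```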

#@title Define helper functions for dataset conversion from csv to tf.Examples
import numpy as np
import tensorflow as tf
import os
from PIL import Image
from io import BytesIO

# Converts a dataframe into a list of tf.Example protos.
# If images_path is specified, it assumes that the dataframe has a special
# column "image_id" and that the path "images_path/image_id" points to an image file.
# Given this structure, the function loads each image, re-encodes it as PNG bytes,
# and stores those bytes in the tf.Example so that WIT can display it. Note that
# 'image/encoded' is a reserved feature name in WIT for encoded image features.
def df_to_examples(df, columns=None, images_path=''):
  examples = []
  if columns is None:
    columns = df.columns.values.tolist()
  for index, row in df.iterrows():
    example = tf.train.Example()
    for col in columns:
      if df[col].dtype == np.int64:
        example.features.feature[col].int64_list.value.append(int(row[col]))
      elif df[col].dtype == np.float64:
        example.features.feature[col].float_list.value.append(row[col])
      elif row[col] == row[col]:
        # row[col] == row[col] is False only for NaN, so missing values are skipped.
        example.features.feature[col].bytes_list.value.append(
            str(row[col]).encode('utf-8'))
    if images_path:
      fname = row['image_id']
      with open(os.path.join(images_path, fname), 'rb') as f:
        im = Image.open(f)
        buf = BytesIO()
        im.save(buf, format='PNG')
        im_bytes = buf.getvalue()
        # Store the PNG bytes under WIT's reserved image feature key.
        example.features.feature['image/encoded'].bytes_list.value.append(im_bytes)
    examples.append(example)
  return examples

# Converts a dataframe column into a column of 0's and 1's based on the provided test.
# Used to force label columns to be numeric for binary classification using a TF estimator.
def make_label_column_numeric(df, label_column, test):
  df[label_column] = np.where(test(df[label_column]), 1, 0)
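Two details of these helpers are easy to miss. In the string branch, `row[col] == row[col]` is a NaN test (NaN is the only value not equal to itself), so missing string values are simply left out of the tf.Example, while NaNs in float64 columns are still written to `float_list`. The pure-Python sketch below mirrors that dispatch and the 0/1 label conversion without TensorFlow or pandas. `feature_list_kind` and `to_binary_labels` are hypothetical helpers for illustration, and the column dtype is passed as a plain string.

```python
def feature_list_kind(column_dtype, value):
    """Return the tf.train.Feature list that df_to_examples would fill for
    `value`, or None if the value is skipped.

    `column_dtype` is a simplified, hypothetical stand-in for the pandas dtype:
    "int64", "float64", or anything else (treated as a string column).
    """
    if column_dtype == "int64":
        return "int64_list"
    if column_dtype == "float64":
        return "float_list"  # NaN is NOT filtered out in this branch
    if value == value:       # False only for NaN, so missing strings are skipped
        return "bytes_list"
    return None


def to_binary_labels(values, test):
    """Per-element analogue of np.where(test(column), 1, 0)."""
    return [1 if test(v) else 0 for v in values]


nan = float("nan")
print(nan == nan)                                     # False
print(feature_list_kind("object", nan))               # None
print(feature_list_kind("float64", nan))              # float_list
print(to_binary_labels([-1, 1, 1], lambda v: v > 0))  # [0, 1, 1]
```

Note that the original `make_label_column_numeric` passes the whole column to `test` at once (vectorized), whereas this sketch applies `test` element by element.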

#@title Load the CSV file into a pandas DataFrame and convert it for WIT
import pandas as pd

data = pd.read_csv('celeba/data_test_subset.csv')
examples = df_to_examples(data, images_path='celeba/img_test_subset_resized/')
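`df_to_examples` opens every file named in the `image_id` column and stops with an error at the first one that is missing, for instance if the unzip step did not complete. A quick pre-flight check can report all missing files at once. The sketch below reads the CSV with the standard-library `csv` module rather than pandas; `missing_image_files` is a hypothetical helper, and it assumes, as `df_to_examples` does, that `image_id` holds file names relative to the images directory.

```python
import csv
import os


def missing_image_files(csv_path, images_dir, id_column="image_id"):
    """Return the id_column values whose file does not exist in images_dir."""
    with open(csv_path, newline="") as f:
        ids = [row[id_column] for row in csv.DictReader(f)]
    return [i for i in ids if not os.path.isfile(os.path.join(images_dir, i))]


# Expected to print [] once the archive has been unzipped:
# print(missing_image_files('celeba/data_test_subset.csv',
#                           'celeba/img_test_subset_resized/'))
```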

#@title Load the Keras model
from tensorflow.keras.models import load_model

model1 = load_model('smile-model.hdf5')

#@title Define the custom predict function for WIT

# This function extracts the 'image/encoded' field, which is the reserved key for
# the feature that contains the encoded image bytes. We read these bytes into a
# BytesIO buffer and decode them back into an image using PIL.
# The model expects an array of images with float pixel values in the range
# 0.0 to 1.0 and outputs a numpy array of shape (n_samples, n_labels).
def custom_predict(examples_to_infer):
  def load_byte_img(im_bytes):
    buf = BytesIO(im_bytes)
    return np.array(Image.open(buf), dtype=np.float64) / 255.

  ims = [load_byte_img(ex.features.feature['image/encoded'].bytes_list.value[0]) 
         for ex in examples_to_infer]
  preds = model1.predict(np.array(ims))
  return preds
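As the comments above describe, `custom_predict` maps a list of examples to an `(n_samples, n_labels)` array: one row of class scores per example, in the same order as the input. The pure-Python sketch below illustrates that contract and the `/ 255.` pixel scaling with a stub in place of `model1`. `stub_model_predict` is hypothetical (it scores images by brightness, not by smiles), and the `[not smiling, smiling]` column order is an assumption for illustration only.

```python
def scale_pixels(pixel_bytes):
    """Map 8-bit pixel values (0..255) to floats in [0.0, 1.0], like `/ 255.`."""
    return [b / 255.0 for b in pixel_bytes]


def stub_model_predict(batch):
    """HYPOTHETICAL stand-in for model1.predict: scores each image by its mean
    brightness and returns [p_not_smiling, p_smiling] per image."""
    preds = []
    for pixels in batch:
        p = sum(pixels) / len(pixels)
        preds.append([1.0 - p, p])
    return preds


def custom_predict_sketch(raw_images):
    """Same contract as custom_predict: one row of class scores per input,
    in the same order as the inputs."""
    batch = [scale_pixels(raw) for raw in raw_images]
    return stub_model_predict(batch)


print(custom_predict_sketch([bytes([0, 0]), bytes([255, 255])]))
# [[1.0, 0.0], [0.0, 1.0]]
```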

Note that this model uses only the image as input, so the partial dependence plots are flat for all other features. Those features are included only for slicing and analysis.
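A partial dependence curve sweeps one feature over a grid of values, overwrites that feature in every datapoint, and averages the model's scores. If the model never reads the feature, every point on the curve is identical, which is why the plots are flat here. The sketch below demonstrates this with a hypothetical image-only model, using dictionaries as stand-ins for tf.Examples; it is a basic 1-D partial dependence computation, not WIT's own implementation.

```python
def partial_dependence(predict, examples, feature, grid):
    """For each grid value, set `feature` on every example and average the
    model's positive-class score (a basic 1-D partial dependence curve)."""
    curve = []
    for value in grid:
        edited = [dict(ex, **{feature: value}) for ex in examples]
        scores = [predict(ex) for ex in edited]
        curve.append(sum(scores) / len(scores))
    return curve


def image_only_predict(ex):
    """HYPOTHETICAL model that, like the smile CNN, reads only the image."""
    pixels = ex["image"]
    return sum(pixels) / (255.0 * len(pixels))


examples = [{"image": [0, 255], "Young": 1}, {"image": [255, 255], "Young": 0}]
print(partial_dependence(image_only_predict, examples, "Young", [0, 1]))
# [0.75, 0.75]  (flat: changing "Young" never changes the score)
```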

#@title Invoke What-If Tool for the data and model
from witwidget.notebook.visualization import WitWidget, WitConfigBuilder

num_datapoints = 250  #@param {type: "number"}
tool_height_in_px = 700  #@param {type: "number"}

# Decode the image stored in a tf.Example's 'image/encoded' bytes feature
def decode_image(ex):
  im_bytes = ex.features.feature['image/encoded'].bytes_list.value[0]
  im = Image.open(BytesIO(im_bytes))
  return im

# Custom distance function: Euclidean distance between the average (per-channel)
# color of the selected image and that of each candidate image
def image_mean_distance(ex, exs, params):
  selected_im = decode_image(ex)
  mean_color = np.mean(selected_im, axis=(0,1))
  image_distances = [np.linalg.norm(mean_color - np.mean(decode_image(e), axis=(0,1))) for e in exs]
  return image_distances

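# Set up the tool with the test examples, the custom predict function,
# and the custom image distance function
config_builder = WitConfigBuilder(examples[:num_datapoints]).set_custom_predict_fn(
  custom_predict).set_custom_distance_fn(image_mean_distance)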

wv = WitWidget(config_builder, height=tool_height_in_px)
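The sketch below restates the idea behind `image_mean_distance` in plain Python on tiny synthetic images, each represented as a list of (R, G, B) tuples, and adds a simple nearest-counterfactual search: the closest image by mean color whose predicted label differs from the selected one. It is not WIT's implementation; `mean_color`, `mean_color_distances`, and `nearest_counterfactual` are hypothetical helpers, and the labels stand in for model predictions.

```python
import math


def mean_color(pixels):
    """Average (R, G, B) over a list of (r, g, b) pixel tuples."""
    n = len(pixels)
    return tuple(sum(p[c] for p in pixels) / n for c in range(3))


def mean_color_distances(selected, candidates):
    """Euclidean distance between the mean color of `selected` and of each candidate."""
    ref = mean_color(selected)
    return [math.dist(ref, mean_color(c)) for c in candidates]


def nearest_counterfactual(selected_idx, images, labels):
    """Index of the closest image (by mean color) whose predicted label differs
    from the selected image's label, or None if no such image exists."""
    dists = mean_color_distances(images[selected_idx], images)
    best = None
    for i, (d, label) in enumerate(zip(dists, labels)):
        if label != labels[selected_idx] and (best is None or d < dists[best]):
            best = i
    return best


imgs = [[(10, 10, 10)] * 4, [(12, 10, 10)] * 4, [(20, 10, 10)] * 4]
print(nearest_counterfactual(0, imgs, ["smiling", "smiling", "not smiling"]))  # 2
```

Image 1 is closer to image 0 in color, but it shares the same prediction, so the search skips it and returns image 2.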

Exploration ideas

  • In the "Performance" tab, set the ground truth feature to "Smiling". You can set a scatter axis or binning option to "inference correct" and analyze how correctness varies across other features (e.g., make a scatter plot of Young vs. inference correct).
  • Choose an image and click "Show nearest counterfactual datapoint". This finds the example that is closest to the selected image in terms of average color but has a different prediction (if the selected image is predicted as "smiling", the counterfactual will be predicted as "not smiling").
  • Define your own custom distance function, set it by calling set_custom_distance_fn on config_builder, and explore the resulting counterfactuals. You can even load another neural network to compute distances!
  • Slice by any one of the features and analyze the confusion matrix and accuracy for each group.
  • In the "Datapoint Editor" tab, upload your own image, or download and modify one of the images, to see how it affects the inference score.
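Slicing by a feature and comparing accuracy per group, as in the first and fourth ideas above, comes down to grouping datapoints by the slice value and counting correct predictions. Here is a minimal pure-Python sketch. It assumes each datapoint is a dict holding the ground-truth `Smiling` label and a hypothetical `pred` key with the thresholded model prediction; WIT computes these metrics for you in the UI.

```python
from collections import defaultdict


def accuracy_by_slice(rows, slice_feature, label_key="Smiling", pred_key="pred"):
    """Group rows by the value of `slice_feature`; return {slice value: accuracy}."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for row in rows:
        key = row[slice_feature]
        total[key] += 1
        correct[key] += int(row[pred_key] == row[label_key])
    return {k: correct[k] / total[k] for k in total}


rows = [
    {"Young": 1, "Smiling": 1, "pred": 1},
    {"Young": 1, "Smiling": 0, "pred": 1},
    {"Young": 0, "Smiling": 0, "pred": 0},
]
print(accuracy_by_slice(rows, "Young"))  # {1: 0.5, 0: 1.0}
```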