What-If Tool Image Smile Detection

This demo shows how to use the What-If Tool with an image recognition model. The task is to predict whether a person is smiling. We provide a CNN trained on a subset of the CelebA dataset and visualize its results on a separate test subset.

Copyright 2019 Google LLC. SPDX-License-Identifier: Apache-2.0


In [0]:
#@title Install the What-If Tool widget if running in colab {display-mode: "form"}

try:
  import google.colab
  !pip install --upgrade witwidget -q
except Exception:
  pass

In [0]:
#@title Download the pretrained keras model and a subset of CelebA images

!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/smile-colab-model.hdf5 -o ./smile-model.hdf5
!curl -L https://storage.googleapis.com/what-if-tool-resources/smile-demo/test_subset.zip -o ./test_subset.zip
!unzip -qq -o test_subset.zip

In [0]:
#@title Define helper functions for dataset conversion from csv to tf.Examples
import numpy as np
import tensorflow as tf
import os
from PIL import Image
from io import BytesIO

# Converts a dataframe into a list of tf.Example protos.
# If images_path is specified, it assumes the dataframe has a special column
# "image_id" such that the path "images_path/image_id" points to an image
# file. Given this structure, the function loads each image and stores it as
# a PNG byte list in the tf.Example so that it can be shown in WIT. Note that
# 'image/encoded' is a reserved field in WIT for encoded image features.
def df_to_examples(df, columns=None, images_path=''):
  examples = []
  if columns is None:
    columns = df.columns.values.tolist()
  for index, row in df.iterrows():
    example = tf.train.Example()
    for col in columns:
      if df[col].dtype == np.int64:
        example.features.feature[col].int64_list.value.append(int(row[col]))
      elif df[col].dtype == np.float64:
        example.features.feature[col].float_list.value.append(row[col])
      elif row[col] == row[col]:  # NaN != NaN, so this skips missing values.
        example.features.feature[col].bytes_list.value.append(row[col].encode('utf-8'))
    if images_path:
      fname = row['image_id']
      with open(os.path.join(images_path, fname), 'rb') as f:
        im = Image.open(f)
        buf = BytesIO()
        im.save(buf, format='PNG')
        im_bytes = buf.getvalue()
        example.features.feature['image/encoded'].bytes_list.value.append(im_bytes)
    examples.append(example)
  return examples

# Converts a dataframe column into a column of 0's and 1's based on the
# provided test function. Used to force label columns to be numeric for
# binary classification.
def make_label_column_numeric(df, label_column, test):
  df[label_column] = np.where(test(df[label_column]), 1, 0)
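
make_label_column_numeric is not needed for the pre-processed csv used in this demo, but as a hypothetical usage sketch: the raw CelebA annotations store attributes as -1/1, and a column in that format could be binarized like this.

In [0]:
#@title Hypothetical usage of make_label_column_numeric
import pandas as pd

# Toy dataframe with a -1/1 attribute column (hypothetical example data).
toy_df = pd.DataFrame({'Smiling': [-1, 1, 1, -1]})
make_label_column_numeric(toy_df, 'Smiling', lambda col: col > 0)
print(toy_df['Smiling'].tolist())  # -> [0, 1, 1, 0]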

In [0]:
#@title Load the csv file into a pandas dataframe and process it for WIT
import pandas as pd

data = pd.read_csv('celeba/data_test_subset.csv')
examples = df_to_examples(data, images_path='celeba/img_test_subset_resized/')
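
As a quick sanity check (not part of the original demo), you can confirm that the conversion produced one tf.Example per csv row and that each example carries the encoded image bytes WIT needs:

In [0]:
#@title Optional: sanity-check the converted examples
print('rows in csv :', len(data))
print('tf.Examples :', len(examples))
# Every example should have the reserved 'image/encoded' feature populated.
n_bytes = len(examples[0].features.feature['image/encoded'].bytes_list.value[0])
print('first image is', n_bytes, 'bytes of PNG data')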

In [0]:
#@title Load the keras model
from tensorflow.keras.models import load_model

model1 = load_model('smile-model.hdf5')
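
To verify the model loaded correctly, you can optionally print its architecture. model.summary() and the shape attributes below are standard Keras; the exact layer stack shown depends on the downloaded file:

In [0]:
#@title Optional: inspect the loaded model
# Prints the layer stack and parameter counts of the downloaded CNN.
model1.summary()
print('input shape :', model1.input_shape)
print('output shape:', model1.output_shape)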

In [0]:
#@title Define the custom predict function for WIT

# This function extracts the 'image/encoded' field, which is the reserved key
# for the feature containing the encoded image byte list. We read this feature
# into BytesIO and decode it back to an image using PIL.
# The model expects an array of images as floats in the range 0.0 to 1.0 and
# outputs a numpy array of shape (n_samples, n_labels).
def custom_predict(examples_to_infer):
  def load_byte_img(im_bytes):
    buf = BytesIO(im_bytes)
    return np.array(Image.open(buf), dtype=np.float64) / 255.

  ims = [load_byte_img(ex.features.feature['image/encoded'].bytes_list.value[0]) 
         for ex in examples_to_infer]
  preds = model1.predict(np.array(ims))
  return preds
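
Before handing custom_predict to WIT, it can help to smoke-test it directly on a few examples. Per the comment above, the result should be a numpy array of shape (n_samples, n_labels); the exact number of labels depends on the downloaded model:

In [0]:
#@title Optional: smoke-test the custom predict function
sample_preds = custom_predict(examples[:3])
print('prediction array shape:', sample_preds.shape)
print(sample_preds)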

Note that this particular model uses only the image as input, so the partial dependence plots are flat for every other feature. Those features are provided purely for slicing and analysis.


In [0]:
#@title Invoke What-If Tool for the data and model {display-mode: "form"}
from witwidget.notebook.visualization import WitWidget, WitConfigBuilder

num_datapoints = 250  #@param {type: "number"}
tool_height_in_px = 700  #@param {type: "number"}

# Decode an image from tf.example bytestring
def decode_image(ex):
  im_bytes = ex.features.feature['image/encoded'].bytes_list.value[0]
  im = Image.open(BytesIO(im_bytes))
  return im

# Define the custom distance function that compares the average color of images
def image_mean_distance(ex, exs, params):
  selected_im = decode_image(ex)
  mean_color = np.mean(selected_im, axis=(0,1))
  image_distances = [np.linalg.norm(mean_color - np.mean(decode_image(e), axis=(0,1))) for e in exs]
  return image_distances

# Set up the tool with the test examples and the trained classifier
config_builder = WitConfigBuilder(examples[:num_datapoints]).set_custom_predict_fn(
    custom_predict).set_custom_distance_fn(image_mean_distance)

wv = WitWidget(config_builder, height=tool_height_in_px)

Exploration ideas

  • In the "Performance" tab, set the ground truth feature to "Smiling". You can set a scatter axis or binning option to be inference correct and analyze how it varies across other features (i.e. you can make a scatter plot of Young vs inference correct).
  • Choose an image and click "Show nearest counterfactual datapoint". This finds the example that is closest to the selected image in terms of average color value but receives a different prediction (if the selected image is predicted "smiling", the counterfactual will be predicted "not smiling").
  • Define your own custom distance function, set it by calling set_custom_distance_fn on the config_builder, and explore the resulting counterfactuals (a sketch follows this list). You can even load another neural network to compute distances!
  • You can slice by any one of the features and analyze the confusion matrix and accuracy for each group.
  • In the "Datapoint Editor" tab, you can upload your own image or download and modify one of the images to see how it affects the inference score.