Fairness Indicators Example Colab

Overview

In this activity, you'll use Fairness Indicators to explore the Civil Comments dataset. Fairness Indicators is a suite of tools built on top of TensorFlow Model Analysis that enables regular evaluation of fairness metrics in product pipelines. This Introductory Video provides more details and context on the real-world scenario presented here, which is one of the primary motivations for creating Fairness Indicators.

About the Dataset

In this exercise, you'll work with the Civil Comments dataset, a collection of approximately 2 million comments made public by the Civil Comments platform in 2017 for ongoing research. This effort was sponsored by Jigsaw, who have hosted competitions on Kaggle to help classify toxic comments as well as minimize unintended model bias.

Each individual text comment in the dataset has a toxicity label. Within the data, a subset of comments are labeled with a variety of identity attributes, including categories for gender, sexual orientation, religion, and race or ethnicity.
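To make this structure concrete, here is a purely illustrative (made-up) record, using annotation fields that appear later in this colab. In the raw data, each identity value is the fraction of raters who felt the comment referenced that identity:

In [0]:
# Illustrative example only -- not an actual record from the dataset.
example_comment = {
    'comment_text': 'Some public comment text...',
    'toxicity': 0.0,  # target toxicity score in [0, 1]
    # Identity annotations (present only for a labeled subset of comments):
    'male': 0.3,
    'female': 1.0,
    'transgender': 0.0,
    'heterosexual': 0.8,
    'homosexual_gay_or_lesbian': 1.0,
}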

About the Tools

TensorFlow Model Analysis is a library for evaluating both TensorFlow and non-TensorFlow machine learning models. It allows users to evaluate their models on large amounts of data in a distributed manner, computing in-graph and other metrics over different slices of data and visualizing the results in notebooks.

Fairness Indicators is built on top of TFMA. With Fairness Indicators, users will be able to:

  • Evaluate model performance, sliced across defined groups of users
  • Feel confident about results with confidence intervals and evaluations at multiple thresholds

Fairness Indicators is packaged with TensorFlow Data Validation and the What-If Tool to allow users to:

  • Evaluate the distribution of datasets
  • Dive deep into individual slices to explore root causes and opportunities for improvement with the What-If Tool

Importing

Run the following code to install the fairness_indicators library. This package contains the tools we'll be using in this exercise. You may be asked to restart the runtime, but this is not necessary.


In [0]:
!pip install --upgrade fairness-indicators

In [0]:
import os
import tempfile
import apache_beam as beam
import numpy as np
import pandas as pd
from datetime import datetime

import tensorflow_hub as hub
import tensorflow as tf
import tensorflow_model_analysis as tfma
import tensorflow_data_validation as tfdv
from tensorflow_model_analysis.addons.fairness.post_export_metrics import fairness_indicators
from tensorflow_model_analysis.addons.fairness.view import widget_view
from fairness_indicators.examples import util

from witwidget.notebook.visualization import WitConfigBuilder
from witwidget.notebook.visualization import WitWidget

Download and Understand the Data

In this exercise, you'll work with the Civil Comments dataset, approximately 2 million comments made public by the Civil Comments platform in 2017. Additionally, a subset of comments have been labeled with a variety of identity attributes, representing the identities that are mentioned in each comment.

We've hosted the dataset on Google Cloud Platform for convenience. Run the following code to download the data from GCP; the data will take about a minute to download and analyze.

TensorFlow Data Validation is one tool you can use to analyze your data. You can use it to find potential problems in your data, such as missing values and data imbalances, that can lead to fairness disparities.


In [0]:
#@title Options for Downloading data

#@markdown You can choose to download the original dataset and process it in
#@markdown the colab, which may take a few minutes. By default, we will download
#@markdown the data that we have already preprocessed for you. In the original
#@markdown dataset, the value of each identity annotation column represents the
#@markdown fraction of raters who thought the comment references that identity.
#@markdown When processing the raw data, a threshold of 0.5 is applied and the
#@markdown identities are grouped together by their categories. For example,
#@markdown if one comment has { male: 0.3, female: 1.0, transgender: 0.0,
#@markdown heterosexual: 0.8, homosexual_gay_or_lesbian: 1.0 }, after
#@markdown processing, the data will be { gender: [female],
#@markdown sexual_orientation: [heterosexual, homosexual_gay_or_lesbian] }.
download_original_data = False #@param {type:"boolean"}

if download_original_data:
  train_tf_file = tf.keras.utils.get_file('train_tf.tfrecord',
                                          'https://storage.googleapis.com/civil_comments_dataset/train_tf.tfrecord')
  validate_tf_file = tf.keras.utils.get_file('validate_tf.tfrecord',
                                             'https://storage.googleapis.com/civil_comments_dataset/validate_tf.tfrecord')

  # The identity terms will be grouped together by their categories
  # (see 'IDENTITY_COLUMNS') using a threshold of 0.5. Only the identity term
  # columns, the text column, and the label column are kept after processing.
  train_tf_file = util.convert_comments_data(train_tf_file)
  validate_tf_file = util.convert_comments_data(validate_tf_file)

else:
  train_tf_file = tf.keras.utils.get_file('train_tf_processed.tfrecord',
                                          'https://storage.googleapis.com/civil_comments_dataset/train_tf_processed.tfrecord')
  validate_tf_file = tf.keras.utils.get_file('validate_tf_processed.tfrecord',
                                             'https://storage.googleapis.com/civil_comments_dataset/validate_tf_processed.tfrecord')

In [0]:
stats = tfdv.generate_statistics_from_tfrecord(data_location=train_tf_file)
tfdv.visualize_statistics(stats)

There are several interesting things to note in this data. The first is that the toxicity label, which is what we are predicting, is unbalanced: only 8% of examples in the training set are toxic, which means that a classifier could get 92% accuracy by predicting that all comments are non-toxic.

For the fields relating to identity terms, note that out of 1.08 million training examples, only around 6.6k deal with homosexuality, and those related to bisexuality are even rarer. This suggests that performance on these slices may suffer due to a lack of training data.
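If you want to cross-check the 8% figure above yourself, the optional sketch below estimates the toxic fraction from a sample of the training TFRecord, treating a toxicity score of 0.5 or above as toxic. It assumes eager execution and only reads a sample, since iterating over all 1.08 million examples in Python is slow:

In [0]:
# Optional: estimate the fraction of toxic examples from a sample of the data.
def toxic_fraction(tf_file, num_examples=100000, threshold=0.5):
  dataset = tf.data.TFRecordDataset(filenames=[tf_file]).take(num_examples)
  total, toxic = 0, 0
  for raw_record in dataset:
    example = tf.io.parse_single_example(
        serialized=raw_record,
        features={'toxicity': tf.io.FixedLenFeature([], tf.float32)})
    total += 1
    toxic += int(example['toxicity'].numpy() >= threshold)
  return toxic / total

print(toxic_fraction(train_tf_file))  # expect a value of roughly 0.08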

Defining Constants

Here, we define the feature map that will be used to parse the data. Each example has a label, the comment text, and the identity features (sexual orientation, gender, religion, race, and disability) associated with the text.


In [0]:
BASE_DIR = tempfile.gettempdir()

TEXT_FEATURE = 'comment_text'
LABEL = 'toxicity'
FEATURE_MAP = {
    # Label:
    LABEL: tf.io.FixedLenFeature([], tf.float32),
    # Text:
    TEXT_FEATURE: tf.io.FixedLenFeature([], tf.string),

    # Identities:
    'sexual_orientation': tf.io.VarLenFeature(tf.string),
    'gender': tf.io.VarLenFeature(tf.string),
    'religion': tf.io.VarLenFeature(tf.string),
    'race': tf.io.VarLenFeature(tf.string),
    'disability': tf.io.VarLenFeature(tf.string),
}
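As an optional sanity check, you can parse a single record with this feature map and look at the parsed values. This is a minimal sketch assuming eager execution and the training file downloaded above:

In [0]:
# Optional sanity check: parse one record with FEATURE_MAP. The label and text
# are fixed-length features, while the identity features are variable-length
# (parsed as SparseTensors) and may be empty for unannotated comments.
raw_record = next(iter(tf.data.TFRecordDataset(filenames=[train_tf_file])))
parsed = tf.io.parse_single_example(serialized=raw_record, features=FEATURE_MAP)
print('toxicity:', parsed[LABEL].numpy())
print('comment :', parsed[TEXT_FEATURE].numpy()[:80])
print('gender  :', tf.sparse.to_dense(parsed['gender']).numpy())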

Train the Model

First, set up the input function to feed data into the model. Since we identified a class imbalance in our earlier TensorFlow Data Validation run, we add a weight column to each example and upweight the toxic examples to account for it. We only use identity features during the evaluation phase, as only the comments are fed into the model at training time.


In [0]:
def train_input_fn():
  def parse_function(serialized):
    parsed_example = tf.io.parse_single_example(
        serialized=serialized, features=FEATURE_MAP)
    # Adds a weight column to deal with unbalanced classes.
    parsed_example['weight'] = tf.add(parsed_example[LABEL], 0.1)
    return (parsed_example,
            parsed_example[LABEL])
  train_dataset = tf.data.TFRecordDataset(
      filenames=[train_tf_file]).map(parse_function).batch(512)
  return train_dataset

Next, create a deep neural network model, and train it on the data:


In [0]:
model_dir = os.path.join(BASE_DIR, 'train', datetime.now().strftime(
    "%Y%m%d-%H%M%S"))

embedded_text_feature_column = hub.text_embedding_column(
    key=TEXT_FEATURE,
    module_spec='https://tfhub.dev/google/nnlm-en-dim128/1')

classifier = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    weight_column='weight',
    feature_columns=[embedded_text_feature_column],
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.003),
    loss_reduction=tf.losses.Reduction.SUM,
    n_classes=2,
    model_dir=model_dir)

classifier.train(input_fn=train_input_fn, steps=1000)

Run TensorFlow Model Analysis with Fairness Indicators

Export Saved Model


In [0]:
def eval_input_receiver_fn():
  serialized_tf_example = tf.compat.v1.placeholder(
      dtype=tf.string, shape=[None], name='input_example_placeholder')

  # This *must* be a dictionary containing a single key 'examples', which
  # points to the input placeholder.
  receiver_tensors = {'examples': serialized_tf_example}

  features = tf.io.parse_example(serialized_tf_example, FEATURE_MAP)
  features['weight'] = tf.ones_like(features[LABEL])

  return tfma.export.EvalInputReceiver(
    features=features,
    receiver_tensors=receiver_tensors,
    labels=features[LABEL])

tfma_export_dir = tfma.export.export_eval_savedmodel(
  estimator=classifier,
  export_dir_base=os.path.join(BASE_DIR, 'tfma_eval_model'),
  eval_input_receiver_fn=eval_input_receiver_fn)

Compute Fairness Metrics

In the panel on the right-hand side, select the identity to compute metrics for and whether to run with confidence intervals. Depending on your configuration, this step will take 2 to 10 minutes to run.


In [0]:
#@title Fairness Indicators Computation Options
tfma_eval_result_path = os.path.join(BASE_DIR, 'tfma_eval_result')

#@markdown Modify the slice_selection for experiments on other identities.
slice_selection = 'sexual_orientation' #@param ["sexual_orientation", "gender", "religion", "race", "disability"]
#@markdown Confidence intervals can help you make better decisions regarding your data, but because they require computing multiple resamples, they are slower to compute, particularly in the colab environment, which cannot take advantage of parallelization.
compute_confidence_intervals = False #@param {type:"boolean"}

# Define slices that you want the evaluation to run on.
slice_spec = [
    tfma.slicer.SingleSliceSpec(), # Overall slice
    tfma.slicer.SingleSliceSpec(columns=[slice_selection]),
]

# Add the fairness metrics.
add_metrics_callbacks = [
  tfma.post_export_metrics.fairness_indicators(
      thresholds=[0.1, 0.3, 0.5, 0.7, 0.9],
      labels_key=LABEL
      )
]

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path=tfma_export_dir,
    add_metrics_callbacks=add_metrics_callbacks)

# Run the fairness evaluation.
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | 'ReadData' >> beam.io.ReadFromTFRecord(validate_tf_file)
      | 'ExtractEvaluateAndWriteResults' >>
       tfma.ExtractEvaluateAndWriteResults(
                 eval_shared_model=eval_shared_model,
                 slice_spec=slice_spec,
                 compute_confidence_intervals=compute_confidence_intervals,
                 output_path=tfma_eval_result_path)
  )

eval_result = tfma.load_eval_result(output_path=tfma_eval_result_path)

Render What-if Tool

In this section, you'll use the What-If Tool's interactive visual interface to explore and manipulate data at a micro-level.

On the right-hand panel in the visualization, you will see a scatter plot where each point represents one of the examples in the subset loaded into the tool. Click on one of the points. In the left-hand panel, you should now see details about this particular example. The comment text, ground truth toxicity, and applicable identities are shown. At the bottom of this left-hand panel, you see the inference results from the model you just trained.

Modify the text of the example. You can then click the "Run inference" button to see how your changes affect the model's toxicity prediction.


In [0]:
DEFAULT_MAX_EXAMPLES = 1000

# Load 100000 examples in memory. When first rendered, 
# What-If Tool should only display 1000 of these due to browser constraints.
def wit_dataset(file, num_examples=100000):
  dataset = tf.data.TFRecordDataset(
      filenames=[file]).take(num_examples)
  return [tf.train.Example.FromString(d.numpy()) for d in dataset]

wit_data = wit_dataset(train_tf_file)
config_builder = WitConfigBuilder(wit_data[:DEFAULT_MAX_EXAMPLES]).set_estimator_and_feature_spec(
    classifier, FEATURE_MAP).set_label_vocab(['non-toxicity', LABEL]).set_target_feature(LABEL)
wit = WitWidget(config_builder)

Render Fairness Indicators

Render the Fairness Indicators widget with the exported evaluation results.

Below, you will see bar charts displaying the performance of each slice of the data on the selected metrics. You can adjust the baseline comparison slice as well as the displayed threshold(s) using the drop-down menus at the top of the visualization.

The Fairness Indicators widget is integrated with the What-If Tool rendered above. If you select one slice of the data in the bar chart, the What-If Tool will update to show examples from the selected slice. When the data reloads in the What-If Tool above, try setting Color By to toxicity. This can give you a visual understanding of the toxicity balance of examples by slice.


In [0]:
event_handlers = {
    'slice-selected':
        wit.create_selection_callback(wit_data, DEFAULT_MAX_EXAMPLES)
}
widget_view.render_fairness_indicator(
    eval_result=eval_result,
    slicing_column=slice_selection,
    event_handlers=event_handlers)

With this particular dataset and task, systematically higher false positive and false negative rates for certain identities can lead to negative consequences. For example, in a content moderation system, a higher-than-overall false positive rate for a certain group can lead to those voices being silenced. Thus, it is important to regularly evaluate these types of criteria as you develop and improve models, and utilize tools such as Fairness Indicators, TFDV, and WIT to help illuminate potential problems. Once you've identified fairness issues, you can experiment with new data sources, data balancing, or other techniques to improve performance on underperforming groups.
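If you prefer to read the numbers programmatically rather than through the widget, you can also inspect the per-slice metrics in the eval_result loaded above. The exact metric key names and the nesting of the result structure depend on your TFMA version, so treat this as an illustrative sketch rather than a stable API:

In [0]:
# Sketch: print false-positive-rate metrics for each slice from the TFMA result.
# Assumes eval_result.slicing_metrics is a list of (slice_key, metrics) pairs,
# with metrics nested by output name and sub-key as in recent TFMA versions.
for slice_key, metrics in eval_result.slicing_metrics:
  print(slice_key)
  for output_name, sub_keys in metrics.items():
    for sub_key, metric_values in sub_keys.items():
      for metric_name, value in metric_values.items():
        if 'false_positive_rate' in metric_name:
          print('  %s: %s' % (metric_name, value))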

For more information and guidance on how Fairness Indicators can be used, see the Fairness Indicators documentation.