The Fairness design pattern provides techniques for ensuring that model predictions are fair and equitable for different groups of users and scenarios. Evaluating your entire end-to-end ML workflow – from data collection to model deployment – through a fairness lens is essential to building successful, high-quality models.
In [3]:
# If you're running on Colab, you'll need to install the What-if Tool package and authenticate
# If you're on Cloud AI Platform Notebooks, you'll need to install XGBoost on the TF instance
def pip_install(module):
  !pip install {module} --quiet

try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
  pip_install('witwidget')
  from google.colab import auth
  auth.authenticate_user()
else:
  pip_install('xgboost')
|████████████████████████████████| 2.3MB 2.8MB/s
In [4]:
import pandas as pd
import xgboost as xgb
import numpy as np
import collections
import witwidget
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.utils import shuffle
from witwidget.notebook.visualization import WitWidget, WitConfigBuilder
In this section we'll:
1. Download a subset of the public HMDA mortgage dataset (in this pre-processed subset, categorical codes have already been mapped to readable values, e.g. agency code 1 becomes Office of the Comptroller of the Currency (OCC))
2. Explore the data with the What-If Tool before training a model
3. Train an XGBoost model to predict whether a mortgage application will be approved
4. Use the What-If Tool to evaluate the trained model's performance and fairness
In [ ]:
# Use a small subset of the data since the original dataset is too big for Colab (2.5GB)
# Data source: https://www.ffiec.gov/hmda/hmdaflat.htm
!gsutil cp gs://mortgage_dataset_files/mortgage-small.csv .
Copying gs://mortgage_dataset_files/mortgage-small.csv...
/ [1 files][330.8 MiB/330.8 MiB]
Operation completed over 1 objects/330.8 MiB.
In [ ]:
# Set column dtypes for Pandas
COLUMN_NAMES = collections.OrderedDict({
  'as_of_year': np.int16,
  'agency_code': 'category',
  'loan_type': 'category',
  'property_type': 'category',
  'loan_purpose': 'category',
  'occupancy': np.int8,
  'loan_amt_thousands': np.float64,
  'preapproval': 'category',
  'county_code': np.float64,
  'applicant_income_thousands': np.float64,
  'purchaser_type': 'category',
  'hoepa_status': 'category',
  'lien_status': 'category',
  'population': np.float64,
  'ffiec_median_fam_income': np.float64,
  'tract_to_msa_income_pct': np.float64,
  'num_owner_occupied_units': np.float64,
  'num_1_to_4_family_units': np.float64,
  'approved': np.int8
})
In [ ]:
# Load data into Pandas
data = pd.read_csv(
  'mortgage-small.csv',
  index_col=False,
  dtype=COLUMN_NAMES
)
data = data.dropna()
data = shuffle(data, random_state=2)
data.head()
Out[ ]:
First five rows of the dataset (data.head()), shown transposed with one column per example row:

| column | 310650 | 630129 | 715484 | 887708 | 719598 |
| --- | --- | --- | --- | --- | --- |
| as_of_year | 2016 | 2016 | 2016 | 2016 | 2016 |
| agency_code | Consumer Financial Protection Bureau (CFPB) | Department of Housing and Urban Development (HUD) | Federal Deposit Insurance Corporation (FDIC) | Office of the Comptroller of the Currency (OCC) | National Credit Union Administration (NCUA) |
| loan_type | Conventional (any loan other than FHA, VA, FSA... | Conventional (any loan other than FHA, VA, FSA... | Conventional (any loan other than FHA, VA, FSA... | Conventional (any loan other than FHA, VA, FSA... | Conventional (any loan other than FHA, VA, FSA... |
| property_type | One to four-family (other than manufactured ho... | One to four-family (other than manufactured ho... | One to four-family (other than manufactured ho... | One to four-family (other than manufactured ho... | One to four-family (other than manufactured ho... |
| loan_purpose | Refinancing | Home purchase | Refinancing | Refinancing | Refinancing |
| occupancy | 1 | 1 | 2 | 1 | 1 |
| loan_amt_thousands | 110.0 | 480.0 | 240.0 | 76.0 | 100.0 |
| preapproval | Not applicable | Not applicable | Not applicable | Not applicable | Not applicable |
| county_code | 119.0 | 33.0 | 59.0 | 65.0 | 127.0 |
| applicant_income_thousands | 55.0 | 270.0 | 96.0 | 85.0 | 70.0 |
| purchaser_type | Freddie Mac (FHLMC) | Loan was not originated or was not sold in cal... | Commercial bank, savings bank or savings assoc... | Loan was not originated or was not sold in cal... | Loan was not originated or was not sold in cal... |
| hoepa_status | Not a HOEPA loan | Not a HOEPA loan | Not a HOEPA loan | Not a HOEPA loan | Not a HOEPA loan |
| lien_status | Secured by a first lien | Secured by a first lien | Secured by a first lien | Secured by a subordinate lien | Secured by a first lien |
| population | 5930.0 | 4791.0 | 3439.0 | 3952.0 | 2422.0 |
| ffiec_median_fam_income | 64100.0 | 90300.0 | 105700.0 | 61300.0 | 46400.0 |
| tract_to_msa_income_pct | 98.81 | 144.06 | 104.62 | 90.93 | 88.37 |
| num_owner_occupied_units | 1305.0 | 1420.0 | 853.0 | 1272.0 | 650.0 |
| num_1_to_4_family_units | 1631.0 | 1450.0 | 1076.0 | 1666.0 | 1006.0 |
| approved | 1 | 0 | 1 | 1 | 1 |
The What-If Tool can be used before you've built or trained a model by passing it a dataset directly. After running the cell below, change the "Color By" dropdown to approved
to see how the data is distributed by our label class.
You can also experiment with further slicing the data. For example, try changing the "Binning | Y-Axis" dropdown to loan_type
to visualize the percentage of applications approved for each loan type in the dataset.
In [ ]:
# Show WIT before training model by passing it only a dataset
config_builder = (WitConfigBuilder(data[:1000].values.tolist(), data.columns.tolist()))
WitWidget(config_builder, height=800)
Out[ ]:
<witwidget.notebook.colab.wit.WitWidget at 0x7fc706ea2208>
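If you'd like to sanity-check what the binning view shows outside the widget, a minimal pandas sketch (using the data frame loaded above, which still contains the approved column at this point) is to look at the approval rate and example count per loan purpose and loan type:

# Sketch: approval rate and example count for each loan purpose / loan type in the sample
for col in ['loan_purpose', 'loan_type']:
  print(data.groupby(col)['approved'].agg(['mean', 'count']), '\n')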
Based on our What-If Tool analysis, we'll limit the dataset to only include loans for home purchases or refinancing since we don't have quite enough data on other loans.
In [ ]:
data = data[data['loan_purpose'].isin(['Home purchase', 'Refinancing'])]
In [ ]:
# Label preprocessing
labels = data['approved'].values
# See the distribution of approved / denied classes (0: denied, 1: approved)
print(data['approved'].value_counts())
1 623613
0 303387
Name: approved, dtype: int64
In [ ]:
data = data.drop(columns=['approved'])
For XGBoost, all model inputs need to be numeric, so we'll use the Pandas get_dummies
method to convert each categorical column into a set of boolean indicator columns.
In [ ]:
# Convert categorical columns to dummy columns
dummy_columns = list(data.dtypes[data.dtypes == 'category'].index)
data = pd.get_dummies(data, columns=dummy_columns)
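As a toy illustration of what this does (a standalone sketch, separate from the notebook's data flow), get_dummies expands a single categorical column into one indicator column per category value:

# Standalone example: each category value becomes its own indicator column
toy = pd.DataFrame({'loan_purpose': ['Home purchase', 'Refinancing', 'Home purchase']})
# Produces loan_purpose_Home purchase and loan_purpose_Refinancing columns
# (0/1 or True/False indicators, depending on your pandas version)
print(pd.get_dummies(toy, columns=['loan_purpose']))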
In [ ]:
# Preview the data
data.head()
Out[ ]:
First five rows after one-hot encoding (data.head()), shown transposed with one column per example row:

| column | 310650 | 630129 | 715484 | 887708 | 719598 |
| --- | --- | --- | --- | --- | --- |
| as_of_year | 2016 | 2016 | 2016 | 2016 | 2016 |
| occupancy | 1 | 1 | 2 | 1 | 1 |
| loan_amt_thousands | 110.0 | 480.0 | 240.0 | 76.0 | 100.0 |
| county_code | 119.0 | 33.0 | 59.0 | 65.0 | 127.0 |
| applicant_income_thousands | 55.0 | 270.0 | 96.0 | 85.0 | 70.0 |
| population | 5930.0 | 4791.0 | 3439.0 | 3952.0 | 2422.0 |
| ffiec_median_fam_income | 64100.0 | 90300.0 | 105700.0 | 61300.0 | 46400.0 |
| tract_to_msa_income_pct | 98.81 | 144.06 | 104.62 | 90.93 | 88.37 |
| num_owner_occupied_units | 1305.0 | 1420.0 | 853.0 | 1272.0 | 650.0 |
| num_1_to_4_family_units | 1631.0 | 1450.0 | 1076.0 | 1666.0 | 1006.0 |
| agency_code_Consumer Financial Protection Bureau (CFPB) | 1 | 0 | 0 | 0 | 0 |
| agency_code_Department of Housing and Urban Development (HUD) | 0 | 1 | 0 | 0 | 0 |
| agency_code_Federal Deposit Insurance Corporation (FDIC) | 0 | 0 | 1 | 0 | 0 |
| agency_code_Federal Reserve System (FRS) | 0 | 0 | 0 | 0 | 0 |
| agency_code_National Credit Union Administration (NCUA) | 0 | 0 | 0 | 0 | 1 |
| agency_code_Office of the Comptroller of the Currency (OCC) | 0 | 0 | 0 | 1 | 0 |
| loan_type_Conventional (any loan other than FHA, VA, FSA, or RHS loans) | 1 | 1 | 1 | 1 | 1 |
| loan_type_FHA-insured (Federal Housing Administration) | 0 | 0 | 0 | 0 | 0 |
| loan_type_FSA/RHS (Farm Service Agency or Rural Housing Service) | 0 | 0 | 0 | 0 | 0 |
| loan_type_VA-guaranteed (Veterans Administration) | 0 | 0 | 0 | 0 | 0 |
| property_type_Manufactured housing | 0 | 0 | 0 | 0 | 0 |
| property_type_One to four-family (other than manufactured housing) | 1 | 1 | 1 | 1 | 1 |
| loan_purpose_Home improvement | 0 | 0 | 0 | 0 | 0 |
| loan_purpose_Home purchase | 0 | 1 | 0 | 0 | 0 |
| loan_purpose_Refinancing | 1 | 0 | 1 | 1 | 1 |
| preapproval_Not applicable | 1 | 1 | 1 | 1 | 1 |
| preapproval_Preapproval was not requested | 0 | 0 | 0 | 0 | 0 |
| preapproval_Preapproval was requested | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Affiliate institution | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Commercial bank, savings bank or savings association | 0 | 0 | 1 | 0 | 0 |
| purchaser_type_Fannie Mae (FNMA) | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Farmer Mac (FAMC) | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Freddie Mac (FHLMC) | 1 | 0 | 0 | 0 | 0 |
| purchaser_type_Ginnie Mae (GNMA) | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Life insurance company, credit union, mortgage bank, or finance company | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Loan was not originated or was not sold in calendar year covered by register | 0 | 1 | 0 | 1 | 1 |
| purchaser_type_Other type of purchaser | 0 | 0 | 0 | 0 | 0 |
| purchaser_type_Private securitization | 0 | 0 | 0 | 0 | 0 |
| hoepa_status_HOEPA loan | 0 | 0 | 0 | 0 | 0 |
| hoepa_status_Not a HOEPA loan | 1 | 1 | 1 | 1 | 1 |
| lien_status_Not applicable (purchased loans) | 0 | 0 | 0 | 0 | 0 |
| lien_status_Not secured by a lien | 0 | 0 | 0 | 0 | 0 |
| lien_status_Secured by a first lien | 1 | 1 | 1 | 0 | 1 |
| lien_status_Secured by a subordinate lien | 0 | 0 | 0 | 1 | 0 |
In [ ]:
# Split the data into train / test sets
x, y = data, labels
x_train, x_test, y_train, y_test = train_test_split(x, y)
In [ ]:
x_train = x_train.astype(float)
x_test = x_test.astype(float)
In [ ]:
# Train the model, this will take a few minutes to run
bst = xgb.XGBClassifier(
  objective='binary:logistic'
)
bst.fit(x_train, y_train)
Out[ ]:
XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
colsample_bynode=1, colsample_bytree=1, gamma=0,
learning_rate=0.1, max_delta_step=0, max_depth=3,
min_child_weight=1, missing=None, n_estimators=100, n_jobs=1,
nthread=None, objective='binary:logistic', random_state=0,
reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
silent=None, subsample=1, verbosity=1)
In [ ]:
# Get predictions on the test set and print the accuracy score
y_pred = bst.predict(x_test)
acc = accuracy_score(y_test, y_pred.round())
print(acc, '\n')
0.881967637540453
In [ ]:
# Print a confusion matrix
print('Confusion matrix:')
cm = confusion_matrix(y_test, y_pred.round())
cm = cm / cm.astype(float).sum(axis=1)
print(cm)
Confusion matrix:
[[0.86590297 0.06469772]
[0.22857749 0.88971835]]
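Note that the division above broadcasts the row sums across columns, so the printed rows don't each sum to 1. If you're on scikit-learn 0.22 or newer, you can ask for a properly row-normalized matrix directly; a minimal sketch:

# Each row (true class) sums to 1: the rate of each predicted outcome per ground-truth class
cm_normalized = confusion_matrix(y_test, y_pred.round(), normalize='true')
print(cm_normalized)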
In [ ]:
# Format a subset of the test data to send to the What-if Tool for visualization
# Append ground truth label value to the test data
# This is the number of examples you want to display in the What-if Tool
num_wit_examples = 1000
test_examples = np.hstack((x_test[:num_wit_examples].values,y_test[:num_wit_examples].reshape(-1,1)))
In [ ]:
# Create a What-if Tool visualization, it may take a minute to load
# See the cell below this for exploration ideas
# This prediction adjustment function is needed as this xgboost model's
# prediction returns just a score for the positive class of the binary
# classification, whereas the What-If Tool expects a list of scores for each
# class (in this case, both the negative class and the positive class).
def custom_fn(examples):
  df = pd.DataFrame(examples, columns=x_train.columns.tolist())
  preds = bst.predict_proba(df)
  return preds

config_builder = (WitConfigBuilder(test_examples.tolist(), data.columns.tolist() + ['mortgage_status'])
  .set_custom_predict_fn(custom_fn)
  .set_target_feature('mortgage_status')
  .set_label_vocab(['denied', 'approved']))
WitWidget(config_builder, height=800)
Out[ ]:
<witwidget.notebook.colab.wit.WitWidget at 0x7fc706f187b8>
Here are a few ways to explore the model's results in the What-If Tool:
* Individual data points: the default view shows all datapoints from the test set, colored by their ground truth label (approved or denied).
* Binning data: create separate graphs for individual features using the binning dropdowns.
* Exploring overall performance: click on the "Performance & Fairness" tab to view overall performance statistics on the model's results on the provided dataset, including confusion matrices, PR curves, and ROC curves. The sketch after this list shows how to compute similar per-slice numbers in code.
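As a rough sketch of the kind of per-slice comparison the Performance & Fairness tab makes (assuming x_test, y_test, and y_pred from the cells above, and slicing on the one-hot loan_purpose columns), you can compute the positive-prediction rate and true positive rate per slice; equal opportunity asks whether the true positive rate is similar across slices:

# Per-slice metrics (sketch): positive-prediction rate (demographic-parity style view)
# and true positive rate (equal-opportunity style view) for each loan_purpose slice
for col in ['loan_purpose_Home purchase', 'loan_purpose_Refinancing']:
  mask = (x_test[col] == 1).values
  y_true_slice, y_pred_slice = y_test[mask], y_pred[mask]
  pos_rate = y_pred_slice.mean()
  tpr = y_pred_slice[y_true_slice == 1].mean()
  print(f'{col}: n={mask.sum()}, positive rate={pos_rate:.3f}, TPR={tpr:.3f}')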
Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License
Content source: GoogleCloudPlatform/ml-design-patterns