In [2]:
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import tensorflow as tf
In [ ]:
!gsutil cp gs://ml-design-patterns/auto-mpg.csv .
data = pd.read_csv('auto-mpg.csv', na_values='?')
data = data.dropna()
In [10]:
# Remove the 'car name' column from the frame.
data = data.drop('car name', axis=1)
In [11]:
# One-hot encode the 'origin' column.
# dtype=float keeps every column numeric: pandas >= 2.0 emits bool dummies by
# default, and a frame that mixes bool with float/int columns turns into an
# object array under `.values`, which the TensorFlow and SHAP cells further
# down cannot consume.
data = pd.get_dummies(data, columns=['origin'], dtype=float)
In [12]:
# Preview the first five rows of the frame after encoding.
data.head(n=5)
Out[12]:
In [13]:
# The regression target is the 'mpg' column.
labels = data.loc[:, 'mpg']
# Features: everything except the target and 'cylinders'.
# NOTE(review): the notebook does not say why 'cylinders' is excluded - confirm.
data = data.drop(['mpg', 'cylinders'], axis=1)
In [14]:
x, y = data, labels
# A fixed random_state makes the shuffle/split repeatable, so the learned
# coefficients and the SHAP attributions computed later do not change on
# every Restart & Run All.
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=42)
Train a scikit-learn linear regression model on the data and display the learned coefficients.
In [15]:
# Fit a linear regression on the training split; fit() returns the fitted estimator.
model = LinearRegression().fit(
    X=x_train,
    y=y_train,
)
In [17]:
# One learned weight per feature column, labelled with the feature names.
coefficients = model.coef_
coefdf = pd.DataFrame(
    {'Learned coefficients': coefficients},
    index=list(data.columns),
)
In [18]:
# Display the table of learned coefficients (one row per feature).
coefdf
Out[18]:
Using the same dataset, we'll train a deep neural net with TensorFlow and use the SHAP library to get feature attributions.
In [19]:
# Seed TensorFlow's global RNG so the randomly initialised layer weights are
# the same on every run of the notebook.
tf.random.set_seed(42)

# Input width = number of feature columns in the training frame.
n_features = x_train.shape[1]

# Two hidden ReLU layers of 16 units each, then a single output unit with no
# activation specified.
# NOTE: this rebinds `model`, replacing the LinearRegression fitted earlier.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=[n_features]),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1),
])

# Train on mean squared error; also report MAE and MSE as metrics.
optimizer = tf.keras.optimizers.RMSprop(0.001)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
In [ ]:
# Train the network for 1000 passes over the training split.
model.fit(x=x_train, y=y_train, epochs=1000)
In [23]:
!pip install shap
In [24]:
import shap
In [25]:
# Build the explainer using the first 200 training rows as the data argument.
background = x_train[:200]
explainer = shap.DeepExplainer(model, background)

# Compute feature attributions for the first 10 examples in the test set.
shap_values = explainer.shap_values(x_test.to_numpy()[:10])
In [26]:
# Show the feature attributions for the first example in our test set:
# the first [0] selects the first entry of `shap_values`, the second [0]
# selects the first element along that entry's leading axis.
shap_values[0][0]
Out[26]:
In [30]:
# This is the baseline value shap is using; its first element is what the
# force-plot cell later passes as the base value.
# `.numpy()` is called on the explainer's expected_value to display it.
explainer.expected_value.numpy()
Out[30]:
In [31]:
# Load SHAP's JavaScript so the interactive plot can render in the notebook.
shap.initjs()

# Pieces of the force plot for the first test example.
base_value = explainer.expected_value[0].numpy()
first_attributions = shap_values[0][0, :]
first_example = x_test.iloc[0, :]

# Kept as the last expression so the notebook displays the plot.
shap.force_plot(base_value, first_attributions, first_example)
Out[31]:
In [32]:
# Summary plot of the computed attributions, labelled with the feature
# column names and a single 'MPG' class name.
shap.summary_plot(
    shap_values,
    feature_names=list(data.columns),
    class_names=['MPG'],
)
This part is coming soon :) In the meantime, see the docs.
Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License