In [1]:
import xgboost
import numpy as np
import shap
import time
In [2]:
import logging

from iml.common import convert_to_instance, convert_to_model, match_instance_to_data, match_model_to_data, convert_to_instance_with_index
from iml.explanations import AdditiveExplanation
from iml.links import convert_to_link, IdentityLink
from iml.datatypes import convert_to_data, DenseData
from shap import KernelExplainer

log = logging.getLogger('shap')

class IMEExplainer(KernelExplainer):
    """ This is an implementation of the IME explanation method (aka. Shapley sampling values).

    IME was proposed in "An Efficient Explanation of Individual Classifications using Game Theory",
    Erik Štrumbelj and Igor Kononenko, JMLR 2010.
    """

    def __init__(self, model, data, **kwargs):
        # silence the warning about large background datasets
        level = log.level
        log.setLevel(logging.ERROR)
        super(IMEExplainer, self).__init__(model, data, **kwargs)
        log.setLevel(level)

    def explain(self, incoming_instance, **kwargs):
        # convert the incoming input to a standardized iml object
        instance = convert_to_instance(incoming_instance)
        match_instance_to_data(instance, self.data)

        # pick a reasonable number of samples if the user didn't specify how many they wanted
        self.nsamples = kwargs.get("nsamples", 0)
        if self.nsamples == 0:
            self.nsamples = 1000 * self.P

        # divide up the samples among the features (each feature gets an even count)
        self.nsamples_each = np.ones(self.P, dtype=np.int64) * 2 * (self.nsamples // (self.P * 2))
        for i in range((self.nsamples % (self.P * 2)) // 2):
            self.nsamples_each[i] += 2

        model_out = self.model.f(instance.x)

        # explain every feature
        phi = np.zeros(self.P)
        self.X_masked = np.zeros((self.nsamples_each.max(), self.data.data.shape[1]))
        for i in range(self.P):
            phi[i] = self.ime(i, self.model.f, instance.x, self.data.data, nsamples=self.nsamples_each[i])

        return AdditiveExplanation(self.link.f(1), self.link.f(1), phi, np.zeros(len(phi)), instance, self.link,
                                   self.model, self.data)

    def ime(self, j, f, x, X, nsamples=10):
        assert nsamples % 2 == 0, "nsamples must be divisible by 2!"

        X_masked = self.X_masked[:nsamples, :]
        inds = np.arange(X.shape[1])

        for i in range(0, nsamples // 2):
            np.random.shuffle(inds)
            pos = np.where(inds == j)[0][0]
            rind = np.random.randint(X.shape[0])

            # sample with feature j "on": features after j in the permutation come from the background row
            X_masked[i, :] = x
            X_masked[i, inds[pos+1:]] = X[rind, inds[pos+1:]]

            # paired sample with feature j "off": feature j and everything after it come from the background row
            X_masked[-(i+1), :] = x
            X_masked[-(i+1), inds[pos:]] = X[rind, inds[pos:]]

        evals = f(X_masked)
        evals_on = evals[:nsamples // 2]
        evals_off = evals[nsamples // 2:][::-1]

        return np.mean(evals_on - evals_off)
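To make the sampling scheme concrete, here is a minimal standalone sketch of the Shapley sampling estimate for a single feature j, mirroring what `ime()` does above. The helper and the toy model (`ime_single_feature`, `toy_f`, `X_bg`, `x_explain`) are illustrative names, not part of the benchmark.
In [ ]:
def ime_single_feature(j, f, x, X_bg, nsamples=500):
    # Average, over random feature orderings and random background rows, the change
    # in f when feature j is switched from the explained instance x to the background row.
    P = X_bg.shape[1]
    diffs = np.zeros(nsamples)
    for s in range(nsamples):
        order = np.random.permutation(P)
        pos = np.where(order == j)[0][0]
        b = X_bg[np.random.randint(X_bg.shape[0])]
        x_on = x.copy()
        x_on[order[pos+1:]] = b[order[pos+1:]]   # feature j kept from x
        x_off = x.copy()
        x_off[order[pos:]] = b[order[pos:]]      # feature j taken from the background row
        diffs[s] = f(x_on.reshape(1, -1))[0] - f(x_off.reshape(1, -1))[0]
    return diffs.mean()

# toy check: for an additive model the estimate should approach the centered contribution of feature j
toy_f = lambda Z: Z[:, 0] + 2 * Z[:, 1]
X_bg = np.random.randn(200, 3)
x_explain = np.array([1.0, -0.5, 2.0])
print(ime_single_feature(1, toy_f, x_explain, X_bg))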
In [94]:
from tqdm import tqdm

tree_shap_times = []
kernel_shap_times = []
ime_times = []

nreps = 10
N = 1000
X_full = np.random.randn(N, 20)
y = np.random.randn(N)

for M in range(4, 8):
    tree_shap_time = 0
    kernel_shap_time = 0
    ime_time = 0

    for k in tqdm(range(nreps)):
        X = X_full[:, :M]
        model = xgboost.train({"eta": 1}, xgboost.DMatrix(X, y), 1000)
        def f(x):
            return model.predict(xgboost.DMatrix(x))

        # time Tree SHAP on all N samples
        start = time.time()
        shap_values = shap.TreeExplainer(model).shap_values(X)
        tree_shap_time += time.time() - start

        # mean std dev of the SHAP values over the samples (drop the last column, which holds the expected value)
        shap_stddev = shap_values.std(0)[:-1].mean()

        # Kernel SHAP: grow nsamples until the estimate's std dev over 50 repeats falls below
        # 1/20th of the natural spread of the SHAP values, then scale the per-sample time to all N=1000 rows
        e = shap.KernelExplainer(f, X.mean(0).reshape(1, M))
        nsamples = 200
        for j in range(2000):
            start = time.time()
            std_dev = np.vstack([e.shap_values(X[:1, :], silent=True, nsamples=nsamples) for i in range(50)]).std(0)[:-1].mean()
            iter_time = (time.time() - start) / 50
            if std_dev < shap_stddev / 20:
                kernel_shap_time += iter_time * 1000
                break
            nsamples += int(nsamples * 0.5)

        # IME (Shapley sampling values): same convergence criterion as above
        e = IMEExplainer(f, X.mean(0).reshape(1, M))
        nsamples = 200
        for j in range(2000):
            start = time.time()
            std_dev = np.vstack([e.shap_values(X[:1, :], silent=True, nsamples=nsamples) for i in range(50)]).std(0)[:-1].mean()
            iter_time = (time.time() - start) / 50
            if std_dev < shap_stddev / 20:
                ime_time += iter_time * 1000
                break
            nsamples += int(nsamples * 0.5)

    tree_shap_times.append(tree_shap_time / nreps)
    kernel_shap_times.append(kernel_shap_time / nreps)
    ime_times.append(ime_time / nreps)
    print("TreeExplainer", tree_shap_times[-1])
    print("KernelExplainer", kernel_shap_times[-1])
    print("IMEExplainer", ime_times[-1])
In [96]:
model.predict(xgboost.DMatrix(X)).mean()
Out[96]:
In [97]:
shap.TreeExplainer(model).shap_values(X)
Out[97]:
In [68]:
e = shap.KernelExplainer(f, X.mean(0).reshape(1,M))
np.vstack([e.shap_values(X[:1,:], silent=True, nsamples=100) for i in range(50)]).std(0)[:-1].mean()
Out[68]:
In [72]:
e = shap.KernelExplainer(f, X.mean(0).reshape(1,M))
nsamples = 200
print(shap_stddev / 20)
for j in range(2000):
    print(nsamples)
    start = time.time()
    std_dev = np.vstack([e.shap_values(X[:1, :], silent=True, nsamples=nsamples) for i in range(50)]).std(0)[:-1].mean()
    iter_time = (time.time() - start) / 50
    print(std_dev)
    if std_dev < shap_stddev / 20:
        print(nsamples)
        break
    nsamples += int(nsamples * 0.2)
In [74]:
e = IMEExplainer(f, X.mean(0).reshape(1,M))
nsamples = 200
print(shap_stddev / 20)
for j in range(2000):
    print()
    print(nsamples)
    start = time.time()
    std_dev = np.vstack([e.shap_values(X[:1, :], silent=True, nsamples=nsamples) for i in range(50)]).std(0)[:-1].mean()
    print("time", (time.time() - start) / 50)
    print(std_dev)
    if std_dev < shap_stddev / 20:
        print(nsamples)
        break
    nsamples += int(nsamples * 0.2)
In [75]:
0.56939 * 1000
Out[75]:
In [36]:
np.std([IMEExplainer(f, X.mean(0).reshape(1,M)).shap_values(X[:1,:], silent=True, nsamples=1000)[0,0] for i in range(10)])
Out[36]:
In [20]:
[shap.KernelExplainer(f, X.mean(0).reshape(1,M)).shap_values(X[:1,:], silent=True, nsamples=1000)[0,0] for i in range(100)]
Out[20]:
In [7]:
def f(x):
    return model.predict(xgboost.DMatrix(x))

start = time.time()
shap_values2 = shap.KernelExplainer(f, X.mean(0).reshape(1,M)).shap_values(X)
print(time.time() - start)
In [22]:
start = time.time()
IMEExplainer(f, X.mean(0).reshape(1,M)).shap_values(X)
print(time.time() - start)