In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals  

import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import sklearn.linear_model as linear_model

import scipy
import sklearn

sns.set(color_codes=True)

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import base

from influence.binaryLogisticRegressionWithLBFGS import BinaryLogisticRegressionWithLBFGS
from influence.smooth_hinge import SmoothHinge
import influence.dataset as dataset
from influence.dataset import DataSet

np.random.seed(42)


/juicier/scr100/scr/pangwei/influence_release
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using TensorFlow backend.

In [3]:
def examine_vec(x, verbose=False):
    assert len(feature_names) == len(x)
    print('Age: %s' % x[age_var_indices])
    if verbose:
        for feature_name, val in zip(feature_names, x):
            print('%32s: %.6f' % (feature_name, val))
    
def examine_train_point(idx, verbose=False):
    print('Label: %s' % Y_train[idx])
    examine_vec(modified_X_train[idx, :], verbose)
    
def examine_test_point(idx, verbose=False):
    print('Label: %s' % Y_test[idx])
    examine_vec(X_test[idx, :], verbose)
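
A quick usage note (hedged: these helpers depend on globals defined in later cells — feature_names, age_var_indices, Y_train, Y_test, X_test, and modified_X_train — so they only work once those cells have run). A hypothetical call looks like:

examine_test_point(1742, verbose=False)  # prints the test point's label and its age one-hot block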

Read and process dataset


In [5]:
df = pd.read_csv('diabetic_data.csv')
# Use this if you are not running this in CodaLab
# df = pd.read_csv('../data/diabetic_data.csv')

In [7]:
df


Out[7]:
encounter_id patient_nbr race gender age weight admission_type_id discharge_disposition_id admission_source_id time_in_hospital ... citoglipton insulin glyburide-metformin glipizide-metformin glimepiride-pioglitazone metformin-rosiglitazone metformin-pioglitazone change diabetesMed readmitted
0 2278392 8222157 Caucasian Female [0-10) ? 6 25 1 1 ... No No No No No No No No No NO
1 149190 55629189 Caucasian Female [10-20) ? 1 1 7 3 ... No Up No No No No No Ch Yes >30
2 64410 86047875 AfricanAmerican Female [20-30) ? 1 1 7 2 ... No No No No No No No No Yes NO
3 500364 82442376 Caucasian Male [30-40) ? 1 1 7 2 ... No Up No No No No No Ch Yes NO
4 16680 42519267 Caucasian Male [40-50) ? 1 1 7 1 ... No Steady No No No No No Ch Yes NO
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
101765 443867222 175429310 Caucasian Male [70-80) ? 1 1 7 6 ... No No No No No No No No No NO

101766 rows × 50 columns


In [8]:
# Convert categorical variables into numeric ones

X = pd.DataFrame()

# Numerical variables that we can pull directly
X = df.loc[
    :, 
    [
        'time_in_hospital',
        'num_lab_procedures',
        'num_procedures',
        'num_medications',
        'number_outpatient',
        'number_emergency',
        'number_inpatient',
        'number_diagnoses'
    ]]

categorical_var_names = [
    'gender',
    'race',
    'age', 
    'discharge_disposition_id',
    'max_glu_serum',
    'A1Cresult',
    'metformin',
    'repaglinide',
    'nateglinide',
    'chlorpropamide',
    'glimepiride',
    'acetohexamide',
    'glipizide',
    'glyburide',
    'tolbutamide',
    'pioglitazone',
    'rosiglitazone',
    'acarbose',
    'miglitol',
    'troglitazone',
    'tolazamide',
    'examide',
    'citoglipton',
    'insulin',
    'glyburide-metformin',
    'glipizide-metformin',
    'glimepiride-pioglitazone',
    'metformin-rosiglitazone',
    'metformin-pioglitazone',
    'change',
    'diabetesMed'
]
for categorical_var_name in categorical_var_names:
    categorical_var = pd.Categorical(
        df.loc[:, categorical_var_name])
    
    # Use a single dummy column when the variable has only two categories
    if len(categorical_var.categories) == 2:
        drop_first = True
    else:
        drop_first = False

    dummies = pd.get_dummies(
        categorical_var, 
        prefix=categorical_var_name,
        drop_first=drop_first)
    
    X = pd.concat([X, dummies], axis=1)
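
For intuition, here is a minimal standalone sketch (using a hypothetical toy column, not a column of df) of why drop_first is turned on for two-category variables: without it, the two dummy columns are perfectly redundant (collinear).

toy = pd.Categorical(['Yes', 'No', 'Yes'])
print(pd.get_dummies(toy, prefix='diabetesMed'))                   # two mutually redundant 0/1 columns
print(pd.get_dummies(toy, prefix='diabetesMed', drop_first=True))  # a single 0/1 column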

In [9]:
### Set the Y labels
readmitted = pd.Categorical(df.readmitted)

Y = np.copy(readmitted.codes)

# Combine '>30' and 'NO' into the negative class: codes 1 ('>30') and 2 ('NO') become -1, while code 0 ('<30') becomes +1
Y[Y >= 1] = -1
Y[Y == 0] = 1

# Map to feature names
feature_names = X.columns.values

### Find indices of age features
age_var = pd.Categorical(df.loc[:, 'age'])
age_var_names = ['age_%s' % age_var_name for age_var_name in age_var.categories]    
age_var_indices = []
for age_var_name in age_var_names:
    age_var_indices.append(np.where(X.columns.values == age_var_name)[0][0])
age_var_indices = np.array(age_var_indices, dtype=int)
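
A hedged sanity check (not in the original notebook): the mapping above relies on pandas assigning category codes in lexical order, i.e. '<30' -> 0, '>30' -> 1, 'NO' -> 2, so only patients readmitted within 30 days end up labeled +1.

print(readmitted.categories)              # expected order: ['<30', '>30', 'NO']
print(np.unique(Y, return_counts=True))   # class counts for the -1 / +1 labels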

In [10]:
### Split into training and test sets. 
# For convenience, we balance the training set to have 10k positives and 10k negatives.

np.random.seed(2)
num_examples = len(Y)
assert X.shape[0] == num_examples
num_train_examples = 20000
num_train_examples_per_class = int(num_train_examples / 2)
num_test_examples = num_examples - num_train_examples
assert num_test_examples > 0

pos_idx = np.where(Y == 1)[0]
neg_idx = np.where(Y == -1)[0]
np.random.shuffle(pos_idx)
np.random.shuffle(neg_idx)
assert len(pos_idx) + len(neg_idx) == num_examples

train_idx = np.concatenate((pos_idx[:num_train_examples_per_class], neg_idx[:num_train_examples_per_class]))
test_idx = np.concatenate((pos_idx[num_train_examples_per_class:], neg_idx[num_train_examples_per_class:]))
np.random.shuffle(train_idx)
np.random.shuffle(test_idx)

X_train = np.array(X.iloc[train_idx, :], dtype=np.float32)
Y_train = Y[train_idx]

X_test = np.array(X.iloc[test_idx, :], dtype=np.float32)
Y_test = Y[test_idx]

train = DataSet(X_train, Y_train)
validation = None
test = DataSet(X_test, Y_test)
data_sets = base.Datasets(train=train, validation=validation, test=test)

lr_train = DataSet(X_train, np.array((Y_train + 1) / 2, dtype=int))
lr_validation = None
lr_test = DataSet(X_test, np.array((Y_test + 1) / 2, dtype=int))
lr_data_sets = base.Datasets(train=lr_train, validation=lr_validation, test=lr_test)

test_children_idx = np.where(X_test[:, age_var_indices[0]] == 1)[0]
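
A hedged sanity check (assumed, not part of the original cell): after the balanced split, each class should contribute exactly 10,000 training examples and everything else should land in the test set.

print(np.sum(Y_train == 1), np.sum(Y_train == -1))   # expected: 10000 10000
print(X_train.shape, X_test.shape)                   # expected: (20000, 127) and (81766, 127)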

In [11]:
# Train a model on the training set

num_classes = 2

input_dim = X_train.shape[1]
weight_decay = 0.0001
batch_size = 100
initial_learning_rate = 0.001 
keep_probs = None
decay_epochs = [1000, 10000]
max_lbfgs_iter = 1000
use_bias = True

tf.reset_default_graph()

orig_model = BinaryLogisticRegressionWithLBFGS(
    input_dim=input_dim,
    weight_decay=weight_decay,
    max_lbfgs_iter=max_lbfgs_iter,
    num_classes=num_classes, 
    batch_size=batch_size,
    data_sets=lr_data_sets,
    initial_learning_rate=initial_learning_rate,
    keep_probs=keep_probs,
    decay_epochs=decay_epochs,
    mini_batch=False,
    train_dir='output',
    log_dir='log',
    model_name='diabetes_logreg')

orig_model.train()

orig_model_preds = orig_model.sess.run(
    orig_model.preds,
    feed_dict=orig_model.all_test_feed_dict)
orig_model_preds = orig_model_preds[test_children_idx, 0]


Total number of parameters: 127
Using normal model
/u/nlp/packages/anaconda2/envs/pw/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py:93: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
LBFGS training took [1000] iter.
After training with LBFGS: 
Train loss (w reg) on all data: 0.642285
Train loss (w/o reg) on all data: 0.641162
Test loss (w/o reg) on all data: 0.65299
Train acc on all data:  0.6239
Test acc on all data:   0.681212239806
Norm of the mean of gradients: 0.000209701
Norm of the params: 4.74007

In [12]:
# Remove from the training set all but one of the young patients who weren't readmitted
mask_to_remove = (Y_train == -1) & (X_train[:, age_var_indices[0]] == 1) 
idx_to_remove = np.where(mask_to_remove)[0][:-1] # Keep 1 of them
mask_to_keep = np.array([True] * len(mask_to_remove), dtype=bool)
mask_to_keep[idx_to_remove] = False

modified_X_train = np.copy(X_train)
modified_Y_train = np.copy(Y_train)

modified_X_train = modified_X_train[mask_to_keep, :]
modified_Y_train = modified_Y_train[mask_to_keep]

print('In original data, %s/%s children were readmitted.' % (
        np.sum((Y_train == 1) & (X_train[:, age_var_indices[0]] == 1)),
        np.sum((X_train[:, age_var_indices[0]] == 1))))
print('In modified data, %s/%s children were readmitted.' % (
        np.sum((modified_Y_train == 1) & (modified_X_train[:, age_var_indices[0]] == 1)),
        np.sum((modified_X_train[:, age_var_indices[0]] == 1))))

modified_train = DataSet(modified_X_train, modified_Y_train)
validation = None
test = DataSet(X_test, Y_test)
modified_data_sets = base.Datasets(train=modified_train, validation=validation, test=test)


lr_modified_train = DataSet(modified_X_train, np.array((modified_Y_train + 1) / 2, dtype=int))
lr_modified_data_sets = base.Datasets(train=lr_modified_train, validation=lr_validation, test=lr_test)
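
A hedged check (assumed, not part of the original cell): removing all but one of the non-readmitted children should shrink the training set from 20,000 to 19,980 rows, which matches the '19980 train examples' reported by the influence computation further down.

print(len(modified_Y_train))   # expected: 19980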


In original data, 3/24 children were readmitted.
In modified data, 3/4 children were readmitted.

In [13]:
# Train a model on the modified training set
tf.reset_default_graph()

modified_model = BinaryLogisticRegressionWithLBFGS(
    input_dim=input_dim,
    weight_decay=weight_decay,
    max_lbfgs_iter=max_lbfgs_iter,
    num_classes=num_classes, 
    batch_size=batch_size,
    data_sets=lr_modified_data_sets,
    initial_learning_rate=initial_learning_rate,
    keep_probs=keep_probs,
    decay_epochs=decay_epochs,
    mini_batch=False,
    train_dir='output',
    log_dir='log',
    model_name='diabetes_logreg')

modified_model.train()

modified_model_preds = modified_model.sess.run(
    modified_model.preds,
    feed_dict=modified_model.all_test_feed_dict)
modified_model_preds = modified_model_preds[test_children_idx, 0]
modified_theta = modified_model.sess.run(modified_model.params)[0]


Total number of parameters: 127
Using normal model
LBFGS training took [1000] iter.
After training with LBFGS: 
Train loss (w reg) on all data: 0.642617
Train loss (w/o reg) on all data: 0.64149
Test loss (w/o reg) on all data: 0.653814
Train acc on all data:  0.624074074074
Test acc on all data:   0.681346770051
Norm of the mean of gradients: 0.00214773
Norm of the params: 4.74706

In [14]:
# Baseline: look at coefficient values
sns.set_style('white')
plt.figure(figsize=(8, 10))
idx = np.argsort(np.abs(modified_theta))[-20:]
sns.barplot(np.abs(modified_theta[idx]), X.columns.values[idx])


Out[14]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd88ed21a10>

In [15]:
# Find children in the test set and see how predictions change on them
true_labels = Y_test[test_children_idx]

for i in range(len(test_children_idx)):
    if (orig_model_preds[i] < 0.5) != (modified_model_preds[i] < 0.5):
        print('*** ', end='')
    print("index %s, label %s: %s vs. %s" % (
        test_children_idx[i], true_labels[i], 
        orig_model_preds[i], modified_model_preds[i]))


index 473, label -1: 0.814322 vs. 0.575827
index 647, label -1: 0.845249 vs. 0.627189
index 1037, label -1: 0.792804 vs. 0.54329
index 1599, label -1: 0.792862 vs. 0.543287
index 1611, label -1: 0.817939 vs. 0.5821
*** index 1742, label -1: 0.73701 vs. 0.465439
index 2119, label -1: 0.826754 vs. 0.59567
index 2623, label -1: 0.800421 vs. 0.554157
index 4901, label -1: 0.775293 vs. 0.518596
index 6120, label -1: 0.842892 vs. 0.623002
index 6160, label -1: 0.830722 vs. 0.602746
index 6997, label -1: 0.835345 vs. 0.610143
index 7062, label -1: 0.801486 vs. 0.555715
index 7619, label -1: 0.835218 vs. 0.609795
index 9026, label -1: 0.826323 vs. 0.595415
index 9345, label -1: 0.838477 vs. 0.615304
index 9448, label -1: 0.821315 vs. 0.589255
index 10163, label -1: 0.868876 vs. 0.668612
index 10247, label -1: 0.831044 vs. 0.603372
index 10637, label -1: 0.842602 vs. 0.624872
index 11036, label -1: 0.841051 vs. 0.619626
index 11599, label -1: 0.824198 vs. 0.591766
*** index 11822, label -1: 0.760467 vs. 0.49597
*** index 11898, label -1: 0.750312 vs. 0.485625
index 13041, label -1: 0.827868 vs. 0.597488
index 13524, label -1: 0.778674 vs. 0.525148
*** index 14424, label -1: 0.702859 vs. 0.424936
index 14809, label -1: 0.809898 vs. 0.569595
index 15022, label -1: 0.986022 vs. 0.960513
index 15514, label -1: 0.829316 vs. 0.600169
index 15559, label -1: 0.797653 vs. 0.550631
index 15828, label -1: 0.833504 vs. 0.606953
index 16028, label -1: 0.82816 vs. 0.59806
index 16111, label -1: 0.790766 vs. 0.539953
index 18073, label -1: 0.826219 vs. 0.594721
index 18238, label -1: 0.849101 vs. 0.633917
index 19151, label -1: 0.806462 vs. 0.56322
index 19181, label -1: 0.82368 vs. 0.591865
index 20173, label -1: 0.824339 vs. 0.591751
index 20226, label -1: 0.80669 vs. 0.563533
index 20466, label -1: 0.8096 vs. 0.568524
index 20611, label -1: 0.813597 vs. 0.574934
index 22432, label -1: 0.787011 vs. 0.53471
*** index 22952, label -1: 0.743058 vs. 0.472022
index 23270, label -1: 0.829624 vs. 0.601056
index 24259, label -1: 0.796638 vs. 0.549222
index 24285, label -1: 0.838402 vs. 0.61525
index 25025, label -1: 0.83577 vs. 0.610847
index 25141, label -1: 0.830199 vs. 0.601371
index 25301, label -1: 0.851899 vs. 0.640918
index 25635, label -1: 0.798233 vs. 0.550498
index 25648, label -1: 0.883789 vs. 0.700199
index 25653, label -1: 0.832594 vs. 0.606521
*** index 25767, label -1: 0.696424 vs. 0.415622
index 25812, label -1: 0.861221 vs. 0.657522
index 27036, label -1: 0.839656 vs. 0.617413
index 27986, label -1: 0.830936 vs. 0.60253
index 28287, label -1: 0.837199 vs. 0.615238
*** index 28837, label -1: 0.751692 vs. 0.484555
index 29047, label -1: 0.833606 vs. 0.606904
index 29459, label -1: 0.841358 vs. 0.620153
*** index 29491, label -1: 0.760611 vs. 0.496696
index 29570, label -1: 0.8082 vs. 0.568724
index 29817, label -1: 0.833041 vs. 0.606064
index 30118, label -1: 0.814511 vs. 0.57649
index 30645, label -1: 0.841664 vs. 0.62068
index 31932, label -1: 0.810471 vs. 0.569706
index 32323, label -1: 0.832072 vs. 0.604239
index 34126, label -1: 0.804687 vs. 0.561342
index 35815, label -1: 0.812898 vs. 0.574174
index 36759, label -1: 0.830621 vs. 0.603345
index 36836, label -1: 0.82175 vs. 0.587326
*** index 37638, label -1: 0.699852 vs. 0.4236
index 37897, label -1: 0.785756 vs. 0.532968
index 38458, label -1: 0.85535 vs. 0.647742
index 39010, label -1: 0.770526 vs. 0.511126
index 39201, label -1: 0.780723 vs. 0.525501
*** index 42915, label -1: 0.752184 vs. 0.485036
index 43233, label -1: 0.848165 vs. 0.634309
index 44240, label -1: 0.81389 vs. 0.574799
index 44287, label -1: 0.838459 vs. 0.617551
index 44430, label -1: 0.8141 vs. 0.577042
index 44841, label -1: 0.822682 vs. 0.588742
index 45342, label -1: 0.839195 vs. 0.61672
index 46154, label -1: 0.804465 vs. 0.560629
index 46643, label -1: 0.841211 vs. 0.622801
index 47973, label -1: 0.832545 vs. 0.605167
index 49361, label -1: 0.840731 vs. 0.619051
*** index 49939, label -1: 0.736308 vs. 0.467806
index 50352, label -1: 0.808324 vs. 0.56789
index 51399, label -1: 0.77725 vs. 0.521255
index 52239, label -1: 0.82414 vs. 0.591479
index 52773, label -1: 0.839711 vs. 0.618385
index 53475, label -1: 0.826785 vs. 0.59654
index 53537, label -1: 0.805919 vs. 0.562927
index 53860, label -1: 0.800119 vs. 0.55349
index 55004, label -1: 0.809498 vs. 0.568475
index 55399, label -1: 0.834572 vs. 0.608683
index 59515, label -1: 0.849027 vs. 0.635559
index 59753, label -1: 0.853827 vs. 0.644309
index 59876, label -1: 0.820856 vs. 0.586067
index 59960, label -1: 0.804424 vs. 0.561513
index 61287, label -1: 0.828513 vs. 0.59834
index 62074, label -1: 0.808428 vs. 0.566838
index 62349, label -1: 0.839508 vs. 0.616989
index 65514, label -1: 0.783918 vs. 0.529857
index 65778, label -1: 0.825501 vs. 0.596138
index 66758, label -1: 0.848573 vs. 0.632719
index 66924, label -1: 0.812415 vs. 0.573113
index 67207, label -1: 0.842219 vs. 0.623874
index 67657, label -1: 0.831328 vs. 0.603356
index 67943, label -1: 0.764107 vs. 0.502892
index 68752, label -1: 0.832811 vs. 0.605837
index 69044, label -1: 0.811568 vs. 0.571567
index 69654, label -1: 0.828752 vs. 0.599372
index 70444, label -1: 0.855926 vs. 0.647485
index 70571, label -1: 0.806115 vs. 0.563164
index 70689, label -1: 0.840199 vs. 0.618415
index 71430, label -1: 0.848436 vs. 0.632296
index 71451, label -1: 0.793693 vs. 0.544805
index 72177, label -1: 0.834813 vs. 0.60903
index 72424, label -1: 0.824159 vs. 0.592486
*** index 72808, label -1: 0.755786 vs. 0.492333
index 73832, label -1: 0.810638 vs. 0.570301
index 73938, label -1: 0.86599 vs. 0.667222
index 75548, label -1: 0.834642 vs. 0.608949
index 76070, label -1: 0.843863 vs. 0.624729
index 76250, label -1: 0.795465 vs. 0.546794
index 76627, label -1: 0.844104 vs. 0.625101
index 77527, label -1: 0.807893 vs. 0.566437
index 77959, label -1: 0.81932 vs. 0.584949
index 78098, label -1: 0.863809 vs. 0.662731
index 79909, label -1: 0.851577 vs. 0.63991
index 80509, label -1: 0.839096 vs. 0.616366
index 80828, label -1: 0.839535 vs. 0.619414
*** index 81437, label -1: 0.589415 vs. 0.30427
index 81705, label -1: 0.836527 vs. 0.612068

In [16]:
# Pick one of those children and find the training points with the most influence on it
test_idx = 1742  # note: this reuses the name test_idx from the train/test split cell above
x_test = X_test[test_idx, :]
y_test = Y_test[test_idx]
print("Test point features:")
print(x_test)
print(y_test)
print('Younger than 10? %s' % x_test[age_var_indices[0]])

influences = modified_model.get_influence_on_test_loss(
    test_indices=[1742],
    train_idx=np.arange(len(modified_model.data_sets.train.labels)))

top_k = 10
helpful_points = np.argsort(influences)[-top_k:][::-1]
unhelpful_points = np.argsort(influences)[:top_k]

influences_to_plot = []
ages_to_plot = []

for points, message in [
    (unhelpful_points, 'worse'), (helpful_points, 'better')]:
    print("Top %s training points making the loss on the test point %s:" % (top_k, message))
    for counter, idx in enumerate(points):
        print("#%5d, class=%s, age=%s, predicted_loss_diff=%.8f" % (
            idx,                 
            modified_Y_train[idx], 
            modified_X_train[idx, age_var_indices],
            influences[idx]))
        
        ages_to_plot.append(idx)  # despite its name, this stores the training index, used as the bar label below
        influences_to_plot.append(influences[idx])
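
A brief hedged reminder of what predicted_loss_diff approximates (the influence-function expression from Koh & Liang, 2017; the exact sign and 1/n scaling conventions inside get_influence_on_test_loss may differ):

    I_up,loss(z, z_test) = -grad_theta L(z_test, theta_hat)^T  H_theta_hat^{-1}  grad_theta L(z, theta_hat)

where H_theta_hat is the Hessian of the regularized training loss at the fitted parameters. The inverse Hessian-vector product is what the solver trace below computes and saves.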


Test point features:
[  1.  55.   0.  10.   0.   0.   1.   3.   1.   0.   0.   0.   0.   0.   1.
   0.   0.   1.   0.   0.   0.   0.   0.   0.   0.   0.   0.   1.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   1.   0.   0.   0.   1.
   0.   0.   1.   0.   0.   0.   1.   0.   0.   0.   1.   0.   0.   0.   1.
   0.   0.   0.   1.   0.   0.   0.   0.   1.   0.   0.   0.   1.   0.   0.
   0.   0.   1.   0.   0.   0.   1.   0.   0.   0.   1.   0.   0.   0.   1.
   0.   0.   0.   1.   0.   0.   1.   1.   1.   0.   0.   0.   0.   1.   0.
   0.   0.   0.   0.   0.   0.   1.]
-1
Younger than 10? 1.0
Norm of test gradient: 30.045
Function value: 200.873016357
Split function value: 234.828887939, -33.9559
Predicted loss diff on train_idx 5: -0.000563624766735
Function value: -3.99988651276
Split function value: 5.70028972626, -9.70018
Predicted loss diff on train_idx 5: -0.000125480640877
Function value: -4.16637945175
Split function value: 4.83428239822, -9.00066
Predicted loss diff on train_idx 5: -9.29122333889e-05
Function value: -411.377975464
Split function value: 56.5357208252, -467.914
Predicted loss diff on train_idx 5: 6.77129766485e-05
Function value: -962.767211914
Split function value: 555.848022461, -1518.62
Predicted loss diff on train_idx 5: -0.000163234867253
Function value: -971.90447998
Split function value: 576.027038574, -1547.93
Predicted loss diff on train_idx 5: -7.86889184106e-05
Function value: -989.190124512
Split function value: 627.048522949, -1616.24
Predicted loss diff on train_idx 5: -6.85309027289e-05
Function value: -1018.63018799
Split function value: 745.413879395, -1764.04
Predicted loss diff on train_idx 5: -5.11558922203e-05
Function value: -1044.78051758
Split function value: 1041.23291016, -2086.01
Predicted loss diff on train_idx 5: 0.00013631948599
Function value: -1045.06591797
Split function value: 1041.2265625, -2086.29
Predicted loss diff on train_idx 5: 2.41368382543e-06
Function value: -1045.07397461
Split function value: 1041.06396484, -2086.14
Predicted loss diff on train_idx 5: 7.1184771197e-06
Function value: -1045.08337402
Split function value: 1041.33703613, -2086.42
Predicted loss diff on train_idx 5: -5.27489770043e-06
Function value: -1045.08825684
Split function value: 1045.06164551, -2090.15
Predicted loss diff on train_idx 5: -6.71867851738e-06
Warning: Desired error not necessarily achieved due to precision loss.
         Current function value: -1045.088257
         Iterations: 13
         Function evaluations: 77
         Gradient evaluations: 78
         Hessian evaluations: 563
Saved inverse HVP to output/diabetes_logreg-cg-normal_loss-test-[1742].npz
Inverse HVP took 20.9527261257 sec
Multiplying by 19980 train examples took 37.0755660534 sec
Top 10 training points making the loss on the test point worse:
#13685, class=1, age=[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.11889609
# 9366, class=1, age=[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.11746179
#11116, class=1, age=[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.08258190
# 5825, class=-1, age=[ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.], predicted_loss_diff=-0.00292092
#13027, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00249946
# 1912, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00194527
#15190, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00190644
#13061, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00166553
# 6890, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00163078
# 1132, class=1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=-0.00154226
Top 10 training points making the loss on the test point better:
#19590, class=-1, age=[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=0.07936274
#11641, class=-1, age=[ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=0.00142359
#14549, class=-1, age=[ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.], predicted_loss_diff=0.00140335
#19107, class=-1, age=[ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.], predicted_loss_diff=0.00135654
# 4849, class=-1, age=[ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.], predicted_loss_diff=0.00132718
#17796, class=-1, age=[ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.], predicted_loss_diff=0.00130772
# 3421, class=-1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=0.00130320
# 5282, class=-1, age=[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], predicted_loss_diff=0.00129615
#18404, class=1, age=[ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.], predicted_loss_diff=0.00125875
# 8749, class=-1, age=[ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.], predicted_loss_diff=0.00118324

In [23]:
# The children kept in the modified training set are by far the most influential points
plt.figure(figsize=(15,6))
sort_idx = np.argsort(influences_to_plot)
ages_to_plot = np.array(ages_to_plot)
sns.barplot(ages_to_plot, influences_to_plot, order=ages_to_plot[sort_idx])


Out[23]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd828250110>

In [24]:
# Look at which features are causing this influence
grad_influences_wrt_input_val = modified_model.get_grad_of_influence_wrt_input(
    [19590, 13685, 9366, 11116], 
    [test_idx], 
    force_refresh=False,
    test_description=None,
    loss_type='normal_loss')    

delta = grad_influences_wrt_input_val[0, :]
plt.figure(figsize=(8, 6))
idx_to_plot = np.zeros(len(delta), dtype=bool)
idx_to_plot[:10] = True
idx_to_plot[-10:] = True
sns.barplot(np.sort(delta)[idx_to_plot], feature_names[np.argsort(delta)[idx_to_plot]], orient='h')
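
For reference, a hedged note on what this plots (as I read get_grad_of_influence_wrt_input; the library's exact conventions may differ): each row of grad_influences_wrt_input_val is the gradient of a chosen training point's influence on the test loss with respect to that training point's input features, i.e. the perturbation influence from Koh & Liang (2017),

    I_pert,loss(z, z_test) = -grad_theta L(z_test, theta_hat)^T  H_theta_hat^{-1}  grad_x grad_theta L(z, theta_hat)

so the largest-magnitude bars indicate which features of training point #19590 most affect its influence on test point 1742.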


Norm of test gradient: 30.045
Loaded inverse HVP from output/diabetes_logreg-cg-normal_loss-test-[1742].npz
Inverse HVP took 0.00241303443909 sec
Out[24]:
<matplotlib.axes._subplots.AxesSubplot at 0x7fd8804c5250>