Cross Validation

This script runs 3-fold cross-validation, repeated 10 times, and plots the validation result as several boxplots. It is the same procedure as performed in the TestData-script, but repeated many times. The test-data is never used in this script.


In [51]:
# import stuffs
%matplotlib inline
import numpy as np
import pandas as pd
from pyplotthemes import get_savefig, ClassicTheme
#plt.latex = True
# Set to PLOS One requirements
pltkwargs = {'text.usetex': True,
            'text.latex.unicode': True,
            #'font.family': 'sans-serif',
            'font.sans-serif': 'Arial',
            'font.serif': 'Times'}
plt = ClassicTheme(**pltkwargs)
savefig = get_savefig('../fig', extensions=['pdf', 'eps'])

Load the data

Note that only training data is read in.


In [2]:
from datasets import get_colon, get_nwtco, get_flchain, get_pbc, get_lung

# Training data
datasets = {}

# Add the data sets
for name, getter in zip(["pbc", "lung", "colon", "nwtco", "flchain"],
                         [get_pbc, get_lung, get_colon, get_nwtco, get_flchain]):
    trn = getter(norm_in=True, norm_out=False, training=True)
    datasets[name] = trn
    
    cens = (trn.iloc[:, 1] == 0)
    censcount = np.sum(cens) / trn.shape[0]
    print(name, "censed:", censcount)
    

# Crossval variables
cross_n = 10
cross_k = 3


pbc censed: 0.608163265306
lung censed: 0.27485380117
colon censed: 0.495689655172
nwtco censed: 0.857994041708
flchain censed: 0.724716245977

Prepare the models

Create objects for the models and set their parameters. Also create a helper method for ANN.


In [3]:
from pysurvival.rpart import RPartModel

# Save predictions for later
rpart_val_preds = {}
for dname in datasets.keys():
    # N lists : Each k items later
    rpart_val_preds[dname] = [[] for _ in range(cross_n)]
    

# Default values in rpart actually
rpart_kwargs = dict(highlim=0.15,
                    lowlim=0.15,
                    minsplit=20,
                    minbucket=None,
                    xval=3,
                    cp=0.01)

# Networks depend on rparts training values
rpart = RPartModel(**rpart_kwargs)

In [4]:
from classcox import CoxClasser

# Save predictions for later
cox_val_preds = {}
for dname in datasets.keys():
    # N lists : Each k items later
    cox_val_preds[dname] = [[] for _ in range(cross_n)]

In [5]:
import ann
from classensemble import ClassEnsemble
from helpers import get_net


# Save predictions for later
ann_val_preds = {}
for dname in datasets.keys():
    # N lists : Each k items later
    ann_val_preds[dname] = [[] for _ in range(cross_n)]
    

def get_ensemble(incols, high_size, low_size):
    '''
    Creates an ensemble of neural networks, with 34 networks in total.
    17 are configured for high-risk groups, and 17 are configured for low-risk
    groups, to avoid ties between them.
    '''
    hnets = []
    lnets = []
    
    netcount = 34
    for i in range(netcount):
        if i % 2:
            n = get_net(incols, high_size, ann.geneticnetwork.FITNESS_SURV_KAPLAN_MIN)
            hnets.append(n)
        else:
            n = get_net(incols, low_size, ann.geneticnetwork.FITNESS_SURV_KAPLAN_MAX)
            lnets.append(n)
            
    return ClassEnsemble(hnets, lnets)

Repeated Cross-validation

Here the actual cross-validation is performed. In each NxK training phase, Rpart's training group sizes are given to Cox and ANN, just like in the testing script, to make the curves comparable.

Due to how the notebook works, and how this script is written, it is possible to interrupt the execution prematurely and proceed with the later steps, as long as some results have been saved (and one makes sure that no model has more results than any other).


In [6]:
from lifelines.estimation import KaplanMeierFitter, median_survival_times

# Save each random permutation
data_permutations = {}
for dname in datasets.keys():
    # N lists : Each k items later
    data_permutations[dname] = [[] for _ in range(cross_n)]


# Repeat cross validation
for rep in range(cross_n):
    print("n =", rep)
    # For each data set
    for dname, _df in datasets.items():
        n, d = _df.shape
        k = cross_k
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        testing_columns = _df.columns - [duration_col, event_col]
        
        # Random divisions, stratified on events
        perm = np.random.permutation(_df.index)
        # Save it for later
        data_permutations[dname][rep].append(perm)
        
        df = _df.reindex(perm).sort(event_col)

        assignments = np.array((n // k + 1) * list(range(1, k + 1)))
        assignments = assignments[:n]

        # For each division
        for i in range(1, k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]

            #T_actual = testing_data[duration_col].values
            #E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            # Train rpart first
            rpart.fit(training_data, duration_col, event_col)
            
            rpart_val_preds[dname][rep].append(rpart.predict_classes(testing_data))
            
            # Use Rpart group sizes on training data for Cox and ANN below
            total = training_data.shape[0]
            high_size = rpart.high_size
            low_size = rpart.low_size
        
            # Cox uses quartile formulation 0 - 100
            cox = CoxClasser(100 * (1 - high_size / total),
                     100 * low_size / total)
            cox.fit(training_data, duration_col, event_col)
            cox_val_preds[dname][rep].append(cox.predict_classes(testing_data))
            
            # ANN
            net = get_ensemble(len(testing_columns), high_size, low_size)
            net.fit(training_data, duration_col, event_col) 
            ann_val_preds[dname][rep].append(net.predict_classes(testing_data))
        


def score(T_actual, labels, E_actual):
    '''
    Return a score based on grouping. Each score is actually several values:
    End survival rate, median survival time, member count, last event
    '''
    scores = []
    labels = labels.ravel()
    for g in ['high', 'mid', 'low']:
        members = labels == g
        
        if np.sum(members) > 0:
            kmf = KaplanMeierFitter()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label='{}'.format(g))
            
            # Last survival time
            if np.sum(E_actual[members]) > 0:
                lasttime = np.max(T_actual[members][E_actual[members] == 1])
            else:
                lasttime = np.nan
        
            # End survival rate, median survival time, member count, last event
            subscore = (kmf.survival_function_.iloc[-1, 0],
                        median_survival_times(kmf.survival_function_),
                        np.sum(members),
                        lasttime)
        else:
            # Rpart might fail in this respect
            subscore = (np.nan, np.nan, np.sum(members), np.nan)
            
        scores.append(subscore)
    return scores


n = 0
n = 1
n = 2
n = 3
n = 4
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-6-9c0d2815ce1d> in <module>()
     57             # ANN
     58             net = get_ensemble(len(testing_columns), high_size, low_size)
---> 59             net.fit(training_data, duration_col, event_col)
     60             ann_val_preds[dname][rep].append(net.predict_classes(testing_data))
     61 

/home/gibson/jonask/Articles/riskgroup_maximization/src/classensemble.py in fit(self, df, duration_col, event_col)
    129 
    130         self.learn(df[self.x_cols].values,
--> 131                    df[[duration_col, event_col]].values)
    132 
    133     def predict_classes(self, df):

/home/gibson/jonask/Articles/riskgroup_maximization/src/classensemble.py in learn(self, datain, dataout, limit)
    114             # Create new data using bagging. Combine the data into one array
    115             bag = ordered_bagging(datain.shape[0], count=limit)
--> 116             net.learn(datain[asc][bag], dataout[asc][bag])
    117 
    118     def fit(self, df, duration_col, event_col):

/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/ann/__init__.py in learn(self, trninputs, trnoutputs)
    112           or self.fitness_function == _gennetwork.FITNESS_TARONEWARE_HIGHLOW):
    113             asc = trnoutputs[:, 0].argsort()
--> 114             super(geneticnetwork, self).learn(trninputs[asc], trnoutputs[asc])
    115         else:
    116             # Pass data through unchanged

KeyboardInterrupt: 

In [26]:
# Only keep valid first 4 entries here
for dname in datasets.keys():
    #cox_val_preds[dname][4].clear()
    print(dname, len(data_permutations[dname][4]))


flchain 0
nwtco 0
lung 0
pbc 0
colon 0

Save cross-validation predictions to file

Can be quite time-consuming to re-run so save validation predictions to file


In [27]:
from pickle import dump
import os

# These are expensive, so save to disk
crossval_results = {}
crossval_results['data_permutations'] = data_permutations
crossval_results['rpart_val_preds'] = rpart_val_preds
crossval_results['cox_val_preds'] = cox_val_preds
crossval_results['ann_val_preds'] = ann_val_preds

path = "crossval-7to10.pickle"
if os.path.exists(path):
    raise ValueError("File exists. Should not be overwritten")

with open(path, 'wb') as F:
    dump(crossval_results, F)

Read saved predctions into memory

Read the saved preditions back (useful for resuming steps below without re-training).


In [42]:
from pickle import load
import os

# Read them back
path = "crossval-7to10.pickle"
with open(path, 'rb') as F:
    crossval_results = load(F)
    
data_permutations = crossval_results['data_permutations']
rpart_val_preds = crossval_results['rpart_val_preds']
cox_val_preds = crossval_results['cox_val_preds']
ann_val_preds = crossval_results['ann_val_preds']

In [43]:
# Read them back
for i in range(1, 7):
    j = i + 3
    path = "crossval-{}.pickle".format(i)
    
    with open(path, 'rb') as F:
        crossval_results = load(F)
        
    for dname in datasets.keys():
        data_permutations[dname][j] = crossval_results['data_permutations'][dname][0]
        rpart_val_preds[dname][j] = crossval_results['rpart_val_preds'][dname][0]
        cox_val_preds[dname][j] = crossval_results['cox_val_preds'][dname][0]
        ann_val_preds[dname][j] = crossval_results['ann_val_preds'][dname][0]

Plot results

Median survival for each group


In [72]:
# Median survival
for dname, _df in datasets.items():
    # nwtco never goes below 0.5
    if dname == 'nwtco':
        continue
    
    labels = [n + " high" for n in ['Rpart', 'Cox', 'ANN']]
    
    rpart_box = []
    cox_box = []
    ann_box = []
    
    # For lung, we can plot low-risk median survival
    lunglabels = [n + " low" for n in ['Rpart', 'Cox', 'ANN']]
    lunglabels.extend([n + " high" for n in ['Rpart', 'Cox', 'ANN']])
    rpart_low_box = []
    cox_low_box = []
    ann_low_box = []
    
    
    n = _df.shape[0]
        
    # For each repetition
    for rep in range(cross_n):
        # Should not have been a list...
        perm = data_permutations[dname][rep][0]
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        
        df = _df.reindex(perm).sort(event_col)
        
        assignments = np.array((n // cross_k + 1) * list(range(1, cross_k + 1)))
        assignments = assignments[:n]
            
        # For each division
        for i in range(1, cross_k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]
            
            T_actual = testing_data[duration_col].values
            E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            # Calculate what of interest
            kmf = KaplanMeierFitter()
            
            ## Rpart
            rpart_preds = rpart_val_preds[dname][rep][i-1]
            members = (rpart_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            rpart_box.append(kmf.median_)
            
            if dname == 'lung':
                members = (rpart_preds == 'low').values.ravel()
            
                kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
                rpart_low_box.append(kmf.median_)
            
            ## Cox
            cox_preds = cox_val_preds[dname][rep][i-1]
            members = (cox_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            cox_box.append(kmf.median_)
            
            if dname == 'lung':
                members = (cox_preds == 'low').values.ravel()
            
                kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
                cox_low_box.append(kmf.median_)
            
            ## ANN
            ann_preds = ann_val_preds[dname][rep][i-1]
            members = (ann_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            ann_box.append(kmf.median_)
            
            if dname == 'lung':
                members = (ann_preds == 'low').values.ravel()
            
                kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
                ann_low_box.append(kmf.median_)
                
    plt.figure()
    #fig = plt.figure(figsize=(2.5, 2.0))
    plt.title(dname)
    if dname == 'lung':
        plt.boxplot([rpart_low_box, cox_low_box, ann_low_box, rpart_box, cox_box, ann_box], 
                    labels=lunglabels, vert=False, colors=plt.colors[:3])
    else:
        plt.boxplot([rpart_box, cox_box, ann_box], 
                    labels=labels, vert=False, colors=plt.colors[:3])
    plt.xlabel("Median survival time")
    savefig("crossval-{}-median".format(dname))


/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/matplotlib/font_manager.py:1282: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))

Median survival difference for lung

To see how the survival curves crossed for this data set, the difference between median survival times for high and low risk groups are plotted here.

If the curves crossed early, the median survival time difference will be negative.

If they did not cross, we'd expect low-risk group to have a higher median survival time than the high-risk group, so the difference should be positive.

The closer the the value is to zero, the closer the curves are to each other.


In [74]:
# Median survival difference, only applicable to lung
for dname, _df in datasets.items():
    if dname != 'lung':
        continue
    
    rpart_box = []
    cox_box = []
    ann_box = []
    
    labels = [n for n in ['Rpart', 'Cox', 'ANN']]
    
    n = _df.shape[0]
        
    # For each repetition
    for rep in range(cross_n):
        # Should not have been a list...
        perm = data_permutations[dname][rep][0]
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        
        df = _df.reindex(perm).sort(event_col)
        
        assignments = np.array((n // cross_k + 1) * list(range(1, cross_k + 1)))
        assignments = assignments[:n]
            
        # For each division
        for i in range(1, cross_k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]
            
            T_actual = testing_data[duration_col].values
            E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            # Calculate what of interest
            kmf = KaplanMeierFitter()
            
            ## Rpart
            rpart_preds = rpart_val_preds[dname][rep][i-1]
            members = (rpart_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            high_median = kmf.median_
            
            members = (rpart_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            low_median = kmf.median_
            
            rpart_box.append(low_median - high_median)
            
            ## Cox
            cox_preds = cox_val_preds[dname][rep][i-1]
            members = (cox_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            high_median = kmf.median_
            
            members = (cox_preds == 'low').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            low_median = kmf.median_
            
            cox_box.append(low_median - high_median)
            
            ## ANN
            ann_preds = ann_val_preds[dname][rep][i-1]
            members = (ann_preds == 'high').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            high_median = kmf.median_
            
            members = (ann_preds == 'low').values.ravel()
            
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            
            low_median = kmf.median_
            
            ann_box.append(low_median - high_median)
                
    plt.figure()
    #fig = plt.figure(figsize=(2.5, 2.0))
    plt.title(dname)
    plt.boxplot([rpart_box, cox_box, ann_box], 
                labels=labels, vert=False, colors=plt.colors[:3])
    plt.xlabel("Median survival time difference")
    savefig("crossval-{}-mediandiff".format(dname))


/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/matplotlib/font_manager.py:1282: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))

Group sizes for each group

Plot group sizes seen on the validation data, to make sure other survival characteristics are comparable.


In [56]:
# Group sizes
for dname, _df in datasets.items():
    labels = [n + " low" for n in ['Rpart', 'Cox', 'ANN']]
    labels.extend([n + " high" for n in ['Rpart', 'Cox', 'ANN']])
    
    rpart_high_box, rpart_low_box = [], []
    cox_high_box, cox_low_box = [], []
    ann_high_box, ann_low_box = [], []
    
    n = _df.shape[0]
        
    # For each repetition
    for rep in range(cross_n):
        # Should not have been a list...
        perm = data_permutations[dname][rep][0]
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        
        df = _df.reindex(perm).sort(event_col)
        
        assignments = np.array((n // cross_k + 1) * list(range(1, cross_k + 1)))
        assignments = assignments[:n]
            
        # For each division
        for i in range(1, cross_k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]
            
            T_actual = testing_data[duration_col].values
            E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            ## Rpart
            rpart_preds = rpart_val_preds[dname][rep][i-1]
            
            rpart_high_box.append(np.sum(rpart_preds == 'high'))
            rpart_low_box.append(np.sum(rpart_preds == 'low'))
            
            ## Cox
            cox_preds = cox_val_preds[dname][rep][i-1]
            
            cox_high_box.append(np.sum(cox_preds == 'high'))
            cox_low_box.append(np.sum(cox_preds == 'low'))
            
            ## ANN
            ann_preds = ann_val_preds[dname][rep][i-1]
            
            ann_high_box.append(np.sum(ann_preds == 'high'))
            ann_low_box.append(np.sum(ann_preds == 'low'))
            
    plt.figure()
    plt.title(dname)
    # Labels are ordered low-high
    plt.boxplot([rpart_low_box, cox_low_box, ann_low_box, rpart_high_box, cox_high_box, ann_high_box], 
                labels=labels, vert=False, colors=plt.colors[:3])
    plt.xlabel("Group sizes")
    savefig("crossval-{}-groupsize".format(dname))


/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/matplotlib/font_manager.py:1282: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))

End survival rate for each group


In [57]:
# End survival rate
for dname, _df in datasets.items():
    labels = [n + " low" for n in ['Rpart', 'Cox', 'ANN']]
    labels.extend([n + " high" for n in ['Rpart', 'Cox', 'ANN']])
    
    rpart_high_box, rpart_low_box = [], []
    cox_high_box, cox_low_box = [], []
    ann_high_box, ann_low_box = [], []
    
    n = _df.shape[0]
        
    # For each repetition
    for rep in range(cross_n):
        # Should not have been a list...
        perm = data_permutations[dname][rep][0]
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        
        df = _df.reindex(perm).sort(event_col)
        
        assignments = np.array((n // cross_k + 1) * list(range(1, cross_k + 1)))
        assignments = assignments[:n]
            
        # For each division
        for i in range(1, cross_k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]
            
            T_actual = testing_data[duration_col].values
            E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            # Calculate what of interest
            kmf = KaplanMeierFitter()
            
            ## Rpart
            rpart_preds = rpart_val_preds[dname][rep][i-1]
            
            members = (rpart_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            rpart_high_box.append(kmf.survival_function_.iloc[-1, 0])
            
            members = (rpart_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            rpart_low_box.append(kmf.survival_function_.iloc[-1, 0])
            
            ## Cox
            cox_preds = cox_val_preds[dname][rep][i-1]
            
            members = (cox_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            cox_high_box.append(kmf.survival_function_.iloc[-1, 0])
            
            members = (cox_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            cox_low_box.append(kmf.survival_function_.iloc[-1, 0])
            
            ## ANN
            ann_preds = ann_val_preds[dname][rep][i-1]
            
            members = (ann_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            ann_high_box.append(kmf.survival_function_.iloc[-1, 0])
            
            members = (ann_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            ann_low_box.append(kmf.survival_function_.iloc[-1, 0])
            
    plt.figure()
    plt.title(dname)
    # Labels are ordered low-high
    plt.boxplot([rpart_low_box, cox_low_box, ann_low_box, rpart_high_box, cox_high_box, ann_high_box], 
                labels=labels, vert=False, colors=plt.colors[:3])
    plt.xlabel("End survival rate")
    savefig("crossval-{}-endsurvrate".format(dname))


/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/matplotlib/font_manager.py:1282: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))

Difference in end survival rate

To see if the survival curves crossed, the difference between end survival rates for high and low risk groups are plotted here.

If the curves crossed, the difference will be negative. If they did not cross, we'd expect low-risk group to have a higher end survival time than the high-risk group, so the difference should be positive. The closer the the value is to zero, the closer the curves are to each other (towards the end).


In [58]:
# Crossing survival curves/ Distance between them at end
for dname, _df in datasets.items():
    labels = ['Rpart', 'Cox', 'ANN']
    
    rpart_box = []
    cox_box = []
    ann_box = []
    
    n = _df.shape[0]
        
    # For each repetition
    for rep in range(cross_n):
        # Should not have been a list...
        perm = data_permutations[dname][rep][0]
        
        duration_col = _df.columns[0]
        event_col = _df.columns[1]
        
        df = _df.reindex(perm).sort(event_col)
        
        assignments = np.array((n // cross_k + 1) * list(range(1, cross_k + 1)))
        assignments = assignments[:n]
            
        # For each division
        for i in range(1, cross_k + 1):
            ix = assignments == i
            training_data = df.ix[~ix]
            testing_data = df.ix[ix]
            
            T_actual = testing_data[duration_col].values
            E_actual = testing_data[event_col].values
            #X_testing = testing_data[testing_columns]
            
            # Calculate what of interest
            kmf = KaplanMeierFitter()
            
            ## Rpart
            rpart_preds = rpart_val_preds[dname][rep][i-1]
            
            members = (rpart_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            high_rate = kmf.survival_function_.iloc[-1, 0]
            
            members = (rpart_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            low_rate = kmf.survival_function_.iloc[-1, 0]
            
            rpart_box.append(low_rate - high_rate)
            
            ## Cox
            cox_preds = cox_val_preds[dname][rep][i-1]
            
            members = (cox_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            high_rate = kmf.survival_function_.iloc[-1, 0]
            
            members = (cox_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            low_rate = kmf.survival_function_.iloc[-1, 0]
            
            cox_box.append(low_rate - high_rate)
            
            ## ANN
            ann_preds = ann_val_preds[dname][rep][i-1]
            
            members = (ann_preds == 'high').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            high_rate = kmf.survival_function_.iloc[-1, 0]
            
            members = (ann_preds == 'low').values.ravel()
            kmf.fit(T_actual[members],
                    E_actual[members],
                    label=None)
            low_rate = kmf.survival_function_.iloc[-1, 0]
            
            ann_box.append(low_rate - high_rate)
            
    plt.figure()
    plt.title(dname)
    # Labels are ordered low-high
    plt.boxplot([rpart_box, cox_box, ann_box], 
                labels=labels, vert=False, colors=plt.colors[:3])
    plt.xlabel("Survival rate difference")
    savefig("crossval-{}-endsurvdiff".format(dname))


/home/gibson/jonask/anaconda3/envs/riskgroups/lib/python3.4/site-packages/matplotlib/font_manager.py:1282: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to Bitstream Vera Sans
  (prop.get_family(), self.defaultFamily[fontext]))