
Using the parameters found by the cross-validation scripts, the models are trained here on the full training data, and their predictions on the test set are reported.

The group sizes generated by Rpart on the training data are used for Cox and the ANN so that the survival curves are comparable.

In [45]:
# Imports
%matplotlib inline
import numpy as np
import pandas as pd
from pyplotthemes import get_savefig, ClassicTheme
from lifelines.plotting import add_at_risk_counts
from lifelines.estimation import KaplanMeierFitter
from lifelines.estimation import median_survival_times
#plt.latex = True
# Set to PLOS One requirements
pltkwargs = {'text.usetex': True,
            'text.latex.unicode': True,
            #'font.family': 'sans-serif',
            'font.sans-serif': 'Arial',
            'font.serif': 'Times'}
plt = ClassicTheme(**pltkwargs)
savefig = get_savefig('../fig', extensions=['pdf', 'eps'])

Load data

The columns are always ordered as: target (survival time), event indicator, inputs.
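
A minimal sketch of this layout, with made-up column names and values: the survival time comes first, the event indicator (1 = event observed, 0 = censored) second, and the inputs after that. Selecting columns by position, as the loading code below does with iloc, keeps the code independent of the column names.

```python
import numpy as np
import pandas as pd

# Toy data in the assumed layout: target (time), event, then inputs.
# Column names and values are made up for illustration.
df = pd.DataFrame({
    "time": [5.0, 8.0, 12.0, 3.0],
    "event": [1, 0, 0, 1],  # 1 = event observed, 0 = censored
    "age": [0.1, -0.3, 1.2, 0.4],
    "stage": [1.0, 0.0, -1.0, 0.5],
}, columns=["time", "event", "age", "stage"])

# Select by position, so nothing depends on the column names
durations = df.iloc[:, 0]
events = df.iloc[:, 1]
inputs = df.iloc[:, 2:]

# Fraction of censored rows, computed as in the loading loop
censored_fraction = np.mean(events == 0)
print("censored:", censored_fraction)
```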

In [2]:
from datasets import get_colon, get_nwtco, get_flchain, get_pbc, get_lung

# Maps each data set name to a (trn, test) tuple
datasets = {}

# Add the data sets
for name, getter in zip(["pbc", "lung", "colon", "nwtco", "flchain"],
                         [get_pbc, get_lung, get_colon, get_nwtco, get_flchain]):
    trn = getter(norm_in=True, norm_out=False, training=True)
    test = getter(norm_in=True, norm_out=False, training=False)
    datasets[name] = (trn, test)
    cens = (trn.iloc[:, 1] == 0)
    censcount = np.sum(cens) / trn.shape[0]
    print(name, "censed:", censcount)

pbc censed: 0.608163265306
lung censed: 0.27485380117
colon censed: 0.495689655172
nwtco censed: 0.857994041708
flchain censed: 0.724716245977

Train the models

Rpart first

Rpart is trained first because it dictates the group sizes used for the rest of the models. For each data set, groups that are not statistically different are joined, and then groups are joined manually where the low- or high-risk group is deemed too small. In the one case where Rpart produces 3 groups directly, that grouping is used as-is.

The original grouping is plotted first, followed by the merged grouping result (all on training data).
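
The joining of adjacent groups can be sketched without RPartModel. Walking the groups in order of predicted risk, each run of consecutive groups with the same risk label is fused into one. In this hypothetical helper, merge_runs, the label of each group is supplied by hand; in show_rpart below the labels instead come from membership in rm._low and rm._high.

```python
import numpy as np

def merge_runs(group_ids, labels):
    """Fuse consecutive groups (sorted by risk) that share a label.

    group_ids: per-sample group id, e.g. an Rpart prediction
    labels: dict mapping group id -> 'low', 'mid' or 'high'
    Returns the per-sample labels and a list of (label, run size).
    """
    group_ids = np.asarray(group_ids)
    merged = np.empty(group_ids.shape, dtype=object)
    sizes = []
    current_label, current_size = None, 0
    for g in np.unique(group_ids):  # np.unique returns sorted ids
        members = group_ids == g
        merged[members] = labels[g]
        if labels[g] == current_label:
            # Same label as the previous group: extend the current run
            current_size += int(members.sum())
        else:
            # New label: close the previous run and start a new one
            if current_label is not None:
                sizes.append((current_label, current_size))
            current_label, current_size = labels[g], int(members.sum())
    sizes.append((current_label, current_size))
    return merged, sizes


ids = [0, 0, 1, 2, 2, 2, 3]
group_labels = {0: "low", 1: "low", 2: "mid", 3: "high"}
merged, sizes = merge_runs(ids, group_labels)
print(sizes)
```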

In [3]:
from pysurvival.rpart import RPartModel

# Save values for later here
rpart_trn_preds = {}
rpart_test_preds = {}
group_sizes = {}

# These are actually the default values in rpart
rpart_kwargs = dict(highlim=0.1)

def show_rpart(trn, rm=None, joinmid=False):
    """Plot Kaplan-Meier curves of the Rpart groups on the training data.

    By default, the implementation will join non-significantly different
    groups into low and high groups. If joinmid is True, adjacent groups
    that are neither low nor high risk are joined as well.
    Returns an rpart model.
    """
    if rm is None:
        # Low limits so we don't merge significantly different groups yet
        rm = RPartModel(**rpart_kwargs)
        rm.fit(trn, trn.columns[0], trn.columns[1])
    preds = rm.predict(trn)
    up = np.unique(preds)  # sorted group values
    fitters = []
    # New figure, so each call gets its own plot
    plt.figure()
    mygroup = np.zeros_like(preds, dtype=bool)
    for i, g in enumerate(up):
        members = (preds == g).ravel()
        # Accumulate members until a run of same-category groups ends
        mygroup |= members
        if i == len(up) - 1:
            # Last group: always plot what has accumulated
            pass
        elif g in rm._high and up[i+1] in rm._high:
            # Both in high: keep accumulating
            continue
        elif g in rm._low and up[i+1] in rm._low:
            # Both in low: keep accumulating
            continue
        elif (joinmid and
              (g not in rm._high and up[i+1] not in rm._high) and
              (g not in rm._low and up[i+1] not in rm._low)):
            # Both in mid: keep accumulating
            continue
        print("Group {} has size {} = {:.2f}".format(i, int(mygroup.sum()),
                                                     mygroup.mean()))
        kmf = KaplanMeierFitter()
        kmf.fit(trn.iloc[mygroup, 0],
                trn.iloc[mygroup, 1],
                label="Group {}".format(i))
        res = kmf.plot(ax=plt.gca(), ci_show=False)
        fitters.append(kmf)
        # Reset my group
        mygroup = np.zeros_like(preds, dtype=bool)
    # Nicer legend
    plt.legend(loc='best', framealpha=0.0, ncol=1 if joinmid else 3)

    plt.ylim((0, 1.05))
    return rm


Rpart first makes its normal pass, which generates many groups. Groups that are not statistically different are then merged. If there are still too many groups, they are merged so that the outer groups approach quartile sizes (or come as close as they can).

For pbc, only the intermediate groups are merged.
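
The quartile rule can be written as a small search. The sketch below is a hypothetical automation, not the procedure used in this notebook, where the final merges are chosen by hand. Given group sizes ordered from lowest to highest risk, quartile_merge picks how many groups to fuse at each end so that the low- and high-risk fractions come as close to 0.25 as possible, and puts the rest into a single mid group.

```python
import numpy as np

def quartile_merge(sizes, target=0.25):
    """Fuse ordered group sizes (low to high risk) into (low, mid, high).

    The number of groups fused at each end is chosen so the outer
    fractions are closest to `target`. Needs at least three groups.
    """
    sizes = np.asarray(sizes, dtype=float)
    total = sizes.sum()
    n = len(sizes)
    # Candidate low ends: first 1 .. n-2 groups (leave room for mid, high)
    low_fracs = np.cumsum(sizes)[: n - 2] / total
    k_low = int(np.argmin(np.abs(low_fracs - target))) + 1
    # Candidate high ends, counted from the top, never emptying the mid
    high_fracs = np.cumsum(sizes[::-1])[: n - k_low - 1] / total
    k_high = int(np.argmin(np.abs(high_fracs - target))) + 1
    low = int(sizes[:k_low].sum())
    high = int(sizes[n - k_high:].sum())
    return low, int(total) - low - high, high


# First-pass colon sizes, ordered from low to high risk
print(quartile_merge([71, 145, 272, 208]))
```

With the first-pass sizes printed for colon and pbc, it reproduces the merged groups shown in those cells. It does not capture every manual decision: for nwtco it would fold the small group 2 into the high-risk group, whereas the notebook keeps it in the mid group.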

In [4]:
name = 'pbc'
trn, test = datasets[name]
rm = show_rpart(trn)
# PBC finds reasonable groups without merging non-significant groups
up = np.unique(rm.predict(trn))
#print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)

preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)

group_sizes[name] = (int(np.sum(preds == 'high')), 
                     int(np.sum(preds == 'low')))

Group 4 has size 88 = 0.36
Group 5 has size 27 = 0.11
Group 6 has size 15 = 0.06
Group 7 has size 7 = 0.03
Group 8 has size 13 = 0.05
Group 9 has size 8 = 0.03
Group 10 has size 17 = 0.07
Group 11 has size 13 = 0.05
Group 15 has size 57 = 0.23
Group 4 has size 88 = 0.36
Group 11 has size 100 = 0.41
Group 15 has size 57 = 0.23


For lung, Rpart produces many groups, several of which are not well separated from each other and are therefore merged. Even after this, the low-risk group is deemed too small, so it is joined with the adjacent group.

In [5]:
name = 'lung'
trn, test = datasets[name]
rm = show_rpart(trn)
# LUNG needs merging low risk groups
up = np.unique(rm.predict(trn))
print("Low risk groups:", rm._low)

show_rpart(trn, rm, joinmid=True)

preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)

group_sizes[name] = (int(np.sum(preds == 'high')), 
                     int(np.sum(preds == 'low')))

Group 1 has size 22 = 0.13
Group 2 has size 17 = 0.10
Group 3 has size 9 = 0.05
Group 4 has size 19 = 0.11
Group 5 has size 7 = 0.04
Group 6 has size 17 = 0.10
Group 12 has size 80 = 0.47
Low risk groups: [0.27534547670833437, 0.27693381720928589, 0.58081827076604597]
Group 2 has size 39 = 0.23
Group 6 has size 52 = 0.30
Group 12 has size 80 = 0.47


For colon, Rpart produces 4 separated groups, but the low-risk group (10% of the training data) is deemed too small. It is therefore merged with the second group.

In [6]:
name = 'colon'
trn, test = datasets[name]
rm = show_rpart(trn)
# COLON requires merging the small low risk group
up = np.unique(rm.predict(trn))
print("Low risk groups:", rm._low)
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)

preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)

group_sizes[name] = (int(np.sum(preds == 'high')), 
                     int(np.sum(preds == 'low')))

Group 0 has size 71 = 0.10
Group 1 has size 145 = 0.21
Group 2 has size 272 = 0.39
Group 5 has size 208 = 0.30
Low risk groups: [0.24616618934561699, 0.5374116150971846]
High risk groups: [3.0948674006959433, 1.6460031120776577, 1.5754895904741031]
Group 1 has size 216 = 0.31
Group 2 has size 272 = 0.39
Group 5 has size 208 = 0.30


For flchain, Rpart outputs three groups directly, so no merging is performed. This is one case where a large deviation from the quartile goal can be seen: the low-risk group holds 55% of the training data.

In [7]:
name = 'flchain'
trn, test = datasets[name]
rm = show_rpart(trn)
# FLCHAIN gives 3 significantly different groups, where the
# low risk group is very large (>0.5), the high risk group is small (0.20)
# and the mid group is 0.24. In this case, one should leave the groups be.

up = np.unique(rm.predict(trn))
print("Low risk groups:", rm._low)
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)

preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)

group_sizes[name] = (int(np.sum(preds == 'high')), 
                     int(np.sum(preds == 'low')))

Group 0 has size 3272 = 0.55
Group 1 has size 1438 = 0.24
Group 3 has size 1193 = 0.20
Low risk groups: [0.33686666778107871]
High risk groups: [8.1056320388011542, 3.009865289679559]
Group 0 has size 3272 = 0.55
Group 1 has size 1438 = 0.24
Group 3 has size 1193 = 0.20


For nwtco, only the two intermediate groups are merged.

In [8]:
name = 'nwtco'
trn, test = datasets[name]
rm = show_rpart(trn)
# NWTC gives 4 significantly different groups (join two mid groups
# which are overlapped). The low risk group is large (0.41) and the
# two high risk groups (quite separated) joined together make up (0.17) and
# here I'd say the data set does not warrant further action.
up = np.unique(rm.predict(trn))
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)

preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)

group_sizes[name] = (int(np.sum(preds == 'high')), 
                     int(np.sum(preds == 'low')))

Group 0 has size 1232 = 0.41
Group 1 has size 1183 = 0.39
Group 2 has size 96 = 0.03
Group 5 has size 510 = 0.17
High risk groups: [13.893294227665161, 3.9996498800207285, 1.8747162224287353]
Group 0 has size 1232 = 0.41
Group 2 has size 1279 = 0.42
Group 5 has size 510 = 0.17


Next, the group sizes generated by Rpart are used to place the cuts for Cox. A Cox model is trained on each training data set, and the cuts are set so that each group has the same size as the corresponding Rpart group.

Resulting groupings are plotted (on the training data).
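
The size matching can be illustrated with percentile cuts on a risk score. This sketch assumes that a higher score (for Cox, the linear predictor) means higher risk, and that the cuts are percentiles of the training scores, which is how the arguments passed to CoxClasser read. CoxClasser's internals are not shown here, so percentile_groups is a hypothetical stand-in.

```python
import numpy as np

def percentile_groups(risk, high_size, low_size):
    """Label samples 'high', 'mid' or 'low' so the outer groups get the
    requested sizes (only approximately if `risk` has ties).

    Assumes a higher risk score means higher risk. Returns the labels
    and the (low_cut, high_cut) pair.
    """
    risk = np.asarray(risk, dtype=float)
    total = len(risk)
    # Same percentile arithmetic as the CoxClasser arguments above
    high_cut = np.percentile(risk, 100 * (1 - high_size / total))
    low_cut = np.percentile(risk, 100 * low_size / total)
    labels = np.full(total, "mid", dtype=object)
    labels[risk > high_cut] = "high"
    labels[risk <= low_cut] = "low"
    return labels, (low_cut, high_cut)


# 100 evenly spread scores, with sizes of 20 (high) and 30 (low)
labels, cuts = percentile_groups(np.arange(100), high_size=20, low_size=30)
print({g: int(np.sum(labels == g)) for g in ("high", "mid", "low")})
```

The returned cut points would be reused unchanged on test-set scores, so test group sizes only match the training sizes approximately.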

In [9]:
from classcox import CoxClasser

# Save predictions for later
cox_trn_preds = {}
cox_test_preds = {}

for name, (trn, test) in datasets.items():
    total = trn.shape[0]
    # Get the rpart sizes for this data
    high_size, low_size = group_sizes[name]
    cox = CoxClasser(100 * (1 - high_size / total),
                     100 * low_size / total)
    #cox = CoxClasser()
    cox.fit(trn, trn.columns[0], trn.columns[1])
    cox_test_preds[name] = cox.predict_classes(test)
    preds = cox.predict_classes(trn)
    cox_trn_preds[name] = preds
    # One figure per data set
    plt.figure()
    fitters = []
    for g in ['high', 'mid', 'low']:
        members = (preds == g).values.ravel()
        if not np.any(members):
            # Skip empty groups
            continue
        kmf = KaplanMeierFitter()
        kmf.fit(trn.iloc[members, 0],
                trn.iloc[members, 1],
                label=g)
        res = kmf.plot(ax=plt.gca(), ci_show=False)
        fitters.append(kmf)
    plt.legend(loc='best', framealpha=0)

    plt.ylim((0, 1.05))


For the ANN, the group sizes from Rpart are used again. Here they are passed as an argument to the ANN, which enforces them as a hard minimum group size in the genetic fitness function. It is only a minimum, to give the network some flexibility in finding an optimal solution. Allowing larger groups can only make the survival curves worse, never better, so this gives the ANN no advantage.

Groupings are plotted for each training data set.
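
To show what a hard minimum group size means for a fitness function, here is a hypothetical fitness for a network that selects the high-risk group. It scores the selected samples by their restricted mean survival (the area under the Kaplan-Meier curve up to a horizon tmax) and returns minus infinity when fewer than min_size samples are selected. The function names, the horizon, and the use of restricted mean survival are assumptions for illustration; this is not the fitness implemented behind ann.geneticnetwork.FITNESS_SURV_KAPLAN_MIN.

```python
import numpy as np

def kaplan_meier(times, events):
    """Kaplan-Meier survival estimate at each distinct event time."""
    times = np.asarray(times, dtype=float)
    events = np.asarray(events)
    event_times = np.unique(times[events == 1])
    surv = np.empty(len(event_times))
    s = 1.0
    for i, t in enumerate(event_times):
        at_risk = np.sum(times >= t)
        deaths = np.sum((times == t) & (events == 1))
        s *= 1.0 - deaths / at_risk
        surv[i] = s
    return event_times, surv

def restricted_mean(times, events, tmax):
    """Area under the Kaplan-Meier step curve from 0 to tmax.

    Assumes tmax is at least the largest event time.
    """
    t, s = kaplan_meier(times, events)
    edges = np.concatenate(([0.0], t, [tmax]))
    levels = np.concatenate(([1.0], s))  # survival on each interval
    return float(np.sum(levels * np.diff(edges)))

def high_risk_fitness(times, events, members, min_size, tmax):
    """Hypothetical fitness: low survival is good, small groups invalid."""
    members = np.asarray(members, dtype=bool)
    if members.sum() < min_size:
        return -np.inf  # hard minimum on the group size
    times = np.asarray(times, dtype=float)
    events = np.asarray(events)
    return -restricted_mean(times[members], events[members], tmax)


times = [1, 2, 3, 4, 10]
events = [1, 1, 1, 1, 0]
members = [True, True, True, True, False]
print(high_risk_fitness(times, events, members, min_size=3, tmax=10))
```

A low-risk counterpart would return +restricted_mean instead. Because selections below min_size are rejected outright while larger ones are merely allowed, min_size acts as a floor rather than a target.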

In [10]:
import ann
from classensemble import ClassEnsemble
from helpers import get_net

# Save predictions for later
ann_trn_preds = {}
ann_test_preds = {}

for name, (trn, test) in datasets.items():
    total = trn.shape[0]
    # Get the rpart sizes for this data
    high_size, low_size = group_sizes[name]
    incols = trn.shape[1] - 2
    hnets = []
    lnets = []
    netcount = 34
    for i in range(netcount):
        if i % 2:
            # Networks searching for the high-risk group
            n = get_net(incols, high_size, ann.geneticnetwork.FITNESS_SURV_KAPLAN_MIN)
            hnets.append(n)
        else:
            # Networks searching for the low-risk group
            n = get_net(incols, low_size, ann.geneticnetwork.FITNESS_SURV_KAPLAN_MAX)
            lnets.append(n)
    e = ClassEnsemble(hnets, lnets)
    e.fit(trn, trn.columns[0], trn.columns[1])
    ann_test_preds[name] = e.predict_classes(test)
    preds = e.predict_classes(trn)
    ann_trn_preds[name] = preds
    # One figure per data set
    plt.figure()
    fitters = []
    for g in ['high', 'mid', 'low']:
        members = (preds == g).values.ravel()
        if not np.any(members):
            # Skip empty groups
            continue
        kmf = KaplanMeierFitter()
        kmf.fit(trn.iloc[members, 0],
                trn.iloc[members, 1],
                label=g)
        res = kmf.plot(ax=plt.gca(), ci_show=False)
        fitters.append(kmf)
    plt.legend(loc='best', framealpha=0)

    plt.ylim((0, 1.05))


Save predictions to file

Here the ANN's predictions for both the training and test sets are saved to file using Python's built-in pickle module.
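
The exists-then-open check below works, but the same guarantee can be had in one step with the exclusive-creation mode 'xb' (available since Python 3.3): open() then raises FileExistsError if the file is already there, and there is no window between the check and the write. dump_once is a hypothetical helper name.

```python
import pickle

def dump_once(obj, path):
    """Pickle obj to path, refusing to overwrite an existing file.

    Mode 'xb' creates the file exclusively: open() raises
    FileExistsError if path already exists.
    """
    with open(path, "xb") as f:
        pickle.dump(obj, f)
```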

In [11]:
from pickle import dump
import os

# These are expensive, so save to disk
path = "ann_test_predictions.pickle"
if os.path.exists(path):
    raise ValueError("File exists. Should not be overwritten")

with open(path, 'wb') as F:
    dump(ann_test_preds, F)
path = "ann_trn_predictions.pickle"
if os.path.exists(path):
    raise ValueError("File exists. Should not be overwritten")

with open(path, 'wb') as F:
    dump(ann_trn_preds, F)


The predictions on the test data are plotted in the last section, together with group size information.

Load predictions from file

Here we read the predictions back into memory. This means we can rerun the plotting steps below, and compute additional statistics, without retraining the ANN ensemble, which is slow. Cox and Rpart are both very fast, so this is not a concern for them.
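
The save and load cells could also be folded into a single caching step. This is a sketch with a hypothetical helper, load_or_compute, which unpickles the file if it exists and otherwise calls a function, pickles its result, and returns it; it is not how this notebook is organized.

```python
import os
import pickle

def load_or_compute(path, compute):
    """Unpickle path if it exists; otherwise call compute(), pickle the
    result to path and return it."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    result = compute()
    with open(path, "wb") as f:
        pickle.dump(result, f)
    return result
```

Here compute would wrap the expensive ANN training cell, so it only runs when no cached file exists.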

In [12]:
from pickle import load
import os

# Read them back
path = "ann_test_predictions.pickle"
with open(path, 'rb') as F:
    ann_test_preds = load(F)
# Read them back
path = "ann_trn_predictions.pickle"
with open(path, 'rb') as F:
    ann_trn_preds = load(F)

Plot results

Each data set gets one plot combining all three models. For each model and risk group, the number of events and the group size on the test set are printed as text.
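
The printed events / total figures could also be collected into a table. In this sketch, events_table (a hypothetical helper) takes the test-set event indicator and a dict of group labels per model, as stored in rpart_test_preds, cox_test_preds and ann_test_preds, and returns a DataFrame with one row per risk group and one column per model.

```python
import numpy as np
import pandas as pd

def events_table(event, groupings):
    """Tabulate 'events / total' per risk group for several models.

    event: 0/1 event indicator per test sample
    groupings: dict mapping model name -> per-sample group labels
    """
    event = np.asarray(event).ravel()
    rows = {}
    for model, labels in groupings.items():
        labels = np.asarray(labels).ravel()
        cells = {}
        for g in ("high", "mid", "low"):
            members = labels == g
            cells[g] = "{} / {}".format(int(event[members].sum()),
                                        int(members.sum()))
        rows[model] = cells
    # Fix row and column order explicitly
    return pd.DataFrame(rows, index=["high", "mid", "low"],
                        columns=list(groupings))


event = [1, 1, 0, 1, 0, 0]
table = events_table(event, {
    "rpart": ["high", "high", "mid", "low", "low", "low"],
    "cox": ["high", "mid", "mid", "low", "low", "low"],
})
print(table)
```

For a data set in this notebook, the call would pass test.iloc[:, 1] and the three stored predictions for that name.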

In [46]:
model_names = ('rpart', 'cox', 'ann')

for name, (trn, test) in datasets.items():
    fig = plt.figure(figsize=(2.5, 2.0))
    print("\n", name, "events / total")

    fitters = []
    linestyles = ['-', ':', '--']
    for modelidx, (model, grouplabels) in enumerate(zip(model_names,
            [rpart_test_preds[name], cox_test_preds[name], ann_test_preds[name]])):
        print("  ", model)
        for g in ['high', 'mid', 'low']:
            members = (grouplabels == g).values.ravel()
            print("      {:4} {:4} / {:4}".format(g, int(np.sum(test.iloc[members, 1])), int(np.sum(members))))
            if not np.any(members):
                # Nothing to plot for an empty group
                continue
            if g == 'mid':
                linestyle = ':'
            elif g == 'low':
                linestyle = '--'
            else:
                linestyle = '-'
            kmf = KaplanMeierFitter()
            kmf.fit(test.iloc[members, 0],
                    test.iloc[members, 1],
                    label="{} {}".format(model, g))
            res = kmf.plot(ax=plt.gca(), color=plt.colors[modelidx],
                           ci_show=False, linestyle=linestyle)

    # Override legend entries: keep the first line of each model
    leg = plt.legend(loc='best', frameon=False).get_lines()
    leg_entries = [leg[0], leg[3], leg[6]]
    #plt.legend([leg[0], leg[3], leg[6]], model_names,
    #                loc='best', frameon=False)
    # Now remove legend
    plt.gca().legend_.remove()

    plt.ylim((0, 1))
    plt.locator_params(tight=True, nbins=5)
    savefig("test" + name)
# Create separate figure for legend
plt.figure(figsize=(2.5, 2.0))
plt.legend(leg_entries, model_names, loc='center', frameon=False)

 nwtco events / total
   rpart
      high   64 /  178
      mid    51 /  385
      low    27 /  444
   cox
      high   62 /  157
      mid    46 /  391
      low    34 /  459
   ann
      high   63 /  163
      mid    46 /  393
      low    33 /  451

 colon events / total
   rpart
      high   39 /   57
      mid    52 /  106
      low    26 /   70
   cox
      high   47 /   73
      mid    43 /   83
      low    27 /   77
   ann
      high   48 /   71
      mid    46 /   99
      low    23 /   63

 pbc events / total
   rpart
      high   13 /   15
      mid    12 /   29
      low     4 /   23
   cox
      high   10 /   10
      mid    16 /   38
      low     3 /   19
   ann
      high   10 /   11
      mid    17 /   29
      low     2 /   27

 flchain events / total
   rpart
      high  292 /  415
      mid   137 /  452
      low   112 / 1101
   cox
      high  301 /  411
      mid   139 /  479
      low   101 / 1078
   ann
      high  296 /  402
      mid   145 /  475
      low   100 / 1091

 lung events / total
   rpart
      high   20 /   26
      mid    13 /   22
      low     8 /    9
   cox
      high   23 /   28
      mid     8 /   11
      low    10 /   18
   ann
      high   23 /   31
      mid    12 /   15
      low     6 /   11