In [45]:
# Imports
%matplotlib inline
import numpy as np
import pandas as pd
from pyplotthemes import get_savefig, ClassicTheme
from lifelines.plotting import add_at_risk_counts
from lifelines.estimation import KaplanMeierFitter
from lifelines.estimation import median_survival_times
#plt.latex = True
# Set to PLOS One requirements
pltkwargs = {'text.usetex': True,
             'text.latex.unicode': True,
             #'font.family': 'sans-serif',
             'font.sans-serif': 'Arial',
             'font.serif': 'Times'}
plt = ClassicTheme(**pltkwargs)
savefig = get_savefig('../fig', extensions=['pdf', 'eps'])
In [2]:
from datasets import get_colon, get_nwtco, get_flchain, get_pbc, get_lung
# Values of (trn, test)
datasets = {}
# Add the data sets
for name, getter in zip(["pbc", "lung", "colon", "nwtco", "flchain"],
                        [get_pbc, get_lung, get_colon, get_nwtco, get_flchain]):
    trn = getter(norm_in=True, norm_out=False, training=True)
    test = getter(norm_in=True, norm_out=False, training=False)
    datasets[name] = (trn, test)
    cens = (trn.iloc[:, 1] == 0)
    censcount = np.sum(cens) / trn.shape[0]
    print(name, "censored:", censcount)
Rpart is run first because it dictates the group sizes used for the rest of the models. For each data set, groups that are not statistically different are joined, and we then manually join groups in cases where the low- or high-risk group is deemed too small. In the one case where Rpart produces three groups directly, we use that grouping as-is.
The original grouping is plotted first, followed by the merged grouping result (all on training data).
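To illustrate the first merging step, adjacent risk groups can be compared with pairwise log-rank tests and relabeled when they do not differ significantly. The sketch below is only an outline of the idea, not RPartModel's actual implementation; the merge_adjacent_groups helper is hypothetical, and the keyword names follow recent lifelines versions.

from lifelines.statistics import logrank_test

def merge_adjacent_groups(trn, preds, alpha=0.05):
    # Hypothetical sketch, not part of RPartModel: join adjacent risk
    # groups whose survival curves are not significantly different.
    # trn's first column is time, its second the event indicator.
    groups = sorted(np.unique(preds))
    labels = {g: g for g in groups}
    for g, g_next in zip(groups[:-1], groups[1:]):
        a = (preds == g).ravel()
        b = (preds == g_next).ravel()
        res = logrank_test(trn.iloc[a, 0], trn.iloc[b, 0],
                           event_observed_A=trn.iloc[a, 1],
                           event_observed_B=trn.iloc[b, 1])
        if res.p_value > alpha:
            # Not significantly different: give both the same label.
            # (A full implementation would re-test after each merge.)
            labels[g_next] = labels[g]
    return np.array([labels[g] for g in np.asarray(preds).ravel()])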
In [3]:
from pysurvival.rpart import RPartModel
# Save values for later here
rpart_trn_preds = {}
rpart_test_preds = {}
group_sizes = {}
# Default values in rpart actually
rpart_kwargs = dict(highlim=0.1,
                    lowlim=0.1,
                    minsplit=20,
                    minbucket=None,
                    xval=3,
                    cp=0.01)
def show_rpart(trn, rm=None, joinmid=False):
    '''
    By default, the implementation will join non-significantly different
    groups into low and high groups.

    Returns an rpart model.
    '''
    if rm is None:
        # Low limits so we don't merge significantly different groups yet
        rm = RPartModel(**rpart_kwargs)
        rm.fit(trn, trn.columns[0], trn.columns[1])
    preds = rm.predict(trn)
    up = np.unique(preds)

    plt.figure()
    fitters = []
    mygroup = np.zeros_like(preds, dtype=bool)
    for i, g in enumerate(sorted(up)):
        members = (preds == g).ravel()
        mygroup += members
        if i == len(up) - 1:
            pass
        elif g in rm._high and up[i+1] in rm._high:
            # Both in high
            continue
        elif g in rm._low and up[i+1] in rm._low:
            # Both in low
            continue
        elif (joinmid and
              (g not in rm._high and up[i+1] not in rm._high) and
              (g not in rm._low and up[i+1] not in rm._low)):
            # Both in mid
            continue

        print("Group {} has size {} = {:.2f}".format(i,
                                                     np.sum(mygroup),
                                                     np.sum(mygroup)/trn.shape[0]))
        kmf = KaplanMeierFitter()
        fitters.append(kmf)
        kmf.fit(trn.iloc[mygroup, 0],
                trn.iloc[mygroup, 1],
                label="{}".format(i))
        res = kmf.plot(ax=plt.gca(), ci_show=False)
        # Reset my group
        mygroup = np.zeros_like(preds, dtype=bool)

    # Nicer legend
    plt.legend(loc='best', framealpha=0.0, ncol=1 if joinmid else 3)
    plt.ylim((0, 1.05))
    add_at_risk_counts(*fitters)
    return rm
First, Rpart does its normal pass, generating many groups. Groups which are not statistically different are then merged. If there are still too many groups, they are merged such that the outer groups approach quartile sizes (or come as close as they can).
Here only the intermediate groups are merged.
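The quartile step can be outlined roughly as below: grow the outer groups inwards until each covers about a quarter of the data. This is a hypothetical sketch (merge_towards_quartiles is not a function in this codebase); in practice the adjustment is done manually via rm._low / rm._high in the cells that follow.

def merge_towards_quartiles(fractions, target=0.25):
    # Hypothetical sketch: `fractions` holds group sizes as fractions of
    # the training set, ordered from lowest to highest risk.
    low, high = [0], [len(fractions) - 1]
    # Extend the low-risk set inwards until it reaches the target size
    while sum(fractions[i] for i in low) < target and low[-1] + 1 < high[0]:
        low.append(low[-1] + 1)
    # Likewise for the high-risk set, from the other end
    while sum(fractions[i] for i in high) < target and high[0] - 1 > low[-1]:
        high.insert(0, high[0] - 1)
    # Indices of the groups to join into the low and high groups
    return low, high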
In [4]:
name = 'pbc'
trn, test = datasets[name]
rm = show_rpart(trn)
# PBC finds reasonable groups without merging non-significant groups
up = np.unique(rm.predict(trn))
#rm._high.append(up[-2])
#print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)
preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)
group_sizes[name] = (int(np.sum(preds == 'high')),
                     int(np.sum(preds == 'low')))
First, Rpart does its normal pass, generating many groups. Groups which are not statistically different are then merged. If there are still too many groups, they are merged such that the outer groups approach quartile sizes (or come as close as they can).
Here Rpart produces a large number of groups, several of which are not well separated from each other (and are thus merged). Even after this, the low-risk group is deemed too small, so it is joined with the adjacent group.
In [5]:
name = 'lung'
trn, test = datasets[name]
rm = show_rpart(trn)
# LUNG needs merging low risk groups
up = np.unique(rm.predict(trn))
rm._low.append(up[2])
print("Low risk groups:", rm._low)
show_rpart(trn, rm, joinmid=True)
preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)
group_sizes[name] = (int(np.sum(preds == 'high')),
                     int(np.sum(preds == 'low')))
First, Rpart does its normal pass, generating many groups. Groups which are not statistically different are then merged. If there are still too many groups, they are merged such that the outer groups approach quartile sizes (or come as close as they can).
Rpart produces four well-separated groups here, but the low-risk group is deemed too small. It is therefore merged with the second group.
In [6]:
name = 'colon'
trn, test = datasets[name]
rm = show_rpart(trn)
# COLON requires merging
up = np.unique(rm.predict(trn))
rm._low.append(up[1])
print("Low risk groups:", rm._low)
#rm._high.append(up[-3])
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)
preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)
group_sizes[name] = (int(np.sum(preds == 'high')),
                     int(np.sum(preds == 'low')))
First, Rpart does its normal pass, generating many groups. Groups which are not statistically different are then merged. If there are still too many groups, they are merged such that the outer groups approach quartile sizes (or come as close as they can).
Here the output was three groups, and so no merging was performed. This is one case where a large deviation from the quartile goal can be seen.
In [7]:
name = 'flchain'
trn, test = datasets[name]
rm = show_rpart(trn)
# FLCHAIN gives 3 significantly different groups, where the
# high-risk group is very large (>0.5), the low-risk group is small (0.2),
# and the mid group is (0.24). In this case, one should leave the groups be.
up = np.unique(rm.predict(trn))
#rm._low.append(up[1])
print("Low risk groups:", rm._low)
#rm._high.append(up[-2])
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)
preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)
group_sizes[name] = (int(np.sum(preds == 'high')),
                     int(np.sum(preds == 'low')))
First, Rpart does its normal pass, generating many groups. Groups which are not statistically different are then merged. If there are still too many groups, they are merged such that the outer groups approach quartile sizes (or come as close as they can).
Here only the two intermediate groups are merged.
In [8]:
name = 'nwtco'
trn, test = datasets[name]
rm = show_rpart(trn)
# NWTCO gives 4 significantly different groups (join the two mid groups,
# which overlap). The low-risk group is large (0.41) and the two
# high-risk groups (quite separated) joined together make up (0.17);
# here I'd say the data set does not warrant further action.
up = np.unique(rm.predict(trn))
#print(up)
#rm._high.append(up[-3])
print("High risk groups:", rm._high)
show_rpart(trn, rm, joinmid=True)
preds = rm.predict_classes(trn)
rpart_trn_preds[name] = preds
rpart_test_preds[name] = rm.predict_classes(test)
group_sizes[name] = (int(np.sum(preds == 'high')),
                     int(np.sum(preds == 'low')))
In [9]:
from classcox import CoxClasser
# Save predictions for later
cox_trn_preds = {}
cox_test_preds = {}
for name, (trn, test) in datasets.items():
    total = trn.shape[0]
    # Get the rpart sizes for this data
    high_size, low_size = group_sizes[name]
    cox = CoxClasser(100 * (1 - high_size / total),
                     100 * low_size / total)
    #cox = CoxClasser()
    cox.fit(trn, trn.columns[0], trn.columns[1])
    cox_test_preds[name] = cox.predict_classes(test)
    preds = cox.predict_classes(trn)
    cox_trn_preds[name] = preds

    plt.figure()
    fitters = []
    for g in ['high', 'mid', 'low']:
        members = (preds == g).values.ravel()
        if not np.any(members):
            continue
        kmf = KaplanMeierFitter()
        fitters.append(kmf)
        kmf.fit(trn.iloc[members, 0],
                trn.iloc[members, 1],
                label="{}".format(g))
        res = kmf.plot(ax=plt.gca(), ci_show=False)
    plt.legend(loc='best', framealpha=0)
    plt.ylim((0, 1.05))
    plt.title(name)
    add_at_risk_counts(*fitters)
Again the group sizes from Rpart are used. Here they are given as arguments to the ANN, which enforces them as a hard minimum group size in the genetic fitness function. Making it only a minimum allows the network some flexibility in finding an optimal solution; larger group sizes can only make the survival curves worse, never better, so this is no advantage for the ANN.
Groupings are plotted for each training data set.
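As a rough picture of what such a fitness might do: the actual criterion lives in ann.geneticnetwork (FITNESS_SURV_KAPLAN_MIN/MAX); the helper below and its area-under-KM proxy are our assumptions, not that implementation.

def grouped_km_fitness(durations, events, in_group, min_size, maximize):
    # Hypothetical sketch of a Kaplan-Meier fitness with a hard minimum
    # group size; not the actual geneticnetwork implementation.
    if np.sum(in_group) < min_size:
        return -np.inf  # hard constraint: too-small groups never win
    kmf = KaplanMeierFitter()
    kmf.fit(durations[in_group], events[in_group])
    # Area under the KM curve as a proxy for the group's survival:
    # a low-risk group maximizes it, a high-risk group minimizes it.
    area = np.trapz(kmf.survival_function_.values.ravel(),
                    kmf.survival_function_.index.values)
    return area if maximize else -area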
In [10]:
import ann
from classensemble import ClassEnsemble
from helpers import get_net
# Save predictions for later
ann_trn_preds = {}
ann_test_preds = {}
for name, (trn, test) in datasets.items():
    print(name)
    total = trn.shape[0]
    # Get the rpart sizes for this data
    high_size, low_size = group_sizes[name]
    incols = trn.shape[1] - 2

    hnets = []
    lnets = []
    netcount = 34
    for i in range(netcount):
        if i % 2:
            n = get_net(incols, high_size,
                        ann.geneticnetwork.FITNESS_SURV_KAPLAN_MIN)
            hnets.append(n)
        else:
            n = get_net(incols, low_size,
                        ann.geneticnetwork.FITNESS_SURV_KAPLAN_MAX)
            lnets.append(n)

    e = ClassEnsemble(hnets, lnets)
    e.fit(trn, trn.columns[0], trn.columns[1])
    ann_test_preds[name] = e.predict_classes(test)
    preds = e.predict_classes(trn)
    ann_trn_preds[name] = preds

    plt.figure()
    fitters = []
    for g in ['high', 'mid', 'low']:
        members = (preds == g).values.ravel()
        if not np.any(members):
            continue
        kmf = KaplanMeierFitter()
        fitters.append(kmf)
        kmf.fit(trn.iloc[members, 0],
                trn.iloc[members, 1],
                label="{}".format(g))
        res = kmf.plot(ax=plt.gca(), ci_show=False)
    plt.legend(loc='best', framealpha=0)
    plt.ylim((0, 1.05))
    plt.title(name)
    add_at_risk_counts(*fitters)
In [11]:
from pickle import dump
import os
# These are expensive, so save to disk
path = "ann_test_predictions.pickle"
if os.path.exists(path):
    raise ValueError("File exists. Should not be overwritten")
with open(path, 'wb') as F:
    dump(ann_test_preds, F)

path = "ann_trn_predictions.pickle"
if os.path.exists(path):
    raise ValueError("File exists. Should not be overwritten")
with open(path, 'wb') as F:
    dump(ann_trn_preds, F)
Here we read the predictions back into memory. This means we can rerun the plotting steps below and compute additional statistics without a full re-training of the ANN model, which can be tedious. Cox and Rpart are both very fast, so this is not a concern for them.
In [12]:
from pickle import load
import os
# Read them back
path = "ann_test_predictions.pickle"
with open(path, 'rb') as F:
    ann_test_preds = load(F)

path = "ann_trn_predictions.pickle"
with open(path, 'rb') as F:
    ann_trn_preds = load(F)
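The two cells above could also be folded into a single load-or-compute pattern; a minimal sketch, assuming pickle-compatible results (the helper name is ours, not part of the notebook's modules):

def load_or_compute(path, compute):
    # Reuse a pickled result when present; otherwise compute it once
    # and save it for the next run.
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return load(f)
    result = compute()
    with open(path, 'wb') as f:
        dump(result, f)
    return result

# e.g. preds = load_or_compute("ann_test_predictions.pickle",
#                              lambda: compute_ann_predictions())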
In [46]:
model_names = ('rpart', 'cox', 'ann')

for name, (trn, test) in datasets.items():
    fig = plt.figure(figsize=(2.5, 2.0))
    plt.title(name)
    print("\n", name, "events / total")
    fitters = []
    linestyles = ['-', ':', '--']
    for modelidx, (model, grouplabels) in enumerate(zip(model_names,
                                                        [rpart_test_preds[name],
                                                         cox_test_preds[name],
                                                         ann_test_preds[name]])):
        print(" ", model)
        for g in ['high', 'mid', 'low']:
            members = (grouplabels == g).values.ravel()
            print(" {:4} {:4} / {:4}".format(g, int(np.sum(test.iloc[members, 1])), int(np.sum(members))))
            if not np.any(members):
                continue
            if g == 'mid':
                linestyle = ':'
            elif g == 'low':
                linestyle = '--'
            else:
                linestyle = '-'
            kmf = KaplanMeierFitter()
            fitters.append(kmf)
            kmf.fit(test.iloc[members, 0],
                    test.iloc[members, 1],
                    label="{} {}".format(model, g))
            res = kmf.plot(ax=plt.gca(), color=plt.colors[modelidx],
                           ci_show=False, linestyle=linestyle)
    # Override legend entries
    leg = plt.legend(loc='best', frameon=False).get_lines()
    leg_entries = [leg[0], leg[3], leg[6]]
    #plt.legend([leg[0], leg[3], leg[6]], model_names,
    #           loc='best', frameon=False)
    # Now remove legend
    plt.gca().legend().set_visible(False)
    plt.ylim((0, 1))
    plt.xlabel("Time")
    plt.locator_params(tight=True, nbins=5)
    #add_at_risk_counts(*fitters)
    savefig("test" + name)

# Create separate figure for legend
plt.figure(figsize=(2.5, 2.0))
plt.axis('off')
plt.legend(leg_entries, model_names, loc='center', frameon=False)
savefig("testlegend")