Introduction to Survival Analysis with scikit-survival

scikit-survival is a Python module for survival analysis built on top of scikit-learn. It allows performing survival analysis while utilizing the power of scikit-learn, e.g., for pre-processing or cross-validation.

What is Survival Analysis?

The objective in survival analysis — also referred to as reliability analysis in engineering — is to establish a connection between covariates and the time of an event. The name survival analysis originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression by the fact that parts of the training data can only be partially observed – they are censored.

As an example, consider a clinical study investigating coronary heart disease that was carried out over a 1-year period, as in the figure below.

Patient A was lost to follow-up after three months with no recorded cardiovascular event, patient B experienced an event four and a half months after enrollment, patient D withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the exact time of a cardiovascular event could only be recorded for patients B and C; their records are uncensored. For the remaining patients, it is unknown whether they experienced an event after they left the study or after the study ended. The only valid information available for patients A, D, and E is that they were event-free up to their last follow-up; therefore, their records are censored.

Formally, each patient record consists of a set of covariates $x \in \mathbb{R}^d$, and the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\delta \in \{0, 1\}$ and the observable survival time $y>0$. The observable time $y$ of a right-censored sample is defined as

$$ y = \min(t, c) = \begin{cases} t & \text{if } \delta = 1 , \\ c & \text{if } \delta = 0 . \end{cases} $$

Consequently, survival analysis demands models that take this unique characteristic of such a dataset into account, some of which are showcased below.
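To make the notation above concrete, here is a minimal sketch (with purely hypothetical event and censoring times) of how the observable time $y$ and the event indicator $\delta$ arise from the underlying times $t$ and $c$:

import numpy as np

# Purely hypothetical event times t and censoring times c for five subjects.
t = np.array([4.5, 12.0, 3.0, 8.0, 15.0])   # time until the event would occur
c = np.array([12.0, 2.0, 12.0, 3.0, 12.0])  # time of censoring (e.g., end of follow-up)

delta = t <= c         # event indicator: True if the event was actually observed
y = np.minimum(t, c)   # observable survival time y = min(t, c)
# Here only the 1st and 3rd subjects have observed events; the rest are censored.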

The Veterans' Administration Lung Cancer Trial

The Veterans' Administration Lung Cancer Trial is a randomized trial of two treatment regimens for lung cancer. The data set (Kalbfleisch, J. and Prentice, R. (1980). The Statistical Analysis of Failure Time Data. New York: Wiley) consists of 137 patients and 8 variables, which are described below:

  • Treatment: denotes the type of lung cancer treatment; standard or test drug.
  • Celltype: denotes the type of cell involved; squamous, small cell, adeno, or large.
  • Karnofsky_score: is the Karnofsky performance score, a measure of the patient's overall functional status.
  • Diag: is the time since diagnosis in months.
  • Age: is the age in years.
  • Prior_therapy: denotes whether any prior therapy was administered; none or yes.
  • Status: denotes whether the patient was dead or alive at the end of follow-up; dead or alive.
  • Survival_in_days: is the survival time in days since the start of treatment.

Our primary interest is studying whether there are subgroups that differ in survival and whether we can predict survival times.

Survival Data

As described in the section What is Survival Analysis? above, survival times are subject to right-censoring; therefore, we need to consider an individual's status in addition to survival time. To be fully compatible with scikit-learn, Status and Survival_in_days need to be stored as a structured array with the first field indicating whether the actual survival time was observed or if it was censored, and the second field denoting the observed survival time, which corresponds to the time of death (if Status == 'dead', $\delta = 1$) or the last time that person was contacted (if Status == 'alive', $\delta = 0$).
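The dataset loader used below already returns data in this format. If your own data comes as plain arrays, the helper class sksurv.util.Surv can build the required structured array; the values in this sketch are made up:

import numpy as np
from sksurv.util import Surv

event = np.array([True, False, True])   # True = event observed, False = censored
time = np.array([72.0, 100.0, 228.0])   # observed survival time y

y = Surv.from_arrays(event, time)
print(y.dtype)  # structured dtype with a boolean event field and a float time field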


In [1]:
from sksurv.datasets import load_veterans_lung_cancer

data_x, data_y = load_veterans_lung_cancer()
data_y


Out[1]:
array([( True,  72.), ( True, 411.), ( True, 228.), ( True, 126.),
       ( True, 118.), ( True,  10.), ( True,  82.), ( True, 110.),
       ( True, 314.), (False, 100.), ( True,  42.), ( True,   8.),
       ( True, 144.), (False,  25.), ( True,  11.), ( True,  30.),
       ( True, 384.), ( True,   4.), ( True,  54.), ( True,  13.),
       (False, 123.), (False,  97.), ( True, 153.), ( True,  59.),
       ( True, 117.), ( True,  16.), ( True, 151.), ( True,  22.),
       ( True,  56.), ( True,  21.), ( True,  18.), ( True, 139.),
       ( True,  20.), ( True,  31.), ( True,  52.), ( True, 287.),
       ( True,  18.), ( True,  51.), ( True, 122.), ( True,  27.),
       ( True,  54.), ( True,   7.), ( True,  63.), ( True, 392.),
       ( True,  10.), ( True,   8.), ( True,  92.), ( True,  35.),
       ( True, 117.), ( True, 132.), ( True,  12.), ( True, 162.),
       ( True,   3.), ( True,  95.), ( True, 177.), ( True, 162.),
       ( True, 216.), ( True, 553.), ( True, 278.), ( True,  12.),
       ( True, 260.), ( True, 200.), ( True, 156.), (False, 182.),
       ( True, 143.), ( True, 105.), ( True, 103.), ( True, 250.),
       ( True, 100.), ( True, 999.), ( True, 112.), (False,  87.),
       (False, 231.), ( True, 242.), ( True, 991.), ( True, 111.),
       ( True,   1.), ( True, 587.), ( True, 389.), ( True,  33.),
       ( True,  25.), ( True, 357.), ( True, 467.), ( True, 201.),
       ( True,   1.), ( True,  30.), ( True,  44.), ( True, 283.),
       ( True,  15.), ( True,  25.), (False, 103.), ( True,  21.),
       ( True,  13.), ( True,  87.), ( True,   2.), ( True,  20.),
       ( True,   7.), ( True,  24.), ( True,  99.), ( True,   8.),
       ( True,  99.), ( True,  61.), ( True,  25.), ( True,  95.),
       ( True,  80.), ( True,  51.), ( True,  29.), ( True,  24.),
       ( True,  18.), (False,  83.), ( True,  31.), ( True,  51.),
       ( True,  90.), ( True,  52.), ( True,  73.), ( True,   8.),
       ( True,  36.), ( True,  48.), ( True,   7.), ( True, 140.),
       ( True, 186.), ( True,  84.), ( True,  19.), ( True,  45.),
       ( True,  80.), ( True,  52.), ( True, 164.), ( True,  19.),
       ( True,  53.), ( True,  15.), ( True,  43.), ( True, 340.),
       ( True, 133.), ( True, 111.), ( True, 231.), ( True, 378.),
       ( True,  49.)],
      dtype=[('Status', '?'), ('Survival_in_days', '<f8')])

We can easily see that only a small number of survival times are right-censored (Status is False), i.e., most veterans died during the study period (Status is True).
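We can count this directly from the structured array; from the output above, only 9 of the 137 records are censored:

n_total = len(data_y)
n_events = data_y["Status"].sum()   # number of observed deaths
n_censored = n_total - n_events     # number of right-censored records
print(f"{n_events} events, {n_censored} censored out of {n_total} records")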

The Survival Function

A key quantity in survival analysis is the so-called survival function, which relates time to the probability of surviving beyond a given time point.

Let $T$ denote a continuous non-negative random variable corresponding to a patient’s survival time. The survival function $S(t)$ returns the probability of survival beyond time $t$ and is defined as $$ S(t) = P (T > t). $$

If we observed the exact survival time of all subjects, i.e., everyone died before the study ended, the survival function at time $t$ can simply be estimated by the ratio of patients surviving beyond time $t$ and the total number of patients:

$$ \hat{S}(t) = \frac{ \text{number of patients surviving beyond $t$} }{ \text{total number of patients} } $$

In the presence of censoring, this estimator cannot be used, because the numerator cannot be determined: for patients who were censored before time $t$, we do not know whether they survived beyond $t$. For instance, consider the following set of patients:


In [2]:
import pandas as pd

pd.DataFrame.from_records(data_y[[11, 5, 32, 13, 23]], index=range(1, 6))


Out[2]:
   Status  Survival_in_days
1    True               8.0
2    True              10.0
3    True              20.0
4   False              25.0
5    True              59.0

Using the formula from above, we can compute $\hat{S}(t=11) = \frac{3}{5}$, but not $\hat{S}(t=30)$, because we don't know whether the 4th patient is still alive at $t = 30$; all we know is that when we last checked at $t = 25$, the patient was still alive.
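For illustration, we can reproduce this naive computation for the five patients above with a one-liner; it is not a valid estimator in general, precisely because of censoring:

# Naive estimate of S(11) for the five patients shown above: all five were still
# under observation at t = 11, and three of them survived beyond that time.
subset = data_y[[11, 5, 32, 13, 23]]
s_hat_11 = (subset["Survival_in_days"] > 11).mean()
print(s_hat_11)  # 0.6, i.e., 3 out of 5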

An estimator, similar to the one above, that remains valid if survival times are right-censored is the Kaplan-Meier estimator.


In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator

time, survival_prob = kaplan_meier_estimator(data_y["Status"], data_y["Survival_in_days"])
plt.step(time, survival_prob, where="post")
plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")


Out[3]:
Text(0.5, 0, 'time $t$')

The estimated curve is a step function, with steps occurring at time points where one or more patients died. From the plot we can see that most patients died within the first 200 days, as indicated by the steep slope of the estimated survival function over that period.
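Because time and survival_prob are plain NumPy arrays, summary quantities can be read off directly. As a small sketch (assuming the estimated curve actually drops below 0.5, which it does here), the median survival time is the first time point at which $\hat{S}(t) \leq 0.5$:

import numpy as np

# First time point at which the estimated survival probability drops to 0.5 or below.
median_survival_time = time[np.argmax(survival_prob <= 0.5)]
print(median_survival_time)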

Considering other variables by stratification

Survival functions by treatment

Patients enrolled in the Veterans' Administration Lung Cancer Trial were randomized to one of two treatments: standard and a new test drug. Next, let's have a look at how many patients underwent the standard treatment and how many received the new drug.


In [4]:
data_x["Treatment"].value_counts()


Out[4]:
standard    69
test        68
Name: Treatment, dtype: int64

Roughly half the patients received the alternative treatment.

The obvious question to ask is:

Is there any difference in survival between the two treatment groups?

As a first attempt, we can estimate the survival function in both treatment groups separately.


In [5]:
for treatment_type in ("standard", "test"):
    mask_treat = data_x["Treatment"] == treatment_type
    time_treatment, survival_prob_treatment = kaplan_meier_estimator(
        data_y["Status"][mask_treat],
        data_y["Survival_in_days"][mask_treat])
    
    plt.step(time_treatment, survival_prob_treatment, where="post",
             label="Treatment = %s" % treatment_type)

plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")


Out[5]:
<matplotlib.legend.Legend at 0x7f55a1c4d9e8>

Unfortunately, the results are inconclusive, because the difference between the two estimated survival functions is too small to confidently conclude whether or not the test drug affects survival.

Sidenote: Visually comparing estimated survival curves in order to assess whether there is a difference in survival between groups is usually not recommended, because it is highly subjective. Statistical tests such as the log-rank test are usually more appropriate.
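scikit-survival provides the log-rank test via sksurv.compare.compare_survival, which takes the structured survival array and a group label per sample and returns the test statistic and p-value. A minimal sketch for the two treatment groups:

from sksurv.compare import compare_survival

# Log-rank test comparing survival between the two treatment groups.
chisq, pvalue = compare_survival(data_y, data_x["Treatment"])
print(f"chi^2 = {chisq:.2f}, p = {pvalue:.3f}")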

Survival functions by cell type

Next, let's have a look at the cell type, which has been recorded as well, and repeat the analysis from above.


In [6]:
for value in data_x["Celltype"].unique():
    mask = data_x["Celltype"] == value
    time_cell, survival_prob_cell = kaplan_meier_estimator(data_y["Status"][mask],
                                                           data_y["Survival_in_days"][mask])
    plt.step(time_cell, survival_prob_cell, where="post",
             label="%s (n = %d)" % (value, mask.sum()))

plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")


Out[6]:
<matplotlib.legend.Legend at 0x7f55a1bb97f0>

In this case, we observe a pronounced difference between two groups. Patients with squamous or large cells seem to have a better prognosis compared to patients with small or adeno cells.

Multivariate Survival Models

In the Kaplan-Meier approach used above, we estimated multiple survival curves by dividing the dataset into smaller sub-groups according to a variable. If we want to consider more than one or two variables, this approach quickly becomes infeasible, because the subgroups will get very small. Instead, we can use a linear model, Cox's proportional hazards model, to estimate the impact each variable has on survival.
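For reference, Cox's model expresses the hazard of an individual with covariates $x$ as a baseline hazard $h_0(t)$, shared by everyone, scaled by a factor that depends on the covariates but not on time:

$$ h(t \mid x) = h_0(t) \exp(\beta_1 x_1 + \dots + \beta_d x_d) . $$

Consequently, each coefficient $\beta_j$ is a log hazard ratio: increasing $x_j$ by one unit multiplies the hazard by $\exp(\beta_j)$, with all other covariates held fixed.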

First however, we need to convert the categorical variables in the data set into numeric values.


In [7]:
from sksurv.preprocessing import OneHotEncoder

data_x_numeric = OneHotEncoder().fit_transform(data_x)
data_x_numeric.head()


Out[7]:
   Age_in_years  Celltype=large  Celltype=smallcell  Celltype=squamous  Karnofsky_score  Months_from_Diagnosis  Prior_therapy=yes  Treatment=test
0          69.0             0.0                 0.0                1.0             60.0                    7.0                0.0             0.0
1          64.0             0.0                 0.0                1.0             70.0                    5.0                1.0             0.0
2          38.0             0.0                 0.0                1.0             60.0                    3.0                0.0             0.0
3          63.0             0.0                 0.0                1.0             60.0                    9.0                1.0             0.0
4          65.0             0.0                 0.0                1.0             70.0                   11.0                1.0             0.0

Survival models in scikit-survival follow the same rules as estimators in scikit-learn, i.e., they have a fit method, which expects a data matrix and a structured array of survival times and binary event indicators.


In [8]:
from sksurv.linear_model import CoxPHSurvivalAnalysis

estimator = CoxPHSurvivalAnalysis()
estimator.fit(data_x_numeric, data_y)


Out[8]:
CoxPHSurvivalAnalysis(alpha=0, n_iter=100, tol=1e-09, verbose=0)

The result is a vector of coefficients, one for each variable, where each value corresponds to the log hazard ratio.


In [9]:
pd.Series(estimator.coef_, index=data_x_numeric.columns)


Out[9]:
Age_in_years            -0.008549
Celltype=large          -0.788672
Celltype=smallcell      -0.331813
Celltype=squamous       -1.188299
Karnofsky_score         -0.032622
Months_from_Diagnosis   -0.000092
Prior_therapy=yes        0.072327
Treatment=test           0.289936
dtype: float64
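Since these coefficients are log hazard ratios, exponentiating them gives hazard ratios, which are often easier to interpret: a value above 1 indicates increased risk, a value below 1 decreased risk. A small sketch:

import numpy as np

# Convert log hazard ratios to hazard ratios.
hazard_ratios = pd.Series(np.exp(estimator.coef_), index=data_x_numeric.columns)
hazard_ratios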

Using the fitted model, we can predict a patient-specific survival function by passing an appropriate data matrix to the estimator's predict_survival_function method.

First, let's create a set of four synthetic patients.


In [10]:
x_new = pd.DataFrame.from_dict({
    1: [65, 0, 0, 1, 60, 1, 0, 1],
    2: [65, 0, 0, 1, 60, 1, 0, 0],
    3: [65, 0, 1, 0, 60, 1, 0, 0],
    4: [65, 0, 1, 0, 60, 1, 0, 1]},
    columns=data_x_numeric.columns, orient='index')
x_new


Out[10]:
   Age_in_years  Celltype=large  Celltype=smallcell  Celltype=squamous  Karnofsky_score  Months_from_Diagnosis  Prior_therapy=yes  Treatment=test
1            65               0                   0                  1               60                      1                  0               1
2            65               0                   0                  1               60                      1                  0               0
3            65               0                   1                  0               60                      1                  0               0
4            65               0                   1                  0               60                      1                  0               1

Similar to kaplan_meier_estimator, the predict_survival_function method returns a sequence of step functions, which we can plot.


In [11]:
pred_surv = estimator.predict_survival_function(x_new)
for i, c in enumerate(pred_surv):
    plt.step(c.x, c.y, where="post", label="Sample %d" % (i + 1))
plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")


Out[11]:
<matplotlib.legend.Legend at 0x7f55a0c87c18>

Measuring the Performance of Survival Models

Once we fit a survival model, we usually want to assess how well it can actually predict survival. Our test data is usually subject to censoring too; therefore, metrics like root mean squared error or correlation are unsuitable. Instead, we use a generalization of the area under the receiver operating characteristic (ROC) curve called Harrell's concordance index, or c-index.

The interpretation is identical to the traditional area under the ROC curve metric for binary classification:

  • a value of 0.5 denotes a random model,
  • a value of 1.0 denotes a perfect model,
  • a value of 0.0 denotes a perfectly wrong model.

In [12]:
from sksurv.metrics import concordance_index_censored

prediction = estimator.predict(data_x_numeric)
result = concordance_index_censored(data_y["Status"], data_y["Survival_in_days"], prediction)
result[0]


Out[12]:
0.7362562471603816

or alternatively


In [13]:
estimator.score(data_x_numeric, data_y)


Out[13]:
0.7362562471603816

Our model's c-index indicates that the model clearly performs better than random, but is also far from perfect.
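Note that we computed this c-index on the same data the model was fit on, which tends to be optimistic. A more honest estimate uses held-out data; here is a sketch using scikit-learn's train_test_split (the split parameters are arbitrary):

from sklearn.model_selection import train_test_split

# Hold out 25% of the data for evaluation; stratify on the event indicator
# so both sets contain a similar share of censored records.
X_train, X_test, y_train, y_test = train_test_split(
    data_x_numeric, data_y, test_size=0.25, stratify=data_y["Status"], random_state=0)

est = CoxPHSurvivalAnalysis().fit(X_train, y_train)
print(est.score(X_test, y_test))  # c-index on held-out data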

Feature Selection: Which Variable is Most Predictive?

The model above considered all available variables for prediction. Next, we want to investigate which single variable is the best risk predictor. Therefore, we fit a Cox model to each variable individually and record the c-index on the training set.


In [14]:
import numpy as np

def fit_and_score_features(X, y):
    """Fit a Cox model to each feature individually and return each model's c-index."""
    n_features = X.shape[1]
    scores = np.empty(n_features)
    m = CoxPHSurvivalAnalysis()
    for j in range(n_features):
        Xj = X[:, j:j+1]  # use only the j-th feature
        m.fit(Xj, y)
        scores[j] = m.score(Xj, y)  # c-index on the training data
    return scores

scores = fit_and_score_features(data_x_numeric.values, data_y)
pd.Series(scores, index=data_x_numeric.columns).sort_values(ascending=False)


Out[14]:
Karnofsky_score          0.709280
Celltype=smallcell       0.572581
Celltype=large           0.561620
Celltype=squamous        0.550545
Treatment=test           0.525386
Age_in_years             0.515107
Months_from_Diagnosis    0.509030
Prior_therapy=yes        0.494434
dtype: float64

Karnofsky_score is the best variable, whereas Months_from_Diagnosis and Prior_therapy=yes have almost no predictive power on their own.

Next, we want to build a parsimonious model by excluding irrelevant features. We could use the ranking from above, but would need to determine what the optimal cut-off should be. Luckily, scikit-learn has built-in support for performing grid search.

First, we create a pipeline that puts all the parts together.


In [15]:
from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import Pipeline

pipe = Pipeline([('encode', OneHotEncoder()),
                 ('select', SelectKBest(fit_and_score_features, k=3)),
                 ('model', CoxPHSurvivalAnalysis())])

Next, we need to define the range of parameters we want to explore during grid search. Here, we want to optimize the parameter k of the SelectKBest class and allow k to vary from 1 feature to all 8 features.


In [16]:
from sklearn.model_selection import GridSearchCV

param_grid = {'select__k': np.arange(1, data_x_numeric.shape[1] + 1)}
gcv = GridSearchCV(pipe, param_grid, return_train_score=True, cv=3, iid=True)
gcv.fit(data_x, data_y)

pd.DataFrame(gcv.cv_results_).sort_values(by='mean_test_score', ascending=False)


Out[16]:
mean_fit_time std_fit_time mean_score_time std_score_time param_select__k params split0_test_score split1_test_score split2_test_score mean_test_score std_test_score rank_test_score split0_train_score split1_train_score split2_train_score mean_train_score std_train_score
2 0.126878 0.008825 0.005426 0.000059 3 {'select__k': 3} 0.650628 0.718220 0.754649 0.707491 0.043067 1 0.765569 0.695647 0.714082 0.725099 0.029590
4 0.125702 0.008785 0.005293 0.000101 5 {'select__k': 5} 0.644874 0.738347 0.728822 0.703834 0.042098 2 0.783946 0.698207 0.718498 0.733551 0.036585
3 0.125528 0.008009 0.005279 0.000040 4 {'select__k': 4} 0.650628 0.719809 0.725723 0.698523 0.034138 3 0.768121 0.691037 0.707327 0.722162 0.033172
1 0.129095 0.011951 0.005358 0.000054 2 {'select__k': 2} 0.630753 0.717161 0.747934 0.698256 0.049604 4 0.758550 0.683611 0.705638 0.715933 0.031448
0 0.132264 0.016141 0.005395 0.000110 1 {'select__k': 1} 0.630753 0.715042 0.737087 0.693982 0.045843 5 0.744640 0.676697 0.695246 0.705527 0.028675
5 0.125061 0.008772 0.005272 0.000049 6 {'select__k': 6} 0.657427 0.669492 0.724690 0.683572 0.029179 6 0.783946 0.698848 0.716160 0.732985 0.036722
6 0.125344 0.009078 0.005225 0.000024 7 {'select__k': 7} 0.654812 0.659958 0.714876 0.676269 0.027083 7 0.788412 0.695519 0.712133 0.732021 0.040447
7 0.124911 0.009041 0.005294 0.000054 8 {'select__k': 8} 0.656904 0.653602 0.716942 0.675516 0.029004 8 0.786371 0.695006 0.713692 0.731690 0.039411

The results show that it is sufficient to select the 3 most predictive features.


In [17]:
pipe.set_params(**gcv.best_params_)
pipe.fit(data_x, data_y)

encoder, transformer, final_estimator = [s[1] for s in pipe.steps]
pd.Series(final_estimator.coef_, index=encoder.encoded_columns_[transformer.get_support()])


Out[17]:
Celltype=large       -0.067277
Celltype=smallcell    0.271007
Karnofsky_score      -0.031285
dtype: float64

What's next?

Cox's proportional hazards model is by far the most popular survival model, because once trained, it is easy to interpret. However, if prediction performance is the main objective, more sophisticated, non-linear, or ensemble models might lead to better results. Check out this notebook to get a better understanding of how to evaluate survival models, and this notebook to learn more about Kernel Survival Support Vector Machines. The API reference contains a full list of models available within scikit-survival. In addition, you can use any unsupervised pre-processing method available in scikit-learn; for instance, you could perform dimensionality reduction using Non-Negative Matrix Factorization (NMF) before training a Cox model.
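As a small taste, the sketch below fits a Random Survival Forest, one of the ensemble models in sksurv.ensemble, and reports its c-index on the training data; the hyperparameters are illustrative only.

from sksurv.ensemble import RandomSurvivalForest

# Fit a Random Survival Forest on the encoded veterans data (illustrative hyperparameters).
rsf = RandomSurvivalForest(n_estimators=100, min_samples_leaf=10, random_state=0)
rsf.fit(data_x_numeric, data_y)
print(rsf.score(data_x_numeric, data_y))  # c-index on the training data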

