scikit-survival is a Python module for survival analysis built on top of scikit-learn. It allows performing survival analysis while utilizing the power of scikit-learn, e.g., for pre-processing or cross-validation.
The objective in survival analysis — also referred to as reliability analysis in engineering — is to establish a connection between covariates and the time of an event. The name survival analysis originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression by the fact that parts of the training data can only be partially observed – they are censored.
As an example, consider a clinical study investigating coronary heart disease that has been carried out over a one-year period, as in the figure below.
Patient A was lost to follow-up after three months with no recorded cardiovascular event, patient B experienced an event four and a half months after enrollment, patient D withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended; patient C, as shown in the figure, experienced an event during the study period. Consequently, the exact time of a cardiovascular event could only be recorded for patients B and C; their records are uncensored. For the remaining patients it is unknown whether they experienced an event after termination of the study. The only valid information that is available for patients A, D, and E is that they were event-free up to their last follow-up. Therefore, their records are censored.
Formally, each patient record consists of a set of covariates $x \in \mathbb{R}^d$, and the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\delta \in \{0,1\}$ and the observable survival time $y>0$. The observable time $y$ of a right-censored sample is defined as
$$ y = \min(t, c) = \begin{cases} t & \text{if } \delta = 1 , \\ c & \text{if } \delta = 0 . \end{cases} $$

Consequently, survival analysis demands models that take this unique characteristic of such a dataset into account, some of which are showcased below.
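In scikit-survival, these pairs $(\delta, y)$ are typically stored together in a structured array. Below is a minimal sketch, assuming the helper sksurv.util.Surv is available and using made-up event indicators and times:

import numpy as np
from sksurv.util import Surv

# hypothetical data: event=True marks an observed event, y_time holds y = min(t, c)
event = np.array([True, False, True, False])
y_time = np.array([10.0, 25.0, 7.5, 30.0])
y = Surv.from_arrays(event=event, time=y_time)
y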
The Veterans' Administration Lung Cancer Trial is a randomized trial of two treatment regimens for lung cancer. The data set (Kalbfleisch, J. D. and Prentice, R. L. (1980). The Statistical Analysis of Failure Time Data. New York: Wiley) consists of 137 patients and 8 variables, which are described below:
- Treatment: denotes the type of lung cancer treatment; standard and test drug.
- Celltype: denotes the type of cell involved; squamous, small cell, adeno, large.
- Karnofsky_score: is the Karnofsky score.
- Diag: is the time since diagnosis in months.
- Age: is the age in years.
- Prior_Therapy: denotes any prior therapy; none or yes.
- Status: denotes the status of the patient as dead or alive; dead or alive.
- Survival_in_days: is the survival time in days since the treatment.

Our primary interest is studying whether there are subgroups that differ in survival and whether we can predict survival times.
As described in the section What is Survival Analysis? above, survival times are subject to right-censoring; therefore, we need to consider an individual's status in addition to survival time. To be fully compatible with scikit-learn, Status and Survival_in_days need to be stored as a structured array, with the first field indicating whether the actual survival time was observed or if it was censored, and the second field denoting the observed survival time, which corresponds to the time of death (if Status == 'dead', $\delta = 1$) or the last time that person was contacted (if Status == 'alive', $\delta = 0$).
In [1]:
from sksurv.datasets import load_veterans_lung_cancer
data_x, data_y = load_veterans_lung_cancer()
data_y
Out[1]:
We can easily see that only a small number of survival times are right-censored (Status is False), i.e., most veterans died during the study period (Status is True).
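To get exact counts rather than eyeballing the printed array, a quick tally of the event indicator works; this is just a convenience snippet, not part of the original analysis.

import numpy as np

# count censored (False) and uncensored (True) records
np.unique(data_y["Status"], return_counts=True)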
A key quantity in survival analysis is the so-called survival function, which relates time to the probability of surviving beyond a given time point.
Let $T$ denote a continuous non-negative random variable corresponding to a patient’s survival time. The survival function $S(t)$ returns the probability of survival beyond time $t$ and is defined as $$ S(t) = P (T > t). $$
If we observed the exact survival time of all subjects, i.e., everyone died before the study ended, the survival function at time $t$ can simply be estimated by the ratio of patients surviving beyond time $t$ and the total number of patients:
$$ \hat{S}(t) = \frac{ \text{number of patients surviving beyond $t$} }{ \text{total number of patients} } $$

In the presence of censoring, this estimator cannot be used, because the numerator is not always defined. For instance, consider the following set of patients:
In [2]:
import pandas as pd
pd.DataFrame.from_records(data_y[[11, 5, 32, 13, 23]], index=range(1, 6))
Out[2]:
Using the formula from above, we can compute $\hat{S}(t=11) = \frac{3}{5}$, but not $\hat{S}(t=30)$, because we don't know whether the 4th patient is still alive at $t = 30$; all we know is that when we last checked at $t = 25$, the patient was still alive.
An estimator, similar to the one above, that is valid if survival times are right-censored is the Kaplan-Meier estimator.
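For reference, the Kaplan-Meier estimator is a product over all distinct event times $t_i \le t$:

$$ \hat{S}(t) = \prod_{t_i \le t} \left( 1 - \frac{d_i}{n_i} \right), $$

where $d_i$ is the number of events at time $t_i$ and $n_i$ is the number of individuals still at risk (neither experienced an event nor were censored) just before $t_i$.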
In [3]:
%matplotlib inline
import matplotlib.pyplot as plt
from sksurv.nonparametric import kaplan_meier_estimator
time, survival_prob = kaplan_meier_estimator(data_y["Status"], data_y["Survival_in_days"])
plt.step(time, survival_prob, where="post")
plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
Out[3]:
The estimated curve is a step function, with steps occurring at time points where one or more patients died. From the plot we can see that most patients died within the first 200 days, as indicated by the steep slope of the estimated survival function over that period.
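As a side note, the median survival time can be read off the estimated curve as the earliest time point at which $\hat{S}(t)$ drops to 0.5 or below. A small sketch using the time and survival_prob arrays computed above (it assumes the curve actually falls below 0.5, which it does here):

# earliest time point where the estimated survival probability is <= 0.5
time[survival_prob <= 0.5][0]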
Patients enrolled in the Veterans' Administration Lung Cancer Trial were randomized to one of two treatments: the standard and a new test drug. Next, let's have a look at how many patients underwent the standard treatment and how many received the new drug.
In [4]:
data_x["Treatment"].value_counts()
Out[4]:
Roughly half the patients received the alternative treatment.
The obvious question to ask is:
Is there any difference in survival between the two treatment groups?
As a first attempt, we can estimate the survival function in both treatment groups separately.
In [5]:
for treatment_type in ("standard", "test"):
    mask_treat = data_x["Treatment"] == treatment_type
    time_treatment, survival_prob_treatment = kaplan_meier_estimator(
        data_y["Status"][mask_treat],
        data_y["Survival_in_days"][mask_treat])
    plt.step(time_treatment, survival_prob_treatment, where="post",
             label="Treatment = %s" % treatment_type)

plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")
Out[5]:
Unfortunately, the results are inconclusive: the difference between the two estimated survival functions is too small to confidently state whether the drug affects survival.
Sidenote: Visually comparing estimated survival curves in order to assess whether there is a difference in survival between groups is usually not recommended, because it is highly subjective. Statistical tests such as the log-rank test are usually more appropriate.
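For example, a log-rank test comparing the two treatment groups could look like the following sketch, assuming your version of scikit-survival provides sksurv.compare.compare_survival:

from sksurv.compare import compare_survival

# log-rank test: the null hypothesis is that both treatment groups share the same survival function
chisq, pvalue = compare_survival(data_y, data_x["Treatment"])
pvalue

Treatment is not the only variable we can group by; the next cell repeats the group-wise Kaplan-Meier estimation for Celltype.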
In [6]:
for value in data_x["Celltype"].unique():
    mask = data_x["Celltype"] == value
    time_cell, survival_prob_cell = kaplan_meier_estimator(data_y["Status"][mask],
                                                           data_y["Survival_in_days"][mask])
    plt.step(time_cell, survival_prob_cell, where="post",
             label="%s (n = %d)" % (value, mask.sum()))

plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")
Out[6]:
In this case, we observe a pronounced difference between the groups: patients with squamous or large cells seem to have a better prognosis than patients with small or adeno cells.
In the Kaplan-Meier approach used above, we estimated multiple survival curves by dividing the dataset into smaller sub-groups according to a variable. If we want to consider more than one or two variables, this approach quickly becomes infeasible, because the subgroups get very small. Instead, we can use a linear model, Cox's proportional hazards model, to estimate the impact each variable has on survival.
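In this model, the hazard of an individual with covariates $x$ is a baseline hazard scaled multiplicatively by the covariates:

$$ h(t \mid x) = h_0(t) \exp(x^\top \beta), $$

where $h_0(t)$ is an unspecified baseline hazard function and $\beta$ is the vector of coefficients estimated from the data.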
First however, we need to convert the categorical variables in the data set into numeric values.
In [7]:
from sksurv.preprocessing import OneHotEncoder
data_x_numeric = OneHotEncoder().fit_transform(data_x)
data_x_numeric.head()
Out[7]:
Survival models in scikit-survival follow the same rules as estimators in scikit-learn, i.e., they have a fit method, which expects a data matrix and a structured array of survival times and binary event indicators.
In [8]:
from sksurv.linear_model import CoxPHSurvivalAnalysis
estimator = CoxPHSurvivalAnalysis()
estimator.fit(data_x_numeric, data_y)
Out[8]:
The result is a vector of coefficients, one for each variable, where each value corresponds to the log hazard ratio.
In [9]:
pd.Series(estimator.coef_, index=data_x_numeric.columns)
Out[9]:
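Since each coefficient is a log hazard ratio, exponentiating it yields the hazard ratio itself, which is often easier to interpret; a small sketch, not part of the original analysis:

import numpy as np

# hazard ratios: values above 1 indicate an increased hazard, values below 1 a decreased hazard
pd.Series(np.exp(estimator.coef_), index=data_x_numeric.columns)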
Using the fitted model, we can predict a patient-specific survival function by passing an appropriate data matrix to the estimator's predict_survival_function method.
First, let's create a set of four synthetic patients.
In [10]:
x_new = pd.DataFrame.from_dict({
    1: [65, 0, 0, 1, 60, 1, 0, 1],
    2: [65, 0, 0, 1, 60, 1, 0, 0],
    3: [65, 0, 1, 0, 60, 1, 0, 0],
    4: [65, 0, 1, 0, 60, 1, 0, 1]},
    columns=data_x_numeric.columns, orient='index')
x_new
Out[10]:
Similar to kaplan_meier_estimator, the predict_survival_function method returns a sequence of step functions, which we can plot.
In [11]:
pred_surv = estimator.predict_survival_function(x_new)
for i, c in enumerate(pred_surv):
    plt.step(c.x, c.y, where="post", label="Sample %d" % (i + 1))
plt.ylabel("est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
plt.legend(loc="best")
Out[11]:
Once we fit a survival model, we usually want to assess how well it can actually predict survival. Our test data is usually subject to censoring too, so metrics like root mean squared error or correlation are unsuitable. Instead, we use a generalization of the area under the receiver operating characteristic (ROC) curve called Harrell's concordance index or c-index.
The interpretation is identical to the traditional area under the ROC curve metric for binary classification: a value of 0.5 denotes a model no better than random, and a value of 1.0 a model that ranks all pairs of samples perfectly.
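Ignoring ties, the c-index is the fraction of comparable pairs whose predicted risk scores are ordered consistently with their observed survival times:

$$ c = \frac{ \sum_{i \neq j} \delta_i \, \mathbf{1}(y_i < y_j) \, \mathbf{1}(\hat{f}(x_i) > \hat{f}(x_j)) }{ \sum_{i \neq j} \delta_i \, \mathbf{1}(y_i < y_j) } , $$

where $\hat{f}(x_i)$ denotes the predicted risk score of the $i$-th sample. The following cell computes it on the training data.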
In [12]:
from sksurv.metrics import concordance_index_censored
prediction = estimator.predict(data_x_numeric)
result = concordance_index_censored(data_y["Status"], data_y["Survival_in_days"], prediction)
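# concordance_index_censored returns a tuple: (c-index, concordant pairs, discordant pairs, tied risk, tied time)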
result[0]
Out[12]:
or alternatively
In [13]:
estimator.score(data_x_numeric, data_y)
Out[13]:
The c-index indicates that our model clearly performs better than random, but is also far from perfect.
The model above considered all available variables for prediction. Next, we want to investigate which single variable is the best risk predictor. Therefore, we fit a Cox model to each variable individually and record the c-index on the training set.
In [14]:
import numpy as np
def fit_and_score_features(X, y):
    n_features = X.shape[1]
    scores = np.empty(n_features)
    m = CoxPHSurvivalAnalysis()
    for j in range(n_features):
        Xj = X[:, j:j+1]
        m.fit(Xj, y)
        scores[j] = m.score(Xj, y)
    return scores

scores = fit_and_score_features(data_x_numeric.values, data_y)
pd.Series(scores, index=data_x_numeric.columns).sort_values(ascending=False)
Out[14]:
Karnofsky_score is the best variable, whereas Months_from_Diagnosis and Prior_therapy='yes' have almost no predictive power on their own.
Next, we want to build a parsimonious model by excluding irrelevant features. We could use the ranking from above, but would need to determine what the optimal cut-off should be. Luckily, scikit-learn has built-in support for performing grid search.
First, we create a pipeline that puts all the parts together.
In [15]:
from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import Pipeline
pipe = Pipeline([('encode', OneHotEncoder()),
                 ('select', SelectKBest(fit_and_score_features, k=3)),
                 ('model', CoxPHSurvivalAnalysis())])
Next, we need to define the range of parameters we want to explore during the grid search. Here, we want to optimize the parameter k of the SelectKBest class and allow k to vary from 1 feature to all 8 features.
In [16]:
from sklearn.model_selection import GridSearchCV
param_grid = {'select__k': np.arange(1, data_x_numeric.shape[1] + 1)}
gcv = GridSearchCV(pipe, param_grid, return_train_score=True, cv=3, iid=True)
gcv.fit(data_x, data_y)
pd.DataFrame(gcv.cv_results_).sort_values(by='mean_test_score', ascending=False)
Out[16]:
The results show that it is sufficient to select the 3 most predictive features.
In [17]:
pipe.set_params(**gcv.best_params_)
pipe.fit(data_x, data_y)
encoder, transformer, final_estimator = [s[1] for s in pipe.steps]
pd.Series(final_estimator.coef_, index=encoder.encoded_columns_[transformer.get_support()])
Out[17]:
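The fitted pipeline can also be used end-to-end: encoding and feature selection happen inside the pipeline, so the raw covariates can be passed directly. A usage sketch, not part of the original analysis:

# c-index of the reduced model on the training data
pipe.score(data_x, data_y)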
Cox's proportional hazards model is by far the most popular survival model, because once trained, it is easy to interpret. However, if prediction performance is the main objective, more sophisticated, non-linear, or ensemble models might lead to better results. Check out this notebook to get a better understanding of how to evaluate survival models, and this notebook to learn more about Kernel Survival Support Vector Machines. The API reference contains a full list of models that are available within scikit-survival. In addition, you can use any unsupervised pre-processing method available in scikit-learn; for instance, you could perform dimensionality reduction using Non-negative Matrix Factorization (NMF) before training a Cox model.