Skip to content

Commit

Permalink
Survival chapter update (#850)
Browse files Browse the repository at this point in the history
* update survival chapter

* update errata

* add John as author
  • Loading branch information
bblodfon authored Dec 2, 2024
1 parent 624a9c3 commit 6aca217
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 45 deletions.
6 changes: 6 additions & 0 deletions book/chapters/appendices/errata.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,9 @@ This appendix lists changes to the online version of this book to chapters inclu
* Subset task to row 127 instead of 35 for the local surrogate model.
* Add `as.data.frame()` to "Correctly Interpreting Shapley Values" section.

## 13. Beyond Regression and Classification

* Use `gamma` instead of `gamma.mu` for `lrn("surv.svm")`
* Substitute RCLL with ISBS measure
* Mention `pipeline_responsecompositor()` pipeline for changing predict types
* Use `lrn("surv.xgboost.aft")` instead of `lrn("surv.glmnet")` in "Composition" subsection
64 changes: 35 additions & 29 deletions book/chapters/chapter13/beyond_regression_and_classification.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Instead of throwing away information about censored events, survival analysis da
So in our example, we might write the runner's outcome as $(4, 1)$ if they finish the race at four hours, otherwise, if they give up at two hours we would write $(2, 0)$.

The key to modeling in survival analysis is that we assume there exists a hypothetical time the marathon runner would have finished if they had not been censored, it is then the job of a survival learner to estimate what the true survival time would have been for a similar runner, assuming they are *not* censored (see @fig-censoring).
Mathematically, this is represented by the hypothetical event time, $Y$, the hypothetical censoring time, $C$, the observed outcome time, $T = \min(Y, C)$, the event indicator $\Delta = (T = Y)$, and as usual some features, $X$.
Mathematically, this is represented by the hypothetical event time, $Y$, the hypothetical censoring time, $C$, the observed outcome time, $T = \min(Y, C)$, the event indicator $\Delta := (T = Y)$, and as usual some features, $X$.
Learners are trained on $(T, \Delta)$ but, critically, make predictions of $Y$ from previously unseen features.
This means that unlike classification and regression, learners are trained on two variables, $(T, \Delta)$, which, in R, is often captured in a `r ref("survival::Surv")` object.
Relating to our example above, the runner's outcome would then be $(T = 4, \Delta = 1)$ or $(T = 2, \Delta = 0)$.
Expand Down Expand Up @@ -237,7 +237,7 @@ In survival analysis, the following predictions can be made:
* `crank` -- Continuous risk ranking.

We will go through each of these prediction types in more detail and with examples to make them less abstract.
We will use `lrn("surv.coxph")`\index{Cox Proportional Hazards} trained on `tsk("rats")` as a running example, for this model, all predict types except `response` can be computed.
We will use `lrn("surv.coxph")`\index{Cox Proportional Hazards} trained on `tsk("rats")` as a running example, since for this model, all predict types except `response` can be computed.

```{r special-015}
tsk_rats = tsk("rats")
Expand All @@ -259,7 +259,7 @@ We then compare the predictions from the model to the true data.

```{r special-016}
library(mlr3extralearners)
prediction_svm = lrn("surv.svm", type = "regression", gamma.mu = 1e-3)$
prediction_svm = lrn("surv.svm", type = "regression", gamma = 1e-3)$
train(tsk_rats, split$train)$predict(tsk_rats, split$test)
data.frame(pred = prediction_svm$response[1:3],
truth = prediction_svm$truth[1:3])
Expand All @@ -282,7 +282,6 @@ prediction_cph$distr[1:3]$survival(77)

The output indicates that there is a `r paste0(round(prediction_cph$distr[1:3]$survival(77)*100, 1), "%")`, chance of the first three predicted rats being alive at time 77 respectively.


#### predict_type = "lp" {.unnumbered .unlisted}

`lp`, often written as $\eta$ in academic writing, is computationally the simplest prediction and has a natural analog in regression modeling.
Expand Down Expand Up @@ -323,7 +322,7 @@ Survival models in `mlr3proba` are evaluated with `r ref("MeasureSurv")` objects
In general survival measures can be grouped into the following:

1. Discrimination measures -- Quantify if a model correctly identifies if one observation is at higher risk than another. Evaluate `crank` and/or `lp` predictions.
2. Calibration measures -- Quantify if the average prediction is close to the truth (all definitions of calibration are unfortunately vague in a survival context). Evaluate `crank` and/or `distr` predictions.
2. Calibration measures -- Quantify if the average prediction is close to the truth (all definitions of calibration are unfortunately vague in a survival context). Evaluate `crank` and/or `lp` predictions.
3. Scoring rules -- Quantify if probabilistic predictions are close to true values. Evaluate `distr` predictions.

```{r special-020}
Expand All @@ -332,15 +331,15 @@ as.data.table(mlr_measures)[
```

There is not a consensus in the literature around the 'best' survival measures to use to evaluate models.
We recommend RCLL (right-censored logloss) (`msr("surv.rcll")`) to evaluate the quality of `distr` predictions, concordance index (`msr("surv.cindex")`) to evaluate a model's discrimination, and D-Calibration (`msr("surv.dcalib")`) to evaluate a model's calibration.
We recommend ISBS (Integrated Survival Brier Score) (`msr("surv.graf")`) to evaluate the quality of `distr` predictions, concordance index (`msr("surv.cindex")`) to evaluate a model's discrimination, and D-Calibration (`msr("surv.dcalib")`) to evaluate a model's calibration.

Using these measures, we can now evaluate our predictions from the previous example.

```{r}
prediction_cph$score(msrs(c("surv.rcll", "surv.cindex", "surv.dcalib")))
prediction_cph$score(msrs(c("surv.graf", "surv.cindex", "surv.dcalib")))
```

The model's performance seems okay as the RCLL and DCalib are relatively low 0 and the C-index is greater than 0.5 however it is very hard to determine the performance of any survival model without comparing it to some baseline (usually the Kaplan-Meier).
The model's performance seems okay as the ISBS and DCalib are relatively low and the C-index is greater than 0.5 however it is very hard to determine the performance of any survival model without comparing it to some baseline (usually the Kaplan-Meier).

### Composition {#sec-surv-comp}

Expand All @@ -350,7 +349,7 @@ We define a 'native' prediction as the prediction made by a model without any po
#### Internal Composition

`mlr3proba` makes use of composition internally to return a `"crank"` prediction for every learner.
This is to ensure that we can meaningfully benchmark all models according to at least one criterion.
This is to ensure that we can meaningfully benchmark all models according to at least one criterion (discrimination performance).
The package uses the following rules to create `"crank"` predictions:

1. If a model returns a 'risk' prediction then `crank = risk` (we may multiply this by $-1$ to ensure the 'low-value low-risk' interpretation).
Expand All @@ -362,16 +361,19 @@ The package uses the following rules to create `"crank"` predictions:

At the start of this section, we mentioned that it is possible to transform prediction types between each other.
In `mlr3proba` this is possible with 'compositor' pipelines (@sec-pipelines).
There are several pipelines implemented in the package but two in particular focus on predict type transformation:
There are several pipelines implemented in the package but three in particular focus on predict type transformation:

1. `r ref("pipeline_crankcompositor()")` -- Transforms a `"distr"` prediction to `"crank"`
2. `r ref("pipeline_distrcompositor()")` -- Transforms a `"lp"` prediction to `"distr"`
3. `r ref("pipeline_responsecompositor()")` -- Transforms a `"distr"` prediction to `"response"` (survival time)

In practice, the second pipeline is more common as we internally use a version of the first pipeline whenever we return predictions from survival models (so only use the first pipeline to overwrite these ranking predictions), and so we will just look at the second pipeline.
We internally use a version of the first pipeline whenever we return predictions from survival models so that every model has a `"crank"` prediction type - so only use the first pipeline to overwrite these ranking predictions.
In practice, the second pipeline is more common as Cox or Accelerated Failure Time (AFT) type models always return a linear predictor (`"lp"`), but sometimes the internal `predict()` functions don't provide a transformation to a survival distribution prediction (`"distr"`).
The third pipeline summarizes the predicted survival curves to a single number (expected survival time), and as previously mentioned, are rarely useful for evaluating the performance of survival machine learning models.

In the example below we load the `rats` dataset, remove factor columns, and then partition the data into training and testing.
We construct the `distrcompositor` pipeline around a survival GLMnet learner (`lrn("surv.glmnet")`) which by default can only make predictions for `"lp"` and `"crank"`.
In the pipeline, we specify that we will estimate the baseline distribution with a `r index("Kaplan-Meier", lower = FALSE)` estimator (`estimator = "kaplan"`) and that we want to assume a proportional hazards form for our estimated distribution (`form = "ph"`).
We construct the `distrcompositor` pipeline around a survival XGBoost Accelerated Failure Time (AFT) learner (`lrn("surv.xgboost.aft")`) which by default makes predictions for `"lp"`, `"crank"` and `"response"`.
In the pipeline, we specify that we will estimate the baseline distribution with a `r index("Kaplan-Meier", lower = FALSE)` estimator (`estimator = "kaplan"`) and that we want to assume an AFT form for our estimated distribution (`form = "aft"`).
We then train and predict in the usual way and in our output we can now see a `distr` prediction.

```{r special-021, warning=FALSE}
Expand All @@ -381,28 +383,29 @@ library(mlr3extralearners)
tsk_rats = tsk("rats")$select(c("litter", "rx"))
split = partition(tsk_rats)
learner = lrn("surv.glmnet")
learner = lrn("surv.xgboost.aft", nrounds = 10)
# no distr output
learner$train(tsk_rats, split$train)$predict(tsk_rats, split$test)
graph_learner = as_learner(ppl(
graph_learner = ppl(
"distrcompositor",
learner = learner,
estimator = "kaplan",
form = "ph"
))
form = "aft",
graph_learner = TRUE
)
# now with distr
graph_learner$train(tsk_rats, split$train)$predict(tsk_rats, split$test)
```

Mathematically, we have done the following:

1. Assume our estimated distribution will have the form $h(t) = h_0(t)\exp(\eta)$ where $h$ is the hazard function and $h_0$ is the baseline hazard function.
1. Estimate $\hat{\eta}$ prediction using GLMnet
1. Estimate $\hat{h}_0(t)$ with the Kaplan-Meier estimator
1. Put this all together as $h(t) = \hat{h}_0(t)\exp(\hat{\eta})$
1. Assume our estimated distribution will have the form $S(t) = S_0(\frac{t}{\exp(\eta)})$ where $S$ is the survival function, $S_0$ is the baseline survival function and $\eta$ is the linear predictor.
1. Estimate $\hat{\eta}$ prediction using XGBoost
1. Estimate $\hat{S}_0(t)$ with the Kaplan-Meier estimator
1. Put this all together as $S(t) = \hat{S}_0(\frac{t}{\exp(\hat{\eta})})$

For more detail about prediction types and composition we recommend @Kalbfleisch2011.

Expand All @@ -411,32 +414,35 @@ For more detail about prediction types and composition we recommend @Kalbfleisch

Finally, we will put all the above into practice in a small benchmark experiment.
We first load `tsk("grace")` (which only has numeric features) and sample 500 rows randomly.
We then select the RCLL, D-Calibration, and C-index to evaluate predictions, set up the same pipeline we used in the previous experiment, and load a Cox PH and Kaplan-Meier estimator.
We then select the ISBS, D-Calibration, and C-index to evaluate predictions, set up the same pipeline we used in the previous experiment, and load a Cox PH and Kaplan-Meier estimator.
We run our experiment with three-fold CV and aggregate the results.

```{r special-022, warning=FALSE}
set.seed(42)
library(mlr3extralearners)
tsk_grace = tsk("grace")
tsk_grace$filter(sample(tsk_grace$nrow, 500))
msr_txt = c("surv.rcll", "surv.cindex", "surv.dcalib")
msr_txt = c("surv.graf", "surv.cindex", "surv.dcalib")
measures = msrs(msr_txt)
graph_learner = as_learner(ppl(
graph_learner = ppl(
"distrcompositor",
learner = lrn("surv.glmnet"),
learner = lrn("surv.xgboost.aft", nrounds = 10),
estimator = "kaplan",
form = "ph"
))
graph_learner$id = "Coxnet"
form = "aft",
graph_learner = TRUE,
scale_lp = TRUE
)
graph_learner$id = "XGBoost-AFT"
learners = c(lrns(c("surv.coxph", "surv.kaplan")), graph_learner)
bmr = benchmark(benchmark_grid(tsk_grace, learners,
rsmp("cv", folds = 3)))
bmr$aggregate(measures)[, c("learner_id", ..msr_txt)]
```

In this small experiment, Coxnet and Cox PH have the best discrimination, the Kaplan-Meier baseline has the best calibration, and Coxnet and Cox PH have similar overall predictive accuracy (with the lowest RCLL).
In this small experiment, XGBoost-AFT and Cox PH have the best discrimination, the Kaplan-Meier baseline has the best calibration, and Cox PH has the best overall predictive accuracy (with the lowest ISBS).

## Density Estimation {#sec-density}

Expand Down
32 changes: 16 additions & 16 deletions book/common/chap_auths.csv
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
Chapter Number,Title,Authors
1,Introduction and Overview,"Lars Kotthoff, Raphael Sonabend, Natalie Foss, Bernd Bischl"
2,Data and Basic Modeling,"Natalie Foss, Lars Kotthoff"
3,Evaluation and Benchmarking,"Giuseppe Casalicchio, Lukas Burk"
4,Hyperparameter Optimization,"Marc Becker, Lennart Schneider, Sebastian Fischer"
5,Advanced Tuning Methods and Black Box Optimization,"Lennart Schneider, Marc Becker"
6,Feature Selection,Marvin N. Wright
7,Sequential Pipelines,"Martin Binder, Florian Pfisterer"
8,Non-sequential Pipelines and Tuning,"Martin Binder, Florian Pfisterer, Marc Becker, Marvin N. Wright"
9,Preprocessing,Janek Thomas
10,Advanced Technical Aspects of mlr3,"Michel Lang, Sebastian Fischer, Raphael Sonabend"
11,Large-Scale Benchmarking,"Sebastian Fischer, Michel Lang, Marc Becker"
12,Model Interpretation,"Susanne Dandl, Przemysław Biecek, Giuseppe Casalicchio, Marvin N. Wright"
13,Beyond Regression and Classification,"Raphael Sonabend, Patrick Schratz, Damir Pulatov"
14,Algorithmic Fairness,Florian Pfisterer
15,"Predict Sets, Validation and Internal Tuning (+)", Sebastian Fischer
Chapter Number,Title,Authors
1,Introduction and Overview,"Lars Kotthoff, Raphael Sonabend, Natalie Foss, Bernd Bischl"
2,Data and Basic Modeling,"Natalie Foss, Lars Kotthoff"
3,Evaluation and Benchmarking,"Giuseppe Casalicchio, Lukas Burk"
4,Hyperparameter Optimization,"Marc Becker, Lennart Schneider, Sebastian Fischer"
5,Advanced Tuning Methods and Black Box Optimization,"Lennart Schneider, Marc Becker"
6,Feature Selection,Marvin N. Wright
7,Sequential Pipelines,"Martin Binder, Florian Pfisterer"
8,Non-sequential Pipelines and Tuning,"Martin Binder, Florian Pfisterer, Marc Becker, Marvin N. Wright"
9,Preprocessing,Janek Thomas
10,Advanced Technical Aspects of mlr3,"Michel Lang, Sebastian Fischer, Raphael Sonabend"
11,Large-Scale Benchmarking,"Sebastian Fischer, Michel Lang, Marc Becker"
12,Model Interpretation,"Susanne Dandl, Przemysław Biecek, Giuseppe Casalicchio, Marvin N. Wright"
13,Beyond Regression and Classification,"Raphael Sonabend, Patrick Schratz, Damir Pulatov, John Zobolas"
14,Algorithmic Fairness,Florian Pfisterer
15,"Predict Sets, Validation and Internal Tuning (+)", Sebastian Fischer
1 change: 1 addition & 0 deletions book/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Bernd Bischl, Raphael Sonabend, Lars Kotthoff, Michel Lang
- Raphael Sonabend
- Janek Thomas
- Marvin N. Wright
- John Zobolas
:::

::::
Expand Down

0 comments on commit 6aca217

Please sign in to comment.