I ran a small benchmark related to this PR (see the reprex below): for every mlr3 survival learner that produces a survival matrix (the `distr` predict type in mlr3proba), I wanted to know which time points are used as its columns.
Results
Most survival learners use all the unique training time points. This matters a lot when computing metrics such as the IBS and when comparing learners fairly. The exceptions are:
- `surv.aorsf` uses the unique event time points from the **test** set (code). It may be a good idea to change that and harmonize it with the other RSFs, which (for some reason) use the unique event time points from the training set. We can use the `learner$model$event_times` slot directly in `predict()`.
- `akritas` and `parametric` have an `ntime` argument (default 150) that thins out the training time points. This was done for efficiency, to avoid having too many time points. We could change the default to use all unique training time points from the `model$y[, "time"]` slot, and let users set `ntime` if they want fewer.
- `penalized` uses the unique training event times but adds 0 and the largest time point (if the latter belongs to a censored observation, it is an extra time point). This was discussed with the author.
- `rfsrc` also has an `ntime` argument (default: 150) that reduces the unique event times to a grid of 150 time points if the training data contain more than 150. This doesn't happen with the example task below (fewer than 150 events), but having such a parameter for all learners seems like a good idea.
- `surv.ranger` also has a `time.interest` argument, similar to `ntime`; its default is `NULL` (use all observed time points).
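To make the differences concrete, here is a minimal base-R sketch of the time grids described above, built on a small made-up dataset (`train_time`, `train_status`, `test_time` and `test_status` are hypothetical). The penalized-style grid follows the description in this issue, not the package source.

```r
# Hypothetical toy data: observed times and event indicators (1 = event, 0 = censored)
train_time   <- c(5, 3, 8, 3, 10, 12, 7)
train_status <- c(1, 0, 1, 1, 1, 0, 0)
test_time    <- c(4, 9, 2)
test_status  <- c(1, 1, 0)

# Most learners: all unique training times (events and censorings)
grid_train <- sort(unique(train_time))                          # 3 5 7 8 10 12
# ranger / rfsrc: unique training event times only
grid_train_event <- sort(unique(train_time[train_status == 1])) # 3 5 8 10
# aorsf (current behaviour): unique *test* event times, so the grid depends on the test set
grid_test_event <- sort(unique(test_time[test_status == 1]))    # 4 9
# penalized-style: training event times plus 0 and the largest training time
# (12 is censored here, so it is an extra point compared to the event times)
grid_penalized <- sort(unique(c(0, grid_train_event, max(train_time))))  # 0 3 5 8 10 12
```

Only the first two grids are fixed once the model is trained; the aorsf grid changes whenever the test set changes, which is what makes cross-learner comparisons awkward.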
```r
library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)
lrn_ids = mlr_learners$keys("^surv")
# remove some learners (DL models take too much time, mboost has issues, etc.)
lrn_ids = lrn_ids[!grepl(pattern = "blackboost|mboost|deep|pchazard|coxtime|priority|dnn|loghaz|gamboost", lrn_ids)]
# remove learners that don't predict `distr`
lrn_ids = lapply(lrn_ids, function(id) {
  learner = lrn(id)
  if ("distr" %in% learner$predict_types) id else NULL
}) |> unlist()
lrn_ids # ~18 survival learners
#> [1] "surv.akritas" "surv.aorsf" "surv.bart" "surv.cforest"
#> [5] "surv.coxboost" "surv.coxph" "surv.ctree" "surv.cv_coxboost"
#> [9] "surv.cv_glmnet" "surv.flexible" "surv.glmnet" "surv.kaplan"
#> [13] "surv.nelson" "surv.parametric" "surv.penalized" "surv.ranger"
#> [17] "surv.rfsrc" "surv.xgboost.cox"
task = tsk("gbcs")
set.seed(42)
part = partition(task, ratio = 0.5)
# keep different sets of time points to check later
train_times = task$unique_times(part$train)
train_event_times = task$unique_event_times(part$train)
test_times = task$times(part$test)
test_status = task$status(part$test)
test_event_times = sort(unique(test_times[test_status == 1]))
test_times = sort(unique(test_times))
all_times = task$unique_times()
all_event_times = task$unique_event_times()
res = lapply(lrn_ids, function(id) {
  print(id)
  learner = lrn(id)
  if (id == "surv.parametric") learner$param_set$set_values(.values = list(discrete = TRUE))
  if (id == "surv.bart") {
    # low settings to make computation faster
    learner$param_set$set_values(.values = list(nskip = 1, ndpost = 3, keepevery = 2, mc.cores = 14))
  }
  if (id == "surv.cforest") learner$param_set$set_values(.values = list(cores = 14))
  if (id == "surv.ranger") learner$param_set$set_values(.values = list(num.threads = 14))
  learner$train(task, part$train)
  p = learner$predict(task, part$test)
  times = as.numeric(colnames(p$data$distr))
  # return the discrete time points at which S(t) is predicted
  times
})
#> [1] "surv.akritas"
#> [1] "surv.aorsf"
#> [1] "surv.bart"
#> [1] "surv.cforest"
#> [1] "surv.coxboost"
#> [1] "surv.coxph"
#> [1] "surv.ctree"
#> [1] "surv.cv_coxboost"
#> [1] "surv.cv_glmnet"
#> [1] "surv.flexible"
#> [1] "surv.glmnet"
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see
#> parameter 's').
#> [1] "surv.kaplan"
#> [1] "surv.nelson"
#> [1] "surv.parametric"
#> [1] "surv.penalized"
#> [1] "surv.ranger"
#> [1] "surv.rfsrc"
#> [1] "surv.xgboost.cox"
names(res) = lrn_ids
# example times:
head(res$surv.aorsf)
#> [1] 72 177 210 294 311 323
# TRUE only if both vectors contain exactly the same time points (no silent recycling)
same_times = function(a, b) length(a) == length(b) && all(a == b)
which_times = lapply(lrn_ids, function(id) {
  times = res[[id]]
  lgl_list = list(
    train = same_times(times, train_times),
    train_event = same_times(times, train_event_times),
    test = same_times(times, test_times),
    test_event = same_times(times, test_event_times),
    all = same_times(times, all_times),
    all_event = same_times(times, all_event_times)
  )
  names(which(mlr3misc::map_lgl(lgl_list, isTRUE)))
})
names(which_times) = lrn_ids
# Results: which time points are used by each learner in the predicted survival matrix?
which_times
#> $surv.akritas
#> character(0)
#> $surv.aorsf
#> [1] "test_event"
#> $surv.bart
#> [1] "train"
#> $surv.cforest
#> [1] "train"
#> $surv.coxboost
#> [1] "train"
#> $surv.coxph
#> [1] "train"
#> $surv.ctree
#> [1] "train"
#> $surv.cv_coxboost
#> [1] "train"
#> $surv.cv_glmnet
#> [1] "train"
#> $surv.flexible
#> [1] "train"
#> $surv.glmnet
#> [1] "train"
#> $surv.kaplan
#> [1] "train"
#> $surv.nelson
#> [1] "train"
#> $surv.parametric
#> character(0)
#> $surv.penalized
#> character(0)
#> $surv.ranger
#> [1] "train_event"
#> $surv.rfsrc
#> [1] "train_event"
#> $surv.xgboost.cox
#> [1] "train"
```
Hey John! I think harmonizing is a good idea, and it's much easier to align aorsf with the other learners than the other way around. My rationale was that evaluating model predictions at the times when events occur should be more efficient than evaluating them at times around those points, and that it avoids missing event times in the testing data that fall before the first or after the last event time in the training data. But in most cases I think the event times will be very similar in the training and testing data.
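On the efficiency point: for a step-function estimate such as Kaplan-Meier, S(t) only drops at event times, so evaluating on all unique training times instead of only the event times adds columns that repeat the value of the preceding event time. The sketch below assumes a right-continuous step function with S(t) = 1 before the first grid point and carries the last value forward past the last grid point; `eval_step_surv()` and the survival values are hypothetical and not part of mlr3proba.

```r
# Hypothetical helper: evaluate a right-continuous survival step function, known on the
# sorted vector `grid` with values `surv`, at arbitrary `new_times`.
# Assumptions: S(t) = 1 before the first grid point; the last value is carried forward
# after the last grid point.
eval_step_surv <- function(grid, surv, new_times) {
  # findInterval() returns the index of the last grid point <= t, or 0 if there is none
  idx <- findInterval(new_times, grid)
  c(1, surv)[idx + 1]
}

# Made-up survival curve that drops only at the event times 3, 5, 8 and 10
event_grid <- c(3, 5, 8, 10)
surv_event <- c(0.85, 0.70, 0.50, 0.30)

# All unique training times (events plus the censoring times 7 and 12)
full_grid <- c(3, 5, 7, 8, 10, 12)
surv_full <- eval_step_surv(event_grid, surv_event, full_grid)
surv_full  # 0.85 0.70 0.70 0.50 0.30 0.30: censoring times repeat the previous value

# Test times before the first / after the last training event time
eval_step_surv(event_grid, surv_event, c(2, 15))  # returns c(1, 0.3)
```

Under these assumptions, test event times outside the training range are still covered; how well the carried-forward value represents the truth beyond the last training event time is a modelling question, not a grid question.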
- In the code example, the 3 RSFs (ranger, aorsf and rfsrc) now use the unique training event time points, while all the other learners use the unique training time points as the columns of the predicted survival matrix.
- penalized behaves like the RSFs (unique training event times), but also adds 0 and the largest training time point (the latter is an extra point only if it belongs to a censored observation).
- Some learners (e.g. parametric, rfsrc, ranger, akritas) have an argument to change the granularity of the time points, i.e. how many are used.
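As a sketch of the proposed default (all unique training times, with an optional way to thin them), the hypothetical helper `make_time_grid()` below keeps every unique training time when `ntime = NULL` and otherwise keeps at most `ntime` points at evenly spaced positions. The thinning rule is an assumption for illustration; it is not necessarily what `ntime` in akritas, parametric or rfsrc, or `time.interest` in ranger, actually does.

```r
# Hypothetical helper: build a prediction grid from the training data.
# ntime = NULL keeps every unique training time (the proposed default);
# an integer ntime keeps at most that many points, spread evenly over the sorted unique times.
# The thinning rule is an assumption for illustration, not the rule used by any learner.
make_time_grid <- function(times, status, event_only = FALSE, ntime = NULL) {
  if (event_only) times <- times[status == 1]
  grid <- sort(unique(times))
  if (is.null(ntime) || length(grid) <= ntime) return(grid)
  # ntime evenly spaced positions, always keeping the first and last time point
  grid[unique(round(seq(1, length(grid), length.out = ntime)))]
}

set.seed(1)
times  <- round(rexp(500, rate = 0.01), 1)  # simulated training times
status <- rbinom(500, 1, 0.7)               # simulated event indicators
length(make_time_grid(times, status))              # all unique training times
length(make_time_grid(times, status, ntime = 150)) # at most 150 points
```

Keeping the first and last positions means the thinned grid still spans the whole training follow-up; `event_only = TRUE` reproduces the RSF-style grid of unique training event times.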
Created on 2024-09-26 with reprex v2.1.1