Time points used in survival learners predicted matrix (distr) #387

Closed
bblodfon opened this issue Sep 26, 2024 · 3 comments

@bblodfon
Collaborator

bblodfon commented Sep 26, 2024

Investigation

I ran a small benchmark related to this PR (see the reprex below). I wanted to know, across all mlr3 survival learners that produce a survival matrix (the distr predict type in mlr3proba), which time points are used as its columns.

Results

Most survival learners use all unique training time points (this matters a lot for computing metrics such as the IBS and for keeping comparisons fair). The exceptions are the following:

  • surv.aorsf uses the unique event time points from the test set (see the aorsf code). It may be a good idea to change that and harmonize it with the other RSFs (for some reason these learners use the unique event time points from the training set). We could use the learner$model$event_times slot directly in predict().
  • akritas and parametric have an ntime argument (default 150) that "spreads out" the training time points over a smaller grid. The reason for this was efficiency (to NOT have too many time points). We could change the default to use the unique training time points from the model$y[, "time"] slot, and users who want a coarser grid can still set ntime.
  • penalized uses the unique training event times but adds 0 and the largest time point (if the largest time belongs to a censored observation, this is an extra time point); this was discussed with the package author.
  • rfsrc also has an ntime argument (default: 150) that reduces the unique event times to 150 points if more than 150 exist in the training data. This doesn't happen in the example task below (< 150 events), but having such a parameter for all learners seems like a good idea.
  • surv.ranger also has a time.interest argument, similar to ntime; its default is NULL (use all observed time points).
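To make the difference between these time point sets concrete, here is a minimal base R sketch on a small made-up dataset (hypothetical values). It computes the unique training times (what most learners use), the unique training event times (what the RSFs use), and a coarser grid of at most ntime points. The coarsen_times() helper is hypothetical, and its quantile-based thinning is an assumption for illustration only, not how akritas, parametric or rfsrc actually build their grids.

```r
# Toy right-censored training data (hypothetical values, for illustration only)
time   = c(5, 8, 8, 12, 15, 20, 20, 31, 40, 47)
status = c(1, 0, 1,  1,  0,  1,  1,  0,  1,  0)

# Unique training time points (events and censorings), as most learners use
train_times = sort(unique(time))

# Unique training event time points, as the RSFs use
train_event_times = sort(unique(time[status == 1]))

# Hypothetical helper: thin a sorted time vector to at most `ntime` points.
# ASSUMPTION: quantile-based thinning (type 1 returns observed values and keeps
# the min and max); real learners may build their `ntime` grid differently.
coarsen_times = function(times, ntime) {
  if (length(times) <= ntime) return(times)
  probs = seq(0, 1, length.out = ntime)
  unique(unname(quantile(times, probs = probs, type = 1)))
}

train_times        # 5 8 12 15 20 31 40 47
train_event_times  # 5 8 12 20 40
coarsen_times(train_times, ntime = 4)
```

Setting ntime at least as large as the number of unique times leaves the grid unchanged, which mirrors the rfsrc behaviour described above for tasks with fewer than 150 events.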
library(mlr3proba)
#> Loading required package: mlr3
library(mlr3extralearners)

lrn_ids = mlr_learners$keys("^surv")
# remove some learners (slow deep learning models, mboost-based learners with issues, etc.)
lrn_ids = lrn_ids[!grepl(pattern = "blackboost|mboost|deep|pchazard|coxtime|priority|dnn|loghaz|gamboost", lrn_ids)]
# remove learners that don't predict `distr`
lrn_ids = lapply(lrn_ids, function(id) {
  learner = lrn(id)
  if ("distr" %in% learner$predict_types) {
    id
  } else {
    NULL
  }
}) |> unlist()

lrn_ids # ~18 survival learners
#>  [1] "surv.akritas"     "surv.aorsf"       "surv.bart"        "surv.cforest"    
#>  [5] "surv.coxboost"    "surv.coxph"       "surv.ctree"       "surv.cv_coxboost"
#>  [9] "surv.cv_glmnet"   "surv.flexible"    "surv.glmnet"      "surv.kaplan"     
#> [13] "surv.nelson"      "surv.parametric"  "surv.penalized"   "surv.ranger"     
#> [17] "surv.rfsrc"       "surv.xgboost.cox"

task = tsk("gbcs")
set.seed(42)
part = partition(task, ratio = 0.5)

# keep different time points sets to check later
train_times = task$unique_times(part$train)
train_event_times = task$unique_event_times(part$train)

test_times = task$times(part$test)
test_status = task$status(part$test)
test_event_times = sort(unique(test_times[test_status == 1]))
test_times = sort(unique(test_times))

all_times = task$unique_times()
all_event_times = task$unique_event_times()

res = lapply(lrn_ids, function(id) {
  print(id)
  learner = lrn(id)

  if (id == "surv.parametric") {
    learner$param_set$set_values(.values = list(discrete = TRUE))
  }

  if (id == "surv.bart") {
    learner$param_set$set_values(
      # low settings to make computation faster
      .values = list(nskip = 1, ndpost = 3, keepevery = 2, mc.cores = 14)
    )
  }

  if (id == "surv.cforest") {
    learner$param_set$set_values(.values = list(cores = 14))
  }

  if (id == "surv.ranger") {
    learner$param_set$set_values(.values = list(num.threads = 14))
  }

  learner$train(task, part$train)
  p = learner$predict(task, part$test)
  times = as.numeric(colnames(p$data$distr))

  # return discrete times for which we have the predicted S(times)
  times
})
#> [1] "surv.akritas"
#> [1] "surv.aorsf"
#> [1] "surv.bart"
#> [1] "surv.cforest"
#> [1] "surv.coxboost"
#> [1] "surv.coxph"
#> [1] "surv.ctree"
#> [1] "surv.cv_coxboost"
#> [1] "surv.cv_glmnet"
#> [1] "surv.flexible"
#> [1] "surv.glmnet"
#> Warning: Multiple lambdas have been fit. Lambda will be set to 0.01 (see
#> parameter 's').
#> [1] "surv.kaplan"
#> [1] "surv.nelson"
#> [1] "surv.parametric"
#> [1] "surv.penalized"
#> [1] "surv.ranger"
#> [1] "surv.rfsrc"
#> [1] "surv.xgboost.cox"

names(res) = lrn_ids

# example times:
head(res$surv.aorsf)
#> [1]  72 177 210 294 311 323

which_times = lapply(lrn_ids, function(id) {
  times = res[[id]]
  #print(id)

  lgl_list = suppressWarnings(list(
    train = all(times == train_times),
    train_event = all(times == train_event_times),
    test = all(times == test_times),
    test_event = all(times == test_event_times),
    all = all(times == all_times),
    all_events = all(times == all_event_times)
  ))
  # keep the names of all time point sets that match exactly
  names(which(mlr3misc::map_lgl(lgl_list, isTRUE)))
})

names(which_times) = lrn_ids

# Results: which time points are used by each learner in the predicted survival matrix?
which_times
#> $surv.akritas
#> character(0)
#> 
#> $surv.aorsf
#> [1] "test_event"
#> 
#> $surv.bart
#> [1] "train"
#> 
#> $surv.cforest
#> [1] "train"
#> 
#> $surv.coxboost
#> [1] "train"
#> 
#> $surv.coxph
#> [1] "train"
#> 
#> $surv.ctree
#> [1] "train"
#> 
#> $surv.cv_coxboost
#> [1] "train"
#> 
#> $surv.cv_glmnet
#> [1] "train"
#> 
#> $surv.flexible
#> [1] "train"
#> 
#> $surv.glmnet
#> [1] "train"
#> 
#> $surv.kaplan
#> [1] "train"
#> 
#> $surv.nelson
#> [1] "train"
#> 
#> $surv.parametric
#> character(0)
#> 
#> $surv.penalized
#> character(0)
#> 
#> $surv.ranger
#> [1] "train_event"
#> 
#> $surv.rfsrc
#> [1] "train_event"
#> 
#> $surv.xgboost.cox
#> [1] "train"

Created on 2024-09-26 with reprex v2.1.1
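The comparison in the reprex uses all(times == x) inside suppressWarnings(), which relies on R's vector recycling when the two vectors differ in length; for example, all(c(1, 2, 1, 2) == c(1, 2)) is TRUE without even a warning. A stricter check compares lengths first. A minimal sketch, with a hypothetical helper match_time_sets() and made-up reference sets:

```r
# Hypothetical helper: return the names of all reference sets that `times`
# matches exactly (same length and same values, up to numeric tolerance).
match_time_sets = function(times, ref_sets) {
  hits = vapply(ref_sets, function(ref) {
    length(times) == length(ref) &&
      isTRUE(all.equal(as.numeric(times), as.numeric(ref)))
  }, logical(1))
  names(ref_sets)[hits]
}

# Made-up reference sets (not the gbcs task times)
ref_sets = list(
  train       = c(5, 8, 12, 15, 20, 31, 40, 47),
  train_event = c(5, 8, 12, 20, 40)
)

match_time_sets(c(5, 8, 12, 20, 40), ref_sets)  # "train_event"
match_time_sets(c(5, 8, 12), ref_sets)          # character(0)
```

Because lengths are checked explicitly, a vector that merely repeats a reference set is no longer reported as a match, and no warnings need to be suppressed.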

@bcjaeger
Contributor

Hey John! I think harmonizing is a good idea, and it's much easier to align aorsf with the other learners than to align the other learners with aorsf. I think my rationale was that evaluating model predictions at the times when events occur should be more efficient than evaluating them at times around those points, and it avoids potentially missing test event times that occur before the first or after the last event time in the training data. But in most cases I think the event times will be very similar in the training and test data.
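To illustrate the concern about test event times outside the training range: if a predicted survival curve is only known on a grid of training time points, evaluating it at a test time before the first or after the last grid point needs some extrapolation rule. A minimal base R sketch, assuming constant (step-function) interpolation with S(t) = 1 before the first grid point and the last value carried forward afterwards. This is one common convention, not necessarily what mlr3proba does, and eval_surv_at() and the numbers are hypothetical:

```r
# Hypothetical helper: evaluate a survival curve known only on `grid` at
# `new_times`, treating S(t) as a right-continuous step function.
# ASSUMPTION: S(t) = 1 before the first grid point, and the last known value
# is carried forward after the final grid point.
eval_surv_at = function(grid, surv, new_times) {
  idx = findInterval(new_times, grid)  # index of the last grid point <= new time
  c(1, surv)[idx + 1]
}

grid = c(5, 8, 12, 20, 40)         # e.g. unique training event times
surv = c(0.9, 0.8, 0.7, 0.5, 0.3)  # predicted S(t) at those times (made up)

# Test event times before the first and after the last training event time
test_event_times = c(2, 10, 45)
eval_surv_at(grid, surv, test_event_times)  # 1.0 0.8 0.3
```

The values at t = 2 and t = 45 come purely from the extrapolation rule rather than from the model, which is the information loss that evaluating directly at test event times avoids.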

@bblodfon
Collaborator Author

bblodfon commented Oct 14, 2024

See #385 for the time point harmonization.

Rerunning the code example, the 3 RSFs (ranger, aorsf and rfsrc) now provide the unique training event time points, while all the other learners provide the unique training time points as the columns of the predicted survival matrix.

  • penalized behaves like the RSFs (unique training event times) but also adds 0 and the largest time point if it belongs to a censored observation
  • Some learners (e.g. parametric, rfsrc, ranger, akritas) have an argument to change the granularity of the time points, i.e. how many time points are used
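The penalized grid described above can be sketched in base R as follows. penalized_like_times() is a hypothetical helper written from the description in this thread, not code from the penalized package, and the data are made up:

```r
# Hypothetical helper: unique training event times, plus 0, plus the largest
# observed time (an extra point only when it belongs to a censored observation).
penalized_like_times = function(time, status) {
  event_times = unique(time[status == 1])
  sort(unique(c(0, event_times, max(time))))
}

time   = c(5, 8, 8, 12, 15, 20, 20, 31, 40, 47)
status = c(1, 0, 1,  1,  0,  1,  1,  0,  1,  0)

# 47 is censored, so it is added on top of the event times
penalized_like_times(time, status)  # 0 5 8 12 20 40 47
```

If the largest time is already an event time, unique() drops the duplicate, so only 0 is added to the unique training event times.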

@bblodfon bblodfon self-assigned this Oct 15, 2024
@bblodfon
Collaborator Author

#385
