
check tune::compute_metrics for survival models #170

Closed
topepo opened this issue Jan 18, 2024 · 2 comments · Fixed by #193

Comments

topepo (Member) commented Jan 18, 2024

I'm not sure if or where this is currently tested for survival models.

topepo (Member, Author) commented Jan 19, 2024

Looking at the code, I thought that we would need to add an eval_time argument (more on that below).

With no changes, the function does pretty much what it should (see the caveat below):

library(tidymodels)
library(censored)
#> Loading required package: survival

lung_surv <- lung %>%
  dplyr::mutate(surv = Surv(time, status), .keep = "unused")

metrics <- metric_set(concordance_survival, brier_survival_integrated, brier_survival)

times <- c(2, 50, 100)

set.seed(2193)
tune_res <-
  proportional_hazards(penalty = tune(), engine = "glmnet") %>%
  tune_grid(
    surv ~ .,
    resamples = vfold_cv(lung_surv, 2),
    grid = tibble(penalty = c(0.001, 0.1)),
    control = control_grid(save_pred = TRUE),
    metrics = metrics,
    eval_time = times
  )

recomp <- compute_metrics(tune_res, metrics, summarize = TRUE)
original <- collect_metrics(tune_res)

all.equal(recomp, original)
#> [1] TRUE

stc_only <- compute_metrics(tune_res, metric_set(concordance_survival), summarize = TRUE)
stc_original <- collect_metrics(tune_res) %>% filter(.metric == "concordance_survival")

all.equal(stc_only, stc_original) # original has a .eval_time column which is all NA
#> [1] "Names: 4 string mismatches"                                   
#> [2] "Length mismatch: comparison on first 7 components"            
#> [3] "Component 4: 'is.NA' value mismatch: 2 in current 0 in target"
#> [4] "Component 5: Mean relative difference: 0.6953599"             
#> [5] "Component 6: Mean relative difference: 58.79083"              
#> [6] "Component 7: Modes: character, numeric"                       
#> [7] "Component 7: target is character, current is numeric"
all.equal(stc_only, stc_original %>% select(-.eval_time))
#> [1] TRUE

Created on 2024-01-19 with reprex v2.0.2

The only caveat is that, without an eval_time argument, users cannot specify a subset of evaluation times. We could add one, but the machinery needed to validate those times against the metric set would add a substantial amount of code.

I'll look into where this is tested for survival models, but I think we should leave the API as-is and see if there are requests for a subset of eval times. Also, I don't think there are parity metrics for censored regression, so that point is moot for this application.
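As a workaround (this is a sketch of my own, not something discussed above), a subset of evaluation times can be selected after the fact by filtering the `.eval_time` column that `compute_metrics()`/`collect_metrics()` return, assuming the `tune_res` and `metrics` objects from the reprex:

```r
# Recompute all metrics, then keep only some evaluation times.
# Static metrics (e.g. concordance) have NA in .eval_time, so retain those too.
subset_res <- compute_metrics(tune_res, metrics, summarize = TRUE) %>%
  dplyr::filter(is.na(.eval_time) | .eval_time %in% c(50, 100))
```

This avoids any new API surface at the cost of computing metrics at all times first.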

topepo (Member, Author) commented Jan 25, 2024

We've settled on leaving compute_metrics() as-is. If we later determine that users would like to filter evaluation times, we can make those changes.

We still need unit tests for survival models, so I'll keep this issue open until they are added.

topepo added a commit that referenced this issue Jan 30, 2024
topepo added a commit that referenced this issue Jan 31, 2024