Test last_fit() for survival models (#142)
* initial changes

* test for Bayesian optimization

* move grid testing to use glmnet

* reorganize/block code

* add more tests

* move to mboost as model

* update SA metric collection and other tests

* update to race with rpart trees

* Added tests for #113

* partial collect metric tests

* remaining racing results

* last_fit tests for survival models

* smaller, more portable file names

* changes based on reviewer comments

* remove fit_resamples tests

* Apply suggestions from code review

Co-authored-by: Emil Hvitfeldt <[email protected]>

* Apply suggestions from code review

Co-authored-by: Emil Hvitfeldt <[email protected]>

---------

Co-authored-by: Emil Hvitfeldt <[email protected]>
topepo and EmilHvitfeldt authored Dec 18, 2023
1 parent 4834aea commit 1c35f59
Showing 1 changed file with 280 additions and 0 deletions.
280 changes: 280 additions & 0 deletions tests/testthat/test-survival-last-fit.R
@@ -0,0 +1,280 @@
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(censored))

skip_if_not_installed("parsnip", minimum_version = "1.1.0.9003")
skip_if_not_installed("censored", minimum_version = "0.2.0.9000")
skip_if_not_installed("tune", minimum_version = "1.1.1.9001")
skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001")

test_that("last fit for survival models with static metric", {
skip_if_not_installed("prodlim")

# standard setup start -------------------------------------------------------

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

# last fit for models with static metrics ------------------------------------

stc_mtrc <- metric_set(concordance_survival)

set.seed(2193)
rs_static_res <-
survival_reg() %>%
last_fit(
event_time ~ X1 + X2,
split = split,
metrics = stc_mtrc
)

# test structure of results --------------------------------------------------

expect_named(
rs_static_res,
c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
)
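  # concordance_survival is a static metric, so no .eval_time column is expected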
expect_false(".eval_time" %in% names(rs_static_res$.metrics[[1]]))
expect_equal(
names(rs_static_res$.predictions[[1]]),
c(".pred_time", ".row", "event_time", ".config")
)
expect_s3_class(rs_static_res$.workflow[[1]], "workflow")

# test metric collection -----------------------------------------------------

metric_sum <- collect_metrics(rs_static_res)
exp_metric_sum <-
tibble(
.metric = character(0),
.estimator = character(0),
.estimate = numeric(0),
.config = character(0)
)

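  # a single static metric should give exactly one row of collected metrics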
  expect_true(nrow(metric_sum) == 1)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "concordance_survival"))

})

test_that("last fit for survival models with integrated metric", {
skip_if_not_installed("prodlim")

# standard setup start -------------------------------------------------------

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

# last fit for models with integrated metrics --------------------------------

sint_mtrc <- metric_set(brier_survival_integrated)

set.seed(2193)
rs_integrated_res <-
survival_reg() %>%
last_fit(
event_time ~ X1 + X2,
split = split,
metrics = sint_mtrc,
eval_time = time_points
)

# test structure of results --------------------------------------------------

expect_named(
rs_integrated_res,
c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
)
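  # the integrated Brier score summarizes across evaluation times, so no
  # .eval_time column is expected in the metrics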
expect_false(".eval_time" %in% names(rs_integrated_res$.metrics[[1]]))
expect_named(
rs_integrated_res$.predictions[[1]],
c(".pred", ".row", "event_time", ".config")
)
expect_true(is.list(rs_integrated_res$.predictions[[1]]$.pred))
expect_named(
rs_integrated_res$.predictions[[1]]$.pred[[1]],
c(".eval_time", ".pred_survival", ".weight_censored")
)
expect_equal(
rs_integrated_res$.predictions[[1]]$.pred[[1]]$.eval_time,
time_points
)

# test metric collection -----------------------------------------------------

metric_sum <- collect_metrics(rs_integrated_res)
exp_metric_sum <-
tibble(
.metric = character(0),
.estimator = character(0),
.estimate = numeric(0),
.config = character(0)
)

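  # the integrated metric also collapses to a single row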
  expect_true(nrow(metric_sum) == 1)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "brier_survival_integrated"))

})

test_that("last fit for survival models with dynamic metric", {
skip_if_not_installed("prodlim")

# standard setup start -------------------------------------------------------

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

# last fit for models with dynamic metrics -----------------------------------

dyn_mtrc <- metric_set(brier_survival)

set.seed(2193)
rs_dynamic_res <-
survival_reg() %>%
last_fit(
event_time ~ X1 + X2,
split = split,
metrics = dyn_mtrc,
eval_time = time_points
)

# test structure of results --------------------------------------------------

expect_named(
rs_dynamic_res,
c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
)
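  # dynamic metrics are computed at each evaluation time, so .eval_time should
  # be present in the metrics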
expect_true(".eval_time" %in% names(rs_dynamic_res$.metrics[[1]]))
expect_named(
rs_dynamic_res$.predictions[[1]],
c(".pred", ".row", "event_time", ".config")
)
expect_true(is.list(rs_dynamic_res$.predictions[[1]]$.pred))
expect_named(
rs_dynamic_res$.predictions[[1]]$.pred[[1]],
c(".eval_time", ".pred_survival", ".weight_censored")
)
expect_equal(
rs_dynamic_res$.predictions[[1]]$.pred[[1]]$.eval_time,
time_points
)

# test metric collection -----------------------------------------------------

metric_sum <- collect_metrics(rs_dynamic_res)
exp_metric_sum <-
tibble(
.metric = character(0),
.estimator = character(0),
.eval_time = numeric(0),
.estimate = numeric(0),
.config = character(0)
)

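  # expect one row per evaluation time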
  expect_true(nrow(metric_sum) == length(time_points))
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "brier_survival"))
})

test_that("last fit for survival models with mixture of metrics", {
skip_if_not_installed("prodlim")

# standard setup start -------------------------------------------------------

set.seed(1)
sim_dat <- prodlim::SimSurv(500) %>%
mutate(event_time = Surv(time, event)) %>%
select(event_time, X1, X2)

set.seed(2)
split <- initial_split(sim_dat)
sim_tr <- training(split)
sim_te <- testing(split)
sim_rs <- vfold_cv(sim_tr)

time_points <- c(10, 1, 5, 15)

# last fit for models with a mixture of metrics ------------------------------

mix_mtrc <- metric_set(brier_survival, brier_survival_integrated, concordance_survival)

set.seed(2193)
rs_mixed_res <-
survival_reg() %>%
last_fit(
event_time ~ X1 + X2,
split = split,
metrics = mix_mtrc,
eval_time = time_points
)

# test structure of results --------------------------------------------------

expect_named(
rs_mixed_res,
c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
)
expect_true(".eval_time" %in% names(rs_mixed_res$.metrics[[1]]))
expect_named(
rs_mixed_res$.predictions[[1]],
c(".pred", ".row", ".pred_time", "event_time", ".config")
)
expect_true(is.list(rs_mixed_res$.predictions[[1]]$.pred))
expect_named(
rs_mixed_res$.predictions[[1]]$.pred[[1]],
c(".eval_time", ".pred_survival", ".weight_censored")
)
expect_equal(
rs_mixed_res$.predictions[[1]]$.pred[[1]]$.eval_time,
time_points
)

# test metric collection -----------------------------------------------------

metric_sum <- collect_metrics(rs_mixed_res)
exp_metric_sum <-
tibble(
.metric = character(0),
.estimator = character(0),
.eval_time = numeric(0),
.estimate = numeric(0),
.config = character(0)
)

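  # brier_survival contributes one row per evaluation time; the integrated and
  # static metrics add one row each (with missing .eval_time)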
  expect_true(nrow(metric_sum) == length(time_points) + 2)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(sum(is.na(metric_sum$.eval_time)) == 2)
  expect_equal(as.vector(table(metric_sum$.metric)), c(length(time_points), 1L, 1L))

})
