Test last_fit() for survival models (#142)
* initial changes
* test for Bayesian optimization
* move grid testing to use glmnet
* reorganize/block code
* add more tests
* move to mboost as model
* update SA metric collection and other tests
* update to race with rpart trees
* Added tests for #113
* partial collect metric tests
* remaining racing results
* last_fit tests for survival models
* smaller, more portable file names
* changes based on reviewer comments
* remove fit_resamples tests
* Apply suggestions from code review

  Co-authored-by: Emil Hvitfeldt <[email protected]>

* Apply suggestions from code review

  Co-authored-by: Emil Hvitfeldt <[email protected]>

---------

Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent 4834aea · commit 1c35f59
Showing 1 changed file with 280 additions and 0 deletions.
@@ -0,0 +1,280 @@
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(censored))

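# these development versions are assumed to carry the censored regression and
# survival metric support exercised below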
skip_if_not_installed("parsnip", minimum_version = "1.1.0.9003")
skip_if_not_installed("censored", minimum_version = "0.2.0.9000")
skip_if_not_installed("tune", minimum_version = "1.1.1.9001")
skip_if_not_installed("yardstick", minimum_version = "1.2.0.9001")

test_that("last fit for survival models with static metric", {
  skip_if_not_installed("prodlim")

  # standard setup start -------------------------------------------------------

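  # simulate survival data with prodlim and collect the time/event pair into a
  # Surv() outcome column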
  set.seed(1)
  sim_dat <- prodlim::SimSurv(500) %>%
    mutate(event_time = Surv(time, event)) %>%
    select(event_time, X1, X2)

  set.seed(2)
  split <- initial_split(sim_dat)
  sim_tr <- training(split)
  sim_te <- testing(split)
  sim_rs <- vfold_cv(sim_tr)

  time_points <- c(10, 1, 5, 15)

  # last fit for models with static metrics ------------------------------------

  stc_mtrc <- metric_set(concordance_survival)

  set.seed(2193)
  rs_static_res <-
    survival_reg() %>%
    last_fit(
      event_time ~ X1 + X2,
      split = split,
      metrics = stc_mtrc
    )

  # test structure of results --------------------------------------------------

  expect_named(
    rs_static_res,
    c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
  )
  expect_false(".eval_time" %in% names(rs_static_res$.metrics[[1]]))
  expect_equal(
    names(rs_static_res$.predictions[[1]]),
    c(".pred_time", ".row", "event_time", ".config")
  )
  expect_s3_class(rs_static_res$.workflow[[1]], "workflow")

  # test metric collection -----------------------------------------------------

  metric_sum <- collect_metrics(rs_static_res)
  exp_metric_sum <-
    tibble(
      .metric = character(0),
      .estimator = character(0),
      .estimate = numeric(0),
      .config = character(0)
    )

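  # concordance is a static metric, so collect_metrics() should return a single
  # row with no .eval_time column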
  expect_true(nrow(metric_sum) == 1)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "concordance_survival"))

})

test_that("last fit for survival models with integrated metric", {
  skip_if_not_installed("prodlim")

  # standard setup start -------------------------------------------------------

  set.seed(1)
  sim_dat <- prodlim::SimSurv(500) %>%
    mutate(event_time = Surv(time, event)) %>%
    select(event_time, X1, X2)

  set.seed(2)
  split <- initial_split(sim_dat)
  sim_tr <- training(split)
  sim_te <- testing(split)
  sim_rs <- vfold_cv(sim_tr)

  time_points <- c(10, 1, 5, 15)

  # last fit for models with integrated metrics --------------------------------

  sint_mtrc <- metric_set(brier_survival_integrated)

  set.seed(2193)
  rs_integrated_res <-
    survival_reg() %>%
    last_fit(
      event_time ~ X1 + X2,
      split = split,
      metrics = sint_mtrc,
      eval_time = time_points
    )

  # test structure of results --------------------------------------------------

  expect_named(
    rs_integrated_res,
    c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
  )
  expect_false(".eval_time" %in% names(rs_integrated_res$.metrics[[1]]))
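  # survival probabilities are returned in a nested .pred list-column with one
  # row per evaluation time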
  expect_named(
    rs_integrated_res$.predictions[[1]],
    c(".pred", ".row", "event_time", ".config")
  )
  expect_true(is.list(rs_integrated_res$.predictions[[1]]$.pred))
  expect_named(
    rs_integrated_res$.predictions[[1]]$.pred[[1]],
    c(".eval_time", ".pred_survival", ".weight_censored")
  )
  expect_equal(
    rs_integrated_res$.predictions[[1]]$.pred[[1]]$.eval_time,
    time_points
  )

  # test metric collection -----------------------------------------------------

  metric_sum <- collect_metrics(rs_integrated_res)
  exp_metric_sum <-
    tibble(
      .metric = character(0),
      .estimator = character(0),
      .estimate = numeric(0),
      .config = character(0)
    )

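  # the integrated Brier score summarizes across evaluation times, so only a
  # single row (and no .eval_time column) is expected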
  expect_true(nrow(metric_sum) == 1)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "brier_survival_integrated"))

})

test_that("last fit for survival models with dynamic metric", {
  skip_if_not_installed("prodlim")

  # standard setup start -------------------------------------------------------

  set.seed(1)
  sim_dat <- prodlim::SimSurv(500) %>%
    mutate(event_time = Surv(time, event)) %>%
    select(event_time, X1, X2)

  set.seed(2)
  split <- initial_split(sim_dat)
  sim_tr <- training(split)
  sim_te <- testing(split)
  sim_rs <- vfold_cv(sim_tr)

  time_points <- c(10, 1, 5, 15)

  # last fit for models with dynamic metrics -----------------------------------

  dyn_mtrc <- metric_set(brier_survival)

  set.seed(2193)
  rs_dynamic_res <-
    survival_reg() %>%
    last_fit(
      event_time ~ X1 + X2,
      split = split,
      metrics = dyn_mtrc,
      eval_time = time_points
    )

  # test structure of results --------------------------------------------------

  expect_named(
    rs_dynamic_res,
    c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
  )
  expect_true(".eval_time" %in% names(rs_dynamic_res$.metrics[[1]]))
  expect_named(
    rs_dynamic_res$.predictions[[1]],
    c(".pred", ".row", "event_time", ".config")
  )
  expect_true(is.list(rs_dynamic_res$.predictions[[1]]$.pred))
  expect_named(
    rs_dynamic_res$.predictions[[1]]$.pred[[1]],
    c(".eval_time", ".pred_survival", ".weight_censored")
  )
  expect_equal(
    rs_dynamic_res$.predictions[[1]]$.pred[[1]]$.eval_time,
    time_points
  )

  # test metric collection -----------------------------------------------------

  metric_sum <- collect_metrics(rs_dynamic_res)
  exp_metric_sum <-
    tibble(
      .metric = character(0),
      .estimator = character(0),
      .eval_time = numeric(0),
      .estimate = numeric(0),
      .config = character(0)
    )

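  # the dynamic Brier score is computed at each evaluation time, so there should
  # be one row per time point, indexed by .eval_time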
  expect_true(nrow(metric_sum) == length(time_points))
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(all(metric_sum$.metric == "brier_survival"))
})

test_that("last fit for survival models with mixture of metrics", {
  skip_if_not_installed("prodlim")

  # standard setup start -------------------------------------------------------

  set.seed(1)
  sim_dat <- prodlim::SimSurv(500) %>%
    mutate(event_time = Surv(time, event)) %>%
    select(event_time, X1, X2)

  set.seed(2)
  split <- initial_split(sim_dat)
  sim_tr <- training(split)
  sim_te <- testing(split)
  sim_rs <- vfold_cv(sim_tr)

  time_points <- c(10, 1, 5, 15)

  # last fit for models with a mixture of metrics ------------------------------

  mix_mtrc <- metric_set(brier_survival, brier_survival_integrated, concordance_survival)

  set.seed(2193)
  rs_mixed_res <-
    survival_reg() %>%
    last_fit(
      event_time ~ X1 + X2,
      split = split,
      metrics = mix_mtrc,
      eval_time = time_points
    )

  # test structure of results --------------------------------------------------

  expect_named(
    rs_mixed_res,
    c("splits", "id", ".metrics", ".notes", ".predictions", ".workflow")
  )
  expect_true(".eval_time" %in% names(rs_mixed_res$.metrics[[1]]))
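  # with both dynamic and static metrics, predictions hold the nested dynamic
  # predictions (.pred) alongside the static .pred_time column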
  expect_named(
    rs_mixed_res$.predictions[[1]],
    c(".pred", ".row", ".pred_time", "event_time", ".config")
  )
  expect_true(is.list(rs_mixed_res$.predictions[[1]]$.pred))
  expect_named(
    rs_mixed_res$.predictions[[1]]$.pred[[1]],
    c(".eval_time", ".pred_survival", ".weight_censored")
  )
  expect_equal(
    rs_mixed_res$.predictions[[1]]$.pred[[1]]$.eval_time,
    time_points
  )

  # test metric collection -----------------------------------------------------

  metric_sum <- collect_metrics(rs_mixed_res)
  exp_metric_sum <-
    tibble(
      .metric = character(0),
      .estimator = character(0),
      .eval_time = numeric(0),
      .estimate = numeric(0),
      .config = character(0)
    )

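  # expected rows: one per evaluation time for brier_survival plus one each for
  # the integrated Brier score and concordance (which have no .eval_time);
  # table() orders the metric names alphabetically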
  expect_true(nrow(metric_sum) == length(time_points) + 2)
  expect_equal(metric_sum[0,], exp_metric_sum)
  expect_true(sum(is.na(metric_sum$.eval_time)) == 2)
  expect_equal(as.vector(table(metric_sum$.metric)), c(length(time_points), 1L, 1L))

})