From 585e01d804a8496c55a34c13cbe83a78d893904a Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Mon, 11 Dec 2023 13:43:32 -0500 Subject: [PATCH] Refactored functions to evaluate/check metric and eval time arguments (#780) * move function to new file * change function order for docs * documentation start * updates to the show/select functions * updates to select/show functions * updates for selecting eval times * remove commented out code * bug fix * metric test cases * add a survival model object * note for next PR * select/show test cases * small set of direct tests * update snapshot * Apply suggestions from code review Co-authored-by: Hannah Frick * Apply suggestions from code review * updates from previous review * small cli update * doc update * refresh snapshots * modularize a check * Remake with newest CRAN version of scales for #775 * argument checks for tune functions * Apply suggestions from code review Co-authored-by: Hannah Frick * Apply suggestions from code review Co-authored-by: Hannah Frick * add dot when function is invoked * add a warning for eval times with non-survival models * go back to enquos * rework warning text * rework warning text pt 2 * unit tests for regression and classification * caller envs * unit tests * rework call envs * survival unit tests * Apply suggestions from code review Co-authored-by: Hannah Frick * update tests for more recent version of yardstick * use standalone file from tidymodels/parsnip#1034 --------- Co-authored-by: Hannah Frick --- DESCRIPTION | 4 +- NAMESPACE | 2 + R/last_fit.R | 9 +- R/metric-selection.R | 116 +++- R/standalone-survival.R | 141 ++++ R/tune_bayes.R | 15 +- R/tune_grid.R | 5 +- man/choose_metric.Rd | 11 + tests/testthat/_snaps/censored-reg.md | 37 +- tests/testthat/_snaps/eval-time-args.md | 845 ++++++++++++++++++++++++ tests/testthat/_snaps/last-fit.new.md | 63 ++ tests/testthat/_snaps/metric-args.md | 290 ++++++++ tests/testthat/test-censored-reg.R | 7 +- tests/testthat/test-eval-time-args.R | 329 +++++++++ tests/testthat/test-metric-args.R | 172 +++++ 15 files changed, 2015 insertions(+), 31 deletions(-) create mode 100644 R/standalone-survival.R create mode 100644 tests/testthat/_snaps/eval-time-args.md create mode 100644 tests/testthat/_snaps/last-fit.new.md create mode 100644 tests/testthat/_snaps/metric-args.md create mode 100644 tests/testthat/test-eval-time-args.R create mode 100644 tests/testthat/test-metric-args.R diff --git a/DESCRIPTION b/DESCRIPTION index d3ea4dc1e..916394360 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: tune Title: Tidy Tuning Tools -Version: 1.1.2.9002 +Version: 1.1.2.9003 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-2402-136X")), @@ -38,7 +38,7 @@ Imports: vctrs (>= 0.6.1), withr, workflows (>= 1.0.0), - yardstick (>= 1.2.0.9001) + yardstick (>= 1.2.0.9002) Suggests: C50, censored, diff --git a/NAMESPACE b/NAMESPACE index 099354149..274c948e5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -151,8 +151,10 @@ export(.stash_last_result) export(.use_case_weights_with_yardstick) export(augment) export(autoplot) +export(check_eval_time_arg) export(check_initial) export(check_metrics) +export(check_metrics_arg) export(check_parameters) export(check_rset) export(check_time) diff --git a/R/last_fit.R b/R/last_fit.R index d721ae3c0..fd91322cf 100644 --- a/R/last_fit.R +++ b/R/last_fit.R @@ -140,7 +140,8 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL wflow <- add_formula(wflow, 
preprocessor) } - last_fit_workflow(wflow, split, metrics, control, eval_time, add_validation_set) + last_fit_workflow(wflow, split, metrics, control, eval_time, + add_validation_set) } @@ -153,7 +154,8 @@ last_fit.workflow <- function(object, split, ..., metrics = NULL, control <- parsnip::condense_control(control, control_last_fit()) - last_fit_workflow(object, split, metrics, control, eval_time, add_validation_set) + last_fit_workflow(object, split, metrics, control, eval_time, + add_validation_set) } @@ -191,7 +193,8 @@ last_fit_workflow <- function(object, metrics = metrics, control = control, eval_time = eval_time, - rng = rng + rng = rng, + call = call ) res$.workflow <- res$.extracts[[1]][[1]] diff --git a/R/metric-selection.R b/R/metric-selection.R index 6fc1e7064..a32b7471d 100644 --- a/R/metric-selection.R +++ b/R/metric-selection.R @@ -4,9 +4,10 @@ #' @param metric A character value for which metric is being used. #' @param eval_time An optional vector of times to compute dynamic and/or #' integrated metrics. +#' @param wflow A [workflows::workflow()]. #' @param x An object with class `tune_results`. #' @param call The call to be displayed in warnings or errors. -#' @description +#' @details #' These are developer-facing functions used to compute and validate choices #' for performance metrics. For survival analysis models, there are similar #' functions for the evaluation time(s) required for dynamic and/or integrated @@ -191,6 +192,119 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) { summary_res } +# ------------------------------------------------------------------------------ + +#' @rdname choose_metric +#' @export +check_metrics_arg <- function(mtr_set, wflow, call = rlang::caller_env()) { + mode <- extract_spec_parsnip(wflow)$mode + + if (is.null(mtr_set)) { + switch(mode, + regression = { + mtr_set <- yardstick::metric_set(rmse, rsq) + }, + classification = { + mtr_set <- yardstick::metric_set(roc_auc, accuracy) + }, + 'censored regression' = { + mtr_set <- yardstick::metric_set(brier_survival) + }, + # workflows cannot be set with an unknown mode + cli::cli_abort("Model value {.val {mode}} can't be used.", call = call) + ) + + return(mtr_set) + } + + is_numeric_metric_set <- inherits(mtr_set, "numeric_metric_set") + is_class_prob_metric_set <- inherits(mtr_set, "class_prob_metric_set") + is_surv_metric_set <- inherits(mtr_set, c("survival_metric_set")) + + if (!is_numeric_metric_set && !is_class_prob_metric_set && !is_surv_metric_set) { + cli::cli_abort("The {.arg metrics} argument should be the results of + {.fn yardstick::metric_set}.", call = call) + } + + if (mode == "regression" && !is_numeric_metric_set) { + cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}}, + but the {.arg metrics} is a metric set for a + different model mode.", call = call) + } + + if (mode == "classification" && !is_class_prob_metric_set) { + cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}}, + but the {.arg metrics} is a metric set for a + different model mode.", call = call) + } + + if (mode == "censored regression" && !is_surv_metric_set) { + cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}}, + but the {.arg metrics} is a metric set for a + different model mode.", call = call) + } + + mtr_set +} + +# ------------------------------------------------------------------------------ + +#' @rdname choose_metric +#' @export +check_eval_time_arg <- function(eval_time, mtr_set, call = 
rlang::caller_env()) {
+  mtr_info <- tibble::as_tibble(mtr_set)
+
+  # Not a survival metric
+  if (!contains_survival_metric(mtr_info)) {
+    if (!is.null(eval_time)) {
+      cli::cli_warn("Evaluation times are only required when the model
+                     mode is {.val censored regression} (and will be ignored).")
+    }
+    return(NULL)
+  }
+
+  cls <- mtr_info$class
+  uni_cls <- sort(unique(cls))
+  eval_time <- .filter_eval_time(eval_time)
+
+  num_times <- length(eval_time)
+
+  max_times_req <- req_eval_times(mtr_set)
+
+  if (max_times_req > num_times) {
+    cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
+                    required for the metric type(s) requested: {.val {uni_cls}}.
+                    Only {num_times} unique time{?s} {?was/were} given.",
+                   call = call)
+  }
+
+  if (max_times_req == 0 && num_times > 0) {
+    cli::cli_warn("Evaluation times are only required when dynamic or
+                   integrated metrics are used (and will be ignored here).")
+    eval_time <- NULL
+  }
+
+  eval_time
+}
+
+req_eval_times <- function(mtr_set) {
+  mtr_info <- tibble::as_tibble(mtr_set)
+  cls <- mtr_info$class
+
+  # Default for non-survival and static metrics
+  max_req_times <- 0
+
+  if (any(cls == "dynamic_survival_metric")) {
+    max_req_times <- max(max_req_times, 1)
+  }
+
+  if (any(cls == "integrated_survival_metric")) {
+    max_req_times <- max(max_req_times, 2)
+  }
+
+  max_req_times
+}
+
 # TODO will be removed shortly
 middle_eval_time <- function(x) {
diff --git a/R/standalone-survival.R b/R/standalone-survival.R
new file mode 100644
index 000000000..7d8838f8e
--- /dev/null
+++ b/R/standalone-survival.R
@@ -0,0 +1,141 @@
+# ---
+# repo: tidymodels/parsnip
+# file: standalone-survival.R
+# last-updated: 2023-12-08
+# license: https://unlicense.org
+# ---
+
+# This file provides a portable set of helper functions for survival analysis.
+# The original is in the parsnip package.
+
+# ## Changelog
+# 2023-12-08
+# * move .filter_eval_time to this file
+#
+# 2023-11-09
+# * make sure survival vectors are unnamed.
+#
+# 2023-06-14
+# * removed time to factor conversion
+#
+# 2023-05-18
+# * added time to factor conversion
+#
+# 2023-02-28:
+# * Initial version
+#
+# ------------------------------------------------------------------------------
+#
+# @param surv A [survival::Surv()] object
+# @details
+# `.is_censored_right()` always returns a logical while
+# `.check_censored_right()` will fail if `FALSE`.
+#
+# `.extract_surv_status()` will return the data as 0/1 even if the original
+# object used the legacy encoding of 1/2. See [survival::Surv()].
+
+# @return
+# - `.extract_surv_status()` returns a vector.
+# - `.extract_surv_time()` returns a vector when the type is `"right"` or `"left"`
+#   and a tibble otherwise.
+# - Functions starting with `.is_` or `.check_` return logicals although the
+#   latter will fail when `FALSE`.
+
+# nocov start
+# These are tested in the extratests repo since it would require a dependency
+# on the survival package.
+# https://github.com/tidymodels/extratests/pull/78
+.is_surv <- function(surv, fail = TRUE, call = rlang::caller_env()) {
+  is_surv <- inherits(surv, "Surv")
+  if (!is_surv && fail) {
+    rlang::abort("The object does not have class `Surv`.", call = call)
+  }
+  is_surv
+}
+
+.extract_surv_type <- function(surv) {
+  attr(surv, "type")
+}
+
+.check_cens_type <-
+  function(surv,
+           type = "right",
+           fail = TRUE,
+           call = rlang::caller_env()) {
+    .is_surv(surv, call = call)
+    obj_type <- .extract_surv_type(surv)
+    good_type <- all(obj_type %in% type)
+    if (!good_type && fail) {
+      c_list <- paste0("'", type, "'")
+      msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}")
+      rlang::abort(msg, call = call)
+    }
+    good_type
+  }
+
+.is_censored_right <- function(surv) {
+  .check_cens_type(surv, type = "right", fail = FALSE)
+}
+
+.check_censored_right <- function(surv, call = rlang::caller_env()) {
+  .check_cens_type(surv, type = "right", fail = TRUE, call = call)
+} # will add more as we need them
+
+.extract_surv_time <- function(surv) {
+  .is_surv(surv)
+  keepers <- c("time", "start", "stop", "time1", "time2")
+  cols <- colnames(surv)[colnames(surv) %in% keepers]
+  res <- surv[, cols, drop = FALSE]
+  if (length(cols) > 1) {
+    res <- tibble::tibble(as.data.frame(res))
+  } else {
+    res <- as.numeric(res)
+  }
+  res
+}
+
+.extract_surv_status <- function(surv) {
+  .is_surv(surv)
+  res <- surv[, "status"]
+  un_vals <- sort(unique(res))
+  event_type_to_01 <-
+    !(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate"))
+  if (event_type_to_01 &&
+      (identical(un_vals, 1:2) || identical(un_vals, c(1.0, 2.0)))) {
+    res <- res - 1
+  }
+  unname(res)
+}
+
+# nocov end
+
+# ------------------------------------------------------------------------------
+
+# @param eval_time A vector of numeric time points
+# @details
+# `.filter_eval_time()` checks the validity of the time points.
+#
+# @return A potentially modified vector of time points.
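+#
+# A minimal sketch of the expected behavior (hypothetical input values):
+#
+#   .filter_eval_time(c(-1, 0, 1, 1, NA, Inf))
+#   #> Warning: There were 3 inappropriate evaluation time points that were removed.
+#   #> [1] 0 1
+#
+# Negative, missing, and non-finite values are dropped, duplicates are
+# reduced to unique values, and a warning reports how many values were removed.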
+.filter_eval_time <- function(eval_time, fail = TRUE) { + if (!is.null(eval_time)) { + eval_time <- as.numeric(eval_time) + } + eval_time_0 <- eval_time + # will still propagate nulls: + eval_time <- eval_time[!is.na(eval_time)] + eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)] + eval_time <- unique(eval_time) + if (fail && identical(eval_time, numeric(0))) { + cli::cli_abort( + "There were no usable evaluation times (finite, non-missing, and >= 0).", + call = NULL + ) + } + if (!identical(eval_time, eval_time_0)) { + diffs <- setdiff(eval_time_0, eval_time) + cli::cli_warn("There {?was/were} {length(diffs)} inappropriate evaluation + time point{?s} that {?was/were} removed.", call = NULL) + + } + eval_time +} diff --git a/R/tune_bayes.R b/R/tune_bayes.R index f6e5636f3..bfecdcc3b 100644 --- a/R/tune_bayes.R +++ b/R/tune_bayes.R @@ -272,12 +272,13 @@ tune_bayes_workflow <- check_iter(iter, call = call) - metrics <- check_metrics(metrics, object) - check_eval_time(eval_time, metrics) - metrics_data <- metrics_info(metrics) - metrics_name <- metrics_data$.metric[1] - metrics_time <- get_metric_time(metrics, eval_time) - maximize <- metrics_data$direction[metrics_data$.metric == metrics_name] == "maximize" + metrics <- check_metrics_arg(metrics, object, call = call) + opt_metric <- first_metric(metrics) + metrics_name <- opt_metric$metric + maximize <- opt_metric$direction == "maximize" + + eval_time <- check_eval_time_arg(eval_time, metrics, call = call) + metrics_time <- first_eval_time(metrics, metrics_name, eval_time) if (is.null(param_info)) { param_info <- hardhat::extract_parameter_set_dials(object) @@ -334,7 +335,7 @@ tune_bayes_workflow <- if (control$verbose_iter) { msg <- paste("Optimizing", metrics_name, "using", objective$label) - if (!is.null(eval_time)) { + if (!is.null(metrics_time)) { msg <- paste(msg, "at evaluation time", format(metrics_time, digits = 3)) } message_wrap(msg) diff --git a/R/tune_grid.R b/R/tune_grid.R index a558245f2..2b1c9bc5a 100644 --- a/R/tune_grid.R +++ b/R/tune_grid.R @@ -332,9 +332,8 @@ tune_grid_workflow <- function(workflow, call = caller_env()) { check_rset(resamples) - - metrics <- check_metrics(metrics, workflow) - check_eval_time(eval_time, metrics) + metrics <- check_metrics_arg(metrics, workflow, call = call) + eval_time <- check_eval_time_arg(eval_time, metrics, call = call) pset <- check_parameters( workflow, diff --git a/man/choose_metric.Rd b/man/choose_metric.Rd index ce1790dfb..5b38a5c90 100644 --- a/man/choose_metric.Rd +++ b/man/choose_metric.Rd @@ -6,6 +6,8 @@ \alias{first_metric} \alias{first_eval_time} \alias{.filter_perf_metrics} +\alias{check_metrics_arg} +\alias{check_eval_time_arg} \title{Tools for selecting metrics and evaluation times} \usage{ choose_metric(x, metric, ..., call = rlang::caller_env()) @@ -17,6 +19,10 @@ first_metric(mtr_set) first_eval_time(mtr_set, metric = NULL, eval_time = NULL) .filter_perf_metrics(x, metric, eval_time) + +check_metrics_arg(mtr_set, wflow, call = rlang::caller_env()) + +check_eval_time_arg(eval_time, mtr_set, call = rlang::caller_env()) } \arguments{ \item{x}{An object with class \code{tune_results}.} @@ -29,8 +35,13 @@ first_eval_time(mtr_set, metric = NULL, eval_time = NULL) integrated metrics.} \item{mtr_set}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}}.} + +\item{wflow}{A \code{\link[workflows:workflow]{workflows::workflow()}}.} } \description{ +Tools for selecting metrics and evaluation times +} +\details{ These are developer-facing functions used to 
compute and validate choices for performance metrics. For survival analysis models, there are similar functions for the evaluation time(s) required for dynamic and/or integrated diff --git a/tests/testthat/_snaps/censored-reg.md b/tests/testthat/_snaps/censored-reg.md index 61e72ce32..35810ff63 100644 --- a/tests/testthat/_snaps/censored-reg.md +++ b/tests/testthat/_snaps/censored-reg.md @@ -3,16 +3,16 @@ Code spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = mtr) Condition - Error: - ! One or more metric requires the specification of time points in the `eval_time` argument. + Error in `tune_grid()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given. --- Code spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = reg_mtr) Condition - Error in `check_metrics()`: - ! The parsnip model has `mode = 'censored regression'`, but `metrics` is a metric set for a different model mode. + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. --- @@ -20,16 +20,33 @@ linear_reg() %>% tune_grid(age ~ ., resamples = rs, metrics = reg_mtr, eval_time = 1) Condition - Error: - ! Evaluation times are only used for dynamic and integrated survival metrics. + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + Warning: + No tuning parameters have been detected, performance will be evaluated using the resamples with no tuning. Did you want to [tune()] parameters? + Output + # Tuning results + # 10-fold cross-validation using stratification + # A tibble: 10 x 4 + splits id .metrics .notes + + 1 Fold01 + 2 Fold02 + 3 Fold03 + 4 Fold04 + 5 Fold05 + 6 Fold06 + 7 Fold07 + 8 Fold08 + 9 Fold09 + 10 Fold10 --- Code - show_notes(no_usable_times) - Output - unique notes: - ------------------------------------------------------------------------ + no_usable_times <- spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, + metrics = mtr, eval_time = c(-1, Inf)) + Condition Error: ! There were no usable evaluation times (finite, non-missing, and >= 0). diff --git a/tests/testthat/_snaps/eval-time-args.md b/tests/testthat/_snaps/eval-time-args.md new file mode 100644 index 000000000..d1fbfefe0 --- /dev/null +++ b/tests/testthat/_snaps/eval-time-args.md @@ -0,0 +1,845 @@ +# eval time inputs are checked for regression models + + Code + check_eval_time_arg(NULL, met_reg) + Output + NULL + +--- + + Code + check_eval_time_arg(times, met_reg) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + Output + NULL + +--- + + Code + res <- fit_resamples(wflow, rs, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +--- + + Code + set.seed(1) + res <- tune_grid(wflow_tune, rs, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +--- + + Code + set.seed(1) + res <- tune_bayes(wflow_tune, rs, iter = 1, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). 
+ +--- + + Code + res <- last_fit(wflow, split, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +# eval time are checked for classification models + + Code + check_eval_time_arg(NULL, met_cls) + Output + NULL + +--- + + Code + check_eval_time_arg(times, met_cls) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + Output + NULL + +--- + + Code + res <- fit_resamples(wflow, rs, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +--- + + Code + set.seed(1) + res <- tune_grid(wflow_tune, rs, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +--- + + Code + set.seed(1) + res <- tune_bayes(wflow_tune, rs, iter = 1, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +--- + + Code + res <- last_fit(wflow, split, eval_time = times) + Condition + Warning: + Evaluation times are only required when the model mode is "censored regression" (and will be ignored). + +# eval time inputs are checked for censored regression models + + Code + check_eval_time_arg(NULL, met_stc) + Output + NULL + +--- + + Code + check_eval_time_arg(NULL, met_dyn) + Condition + Error: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_stc_dyn) + Condition + Error: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_stc_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_dyn_stc) + Condition + Error: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_dyn_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_int_stc) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(NULL, met_int_dyn) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + check_eval_time_arg(2, met_stc) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). 
+ Output + NULL + +--- + + Code + check_eval_time_arg(2, met_dyn) + Output + [1] 2 + +--- + + Code + check_eval_time_arg(2, met_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + check_eval_time_arg(2, met_stc_dyn) + Output + [1] 2 + +--- + + Code + check_eval_time_arg(2, met_stc_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + check_eval_time_arg(2, met_dyn_stc) + Output + [1] 2 + +--- + + Code + check_eval_time_arg(2, met_dyn_int) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + check_eval_time_arg(2, met_int_stc) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + check_eval_time_arg(2, met_int_dyn) + Condition + Error: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + check_eval_time_arg(1:3, met_stc) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). + Output + NULL + +--- + + Code + check_eval_time_arg(1:3, met_dyn) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_int) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_stc_dyn) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_stc_int) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_dyn_stc) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_dyn_int) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_int_stc) + Output + [1] 1 2 3 + +--- + + Code + check_eval_time_arg(1:3, met_int_dyn) + Output + [1] 1 2 3 + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc) + +--- + + Code + fit_resamples(wflow, rs, metrics = met_dyn) + Condition + Error in `fit_resamples()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_stc_dyn) + Condition + Error in `fit_resamples()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_stc_int) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_dyn_stc) + Condition + Error in `fit_resamples()`: + ! 
At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_dyn_int) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int_stc) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int_dyn) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc, eval_time = 2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_dyn, eval_time = 2) + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int, eval_time = 2) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc_dyn, eval_time = 2) + +--- + + Code + fit_resamples(wflow, rs, metrics = met_stc_int, eval_time = 2) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_dyn_stc, eval_time = 2) + +--- + + Code + fit_resamples(wflow, rs, metrics = met_dyn_int, eval_time = 2) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int_stc, eval_time = 2) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_int_dyn, eval_time = 2) + Condition + Error in `fit_resamples()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc, eval_time = 1:3) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). 
+ +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_dyn, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_int, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc_dyn, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_stc_int, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_dyn_stc, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_dyn_int, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_int_stc, eval_time = 1:3) + +--- + + Code + res <- fit_resamples(wflow, rs, metrics = met_int_dyn, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc) + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_dyn) + Condition + Error in `tune_grid()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_stc_dyn) + Condition + Error in `tune_grid()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_stc_int) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_dyn_stc) + Condition + Error in `tune_grid()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_dyn_int) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int_stc) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int_dyn) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc, eval_time = 2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_dyn, eval_time = 2) + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int, eval_time = 2) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given. 
+ +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc_dyn, eval_time = 2) + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_stc_int, eval_time = 2) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_dyn_stc, eval_time = 2) + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_dyn_int, eval_time = 2) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int_stc, eval_time = 2) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_int_dyn, eval_time = 2) + Condition + Error in `tune_grid()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc, eval_time = 1:3) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_dyn, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_int, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc_dyn, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_stc_int, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_dyn_stc, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_dyn_int, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_int_stc, eval_time = 1:3) + +--- + + Code + res <- tune_grid(wflow_tune, rs, metrics = met_int_dyn, eval_time = 1:3) + +--- + + Code + last_fit(wflow, split, metrics = met_dyn) + Condition + Error in `last_fit()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_int) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_stc_dyn) + Condition + Error in `last_fit()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_stc_int) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_dyn_stc) + Condition + Error in `last_fit()`: + ! At least 1 evaluation time is required for the metric type(s) requested: "dynamic_survival_metric" and "static_survival_metric". Only 0 unique times were given. 
+ +--- + + Code + last_fit(wflow, split, metrics = met_dyn_int) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_int_stc) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 0 unique times were given. + +--- + + Code + last_fit(wflow, split, metrics = met_int_dyn) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 0 unique times were given. + +--- + + Code + res <- last_fit(wflow, split, metrics = met_stc, eval_time = 2) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). + +--- + + Code + res <- last_fit(wflow, split, metrics = met_dyn, eval_time = 2) + +--- + + Code + last_fit(wflow, split, metrics = met_int, eval_time = 2) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- last_fit(wflow, split, metrics = met_stc_dyn, eval_time = 2) + +--- + + Code + last_fit(wflow, split, metrics = met_stc_int, eval_time = 2) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- last_fit(wflow, split, metrics = met_dyn_stc, eval_time = 2) + +--- + + Code + last_fit(wflow, split, metrics = met_dyn_int, eval_time = 2) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + last_fit(wflow, split, metrics = met_int_stc, eval_time = 2) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "integrated_survival_metric" and "static_survival_metric". Only 1 unique time was given. + +--- + + Code + last_fit(wflow, split, metrics = met_int_dyn, eval_time = 2) + Condition + Error in `last_fit()`: + ! At least 2 evaluation times are required for the metric type(s) requested: "dynamic_survival_metric" and "integrated_survival_metric". Only 1 unique time was given. + +--- + + Code + res <- last_fit(wflow, split, metrics = met_stc, eval_time = 1:3) + Condition + Warning: + Evaluation times are only required when dynmanic or integrated metrics are used (and will be ignored here). 
+ +--- + + Code + res <- last_fit(wflow, split, metrics = met_dyn, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_int, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_stc_dyn, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_stc_int, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_dyn_stc, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_dyn_int, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_int_stc, eval_time = 1:3) + +--- + + Code + res <- last_fit(wflow, split, metrics = met_int_dyn, eval_time = 1:3) + diff --git a/tests/testthat/_snaps/last-fit.new.md b/tests/testthat/_snaps/last-fit.new.md new file mode 100644 index 000000000..2a2445eee --- /dev/null +++ b/tests/testthat/_snaps/last-fit.new.md @@ -0,0 +1,63 @@ +# model_fit method + + Code + last_fit(lm_fit) + Condition + Error in `last_fit()`: + ! `last_fit()` (`?tune::last_fit()`) is not well-defined for fitted model objects. + i `last_fit()` (`?tune::last_fit()`) takes a model specification (`?parsnip::model_spec()`) or unfitted workflow (`?workflows::workflow()`) as its first argument. + +# workflow method + + Code + last_fit(lm_fit) + Condition + Error: + ! `last_fit()` is not well-defined for a fitted workflow. + +# ellipses with last_fit + + Code + linear_reg() %>% set_engine("lm") %>% last_fit(f, split, something = "wrong") + Condition + Warning: + The `...` are not used in this function but one or more objects were passed: 'something' + Output + # Resampling results + # Manual resampling + # A tibble: 1 x 6 + splits id .metrics .notes .predictions .workflow + + 1 train/test split + +# argument order gives errors for recipe/formula + + Code + last_fit(rec, lin_mod, split) + Condition + Error in `last_fit()`: + ! The first argument to [last_fit()] should be either a model or workflow. + +--- + + Code + last_fit(f, lin_mod, split) + Condition + Error in `last_fit()`: + ! The first argument to [last_fit()] should be either a model or workflow. + +# `last_fit()` when objects need tuning + + 2 arguments have been tagged for tuning in these components: model_spec and recipe. + Please use one of the tuning functions (e.g. `tune_grid()`) to optimize them. + +--- + + 1 argument has been tagged for tuning in this component: model_spec. + Please use one of the tuning functions (e.g. `tune_grid()`) to optimize them. + +--- + + 1 argument has been tagged for tuning in this component: recipe. + Please use one of the tuning functions (e.g. `tune_grid()`) to optimize them. + diff --git a/tests/testthat/_snaps/metric-args.md b/tests/testthat/_snaps/metric-args.md new file mode 100644 index 000000000..b72587c8e --- /dev/null +++ b/tests/testthat/_snaps/metric-args.md @@ -0,0 +1,290 @@ +# metric inputs are checked for regression models + + Code + check_metrics_arg(NULL, wflow) + Output + A metric set, consisting of: + - `rmse()`, a numeric metric | direction: minimize + - `rsq()`, a numeric metric | direction: maximize + +--- + + Code + check_metrics_arg(met_reg, wflow) + Output + A metric set, consisting of: + - `rmse()`, a numeric metric | direction: minimize + +--- + + Code + check_metrics_arg(met_cls, wflow) + Condition + Error: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + check_metrics_arg(met_mix_int, wflow) + Condition + Error: + ! 
The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_cls) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_mix_int) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_cls) + Condition + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_mix_int) + Condition + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_cls) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_mix_int) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_cls) + Condition + Error in `last_fit()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_mix_int) + Condition + Error in `last_fit()`: + ! The parsnip model has `mode` value of "regression", but the `metrics` is a metric set for a different model mode. + +# metric inputs are checked for classification models + + Code + check_metrics_arg(NULL, wflow) + Output + A metric set, consisting of: + - `roc_auc()`, a probability metric | direction: maximize + - `accuracy()`, a class metric | direction: maximize + +--- + + Code + check_metrics_arg(met_reg, wflow) + Condition + Error: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + check_metrics_arg(met_cls, wflow) + Output + A metric set, consisting of: + - `brier_class()`, a probability metric | direction: minimize + +--- + + Code + check_metrics_arg(met_mix_int, wflow) + Condition + Error: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_reg) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_mix_int) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_reg) + Condition + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_mix_int) + Condition + Error in `tune_grid()`: + ! 
The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_reg) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_mix_int) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_reg) + Condition + Error in `last_fit()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_mix_int) + Condition + Error in `last_fit()`: + ! The parsnip model has `mode` value of "classification", but the `metrics` is a metric set for a different model mode. + +# metric inputs are checked for censored regression models + + Code + check_metrics_arg(NULL, wflow) + Output + A metric set, consisting of: + - `brier_survival()`, a dynamic survival metric | direction: minimize + +--- + + Code + check_metrics_arg(met_reg, wflow) + Condition + Error: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + check_metrics_arg(met_cls, wflow) + Condition + Error: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + check_metrics_arg(met_srv, wflow) + Output + A metric set, consisting of: + - `concordance_survival()`, a static survival metric | direction: maximize + +--- + + Code + fit_resamples(wflow, rs, metrics = met_cls) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + fit_resamples(wflow, rs, metrics = met_reg) + Condition + Error in `fit_resamples()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_cls) + Condition + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_grid(wflow_tune, rs, metrics = met_reg) + Condition + Error in `tune_grid()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_cls) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + tune_bayes(wflow_tune, rs, metrics = met_reg) + Condition + Error in `tune_bayes()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_cls) + Condition + Error in `last_fit()`: + ! The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + +--- + + Code + last_fit(wflow, split, metrics = met_reg) + Condition + Error in `last_fit()`: + ! 
The parsnip model has `mode` value of "censored regression", but the `metrics` is a metric set for a different model mode. + diff --git a/tests/testthat/test-censored-reg.R b/tests/testthat/test-censored-reg.R index 1bf383c69..12a9558f4 100644 --- a/tests/testthat/test-censored-reg.R +++ b/tests/testthat/test-censored-reg.R @@ -20,17 +20,14 @@ test_that("evaluation time", { expect_snapshot(error = TRUE, spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = reg_mtr) ) - expect_snapshot(error = TRUE, + expect_snapshot( linear_reg() %>% tune_grid(age ~ ., resamples = rs, metrics = reg_mtr, eval_time = 1) ) - suppressMessages(suppressWarnings( + expect_snapshot(error = TRUE, no_usable_times <- spec %>% tune_grid(Surv(time, status) ~ ., resamples = rs, metrics = mtr, eval_time = c(-1, Inf)) - )) - expect_snapshot( - show_notes(no_usable_times) ) times <- 4:1 diff --git a/tests/testthat/test-eval-time-args.R b/tests/testthat/test-eval-time-args.R new file mode 100644 index 000000000..03bed00ea --- /dev/null +++ b/tests/testthat/test-eval-time-args.R @@ -0,0 +1,329 @@ +test_that("eval time inputs are checked for regression models", { + library(parsnip) + library(workflows) + library(yardstick) + library(rsample) + + # ------------------------------------------------------------------------------ + + wflow <- workflow(mpg ~ ., linear_reg()) + knn_spec <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression") + wflow_tune <- workflow(mpg ~ ., knn_spec) + + set.seed(1) + split <- initial_split(mtcars) + rs <- vfold_cv(mtcars) + + times <- c(1, 1:3) + + # ------------------------------------------------------------------------------ + # setup metric sets + + met_reg <- metric_set(rmse) + + # ------------------------------------------------------------------------------ + # check inputs + + expect_snapshot(check_eval_time_arg(NULL, met_reg)) + expect_snapshot(check_eval_time_arg(times, met_reg)) + + # ------------------------------------------------------------------------------ + # resampling + + expect_snapshot( + res <- fit_resamples(wflow, rs, eval_time = times) + ) + + # ------------------------------------------------------------------------------ + # tuning + + expect_snapshot({ + set.seed(1) + res <- tune_grid(wflow_tune, rs, eval_time = times) + }) + + expect_snapshot({ + set.seed(1) + res <- tune_bayes(wflow_tune, rs, iter = 1, eval_time = times) + }) + + # ------------------------------------------------------------------------------ + # final fit + + expect_snapshot( + res <- last_fit(wflow, split, eval_time = times) + ) + +}) + +test_that("eval time are checked for classification models", { + library(parsnip) + library(workflows) + library(yardstick) + library(rsample) + + data(two_class_dat, package = "modeldata") + wflow <- workflow(Class ~ A + B, logistic_reg()) + knn_spec <- nearest_neighbor(neighbors = tune()) %>% set_mode("classification") + wflow_tune <- workflow(Class ~ A + B, knn_spec) + + set.seed(1) + split <- initial_split(two_class_dat) + rs <- vfold_cv(two_class_dat) + + times <- c(1, 1:3) + + # ------------------------------------------------------------------------------ + # setup metric sets + + met_cls <- metric_set(brier_class) + + # ------------------------------------------------------------------------------ + # check inputs + + expect_snapshot(check_eval_time_arg(NULL, met_cls)) + expect_snapshot(check_eval_time_arg(times, met_cls)) + + # ------------------------------------------------------------------------------ + # resampling + + 
expect_snapshot( + res <- fit_resamples(wflow, rs, eval_time = times) + ) + + # ------------------------------------------------------------------------------ + # tuning + + expect_snapshot({ + set.seed(1) + res <- tune_grid(wflow_tune, rs, eval_time = times) + }) + + expect_snapshot({ + set.seed(1) + res <- tune_bayes(wflow_tune, rs, iter = 1, eval_time = times) + }) + + # ------------------------------------------------------------------------------ + # final fit + + expect_snapshot( + res <- last_fit(wflow, split, eval_time = times) + ) + +}) + +test_that("eval time inputs are checked for censored regression models", { + skip_if_not_installed("censored") + + library(parsnip) + library(workflows) + library(yardstick) + library(rsample) + suppressPackageStartupMessages(library(censored)) + + stanford2$event_time <- Surv(stanford2$time, stanford2$status) + stanford2 <- stanford2[, c("event_time", "age")] + + wflow <- workflow(event_time ~ age, survival_reg()) + sr_spec <- survival_reg(dist = tune()) + wflow_tune <- workflow(event_time ~ age, sr_spec) + + set.seed(1) + split <- initial_split(stanford2) + rs <- vfold_cv(stanford2) + + # ------------------------------------------------------------------------------ + # setup metric sets + + met_stc <- metric_set(concordance_survival) + met_dyn <- metric_set(brier_survival) + met_int <- metric_set(brier_survival_integrated) + met_stc_dyn <- metric_set(concordance_survival, brier_survival) + met_stc_int <- metric_set(concordance_survival, brier_survival_integrated) + met_dyn_stc <- metric_set(brier_survival, concordance_survival) + met_dyn_int <- metric_set(brier_survival, brier_survival_integrated) + met_int_stc <- metric_set(brier_survival_integrated, concordance_survival) + met_int_dyn <- metric_set(brier_survival_integrated, brier_survival) + + # ------------------------------------------------------------------------------ + # check inputs when eval_time left out + + expect_snapshot(check_eval_time_arg(NULL, met_stc)) + expect_snapshot(check_eval_time_arg(NULL, met_dyn), error = TRUE) + expect_snapshot(check_eval_time_arg(NULL, met_int), error = TRUE) + + expect_snapshot(check_eval_time_arg(NULL, met_stc_dyn), error = TRUE) + expect_snapshot(check_eval_time_arg(NULL, met_stc_int), error = TRUE) + expect_snapshot(check_eval_time_arg(NULL, met_dyn_stc), error = TRUE) + + expect_snapshot(check_eval_time_arg(NULL, met_dyn_int), error = TRUE) + expect_snapshot(check_eval_time_arg(NULL, met_int_stc), error = TRUE) + expect_snapshot(check_eval_time_arg(NULL, met_int_dyn), error = TRUE) + + # ------------------------------------------------------------------------------ + # check inputs with single eval times + + expect_snapshot(check_eval_time_arg(2, met_stc)) + expect_snapshot(check_eval_time_arg(2, met_dyn)) + expect_snapshot(check_eval_time_arg(2, met_int), error = TRUE) + + expect_snapshot(check_eval_time_arg(2, met_stc_dyn)) + expect_snapshot(check_eval_time_arg(2, met_stc_int), error = TRUE) + + expect_snapshot(check_eval_time_arg(2, met_dyn_stc)) + expect_snapshot(check_eval_time_arg(2, met_dyn_int), error = TRUE) + + expect_snapshot(check_eval_time_arg(2, met_int_stc), error = TRUE) + expect_snapshot(check_eval_time_arg(2, met_int_dyn), error = TRUE) + + # ------------------------------------------------------------------------------ + # check inputs with multiple eval times + + expect_snapshot(check_eval_time_arg(1:3, met_stc)) + expect_snapshot(check_eval_time_arg(1:3, met_dyn)) + expect_snapshot(check_eval_time_arg(1:3, met_int)) + + 
expect_snapshot(check_eval_time_arg(1:3, met_stc_dyn)) + expect_snapshot(check_eval_time_arg(1:3, met_stc_int)) + expect_snapshot(check_eval_time_arg(1:3, met_dyn_stc)) + + expect_snapshot(check_eval_time_arg(1:3, met_dyn_int)) + expect_snapshot(check_eval_time_arg(1:3, met_int_stc)) + expect_snapshot(check_eval_time_arg(1:3, met_int_dyn)) + + # ------------------------------------------------------------------------------ + # resampling + + # no eval time + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc)) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_dyn), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int), error = TRUE) + + expect_snapshot(fit_resamples(wflow, rs, metrics = met_stc_dyn), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_stc_int), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_dyn_stc), error = TRUE) + + expect_snapshot(fit_resamples(wflow, rs, metrics = met_dyn_int), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int_stc), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int_dyn), error = TRUE) + + # one eval time + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc, eval_time = 2)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_dyn, eval_time = 2)) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc_dyn, eval_time = 2)) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_stc_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_dyn_stc, eval_time = 2)) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_dyn_int, eval_time = 2), error = TRUE) + + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int_stc, eval_time = 2), error = TRUE) + expect_snapshot(fit_resamples(wflow, rs, metrics = met_int_dyn, eval_time = 2), error = TRUE) + + # multiple eval times + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc, eval_time = 1:3)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_dyn, eval_time = 1:3)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_int, eval_time = 1:3)) + + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc_dyn, eval_time = 1:3)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_stc_int, eval_time = 1:3)) + + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_dyn_stc, eval_time = 1:3)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_dyn_int, eval_time = 1:3)) + + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_int_stc, eval_time = 1:3)) + expect_snapshot(res <- fit_resamples(wflow, rs, metrics = met_int_dyn, eval_time = 1:3)) + + # ------------------------------------------------------------------------------ + # grid tuning (tune bayes tests in extratests repo) + + # no eval time + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc)) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_dyn), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int), error = TRUE) + + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_stc_dyn), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_stc_int), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_dyn_stc), error = TRUE) + + expect_snapshot(tune_grid(wflow_tune, rs, metrics = 
met_dyn_int), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int_stc), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int_dyn), error = TRUE) + + # one eval time + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc, eval_time = 2)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_dyn, eval_time = 2)) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc_dyn, eval_time = 2)) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_stc_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_dyn_stc, eval_time = 2)) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_dyn_int, eval_time = 2), error = TRUE) + + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int_stc, eval_time = 2), error = TRUE) + expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_int_dyn, eval_time = 2), error = TRUE) + + # multiple eval times + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc, eval_time = 1:3)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_dyn, eval_time = 1:3)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_int, eval_time = 1:3)) + + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc_dyn, eval_time = 1:3)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_stc_int, eval_time = 1:3)) + + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_dyn_stc, eval_time = 1:3)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_dyn_int, eval_time = 1:3)) + + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_int_stc, eval_time = 1:3)) + expect_snapshot(res <- tune_grid(wflow_tune, rs, metrics = met_int_dyn, eval_time = 1:3)) + + # ------------------------------------------------------------------------------ + # last fit + + # no eval time + expect_silent(res <- last_fit(wflow, split, metrics = met_stc)) + expect_snapshot(last_fit(wflow, split, metrics = met_dyn), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_int), error = TRUE) + + expect_snapshot(last_fit(wflow, split, metrics = met_stc_dyn), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_stc_int), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_dyn_stc), error = TRUE) + + expect_snapshot(last_fit(wflow, split, metrics = met_dyn_int), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_int_stc), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_int_dyn), error = TRUE) + + # one eval time + expect_snapshot(res <- last_fit(wflow, split, metrics = met_stc, eval_time = 2)) + expect_snapshot(res <- last_fit(wflow, split, metrics = met_dyn, eval_time = 2)) + expect_snapshot(last_fit(wflow, split, metrics = met_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- last_fit(wflow, split, metrics = met_stc_dyn, eval_time = 2)) + expect_snapshot(last_fit(wflow, split, metrics = met_stc_int, eval_time = 2), error = TRUE) + + expect_snapshot(res <- last_fit(wflow, split, metrics = met_dyn_stc, eval_time = 2)) + expect_snapshot(last_fit(wflow, split, metrics = met_dyn_int, eval_time = 2), error = TRUE) + + expect_snapshot(last_fit(wflow, split, metrics = met_int_stc, eval_time = 2), error = TRUE) + expect_snapshot(last_fit(wflow, split, metrics = met_int_dyn, eval_time = 2), error = TRUE) + + # 
multiple eval times
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_stc, eval_time = 1:3))
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_dyn, eval_time = 1:3))
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_int, eval_time = 1:3))
+
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_stc_dyn, eval_time = 1:3))
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_stc_int, eval_time = 1:3))
+
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_dyn_stc, eval_time = 1:3))
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_dyn_int, eval_time = 1:3))
+
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_int_stc, eval_time = 1:3))
+  expect_snapshot(res <- last_fit(wflow, split, metrics = met_int_dyn, eval_time = 1:3))
+
+})
+
diff --git a/tests/testthat/test-metric-args.R b/tests/testthat/test-metric-args.R
new file mode 100644
index 000000000..7dd2c1865
--- /dev/null
+++ b/tests/testthat/test-metric-args.R
@@ -0,0 +1,172 @@
+test_that("metric inputs are checked for regression models", {
+  library(parsnip)
+  library(workflows)
+  library(yardstick)
+  library(rsample)
+
+  wflow <- workflow(mpg ~ cyl + disp, linear_reg())
+  knn_spec <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")
+  wflow_tune <- workflow(mpg ~ cyl + disp, knn_spec)
+
+  set.seed(1)
+  split <- initial_split(mtcars)
+  rs <- vfold_cv(mtcars)
+
+  # ------------------------------------------------------------------------------
+  # setup metric sets
+
+  met_mix_int <-
+    metric_set(brier_survival_integrated,
+               brier_survival,
+               concordance_survival)
+  met_reg <- metric_set(rmse)
+  met_cls <- metric_set(brier_class)
+
+  # ------------------------------------------------------------------------------
+  # check inputs
+
+  expect_snapshot(check_metrics_arg(NULL, wflow))
+
+  expect_snapshot(check_metrics_arg(met_reg, wflow))
+  expect_snapshot(check_metrics_arg(met_cls, wflow), error = TRUE)
+  expect_snapshot(check_metrics_arg(met_mix_int, wflow), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # resampling
+
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_mix_int), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # tuning
+
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_mix_int), error = TRUE)
+
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_mix_int), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # final fit
+
+  expect_snapshot(last_fit(wflow, split, metrics = met_cls), error = TRUE)
+  expect_snapshot(last_fit(wflow, split, metrics = met_mix_int), error = TRUE)
+
+})
+
+test_that("metric inputs are checked for classification models", {
+  library(parsnip)
+  library(workflows)
+  library(yardstick)
+  library(rsample)
+
+  data(two_class_dat, package = "modeldata")
+  wflow <- workflow(Class ~ A + B, logistic_reg())
+  knn_spec <- nearest_neighbor(neighbors = tune()) %>% set_mode("classification")
+  wflow_tune <- workflow(Class ~ A + B, knn_spec)
+
+  set.seed(1)
+  split <- initial_split(two_class_dat)
+  rs <- vfold_cv(two_class_dat)
+
+  # ------------------------------------------------------------------------------
+  # setup metric sets
+
+  met_mix_int <-
+    metric_set(brier_survival_integrated,
+               brier_survival,
+               concordance_survival)
+  met_reg <- metric_set(rmse)
+  met_cls <- metric_set(brier_class)
+
+  # ------------------------------------------------------------------------------
+  # check inputs
+
+  expect_snapshot(check_metrics_arg(NULL, wflow))
+
+  expect_snapshot(check_metrics_arg(met_reg, wflow), error = TRUE)
+  expect_snapshot(check_metrics_arg(met_cls, wflow))
+  expect_snapshot(check_metrics_arg(met_mix_int, wflow), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # resampling
+
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_reg), error = TRUE)
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_mix_int), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # tuning
+
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_reg), error = TRUE)
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_mix_int), error = TRUE)
+
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_reg), error = TRUE)
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_mix_int), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # final fit
+
+  expect_snapshot(last_fit(wflow, split, metrics = met_reg), error = TRUE)
+  expect_snapshot(last_fit(wflow, split, metrics = met_mix_int), error = TRUE)
+})
+
+
+test_that("metric inputs are checked for censored regression models", {
+  skip_if_not_installed("censored")
+  library(parsnip)
+  library(workflows)
+  library(yardstick)
+  library(rsample)
+  suppressPackageStartupMessages(library(censored))
+
+  stanford2$event_time <- Surv(stanford2$time, stanford2$status)
+  stanford2 <- stanford2[, c("event_time", "age")]
+
+  wflow <- workflow(event_time ~ age, survival_reg())
+  sr_spec <- survival_reg(dist = tune())
+  wflow_tune <- workflow(event_time ~ age, sr_spec)
+
+  set.seed(1)
+  split <- initial_split(stanford2)
+  rs <- vfold_cv(stanford2)
+
+  # ------------------------------------------------------------------------------
+  # setup metric sets
+
+  met_srv <- metric_set(concordance_survival)
+  met_reg <- metric_set(rmse)
+  met_cls <- metric_set(brier_class)
+
+  # ------------------------------------------------------------------------------
+  # check inputs
+
+  expect_snapshot(check_metrics_arg(NULL, wflow))
+
+  expect_snapshot(check_metrics_arg(met_reg, wflow), error = TRUE)
+  expect_snapshot(check_metrics_arg(met_cls, wflow), error = TRUE)
+  expect_snapshot(check_metrics_arg(met_srv, wflow))
+
+  # ------------------------------------------------------------------------------
+  # resampling
+
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(fit_resamples(wflow, rs, metrics = met_reg), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # tuning
+
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(tune_grid(wflow_tune, rs, metrics = met_reg), error = TRUE)
+
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_cls), error = TRUE)
+  expect_snapshot(tune_bayes(wflow_tune, rs, metrics = met_reg), error = TRUE)
+
+  # ------------------------------------------------------------------------------
+  # final fit
+
+  expect_snapshot(last_fit(wflow, split, metrics = met_cls), error = TRUE)
+  expect_snapshot(last_fit(wflow, split, metrics = met_reg), error = TRUE)
+
+})
+
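+
+# Illustrative sketch (never executed; kept behind `if (FALSE)`) of the contract
+# that the snapshots above pin down, reusing the censored regression `wflow` and
+# the yardstick metric sets from the final test. The exact messages live in the
+# snapshot files under tests/testthat/_snaps/.
+if (FALSE) {
+  # NULL metrics fall back to a mode-appropriate default; for censored
+  # regression that default is metric_set(brier_survival)
+  check_metrics_arg(NULL, wflow)
+
+  # a metric set from another model mode is rejected
+  check_metrics_arg(metric_set(rmse), wflow)
+
+  # static survival metrics need no eval_time ...
+  check_eval_time_arg(NULL, metric_set(concordance_survival))
+
+  # ... a single time is enough for dynamic metrics ...
+  check_eval_time_arg(2, metric_set(brier_survival))
+
+  # ... and integrated metrics require more than one
+  check_eval_time_arg(1:3, metric_set(brier_survival_integrated))
+}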