Skip to content

Commit

Permalink
Refactored functions to evaluate/check metric and eval time arguments (
Browse files Browse the repository at this point in the history
…#780)

* move function to new file

* change function order for docs

* documentation start

* updates to the show/select functions

* updates to select/show functions

* updates for selecting eval times

* remove commented out code

* bug fix

* metric test cases

* add a survival model object

* note for next PR

* select/show test cases

* small set of direct tests

* update snapshot

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

* Apply suggestions from code review

* updates from previous review

* small cli update

* doc update

* refresh snapshots

* modularize a check

* Remake with newest CRAN version of scales for #775

* argument checks for tune functions

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

* add dot when function is invoked

* add a warning for eval times with non-survival models

* go back to enquos

* rework warning text

* rework warning text pt 2

* unit tests for regression and classification

* caller envs

* unit tests

* rework call envs

* survival unit tests

* Apply suggestions from code review

Co-authored-by: Hannah Frick <[email protected]>

* update tests for more recent version of yardstick

* use standalone file from tidymodels/parsnip#1034

---------

Co-authored-by: Hannah Frick <[email protected]>
  • Loading branch information
topepo and hfrick authored Dec 11, 2023
1 parent 8f99330 commit 585e01d
Show file tree
Hide file tree
Showing 15 changed files with 2,015 additions and 31 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
Version: 1.1.2.9002
Version: 1.1.2.9003
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
Expand Down Expand Up @@ -38,7 +38,7 @@ Imports:
vctrs (>= 0.6.1),
withr,
workflows (>= 1.0.0),
yardstick (>= 1.2.0.9001)
yardstick (>= 1.2.0.9002)
Suggests:
C50,
censored,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,10 @@ export(.stash_last_result)
export(.use_case_weights_with_yardstick)
export(augment)
export(autoplot)
export(check_eval_time_arg)
export(check_initial)
export(check_metrics)
export(check_metrics_arg)
export(check_parameters)
export(check_rset)
export(check_time)
Expand Down
9 changes: 6 additions & 3 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
wflow <- add_formula(wflow, preprocessor)
}

last_fit_workflow(wflow, split, metrics, control, eval_time, add_validation_set)
last_fit_workflow(wflow, split, metrics, control, eval_time,
add_validation_set)
}


Expand All @@ -153,7 +154,8 @@ last_fit.workflow <- function(object, split, ..., metrics = NULL,

control <- parsnip::condense_control(control, control_last_fit())

last_fit_workflow(object, split, metrics, control, eval_time, add_validation_set)
last_fit_workflow(object, split, metrics, control, eval_time,
add_validation_set)
}


Expand Down Expand Up @@ -191,7 +193,8 @@ last_fit_workflow <- function(object,
metrics = metrics,
control = control,
eval_time = eval_time,
rng = rng
rng = rng,
call = call
)

res$.workflow <- res$.extracts[[1]][[1]]
Expand Down
116 changes: 115 additions & 1 deletion R/metric-selection.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
#' @param metric A character value for which metric is being used.
#' @param eval_time An optional vector of times to compute dynamic and/or
#' integrated metrics.
#' @param wflow A [workflows::workflow()].
#' @param x An object with class `tune_results`.
#' @param call The call to be displayed in warnings or errors.
#' @description
#' @details
#' These are developer-facing functions used to compute and validate choices
#' for performance metrics. For survival analysis models, there are similar
#' functions for the evaluation time(s) required for dynamic and/or integrated
Expand Down Expand Up @@ -191,6 +192,119 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL) {
summary_res
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
#' @export
check_metrics_arg <- function(mtr_set, wflow, call = rlang::caller_env()) {
mode <- extract_spec_parsnip(wflow)$mode

if (is.null(mtr_set)) {
switch(mode,
regression = {
mtr_set <- yardstick::metric_set(rmse, rsq)
},
classification = {
mtr_set <- yardstick::metric_set(roc_auc, accuracy)
},
'censored regression' = {
mtr_set <- yardstick::metric_set(brier_survival)
},
# workflows cannot be set with an unknown mode
cli::cli_abort("Model value {.val {mode}} can't be used.", call = call)
)

return(mtr_set)
}

is_numeric_metric_set <- inherits(mtr_set, "numeric_metric_set")
is_class_prob_metric_set <- inherits(mtr_set, "class_prob_metric_set")
is_surv_metric_set <- inherits(mtr_set, c("survival_metric_set"))

if (!is_numeric_metric_set && !is_class_prob_metric_set && !is_surv_metric_set) {
cli::cli_abort("The {.arg metrics} argument should be the results of
{.fn yardstick::metric_set}.", call = call)
}

if (mode == "regression" && !is_numeric_metric_set) {
cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
different model mode.", call = call)
}

if (mode == "classification" && !is_class_prob_metric_set) {
cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
different model mode.", call = call)
}

if (mode == "censored regression" && !is_surv_metric_set) {
cli::cli_abort("The parsnip model has {.code mode} value of {.val {mode}},
but the {.arg metrics} is a metric set for a
different model mode.", call = call)
}

mtr_set
}

# ------------------------------------------------------------------------------

#' @rdname choose_metric
#' @export
check_eval_time_arg <- function(eval_time, mtr_set, call = rlang::caller_env()) {
mtr_info <- tibble::as_tibble(mtr_set)

# Not a survival metric
if (!contains_survival_metric(mtr_info)) {
if (!is.null(eval_time)) {
cli::cli_warn("Evaluation times are only required when the model
mode is {.val censored regression} (and will be ignored).")
}
return(NULL)
}

cls <- mtr_info$class
uni_cls <- sort(unique(cls))
eval_time <- .filter_eval_time(eval_time)

num_times <- length(eval_time)

max_times_req <- req_eval_times(mtr_set)

if (max_times_req > num_times) {
cli::cli_abort("At least {max_times_req} evaluation time{?s} {?is/are}
required for the metric type(s) requested: {.val {uni_cls}}.
Only {num_times} unique time{?s} {?was/were} given.",
call = call)
}

if (max_times_req == 0 & num_times > 0) {
cli::cli_warn("Evaluation times are only required when dynmanic or
integrated metrics are used (and will be ignored here).")
eval_time <- NULL
}

eval_time
}

req_eval_times <- function(mtr_set) {
mtr_info <- tibble::as_tibble(mtr_set)
cls <- mtr_info$class

# Default for non-survival and static metrics
max_req_times <- 0

if (any(cls == "dynamic_survival_metric")) {
max_req_times <- max(max_req_times, 1)
}

if (any(cls == "integrated_survival_metric")) {
max_req_times <- max(max_req_times, 2)
}

max_req_times
}

# TODO will be removed shortly

middle_eval_time <- function(x) {
Expand Down
141 changes: 141 additions & 0 deletions R/standalone-survival.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# ---
# repo: tidymodels/parsnip
# file: standalone-survival.R
# last-updated: 2023-12-08
# license: https://unlicense.org
# ---

# This file provides a portable set of helper functions for survival analysis.
# The original is in the parsnip package.

# ## Changelog
# 2023-12-08
# * move .filter_eval_time to this file
#
# 2023-11-09
# * make sure survival vectors are unnamed.
#
# 2023-06-14
# * removed time to factor conversion
#
# 2023-05-18
# * added time to factor conversion
#
# 2023-02-28:
# * Initial version
#
# ------------------------------------------------------------------------------
#
# @param surv A [survival::Surv()] object
# @details
# `.is_censored_right()` always returns a logical while
# `.check_censored_right()` will fail if `FALSE`.
#
# `.extract_status()` will return the data as 0/1 even if the original object
# used the legacy encoding of 1/2. See [survival::Surv()].

# @return
# - `.extract_surv_status()` returns a vector.
# - `.extract_surv_time()` returns a vector when the type is `"right"` or `"left"`
# and a tibble otherwise.
# - Functions starting with `.is_` or `.check_` return logicals although the
# latter will fail when `FALSE`.

# nocov start
# These are tested in the extratests repo since it would require a dependency
# on the survival package. https://github.com/tidymodels/extratests/pull/78
.is_surv <- function(surv, fail = TRUE, call = rlang::caller_env()) {
is_surv <- inherits(surv, "Surv")
if (!is_surv && fail) {
rlang::abort("The object does not have class `Surv`.", call = call)
}
is_surv
}

.extract_surv_type <- function(surv) {
attr(surv, "type")
}

.check_cens_type <-
function(surv,
type = "right",
fail = TRUE,
call = rlang::caller_env()) {
.is_surv(surv, call = call)
obj_type <- .extract_surv_type(surv)
good_type <- all(obj_type %in% type)
if (!good_type && fail) {
c_list <- paste0("'", type, "'")
msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}")
rlang::abort(msg, call = call)
}
good_type
}

.is_censored_right <- function(surv) {
.check_cens_type(surv, type = "right", fail = FALSE)
}

.check_censored_right <- function(surv, call = rlang::caller_env()) {
.check_cens_type(surv, type = "right", fail = TRUE, call = call)
} # will add more as we need them

.extract_surv_time <- function(surv) {
.is_surv(surv)
keepers <- c("time", "start", "stop", "time1", "time2")
cols <- colnames(surv)[colnames(surv) %in% keepers]
res <- surv[, cols, drop = FALSE]
if (length(cols) > 1) {
res <- tibble::tibble(as.data.frame(res))
} else {
res <- as.numeric(res)
}
res
}

.extract_surv_status <- function(surv) {
.is_surv(surv)
res <- surv[, "status"]
un_vals <- sort(unique(res))
event_type_to_01 <-
!(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate"))
if (
event_type_to_01 &&
(identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) ) {
res <- res - 1
}
unname(res)
}

# nocov end

# ------------------------------------------------------------------------------

# @param eval_time A vector of numeric time points
# @details
# `.filter_eval_time` checks the validity of the time points.
#
# @return A potentially modified vector of time points.
.filter_eval_time <- function(eval_time, fail = TRUE) {
if (!is.null(eval_time)) {
eval_time <- as.numeric(eval_time)
}
eval_time_0 <- eval_time
# will still propagate nulls:
eval_time <- eval_time[!is.na(eval_time)]
eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)]
eval_time <- unique(eval_time)
if (fail && identical(eval_time, numeric(0))) {
cli::cli_abort(
"There were no usable evaluation times (finite, non-missing, and >= 0).",
call = NULL
)
}
if (!identical(eval_time, eval_time_0)) {
diffs <- setdiff(eval_time_0, eval_time)
cli::cli_warn("There {?was/were} {length(diffs)} inappropriate evaluation
time point{?s} that {?was/were} removed.", call = NULL)

}
eval_time
}
15 changes: 8 additions & 7 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,13 @@ tune_bayes_workflow <-

check_iter(iter, call = call)

metrics <- check_metrics(metrics, object)
check_eval_time(eval_time, metrics)
metrics_data <- metrics_info(metrics)
metrics_name <- metrics_data$.metric[1]
metrics_time <- get_metric_time(metrics, eval_time)
maximize <- metrics_data$direction[metrics_data$.metric == metrics_name] == "maximize"
metrics <- check_metrics_arg(metrics, object, call = call)
opt_metric <- first_metric(metrics)
metrics_name <- opt_metric$metric
maximize <- opt_metric$direction == "maximize"

eval_time <- check_eval_time_arg(eval_time, metrics, call = call)
metrics_time <- first_eval_time(metrics, metrics_name, eval_time)

if (is.null(param_info)) {
param_info <- hardhat::extract_parameter_set_dials(object)
Expand Down Expand Up @@ -334,7 +335,7 @@ tune_bayes_workflow <-

if (control$verbose_iter) {
msg <- paste("Optimizing", metrics_name, "using", objective$label)
if (!is.null(eval_time)) {
if (!is.null(metrics_time)) {
msg <- paste(msg, "at evaluation time", format(metrics_time, digits = 3))
}
message_wrap(msg)
Expand Down
5 changes: 2 additions & 3 deletions R/tune_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,8 @@ tune_grid_workflow <- function(workflow,
call = caller_env()) {
check_rset(resamples)


metrics <- check_metrics(metrics, workflow)
check_eval_time(eval_time, metrics)
metrics <- check_metrics_arg(metrics, workflow, call = call)
eval_time <- check_eval_time_arg(eval_time, metrics, call = call)

pset <- check_parameters(
workflow,
Expand Down
Loading

0 comments on commit 585e01d

Please sign in to comment.