-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactored functions to evaluate/check metric and eval time arguments (…
…#780) * move function to new file * change function order for docs * documentation start * updates to the show/select functions * updates to select/show functions * updates for selecting eval times * remove commented out code * bug fix * metric test cases * add a survival model object * note for next PR * select/show test cases * small set of direct tests * update snapshot * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> * Apply suggestions from code review * updates from previous review * small cli update * doc update * refresh snapshots * modularize a check * Remake with newest CRAN version of scales for #775 * argument checks for tune functions * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> * add dot when function is invoked * add a warning for eval times with non-survival models * go back to enquos * rework warning text * rework warning text pt 2 * unit tests for regression and classification * caller envs * unit tests * rework call envs * survival unit tests * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> * update tests for more recent version of yardstick * use standalone file from tidymodels/parsnip#1034 --------- Co-authored-by: Hannah Frick <[email protected]>
- Loading branch information
Showing
15 changed files
with
2,015 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: tune | ||
Title: Tidy Tuning Tools | ||
Version: 1.1.2.9002 | ||
Version: 1.1.2.9003 | ||
Authors@R: c( | ||
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"), | ||
comment = c(ORCID = "0000-0003-2402-136X")), | ||
|
@@ -38,7 +38,7 @@ Imports: | |
vctrs (>= 0.6.1), | ||
withr, | ||
workflows (>= 1.0.0), | ||
yardstick (>= 1.2.0.9001) | ||
yardstick (>= 1.2.0.9002) | ||
Suggests: | ||
C50, | ||
censored, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# --- | ||
# repo: tidymodels/parsnip | ||
# file: standalone-survival.R | ||
# last-updated: 2023-12-08 | ||
# license: https://unlicense.org | ||
# --- | ||
|
||
# This file provides a portable set of helper functions for survival analysis. | ||
# The original is in the parsnip package. | ||
|
||
# ## Changelog | ||
# 2023-12-08 | ||
# * move .filter_eval_time to this file | ||
# | ||
# 2023-11-09 | ||
# * make sure survival vectors are unnamed. | ||
# | ||
# 2023-06-14 | ||
# * removed time to factor conversion | ||
# | ||
# 2023-05-18 | ||
# * added time to factor conversion | ||
# | ||
# 2023-02-28: | ||
# * Initial version | ||
# | ||
# ------------------------------------------------------------------------------ | ||
# | ||
# @param surv A [survival::Surv()] object | ||
# @details | ||
# `.is_censored_right()` always returns a logical while | ||
# `.check_censored_right()` will fail if `FALSE`. | ||
# | ||
# `.extract_status()` will return the data as 0/1 even if the original object | ||
# used the legacy encoding of 1/2. See [survival::Surv()]. | ||
|
||
# @return | ||
# - `.extract_surv_status()` returns a vector. | ||
# - `.extract_surv_time()` returns a vector when the type is `"right"` or `"left"` | ||
# and a tibble otherwise. | ||
# - Functions starting with `.is_` or `.check_` return logicals although the | ||
# latter will fail when `FALSE`. | ||
|
||
# nocov start | ||
# These are tested in the extratests repo since it would require a dependency | ||
# on the survival package. https://github.com/tidymodels/extratests/pull/78 | ||
.is_surv <- function(surv, fail = TRUE, call = rlang::caller_env()) { | ||
is_surv <- inherits(surv, "Surv") | ||
if (!is_surv && fail) { | ||
rlang::abort("The object does not have class `Surv`.", call = call) | ||
} | ||
is_surv | ||
} | ||
|
||
.extract_surv_type <- function(surv) { | ||
attr(surv, "type") | ||
} | ||
|
||
.check_cens_type <- | ||
function(surv, | ||
type = "right", | ||
fail = TRUE, | ||
call = rlang::caller_env()) { | ||
.is_surv(surv, call = call) | ||
obj_type <- .extract_surv_type(surv) | ||
good_type <- all(obj_type %in% type) | ||
if (!good_type && fail) { | ||
c_list <- paste0("'", type, "'") | ||
msg <- cli::format_inline("For this usage, the allowed censoring type{?s} {?is/are}: {c_list}") | ||
rlang::abort(msg, call = call) | ||
} | ||
good_type | ||
} | ||
|
||
.is_censored_right <- function(surv) { | ||
.check_cens_type(surv, type = "right", fail = FALSE) | ||
} | ||
|
||
.check_censored_right <- function(surv, call = rlang::caller_env()) { | ||
.check_cens_type(surv, type = "right", fail = TRUE, call = call) | ||
} # will add more as we need them | ||
|
||
.extract_surv_time <- function(surv) { | ||
.is_surv(surv) | ||
keepers <- c("time", "start", "stop", "time1", "time2") | ||
cols <- colnames(surv)[colnames(surv) %in% keepers] | ||
res <- surv[, cols, drop = FALSE] | ||
if (length(cols) > 1) { | ||
res <- tibble::tibble(as.data.frame(res)) | ||
} else { | ||
res <- as.numeric(res) | ||
} | ||
res | ||
} | ||
|
||
.extract_surv_status <- function(surv) { | ||
.is_surv(surv) | ||
res <- surv[, "status"] | ||
un_vals <- sort(unique(res)) | ||
event_type_to_01 <- | ||
!(.extract_surv_type(surv) %in% c("interval", "interval2", "mstate")) | ||
if ( | ||
event_type_to_01 && | ||
(identical(un_vals, 1:2) | identical(un_vals, c(1.0, 2.0))) ) { | ||
res <- res - 1 | ||
} | ||
unname(res) | ||
} | ||
|
||
# nocov end | ||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
# @param eval_time A vector of numeric time points | ||
# @details | ||
# `.filter_eval_time` checks the validity of the time points. | ||
# | ||
# @return A potentially modified vector of time points. | ||
.filter_eval_time <- function(eval_time, fail = TRUE) { | ||
if (!is.null(eval_time)) { | ||
eval_time <- as.numeric(eval_time) | ||
} | ||
eval_time_0 <- eval_time | ||
# will still propagate nulls: | ||
eval_time <- eval_time[!is.na(eval_time)] | ||
eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)] | ||
eval_time <- unique(eval_time) | ||
if (fail && identical(eval_time, numeric(0))) { | ||
cli::cli_abort( | ||
"There were no usable evaluation times (finite, non-missing, and >= 0).", | ||
call = NULL | ||
) | ||
} | ||
if (!identical(eval_time, eval_time_0)) { | ||
diffs <- setdiff(eval_time_0, eval_time) | ||
cli::cli_warn("There {?was/were} {length(diffs)} inappropriate evaluation | ||
time point{?s} that {?was/were} removed.", call = NULL) | ||
|
||
} | ||
eval_time | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.