diff --git a/R/Task.R b/R/Task.R index 21135a335..1f39cd1ab 100644 --- a/R/Task.R +++ b/R/Task.R @@ -36,6 +36,7 @@ #' * Any modification of the lists `$col_roles` or `$row_roles`. #' This provides a different "view" on the data without altering the data itself. #' * Modification of column or row roles via `$set_col_roles()` or `$set_row_roles()`, respectively. +#' * Modification of the row roles via `$split_validation()`. #' * `$filter()` and `$select()` subset the set of active rows or features in `$row_roles` or `$col_roles`, respectively. #' This provides a different "view" on the data without altering the data itself. #' * `rbind()` and `cbind()` change the task in-place by binding rows or columns to the data, but without modifying the original [DataBackend]. @@ -550,6 +551,31 @@ Task = R6Class("Task", self$col_info = ujoin(self$col_info, tab, key = "id") invisible(self) + }, + + #' @description + #' Keeps `ratio` percent of the active observations (row role `"use"`) and moves the + #' other `(1 - ratio)` percent to the validation set (row role `"validation"`). + #' Internally, [ResamplingHoldout] is called to support the column roles `"strata"` and + #' `"groups"`. + #' + #' + #' If you need more fine-grained control over which rows to put into the validation + #' data set, use `$set_row_roles(row_ids, "validation")` instead. + #' + #' @param ratio (`numeric(1)`)\cr + #' Proportion of the rows to keep for training. + #' + #' @return Modified `self`. + split_validation = function(ratio = 0.67) { + assert_number(ratio, lower = 0, upper = 1) + n = length(self$row_roles$validation) + if (n > 0L) { + stopf("%i rows already in the validation set", n) + } + + r = rsmp("holdout", ratio = ratio)$instantiate(self) + self$set_row_roles(r$test_set(1L), "validation") } ), @@ -608,18 +634,21 @@ Task = R6Class("Task", #' Possible properties are are stored in [mlr_reflections$task_properties][mlr_reflections]. #' The following properties are currently standardized and understood by tasks in \CRANpkg{mlr3}: #' - #' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`). - #' * `"groups"`: The task comes with grouping/blocking information (role `"group"`). - #' * `"weights"`: The task comes with observation weights (role `"weight"`). + #' * `"strata"`: The task is resampled using one or more stratification variables (column role `"stratum"`). + #' * `"groups"`: The task comes with grouping/blocking information (column role `"group"`). + #' * `"validation"`: The task holds observations for validation (row role `"validation"`). + #' * `"weights"`: The task comes with observation weights (column role `"weight"`). #' #' Note that above listed properties are calculated from the `$col_roles` and may not be set explicitly. properties = function(rhs) { if (missing(rhs)) { col_roles = private$.col_roles + row_roles = private$.row_roles c(character(), private$.properties, if (length(col_roles$group)) "groups" else NULL, if (length(col_roles$stratum)) "strata" else NULL, + if (length(row_roles$validation)) "validation" else NULL, if (length(col_roles$weight)) "weights" else NULL ) } else { @@ -636,6 +665,9 @@ Task = R6Class("Task", #' #' `row_roles` is a named list whose elements are named by row role and each element is an `integer()` vector of row ids. #' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots). + #' The method `$set_row_roles()` provides a convenient alternative to assign rows to roles. + #' Additionally, `$split_validation()` is a quick way to assign a random subset of the + #' observations to the validation set. row_roles = function(rhs) { if (missing(rhs)) { return(private$.row_roles) @@ -665,7 +697,7 @@ Task = R6Class("Task", #' #' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names. #' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots). - #' The method `$set_col_roles` provides a convenient alternative to assign columns to roles. + #' The method `$set_col_roles()` provides a convenient alternative to assign columns to roles. col_roles = function(rhs) { if (missing(rhs)) { return(private$.col_roles) diff --git a/R/mlr_reflections.R b/R/mlr_reflections.R index 81edf403f..a7778b754 100644 --- a/R/mlr_reflections.R +++ b/R/mlr_reflections.R @@ -42,7 +42,7 @@ #' List of lists of supported [Learner] predict_types, named by their task type. #' #' * `predict_sets` (`character()`)\cr -#' Vector of possible predict sets. Currently supported are `"train"` and `"test"`. +#' Vector of possible predict sets. Currently supported are `"train"`, `"test"` and `"validation"`. #' #' * `measure_properties` (list of `character()`)\cr #' List of vectors of supported [Measure] properties, named by their task type. diff --git a/man/Task.Rd b/man/Task.Rd index 543aa0627..6505e1174 100644 --- a/man/Task.Rd +++ b/man/Task.Rd @@ -36,6 +36,7 @@ The following methods change the task in-place: \item Any modification of the lists \verb{$col_roles} or \verb{$row_roles}. This provides a different "view" on the data without altering the data itself. \item Modification of column or row roles via \verb{$set_col_roles()} or \verb{$set_row_roles()}, respectively. +\item Modification of the row roles via \verb{$split_validation()}. \item \verb{$filter()} and \verb{$select()} subset the set of active rows or features in \verb{$row_roles} or \verb{$col_roles}, respectively. This provides a different "view" on the data without altering the data itself. \item \code{rbind()} and \code{cbind()} change the task in-place by binding rows or columns to the data, but without modifying the original \link{DataBackend}. @@ -151,9 +152,10 @@ Set of task properties. Possible properties are are stored in \link[=mlr_reflections]{mlr_reflections$task_properties}. The following properties are currently standardized and understood by tasks in \CRANpkg{mlr3}: \itemize{ -\item \code{"strata"}: The task is resampled using one or more stratification variables (role \code{"stratum"}). -\item \code{"groups"}: The task comes with grouping/blocking information (role \code{"group"}). -\item \code{"weights"}: The task comes with observation weights (role \code{"weight"}). +\item \code{"strata"}: The task is resampled using one or more stratification variables (column role \code{"stratum"}). +\item \code{"groups"}: The task comes with grouping/blocking information (column role \code{"group"}). +\item \code{"validation"}: The task holds observations for validation (row role \code{"validation"}). +\item \code{"weights"}: The task comes with observation weights (column role \code{"weight"}). } Note that above listed properties are calculated from the \verb{$col_roles} and may not be set explicitly.} @@ -167,7 +169,10 @@ Can be used as truly independent test set. } \code{row_roles} is a named list whose elements are named by row role and each element is an \code{integer()} vector of row ids. -To alter the roles, just modify the list, e.g. with \R's set functions (\code{\link[=intersect]{intersect()}}, \code{\link[=setdiff]{setdiff()}}, \code{\link[=union]{union()}}, \ldots).} +To alter the roles, just modify the list, e.g. with \R's set functions (\code{\link[=intersect]{intersect()}}, \code{\link[=setdiff]{setdiff()}}, \code{\link[=union]{union()}}, \ldots). +The method \verb{$set_row_roles()} provides a convenient alternative to assign rows to roles. +Additionally, \verb{$split_validation()} is a quick way to assign a random subset of the +observations to the validation set.} \item{\code{col_roles}}{(named \code{list()})\cr Each column (feature) can have an arbitrary number of the following roles: @@ -186,7 +191,7 @@ Note that only up to one column may have this role. \code{col_roles} is a named list whose elements are named by column role and each element is a \code{character()} vector of column names. To alter the roles, just modify the list, e.g. with \R's set functions (\code{\link[=intersect]{intersect()}}, \code{\link[=setdiff]{setdiff()}}, \code{\link[=union]{union()}}, \ldots). -The method \verb{$set_col_roles} provides a convenient alternative to assign columns to roles.} +The method \verb{$set_col_roles()} provides a convenient alternative to assign columns to roles.} \item{\code{nrow}}{(\code{integer(1)})\cr Returns the total number of rows with role "use".} @@ -270,6 +275,7 @@ Returns \code{NULL} if there are is no uri column.} \item \href{#method-set_row_roles}{\code{Task$set_row_roles()}} \item \href{#method-set_col_roles}{\code{Task$set_col_roles()}} \item \href{#method-droplevels}{\code{Task$droplevels()}} +\item \href{#method-split_validation}{\code{Task$split_validation()}} \item \href{#method-clone}{\code{Task$clone()}} } } @@ -713,6 +719,33 @@ Modified \code{self}. } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-split_validation}{}}} +\subsection{Method \code{split_validation()}}{ +Keeps \code{ratio} percent of the active observations (row role \code{"use"}) and moves the +other \code{(1 - ratio)} percent to the validation set (row role \code{"validation"}). +Internally, \link{ResamplingHoldout} is called to support the column roles \code{"strata"} and +\code{"groups"}. + +If you need more fine-grained control over which rows to put into the validation +data set, use \verb{$set_row_roles(row_ids, "validation")} instead. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Task$split_validation(ratio = 0.67)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{ratio}}{(\code{numeric(1)})\cr +Proportion of the rows to keep for training.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +Modified \code{self}. +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-clone}{}}} \subsection{Method \code{clone()}}{ diff --git a/man/TaskClassif.Rd b/man/TaskClassif.Rd index aac92305f..ae8f7fb31 100644 --- a/man/TaskClassif.Rd +++ b/man/TaskClassif.Rd @@ -95,6 +95,7 @@ Stores the negative class for binary classification tasks, and \code{NA} for mul \item \out{}\href{../../mlr3/html/Task.html#method-select}{\code{mlr3::Task$select()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_col_roles}{\code{mlr3::Task$set_col_roles()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_row_roles}{\code{mlr3::Task$set_row_roles()}}\out{} +\item \out{}\href{../../mlr3/html/Task.html#method-split_validation}{\code{mlr3::Task$split_validation()}}\out{} } \out{} } diff --git a/man/TaskRegr.Rd b/man/TaskRegr.Rd index 2cb13e13b..2c5d354b1 100644 --- a/man/TaskRegr.Rd +++ b/man/TaskRegr.Rd @@ -70,6 +70,7 @@ Other Task: \item \out{}\href{../../mlr3/html/Task.html#method-select}{\code{mlr3::Task$select()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_col_roles}{\code{mlr3::Task$set_col_roles()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_row_roles}{\code{mlr3::Task$set_row_roles()}}\out{} +\item \out{}\href{../../mlr3/html/Task.html#method-split_validation}{\code{mlr3::Task$split_validation()}}\out{} } \out{} } diff --git a/man/TaskSupervised.Rd b/man/TaskSupervised.Rd index 90a5c8dbf..557a4079a 100644 --- a/man/TaskSupervised.Rd +++ b/man/TaskSupervised.Rd @@ -64,6 +64,7 @@ Other Task: \item \out{}\href{../../mlr3/html/Task.html#method-select}{\code{mlr3::Task$select()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_col_roles}{\code{mlr3::Task$set_col_roles()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_row_roles}{\code{mlr3::Task$set_row_roles()}}\out{} +\item \out{}\href{../../mlr3/html/Task.html#method-split_validation}{\code{mlr3::Task$split_validation()}}\out{} } \out{} } diff --git a/man/TaskUnsupervised.Rd b/man/TaskUnsupervised.Rd index 2ae9e79e5..c55a2a321 100644 --- a/man/TaskUnsupervised.Rd +++ b/man/TaskUnsupervised.Rd @@ -59,6 +59,7 @@ Other Task: \item \out{}\href{../../mlr3/html/Task.html#method-select}{\code{mlr3::Task$select()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_col_roles}{\code{mlr3::Task$set_col_roles()}}\out{} \item \out{}\href{../../mlr3/html/Task.html#method-set_row_roles}{\code{mlr3::Task$set_row_roles()}}\out{} +\item \out{}\href{../../mlr3/html/Task.html#method-split_validation}{\code{mlr3::Task$split_validation()}}\out{} } \out{} } diff --git a/man/mlr_reflections.Rd b/man/mlr_reflections.Rd index 6d97828fa..1abf1ba84 100644 --- a/man/mlr_reflections.Rd +++ b/man/mlr_reflections.Rd @@ -42,7 +42,7 @@ predict type \code{"prob"} for a \link{LearnerClassif} provides the probabilitie \item \code{learner_predict_types} (list of list of \code{character()})\cr List of lists of supported \link{Learner} predict_types, named by their task type. \item \code{predict_sets} (\code{character()})\cr -Vector of possible predict sets. Currently supported are \code{"train"} and \code{"test"}. +Vector of possible predict sets. Currently supported are \code{"train"}, \code{"test"} and \code{"validation"}. \item \code{measure_properties} (list of \code{character()})\cr List of vectors of supported \link{Measure} properties, named by their task type. \item \code{default_measures} (list of \code{character()})\cr diff --git a/tests/testthat/test_Task.R b/tests/testthat/test_Task.R index 09bf16b38..07a34f0c0 100644 --- a/tests/testthat/test_Task.R +++ b/tests/testthat/test_Task.R @@ -388,3 +388,27 @@ test_that("Task$set_col_roles", { expect_true("age" %in% task$feature_names) expect_null(task$weights) }) + +test_that("split_validation", { + task = tsk("mtcars") + task$split_validation(0.75) + expect_true("validation" %in% task$properties) + expect_equal(task$nrow, 24L) + expect_integer(task$row_roles$validation, len = 8L) + expect_error(task$split_validation(), "already in the validation set") + + # validation workflow + learner = lrn("regr.featureless", predict_sets = c("test", "validation")) + learner$train(task) + p = learner$predict(task, row_ids = task$row_roles$validation) + expect_data_table(as.data.table(p), nrows = 8) + + rr = resample(task, learner, rsmp("cv", folds = 3)) + measures = list( + msr("regr.mae", id = "mae_test"), + msr("regr.mae", id = "mae_validation", predict_sets = "validation") + ) + aggr = rr$aggregate(measures) + expect_numeric(aggr, len = 2L, any.missing = FALSE) + expect_true(aggr[1] != aggr[2]) +})