Properly support validation #607

Closed
wants to merge 4 commits into from
40 changes: 36 additions & 4 deletions R/Task.R
@@ -36,6 +36,7 @@
#' * Any modification of the lists `$col_roles` or `$row_roles`.
#' This provides a different "view" on the data without altering the data itself.
#' * Modification of column or row roles via `$set_col_roles()` or `$set_row_roles()`, respectively.
#' * Modification of the row roles via `$split_validation()`.
#' * `$filter()` and `$select()` subset the set of active rows or features in `$row_roles` or `$col_roles`, respectively.
#' This provides a different "view" on the data without altering the data itself.
#' * `rbind()` and `cbind()` change the task in-place by binding rows or columns to the data, but without modifying the original [DataBackend].
@@ -550,6 +551,31 @@ Task = R6Class("Task",

self$col_info = ujoin(self$col_info, tab, key = "id")
invisible(self)
},

#' @description
#' Keeps a proportion `ratio` of the active observations (row role `"use"`) and moves the
#' remaining proportion `1 - ratio` to the validation set (row role `"validation"`).
#' Internally, [ResamplingHoldout] is called to support the column roles `"strata"` and
#' `"groups"`.
#'
#' If you need more fine-grained control over which rows to put into the validation
#' data set, use `$set_row_roles(row_ids, "validation")` instead.
#'
#' @param ratio (`numeric(1)`)\cr
#' Proportion of the rows to keep for training.
#'
#' @return Modified `self`.
split_validation = function(ratio = 0.67) {
assert_number(ratio, lower = 0, upper = 1)
n = length(self$row_roles$validation)
if (n > 0L) {
stopf("%i rows already in the validation set", n)
}

r = rsmp("holdout", ratio = ratio)$instantiate(self)
self$set_row_roles(r$test_set(1L), "validation")
}
),

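The logic of `split_validation()` in the hunk above can be sketched outside of mlr3 on a plain list of row roles. This is a minimal sketch under stated assumptions, not the PR's implementation: `split_validation_sketch()` is a hypothetical helper, it uses simple random sampling instead of `rsmp("holdout")` (so strata and groups are ignored), and rounding the number of retained rows with `round()` is an assumption.

```r
# Hypothetical stand-alone sketch of the split; not part of mlr3.
# `row_roles` is a named list with integer vectors `use` and `validation`.
split_validation_sketch = function(row_roles, ratio = 0.67) {
  stopifnot(is.numeric(ratio), length(ratio) == 1L, ratio >= 0, ratio <= 1)

  # Refuse to split twice, mirroring the check in the PR.
  n = length(row_roles$validation)
  if (n > 0L) {
    stop(sprintf("%i rows already in the validation set", n))
  }

  ids = row_roles$use
  n_keep = round(length(ids) * ratio)  # assumption: plain rounding
  idx = sample.int(length(ids), n_keep)

  # Rows that are kept stay in "use"; all others move to "validation".
  kept = ids[sort(idx)]
  row_roles$validation = setdiff(ids, kept)
  row_roles$use = kept
  row_roles
}

set.seed(1)
roles = list(use = 1:32, validation = integer())
roles = split_validation_sketch(roles, ratio = 0.75)
lengths(roles)
```

With 32 active rows and `ratio = 0.75`, the sketch keeps 24 rows and moves 8 to the validation set, which matches the counts that the PR's unit test expects for `mtcars`.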
@@ -608,18 +634,21 @@ Task = R6Class("Task",
#' Possible properties are stored in [mlr_reflections$task_properties][mlr_reflections].
#' The following properties are currently standardized and understood by tasks in \CRANpkg{mlr3}:
#'
#' * `"strata"`: The task is resampled using one or more stratification variables (role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (role `"group"`).
#' * `"weights"`: The task comes with observation weights (role `"weight"`).
#' * `"strata"`: The task is resampled using one or more stratification variables (column role `"stratum"`).
#' * `"groups"`: The task comes with grouping/blocking information (column role `"group"`).
#' * `"validation"`: The task holds observations for validation (row role `"validation"`).
#' * `"weights"`: The task comes with observation weights (column role `"weight"`).
#'
#' Note that the properties listed above are calculated from `$col_roles` and `$row_roles` and may not be set explicitly.
properties = function(rhs) {
if (missing(rhs)) {
col_roles = private$.col_roles
row_roles = private$.row_roles
c(character(),
private$.properties,
if (length(col_roles$group)) "groups" else NULL,
if (length(col_roles$stratum)) "strata" else NULL,
if (length(row_roles$validation)) "validation" else NULL,
if (length(col_roles$weight)) "weights" else NULL
)
} else {
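The way the read branch of `properties` derives its result from the role lists can be illustrated in isolation. The following is a rough sketch, assuming plain lists in place of the private fields `private$.col_roles` and `private$.row_roles`; `derive_properties()` and its `fixed` argument are hypothetical names used only for illustration.

```r
# Hypothetical helper mirroring the read branch of the active binding.
# `fixed` stands in for the explicitly stored properties.
derive_properties = function(col_roles, row_roles, fixed = character()) {
  c(character(),
    fixed,
    # A property is reported only if its role holds at least one entry.
    if (length(col_roles$group)) "groups" else NULL,
    if (length(col_roles$stratum)) "strata" else NULL,
    if (length(row_roles$validation)) "validation" else NULL,
    if (length(col_roles$weight)) "weights" else NULL
  )
}

col_roles = list(group = character(), stratum = "cyl", weight = character())
row_roles = list(use = 1:24, validation = 25:32)
props = derive_properties(col_roles, row_roles)
print(props)
```

Because `NULL` elements are dropped by `c()`, empty roles contribute nothing, so the result always has type `character()` even when no property applies.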
@@ -636,6 +665,9 @@ Task = R6Class("Task",
#'
#' `row_roles` is a named list whose elements are named by row role and each element is an `integer()` vector of row ids.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
#' The method `$set_row_roles()` provides a convenient alternative to assign rows to roles.
#' Additionally, `$split_validation()` is a quick way to assign a random subset of the
#' observations to the validation set.
row_roles = function(rhs) {
if (missing(rhs)) {
return(private$.row_roles)
@@ -665,7 +697,7 @@ Task = R6Class("Task",
#'
#' `col_roles` is a named list whose elements are named by column role and each element is a `character()` vector of column names.
#' To alter the roles, just modify the list, e.g. with \R's set functions ([intersect()], [setdiff()], [union()], \ldots).
#' The method `$set_col_roles` provides a convenient alternative to assign columns to roles.
#' The method `$set_col_roles()` provides a convenient alternative to assign columns to roles.
col_roles = function(rhs) {
if (missing(rhs)) {
return(private$.col_roles)
2 changes: 1 addition & 1 deletion R/mlr_reflections.R
@@ -42,7 +42,7 @@
#' List of lists of supported [Learner] predict_types, named by their task type.
#'
#' * `predict_sets` (`character()`)\cr
#' Vector of possible predict sets. Currently supported are `"train"` and `"test"`.
#' Vector of possible predict sets. Currently supported are `"train"`, `"test"` and `"validation"`.
#'
#' * `measure_properties` (list of `character()`)\cr
#' List of vectors of supported [Measure] properties, named by their task type.
43 changes: 38 additions & 5 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/TaskClassif.Rd

1 change: 1 addition & 0 deletions man/TaskRegr.Rd

1 change: 1 addition & 0 deletions man/TaskSupervised.Rd

1 change: 1 addition & 0 deletions man/TaskUnsupervised.Rd

2 changes: 1 addition & 1 deletion man/mlr_reflections.Rd

24 changes: 24 additions & 0 deletions tests/testthat/test_Task.R
@@ -388,3 +388,27 @@ test_that("Task$set_col_roles", {
expect_true("age" %in% task$feature_names)
expect_null(task$weights)
})

test_that("split_validation", {
task = tsk("mtcars")
task$split_validation(0.75)
expect_true("validation" %in% task$properties)
expect_equal(task$nrow, 24L)
expect_integer(task$row_roles$validation, len = 8L)
expect_error(task$split_validation(), "already in the validation set")

# validation workflow
learner = lrn("regr.featureless", predict_sets = c("test", "validation"))
learner$train(task)
p = learner$predict(task, row_ids = task$row_roles$validation)
expect_data_table(as.data.table(p), nrows = 8)

rr = resample(task, learner, rsmp("cv", folds = 3))
measures = list(
msr("regr.mae", id = "mae_test"),
msr("regr.mae", id = "mae_validation", predict_sets = "validation")
)
aggr = rr$aggregate(measures)
expect_numeric(aggr, len = 2L, any.missing = FALSE)
expect_true(aggr[1] != aggr[2])
})