Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecate divide method for task #1090

Merged
merged 5 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* refactor: Optimize runtime of setting row roles.
* refactor: Optimize runtime of marshalling.
* refactor: Optimize runtime of `Task$col_info`
* feat: `$internal_valid_task` can now be set to an `integer` vector.
* deprecated: The `$divide()` method of `Task` is deprecated; assign to the field `$internal_valid_task` instead.
* fix: `Task$cbind()` now works with non-standard primary keys
for `data.frames` (#961).
* fix: Triggering of fallback learner now has log-level "info"
Expand Down
26 changes: 17 additions & 9 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,7 @@ Task = R6Class("Task",
},

#' @description
#' Creates an internal validation task (field `$internal_valid_task`) from the primary task.
#' This modifies the task in-place.
#' Subsequent operations on the (primary) task are **not** relayed to the internal validation task.
#' One must either provide the parameter `ratio` or `ids`.
#' Deprecated.
#'
#' @param ratio (`numeric(1)`)\cr
#' The proportion of datapoints to use as validation data.
Expand All @@ -172,6 +169,7 @@ Task = R6Class("Task",
#'
#' @return Modified `Self`.
divide = function(ratio = NULL, ids = NULL, remove = TRUE) {
.Deprecated("field $internal_valid_task")
assert_flag(remove)
private$.hash = NULL

Expand Down Expand Up @@ -786,9 +784,11 @@ Task = R6Class("Task",
private$.id = assert_string(rhs, min.chars = 1L)
},

#' @field internal_valid_task (`Task` or `NULL`)\cr
#' @field internal_valid_task (`Task` or `integer()` or `NULL`)\cr
#' Optional validation task that can, e.g., be used for early stopping with learners such as XGBoost.
#' See also the `$validate` field of [`Learner`].
#' If an integer vector is assigned, those row ids are removed from the primary task and an internal validation task
#' containing only those rows is created from the primary task.
#' When assigning a new task, it is always cloned.
internal_valid_task = function(rhs) {
if (missing(rhs)) {
Expand All @@ -799,11 +799,19 @@ Task = R6Class("Task",
private$.internal_valid_task = NULL
return(invisible(private$.internal_valid_task))
}
private$.hash = NULL

assert_task(rhs, task_type = self$task_type)
rhs = rhs$clone(deep = TRUE)
if (!is.null(rhs$internal_valid_task)) { # avoid recursive structures
stopf("Trying to assign task '%s' as a validation task, remove its validation task first.", rhs$id)
if (test_integerish(rhs)) {
train_ids = setdiff(self$row_ids, rhs)
rhs = self$clone(deep = TRUE)$filter(rhs)
rhs$internal_valid_task = NULL
self$row_roles$use = train_ids
} else {
if (!is.null(rhs$internal_valid_task)) { # avoid recursive structures
stopf("Trying to assign task '%s' as a validation task, remove its validation task first.", rhs$id)
}
assert_task(rhs, task_type = self$task_type)
rhs = rhs$clone(deep = TRUE)
}

ci1 = self$col_info
Expand Down
5 changes: 4 additions & 1 deletion R/helper_hashes.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ resampling_task_hashes = function(task, resampling, learner = NULL) {
task_hash = function(task, use_ids, test_ids = NULL, ignore_internal_valid_task = FALSE) {
# order matters: we first check for test_ids and then for the internal_valid_task
internal_valid_task_hash = if (!is.null(test_ids)) {
# this does the same as task$divide(ids = test_ids)$internal_valid_task$hash but avoids the deep clone
# this does the same as
#   task$internal_valid_task = test_ids
#   task$internal_valid_task$hash
# but avoids the deep clone
task_hash(task, use_ids = test_ids, test_ids = NULL, ignore_internal_valid_task = TRUE)
} else if (!ignore_internal_valid_task) {
task$internal_valid_task$hash
Expand Down
6 changes: 4 additions & 2 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -463,14 +463,16 @@ create_internal_valid_task = function(validate, task, test_row_ids, prev_valid,
}
# at this point, the train rows are already set to the train set, i.e. we don't have to remove the test ids
# from the primary task (this would cause bugs for resamplings with overlapping train and test set)
task$divide(ids = test_row_ids, remove = FALSE)
valid_task = task$clone(deep = TRUE)
valid_task$row_roles$use = test_row_ids
task$internal_valid_task = valid_task
return(task)
}

return(task)
}

# validate is numeric
task$divide(ratio = validate, remove = TRUE)
task$internal_valid_task = partition(task, ratio = 1 - validate)$test
return(task)
}
9 changes: 4 additions & 5 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 10 additions & 7 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ test_that("Models can be replaced", {

test_that("validation task's backend is removed", {
learner = lrn("regr.rpart")
task = tsk("mtcars")$divide(ids = 1:10)
task = tsk("mtcars")
task$internal_valid_task = 1:10
learner$train(task)
expect_true(is.null(learner$state$train_task$internal_valid_task$backend))
})
Expand All @@ -340,7 +341,7 @@ test_that("manual $train() stores validation hash and validation ids", {

l = lrn("classif.debug", validate = "predefined")
task = tsk("iris")
task$divide(ids = 1:10)
task$internal_valid_task = 1:10
l$train(task)
expect_equal(l$state$internal_valid_task_hash, task$internal_valid_task$hash)

Expand All @@ -354,7 +355,7 @@ test_that("error when training a learner that sets valiadte to 'predefined' on a
task = tsk("iris")
learner = lrn("classif.debug", validate = "predefined")
expect_error(learner$train(task), "is set to 'predefined'")
task$divide(ids = 1:10)
task$internal_valid_task = 1:10
expect_class(learner, "Learner")
})

Expand All @@ -364,7 +365,7 @@ test_that("properties are also checked on validation task", {
row[[1]][1] = NA
row$..row_id = 151
task$rbind(row)
task$divide(ids = 151)
task$internal_valid_task = 151
learner = lrn("classif.debug", validate = "predefined")
learner$properties = setdiff(learner$properties, "missings")
expect_error(learner$train(task), "missing values")
Expand Down Expand Up @@ -416,7 +417,8 @@ test_that("internal_valid_task is created correctly", {
)
# validate = NULL (but task has one)
learner = LearnerClassifTest$new()
task = tsk("iris")$divide(0.3)
task = tsk("iris")
task$internal_valid_task = partition(task)$test
learner$train(task)
learner$validate = NULL
expect_true(is.null(learner$internal_valid_scores))
Expand Down Expand Up @@ -480,7 +482,8 @@ test_that("internal_valid_task is created correctly", {

test_that("compatability check on validation task", {
learner = lrn("classif.debug", validate = "predefined")
task = tsk("german_credit")$divide(ids = 1:10)
task = tsk("german_credit")
task$internal_valid_task = 1:10
task$col_roles$feature = "age"
expect_error(learner$train(task), "has different features")
task$internal_valid_task$col_roles$feature = "age"
Expand Down Expand Up @@ -544,6 +547,6 @@ test_that("learner state contains internal valid task information", {
test_that("validation task with 0 observations", {
learner = lrn("classif.debug", validate = "predefined")
task = tsk("iris")
task$divide(ids = integer(0))
task$internal_valid_task = integer(0)
expect_error({learner$train(task)}, "has 0 observations")
})
53 changes: 10 additions & 43 deletions tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ test_that("Roles get printed (#877)", {

test_that("validation task is cloned", {
task = tsk("iris")
task$divide(ids = c(1:10, 51:60, 101:110))
task$internal_valid_task = c(1:10, 51:60, 101:110)
task2 = task$clone(deep = TRUE)
expect_false(identical(task$internal_valid_task, task2$internal_valid_task))
expect_equal(task$internal_valid_task, task2$internal_valid_task)
Expand All @@ -550,23 +550,10 @@ test_that("task is cloned when assining internal validation task", {
expect_false(identical(task, task$internal_valid_task))
})

test_that("validation task cannot have a validation task", {
task = tsk("iris")
expect_error({task$internal_valid_task = task$clone(deep = TRUE)$divide(ids = 1) }, "remove its validation")
})

test_that("divide works with ratio", {
task = tsk("iris")$filter(1:10)
task$divide(ratio = 0.1)
expect_equal(task$nrow, 9)
expect_equal(task$internal_valid_task$nrow, 1)
expect_permutation(1:10, c(task$row_ids, task$internal_valid_task$row_ids))
})

test_that("validation task changes a task's hash", {
task = tsk("iris")
h1 = task$hash
task$divide(ids = 1:10, remove = FALSE)
task$internal_valid_task = task$clone(deep = TRUE)$filter(1:10)
h2 = task$hash
expect_false(h1 == h2)
})
Expand All @@ -585,20 +572,14 @@ test_that("compatibility checks on internal_valid_task", {

test_that("can NULL validation task", {
task = tsk("iris")
task$divide(ids = 1)
task$internal_valid_task = 1
task$internal_valid_task = NULL
expect_equal(length(task$row_ids), 149)
})

test_that("can call $divide twice", {
task = tsk("iris")
task$divide(ids = 1:10)
expect_task(task$divide(ids = 1:10))
})

test_that("internal_valid_task is printed", {
task = tsk("iris")
task$divide(ids = c(1:10, 51:60, 101:110))
task$internal_valid_task = c(1:10, 51:60, 101:110)
out = capture_output(print(task))
expect_true(grepl(pattern = "* Validation Task: (30x5)", fixed = TRUE, x = out))
})
Expand All @@ -608,31 +589,17 @@ test_that("task hashes during resample", {
task = orig$clone(deep = TRUE)
resampling = rsmp("holdout")
resampling$instantiate(task)
task$divide(ids = resampling$test_set(1))
task$internal_valid_task = resampling$test_set(1)
task$hash
learner = lrn("classif.debug", validate = "test")
expect_equal(resampling_task_hashes(task, resampling, learner), task$hash)
})

test_that("divide remove parameter works", {
task = tsk("iris")
task$divide(ids = 1L, remove = FALSE)
expect_true(1L %in% task$row_ids)
task = tsk("iris")
task$divide(ids = 1L, remove = TRUE)
expect_false(1L %in% task$row_ids)
})

test_that("divide does not take ratio and ids", {
expect_error(tsk("iris")$divide(0.2, 1), "to create a validation task")
})

test_that("divide requires ratio in (0, 1)", {
expect_error(tsk("iris")$divide(1.2))
})

test_that("divide requires ids to be row_ids", {
expect_error(tsk("iris")$divide(ids = 0.5))
test_that("integer vector can be passed to internal_valid_task", {
task = tsk("iris")$filter(1:5)
task$internal_valid_task = 5
expect_permutation(task$row_ids, 1:4)
expect_equal(task$internal_valid_task$row_ids, 5)
})

test_that("cbind supports non-standard primary key (#961)", {
Expand Down
8 changes: 5 additions & 3 deletions tests/testthat/test_benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -491,14 +491,16 @@ test_that("param_values in benchmark", {
test_that("learner's validate cannot be 'test' if internal_valid_set is present", {
# otherwise, predict_set = "internal_valid" would be ambiguous
learner = lrn("classif.debug", validate = "test", predict_sets = c("train", "internal_valid"))
task = tsk("iris")$divide(ids = 1)
task = tsk("iris")
task$internal_valid_task = 1
expect_error(benchmark(benchmark_grid(task, learner, rsmp("holdout"))), "cannot be set to ")
})

test_that("learner's validate cannot be a ratio if internal_valid_set is present", {
# otherwise, predict_set = "internal_valid" would be ambiguous
learner = lrn("classif.debug", validate = 0.5, predict_sets = c("train", "internal_valid"))
task = tsk("iris")$divide(ids = 1)
task = tsk("iris")
task$internal_valid_task = 1
expect_error(benchmark(benchmark_grid(task, learner, rsmp("holdout"))), "cannot be set to ")
})

Expand All @@ -508,7 +510,7 @@ test_that("properties are also checked on validation task", {
row[[1]][1] = NA
row$..row_id = 151
task$rbind(row)
task$divide(ids = 151)
task$internal_valid_task = 151
learner = lrn("classif.debug", validate = "predefined")
learner$properties = setdiff(learner$properties, "missings")

Expand Down
14 changes: 9 additions & 5 deletions tests/testthat/test_resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,19 +282,21 @@ test_that("can make predictions for internal_valid_task", {
test_that("learner's validate cannot be 'test' if internal_valid_set is present", {
# otherwise, predict_set = "internal_valid" would be ambiguous
learner = lrn("classif.debug", validate = "test", predict_sets = c("train", "internal_valid"))
task = tsk("iris")$divide(ids = 1)
task = tsk("iris")
task$internal_valid_task = 1
expect_error(resample(task, learner, rsmp("holdout")), "cannot be set to ")
})

test_that("learner's validate cannot be a ratio if internal_valid_set is present", {
# otherwise, predict_set = "internal_valid" would be ambiguous
learner = lrn("classif.debug", validate = 0.5, predict_sets = c("train", "internal_valid"))
task = tsk("iris")$divide(ids = 1)
task$internal_valid_task = 1
expect_error(resample(task, learner, rsmp("holdout")), "cannot be set to")
})

test_that("internal_valid and train predictions", {
task = tsk("iris")$divide(ids = 1:2)
task = tsk("iris")
task$internal_valid_task = 1:2
learner = lrn("classif.debug", validate = "predefined", predict_sets = c("train", "internal_valid", "test"))
rr = resample(task, learner, rsmp("insample"))
measure_valid = msr("classif.acc")
Expand Down Expand Up @@ -340,7 +342,7 @@ test_that("properties are also checked on validation task", {
row[[1]][1] = NA
row$..row_id = 151
task$rbind(row)
task$divide(ids = 151)
task$internal_valid_task = 151
learner = lrn("classif.debug", validate = "predefined")
learner$properties = setdiff(learner$properties, "missings")

Expand All @@ -366,7 +368,9 @@ test_that("predict_set internal_valid throws error when none is available", {
})

test_that("can even use internal_valid predict set on learners that don't support validation", {
rr = resample(tsk("mtcars")$divide(ids = 1:10), lrn("regr.debug", predict_sets = "internal_valid"), rsmp("holdout"))
task = tsk("mtcars")
task$internal_valid_task = 1:10
rr = resample(task, lrn("regr.debug", predict_sets = "internal_valid"), rsmp("holdout"))
})

test_that("callr during prediction triggers marshaling", {
Expand Down