feat: pass parameter values in design (#31)
* feat: pass parameter values in design

* chore: update news
be-marc authored May 28, 2024
1 parent a8820b2 commit 10a635d
Showing 7 changed files with 80 additions and 14 deletions.
10 changes: 6 additions & 4 deletions NEWS.md
@@ -1,9 +1,11 @@
# mlr3batchmark (development version)

* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`, useful for deleting model data to avoid running out of memory, https://github.com/mlr-org/mlr3batchmark/issues/18 Thanks to Toby Dylan Hocking @tdhock for the PR.
* docs: A warning is now given when the loaded mlr3 version differs from the
mlr3 version stored in the trained learners
* Support marshaling
* feat: The design of `batchmark()` can now include parameter settings.
* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`.
Useful for deleting model data to avoid running out of memory.
Thanks to Toby Dylan Hocking @tdhock for the PR (https://github.com/mlr-org/mlr3batchmark/issues/18).
* docs: A warning is now given when the loaded mlr3 version differs from the mlr3 version stored in the trained learners
* feat: support marshaling

# mlr3batchmark 0.1.1

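For context, a minimal sketch of what a `batchmark()` design with parameter settings can look like after this change. The task, learner, and parameter names are illustrative (they mirror the new test further below), not taken from this commit:

```r
library(mlr3)
library(mlr3batchmark)
library(batchtools)
library(data.table)

task = tsk("iris")
resampling = rsmp("cv", folds = 3)$instantiate(task)

# `param_values` holds, per learner, a list of parameter configurations;
# each configuration is a named list of parameter values.
design = data.table(
  task = list(task),
  learner = list(lrn("classif.rpart")),
  resampling = list(resampling),
  param_values = list(list(list(cp = 0.01), list(cp = 0.1)))
)

reg = makeExperimentRegistry(NA, make.default = FALSE)
batchmark(design, reg = reg)  # 2 configurations x 3 folds = 6 jobs
# then: submitJobs(reg = reg); reduceResultsBatchmark(reg = reg)
```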
12 changes: 12 additions & 0 deletions R/assertions.R
@@ -0,0 +1,12 @@
assert_param_values = function(x, n_learners = NULL, .var.name = vname(x)) {
assert_list(x, len = n_learners, .var.name = .var.name)

ok = every(x, function(x) {
test_list(x) && every(x, test_list, names = "unique", null.ok = TRUE)
})

if (!ok) {
stopf("'%s' must be a three-level nested list and the innermost lists must be named", .var.name)
}
invisible(x)
}
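To make the expected shape concrete, here is a sketch of a value that passes this assertion (modeled on the new test below; `x` is an arbitrary parameter name of the learner):

```r
# Outer list:  one element per learner in the design.
# Middle list: the parameter configurations to evaluate for that learner.
# Inner lists: one named configuration each.
param_values = list(
  list(                 # configurations for the first (and only) learner
    list(x = 1),        # configuration 1
    list(x = 0.5)       # configuration 2
  )
)
# assert_param_values(param_values, n_learners = 1) returns the input invisibly.
```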
32 changes: 27 additions & 5 deletions R/batchmark.R
@@ -35,7 +35,7 @@
#' reduceResultsBatchmark(reg = reg)
batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultRegistry()) {
design = as.data.table(assert_data_frame(design, min.rows = 1L))
assert_names(names(design), permutation.of = c("task", "learner", "resampling"))
assert_names(names(design), must.include = c("task", "learner", "resampling"))
assert_flag(store_models)
batchtools::assertRegistry(reg, class = "ExperimentRegistry", writeable = TRUE, sync = TRUE,
running.ok = FALSE)
@@ -53,10 +53,22 @@ batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultR
batchtools::addAlgorithm("run_learner", fun = run_learner, reg = reg)
}

# group per problem to speed up addExperiments()
# set hashes
set(design, j = "task_hash", value = map_chr(design$task, "hash"))
set(design, j = "learner_hash", value = map_chr(design$learner, "hash"))
set(design, j = "resampling_hash", value = map_chr(design$resampling, "hash"))

# expand with param values
if (is.null(design$param_values)) {
design$param_values = list()
} else {
design$param_values = list(assert_param_values(design$param_values, n_learners = length(design$learner)))
task = learner = resampling = NULL
design = design[, list(task, learner, resampling, param_values = unlist(get("param_values"), recursive = FALSE)), by = c("learner_hash", "task_hash", "resampling_hash")]
}
design[, "param_values_hash" := map(get("param_values"), calculate_hash)]

# group per problem to speed up addExperiments()
design[, "group" := .GRP, by = c("task_hash", "resampling_hash")]

groups = unique(design$group)
@@ -85,13 +97,23 @@ batchmark = function(design, store_models = FALSE, reg = batchtools::getDefaultR
exports = c(exports, learner_hashes[i])
}

param_values_hashes = tab$param_values_hash
for (i in which(param_values_hashes %nin% exports)) {
batchtools::batchExport(export = set_names(list(tab$param_values[[i]]), param_values_hashes[i]), reg = reg)
exports = c(exports, param_values_hashes[i])
}

prob_design = data.table(
task_hash = task_hash, task_id = task$id,
resampling_hash = resampling_hash, resampling_id = resampling$id
task_hash = task_hash,
task_id = task$id,
resampling_hash = resampling_hash,
resampling_id = resampling$id
)

algo_design = data.table(
learner_hash = learner_hashes, learner_id = map_chr(tab$learner, "id"),
learner_hash = learner_hashes,
learner_id = map_chr(tab$learner, "id"),
param_values_hash = param_values_hashes,
store_models = store_models
)

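A toy illustration (hypothetical hashes, simplified columns) of what the `param_values` expansion in `batchmark()` above does: a design row carrying k parameter configurations is unnested into k rows, each of which later gets its own `param_values_hash` and is exported as an algorithm parameter:

```r
library(data.table)

design = data.table(
  task_hash = "t1", learner_hash = "l1", resampling_hash = "r1",
  param_values = list(list(list(x = 1), list(x = 0.5)))
)

# Mirrors the grouped unlist() step in batchmark(): one row per configuration.
expanded = design[, list(param_values = unlist(param_values, recursive = FALSE)),
  by = c("learner_hash", "task_hash", "resampling_hash")]

expanded$param_values  # list(x = 1) and list(x = 0.5), one per row
```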
3 changes: 1 addition & 2 deletions R/reduceResultsBatchmark.R
@@ -1,8 +1,7 @@
#' @title Collect Results from batchmark
#'
#' @description
#' Collect the results from jobs defined via [batchmark()] and combine them into
#' a [mlr3::BenchmarkResult].
#' Collect the results from jobs defined via [batchmark()] and combine them into a [mlr3::BenchmarkResult].
#'
#' Note that `ids` defaults to finished jobs (as reported by [batchtools::findDone()]).
#' If a job threw an error, is expired or is still running, it will be ignored with this default.
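The NEWS entry above mentions that `reduceResultsBatchmark()` gained a `fun` argument, forwarded to `batchtools::reduceResultsList()`. A hedged sketch of using it to drop model data while collecting results follows; it is not part of this commit's diff, and the `learner_state$model` slot of the stored job result is an assumption that may need adjusting:

```r
# Assumed usage: strip fitted models from each stored job result to save
# memory before the BenchmarkResult is assembled. `fun` is forwarded to
# batchtools::reduceResultsList(); extra arguments are absorbed via `...`.
bmr = reduceResultsBatchmark(
  reg = reg,  # an existing, populated ExperimentRegistry
  fun = function(x, ...) {
    x$learner_state$model = NULL  # assumed result layout
    x
  }
)
```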
5 changes: 4 additions & 1 deletion R/worker.R
@@ -1,7 +1,10 @@
run_learner = function(job, data, learner_hash, store_models, ...) {
run_learner = function(job, data, learner_hash, param_values_hash, store_models, ...) {
workhorse = utils::getFromNamespace("workhorse", ns = asNamespace("mlr3"))
resampling = get(job$prob.pars$resampling_hash, envir = .GlobalEnv)
learner = get(learner_hash, envir = .GlobalEnv)
param_values = get(param_values_hash, envir = .GlobalEnv)

if (!is.null(param_values)) learner$param_set$set_values(.values = param_values)

workhorse(
iteration = job$repl,
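As a standalone sketch of the runtime step added above: the parameter configuration exported for the job is written into the learner's parameter set via `set_values()` before the mlr3 workhorse runs. The learner and values here are illustrative:

```r
library(mlr3)

learner = lrn("classif.rpart")
param_values = list(cp = 0.1, minsplit = 20)  # one exported configuration

# Same pattern as in run_learner(): apply the job's parameter values
# to the learner before resampling.
if (!is.null(param_values)) {
  learner$param_set$set_values(.values = param_values)
}

learner$param_set$values$cp  # 0.1
```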
3 changes: 1 addition & 2 deletions man/reduceResultsBatchmark.Rd


29 changes: 29 additions & 0 deletions tests/testthat/test_batchmark.R
@@ -97,3 +97,32 @@ test_that("marshaling", {
expect_true(bmr_marshaled$resample_result(1)$learners[[1]]$marshaled)
expect_false(bmr_unmarshaled$resample_result(1)$learners[[1]]$marshaled)
})

test_that("adding parameter values works", {
tasks = tsks(c("iris", "spam"))
resamplings = list(rsmp("cv", folds = 3)$instantiate(tasks[[1]]))
learners = lrns("classif.debug")

design = data.table(
task = tasks,
learner = learners,
resampling = resamplings,
param_values = list(list(list(x = 1), list(x = 0.5))))

reg = batchtools::makeExperimentRegistry(NA, make.default = FALSE)

ids = batchmark(design, reg = reg)
expect_data_table(ids, ncol = 1L, nrows = 12L)
ids = batchtools::submitJobs(reg = reg)
batchtools::waitForJobs(reg = reg)
expect_data_table(ids, nrows = 12)

logs = batchtools::getErrorMessages(reg = reg)
expect_data_table(logs, nrows = 0L)
results = reduceResultsBatchmark(reg = reg)
expect_is(results, "BenchmarkResult")
expect_benchmark_result(results)
expect_data_table(as.data.table(results), nrow = 12L)
expect_equal(results$learners$learner[[1]]$param_set$values$x, 1)
expect_equal(results$learners$learner[[2]]$param_set$values$x, 0.5)
})
