Skip to content

Commit

Permalink
align .config entries in tune_bayes() output (#718)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Sep 8, 2023
1 parent 9571f05 commit 74854a5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

* A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.

* Fixes bug where `.config` entries in the `.extracts` column in `tune_bayes()` output didn't align with the entries they ought to in the `.metrics` and `.predictions` columns (#715).

# tune 1.1.2

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).
Expand Down
6 changes: 6 additions & 0 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ tune_bayes_workflow <-
~ dplyr::mutate(., .config = paste0("Iter", i))
)
}
if (".extracts" %in% names(tmp_res)) {
tmp_res[[".extracts"]] <- purrr::map(
tmp_res[[".extracts"]],
~ dplyr::mutate(., .config = paste0("Iter", i))
)
}
unsummarized <- dplyr::bind_rows(unsummarized, tmp_res %>% mutate(.iter = i))
rs_estimate <- estimate_tune_results(tmp_res)
mean_stats <- dplyr::bind_rows(mean_stats, rs_estimate %>% dplyr::mutate(.iter = i))
Expand Down
Binary file modified tests/testthat/data/test_objects.RData
Binary file not shown.
19 changes: 19 additions & 0 deletions tests/testthat/test-extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,25 @@ test_that("check .config in extracts", {
expect_true(any(names(mt_spln_lm_bo$.extracts[[i]]) == ".config"))
}

recipe_only_configs <-
full_join(
mt_spln_lm_bo %>%
filter(id == first(id)) %>%
select(.iter, .metrics) %>%
unnest(cols = .metrics) %>%
filter(.metric == first(.metric)),
mt_spln_lm_bo %>%
filter(id == first(id)) %>%
select(.iter, .extracts) %>%
unnest(cols = .extracts),
by = c(".iter", "deg_free")
)

expect_equal(
recipe_only_configs$.config.x,
recipe_only_configs$.config.y
)

# recipe and model
for (i in 1:nrow(mt_spln_knn_grid)) {
expect_true(any(names(mt_spln_knn_grid$.extracts[[i]]) == ".config"))
Expand Down

0 comments on commit 74854a5

Please sign in to comment.