Skip to content

Commit

Permalink
use R6 class to track unmber of shuffling steps
Browse files Browse the repository at this point in the history
  • Loading branch information
idavydov committed Mar 13, 2024
1 parent 84f7e8b commit eff8b71
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions tests/testthat/test_optimize_design_simple_shuffle.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,83 +6,89 @@ samples <- data.frame(
sampleId = seq_len(10)
)

n_elements_changed <- function(bc) {
df <- bc$get_samples(include_id = TRUE, as_tibble = FALSE)
cur_state <- df$.sample_id
cur_state <- tidyr::replace_na(cur_state, -1)
n_changed <<- c(n_changed, sum(start_state != cur_state))
# never accept the change
Inf
}
# This is a helper class which provides a scoring function and conunts
# internally the number of elements shuffled at each iteration.
NElementsChanged <- R6::R6Class("NElementsChanged",
list(
scoring_f = function(bc) {
df <- bc$get_samples(include_id = TRUE, as_tibble = FALSE)
cur_state <- df$.sample_id
cur_state <- tidyr::replace_na(cur_state, -1)
self$n <- c(self$n, sum(start_state != cur_state))
# never accept the change
Inf
},
n = numeric(0)
)
)

bc <- assign_in_order(bc, samples)

scoring_f <- n_elements_changed
bc <- assign_in_order(bc, samples)

set.seed(42)

start_state <- ifelse(is.na(bc$assignment), -1, bc$assignment)

n_changed <- numeric(0)
n_changed <- NElementsChanged$new()
test_that("correct number of shuffles = 1", {
optimize_design(
bc,
scoring = scoring_f,
scoring = n_changed$scoring_f,
max_iter = 10,
check_score_variance = F,
autoscale_scores = F
)
expect_equal(n_changed, c(0, rep(2, 10)))
expect_equal(n_changed$n, c(0, rep(2, 10)))
})

n_changed <- numeric(0)
n_changed <- NElementsChanged$new()
test_that("correct number of shuffles = 2", {
optimize_design(
bc,
scoring = scoring_f,
scoring = n_changed$scoring_f,
max_iter = 10,
n_shuffle = 2,
check_score_variance = F,
autoscale_scores = F
)
expect_equal(n_changed, c(0, rep(4, 10)))
expect_equal(n_changed$n, c(0, rep(4, 10)))
})

n_changed <- numeric(0)
n_changed <- NElementsChanged$new()
test_that("correct number of shuffles = 5", {
optimize_design(
bc,
scoring = scoring_f,
scoring = n_changed$scoring_f,
max_iter = 10,
n_shuffle = 5,
check_score_variance = FALSE,
autoscale_scores = FALSE
)
expect_equal(n_changed, c(0, rep(10, 10)))
expect_equal(n_changed$n, c(0, rep(10, 10)))
})

n_changed <- numeric(0)
n_changed <- NElementsChanged$new()
test_that("specify too many shuffles", {
optimize_design(
bc,
scoring = scoring_f,
scoring = n_changed$scoring_f,
max_iter = 10,
n_shuffle = 40,
check_score_variance = FALSE,
autoscale_scores = FALSE
)
expect_equal(n_changed, c(0, rep(20, 10)))
expect_equal(n_changed$n, c(0, rep(20, 10)))
})

n_changed <- numeric(0)
n_changed <- NElementsChanged$new()
test_that("complex shuffling schedule", {
optimize_design(
bc,
scoring = scoring_f,
scoring = n_changed$scoring_f,
max_iter = 10,
n_shuffle = c(2, 2, 5, 2, 2, 10, 20, 40, 40),
check_score_variance = F,
autoscale_scores = F
)
expect_equal(n_changed, c(0, c(4, 4, 10, 4, 4, 20, 20, 20, 20)))
expect_equal(n_changed$n, c(0, c(4, 4, 10, 4, 4, 20, 20, 20, 20)))
})

0 comments on commit eff8b71

Please sign in to comment.