Skip to content

Commit

Permalink
Merge pull request #28 from MatsuuraKentaro/accelerate-compute-state
Browse files Browse the repository at this point in the history
accelerate compute_state()
  • Loading branch information
hoxo-m authored Nov 9, 2024
2 parents 60e646c + d373459 commit 7fda1cf
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,24 @@ sample_variance <- function(x) {
var(x) * (length(x) - 1L) / length(x)
}

sample_standard_deviation <- function(x) {
sample_std_dev <- function(x) {
sqrt(sample_variance(x))
}

# The state s is described in Section 2.3 of the original paper
compute_state <- function(actions, resps, N_total) {
# Check argument
count_per_action <- tapply(resps, actions, length)
stopifnot(length(actions) == length(resps))

resps_per_action <- split(resps, actions)
count_per_action <- vapply(resps_per_action, length, integer(1L), USE.NAMES = FALSE)
stopifnot("the number of allocated subjects at each dose should be >= 2" = count_per_action >= 2L)

mean_resps <- tapply(resps, actions, mean)
mean_resps <- vapply(resps_per_action, mean, double(1L), USE.NAMES = FALSE)
shifted_mean_resps <- mean_resps[-1L] - mean_resps[1L]
sd_resps <- tapply(resps, actions, sample_standard_deviation)
sd_resps <- vapply(resps_per_action, sample_std_dev, double(1L), USE.NAMES = FALSE)
proportion_per_action <- count_per_action / N_total
state <- as.array(unname(c(shifted_mean_resps, sd_resps, proportion_per_action)))

state <- as.array(c(shifted_mean_resps, sd_resps, proportion_per_action))
state
}

Expand Down

0 comments on commit 7fda1cf

Please sign in to comment.