Skip to content

Commit

Permalink
Merge pull request #27 from MatsuuraKentaro/extract-compute-state
Browse files Browse the repository at this point in the history
extract compute_state()
  • Loading branch information
hoxo-m authored Nov 9, 2024
2 parents 51d2c7f + 3475f3e commit 60e646c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
9 changes: 1 addition & 8 deletions R/allocation_rule.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ AllocationRule <- R6Class(
data_actions <- actions[as.character(data_doses)]

# Obtain the probabilities of next actions
count_per_action <- tapply(data_resps, data_actions, length)
# Check argument
stopifnot("the number of allocated subjects at each dose should be >= 2" = count_per_action >= 2L)
mean_resps <- tapply(data_resps, data_actions, mean)
shifted_mean_resps <- mean_resps[-1] - mean_resps[1]
sd_resps <- tapply(data_resps, data_actions, function(x) sd(x)*sqrt((length(x) - 1)/length(x)))
proportion_per_action <- count_per_action / N_total
state <- as.array(unname(c(shifted_mean_resps, sd_resps, proportion_per_action)))
state <- compute_state(data_actions, data_resps, N_total)
info <- policy$compute_single_action(state, full_fetch = TRUE)[[3L]]
action_probs <- info$action_dist_inputs # array
action_probs <- as.vector(action_probs) # cast to numeric vector
Expand Down
21 changes: 21 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,27 @@ compute_MAE <- function(estimated_response, true_response) {
return(mean(abs(errors)))
}

sample_variance <- function(x) {
var(x) * (length(x) - 1L) / length(x)
}

sample_standard_deviation <- function(x) {
sqrt(sample_variance(x))
}

compute_state <- function(actions, resps, N_total) {
# Check argument
count_per_action <- tapply(resps, actions, length)
stopifnot("the number of allocated subjects at each dose should be >= 2" = count_per_action >= 2L)

mean_resps <- tapply(resps, actions, mean)
shifted_mean_resps <- mean_resps[-1L] - mean_resps[1L]
sd_resps <- tapply(resps, actions, sample_standard_deviation)
proportion_per_action <- count_per_action / N_total
state <- as.array(unname(c(shifted_mean_resps, sd_resps, proportion_per_action)))
state
}

is_apple_silicon <- function() {
sys_info <- Sys.info()
sys_info["sysname"] == "Darwin" && sys_info["machine"] == "arm64"
Expand Down

0 comments on commit 60e646c

Please sign in to comment.