Skip to content

Commit

Permalink
Modify compute_reward_TD to precompute the lower and upper bounds for all models
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxo-m committed Dec 7, 2024
1 parent 0a25397 commit a809b88
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions R/generate_setup_code.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ generate_setup_code <- function(
true_responses <- unname(unlist(true_response_list))

# Utility functions
lower_all <- DoseFinding::TD(rl_models, Delta = Delta_lower, direction = direction)
upper_all <- DoseFinding::TD(rl_models, Delta = Delta_upper, direction = direction)

compute_reward_TD <- function(estimated_target_dose, true_model_name) {
# estimated_target_dose is possibly NA
if (is.na(estimated_target_dose)) return(0)

# Note: Calculating TD for all DR models is costly, but extracting
# a single model from 'rl_models' is troublesome.
# The overhead is about 100ms per execution.
lower <- DoseFinding::TD(rl_models, Delta = Delta_lower, direction = direction)
lower <- unname(lower[true_model_name])
upper <- DoseFinding::TD(rl_models, Delta = Delta_upper, direction = direction)
upper <- unname(upper[true_model_name])
lower <- lower_all[[true_model_name]]
upper <- upper_all[[true_model_name]]

# TD returns NA for large Delta, so replace an NA bound with Inf.
lower <- ifelse(is.na(lower), Inf, lower)
upper <- ifelse(is.na(upper), Inf, upper)
Expand Down

0 comments on commit a809b88

Please sign in to comment.