Skip to content

Commit

Permalink
Merge 4d186fc into bcf297c
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Jan 22, 2024
2 parents bcf297c + 4d186fc commit 99c3559
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 13 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
## Model changes

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in # and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk.

# EpiNow2 1.4.0

Expand Down
3 changes: 2 additions & 1 deletion R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ create_obs_model <- function(obs = obs_opts(), dates) {
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
Expand Down Expand Up @@ -481,7 +482,7 @@ create_stan_data <- function(reported_cases, seeding_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (data$seeding_time > 1) {
if (data$seeding_time > 1 && nrow(first_week) > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
data$prior_growth <- ifelse(is.null(data$prior_growth), 0,
Expand Down
1 change: 1 addition & 0 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ estimate_secondary <- function(reports,
data <- list(
t = nrow(reports),
obs = reports$secondary,
obs_time = seq_along(reports$secondary),
primary = reports$primary,
burn_in = burn_in,
seeding_time = 0
Expand Down
13 changes: 13 additions & 0 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,15 @@ gp_opts <- function(basis_prop = 0.2,
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#'
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
#' "missing", any NA values in the observation data set will be interpreted as
#' missing and skipped in the likelihood. If set to "accumulate", modelled
#' observations will be accumulated and added to the next non-NA data point.
#' This can be used to model incidence data that is reported at less than
#' daily intervals. If set to "accumulate", the first data point is not
#' included in the data point but used only to reset modelled observations to
#' zero.
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#' included in the model.
#'
Expand All @@ -471,18 +480,22 @@ obs_opts <- function(family = "negbin",
week_effect = TRUE,
week_length = 7,
scale = list(),
na = "missing",
likelihood = TRUE,
return_likelihood = FALSE) {
if (length(phi) != 2 || !is.numeric(phi)) {
stop("phi be numeric and of length two")
}
na <- arg_match(na, values = c("missing", "accumulate"))

obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
phi = phi,
weight = weight,
week_effect = week_effect,
week_length = week_length,
scale = scale,
accumulate = as.integer(na == "accumulate"),
likelihood = likelihood,
return_likelihood = return_likelihood
)
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
real obs_weight; // weight given to observation in log density
int likelihood; // Should the likelihood be included in the model
int return_likelihood; // Should the likehood be returned by the model
int accumulate; // Should missing values be accumulated
int<lower = 0> trunc_id; // id of truncation
int<lower = 0> delay_id; // id of delay
4 changes: 2 additions & 2 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
}
}
Expand Down
5 changes: 3 additions & 2 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ functions {
data {
int t; // time of observations
array[t] int<lower = 0> obs; // observed secondary data
array[t] int obs_time; // observed secondary data
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
#include data/secondary.stan
Expand Down Expand Up @@ -83,8 +84,8 @@ model {
}
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1);
report_lp(obs[(burn_in + 1):t], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate);
}
}

Expand Down
34 changes: 28 additions & 6 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,44 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
// update log density for reported cases
void report_lp(array[] int cases, vector reports,
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight) {
int model_type, real weight, int accumulate) {
int n = num_elements(cases) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int current_obs = 0;
obs_reports = rep_vector(0, n);
for (i in 1:t) {
if (current_obs > 0) { // first observation gets ignored when acucmulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs]) {
current_obs += 1;
}
}
obs_cases = cases[2:(n - 1)];
} else {
obs_reports = reports[cases_time];
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
if (weight == 1) {
cases ~ neg_binomial_2(reports, dispersion);
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
} else {
target += neg_binomial_2_lpmf(cases | reports, dispersion) * weight;
target += neg_binomial_2_lpmf(
obs_cases | obs_reports, dispersion
) * weight;
}
} else {
if (weight == 1) {
cases ~ poisson(reports);
obs_cases ~ poisson(obs_reports);
} else {
target += poisson_lpmf(cases | reports) * weight;
target += poisson_lpmf(obs_cases | obs_reports) * weight;
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions man/obs_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-create_obs_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ dates <- seq(as.Date("2020-03-15"), by = "days", length.out = 15)

test_that("create_obs_model works with default settings", {
obs <- create_obs_model(dates = dates)
expect_equal(length(obs), 11)
expect_equal(length(obs), 12)
expect_equal(names(obs), c(
"model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight",
"obs_scale", "likelihood", "return_likelihood",
"obs_scale", "accumulate", "likelihood", "return_likelihood",
"day_of_week", "obs_scale_mean",
"obs_scale_sd"
))
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ test_that("estimate_infections successfully returns estimates when passed NA val
test_estimate_infections(reported_cases_na)
})

test_that("estimate_infections successfully returns estimates when accumulating to weekly", {
skip_on_cran()
reported_cases_weekly <- data.table::copy(reported_cases)
reported_cases_weekly[, confirm := frollsum(confirm, 7)]
reported_cases_weekly <-
reported_cases_weekly[seq(7, nrow(reported_cases_weekly), 7)]
test_estimate_infections(reported_cases_weekly, obs = obs_opts(na = "accumulate"))
})

test_that("estimate_infections successfully returns estimates using no delays", {
skip_on_cran()
Expand Down

0 comments on commit 99c3559

Please sign in to comment.