diff --git a/.Rbuildignore b/.Rbuildignore index 9ad6ce1..8cad9b1 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -7,3 +7,5 @@ ^pkgdown$ ^codecov\.yml$ ^\.travis\.yml$ +^doc$ +^Meta$ diff --git a/.gitignore b/.gitignore index 343e9b9..4bca403 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ .Rproj.user inst/doc +doc +Meta diff --git a/DESCRIPTION b/DESCRIPTION index 0bceaa4..3b90b9e 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,12 +2,12 @@ Package: adjustr Encoding: UTF-8 Type: Package Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling -Version: 0.0.0.9000 +Version: 0.1.0 Authors@R: person("Cory", "McCartan", email = "cmccartan@g.harvard.edu", role = c("aut", "cre")) Description: Functions to help assess the sensitivity of a Bayesian model (fitted using the rstan pakcage) to the specification of its likelihood and - priors.Users provide a series of alternate sampling specifications, and the + priors. Users provide a series of alternate sampling specifications, and the package uses Pareto-smoothed importance sampling to estimate posterior quantities of interest under each specification. 
License: BSD_3_clause + file LICENSE @@ -15,15 +15,14 @@ Depends: R (>= 3.6.0) Imports: tibble, tidyselect, - dplyr, + dplyr (>= 1.0.0), purrr, + stringr, methods, utils, stats, rlang, rstan, - stringr, - dparser, ggplot2, loo Suggests: @@ -35,5 +34,5 @@ Suggests: rmarkdown URL: https://corymccartan.github.io/adjustr/ LazyData: true -RoxygenNote: 7.1.0 +RoxygenNote: 7.1.1 VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index f251831..19ec0c2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,7 +3,6 @@ S3method(arrange,adjustr_spec) S3method(as.data.frame,adjustr_spec) S3method(length,adjustr_spec) -S3method(plot,adjustr_weighted) S3method(print,adjustr_spec) S3method(pull,adjustr_weighted) S3method(rename,adjustr_spec) @@ -15,6 +14,7 @@ export(adjust_weights) export(extract_samp_stmts) export(get_resampling_idxs) export(make_spec) +export(spec_plot) import(dplyr) import(ggplot2) import(rlang) diff --git a/NEWS.md b/NEWS.md new file mode 100644 index 0000000..07159a5 --- /dev/null +++ b/NEWS.md @@ -0,0 +1,5 @@ +# adjustr 0.1.0 + +* Initial release. + +* Basic workflow implemented: `make_spec()`, `adjust_weights()`, and `summarize()`/`spec_plot()`. \ No newline at end of file diff --git a/R/adjust_weights.R b/R/adjust_weights.R index 5c264de..5e252d8 100644 --- a/R/adjust_weights.R +++ b/R/adjust_weights.R @@ -18,15 +18,19 @@ #' posterior, and which as a result cannot be reliably estimated using #' importance sampling (i.e., if the Pareto shape parameter is larger than #' 0.7), have their weights discarded. +#' @param incl_orig When \code{TRUE}, include a row for the original +#' model specification, with all weights equal. Can facilitate comaprison +#' and plotting later. #' #' @return A tibble, produced by converting the provided \code{specs} to a #' tibble (see \code{\link{as.data.frame.adjustr_spec}}), and adding columns #' \code{.weights}, containing vectors of weights for each draw, and #' \code{.pareto_k}, containing the diagnostic Pareto shape parameters. 
Values #' greater than 0.7 indicate that importance sampling is not reliable. -#' Weights can be extracted with the \code{\link{pull.adjustr_weighted}} -#' method. The returned object also includes the model sample draws, in the -#' \code{draws} attribute. +#' If \code{incl_orig} is \code{TRUE}, a row is added for the original model +#' specification. Weights can be extracted with the +#' \code{\link{pull.adjustr_weighted}} method. The returned object also +#' includes the model sample draws, in the \code{draws} attribute. #' #' @examples \dontrun{ #' model_data = list( @@ -46,27 +50,25 @@ #' } #' #' @export -adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) { +adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE, incl_orig=TRUE) { # CHECK ARGUMENTS object = get_fit_obj(object) model_code = object@stanmodel@model_code stopifnot(is.adjustr_spec(spec)) - parsed_model = parse_model(model_code) - parsed_vars = get_variables(parsed_model) - parsed_samp = get_sampling_stmts(parsed_model) + parsed = parse_model(model_code) # if no model data provided, we can only resample distributions of parameters if (is.null(data)) { - samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.))) - prior_vars = parsed_vars[samp_vars] != "data" - parsed_samp = parsed_samp[prior_vars] + samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.))) + prior_vars = parsed$vars[samp_vars] != "data" + parsed$samp = parsed$samp[prior_vars] data = list() } - matched_samp = match_sampling_stmts(spec$samp, parsed_samp) - original_lp = calc_original_lp(object, matched_samp, parsed_vars, data) - specs_lp = calc_specs_lp(object, spec$samp, parsed_vars, data, spec$params) + matched_samp = match_sampling_stmts(spec$samp, parsed$samp) + original_lp = calc_original_lp(object, matched_samp, parsed$vars, data) + specs_lp = calc_specs_lp(object, spec$samp, parsed$vars, data, spec$params) # compute weights wgts = map(specs_lp, function(spec_lp) { @@ -95,6 +97,14 @@ adjust_weights = 
function(spec, object, data=NULL, keep_bad=FALSE) { attr(adjust_obj, "draws") = rstan::extract(object) attr(adjust_obj, "data") = data attr(adjust_obj, "iter") = object@sim$chains * (object@sim$iter - object@sim$warmup) + if (incl_orig) { + adjust_obj = bind_rows(adjust_obj, tibble( + .weights=list(rep(1, attr(adjust_obj, "iter"))), + .pareto_k = -Inf)) + samp_cols = stringr::str_detect(names(adjust_obj), "\\.samp") + adjust_obj[nrow(adjust_obj), samp_cols] = "" + } + adjust_obj } @@ -141,19 +151,17 @@ pull.adjustr_weighted = function(.data, var=".weights") { extract_samp_stmts = function(object) { model_code = get_fit_obj(object)@stanmodel@model_code - parsed_model = parse_model(model_code) - parsed_vars = get_variables(parsed_model) - parsed_samp = get_sampling_stmts(parsed_model) + parsed = parse_model(model_code) - samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.))) + samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.))) type = map_chr(samp_vars, function(var) { - if (stringr::str_ends(parsed_vars[var], "data")) "data" else "parameter" + if (stringr::str_ends(parsed$vars[var], "data")) "data" else "parameter" }) print_order = order(type, samp_vars, decreasing=c(T, F)) cat(paste0("Sampling statements for model ", object@model_name, ":\n")) - purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed_samp[.])))) - invisible(parsed_samp) + purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed$samp[.])))) + invisible(parsed$samp) } # Check that the model object is correct, and extract its Stan code diff --git a/R/adjustr-package.R b/R/adjustr-package.R index df5e404..53b3880 100644 --- a/R/adjustr-package.R +++ b/R/adjustr-package.R @@ -34,9 +34,9 @@ pkg_env = new_environment() .onLoad = function(libname, pkgname) { # nocov start # create the Stan parser - tryCatch(get_parser(), error = function(e) {}) + #tryCatch(get_parser(), error = function(e) {}) utils::globalVariables(c("name", "pos", "value", 
".y", ".y_ol", ".y_ou", - ".y_il", ".y_iu", ".y_med")) + ".y_il", ".y_iu", ".y_med", "quantile", "median")) } # nocov end #> NULL \ No newline at end of file diff --git a/R/make_spec.R b/R/make_spec.R index 15eac13..f38f410 100644 --- a/R/make_spec.R +++ b/R/make_spec.R @@ -26,7 +26,7 @@ #' frame, each entry in each column will be substituted into the corresponding #' parameter in the sampling statements. #' -#' List arguments are coerced to data frame. They can either be lists of named +#' List arguments are coerced to data frames. They can either be lists of named #' vectors, or lists of lists of single-element named vector. #' #' The lengths of all parameter arguments must be consistent. Named vectors diff --git a/R/mockup.R b/R/mockup.R deleted file mode 100644 index 187eca3..0000000 --- a/R/mockup.R +++ /dev/null @@ -1,111 +0,0 @@ -if (F) { -library(rlang) -library(dplyr) -library(purrr) -library(stringr) -library(dparser) -library(rstan) -rstan_options(auto_write = TRUE) - -model_code = "data { - int J; // number of schools - real y[J]; // estimated treatment effects - real sigma[J]; // standard error of effect estimates -} -parameters { - real mu; // population treatment effect - real tau; // standard deviation in treatment effects - vector[J] eta; // unscaled deviation from mu by school -} -transformed parameters { - vector[J] theta = mu + tau * eta; // school treatment effects -} -model { - eta ~ std_normal(); - y ~ normal(theta, sigma); -}" - -model_d = list(J = 8, - y = c(28, 8, -3, 7, -1, 1, 18, 12), - sigma = c(15, 10, 16, 11, 9, 11, 10, 18)) -eightschools_m = stan(model_code=model_code, chains=2, data=model_d, warmup=500, - iter=510, save_dso=F, save_warmup=F) -eightschools_m@stanmodel@dso = new("cxxdso") -save(eightschools_m, file="tests/test_model.rda") -load("tests/test_model.rda") - -#slot(eightschools_m@stanmodel, "dso", F) = NULL -#draws = extract(eightschools_m) - - - -grammar = paste(readLines("R/stan.dpg"), collapse="\n") -parse_func = 
dparse(grammar, set_op_priority_from_rule=T, longest_match=T) - - -model_code = readr::read_file("~/Documents/Analyses/elections/president/stan/polls.stan") - - - -prog_sections = filter(d, name=="program", value != "") -sec_names = str_trim(str_extract(prog_sections$value, "^.+(?=\\{)")) -id_section = Vectorize(function(i) - sec_names[which.max(i <= c(prog_sections$i, Inf)) - 1]) -prog_vars_d = filter(d, name=="var_decl", pos==1) -prog_vars = id_section(prog_vars_d$i) -names(prog_vars) = prog_vars_d$value - -samp_stmts = d %>% - filter(name == "sampling_statement", pos == -2) %>% - pmap(function(value, ...) as.formula(value, env=global_env())) - -samp_lhs = map_chr(samp_stmts, ~ prog_vars[as.character(f_lhs(.))]) -get_rhs_deps = function(func) { - map_chr(call_args(f_rhs(func)), ~ prog_vars[as.character(.)]) -} -samp_rhs = map(samp_stmts, get_rhs_deps) - - - -make_dens = function(f) { - function(x) { - function(...) { - f(x, ..., log=T) - } - } -} -distrs = list( - normal = dnorm, - std_normal = dnorm, - student_t = dt, - exponential = dexp, - gamma = function(x, alpha, beta, ...) dgamma(x, shape=alpha, rate=beta, ...) 
-) -distr_env = new_environment(map(distrs, make_dens), parent=global_env()) - - -form = sigma_natl ~ gamma(alpha, 3/mean) -ref_form = samp_stmts[[5]] -xr = f_rhs(form) -xl = f_lhs(form) -draws = list(sigma_natl = rgamma(500, 2, 3/0.06)) - -combos = list(alpha = c(1, 2, 3), - mean = seq(0.01, 0.1, 0.01)) %>% - cross_df - -xr - -eval_tidy(xr, combos, distr_env) -eval_tidy(xl, combos, distr_env) -call_fn(xr, distr_env)(eval_tidy(xl, draws))(4, 3) - -call_fn(f_rhs(ref_form), distr_env)(eval_tidy(f_lhs(ref_form), draws)) - - - -specs = cross(list(df=1:4, mean=0)) -make_rs_weights(sm, eta ~ student_t(df, mean, 1), specs) - -} - diff --git a/R/parsing.R b/R/parsing.R index 13ecd36..850b074 100644 --- a/R/parsing.R +++ b/R/parsing.R @@ -1,53 +1,75 @@ -# build the parser and store it -get_parser = function() { # nocov start - if (env_has(pkg_env, "parse_func")) { - return(pkg_env$parse_func) - } else { - grammar_file = system.file("stan.g", package="adjustr") - grammar = paste(readLines(grammar_file), collapse="\n") - pkg_env$parse_func = dparser::dparse(grammar) - return(pkg_env$parse_func) - } -} # nocov end +# regexes +identifier = "[a-zA-Z][a-zA-Z0-9_]*" +re_stmt = paste0("(int|real|(?:unit_|row_)?vector|(?:positive_)?ordered|simplex", + "|(?:cov_|corr_)?matrix|cholesky_factor(?:_corr|_cov)?)(?:<.+>)?(?:\\[.+\\])?", + " (", identifier, ")(?:\\[.+\\])? ?=?") +#re_block = paste0(block_names, " ?\\{ ?(.+) ?\\} ?", block_names, "") +re_block = "((?:transformed )?data|(?:transformed )?parameters|model|generated quantities)" +re_samp = paste0("(", identifier, " ?~[^~{}]+)") +re_samp2 = paste0("target ?\\+= ?(", identifier, ")_lp[md]f\\((", + identifier, ")(?:| ?[|]? 
?(.+))\\)") + +# Extract variable name from variable declaration, or return NA if no declaration +get_variables = function(statement) { + matches = stringr::str_match(statement, re_stmt)[,3] + matches[!is.na(matches)] +} -# Parse Stan `model_code` into a data frame which represents the parsing tree +get_sampling = function(statement) { + samps = stringr::str_match(statement, re_samp)[,2] + samps2 = stringr::str_match(statement, re_samp2)#[,,3] + samps2_rearr = paste0(samps2[,3], " ~ ", samps2[,2], "(", coalesce(samps2[,4], ""), ")") + stmts = c(samps[!is.na(samps)], samps2_rearr[!is.na(samps2[,1])]) + map(stmts, ~ stats::as.formula(., env=empty_env())) +} + +# Parse Stan `model_code` into a list with two elements: `vars` named +# vector, with the names matching the model's variable names and the values +# representing the program blocks they are defined in; `samp` is a list of +# sampling statements (as formulas) parse_model = function(model_code) { - parser_output = utils::capture.output( - get_parser()(model_code, function(name, value, pos, depth) { - cat('"', name, '","', value, '",', pos, ',', depth, '\n', sep="") + clean_code = stringr::str_replace_all(model_code, "//.*", "") %>% + stringr::str_replace_all("/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/", "") %>% + stringr::str_replace_all("\\n", " ") %>% + stringr::str_replace_all("\\s\\s+", " ") + + block_names = stringr::str_extract_all(clean_code, re_block)[[1]] + if (length(block_names)==0) return(list(vars=character(0), samps=list())) + + block_locs = rbind(stringr::str_locate_all(clean_code, re_block)[[1]], + c(nchar(clean_code), NA)) + blocks = map(1:length(block_names), function(i) { + block = stringr::str_sub(clean_code, block_locs[i,2]+1, block_locs[i+1,1]) + start = stringr::str_locate_all(block, stringr::fixed("{"))[[1]][1,1] + 1 + end = utils::tail(stringr::str_locate_all(block, stringr::fixed("}"))[[1]][,1], 1) - 1 + stringr::str_trim(stringr::str_sub(block, start+1, end-1)) }) - ) - parser_csv = 
paste0("name,value,pos,depth\n", - paste(parser_output, collapse="\n")) - parsed = utils::read.csv(text=parser_csv, as.is=T) - parsed$i = 1:nrow(parsed) - parsed -} + names(blocks) = block_names -# Take a parsing tree and return a named vector, with the names matching the -# model's variable names and the values representing the program blocks they -# are defined in -get_variables = function(parsed_model) { - prog_sections = filter(parsed_model, stringr::str_starts(name, "program__"), pos==-2) - sec_names = stringr::str_extract(prog_sections$value, "^.+(?=\\{)") %>% - stringr::str_trim() - id_section = Vectorize(function(i) - sec_names[which.max(i <= c(prog_sections$i, Inf)) - 1]) + statements = map(blocks, ~ stringr::str_split(., "; ?", simplify=T)[1,]) - prog_vars_d = filter(parsed_model, name=="var_decl", pos==1) - prog_vars = id_section(prog_vars_d$i) - names(prog_vars) = prog_vars_d$value - prog_vars -} + vars = map(statements, get_variables) + vars = purrr::flatten_chr(purrr::imap(vars, function(name, block) { + block = rep(block, length(name)) + names(block) = name + block + })) + + + samps = map(statements, get_sampling) + names(samps) = NULL + samps = flatten(samps) -# Take a parsing tree and return a list of formulas, one for each samping statement -# in the model -get_sampling_stmts = function(parsed_model) { - parsed_model %>% - filter(name == "sampling_statement", pos == -2) %>% - mutate(value = stringr::str_replace(value, " \\.\\*", "*")) %>% - purrr::pmap(function(value, ...) 
stats::as.formula(value, env=empty_env())) + parameters = names(vars)[vars == "parameters"] + sampled_pars = map_chr(samps, ~ as.character(f_lhs(.))) + uniform_pars = setdiff(parameters, sampled_pars) + uniform_samp = paste0(uniform_pars, " ~ uniform(-1e100, 1e100)") + uniform_samp = map(uniform_samp, ~ stats::as.formula(., env=empty_env())) + + list(vars=vars, samp=c(samps, uniform_samp)) } + + # Take a list of provided sampling formulas and return a matching list of # sampling statements from a reference list match_sampling_stmts = function(new_samp, ref_samp) { @@ -56,7 +78,7 @@ match_sampling_stmts = function(new_samp, ref_samp) { indices = match(new_vars, ref_vars) # check that every prior was matched if (any(is.na(indices))) { - stop("No matching sampling statement found for prior ", + stop("No matching sampling statement found for ", new_samp[which.max(is.na(indices))], "\n Check sampling statements and ensure that model data ", "has been provided.") diff --git a/R/use_weights.R b/R/use_weights.R index b56e510..6aca2ed 100644 --- a/R/use_weights.R +++ b/R/use_weights.R @@ -57,6 +57,12 @@ get_resampling_idxs = function(x, frac=1, replace=T) { #' a value of \code{mean(theta)} will compute the posterior mean of #' \code{theta} for each alternative specification. #' +#' Also supported is the custom function \code{wasserstein}, which computes +#' the Wasserstein-p distance between the posterior distribution of the +#' provided expression under the new model and under the original model, with +#' \code{p=1} the default. Lower the \code{spacing} parameter from the +#' default of 0.005 to compute a finer (but slower) approximation. +#' #' The arguments in \code{...} are automatically quoted and evaluated in the #' context of \code{.data}. They support unquoting and splicing. 
#' @param .resampling Whether to compute summary statistics by first resampling @@ -80,6 +86,7 @@ get_resampling_idxs = function(x, frac=1, replace=T) { #' adjusted = adjust_weights(spec, eightschools_m) #' #' summarize(adjusted, mean(mu), var(mu)) +#' summarize(adjusted, wasserstein(mu, p=2)) #' summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data) #' summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95))) #' } @@ -105,8 +112,13 @@ summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data if (name == "") name = expr_name(args[[i]]) call = args[[i]] + if (!is_call(call)) { + stop("Expressions must summarize posterior draws; `", + expr_text(call), "` has a different value for each draw.\n", + " Try `mean(", expr_text(call), ")` or `sd(", expr_text(call), ")`.") + } if (!.resampling && exists(call_name(call), funs_env)) { - fun = call_fn(call, funs_env) + fun = funs_env[[call_name(call)]] } else { fun = function(x, ...) apply(x, 2, call_fn(call), ...) 
.resampling = T @@ -120,9 +132,12 @@ summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data if (length(dim(computed)) == 1) dim(computed) = c(dim(computed), 1) if (!.resampling) { - new_col = map(.data$.weights, ~ fun(computed, .)) + new_col = map(.data$.weights, function(w) { + exec(fun, computed, w, !!!map(call_args(call)[-1], eval_tidy)) + }) } else { - idxs = map(.data$.weights, ~ sample.int(iter, iter, replace=T, prob=.)) + n_idx = max(min(5*iter, 20e3), iter) + idxs = map(.data$.weights, ~ sample.int(iter, n_idx, replace=T, prob=.)) new_col = map(idxs, function(idx) { comp = as.array(computed[idx,]) if (length(dim(comp)) == 1) dim(comp) = c(dim(comp), 1) @@ -140,22 +155,60 @@ summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data #' @export summarize.adjustr_weighted = summarise.adjustr_weighted +# Weighted ECDF +weighted.ecdf = function(samp, wgt=rep(1, length(samp))) { + or = order(samp) + y = cumsum(wgt[or])/sum(wgt) + f = stats::stepfun(samp[or], c(0, y)) + class(f) = c("weighted.ecdf", "ecdf", class(f)) + attr(f, "call") = sys.call() + f +} +quantile.weighted.ecdf = function(f, q) { + x = environment(f)$x + y = environment(f)$y + purrr::map_dbl(q, function(q) { + if (q == 0) return(x[1]) + if (q == 1) return(utils::tail(x, 1)) + idx = which(y > q)[1] + if (idx == 1) return(x[1]) + stats::approx(y[idx-0:1], x[idx-0:1], q)$y + }) +} + +weighted.wasserstein = function(samp, wgt, p=1, spacing=0.005) { + f = weighted.ecdf(samp, wgt) + q = seq(0, 1, spacing) + W = mean(abs(stats::quantile(samp, q, names=F, type=4) - quantile.weighted.ecdf(f, q))^p) + if (W < .Machine$double.eps) 0 else W^(1/p) +} + # Weighted summary functions that work on arrays wtd_array_mean = function(arr, wgt) colSums(as.array(arr)*wgt) / sum(wgt) wtd_array_var = function(arr, wgt) wtd_array_mean((arr - wtd_array_mean(arr, wgt))^2, wgt) wtd_array_sd = function(arr, wgt) sqrt(wtd_array_var(arr, wgt)) +wtd_array_quantile = function(arr, 
wgt, probs=c(0.05, 0.25, 0.5, 0.75, 0.95)) { + apply(arr, 2, function(x) quantile.weighted.ecdf(weighted.ecdf(x, wgt), probs)) +} +wtd_array_median = function(arr, wgt) wtd_array_quantile(arr, wgt, 0.5) +wtd_array_wasserstein = function(arr, wgt, ...) { + apply(arr, 2, function(x) weighted.wasserstein(x, wgt, ...)) +} funs_env = new_environment(list( mean = wtd_array_mean, var = wtd_array_var, - sd = wtd_array_sd + sd = wtd_array_sd, + quantile = wtd_array_quantile, + median = wtd_array_median, + wasserstein = wtd_array_wasserstein )) #' Plot Posterior Quantities of Interest Under Alternative Model Specifications #' #' Uses weights computed in \code{\link{adjust_weights}} to plot posterior -#' quantities of interest versus +#' quantities of interest versus specification parameters #' #' @param x An \code{adjustr_weighted} object. #' @param by The x-axis variable, which is usually one of the specification @@ -179,13 +232,13 @@ funs_env = new_environment(list( #' df=1:10, scale=seq(2, 1, -1/9)) #' adjusted = adjust_weights(spec, eightschools_m) #' -#' plot(adjusted, df, theta[1]) -#' plot(adjusted, df, mu, only_mean=TRUE) -#' plot(adjusted, scale, tau) +#' spec_plot(adjusted, df, theta[1]) +#' spec_plot(adjusted, df, mu, only_mean=TRUE) +#' spec_plot(adjusted, scale, tau) #' } #' #' @export -plot.adjustr_weighted = function(x, by, post, only_mean=FALSE, ci_level=0.8, +spec_plot = function(x, by, post, only_mean=FALSE, ci_level=0.8, outer_level=0.95, ...) { if (!requireNamespace("ggplot2", quietly=TRUE)) { # nocov start stop("Package `ggplot2` must be installed to plot posterior quantities of interest.") @@ -193,13 +246,15 @@ plot.adjustr_weighted = function(x, by, post, only_mean=FALSE, ci_level=0.8, if (ci_level > outer_level) stop("`ci_level` should be less than `outer_level`") post = enexpr(post) + orig_row = filter(x, across(starts_with(".samp"), ~ . 
== "")) if (!only_mean) { outer = (1 - outer_level) / 2 inner = (1 - ci_level) / 2 q_probs = c(outer, inner, 0.5, 1-inner, 1-outer) - sum_arg = quo(stats::quantile(!!post, probs = !!q_probs)) + sum_arg = quo(quantile(!!post, probs = !!q_probs)) - summarise.adjustr_weighted(x, .y = !!sum_arg) %>% + filter(x, across(starts_with(".samp"), ~ . != "")) %>% + summarise.adjustr_weighted(.y = !!sum_arg) %>% rowwise() %>% mutate(.y_ol = .y[1], .y_il = .y[2], @@ -209,16 +264,25 @@ plot.adjustr_weighted = function(x, by, post, only_mean=FALSE, ci_level=0.8, ggplot(aes({{ by }}, .y_med)) + geom_ribbon(aes(ymin=.y_ol, ymax=.y_ou), alpha=0.4) + geom_ribbon(aes(ymin=.y_il, ymax=.y_iu), alpha=0.5) + + { if (nrow(orig_row) == 1) + geom_hline(yintercept=summarise.adjustr_weighted(orig_row, .y = median(!!post))$`.y`, + lty="dashed") + } + geom_line() + geom_point(size=3) + theme_minimal() + - labs(y= expr_name(post)) + labs(y = expr_name(post)) } else { - summarise.adjustr_weighted(x, .y = mean(!!post)) %>% + filter(x, across(starts_with(".samp"), ~ . != "")) %>% + summarise.adjustr_weighted(.y = mean(!!post)) %>% ggplot(aes({{ by }}, .y)) + + { if (nrow(orig_row) == 1) + geom_hline(yintercept=summarise.adjustr_weighted(orig_row, .y = mean(!!post))$`.y`, + lty="dashed") + } + geom_line() + geom_point(size=3) + theme_minimal() + labs(y = expr_name(post)) } -} +} \ No newline at end of file diff --git a/README.md b/README.md index f634fba..a4f96a9 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,21 @@ [![Codecov test coverage](https://codecov.io/gh/CoryMcCartan/adjustr/branch/master/graph/badge.svg)](https://codecov.io/gh/CoryMcCartan/adjustr?branch=master) -**adjustr** is an R package which provides functions to help assess the -sensitivity of a Bayesian model (fitted with [Stan](https://mc-stan.org)) to the -specification of its likelihood and priors. 
Users provide a series of alternate -sampling specifications, and the package uses Pareto-smoothed importance -sampling to estimate the posterior under each specification. The package also -provides functions to summarize and plot how posterior quantities quantities -change across specifications. - -The package aims to provide simple interface that makes it as easy as possible +Sensitivity analysis is a critical component of a good modeling workflow. Yet +as the number and power of Bayesian computational tools has increased, the +options for sensitivity analysis have remained largely the same: compute +importance sampling weights manually, or fit a large number of similar models, +dramatically increasing computation time. Neither option is satisfactory for +most applied modeling. + +**adjustr** is an R package which aims to make sensitivity analysis faster +and easier, and works with Bayesian models fitted with [Stan](https://mc-stan.org). +Users provide a series of alternate sampling specifications, and the package +uses Pareto-smoothed importance sampling to estimate the posterior under each +specification. The package also provides functions to summarize and plot how +posterior quantities quantities change across specifications. + +The package provides simple interface that makes it as easy as possible for modellers to try out various adjustments to their Stan models, without needing to write any specific Stan code or even recompile or rerun their model. @@ -25,7 +31,21 @@ complex model templates, and cannot be used. ## Getting Started -The tutorial [vignettes](https://corymccartan.github.io/adjustr/articles/index.html) +The basic __adjustr__ workflow is as follows: + +1. Use [`make_spec`](https://corymccartan.github.io/adjustr/reference/make_spec.html) +to specify the set of alternative model specifications you'd like to fit. + +2. 
Use [`adjust_weights`](https://corymccartan.github.io/adjustr/reference/adjust_weights.html) +to calculate importance sampling weights which approximate the posterior of each +alternative specification. + +3. Use [`summarize`](https://corymccartan.github.io/adjustr/reference/summarize.adjustr_weighted.html) +and [`spec_plot`](https://corymccartan.github.io/adjustr/reference/spec_plot.html) +to examine posterior quantities of interest for each alternative specification, +in order to assess the sensitivity of the underlying model. + +The tutorial [vignette](https://corymccartan.github.io/adjustr/articles/eight-schools.html) walk through a full sensitivity analysis for the classic 8-schools example. Smaller examples are also included in the package [documentation](https://corymccartan.github.io/adjustr/reference/index.html). diff --git a/_pkgdown.yml b/_pkgdown.yml index 370d297..401d2f5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -73,7 +73,7 @@ reference: - make_spec - adjust_weights - summarize.adjustr_weighted - - plot.adjustr_weighted + - spec_plot - title: "Helper Functions" desc: > Various helper functions for examining a model or building sampling diff --git a/docs/404.html b/docs/404.html index dcbf5b1..5daad1a 100644 --- a/docs/404.html +++ b/docs/404.html @@ -81,7 +81,7 @@ adjustr - 0.0.0.9000 + 0.1.0 diff --git a/docs/LICENSE-text.html b/docs/LICENSE-text.html index 8bd331e..e9221b2 100644 --- a/docs/LICENSE-text.html +++ b/docs/LICENSE-text.html @@ -81,7 +81,7 @@ adjustr - 0.0.0.9000 + 0.1.0 diff --git a/docs/articles/eight-schools.html b/docs/articles/eight-schools.html index 9983dcf..ad71d6d 100644 --- a/docs/articles/eight-schools.html +++ b/docs/articles/eight-schools.html @@ -37,7 +37,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -115,7 +115,7 @@ -
+
+ +
diff --git a/docs/articles/eight-schools_files/accessible-code-block-0.0.1/empty-anchor.js b/docs/articles/eight-schools_files/accessible-code-block-0.0.1/empty-anchor.js new file mode 100644 index 0000000..ca349fd --- /dev/null +++ b/docs/articles/eight-schools_files/accessible-code-block-0.0.1/empty-anchor.js @@ -0,0 +1,15 @@ +// Hide empty tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) --> +// v0.0.1 +// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020. + +document.addEventListener('DOMContentLoaded', function() { + const codeList = document.getElementsByClassName("sourceCode"); + for (var i = 0; i < codeList.length; i++) { + var linkList = codeList[i].getElementsByTagName('a'); + for (var j = 0; j < linkList.length; j++) { + if (linkList[j].innerHTML === "") { + linkList[j].setAttribute('aria-hidden', 'true'); + } + } + } +}); diff --git a/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-1.png b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-1.png new file mode 100644 index 0000000..1e1d3bb Binary files /dev/null and b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-1.png differ diff --git a/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-2.png b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-2.png new file mode 100644 index 0000000..0c4241a Binary files /dev/null and b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-12-2.png differ diff --git a/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-1.png b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-1.png new file mode 100644 index 0000000..f9626cf Binary files /dev/null and b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-1.png differ diff --git a/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-2.png 
b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-2.png new file mode 100644 index 0000000..3bd9a16 Binary files /dev/null and b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-13-2.png differ diff --git a/docs/articles/eight-schools_files/figure-html/unnamed-chunk-3-1.png b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-3-1.png new file mode 100644 index 0000000..3373195 Binary files /dev/null and b/docs/articles/eight-schools_files/figure-html/unnamed-chunk-3-1.png differ diff --git a/docs/articles/index.html b/docs/articles/index.html index 67c4192..5f69efc 100644 --- a/docs/articles/index.html +++ b/docs/articles/index.html @@ -81,7 +81,7 @@ adjustr - 0.0.0.9000 + 0.1.0
diff --git a/docs/authors.html b/docs/authors.html index a965615..f0a36a9 100644 --- a/docs/authors.html +++ b/docs/authors.html @@ -81,7 +81,7 @@ adjustr - 0.0.0.9000 + 0.1.0 diff --git a/docs/index.html b/docs/index.html index c1d1aec..fa76fa9 100644 --- a/docs/index.html +++ b/docs/index.html @@ -38,7 +38,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -123,13 +123,20 @@ adjustr -

adjustr is an R package which provides functions to help assess the sensitivity of a Bayesian model (fitted with Stan) to the specification of its likelihood and priors. Users provide a series of alternate sampling specifications, and the package uses Pareto-smoothed importance sampling to estimate the posterior under each specification. The package also provides functions to summarize and plot how posterior quantities quantities change across specifications.

-

The package aims to provide simple interface that makes it as easy as possible for modellers to try out various adjustments to their Stan models, without needing to write any specific Stan code or even recompile or rerun their model.

+

Sensitivity analysis is a critical component of a good modeling workflow. Yet as the number and power of Bayesian computational tools has increased, the options for sensitivity analysis have remained largely the same: compute importance sampling weights manually, or fit a large number of similar models, dramatically increasing computation time. Neither option is satisfactory for most applied modeling.

+

adjustr is an R package which aims to make sensitivity analysis faster and easier, and works with Bayesian models fitted with Stan. Users provide a series of alternate sampling specifications, and the package uses Pareto-smoothed importance sampling to estimate the posterior under each specification. The package also provides functions to summarize and plot how posterior quantities change across specifications.

+

The package provides a simple interface that makes it as easy as possible for modellers to try out various adjustments to their Stan models, without needing to write any specific Stan code or even recompile or rerun their model.

The package works by parsing Stan model code, so everything works best if the model was written by the user. Models made using brms may in principle be used as well. Models made using rstanarm are constructed using more complex model templates, and cannot be used.

Getting Started

-

The tutorial vignettes walk through a full sensitivity analysis for the classic 8-schools example. Smaller examples are also included in the package documentation.

+

The basic adjustr workflow is as follows:

+
    +
  1. Use make_spec to specify the set of alternative model specifications you’d like to fit.

  2. +
  3. Use adjust_weights to calculate importance sampling weights which approximate the posterior of each alternative specification.

  4. +
  5. Use summarize and spec_plot to examine posterior quantities of interest for each alternative specification, in order to assess the sensitivity of the underlying model.

  6. +
+

The tutorial vignette walks through a full sensitivity analysis for the classic 8-schools example. Smaller examples are also included in the package documentation.

diff --git a/docs/news/index.html b/docs/news/index.html new file mode 100644 index 0000000..c510fcd --- /dev/null +++ b/docs/news/index.html @@ -0,0 +1,206 @@ + + + + + + + + +Changelog • adjustr + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+ + + + +
+ +
+
+ + +
+

+adjustr 0.1.0

+ +
+
+ + + +
+ + + +
+ + + + + + + + diff --git a/docs/pkgdown.yml b/docs/pkgdown.yml index 56a868f..f0640b6 100644 --- a/docs/pkgdown.yml +++ b/docs/pkgdown.yml @@ -1,9 +1,9 @@ -pandoc: 2.3.1 +pandoc: 2.7.3 pkgdown: 1.5.1 pkgdown_sha: ~ articles: eight-schools: eight-schools.html -last_built: 2020-04-24T04:26Z +last_built: 2020-08-03T23:18Z urls: reference: https://corymccartan.github.io/adjustr//reference article: https://corymccartan.github.io/adjustr//articles diff --git a/docs/reference/adjust_weights.html b/docs/reference/adjust_weights.html index 0f1213e..8798aa0 100644 --- a/docs/reference/adjust_weights.html +++ b/docs/reference/adjust_weights.html @@ -86,7 +86,7 @@ adjustr - 0.0.0.9000 + 0.1.0

@@ -179,7 +179,7 @@

Compute Pareto-smoothed Importance Weights for Alternative Model them to the specification object, for further calculation and plotting.

-
adjust_weights(spec, object, data = NULL, keep_bad = FALSE)
+
adjust_weights(spec, object, data = NULL, keep_bad = FALSE, incl_orig = TRUE)

Arguments

@@ -193,7 +193,7 @@

Arg

- + @@ -208,6 +208,12 @@

Arg posterior, and which as a result cannot be reliably estimated using importance sampling (i.e., if the Pareto shape parameter is larger than 0.7), have their weights discarded.

+

+ + +
object

A stanfit model object.

A stanfit model object.

data
incl_orig

When TRUE, include a row for the original +model specification, with all weights equal. Can facilitate comparison +and plotting later.

@@ -218,9 +224,10 @@

Value

.weights, containing vectors of weights for each draw, and .pareto_k, containing the diagnostic Pareto shape parameters. Values greater than 0.7 indicate that importance sampling is not reliable. - Weights can be extracted with the pull.adjustr_weighted - method. The returned object also includes the model sample draws, in the - draws attribute.

+ If incl_orig is TRUE, a row is added for the original model + specification. Weights can be extracted with the + pull.adjustr_weighted method. The returned object also + includes the model sample draws, in the draws attribute.

Examples

if (FALSE) { diff --git a/docs/reference/adjustr.html b/docs/reference/adjustr.html index 1213c2a..e20908e 100644 --- a/docs/reference/adjustr.html +++ b/docs/reference/adjustr.html @@ -88,7 +88,7 @@ adjustr - 0.0.0.9000 + 0.1.0
@@ -197,7 +197,7 @@

make_spec

  • adjust_weights

  • summarize

  • -
  • plot

  • +
  • plot

  • diff --git a/docs/reference/as.data.frame.adjustr_spec.html b/docs/reference/as.data.frame.adjustr_spec.html index 9c8df20..6d8c9d3 100644 --- a/docs/reference/as.data.frame.adjustr_spec.html +++ b/docs/reference/as.data.frame.adjustr_spec.html @@ -84,7 +84,7 @@ adjustr - 0.0.0.9000 + 0.1.0 diff --git a/docs/reference/dplyr.adjustr_spec.html b/docs/reference/dplyr.adjustr_spec.html index 9bc1b44..5970dfb 100644 --- a/docs/reference/dplyr.adjustr_spec.html +++ b/docs/reference/dplyr.adjustr_spec.html @@ -86,7 +86,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -176,7 +176,7 @@

    dplyr Methods for adjustr_spec Objects

    Core dplyr verbs which don't involve grouping (filter, arrange, mutate, select, - rename, and slice) are + rename, and slice) are implemented and operate on the underlying table of specification parameters.

    diff --git a/docs/reference/extract_samp_stmts.html b/docs/reference/extract_samp_stmts.html index b45300a..8e6a15e 100644 --- a/docs/reference/extract_samp_stmts.html +++ b/docs/reference/extract_samp_stmts.html @@ -84,7 +84,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -183,7 +183,7 @@

    Arg object -

    A stanfit model object.

    +

    A stanfit model object.

    diff --git a/docs/reference/get_resampling_idxs.html b/docs/reference/get_resampling_idxs.html index bc5ef02..bc7229e 100644 --- a/docs/reference/get_resampling_idxs.html +++ b/docs/reference/get_resampling_idxs.html @@ -83,7 +83,7 @@ adjustr - 0.0.0.9000 + 0.1.0 diff --git a/docs/reference/index.html b/docs/reference/index.html index 2a2a3fc..607ad7a 100644 --- a/docs/reference/index.html +++ b/docs/reference/index.html @@ -81,7 +81,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -207,7 +207,7 @@

    plot(<adjustr_weighted>)

    +

    spec_plot()

    Plot Posterior Quantities of Interest Under Alternative Model Specifications

    diff --git a/docs/reference/make_spec.html b/docs/reference/make_spec.html index ae0b96c..e980a24 100644 --- a/docs/reference/make_spec.html +++ b/docs/reference/make_spec.html @@ -84,7 +84,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -202,7 +202,7 @@

    Arg into the corresponding parameter in the sampling statements. For data frame, each entry in each column will be substituted into the corresponding parameter in the sampling statements.

    -

    List arguments are coerced to data frame. They can either be lists of named +

    List arguments are coerced to data frames. They can either be lists of named vectors, or lists of lists of single-element named vector.

    The lengths of all parameter arguments must be consistent. Named vectors can have length 1 or must have length equal to the number of rows in all @@ -218,16 +218,16 @@

    Value

    dplyr verbs which don't involve grouping (filter, arrange, mutate, select, -rename, and slice) are +rename, and slice) are supported and operate on the underlying table of specification parameters.

    Examples

    make_spec(eta ~ cauchy(0, 1))
    #> Sampling specifications: #> eta ~ cauchy(0, 1) -#> <environment: 0x7fc7dbbb1418>
    +#> <environment: 0x7f97fff72470>
    make_spec(eta ~ student_t(df, 0, 1), df=1:10)
    #> Sampling specifications: #> eta ~ student_t(df, 0, 1) -#> <environment: 0x7fc7dbbb1418> +#> <environment: 0x7f97fff72470> #> #> Specification parameters: #> df @@ -246,9 +246,9 @@

    Examp y ~ normal(theta, infl*sigma), params)

    #> Sampling specifications: #> eta ~ student_t(df, 0, 1) -#> <environment: 0x7fc7dbbb1418> +#> <environment: 0x7f97fff72470> #> y ~ normal(theta, infl * sigma) -#> <environment: 0x7fc7dbbb1418> +#> <environment: 0x7f97fff72470> #> #> Specification parameters: #> df infl diff --git a/docs/reference/plot.adjustr_weighted.html b/docs/reference/plot.adjustr_weighted.html index 4323657..01fa1de 100644 --- a/docs/reference/plot.adjustr_weighted.html +++ b/docs/reference/plot.adjustr_weighted.html @@ -175,7 +175,7 @@

    Plot Posterior Quantities of Interest Under Alternative Model Specifications

    # S3 method for adjustr_weighted
    -plot(x, by, post, only_mean = FALSE, ci_level = 0.8, outer_level = 0.95, ...)
    +plot(x, by, post, only_mean = FALSE, ci_level = 0.8, outer_level = 0.95, ...)

    Arguments

    @@ -227,9 +227,9 @@

    Examp df=1:10, scale=seq(2, 1, -1/9)) adjusted = adjust_weights(spec, eightschools_m) -plot(adjusted, df, theta[1]) -plot(adjusted, df, mu, only_mean=TRUE) -plot(adjusted, scale, tau) +plot(adjusted, df, theta[1]) +plot(adjusted, df, mu, only_mean=TRUE) +plot(adjusted, scale, tau) } diff --git a/docs/reference/spec_plot.html b/docs/reference/spec_plot.html new file mode 100644 index 0000000..43d0a50 --- /dev/null +++ b/docs/reference/spec_plot.html @@ -0,0 +1,268 @@ + + + + + + + + +Plot Posterior Quantities of Interest Under Alternative Model Specifications — spec_plot • adjustr + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    +
    + + + + +
    + +
    +
    + + +
    +

    Uses weights computed in adjust_weights to plot posterior +quantities of interest versus specification parameters

    +
    + +
    spec_plot(
    +  x,
    +  by,
    +  post,
    +  only_mean = FALSE,
    +  ci_level = 0.8,
    +  outer_level = 0.95,
    +  ...
    +)
    + +

    Arguments

    +

    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    x

    An adjustr_weighted object.

    by

    The x-axis variable, which is usually one of the specification +parameters. Can be set to 1 if there is only one specification. +Automatically quoted and evaluated in the context of x.

    post

    The posterior quantity of interest, to be computed for each +resampled draw of each specification. Should evaluate to a single number +for each draw. Automatically quoted and evaluated in the context of x.

    only_mean

    Whether to only plot the posterior mean. May be more stable.

    ci_level

    The inner credible interval to plot. Central +100*ci_level +posterior draws.

    outer_level

    The outer credible interval to plot.

    ...

    Ignored.

    + +

    Value

    + +

    A ggplot object which can be further + customized with the ggplot2 package.

    + +

    Examples

    +
    if (FALSE) { +spec = make_spec(eta ~ student_t(df, 0, scale), + df=1:10, scale=seq(2, 1, -1/9)) +adjusted = adjust_weights(spec, eightschools_m) + +spec_plot(adjusted, df, theta[1]) +spec_plot(adjusted, df, mu, only_mean=TRUE) +spec_plot(adjusted, scale, tau) +}
    + + + + + + + + + + + + + + + diff --git a/docs/reference/summarize.adjustr_weighted.html b/docs/reference/summarize.adjustr_weighted.html index 57d3a21..8510309 100644 --- a/docs/reference/summarize.adjustr_weighted.html +++ b/docs/reference/summarize.adjustr_weighted.html @@ -85,7 +85,7 @@ adjustr - 0.0.0.9000 + 0.1.0 @@ -198,6 +198,11 @@

    Arg posterior distribution of each alternative specification. For example, a value of mean(theta) will compute the posterior mean of theta for each alternative specification.

    +

    Also supported is the custom function wasserstein, which computes + the Wasserstein-p distance between the posterior distribution of the + provided expression under the new model and under the original model, with + p=1 the default. Lower the spacing parameter from the + default of 0.005 to compute a finer (but slower) approximation.

    The arguments in ... are automatically quoted and evaluated in the context of .data. They support unquoting and splicing.

    @@ -232,6 +237,7 @@

    Examp adjusted = adjust_weights(spec, eightschools_m) summarize(adjusted, mean(mu), var(mu)) +summarize(adjusted, wasserstein(mu, p=2)) summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data) summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95))) }

    diff --git a/docs/sitemap.xml b/docs/sitemap.xml index c788329..4bd8fb2 100644 --- a/docs/sitemap.xml +++ b/docs/sitemap.xml @@ -25,10 +25,10 @@ https://corymccartan.github.io/adjustr//reference/make_spec.html - https://corymccartan.github.io/adjustr//reference/plot.adjustr_weighted.html + https://corymccartan.github.io/adjustr//reference/pull.adjustr_weighted.html - https://corymccartan.github.io/adjustr//reference/pull.adjustr_weighted.html + https://corymccartan.github.io/adjustr//reference/spec_plot.html https://corymccartan.github.io/adjustr//reference/summarize.adjustr_weighted.html diff --git a/inst/stan.g b/inst/stan.g deleted file mode 100644 index 24ea6bc..0000000 --- a/inst/stan.g +++ /dev/null @@ -1,152 +0,0 @@ -program: functions? data? tdata? params? tparams? model? generated?; - -functions: 'functions' function_decls; -data: 'data' var_decls; -tdata: 'transformed data' var_decls_statements; -params: 'parameters' var_decls; -tparams: 'transformed parameters' var_decls_statements; -model: 'model' var_decls_statements; -generated: 'generated quantities' var_decls_statements; -function_decls: '{' function_decl* '}'; -var_decls: '{' var_decl* '}'; -var_decls_statements: '{' var_decl* statement* '}'; - - -function_decl: return_type identifier '(' (parameter_decl (',' parameter_decl)*)? ')' - statement; - -return_type: 'void' | unsized_type; -parameter_decl: 'data'? unsized_type identifier; -unsized_type: basic_type unsized_dims?; -basic_type: 'int' | 'real' | 'vector' | 'row_vector' | 'matrix'; -unsized_dims: '[' ','* ']'; - - - - -var_decl: var_type variable dims? ('=' expression)? 
';'; - -var_type: 'int' range_constraint - | 'real' constraint - | 'vector' constraint '[' expression ']' - | 'ordered' '[' expression ']' - | 'positive_ordered' '[' expression ']' - | 'simplex' '[' expression ']' - | 'unit_vector' '[' expression ']' - | 'row_vector' constraint '[' expression ']' - | 'matrix' constraint '[' expression ',' expression ']' - | 'cholesky_factor_corr' '[' expression ']' - | 'cholesky_factor_cov' '[' expression (',' expression)? ']' - | 'corr_matrix' '[' expression ']' - | 'cov_matrix' '[' expression ']'; - -constraint: range_constraint | ('<' offset_multiplier '>'); - -range_constraint: ('<' range '>')?; - -range: 'lower' '=' constr_expression ',' 'upper' '=' constr_expression - | 'lower' '=' constr_expression - | 'upper' '=' constr_expression; - - -offset_multiplier: 'offset' '=' constr_expression ',' - 'multiplier' '=' constr_expression - | 'offset' '=' constr_expression - | 'multiplier' '=' constr_expression; - -dims: '[' expressions ']'; - -variable: identifier; - -identifier: "[a-zA-Z][a-zA-Z0-9_]*"; - - - -expressions: (expression (',' expression)*)?; - -expression: expression '?' expression ':' expression - | expression infixOp expression - | prefixOp expression - | expression postfixOp - | expression '[' indexes ']' - | common_expression; - -constr_expression: constr_expression arithmeticInfixOp constr_expression - | prefixOp constr_expression - | constr_expression postfixOp - | constr_expression '[' indexes ']' - | common_expression; - -common_expression : real_literal - | variable - | '{' expressions '}' - | '[' expressions ']' - | function_literal '(' expressions? ')' - | function_literal '(' expression ('|' (expression (',' expression)*)?)? 
')' - | 'integrate_1d' '(' function_literal (',' expression)@5:6 ')' - | 'integrate_ode' '(' function_literal (',' expression)@6 ')' - | 'integrate_ode_rk45' '(' function_literal (',' expression)@6:9 ')' - | 'integrate_ode_bdf' '(' function_literal (',' expression)@6:9 ')' - | 'algebra_solver' '(' function_literal (',' expression)@4:7 ')' - | 'map_rect' '(' function_literal (',' expression)@4 ')' - | '(' expression ')'; - -prefixOp: ('!' | '-' | '+' | '^'); - -postfixOp: '\''; - -infixOp: arithmeticInfixOp | logicalInfixOp; - -arithmeticInfixOp: ('+' | '-' | '*' | '/' | '%' | '\\' | '.*' | './'); - -logicalInfixOp: ('||' | '&&' | '==' | '!=' | '<' | '<=' | '>' | '>='); - -index: (expression | expression ':' | ':' expression - | expression ':' expression)?; - -indexes: (index (',' index)*)?; - -integer_literal: "[0-9]+"; - -real_literal: integer_literal '.' "[0-9]*" exp_literal? - | '.' "[0-9]+" exp_literal? - | integer_literal exp_literal?; - -exp_literal: ('e' | 'E') ('+' | '-')? integer_literal; - -function_literal: identifier; - - - -statement: atomic_statement | nested_statement; - -atomic_statement: lhs assignment_op expression ';' - | sampling_statement - | function_literal '(' expressions ')' ';' - | 'increment_log_prob' '(' expression ')' ';' - | 'target' '+=' expression ';' - | 'break' ';' - | 'continue' ';' - | 'print' '(' ((expression | string_literal) (',' (expression | string_literal))*)? ')' ';' - | 'reject' '(' ((expression | string_literal) (',' (expression | string_literal))*)? ')' ';' - | 'return' expression ';' - | ';'; - -sampling_statement: expression '~' identifier '(' expressions ')' truncation? 
';'; - -assignment_op: '<-' | '=' | '+=' | '-=' | '*=' | '/=' | '.*=' | './='; - -//string_literal: '"' char* '"'; -string_literal: "\"([^\"\\]|\\[^])*\""; - -truncation: 'T' '[' ?expression ',' ?expression ']'; - -lhs: identifier ('[' indexes ']')*; - -nested_statement: 'if' '(' expression ')' statement - ('else' 'if' '(' expression ')' statement)* - ('else' statement)? - | 'while' '(' expression ')' statement - | 'for' '(' identifier 'in' expression ':' expression ')' statement - | 'for' '(' identifier 'in' expression ')' statement - | '{' var_decl* statement+ '}'; diff --git a/man/adjust_weights.Rd b/man/adjust_weights.Rd index 1d0800f..5c7ccd7 100644 --- a/man/adjust_weights.Rd +++ b/man/adjust_weights.Rd @@ -5,7 +5,7 @@ \title{Compute Pareto-smoothed Importance Weights for Alternative Model Specifications} \usage{ -adjust_weights(spec, object, data = NULL, keep_bad = FALSE) +adjust_weights(spec, object, data = NULL, keep_bad = FALSE, incl_orig = TRUE) } \arguments{ \item{spec}{An object of class \code{adjustr_spec}, probably produced by @@ -24,6 +24,10 @@ alternate specifications which deviate too much from the original posterior, and which as a result cannot be reliably estimated using importance sampling (i.e., if the Pareto shape parameter is larger than 0.7), have their weights discarded.} + +\item{incl_orig}{When \code{TRUE}, include a row for the original +model specification, with all weights equal. Can facilitate comaprison +and plotting later.} } \value{ A tibble, produced by converting the provided \code{specs} to a @@ -31,9 +35,10 @@ A tibble, produced by converting the provided \code{specs} to a \code{.weights}, containing vectors of weights for each draw, and \code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values greater than 0.7 indicate that importance sampling is not reliable. - Weights can be extracted with the \code{\link{pull.adjustr_weighted}} - method. 
The returned object also includes the model sample draws, in the - \code{draws} attribute. + If \code{incl_orig} is \code{TRUE}, a row is added for the original model + specification. Weights can be extracted with the + \code{\link{pull.adjustr_weighted}} method. The returned object also + includes the model sample draws, in the \code{draws} attribute. } \description{ Given a set of new sampling statements, which can be parametrized by a diff --git a/man/make_spec.Rd b/man/make_spec.Rd index fa662fd..da2fce0 100644 --- a/man/make_spec.Rd +++ b/man/make_spec.Rd @@ -29,7 +29,7 @@ make_spec(...) frame, each entry in each column will be substituted into the corresponding parameter in the sampling statements. - List arguments are coerced to data frame. They can either be lists of named + List arguments are coerced to data frames. They can either be lists of named vectors, or lists of lists of single-element named vector. The lengths of all parameter arguments must be consistent. Named vectors diff --git a/man/plot.adjustr_weighted.Rd b/man/spec_plot.Rd similarity index 82% rename from man/plot.adjustr_weighted.Rd rename to man/spec_plot.Rd index f740c3c..3dfdef3 100644 --- a/man/plot.adjustr_weighted.Rd +++ b/man/spec_plot.Rd @@ -1,10 +1,18 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/use_weights.R -\name{plot.adjustr_weighted} -\alias{plot.adjustr_weighted} +\name{spec_plot} +\alias{spec_plot} \title{Plot Posterior Quantities of Interest Under Alternative Model Specifications} \usage{ -\method{plot}{adjustr_weighted}(x, by, post, only_mean = FALSE, ci_level = 0.8, outer_level = 0.95, ...) +spec_plot( + x, + by, + post, + only_mean = FALSE, + ci_level = 0.8, + outer_level = 0.95, + ... 
+) } \arguments{ \item{x}{An \code{adjustr_weighted} object.} @@ -33,7 +41,7 @@ A \code{\link[ggplot2]{ggplot}} object which can be further } \description{ Uses weights computed in \code{\link{adjust_weights}} to plot posterior -quantities of interest versus +quantities of interest versus specification parameters } \examples{ \dontrun{ @@ -41,9 +49,9 @@ spec = make_spec(eta ~ student_t(df, 0, scale), df=1:10, scale=seq(2, 1, -1/9)) adjusted = adjust_weights(spec, eightschools_m) -plot(adjusted, df, theta[1]) -plot(adjusted, df, mu, only_mean=TRUE) -plot(adjusted, scale, tau) +spec_plot(adjusted, df, theta[1]) +spec_plot(adjusted, df, mu, only_mean=TRUE) +spec_plot(adjusted, scale, tau) } } diff --git a/man/summarize.adjustr_weighted.Rd b/man/summarize.adjustr_weighted.Rd index 17b762e..2edf524 100644 --- a/man/summarize.adjustr_weighted.Rd +++ b/man/summarize.adjustr_weighted.Rd @@ -18,6 +18,12 @@ a value of \code{mean(theta)} will compute the posterior mean of \code{theta} for each alternative specification. + Also supported is the custom function \code{wasserstein}, which computes + the Wasserstein-p distance between the posterior distribution of the + provided expression under the new model and under the original model, with + \code{p=1} the default. Lower the \code{spacing} parameter from the + default of 0.005 to compute a finer (but slower) approximation. + The arguments in \code{...} are automatically quoted and evaluated in the context of \code{.data}. 
They support unquoting and splicing.} @@ -51,6 +57,7 @@ spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) adjusted = adjust_weights(spec, eightschools_m) summarize(adjusted, mean(mu), var(mu)) +summarize(adjusted, wasserstein(mu, p=2)) summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data) summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95))) } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 5a27d52..954b66c 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -1,4 +1 @@ -load("../test_model.rda") - -# set up Stan parsing -get_parser() \ No newline at end of file +load("../test_model.rda") \ No newline at end of file diff --git a/tests/testthat/test_logprob.R b/tests/testthat/test_logprob.R index 3bb2f1f..1c31b15 100644 --- a/tests/testthat/test_logprob.R +++ b/tests/testthat/test_logprob.R @@ -24,10 +24,9 @@ test_that("Model parameter log probabilities are calculated correctly", { test_that("Data is assembled correctly", { code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) + parsed = parse_model(code) bd = get_base_data(eightschools_m, list(eta ~ student_t(df, 0, tau)), - parsed_vars, list(df=1:2), "df") + parsed$vars, list(df=1:2), "df") expect_length(bd, 1) expect_named(bd[[1]], c("eta", "tau"), ignore.order=T) @@ -37,10 +36,9 @@ test_that("Data is assembled correctly", { test_that("MCMC draws are preferred over provided data", { code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) + parsed = parse_model(code) bd = get_base_data(eightschools_m, list(eta ~ student_t(2, 0, tau)), - parsed_vars, list(tau=3)) + parsed$vars, list(tau=3)) expect_length(bd, 1) expect_named(bd[[1]], c("eta", "tau"), ignore.order=T) @@ -50,9 +48,8 @@ test_that("MCMC draws are preferred over provided data", { test_that("Parameter-less specification data is correctly assembled", { code = 
eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) - bd = get_base_data(eightschools_m, list(y ~ std_normal()), parsed_vars, + parsed = parse_model(code) + bd = get_base_data(eightschools_m, list(y ~ std_normal()), parsed$vars, list(y=c(28, 8, -3, 7, -1, 1, 18, 12), J=8)) expect_length(bd, 1) @@ -63,34 +60,31 @@ test_that("Parameter-less specification data is correctly assembled", { test_that("Error thrown for missing data", { code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) + parsed = parse_model(code) expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, sigma)), - parsed_vars, list()), "sigma not found") + parsed$vars, list()), "sigma not found") expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, 2)), - parsed_vars, list()), "gamma not found") + parsed$vars, list()), "gamma not found") }) test_that("Model log probability is correctly calculated", { code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) + parsed = parse_model(code) form = eta ~ normal(0, 1) draws = rstan::extract(eightschools_m, "eta", permuted=F) exp_lp = 2*apply(dnorm(draws, 0, 1, log=T), 1:2, sum) - lp = calc_original_lp(eightschools_m, list(form, form), parsed_vars, list()) + lp = calc_original_lp(eightschools_m, list(form, form), parsed$vars, list()) expect_equal(exp_lp, lp) }) test_that("Alternate specifications log probabilities are correctly calculated", { code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - parsed_vars = get_variables(parsed_model) + parsed = parse_model(code) form = eta ~ normal(0, s) draws = rstan::extract(eightschools_m, "eta", permuted=F) exp_lp = 2*apply(dnorm(draws, 0, 1, log=T), 1:2, sum) - lp = calc_specs_lp(eightschools_m, list(form, form), parsed_vars, list(), list(list(s=1))) + lp = 
calc_specs_lp(eightschools_m, list(form, form), parsed$vars, list(), list(list(s=1))) expect_equal(exp_lp, lp[[1]]) }) diff --git a/tests/testthat/test_parsing.R b/tests/testthat/test_parsing.R index 846f2be..05b4584 100644 --- a/tests/testthat/test_parsing.R +++ b/tests/testthat/test_parsing.R @@ -1,15 +1,9 @@ context("Stan model parsing") - -test_that("Model parses and returns parsing table", { - code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - expect_equal(names(parsed_model), c("name", "value", "pos", "depth", "i")) -}) - -test_that("Bad model throws syntax error", { - code = paste(eightschools_m@stanmodel@model_code, "\ndata{\n}\n") - expect_error(parse_model(code), "syntax error") +test_that("Empty model handled correctly", { + parsed = parse_model("") + expect_equal(length(parsed$vars), 0) + expect_equal(length(parsed$samps), 0) }) test_that("Correct parsed variables", { @@ -17,15 +11,16 @@ test_that("Correct parsed variables", { tau="parameters", eta="parameters", theta="transformed parameters") code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - expect_equal(get_variables(parsed_model), correct_vars) + parsed = parse_model(code) + expect_equal(parsed$vars, correct_vars) }) test_that("Correct parsed sampling statements", { - correct_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma)) + correct_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma), + mu ~ uniform(-1e+100, 1e+100), tau ~ uniform(-1e+100, 1e+100)) code = eightschools_m@stanmodel@model_code - parsed_model = parse_model(code) - expect_equal(get_sampling_stmts(parsed_model), correct_samp) + parsed = parse_model(code) + expect_equal(parsed$samp, correct_samp) }) test_that("Provided sampling statements can be matched to model", { @@ -41,7 +36,7 @@ test_that("Extra sampling statements not in model throw an error", { model_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma)) prov_samp = list(eta ~ exponential(5), x ~ 
normal(theta, sigma)) expect_error(match_sampling_stmts(prov_samp, model_samp), - "No matching sampling statement found for prior x ~ normal\\(theta, sigma\\)") + "No matching sampling statement found for x ~ normal\\(theta, sigma\\)") }) test_that("Variables are correctly extracted from sampling statements", { diff --git a/tests/testthat/test_use.R b/tests/testthat/test_use.R index 695d26e..c0045dc 100644 --- a/tests/testthat/test_use.R +++ b/tests/testthat/test_use.R @@ -34,6 +34,8 @@ test_that("Weighted array functions compute correctly", { expect_equal(wtd_array_mean(y, wgt), wtd_mean) expect_equal(wtd_array_var(y, wgt), weighted.mean((y - wtd_mean)^2, wgt)) expect_equal(wtd_array_sd(y, wgt), sqrt(weighted.mean((y - wtd_mean)^2, wgt))) + expect_equal(wtd_array_quantile(y, rep(1, 5), 0.2), 1) + expect_equal(wtd_array_median(y, rep(1, 5)), 2.5) }) test_that("Empty call to `summarize` should change nothing", { @@ -42,6 +44,15 @@ test_that("Empty call to `summarize` should change nothing", { expect_identical(summarize(obj), obj) }) +test_that("Non-summary call to `summarize` should throw error", { + obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) + attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2)) + attr(obj, "iter") = 3 + class(obj) = c("adjustr_weighted", class(obj)) + + expect_error(summarize(obj, theta), "must summarize posterior draws") +}) + test_that("Basic summaries are computed correctly", { obj = tibble(.weights=list(c(1,1,1), c(1,1,4))) attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2)) @@ -56,6 +67,10 @@ test_that("Basic summaries are computed correctly", { expect_is(sum2, "adjustr_weighted") expect_equal(sum2$th, list(c(5, 1), c(6, 1))) + sum3 = summarize(obj, W = wasserstein(theta[1])) + expect_is(sum3, "adjustr_weighted") + expect_equal(sum3$W[1], 0) + expect_error(summarise.adjustr_weighted(as_tibble(obj)), "is not TRUE") }) @@ -82,19 +97,20 @@ test_that("Resampling-based summaries are computed correctly", { sum1 = 
summarize(obj, th=mean(theta), .resampling=T) expect_equal(sum1$th, c(3,7)) - sum2 = summarize(obj, th=quantile(theta, 0.05)) + sum2 = summarize(obj, th=quantile(theta, 0.05), .resampling=T) expect_equal(sum2$th, c(3,7)) }) test_that("Plotting function handles arguments correctly", { - obj = tibble(.weights=list(c(1,0,0), c(0,0,1))) + obj = tibble(.weights=list(c(1,0,0), c(0,0,1), c(1,1,1)), + .samp=c("y ~ normal(0, 1)", "y ~ normal(0, 2)", "")) attr(obj, "draws") = list(theta=matrix(c(3,5,7), nrow=3, ncol=1)) attr(obj, "iter") = 3 class(obj) = c("adjustr_weighted", class(obj)) - expect_is(plot(obj, 1, theta), "ggplot") - expect_is(plot(obj, 1, theta, only_mean=T), "ggplot") + expect_is(spec_plot(obj, 1, theta), "ggplot") + expect_is(spec_plot(obj, 1, theta, only_mean=T), "ggplot") - expect_error(plot(obj, 1, theta, outer_level=0.4), "should be less than") + expect_error(spec_plot(obj, 1, theta, outer_level=0.4), "should be less than") }) diff --git a/tests/testthat/test_weights.R b/tests/testthat/test_weights.R index c629746..0691b13 100644 --- a/tests/testthat/test_weights.R +++ b/tests/testthat/test_weights.R @@ -32,7 +32,7 @@ test_that("Weights calculated correctly (normal/inflated)", { weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F)) spec = make_spec(y ~ normal(theta, 1.1*sigma)) - obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T) + obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F) expect_s3_class(obj, "adjustr_weighted") expect_s3_class(obj, "tbl_df") @@ -58,7 +58,7 @@ test_that("Weights calculated correctly (normal/student_t)", { weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F)) spec = make_spec(y ~ student_t(df, theta, sigma), df=5:6) - obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T) + obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F) expect_equal(weights, obj$.weights[[2]]) expect_equal(pareto_k, 
obj$.pareto_k[2]) @@ -77,7 +77,7 @@ test_that("Weights calculated correctly (no data normal/student_t)", { weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F)) spec = make_spec(eta ~ student_t(4, 0, 1)) - obj = adjust_weights(spec, eightschools_m, keep_bad=T) + obj = adjust_weights(spec, eightschools_m, keep_bad=T, incl_orig=F) expect_equal(weights, obj$.weights[[1]]) expect_equal(pareto_k, obj$.pareto_k) @@ -86,14 +86,14 @@ test_that("Weights calculated correctly (no data normal/student_t)", { test_that("Weights extracted correctly", { spec = make_spec(y ~ student_t(df, theta, sigma), df=5) - obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T) + obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F) pulled = pull(obj) expect_is(pulled, "numeric") expect_length(pulled, 20) spec2 = make_spec(y ~ student_t(df, theta, sigma), df=5:6) - obj = adjust_weights(spec2, eightschools_m, pkg_env$model_d, keep_bad=T) + obj = adjust_weights(spec2, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F) pulled = pull(obj) expect_is(pulled, "list") @@ -104,6 +104,8 @@ test_that("Weights extracted correctly", { test_that("Sampling statements printed correctly", { expect_output(extract_samp_stmts(eightschools_m), "Sampling statements for model 2c8d1d8a30137533422c438f23b83428: + parameter tau ~ uniform(-1e+100, 1e+100) + parameter mu ~ uniform(-1e+100, 1e+100) parameter eta ~ std_normal() data y ~ normal(theta, sigma)", fixed=T) }) diff --git a/vignettes/eight-schools.Rmd b/vignettes/eight-schools.Rmd index c8ee816..7f1c4f3 100644 --- a/vignettes/eight-schools.Rmd +++ b/vignettes/eight-schools.Rmd @@ -1,8 +1,9 @@ --- title: "Sensitivity Analysis of a Simple Hierarchical Model" output: rmarkdown::html_vignette +bibliography: eight-schools.bib vignette: > - %\VignetteIndexEntry{eight-schools} + %\VignetteIndexEntry{Sensitivity Analysis of a Simple Hierarchical Model} %\VignetteEngine{knitr::rmarkdown} 
%\VignetteEncoding{UTF-8} --- @@ -12,8 +13,175 @@ knitr::opts_chunk$set( collapse = TRUE, comment = "#>" ) + +library(dplyr) +library(rstan) +library(adjustr) +load("eightschools_model.rda") ``` -```{r setup} +## Introduction + +This vignette walks through the process of performing sensitivity +analysis using the `adjustr` package for the classic introductory +hierarchical model: the "eight schools" meta-analysis from Chapter 5 +of @bda3. + +We begin by specifying and fitting the model, which should be familiar +to most users of Stan. +```{r eval=F} +library(dplyr) +library(rstan) library(adjustr) + +model_code = " +data { + int J; // number of schools + real y[J]; // estimated treatment effects + real sigma[J]; // standard error of effect estimates +} +parameters { + real mu; // population treatment effect + real tau; // standard deviation in treatment effects + vector[J] eta; // unscaled deviation from mu by school +} +transformed parameters { + vector[J] theta = mu + tau * eta; // school treatment effects +} +model { + eta ~ std_normal(); + y ~ normal(theta, sigma); +}" + +model_d = list(J = 8, + y = c(28, 8, -3, 7, -1, 1, 18, 12), + sigma = c(15, 10, 16, 11, 9, 11, 10, 18)) +eightschools_m = stan(model_code=model_code, chains=2, data=model_d, + warmup=500, iter=1000) +``` + +We plot the original estimates for each of the eight schools. +```{r} +plot(eightschools_m, pars="theta") +``` + +The model partially pools information, pulling the school-level treatment effects +towards the overall mean. + +It is natural to wonder how much these estimates depend on certain aspects of +our model. The individual and school treatment effects are assumed to follow a +normal distribution, and we have used a uniform prior on the population +parameters `mu` and `tau`. + +The basic __adjustr__ workflow is as follows: + +1. Use `make_spec` to specify the set of alternative model specifications you'd +like to fit. + +2. 
Use `adjust_weights` to calculate importance sampling weights which +approximate the posterior of each alternative specification. + +3. Use `summarize` and `spec_plot` to examine posterior quantities of interest +for each alternative specification, in order to assess the sensitivity of the +underlying model. + +## Basic Workflow Example + +First suppose we want to examine the effect of our choice of uniform prior +on `mu` and `tau`. We begin by specifying an alternative model in which +these parameters have more informative priors. This just requires +passing the `make_spec` function the new sampling statements we'd like to +use. These replace any in the original model (`mu` and `tau` have implicit +improper uniform priors, since the original model does not have any sampling +statements for them). +```{r} +spec = make_spec(mu ~ normal(0, 20), tau ~ exponential(5)) +print(spec) +``` + +Then we compute importance sampling weights to approximate the posterior under +this alternative model. +```{r include=F} +adjusted = adjust_weights(spec, eightschools_m, keep_bad=T) ``` +```{r eval=F} +adjusted = adjust_weights(spec, eightschools_m) +``` + +The `adjust_weights` function returns a data frame +containing a summary of the alternative model and a list-column named `.weights` +containing the importance weights. The last row of the table by default +corresponds to the original model specification. The table also includes the +diagnostic Pareto *k*-value. When this value exceeds 0.7, importance sampling is +unreliable, and by default `adjust_weights` discards weights with a Pareto *k* +above 0.7. +```{r} +print(adjusted) +``` + +Finally, we can examine how these alternative priors have changed our posterior +inference. We use `summarize` to calculate the posterior mean and variance of `mu` under the alternative model. +```{r} +summarize(adjusted, mean(mu), var(mu)) +``` +We see that the more informative priors have pulled the posterior distribution +of `mu` towards zero and made it less variable.
+ +## Multiple Alternative Specifications +What if instead we are concerned about our distributional assumption on the +school treatment effects? We could probe this assumption by fitting a series of +models where `eta` had a Student's *t* distribution, with varying degrees of +freedom. + +The `make_spec` function handles this easily. +```{r} +spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10) +print(spec) +``` +Notice how we have parameterized the alternative sampling statement with +a variable `df`, and then provided the values `df` takes in another argument +to `make_spec`. + +As before, we compute importance sampling weights to approximate the posterior +under these alternative models. Here, for the purposes of illustration, +we are using `keep_bad=T` to compute weights even when the Pareto _k_ diagnostic +value is above 0.7. In practice, alternative models with such large _k_ values should be completely +re-fit in Stan. +```{r} +adjusted = adjust_weights(spec, eightschools_m, keep_bad=T) +``` +Now, `adjusted` has one row for each of the ten alternative models, followed by a final row for the original model. +```{r} +print(adjusted) +``` + +To examine the impact of these model changes, we can plot the posterior for +a quantity of interest versus the degrees of freedom for the *t* distribution. +The package provides the `spec_plot` function, which takes an x-axis specification +parameter and a y-axis posterior quantity (which must evaluate to a single +number per posterior draw). The dashed line shows the posterior median +under the original model. +```{r} +spec_plot(adjusted, df, mu) +spec_plot(adjusted, df, theta[3]) +``` + +It appears that changing the distribution of `eta`/`theta` from normal to +*t* has a small effect on posterior inferences (although, as noted above, +these inferences are unreliable because _k_ > 0.7). + +By default, the function plots an inner 80\% credible interval and an outer +95\% credible interval, but these can be changed by the user.
+ +We can also measure the distance between the new and original posterior +marginals by using the special `wasserstein()` function available in +`summarize()`: +```{r} +summarize(adjusted, wasserstein(mu)) +``` +As we would expect, the 1-Wasserstein distance decreases as the degrees of +freedom increase. In general, we can compute the _p_-Wasserstein distance +by passing an extra `p` parameter to `wasserstein()`. + + +### \ No newline at end of file diff --git a/vignettes/eight-schools.bib b/vignettes/eight-schools.bib new file mode 100644 index 0000000..a414cde --- /dev/null +++ b/vignettes/eight-schools.bib @@ -0,0 +1,7 @@ +@book{bda3, + address = {London}, + author = {Andrew Gelman and J.~B.~Carlin and Hal S.~Stern and David B.~Dunson and Aki Vehtari and Donald B.~Rubin}, + edition = {3rd}, + publisher = {CRC Press}, + title = {Bayesian Data Analysis}, + year = {2013}} diff --git a/vignettes/eightschools_model.rda b/vignettes/eightschools_model.rda new file mode 100644 index 0000000..cfc44de Binary files /dev/null and b/vignettes/eightschools_model.rda differ