This vignette walks through the process of performing sensitivity analysis using the adjustr package for the classic introductory hierarchical model: the “eight schools” meta-analysis from Chapter 5 of Gelman et al. (2013).
+
We begin by specifying and fitting the model, which should be familiar to most users of Stan.
+
library(dplyr)
+library(rstan)
+library(adjustr)
+
+model_code="
+data {
+ int<lower=0> J; // number of schools
+ real y[J]; // estimated treatment effects
+ real<lower=0> sigma[J]; // standard error of effect estimates
+}
+parameters {
+ real mu; // population treatment effect
+ real<lower=0> tau; // standard deviation in treatment effects
+ vector[J] eta; // unscaled deviation from mu by school
+}
+transformed parameters {
+ vector[J] theta = mu + tau * eta; // school treatment effects
+}
+model {
+ eta ~ std_normal();
+ y ~ normal(theta, sigma);
+}"
+
+model_d = list(J = 8,
+ y = c(28, 8, -3, 7, -1, 1, 18, 12),
+ sigma = c(15, 10, 16, 11, 9, 11, 10, 18))
+eightschools_m = stan(model_code=model_code, chains=2, data=model_d,
+ warmup=500, iter=1000)
+
We plot the original estimates for each of the eight schools.
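plot(eightschools_m, pars="theta")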
The model partially pools information, pulling the school-level treatment effects towards the overall mean.
+
It is natural to wonder how much these estimates depend on certain aspects of our model. The individual and school treatment effects are assumed to follow a normal distribution, and we have used a uniform prior on the population parameters mu and tau.
+
The basic adjustr workflow is as follows:
+
+
1. Use make_spec to specify the set of alternative model specifications you’d like to fit.
+
2. Use adjust_weights to calculate importance sampling weights which approximate the posterior of each alternative specification.
+
3. Use summarize and spec_plot to examine posterior quantities of interest for each alternative specification, in order to assess the sensitivity of the underlying model.
+
+
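In code, the three steps look like this (a minimal sketch using the model above; each call is demonstrated in the sections below):

spec = make_spec(mu ~ normal(0, 20), tau ~ exponential(5))
adjusted = adjust_weights(spec, eightschools_m)
summarize(adjusted, mean(mu), var(mu))
spec_plot(adjusted, 1, mu) # by=1, since there is a single specification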
+
+
+Basic Workflow Example
+
First suppose we want to examine the effect of our choice of uniform prior on mu and tau. We begin by specifying an alternative model in which these parameters have more informative priors. This just requires passing the make_spec function the new sampling statements we’d like to use. These replace any in the original model (mu and tau have implicit improper uniform priors, since the original model does not have any sampling statements for them).
+
spec = make_spec(mu ~ normal(0, 20), tau ~ exponential(5))
+print(spec)
+#> Sampling specifications:
+#> mu ~ normal(0, 20)
+#> tau ~ exponential(5)
+
Then we compute importance sampling weights to approximate the posterior under this alternative model.
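adjusted = adjust_weights(spec, eightschools_m)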
The adjust_weights function returns a data frame containing a summary of the alternative model and a list-column named .weights containing the importance weights. The last row of the table by default corresponds to the original model specification. The table also includes the diagnostic Pareto k-value. When this value exceeds 0.7, importance sampling is unreliable, and by default adjust_weights discards weights with a Pareto k above 0.7.
+
print(adjusted)
+#> # A tibble: 2 x 4
+#> .samp_1 .samp_2 .weights .pareto_k
+#> <chr> <chr> <list> <dbl>
+#> 1 mu ~ normal(0, 20) tau ~ exponential(5) <dbl [1,000]> 0.569
+#> 2 <original model> <original model> <dbl [1,000]> -Inf
+
Finally, we can examine how these alternative priors have changed our posterior inference. We use summarize to calculate these under the alternative model.
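summarize(adjusted, mean(mu), var(mu))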
We see that the more informative priors have pulled the posterior distribution of mu towards zero and made it less variable.
+
+
+
+Multiple Alternative Specifications
+
What if instead we are concerned about our distributional assumption on the school treatment effects? We could probe this assumption by fitting a series of models where eta had a Student’s t distribution, with varying degrees of freedom.
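The make_spec function handles this easily.

spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
print(spec)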
Notice how we have parameterized the alternative sampling statement with a variable df, and then provided the values df takes in another argument to make_spec.
+
As before, we compute importance sampling weights to approximate the posterior under these alternative models. Here, for the purposes of illustration, we are using keep_bad=T to compute weights even when the Pareto k diagnostic value is above 0.7. In practice, the alternative models should be completely re-fit in Stan.
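adjusted = adjust_weights(spec, eightschools_m, keep_bad=T)

Now, adjusted has a row for each of the ten alternative models, plus one for the original specification.

print(adjusted)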
To examine the impact of these model changes, we can plot the posterior for a quantity of interest versus the degrees of freedom for the t distribution. The package provides the spec_plot function which takes an x-axis specification parameter and a y-axis posterior quantity (which must evaluate to a single number per posterior draw). The dashed line shows the posterior median under the original model.
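spec_plot(adjusted, df, mu)
spec_plot(adjusted, df, theta[3])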
It appears that changing the distribution of eta/theta from normal to t has a small effect on posterior inferences (although, as noted above, these inferences are unreliable, since k > 0.7).
+
By default, the function plots an inner 80% credible interval and an outer 95% credible interval, but these can be changed by the user.
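For example (illustrative interval levels):

spec_plot(adjusted, df, mu, ci_level=0.5, outer_level=0.9)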
+
We can also measure the distance between the new and original posterior marginals by using the special wasserstein() function available in summarize():
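summarize(adjusted, wasserstein(mu))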
As we would expect, the 1-Wasserstein distance decreases as the degrees of freedom increase. In general, we can compute the p-Wasserstein distance by passing an extra p parameter to wasserstein().
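summarize(adjusted, wasserstein(mu, p=2))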
+
+
+
+
+
Gelman, Andrew, J. B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari, and Donald B. Rubin. 2013. Bayesian Data Analysis. 3rd ed. London: CRC Press.
adjustr is an R package which provides functions to help assess the sensitivity of a Bayesian model (fitted with Stan) to the specification of its likelihood and priors. Users provide a series of alternate sampling specifications, and the package uses Pareto-smoothed importance sampling to estimate the posterior under each specification. The package also provides functions to summarize and plot how posterior quantities change across specifications.
-
The package aims to provide a simple interface that makes it as easy as possible for modellers to try out various adjustments to their Stan models, without needing to write any specific Stan code or even recompile or rerun their model.
+
Sensitivity analysis is a critical component of a good modeling workflow. Yet as the number and power of Bayesian computational tools have increased, the options for sensitivity analysis have remained largely the same: compute importance sampling weights manually, or fit a large number of similar models, dramatically increasing computation time. Neither option is satisfactory for most applied modeling.
+
adjustr is an R package which aims to make sensitivity analysis faster and easier, and works with Bayesian models fitted with Stan. Users provide a series of alternate sampling specifications, and the package uses Pareto-smoothed importance sampling to estimate the posterior under each specification. The package also provides functions to summarize and plot how posterior quantities change across specifications.
+
The package provides a simple interface that makes it as easy as possible for modellers to try out various adjustments to their Stan models, without needing to write any specific Stan code or even recompile or rerun their model.
The package works by parsing Stan model code, so it works best when the model code was written directly by the user. Models made using brms may in principle be used as well. Models made using rstanarm are constructed from more complex model templates, and cannot be used.
Getting Started
-
The tutorial vignettes walk through a full sensitivity analysis for the classic 8-schools example. Smaller examples are also included in the package documentation.
+
The basic adjustr workflow is as follows:
+
+
1. Use make_spec to specify the set of alternative model specifications you’d like to fit.
+
2. Use adjust_weights to calculate importance sampling weights which approximate the posterior of each alternative specification.
+
3. Use summarize and spec_plot to examine posterior quantities of interest for each alternative specification, in order to assess the sensitivity of the underlying model.
+
+
The tutorial vignette walks through a full sensitivity analysis for the classic 8-schools example. Smaller examples are also included in the package documentation.
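If the package is not yet installed, something like the following should work (a sketch; the corymccartan/adjustr repository path is inferred from the package website URL):

# install.packages("remotes")
remotes::install_github("corymccartan/adjustr")
library(adjustr)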
Arguments

keep_bad

When FALSE (the default), alternate specifications which deviate too much from the original
posterior, and which as a result cannot be reliably estimated using
importance sampling (i.e., if the Pareto shape parameter is larger than
0.7), have their weights discarded.
+
+
+
incl_orig
+
When TRUE, include a row for the original
+model specification, with all weights equal. Can facilitate comparison
+and plotting later.
@@ -218,9 +224,10 @@
Value
.weights, containing vectors of weights for each draw, and
.pareto_k, containing the diagnostic Pareto shape parameters. Values
greater than 0.7 indicate that importance sampling is not reliable.
- Weights can be extracted with the pull.adjustr_weighted
- method. The returned object also includes the model sample draws, in the
- draws attribute.
+ If incl_orig is TRUE, a row is added for the original model
+ specification. Weights can be extracted with the
+ pull.adjustr_weighted method. The returned object also
+ includes the model sample draws, in the draws attribute.
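Examples

For example (adapted from the package’s tests; keep_bad retains unreliable weights and incl_orig=FALSE drops the original-model row):

spec = make_spec(eta ~ student_t(4, 0, 1))
adjust_weights(spec, eightschools_m, keep_bad = TRUE, incl_orig = FALSE)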
Arguments

into the corresponding parameter in the sampling statements. For data
frames, each entry in each column will be substituted into the corresponding
parameter in the sampling statements.
-
List arguments are coerced to data frame. They can either be lists of named
+
List arguments are coerced to data frames. They can either be lists of named
vectors, or lists of lists of single-element named vectors.
The lengths of all parameter arguments must be consistent. Named vectors
can have length 1 or must have length equal to the number of rows in all
data frame arguments.
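For instance, a specification with two parameters varying in parallel (from the package examples):

spec = make_spec(eta ~ student_t(df, 0, scale),
 df=1:10, scale=seq(2, 1, -1/9))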
@@ -218,16 +218,16 @@
by

The x-axis variable, which is usually one of the specification
+parameters. Can be set to 1 if there is only one specification.
+Automatically quoted and evaluated in the context of x.
+
+
+
post
+
The posterior quantity of interest, to be computed for each
+resampled draw of each specification. Should evaluate to a single number
+for each draw. Automatically quoted and evaluated in the context of x.
+
+
+
only_mean
+
Whether to only plot the posterior mean. May be more stable.
+
+
+
ci_level
+
The inner credible interval to plot, covering the central
+100*ci_level% of posterior draws.
+
+
+
outer_level
+
The outer credible interval to plot.
+
+
+
...
+
Ignored.
+
+
+
+
Value
+
+
A ggplot object which can be further
+ customized with the ggplot2 package.
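Examples

spec = make_spec(eta ~ student_t(df, 0, scale),
 df=1:10, scale=seq(2, 1, -1/9))
adjusted = adjust_weights(spec, eightschools_m)
spec_plot(adjusted, df, theta[1])
spec_plot(adjusted, df, mu, only_mean=TRUE)
spec_plot(adjusted, scale, tau)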
Arguments

posterior distribution of each alternative specification. For example,
a value of mean(theta) will compute the posterior mean of
theta for each alternative specification.
+
Also supported is the custom function wasserstein, which computes
+ the Wasserstein-p distance between the posterior distribution of the
+ provided expression under the new model and under the original model, with
+ p=1 the default. Lower the spacing parameter from the
+ default of 0.005 to compute a finer (but slower) approximation.
The arguments in ... are automatically quoted and evaluated in the
context of .data. They support unquoting and splicing.
@@ -232,6 +237,7 @@
Examples

spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
adjusted = adjust_weights(spec, eightschools_m)
summarize(adjusted, mean(mu), var(mu))
+summarize(adjusted, wasserstein(mu, p=2))
summarize(adjusted, diff_1=mean(y[1] - theta[1]), .model_data=model_data)
summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95)))
}
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
index c788329..4bd8fb2 100644
--- a/docs/sitemap.xml
+++ b/docs/sitemap.xml
@@ -25,10 +25,10 @@
https://corymccartan.github.io/adjustr//reference/make_spec.html
- https://corymccartan.github.io/adjustr//reference/plot.adjustr_weighted.html
+ https://corymccartan.github.io/adjustr//reference/pull.adjustr_weighted.html
- https://corymccartan.github.io/adjustr//reference/pull.adjustr_weighted.html
+ https://corymccartan.github.io/adjustr//reference/spec_plot.html
https://corymccartan.github.io/adjustr//reference/summarize.adjustr_weighted.html
diff --git a/inst/stan.g b/inst/stan.g
deleted file mode 100644
index 24ea6bc..0000000
--- a/inst/stan.g
+++ /dev/null
@@ -1,152 +0,0 @@
-program: functions? data? tdata? params? tparams? model? generated?;
-
-functions: 'functions' function_decls;
-data: 'data' var_decls;
-tdata: 'transformed data' var_decls_statements;
-params: 'parameters' var_decls;
-tparams: 'transformed parameters' var_decls_statements;
-model: 'model' var_decls_statements;
-generated: 'generated quantities' var_decls_statements;
-function_decls: '{' function_decl* '}';
-var_decls: '{' var_decl* '}';
-var_decls_statements: '{' var_decl* statement* '}';
-
-
-function_decl: return_type identifier '(' (parameter_decl (',' parameter_decl)*)? ')'
- statement;
-
-return_type: 'void' | unsized_type;
-parameter_decl: 'data'? unsized_type identifier;
-unsized_type: basic_type unsized_dims?;
-basic_type: 'int' | 'real' | 'vector' | 'row_vector' | 'matrix';
-unsized_dims: '[' ','* ']';
-
-
-
-
-var_decl: var_type variable dims? ('=' expression)? ';';
-
-var_type: 'int' range_constraint
- | 'real' constraint
- | 'vector' constraint '[' expression ']'
- | 'ordered' '[' expression ']'
- | 'positive_ordered' '[' expression ']'
- | 'simplex' '[' expression ']'
- | 'unit_vector' '[' expression ']'
- | 'row_vector' constraint '[' expression ']'
- | 'matrix' constraint '[' expression ',' expression ']'
- | 'cholesky_factor_corr' '[' expression ']'
- | 'cholesky_factor_cov' '[' expression (',' expression)? ']'
- | 'corr_matrix' '[' expression ']'
- | 'cov_matrix' '[' expression ']';
-
-constraint: range_constraint | ('<' offset_multiplier '>');
-
-range_constraint: ('<' range '>')?;
-
-range: 'lower' '=' constr_expression ',' 'upper' '=' constr_expression
- | 'lower' '=' constr_expression
- | 'upper' '=' constr_expression;
-
-
-offset_multiplier: 'offset' '=' constr_expression ','
- 'multiplier' '=' constr_expression
- | 'offset' '=' constr_expression
- | 'multiplier' '=' constr_expression;
-
-dims: '[' expressions ']';
-
-variable: identifier;
-
-identifier: "[a-zA-Z][a-zA-Z0-9_]*";
-
-
-
-expressions: (expression (',' expression)*)?;
-
-expression: expression '?' expression ':' expression
- | expression infixOp expression
- | prefixOp expression
- | expression postfixOp
- | expression '[' indexes ']'
- | common_expression;
-
-constr_expression: constr_expression arithmeticInfixOp constr_expression
- | prefixOp constr_expression
- | constr_expression postfixOp
- | constr_expression '[' indexes ']'
- | common_expression;
-
-common_expression : real_literal
- | variable
- | '{' expressions '}'
- | '[' expressions ']'
- | function_literal '(' expressions? ')'
- | function_literal '(' expression ('|' (expression (',' expression)*)?)? ')'
- | 'integrate_1d' '(' function_literal (',' expression)@5:6 ')'
- | 'integrate_ode' '(' function_literal (',' expression)@6 ')'
- | 'integrate_ode_rk45' '(' function_literal (',' expression)@6:9 ')'
- | 'integrate_ode_bdf' '(' function_literal (',' expression)@6:9 ')'
- | 'algebra_solver' '(' function_literal (',' expression)@4:7 ')'
- | 'map_rect' '(' function_literal (',' expression)@4 ')'
- | '(' expression ')';
-
-prefixOp: ('!' | '-' | '+' | '^');
-
-postfixOp: '\'';
-
-infixOp: arithmeticInfixOp | logicalInfixOp;
-
-arithmeticInfixOp: ('+' | '-' | '*' | '/' | '%' | '\\' | '.*' | './');
-
-logicalInfixOp: ('||' | '&&' | '==' | '!=' | '<' | '<=' | '>' | '>=');
-
-index: (expression | expression ':' | ':' expression
- | expression ':' expression)?;
-
-indexes: (index (',' index)*)?;
-
-integer_literal: "[0-9]+";
-
-real_literal: integer_literal '.' "[0-9]*" exp_literal?
- | '.' "[0-9]+" exp_literal?
- | integer_literal exp_literal?;
-
-exp_literal: ('e' | 'E') ('+' | '-')? integer_literal;
-
-function_literal: identifier;
-
-
-
-statement: atomic_statement | nested_statement;
-
-atomic_statement: lhs assignment_op expression ';'
- | sampling_statement
- | function_literal '(' expressions ')' ';'
- | 'increment_log_prob' '(' expression ')' ';'
- | 'target' '+=' expression ';'
- | 'break' ';'
- | 'continue' ';'
- | 'print' '(' ((expression | string_literal) (',' (expression | string_literal))*)? ')' ';'
- | 'reject' '(' ((expression | string_literal) (',' (expression | string_literal))*)? ')' ';'
- | 'return' expression ';'
- | ';';
-
-sampling_statement: expression '~' identifier '(' expressions ')' truncation? ';';
-
-assignment_op: '<-' | '=' | '+=' | '-=' | '*=' | '/=' | '.*=' | './=';
-
-//string_literal: '"' char* '"';
-string_literal: "\"([^\"\\]|\\[^])*\"";
-
-truncation: 'T' '[' ?expression ',' ?expression ']';
-
-lhs: identifier ('[' indexes ']')*;
-
-nested_statement: 'if' '(' expression ')' statement
- ('else' 'if' '(' expression ')' statement)*
- ('else' statement)?
- | 'while' '(' expression ')' statement
- | 'for' '(' identifier 'in' expression ':' expression ')' statement
- | 'for' '(' identifier 'in' expression ')' statement
- | '{' var_decl* statement+ '}';
diff --git a/man/adjust_weights.Rd b/man/adjust_weights.Rd
index 1d0800f..5c7ccd7 100644
--- a/man/adjust_weights.Rd
+++ b/man/adjust_weights.Rd
@@ -5,7 +5,7 @@
\title{Compute Pareto-smoothed Importance Weights for Alternative Model
Specifications}
\usage{
-adjust_weights(spec, object, data = NULL, keep_bad = FALSE)
+adjust_weights(spec, object, data = NULL, keep_bad = FALSE, incl_orig = TRUE)
}
\arguments{
\item{spec}{An object of class \code{adjustr_spec}, probably produced by
@@ -24,6 +24,10 @@ alternate specifications which deviate too much from the original
posterior, and which as a result cannot be reliably estimated using
importance sampling (i.e., if the Pareto shape parameter is larger than
0.7), have their weights discarded.}
+
+\item{incl_orig}{When \code{TRUE}, include a row for the original
+model specification, with all weights equal. Can facilitate comparison
+and plotting later.}
}
\value{
A tibble, produced by converting the provided \code{specs} to a
@@ -31,9 +35,10 @@ A tibble, produced by converting the provided \code{specs} to a
\code{.weights}, containing vectors of weights for each draw, and
\code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values
greater than 0.7 indicate that importance sampling is not reliable.
- Weights can be extracted with the \code{\link{pull.adjustr_weighted}}
- method. The returned object also includes the model sample draws, in the
- \code{draws} attribute.
+ If \code{incl_orig} is \code{TRUE}, a row is added for the original model
+ specification. Weights can be extracted with the
+ \code{\link{pull.adjustr_weighted}} method. The returned object also
+ includes the model sample draws, in the \code{draws} attribute.
}
\description{
Given a set of new sampling statements, which can be parametrized by a
diff --git a/man/make_spec.Rd b/man/make_spec.Rd
index fa662fd..da2fce0 100644
--- a/man/make_spec.Rd
+++ b/man/make_spec.Rd
@@ -29,7 +29,7 @@ make_spec(...)
frame, each entry in each column will be substituted into the corresponding
parameter in the sampling statements.
- List arguments are coerced to data frame. They can either be lists of named
+ List arguments are coerced to data frames. They can either be lists of named
vectors, or lists of lists of single-element named vector.
The lengths of all parameter arguments must be consistent. Named vectors
diff --git a/man/plot.adjustr_weighted.Rd b/man/spec_plot.Rd
similarity index 82%
rename from man/plot.adjustr_weighted.Rd
rename to man/spec_plot.Rd
index f740c3c..3dfdef3 100644
--- a/man/plot.adjustr_weighted.Rd
+++ b/man/spec_plot.Rd
@@ -1,10 +1,18 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/use_weights.R
-\name{plot.adjustr_weighted}
-\alias{plot.adjustr_weighted}
+\name{spec_plot}
+\alias{spec_plot}
\title{Plot Posterior Quantities of Interest Under Alternative Model Specifications}
\usage{
-\method{plot}{adjustr_weighted}(x, by, post, only_mean = FALSE, ci_level = 0.8, outer_level = 0.95, ...)
+spec_plot(
+ x,
+ by,
+ post,
+ only_mean = FALSE,
+ ci_level = 0.8,
+ outer_level = 0.95,
+ ...
+)
}
\arguments{
\item{x}{An \code{adjustr_weighted} object.}
@@ -33,7 +41,7 @@ A \code{\link[ggplot2]{ggplot}} object which can be further
}
\description{
Uses weights computed in \code{\link{adjust_weights}} to plot posterior
-quantities of interest versus
+quantities of interest versus specification parameters.
}
\examples{
\dontrun{
@@ -41,9 +49,9 @@ spec = make_spec(eta ~ student_t(df, 0, scale),
df=1:10, scale=seq(2, 1, -1/9))
adjusted = adjust_weights(spec, eightschools_m)
-plot(adjusted, df, theta[1])
-plot(adjusted, df, mu, only_mean=TRUE)
-plot(adjusted, scale, tau)
+spec_plot(adjusted, df, theta[1])
+spec_plot(adjusted, df, mu, only_mean=TRUE)
+spec_plot(adjusted, scale, tau)
}
}
diff --git a/man/summarize.adjustr_weighted.Rd b/man/summarize.adjustr_weighted.Rd
index 17b762e..2edf524 100644
--- a/man/summarize.adjustr_weighted.Rd
+++ b/man/summarize.adjustr_weighted.Rd
@@ -18,6 +18,12 @@
a value of \code{mean(theta)} will compute the posterior mean of
\code{theta} for each alternative specification.
+ Also supported is the custom function \code{wasserstein}, which computes
+ the Wasserstein-p distance between the posterior distribution of the
+ provided expression under the new model and under the original model, with
+ \code{p=1} the default. Lower the \code{spacing} parameter from the
+ default of 0.005 to compute a finer (but slower) approximation.
+
The arguments in \code{...} are automatically quoted and evaluated in the
context of \code{.data}. They support unquoting and splicing.}
@@ -51,6 +57,7 @@ spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
adjusted = adjust_weights(spec, eightschools_m)
summarize(adjusted, mean(mu), var(mu))
+summarize(adjusted, wasserstein(mu, p=2))
summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data)
summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95)))
}
diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R
index 5a27d52..954b66c 100644
--- a/tests/testthat/setup.R
+++ b/tests/testthat/setup.R
@@ -1,4 +1 @@
-load("../test_model.rda")
-
-# set up Stan parsing
-get_parser()
\ No newline at end of file
+load("../test_model.rda")
\ No newline at end of file
diff --git a/tests/testthat/test_logprob.R b/tests/testthat/test_logprob.R
index 3bb2f1f..1c31b15 100644
--- a/tests/testthat/test_logprob.R
+++ b/tests/testthat/test_logprob.R
@@ -24,10 +24,9 @@ test_that("Model parameter log probabilities are calculated correctly", {
test_that("Data is assembled correctly", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
+ parsed = parse_model(code)
bd = get_base_data(eightschools_m, list(eta ~ student_t(df, 0, tau)),
- parsed_vars, list(df=1:2), "df")
+ parsed$vars, list(df=1:2), "df")
expect_length(bd, 1)
expect_named(bd[[1]], c("eta", "tau"), ignore.order=T)
@@ -37,10 +36,9 @@ test_that("Data is assembled correctly", {
test_that("MCMC draws are preferred over provided data", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
+ parsed = parse_model(code)
bd = get_base_data(eightschools_m, list(eta ~ student_t(2, 0, tau)),
- parsed_vars, list(tau=3))
+ parsed$vars, list(tau=3))
expect_length(bd, 1)
expect_named(bd[[1]], c("eta", "tau"), ignore.order=T)
@@ -50,9 +48,8 @@ test_that("MCMC draws are preferred over provided data", {
test_that("Parameter-less specification data is correctly assembled", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
- bd = get_base_data(eightschools_m, list(y ~ std_normal()), parsed_vars,
+ parsed = parse_model(code)
+ bd = get_base_data(eightschools_m, list(y ~ std_normal()), parsed$vars,
list(y=c(28, 8, -3, 7, -1, 1, 18, 12), J=8))
expect_length(bd, 1)
@@ -63,34 +60,31 @@ test_that("Parameter-less specification data is correctly assembled", {
test_that("Error thrown for missing data", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
+ parsed = parse_model(code)
expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, sigma)),
- parsed_vars, list()), "sigma not found")
+ parsed$vars, list()), "sigma not found")
expect_error(get_base_data(eightschools_m, list(eta ~ normal(gamma, 2)),
- parsed_vars, list()), "gamma not found")
+ parsed$vars, list()), "gamma not found")
})
test_that("Model log probability is correctly calculated", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
+ parsed = parse_model(code)
form = eta ~ normal(0, 1)
draws = rstan::extract(eightschools_m, "eta", permuted=F)
exp_lp = 2*apply(dnorm(draws, 0, 1, log=T), 1:2, sum)
- lp = calc_original_lp(eightschools_m, list(form, form), parsed_vars, list())
+ lp = calc_original_lp(eightschools_m, list(form, form), parsed$vars, list())
expect_equal(exp_lp, lp)
})
test_that("Alternate specifications log probabilities are correctly calculated", {
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- parsed_vars = get_variables(parsed_model)
+ parsed = parse_model(code)
form = eta ~ normal(0, s)
draws = rstan::extract(eightschools_m, "eta", permuted=F)
exp_lp = 2*apply(dnorm(draws, 0, 1, log=T), 1:2, sum)
- lp = calc_specs_lp(eightschools_m, list(form, form), parsed_vars, list(), list(list(s=1)))
+ lp = calc_specs_lp(eightschools_m, list(form, form), parsed$vars, list(), list(list(s=1)))
expect_equal(exp_lp, lp[[1]])
})
diff --git a/tests/testthat/test_parsing.R b/tests/testthat/test_parsing.R
index 846f2be..05b4584 100644
--- a/tests/testthat/test_parsing.R
+++ b/tests/testthat/test_parsing.R
@@ -1,15 +1,9 @@
context("Stan model parsing")
-
-test_that("Model parses and returns parsing table", {
- code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- expect_equal(names(parsed_model), c("name", "value", "pos", "depth", "i"))
-})
-
-test_that("Bad model throws syntax error", {
- code = paste(eightschools_m@stanmodel@model_code, "\ndata{\n}\n")
- expect_error(parse_model(code), "syntax error")
+test_that("Empty model handled correctly", {
+ parsed = parse_model("")
+ expect_equal(length(parsed$vars), 0)
+ expect_equal(length(parsed$samps), 0)
})
test_that("Correct parsed variables", {
@@ -17,15 +11,16 @@ test_that("Correct parsed variables", {
tau="parameters", eta="parameters",
theta="transformed parameters")
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- expect_equal(get_variables(parsed_model), correct_vars)
+ parsed = parse_model(code)
+ expect_equal(parsed$vars, correct_vars)
})
test_that("Correct parsed sampling statements", {
- correct_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma))
+ correct_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma),
+ mu ~ uniform(-1e+100, 1e+100), tau ~ uniform(-1e+100, 1e+100))
code = eightschools_m@stanmodel@model_code
- parsed_model = parse_model(code)
- expect_equal(get_sampling_stmts(parsed_model), correct_samp)
+ parsed = parse_model(code)
+ expect_equal(parsed$samp, correct_samp)
})
test_that("Provided sampling statements can be matched to model", {
@@ -41,7 +36,7 @@ test_that("Extra sampling statements not in model throw an error", {
model_samp = list(eta ~ std_normal(), y ~ normal(theta, sigma))
prov_samp = list(eta ~ exponential(5), x ~ normal(theta, sigma))
expect_error(match_sampling_stmts(prov_samp, model_samp),
- "No matching sampling statement found for prior x ~ normal\\(theta, sigma\\)")
+ "No matching sampling statement found for x ~ normal\\(theta, sigma\\)")
})
test_that("Variables are correctly extracted from sampling statements", {
diff --git a/tests/testthat/test_use.R b/tests/testthat/test_use.R
index 695d26e..c0045dc 100644
--- a/tests/testthat/test_use.R
+++ b/tests/testthat/test_use.R
@@ -34,6 +34,8 @@ test_that("Weighted array functions compute correctly", {
expect_equal(wtd_array_mean(y, wgt), wtd_mean)
expect_equal(wtd_array_var(y, wgt), weighted.mean((y - wtd_mean)^2, wgt))
expect_equal(wtd_array_sd(y, wgt), sqrt(weighted.mean((y - wtd_mean)^2, wgt)))
+ expect_equal(wtd_array_quantile(y, rep(1, 5), 0.2), 1)
+ expect_equal(wtd_array_median(y, rep(1, 5)), 2.5)
})
test_that("Empty call to `summarize` should change nothing", {
@@ -42,6 +44,15 @@ test_that("Empty call to `summarize` should change nothing", {
expect_identical(summarize(obj), obj)
})
+test_that("Non-summary call to `summarize` should throw error", {
+ obj = tibble(.weights=list(c(1,1,1), c(1,1,4)))
+ attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2))
+ attr(obj, "iter") = 3
+ class(obj) = c("adjustr_weighted", class(obj))
+
+ expect_error(summarize(obj, theta), "must summarize posterior draws")
+})
+
test_that("Basic summaries are computed correctly", {
obj = tibble(.weights=list(c(1,1,1), c(1,1,4)))
attr(obj, "draws") = list(theta=matrix(c(3,5,7,1,1,1), ncol=2))
@@ -56,6 +67,10 @@ test_that("Basic summaries are computed correctly", {
expect_is(sum2, "adjustr_weighted")
expect_equal(sum2$th, list(c(5, 1), c(6, 1)))
+ sum3 = summarize(obj, W = wasserstein(theta[1]))
+ expect_is(sum3, "adjustr_weighted")
+ expect_equal(sum3$W[1], 0)
+
expect_error(summarise.adjustr_weighted(as_tibble(obj)), "is not TRUE")
})
@@ -82,19 +97,20 @@ test_that("Resampling-based summaries are computed correctly", {
sum1 = summarize(obj, th=mean(theta), .resampling=T)
expect_equal(sum1$th, c(3,7))
- sum2 = summarize(obj, th=quantile(theta, 0.05))
+ sum2 = summarize(obj, th=quantile(theta, 0.05), .resampling=T)
expect_equal(sum2$th, c(3,7))
})
test_that("Plotting function handles arguments correctly", {
- obj = tibble(.weights=list(c(1,0,0), c(0,0,1)))
+ obj = tibble(.weights=list(c(1,0,0), c(0,0,1), c(1,1,1)),
+ .samp=c("y ~ normal(0, 1)", "y ~ normal(0, 2)", ""))
attr(obj, "draws") = list(theta=matrix(c(3,5,7), nrow=3, ncol=1))
attr(obj, "iter") = 3
class(obj) = c("adjustr_weighted", class(obj))
- expect_is(plot(obj, 1, theta), "ggplot")
- expect_is(plot(obj, 1, theta, only_mean=T), "ggplot")
+ expect_is(spec_plot(obj, 1, theta), "ggplot")
+ expect_is(spec_plot(obj, 1, theta, only_mean=T), "ggplot")
- expect_error(plot(obj, 1, theta, outer_level=0.4), "should be less than")
+ expect_error(spec_plot(obj, 1, theta, outer_level=0.4), "should be less than")
})
diff --git a/tests/testthat/test_weights.R b/tests/testthat/test_weights.R
index c629746..0691b13 100644
--- a/tests/testthat/test_weights.R
+++ b/tests/testthat/test_weights.R
@@ -32,7 +32,7 @@ test_that("Weights calculated correctly (normal/inflated)", {
weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F))
spec = make_spec(y ~ normal(theta, 1.1*sigma))
- obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T)
+ obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F)
expect_s3_class(obj, "adjustr_weighted")
expect_s3_class(obj, "tbl_df")
@@ -58,7 +58,7 @@ test_that("Weights calculated correctly (normal/student_t)", {
weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F))
spec = make_spec(y ~ student_t(df, theta, sigma), df=5:6)
- obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T)
+ obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F)
expect_equal(weights, obj$.weights[[2]])
expect_equal(pareto_k, obj$.pareto_k[2])
@@ -77,7 +77,7 @@ test_that("Weights calculated correctly (no data normal/student_t)", {
weights = as.numeric(loo::weights.importance_sampling(psis_wgt, log=F))
spec = make_spec(eta ~ student_t(4, 0, 1))
- obj = adjust_weights(spec, eightschools_m, keep_bad=T)
+ obj = adjust_weights(spec, eightschools_m, keep_bad=T, incl_orig=F)
expect_equal(weights, obj$.weights[[1]])
expect_equal(pareto_k, obj$.pareto_k)
@@ -86,14 +86,14 @@ test_that("Weights calculated correctly (no data normal/student_t)", {
test_that("Weights extracted correctly", {
spec = make_spec(y ~ student_t(df, theta, sigma), df=5)
- obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T)
+ obj = adjust_weights(spec, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F)
pulled = pull(obj)
expect_is(pulled, "numeric")
expect_length(pulled, 20)
spec2 = make_spec(y ~ student_t(df, theta, sigma), df=5:6)
- obj = adjust_weights(spec2, eightschools_m, pkg_env$model_d, keep_bad=T)
+ obj = adjust_weights(spec2, eightschools_m, pkg_env$model_d, keep_bad=T, incl_orig=F)
pulled = pull(obj)
expect_is(pulled, "list")
@@ -104,6 +104,8 @@ test_that("Weights extracted correctly", {
test_that("Sampling statements printed correctly", {
expect_output(extract_samp_stmts(eightschools_m),
"Sampling statements for model 2c8d1d8a30137533422c438f23b83428:
+ parameter tau ~ uniform(-1e+100, 1e+100)
+ parameter mu ~ uniform(-1e+100, 1e+100)
parameter eta ~ std_normal()
data y ~ normal(theta, sigma)", fixed=T)
})
diff --git a/vignettes/eight-schools.Rmd b/vignettes/eight-schools.Rmd
index c8ee816..7f1c4f3 100644
--- a/vignettes/eight-schools.Rmd
+++ b/vignettes/eight-schools.Rmd
@@ -1,8 +1,9 @@
---
title: "Sensitivity Analysis of a Simple Hierarchical Model"
output: rmarkdown::html_vignette
+bibliography: eight-schools.bib
vignette: >
- %\VignetteIndexEntry{eight-schools}
+ %\VignetteIndexEntry{Sensitivity Analysis of a Simple Hierarchical Model}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
@@ -12,8 +13,175 @@ knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
+
+library(dplyr)
+library(rstan)
+library(adjustr)
+load("eightschools_model.rda")
```
-```{r setup}
+## Introduction
+
+This vignette walks through the process of performing sensitivity
+analysis using the `adjustr` package for the classic introductory
+hierarchical model: the "eight schools" meta-analysis from Chapter 5
+of @bda3.
+
+We begin by specifying and fitting the model, which should be familiar
+to most users of Stan.
+```{r eval=F}
+library(dplyr)
+library(rstan)
library(adjustr)
+
+model_code = "
+data {
+ int<lower=0> J; // number of schools
+ real y[J]; // estimated treatment effects
+ real<lower=0> sigma[J]; // standard error of effect estimates
+}
+parameters {
+ real mu; // population treatment effect
+ real<lower=0> tau; // standard deviation in treatment effects
+ vector[J] eta; // unscaled deviation from mu by school
+}
+transformed parameters {
+ vector[J] theta = mu + tau * eta; // school treatment effects
+}
+model {
+ eta ~ std_normal();
+ y ~ normal(theta, sigma);
+}"
+
+model_d = list(J = 8,
+ y = c(28, 8, -3, 7, -1, 1, 18, 12),
+ sigma = c(15, 10, 16, 11, 9, 11, 10, 18))
+eightschools_m = stan(model_code=model_code, chains=2, data=model_d,
+ warmup=500, iter=1000)
+```
+
+We plot the original estimates for each of the eight schools.
+```{r}
+plot(eightschools_m, pars="theta")
+```
+
+The model partially pools information, pulling the school-level treatment effects
+towards the overall mean.
+
+It is natural to wonder how much these estimates depend on certain aspects of
+our model. The individual and school treatment effects are assumed to follow a
+normal distribution, and we have used a uniform prior on the population
+parameters `mu` and `tau`.
+
+The basic __adjustr__ workflow is as follows:
+
+1. Use `make_spec` to specify the set of alternative model specifications you'd
+like to fit.
+
+2. Use `adjust_weights` to calculate importance sampling weights which
+approximate the posterior of each alternative specification.
+
+3. Use `summarize` and `spec_plot` to examine posterior quantities of interest
+for each alternative specification, in order to assess the sensitivity of the
+underlying model.
+
+## Basic Workflow Example
+
+First suppose we want to examine the effect of our choice of uniform prior
+on `mu` and `tau`. We begin by specifying an alternative model in which
+these parameters have more informative priors. This just requires
+passing the `make_spec` function the new sampling statements we'd like to
+use. These replace any in the original model (`mu` and `tau` have implicit
+improper uniform priors, since the original model does not have any sampling
+statements for them).
+```{r}
+spec = make_spec(mu ~ normal(0, 20), tau ~ exponential(5))
+print(spec)
+```
+
+Then we compute importance sampling weights to approximate the posterior under
+this alternative model.
+```{r include=F}
+adjusted = adjust_weights(spec, eightschools_m, keep_bad=T)
```
+```{r eval=F}
+adjusted = adjust_weights(spec, eightschools_m)
+```
+
+The `adjust_weights` function returns a data frame
+containing a summary of the alternative model and a list-column named `.weights`
+containing the importance weights. The last row of the table by default
+corresponds to the original model specification. The table also includes the
+diagnostic Pareto *k*-value. When this value exceeds 0.7, importance sampling is
+unreliable, and by default `adjust_weights` discards weights with a Pareto *k*
+above 0.7.
+```{r}
+print(adjusted)
+```
+
+Finally, we can examine how these alternative priors have changed our posterior
+inference. We use `summarize` to calculate these under the alternative model.
+```{r}
+summarize(adjusted, mean(mu), var(mu))
+```
+We see that the more informative priors have pulled the posterior distribution
+of `mu` towards zero and made it less variable.
+
+## Multiple Alternative Specifications
+What if instead we are concerned about our distributional assumption on the
+school treatment effects? We could probe this assumption by fitting a series of
+models where `eta` had a Student's *t* distribution, with varying degrees of
+freedom.
+
+The `make_spec` function handles this easily.
+```{r}
+spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
+print(spec)
+```
+Notice how we have parameterized the alternative sampling statement with
+a variable `df`, and then provided the values `df` takes in another argument
+to `make_spec`.
+
+As before, we compute importance sampling weights to approximate the posterior
+under these alternative models. Here, for the purposes of illustration,
+we are using `keep_bad=T` to compute weights even when the Pareto _k_ diagnostic
+value is above 0.7. In practice, the alternative models should be completely
+re-fit in Stan.
+```{r}
+adjusted = adjust_weights(spec, eightschools_m, keep_bad=T)
+```
+Now, `adjusted` has a row for each of the ten alternative models, plus one for the original specification.
+```{r}
+print(adjusted)
+```
+
+To examine the impact of these model changes, we can plot the posterior for
+a quantity of interest versus the degrees of freedom for the *t* distribution.
+The package provides the `spec_plot` function which takes an x-axis specification
+parameter and a y-axis posterior quantity (which must evaluate to a single
+number per posterior draw). The dashed line shows the posterior median
+under the original model.
+```{r}
+spec_plot(adjusted, df, mu)
+spec_plot(adjusted, df, theta[3])
+```
+
+It appears that changing the distribution of `eta`/`theta` from normal to
+*t* has a small effect on posterior inferences (although, as noted above,
+these inferences are unreliable, since _k_ > 0.7).
+
+By default, the function plots an inner 80\% credible interval and an outer
+95\% credible interval, but these can be changed by the user.
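+For example (illustrative interval levels):
+```{r}
+spec_plot(adjusted, df, mu, ci_level=0.5, outer_level=0.9)
+```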
+
+We can also measure the distance between the new and original posterior
+marginals by using the special `wasserstein()` function available in
+`summarize()`:
+```{r}
+summarize(adjusted, wasserstein(mu))
+```
+As we would expect, the 1-Wasserstein distance decreases as the degrees of
+freedom increase. In general, we can compute the _p_-Wasserstein distance
+by passing an extra `p` parameter to `wasserstein()`.
+
+
\ No newline at end of file
diff --git a/vignettes/eight-schools.bib b/vignettes/eight-schools.bib
new file mode 100644
index 0000000..a414cde
--- /dev/null
+++ b/vignettes/eight-schools.bib
@@ -0,0 +1,7 @@
+@book{bda3,
+ address = {London},
+ author = {Andrew Gelman and J.~B.~Carlin and Hal S.~Stern and David B.~Dunson and Aki Vehtari and Donald B.~Rubin},
+ edition = {3rd},
+ publisher = {CRC Press},
+ title = {Bayesian Data Analysis},
+ year = {2013}}
diff --git a/vignettes/eightschools_model.rda b/vignettes/eightschools_model.rda
new file mode 100644
index 0000000..cfc44de
Binary files /dev/null and b/vignettes/eightschools_model.rda differ