diff --git a/docs/source/automated_dr_learner.ipynb b/docs/source/automated_dr_learner.ipynb
new file mode 100644
index 00000000..4c68e400
--- /dev/null
+++ b/docs/source/automated_dr_learner.ipynb
@@ -0,0 +1,809 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Automated doubly robust estimation with ChiRho"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, we import the necessary PyTorch, Pyro, and ChiRho dependencies for this example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import Callable, Optional, Tuple\n",
+ "\n",
+ "import functools\n",
+ "import torch\n",
+ "import math\n",
+ "import seaborn as sns\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import pyro\n",
+ "import pyro.distributions as dist\n",
+ "\n",
+ "from chirho.counterfactual.handlers import MultiWorldCounterfactual\n",
+ "from chirho.indexed.ops import IndexSet, gather\n",
+ "from chirho.interventional.handlers import do\n",
+ "from chirho.observational.handlers.condition import condition\n",
+ "from chirho.observational.handlers.predictive import PredictiveModel\n",
+ "from chirho.robust.handlers.estimators import MonteCarloInfluenceEstimator, one_step_corrected_estimator \n",
+ "\n",
+ "if not pyro.settings.get(\"module_local_params\"):\n",
+ "    pyro.settings.set(module_local_params=True)\n",
+ "\n",
+ "sns.set_style(\"white\")\n",
+ "\n",
+ "pyro.set_rng_seed(321) # for reproducibility"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview: automated robust estimation pipeline\n",
+ "\n",
+ "In this tutorial, we will use ChiRho to estimate the average treatment effect (ATE) from observational data. We will use a simple example to illustrate the basic concepts of doubly robust estimation and how ChiRho can be used to automate the process for more general summaries of interest. \n",
+ "\n",
+ "There are five main steps to our doubly robust estimation procedure, but only the last step differs from a standard probabilistic programming workflow:\n",
+ "1. Write model of interest\n",
+ "   - Define probabilistic model of interest using Pyro\n",
+ "2. Feed in data\n",
+ "   - Observed data used to train the model\n",
+ "3. Run inference\n",
+ "   - Use Pyro's rich inference library to fit the model to the data\n",
+ "4. Define target functional\n",
+ "   - This is the model summary of interest (e.g. average treatment effect)\n",
+ "5. Compute robust estimate\n",
+ "   - Use ChiRho to compute the doubly robust estimate of the target functional\n",
+ "   - Importantly, this step is automated and does not require refitting the model for each new functional\n",
+ "\n",
+ "\n",
+ "Our proposed automated robust inference pipeline is summarized in the figure below.\n",
+ "\n",
+ "![fig1](figures/robust_pipeline.png)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Causal Probabilistic Program\n",
+ "\n",
+ "### Model Description\n",
+ "In this example, we will focus on a canonical model `CausalGLM` consisting of three types of variables: binary treatment (`A`), confounders (`X`), and response (`Y`). For simplicity, we assume that the response is generated from a generalized linear model with link function $g$. The model is described by the following generative process:\n",
+ "\n",
+ "$$\n",
+ "\\begin{align*}\n",
+ "X &\\sim \\text{Normal}(0, I_p) \\\\\n",
+ "A &\\sim \\text{Bernoulli}(\\pi(X)) \\\\\n",
+ "\\mu &= \\beta_0 + \\beta_1^T X + \\tau A \\\\\n",
+ "Y &\\sim \\text{ExponentialFamily}(\\text{mean} = g^{-1}(\\mu))\n",
+ "\\end{align*}\n",
+ "$$\n",
+ "\n",
+ "where $p$ denotes the number of confounders, $\\pi(X)$ is the probability of treatment conditional on confounders $X$, $\\beta_0$ is the intercept, $\\beta_1$ is the confounder effect, and $\\tau$ is the treatment effect."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CausalGLM(pyro.nn.PyroModule):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        p: int,\n",
+ "        link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n",
+ "        prior_scale: Optional[float] = None,\n",
+ "    ):\n",
+ "        super().__init__()\n",
+ "        self.p = p\n",
+ "        self.link_fn = link_fn\n",
+ "        if prior_scale is None:\n",
+ "            self.prior_scale = 1 / math.sqrt(self.p)\n",
+ "        else:\n",
+ "            self.prior_scale = prior_scale\n",
+ "\n",
+ "    @pyro.nn.PyroSample\n",
+ "    def outcome_weights(self):\n",
+ "        return dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1)\n",
+ "\n",
+ "    @pyro.nn.PyroSample\n",
+ "    def intercept(self):\n",
+ "        return dist.Normal(0.0, 1.0)\n",
+ "\n",
+ "    @pyro.nn.PyroSample\n",
+ "    def propensity_weights(self):\n",
+ "        return dist.Normal(0.0, self.prior_scale).expand((self.p,)).to_event(1)\n",
+ "\n",
+ "    @pyro.nn.PyroSample\n",
+ "    def treatment_weight(self):\n",
+ "        return dist.Normal(0.0, 1.0)\n",
+ "\n",
+ "    @property\n",
+ "    def covariate_loc(self):\n",
+ "        return torch.zeros(self.p)\n",
+ "\n",
+ "    @property\n",
+ "    def covariate_scale(self):\n",
+ "        return torch.ones(self.p)\n",
+ "\n",
+ "    def forward(self):\n",
+ "        X = pyro.sample(\"X\", dist.Normal(self.covariate_loc, self.covariate_scale).to_event(1))\n",
+ "        A = pyro.sample(\n",
+ "            \"A\",\n",
+ "            dist.Bernoulli(\n",
+ "                logits=torch.einsum(\"...i,...i->...\", X, self.propensity_weights)\n",
+ "            ),\n",
+ "        )\n",
+ "\n",
+ "        return pyro.sample(\n",
+ "            \"Y\",\n",
+ "            self.link_fn(\n",
+ "                torch.einsum(\"...i,...i->...\", X, self.outcome_weights) + A * self.treatment_weight + self.intercept\n",
+ "            ),\n",
+ "        )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we condition the model on the observed confounders, treatment, and outcome so that its parameters can be fit to data. We will use the following causal probabilistic program to do so:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ConditionedModel(CausalGLM):\n",
+ "\n",
+ "    def forward(self, *, X: torch.Tensor, A: torch.Tensor, Y: torch.Tensor):\n",
+ "        with condition(data={\"X\": X, \"A\": A, \"Y\": Y}):\n",
+ "            self.intercept, self.outcome_weights, self.propensity_weights, self.treatment_weight  # sample global parameters once, outside the plate\n",
+ "            with pyro.plate(\"__train__\", size=X.shape[0], dim=-1):\n",
+ "                return super().forward()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Visualize the model\n",
+ "pyro.render_model(\n",
+ "    lambda: ConditionedModel(p=1)(X=torch.zeros(1, 1), A=torch.zeros(1), Y=torch.zeros(1)),\n",
+ "    render_params=True,\n",
+ "    render_distributions=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generating data\n",
+ "\n",
+ "For evaluation, we generate `N_datasets` datasets, each with `N_train` training samples and `N_test` test samples. We compare vanilla (plug-in) estimates of the target functional with the doubly robust estimates of the target functional across the `N_datasets` datasets. We use a data generating process similar to that in Kennedy (2022)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GroundTruthModel(CausalGLM):\n",
+ "    def __init__(\n",
+ "        self,\n",
+ "        p: int,\n",
+ "        alpha: int,\n",
+ "        beta: int,\n",
+ "        link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),\n",
+ "    ):\n",
+ "        super().__init__(p, link_fn)\n",
+ "        self.alpha = alpha  # sparsity of propensity weights\n",
+ "        self.beta = beta  # sparsity of outcome weights\n",
+ "\n",
+ "    @property\n",
+ "    def outcome_weights(self):\n",
+ "        outcome_weights = 1 / math.sqrt(self.beta) * torch.ones(self.p)\n",
+ "        outcome_weights[self.beta :] = 0.0\n",
+ "        return outcome_weights\n",
+ "\n",
+ "    @property\n",
+ "    def propensity_weights(self):\n",
+ "        propensity_weights = 1 / math.sqrt(self.alpha) * torch.ones(self.p)\n",
+ "        propensity_weights[self.alpha :] = 0.0\n",
+ "        return propensity_weights\n",
+ "\n",
+ "    @property\n",
+ "    def treatment_weight(self):\n",
+ "        return torch.tensor(0.0)\n",
+ "\n",
+ "    @property\n",
+ "    def intercept(self):\n",
+ "        return torch.tensor(0.0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "N_datasets = 100\n",
+ "simulated_datasets = []\n",
+ "\n",
+ "# Data configuration\n",
+ "p = 200\n",
+ "alpha = 50\n",
+ "beta = 50\n",
+ "N_train = 500\n",
+ "N_test = 500\n",
+ "\n",
+ "true_model = GroundTruthModel(p, alpha, beta)\n",
+ "\n",
+ "for _ in range(N_datasets):\n",
+ "    # Generate data\n",
+ "    D_train = pyro.infer.Predictive(\n",
+ "        true_model, num_samples=N_train, return_sites=[\"X\", \"A\", \"Y\"], parallel=True\n",
+ "    )()\n",
+ "    D_test = pyro.infer.Predictive(\n",
+ "        true_model, num_samples=N_test, return_sites=[\"X\", \"A\", \"Y\"], parallel=True\n",
+ "    )()\n",
+ "    simulated_datasets.append((D_train, D_test))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Fit parameters via maximum likelihood"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trained_guides = []\n",
+ "for i in range(N_datasets):\n",
+ "    # Retrieve the training split of the i-th simulated dataset\n",
+ "    D_train = simulated_datasets[i][0]\n",
+ "\n",
+ "    # Fit a point estimate of the model parameters using an AutoDelta guide\n",
+ "    conditioned_model = ConditionedModel(p=D_train[\"X\"].shape[1])\n",
+ "\n",
+ "    guide_train = pyro.infer.autoguide.AutoDelta(conditioned_model)\n",
+ "    elbo = pyro.infer.Trace_ELBO()(conditioned_model, guide_train)\n",
+ "\n",
+ "    # initialize parameters\n",
+ "    elbo(X=D_train[\"X\"], A=D_train[\"A\"], Y=D_train[\"Y\"])\n",
+ "    adam = torch.optim.Adam(elbo.parameters(), lr=0.03)\n",
+ "\n",
+ "    # Do gradient steps\n",
+ "    for _ in range(2000):\n",
+ "        adam.zero_grad()\n",
+ "        loss = elbo(X=D_train[\"X\"], A=D_train[\"A\"], Y=D_train[\"Y\"])\n",
+ "        loss.backward()\n",
+ "        adam.step()\n",
+ "\n",
+ "    trained_guides.append(guide_train)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Causal Query: Average treatment effect (ATE)\n",
+ "\n",
+ "The average treatment effect summarizes, on average, how much the treatment changes the response, $ATE = \\mathbb{E}[Y|do(A=1)] - \\mathbb{E}[Y|do(A=0)]$. The `do` notation indicates that the expectations are taken according to *intervened* versions of the model, with $A$ set to a particular value. Note from our [tutorial](tutorial_i.ipynb) that this is different from conditioning on $A$ in the original `CausalGLM` model, in which $X$ and $A$ are dependent.\n",
+ "\n",
+ "\n",
+ "To implement this query in ChiRho, we define the `ATEFunctional` class, which takes in a `model` (here, a `PredictiveModel` that combines the model and guide) and returns the average treatment effect by simulating from its posterior predictive distribution."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Defining the target functional"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ATEFunctional(torch.nn.Module):\n",
+ "    def __init__(self, model: Callable, *, num_monte_carlo: int = 100):\n",
+ "        super().__init__()\n",
+ "        self.model = model\n",
+ "        self.num_monte_carlo = num_monte_carlo\n",
+ "\n",
+ "    def forward(self, *args, **kwargs):\n",
+ "        with MultiWorldCounterfactual():\n",
+ "            with pyro.plate(\"monte_carlo_functional\", size=self.num_monte_carlo, dim=-2):\n",
+ "                with do(actions=dict(A=(torch.tensor(0.0), torch.tensor(1.0)))):\n",
+ "                    Ys = self.model(*args, **kwargs)\n",
+ "                Y0 = gather(Ys, IndexSet(A={1}), event_dim=0)  # world with A set to 0\n",
+ "                Y1 = gather(Ys, IndexSet(A={2}), event_dim=0)  # world with A set to 1\n",
+ "        ate = (Y1 - Y0).mean(dim=-2, keepdim=True).mean(dim=-1, keepdim=True).squeeze()\n",
+ "        return pyro.deterministic(\"ATE\", ate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Closed form doubly robust correction\n",
+ "\n",
+ "For the average treatment effect functional, there exists a closed-form analytical formula for the doubly robust correction. This formula is derived in Kennedy (2022) and is implemented below:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Closed form expression\n",
+ "def closed_form_doubly_robust_ate_correction(X_test, theta) -> Tuple[torch.Tensor, torch.Tensor]:\n",
+ "    X = X_test[\"X\"]\n",
+ "    A = X_test[\"A\"]\n",
+ "    Y = X_test[\"Y\"]\n",
+ "    pi_X = torch.sigmoid(X.mv(theta[\"propensity_weights\"]))  # estimated propensity score\n",
+ "    mu_X = (\n",
+ "        X.mv(theta[\"outcome_weights\"])\n",
+ "        + A * theta[\"treatment_weight\"]\n",
+ "        + theta[\"intercept\"]\n",
+ "    )  # estimated outcome mean at the observed treatment\n",
+ "    analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)\n",
+ "    analytic_correction = analytic_eif_at_test_pts.mean()\n",
+ "    return analytic_correction, analytic_eif_at_test_pts"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Computing automated doubly robust correction via Monte Carlo\n",
+ "\n",
+ "While the doubly robust correction term is known in closed form for the average treatment effect functional, the `one_step_corrected_estimator` function in ChiRho works for a wide class of other functionals. We focus on the average treatment effect functional here so that we have a ground truth to compare `one_step_corrected_estimator` against."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/eli/development/chirho/chirho/robust/handlers/estimators.py:72: UserWarning: Calling influence_fn with torch.grad enabled can lead to memory leaks. Please use torch.no_grad() to avoid this issue. See example in the docstring.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Compute doubly robust ATE estimates using both the automated and closed form expressions\n",
+ "plug_in_ates = []\n",
+ "analytic_corrections = []\n",
+ "automated_monte_carlo_corrections = []\n",
+ "for i in range(N_datasets):\n",
+ "    trained_guide = trained_guides[i]\n",
+ "    D_test = simulated_datasets[i][1]\n",
+ "    functional = functools.partial(ATEFunctional, num_monte_carlo=10000)\n",
+ "    ate_plug_in = functional(\n",
+ "        PredictiveModel(CausalGLM(p), trained_guide)\n",
+ "    )()\n",
+ "    analytic_correction, analytic_eif_at_test_pts = closed_form_doubly_robust_ate_correction(D_test, trained_guide(**D_test))\n",
+ "    with MonteCarloInfluenceEstimator(num_samples_outer=max(10000, 100 * p), num_samples_inner=1):\n",
+ "        automated_monte_carlo_correction = one_step_corrected_estimator(functional, D_test)(\n",
+ "            PredictiveModel(CausalGLM(p), trained_guide)\n",
+ "        )()\n",
+ "\n",
+ "    plug_in_ates.append(ate_plug_in.detach().item())\n",
+ "    analytic_corrections.append(ate_plug_in.detach().item() + analytic_correction.detach().item())\n",
+ "    automated_monte_carlo_corrections.append(automated_monte_carlo_correction.detach().item())\n",
+ "\n",
+ "plug_in_ates = np.array(plug_in_ates)\n",
+ "analytic_corrections = np.array(analytic_corrections)\n",
+ "automated_monte_carlo_corrections = np.array(automated_monte_carlo_corrections)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "results = pd.DataFrame(\n",
+ "    {\n",
+ "        \"plug_in_ate\": plug_in_ates,\n",
+ "        \"analytic_correction\": analytic_corrections,\n",
+ "        \"automated_monte_carlo_correction\": automated_monte_carlo_corrections,\n",
+ "    }\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ "    <tr style=\"text-align: right;\">\n",
+ "      <th></th>\n",
+ "      <th>plug_in_ate</th>\n",
+ "      <th>analytic_correction</th>\n",
+ "      <th>automated_monte_carlo_correction</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <th>count</th>\n",
+ "      <td>100.00</td>\n",
+ "      <td>100.00</td>\n",
+ "      <td>100.00</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>mean</th>\n",
+ "      <td>0.31</td>\n",
+ "      <td>0.20</td>\n",
+ "      <td>0.20</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>std</th>\n",
+ "      <td>0.11</td>\n",
+ "      <td>0.11</td>\n",
+ "      <td>0.11</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>min</th>\n",
+ "      <td>-0.01</td>\n",
+ "      <td>-0.07</td>\n",
+ "      <td>-0.08</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>25%</th>\n",
+ "      <td>0.24</td>\n",
+ "      <td>0.13</td>\n",
+ "      <td>0.14</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>50%</th>\n",
+ "      <td>0.32</td>\n",
+ "      <td>0.20</td>\n",
+ "      <td>0.21</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>75%</th>\n",
+ "      <td>0.37</td>\n",
+ "      <td>0.27</td>\n",
+ "      <td>0.28</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <th>max</th>\n",
+ "      <td>0.57</td>\n",
+ "      <td>0.44</td>\n",
+ "      <td>0.46</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ "       plug_in_ate  analytic_correction  automated_monte_carlo_correction\n",
+ "count       100.00               100.00                            100.00\n",
+ "mean          0.31                 0.20                              0.20\n",
+ "std           0.11                 0.11                              0.11\n",
+ "min          -0.01                -0.07                             -0.08\n",
+ "25%           0.24                 0.13                              0.14\n",
+ "50%           0.32                 0.20                              0.21\n",
+ "75%           0.37                 0.27                              0.28\n",
+ "max           0.57                 0.44                              0.46"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The true treatment effect is 0, so a mean estimate closer to zero is better\n",
+ "results.describe().round(2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Text(0.5, 0, 'ATE Estimate')"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "