From 3eecc4c2300e0ba36f1df0386a32d42e4b26382c Mon Sep 17 00:00:00 2001 From: PoorvaGarg Date: Mon, 26 Aug 2024 14:27:23 -0400 Subject: [PATCH] small edits --- docs/source/explainable_sir.ipynb | 66 ++++++++-------- docs/source/inference.ipynb | 120 ++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 33 deletions(-) create mode 100644 docs/source/inference.ipynb diff --git a/docs/source/explainable_sir.ipynb b/docs/source/explainable_sir.ipynb index d7063349..e94c6102 100644 --- a/docs/source/explainable_sir.ipynb +++ b/docs/source/explainable_sir.ipynb @@ -11,7 +11,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The **Explainable Reasoning with Chirho** package aims to provide a unified, principled approach to computations of causal explanations. We showed in an earlier [tutorial](https://basisresearch.github.io/chirho/explainable_categorical.html) how Chirho provides a handler `SearchForExplanation` to carry out the program transformations needed to compute causal queries and explanations, focusing on on discrete variables (we assume the reader is familar with it). In this notebook we illustrate the usage of `SearchForExplanation` for causal models with continuous random variables in the context of a dynamical system.\n", + "The **Explainable Reasoning with Chirho** package aims to provide a unified, principled approach to computations of causal explanations. We showed in an earlier [tutorial](https://basisresearch.github.io/chirho/explainable_categorical.html) how Chirho provides a handler `SearchForExplanation` to carry out the program transformations needed to compute causal queries and explanations, focusing on discrete variables (we assume the reader is familar with it). In this notebook we illustrate the usage of `SearchForExplanation` for causal models with continuous random variables in the context of a dynamical system.\n", "\n", "We take an epidemiological dynamical system model (described in more detail in this [tutorial](https://basisresearch.github.io/chirho/dynamical_intro.html)) and show how the but-for analysis is not sufficiently fine-grained to allow us to derive the right conclusions about effects of different policies during a pandemic. Next, we illustrate how various causal explanation queries can be computed using `SearchForExplanation` and inference algorithms. We also demonstrate how more detailed causal queries can be answered by post-processing the samples obtained using the handler. " ] @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -118,7 +118,7 @@ "\n", "This quantity is of interest because epidemic mitigation policies often have multiple goals that need to be balanced. One goal is to increase `S_final`, i.e., to limit the total number of infected individuals. Another goal is to limit the number of infected individuals at the peak of the epidemic to avoid overwhelming the healthcare system. A further goal is to minimize the proportion of the population that becomes infected after the peak, that is, the overshoot, to reduce healthcare and economic burdens. Balancing these objectives involves making trade-offs.\n", "\n", - " Suppose we are working under constraint that the overshoot show be lower than 20% of the population, and we implement two policies, lockdown and masking, which together seem to lead to the overshoot being too high. In fact, only one of them is responsible, and we are interested in being able to identify which one. " + " Suppose we are working under constraint that the overshoot should be lower than 20% of the population, and we implement two policies, lockdown and masking, which together seem to lead to the overshoot being too high. In fact, only one of them is responsible, and we are interested in being able to identify which one. " ] }, { @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -170,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -232,12 +232,12 @@ "source": [ "\n", "\n", - "Now suppose we are uncertain about $\\beta, \\gamma$, and want to construct a Bayesian SIR model that incorporates this uncertainty. Say we inducing $\\beta$ to be drawn from `Beta(18, 600)`, and $\\gamma$ to be drawn from distribution `Beta(1600, 1600)`. " + "Now suppose we are uncertain about $\\beta, \\gamma$, and want to construct a Bayesian SIR model that incorporates this uncertainty. Say we induce $\\beta$ to be drawn from the distribution `Beta(18, 600)`, and $\\gamma$ to be drawn from distribution `Beta(1600, 1600)`. " ] }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -273,13 +273,13 @@ "metadata": {}, "source": [ "\n", - "Now we incorporate the Bayesian SIR model into a larger model that includes the effect of two different policies, lockdown and masking, where each can be implemented with $50\\%$ probability (these probabilities won't really matter, as we will be intervening on these, the sampling is mainly used to register the parameters with Pyro). We encode their efficiencies which further affect the model. Crucially, these efficiencies interact in a fashion resembling the structure of the stone-throwing example we discussed in the tutorial on categorical variables. If lockdown is present, this limits the impact of masking as agents interact less and so masks have fewer opportunities to block anything. We assume the situation is assymetric: masking has no impact on the efficiency of lockdown. The model also computes `overshoot` and `os_too_high` for further analysis.\n", + "Now we incorporate the Bayesian SIR model into a larger model that includes the effect of two different policies, lockdown and masking, where each can be implemented with $50\\%$ probability (these probabilities won't really matter, as we will be intervening on these, the sampling is mainly used to register the parameters with Pyro). We encode their efficiencies which further affect the model. Crucially, these efficiencies interact in a fashion resembling the structure of the stone-throwing example we discussed in the tutorial on categorical variables. If lockdown is present, this limits the impact of masking as agents interact less and so masks have fewer opportunities to block anything. We assume the situation is asymmetric: masking has no impact on the efficiency of lockdown. The model also computes `overshoot` and `os_too_high` for further analysis.\n", "\n" ] }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -379,12 +379,12 @@ "3. Only masking was imposed\n", "4. Only lockdown was imposed\n", "\n", - "The hope is that by looking at these we will be able to indentify the culprit. We create these four models by conditioning on the policies being imposed as required (in fact, this has the same effect as intervening here, as the sites are upstream from the model). The models obtained are similar to the intervened models since the variables `lockdown` and `mask` do not have any variables upstream to them. In principle we could emulate 1-4 using `do` with the same estimates. For the sake of completeness, we also illustrate the consequences of deciding randomly about the policies." + "The hope is that by looking at these we will be able to indentify the culprit. We create these four models by conditioning on the policies being imposed as required (in fact, this has the same effect as intervening here, as the sites are upstream from the model). In principle we could emulate 1-4 using `do` with the same estimates. For the sake of completeness, we also illustrate the consequences of deciding randomly about the policies." ] }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -431,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -564,11 +564,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The plots above show what happens in the four different scenarios. We observe that in the model where none of the policies were imposed, ther probability of the overshoot being too high is relatively low, $0.24$. On the other hand, when both policies were imposed, the probability of the overshoot being to high was relatively higher $0.81$. \n", + "The plots above show what happens in the four different scenarios. We observe that in the model where none of the policies were imposed, the probability of the overshoot being too high is relatively low, $0.24$. On the other hand, when both policies were imposed, the probability of the overshoot being to high was relatively higher $0.81$. \n", "\n", - "To identify which of `lockdown` and `mask` is the cause, we analyze the models where only one of the policies were imposed. In both cases, the probability of too high overshoot seems to be even higher - $0.96$ and $0.9$. Interestingly, the effect of the interventions is somewhat nuanced. Implementing both increases the risk of overshoot as compared to the no intervention model. But individual intereventions would have even worse consequences, which means that the two interventions while jointly increasing the risk to some extent mitigate each other's contribution to that risk as well.\n", + "To identify which of `lockdown` and `mask` is the cause, we analyze the models where only one of the policies were imposed. In both cases, the probability of too high overshoot seems to be even higher - $0.96$ and $0.9$. Interestingly, the effect of the interventions is somewhat nuanced. Implementing both increases the risk of overshoot as compared to the no intervention model. But individual interventions would have even worse consequences, which means that the two interventions while jointly increasing the risk to some extent mitigate each other's contribution to that risk as well.\n", "\n", - "Crucially, the analysis does not allow us to distinghuish the intuitive role that the lockdown played, as opposed to masking (whose impact has been limited by the presence of lockdown). So, we need of a more fine-grained analysis where we not only control the variables being intervened on (that is, the policies), but also pay attention to what context we are in. We achieve that level of sensitivity by stochastically keeping part of the context (that is, other variables in the model) fixed (see the tutorial for categorical variables for a more extensive explanation of this method and simpler examples). The key idea is that starting with the scenario in which both interventions have been implemented, there is a context such that if we keep it fixed, removing lockdown would significantly lower the overshoot, but there is no context that we could keep fixed such that if in that context we remove the masking policy, the overshoot would decrease. In the next section, we show how this analysis can be carried out with the help of `SearchForExplanation`." + "Crucially, the analysis does not allow us to distinghuish the intuitive role that the lockdown played, as opposed to masking (whose impact has been limited by the presence of lockdown). So, we need a more fine-grained analysis where we not only control the variables being intervened on (that is, the policies), but also pay attention to what context we are in. We achieve that level of sensitivity by stochastically keeping part of the context (that is, other variables in the model) fixed (see the tutorial for categorical variables for a more extensive explanation of this method and simpler examples). The key idea is that starting with the scenario in which both interventions have been implemented, there is a context such that if we keep it fixed, removing lockdown would significantly lower the overshoot, but there is no context that we could keep fixed such that if in that context we remove the masking policy, the overshoot would decrease. In the next section, we show how this analysis can be carried out with the help of `SearchForExplanation`." ] }, { @@ -589,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -651,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -692,12 +692,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The above probability itself is not directly related to our query. It is the probability that the overshoot is too high in the antecedents-intervened workd and not too high in the alterantives-intervened world, where antecedent interventions are preempted with probabilities $0.5$ at each site, and witnesses are kept fixed at the observed values with probability $0.5+0.2$ at each site. But more fine-grained queries can be answered using the 10000 samples we have drawn in the process. We first compute the probabilities that different sets of antecedent candidates have causal effect over `os_too_high`." + "The above probability itself is not directly related to our query. It is the probability that the overshoot is too high in the antecedents-intervened workd and not too high in the alternatives-intervened world, where antecedent interventions are preempted with probabilities $0.5$ at each site, and witnesses are kept fixed at the observed values with probability $0.5+0.2$ at each site. But more fine-grained queries can be answered using the 10000 samples we have drawn in the process. We first compute the probabilities that different sets of antecedent candidates have causal effect over `os_too_high`." ] }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -723,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 39, "metadata": {}, "outputs": [ { @@ -771,7 +771,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Note that one could also compute above queries by giving specific parameters to `SearchForExplanation` instead of subselecting the samples, as we did in the tutorial for explainable module for models with categorical variables. Here, however, we illustrate that running a sufficiently general query ones produces samples that can be used to answer multiple different questions.\n", + "Note that one could also compute above queries by giving specific parameters to `SearchForExplanation` instead of subselecting the samples, as we did in the tutorial for explainable module for models with categorical variables. Here, however, we illustrate that running a sufficiently general query once produces samples that can be used to answer multiple different questions.\n", "\n", "Also, we use the log probabilities above to identify whether a particular combination of intervening nodes and context nodes have causal power or not, which is made possible by the fact that our handler adds appropriate log probabilities to the trace (see the previous tutorial and documentation for more explanation). One can also obtain these results by explictly analyzing the sample trace as we do in the next section." ] @@ -785,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 40, "metadata": {}, "outputs": [ { @@ -834,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -876,7 +876,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -907,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 43, "metadata": {}, "outputs": [ { @@ -992,7 +992,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +1023,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 45, "metadata": {}, "outputs": [ { @@ -1108,7 +1108,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -1201,7 +1201,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 48, "metadata": {}, "outputs": [ { @@ -1306,7 +1306,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -1336,7 +1336,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 50, "metadata": {}, "outputs": [ { diff --git a/docs/source/inference.ipynb b/docs/source/inference.ipynb new file mode 100644 index 00000000..d4b99e30 --- /dev/null +++ b/docs/source/inference.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional, Callable\n", + "import math\n", + "\n", + "import pyro.distributions as dist\n", + "import torch\n", + "\n", + "import pyro\n", + "from chirho.counterfactual.handlers.counterfactual import \\\n", + " MultiWorldCounterfactual\n", + "from chirho.explainable.handlers import SearchForExplanation\n", + "from chirho.explainable.handlers.components import ExtractSupports\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "def model():\n", + " a = pyro.sample(\"a\", dist.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)))\n", + " b = pyro.sample(\"b\", dist.Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0)))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "with ExtractSupports() as s:\n", + " model()\n", + "\n", + "query = SearchForExplanation(\n", + " supports=s.supports,\n", + " alternatives={\"a\": torch.tensor(0.5)},\n", + " antecedents={\"a\": torch.tensor(-0.5)},\n", + " antecedent_bias=0.0,\n", + " witnesses={},\n", + " consequents={\"b\": torch.tensor(0.0)},\n", + " consequent_scale=1e-8,\n", + " )(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How can I compute the probability that `a=-0.5` is a sufficienct and necessary cause of `b=0` using `SearchForExplanation`?" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def importance_infer(\n", + " model: Optional[Callable] = None, *, num_samples: int\n", + "):\n", + " \n", + " if model is None:\n", + " return lambda m: importance_infer(m, num_samples=num_samples)\n", + "\n", + " def _wrapped_model(\n", + " *args,\n", + " **kwargs\n", + " ):\n", + "\n", + " guide = pyro.poutine.block(hide_fn=lambda msg: msg[\"is_observed\"])(model)\n", + "\n", + " max_plate_nesting = 9 # TODO guess\n", + "\n", + " with pyro.poutine.block(), MultiWorldCounterfactual() as mwc_imp:\n", + " log_weights, importance_tr, _ = pyro.infer.importance.vectorized_importance_weights(\n", + " model,\n", + " guide,\n", + " *args,\n", + " num_samples=num_samples,\n", + " max_plate_nesting=max_plate_nesting,\n", + " normalized=False,\n", + " **kwargs\n", + " )\n", + "\n", + " return torch.logsumexp(log_weights, dim=0) - math.log(num_samples), importance_tr, mwc_imp, log_weights\n", + "\n", + " return _wrapped_model" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}