From ed95ddde633a021306b14e257e1a6049b4f94293 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Mon, 7 Aug 2023 22:18:08 +0200 Subject: [PATCH 01/13] refactored the actual causality class --- docs/source/actual_causality.ipynb | 1256 ++++++++++++++++++++++++++++ 1 file changed, 1256 insertions(+) create mode 100644 docs/source/actual_causality.ipynb diff --git a/docs/source/actual_causality.ipynb b/docs/source/actual_causality.ipynb new file mode 100644 index 00000000..80601354 --- /dev/null +++ b/docs/source/actual_causality.ipynb @@ -0,0 +1,1256 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Actual Causality: the modified Halpern-Pearl definition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Summary**\n", + "\n", + "Here we show how the tools made available within Causal Pyro TODO: CHANGE NAME(?) can be used to implement the notion of actual causality developed by Halpern and Pearl (see J. Halpern, *Actual Causality*, 2016), and illustrate its workings by replicating a few key examples from the book." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Outline**\n", + "\n", + "[Intuitions](##intuitions)\n", + " \n", + "[Formalization](#formalization)\n", + "\n", + "- [Structural causal models](#structural-causal-models)\n", + "\n", + "- [Halpern-Pearl modified definition of actual causality](#halpern-pearl-modified-definition-of-actual-causality)\n", + "\n", + "[Implementation](#implementation)\n", + "\n", + "[Examples](#examples)\n", + "\n", + "- [Comments on example selection](#comments-on-example-selection)\n", + " \n", + "- [Stone-throwing](#stone-throwing)\n", + "\n", + "- [Forest fire](#forest-fire)\n", + "\n", + "- [Doctors](#doctors)\n", + "\n", + "- [Friendly fire](#friendly-fire)\n", + "\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intuitions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Actual causality (sometimes called **token causality** or **specific causality**) is usually contrasted with type causality (sometimes called **general causality**). While the latter is concerned with general statements (such as \"smoking causes cancer\"), actual causality focuses on particular events. For illustration, consider the following causality-related questions:\n", + "\n", + "- **Friendly Fire**: On March 24, 2002, A B-52 bomber fired a Joint Direct Attack Munition at a US battalion command post, killing three and injuring twenty special forces soldiers. Out of multiple potential contributing factors, which were **actually** responsible for the incident?\n", + " \n", + "- **Schizophrenia** : The disease arises from the interaction between multiple genetic and environmental factors. Given a particular patient and what we know about them, which of these factors **actually** caused her state?\n", + " \n", + "- **Explainable AI**: Your loan application has been refused. The bank representative informs you the decision was made using predictive modeling to estimate the probability of default. They give you a list of various factors considered in the prediction. But which of these factors **actually** resulted in the rejection, and what were their contributions?\n", + " \n", + "These are questions about **actual causality**. While having answers to such questions is not directly useful for prediction tasks, they are useful for understanding how we can prevent undesirable outcomes similar to ones that we have observed or promote the occurrence of desirable outcomes in contexts similar to the ones in which they had been observed. These context-sensitive causality questions are also an essential element of blame and responsibility assignments, and of at least one prominent account of the notion of explanation (all of which will be explored in other notebooks). TODO add links\n", + "\n", + "The general intuition behind the notion of actual causality that we will focus on is that a certain state of antecedent nodes is the cause of a given state of the consequent nodes if there is a part of the actual reality such that if it is kept fixed at what it actually is, and we intervened on the antecedent nodes to be in a different state, the consequent nodes would no longer be in the observed states. A proper explication of this notion requires the context of structural causal models - we first explain what these are, and then move on to the definition." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Formalization " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Structural causal models" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "While statistical information might help address questions of actual causality, is not sufficient. One requires causal theories that explain how the relevant aspects of the world function, as well as information about the actual facts pertaining to the specific case. For this reason, the notion on which we focus in this notebook is formulated within the framework of structural causal models, which can represent such information.\n", + "\n", + "The notion is defined in the context of a deterministic structural causal model (SCMs). One major component thereof is a selection of **variables**. For instance, in a very simple model for a forest-fire problem, we might consider a model with three endogenous binary variables: $FF$ (forest fire), $L$ (lightning), and $MD$ (match dropped) whose values are determined by the values of other variables, and two exogenous noise variables $U_{MD}$ and $U_L$ that determine the values of $MD$ and $L$. Moreover, some of those variables/nodes are connected by means of directed **edges**. For instance, in the example at hand, the model contains two edges that go from $U_MD$ to $MD$ and from $U_L$ to $L$ respectively, and two edges that go from $L$ to $FF$ and from $MD$ to $FF$. Each influence is associated with a **structural equation** - for instance, $FF = max(L, MD)$ indicates that a forest fire occurs if either of the two factors occurs. SCMs come also with a **context**, which is the values of **exogenous variables** whose values are not determined by the structural equations, but rather by factors outside the model. In our example, one context might be that both a match has been dropped and a lightning occurred.\n", + "\n", + "More formally, a causal model $M$ is a tuple $\\langle S, F\\rangle$, where:\n", + "\n", + "- $S$ is a **signature**, that is a tuple $\\langle U, V, R\\rangle$, where $U$ is a set of exogenous variables, $V$ is a set of endogenous variables and $R: U \\cup V \\mapsto R(Y)$, where $R(Y)\\neq \\emptyset$, that is $R$ assigns non-empty ranges to exogenous and endogenous variables.\n", + "\n", + "- To each endogenous $X\\in V$, $F$ assigns a function $F_X$, which maps the cross-product of ranges of all variables other than $X$ to $R(X)$. In other words, $F_X$ determines the value of $X$ given the values of other variables in the model (some of them might be redundant in a given equation). The intuition is that these functions correspond to structural equations of the form $X = F_X(U, V)$ which are to be read from right to left: if the values of $U\\cup V$ are fixed to be such-and-such, say $\\vec{u}$ and $\\vec{v}$, this causes $X$ to take the value $F_X(\\vec{u}, \\vec{v})$.\n", + "\n", + "A **deterministic causal model** (also called **causal setting**), $\\langle M, \\vec{u}\\rangle$ is a causal model $M$ together with fixed settings $\\vec{u}$ of its exogenous variables $U$. To intervene, say, to make $Y$ have value $y$, is to replace the structural equation for $Y$ of the form $Y = F_Y(U, V)$ with $Y = y$. $\\langle M, \\vec{u}\\rangle \\models [Y \\leftarrow y](X = x)$ means: in the deterministic model obtained from $\\langle M, \\vec{u}\\rangle$ by intervening on $Y$ to have value $y$ $X$ has value $x$. Sometimes, instead of $X = x$, one might be interested in a more general claim $\\varphi$ involving potentially multiple variables, in which case the notation is $\\langle M, \\vec{u}\\rangle \\models [Y \\leftarrow y](\\varphi)$. " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Halpern-Pearl modified definition of actual causality" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is important to recognize that the straightforward counterfactual strategy, which asks whether the event would have occurred if the antecedent had not taken place, is inadequate as a definition of actual causality. A simple example can help illustrate this point. Suppose I throw a stone, which hits and shatters a bottle. However, just a second later, Bill also throws a stone at the bottle but misses, solely because the bottle was already shattered by my stone. In this scenario, the intuition is that my throw is the cause of the bottle shattering, even though the bottle would still have shattered if I hadn't thrown the stone. \n", + "This highlights the need for a more elaborate account that considers the actual state, taking into consideration the fact that Bill's stone did not, in fact, hit the bottle. One such account involves the following definition of actual causality:\n", + "\n", + "Given an SCM $M$ and a vector of its exogenous variable settings $\\vec{u}$ we'll write $(M, \\vec{u})\\models [ \\vec{Y} \\leftarrow \\vec{y}]\\psi$ just in case $\\psi$ holds in $(M',\\vec{u})$, where $M'$ is the intervened model obtained by replacing the structural equation(s) for $\\vec{Y}$ in $M$ with $\\vec{Y_i} = \\vec{y_i}$. \n", + "\n", + "We say that $\\vec{X}=\\vec{x}$ is an actual cause of $\\varphi$ in $(M,\\vec{u})$ just in case:\n", + "\n", + "AC1. Factivity: $(M, \\vec{u}) \\models [\\vec{X} = \\vec{x} \\wedge \\varphi]$\n", + "\n", + "AC2. Necessity:\n", + "\n", + "$\\exists \\vec{W}, \\vec{x}'(M, \\vec{u})\\models [\\vec{X} \\leftarrow \\vec{x}', \\vec{W} = \\vec{w}^{\\star}] \\neg \\varphi$,\n", + "where $\\vec{w}^\\star$ are the actual values of $\\vec{W}$, i.e. $(M, \\vec{u}) \\models \\vec{W} = \\vec{w}^\\star$\n", + "\n", + "AC3. Minimality: $\\vec{X}$ is a subset-minimal set of potential causes satisfying AC2\n", + "\n", + "AC1 requires that both the antecedent and the consequent hold. The intuition behind AC2 is that for $\\vec{X}=\\vec{x}$ to be the actual cause of $\\varphi$, there needs to be a vector of witness nodes $\\vec{W}$ and a vector $\\vec{x'}$ of *alternative* settings of $\\vec{X}$ such that if $\\vec{W}$ are intervened to have their actual values $\\vec{w^\\star}$, and $\\vec{X}$ are intervened to have values $\\vec{x'}$, $\\varphi$ no longer holds in the resulting model. AC3 requires that the antecedent should be a minimal one satisfying AC2." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implementation\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "import numpy as np\n", + "from itertools import combinations\n", + "\n", + "\n", + "import torch\n", + "from typing import Dict, List, Optional, Union, Callable, Any\n", + "\n", + "import pandas as pd\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "\n", + "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", + "from chirho.interventional.handlers import do\n", + "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here and in later notebooks, instead of full enumeration, we will be approximating the answers with sampling. In particular, answering an actual causality query requires investigating the consequences of intervening on all possible witness candidate nodes in all possible combinations to have the values they actually have in a given model. While complete enumeration would work for smaller models, we implement a more general approximate method, which draws random sets of witness nodes multiple times. For smaller models (as the one used in our examples), complete coverage of all possible combinations is easily obtained. For larger models complete enumeration becomes less feasible.\n", + "\n", + "An SCM in this context is represented by a Pyro model, where the exogenous variables are stochastic and introduced using `pyro.sample`, and all the endogenous variables are determined by these, and introduced by `pyro.deterministic` (read on for examples). For simplicity we also assume the antecedent nodes are binary (this assumption can be weakened to them being discrete), and that the consequent nodes are discrete. \n", + "\n", + "The key role in this implementation is played by the preemption handler, which is used to randomly select some of the witness candidate nodes and preempt them - in our case, with their observed values.\n", + "\n", + "The key moves are in our `def __call__`:\n", + "\n", + "- `pyro.plate` is a messenger used to construct conditionally independent sequences of variables, we use it to obtain multiple independent samples.\n", + "\n", + "- `pyro.condition` is used to constrain the values of some variables. We use it to fix the values of the exogenous variables. As they are causally upstream from all the other variables, this also fixes the values of the endogenous variables.\n", + " \n", + "- `Preemptions` is a messenger that in post-processing fixes a random selection of (witness) nodes to the results of certain actions - in our case, these actions correspond to the actual values of the variables.\n", + " \n", + "- `do` is a handler that intervenes on selected variables to have the specified values. We use it to fix the *alternative* combination of values for the antecedent nodes.\n", + " \n", + "- `MutliWorldCounterfactual` allows us to keep track of all the scenarios, keeping split records thereof as values at various sites.\n", + "\n", + "- `pyro.poutine.trace` is a messenger that keeps track of sites visited for downstream use.\n", + "\n", + "- within the model, we compare the observed and the intervened consequent and record whether these differ.\n", + " \n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class BiasedPreemptions(Preemptions):\n", + " \"\"\"\n", + " Counterfactual handler that preempts the model with a biased coin flip.\n", + " \"\"\"\n", + " def __init__(self, actions, weights: torch.Tensor) -> None:\n", + " self.weights = weights\n", + " super().__init__(actions)\n", + "\n", + " def _pyro_preempt(self,msg: Dict[str, Any]) -> None:\n", + " if msg[\"name\"] not in self.actions:\n", + " return\n", + " obs, acts, case = msg[\"args\"]\n", + " msg[\"kwargs\"][\"name\"] = f\"__split_{msg['name']}\"\n", + " case_dist = pyro.distributions.Categorical(self.weights)\n", + " case = pyro.sample(msg[\"kwargs\"][\"name\"], case_dist, obs=case)\n", + " msg[\"args\"] = (obs, acts, case)\n", + " msg[\"stop\"] = True" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "class HalpernPearlModifiedApproximate:\n", + "\n", + " def __init__(\n", + " self, \n", + " model: Callable,\n", + " antecedents: Dict[str, torch.Tensor],\n", + " outcome: str,\n", + " witness_candidates: List[str],\n", + " observations: Optional[Dict[str, torch.Tensor]],\n", + " sample_size: int = 100,\n", + " event_dim: int = 0\n", + " ):\n", + " \n", + " self.model = model\n", + " # rename to counterfactual antecedents\n", + " self.antecedents = antecedents\n", + " self.outcome = outcome\n", + " self.witness_candidates = witness_candidates\n", + " self.observations = observations\n", + " self.sample_size = sample_size\n", + "\n", + " self.antecedent_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", + " antecedents = [antecedent]) for\n", + " antecedent in self.antecedents.keys()}\n", + " \n", + " self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,\n", + " antecedents = self.antecedents) for \n", + " candidate in self.witness_candidates}\n", + " \n", + " @staticmethod \n", + " def preempt_with_factual(value: torch.Tensor, *,\n", + " antecedents: List[str] = None, event_dim: int = 0):\n", + " \n", + " if antecedents is None:\n", + " antecedents = []\n", + "\n", + " antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]\n", + "\n", + " factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),\n", + " event_dim=event_dim)\n", + " \n", + " return scatter({\n", + " IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,\n", + " IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,\n", + " }, event_dim=event_dim)\n", + " \n", + " \n", + " def __call__(self, *args, **kwargs):\n", + " with MultiWorldCounterfactual():\n", + " with do(actions=self.antecedents):\n", + " # the last element of the tensor is the factual case (preempted)\n", + " with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.1, .9])):\n", + " with Preemptions(actions = self.witness_preemptions):\n", + " with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", + " self.consequent = self.model()[self.outcome]\n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))\n", + " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", + " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + " \n", + " # self.trace = trace.trace\n", + "\n", + " # # slightly hacky solution for odd witness candidate sets\n", + " # if isinstance(self.consequent_differs.squeeze().tolist(), bool):\n", + " # self.existential_but_for = self.consequent_differs.squeeze()\n", + " # else:\n", + " # #if (len(self.consequent_differs.squeeze().tolist() )>1):\n", + " # self.existential_but_for = any(self.consequent_differs.squeeze().tolist() ) \n", + "\n", + " \n", + "\n", + " # witness_dict = dict()\n", + " # if self.witness_candidates:\n", + " # witness_keys = [\"__split_\" + candidate for candidate in self.witness_candidates]\n", + " # witness_dict = {key: self.trace.nodes[key]['value'] for key in witness_keys}\n", + " \n", + " # witness_dict['observed'] = self.observed_consequent.squeeze()\n", + " # witness_dict['intervened'] = self.intervened_consequent.squeeze()\n", + " # witness_dict['consequent_differs'] = self.consequent_differs.squeeze()\n", + "\n", + " # # slightly hacky as above\n", + " # self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As the definition of actual causality requires a minimality check, we implement it on top of the existential but-for test:" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comments on example selection\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the sake of illustration, we reconstruct a few examples, which-with one exception (friendly fire incident)-come from Halpern's book. The selection is as follows:\n", + "\n", + "- **Stone throwing:** this is a classic, simple structure in which the but-for clause fails due to over-determination, but an actual causality claim holds.\n", + "\n", + "- **Forest fire:** one of the simplest structures illustrating conjunctions being actual causes, and how an event can be part of an actual cause without being an actual cause itself.\n", + "\n", + "- **Doctors:** a simple example illustrating the intransitivity of actual causality.\n", + "\n", + "- **Friendly fire incident:** a real-life example, to illustrate how the tools can be applied outside of a narrow selection of thought experiments.\n", + "\n", + "- **Voting:** this illustrates how on this approach a voter is only an actual cause if they can make a difference, but only part of an actual cause otherwise, which motivates reflection on responsibility and blame, to which we will come back in other notebooks #dTODO add links" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stone-throwing" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sally and Billy pick up stones and throw them at a bottle. Sally's stone gets there first, shattering the bottle. Both throws are perfectly accurate, so Billy's stone would have shattered the bottle had it not been preempted by Sally’s throw. (see *Actual Causality*, p. 3 and multiple further points at which the example is discussed in the book)." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def stones_model(): \n", + " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", + " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", + " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", + " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", + "\n", + "\n", + " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", + " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", + "\n", + " new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)\n", + "\n", + " sally_hits = pyro.sample(\"sally_hits\",dist.Bernoulli(new_shp))\n", + "\n", + " new_bhp = torch.where(\n", + " (\n", + " bill_throws.bool()\n", + " & (~sally_hits.bool())\n", + " )\n", + " == 1,\n", + " prob_bill_hits,\n", + " torch.tensor(0.0),\n", + " )\n", + "\n", + "\n", + " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(new_bhp))\n", + "\n", + " new_bsp = torch.where(\n", + " bill_hits.bool() == 1,\n", + " prob_bottle_shatters_if_bill,\n", + " torch.where(\n", + " sally_hits.bool() == 1,\n", + " prob_bottle_shatters_if_sally,\n", + " torch.tensor(0.0),\n", + " ),\n", + " )\n", + "\n", + " bottle_shatters = pyro.sample(\n", + " \"bottle_shatters\", dist.Bernoulli(new_bsp)\n", + " )\n", + "\n", + " return {\n", + " \"sally_throws\": sally_throws,\n", + " \"bill_throws\": bill_throws,\n", + " \"sally_hits\": sally_hits,\n", + " \"bill_hits\": bill_hits,\n", + " \"bottle_shatters\": bottle_shatters,\n", + " }\n", + "\n", + "stones_model.nodes = [\n", + " \"sally_throws\",\n", + " \"bill_throws\",\n", + " \"sally_hits\",\n", + " \"bill_hits\",\n", + " \"bottle_shatters\",\n", + " ]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now instantiate the class, specifying the observations. When we run the resulting model, we then randomly generate 100 witness sets - `witness_df` contains information on whether the intervention changes the consequent for each of these preemptions. " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(101)\n", + "stonesHPM = HalpernPearlModifiedApproximate(\n", + " model = stones_model,\n", + " antecedents = {\"sally_throws\": 0},\n", + " outcome = \"bottle_shatters\",\n", + " witness_candidates = [\"bill_throws\", \"bill_hits\"],\n", + " observations = {\"prob_sally_throws\": 1, \n", + " \"prob_bill_throws\": 1,\n", + " \"prob_sally_hits\": 1,\n", + " \"prob_bill_hits\": 1,\n", + " \"prob_bottle_shatters_if_sally\": 1,\n", + " \"prob_bottle_shatters_if_bill\": 1,\n", + " \"sally_throws\": 1, \"bill_throws\": 1},\n", + " sample_size = 100,\n", + " event_dim = 0\n", + ")\n", + "\n", + "with pyro.poutine.trace() as trace:\n", + " stonesHPM()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'type': 'sample', 'name': 'consequent_differs', 'fn': Unit(log_factor: tensor([[[[[-100000000.]]]]])), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor([], size=(1, 1, 1, 1, 1, 0)), 'infer': {'is_auxiliary': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}\n" + ] + } + ], + "source": [ + "#print(trace.trace.nodes.keys())\n", + "\n", + "\n", + "print(trace.trace.nodes[\"consequent_differs\"])\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The existential causality claim (*is there a witness set such that if it is intervened to be fixed at the actual values, an intervention on the antecedent to have a different value would cause the consequent to have a different value?*) holds just in case consequent differs at list ones. This information is contained in `existential_but_for`." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stonesHPM.existential_but_for" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now can use `ac_minimality_check` to check further conditions." + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[]]\n", + "[False]\n", + "True\n", + "True\n" + ] + } + ], + "source": [ + "stones_min = ac_minimality_check(stonesHPM)\n", + "print(stones_min.ante_subsets) #there is only one, empty subset\n", + "print(stones_min.ante_existential_but_for) #absolute but-for clause fails\n", + "print(stones_min.minimal) #so our antecedent is minimal\n", + "print(stones_min.ac) #and the actual causality claim holds" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[], ['sally_throws'], ['bill_throws']]\n", + "[False, True, False]\n", + "False\n" + ] + } + ], + "source": [ + "# let's compare it to a case in which \n", + "# the antecedent is not minimal\n", + "pyro.set_rng_seed(101)\n", + "stones_redundant_HPM = HalpernPearlModifiedApproximate(\n", + " model = stones_model,\n", + " antecedents = [\"sally_throws\", \"bill_throws\"],\n", + " outcome = \"bottle_shatters\",\n", + " witness_candidates = [\"bill_throws\", \"bill_hits\"],\n", + " observations = {\"prob_sally_throws\": 1, \n", + " \"prob_bill_throws\": 1,\n", + " \"prob_sally_hits\": 1,\n", + " \"prob_bill_hits\": 1,\n", + " \"prob_bottle_shatters_if_sally\": 1,\n", + " \"prob_bottle_shatters_if_bill\": 1,\n", + " \"sally_throws\": 1, \"bill_throws\": 1},\n", + " sample_size = 100,\n", + " event_dim = 0\n", + ")\n", + "\n", + "stones_redundant_HPM()\n", + "\n", + "stones_min_checked = ac_minimality_check(stones_redundant_HPM)\n", + "print(stones_min_checked.ante_subsets)\n", + "print(stones_min_checked.ante_existential_but_for)\n", + "print(stones_min_checked.ac)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Forest fire" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this simplified model, a forest fire was caused by lightning or an arsonist, so we use three endogenous variables, and two exogenous variables corresponding to the two factors. In the conjunctive model,\n", + "both of the factors have to be present for the fire to start. In the disjunctive model, each of them alone is sufficient." + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "def ff_conjunctive():\n", + " u_match_dropped = pyro.sample(\"u_match_dropped\", dist.Bernoulli(0.5))\n", + " u_lightning = pyro.sample(\"u_lightning\", dist.Bernoulli(0.5))\n", + "\n", + " match_dropped = pyro.deterministic(\"match_dropped\",\n", + " u_match_dropped, event_dim=0)\n", + " lightning = pyro.deterministic(\"lightning\", u_lightning, event_dim=0)\n", + " forest_fire = pyro.deterministic(\"forest_fire\", torch.logical_and(match_dropped, lightning), event_dim=0).float()\n", + "\n", + " return {\"match_dropped\": match_dropped, \"lightning\": lightning,\n", + " \"forest_fire\": forest_fire}" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "def ff_disjunctive():\n", + " u_match_dropped = pyro.sample(\"u_match_dropped\", dist.Bernoulli(0.5))\n", + " u_lightning = pyro.sample(\"u_lightning\", dist.Bernoulli(0.5))\n", + "\n", + " match_dropped = pyro.deterministic(\"match_dropped\",\n", + " u_match_dropped, event_dim=0)\n", + " lightning = pyro.deterministic(\"lightning\", u_lightning, event_dim=0)\n", + " forest_fire = pyro.deterministic(\"forest_fire\", torch.logical_or(match_dropped, lightning), event_dim=0).float()\n", + "\n", + " return {\"match_dropped\": match_dropped, \"lightning\": lightning,\n", + " \"forest_fire\": forest_fire}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "# In the conjunctive model \n", + "# Each of the two factors is a but-for cause\n", + " \n", + "pyro.set_rng_seed(101)\n", + "ff_conjunctiveHPM = HalpernPearlModifiedApproximate(\n", + " model = ff_conjunctive,\n", + " antecedents = [\"match_dropped\"],\n", + " outcome = \"forest_fire\",\n", + " witness_candidates = [\"lightning\"],\n", + " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "ff_conjunctiveHPM()\n", + "ff_conjunctiveHPM_min = ac_minimality_check(ff_conjunctiveHPM)\n", + "print(ff_conjunctiveHPM_min.ac) \n" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# In the disjunctive model \n", + "# there still would be fire if there was no lightning\n", + "\n", + "pyro.set_rng_seed(101)\n", + "ff_disjunctiveHPM = HalpernPearlModifiedApproximate(\n", + " model = ff_disjunctive,\n", + " antecedents = [\"match_dropped\"],\n", + " outcome = \"forest_fire\",\n", + " witness_candidates = [\"lightning\"],\n", + " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "ff_disjunctiveHPM()\n", + "ff_disjunctiveHPM.existential_but_for\n", + "# no need for further checks" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(True)\n", + "[[], ['match_dropped'], ['lightning']]\n", + "[tensor(False), tensor(False), tensor(False)]\n", + "tensor(True)\n" + ] + } + ], + "source": [ + "# in the disjunctive model\n", + "# the actual cause is the conjunction of the two factors\n", + "\n", + "pyro.set_rng_seed(101)\n", + "ff_disjunctive_jointHPM = HalpernPearlModifiedApproximate(\n", + " model = ff_disjunctive,\n", + " antecedents = [\"match_dropped\", \"lightning\"],\n", + " outcome = \"forest_fire\",\n", + " witness_candidates = [],\n", + " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "ff_disjunctive_jointHPM()\n", + "print(ff_disjunctive_jointHPM.existential_but_for)\n", + "\n", + "ff_disjunctive_jointHPM_min = ac_minimality_check(ff_disjunctive_jointHPM)\n", + "print(ff_disjunctive_jointHPM_min.ante_subsets)\n", + "print(ff_disjunctive_jointHPM_min.ante_existential_but_for)\n", + "print(ff_disjunctive_jointHPM_min.ac) " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Doctors" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This example illustrates that actual causality is not, in general, transitive. One doctor is responsible for administering the medicine on Monday, and if she does, Bill recovers on Tuesday.\n", + "Another doctor is reliable and treats Bill on Tuesday if the first doctor failed to do so on Monday. If both doctors treat Bill, he is in `condition1`, dead on Wednesday. Otherwise, he is either healthy on Tuesday (`condition2`) or healthy on Wednesday (`condition3`), or did not receive any treatment and feels worse but is alive on Wednesday (`condition4`).\n", + "\n", + "Now suppose Bill did receive treatment on Monday. This is an actual cause of his not receiving treatment on Tuesday, and the latter is an actual cause of his being alive on Wednesday. However, there is nothing that the first doctor could do to cause Bill to be dead on Wednesday." + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "def bc_function(mt, tt):\n", + " condition1 = (mt == 1) & (tt == 1)\n", + " condition2 = (mt == 1) & (tt == 0)\n", + " condition3 = (mt == 0) & (tt == 1)\n", + " condition4 = ~(condition1 | condition2 | condition3)\n", + "\n", + " output = torch.where(condition1, torch.tensor(3.0), torch.tensor(0.0))\n", + " output = torch.where(condition2, torch.tensor(0.0), output)\n", + " output = torch.where(condition3, torch.tensor(1.0), output)\n", + " output = torch.where(condition4, torch.tensor(2.0), output)\n", + "\n", + " return output\n", + "\n", + "\n", + "def model_doctors():\n", + " u_monday_treatment = pyro.sample(\"u_monday_treatment\", dist.Bernoulli(0.5))\n", + "\n", + " monday_treatment = pyro.deterministic(\n", + " \"monday_treatment\", u_monday_treatment, event_dim=0\n", + " )\n", + "\n", + " tuesday_treatment = pyro.deterministic(\n", + " \"tuesday_treatment\",\n", + " torch.logical_not(monday_treatment).float(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " bills_condition = pyro.deterministic(\n", + " \"bills_condition\",\n", + " bc_function(monday_treatment, tuesday_treatment),\n", + " event_dim=0,\n", + " )\n", + "\n", + " bill_alive = pyro.deterministic(\n", + " \"bill_alive\", bills_condition.not_equal(3.0).float(), event_dim=0\n", + " )\n", + "\n", + " return {\n", + " \"monday_treatment\": monday_treatment,\n", + " \"tuesday_treatment\": tuesday_treatment,\n", + " \"bills_condition\": bills_condition,\n", + " \"bill_alive\": bill_alive,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "step 1: tensor(True) step 2: tensor(True) step 3: tensor(False)\n", + "tensor(False)\n" + ] + } + ], + "source": [ + "doctors1_HPM = HalpernPearlModifiedApproximate(\n", + " model = model_doctors,\n", + " antecedents = [\"monday_treatment\"],\n", + " outcome = \"tuesday_treatment\",\n", + " witness_candidates = [],\n", + " observations = {\"u_monday_treatment\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "doctors1_HPM()\n", + "\n", + "doctors2_HPM = HalpernPearlModifiedApproximate(\n", + " model = model_doctors,\n", + " antecedents = [\"tuesday_treatment\"],\n", + " outcome = \"bill_alive\",\n", + " witness_candidates = [],\n", + " observations = {\"u_monday_treatment\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "doctors2_HPM()\n", + "\n", + "doctors3_HPM = HalpernPearlModifiedApproximate(\n", + " model = model_doctors,\n", + " antecedents = [\"monday_treatment\"],\n", + " outcome = \"bill_alive\",\n", + " witness_candidates = [],\n", + " observations = {\"u_monday_treatment\": 1},\n", + " sample_size = 4,\n", + " event_dim = 0\n", + ")\n", + "\n", + "doctors3_HPM()\n", + "\n", + "\n", + "doctors1_HPM_min = ac_minimality_check(doctors1_HPM)\n", + "doctors2_HPM_min = ac_minimality_check(doctors2_HPM)\n", + "doctors3_HPM_min = ac_minimality_check(doctors3_HPM)\n", + "\n", + "\n", + "print(\n", + "\"step 1:\", doctors1_HPM.ac,\n", + "\"step 2:\", doctors2_HPM.ac,\n", + "\"step 3:\", doctors3_HPM.ac\n", + ")\n", + "\n", + "print(doctors3_HPM_min.existential_but_for)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Friendly fire\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This comes from a causal model developed in a real-life incident investigation, as discussed in the [Incident Reporting using SERAS® Reporter and SERAS® Analyst](http://www.causalis.com/90-publications/IncidentReportingUsingSERAS.pdf) paper.\n", + "\n", + "a U.S. Special Forces air controller changing the battery on a Global Positioning System device he was using to target a Taliban outpost north of Kandahar. Three special forces soldiers were killed and 20 were injured when a 2,000-pound, satellite-guided bomb landed, not on the Taliban outpost, but on a battalion command post occupied by American forces and a group of Afghan allies, including Hamid Karzai, now the interim prime minister. The Air Force combat controller was using a Precision Lightweight GPS Receiver to calculate the Taliban's coordinates for the attack. The controller did not realize that after he changed the device's battery, the machine was programmed to automatically come back on displaying coordinates for its own location, the official said.\n", + "\n", + "Minutes before the B-52 strike, the controller had used the GPS receiver to\n", + "calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18. Then, with the B-52 approaching the target, the air controller did a second calculation in “degree decimals” required by the bomber crew. The controller had performed the calculation and recorded the position, when the receiver battery died. Without realizing the machine was programmed to come back on showing the coordinates of its\n", + "own location, the controller mistakenly called in the American position to the B-52.\n", + "\n", + "Factors included in the model:\n", + "\n", + "1. The air controller changed the battery on the PLGR\n", + "2. Three special forces soldiers were killed and 20 were injured\n", + "3. B-52 fired a JDAM bomb at the Allied position\n", + "4. The air controller was using the PLGR to calculate the Taliban's coordinates\n", + "5. The controller did not realize that the PLGR was programmed to automatically come back on displaying coordinates for its own location\n", + "6. The controller had used the PLGR to calculate the latitude and longitude of the Taliban position in minutes and seconds for an airstrike by a Navy F/A-18\n", + "7. The air controller did a second calculation in “degree decimals” required by the bomber crew\n", + "8. The controller had performed the calculation and recorded the position\n", + "9. The controller mistakenly called in the American position to the B-52\n", + "10. The B-52 fired a JDAM bomb at the Allied position\n", + "11. The U.S. Air Force and Army had a training problem\n", + "12. The PLRG resumed displaying the coordinates of its own location after the battery was changed\n", + "13. The battery died at the crucial time\n", + "14. The controller thought he was calling in the Taliban position\n", + "\n", + "The DAG used in the model is as follows:\n", + "![Friendly Fire DAG](figures/friendly_fire_dag.png)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def model_friendly_fire():\n", + " u_f4_PLGR_now = pyro.sample(\"u_f4_PLGR_now\", dist.Bernoulli(0.5))\n", + " u_f11_training = pyro.sample(\"u_f11_training\", dist.Bernoulli(0.5))\n", + "\n", + " f4_PLGR_now = pyro.deterministic(\"f4_PLGR_now\", u_f4_PLGR_now, event_dim=0)\n", + " f11_training = pyro.deterministic(\n", + " \"f11_training\", u_f11_training, event_dim=0\n", + " )\n", + "\n", + " f6_PLGR_before = pyro.deterministic(\n", + " \"f6_PLGR_before\", f4_PLGR_now, event_dim=0\n", + " )\n", + " f7_second_calculation = pyro.deterministic(\n", + " \"f7_second_calculation\", f4_PLGR_now, event_dim=0\n", + " )\n", + " f13_battery_died = pyro.deterministic(\n", + " \"f13_battery_died\",\n", + " f6_PLGR_before.bool() & f7_second_calculation.bool(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " f1_battery_change = pyro.deterministic(\n", + " \"f1_battery_change\", f13_battery_died, event_dim=0\n", + " )\n", + "\n", + " f12_PLGR_after = pyro.deterministic(\n", + " \"f12_PLGR_after\", f1_battery_change, event_dim=0\n", + " )\n", + "\n", + " f5_unaware = pyro.deterministic(\"f5_unaware\", f11_training, event_dim=0)\n", + "\n", + " f14_wrong_position = pyro.deterministic(\n", + " \"f14_wrong_position\", f5_unaware, event_dim=0\n", + " )\n", + "\n", + " f9_mistake_call = pyro.deterministic(\n", + " \"f9_mistake_call\",\n", + " f12_PLGR_after.bool() & \n", + " f14_wrong_position.bool(),\n", + " event_dim=0,\n", + " )\n", + "\n", + " f3_fired = pyro.deterministic(\"f3_fired\", f9_mistake_call, event_dim=0)\n", + "\n", + " f10_landed = pyro.deterministic(\n", + " \"f10_landed\", f3_fired.bool() & f9_mistake_call.bool(), event_dim=0\n", + " )\n", + "\n", + " f2_killed = pyro.deterministic(\"f2_killed\", f10_landed, event_dim=0)\n", + "\n", + " return {\n", + " \"f1_battery_change\": f1_battery_change,\n", + " \"f2_killed\": f2_killed,\n", + " \"f3_fired\": f3_fired,\n", + " \"f4_PLGR_now\": f4_PLGR_now,\n", + " \"f5_unaware\": f5_unaware,\n", + " \"f6_PLGR_before\": f6_PLGR_before,\n", + " \"f7_second_calculation\": f7_second_calculation,\n", + " \"f9_mistake_call\": f9_mistake_call,\n", + " \"f10_landed\": f10_landed,\n", + " \"f11_training\": f11_training,\n", + " \"f12_PLGR_after\": f12_PLGR_after,\n", + " \"f13_battery_died\": f13_battery_died,\n", + " \"f14_wrong_position\": f14_wrong_position,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tuple: True PLGR_before: True second calculation: True\n", + "[[], ['f6_PLGR_before'], ['f7_second_calculation']]\n", + "[False, True, True]\n", + "False\n", + "False\n" + ] + } + ], + "source": [ + "# while a conjunction of these two nodes satisfies the existential but-for...\n", + "\n", + "friendly_fire_HPM = HalpernPearlModifiedApproximate(\n", + " model = model_friendly_fire,\n", + " antecedents = [\"f6_PLGR_before\", \"f7_second_calculation\"],\n", + " outcome = \"f2_killed\",\n", + " witness_candidates = [\"f4_PLGR_now\",\"f5_unaware\",\n", + " \"f11_training\",\n", + " \"f14_wrong_position\"],\n", + " observations = {\"u_f4_PLGR_now\": 1.0, \"u_f11_training\": 1.0},\n", + " sample_size = 20,\n", + " event_dim = 0\n", + ")\n", + "\n", + "friendly_fire_HPM()\n", + "print(friendly_fire_HPM.existential_but_for)\n", + "\n", + "# ... it is not minimal as so does any of the two factors alone\n", + "friendly_fire_HPM_min = ac_minimality_check(friendly_fire_HPM)\n", + "\n", + "print(friendly_fire_HPM_min.ante_subsets)\n", + "print(friendly_fire_HPM_min.ante_existential_but_for)\n", + "print(friendly_fire_HPM_min.minimal)\n", + "print(friendly_fire_HPM_min.ac)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Voting\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The main reason why the voting models are interesting in this context is that we are interested in the role of particular voters in the coming about of the result. The intuition-and we will pursue it in the responsibility notebook-is that a voter might play are role or be blamed for not voting even if her vote is not decisive. For now, we just notice that the notion of actual causality at play is not enough to capture these intuitions. Say you give one vote in a binary majority vote, `vote0`, you vote \"for\", and there are six other voters. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def voting_model():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", + " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", + " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", + " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", + " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", + " return {\"outcome\": vote0 + vote1 + vote2 + vote3 + vote4 + vote5 > 3}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if you're one of four voters who voted for, you are an actual cause\n", + "# of the outcome\n", + "\n", + "voting4HPM = HalpernPearlModifiedApproximate(\n", + " model = voting_model,\n", + " antecedents = [\"vote0\"],\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,6)],\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=0., u_vote5=0),\n", + " sample_size = 1000)\n", + "\n", + "voting4HPM()\n", + "\n", + "voting4HPM.existential_but_for" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if you're one of five voters who voted for, you are not an actual cause\n", + "# of the outcome\n", + "\n", + "voting5HPM = HalpernPearlModifiedApproximate(\n", + " model = voting_model,\n", + " antecedents = [\"vote0\"],\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,6)],\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0),\n", + " sample_size = 1000)\n", + "\n", + "voting5HPM()\n", + "\n", + "voting5HPM.existential_but_for" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# still, you are part of an actual cause \n", + "\n", + "voting_groupHPM = HalpernPearlModifiedApproximate(\n", + " model = voting_model,\n", + " antecedents = [\"vote0\", \"vote1\"],\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(2,6)],\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0),\n", + " sample_size = 1000)\n", + "\n", + "voting_groupHPM()\n", + "\n", + "voting_groupHPM.existential_but_for\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## References" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "(Halpern 2016) Halpern, Josepy Y., \"Actual Causality\", MIT Press, Cambridge, 2016" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 577eb5352bdbc0ed0ec842a19f6e4b9a7811962c Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Mon, 14 Aug 2023 20:52:13 +0200 Subject: [PATCH 02/13] actual cause revisions in progress --- docs/source/actual_causality.ipynb | 711 +++++++++++++++++++++++------ 1 file changed, 559 insertions(+), 152 deletions(-) diff --git a/docs/source/actual_causality.ipynb b/docs/source/actual_causality.ipynb index 80601354..1ff99ccf 100644 --- a/docs/source/actual_causality.ipynb +++ b/docs/source/actual_causality.ipynb @@ -153,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -173,7 +173,8 @@ "\n", "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", "from chirho.interventional.handlers import do\n", - "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions" + "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions\n", + "from chirho.observational.handlers import condition" ] }, { @@ -185,29 +186,27 @@ "\n", "An SCM in this context is represented by a Pyro model, where the exogenous variables are stochastic and introduced using `pyro.sample`, and all the endogenous variables are determined by these, and introduced by `pyro.deterministic` (read on for examples). For simplicity we also assume the antecedent nodes are binary (this assumption can be weakened to them being discrete), and that the consequent nodes are discrete. \n", "\n", - "The key role in this implementation is played by the preemption handler, which is used to randomly select some of the witness candidate nodes and preempt them - in our case, with their observed values.\n", + "The key role in this implementation is played by (1) the biased preemption handler, , and (2) the preemption handler, which is used to randomly select some of the witness candidate nodes and preempt them - in our case, with their observed values.\n", "\n", "The key moves are in our `def __call__`:\n", "\n", - "- `pyro.plate` is a messenger used to construct conditionally independent sequences of variables, we use it to obtain multiple independent samples.\n", + "- `MutliWorldCounterfactual` allows us to keep track of all the scenarios, keeping split records thereof as values at various sites.\n", "\n", - "- `pyro.condition` is used to constrain the values of some variables. We use it to fix the values of the exogenous variables. As they are causally upstream from all the other variables, this also fixes the values of the endogenous variables.\n", + "- `do` is a handler that intervenes on selected variables to have the specified values. We use it to fix the *counterfactual* combination of values for the antecedent nodes.\n", " \n", - "- `Preemptions` is a messenger that in post-processing fixes a random selection of (witness) nodes to the results of certain actions - in our case, these actions correspond to the actual values of the variables.\n", + "- `BiasedPreemptions` is an effect handler that randomly preempts counterfactual interventions on antecedent nodes with an assymetric categorical distribution, so that the resulting log probabilities prefer smaller antecedent sets\n", " \n", - "- `do` is a handler that intervenes on selected variables to have the specified values. We use it to fix the *alternative* combination of values for the antecedent nodes.\n", + "- `Preemptions` is a messenger that fixes a random selection of (witness) nodes to the results of certain actions - in our case, these actions will correspond to the factual values of the variables.\n", " \n", - "- `MutliWorldCounterfactual` allows us to keep track of all the scenarios, keeping split records thereof as values at various sites.\n", - "\n", + "- `pyro.condition` is used to constrain the values of some variables. We use it to fix the values of the exogenous variables. As in the models that we use as examples they are causally upstream from all the other variables, this also fixes the values of the endogenous variables.\n", + " \n", "- `pyro.poutine.trace` is a messenger that keeps track of sites visited for downstream use.\n", - "\n", - "- within the model, we compare the observed and the intervened consequent and record whether these differ.\n", - " \n" + "\n" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -232,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -241,28 +240,27 @@ " def __init__(\n", " self, \n", " model: Callable,\n", - " antecedents: Dict[str, torch.Tensor],\n", + " counterfactual_antecedents: Dict[str, torch.Tensor],\n", " outcome: str,\n", " witness_candidates: List[str],\n", - " observations: Optional[Dict[str, torch.Tensor]],\n", - " sample_size: int = 100,\n", - " event_dim: int = 0\n", + " observations: Optional[Dict[str, torch.Tensor]] = None\n", " ):\n", " \n", + " if observations is None:\n", + " observations = {}\n", + "\n", " self.model = model\n", - " # rename to counterfactual antecedents\n", - " self.antecedents = antecedents\n", + " self.counterfactual_antecedents = counterfactual_antecedents\n", " self.outcome = outcome\n", " self.witness_candidates = witness_candidates\n", " self.observations = observations\n", - " self.sample_size = sample_size\n", "\n", " self.antecedent_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", " antecedents = [antecedent]) for\n", - " antecedent in self.antecedents.keys()}\n", + " antecedent in self.counterfactual_antecedents.keys()}\n", " \n", " self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,\n", - " antecedents = self.antecedents) for \n", + " antecedents = self.counterfactual_antecedents) for \n", " candidate in self.witness_candidates}\n", " \n", " @staticmethod \n", @@ -285,55 +283,103 @@ " \n", " def __call__(self, *args, **kwargs):\n", " with MultiWorldCounterfactual():\n", - " with do(actions=self.antecedents):\n", + " with do(actions=self.counterfactual_antecedents):\n", " # the last element of the tensor is the factual case (preempted)\n", - " with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.1, .9])):\n", + " with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.4, .6])): \n", " with Preemptions(actions = self.witness_preemptions):\n", - " with pyro.condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", - " self.consequent = self.model()[self.outcome]\n", - " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.antecedents}))\n", - " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.antecedents}))\n", - " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", - " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))\n", - "\n" + " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", + " with pyro.poutine.trace() as self.trace:\n", + " self.consequent = self.model(*args, **kwargs)[self.outcome]\n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_antecedents}))\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_antecedents}))\n", + " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", + " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", + " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))\n" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - " \n", - " # self.trace = trace.trace\n", + "# this will explore the trace once we run inference on the model\n", "\n", - " # # slightly hacky solution for odd witness candidate sets\n", - " # if isinstance(self.consequent_differs.squeeze().tolist(), bool):\n", - " # self.existential_but_for = self.consequent_differs.squeeze()\n", - " # else:\n", - " # #if (len(self.consequent_differs.squeeze().tolist() )>1):\n", - " # self.existential_but_for = any(self.consequent_differs.squeeze().tolist() ) \n", + "def get_table(nodes, antecedents, witness_candidates, consequent):\n", + " \n", + " values_table = {}\n", "\n", - " \n", + " for antecedent in antecedents:\n", + " values_table[antecedent] = nodes[antecedent][\"value\"].squeeze().tolist()\n", + " values_table['preempted_' + antecedent] = nodes['__split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['preempted_' + antecedent + '_log_prob'] = nodes['__split_' + antecedent][\"fn\"].log_prob(nodes['__split_' + antecedent][\"value\"]).squeeze().tolist()\n", "\n", - " # witness_dict = dict()\n", - " # if self.witness_candidates:\n", - " # witness_keys = [\"__split_\" + candidate for candidate in self.witness_candidates]\n", - " # witness_dict = {key: self.trace.nodes[key]['value'] for key in witness_keys}\n", - " \n", - " # witness_dict['observed'] = self.observed_consequent.squeeze()\n", - " # witness_dict['intervened'] = self.intervened_consequent.squeeze()\n", - " # witness_dict['consequent_differs'] = self.consequent_differs.squeeze()\n", "\n", - " # # slightly hacky as above\n", - " # self.witness_df = pd.DataFrame(witness_dict) if self.witness_candidates else witness_dict\n" + " for candidate in witness_candidates:\n", + " _values = nodes[candidate][\"value\"].squeeze().tolist()\n", + " # TODO: uncomment in the final version (?) \n", + " #values_table[candidate + '0'] = _values[0]\n", + " #values_table[candidate + '1'] = _values[1]\n", + " values_table['fixed_factual_' + candidate] = nodes['__split_' + candidate][\"value\"].squeeze().tolist()\n", + " \n", + " # TODO uncomment in the final version (?)\n", + " #values_table[consequent + '0'] = nodes[consequent][\"value\"].squeeze().tolist()[0]\n", + " #values_table[consequent + '1'] = nodes[consequent][\"value\"].squeeze().tolist()[1]\n", + " values_table['consequent_differs_binary'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " values_table['consequent_log_prob'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['consequent_log_prob'], float):\n", + " values_df = pd.DataFrame([values_table])\n", + " else:\n", + " values_df = pd.DataFrame(values_table)\n", + " \n", + "\n", + " summands = ['preempted_' + antecedent + '_log_prob' for antecedent in antecedents]\n", + " summands.append('consequent_log_prob')\n", + " values_df[\"sum_log_prob\"] = values_df[summands].sum(axis = 1) \n", + " values_df.drop_duplicates(inplace = True)\n", + " values_df.sort_values(by = \"sum_log_prob\", inplace = True, ascending = False)\n", + "\n", + " return values_df.reset_index(drop = True)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": 5, "metadata": {}, + "outputs": [], "source": [ - "As the definition of actual causality requires a minimality check, we implement it on top of the existential but-for test:" + "# this reduces the actual causality check to checking a property of the resulting sums of log probabilities for the antecedent preemption and the consequent differs nodes\n", + "\n", + "def ac_check(hpm,nodes):\n", + "\n", + " antecedents = list(hpm.counterfactual_antecedents.keys())\n", + " witness_candidates= hpm.witness_candidates\n", + " consequent= hpm.outcome\n", + "\n", + " table = get_table(nodes,\n", + " antecedents,\n", + " witness_candidates,\n", + " consequent)\n", + "\n", + " if table['sum_log_prob'][0] == -1e8:\n", + " print(\"No resulting difference to the consequent in the sample.\")\n", + " return\n", + " \n", + " winner = table.iloc[0]\n", + " active_antecedents = []\n", + " for antecedent in antecedents:\n", + " if winner['preempted_'+antecedent] == 0:\n", + " active_antecedents.append(antecedent)\n", + "\n", + " ac_flag = set(active_antecedents) == set(antecedents)\n", + " \n", + " if not ac_flag:\n", + " print(\"The antecedent set is not minimal.\")\n", + " else:\n", + " print(\"The antecedent set is an actual cause.\")\n", + "\n", + " return ac_flag\n" ] }, { @@ -389,10 +435,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "@pyro.infer.config_enumerate\n", "def stones_model(): \n", " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", @@ -458,19 +505,19 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now instantiate the class, specifying the observations. When we run the resulting model, we then randomly generate 100 witness sets - `witness_df` contains information on whether the intervention changes the consequent for each of these preemptions. " + "We now instantiate the class, specifying the observations." ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "pyro.set_rng_seed(101)\n", "stonesHPM = HalpernPearlModifiedApproximate(\n", " model = stones_model,\n", - " antecedents = {\"sally_throws\": 0},\n", + " counterfactual_antecedents = {\"sally_throws\": 0},\n", " outcome = \"bottle_shatters\",\n", " witness_candidates = [\"bill_throws\", \"bill_hits\"],\n", " observations = {\"prob_sally_throws\": 1, \n", @@ -479,109 +526,423 @@ " \"prob_bill_hits\": 1,\n", " \"prob_bottle_shatters_if_sally\": 1,\n", " \"prob_bottle_shatters_if_bill\": 1,\n", - " \"sally_throws\": 1, \"bill_throws\": 1},\n", - " sample_size = 100,\n", - " event_dim = 0\n", + " \"sally_throws\": 1, \"bill_throws\": 1}\n", ")\n", "\n", "with pyro.poutine.trace() as trace:\n", - " stonesHPM()\n", - "\n" + " stonesHPM()\n" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 11, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'type': 'sample', 'name': 'consequent_differs', 'fn': Unit(log_factor: tensor([[[[[-100000000.]]]]])), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor([], size=(1, 1, 1, 1, 1, 0)), 'infer': {'is_auxiliary': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sally_throwspreempted_sally_throwspreempted_sally_throws_log_probfixed_factual_bill_hitsfixed_factual_bill_throwsconsequent_differs_binaryconsequent_log_probsum_log_prob
010-0.91629110True0.0-9.162907e-01
110-0.91629111True0.0-9.162907e-01
211-0.51082600False-100000000.0-1.000000e+08
311-0.51082601False-100000000.0-1.000000e+08
411-0.51082611False-100000000.0-1.000000e+08
511-0.51082610False-100000000.0-1.000000e+08
610-0.91629101False-100000000.0-1.000000e+08
710-0.91629100False-100000000.0-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " sally_throws preempted_sally_throws preempted_sally_throws_log_prob \\\n", + "0 1 0 -0.916291 \n", + "1 1 0 -0.916291 \n", + "2 1 1 -0.510826 \n", + "3 1 1 -0.510826 \n", + "4 1 1 -0.510826 \n", + "5 1 1 -0.510826 \n", + "6 1 0 -0.916291 \n", + "7 1 0 -0.916291 \n", + "\n", + " fixed_factual_bill_hits fixed_factual_bill_throws \\\n", + "0 1 0 \n", + "1 1 1 \n", + "2 0 0 \n", + "3 0 1 \n", + "4 1 1 \n", + "5 1 0 \n", + "6 0 1 \n", + "7 0 0 \n", + "\n", + " consequent_differs_binary consequent_log_prob sum_log_prob \n", + "0 True 0.0 -9.162907e-01 \n", + "1 True 0.0 -9.162907e-01 \n", + "2 False -100000000.0 -1.000000e+08 \n", + "3 False -100000000.0 -1.000000e+08 \n", + "4 False -100000000.0 -1.000000e+08 \n", + "5 False -100000000.0 -1.000000e+08 \n", + "6 False -100000000.0 -1.000000e+08 \n", + "7 False -100000000.0 -1.000000e+08 " + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "#print(trace.trace.nodes.keys())\n", + "with pyro.poutine.trace() as basic_trace_stones:\n", + " with pyro.plate(\"runs\", 1000):\n", + " stonesHPM()\n", "\n", + "btr_stones= basic_trace_stones.trace.nodes\n", "\n", - "print(trace.trace.nodes[\"consequent_differs\"])\n", - "\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The existential causality claim (*is there a witness set such that if it is intervened to be fixed at the actual values, an intervention on the antecedent to have a different value would cause the consequent to have a different value?*) holds just in case consequent differs at list ones. This information is contained in `existential_but_for`." + "get_table(btr_stones, antecedents = ['sally_throws'], witness_candidates= ['bill_hits','bill_throws'], consequent= 'bottle_shatters')" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 13, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, { "data": { "text/plain": [ "True" ] }, - "execution_count": 52, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "stonesHPM.existential_but_for" + "ac_check(stonesHPM, btr_stones)" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "We now can use `ac_minimality_check` to check further conditions." + "The existential causality claim (*is there a witness set such that if it is intervened to be fixed at the actual values, an intervention on the antecedent to have a different value would cause the consequent to have a different value?*) holds just in case consequent differs at list ones. This information is contained in `existential_but_for`." ] }, { - "cell_type": "code", - "execution_count": 53, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[]]\n", - "[False]\n", - "True\n", - "True\n" - ] - } - ], "source": [ - "stones_min = ac_minimality_check(stonesHPM)\n", - "print(stones_min.ante_subsets) #there is only one, empty subset\n", - "print(stones_min.ante_existential_but_for) #absolute but-for clause fails\n", - "print(stones_min.minimal) #so our antecedent is minimal\n", - "print(stones_min.ac) #and the actual causality claim holds" + "We now can use `ac_minimality_check` to check further conditions." ] }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 22, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "[[], ['sally_throws'], ['bill_throws']]\n", - "[False, True, False]\n", - "False\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sally_throwspreempted_sally_throwspreempted_sally_throws_log_probbill_throwspreempted_bill_throwspreempted_bill_throws_log_probfixed_factual_bill_hitsconsequent_differs_binaryconsequent_log_probsum_log_prob
010-0.91629111-0.5108261True0.0-1.427116e+00
110-0.91629110-0.9162910True0.0-1.832581e+00
210-0.91629110-0.9162911True0.0-1.832581e+00
311-0.51082611-0.5108260False-100000000.0-1.000000e+08
411-0.51082611-0.5108261False-100000000.0-1.000000e+08
511-0.51082610-0.9162910False-100000000.0-1.000000e+08
611-0.51082610-0.9162911False-100000000.0-1.000000e+08
710-0.91629111-0.5108260False-100000000.0-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " sally_throws preempted_sally_throws preempted_sally_throws_log_prob \\\n", + "0 1 0 -0.916291 \n", + "1 1 0 -0.916291 \n", + "2 1 0 -0.916291 \n", + "3 1 1 -0.510826 \n", + "4 1 1 -0.510826 \n", + "5 1 1 -0.510826 \n", + "6 1 1 -0.510826 \n", + "7 1 0 -0.916291 \n", + "\n", + " bill_throws preempted_bill_throws preempted_bill_throws_log_prob \\\n", + "0 1 1 -0.510826 \n", + "1 1 0 -0.916291 \n", + "2 1 0 -0.916291 \n", + "3 1 1 -0.510826 \n", + "4 1 1 -0.510826 \n", + "5 1 0 -0.916291 \n", + "6 1 0 -0.916291 \n", + "7 1 1 -0.510826 \n", + "\n", + " fixed_factual_bill_hits consequent_differs_binary consequent_log_prob \\\n", + "0 1 True 0.0 \n", + "1 0 True 0.0 \n", + "2 1 True 0.0 \n", + "3 0 False -100000000.0 \n", + "4 1 False -100000000.0 \n", + "5 0 False -100000000.0 \n", + "6 1 False -100000000.0 \n", + "7 0 False -100000000.0 \n", + "\n", + " sum_log_prob \n", + "0 -1.427116e+00 \n", + "1 -1.832581e+00 \n", + "2 -1.832581e+00 \n", + "3 -1.000000e+08 \n", + "4 -1.000000e+08 \n", + "5 -1.000000e+08 \n", + "6 -1.000000e+08 \n", + "7 -1.000000e+08 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -590,26 +951,55 @@ "pyro.set_rng_seed(101)\n", "stones_redundant_HPM = HalpernPearlModifiedApproximate(\n", " model = stones_model,\n", - " antecedents = [\"sally_throws\", \"bill_throws\"],\n", + " counterfactual_antecedents = {\"sally_throws\": 0, \"bill_throws\": 0},\n", " outcome = \"bottle_shatters\",\n", - " witness_candidates = [\"bill_throws\", \"bill_hits\"],\n", + " witness_candidates = [\"bill_hits\"], #note \"bill_throws\" is no longer a witness candidate\n", " observations = {\"prob_sally_throws\": 1, \n", " \"prob_bill_throws\": 1,\n", " \"prob_sally_hits\": 1,\n", " \"prob_bill_hits\": 1,\n", " \"prob_bottle_shatters_if_sally\": 1,\n", " \"prob_bottle_shatters_if_bill\": 1,\n", - " \"sally_throws\": 1, \"bill_throws\": 1},\n", - " sample_size = 100,\n", - " event_dim = 0\n", + " \"sally_throws\": 1, \"bill_throws\": 1}\n", ")\n", "\n", - "stones_redundant_HPM()\n", "\n", - "stones_min_checked = ac_minimality_check(stones_redundant_HPM)\n", - "print(stones_min_checked.ante_subsets)\n", - "print(stones_min_checked.ante_existential_but_for)\n", - "print(stones_min_checked.ac)" + "with pyro.poutine.trace() as basic_trace__redundant_stones:\n", + " with pyro.plate(\"runs\", 1000):\n", + " stones_redundant_HPM()\n", + "\n", + "btr_redundant_stones= basic_trace__redundant_stones.trace.nodes\n", + "\n", + "btr_redundant_stones.keys()\n", + "\n", + "get_table(btr_redundant_stones, antecedents = ['sally_throws', 'bill_throws'], witness_candidates= ['bill_hits'], consequent= 'bottle_shatters')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is not minimal.\n" + ] + }, + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ac_check(stones_redundant_HPM, btr_redundant_stones)" ] }, { @@ -631,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -650,7 +1040,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -669,49 +1059,67 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "True\n" + "The antecedent set is an actual cause.\n" ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "# In the conjunctive model \n", "# Each of the two factors is a but-for cause\n", " \n", - "pyro.set_rng_seed(101)\n", "ff_conjunctiveHPM = HalpernPearlModifiedApproximate(\n", " model = ff_conjunctive,\n", - " antecedents = [\"match_dropped\"],\n", + " counterfactual_antecedents = {\"match_dropped\": 0},\n", " outcome = \"forest_fire\",\n", " witness_candidates = [\"lightning\"],\n", - " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", + " observations = {\"match_dropped\": 1, \"lightning\": 1}\n", ")\n", "\n", - "ff_conjunctiveHPM()\n", - "ff_conjunctiveHPM_min = ac_minimality_check(ff_conjunctiveHPM)\n", - "print(ff_conjunctiveHPM_min.ac) \n" + "with pyro.poutine.trace() as trace_ff_conjunctive:\n", + " with pyro.plate(\"runs\", 500):\n", + " ff_conjunctiveHPM()\n", + "\n", + "tfc = trace_ff_conjunctive.trace.nodes \n", + "\n", + "ac_check(ff_conjunctiveHPM, tfc)" ] }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 33, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is not minimal.\n" + ] + }, { "data": { "text/plain": [ "False" ] }, - "execution_count": 58, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -723,17 +1131,19 @@ "pyro.set_rng_seed(101)\n", "ff_disjunctiveHPM = HalpernPearlModifiedApproximate(\n", " model = ff_disjunctive,\n", - " antecedents = [\"match_dropped\"],\n", + " counterfactual_antecedents = {\"match_dropped\":0},\n", " outcome = \"forest_fire\",\n", " witness_candidates = [\"lightning\"],\n", - " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", + " observations = {\"match_dropped\": 1, \"lightning\": 1}\n", ")\n", "\n", - "ff_disjunctiveHPM()\n", - "ff_disjunctiveHPM.existential_but_for\n", - "# no need for further checks" + "with pyro.poutine.trace() as trace_ff_disjunctive:\n", + " with pyro.plate(\"runs\", 100):\n", + " ff_disjunctiveHPM()\n", + "\n", + "tfd = trace_ff_disjunctive.trace.nodes \n", + "\n", + "ac_check(ff_disjunctiveHPM, tfd)\n" ] }, { @@ -756,24 +1166,21 @@ "# in the disjunctive model\n", "# the actual cause is the conjunction of the two factors\n", "\n", - "pyro.set_rng_seed(101)\n", "ff_disjunctive_jointHPM = HalpernPearlModifiedApproximate(\n", " model = ff_disjunctive,\n", - " antecedents = [\"match_dropped\", \"lightning\"],\n", + " counterfactual_antecedents = {\"match_dropped\": 0, \"lightning\":0},\n", " outcome = \"forest_fire\",\n", " witness_candidates = [],\n", " observations = {\"match_dropped\": 1, \"lightning\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", ")\n", "\n", - "ff_disjunctive_jointHPM()\n", - "print(ff_disjunctive_jointHPM.existential_but_for)\n", + "with pyro.poutine.trace() as trace_ff_disjunctive_joint:\n", + " with pyro.plate(\"runs\", 100):\n", + " ff_disjunctive_jointHPM()\n", + "\n", + "tfj = trace_ff_disjunctive_joint.trace.nodes \n", "\n", - "ff_disjunctive_jointHPM_min = ac_minimality_check(ff_disjunctive_jointHPM)\n", - "print(ff_disjunctive_jointHPM_min.ante_subsets)\n", - "print(ff_disjunctive_jointHPM_min.ante_existential_but_for)\n", - "print(ff_disjunctive_jointHPM_min.ac) " + "ac_check(ff_disjunctive_jointHPM, tfj)" ] }, { From 074ab4e4ffc4e0a6814f30c4f423fcf7cdb8632d Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Mon, 14 Aug 2023 23:44:22 +0200 Subject: [PATCH 03/13] actual causality notebook with preemption --- docs/source/actual_causality.ipynb | 356 +++++++++++++++++++---------- 1 file changed, 235 insertions(+), 121 deletions(-) diff --git a/docs/source/actual_causality.ipynb b/docs/source/actual_causality.ipynb index 1ff99ccf..7525c06d 100644 --- a/docs/source/actual_causality.ipynb +++ b/docs/source/actual_causality.ipynb @@ -162,7 +162,6 @@ "import numpy as np\n", "from itertools import combinations\n", "\n", - "\n", "import torch\n", "from typing import Dict, List, Optional, Union, Callable, Any\n", "\n", @@ -305,7 +304,7 @@ "source": [ "# this will explore the trace once we run inference on the model\n", "\n", - "def get_table(nodes, antecedents, witness_candidates, consequent):\n", + "def get_table(nodes, antecedents, witness_candidates):\n", " \n", " values_table = {}\n", "\n", @@ -349,7 +348,8 @@ "metadata": {}, "outputs": [], "source": [ - "# this reduces the actual causality check to checking a property of the resulting sums of log probabilities for the antecedent preemption and the consequent differs nodes\n", + "# this reduces the actual causality check to checking a property of the resulting sums of log probabilities\n", + "# for the antecedent preemption and the consequent differs nodes\n", "\n", "def ac_check(hpm,nodes):\n", "\n", @@ -359,10 +359,9 @@ "\n", " table = get_table(nodes,\n", " antecedents,\n", - " witness_candidates,\n", - " consequent)\n", + " witness_candidates)\n", "\n", - " if table['sum_log_prob'][0] == -1e8:\n", + " if table['sum_log_prob'][0] <= -1e8:\n", " print(\"No resulting difference to the consequent in the sample.\")\n", " return\n", " \n", @@ -510,7 +509,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -535,7 +534,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -694,7 +693,7 @@ "7 False -100000000.0 -1.000000e+08 " ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -706,12 +705,12 @@ "\n", "btr_stones= basic_trace_stones.trace.nodes\n", "\n", - "get_table(btr_stones, antecedents = ['sally_throws'], witness_candidates= ['bill_hits','bill_throws'], consequent= 'bottle_shatters')" + "get_table(btr_stones, antecedents = ['sally_throws'], witness_candidates= ['bill_hits','bill_throws'])" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -727,7 +726,7 @@ "True" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -753,7 +752,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -940,7 +939,7 @@ "7 -1.000000e+08 " ] }, - "execution_count": 22, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -972,12 +971,12 @@ "\n", "btr_redundant_stones.keys()\n", "\n", - "get_table(btr_redundant_stones, antecedents = ['sally_throws', 'bill_throws'], witness_candidates= ['bill_hits'], consequent= 'bottle_shatters')\n" + "get_table(btr_redundant_stones, antecedents = ['sally_throws', 'bill_throws'], witness_candidates= ['bill_hits'])\n" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -993,7 +992,7 @@ "False" ] }, - "execution_count": 23, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1021,7 +1020,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -1040,7 +1039,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -1059,7 +1058,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -1075,7 +1074,7 @@ "True" ] }, - "execution_count": 29, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1103,25 +1102,15 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The antecedent set is not minimal.\n" + "No resulting difference to the consequent in the sample.\n" ] - }, - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -1138,7 +1127,7 @@ ")\n", "\n", "with pyro.poutine.trace() as trace_ff_disjunctive:\n", - " with pyro.plate(\"runs\", 100):\n", + " with pyro.plate(\"runs\", 500):\n", " ff_disjunctiveHPM()\n", "\n", "tfd = trace_ff_disjunctive.trace.nodes \n", @@ -1148,18 +1137,25 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor(True)\n", - "[[], ['match_dropped'], ['lightning']]\n", - "[tensor(False), tensor(False), tensor(False)]\n", - "tensor(True)\n" + "The antecedent set is an actual cause.\n" ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -1175,7 +1171,7 @@ ")\n", "\n", "with pyro.poutine.trace() as trace_ff_disjunctive_joint:\n", - " with pyro.plate(\"runs\", 100):\n", + " with pyro.plate(\"runs\", 500):\n", " ff_disjunctive_jointHPM()\n", "\n", "tfj = trace_ff_disjunctive_joint.trace.nodes \n", @@ -1204,7 +1200,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -1255,68 +1251,118 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "step 1: tensor(True) step 2: tensor(True) step 3: tensor(False)\n", - "tensor(False)\n" + "The antecedent set is an actual cause.\n" ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "# The first actual causal link holds\n", + "\n", "doctors1_HPM = HalpernPearlModifiedApproximate(\n", " model = model_doctors,\n", - " antecedents = [\"monday_treatment\"],\n", + " counterfactual_antecedents = {\"monday_treatment\": 0.0},\n", " outcome = \"tuesday_treatment\",\n", " witness_candidates = [],\n", - " observations = {\"u_monday_treatment\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", - ")\n", + " observations = {\"u_monday_treatment\": 1.0}\n", + " )\n", "\n", - "doctors1_HPM()\n", + "with pyro.poutine.trace() as trace_doctors1:\n", + " with pyro.plate(\"runs\", 500):\n", + " doctors1_HPM()\n", + "\n", + "doc1 = trace_doctors1.trace.nodes \n", + "ac_check(doctors1_HPM, doc1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ "\n", "doctors2_HPM = HalpernPearlModifiedApproximate(\n", " model = model_doctors,\n", - " antecedents = [\"tuesday_treatment\"],\n", + " counterfactual_antecedents = {\"tuesday_treatment\": 1.0},\n", " outcome = \"bill_alive\",\n", " witness_candidates = [],\n", - " observations = {\"u_monday_treatment\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", + " observations = {\"u_monday_treatment\": 1.0}\n", ")\n", "\n", - "doctors2_HPM()\n", + "with pyro.poutine.trace() as trace_doctors2:\n", + " with pyro.plate(\"runs\", 5):\n", + " doctors2_HPM()\n", + "\n", + "doc2 = trace_doctors2.trace.nodes \n", + "\n", + "ac_check(doctors2_HPM, doc2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] + } + ], + "source": [ "\n", "doctors3_HPM = HalpernPearlModifiedApproximate(\n", " model = model_doctors,\n", - " antecedents = [\"monday_treatment\"],\n", + " counterfactual_antecedents = {\"monday_treatment\": 0.0},\n", " outcome = \"bill_alive\",\n", " witness_candidates = [],\n", - " observations = {\"u_monday_treatment\": 1},\n", - " sample_size = 4,\n", - " event_dim = 0\n", + " observations = {\"u_monday_treatment\": 1.0}\n", ")\n", "\n", - "doctors3_HPM()\n", - "\n", - "\n", - "doctors1_HPM_min = ac_minimality_check(doctors1_HPM)\n", - "doctors2_HPM_min = ac_minimality_check(doctors2_HPM)\n", - "doctors3_HPM_min = ac_minimality_check(doctors3_HPM)\n", - "\n", + "with pyro.poutine.trace() as trace_doctors3:\n", + " with pyro.plate(\"runs\", 500):\n", + " doctors3_HPM()\n", "\n", - "print(\n", - "\"step 1:\", doctors1_HPM.ac,\n", - "\"step 2:\", doctors2_HPM.ac,\n", - "\"step 3:\", doctors3_HPM.ac\n", - ")\n", + "doc3 = trace_doctors3.trace.nodes \n", "\n", - "print(doctors3_HPM_min.existential_but_for)" + "ac_check(doctors3_HPM, doc3)" ] }, { @@ -1364,7 +1410,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -1439,46 +1485,91 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tuple: True PLGR_before: True second calculation: True\n", - "[[], ['f6_PLGR_before'], ['f7_second_calculation']]\n", - "[False, True, True]\n", - "False\n", - "False\n" + "The antecedent set is not minimal.\n" ] + }, + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "# while a conjunction of these two nodes satisfies the existential but-for...\n", + "# a conjunction of these two nodes is not minimal:\n", "\n", "friendly_fire_HPM = HalpernPearlModifiedApproximate(\n", " model = model_friendly_fire,\n", - " antecedents = [\"f6_PLGR_before\", \"f7_second_calculation\"],\n", + " counterfactual_antecedents = {\"f6_PLGR_before\": 0.0, \"f7_second_calculation\": 0.0},\n", " outcome = \"f2_killed\",\n", " witness_candidates = [\"f4_PLGR_now\",\"f5_unaware\",\n", " \"f11_training\",\n", " \"f14_wrong_position\"],\n", - " observations = {\"u_f4_PLGR_now\": 1.0, \"u_f11_training\": 1.0},\n", - " sample_size = 20,\n", - " event_dim = 0\n", + " observations = {\"u_f4_PLGR_now\": 1.0, \"u_f11_training\": 1.0}\n", ")\n", "\n", - "friendly_fire_HPM()\n", - "print(friendly_fire_HPM.existential_but_for)\n", + "with pyro.poutine.trace() as friendly_tuple_trace:\n", + " with pyro.plate(\"runs\", 500):\n", + " friendly_fire_HPM()\n", "\n", - "# ... it is not minimal as so does any of the two factors alone\n", - "friendly_fire_HPM_min = ac_minimality_check(friendly_fire_HPM)\n", + "ft = friendly_tuple_trace.trace.nodes \n", "\n", - "print(friendly_fire_HPM_min.ante_subsets)\n", - "print(friendly_fire_HPM_min.ante_existential_but_for)\n", - "print(friendly_fire_HPM_min.minimal)\n", - "print(friendly_fire_HPM_min.ac)\n" + "ac_check(friendly_fire_HPM, ft)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "friendly_fire_singleHPM = HalpernPearlModifiedApproximate(\n", + " model = model_friendly_fire,\n", + " counterfactual_antecedents = {\"f6_PLGR_before\": 0.0},\n", + " outcome = \"f2_killed\",\n", + " witness_candidates = [\"f4_PLGR_now\",\"f5_unaware\",\n", + " \"f11_training\",\n", + " \"f14_wrong_position\"],\n", + " observations = {\"u_f4_PLGR_now\": 1.0, \"u_f11_training\": 1.0}\n", + ")\n", + "\n", + "\n", + "with pyro.poutine.trace() as friendly_single_trace:\n", + " with pyro.plate(\"runs\", 500):\n", + " friendly_fire_singleHPM()\n", + "\n", + "fs = friendly_single_trace.trace.nodes \n", + "\n", + "ac_check(friendly_fire_singleHPM, fs)" ] }, { @@ -1499,7 +1590,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -1522,16 +1613,23 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 26, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, { "data": { "text/plain": [ "True" ] }, - "execution_count": 16, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1542,32 +1640,34 @@ "\n", "voting4HPM = HalpernPearlModifiedApproximate(\n", " model = voting_model,\n", - " antecedents = [\"vote0\"],\n", + " counterfactual_antecedents = {\"vote0\":0.0},\n", " outcome = \"outcome\",\n", " witness_candidates = [f\"vote{i}\" for i in range(1,6)],\n", " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=0., u_vote5=0),\n", - " sample_size = 1000)\n", + " u_vote3=1., u_vote4=0., u_vote5=0.))\n", + "\n", "\n", - "voting4HPM()\n", "\n", - "voting4HPM.existential_but_for" + "with pyro.poutine.trace() as voting4_trace:\n", + " with pyro.plate(\"runs\", 500):\n", + " voting4HPM()\n", + "\n", + "v4 = voting4_trace.trace.nodes \n", + "\n", + "ac_check(voting4HPM, v4)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 27, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "No resulting difference to the consequent in the sample.\n" + ] } ], "source": [ @@ -1576,30 +1676,40 @@ "\n", "voting5HPM = HalpernPearlModifiedApproximate(\n", " model = voting_model,\n", - " antecedents = [\"vote0\"],\n", + " counterfactual_antecedents = {\"vote0\": 0.0},\n", " outcome = \"outcome\",\n", " witness_candidates = [f\"vote{i}\" for i in range(1,6)],\n", " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=0),\n", - " sample_size = 1000)\n", + " u_vote3=1., u_vote4=1., u_vote5=0.))\n", + "\n", + "with pyro.poutine.trace() as voting5_trace:\n", + " with pyro.plate(\"runs\", 500):\n", + " voting5HPM()\n", "\n", - "voting5HPM()\n", + "v5 = voting5_trace.trace.nodes \n", "\n", - "voting5HPM.existential_but_for" + "ac_check(voting5HPM, v5)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 28, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The antecedent set is an actual cause.\n" + ] + }, { "data": { "text/plain": [ "True" ] }, - "execution_count": 18, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1609,16 +1719,20 @@ "\n", "voting_groupHPM = HalpernPearlModifiedApproximate(\n", " model = voting_model,\n", - " antecedents = [\"vote0\", \"vote1\"],\n", + " counterfactual_antecedents = {\"vote0\": 0.0, \"vote1\": 0.0},\n", " outcome = \"outcome\",\n", " witness_candidates = [f\"vote{i}\" for i in range(2,6)],\n", " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=0),\n", - " sample_size = 1000)\n", + " u_vote3=1., u_vote4=1., u_vote5=0))\n", + "\n", + "\n", + "with pyro.poutine.trace() as voting_group_trace:\n", + " with pyro.plate(\"runs\", 500):\n", + " voting_groupHPM()\n", "\n", - "voting_groupHPM()\n", + "vg = voting_group_trace.trace.nodes \n", "\n", - "voting_groupHPM.existential_but_for\n" + "ac_check(voting_groupHPM, vg)" ] }, { From b4d19acb21e36fd7bd085f1f73c60d1fc016c018 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Wed, 16 Aug 2023 00:27:31 +0200 Subject: [PATCH 04/13] responsibility debugging --- docs/source/responsibility.ipynb | 1306 ++++++++++++++++++++++++++++++ 1 file changed, 1306 insertions(+) create mode 100644 docs/source/responsibility.ipynb diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb new file mode 100644 index 00000000..fd868073 --- /dev/null +++ b/docs/source/responsibility.ipynb @@ -0,0 +1,1306 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Responsibility and actual causality" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Preceding notebook**\n", + "\n", + "- [Actual Causality: the modified Halpern-Pearl definition]() TODO add link\n", + "\n", + "\n", + "**Summary**\n", + "\n", + "In a previous notebook, we introduced and implemented the Halpern-Pearl modified definition of actual causality. Here we implement the way Halpern used this notion to introduce his so-called *naive definition of responsibility*. We also briefly illustrate some reasons to think a somewhat more sophisticated notion is needed." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Outline**\n", + "\n", + "[Intuitions](##intuitions)\n", + " \n", + "[Formalization](#formalization)\n", + "\n", + "[Implementation](#implementation)\n", + "\n", + "[Examples](#examples)\n", + "\n", + "- [Comments on example selection](#comments-on-example-selection)\n", + " \n", + "- [Voting](#voting)\n", + "\n", + "- [Stone-throwing](#stone-throwing)\n", + "\n", + "- [Firing squad](#firing-squad)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intuitions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The key idea here is that your responsibility for an outcome is to be measured in terms of how drastic a change would have to be made to the world for the outcome to depend counterfactually on your actions. However, the definition uses a fairly crude measure thereof, the minimal *number* of changes needed, where those numbers are individuated in terms of nodes. On one hand, if you are part of a cause, we count how many elements the cause has. On the other, we count the number of nodes that a witness set has. We add these two numbers for any combination of an actual cause and a witness set and we take the minimum, say $k$. Your responsibility is then $1/k$. " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Formalization" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The degree of responsibility of $X = x$ for $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ is 0 if $X = x$ is not part of an actual cause of $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ according\n", + "to the modified HP definition. It is $1/k$ if there exists an actual cause $\\vec{X} = \\vec{x}$ of $\\varphi$ and a witness $\\vec{W}$ to $\\vec{X}=\\vec{x}$ being a cause of $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ such that \n", + "(a) $X=x$ is a conjunct in $\\vec{X}= \\vec{x}$, (b) $\\vert \\vec{W}\\vert + \\vert\\vec{X}\\vert = k$, and (c) $k$ is minimal such a number.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "import numpy as np\n", + "from itertools import combinations\n", + "\n", + "import torch\n", + "from typing import Dict, List, Optional, Union, Callable, Any\n", + "\n", + "import pandas as pd\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "\n", + "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", + "from chirho.interventional.handlers import do\n", + "from chirho.counterfactual.ops import preempt\n", + "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions\n", + "from chirho.observational.handlers import condition" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class BiasedPreemptions(Preemptions):\n", + " \"\"\"\n", + " Counterfactual handler that preempts the model with a biased coin flip.\n", + " \"\"\"\n", + " def __init__(self, actions, weights: torch.Tensor, event_dim: int = 0) -> None:\n", + " self.weights = weights\n", + " self.event_dim = event_dim\n", + " super().__init__(actions)\n", + "\n", + " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", + " try:\n", + " name = msg[\"name\"]\n", + " action = self.actions[name]\n", + " except KeyError:\n", + " return\n", + " value = msg[\"value\"]\n", + " factual_value = gather(value, IndexSet(**{name: {0}}),\n", + " event_dim=self.event_dim),\n", + " counterfactual_value = gather(value, IndexSet(**{name: {1}}),\n", + " event_dim=self.event_dim),\n", + " factual_value = preempt(\n", + " factual_value,\n", + " (action,),\n", + " None,\n", + " event_dim=len(msg[\"fn\"].event_shape),\n", + " name=msg[\"name\"],\n", + " )\n", + "\n", + " msg[\"value\"] = scatter({\n", + " IndexSet(**{name: {0}}): factual_value,\n", + " IndexSet(**{name: {1}}): counterfactual_value,\n", + " }, event_dim=self.event_dim)\n", + "\n", + " def _pyro_preempt(self,msg: Dict[str, Any]) -> None:\n", + " if msg[\"name\"] not in self.actions:\n", + " return\n", + " obs, acts, case = msg[\"args\"]\n", + " msg[\"kwargs\"][\"name\"] = f\"__split_{msg['name']}\"\n", + " case_dist = pyro.distributions.Categorical(self.weights)\n", + " case = pyro.sample(msg[\"kwargs\"][\"name\"], case_dist, obs=case)\n", + " msg[\"args\"] = (obs, acts, case)\n", + " msg[\"stop\"] = True" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# slight modification wrt. to the orginal class:\n", + "# we use BiasedPreemption for witnesses\n", + "# with the same intervention bias as for antecedents\n", + "# to minimize for the number of active (antecedents + witnesses)\n", + "\n", + "class HalpernPearlModifiedApproximate:\n", + "\n", + " def __init__(\n", + " self, \n", + " model: Callable,\n", + " counterfactual_antecedents: Dict[str, torch.Tensor],\n", + " outcome: str,\n", + " witness_candidates: List[str],\n", + " observations: Optional[Dict[str, torch.Tensor]] = None\n", + " ):\n", + " \n", + " if observations is None:\n", + " observations = {}\n", + "\n", + " self.model = model\n", + " self.counterfactual_antecedents = counterfactual_antecedents\n", + " self.outcome = outcome\n", + " self.witness_candidates = witness_candidates\n", + " self.observations = observations\n", + "\n", + " self.antecedent_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", + " antecedents = [antecedent]) for\n", + " antecedent in self.counterfactual_antecedents.keys()}\n", + " \n", + " self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,\n", + " antecedents = self.counterfactual_antecedents) for \n", + " candidate in self.witness_candidates}\n", + " \n", + " @staticmethod \n", + " def preempt_with_factual(value: torch.Tensor, *,\n", + " antecedents: List[str] = None, event_dim: int = 0):\n", + " \n", + " if antecedents is None:\n", + " antecedents = []\n", + "\n", + " antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]\n", + "\n", + " factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),\n", + " event_dim=event_dim)\n", + " \n", + " return scatter({\n", + " IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,\n", + " IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,\n", + " }, event_dim=event_dim)\n", + " \n", + " \n", + " def __call__(self, *args, **kwargs):\n", + " with MultiWorldCounterfactual():\n", + " with do(actions=self.counterfactual_antecedents):\n", + " # the last element of the tensor is the factual case (preempted)\n", + " with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.4, .6])):\n", + " #the last element is the fixed at the observed value (preempted) \n", + " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.6, .4])):\n", + " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", + " with pyro.poutine.trace() as self.trace:\n", + " self.consequent = self.model(*args, **kwargs)[self.outcome]\n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_antecedents}))\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_antecedents}))\n", + " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", + " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", + " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# this will explore the trace once we run inference on the model\n", + "\n", + "def get_table(nodes, antecedents, witness_candidates):\n", + " \n", + " values_table = {}\n", + "\n", + " for antecedent in antecedents:\n", + " values_table[antecedent] = nodes[antecedent][\"value\"].squeeze().tolist()\n", + " values_table['preempted_' + antecedent] = nodes['__split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['preempted_' + antecedent + '_log_prob'] = nodes['__split_' + antecedent][\"fn\"].log_prob(nodes['__split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + "\n", + " for candidate in witness_candidates:\n", + " _values = nodes[candidate][\"value\"].squeeze().tolist()\n", + " # TODO: uncomment in the final version (?) \n", + " #values_table[candidate + '0'] = _values[0]\n", + " #values_table[candidate + '1'] = _values[1]\n", + " values_table['fixed_factual_' + candidate] = nodes['__split_' + candidate][\"value\"].squeeze().tolist()\n", + " \n", + " # TODO uncomment in the final version (?)\n", + " #values_table[consequent + '0'] = nodes[consequent][\"value\"].squeeze().tolist()[0]\n", + " #values_table[consequent + '1'] = nodes[consequent][\"value\"].squeeze().tolist()[1]\n", + " values_table['consequent_differs_binary'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " values_table['consequent_log_prob'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['consequent_log_prob'], float):\n", + " values_df = pd.DataFrame([values_table])\n", + " else:\n", + " values_df = pd.DataFrame(values_table)\n", + " \n", + "\n", + " summands = ['preempted_' + antecedent + '_log_prob' for antecedent in antecedents]\n", + " summands.append('consequent_log_prob')\n", + " values_df[\"sum_log_prob\"] = values_df[summands].sum(axis = 1) \n", + " values_df.drop_duplicates(inplace = True)\n", + " values_df.sort_values(by = \"sum_log_prob\", inplace = True, ascending = False)\n", + "\n", + " return values_df.reset_index(drop = True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def voting_model():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", + " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", + " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", + " u_vote6 = pyro.sample(\"u_vote6\", dist.Bernoulli(0.6))\n", + " u_vote7 = pyro.sample(\"u_vote7\", dist.Bernoulli(0.6))\n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", + " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", + " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", + " vote6 = pyro.deterministic(\"vote6\", u_vote6, event_dim=0)\n", + " vote7 = pyro.deterministic(\"vote7\", u_vote7, event_dim=0)\n", + "\n", + "\n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + \n", + " vote4 + vote5 + vote6 + vote7 > 4)\n", + " return {\"outcome\": outcome.float()}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# if you're one of five voters who voted for, you are an actual cause\n", + "\n", + "# and your responsibility is 1 \n", + "\n", + "observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0.,\n", + " u_vote6=0., u_vote7=0.)\n", + "\n", + "\n", + "\n", + "counterfactual_antecedents = {key[2:]: 1-v for key, v in observations.items()}\n", + "\n", + "\n", + "voting5HPM = HalpernPearlModifiedApproximate(\n", + " model = voting_model,\n", + " counterfactual_antecedents = counterfactual_antecedents,\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,8)],\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0.,\n", + " u_vote6=0., u_vote7=0.))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "NotImplementedError", + "evalue": "intervene not implemented for type ", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mwith\u001b[39;00m pyro\u001b[39m.\u001b[39mplate(\u001b[39m\"\u001b[39m\u001b[39mruns\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m10\u001b[39m):\n\u001b[0;32m----> 2\u001b[0m voting5HPM()\n", + "Cell \u001b[0;32mIn[3], line 61\u001b[0m, in \u001b[0;36mHalpernPearlModifiedApproximate.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[39mwith\u001b[39;00m condition(data\u001b[39m=\u001b[39m{k: torch\u001b[39m.\u001b[39mas_tensor(v) \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobservations\u001b[39m.\u001b[39mitems()}):\n\u001b[1;32m 60\u001b[0m \u001b[39mwith\u001b[39;00m pyro\u001b[39m.\u001b[39mpoutine\u001b[39m.\u001b[39mtrace() \u001b[39mas\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrace:\n\u001b[0;32m---> 61\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutcome]\n\u001b[1;32m 62\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mintervened_consequent \u001b[39m=\u001b[39m gather(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{ant: {\u001b[39m1\u001b[39m} \u001b[39mfor\u001b[39;00m ant \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcounterfactual_antecedents}))\n\u001b[1;32m 63\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobserved_consequent \u001b[39m=\u001b[39m gather(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{ant: {\u001b[39m0\u001b[39m} \u001b[39mfor\u001b[39;00m ant \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcounterfactual_antecedents}))\n", + "Cell \u001b[0;32mIn[5], line 11\u001b[0m, in \u001b[0;36mvoting_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m u_vote6 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39msample(\u001b[39m\"\u001b[39m\u001b[39mu_vote6\u001b[39m\u001b[39m\"\u001b[39m, dist\u001b[39m.\u001b[39mBernoulli(\u001b[39m0.6\u001b[39m))\n\u001b[1;32m 9\u001b[0m u_vote7 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39msample(\u001b[39m\"\u001b[39m\u001b[39mu_vote7\u001b[39m\u001b[39m\"\u001b[39m, dist\u001b[39m.\u001b[39mBernoulli(\u001b[39m0.6\u001b[39m))\n\u001b[0;32m---> 11\u001b[0m vote0 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39;49mdeterministic(\u001b[39m\"\u001b[39;49m\u001b[39mvote0\u001b[39;49m\u001b[39m\"\u001b[39;49m, u_vote0, event_dim\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m)\n\u001b[1;32m 12\u001b[0m vote1 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39mdeterministic(\u001b[39m\"\u001b[39m\u001b[39mvote1\u001b[39m\u001b[39m\"\u001b[39m, u_vote1, event_dim\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 13\u001b[0m vote2 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39mdeterministic(\u001b[39m\"\u001b[39m\u001b[39mvote2\u001b[39m\u001b[39m\"\u001b[39m, u_vote2, event_dim\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/primitives.py:209\u001b[0m, in \u001b[0;36mdeterministic\u001b[0;34m(name, value, event_dim)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[39mDeterministic statement to add a :class:`~pyro.distributions.Delta` site\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[39mwith name `name` and value `value` to the trace. This is useful when we\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[39m:param int event_dim: Optional event dimension, defaults to `value.ndim`.\u001b[39;00m\n\u001b[1;32m 207\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 208\u001b[0m event_dim \u001b[39m=\u001b[39m value\u001b[39m.\u001b[39mndim \u001b[39mif\u001b[39;00m event_dim \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m event_dim\n\u001b[0;32m--> 209\u001b[0m \u001b[39mreturn\u001b[39;00m sample(\n\u001b[1;32m 210\u001b[0m name,\n\u001b[1;32m 211\u001b[0m dist\u001b[39m.\u001b[39;49mDelta(value, event_dim\u001b[39m=\u001b[39;49mevent_dim)\u001b[39m.\u001b[39;49mmask(\u001b[39mFalse\u001b[39;49;00m),\n\u001b[1;32m 212\u001b[0m obs\u001b[39m=\u001b[39;49mvalue,\n\u001b[1;32m 213\u001b[0m infer\u001b[39m=\u001b[39;49m{\u001b[39m\"\u001b[39;49m\u001b[39m_deterministic\u001b[39;49m\u001b[39m\"\u001b[39;49m: \u001b[39mTrue\u001b[39;49;00m},\n\u001b[1;32m 214\u001b[0m )\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/primitives.py:163\u001b[0m, in \u001b[0;36msample\u001b[0;34m(name, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 146\u001b[0m msg \u001b[39m=\u001b[39m {\n\u001b[1;32m 147\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msample\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 148\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcontinuation\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 161\u001b[0m }\n\u001b[1;32m 162\u001b[0m \u001b[39m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m apply_stack(msg)\n\u001b[1;32m 164\u001b[0m \u001b[39mreturn\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:220\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 217\u001b[0m default_process_message(msg)\n\u001b[1;32m 219\u001b[0m \u001b[39mfor\u001b[39;00m frame \u001b[39min\u001b[39;00m stack[\u001b[39m-\u001b[39mpointer:]:\n\u001b[0;32m--> 220\u001b[0m frame\u001b[39m.\u001b[39;49m_postprocess_message(msg)\n\u001b[1;32m 222\u001b[0m cont \u001b[39m=\u001b[39m msg[\u001b[39m\"\u001b[39m\u001b[39mcontinuation\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 223\u001b[0m \u001b[39mif\u001b[39;00m cont \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:168\u001b[0m, in \u001b[0;36mMessenger._postprocess_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 166\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m_pyro_post_\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(msg[\u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m]), \u001b[39mNone\u001b[39;00m)\n\u001b[1;32m 167\u001b[0m \u001b[39mif\u001b[39;00m method \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 168\u001b[0m \u001b[39mreturn\u001b[39;00m method(msg)\n\u001b[1;32m 169\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mNone\u001b[39;00m\n", + "Cell \u001b[0;32mIn[2], line 21\u001b[0m, in \u001b[0;36mBiasedPreemptions._pyro_post_sample\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 17\u001b[0m factual_value \u001b[39m=\u001b[39m gather(value, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}),\n\u001b[1;32m 18\u001b[0m event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim),\n\u001b[1;32m 19\u001b[0m counterfactual_value \u001b[39m=\u001b[39m gather(value, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m1\u001b[39m}}),\n\u001b[1;32m 20\u001b[0m event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim),\n\u001b[0;32m---> 21\u001b[0m factual_value \u001b[39m=\u001b[39m preempt(\n\u001b[1;32m 22\u001b[0m factual_value,\n\u001b[1;32m 23\u001b[0m (action,),\n\u001b[1;32m 24\u001b[0m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 25\u001b[0m event_dim\u001b[39m=\u001b[39;49m\u001b[39mlen\u001b[39;49m(msg[\u001b[39m\"\u001b[39;49m\u001b[39mfn\u001b[39;49m\u001b[39m\"\u001b[39;49m]\u001b[39m.\u001b[39;49mevent_shape),\n\u001b[1;32m 26\u001b[0m name\u001b[39m=\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39mname\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 27\u001b[0m )\n\u001b[1;32m 29\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m scatter({\n\u001b[1;32m 30\u001b[0m IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}): factual_value,\n\u001b[1;32m 31\u001b[0m IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m1\u001b[39m}}): counterfactual_value,\n\u001b[1;32m 32\u001b[0m }, event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:281\u001b[0m, in \u001b[0;36meffectful.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m msg \u001b[39m=\u001b[39m {\n\u001b[1;32m 265\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mtype\u001b[39m,\n\u001b[1;32m 266\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[39m\"\u001b[39m\u001b[39minfer\u001b[39m\u001b[39m\"\u001b[39m: infer,\n\u001b[1;32m 279\u001b[0m }\n\u001b[1;32m 280\u001b[0m \u001b[39m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m apply_stack(msg)\n\u001b[1;32m 282\u001b[0m \u001b[39mreturn\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:217\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[39mif\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mstop\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[1;32m 215\u001b[0m \u001b[39mbreak\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m default_process_message(msg)\n\u001b[1;32m 219\u001b[0m \u001b[39mfor\u001b[39;00m frame \u001b[39min\u001b[39;00m stack[\u001b[39m-\u001b[39mpointer:]:\n\u001b[1;32m 220\u001b[0m frame\u001b[39m.\u001b[39m_postprocess_message(msg)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:179\u001b[0m, in \u001b[0;36mdefault_process_message\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m 176\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mdone\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[1;32m 177\u001b[0m \u001b[39mreturn\u001b[39;00m msg\n\u001b[0;32m--> 179\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m msg[\u001b[39m\"\u001b[39;49m\u001b[39mfn\u001b[39;49m\u001b[39m\"\u001b[39;49m](\u001b[39m*\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39margs\u001b[39;49m\u001b[39m\"\u001b[39;49m], \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39mkwargs\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n\u001b[1;32m 181\u001b[0m \u001b[39m# after fn has been called, update msg to prevent it from being called again.\u001b[39;00m\n\u001b[1;32m 182\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mdone\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12\u001b[0m, in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_context_wrap\u001b[39m(context, fn, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[39mwith\u001b[39;00m context:\n\u001b[0;32m---> 12\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", + "File \u001b[0;32m~/L2projects/chirho/chirho/counterfactual/ops.py:53\u001b[0m, in \u001b[0;36mpreempt\u001b[0;34m(obs, acts, case, **kwargs)\u001b[0m\n\u001b[1;32m 51\u001b[0m act_values \u001b[39m=\u001b[39m {IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}): obs}\n\u001b[1;32m 52\u001b[0m \u001b[39mfor\u001b[39;00m i, act \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(acts):\n\u001b[0;32m---> 53\u001b[0m act_values[IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {i \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m}})] \u001b[39m=\u001b[39m intervene(obs, act, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 55\u001b[0m \u001b[39mreturn\u001b[39;00m cond(act_values, case, event_dim\u001b[39m=\u001b[39mkwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mevent_dim\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m0\u001b[39m))\n", + "File \u001b[0;32m~/anaconda3/envs/causal_pyro/lib/python3.10/functools.py:889\u001b[0m, in \u001b[0;36msingledispatch..wrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 885\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m args:\n\u001b[1;32m 886\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mfuncname\u001b[39m}\u001b[39;00m\u001b[39m requires at least \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 887\u001b[0m \u001b[39m'\u001b[39m\u001b[39m1 positional argument\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m--> 889\u001b[0m \u001b[39mreturn\u001b[39;00m dispatch(args[\u001b[39m0\u001b[39;49m]\u001b[39m.\u001b[39;49m\u001b[39m__class__\u001b[39;49m)(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkw)\n", + "File \u001b[0;32m~/L2projects/chirho/chirho/interventional/ops.py:25\u001b[0m, in \u001b[0;36mintervene\u001b[0;34m(obs, act, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39msingledispatch\n\u001b[1;32m 12\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mintervene\u001b[39m(obs, act: Optional[Intervention[T]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 13\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[39m Intervene on a value in a probabilistic program.\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m :param act: an optional intervention.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 25\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mintervene not implemented for type \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(obs)\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", + "\u001b[0;31mNotImplementedError\u001b[0m: intervene not implemented for type " + ] + } + ], + "source": [ + "with pyro.plate(\"runs\", 10):\n", + " voting5HPM()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This implementation is now used within another class definition, where, again, the main moves are in `def __call__`. We sample antecedent sets, leave other nodes (aside from the outcome) as witness candidates, and pass the result to an actual causality evaluation, keeping track of minimal antecedent sets and the corresponding witness sizes. Then we find a minimum of the sum and use it in the denominator.`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "class HalpernPearlResponsibilityApproximate:\n", + "\n", + " def __init__(\n", + " self, \n", + " model: Callable,\n", + " nodes: List,\n", + " antecedent: str,\n", + " outcome: str,\n", + " observations: Dict[str, torch.Tensor], \n", + " runs_n: int \n", + " ):\n", + " self.model = model\n", + " self.nodes = nodes\n", + " self.antecedent = antecedent\n", + " self.outcome = outcome\n", + " self.observations = observations\n", + " self.runs_n = runs_n\n", + " \n", + " self.minimal_antecedents_cache = []\n", + " self.antecedent_sizes = []\n", + " self.existential_but_fors = []\n", + " self.acs = []\n", + " self.minimal_witness_sizes = []\n", + " self.responsibilities = []\n", + " self.HPMs = []\n", + "\n", + " def __call__(self):\n", + " \n", + " for step in range(1,self.runs_n):\n", + "\n", + " nodes = self.nodes\n", + " if self.outcome in nodes:\n", + " nodes.remove(self.outcome) \n", + " \n", + " companion_size = random.randint(0,len(nodes))\n", + " companion_candidates = random.sample(self.nodes, companion_size)\n", + " witness_candidates = [node for node in self.nodes if \n", + " node != self.antecedent and \n", + " node != self.outcome and \n", + " node not in companion_candidates]\n", + "\n", + " HPM = HalpernPearlModifiedApproximate(\n", + " model = self.model,\n", + " antecedents = companion_candidates,\n", + " outcome = self.outcome,\n", + " witness_candidates = witness_candidates,\n", + " observations = self.observations,\n", + " sample_size = 1000)\n", + " \n", + " HPM()\n", + " self.HPMs.append(HPM)\n", + "\n", + " if HPM.existential_but_for:\n", + " \n", + " HPM_min = ac_minimality_check(HPM)\n", + "\n", + " if HPM_min.ac:\n", + "\n", + " subset_in_cache = any([s.issubset(set(HPM.antecedents)) for s in self.minimal_antecedents_cache])\n", + " if not subset_in_cache:\n", + " for s in self.minimal_antecedents_cache:\n", + " if set(HPM.antecedents).issubset(s):\n", + " self.minimal_antecedents_cache.remove(s)\n", + " self.minimal_antecedents_cache.append(set(HPM.antecedents))\n", + "\n", + " if self.antecedent in HPM.antecedents:\n", + " self.antecedent_sizes.append(len(HPM.antecedents))\n", + " self.existential_but_fors.append(HPM.existential_but_for)\n", + " self.acs.append(HPM.ac)\n", + " self.minimal_witness_sizes.append(HPM.minimal_witness_size)\n", + " self.responsibilities.append(HPM.responsibility_internal)\n", + "\n", + "\n", + " self.denumerators = [x + y for x, y in zip(self.antecedent_sizes, self.minimal_witness_sizes)]\n", + "\n", + " self.responsibilityDF = pd.DataFrame(\n", + " {#\"existential_but_for\": [bool(value) for value in self.existential_but_fors],\n", + " \"acs\": [bool(value) for value in self.acs],\n", + " \"antecedent_size\": self.antecedent_sizes, \n", + " \"minimal_witness_size\": self.minimal_witness_sizes,\n", + " \"denumerator\": self.denumerators,\n", + " \"responsibility\": self.responsibilities\n", + " }\n", + " )\n", + " if len(self.responsibilityDF['acs']) == 0:\n", + " self.responsibility = 0\n", + " else:\n", + " min_denumerator = min(self.responsibilityDF['denumerator'])\n", + " self.responsibility = 1/min_denumerator\n", + "\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comments on example selection\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- **Voting:** the example illustrates that parts of actual causes can share various degrees of responsibility for the outcome, without being actual causes.\n", + "\n", + "- **Stone-throwing:** responsibility calculations in one of the main running examples in the *Actual Causality* book by Halpern (2016).\n", + "\n", + "- **Firing squad:** an example in which responsibility and actual causality agree, where-as disussed in the notebook on the notion of blame-the notion of responsibility and blame will diverge." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Voting" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We discussed a similar model in a previous notebook. This time we have eight voters involved in a binary majority voting procedure and we investigate the responsibility assigned to voter 0. The situation is analogous to the one discussed in the actual causality notebook: if your vote is decisive, you are an actual cause, and you're not an actual cause otherwise. What's your responsibility, though? " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'__split_vote0'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m get_table(voting5HPM\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mnodes, antecedents \u001b[39m=\u001b[39;49m counterfactual_antecedents, witness_candidates \u001b[39m=\u001b[39;49m voting5HPM\u001b[39m.\u001b[39;49mwitness_candidates)\n", + "Cell \u001b[0;32mIn[4], line 9\u001b[0m, in \u001b[0;36mget_table\u001b[0;34m(nodes, antecedents, witness_candidates)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mfor\u001b[39;00m antecedent \u001b[39min\u001b[39;00m antecedents:\n\u001b[1;32m 8\u001b[0m values_table[antecedent] \u001b[39m=\u001b[39m nodes[antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[0;32m----> 9\u001b[0m values_table[\u001b[39m'\u001b[39m\u001b[39mpreempted_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent] \u001b[39m=\u001b[39m nodes[\u001b[39m'\u001b[39;49m\u001b[39m__split_\u001b[39;49m\u001b[39m'\u001b[39;49m \u001b[39m+\u001b[39;49m antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[1;32m 10\u001b[0m values_table[\u001b[39m'\u001b[39m\u001b[39mpreempted_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m_log_prob\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m nodes[\u001b[39m'\u001b[39m\u001b[39m__split_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent][\u001b[39m\"\u001b[39m\u001b[39mfn\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mlog_prob(nodes[\u001b[39m'\u001b[39m\u001b[39m__split_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m])\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[1;32m 13\u001b[0m \u001b[39mfor\u001b[39;00m candidate \u001b[39min\u001b[39;00m witness_candidates:\n", + "\u001b[0;31mKeyError\u001b[0m: '__split_vote0'" + ] + } + ], + "source": [ + "get_table(voting5HPM.trace.trace.nodes, antecedents = counterfactual_antecedents, witness_candidates = voting5HPM.witness_candidates)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "# if everyone voted for, you are not an actual cause\n", + "\n", + "everyone_voted_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model,\n", + " nodes = [f\"vote{i}\" for i in range(0,8,)],\n", + " antecedent = \"vote0\", outcome = \"outcome\",\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 1.), \n", + " runs_n=500\n", + " )\n", + "\n", + "pyro.set_rng_seed(42)\n", + "everyone_voted_HPR()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
acsantecedent_sizeminimal_witness_sizedenumeratorresponsibility
0True4040.25
1True4040.25
2True4040.25
3True4040.25
4True4040.25
5True4040.25
6True4040.25
7True4040.25
8True4040.25
9True4040.25
10True4040.25
11True4040.25
12True4040.25
13True4040.25
14True4040.25
15True4040.25
16True4040.25
17True4040.25
18True4040.25
19True4040.25
20True4040.25
21True4040.25
\n", + "
" + ], + "text/plain": [ + " acs antecedent_size minimal_witness_size denumerator responsibility\n", + "0 True 4 0 4 0.25\n", + "1 True 4 0 4 0.25\n", + "2 True 4 0 4 0.25\n", + "3 True 4 0 4 0.25\n", + "4 True 4 0 4 0.25\n", + "5 True 4 0 4 0.25\n", + "6 True 4 0 4 0.25\n", + "7 True 4 0 4 0.25\n", + "8 True 4 0 4 0.25\n", + "9 True 4 0 4 0.25\n", + "10 True 4 0 4 0.25\n", + "11 True 4 0 4 0.25\n", + "12 True 4 0 4 0.25\n", + "13 True 4 0 4 0.25\n", + "14 True 4 0 4 0.25\n", + "15 True 4 0 4 0.25\n", + "16 True 4 0 4 0.25\n", + "17 True 4 0 4 0.25\n", + "18 True 4 0 4 0.25\n", + "19 True 4 0 4 0.25\n", + "20 True 4 0 4 0.25\n", + "21 True 4 0 4 0.25" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# but the size-minimal actual causes are all of size 4\n", + "# so your responsibility is 1/4\n", + "\n", + "everyone_voted_HPR.responsibilityDF\n" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.25" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# four people would need to change their votes\n", + "# to change the outcome\n", + "# so your responsibility is 1/4\n", + "\n", + "everyone_voted_HPR.responsibility" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "# if only seven people voted for, \n", + "# your responsibility changes to 1/3\n", + "\n", + "seven_voted_for_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model,\n", + " nodes = [f\"vote{i}\" for i in range(0,8,)],\n", + " antecedent = \"vote0\", outcome = \"outcome\",\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 0.), \n", + " runs_n=500\n", + " )\n", + "\n", + "pyro.set_rng_seed(42)\n", + "seven_voted_for_HPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.3333333333333333" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# your responsibility is 1/3 as in this case\n", + "# it would be enough for three people to vote against\n", + "# to change the outcome\n", + "\n", + "seven_voted_for_HPR.responsibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stone-throwing\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've already discussed the model in the actual causality notebook. Sally and Bill throw stones at a bottle, Sally throws first. Bill is perfectly accurate, so his stone would have shattered the bottle had not Sally's stone done it. The model is worth looking at, as the causal structure is less trivial. Again, we will see that responsibility judgment might to some extent disagree with actual causality." + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "def stones_model(): \n", + " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", + " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", + " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", + " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", + "\n", + "\n", + " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", + " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", + "\n", + " new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)\n", + "\n", + " sally_hits = pyro.sample(\"sally_hits\",dist.Bernoulli(new_shp))\n", + "\n", + " new_bhp = torch.where(\n", + " (\n", + " bill_throws.bool()\n", + " & (~sally_hits.bool())\n", + " )\n", + " == 1,\n", + " prob_bill_hits,\n", + " torch.tensor(0.0),\n", + " )\n", + "\n", + "\n", + " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(new_bhp))\n", + "\n", + " new_bsp = torch.where(\n", + " bill_hits.bool() == 1,\n", + " prob_bottle_shatters_if_bill,\n", + " torch.where(\n", + " sally_hits.bool() == 1,\n", + " prob_bottle_shatters_if_sally,\n", + " torch.tensor(0.0),\n", + " ),\n", + " )\n", + "\n", + " bottle_shatters = pyro.sample(\n", + " \"bottle_shatters\", dist.Bernoulli(new_bsp)\n", + " )\n", + "\n", + " return {\n", + " \"sally_throws\": sally_throws,\n", + " \"bill_throws\": bill_throws,\n", + " \"sally_hits\": sally_hits,\n", + " \"bill_hits\": bill_hits,\n", + " \"bottle_shatters\": bottle_shatters,\n", + " }\n", + "\n", + "stones_model.nodes = [\n", + " \"sally_throws\",\n", + " \"bill_throws\",\n", + " \"sally_hits\",\n", + " \"bill_hits\",\n", + " \"bottle_shatters\",\n", + " ]" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(101)\n", + "responsibility_stones_sally_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = stones_model,\n", + " nodes = stones_model.nodes,\n", + " antecedent = \"sally_throws\", outcome = \"bottle_shatters\",\n", + " observations = {\"prob_sally_throws\": 1, \n", + " \"prob_bill_throws\": 1,\n", + " \"prob_sally_hits\": 1,\n", + " \"prob_bill_hits\": 1,\n", + " \"prob_bottle_shatters_if_sally\": 1,\n", + " \"prob_bottle_shatters_if_bill\": 1,\n", + " \"sally_throws\": 1, \"bill_throws\": 1},\n", + " runs_n=100)\n", + "\n", + "responsibility_stones_sally_HPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
acsantecedent_sizeminimal_witness_sizedenumeratorresponsibility
0True1120.5
\n", + "
" + ], + "text/plain": [ + " acs antecedent_size minimal_witness_size denumerator responsibility\n", + "0 True 1 1 2 0.5" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# minimal witness size becomes non-trivial here\n", + "# we only record different minimal difference-making scenarios\n", + "\n", + "responsibility_stones_sally_HPR.responsibilityDF" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# following Halpern\n", + "# Sally's responsibility is 1/2\n", + "\n", + "responsibility_stones_sally_HPR.responsibility" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Billy has degree of responsibility 0\n", + "# for the bottle shattering,\n", + "# as his throw is not a part of an actual cause\n", + "\n", + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_stones_bill_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = stones_model,\n", + " nodes = stones_model.nodes,\n", + " antecedent = \"bill_throws\", outcome = \"bottle_shatters\",\n", + " observations = {\"prob_sally_throws\": 1, \n", + " \"prob_bill_throws\": 1,\n", + " \"prob_sally_hits\": 1,\n", + " \"prob_bill_hits\": 1,\n", + " \"prob_bottle_shatters_if_sally\": 1,\n", + " \"prob_bottle_shatters_if_bill\": 1,\n", + " \"sally_throws\": 1, \"bill_throws\": 1},\n", + " runs_n=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(101)\n", + "responsibility_stones_bill_HPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "responsibility_stones_bill_HPR.responsibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Firing squad" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There is a firing squad consisting of five excellent marksmen. Only one of them has a live bullet in his rifle and the rest have blanks. They shoot and the prisoner dies. The marksmen shoot at the prisoner and he dies. The only cause of the prisoner’s death is the marksman with the live bullet. That marksman has degree of responsibility 1 for the death and all the others have degree of responsibility 0. In the notebook on blame, TODO add link we will see that if the marksmen completely do not know which of them has the live bullet, blame is nevertheless equally distributed between them." + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def firing_squad_model():\n", + " probs = pyro.sample(\"probs\", dist.Dirichlet(torch.ones(5)))\n", + "\n", + " who_has_bullet = pyro.sample(\"who_has_bullet\", dist.OneHotCategorical(probs))\n", + "\n", + " mark0 = pyro.deterministic(\"mark0\", torch.tensor([who[0] for who in who_has_bullet]), event_dim=0)\n", + " mark1 = pyro.deterministic(\"mark1\", torch.tensor([who[1] for who in who_has_bullet]), event_dim=0)\n", + " mark2 = pyro.deterministic(\"mark2\", torch.tensor([who[2] for who in who_has_bullet]), event_dim=0)\n", + " mark3 = pyro.deterministic(\"mark3\", torch.tensor([who[3] for who in who_has_bullet]), event_dim=0)\n", + " mark4 = pyro.deterministic(\"mark4\", torch.tensor([who[4] for who in who_has_bullet]), event_dim=0)\n", + "\n", + " dead = pyro.deterministic(\"dead\", mark0 + mark1 + mark2 + mark3 + \n", + " mark4 > 0)\n", + " \n", + " return {\"probs\": probs,\n", + " \"mark0\": mark0,\n", + " \"mark1\": mark1,\n", + " \"mark2\": mark2,\n", + " \"mark3\": mark3,\n", + " \"mark4\": mark4, \n", + " \"dead\": dead}\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_loaded_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = firing_squad_model,\n", + " nodes = [\"mark\" + str(i) for i in range(0,5)],\n", + " antecedent = \"mark0\", outcome = \"dead\",\n", + " observations = {\"probs\": torch.tensor([1., 0., 0., 0., 0.]),},\n", + " runs_n=50)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_empty_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = firing_squad_model,\n", + " nodes = [\"mark\" + str(i) for i in range(0,5)],\n", + " antecedent = \"mark1\", outcome = \"dead\",\n", + " observations = {\"probs\": torch.tensor([1., 0., 0., 0., 0.]),},\n", + " runs_n=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# If you have the live bullet\n", + "\n", + "responsibility_loaded_HPR()\n", + "responsibility_loaded_HPR.responsibility" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if you have a blank,\n", + "# as we keep bullet's location constant in the model\n", + "# nothing can make a difference to mark1's contribution\n", + "# so his responsibility is zero\n", + "\n", + "responsibility_empty_HPR()\n", + "responsibility_empty_HPR.responsibility" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "causal_pyro", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 55670684409c6afea1e56dda1f7ec61b54ea645b Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Tue, 22 Aug 2023 11:27:20 +0200 Subject: [PATCH 05/13] voting and responsibility done --- docs/source/.~lock.smalltab.csv# | 1 + docs/source/responsibility.ipynb | 1176 ++++++++++++++++++++++-------- 2 files changed, 887 insertions(+), 290 deletions(-) create mode 100644 docs/source/.~lock.smalltab.csv# diff --git a/docs/source/.~lock.smalltab.csv# b/docs/source/.~lock.smalltab.csv# new file mode 100644 index 00000000..855da661 --- /dev/null +++ b/docs/source/.~lock.smalltab.csv# @@ -0,0 +1 @@ +,rafal,pop-os,22.08.2023 11:23,file:///home/rafal/.config/libreoffice/4; \ No newline at end of file diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb index fd868073..7fac683c 100644 --- a/docs/source/responsibility.ipynb +++ b/docs/source/responsibility.ipynb @@ -99,6 +99,7 @@ "\n", "import numpy as np\n", "from itertools import combinations\n", + "import math\n", "\n", "import torch\n", "from typing import Dict, List, Optional, Union, Callable, Any\n", @@ -110,9 +111,9 @@ "\n", "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", "from chirho.interventional.handlers import do\n", - "from chirho.counterfactual.ops import preempt\n", + "from chirho.counterfactual.ops import preempt, intervene\n", "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions\n", - "from chirho.observational.handlers import condition" + "from chirho.observational.handlers import condition\n" ] }, { @@ -121,48 +122,35 @@ "metadata": {}, "outputs": [], "source": [ + "\n", "class BiasedPreemptions(Preemptions):\n", " \"\"\"\n", " Counterfactual handler that preempts the model with a biased coin flip.\n", " \"\"\"\n", - " def __init__(self, actions, weights: torch.Tensor, event_dim: int = 0) -> None:\n", + " def __init__(self, actions, weights: torch.Tensor, event_dim: int = 0, prefix: str = \"__split_\") -> None:\n", " self.weights = weights\n", " self.event_dim = event_dim\n", + " self.prefix = prefix\n", " super().__init__(actions)\n", "\n", - " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", - " try:\n", - " name = msg[\"name\"]\n", - " action = self.actions[name]\n", - " except KeyError:\n", - " return\n", - " value = msg[\"value\"]\n", - " factual_value = gather(value, IndexSet(**{name: {0}}),\n", - " event_dim=self.event_dim),\n", - " counterfactual_value = gather(value, IndexSet(**{name: {1}}),\n", - " event_dim=self.event_dim),\n", - " factual_value = preempt(\n", - " factual_value,\n", - " (action,),\n", - " None,\n", - " event_dim=len(msg[\"fn\"].event_shape),\n", - " name=msg[\"name\"],\n", - " )\n", - "\n", - " msg[\"value\"] = scatter({\n", - " IndexSet(**{name: {0}}): factual_value,\n", - " IndexSet(**{name: {1}}): counterfactual_value,\n", - " }, event_dim=self.event_dim)\n", "\n", " def _pyro_preempt(self,msg: Dict[str, Any]) -> None:\n", " if msg[\"name\"] not in self.actions:\n", - " return\n", + " return \n", + "\n", " obs, acts, case = msg[\"args\"]\n", - " msg[\"kwargs\"][\"name\"] = f\"__split_{msg['name']}\"\n", + " msg[\"kwargs\"][\"name\"] = f\"{self.prefix}{msg['name']}\"\n", " case_dist = pyro.distributions.Categorical(self.weights)\n", + " #print(msg[\"kwargs\"][\"name\"] , self.prefix, msg['name'], self.weights)\n", " case = pyro.sample(msg[\"kwargs\"][\"name\"], case_dist, obs=case)\n", " msg[\"args\"] = (obs, acts, case)\n", - " msg[\"stop\"] = True" + " msg[\"stop\"] = True\n", + "\n", + " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", + " with pyro.poutine.messenger.block_messengers(\n", + " lambda m : (isinstance(m, Preemptions) and (m is not self))\n", + " ):\n", + " super()._pyro_post_sample(msg) " ] }, { @@ -171,39 +159,71 @@ "metadata": {}, "outputs": [], "source": [ - "# slight modification wrt. to the orginal class:\n", - "# we use BiasedPreemption for witnesses\n", - "# with the same intervention bias as for antecedents\n", - "# to minimize for the number of active (antecedents + witnesses)\n", - "\n", - "class HalpernPearlModifiedApproximate:\n", + "class HalpernPearlResponsibilityApproximate:\n", "\n", " def __init__(\n", " self, \n", " model: Callable,\n", - " counterfactual_antecedents: Dict[str, torch.Tensor],\n", - " outcome: str,\n", + " evaluated_node_counterfactual: Dict[str, torch.Tensor],\n", + " treatment_candidates: Dict[str, torch.Tensor],\n", " witness_candidates: List[str],\n", - " observations: Optional[Dict[str, torch.Tensor]] = None\n", + " outcome: str,\n", + " observations: Optional[Dict[str, torch.Tensor]] = None,\n", + " bias_t: float = .2\n", " ):\n", " \n", " if observations is None:\n", " observations = {}\n", "\n", + " if not set(witness_candidates) <= set(treatment_candidates.keys()):\n", + " raise ValueError(\"witness_candidates must be a subset of treatment_candidates.keys().\")\n", + " \n", " self.model = model\n", - " self.counterfactual_antecedents = counterfactual_antecedents\n", - " self.outcome = outcome\n", + " self.evaluated_node_counterfactual = evaluated_node_counterfactual\n", + " self.treatment_candidates = treatment_candidates\n", " self.witness_candidates = witness_candidates\n", + " self.outcome = outcome\n", " self.observations = observations\n", + " self.bias_t = bias_t\n", + " self.bias_n = self.find_max_bias_within(self.bias_t, 1)\n", + " self.bias_w = self.find_max_bias_within(self.bias_n, len(self.witness_candidates))\n", + "\n", + " self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,\n", + " antecedents = [node]) for\n", + " node in self.evaluated_node_counterfactual.keys()}\n", "\n", - " self.antecedent_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", + " self.treatment_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", " antecedents = [antecedent]) for\n", - " antecedent in self.counterfactual_antecedents.keys()}\n", + " antecedent in self.treatment_candidates.keys()}\n", " \n", " self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,\n", - " antecedents = self.counterfactual_antecedents) for \n", + " antecedents = self.treatment_candidates) for \n", " candidate in self.witness_candidates}\n", " \n", + " @staticmethod\n", + " def find_max_bias_within(e: float, n: int,\n", + " max_iterations: int = 1000, learning_rate: float = 0.002):\n", + " \n", + " ediff = math.log(0.5 + e) - math.log(0.5 - e)\n", + " #print(\"up\", math.log(0.5 + e), \"down\", math.log(0.5 - e), \"ediff\", ediff)\n", + "\n", + " w = e\n", + " wdiff = math.log(0.5 + w) - math.log(0.5 - w)\n", + "\n", + " iteration = 0 \n", + " while iteration < max_iterations and ediff <= n * wdiff:\n", + " \n", + " distance = n * wdiff / ediff\n", + " assert w - learning_rate * distance >0 , \"The learning rate is too high.\"\n", + " w -= learning_rate * distance\n", + " \n", + " wdiff = math.log(0.5 + w) - math.log(0.5 - w)\n", + " #print(\"up\", math.log(0.5 + w), \"down\", math.log(0.5 - w), \"wdiff\", wdiff, \"nwdiff\", n * wdiff)\n", + "\n", + " iteration += 1\n", + " \n", + " return w\n", + "\n", " @staticmethod \n", " def preempt_with_factual(value: torch.Tensor, *,\n", " antecedents: List[str] = None, event_dim: int = 0):\n", @@ -223,20 +243,27 @@ " \n", " \n", " def __call__(self, *args, **kwargs):\n", + " print(\"Preemption biases used (upper) - t:\",.5+ self.bias_t, \", n:\", .5 + self.bias_n, \", w:\", .5 + self.bias_w, \".\")\n", " with MultiWorldCounterfactual():\n", - " with do(actions=self.counterfactual_antecedents):\n", - " # the last element of the tensor is the factual case (preempted)\n", - " with BiasedPreemptions(actions = self.antecedent_preemptions, weights = torch.tensor([.4, .6])):\n", - " #the last element is the fixed at the observed value (preempted) \n", - " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.6, .4])):\n", - " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", - " with pyro.poutine.trace() as self.trace:\n", - " self.consequent = self.model(*args, **kwargs)[self.outcome]\n", - " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_antecedents}))\n", - " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_antecedents}))\n", - " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", - " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", - " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))\n" + " # the last element of the tensor is the factual case (preempted)\n", + " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", + " prefix = \"__witness_split_\"):\n", + " with do(actions=self.evaluated_node_counterfactual):\n", + " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", + " prefix = \"__evaluated_split_\"):\n", + " with do(actions=self.treatment_candidates):\n", + " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", + " prefix = \"__treatment_split_\"):\n", + " # the last element is the fixed at the observed value (preempted) \n", + " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", + " with pyro.poutine.trace() as self.trace:\n", + " self.consequent = self.model(*args, **kwargs)[self.outcome]\n", + " self.counterfactual_interventions = list(self.evaluated_node_counterfactual.keys()) + list(self.treatment_candidates.keys())\n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_interventions}))\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_interventions}))\n", + " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", + " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", + " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))" ] }, { @@ -244,257 +271,157 @@ "execution_count": 4, "metadata": {}, "outputs": [], + "source": [ + "# only needed for ease of exposition,\n", + "# not for the inference itself\n", + "\n", + "def remove_redundant_rows(tab):\n", + " existing_pairs = []\n", + "\n", + " for col in tab.columns:\n", + " if col[0:4] == \"apr_\":\n", + " ending = col.split(\"apr_\")[1]\n", + " wpr_col = f\"wpr_{ending}\"\n", + " if wpr_col in tab.columns:\n", + " existing_pairs.append((col,wpr_col))\n", + "\n", + " keep = []\n", + " for index, row in tab.iterrows():\n", + " \n", + " flag = True\n", + " for pair in existing_pairs:\n", + " apr_col = pair[0]\n", + " wpr_col = pair[1]\n", + " apr_value = row[apr_col]\n", + " wpr_value = row[wpr_col]\n", + " \n", + " if apr_value == 0 and wpr_value == 1:\n", + " flag = False\n", + " break\n", + " keep.append(flag)\n", + " \n", + " return(tab[keep])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], "source": [ "# this will explore the trace once we run inference on the model\n", "\n", - "def get_table(nodes, antecedents, witness_candidates):\n", + "def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", " \n", " values_table = {}\n", "\n", + " values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", " for antecedent in antecedents:\n", - " values_table[antecedent] = nodes[antecedent][\"value\"].squeeze().tolist()\n", - " values_table['preempted_' + antecedent] = nodes['__split_' + antecedent][\"value\"].squeeze().tolist()\n", - " values_table['preempted_' + antecedent + '_log_prob'] = nodes['__split_' + antecedent][\"fn\"].log_prob(nodes['__split_' + antecedent][\"value\"]).squeeze().tolist()\n", - "\n", - "\n", - " for candidate in witness_candidates:\n", - " _values = nodes[candidate][\"value\"].squeeze().tolist()\n", - " # TODO: uncomment in the final version (?) \n", - " #values_table[candidate + '0'] = _values[0]\n", - " #values_table[candidate + '1'] = _values[1]\n", - " values_table['fixed_factual_' + candidate] = nodes['__split_' + candidate][\"value\"].squeeze().tolist()\n", - " \n", - " # TODO uncomment in the final version (?)\n", - " #values_table[consequent + '0'] = nodes[consequent][\"value\"].squeeze().tolist()[0]\n", - " #values_table[consequent + '1'] = nodes[consequent][\"value\"].squeeze().tolist()[1]\n", - " values_table['consequent_differs_binary'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", - " values_table['consequent_log_prob'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", - "\n", - " if isinstance(values_table['consequent_log_prob'], float):\n", + " values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + "\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " \n", + " values_table['cdif'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['clp'], float):\n", " values_df = pd.DataFrame([values_table])\n", " else:\n", " values_df = pd.DataFrame(values_table)\n", " \n", + " values_df = pd.DataFrame(values_table)\n", "\n", - " summands = ['preempted_' + antecedent + '_log_prob' for antecedent in antecedents]\n", - " summands.append('consequent_log_prob')\n", - " values_df[\"sum_log_prob\"] = values_df[summands].sum(axis = 1) \n", + " summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + " \n", + " \n", + " values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + "\n", + " values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", " values_df.drop_duplicates(inplace = True)\n", - " values_df.sort_values(by = \"sum_log_prob\", inplace = True, ascending = False)\n", + " values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", "\n", - " return values_df.reset_index(drop = True)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "def voting_model():\n", - " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", - " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", - " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", - " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", - " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", - " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", - " u_vote6 = pyro.sample(\"u_vote6\", dist.Bernoulli(0.6))\n", - " u_vote7 = pyro.sample(\"u_vote7\", dist.Bernoulli(0.6))\n", + " tab = values_df.reset_index(drop = True)\n", "\n", - " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", - " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", - " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", - " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", - " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", - " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", - " vote6 = pyro.deterministic(\"vote6\", u_vote6, event_dim=0)\n", - " vote7 = pyro.deterministic(\"vote7\", u_vote7, event_dim=0)\n", + " tab = remove_redundant_rows(tab)\n", "\n", + " #tab = values_table\n", "\n", - " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + \n", - " vote4 + vote5 + vote6 + vote7 > 4)\n", - " return {\"outcome\": outcome.float()}" + " if round:\n", + " tab = tab.round(3)\n", + "\n", + " return tab\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "# if you're one of five voters who voted for, you are an actual cause\n", + "def responsibility_check(hpr):\n", "\n", - "# and your responsibility is 1 \n", + " evaluated_node = list(hpr.evaluated_node_counterfactual.keys())[0]\n", + " tab = get_table(hpr.trace.trace.nodes,\n", + " evaluated_node ,\n", + " list(hpr.treatment_candidates.keys()), \n", + " hpr.witness_candidates)\n", + " \n", + " max_sum_lp = tab['sum_lp'].max()\n", + " max_sum_lp_rows = tab[tab['sum_lp'] == max_sum_lp]\n", "\n", - "observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=0.,\n", - " u_vote6=0., u_vote7=0.)\n", + " map_estimate = 1/ tab['changes'][0]\n", "\n", + " print (f\"MAP estimate: {map_estimate}\")\n", "\n", + " # sanity check; consider removing later\n", + " min_changes = max_sum_lp_rows['changes'].min()\n", + " min_changes_row = max_sum_lp_rows[max_sum_lp_rows['changes'] == min_changes]\n", "\n", - "counterfactual_antecedents = {key[2:]: 1-v for key, v in observations.items()}\n", + " print(\"Minimal scenarios:\")\n", + " print(min_changes_row)\n", "\n", + " if not (min_changes_row[f'int_{evaluated_node}'] == 0).any():\n", + " print (f\"No MAP estimate includes intervention on int_{evaluated_node} == 0\")\n", + " return 0\n", + " \n", + " min_changes_row = min_changes_row[min_changes_row[f'int_{evaluated_node}'] == 0]\n", "\n", - "voting5HPM = HalpernPearlModifiedApproximate(\n", - " model = voting_model,\n", - " counterfactual_antecedents = counterfactual_antecedents,\n", - " outcome = \"outcome\",\n", - " witness_candidates = [f\"vote{i}\" for i in range(1,8)],\n", - " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=0.,\n", - " u_vote6=0., u_vote7=0.))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "NotImplementedError", - "evalue": "intervene not implemented for type ", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mwith\u001b[39;00m pyro\u001b[39m.\u001b[39mplate(\u001b[39m\"\u001b[39m\u001b[39mruns\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m10\u001b[39m):\n\u001b[0;32m----> 2\u001b[0m voting5HPM()\n", - "Cell \u001b[0;32mIn[3], line 61\u001b[0m, in \u001b[0;36mHalpernPearlModifiedApproximate.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[39mwith\u001b[39;00m condition(data\u001b[39m=\u001b[39m{k: torch\u001b[39m.\u001b[39mas_tensor(v) \u001b[39mfor\u001b[39;00m k, v \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobservations\u001b[39m.\u001b[39mitems()}):\n\u001b[1;32m 60\u001b[0m \u001b[39mwith\u001b[39;00m pyro\u001b[39m.\u001b[39mpoutine\u001b[39m.\u001b[39mtrace() \u001b[39mas\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtrace:\n\u001b[0;32m---> 61\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)[\u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutcome]\n\u001b[1;32m 62\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mintervened_consequent \u001b[39m=\u001b[39m gather(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{ant: {\u001b[39m1\u001b[39m} \u001b[39mfor\u001b[39;00m ant \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcounterfactual_antecedents}))\n\u001b[1;32m 63\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mobserved_consequent \u001b[39m=\u001b[39m gather(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconsequent, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{ant: {\u001b[39m0\u001b[39m} \u001b[39mfor\u001b[39;00m ant \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcounterfactual_antecedents}))\n", - "Cell \u001b[0;32mIn[5], line 11\u001b[0m, in \u001b[0;36mvoting_model\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m u_vote6 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39msample(\u001b[39m\"\u001b[39m\u001b[39mu_vote6\u001b[39m\u001b[39m\"\u001b[39m, dist\u001b[39m.\u001b[39mBernoulli(\u001b[39m0.6\u001b[39m))\n\u001b[1;32m 9\u001b[0m u_vote7 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39msample(\u001b[39m\"\u001b[39m\u001b[39mu_vote7\u001b[39m\u001b[39m\"\u001b[39m, dist\u001b[39m.\u001b[39mBernoulli(\u001b[39m0.6\u001b[39m))\n\u001b[0;32m---> 11\u001b[0m vote0 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39;49mdeterministic(\u001b[39m\"\u001b[39;49m\u001b[39mvote0\u001b[39;49m\u001b[39m\"\u001b[39;49m, u_vote0, event_dim\u001b[39m=\u001b[39;49m\u001b[39m0\u001b[39;49m)\n\u001b[1;32m 12\u001b[0m vote1 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39mdeterministic(\u001b[39m\"\u001b[39m\u001b[39mvote1\u001b[39m\u001b[39m\"\u001b[39m, u_vote1, event_dim\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 13\u001b[0m vote2 \u001b[39m=\u001b[39m pyro\u001b[39m.\u001b[39mdeterministic(\u001b[39m\"\u001b[39m\u001b[39mvote2\u001b[39m\u001b[39m\"\u001b[39m, u_vote2, event_dim\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/primitives.py:209\u001b[0m, in \u001b[0;36mdeterministic\u001b[0;34m(name, value, event_dim)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[39mDeterministic statement to add a :class:`~pyro.distributions.Delta` site\u001b[39;00m\n\u001b[1;32m 194\u001b[0m \u001b[39mwith name `name` and value `value` to the trace. This is useful when we\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[39m:param int event_dim: Optional event dimension, defaults to `value.ndim`.\u001b[39;00m\n\u001b[1;32m 207\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 208\u001b[0m event_dim \u001b[39m=\u001b[39m value\u001b[39m.\u001b[39mndim \u001b[39mif\u001b[39;00m event_dim \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m event_dim\n\u001b[0;32m--> 209\u001b[0m \u001b[39mreturn\u001b[39;00m sample(\n\u001b[1;32m 210\u001b[0m name,\n\u001b[1;32m 211\u001b[0m dist\u001b[39m.\u001b[39;49mDelta(value, event_dim\u001b[39m=\u001b[39;49mevent_dim)\u001b[39m.\u001b[39;49mmask(\u001b[39mFalse\u001b[39;49;00m),\n\u001b[1;32m 212\u001b[0m obs\u001b[39m=\u001b[39;49mvalue,\n\u001b[1;32m 213\u001b[0m infer\u001b[39m=\u001b[39;49m{\u001b[39m\"\u001b[39;49m\u001b[39m_deterministic\u001b[39;49m\u001b[39m\"\u001b[39;49m: \u001b[39mTrue\u001b[39;49;00m},\n\u001b[1;32m 214\u001b[0m )\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/primitives.py:163\u001b[0m, in \u001b[0;36msample\u001b[0;34m(name, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 146\u001b[0m msg \u001b[39m=\u001b[39m {\n\u001b[1;32m 147\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39msample\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 148\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcontinuation\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 161\u001b[0m }\n\u001b[1;32m 162\u001b[0m \u001b[39m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m apply_stack(msg)\n\u001b[1;32m 164\u001b[0m \u001b[39mreturn\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:220\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 217\u001b[0m default_process_message(msg)\n\u001b[1;32m 219\u001b[0m \u001b[39mfor\u001b[39;00m frame \u001b[39min\u001b[39;00m stack[\u001b[39m-\u001b[39mpointer:]:\n\u001b[0;32m--> 220\u001b[0m frame\u001b[39m.\u001b[39;49m_postprocess_message(msg)\n\u001b[1;32m 222\u001b[0m cont \u001b[39m=\u001b[39m msg[\u001b[39m\"\u001b[39m\u001b[39mcontinuation\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 223\u001b[0m \u001b[39mif\u001b[39;00m cont \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:168\u001b[0m, in \u001b[0;36mMessenger._postprocess_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 166\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mgetattr\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39m_pyro_post_\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(msg[\u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m]), \u001b[39mNone\u001b[39;00m)\n\u001b[1;32m 167\u001b[0m \u001b[39mif\u001b[39;00m method \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 168\u001b[0m \u001b[39mreturn\u001b[39;00m method(msg)\n\u001b[1;32m 169\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mNone\u001b[39;00m\n", - "Cell \u001b[0;32mIn[2], line 21\u001b[0m, in \u001b[0;36mBiasedPreemptions._pyro_post_sample\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 17\u001b[0m factual_value \u001b[39m=\u001b[39m gather(value, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}),\n\u001b[1;32m 18\u001b[0m event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim),\n\u001b[1;32m 19\u001b[0m counterfactual_value \u001b[39m=\u001b[39m gather(value, IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m1\u001b[39m}}),\n\u001b[1;32m 20\u001b[0m event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim),\n\u001b[0;32m---> 21\u001b[0m factual_value \u001b[39m=\u001b[39m preempt(\n\u001b[1;32m 22\u001b[0m factual_value,\n\u001b[1;32m 23\u001b[0m (action,),\n\u001b[1;32m 24\u001b[0m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 25\u001b[0m event_dim\u001b[39m=\u001b[39;49m\u001b[39mlen\u001b[39;49m(msg[\u001b[39m\"\u001b[39;49m\u001b[39mfn\u001b[39;49m\u001b[39m\"\u001b[39;49m]\u001b[39m.\u001b[39;49mevent_shape),\n\u001b[1;32m 26\u001b[0m name\u001b[39m=\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39mname\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 27\u001b[0m )\n\u001b[1;32m 29\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m scatter({\n\u001b[1;32m 30\u001b[0m IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}): factual_value,\n\u001b[1;32m 31\u001b[0m IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m1\u001b[39m}}): counterfactual_value,\n\u001b[1;32m 32\u001b[0m }, event_dim\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mevent_dim)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:281\u001b[0m, in \u001b[0;36meffectful.._fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 264\u001b[0m msg \u001b[39m=\u001b[39m {\n\u001b[1;32m 265\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mtype\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mtype\u001b[39m,\n\u001b[1;32m 266\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mname\u001b[39m\u001b[39m\"\u001b[39m: name,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[39m\"\u001b[39m\u001b[39minfer\u001b[39m\u001b[39m\"\u001b[39m: infer,\n\u001b[1;32m 279\u001b[0m }\n\u001b[1;32m 280\u001b[0m \u001b[39m# apply the stack and return its return value\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m apply_stack(msg)\n\u001b[1;32m 282\u001b[0m \u001b[39mreturn\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:217\u001b[0m, in \u001b[0;36mapply_stack\u001b[0;34m(initial_msg)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[39mif\u001b[39;00m msg[\u001b[39m\"\u001b[39m\u001b[39mstop\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[1;32m 215\u001b[0m \u001b[39mbreak\u001b[39;00m\n\u001b[0;32m--> 217\u001b[0m default_process_message(msg)\n\u001b[1;32m 219\u001b[0m \u001b[39mfor\u001b[39;00m frame \u001b[39min\u001b[39;00m stack[\u001b[39m-\u001b[39mpointer:]:\n\u001b[1;32m 220\u001b[0m frame\u001b[39m.\u001b[39m_postprocess_message(msg)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/runtime.py:179\u001b[0m, in \u001b[0;36mdefault_process_message\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m 176\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mdone\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n\u001b[1;32m 177\u001b[0m \u001b[39mreturn\u001b[39;00m msg\n\u001b[0;32m--> 179\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m msg[\u001b[39m\"\u001b[39;49m\u001b[39mfn\u001b[39;49m\u001b[39m\"\u001b[39;49m](\u001b[39m*\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39margs\u001b[39;49m\u001b[39m\"\u001b[39;49m], \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mmsg[\u001b[39m\"\u001b[39;49m\u001b[39mkwargs\u001b[39;49m\u001b[39m\"\u001b[39;49m])\n\u001b[1;32m 181\u001b[0m \u001b[39m# after fn has been called, update msg to prevent it from being called again.\u001b[39;00m\n\u001b[1;32m 182\u001b[0m msg[\u001b[39m\"\u001b[39m\u001b[39mdone\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pyro/poutine/messenger.py:12\u001b[0m, in \u001b[0;36m_context_wrap\u001b[0;34m(context, fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_context_wrap\u001b[39m(context, fn, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[39mwith\u001b[39;00m context:\n\u001b[0;32m---> 12\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", - "File \u001b[0;32m~/L2projects/chirho/chirho/counterfactual/ops.py:53\u001b[0m, in \u001b[0;36mpreempt\u001b[0;34m(obs, acts, case, **kwargs)\u001b[0m\n\u001b[1;32m 51\u001b[0m act_values \u001b[39m=\u001b[39m {IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {\u001b[39m0\u001b[39m}}): obs}\n\u001b[1;32m 52\u001b[0m \u001b[39mfor\u001b[39;00m i, act \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(acts):\n\u001b[0;32m---> 53\u001b[0m act_values[IndexSet(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39m{name: {i \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m}})] \u001b[39m=\u001b[39m intervene(obs, act, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 55\u001b[0m \u001b[39mreturn\u001b[39;00m cond(act_values, case, event_dim\u001b[39m=\u001b[39mkwargs\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39mevent_dim\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m0\u001b[39m))\n", - "File \u001b[0;32m~/anaconda3/envs/causal_pyro/lib/python3.10/functools.py:889\u001b[0m, in \u001b[0;36msingledispatch..wrapper\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m 885\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m args:\n\u001b[1;32m 886\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mfuncname\u001b[39m}\u001b[39;00m\u001b[39m requires at least \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 887\u001b[0m \u001b[39m'\u001b[39m\u001b[39m1 positional argument\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m--> 889\u001b[0m \u001b[39mreturn\u001b[39;00m dispatch(args[\u001b[39m0\u001b[39;49m]\u001b[39m.\u001b[39;49m\u001b[39m__class__\u001b[39;49m)(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkw)\n", - "File \u001b[0;32m~/L2projects/chirho/chirho/interventional/ops.py:25\u001b[0m, in \u001b[0;36mintervene\u001b[0;34m(obs, act, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39msingledispatch\n\u001b[1;32m 12\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mintervene\u001b[39m(obs, act: Optional[Intervention[T]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 13\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[39m Intervene on a value in a probabilistic program.\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[39m :param act: an optional intervention.\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 25\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mintervene not implemented for type \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(obs)\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", - "\u001b[0;31mNotImplementedError\u001b[0m: intervene not implemented for type " - ] - } - ], - "source": [ - "with pyro.plate(\"runs\", 10):\n", - " voting5HPM()\n" + " secondary_check = 1/min_changes_row['changes'].min()\n", + "\n", + " print (f\"Secondary check: {secondary_check}\")\n", + "\n", + " assert map_estimate == secondary_check, \"MAP estimate does not match secondary check.\" \n", + "\n", + " return map_estimate" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ + "#TODO THIS NEEDS TO BE UPDATED \n", "This implementation is now used within another class definition, where, again, the main moves are in `def __call__`. We sample antecedent sets, leave other nodes (aside from the outcome) as witness candidates, and pass the result to an actual causality evaluation, keeping track of minimal antecedent sets and the corresponding witness sizes. Then we find a minimum of the sum and use it in the denominator.`" ] }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "class HalpernPearlResponsibilityApproximate:\n", - "\n", - " def __init__(\n", - " self, \n", - " model: Callable,\n", - " nodes: List,\n", - " antecedent: str,\n", - " outcome: str,\n", - " observations: Dict[str, torch.Tensor], \n", - " runs_n: int \n", - " ):\n", - " self.model = model\n", - " self.nodes = nodes\n", - " self.antecedent = antecedent\n", - " self.outcome = outcome\n", - " self.observations = observations\n", - " self.runs_n = runs_n\n", - " \n", - " self.minimal_antecedents_cache = []\n", - " self.antecedent_sizes = []\n", - " self.existential_but_fors = []\n", - " self.acs = []\n", - " self.minimal_witness_sizes = []\n", - " self.responsibilities = []\n", - " self.HPMs = []\n", - "\n", - " def __call__(self):\n", - " \n", - " for step in range(1,self.runs_n):\n", - "\n", - " nodes = self.nodes\n", - " if self.outcome in nodes:\n", - " nodes.remove(self.outcome) \n", - " \n", - " companion_size = random.randint(0,len(nodes))\n", - " companion_candidates = random.sample(self.nodes, companion_size)\n", - " witness_candidates = [node for node in self.nodes if \n", - " node != self.antecedent and \n", - " node != self.outcome and \n", - " node not in companion_candidates]\n", - "\n", - " HPM = HalpernPearlModifiedApproximate(\n", - " model = self.model,\n", - " antecedents = companion_candidates,\n", - " outcome = self.outcome,\n", - " witness_candidates = witness_candidates,\n", - " observations = self.observations,\n", - " sample_size = 1000)\n", - " \n", - " HPM()\n", - " self.HPMs.append(HPM)\n", - "\n", - " if HPM.existential_but_for:\n", - " \n", - " HPM_min = ac_minimality_check(HPM)\n", - "\n", - " if HPM_min.ac:\n", - "\n", - " subset_in_cache = any([s.issubset(set(HPM.antecedents)) for s in self.minimal_antecedents_cache])\n", - " if not subset_in_cache:\n", - " for s in self.minimal_antecedents_cache:\n", - " if set(HPM.antecedents).issubset(s):\n", - " self.minimal_antecedents_cache.remove(s)\n", - " self.minimal_antecedents_cache.append(set(HPM.antecedents))\n", - "\n", - " if self.antecedent in HPM.antecedents:\n", - " self.antecedent_sizes.append(len(HPM.antecedents))\n", - " self.existential_but_fors.append(HPM.existential_but_for)\n", - " self.acs.append(HPM.ac)\n", - " self.minimal_witness_sizes.append(HPM.minimal_witness_size)\n", - " self.responsibilities.append(HPM.responsibility_internal)\n", - "\n", - "\n", - " self.denumerators = [x + y for x, y in zip(self.antecedent_sizes, self.minimal_witness_sizes)]\n", - "\n", - " self.responsibilityDF = pd.DataFrame(\n", - " {#\"existential_but_for\": [bool(value) for value in self.existential_but_fors],\n", - " \"acs\": [bool(value) for value in self.acs],\n", - " \"antecedent_size\": self.antecedent_sizes, \n", - " \"minimal_witness_size\": self.minimal_witness_sizes,\n", - " \"denumerator\": self.denumerators,\n", - " \"responsibility\": self.responsibilities\n", - " }\n", - " )\n", - " if len(self.responsibilityDF['acs']) == 0:\n", - " self.responsibility = 0\n", - " else:\n", - " min_denumerator = min(self.responsibilityDF['denumerator'])\n", - " self.responsibility = 1/min_denumerator\n", - "\n", - " " - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -543,22 +470,661 @@ "cell_type": "code", "execution_count": 10, "metadata": {}, + "outputs": [], + "source": [ + "# let's start with a minimal interesting example\n", + "# you are one of three voters\n", + "\n", + "def voting_model():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + "\n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 >1\n", + " )\n", + " return {\"outcome\": outcome.float()}\n", + "\n", + "observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.\n", + " )\n", + "\n", + "treatment_candidates = {key[2:]: 1-v for key, v in observations.items() if key != \"u_vote0\"}\n", + "evaluated_node_counterfactual = {\"vote0\": 1 - observations[\"u_vote0\"]}\n", + "\n", + "votingHPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model,\n", + " evaluated_node_counterfactual = evaluated_node_counterfactual,\n", + " treatment_candidates = treatment_candidates,\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,3)],\n", + " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.\n", + " ))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.6026881514681908 .\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
obs_vote0int_vote0epr_vote0elp_vote0obs_vote1int_vote1apr_vote1alp_vote1wpr_vote1wlp_vote1...apr_vote2alp_vote2wpr_vote2wlp_vote2cdifclpintwprchangessum_lp
01.00.00-1.1971.00.00-1.2040-0.506...1-0.3570-0.506True0.0202-3.771000e+00
11.00.00-1.1971.01.01-0.3570-0.506...0-1.2040-0.506True0.0202-3.771000e+00
21.01.01-0.3601.00.00-1.2040-0.506...0-1.2040-0.506True0.0202-3.780000e+00
31.00.00-1.1971.01.01-0.3571-0.923...0-1.2040-0.506True0.0213-4.187000e+00
51.00.00-1.1971.00.00-1.2040-0.506...1-0.3571-0.923True0.0213-4.187000e+00
121.00.00-1.1971.00.00-1.2040-0.506...0-1.2040-0.506True0.0303-4.618000e+00
161.01.01-0.3601.01.01-0.3570-0.506...1-0.3570-0.506False-100000000.0000-1.000000e+08
171.01.01-0.3601.01.01-0.3571-0.923...1-0.3570-0.506False-100000000.0011-1.000000e+08
181.01.01-0.3601.01.01-0.3570-0.506...1-0.3571-0.923False-100000000.0011-1.000000e+08
191.01.01-0.3601.01.01-0.3571-0.923...1-0.3571-0.923False-100000000.0022-1.000000e+08
201.00.00-1.1971.01.01-0.3570-0.506...1-0.3570-0.506False-100000000.0101-1.000000e+08
211.01.01-0.3601.00.00-1.2040-0.506...1-0.3570-0.506False-100000000.0101-1.000000e+08
221.01.01-0.3601.01.01-0.3570-0.506...0-1.2040-0.506False-100000000.0101-1.000000e+08
231.00.00-1.1971.01.01-0.3570-0.506...1-0.3571-0.923False-100000000.0112-1.000000e+08
241.00.00-1.1971.01.01-0.3571-0.923...1-0.3570-0.506False-100000000.0112-1.000000e+08
261.01.01-0.3601.00.00-1.2040-0.506...1-0.3571-0.923False-100000000.0112-1.000000e+08
281.01.01-0.3601.01.01-0.3571-0.923...0-1.2040-0.506False-100000000.0112-1.000000e+08
291.00.00-1.1971.01.01-0.3571-0.923...1-0.3571-0.923False-100000000.0123-1.000000e+08
\n", + "

18 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -1.197 1.0 0.0 \n", + "1 1.0 0.0 0 -1.197 1.0 1.0 \n", + "2 1.0 1.0 1 -0.360 1.0 0.0 \n", + "3 1.0 0.0 0 -1.197 1.0 1.0 \n", + "5 1.0 0.0 0 -1.197 1.0 0.0 \n", + "12 1.0 0.0 0 -1.197 1.0 0.0 \n", + "16 1.0 1.0 1 -0.360 1.0 1.0 \n", + "17 1.0 1.0 1 -0.360 1.0 1.0 \n", + "18 1.0 1.0 1 -0.360 1.0 1.0 \n", + "19 1.0 1.0 1 -0.360 1.0 1.0 \n", + "20 1.0 0.0 0 -1.197 1.0 1.0 \n", + "21 1.0 1.0 1 -0.360 1.0 0.0 \n", + "22 1.0 1.0 1 -0.360 1.0 1.0 \n", + "23 1.0 0.0 0 -1.197 1.0 1.0 \n", + "24 1.0 0.0 0 -1.197 1.0 1.0 \n", + "26 1.0 1.0 1 -0.360 1.0 0.0 \n", + "28 1.0 1.0 1 -0.360 1.0 1.0 \n", + "29 1.0 0.0 0 -1.197 1.0 1.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", + "0 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "2 0 -1.204 0 -0.506 ... 0 -1.204 \n", + "3 1 -0.357 1 -0.923 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "12 0 -1.204 0 -0.506 ... 0 -1.204 \n", + "16 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "17 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "18 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "19 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "20 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "21 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "22 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "23 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "24 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "26 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "28 1 -0.357 1 -0.923 ... 0 -1.204 \n", + "29 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "\n", + " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", + "0 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", + "1 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", + "2 0 -0.506 True 0.0 2 0 2 -3.780000e+00 \n", + "3 0 -0.506 True 0.0 2 1 3 -4.187000e+00 \n", + "5 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", + "12 0 -0.506 True 0.0 3 0 3 -4.618000e+00 \n", + "16 0 -0.506 False -100000000.0 0 0 0 -1.000000e+08 \n", + "17 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", + "18 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", + "19 1 -0.923 False -100000000.0 0 2 2 -1.000000e+08 \n", + "20 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", + "21 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", + "22 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", + "23 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", + "24 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", + "26 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", + "28 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", + "29 1 -0.923 False -100000000.0 1 2 3 -1.000000e+08 \n", + "\n", + "[18 rows x 22 columns]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now inference, and let's inspect the table while it's small\n", + "\n", + "with pyro.plate(\"runs\", 1000):\n", + " votingHPR()\n", + "\n", + "vtr = votingHPR.trace.trace.nodes\n", + "\n", + "get_table(vtr, \"vote0\", treatment_candidates, [f\"vote{i}\" for i in range(1,3)])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TODO need a brief explanation of what's going on here" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, "outputs": [ { - "ename": "KeyError", - "evalue": "'__split_vote0'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m get_table(voting5HPM\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mnodes, antecedents \u001b[39m=\u001b[39;49m counterfactual_antecedents, witness_candidates \u001b[39m=\u001b[39;49m voting5HPM\u001b[39m.\u001b[39;49mwitness_candidates)\n", - "Cell \u001b[0;32mIn[4], line 9\u001b[0m, in \u001b[0;36mget_table\u001b[0;34m(nodes, antecedents, witness_candidates)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mfor\u001b[39;00m antecedent \u001b[39min\u001b[39;00m antecedents:\n\u001b[1;32m 8\u001b[0m values_table[antecedent] \u001b[39m=\u001b[39m nodes[antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[0;32m----> 9\u001b[0m values_table[\u001b[39m'\u001b[39m\u001b[39mpreempted_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent] \u001b[39m=\u001b[39m nodes[\u001b[39m'\u001b[39;49m\u001b[39m__split_\u001b[39;49m\u001b[39m'\u001b[39;49m \u001b[39m+\u001b[39;49m antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[1;32m 10\u001b[0m values_table[\u001b[39m'\u001b[39m\u001b[39mpreempted_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m_log_prob\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m nodes[\u001b[39m'\u001b[39m\u001b[39m__split_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent][\u001b[39m\"\u001b[39m\u001b[39mfn\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mlog_prob(nodes[\u001b[39m'\u001b[39m\u001b[39m__split_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent][\u001b[39m\"\u001b[39m\u001b[39mvalue\u001b[39m\u001b[39m\"\u001b[39m])\u001b[39m.\u001b[39msqueeze()\u001b[39m.\u001b[39mtolist()\n\u001b[1;32m 13\u001b[0m \u001b[39mfor\u001b[39;00m candidate \u001b[39min\u001b[39;00m witness_candidates:\n", - "\u001b[0;31mKeyError\u001b[0m: '__split_vote0'" + "name": "stdout", + "output_type": "stream", + "text": [ + "MAP estimate: 0.5\n", + "Minimal scenarios:\n", + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -1.197 1.0 0.0 \n", + "1 1.0 0.0 0 -1.197 1.0 1.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", + "0 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "\n", + " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", + "0 0 -0.506 True 0.0 2 0 2 -3.771 \n", + "1 0 -0.506 True 0.0 2 0 2 -3.771 \n", + "\n", + "[2 rows x 22 columns]\n", + "Secondary check: 0.5\n" ] + }, + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "get_table(voting5HPM.trace.trace.nodes, antecedents = counterfactual_antecedents, witness_candidates = voting5HPM.witness_candidates)" + "responsibility_check(votingHPR)" ] }, { @@ -567,19 +1133,32 @@ "metadata": {}, "outputs": [], "source": [ - "# if everyone voted for, you are not an actual cause\n", + "# now consider a more complex example,\n", + "# with 8 voters, where you are not an actual cause\n", "\n", - "everyone_voted_HPR = HalpernPearlResponsibilityApproximate(\n", - " model = voting_model,\n", - " nodes = [f\"vote{i}\" for i in range(0,8,)],\n", - " antecedent = \"vote0\", outcome = \"outcome\",\n", - " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 1.), \n", - " runs_n=500\n", - " )\n", + "def voting_model8():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", + " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", + " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", + " u_vote6 = pyro.sample(\"u_vote6\", dist.Bernoulli(0.6))\n", + " u_vote7 = pyro.sample(\"u_vote7\", dist.Bernoulli(0.6))\n", "\n", - "pyro.set_rng_seed(42)\n", - "everyone_voted_HPR()\n" + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", + " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", + " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", + " vote6 = pyro.deterministic(\"vote6\", u_vote6, event_dim=0)\n", + " vote7 = pyro.deterministic(\"vote7\", u_vote7, event_dim=0)\n", + "\n", + "\n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + \n", + " vote4 + vote5 + vote6 + vote7 > 4)\n", + " return {\"outcome\": outcome.float()}\n" ] }, { @@ -828,10 +1407,27 @@ } ], "source": [ - "# but the size-minimal actual causes are all of size 4\n", + "# everyone voter for \n", + "# the minimal number of interventions to\n", + "# change the outcome is 4\n", "# so your responsibility is 1/4\n", "\n", - "everyone_voted_HPR.responsibilityDF\n" + "observations8 = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1., u_vote5=0.,\n", + " u_vote6=1., u_vote7=1.)\n", + "\n", + "\n", + "treatment_candidates8 = {key[2:]: 1-v for key, v in observations.items() if key != \"u_vote0\"}\n", + "evaluated_node_counterfactual8 = {\"vote0\": 1 - observations[\"u_vote0\"]}\n", + "\n", + "voting8HPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model8,\n", + " evaluated_node_counterfactual = evaluated_node_counterfactual8,\n", + " treatment_candidates = treatment_candidates8,\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,9)],\n", + " observations = observations8)\n", + "\n" ] }, { @@ -1297,7 +1893,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.6" }, "orig_nbformat": 4 }, From dad1096b8c7e80c3fad631b5d48a4c751bb382c1 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Tue, 22 Aug 2023 11:27:39 +0200 Subject: [PATCH 06/13] voting in responsibility done --- docs/source/responsibility.ipynb | 570 +++++++++---------------------- 1 file changed, 166 insertions(+), 404 deletions(-) diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb index 7fac683c..55d79db1 100644 --- a/docs/source/responsibility.ipynb +++ b/docs/source/responsibility.ipynb @@ -305,7 +305,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -372,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -408,7 +408,7 @@ "\n", " print (f\"Secondary check: {secondary_check}\")\n", "\n", - " assert map_estimate == secondary_check, \"MAP estimate does not match secondary check.\" \n", + " assert map_estimate == secondary_check, \"MAP estimate does not match secondary check, increase sample size.\" \n", "\n", " return map_estimate" ] @@ -468,7 +468,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -500,13 +500,12 @@ " treatment_candidates = treatment_candidates,\n", " outcome = \"outcome\",\n", " witness_candidates = [f\"vote{i}\" for i in range(1,3)],\n", - " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.\n", - " ))\n" + " observations = observations)\n" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -568,14 +567,14 @@ " 0\n", " -1.197\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", + " 1.0\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " ...\n", - " 1\n", - " -0.357\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " True\n", @@ -592,14 +591,14 @@ " 0\n", " -1.197\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", + " 0.0\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " ...\n", - " 0\n", - " -1.204\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " True\n", @@ -640,16 +639,16 @@ " 0\n", " -1.197\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.923\n", - " ...\n", + " 0.0\n", " 0\n", " -1.204\n", " 0\n", " -0.506\n", + " ...\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.923\n", " True\n", " 0.0\n", " 2\n", @@ -658,22 +657,22 @@ " -4.187000e+00\n", " \n", " \n", - " 5\n", + " 6\n", " 1.0\n", " 0.0\n", " 0\n", " -1.197\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", - " ...\n", + " 1.0\n", " 1\n", " -0.357\n", " 1\n", " -0.923\n", + " ...\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.506\n", " True\n", " 0.0\n", " 2\n", @@ -739,13 +738,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " ...\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " False\n", " -100000000.0\n", " 0\n", @@ -763,13 +762,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " ...\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " False\n", " -100000000.0\n", " 0\n", @@ -832,14 +831,14 @@ " 1\n", " -0.360\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", + " 1.0\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " ...\n", - " 1\n", - " -0.357\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " False\n", @@ -856,14 +855,14 @@ " 1\n", " -0.360\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", + " 0.0\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " ...\n", - " 0\n", - " -1.204\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " False\n", @@ -922,22 +921,22 @@ " -1.000000e+08\n", " \n", " \n", - " 26\n", + " 25\n", " 1.0\n", " 1.0\n", " 1\n", " -0.360\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", - " ...\n", + " 1.0\n", " 1\n", " -0.357\n", " 1\n", " -0.923\n", + " ...\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.506\n", " False\n", " -100000000.0\n", " 1\n", @@ -952,16 +951,16 @@ " 1\n", " -0.360\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.923\n", - " ...\n", + " 0.0\n", " 0\n", " -1.204\n", " 0\n", " -0.506\n", + " ...\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.923\n", " False\n", " -100000000.0\n", " 1\n", @@ -1000,69 +999,69 @@ ], "text/plain": [ " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -1.197 1.0 0.0 \n", - "1 1.0 0.0 0 -1.197 1.0 1.0 \n", + "0 1.0 0.0 0 -1.197 1.0 1.0 \n", + "1 1.0 0.0 0 -1.197 1.0 0.0 \n", "2 1.0 1.0 1 -0.360 1.0 0.0 \n", - "3 1.0 0.0 0 -1.197 1.0 1.0 \n", - "5 1.0 0.0 0 -1.197 1.0 0.0 \n", + "3 1.0 0.0 0 -1.197 1.0 0.0 \n", + "6 1.0 0.0 0 -1.197 1.0 1.0 \n", "12 1.0 0.0 0 -1.197 1.0 0.0 \n", "16 1.0 1.0 1 -0.360 1.0 1.0 \n", "17 1.0 1.0 1 -0.360 1.0 1.0 \n", "18 1.0 1.0 1 -0.360 1.0 1.0 \n", "19 1.0 1.0 1 -0.360 1.0 1.0 \n", "20 1.0 0.0 0 -1.197 1.0 1.0 \n", - "21 1.0 1.0 1 -0.360 1.0 0.0 \n", - "22 1.0 1.0 1 -0.360 1.0 1.0 \n", + "21 1.0 1.0 1 -0.360 1.0 1.0 \n", + "22 1.0 1.0 1 -0.360 1.0 0.0 \n", "23 1.0 0.0 0 -1.197 1.0 1.0 \n", "24 1.0 0.0 0 -1.197 1.0 1.0 \n", - "26 1.0 1.0 1 -0.360 1.0 0.0 \n", - "28 1.0 1.0 1 -0.360 1.0 1.0 \n", + "25 1.0 1.0 1 -0.360 1.0 1.0 \n", + "28 1.0 1.0 1 -0.360 1.0 0.0 \n", "29 1.0 0.0 0 -1.197 1.0 1.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "1 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "0 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.506 ... 1 -0.357 \n", "2 0 -1.204 0 -0.506 ... 0 -1.204 \n", - "3 1 -0.357 1 -0.923 ... 0 -1.204 \n", - "5 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "3 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "6 1 -0.357 1 -0.923 ... 0 -1.204 \n", "12 0 -1.204 0 -0.506 ... 0 -1.204 \n", "16 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "17 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "18 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "17 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "18 1 -0.357 1 -0.923 ... 1 -0.357 \n", "19 1 -0.357 1 -0.923 ... 1 -0.357 \n", "20 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "21 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "22 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "21 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "22 0 -1.204 0 -0.506 ... 1 -0.357 \n", "23 1 -0.357 0 -0.506 ... 1 -0.357 \n", "24 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "26 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "28 1 -0.357 1 -0.923 ... 0 -1.204 \n", + "25 1 -0.357 1 -0.923 ... 0 -1.204 \n", + "28 0 -1.204 0 -0.506 ... 1 -0.357 \n", "29 1 -0.357 1 -0.923 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", "0 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", "1 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", "2 0 -0.506 True 0.0 2 0 2 -3.780000e+00 \n", - "3 0 -0.506 True 0.0 2 1 3 -4.187000e+00 \n", - "5 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", + "3 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", + "6 0 -0.506 True 0.0 2 1 3 -4.187000e+00 \n", "12 0 -0.506 True 0.0 3 0 3 -4.618000e+00 \n", "16 0 -0.506 False -100000000.0 0 0 0 -1.000000e+08 \n", - "17 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", - "18 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", + "17 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", + "18 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", "19 1 -0.923 False -100000000.0 0 2 2 -1.000000e+08 \n", "20 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", "21 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", "22 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", "23 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", "24 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", - "26 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", - "28 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", + "25 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", + "28 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", "29 1 -0.923 False -100000000.0 1 2 3 -1.000000e+08 \n", "\n", "[18 rows x 22 columns]" ] }, - "execution_count": 11, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -1087,7 +1086,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -1097,12 +1096,12 @@ "MAP estimate: 0.5\n", "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -1.197 1.0 0.0 \n", - "1 1.0 0.0 0 -1.197 1.0 1.0 \n", + "0 1.0 0.0 0 -1.197 1.0 1.0 \n", + "1 1.0 0.0 0 -1.197 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "1 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "0 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.506 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", "0 0 -0.506 True 0.0 2 0 2 -3.771 \n", @@ -1118,7 +1117,7 @@ "0.5" ] }, - "execution_count": 12, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -1129,14 +1128,14 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# now consider a more complex example,\n", - "# with 8 voters, where you are not an actual cause\n", + "# with 7 voters, where you are not an actual cause\n", "\n", - "def voting_model8():\n", + "def voting_model7():\n", " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", @@ -1144,7 +1143,7 @@ " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", " u_vote6 = pyro.sample(\"u_vote6\", dist.Bernoulli(0.6))\n", - " u_vote7 = pyro.sample(\"u_vote7\", dist.Bernoulli(0.6))\n", + " \n", "\n", " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", @@ -1153,351 +1152,114 @@ " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", " vote6 = pyro.deterministic(\"vote6\", u_vote6, event_dim=0)\n", - " vote7 = pyro.deterministic(\"vote7\", u_vote7, event_dim=0)\n", - "\n", "\n", " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + \n", - " vote4 + vote5 + vote6 + vote7 > 4)\n", + " vote4 + vote5 + vote6 > 3 )\n", " return {\"outcome\": outcome.float()}\n" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 23, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
acsantecedent_sizeminimal_witness_sizedenumeratorresponsibility
0True4040.25
1True4040.25
2True4040.25
3True4040.25
4True4040.25
5True4040.25
6True4040.25
7True4040.25
8True4040.25
9True4040.25
10True4040.25
11True4040.25
12True4040.25
13True4040.25
14True4040.25
15True4040.25
16True4040.25
17True4040.25
18True4040.25
19True4040.25
20True4040.25
21True4040.25
\n", - "
" - ], - "text/plain": [ - " acs antecedent_size minimal_witness_size denumerator responsibility\n", - "0 True 4 0 4 0.25\n", - "1 True 4 0 4 0.25\n", - "2 True 4 0 4 0.25\n", - "3 True 4 0 4 0.25\n", - "4 True 4 0 4 0.25\n", - "5 True 4 0 4 0.25\n", - "6 True 4 0 4 0.25\n", - "7 True 4 0 4 0.25\n", - "8 True 4 0 4 0.25\n", - "9 True 4 0 4 0.25\n", - "10 True 4 0 4 0.25\n", - "11 True 4 0 4 0.25\n", - "12 True 4 0 4 0.25\n", - "13 True 4 0 4 0.25\n", - "14 True 4 0 4 0.25\n", - "15 True 4 0 4 0.25\n", - "16 True 4 0 4 0.25\n", - "17 True 4 0 4 0.25\n", - "18 True 4 0 4 0.25\n", - "19 True 4 0 4 0.25\n", - "20 True 4 0 4 0.25\n", - "21 True 4 0 4 0.25" - ] - }, - "execution_count": 52, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.5347741090651572 .\n" + ] } ], "source": [ - "# everyone voter for \n", - "# the minimal number of interventions to\n", - "# change the outcome is 4\n", + "# everyone voted for,\n", + "# you are not an actual cause \n", + "# the minimal number of interventions \n", + "# including your change of vote\n", + "# needed to change the outcome is 4\n", "# so your responsibility is 1/4\n", "\n", - "observations8 = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=0.,\n", - " u_vote6=1., u_vote7=1.)\n", + "observations7 = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1.,\n", + " u_vote5=1.,\n", + " u_vote6=1.,\n", + " )\n", + "\n", "\n", + "treatment_candidates7 = {key[2:]: 1-v for key, v in observations7.items() if key != \"u_vote0\"}\n", "\n", - "treatment_candidates8 = {key[2:]: 1-v for key, v in observations.items() if key != \"u_vote0\"}\n", - "evaluated_node_counterfactual8 = {\"vote0\": 1 - observations[\"u_vote0\"]}\n", + "evaluated_node_counterfactual7 = {\"vote0\": 1 - observations7[\"u_vote0\"]}\n", "\n", "voting8HPR = HalpernPearlResponsibilityApproximate(\n", - " model = voting_model8,\n", - " evaluated_node_counterfactual = evaluated_node_counterfactual8,\n", - " treatment_candidates = treatment_candidates8,\n", + " model = voting_model7,\n", + " evaluated_node_counterfactual = evaluated_node_counterfactual7,\n", + " treatment_candidates = treatment_candidates7,\n", " outcome = \"outcome\",\n", - " witness_candidates = [f\"vote{i}\" for i in range(1,9)],\n", - " observations = observations8)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.25" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# four people would need to change their votes\n", - "# to change the outcome\n", - "# so your responsibility is 1/4\n", - "\n", - "everyone_voted_HPR.responsibility" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [], - "source": [ - "# if only seven people voted for, \n", - "# your responsibility changes to 1/3\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,7)],\n", + " observations = observations7)\n", "\n", - "seven_voted_for_HPR = HalpernPearlResponsibilityApproximate(\n", - " model = voting_model,\n", - " nodes = [f\"vote{i}\" for i in range(0,8,)],\n", - " antecedent = \"vote0\", outcome = \"outcome\",\n", - " observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", - " u_vote3=1., u_vote4=1., u_vote5=1., u_vote6 = 1., u_vote7 = 0.), \n", - " runs_n=500\n", - " )\n", - "\n", - "pyro.set_rng_seed(42)\n", - "seven_voted_for_HPR()" + "with pyro.plate(\"runs\", 10000):\n", + " voting8HPR()\n" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 24, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAP estimate: 0.25\n", + "Minimal scenarios:\n", + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -1.197 1.0 1.0 \n", + "1 1.0 0.0 0 -1.197 1.0 1.0 \n", + "2 1.0 0.0 0 -1.197 1.0 1.0 \n", + "3 1.0 0.0 0 -1.197 1.0 0.0 \n", + "4 1.0 0.0 0 -1.197 1.0 0.0 \n", + "5 1.0 0.0 0 -1.197 1.0 0.0 \n", + "6 1.0 0.0 0 -1.197 1.0 0.0 \n", + "7 1.0 0.0 0 -1.197 1.0 1.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote6 alp_vote6 \\\n", + "0 1 -0.357 0 -0.626 ... 0 -1.204 \n", + "1 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "2 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "3 0 -1.204 0 -0.626 ... 1 -0.357 \n", + "4 0 -1.204 0 -0.626 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.626 ... 1 -0.357 \n", + "6 0 -1.204 0 -0.626 ... 0 -1.204 \n", + "7 1 -0.357 0 -0.626 ... 0 -1.204 \n", + "\n", + " wpr_vote6 wlp_vote6 cdif clp int wpr changes sum_lp \n", + "0 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "1 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "2 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "3 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "4 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "5 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "6 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "7 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "\n", + "[8 rows x 46 columns]\n", + "Secondary check: 0.25\n" + ] + }, { "data": { "text/plain": [ - "0.3333333333333333" + "0.25" ] }, - "execution_count": 56, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# your responsibility is 1/3 as in this case\n", - "# it would be enough for three people to vote against\n", - "# to change the outcome\n", - "\n", - "seven_voted_for_HPR.responsibility" + "responsibility_check(voting8HPR)" ] }, { From fef86dabfa43e0096b3591c4d6bd7108049cdef2 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Wed, 23 Aug 2023 12:49:51 +0200 Subject: [PATCH 07/13] blocked with triple preemptions going awry --- docs/source/responsibility.ipynb | 1384 ++++++++++++++++++++++++++---- 1 file changed, 1217 insertions(+), 167 deletions(-) diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb index 55d79db1..fc2c88c3 100644 --- a/docs/source/responsibility.ipynb +++ b/docs/source/responsibility.ipynb @@ -175,8 +175,8 @@ " if observations is None:\n", " observations = {}\n", "\n", - " if not set(witness_candidates) <= set(treatment_candidates.keys()):\n", - " raise ValueError(\"witness_candidates must be a subset of treatment_candidates.keys().\")\n", + " #if not set(witness_candidates) <= set(treatment_candidates.keys()):\n", + " # raise ValueError(\"witness_candidates must be a subset of treatment_candidates.keys().\")\n", " \n", " self.model = model\n", " self.evaluated_node_counterfactual = evaluated_node_counterfactual\n", @@ -257,10 +257,15 @@ " # the last element is the fixed at the observed value (preempted) \n", " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", " with pyro.poutine.trace() as self.trace:\n", - " self.consequent = self.model(*args, **kwargs)[self.outcome]\n", - " self.counterfactual_interventions = list(self.evaluated_node_counterfactual.keys()) + list(self.treatment_candidates.keys())\n", - " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.counterfactual_interventions}))\n", - " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.counterfactual_interventions}))\n", + " self.run = self.model(*args, **kwargs)\n", + " self.consequent = self.run[self.outcome]\n", + " self.active_interventions = { intervention: {1} for intervention \n", + " in list(self.evaluated_node_counterfactual.keys()) + \n", + " list(self.treatment_candidates.keys())}\n", + " \n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.all_interventions}))\n", + " print(\"ic:\", self.intervened_consequent)\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.all_interventions}))\n", " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))" @@ -372,7 +377,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -468,7 +473,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -505,14 +510,77 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.6026881514681908 .\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.6026881514681908 .\n", + "ic: tensor([[[[[[[1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,\n", + " 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,\n", + " 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1.,\n", + " 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1.,\n", + " 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1.,\n", + " 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0.,\n", + " 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0.,\n", + " 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", + " 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1.,\n", + " 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", + " 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", + " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1.,\n", + " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1.,\n", + " 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,\n", + " 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0.,\n", + " 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1.,\n", + " 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1.,\n", + " 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,\n", + " 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", + " 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,\n", + " 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 0.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", + " 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0.,\n", + " 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1.,\n", + " 1., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", + " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", + " 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1.,\n", + " 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", + " 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0.,\n", + " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1.,\n", + " 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1.,\n", + " 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,\n", + " 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1.,\n", + " 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0.,\n", + " 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.,\n", + " 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,\n", + " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", + " 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,\n", + " 0., 1., 1., 0., 1., 1., 1., 0.]]]]]]])\n" ] }, { @@ -633,7 +701,7 @@ " -3.780000e+00\n", " \n", " \n", - " 3\n", + " 4\n", " 1.0\n", " 0.0\n", " 0\n", @@ -738,13 +806,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " ...\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " False\n", " -100000000.0\n", " 0\n", @@ -762,13 +830,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " ...\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " False\n", " -100000000.0\n", " 0\n", @@ -831,14 +899,14 @@ " 1\n", " -0.360\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", + " 0.0\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " ...\n", - " 0\n", - " -1.204\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " False\n", @@ -855,14 +923,14 @@ " 1\n", " -0.360\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", + " 1.0\n", + " 1\n", + " -0.357\n", " 0\n", " -0.506\n", " ...\n", - " 1\n", - " -0.357\n", + " 0\n", + " -1.204\n", " 0\n", " -0.506\n", " False\n", @@ -882,13 +950,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " ...\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " False\n", " -100000000.0\n", " 1\n", @@ -906,13 +974,13 @@ " 1.0\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.506\n", " ...\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.923\n", " False\n", " -100000000.0\n", " 1\n", @@ -921,22 +989,22 @@ " -1.000000e+08\n", " \n", " \n", - " 25\n", + " 26\n", " 1.0\n", " 1.0\n", " 1\n", " -0.360\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.923\n", - " ...\n", + " 0.0\n", " 0\n", " -1.204\n", " 0\n", " -0.506\n", + " ...\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.923\n", " False\n", " -100000000.0\n", " 1\n", @@ -945,22 +1013,22 @@ " -1.000000e+08\n", " \n", " \n", - " 28\n", + " 27\n", " 1.0\n", " 1.0\n", " 1\n", " -0.360\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", - " ...\n", + " 1.0\n", " 1\n", " -0.357\n", " 1\n", " -0.923\n", + " ...\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.506\n", " False\n", " -100000000.0\n", " 1\n", @@ -1002,7 +1070,7 @@ "0 1.0 0.0 0 -1.197 1.0 1.0 \n", "1 1.0 0.0 0 -1.197 1.0 0.0 \n", "2 1.0 1.0 1 -0.360 1.0 0.0 \n", - "3 1.0 0.0 0 -1.197 1.0 0.0 \n", + "4 1.0 0.0 0 -1.197 1.0 0.0 \n", "6 1.0 0.0 0 -1.197 1.0 1.0 \n", "12 1.0 0.0 0 -1.197 1.0 0.0 \n", "16 1.0 1.0 1 -0.360 1.0 1.0 \n", @@ -1010,58 +1078,58 @@ "18 1.0 1.0 1 -0.360 1.0 1.0 \n", "19 1.0 1.0 1 -0.360 1.0 1.0 \n", "20 1.0 0.0 0 -1.197 1.0 1.0 \n", - "21 1.0 1.0 1 -0.360 1.0 1.0 \n", - "22 1.0 1.0 1 -0.360 1.0 0.0 \n", + "21 1.0 1.0 1 -0.360 1.0 0.0 \n", + "22 1.0 1.0 1 -0.360 1.0 1.0 \n", "23 1.0 0.0 0 -1.197 1.0 1.0 \n", "24 1.0 0.0 0 -1.197 1.0 1.0 \n", - "25 1.0 1.0 1 -0.360 1.0 1.0 \n", - "28 1.0 1.0 1 -0.360 1.0 0.0 \n", + "26 1.0 1.0 1 -0.360 1.0 0.0 \n", + "27 1.0 1.0 1 -0.360 1.0 1.0 \n", "29 1.0 0.0 0 -1.197 1.0 1.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", "0 1 -0.357 0 -0.506 ... 0 -1.204 \n", "1 0 -1.204 0 -0.506 ... 1 -0.357 \n", "2 0 -1.204 0 -0.506 ... 0 -1.204 \n", - "3 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "4 0 -1.204 0 -0.506 ... 1 -0.357 \n", "6 1 -0.357 1 -0.923 ... 0 -1.204 \n", "12 0 -1.204 0 -0.506 ... 0 -1.204 \n", "16 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "17 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "18 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "17 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "18 1 -0.357 0 -0.506 ... 1 -0.357 \n", "19 1 -0.357 1 -0.923 ... 1 -0.357 \n", "20 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "21 1 -0.357 0 -0.506 ... 0 -1.204 \n", - "22 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "23 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "24 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "25 1 -0.357 1 -0.923 ... 0 -1.204 \n", - "28 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "21 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "22 1 -0.357 0 -0.506 ... 0 -1.204 \n", + "23 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "24 1 -0.357 0 -0.506 ... 1 -0.357 \n", + "26 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "27 1 -0.357 1 -0.923 ... 0 -1.204 \n", "29 1 -0.357 1 -0.923 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", "0 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", "1 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", "2 0 -0.506 True 0.0 2 0 2 -3.780000e+00 \n", - "3 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", + "4 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", "6 0 -0.506 True 0.0 2 1 3 -4.187000e+00 \n", "12 0 -0.506 True 0.0 3 0 3 -4.618000e+00 \n", "16 0 -0.506 False -100000000.0 0 0 0 -1.000000e+08 \n", - "17 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", - "18 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", + "17 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", + "18 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", "19 1 -0.923 False -100000000.0 0 2 2 -1.000000e+08 \n", "20 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", "21 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", "22 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", - "23 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", - "24 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", - "25 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", - "28 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", + "23 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", + "24 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", + "26 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", + "27 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", "29 1 -0.923 False -100000000.0 1 2 3 -1.000000e+08 \n", "\n", "[18 rows x 22 columns]" ] }, - "execution_count": 16, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -1086,7 +1154,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1117,7 +1185,7 @@ "0.5" ] }, - "execution_count": 17, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -1128,7 +1196,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -1160,14 +1228,15 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.5347741090651572 .\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.5347741090651572 .\n", + "ic: tensor([[[[[[[[[[[1., 1., 0., ..., 1., 1., 1.]]]]]]]]]]])\n" ] } ], @@ -1204,7 +1273,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -1215,23 +1284,27 @@ "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", "0 1.0 0.0 0 -1.197 1.0 1.0 \n", - "1 1.0 0.0 0 -1.197 1.0 1.0 \n", - "2 1.0 0.0 0 -1.197 1.0 1.0 \n", + "1 1.0 0.0 0 -1.197 1.0 0.0 \n", + "2 1.0 0.0 0 -1.197 1.0 0.0 \n", "3 1.0 0.0 0 -1.197 1.0 0.0 \n", - "4 1.0 0.0 0 -1.197 1.0 0.0 \n", + "4 1.0 0.0 0 -1.197 1.0 1.0 \n", "5 1.0 0.0 0 -1.197 1.0 0.0 \n", - "6 1.0 0.0 0 -1.197 1.0 0.0 \n", + "6 1.0 0.0 0 -1.197 1.0 1.0 \n", "7 1.0 0.0 0 -1.197 1.0 1.0 \n", + "8 1.0 0.0 0 -1.197 1.0 1.0 \n", + "9 1.0 0.0 0 -1.197 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote6 alp_vote6 \\\n", - "0 1 -0.357 0 -0.626 ... 0 -1.204 \n", - "1 1 -0.357 0 -0.626 ... 1 -0.357 \n", - "2 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "0 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "1 0 -1.204 0 -0.626 ... 0 -1.204 \n", + "2 0 -1.204 0 -0.626 ... 1 -0.357 \n", "3 0 -1.204 0 -0.626 ... 1 -0.357 \n", - "4 0 -1.204 0 -0.626 ... 0 -1.204 \n", - "5 0 -1.204 0 -0.626 ... 1 -0.357 \n", - "6 0 -1.204 0 -0.626 ... 0 -1.204 \n", - "7 1 -0.357 0 -0.626 ... 0 -1.204 \n", + "4 1 -0.357 0 -0.626 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.626 ... 0 -1.204 \n", + "6 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "7 1 -0.357 0 -0.626 ... 1 -0.357 \n", + "8 1 -0.357 0 -0.626 ... 0 -1.204 \n", + "9 0 -1.204 0 -0.626 ... 1 -0.357 \n", "\n", " wpr_vote6 wlp_vote6 cdif clp int wpr changes sum_lp \n", "0 0 -0.626 True 0.0 4 0 4 -9.635 \n", @@ -1242,8 +1315,10 @@ "5 0 -0.626 True 0.0 4 0 4 -9.635 \n", "6 0 -0.626 True 0.0 4 0 4 -9.635 \n", "7 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "8 0 -0.626 True 0.0 4 0 4 -9.635 \n", + "9 0 -0.626 True 0.0 4 0 4 -9.635 \n", "\n", - "[8 rows x 46 columns]\n", + "[10 rows x 46 columns]\n", "Secondary check: 0.25\n" ] }, @@ -1253,7 +1328,7 @@ "0.25" ] }, - "execution_count": 24, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1280,7 +1355,62 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'sally_throws': tensor(1.), 'bill_throws': tensor(1.), 'sally_hits': tensor(0.), 'bill_hits': tensor(1.), 'bottle_shatters': tensor(1)}\n" + ] + } + ], + "source": [ + "import pyro\n", + "import pyro.distributions as dist\n", + "import torch\n", + "\n", + "def stones_model2(): \n", + " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", + " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", + " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", + " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", + "\n", + " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", + " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", + " sally_hits = pyro.sample(\"sally_hits\", dist.Bernoulli(prob_sally_hits * sally_throws))\n", + " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(prob_bill_hits * bill_throws * (1 - sally_hits)))\n", + "\n", + " bottle_shatters = pyro.sample(\"bottle_shatters\", dist.Bernoulli(\n", + " prob_bottle_shatters_if_bill * bill_hits + prob_bottle_shatters_if_sally * sally_hits\n", + " )).long()\n", + "\n", + " return {\n", + " \"sally_throws\": sally_throws,\n", + " \"bill_throws\": bill_throws,\n", + " \"sally_hits\": sally_hits,\n", + " \"bill_hits\": bill_hits,\n", + " \"bottle_shatters\": bottle_shatters,\n", + " }\n", + "\n", + "stones_model2.nodes = [\n", + " \"sally_throws\",\n", + " \"bill_throws\",\n", + " \"sally_hits\",\n", + " \"bill_hits\",\n", + " \"bottle_shatters\",\n", + "]\n", + "\n", + "result = stones_model2()\n", + "print(result)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -1334,100 +1464,1020 @@ " \"bill_hits\": bill_hits,\n", " \"bottle_shatters\": bottle_shatters,\n", " }\n", - "\n", - "stones_model.nodes = [\n", - " \"sally_throws\",\n", - " \"bill_throws\",\n", - " \"sally_hits\",\n", - " \"bill_hits\",\n", - " \"bottle_shatters\",\n", - " ]" + "\n" ] }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 55, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.567842770667782 .\n", + "ic: tensor([[[[[[[[1., 1., 1., 1., 1., 0., 1., 1., 1., 1.]]]]]]]])\n" + ] + } + ], "source": [ - "pyro.set_rng_seed(101)\n", - "responsibility_stones_sally_HPR = HalpernPearlResponsibilityApproximate(\n", + "\n", + "pyro.set_rng_seed(4)\n", + "stones_sallyHPR = HalpernPearlResponsibilityApproximate(\n", " model = stones_model,\n", - " nodes = stones_model.nodes,\n", - " antecedent = \"sally_throws\", outcome = \"bottle_shatters\",\n", - " observations = {\"prob_sally_throws\": 1, \n", - " \"prob_bill_throws\": 1,\n", - " \"prob_sally_hits\": 1,\n", - " \"prob_bill_hits\": 1,\n", - " \"prob_bottle_shatters_if_sally\": 1,\n", - " \"prob_bottle_shatters_if_bill\": 1,\n", - " \"sally_throws\": 1, \"bill_throws\": 1},\n", - " runs_n=100)\n", + " evaluated_node_counterfactual= {\"sally_throws\": 0.0},\n", + " treatment_candidates = {\"sally_hits\": 0.0, \"bill_hits\": 1.0, \"bill_throws\": 0.0},\n", + " outcome = \"bottle_shatters\",\n", + " witness_candidates = [\"bill_hits\", \"bill_throws\", \"sally_hits\"],\n", + " observations = {\"prob_sally_throws\": 1.0, \n", + " \"prob_bill_throws\": 1.0,\n", + " \"prob_sally_hits\": 1.0,\n", + " \"prob_bill_hits\": 1.0,\n", + " \"prob_bottle_shatters_if_sally\": 1.0,\n", + " \"prob_bottle_shatters_if_bill\": 1.0})\n", + "\n", + "with pyro.plate(\"runs\",10):\n", + " stones_sallyHPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "odict_keys(['prob_sally_throws', 'prob_bill_throws', 'prob_sally_hits', 'prob_bill_hits', 'prob_bottle_shatters_if_sally', 'prob_bottle_shatters_if_bill', '__evaluated_split_sally_throws', 'sally_throws', '__witness_split_bill_throws', '__treatment_split_bill_throws', 'bill_throws', '__witness_split_sally_hits', '__treatment_split_sally_hits', 'sally_hits', '__witness_split_bill_hits', '__treatment_split_bill_hits', 'bill_hits', 'bottle_shatters', 'consequent_differs_binary', 'consequent_differs'])\n" + ] + } + ], + "source": [ + "print(stones_sallyHPR.trace.trace.nodes.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bh_split tensor([0, 1, 1, 0, 0, 0, 0, 1, 0, 0])\n" + ] + } + ], + "source": [ "\n", - "responsibility_stones_sally_HPR()" + "print(\"bh_split\", stones_sallyHPR.trace.trace.nodes['__witness_split_bill_hits']['value'])\n" ] }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 57, "metadata": {}, "outputs": [ { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
acsantecedent_sizeminimal_witness_sizedenumeratorresponsibility
0True1120.5
\n", - "
" - ], - "text/plain": [ - " acs antecedent_size minimal_witness_size denumerator responsibility\n", - "0 True 1 1 2 0.5" - ] - }, - "execution_count": 59, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "bh tensor([[[[[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 0., 1., 0., 1., 0., 0.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 1., 1., 0., 0., 0., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.]]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 1., 1., 0., 1., 1., 0.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 1., 1., 0., 0., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 1., 0., 0., 1., 1., 1.]]]]]]]])\n" + ] + } + ], + "source": [ + "print(\"bh\",stones_sallyHPR.trace.trace.nodes['bill_hits']['value'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sh tensor([[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 0., 1., 0., 1., 0., 1., 1.]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[1., 1., 1., 1., 0., 1., 1., 1., 1., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 0., 0., 0., 1., 0., 1., 0.]]]]]]])\n" + ] + } + ], + "source": [ + "\n", + "print(\"sh\", stones_sallyHPR.trace.trace.nodes['sally_hits']['value'])\n", + "\n", + "#stones_sallyHPR.trace.trace.nodes['bottle_shatters']\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "con tensor([[[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 0., 1., 0., 1., 0., 1., 1.]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 1., 1., 0., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 0., 1., 0., 1., 1., 1., 1.]]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 0., 1., 0., 1., 1.]]]]]],\n", + "\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 1., 1., 0., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1., 0., 1., 1., 1., 1.]]]]]]]])\n" + ] + } + ], + "source": [ + "print(\"con\", stones_sallyHPR.trace.trace.nodes['bottle_shatters']['value'])" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epr_sally_throwselp_sally_throwsapr_sally_hitsalp_sally_hitswpr_sally_hitswlp_sally_hitsapr_bill_hitsalp_bill_hitswpr_bill_hitswlp_bill_hitsapr_bill_throwsalp_bill_throwswpr_bill_throwswlp_bill_throwsclpintwprchangessum_lp
01-0.3601-0.3571-0.8391-0.3570-0.5660-1.2041-0.839-100000000.0123-1.000000e+08
11-0.3601-0.3570-0.5661-0.3571-0.8390-1.2040-0.566-100000000.0112-1.000000e+08
21-0.3601-0.3571-0.8391-0.3571-0.8390-1.2041-0.839-100000000.0134-1.000000e+08
30-1.1971-0.3571-0.8390-1.2040-0.5660-1.2041-0.839-100000000.0325-1.000000e+08
41-0.3600-1.2040-0.5660-1.2040-0.5661-0.3570-0.566-100000000.0202-1.000000e+08
50-1.1971-0.3570-0.5661-0.3570-0.5660-1.2040-0.5660.0202-4.812000e+00
61-0.3601-0.3570-0.5661-0.3570-0.5661-0.3570-0.566-100000000.0000-1.000000e+08
70-1.1971-0.3570-0.5661-0.3571-0.8391-0.3570-0.566-100000000.0112-1.000000e+08
81-0.3601-0.3570-0.5660-1.2040-0.5661-0.3571-0.839-100000000.0112-1.000000e+08
91-0.3600-1.2041-0.8391-0.3570-0.5661-0.3570-0.566-100000000.0112-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", + "0 1 -0.360 1 -0.357 \n", + "1 1 -0.360 1 -0.357 \n", + "2 1 -0.360 1 -0.357 \n", + "3 0 -1.197 1 -0.357 \n", + "4 1 -0.360 0 -1.204 \n", + "5 0 -1.197 1 -0.357 \n", + "6 1 -0.360 1 -0.357 \n", + "7 0 -1.197 1 -0.357 \n", + "8 1 -0.360 1 -0.357 \n", + "9 1 -0.360 0 -1.204 \n", + "\n", + " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", + "0 1 -0.839 1 -0.357 \n", + "1 0 -0.566 1 -0.357 \n", + "2 1 -0.839 1 -0.357 \n", + "3 1 -0.839 0 -1.204 \n", + "4 0 -0.566 0 -1.204 \n", + "5 0 -0.566 1 -0.357 \n", + "6 0 -0.566 1 -0.357 \n", + "7 0 -0.566 1 -0.357 \n", + "8 0 -0.566 0 -1.204 \n", + "9 1 -0.839 1 -0.357 \n", + "\n", + " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", + "0 0 -0.566 0 -1.204 \n", + "1 1 -0.839 0 -1.204 \n", + "2 1 -0.839 0 -1.204 \n", + "3 0 -0.566 0 -1.204 \n", + "4 0 -0.566 1 -0.357 \n", + "5 0 -0.566 0 -1.204 \n", + "6 0 -0.566 1 -0.357 \n", + "7 1 -0.839 1 -0.357 \n", + "8 0 -0.566 1 -0.357 \n", + "9 0 -0.566 1 -0.357 \n", + "\n", + " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", + "0 1 -0.839 -100000000.0 1 2 3 \n", + "1 0 -0.566 -100000000.0 1 1 2 \n", + "2 1 -0.839 -100000000.0 1 3 4 \n", + "3 1 -0.839 -100000000.0 3 2 5 \n", + "4 0 -0.566 -100000000.0 2 0 2 \n", + "5 0 -0.566 0.0 2 0 2 \n", + "6 0 -0.566 -100000000.0 0 0 0 \n", + "7 0 -0.566 -100000000.0 1 1 2 \n", + "8 1 -0.839 -100000000.0 1 1 2 \n", + "9 0 -0.566 -100000000.0 1 1 2 \n", + "\n", + " sum_lp \n", + "0 -1.000000e+08 \n", + "1 -1.000000e+08 \n", + "2 -1.000000e+08 \n", + "3 -1.000000e+08 \n", + "4 -1.000000e+08 \n", + "5 -4.812000e+00 \n", + "6 -1.000000e+08 \n", + "7 -1.000000e+08 \n", + "8 -1.000000e+08 \n", + "9 -1.000000e+08 " + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def gett(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", + " \n", + " values_table = {}\n", + "\n", + "\n", + "# values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + "# values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", + " for antecedent in antecedents:\n", + "# andecedent_m = HPR.run[antecedent]\n", + "# print(gather(andecedent_m, IndexSet(**{antecedent: {0} for antecedent in antecedents})))\n", + "# values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + "# values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " for witness in witness_candidates:\n", + " if witness not in antecedents:\n", + " values_table['wpr_' + witness] = nodes['__witness_split_' + witness][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + witness] = nodes['__witness_split_' + witness][\"fn\"].log_prob(nodes['__witness_split_' + witness][\"value\"]).squeeze().tolist()\n", + "\n", + " values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['clp'], float):\n", + " values_df = pd.DataFrame([values_table])\n", + " else:\n", + " values_df = pd.DataFrame(values_table)\n", + " \n", + " values_df = pd.DataFrame(values_table)\n", + "\n", + " summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + "\n", + " values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + " values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", + " #values_df.drop_duplicates(inplace = True)\n", + " #values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + "\n", + " tab = values_df.reset_index(drop = True)\n", + "\n", + " #tab = remove_redundant_rows(tab)\n", + " \n", + " if round:\n", + " tab = tab.round(3)\n", + "\n", + " return tab\n", + "\n", + "\n", + "tab = gett(stones_sallyHPR.trace.trace.nodes, \"sally_throws\", stones_sallyHPR.treatment_candidates, \n", + " stones_sallyHPR.witness_candidates)\n", + "\n", + "tab" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epr_sally_throwselp_sally_throwsapr_sally_hitsalp_sally_hitswpr_sally_hitswlp_sally_hitsapr_bill_hitsalp_bill_hitswpr_bill_hitswlp_bill_hitsapr_bill_throwsalp_bill_throwswpr_bill_throwswlp_bill_throwsclpintwprchangessum_lp
170-1.1971-0.3570-0.5661-0.3571-0.8391-0.3570-0.566-100000000.0112-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", + "17 0 -1.197 1 -0.357 \n", + "\n", + " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", + "17 0 -0.566 1 -0.357 \n", + "\n", + " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", + "17 1 -0.839 1 -0.357 \n", + "\n", + " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", + "17 0 -0.566 -100000000.0 1 1 2 \n", + "\n", + " sum_lp \n", + "17 -1.000000e+08 " + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#this is worrying\n", + "\n", + "tab.query(\"epr_sally_throws == 0 & apr_sally_hits == 1 & wpr_sally_hits == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1 & apr_bill_throws == 1 & wpr_bill_throws == 0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'type': 'sample',\n", + " 'name': 'consequent_differs_binary',\n", + " 'fn': MaskedDistribution(),\n", + " 'is_observed': True,\n", + " 'args': (),\n", + " 'kwargs': {},\n", + " 'value': tensor([[[[[False, False, False, False, False, False, False, False, False,\n", + " False]]]]]),\n", + " 'infer': {'_deterministic': True},\n", + " 'scale': 1.0,\n", + " 'mask': None,\n", + " 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=10, counter=0),),\n", + " 'done': True,\n", + " 'stop': False,\n", + " 'continuation': None}" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stones_sallyHPR.trace.trace.nodes[\"consequent_differs_binary\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],\n", + " 'epr_sally_throws': [1, 1, 1, 1, 0],\n", + " 'elp_sally_throws': [-0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -1.1973283290863037],\n", + " 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_bill_throws': [1.0, 0.0, 1.0, 1.0, 1.0],\n", + " 'apr_bill_throws': [1, 0, 1, 1, 1],\n", + " 'alp_bill_throws': [-0.3566749691963196,\n", + " -1.2039728164672852,\n", + " -0.3566749691963196,\n", + " -0.3566749691963196,\n", + " -0.3566749691963196],\n", + " 'obs_bill_hits': [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]],\n", + " 'wpr_bill_hits': [0, 1, 0, 0, 0],\n", + " 'wlp_bill_hits': [-0.3624056577682495,\n", + " -1.1907275915145874,\n", + " -0.3624056577682495,\n", + " -0.3624056577682495,\n", + " -0.3624056577682495]}" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", + " \n", + " values_table = {}\n", + "\n", + " values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", + " for antecedent in antecedents:\n", + " values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + "\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " for witness in witness_candidates:\n", + " if witness not in antecedents:\n", + " values_table[f\"obs_{witness}\"] = nodes[witness][\"value\"][0].squeeze().tolist()\n", + " #values_table[f\"int_{witness}\"] = nodes[witness][\"value\"][1].squeeze().tolist()\n", + " values_table['wpr_' + witness] = nodes['__witness_split_' + witness][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + witness] = nodes['__witness_split_' + witness][\"fn\"].log_prob(nodes['__witness_split_' + witness][\"value\"]).squeeze().tolist()\n", + "\n", + " \n", + " #values_table['cdif'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " #values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " #if isinstance(values_table['clp'], float):\n", + " # values_df = pd.DataFrame([values_table])\n", + " # else:\n", + " # values_df = pd.DataFrame(values_table)\n", + " \n", + " # values_df = pd.DataFrame(values_table)\n", + "\n", + " #summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " #summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " #summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + " \n", + " \n", + " # values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " # values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " # values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " # values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + "\n", + " #values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", + " # values_df.drop_duplicates(inplace = True)\n", + " # values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + "\n", + " # tab = values_df.reset_index(drop = True)\n", + "\n", + " # tab = remove_redundant_rows(tab)\n", + "\n", + " tab = values_table\n", + "\n", + " #if round:\n", + " # tab = tab.round(3)\n", + "\n", + " return tab\n", + "\n", + "\n", + "get_table(stones_sallyHPR.trace.trace.nodes, \"sally_throws\", stones_sallyHPR.treatment_candidates, \n", + " stones_sallyHPR.witness_candidates)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'type': 'sample', 'name': 'prob_bill_hits', 'fn': Beta(), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor(1.), 'infer': {'_do_not_observe': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=5, counter=0),), 'done': True, 'stop': False, 'continuation': None}\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],\n", + " 'epr_sally_throws': [1, 1, 1, 1, 0],\n", + " 'elp_sally_throws': [-0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -1.1973283290863037],\n", + " 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_bill_throws': [1.0, 1.0, 1.0, 0.0, 1.0],\n", + " 'apr_bill_throws': [1, 1, 1, 0, 1],\n", + " 'alp_bill_throws': [-0.3566749691963196,\n", + " -0.3566749691963196,\n", + " -0.3566749691963196,\n", + " -1.2039728164672852,\n", + " -0.3566749691963196],\n", + " 'wpr_bill_throws': [0, 0, 1, 1, 1],\n", + " 'wlp_bill_throws': [-0.5659106969833374,\n", + " -0.5659106969833374,\n", + " -0.8389658331871033,\n", + " -0.8389658331871033,\n", + " -0.8389658331871033]}" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "All arrays must be of the same length", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[63], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# minimal witness size becomes non-trivial here\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[39m# we only record different minimal difference-making scenarios\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m responsibility_check(stones_sallyHPR)\n", + "Cell \u001b[0;32mIn[6], line 4\u001b[0m, in \u001b[0;36mresponsibility_check\u001b[0;34m(hpr)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mresponsibility_check\u001b[39m(hpr):\n\u001b[1;32m 3\u001b[0m evaluated_node \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(hpr\u001b[39m.\u001b[39mevaluated_node_counterfactual\u001b[39m.\u001b[39mkeys())[\u001b[39m0\u001b[39m]\n\u001b[0;32m----> 4\u001b[0m tab \u001b[39m=\u001b[39m get_table(hpr\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mnodes,\n\u001b[1;32m 5\u001b[0m evaluated_node ,\n\u001b[1;32m 6\u001b[0m \u001b[39mlist\u001b[39;49m(hpr\u001b[39m.\u001b[39;49mtreatment_candidates\u001b[39m.\u001b[39;49mkeys()), \n\u001b[1;32m 7\u001b[0m hpr\u001b[39m.\u001b[39;49mwitness_candidates)\n\u001b[1;32m 9\u001b[0m max_sum_lp \u001b[39m=\u001b[39m tab[\u001b[39m'\u001b[39m\u001b[39msum_lp\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m.\u001b[39mmax()\n\u001b[1;32m 10\u001b[0m max_sum_lp_rows \u001b[39m=\u001b[39m tab[tab[\u001b[39m'\u001b[39m\u001b[39msum_lp\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m==\u001b[39m max_sum_lp]\n", + "Cell \u001b[0;32mIn[5], line 31\u001b[0m, in \u001b[0;36mget_table\u001b[0;34m(nodes, evaluated_node, antecedents, witness_candidates, round)\u001b[0m\n\u001b[1;32m 29\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame([values_table])\n\u001b[1;32m 30\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 31\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39;49mDataFrame(values_table)\n\u001b[1;32m 33\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame(values_table)\n\u001b[1;32m 35\u001b[0m summands_ant \u001b[39m=\u001b[39m [\u001b[39m'\u001b[39m\u001b[39malp_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent \u001b[39mfor\u001b[39;00m antecedent \u001b[39min\u001b[39;00m antecedents]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/frame.py:663\u001b[0m, in \u001b[0;36mDataFrame.__init__\u001b[0;34m(self, data, index, columns, dtype, copy)\u001b[0m\n\u001b[1;32m 657\u001b[0m mgr \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_init_mgr(\n\u001b[1;32m 658\u001b[0m data, axes\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mindex\u001b[39m\u001b[39m\"\u001b[39m: index, \u001b[39m\"\u001b[39m\u001b[39mcolumns\u001b[39m\u001b[39m\"\u001b[39m: columns}, dtype\u001b[39m=\u001b[39mdtype, copy\u001b[39m=\u001b[39mcopy\n\u001b[1;32m 659\u001b[0m )\n\u001b[1;32m 661\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, \u001b[39mdict\u001b[39m):\n\u001b[1;32m 662\u001b[0m \u001b[39m# GH#38939 de facto copy defaults to False only in non-dict cases\u001b[39;00m\n\u001b[0;32m--> 663\u001b[0m mgr \u001b[39m=\u001b[39m dict_to_mgr(data, index, columns, dtype\u001b[39m=\u001b[39;49mdtype, copy\u001b[39m=\u001b[39;49mcopy, typ\u001b[39m=\u001b[39;49mmanager)\n\u001b[1;32m 664\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, ma\u001b[39m.\u001b[39mMaskedArray):\n\u001b[1;32m 665\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mma\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmrecords\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mmrecords\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:493\u001b[0m, in \u001b[0;36mdict_to_mgr\u001b[0;34m(data, index, columns, dtype, typ, copy)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 490\u001b[0m \u001b[39m# dtype check to exclude e.g. range objects, scalars\u001b[39;00m\n\u001b[1;32m 491\u001b[0m arrays \u001b[39m=\u001b[39m [x\u001b[39m.\u001b[39mcopy() \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(x, \u001b[39m\"\u001b[39m\u001b[39mdtype\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39melse\u001b[39;00m x \u001b[39mfor\u001b[39;00m x \u001b[39min\u001b[39;00m arrays]\n\u001b[0;32m--> 493\u001b[0m \u001b[39mreturn\u001b[39;00m arrays_to_mgr(arrays, columns, index, dtype\u001b[39m=\u001b[39;49mdtype, typ\u001b[39m=\u001b[39;49mtyp, consolidate\u001b[39m=\u001b[39;49mcopy)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:118\u001b[0m, in \u001b[0;36marrays_to_mgr\u001b[0;34m(arrays, columns, index, dtype, verify_integrity, typ, consolidate)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[39mif\u001b[39;00m verify_integrity:\n\u001b[1;32m 116\u001b[0m \u001b[39m# figure out the index, if necessary\u001b[39;00m\n\u001b[1;32m 117\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 118\u001b[0m index \u001b[39m=\u001b[39m _extract_index(arrays)\n\u001b[1;32m 119\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 120\u001b[0m index \u001b[39m=\u001b[39m ensure_index(index)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:666\u001b[0m, in \u001b[0;36m_extract_index\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 664\u001b[0m lengths \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(\u001b[39mset\u001b[39m(raw_lengths))\n\u001b[1;32m 665\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(lengths) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 666\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mAll arrays must be of the same length\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 668\u001b[0m \u001b[39mif\u001b[39;00m have_dicts:\n\u001b[1;32m 669\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 670\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mMixing dicts with non-Series may lead to ambiguous ordering.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 671\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: All arrays must be of the same length" + ] } ], "source": [ "# minimal witness size becomes non-trivial here\n", "# we only record different minimal difference-making scenarios\n", "\n", - "responsibility_stones_sally_HPR.responsibilityDF" + "responsibility_check(stones_sallyHPR)" ] }, { @@ -1655,7 +2705,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.9" }, "orig_nbformat": 4 }, From 3814b58b5c1b5884aab0ea52422ab9e85619eb07 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Wed, 23 Aug 2023 13:32:30 +0200 Subject: [PATCH 08/13] blocked with responsibility --- docs/source/responsibility.ipynb | 2367 ++++++++++++++++++++---------- 1 file changed, 1593 insertions(+), 774 deletions(-) diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb index fc2c88c3..370183db 100644 --- a/docs/source/responsibility.ipynb +++ b/docs/source/responsibility.ipynb @@ -155,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ " witness_candidates: List[str],\n", " outcome: str,\n", " observations: Optional[Dict[str, torch.Tensor]] = None,\n", - " bias_t: float = .2\n", + " bias_n: float = .2\n", " ):\n", " \n", " if observations is None:\n", @@ -184,8 +184,8 @@ " self.witness_candidates = witness_candidates\n", " self.outcome = outcome\n", " self.observations = observations\n", - " self.bias_t = bias_t\n", - " self.bias_n = self.find_max_bias_within(self.bias_t, 1)\n", + " self.bias_t = .2\n", + " self.bias_n = self.find_max_bias_within(self.bias_t, len(self.treatment_candidates))\n", " self.bias_w = self.find_max_bias_within(self.bias_n, len(self.witness_candidates))\n", "\n", " self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,\n", @@ -245,27 +245,32 @@ " def __call__(self, *args, **kwargs):\n", " print(\"Preemption biases used (upper) - t:\",.5+ self.bias_t, \", n:\", .5 + self.bias_n, \", w:\", .5 + self.bias_w, \".\")\n", " with MultiWorldCounterfactual():\n", - " # the last element of the tensor is the factual case (preempted)\n", - " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", - " prefix = \"__witness_split_\"):\n", - " with do(actions=self.evaluated_node_counterfactual):\n", - " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", + " with do(actions=self.evaluated_node_counterfactual):\n", + " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", " prefix = \"__evaluated_split_\"):\n", - " with do(actions=self.treatment_candidates):\n", - " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", + " with do(actions=self.treatment_candidates):\n", + " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", " prefix = \"__treatment_split_\"):\n", - " # the last element is the fixed at the observed value (preempted) \n", + " # the last element is the fixed at the observed value (preempted) \n", + " # the last element of the tensor is the factual case (preempted)\n", + " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", + " prefix = \"__witness_split_\"):\n", + "\n", " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", " with pyro.poutine.trace() as self.trace:\n", " self.run = self.model(*args, **kwargs)\n", " self.consequent = self.run[self.outcome]\n", - " self.active_interventions = { intervention: {1} for intervention \n", + " self.interventionIndex = { intervention: {1} for intervention \n", " in list(self.evaluated_node_counterfactual.keys()) + \n", - " list(self.treatment_candidates.keys())}\n", + " list(self.treatment_candidates.keys()) + self.witness_candidates}\n", + " \n", + " self.observedIndex = {node: {0} for node in list(self.evaluated_node_counterfactual.keys()) + \n", + " list(self.treatment_candidates.keys()) + self.witness_candidates}\n", + "\n", + " \n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**self.interventionIndex))\n", " \n", - " self.intervened_consequent = gather(self.consequent, IndexSet(**{ant: {1} for ant in self.all_interventions}))\n", - " print(\"ic:\", self.intervened_consequent)\n", - " self.observed_consequent = gather(self.consequent, IndexSet(**{ant: {0} for ant in self.all_interventions}))\n", + " self.observed_consequent = gather(self.consequent, IndexSet(**self.observedIndex))\n", " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))" @@ -273,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -310,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -377,7 +382,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -473,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -510,77 +515,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.6026881514681908 .\n", - "ic: tensor([[[[[[[1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0.,\n", - " 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,\n", - " 0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1.,\n", - " 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1.,\n", - " 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1.,\n", - " 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0.,\n", - " 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 0.,\n", - " 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", - " 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 1.,\n", - " 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", - " 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", - " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1.,\n", - " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1.,\n", - " 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,\n", - " 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0.,\n", - " 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 1.,\n", - " 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1.,\n", - " 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,\n", - " 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", - " 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,\n", - " 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1., 0.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", - " 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0.,\n", - " 0., 1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1.,\n", - " 1., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0.,\n", - " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1.,\n", - " 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1.,\n", - " 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", - " 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0.,\n", - " 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1.,\n", - " 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1.,\n", - " 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,\n", - " 1., 1., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1.,\n", - " 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 0.,\n", - " 1., 0., 1., 0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.,\n", - " 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,\n", - " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,\n", - " 1., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,\n", - " 0., 1., 1., 0., 1., 1., 1., 0.]]]]]]])\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.6024412643276109 , w: 0.5502509213795265 .\n" ] }, { @@ -633,162 +575,162 @@ " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " True\n", " 0.0\n", " 2\n", " 0\n", " 2\n", - " -3.771000e+00\n", + " -3.678000e+00\n", " \n", " \n", " 1\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 0.0\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " True\n", " 0.0\n", " 2\n", " 0\n", " 2\n", - " -3.771000e+00\n", + " -3.678000e+00\n", " \n", " \n", " 2\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.360\n", - " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", - " ...\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", - " True\n", - " 0.0\n", - " 2\n", - " 0\n", - " 2\n", - " -3.780000e+00\n", - " \n", - " \n", - " 4\n", - " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 0.0\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " True\n", " 0.0\n", " 2\n", " 1\n", " 3\n", - " -4.187000e+00\n", + " -3.880000e+00\n", " \n", " \n", - " 6\n", + " 3\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " ...\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " True\n", " 0.0\n", " 2\n", " 1\n", " 3\n", - " -4.187000e+00\n", + " -3.880000e+00\n", " \n", " \n", - " 12\n", + " 4\n", + " 1.0\n", + " 1.0\n", + " 1\n", + " -0.507\n", + " 1.0\n", + " 0.0\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.597\n", + " ...\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.597\n", + " True\n", + " 0.0\n", + " 2\n", + " 0\n", + " 2\n", + " -4.109000e+00\n", + " \n", + " \n", + " 5\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 0.0\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " True\n", " 0.0\n", " 3\n", " 0\n", " 3\n", - " -4.618000e+00\n", + " -4.525000e+00\n", " \n", " \n", - " 16\n", + " 8\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " False\n", " -100000000.0\n", " 0\n", @@ -797,22 +739,22 @@ " -1.000000e+08\n", " \n", " \n", - " 17\n", + " 9\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.799\n", " False\n", " -100000000.0\n", " 0\n", @@ -821,22 +763,22 @@ " -1.000000e+08\n", " \n", " \n", - " 18\n", + " 10\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.799\n", " ...\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.597\n", " False\n", " -100000000.0\n", " 0\n", @@ -845,22 +787,22 @@ " -1.000000e+08\n", " \n", " \n", - " 19\n", + " 11\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " ...\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " False\n", " -100000000.0\n", " 0\n", @@ -869,22 +811,22 @@ " -1.000000e+08\n", " \n", " \n", - " 20\n", + " 12\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " False\n", " -100000000.0\n", " 1\n", @@ -893,142 +835,142 @@ " -1.000000e+08\n", " \n", " \n", - " 21\n", - " 1.0\n", - " 1.0\n", - " 1\n", - " -0.360\n", + " 13\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", + " -0.922\n", + " 1.0\n", + " 1.0\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.799\n", " ...\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " False\n", " -100000000.0\n", " 1\n", - " 0\n", " 1\n", + " 2\n", " -1.000000e+08\n", " \n", " \n", - " 22\n", - " 1.0\n", + " 14\n", " 1.0\n", - " 1\n", - " -0.360\n", + " 0.0\n", + " 0\n", + " -0.922\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.799\n", " False\n", " -100000000.0\n", " 1\n", - " 0\n", " 1\n", + " 2\n", " -1.000000e+08\n", " \n", " \n", - " 23\n", + " 15\n", " 1.0\n", " 0.0\n", " 0\n", - " -1.197\n", + " -0.922\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " ...\n", " 1\n", " -0.357\n", - " 0\n", - " -0.506\n", + " 1\n", + " -0.799\n", " False\n", " -100000000.0\n", " 1\n", - " 1\n", " 2\n", + " 3\n", " -1.000000e+08\n", " \n", " \n", - " 24\n", + " 16\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.197\n", + " 1.0\n", + " 1\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.597\n", " False\n", " -100000000.0\n", " 1\n", + " 0\n", " 1\n", - " 2\n", " -1.000000e+08\n", " \n", " \n", - " 26\n", + " 17\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 0.0\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", - " 1\n", - " -0.923\n", + " 0\n", + " -0.597\n", " False\n", " -100000000.0\n", " 1\n", + " 0\n", " 1\n", - " 2\n", " -1.000000e+08\n", " \n", " \n", - " 27\n", + " 18\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.360\n", + " -0.507\n", " 1.0\n", " 1.0\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " ...\n", " 0\n", " -1.204\n", " 0\n", - " -0.506\n", + " -0.597\n", " False\n", " -100000000.0\n", " 1\n", @@ -1037,27 +979,27 @@ " -1.000000e+08\n", " \n", " \n", - " 29\n", - " 1.0\n", - " 0.0\n", - " 0\n", - " -1.197\n", + " 19\n", " 1.0\n", " 1.0\n", " 1\n", - " -0.357\n", - " 1\n", - " -0.923\n", + " -0.507\n", + " 1.0\n", + " 0.0\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.597\n", " ...\n", " 1\n", " -0.357\n", " 1\n", - " -0.923\n", + " -0.799\n", " False\n", " -100000000.0\n", " 1\n", + " 1\n", " 2\n", - " 3\n", " -1.000000e+08\n", " \n", " \n", @@ -1067,69 +1009,69 @@ ], "text/plain": [ " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -1.197 1.0 1.0 \n", - "1 1.0 0.0 0 -1.197 1.0 0.0 \n", - "2 1.0 1.0 1 -0.360 1.0 0.0 \n", - "4 1.0 0.0 0 -1.197 1.0 0.0 \n", - "6 1.0 0.0 0 -1.197 1.0 1.0 \n", - "12 1.0 0.0 0 -1.197 1.0 0.0 \n", - "16 1.0 1.0 1 -0.360 1.0 1.0 \n", - "17 1.0 1.0 1 -0.360 1.0 1.0 \n", - "18 1.0 1.0 1 -0.360 1.0 1.0 \n", - "19 1.0 1.0 1 -0.360 1.0 1.0 \n", - "20 1.0 0.0 0 -1.197 1.0 1.0 \n", - "21 1.0 1.0 1 -0.360 1.0 0.0 \n", - "22 1.0 1.0 1 -0.360 1.0 1.0 \n", - "23 1.0 0.0 0 -1.197 1.0 1.0 \n", - "24 1.0 0.0 0 -1.197 1.0 1.0 \n", - "26 1.0 1.0 1 -0.360 1.0 0.0 \n", - "27 1.0 1.0 1 -0.360 1.0 1.0 \n", - "29 1.0 0.0 0 -1.197 1.0 1.0 \n", + "0 1.0 0.0 0 -0.922 1.0 1.0 \n", + "1 1.0 0.0 0 -0.922 1.0 0.0 \n", + "2 1.0 0.0 0 -0.922 1.0 0.0 \n", + "3 1.0 0.0 0 -0.922 1.0 1.0 \n", + "4 1.0 1.0 1 -0.507 1.0 0.0 \n", + "5 1.0 0.0 0 -0.922 1.0 0.0 \n", + "8 1.0 1.0 1 -0.507 1.0 1.0 \n", + "9 1.0 1.0 1 -0.507 1.0 1.0 \n", + "10 1.0 1.0 1 -0.507 1.0 1.0 \n", + "11 1.0 1.0 1 -0.507 1.0 1.0 \n", + "12 1.0 0.0 0 -0.922 1.0 1.0 \n", + "13 1.0 0.0 0 -0.922 1.0 1.0 \n", + "14 1.0 0.0 0 -0.922 1.0 1.0 \n", + "15 1.0 0.0 0 -0.922 1.0 1.0 \n", + "16 1.0 1.0 1 -0.507 1.0 1.0 \n", + "17 1.0 1.0 1 -0.507 1.0 0.0 \n", + "18 1.0 1.0 1 -0.507 1.0 1.0 \n", + "19 1.0 1.0 1 -0.507 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 1 -0.357 0 -0.506 ... 0 -1.204 \n", - "1 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "2 0 -1.204 0 -0.506 ... 0 -1.204 \n", - "4 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "6 1 -0.357 1 -0.923 ... 0 -1.204 \n", - "12 0 -1.204 0 -0.506 ... 0 -1.204 \n", - "16 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "17 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "18 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "19 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "20 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "21 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "22 1 -0.357 0 -0.506 ... 0 -1.204 \n", - "23 1 -0.357 1 -0.923 ... 1 -0.357 \n", - "24 1 -0.357 0 -0.506 ... 1 -0.357 \n", - "26 0 -1.204 0 -0.506 ... 1 -0.357 \n", - "27 1 -0.357 1 -0.923 ... 0 -1.204 \n", - "29 1 -0.357 1 -0.923 ... 1 -0.357 \n", + "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "2 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "3 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "4 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "8 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "9 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "10 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "11 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "12 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "13 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "14 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "15 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "16 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "17 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "18 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "19 0 -1.204 0 -0.597 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", - "0 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", - "1 0 -0.506 True 0.0 2 0 2 -3.771000e+00 \n", - "2 0 -0.506 True 0.0 2 0 2 -3.780000e+00 \n", - "4 1 -0.923 True 0.0 2 1 3 -4.187000e+00 \n", - "6 0 -0.506 True 0.0 2 1 3 -4.187000e+00 \n", - "12 0 -0.506 True 0.0 3 0 3 -4.618000e+00 \n", - "16 0 -0.506 False -100000000.0 0 0 0 -1.000000e+08 \n", - "17 0 -0.506 False -100000000.0 0 1 1 -1.000000e+08 \n", - "18 1 -0.923 False -100000000.0 0 1 1 -1.000000e+08 \n", - "19 1 -0.923 False -100000000.0 0 2 2 -1.000000e+08 \n", - "20 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", - "21 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", - "22 0 -0.506 False -100000000.0 1 0 1 -1.000000e+08 \n", - "23 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", - "24 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", - "26 1 -0.923 False -100000000.0 1 1 2 -1.000000e+08 \n", - "27 0 -0.506 False -100000000.0 1 1 2 -1.000000e+08 \n", - "29 1 -0.923 False -100000000.0 1 2 3 -1.000000e+08 \n", + "0 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", + "1 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", + "2 1 -0.799 True 0.0 2 1 3 -3.880000e+00 \n", + "3 0 -0.597 True 0.0 2 1 3 -3.880000e+00 \n", + "4 0 -0.597 True 0.0 2 0 2 -4.109000e+00 \n", + "5 0 -0.597 True 0.0 3 0 3 -4.525000e+00 \n", + "8 0 -0.597 False -100000000.0 0 0 0 -1.000000e+08 \n", + "9 1 -0.799 False -100000000.0 0 1 1 -1.000000e+08 \n", + "10 0 -0.597 False -100000000.0 0 1 1 -1.000000e+08 \n", + "11 1 -0.799 False -100000000.0 0 2 2 -1.000000e+08 \n", + "12 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "13 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "14 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", + "15 1 -0.799 False -100000000.0 1 2 3 -1.000000e+08 \n", + "16 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "17 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "18 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "19 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", "\n", "[18 rows x 22 columns]" ] }, - "execution_count": 8, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1154,7 +1096,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -1164,16 +1106,16 @@ "MAP estimate: 0.5\n", "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -1.197 1.0 1.0 \n", - "1 1.0 0.0 0 -1.197 1.0 0.0 \n", + "0 1.0 0.0 0 -0.922 1.0 1.0 \n", + "1 1.0 0.0 0 -0.922 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 1 -0.357 0 -0.506 ... 0 -1.204 \n", - "1 0 -1.204 0 -0.506 ... 1 -0.357 \n", + "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", - "0 0 -0.506 True 0.0 2 0 2 -3.771 \n", - "1 0 -0.506 True 0.0 2 0 2 -3.771 \n", + "0 0 -0.597 True 0.0 2 0 2 -3.678 \n", + "1 0 -0.597 True 0.0 2 0 2 -3.678 \n", "\n", "[2 rows x 22 columns]\n", "Secondary check: 0.5\n" @@ -1185,7 +1127,7 @@ "0.5" ] }, - "execution_count": 9, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -1196,7 +1138,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -1228,15 +1170,14 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.5347741090651572 .\n", - "ic: tensor([[[[[[[[[[[1., 1., 0., ..., 1., 1., 1.]]]]]]]]]]])\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.5337756999955469 , w: 0.5037726755473835 .\n" ] } ], @@ -1273,7 +1214,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1283,42 +1224,33 @@ "MAP estimate: 0.25\n", "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -1.197 1.0 1.0 \n", - "1 1.0 0.0 0 -1.197 1.0 0.0 \n", - "2 1.0 0.0 0 -1.197 1.0 0.0 \n", - "3 1.0 0.0 0 -1.197 1.0 0.0 \n", - "4 1.0 0.0 0 -1.197 1.0 1.0 \n", - "5 1.0 0.0 0 -1.197 1.0 0.0 \n", - "6 1.0 0.0 0 -1.197 1.0 1.0 \n", - "7 1.0 0.0 0 -1.197 1.0 1.0 \n", - "8 1.0 0.0 0 -1.197 1.0 1.0 \n", - "9 1.0 0.0 0 -1.197 1.0 0.0 \n", + "0 1.0 0.0 0 -0.763 1.0 0.0 \n", + "1 1.0 0.0 0 -0.763 1.0 1.0 \n", + "2 1.0 0.0 0 -0.763 1.0 1.0 \n", + "3 1.0 0.0 0 -0.763 1.0 0.0 \n", + "4 1.0 0.0 0 -0.763 1.0 1.0 \n", + "5 1.0 0.0 0 -0.763 1.0 0.0 \n", + "6 1.0 0.0 0 -0.763 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote6 alp_vote6 \\\n", - "0 1 -0.357 0 -0.626 ... 1 -0.357 \n", - "1 0 -1.204 0 -0.626 ... 0 -1.204 \n", - "2 0 -1.204 0 -0.626 ... 1 -0.357 \n", - "3 0 -1.204 0 -0.626 ... 1 -0.357 \n", - "4 1 -0.357 0 -0.626 ... 0 -1.204 \n", - "5 0 -1.204 0 -0.626 ... 0 -1.204 \n", - "6 1 -0.357 0 -0.626 ... 1 -0.357 \n", - "7 1 -0.357 0 -0.626 ... 1 -0.357 \n", - "8 1 -0.357 0 -0.626 ... 0 -1.204 \n", - "9 0 -1.204 0 -0.626 ... 1 -0.357 \n", + "0 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "2 1 -0.357 0 -0.686 ... 1 -0.357 \n", + "3 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "4 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "6 0 -1.204 0 -0.686 ... 1 -0.357 \n", "\n", " wpr_vote6 wlp_vote6 cdif clp int wpr changes sum_lp \n", - "0 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "1 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "2 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "3 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "4 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "5 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "6 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "7 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "8 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "9 0 -0.626 True 0.0 4 0 4 -9.635 \n", - "\n", - "[10 rows x 46 columns]\n", + "0 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "1 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "2 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "3 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "4 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "5 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "6 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "\n", + "[7 rows x 46 columns]\n", "Secondary check: 0.25\n" ] }, @@ -1328,7 +1260,7 @@ "0.25" ] }, - "execution_count": 12, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -1355,62 +1287,7 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'sally_throws': tensor(1.), 'bill_throws': tensor(1.), 'sally_hits': tensor(0.), 'bill_hits': tensor(1.), 'bottle_shatters': tensor(1)}\n" - ] - } - ], - "source": [ - "import pyro\n", - "import pyro.distributions as dist\n", - "import torch\n", - "\n", - "def stones_model2(): \n", - " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", - " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", - " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", - " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", - " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", - " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", - "\n", - " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", - " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", - " sally_hits = pyro.sample(\"sally_hits\", dist.Bernoulli(prob_sally_hits * sally_throws))\n", - " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(prob_bill_hits * bill_throws * (1 - sally_hits)))\n", - "\n", - " bottle_shatters = pyro.sample(\"bottle_shatters\", dist.Bernoulli(\n", - " prob_bottle_shatters_if_bill * bill_hits + prob_bottle_shatters_if_sally * sally_hits\n", - " )).long()\n", - "\n", - " return {\n", - " \"sally_throws\": sally_throws,\n", - " \"bill_throws\": bill_throws,\n", - " \"sally_hits\": sally_hits,\n", - " \"bill_hits\": bill_hits,\n", - " \"bottle_shatters\": bottle_shatters,\n", - " }\n", - "\n", - "stones_model2.nodes = [\n", - " \"sally_throws\",\n", - " \"bill_throws\",\n", - " \"sally_hits\",\n", - " \"bill_hits\",\n", - " \"bottle_shatters\",\n", - "]\n", - "\n", - "result = stones_model2()\n", - "print(result)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1469,15 +1346,14 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.567842770667782 .\n", - "ic: tensor([[[[[[[[1., 1., 1., 1., 1., 0., 1., 1., 1., 1.]]]]]]]])\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.5693194340229132 , w: 0.5214645870036448 .\n" ] } ], @@ -1497,258 +1373,13 @@ " \"prob_bottle_shatters_if_sally\": 1.0,\n", " \"prob_bottle_shatters_if_bill\": 1.0})\n", "\n", - "with pyro.plate(\"runs\",10):\n", + "with pyro.plate(\"runs\",10000):\n", " stones_sallyHPR()" ] }, { "cell_type": "code", - "execution_count": 42, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "odict_keys(['prob_sally_throws', 'prob_bill_throws', 'prob_sally_hits', 'prob_bill_hits', 'prob_bottle_shatters_if_sally', 'prob_bottle_shatters_if_bill', '__evaluated_split_sally_throws', 'sally_throws', '__witness_split_bill_throws', '__treatment_split_bill_throws', 'bill_throws', '__witness_split_sally_hits', '__treatment_split_sally_hits', 'sally_hits', '__witness_split_bill_hits', '__treatment_split_bill_hits', 'bill_hits', 'bottle_shatters', 'consequent_differs_binary', 'consequent_differs'])\n" - ] - } - ], - "source": [ - "print(stones_sallyHPR.trace.trace.nodes.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bh_split tensor([0, 1, 1, 0, 0, 0, 0, 1, 0, 0])\n" - ] - } - ], - "source": [ - "\n", - "print(\"bh_split\", stones_sallyHPR.trace.trace.nodes['__witness_split_bill_hits']['value'])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bh tensor([[[[[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 0., 1., 0., 1., 0., 0.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 1., 1., 0., 0., 0., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 0., 1., 0., 0., 1., 0., 1.]]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 1., 1., 0., 1., 1., 0.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 0.]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 1., 1., 0., 0., 1., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[0., 0., 0., 1., 1., 0., 0., 0., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[0., 0., 0., 1., 1., 0., 0., 1., 1., 1.]]]]]]]])\n" - ] - } - ], - "source": [ - "print(\"bh\",stones_sallyHPR.trace.trace.nodes['bill_hits']['value'])\n" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "sh tensor([[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 0., 1., 0., 1., 0., 1., 1.]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[1., 1., 1., 1., 0., 1., 1., 1., 1., 0.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 0., 0., 0., 1., 0., 1., 0.]]]]]]])\n" - ] - } - ], - "source": [ - "\n", - "print(\"sh\", stones_sallyHPR.trace.trace.nodes['sally_hits']['value'])\n", - "\n", - "#stones_sallyHPR.trace.trace.nodes['bottle_shatters']\n" - ] - }, - { - "cell_type": "code", - "execution_count": 59, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "con tensor([[[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 0., 1., 0., 1., 0., 1., 1.]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 1., 1., 0., 1., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 0., 1., 0., 1., 1., 1., 1.]]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 0., 1., 0., 1., 1.]]]]]],\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 1., 1., 0., 1., 1.]]]]],\n", - "\n", - "\n", - "\n", - "\n", - " [[[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]],\n", - "\n", - "\n", - "\n", - " [[[[1., 1., 1., 1., 1., 0., 1., 1., 1., 1.]]]]]]]])\n" - ] - } - ], - "source": [ - "print(\"con\", stones_sallyHPR.trace.trace.nodes['bottle_shatters']['value'])" - ] - }, - { - "cell_type": "code", - "execution_count": 60, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1792,112 +1423,992 @@ " changes\n", " sum_lp\n", " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " 0\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 1\n", + " 1\n", + " 2\n", + " -4.520000e+00\n", + " \n", + " \n", + " 1\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0.0\n", + " 1\n", + " 2\n", + " 3\n", + " -4.606000e+00\n", + " \n", + " \n", + " 2\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 2\n", + " 0\n", + " 2\n", + " -4.713000e+00\n", + " \n", + " \n", + " 3\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 2\n", + " 1\n", + " 3\n", + " -4.799000e+00\n", + " \n", + " \n", + " 4\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 2\n", + " 0\n", + " 2\n", + " -5.281000e+00\n", + " \n", + " \n", + " 5\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 2\n", + " 1\n", + " 3\n", + " -5.367000e+00\n", + " \n", + " \n", + " 9\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0.0\n", + " 3\n", + " 0\n", + " 3\n", + " -5.560000e+00\n", + " \n", + " \n", + " 13\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 0\n", + " 0\n", + " 0\n", + " -1.000000e+08\n", + " \n", + " \n", + " 14\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 0\n", + " 1\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 15\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 0\n", + " 1\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 16\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 0\n", + " 1\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 17\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 0\n", + " 2\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 18\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 0\n", + " 2\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 19\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 0\n", + " 2\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 20\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 0\n", + " 3\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 21\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 0\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 22\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 23\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 24\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 25\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 2\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 26\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 2\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 27\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 2\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 28\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 3\n", + " 4\n", + " -1.000000e+08\n", + " \n", + " \n", + " 29\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 0\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 30\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 0\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 31\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 0\n", + " 1\n", + " -1.000000e+08\n", + " \n", + " \n", + " 32\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 33\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 36\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 37\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 38\n", + " 1\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 1\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 40\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 1\n", + " 2\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 47\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 1\n", + " 2\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 51\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 2\n", + " 0\n", + " 2\n", + " -1.000000e+08\n", + " \n", " \n", - " 0\n", - " 1\n", - " -0.360\n", + " 52\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", " 1\n", " -0.357\n", + " 0\n", + " -0.651\n", " 1\n", - " -0.839\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 2\n", + " 0\n", + " 2\n", + " -1.000000e+08\n", + " \n", + " \n", + " 53\n", + " 0\n", + " -0.842\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " 0\n", " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", " -100000000.0\n", + " 2\n", + " 1\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 54\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", " 2\n", + " 1\n", " 3\n", " -1.000000e+08\n", " \n", " \n", - " 1\n", + " 57\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", + " 2\n", " 1\n", - " -0.360\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 59\n", + " 0\n", + " -0.842\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " -100000000.0\n", + " 2\n", + " 1\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 60\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", " 1\n", + " -0.737\n", " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", " 2\n", + " 1\n", + " 3\n", " -1.000000e+08\n", " \n", " \n", - " 2\n", + " 61\n", + " 0\n", + " -0.842\n", + " 1\n", + " -0.357\n", " 1\n", - " -0.360\n", + " -0.737\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 2\n", + " 2\n", + " 4\n", + " -1.000000e+08\n", + " \n", + " \n", + " 67\n", + " 0\n", + " -0.842\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", " 0\n", " -1.204\n", + " 0\n", + " -0.651\n", " 1\n", - " -0.839\n", - " -100000000.0\n", + " -0.357\n", " 1\n", - " 3\n", + " -0.737\n", + " -100000000.0\n", + " 2\n", + " 2\n", " 4\n", " -1.000000e+08\n", " \n", " \n", - " 3\n", + " 68\n", + " 0\n", + " -0.842\n", " 0\n", - " -1.197\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", + " -100000000.0\n", + " 2\n", + " 2\n", + " 4\n", + " -1.000000e+08\n", + " \n", + " \n", + " 73\n", + " 1\n", + " -0.563\n", + " 1\n", + " -0.357\n", + " 0\n", + " -0.651\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " 0\n", " -1.204\n", - " 1\n", - " -0.839\n", + " 0\n", + " -0.651\n", " -100000000.0\n", - " 3\n", " 2\n", - " 5\n", + " 0\n", + " 2\n", " -1.000000e+08\n", " \n", " \n", - " 4\n", + " 74\n", " 1\n", - " -0.360\n", + " -0.563\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " -100000000.0\n", " 2\n", " 0\n", @@ -1905,113 +2416,201 @@ " -1.000000e+08\n", " \n", " \n", - " 5\n", - " 0\n", - " -1.197\n", + " 75\n", " 1\n", - " -0.357\n", + " -0.563\n", " 0\n", - " -0.566\n", - " 1\n", - " -0.357\n", + " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", - " 0.0\n", - " 2\n", - " 0\n", + " -0.651\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.737\n", + " -100000000.0\n", " 2\n", - " -4.812000e+00\n", + " 1\n", + " 3\n", + " -1.000000e+08\n", " \n", " \n", - " 6\n", + " 77\n", " 1\n", - " -0.360\n", + " -0.563\n", " 1\n", " -0.357\n", + " 1\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", + " -100000000.0\n", + " 2\n", " 1\n", - " -0.357\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 92\n", " 0\n", - " -0.566\n", + " -0.842\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", - " -100000000.0\n", + " -0.651\n", " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", " 0\n", + " -0.651\n", + " -100000000.0\n", + " 3\n", " 0\n", + " 3\n", " -1.000000e+08\n", " \n", " \n", - " 7\n", + " 93\n", + " 0\n", + " -0.842\n", " 0\n", - " -1.197\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", + " -100000000.0\n", + " 3\n", + " 0\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 95\n", + " 0\n", + " -0.842\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", - " 1\n", - " -0.357\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " -100000000.0\n", + " 3\n", " 1\n", - " 1\n", - " 2\n", + " 4\n", " -1.000000e+08\n", " \n", " \n", - " 8\n", - " 1\n", - " -0.360\n", - " 1\n", - " -0.357\n", + " 96\n", " 0\n", - " -0.566\n", + " -0.842\n", " 0\n", " -1.204\n", " 0\n", - " -0.566\n", + " -0.651\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", " -100000000.0\n", + " 3\n", " 1\n", - " 1\n", - " 2\n", + " 4\n", " -1.000000e+08\n", " \n", " \n", - " 9\n", - " 1\n", - " -0.360\n", + " 99\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", " 0\n", " -1.204\n", - " 1\n", - " -0.839\n", - " 1\n", - " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " 1\n", " -0.357\n", - " 0\n", - " -0.566\n", + " 1\n", + " -0.737\n", " -100000000.0\n", + " 3\n", " 1\n", + " 4\n", + " -1.000000e+08\n", + " \n", + " \n", + " 114\n", " 1\n", - " 2\n", + " -0.563\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 3\n", + " 0\n", + " 3\n", + " -1.000000e+08\n", + " \n", + " \n", + " 120\n", + " 0\n", + " -0.842\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.651\n", + " -100000000.0\n", + " 4\n", + " 0\n", + " 4\n", " -1.000000e+08\n", " \n", " \n", @@ -2019,68 +2618,288 @@ "" ], "text/plain": [ - " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", - "0 1 -0.360 1 -0.357 \n", - "1 1 -0.360 1 -0.357 \n", - "2 1 -0.360 1 -0.357 \n", - "3 0 -1.197 1 -0.357 \n", - "4 1 -0.360 0 -1.204 \n", - "5 0 -1.197 1 -0.357 \n", - "6 1 -0.360 1 -0.357 \n", - "7 0 -1.197 1 -0.357 \n", - "8 1 -0.360 1 -0.357 \n", - "9 1 -0.360 0 -1.204 \n", + " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", + "0 1 -0.563 0 -1.204 \n", + "1 1 -0.563 0 -1.204 \n", + "2 0 -0.842 1 -0.357 \n", + "3 0 -0.842 1 -0.357 \n", + "4 1 -0.563 0 -1.204 \n", + "5 1 -0.563 0 -1.204 \n", + "9 0 -0.842 0 -1.204 \n", + "13 1 -0.563 1 -0.357 \n", + "14 1 -0.563 1 -0.357 \n", + "15 1 -0.563 1 -0.357 \n", + "16 1 -0.563 1 -0.357 \n", + "17 1 -0.563 1 -0.357 \n", + "18 1 -0.563 1 -0.357 \n", + "19 1 -0.563 1 -0.357 \n", + "20 1 -0.563 1 -0.357 \n", + "21 0 -0.842 1 -0.357 \n", + "22 0 -0.842 1 -0.357 \n", + "23 0 -0.842 1 -0.357 \n", + "24 0 -0.842 1 -0.357 \n", + "25 0 -0.842 1 -0.357 \n", + "26 0 -0.842 1 -0.357 \n", + "27 0 -0.842 1 -0.357 \n", + "28 0 -0.842 1 -0.357 \n", + "29 1 -0.563 1 -0.357 \n", + "30 1 -0.563 1 -0.357 \n", + "31 1 -0.563 0 -1.204 \n", + "32 1 -0.563 1 -0.357 \n", + "33 1 -0.563 1 -0.357 \n", + "36 1 -0.563 1 -0.357 \n", + "37 1 -0.563 1 -0.357 \n", + "38 1 -0.563 0 -1.204 \n", + "40 1 -0.563 1 -0.357 \n", + "47 1 -0.563 1 -0.357 \n", + "51 0 -0.842 1 -0.357 \n", + "52 0 -0.842 0 -1.204 \n", + "53 0 -0.842 1 -0.357 \n", + "54 0 -0.842 1 -0.357 \n", + "57 0 -0.842 0 -1.204 \n", + "59 0 -0.842 1 -0.357 \n", + "60 0 -0.842 0 -1.204 \n", + "61 0 -0.842 1 -0.357 \n", + "67 0 -0.842 1 -0.357 \n", + "68 0 -0.842 0 -1.204 \n", + "73 1 -0.563 1 -0.357 \n", + "74 1 -0.563 0 -1.204 \n", + "75 1 -0.563 0 -1.204 \n", + "77 1 -0.563 1 -0.357 \n", + "92 0 -0.842 1 -0.357 \n", + "93 0 -0.842 0 -1.204 \n", + "95 0 -0.842 1 -0.357 \n", + "96 0 -0.842 0 -1.204 \n", + "99 0 -0.842 0 -1.204 \n", + "114 1 -0.563 0 -1.204 \n", + "120 0 -0.842 0 -1.204 \n", "\n", - " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", - "0 1 -0.839 1 -0.357 \n", - "1 0 -0.566 1 -0.357 \n", - "2 1 -0.839 1 -0.357 \n", - "3 1 -0.839 0 -1.204 \n", - "4 0 -0.566 0 -1.204 \n", - "5 0 -0.566 1 -0.357 \n", - "6 0 -0.566 1 -0.357 \n", - "7 0 -0.566 1 -0.357 \n", - "8 0 -0.566 0 -1.204 \n", - "9 1 -0.839 1 -0.357 \n", + " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", + "0 0 -0.651 1 -0.357 \n", + "1 0 -0.651 1 -0.357 \n", + "2 0 -0.651 1 -0.357 \n", + "3 1 -0.737 1 -0.357 \n", + "4 0 -0.651 1 -0.357 \n", + "5 0 -0.651 1 -0.357 \n", + "9 0 -0.651 1 -0.357 \n", + "13 0 -0.651 1 -0.357 \n", + "14 0 -0.651 1 -0.357 \n", + "15 0 -0.651 1 -0.357 \n", + "16 1 -0.737 1 -0.357 \n", + "17 1 -0.737 1 -0.357 \n", + "18 0 -0.651 1 -0.357 \n", + "19 1 -0.737 1 -0.357 \n", + "20 1 -0.737 1 -0.357 \n", + "21 0 -0.651 1 -0.357 \n", + "22 1 -0.737 1 -0.357 \n", + "23 0 -0.651 1 -0.357 \n", + "24 0 -0.651 1 -0.357 \n", + "25 0 -0.651 1 -0.357 \n", + "26 1 -0.737 1 -0.357 \n", + "27 1 -0.737 1 -0.357 \n", + "28 1 -0.737 1 -0.357 \n", + "29 0 -0.651 0 -1.204 \n", + "30 0 -0.651 1 -0.357 \n", + "31 0 -0.651 1 -0.357 \n", + "32 0 -0.651 0 -1.204 \n", + "33 0 -0.651 1 -0.357 \n", + "36 1 -0.737 1 -0.357 \n", + "37 1 -0.737 0 -1.204 \n", + "38 0 -0.651 1 -0.357 \n", + "40 1 -0.737 0 -1.204 \n", + "47 1 -0.737 1 -0.357 \n", + "51 0 -0.651 0 -1.204 \n", + "52 0 -0.651 1 -0.357 \n", + "53 0 -0.651 0 -1.204 \n", + "54 1 -0.737 0 -1.204 \n", + "57 0 -0.651 1 -0.357 \n", + "59 0 -0.651 1 -0.357 \n", + "60 0 -0.651 1 -0.357 \n", + "61 1 -0.737 1 -0.357 \n", + "67 1 -0.737 0 -1.204 \n", + "68 0 -0.651 1 -0.357 \n", + "73 0 -0.651 0 -1.204 \n", + "74 0 -0.651 0 -1.204 \n", + "75 0 -0.651 0 -1.204 \n", + "77 1 -0.737 0 -1.204 \n", + "92 0 -0.651 0 -1.204 \n", + "93 0 -0.651 0 -1.204 \n", + "95 1 -0.737 0 -1.204 \n", + "96 0 -0.651 1 -0.357 \n", + "99 0 -0.651 0 -1.204 \n", + "114 0 -0.651 0 -1.204 \n", + "120 0 -0.651 0 -1.204 \n", "\n", - " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", - "0 0 -0.566 0 -1.204 \n", - "1 1 -0.839 0 -1.204 \n", - "2 1 -0.839 0 -1.204 \n", - "3 0 -0.566 0 -1.204 \n", - "4 0 -0.566 1 -0.357 \n", - "5 0 -0.566 0 -1.204 \n", - "6 0 -0.566 1 -0.357 \n", - "7 1 -0.839 1 -0.357 \n", - "8 0 -0.566 1 -0.357 \n", - "9 0 -0.566 1 -0.357 \n", + " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", + "0 1 -0.737 1 -0.357 \n", + "1 1 -0.737 1 -0.357 \n", + "2 0 -0.651 0 -1.204 \n", + "3 0 -0.651 0 -1.204 \n", + "4 0 -0.651 0 -1.204 \n", + "5 1 -0.737 0 -1.204 \n", + "9 0 -0.651 0 -1.204 \n", + "13 0 -0.651 1 -0.357 \n", + "14 1 -0.737 1 -0.357 \n", + "15 0 -0.651 1 -0.357 \n", + "16 0 -0.651 1 -0.357 \n", + "17 1 -0.737 1 -0.357 \n", + "18 1 -0.737 1 -0.357 \n", + "19 0 -0.651 1 -0.357 \n", + "20 1 -0.737 1 -0.357 \n", + "21 0 -0.651 1 -0.357 \n", + "22 0 -0.651 1 -0.357 \n", + "23 1 -0.737 1 -0.357 \n", + "24 0 -0.651 1 -0.357 \n", + "25 1 -0.737 1 -0.357 \n", + "26 1 -0.737 1 -0.357 \n", + "27 0 -0.651 1 -0.357 \n", + "28 1 -0.737 1 -0.357 \n", + "29 0 -0.651 1 -0.357 \n", + "30 0 -0.651 0 -1.204 \n", + "31 0 -0.651 1 -0.357 \n", + "32 0 -0.651 1 -0.357 \n", + "33 1 -0.737 0 -1.204 \n", + "36 0 -0.651 0 -1.204 \n", + "37 0 -0.651 1 -0.357 \n", + "38 0 -0.651 1 -0.357 \n", + "40 0 -0.651 1 -0.357 \n", + "47 1 -0.737 0 -1.204 \n", + "51 0 -0.651 1 -0.357 \n", + "52 0 -0.651 1 -0.357 \n", + "53 0 -0.651 1 -0.357 \n", + "54 0 -0.651 1 -0.357 \n", + "57 0 -0.651 1 -0.357 \n", + "59 1 -0.737 0 -1.204 \n", + "60 1 -0.737 1 -0.357 \n", + "61 1 -0.737 0 -1.204 \n", + "67 0 -0.651 1 -0.357 \n", + "68 1 -0.737 1 -0.357 \n", + "73 0 -0.651 0 -1.204 \n", + "74 0 -0.651 1 -0.357 \n", + "75 0 -0.651 1 -0.357 \n", + "77 0 -0.651 0 -1.204 \n", + "92 0 -0.651 0 -1.204 \n", + "93 0 -0.651 1 -0.357 \n", + "95 0 -0.651 0 -1.204 \n", + "96 1 -0.737 0 -1.204 \n", + "99 0 -0.651 1 -0.357 \n", + "114 0 -0.651 0 -1.204 \n", + "120 0 -0.651 0 -1.204 \n", "\n", - " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", - "0 1 -0.839 -100000000.0 1 2 3 \n", - "1 0 -0.566 -100000000.0 1 1 2 \n", - "2 1 -0.839 -100000000.0 1 3 4 \n", - "3 1 -0.839 -100000000.0 3 2 5 \n", - "4 0 -0.566 -100000000.0 2 0 2 \n", - "5 0 -0.566 0.0 2 0 2 \n", - "6 0 -0.566 -100000000.0 0 0 0 \n", - "7 0 -0.566 -100000000.0 1 1 2 \n", - "8 1 -0.839 -100000000.0 1 1 2 \n", - "9 0 -0.566 -100000000.0 1 1 2 \n", + " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", + "0 0 -0.651 0.0 1 1 2 \n", + "1 1 -0.737 0.0 1 2 3 \n", + "2 0 -0.651 0.0 2 0 2 \n", + "3 0 -0.651 0.0 2 1 3 \n", + "4 0 -0.651 0.0 2 0 2 \n", + "5 0 -0.651 0.0 2 1 3 \n", + "9 0 -0.651 0.0 3 0 3 \n", + "13 0 -0.651 -100000000.0 0 0 0 \n", + "14 0 -0.651 -100000000.0 0 1 1 \n", + "15 1 -0.737 -100000000.0 0 1 1 \n", + "16 0 -0.651 -100000000.0 0 1 1 \n", + "17 0 -0.651 -100000000.0 0 2 2 \n", + "18 1 -0.737 -100000000.0 0 2 2 \n", + "19 1 -0.737 -100000000.0 0 2 2 \n", + "20 1 -0.737 -100000000.0 0 3 3 \n", + "21 0 -0.651 -100000000.0 1 0 1 \n", + "22 0 -0.651 -100000000.0 1 1 2 \n", + "23 0 -0.651 -100000000.0 1 1 2 \n", + "24 1 -0.737 -100000000.0 1 1 2 \n", + "25 1 -0.737 -100000000.0 1 2 3 \n", + "26 0 -0.651 -100000000.0 1 2 3 \n", + "27 1 -0.737 -100000000.0 1 2 3 \n", + "28 1 -0.737 -100000000.0 1 3 4 \n", + "29 0 -0.651 -100000000.0 1 0 1 \n", + "30 0 -0.651 -100000000.0 1 0 1 \n", + "31 0 -0.651 -100000000.0 1 0 1 \n", + "32 1 -0.737 -100000000.0 1 1 2 \n", + "33 0 -0.651 -100000000.0 1 1 2 \n", + "36 0 -0.651 -100000000.0 1 1 2 \n", + "37 0 -0.651 -100000000.0 1 1 2 \n", + "38 1 -0.737 -100000000.0 1 1 2 \n", + "40 1 -0.737 -100000000.0 1 2 3 \n", + "47 0 -0.651 -100000000.0 1 2 3 \n", + "51 0 -0.651 -100000000.0 2 0 2 \n", + "52 0 -0.651 -100000000.0 2 0 2 \n", + "53 1 -0.737 -100000000.0 2 1 3 \n", + "54 0 -0.651 -100000000.0 2 1 3 \n", + "57 1 -0.737 -100000000.0 2 1 3 \n", + "59 0 -0.651 -100000000.0 2 1 3 \n", + "60 0 -0.651 -100000000.0 2 1 3 \n", + "61 0 -0.651 -100000000.0 2 2 4 \n", + "67 1 -0.737 -100000000.0 2 2 4 \n", + "68 1 -0.737 -100000000.0 2 2 4 \n", + "73 0 -0.651 -100000000.0 2 0 2 \n", + "74 0 -0.651 -100000000.0 2 0 2 \n", + "75 1 -0.737 -100000000.0 2 1 3 \n", + "77 0 -0.651 -100000000.0 2 1 3 \n", + "92 0 -0.651 -100000000.0 3 0 3 \n", + "93 0 -0.651 -100000000.0 3 0 3 \n", + "95 0 -0.651 -100000000.0 3 1 4 \n", + "96 0 -0.651 -100000000.0 3 1 4 \n", + "99 1 -0.737 -100000000.0 3 1 4 \n", + "114 0 -0.651 -100000000.0 3 0 3 \n", + "120 0 -0.651 -100000000.0 4 0 4 \n", "\n", - " sum_lp \n", - "0 -1.000000e+08 \n", - "1 -1.000000e+08 \n", - "2 -1.000000e+08 \n", - "3 -1.000000e+08 \n", - "4 -1.000000e+08 \n", - "5 -4.812000e+00 \n", - "6 -1.000000e+08 \n", - "7 -1.000000e+08 \n", - "8 -1.000000e+08 \n", - "9 -1.000000e+08 " + " sum_lp \n", + "0 -4.520000e+00 \n", + "1 -4.606000e+00 \n", + "2 -4.713000e+00 \n", + "3 -4.799000e+00 \n", + "4 -5.281000e+00 \n", + "5 -5.367000e+00 \n", + "9 -5.560000e+00 \n", + "13 -1.000000e+08 \n", + "14 -1.000000e+08 \n", + "15 -1.000000e+08 \n", + "16 -1.000000e+08 \n", + "17 -1.000000e+08 \n", + "18 -1.000000e+08 \n", + "19 -1.000000e+08 \n", + "20 -1.000000e+08 \n", + "21 -1.000000e+08 \n", + "22 -1.000000e+08 \n", + "23 -1.000000e+08 \n", + "24 -1.000000e+08 \n", + "25 -1.000000e+08 \n", + "26 -1.000000e+08 \n", + "27 -1.000000e+08 \n", + "28 -1.000000e+08 \n", + "29 -1.000000e+08 \n", + "30 -1.000000e+08 \n", + "31 -1.000000e+08 \n", + "32 -1.000000e+08 \n", + "33 -1.000000e+08 \n", + "36 -1.000000e+08 \n", + "37 -1.000000e+08 \n", + "38 -1.000000e+08 \n", + "40 -1.000000e+08 \n", + "47 -1.000000e+08 \n", + "51 -1.000000e+08 \n", + "52 -1.000000e+08 \n", + "53 -1.000000e+08 \n", + "54 -1.000000e+08 \n", + "57 -1.000000e+08 \n", + "59 -1.000000e+08 \n", + "60 -1.000000e+08 \n", + "61 -1.000000e+08 \n", + "67 -1.000000e+08 \n", + "68 -1.000000e+08 \n", + "73 -1.000000e+08 \n", + "74 -1.000000e+08 \n", + "75 -1.000000e+08 \n", + "77 -1.000000e+08 \n", + "92 -1.000000e+08 \n", + "93 -1.000000e+08 \n", + "95 -1.000000e+08 \n", + "96 -1.000000e+08 \n", + "99 -1.000000e+08 \n", + "114 -1.000000e+08 \n", + "120 -1.000000e+08 " ] }, - "execution_count": 60, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -2132,12 +2951,12 @@ " values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", "\n", " values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", - " #values_df.drop_duplicates(inplace = True)\n", - " #values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + " values_df.drop_duplicates(inplace = True)\n", + " values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", "\n", " tab = values_df.reset_index(drop = True)\n", "\n", - " #tab = remove_redundant_rows(tab)\n", + " tab = remove_redundant_rows(tab)\n", " \n", " if round:\n", " tab = tab.round(3)\n", @@ -2153,7 +2972,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -2200,21 +3019,21 @@ " \n", " \n", " \n", - " 17\n", + " 23\n", " 0\n", - " -1.197\n", + " -0.842\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " 1\n", " -0.357\n", " 1\n", - " -0.839\n", + " -0.737\n", " 1\n", " -0.357\n", " 0\n", - " -0.566\n", + " -0.651\n", " -100000000.0\n", " 1\n", " 1\n", @@ -2227,28 +3046,28 @@ ], "text/plain": [ " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", - "17 0 -1.197 1 -0.357 \n", + "23 0 -0.842 1 -0.357 \n", "\n", " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", - "17 0 -0.566 1 -0.357 \n", + "23 0 -0.651 1 -0.357 \n", "\n", " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", - "17 1 -0.839 1 -0.357 \n", + "23 1 -0.737 1 -0.357 \n", "\n", " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", - "17 0 -0.566 -100000000.0 1 1 2 \n", + "23 0 -0.651 -100000000.0 1 1 2 \n", "\n", " sum_lp \n", - "17 -1.000000e+08 " + "23 -1.000000e+08 " ] }, - "execution_count": 29, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "#this is worrying\n", + "#this is worrying, should be on top with clp == 0\n", "\n", "tab.query(\"epr_sally_throws == 0 & apr_sally_hits == 1 & wpr_sally_hits == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1 & apr_bill_throws == 1 & wpr_bill_throws == 0\")" ] From d1412cbdddb8b0bace383b5a07405f2f6c93831a Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Thu, 24 Aug 2023 09:50:08 +0200 Subject: [PATCH 09/13] expanded responsibility exploration --- docs/source/responsibility.ipynb | 54 -------------------------------- 1 file changed, 54 deletions(-) diff --git a/docs/source/responsibility.ipynb b/docs/source/responsibility.ipynb index 370183db..5ddb636c 100644 --- a/docs/source/responsibility.ipynb +++ b/docs/source/responsibility.ipynb @@ -3215,60 +3215,6 @@ " stones_sallyHPR.witness_candidates)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'type': 'sample', 'name': 'prob_bill_hits', 'fn': Beta(), 'is_observed': True, 'args': (), 'kwargs': {}, 'value': tensor(1.), 'infer': {'_do_not_observe': True}, 'scale': 1.0, 'mask': None, 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=5, counter=0),), 'done': True, 'stop': False, 'continuation': None}\n" - ] - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", - " 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],\n", - " 'epr_sally_throws': [1, 1, 1, 1, 0],\n", - " 'elp_sally_throws': [-0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -1.1973283290863037],\n", - " 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", - " 'int_bill_throws': [1.0, 1.0, 1.0, 0.0, 1.0],\n", - " 'apr_bill_throws': [1, 1, 1, 0, 1],\n", - " 'alp_bill_throws': [-0.3566749691963196,\n", - " -0.3566749691963196,\n", - " -0.3566749691963196,\n", - " -1.2039728164672852,\n", - " -0.3566749691963196],\n", - " 'wpr_bill_throws': [0, 0, 1, 1, 1],\n", - " 'wlp_bill_throws': [-0.5659106969833374,\n", - " -0.5659106969833374,\n", - " -0.8389658331871033,\n", - " -0.8389658331871033,\n", - " -0.8389658331871033]}" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, { "cell_type": "code", "execution_count": 63, From 726c9533d880be36f2f9ff6dc93fa5d9746d2075 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Thu, 24 Aug 2023 09:51:36 +0200 Subject: [PATCH 10/13] responsibility wip --- docs/source/.~lock.smalltab.csv# | 1 - 1 file changed, 1 deletion(-) delete mode 100644 docs/source/.~lock.smalltab.csv# diff --git a/docs/source/.~lock.smalltab.csv# b/docs/source/.~lock.smalltab.csv# deleted file mode 100644 index 855da661..00000000 --- a/docs/source/.~lock.smalltab.csv# +++ /dev/null @@ -1 +0,0 @@ -,rafal,pop-os,22.08.2023 11:23,file:///home/rafal/.config/libreoffice/4; \ No newline at end of file From 1d25a0e578509379cbdbdc3b0c0f5f8b3477ed75 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Thu, 24 Aug 2023 09:52:39 +0200 Subject: [PATCH 11/13] responsibility wip --- docs/source/responsibility_exploration.ipynb | 3479 ++++++++++++++++++ 1 file changed, 3479 insertions(+) create mode 100644 docs/source/responsibility_exploration.ipynb diff --git a/docs/source/responsibility_exploration.ipynb b/docs/source/responsibility_exploration.ipynb new file mode 100644 index 00000000..5ddb636c --- /dev/null +++ b/docs/source/responsibility_exploration.ipynb @@ -0,0 +1,3479 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Responsibility and actual causality" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Preceding notebook**\n", + "\n", + "- [Actual Causality: the modified Halpern-Pearl definition]() TODO add link\n", + "\n", + "\n", + "**Summary**\n", + "\n", + "In a previous notebook, we introduced and implemented the Halpern-Pearl modified definition of actual causality. Here we implement the way Halpern used this notion to introduce his so-called *naive definition of responsibility*. We also briefly illustrate some reasons to think a somewhat more sophisticated notion is needed." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Outline**\n", + "\n", + "[Intuitions](##intuitions)\n", + " \n", + "[Formalization](#formalization)\n", + "\n", + "[Implementation](#implementation)\n", + "\n", + "[Examples](#examples)\n", + "\n", + "- [Comments on example selection](#comments-on-example-selection)\n", + " \n", + "- [Voting](#voting)\n", + "\n", + "- [Stone-throwing](#stone-throwing)\n", + "\n", + "- [Firing squad](#firing-squad)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Intuitions" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The key idea here is that your responsibility for an outcome is to be measured in terms of how drastic a change would have to be made to the world for the outcome to depend counterfactually on your actions. However, the definition uses a fairly crude measure thereof, the minimal *number* of changes needed, where those numbers are individuated in terms of nodes. On one hand, if you are part of a cause, we count how many elements the cause has. On the other, we count the number of nodes that a witness set has. We add these two numbers for any combination of an actual cause and a witness set and we take the minimum, say $k$. Your responsibility is then $1/k$. " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Formalization" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The degree of responsibility of $X = x$ for $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ is 0 if $X = x$ is not part of an actual cause of $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ according\n", + "to the modified HP definition. It is $1/k$ if there exists an actual cause $\\vec{X} = \\vec{x}$ of $\\varphi$ and a witness $\\vec{W}$ to $\\vec{X}=\\vec{x}$ being a cause of $\\varphi$ in $\\langle M, \\vec{u}\\rangle$ such that \n", + "(a) $X=x$ is a conjunct in $\\vec{X}= \\vec{x}$, (b) $\\vert \\vec{W}\\vert + \\vert\\vec{X}\\vert = k$, and (c) $k$ is minimal such a number.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "import numpy as np\n", + "from itertools import combinations\n", + "import math\n", + "\n", + "import torch\n", + "from typing import Dict, List, Optional, Union, Callable, Any\n", + "\n", + "import pandas as pd\n", + "\n", + "import pyro\n", + "import pyro.distributions as dist\n", + "\n", + "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", + "from chirho.interventional.handlers import do\n", + "from chirho.counterfactual.ops import preempt, intervene\n", + "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions\n", + "from chirho.observational.handlers import condition\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "class BiasedPreemptions(Preemptions):\n", + " \"\"\"\n", + " Counterfactual handler that preempts the model with a biased coin flip.\n", + " \"\"\"\n", + " def __init__(self, actions, weights: torch.Tensor, event_dim: int = 0, prefix: str = \"__split_\") -> None:\n", + " self.weights = weights\n", + " self.event_dim = event_dim\n", + " self.prefix = prefix\n", + " super().__init__(actions)\n", + "\n", + "\n", + " def _pyro_preempt(self,msg: Dict[str, Any]) -> None:\n", + " if msg[\"name\"] not in self.actions:\n", + " return \n", + "\n", + " obs, acts, case = msg[\"args\"]\n", + " msg[\"kwargs\"][\"name\"] = f\"{self.prefix}{msg['name']}\"\n", + " case_dist = pyro.distributions.Categorical(self.weights)\n", + " #print(msg[\"kwargs\"][\"name\"] , self.prefix, msg['name'], self.weights)\n", + " case = pyro.sample(msg[\"kwargs\"][\"name\"], case_dist, obs=case)\n", + " msg[\"args\"] = (obs, acts, case)\n", + " msg[\"stop\"] = True\n", + "\n", + " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", + " with pyro.poutine.messenger.block_messengers(\n", + " lambda m : (isinstance(m, Preemptions) and (m is not self))\n", + " ):\n", + " super()._pyro_post_sample(msg) " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class HalpernPearlResponsibilityApproximate:\n", + "\n", + " def __init__(\n", + " self, \n", + " model: Callable,\n", + " evaluated_node_counterfactual: Dict[str, torch.Tensor],\n", + " treatment_candidates: Dict[str, torch.Tensor],\n", + " witness_candidates: List[str],\n", + " outcome: str,\n", + " observations: Optional[Dict[str, torch.Tensor]] = None,\n", + " bias_n: float = .2\n", + " ):\n", + " \n", + " if observations is None:\n", + " observations = {}\n", + "\n", + " #if not set(witness_candidates) <= set(treatment_candidates.keys()):\n", + " # raise ValueError(\"witness_candidates must be a subset of treatment_candidates.keys().\")\n", + " \n", + " self.model = model\n", + " self.evaluated_node_counterfactual = evaluated_node_counterfactual\n", + " self.treatment_candidates = treatment_candidates\n", + " self.witness_candidates = witness_candidates\n", + " self.outcome = outcome\n", + " self.observations = observations\n", + " self.bias_t = .2\n", + " self.bias_n = self.find_max_bias_within(self.bias_t, len(self.treatment_candidates))\n", + " self.bias_w = self.find_max_bias_within(self.bias_n, len(self.witness_candidates))\n", + "\n", + " self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,\n", + " antecedents = [node]) for\n", + " node in self.evaluated_node_counterfactual.keys()}\n", + "\n", + " self.treatment_preemptions = {antecedent: functools.partial(self.preempt_with_factual,\n", + " antecedents = [antecedent]) for\n", + " antecedent in self.treatment_candidates.keys()}\n", + " \n", + " self.witness_preemptions = {candidate: functools.partial(self.preempt_with_factual,\n", + " antecedents = self.treatment_candidates) for \n", + " candidate in self.witness_candidates}\n", + " \n", + " @staticmethod\n", + " def find_max_bias_within(e: float, n: int,\n", + " max_iterations: int = 1000, learning_rate: float = 0.002):\n", + " \n", + " ediff = math.log(0.5 + e) - math.log(0.5 - e)\n", + " #print(\"up\", math.log(0.5 + e), \"down\", math.log(0.5 - e), \"ediff\", ediff)\n", + "\n", + " w = e\n", + " wdiff = math.log(0.5 + w) - math.log(0.5 - w)\n", + "\n", + " iteration = 0 \n", + " while iteration < max_iterations and ediff <= n * wdiff:\n", + " \n", + " distance = n * wdiff / ediff\n", + " assert w - learning_rate * distance >0 , \"The learning rate is too high.\"\n", + " w -= learning_rate * distance\n", + " \n", + " wdiff = math.log(0.5 + w) - math.log(0.5 - w)\n", + " #print(\"up\", math.log(0.5 + w), \"down\", math.log(0.5 - w), \"wdiff\", wdiff, \"nwdiff\", n * wdiff)\n", + "\n", + " iteration += 1\n", + " \n", + " return w\n", + "\n", + " @staticmethod \n", + " def preempt_with_factual(value: torch.Tensor, *,\n", + " antecedents: List[str] = None, event_dim: int = 0):\n", + " \n", + " if antecedents is None:\n", + " antecedents = []\n", + "\n", + " antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]\n", + "\n", + " factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),\n", + " event_dim=event_dim)\n", + " \n", + " return scatter({\n", + " IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,\n", + " IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,\n", + " }, event_dim=event_dim)\n", + " \n", + " \n", + " def __call__(self, *args, **kwargs):\n", + " print(\"Preemption biases used (upper) - t:\",.5+ self.bias_t, \", n:\", .5 + self.bias_n, \", w:\", .5 + self.bias_w, \".\")\n", + " with MultiWorldCounterfactual():\n", + " with do(actions=self.evaluated_node_counterfactual):\n", + " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", + " prefix = \"__evaluated_split_\"):\n", + " with do(actions=self.treatment_candidates):\n", + " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", + " prefix = \"__treatment_split_\"):\n", + " # the last element is the fixed at the observed value (preempted) \n", + " # the last element of the tensor is the factual case (preempted)\n", + " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", + " prefix = \"__witness_split_\"):\n", + "\n", + " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", + " with pyro.poutine.trace() as self.trace:\n", + " self.run = self.model(*args, **kwargs)\n", + " self.consequent = self.run[self.outcome]\n", + " self.interventionIndex = { intervention: {1} for intervention \n", + " in list(self.evaluated_node_counterfactual.keys()) + \n", + " list(self.treatment_candidates.keys()) + self.witness_candidates}\n", + " \n", + " self.observedIndex = {node: {0} for node in list(self.evaluated_node_counterfactual.keys()) + \n", + " list(self.treatment_candidates.keys()) + self.witness_candidates}\n", + "\n", + " \n", + " self.intervened_consequent = gather(self.consequent, IndexSet(**self.interventionIndex))\n", + " \n", + " self.observed_consequent = gather(self.consequent, IndexSet(**self.observedIndex))\n", + " self.consequent_differs = self.intervened_consequent != self.observed_consequent \n", + " pyro.deterministic(\"consequent_differs_binary\", self.consequent_differs, event_dim = 0) #feels inelegant\n", + " pyro.factor(\"consequent_differs\", torch.where(self.consequent_differs, torch.tensor(0.0), torch.tensor(-1e8)))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# only needed for ease of exposition,\n", + "# not for the inference itself\n", + "\n", + "def remove_redundant_rows(tab):\n", + " existing_pairs = []\n", + "\n", + " for col in tab.columns:\n", + " if col[0:4] == \"apr_\":\n", + " ending = col.split(\"apr_\")[1]\n", + " wpr_col = f\"wpr_{ending}\"\n", + " if wpr_col in tab.columns:\n", + " existing_pairs.append((col,wpr_col))\n", + "\n", + " keep = []\n", + " for index, row in tab.iterrows():\n", + " \n", + " flag = True\n", + " for pair in existing_pairs:\n", + " apr_col = pair[0]\n", + " wpr_col = pair[1]\n", + " apr_value = row[apr_col]\n", + " wpr_value = row[wpr_col]\n", + " \n", + " if apr_value == 0 and wpr_value == 1:\n", + " flag = False\n", + " break\n", + " keep.append(flag)\n", + " \n", + " return(tab[keep])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# this will explore the trace once we run inference on the model\n", + "\n", + "def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", + " \n", + " values_table = {}\n", + "\n", + " values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", + " for antecedent in antecedents:\n", + " values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + "\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " \n", + " values_table['cdif'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['clp'], float):\n", + " values_df = pd.DataFrame([values_table])\n", + " else:\n", + " values_df = pd.DataFrame(values_table)\n", + " \n", + " values_df = pd.DataFrame(values_table)\n", + "\n", + " summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + " \n", + " \n", + " values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + "\n", + " values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", + " values_df.drop_duplicates(inplace = True)\n", + " values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + "\n", + " tab = values_df.reset_index(drop = True)\n", + "\n", + " tab = remove_redundant_rows(tab)\n", + "\n", + " #tab = values_table\n", + "\n", + " if round:\n", + " tab = tab.round(3)\n", + "\n", + " return tab\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def responsibility_check(hpr):\n", + "\n", + " evaluated_node = list(hpr.evaluated_node_counterfactual.keys())[0]\n", + " tab = get_table(hpr.trace.trace.nodes,\n", + " evaluated_node ,\n", + " list(hpr.treatment_candidates.keys()), \n", + " hpr.witness_candidates)\n", + " \n", + " max_sum_lp = tab['sum_lp'].max()\n", + " max_sum_lp_rows = tab[tab['sum_lp'] == max_sum_lp]\n", + "\n", + " map_estimate = 1/ tab['changes'][0]\n", + "\n", + " print (f\"MAP estimate: {map_estimate}\")\n", + "\n", + " # sanity check; consider removing later\n", + " min_changes = max_sum_lp_rows['changes'].min()\n", + " min_changes_row = max_sum_lp_rows[max_sum_lp_rows['changes'] == min_changes]\n", + "\n", + " print(\"Minimal scenarios:\")\n", + " print(min_changes_row)\n", + "\n", + " if not (min_changes_row[f'int_{evaluated_node}'] == 0).any():\n", + " print (f\"No MAP estimate includes intervention on int_{evaluated_node} == 0\")\n", + " return 0\n", + " \n", + " min_changes_row = min_changes_row[min_changes_row[f'int_{evaluated_node}'] == 0]\n", + "\n", + " secondary_check = 1/min_changes_row['changes'].min()\n", + "\n", + " print (f\"Secondary check: {secondary_check}\")\n", + "\n", + " assert map_estimate == secondary_check, \"MAP estimate does not match secondary check, increase sample size.\" \n", + "\n", + " return map_estimate" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#TODO THIS NEEDS TO BE UPDATED \n", + "This implementation is now used within another class definition, where, again, the main moves are in `def __call__`. We sample antecedent sets, leave other nodes (aside from the outcome) as witness candidates, and pass the result to an actual causality evaluation, keeping track of minimal antecedent sets and the corresponding witness sizes. Then we find a minimum of the sum and use it in the denominator.`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comments on example selection\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- **Voting:** the example illustrates that parts of actual causes can share various degrees of responsibility for the outcome, without being actual causes.\n", + "\n", + "- **Stone-throwing:** responsibility calculations in one of the main running examples in the *Actual Causality* book by Halpern (2016).\n", + "\n", + "- **Firing squad:** an example in which responsibility and actual causality agree, where-as disussed in the notebook on the notion of blame-the notion of responsibility and blame will diverge." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Voting" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We discussed a similar model in a previous notebook. This time we have eight voters involved in a binary majority voting procedure and we investigate the responsibility assigned to voter 0. The situation is analogous to the one discussed in the actual causality notebook: if your vote is decisive, you are an actual cause, and you're not an actual cause otherwise. What's your responsibility, though? " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# let's start with a minimal interesting example\n", + "# you are one of three voters\n", + "\n", + "def voting_model():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + "\n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 >1\n", + " )\n", + " return {\"outcome\": outcome.float()}\n", + "\n", + "observations = dict(u_vote0=1., u_vote1=1., u_vote2=1.\n", + " )\n", + "\n", + "treatment_candidates = {key[2:]: 1-v for key, v in observations.items() if key != \"u_vote0\"}\n", + "evaluated_node_counterfactual = {\"vote0\": 1 - observations[\"u_vote0\"]}\n", + "\n", + "votingHPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model,\n", + " evaluated_node_counterfactual = evaluated_node_counterfactual,\n", + " treatment_candidates = treatment_candidates,\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,3)],\n", + " observations = observations)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.6024412643276109 , w: 0.5502509213795265 .\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
obs_vote0int_vote0epr_vote0elp_vote0obs_vote1int_vote1apr_vote1alp_vote1wpr_vote1wlp_vote1...apr_vote2alp_vote2wpr_vote2wlp_vote2cdifclpintwprchangessum_lp
01.00.00-0.9221.01.01-0.3570-0.597...0-1.2040-0.597True0.0202-3.678000e+00
11.00.00-0.9221.00.00-1.2040-0.597...1-0.3570-0.597True0.0202-3.678000e+00
21.00.00-0.9221.00.00-1.2040-0.597...1-0.3571-0.799True0.0213-3.880000e+00
31.00.00-0.9221.01.01-0.3571-0.799...0-1.2040-0.597True0.0213-3.880000e+00
41.01.01-0.5071.00.00-1.2040-0.597...0-1.2040-0.597True0.0202-4.109000e+00
51.00.00-0.9221.00.00-1.2040-0.597...0-1.2040-0.597True0.0303-4.525000e+00
81.01.01-0.5071.01.01-0.3570-0.597...1-0.3570-0.597False-100000000.0000-1.000000e+08
91.01.01-0.5071.01.01-0.3570-0.597...1-0.3571-0.799False-100000000.0011-1.000000e+08
101.01.01-0.5071.01.01-0.3571-0.799...1-0.3570-0.597False-100000000.0011-1.000000e+08
111.01.01-0.5071.01.01-0.3571-0.799...1-0.3571-0.799False-100000000.0022-1.000000e+08
121.00.00-0.9221.01.01-0.3570-0.597...1-0.3570-0.597False-100000000.0101-1.000000e+08
131.00.00-0.9221.01.01-0.3571-0.799...1-0.3570-0.597False-100000000.0112-1.000000e+08
141.00.00-0.9221.01.01-0.3570-0.597...1-0.3571-0.799False-100000000.0112-1.000000e+08
151.00.00-0.9221.01.01-0.3571-0.799...1-0.3571-0.799False-100000000.0123-1.000000e+08
161.01.01-0.5071.01.01-0.3570-0.597...0-1.2040-0.597False-100000000.0101-1.000000e+08
171.01.01-0.5071.00.00-1.2040-0.597...1-0.3570-0.597False-100000000.0101-1.000000e+08
181.01.01-0.5071.01.01-0.3571-0.799...0-1.2040-0.597False-100000000.0112-1.000000e+08
191.01.01-0.5071.00.00-1.2040-0.597...1-0.3571-0.799False-100000000.0112-1.000000e+08
\n", + "

18 rows × 22 columns

\n", + "
" + ], + "text/plain": [ + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -0.922 1.0 1.0 \n", + "1 1.0 0.0 0 -0.922 1.0 0.0 \n", + "2 1.0 0.0 0 -0.922 1.0 0.0 \n", + "3 1.0 0.0 0 -0.922 1.0 1.0 \n", + "4 1.0 1.0 1 -0.507 1.0 0.0 \n", + "5 1.0 0.0 0 -0.922 1.0 0.0 \n", + "8 1.0 1.0 1 -0.507 1.0 1.0 \n", + "9 1.0 1.0 1 -0.507 1.0 1.0 \n", + "10 1.0 1.0 1 -0.507 1.0 1.0 \n", + "11 1.0 1.0 1 -0.507 1.0 1.0 \n", + "12 1.0 0.0 0 -0.922 1.0 1.0 \n", + "13 1.0 0.0 0 -0.922 1.0 1.0 \n", + "14 1.0 0.0 0 -0.922 1.0 1.0 \n", + "15 1.0 0.0 0 -0.922 1.0 1.0 \n", + "16 1.0 1.0 1 -0.507 1.0 1.0 \n", + "17 1.0 1.0 1 -0.507 1.0 0.0 \n", + "18 1.0 1.0 1 -0.507 1.0 1.0 \n", + "19 1.0 1.0 1 -0.507 1.0 0.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", + "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "2 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "3 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "4 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "8 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "9 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "10 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "11 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "12 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "13 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "14 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "15 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "16 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "17 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "18 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "19 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "\n", + " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", + "0 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", + "1 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", + "2 1 -0.799 True 0.0 2 1 3 -3.880000e+00 \n", + "3 0 -0.597 True 0.0 2 1 3 -3.880000e+00 \n", + "4 0 -0.597 True 0.0 2 0 2 -4.109000e+00 \n", + "5 0 -0.597 True 0.0 3 0 3 -4.525000e+00 \n", + "8 0 -0.597 False -100000000.0 0 0 0 -1.000000e+08 \n", + "9 1 -0.799 False -100000000.0 0 1 1 -1.000000e+08 \n", + "10 0 -0.597 False -100000000.0 0 1 1 -1.000000e+08 \n", + "11 1 -0.799 False -100000000.0 0 2 2 -1.000000e+08 \n", + "12 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "13 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "14 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", + "15 1 -0.799 False -100000000.0 1 2 3 -1.000000e+08 \n", + "16 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "17 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "18 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "19 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", + "\n", + "[18 rows x 22 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now inference, and let's inspect the table while it's small\n", + "\n", + "with pyro.plate(\"runs\", 1000):\n", + " votingHPR()\n", + "\n", + "vtr = votingHPR.trace.trace.nodes\n", + "\n", + "get_table(vtr, \"vote0\", treatment_candidates, [f\"vote{i}\" for i in range(1,3)])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TODO need a brief explanation of what's going on here" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAP estimate: 0.5\n", + "Minimal scenarios:\n", + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -0.922 1.0 1.0 \n", + "1 1.0 0.0 0 -0.922 1.0 0.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", + "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "\n", + " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", + "0 0 -0.597 True 0.0 2 0 2 -3.678 \n", + "1 0 -0.597 True 0.0 2 0 2 -3.678 \n", + "\n", + "[2 rows x 22 columns]\n", + "Secondary check: 0.5\n" + ] + }, + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "responsibility_check(votingHPR)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# now consider a more complex example,\n", + "# with 7 voters, where you are not an actual cause\n", + "\n", + "def voting_model7():\n", + " u_vote0 = pyro.sample(\"u_vote0\", dist.Bernoulli(0.6))\n", + " u_vote1 = pyro.sample(\"u_vote1\", dist.Bernoulli(0.6))\n", + " u_vote2 = pyro.sample(\"u_vote2\", dist.Bernoulli(0.6))\n", + " u_vote3 = pyro.sample(\"u_vote3\", dist.Bernoulli(0.6))\n", + " u_vote4 = pyro.sample(\"u_vote4\", dist.Bernoulli(0.6))\n", + " u_vote5 = pyro.sample(\"u_vote5\", dist.Bernoulli(0.6))\n", + " u_vote6 = pyro.sample(\"u_vote6\", dist.Bernoulli(0.6))\n", + " \n", + "\n", + " vote0 = pyro.deterministic(\"vote0\", u_vote0, event_dim=0)\n", + " vote1 = pyro.deterministic(\"vote1\", u_vote1, event_dim=0)\n", + " vote2 = pyro.deterministic(\"vote2\", u_vote2, event_dim=0)\n", + " vote3 = pyro.deterministic(\"vote3\", u_vote3, event_dim=0)\n", + " vote4 = pyro.deterministic(\"vote4\", u_vote4, event_dim=0)\n", + " vote5 = pyro.deterministic(\"vote5\", u_vote5, event_dim=0)\n", + " vote6 = pyro.deterministic(\"vote6\", u_vote6, event_dim=0)\n", + "\n", + " outcome = pyro.deterministic(\"outcome\", vote0 + vote1 + vote2 + vote3 + \n", + " vote4 + vote5 + vote6 > 3 )\n", + " return {\"outcome\": outcome.float()}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.5337756999955469 , w: 0.5037726755473835 .\n" + ] + } + ], + "source": [ + "# everyone voted for,\n", + "# you are not an actual cause \n", + "# the minimal number of interventions \n", + "# including your change of vote\n", + "# needed to change the outcome is 4\n", + "# so your responsibility is 1/4\n", + "\n", + "observations7 = dict(u_vote0=1., u_vote1=1., u_vote2=1.,\n", + " u_vote3=1., u_vote4=1.,\n", + " u_vote5=1.,\n", + " u_vote6=1.,\n", + " )\n", + "\n", + "\n", + "treatment_candidates7 = {key[2:]: 1-v for key, v in observations7.items() if key != \"u_vote0\"}\n", + "\n", + "evaluated_node_counterfactual7 = {\"vote0\": 1 - observations7[\"u_vote0\"]}\n", + "\n", + "voting8HPR = HalpernPearlResponsibilityApproximate(\n", + " model = voting_model7,\n", + " evaluated_node_counterfactual = evaluated_node_counterfactual7,\n", + " treatment_candidates = treatment_candidates7,\n", + " outcome = \"outcome\",\n", + " witness_candidates = [f\"vote{i}\" for i in range(1,7)],\n", + " observations = observations7)\n", + "\n", + "with pyro.plate(\"runs\", 10000):\n", + " voting8HPR()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MAP estimate: 0.25\n", + "Minimal scenarios:\n", + " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", + "0 1.0 0.0 0 -0.763 1.0 0.0 \n", + "1 1.0 0.0 0 -0.763 1.0 1.0 \n", + "2 1.0 0.0 0 -0.763 1.0 1.0 \n", + "3 1.0 0.0 0 -0.763 1.0 0.0 \n", + "4 1.0 0.0 0 -0.763 1.0 1.0 \n", + "5 1.0 0.0 0 -0.763 1.0 0.0 \n", + "6 1.0 0.0 0 -0.763 1.0 0.0 \n", + "\n", + " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote6 alp_vote6 \\\n", + "0 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "2 1 -0.357 0 -0.686 ... 1 -0.357 \n", + "3 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "4 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "6 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "\n", + " wpr_vote6 wlp_vote6 cdif clp int wpr changes sum_lp \n", + "0 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "1 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "2 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "3 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "4 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "5 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "6 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "\n", + "[7 rows x 46 columns]\n", + "Secondary check: 0.25\n" + ] + }, + { + "data": { + "text/plain": [ + "0.25" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "responsibility_check(voting8HPR)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Stone-throwing\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We've already discussed the model in the actual causality notebook. Sally and Bill throw stones at a bottle, Sally throws first. Bill is perfectly accurate, so his stone would have shattered the bottle had not Sally's stone done it. The model is worth looking at, as the causal structure is less trivial. Again, we will see that responsibility judgment might to some extent disagree with actual causality." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def stones_model(): \n", + " prob_sally_throws = pyro.sample(\"prob_sally_throws\", dist.Beta(1, 1))\n", + " prob_bill_throws = pyro.sample(\"prob_bill_throws\", dist.Beta(1, 1))\n", + " prob_sally_hits = pyro.sample(\"prob_sally_hits\", dist.Beta(1, 1))\n", + " prob_bill_hits = pyro.sample(\"prob_bill_hits\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_sally = pyro.sample(\"prob_bottle_shatters_if_sally\", dist.Beta(1, 1))\n", + " prob_bottle_shatters_if_bill = pyro.sample(\"prob_bottle_shatters_if_bill\", dist.Beta(1, 1))\n", + "\n", + "\n", + " sally_throws = pyro.sample(\"sally_throws\", dist.Bernoulli(prob_sally_throws))\n", + " bill_throws = pyro.sample(\"bill_throws\", dist.Bernoulli(prob_bill_throws))\n", + "\n", + " new_shp = torch.where(sally_throws == 1,prob_sally_hits , 0.0)\n", + "\n", + " sally_hits = pyro.sample(\"sally_hits\",dist.Bernoulli(new_shp))\n", + "\n", + " new_bhp = torch.where(\n", + " (\n", + " bill_throws.bool()\n", + " & (~sally_hits.bool())\n", + " )\n", + " == 1,\n", + " prob_bill_hits,\n", + " torch.tensor(0.0),\n", + " )\n", + "\n", + "\n", + " bill_hits = pyro.sample(\"bill_hits\", dist.Bernoulli(new_bhp))\n", + "\n", + " new_bsp = torch.where(\n", + " bill_hits.bool() == 1,\n", + " prob_bottle_shatters_if_bill,\n", + " torch.where(\n", + " sally_hits.bool() == 1,\n", + " prob_bottle_shatters_if_sally,\n", + " torch.tensor(0.0),\n", + " ),\n", + " )\n", + "\n", + " bottle_shatters = pyro.sample(\n", + " \"bottle_shatters\", dist.Bernoulli(new_bsp)\n", + " )\n", + "\n", + " return {\n", + " \"sally_throws\": sally_throws,\n", + " \"bill_throws\": bill_throws,\n", + " \"sally_hits\": sally_hits,\n", + " \"bill_hits\": bill_hits,\n", + " \"bottle_shatters\": bottle_shatters,\n", + " }\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preemption biases used (upper) - t: 0.7 , n: 0.5693194340229132 , w: 0.5214645870036448 .\n" + ] + } + ], + "source": [ + "\n", + "pyro.set_rng_seed(4)\n", + "stones_sallyHPR = HalpernPearlResponsibilityApproximate(\n", + " model = stones_model,\n", + " evaluated_node_counterfactual= {\"sally_throws\": 0.0},\n", + " treatment_candidates = {\"sally_hits\": 0.0, \"bill_hits\": 1.0, \"bill_throws\": 0.0},\n", + " outcome = \"bottle_shatters\",\n", + " witness_candidates = [\"bill_hits\", \"bill_throws\", \"sally_hits\"],\n", + " observations = {\"prob_sally_throws\": 1.0, \n", + " \"prob_bill_throws\": 1.0,\n", + " \"prob_sally_hits\": 1.0,\n", + " \"prob_bill_hits\": 1.0,\n", + " \"prob_bottle_shatters_if_sally\": 1.0,\n", + " \"prob_bottle_shatters_if_bill\": 1.0})\n", + "\n", + "with pyro.plate(\"runs\",10000):\n", + " stones_sallyHPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epr_sally_throwselp_sally_throwsapr_sally_hitsalp_sally_hitswpr_sally_hitswlp_sally_hitsapr_bill_hitsalp_bill_hitswpr_bill_hitswlp_bill_hitsapr_bill_throwsalp_bill_throwswpr_bill_throwswlp_bill_throwsclpintwprchangessum_lp
01-0.5630-1.2040-0.6511-0.3571-0.7371-0.3570-0.6510.0112-4.520000e+00
11-0.5630-1.2040-0.6511-0.3571-0.7371-0.3571-0.7370.0123-4.606000e+00
20-0.8421-0.3570-0.6511-0.3570-0.6510-1.2040-0.6510.0202-4.713000e+00
30-0.8421-0.3571-0.7371-0.3570-0.6510-1.2040-0.6510.0213-4.799000e+00
41-0.5630-1.2040-0.6511-0.3570-0.6510-1.2040-0.6510.0202-5.281000e+00
51-0.5630-1.2040-0.6511-0.3571-0.7370-1.2040-0.6510.0213-5.367000e+00
90-0.8420-1.2040-0.6511-0.3570-0.6510-1.2040-0.6510.0303-5.560000e+00
131-0.5631-0.3570-0.6511-0.3570-0.6511-0.3570-0.651-100000000.0000-1.000000e+08
141-0.5631-0.3570-0.6511-0.3571-0.7371-0.3570-0.651-100000000.0011-1.000000e+08
151-0.5631-0.3570-0.6511-0.3570-0.6511-0.3571-0.737-100000000.0011-1.000000e+08
161-0.5631-0.3571-0.7371-0.3570-0.6511-0.3570-0.651-100000000.0011-1.000000e+08
171-0.5631-0.3571-0.7371-0.3571-0.7371-0.3570-0.651-100000000.0022-1.000000e+08
181-0.5631-0.3570-0.6511-0.3571-0.7371-0.3571-0.737-100000000.0022-1.000000e+08
191-0.5631-0.3571-0.7371-0.3570-0.6511-0.3571-0.737-100000000.0022-1.000000e+08
201-0.5631-0.3571-0.7371-0.3571-0.7371-0.3571-0.737-100000000.0033-1.000000e+08
210-0.8421-0.3570-0.6511-0.3570-0.6511-0.3570-0.651-100000000.0101-1.000000e+08
220-0.8421-0.3571-0.7371-0.3570-0.6511-0.3570-0.651-100000000.0112-1.000000e+08
230-0.8421-0.3570-0.6511-0.3571-0.7371-0.3570-0.651-100000000.0112-1.000000e+08
240-0.8421-0.3570-0.6511-0.3570-0.6511-0.3571-0.737-100000000.0112-1.000000e+08
250-0.8421-0.3570-0.6511-0.3571-0.7371-0.3571-0.737-100000000.0123-1.000000e+08
260-0.8421-0.3571-0.7371-0.3571-0.7371-0.3570-0.651-100000000.0123-1.000000e+08
270-0.8421-0.3571-0.7371-0.3570-0.6511-0.3571-0.737-100000000.0123-1.000000e+08
280-0.8421-0.3571-0.7371-0.3571-0.7371-0.3571-0.737-100000000.0134-1.000000e+08
291-0.5631-0.3570-0.6510-1.2040-0.6511-0.3570-0.651-100000000.0101-1.000000e+08
301-0.5631-0.3570-0.6511-0.3570-0.6510-1.2040-0.651-100000000.0101-1.000000e+08
311-0.5630-1.2040-0.6511-0.3570-0.6511-0.3570-0.651-100000000.0101-1.000000e+08
321-0.5631-0.3570-0.6510-1.2040-0.6511-0.3571-0.737-100000000.0112-1.000000e+08
331-0.5631-0.3570-0.6511-0.3571-0.7370-1.2040-0.651-100000000.0112-1.000000e+08
361-0.5631-0.3571-0.7371-0.3570-0.6510-1.2040-0.651-100000000.0112-1.000000e+08
371-0.5631-0.3571-0.7370-1.2040-0.6511-0.3570-0.651-100000000.0112-1.000000e+08
381-0.5630-1.2040-0.6511-0.3570-0.6511-0.3571-0.737-100000000.0112-1.000000e+08
401-0.5631-0.3571-0.7370-1.2040-0.6511-0.3571-0.737-100000000.0123-1.000000e+08
471-0.5631-0.3571-0.7371-0.3571-0.7370-1.2040-0.651-100000000.0123-1.000000e+08
510-0.8421-0.3570-0.6510-1.2040-0.6511-0.3570-0.651-100000000.0202-1.000000e+08
520-0.8420-1.2040-0.6511-0.3570-0.6511-0.3570-0.651-100000000.0202-1.000000e+08
530-0.8421-0.3570-0.6510-1.2040-0.6511-0.3571-0.737-100000000.0213-1.000000e+08
540-0.8421-0.3571-0.7370-1.2040-0.6511-0.3570-0.651-100000000.0213-1.000000e+08
570-0.8420-1.2040-0.6511-0.3570-0.6511-0.3571-0.737-100000000.0213-1.000000e+08
590-0.8421-0.3570-0.6511-0.3571-0.7370-1.2040-0.651-100000000.0213-1.000000e+08
600-0.8420-1.2040-0.6511-0.3571-0.7371-0.3570-0.651-100000000.0213-1.000000e+08
610-0.8421-0.3571-0.7371-0.3571-0.7370-1.2040-0.651-100000000.0224-1.000000e+08
670-0.8421-0.3571-0.7370-1.2040-0.6511-0.3571-0.737-100000000.0224-1.000000e+08
680-0.8420-1.2040-0.6511-0.3571-0.7371-0.3571-0.737-100000000.0224-1.000000e+08
731-0.5631-0.3570-0.6510-1.2040-0.6510-1.2040-0.651-100000000.0202-1.000000e+08
741-0.5630-1.2040-0.6510-1.2040-0.6511-0.3570-0.651-100000000.0202-1.000000e+08
751-0.5630-1.2040-0.6510-1.2040-0.6511-0.3571-0.737-100000000.0213-1.000000e+08
771-0.5631-0.3571-0.7370-1.2040-0.6510-1.2040-0.651-100000000.0213-1.000000e+08
920-0.8421-0.3570-0.6510-1.2040-0.6510-1.2040-0.651-100000000.0303-1.000000e+08
930-0.8420-1.2040-0.6510-1.2040-0.6511-0.3570-0.651-100000000.0303-1.000000e+08
950-0.8421-0.3571-0.7370-1.2040-0.6510-1.2040-0.651-100000000.0314-1.000000e+08
960-0.8420-1.2040-0.6511-0.3571-0.7370-1.2040-0.651-100000000.0314-1.000000e+08
990-0.8420-1.2040-0.6510-1.2040-0.6511-0.3571-0.737-100000000.0314-1.000000e+08
1141-0.5630-1.2040-0.6510-1.2040-0.6510-1.2040-0.651-100000000.0303-1.000000e+08
1200-0.8420-1.2040-0.6510-1.2040-0.6510-1.2040-0.651-100000000.0404-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", + "0 1 -0.563 0 -1.204 \n", + "1 1 -0.563 0 -1.204 \n", + "2 0 -0.842 1 -0.357 \n", + "3 0 -0.842 1 -0.357 \n", + "4 1 -0.563 0 -1.204 \n", + "5 1 -0.563 0 -1.204 \n", + "9 0 -0.842 0 -1.204 \n", + "13 1 -0.563 1 -0.357 \n", + "14 1 -0.563 1 -0.357 \n", + "15 1 -0.563 1 -0.357 \n", + "16 1 -0.563 1 -0.357 \n", + "17 1 -0.563 1 -0.357 \n", + "18 1 -0.563 1 -0.357 \n", + "19 1 -0.563 1 -0.357 \n", + "20 1 -0.563 1 -0.357 \n", + "21 0 -0.842 1 -0.357 \n", + "22 0 -0.842 1 -0.357 \n", + "23 0 -0.842 1 -0.357 \n", + "24 0 -0.842 1 -0.357 \n", + "25 0 -0.842 1 -0.357 \n", + "26 0 -0.842 1 -0.357 \n", + "27 0 -0.842 1 -0.357 \n", + "28 0 -0.842 1 -0.357 \n", + "29 1 -0.563 1 -0.357 \n", + "30 1 -0.563 1 -0.357 \n", + "31 1 -0.563 0 -1.204 \n", + "32 1 -0.563 1 -0.357 \n", + "33 1 -0.563 1 -0.357 \n", + "36 1 -0.563 1 -0.357 \n", + "37 1 -0.563 1 -0.357 \n", + "38 1 -0.563 0 -1.204 \n", + "40 1 -0.563 1 -0.357 \n", + "47 1 -0.563 1 -0.357 \n", + "51 0 -0.842 1 -0.357 \n", + "52 0 -0.842 0 -1.204 \n", + "53 0 -0.842 1 -0.357 \n", + "54 0 -0.842 1 -0.357 \n", + "57 0 -0.842 0 -1.204 \n", + "59 0 -0.842 1 -0.357 \n", + "60 0 -0.842 0 -1.204 \n", + "61 0 -0.842 1 -0.357 \n", + "67 0 -0.842 1 -0.357 \n", + "68 0 -0.842 0 -1.204 \n", + "73 1 -0.563 1 -0.357 \n", + "74 1 -0.563 0 -1.204 \n", + "75 1 -0.563 0 -1.204 \n", + "77 1 -0.563 1 -0.357 \n", + "92 0 -0.842 1 -0.357 \n", + "93 0 -0.842 0 -1.204 \n", + "95 0 -0.842 1 -0.357 \n", + "96 0 -0.842 0 -1.204 \n", + "99 0 -0.842 0 -1.204 \n", + "114 1 -0.563 0 -1.204 \n", + "120 0 -0.842 0 -1.204 \n", + "\n", + " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", + "0 0 -0.651 1 -0.357 \n", + "1 0 -0.651 1 -0.357 \n", + "2 0 -0.651 1 -0.357 \n", + "3 1 -0.737 1 -0.357 \n", + "4 0 -0.651 1 -0.357 \n", + "5 0 -0.651 1 -0.357 \n", + "9 0 -0.651 1 -0.357 \n", + "13 0 -0.651 1 -0.357 \n", + "14 0 -0.651 1 -0.357 \n", + "15 0 -0.651 1 -0.357 \n", + "16 1 -0.737 1 -0.357 \n", + "17 1 -0.737 1 -0.357 \n", + "18 0 -0.651 1 -0.357 \n", + "19 1 -0.737 1 -0.357 \n", + "20 1 -0.737 1 -0.357 \n", + "21 0 -0.651 1 -0.357 \n", + "22 1 -0.737 1 -0.357 \n", + "23 0 -0.651 1 -0.357 \n", + "24 0 -0.651 1 -0.357 \n", + "25 0 -0.651 1 -0.357 \n", + "26 1 -0.737 1 -0.357 \n", + "27 1 -0.737 1 -0.357 \n", + "28 1 -0.737 1 -0.357 \n", + "29 0 -0.651 0 -1.204 \n", + "30 0 -0.651 1 -0.357 \n", + "31 0 -0.651 1 -0.357 \n", + "32 0 -0.651 0 -1.204 \n", + "33 0 -0.651 1 -0.357 \n", + "36 1 -0.737 1 -0.357 \n", + "37 1 -0.737 0 -1.204 \n", + "38 0 -0.651 1 -0.357 \n", + "40 1 -0.737 0 -1.204 \n", + "47 1 -0.737 1 -0.357 \n", + "51 0 -0.651 0 -1.204 \n", + "52 0 -0.651 1 -0.357 \n", + "53 0 -0.651 0 -1.204 \n", + "54 1 -0.737 0 -1.204 \n", + "57 0 -0.651 1 -0.357 \n", + "59 0 -0.651 1 -0.357 \n", + "60 0 -0.651 1 -0.357 \n", + "61 1 -0.737 1 -0.357 \n", + "67 1 -0.737 0 -1.204 \n", + "68 0 -0.651 1 -0.357 \n", + "73 0 -0.651 0 -1.204 \n", + "74 0 -0.651 0 -1.204 \n", + "75 0 -0.651 0 -1.204 \n", + "77 1 -0.737 0 -1.204 \n", + "92 0 -0.651 0 -1.204 \n", + "93 0 -0.651 0 -1.204 \n", + "95 1 -0.737 0 -1.204 \n", + "96 0 -0.651 1 -0.357 \n", + "99 0 -0.651 0 -1.204 \n", + "114 0 -0.651 0 -1.204 \n", + "120 0 -0.651 0 -1.204 \n", + "\n", + " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", + "0 1 -0.737 1 -0.357 \n", + "1 1 -0.737 1 -0.357 \n", + "2 0 -0.651 0 -1.204 \n", + "3 0 -0.651 0 -1.204 \n", + "4 0 -0.651 0 -1.204 \n", + "5 1 -0.737 0 -1.204 \n", + "9 0 -0.651 0 -1.204 \n", + "13 0 -0.651 1 -0.357 \n", + "14 1 -0.737 1 -0.357 \n", + "15 0 -0.651 1 -0.357 \n", + "16 0 -0.651 1 -0.357 \n", + "17 1 -0.737 1 -0.357 \n", + "18 1 -0.737 1 -0.357 \n", + "19 0 -0.651 1 -0.357 \n", + "20 1 -0.737 1 -0.357 \n", + "21 0 -0.651 1 -0.357 \n", + "22 0 -0.651 1 -0.357 \n", + "23 1 -0.737 1 -0.357 \n", + "24 0 -0.651 1 -0.357 \n", + "25 1 -0.737 1 -0.357 \n", + "26 1 -0.737 1 -0.357 \n", + "27 0 -0.651 1 -0.357 \n", + "28 1 -0.737 1 -0.357 \n", + "29 0 -0.651 1 -0.357 \n", + "30 0 -0.651 0 -1.204 \n", + "31 0 -0.651 1 -0.357 \n", + "32 0 -0.651 1 -0.357 \n", + "33 1 -0.737 0 -1.204 \n", + "36 0 -0.651 0 -1.204 \n", + "37 0 -0.651 1 -0.357 \n", + "38 0 -0.651 1 -0.357 \n", + "40 0 -0.651 1 -0.357 \n", + "47 1 -0.737 0 -1.204 \n", + "51 0 -0.651 1 -0.357 \n", + "52 0 -0.651 1 -0.357 \n", + "53 0 -0.651 1 -0.357 \n", + "54 0 -0.651 1 -0.357 \n", + "57 0 -0.651 1 -0.357 \n", + "59 1 -0.737 0 -1.204 \n", + "60 1 -0.737 1 -0.357 \n", + "61 1 -0.737 0 -1.204 \n", + "67 0 -0.651 1 -0.357 \n", + "68 1 -0.737 1 -0.357 \n", + "73 0 -0.651 0 -1.204 \n", + "74 0 -0.651 1 -0.357 \n", + "75 0 -0.651 1 -0.357 \n", + "77 0 -0.651 0 -1.204 \n", + "92 0 -0.651 0 -1.204 \n", + "93 0 -0.651 1 -0.357 \n", + "95 0 -0.651 0 -1.204 \n", + "96 1 -0.737 0 -1.204 \n", + "99 0 -0.651 1 -0.357 \n", + "114 0 -0.651 0 -1.204 \n", + "120 0 -0.651 0 -1.204 \n", + "\n", + " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", + "0 0 -0.651 0.0 1 1 2 \n", + "1 1 -0.737 0.0 1 2 3 \n", + "2 0 -0.651 0.0 2 0 2 \n", + "3 0 -0.651 0.0 2 1 3 \n", + "4 0 -0.651 0.0 2 0 2 \n", + "5 0 -0.651 0.0 2 1 3 \n", + "9 0 -0.651 0.0 3 0 3 \n", + "13 0 -0.651 -100000000.0 0 0 0 \n", + "14 0 -0.651 -100000000.0 0 1 1 \n", + "15 1 -0.737 -100000000.0 0 1 1 \n", + "16 0 -0.651 -100000000.0 0 1 1 \n", + "17 0 -0.651 -100000000.0 0 2 2 \n", + "18 1 -0.737 -100000000.0 0 2 2 \n", + "19 1 -0.737 -100000000.0 0 2 2 \n", + "20 1 -0.737 -100000000.0 0 3 3 \n", + "21 0 -0.651 -100000000.0 1 0 1 \n", + "22 0 -0.651 -100000000.0 1 1 2 \n", + "23 0 -0.651 -100000000.0 1 1 2 \n", + "24 1 -0.737 -100000000.0 1 1 2 \n", + "25 1 -0.737 -100000000.0 1 2 3 \n", + "26 0 -0.651 -100000000.0 1 2 3 \n", + "27 1 -0.737 -100000000.0 1 2 3 \n", + "28 1 -0.737 -100000000.0 1 3 4 \n", + "29 0 -0.651 -100000000.0 1 0 1 \n", + "30 0 -0.651 -100000000.0 1 0 1 \n", + "31 0 -0.651 -100000000.0 1 0 1 \n", + "32 1 -0.737 -100000000.0 1 1 2 \n", + "33 0 -0.651 -100000000.0 1 1 2 \n", + "36 0 -0.651 -100000000.0 1 1 2 \n", + "37 0 -0.651 -100000000.0 1 1 2 \n", + "38 1 -0.737 -100000000.0 1 1 2 \n", + "40 1 -0.737 -100000000.0 1 2 3 \n", + "47 0 -0.651 -100000000.0 1 2 3 \n", + "51 0 -0.651 -100000000.0 2 0 2 \n", + "52 0 -0.651 -100000000.0 2 0 2 \n", + "53 1 -0.737 -100000000.0 2 1 3 \n", + "54 0 -0.651 -100000000.0 2 1 3 \n", + "57 1 -0.737 -100000000.0 2 1 3 \n", + "59 0 -0.651 -100000000.0 2 1 3 \n", + "60 0 -0.651 -100000000.0 2 1 3 \n", + "61 0 -0.651 -100000000.0 2 2 4 \n", + "67 1 -0.737 -100000000.0 2 2 4 \n", + "68 1 -0.737 -100000000.0 2 2 4 \n", + "73 0 -0.651 -100000000.0 2 0 2 \n", + "74 0 -0.651 -100000000.0 2 0 2 \n", + "75 1 -0.737 -100000000.0 2 1 3 \n", + "77 0 -0.651 -100000000.0 2 1 3 \n", + "92 0 -0.651 -100000000.0 3 0 3 \n", + "93 0 -0.651 -100000000.0 3 0 3 \n", + "95 0 -0.651 -100000000.0 3 1 4 \n", + "96 0 -0.651 -100000000.0 3 1 4 \n", + "99 1 -0.737 -100000000.0 3 1 4 \n", + "114 0 -0.651 -100000000.0 3 0 3 \n", + "120 0 -0.651 -100000000.0 4 0 4 \n", + "\n", + " sum_lp \n", + "0 -4.520000e+00 \n", + "1 -4.606000e+00 \n", + "2 -4.713000e+00 \n", + "3 -4.799000e+00 \n", + "4 -5.281000e+00 \n", + "5 -5.367000e+00 \n", + "9 -5.560000e+00 \n", + "13 -1.000000e+08 \n", + "14 -1.000000e+08 \n", + "15 -1.000000e+08 \n", + "16 -1.000000e+08 \n", + "17 -1.000000e+08 \n", + "18 -1.000000e+08 \n", + "19 -1.000000e+08 \n", + "20 -1.000000e+08 \n", + "21 -1.000000e+08 \n", + "22 -1.000000e+08 \n", + "23 -1.000000e+08 \n", + "24 -1.000000e+08 \n", + "25 -1.000000e+08 \n", + "26 -1.000000e+08 \n", + "27 -1.000000e+08 \n", + "28 -1.000000e+08 \n", + "29 -1.000000e+08 \n", + "30 -1.000000e+08 \n", + "31 -1.000000e+08 \n", + "32 -1.000000e+08 \n", + "33 -1.000000e+08 \n", + "36 -1.000000e+08 \n", + "37 -1.000000e+08 \n", + "38 -1.000000e+08 \n", + "40 -1.000000e+08 \n", + "47 -1.000000e+08 \n", + "51 -1.000000e+08 \n", + "52 -1.000000e+08 \n", + "53 -1.000000e+08 \n", + "54 -1.000000e+08 \n", + "57 -1.000000e+08 \n", + "59 -1.000000e+08 \n", + "60 -1.000000e+08 \n", + "61 -1.000000e+08 \n", + "67 -1.000000e+08 \n", + "68 -1.000000e+08 \n", + "73 -1.000000e+08 \n", + "74 -1.000000e+08 \n", + "75 -1.000000e+08 \n", + "77 -1.000000e+08 \n", + "92 -1.000000e+08 \n", + "93 -1.000000e+08 \n", + "95 -1.000000e+08 \n", + "96 -1.000000e+08 \n", + "99 -1.000000e+08 \n", + "114 -1.000000e+08 \n", + "120 -1.000000e+08 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def gett(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", + " \n", + " values_table = {}\n", + "\n", + "\n", + "# values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + "# values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", + " for antecedent in antecedents:\n", + "# andecedent_m = HPR.run[antecedent]\n", + "# print(gather(andecedent_m, IndexSet(**{antecedent: {0} for antecedent in antecedents})))\n", + "# values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + "# values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " for witness in witness_candidates:\n", + " if witness not in antecedents:\n", + " values_table['wpr_' + witness] = nodes['__witness_split_' + witness][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + witness] = nodes['__witness_split_' + witness][\"fn\"].log_prob(nodes['__witness_split_' + witness][\"value\"]).squeeze().tolist()\n", + "\n", + " values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " if isinstance(values_table['clp'], float):\n", + " values_df = pd.DataFrame([values_table])\n", + " else:\n", + " values_df = pd.DataFrame(values_table)\n", + " \n", + " values_df = pd.DataFrame(values_table)\n", + "\n", + " summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + "\n", + " values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + " values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", + " values_df.drop_duplicates(inplace = True)\n", + " values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + "\n", + " tab = values_df.reset_index(drop = True)\n", + "\n", + " tab = remove_redundant_rows(tab)\n", + " \n", + " if round:\n", + " tab = tab.round(3)\n", + "\n", + " return tab\n", + "\n", + "\n", + "tab = gett(stones_sallyHPR.trace.trace.nodes, \"sally_throws\", stones_sallyHPR.treatment_candidates, \n", + " stones_sallyHPR.witness_candidates)\n", + "\n", + "tab" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epr_sally_throwselp_sally_throwsapr_sally_hitsalp_sally_hitswpr_sally_hitswlp_sally_hitsapr_bill_hitsalp_bill_hitswpr_bill_hitswlp_bill_hitsapr_bill_throwsalp_bill_throwswpr_bill_throwswlp_bill_throwsclpintwprchangessum_lp
230-0.8421-0.3570-0.6511-0.3571-0.7371-0.3570-0.651-100000000.0112-1.000000e+08
\n", + "
" + ], + "text/plain": [ + " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", + "23 0 -0.842 1 -0.357 \n", + "\n", + " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", + "23 0 -0.651 1 -0.357 \n", + "\n", + " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", + "23 1 -0.737 1 -0.357 \n", + "\n", + " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", + "23 0 -0.651 -100000000.0 1 1 2 \n", + "\n", + " sum_lp \n", + "23 -1.000000e+08 " + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#this is worrying, should be on top with clp == 0\n", + "\n", + "tab.query(\"epr_sally_throws == 0 & apr_sally_hits == 1 & wpr_sally_hits == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1 & apr_bill_throws == 1 & wpr_bill_throws == 0\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'type': 'sample',\n", + " 'name': 'consequent_differs_binary',\n", + " 'fn': MaskedDistribution(),\n", + " 'is_observed': True,\n", + " 'args': (),\n", + " 'kwargs': {},\n", + " 'value': tensor([[[[[False, False, False, False, False, False, False, False, False,\n", + " False]]]]]),\n", + " 'infer': {'_deterministic': True},\n", + " 'scale': 1.0,\n", + " 'mask': None,\n", + " 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=10, counter=0),),\n", + " 'done': True,\n", + " 'stop': False,\n", + " 'continuation': None}" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stones_sallyHPR.trace.trace.nodes[\"consequent_differs_binary\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],\n", + " 'epr_sally_throws': [1, 1, 1, 1, 0],\n", + " 'elp_sally_throws': [-0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -0.35953617095947266,\n", + " -1.1973283290863037],\n", + " 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", + " 'int_bill_throws': [1.0, 0.0, 1.0, 1.0, 1.0],\n", + " 'apr_bill_throws': [1, 0, 1, 1, 1],\n", + " 'alp_bill_throws': [-0.3566749691963196,\n", + " -1.2039728164672852,\n", + " -0.3566749691963196,\n", + " -0.3566749691963196,\n", + " -0.3566749691963196],\n", + " 'obs_bill_hits': [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]],\n", + " 'wpr_bill_hits': [0, 1, 0, 0, 0],\n", + " 'wlp_bill_hits': [-0.3624056577682495,\n", + " -1.1907275915145874,\n", + " -0.3624056577682495,\n", + " -0.3624056577682495,\n", + " -0.3624056577682495]}" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def get_table(nodes, evaluated_node, antecedents, witness_candidates, round = True):\n", + " \n", + " values_table = {}\n", + "\n", + " values_table[f\"obs_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{evaluated_node}\"] = nodes[evaluated_node][\"value\"][1].squeeze().tolist()\n", + " values_table[f\"epr_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"].squeeze().tolist()\n", + " values_table[f\"elp_{evaluated_node}\"] = nodes[f\"__evaluated_split_{evaluated_node}\"][\"fn\"].log_prob(nodes[f\"__evaluated_split_{evaluated_node}\"][\"value\"]).squeeze().tolist()\n", + "\n", + " for antecedent in antecedents:\n", + " values_table[f\"obs_{antecedent}\"] = nodes[antecedent][\"value\"][0].squeeze().tolist()\n", + " values_table[f\"int_{antecedent}\"] = nodes[antecedent][\"value\"][1].squeeze().tolist()\n", + " values_table['apr_' + antecedent] = nodes['__treatment_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['alp_' + antecedent] = nodes['__treatment_split_' + antecedent][\"fn\"].log_prob(nodes['__treatment_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + "\n", + "\n", + " if f\"__witness_split_{antecedent}\" in nodes.keys():\n", + " values_table['wpr_' + antecedent] = nodes['__witness_split_' + antecedent][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + antecedent] = nodes['__witness_split_' + antecedent][\"fn\"].log_prob(nodes['__witness_split_' + antecedent][\"value\"]).squeeze().tolist()\n", + "\n", + " for witness in witness_candidates:\n", + " if witness not in antecedents:\n", + " values_table[f\"obs_{witness}\"] = nodes[witness][\"value\"][0].squeeze().tolist()\n", + " #values_table[f\"int_{witness}\"] = nodes[witness][\"value\"][1].squeeze().tolist()\n", + " values_table['wpr_' + witness] = nodes['__witness_split_' + witness][\"value\"].squeeze().tolist()\n", + " values_table['wlp_' + witness] = nodes['__witness_split_' + witness][\"fn\"].log_prob(nodes['__witness_split_' + witness][\"value\"]).squeeze().tolist()\n", + "\n", + " \n", + " #values_table['cdif'] = nodes['consequent_differs_binary'][\"value\"].squeeze().tolist()\n", + " #values_table['clp'] = nodes['consequent_differs'][\"fn\"].log_prob(nodes['consequent_differs'][\"value\"]).squeeze().tolist()\n", + "\n", + " #if isinstance(values_table['clp'], float):\n", + " # values_df = pd.DataFrame([values_table])\n", + " # else:\n", + " # values_df = pd.DataFrame(values_table)\n", + " \n", + " # values_df = pd.DataFrame(values_table)\n", + "\n", + " #summands_ant = ['alp_' + antecedent for antecedent in antecedents]\n", + " #summands_wit = ['wlp_' + witness for witness in witness_candidates]\n", + " #summands = [f\"elp_{evaluated_node}\"] + summands_ant + summands_wit + ['clp']\n", + " \n", + " \n", + " # values_df[\"int\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"apr_\")] == 0), axis=1)\n", + " # values_df['int'] = 1 - values_df[f\"epr_{evaluated_node}\"] + values_df[\"int\"]\n", + " # values_df[\"wpr\"] = values_df.apply(lambda row: sum(row[row.index.str.startswith(\"wpr_\")] == 1), axis=1)\n", + " # values_df[\"changes\"] = values_df[\"int\"] + values_df[\"wpr\"]\n", + "\n", + "\n", + " #values_df[\"sum_lp\"] = values_df[summands].sum(axis = 1) \n", + " # values_df.drop_duplicates(inplace = True)\n", + " # values_df.sort_values(by = \"sum_lp\", inplace = True, ascending = False)\n", + "\n", + " # tab = values_df.reset_index(drop = True)\n", + "\n", + " # tab = remove_redundant_rows(tab)\n", + "\n", + " tab = values_table\n", + "\n", + " #if round:\n", + " # tab = tab.round(3)\n", + "\n", + " return tab\n", + "\n", + "\n", + "get_table(stones_sallyHPR.trace.trace.nodes, \"sally_throws\", stones_sallyHPR.treatment_candidates, \n", + " stones_sallyHPR.witness_candidates)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "All arrays must be of the same length", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[63], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# minimal witness size becomes non-trivial here\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[39m# we only record different minimal difference-making scenarios\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m responsibility_check(stones_sallyHPR)\n", + "Cell \u001b[0;32mIn[6], line 4\u001b[0m, in \u001b[0;36mresponsibility_check\u001b[0;34m(hpr)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mresponsibility_check\u001b[39m(hpr):\n\u001b[1;32m 3\u001b[0m evaluated_node \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(hpr\u001b[39m.\u001b[39mevaluated_node_counterfactual\u001b[39m.\u001b[39mkeys())[\u001b[39m0\u001b[39m]\n\u001b[0;32m----> 4\u001b[0m tab \u001b[39m=\u001b[39m get_table(hpr\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mnodes,\n\u001b[1;32m 5\u001b[0m evaluated_node ,\n\u001b[1;32m 6\u001b[0m \u001b[39mlist\u001b[39;49m(hpr\u001b[39m.\u001b[39;49mtreatment_candidates\u001b[39m.\u001b[39;49mkeys()), \n\u001b[1;32m 7\u001b[0m hpr\u001b[39m.\u001b[39;49mwitness_candidates)\n\u001b[1;32m 9\u001b[0m max_sum_lp \u001b[39m=\u001b[39m tab[\u001b[39m'\u001b[39m\u001b[39msum_lp\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m.\u001b[39mmax()\n\u001b[1;32m 10\u001b[0m max_sum_lp_rows \u001b[39m=\u001b[39m tab[tab[\u001b[39m'\u001b[39m\u001b[39msum_lp\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m==\u001b[39m max_sum_lp]\n", + "Cell \u001b[0;32mIn[5], line 31\u001b[0m, in \u001b[0;36mget_table\u001b[0;34m(nodes, evaluated_node, antecedents, witness_candidates, round)\u001b[0m\n\u001b[1;32m 29\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame([values_table])\n\u001b[1;32m 30\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 31\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39;49mDataFrame(values_table)\n\u001b[1;32m 33\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame(values_table)\n\u001b[1;32m 35\u001b[0m summands_ant \u001b[39m=\u001b[39m [\u001b[39m'\u001b[39m\u001b[39malp_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent \u001b[39mfor\u001b[39;00m antecedent \u001b[39min\u001b[39;00m antecedents]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/frame.py:663\u001b[0m, in \u001b[0;36mDataFrame.__init__\u001b[0;34m(self, data, index, columns, dtype, copy)\u001b[0m\n\u001b[1;32m 657\u001b[0m mgr \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_init_mgr(\n\u001b[1;32m 658\u001b[0m data, axes\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mindex\u001b[39m\u001b[39m\"\u001b[39m: index, \u001b[39m\"\u001b[39m\u001b[39mcolumns\u001b[39m\u001b[39m\"\u001b[39m: columns}, dtype\u001b[39m=\u001b[39mdtype, copy\u001b[39m=\u001b[39mcopy\n\u001b[1;32m 659\u001b[0m )\n\u001b[1;32m 661\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, \u001b[39mdict\u001b[39m):\n\u001b[1;32m 662\u001b[0m \u001b[39m# GH#38939 de facto copy defaults to False only in non-dict cases\u001b[39;00m\n\u001b[0;32m--> 663\u001b[0m mgr \u001b[39m=\u001b[39m dict_to_mgr(data, index, columns, dtype\u001b[39m=\u001b[39;49mdtype, copy\u001b[39m=\u001b[39;49mcopy, typ\u001b[39m=\u001b[39;49mmanager)\n\u001b[1;32m 664\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, ma\u001b[39m.\u001b[39mMaskedArray):\n\u001b[1;32m 665\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mma\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmrecords\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mmrecords\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:493\u001b[0m, in \u001b[0;36mdict_to_mgr\u001b[0;34m(data, index, columns, dtype, typ, copy)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 490\u001b[0m \u001b[39m# dtype check to exclude e.g. range objects, scalars\u001b[39;00m\n\u001b[1;32m 491\u001b[0m arrays \u001b[39m=\u001b[39m [x\u001b[39m.\u001b[39mcopy() \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(x, \u001b[39m\"\u001b[39m\u001b[39mdtype\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39melse\u001b[39;00m x \u001b[39mfor\u001b[39;00m x \u001b[39min\u001b[39;00m arrays]\n\u001b[0;32m--> 493\u001b[0m \u001b[39mreturn\u001b[39;00m arrays_to_mgr(arrays, columns, index, dtype\u001b[39m=\u001b[39;49mdtype, typ\u001b[39m=\u001b[39;49mtyp, consolidate\u001b[39m=\u001b[39;49mcopy)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:118\u001b[0m, in \u001b[0;36marrays_to_mgr\u001b[0;34m(arrays, columns, index, dtype, verify_integrity, typ, consolidate)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[39mif\u001b[39;00m verify_integrity:\n\u001b[1;32m 116\u001b[0m \u001b[39m# figure out the index, if necessary\u001b[39;00m\n\u001b[1;32m 117\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 118\u001b[0m index \u001b[39m=\u001b[39m _extract_index(arrays)\n\u001b[1;32m 119\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 120\u001b[0m index \u001b[39m=\u001b[39m ensure_index(index)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:666\u001b[0m, in \u001b[0;36m_extract_index\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 664\u001b[0m lengths \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(\u001b[39mset\u001b[39m(raw_lengths))\n\u001b[1;32m 665\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(lengths) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 666\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mAll arrays must be of the same length\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 668\u001b[0m \u001b[39mif\u001b[39;00m have_dicts:\n\u001b[1;32m 669\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 670\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mMixing dicts with non-Series may lead to ambiguous ordering.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 671\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: All arrays must be of the same length" + ] + } + ], + "source": [ + "# minimal witness size becomes non-trivial here\n", + "# we only record different minimal difference-making scenarios\n", + "\n", + "responsibility_check(stones_sallyHPR)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.5" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# following Halpern\n", + "# Sally's responsibility is 1/2\n", + "\n", + "responsibility_stones_sally_HPR.responsibility" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Billy has degree of responsibility 0\n", + "# for the bottle shattering,\n", + "# as his throw is not a part of an actual cause\n", + "\n", + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_stones_bill_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = stones_model,\n", + " nodes = stones_model.nodes,\n", + " antecedent = \"bill_throws\", outcome = \"bottle_shatters\",\n", + " observations = {\"prob_sally_throws\": 1, \n", + " \"prob_bill_throws\": 1,\n", + " \"prob_sally_hits\": 1,\n", + " \"prob_bill_hits\": 1,\n", + " \"prob_bottle_shatters_if_sally\": 1,\n", + " \"prob_bottle_shatters_if_bill\": 1,\n", + " \"sally_throws\": 1, \"bill_throws\": 1},\n", + " runs_n=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(101)\n", + "responsibility_stones_bill_HPR()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "responsibility_stones_bill_HPR.responsibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Firing squad" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There is a firing squad consisting of five excellent marksmen. Only one of them has a live bullet in his rifle and the rest have blanks. They shoot and the prisoner dies. The marksmen shoot at the prisoner and he dies. The only cause of the prisoner’s death is the marksman with the live bullet. That marksman has degree of responsibility 1 for the death and all the others have degree of responsibility 0. In the notebook on blame, TODO add link we will see that if the marksmen completely do not know which of them has the live bullet, blame is nevertheless equally distributed between them." + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def firing_squad_model():\n", + " probs = pyro.sample(\"probs\", dist.Dirichlet(torch.ones(5)))\n", + "\n", + " who_has_bullet = pyro.sample(\"who_has_bullet\", dist.OneHotCategorical(probs))\n", + "\n", + " mark0 = pyro.deterministic(\"mark0\", torch.tensor([who[0] for who in who_has_bullet]), event_dim=0)\n", + " mark1 = pyro.deterministic(\"mark1\", torch.tensor([who[1] for who in who_has_bullet]), event_dim=0)\n", + " mark2 = pyro.deterministic(\"mark2\", torch.tensor([who[2] for who in who_has_bullet]), event_dim=0)\n", + " mark3 = pyro.deterministic(\"mark3\", torch.tensor([who[3] for who in who_has_bullet]), event_dim=0)\n", + " mark4 = pyro.deterministic(\"mark4\", torch.tensor([who[4] for who in who_has_bullet]), event_dim=0)\n", + "\n", + " dead = pyro.deterministic(\"dead\", mark0 + mark1 + mark2 + mark3 + \n", + " mark4 > 0)\n", + " \n", + " return {\"probs\": probs,\n", + " \"mark0\": mark0,\n", + " \"mark1\": mark1,\n", + " \"mark2\": mark2,\n", + " \"mark3\": mark3,\n", + " \"mark4\": mark4, \n", + " \"dead\": dead}\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_loaded_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = firing_squad_model,\n", + " nodes = [\"mark\" + str(i) for i in range(0,5)],\n", + " antecedent = \"mark0\", outcome = \"dead\",\n", + " observations = {\"probs\": torch.tensor([1., 0., 0., 0., 0.]),},\n", + " runs_n=50)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": {}, + "outputs": [], + "source": [ + "pyro.set_rng_seed(102)\n", + "\n", + "responsibility_empty_HPR = HalpernPearlResponsibilityApproximate(\n", + " model = firing_squad_model,\n", + " nodes = [\"mark\" + str(i) for i in range(0,5)],\n", + " antecedent = \"mark1\", outcome = \"dead\",\n", + " observations = {\"probs\": torch.tensor([1., 0., 0., 0., 0.]),},\n", + " runs_n=50)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1.0" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# If you have the live bullet\n", + "\n", + "responsibility_loaded_HPR()\n", + "responsibility_loaded_HPR.responsibility" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# if you have a blank,\n", + "# as we keep bullet's location constant in the model\n", + "# nothing can make a difference to mark1's contribution\n", + "# so his responsibility is zero\n", + "\n", + "responsibility_empty_HPR()\n", + "responsibility_empty_HPR.responsibility" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "causal_pyro", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 9abf8600c519512c6a95e455d8c06667cf3058e6 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Thu, 24 Aug 2023 10:21:57 +0200 Subject: [PATCH 12/13] dealing with multiple preemptions --- docs/source/responsibility_exploration.ipynb | 2032 ++++-------------- 1 file changed, 385 insertions(+), 1647 deletions(-) diff --git a/docs/source/responsibility_exploration.ipynb b/docs/source/responsibility_exploration.ipynb index 5ddb636c..89cd252a 100644 --- a/docs/source/responsibility_exploration.ipynb +++ b/docs/source/responsibility_exploration.ipynb @@ -155,7 +155,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -245,17 +245,14 @@ " def __call__(self, *args, **kwargs):\n", " print(\"Preemption biases used (upper) - t:\",.5+ self.bias_t, \", n:\", .5 + self.bias_n, \", w:\", .5 + self.bias_w, \".\")\n", " with MultiWorldCounterfactual():\n", - " with do(actions=self.evaluated_node_counterfactual):\n", - " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", + " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", + " prefix = \"__witness_split_\"):\n", + " with do(actions=self.evaluated_node_counterfactual):\n", + " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", " prefix = \"__evaluated_split_\"):\n", - " with do(actions=self.treatment_candidates):\n", - " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", + " with do(actions=self.treatment_candidates):\n", + " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", " prefix = \"__treatment_split_\"):\n", - " # the last element is the fixed at the observed value (preempted) \n", - " # the last element of the tensor is the factual case (preempted)\n", - " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", - " prefix = \"__witness_split_\"):\n", - "\n", " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", " with pyro.poutine.trace() as self.trace:\n", " self.run = self.model(*args, **kwargs)\n", @@ -267,7 +264,7 @@ " self.observedIndex = {node: {0} for node in list(self.evaluated_node_counterfactual.keys()) + \n", " list(self.treatment_candidates.keys()) + self.witness_candidates}\n", "\n", - " \n", + " print(indices_of(self.consequent, event_dim = 0))\n", " self.intervened_consequent = gather(self.consequent, IndexSet(**self.interventionIndex))\n", " \n", " self.observed_consequent = gather(self.consequent, IndexSet(**self.observedIndex))\n", @@ -278,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -315,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -382,7 +379,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -478,7 +475,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -515,14 +512,15 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.6024412643276109 , w: 0.5502509213795265 .\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.6024412643276109 , w: 0.5502509213795265 .\n", + "IndexSet({'vote0': {0, 1}, 'vote1': {0, 1}, 'vote2': {0, 1}})\n" ] }, { @@ -577,14 +575,14 @@ " 0\n", " -0.922\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", + " 0.0\n", + " 0\n", + " -1.204\n", " 0\n", " -0.597\n", " ...\n", - " 0\n", - " -1.204\n", + " 1\n", + " -0.357\n", " 0\n", " -0.597\n", " True\n", @@ -601,14 +599,14 @@ " 0\n", " -0.922\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", + " 1.0\n", + " 1\n", + " -0.357\n", " 0\n", " -0.597\n", " ...\n", - " 1\n", - " -0.357\n", + " 0\n", + " -1.204\n", " 0\n", " -0.597\n", " True\n", @@ -625,16 +623,16 @@ " 0\n", " -0.922\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.597\n", - " ...\n", + " 1.0\n", " 1\n", " -0.357\n", " 1\n", " -0.799\n", + " ...\n", + " 0\n", + " -1.204\n", + " 0\n", + " -0.597\n", " True\n", " 0.0\n", " 2\n", @@ -643,22 +641,22 @@ " -3.880000e+00\n", " \n", " \n", - " 3\n", + " 5\n", " 1.0\n", " 0.0\n", " 0\n", " -0.922\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.799\n", - " ...\n", + " 0.0\n", " 0\n", " -1.204\n", " 0\n", " -0.597\n", + " ...\n", + " 1\n", + " -0.357\n", + " 1\n", + " -0.799\n", " True\n", " 0.0\n", " 2\n", @@ -667,7 +665,7 @@ " -3.880000e+00\n", " \n", " \n", - " 4\n", + " 8\n", " 1.0\n", " 1.0\n", " 1\n", @@ -691,7 +689,7 @@ " -4.109000e+00\n", " \n", " \n", - " 5\n", + " 12\n", " 1.0\n", " 0.0\n", " 0\n", @@ -715,7 +713,7 @@ " -4.525000e+00\n", " \n", " \n", - " 8\n", + " 16\n", " 1.0\n", " 1.0\n", " 1\n", @@ -739,7 +737,7 @@ " -1.000000e+08\n", " \n", " \n", - " 9\n", + " 17\n", " 1.0\n", " 1.0\n", " 1\n", @@ -763,7 +761,7 @@ " -1.000000e+08\n", " \n", " \n", - " 10\n", + " 18\n", " 1.0\n", " 1.0\n", " 1\n", @@ -787,7 +785,7 @@ " -1.000000e+08\n", " \n", " \n", - " 11\n", + " 19\n", " 1.0\n", " 1.0\n", " 1\n", @@ -811,7 +809,7 @@ " -1.000000e+08\n", " \n", " \n", - " 12\n", + " 20\n", " 1.0\n", " 0.0\n", " 0\n", @@ -835,7 +833,7 @@ " -1.000000e+08\n", " \n", " \n", - " 13\n", + " 21\n", " 1.0\n", " 0.0\n", " 0\n", @@ -859,7 +857,7 @@ " -1.000000e+08\n", " \n", " \n", - " 14\n", + " 22\n", " 1.0\n", " 0.0\n", " 0\n", @@ -883,7 +881,7 @@ " -1.000000e+08\n", " \n", " \n", - " 15\n", + " 23\n", " 1.0\n", " 0.0\n", " 0\n", @@ -907,20 +905,20 @@ " -1.000000e+08\n", " \n", " \n", - " 16\n", + " 24\n", " 1.0\n", " 1.0\n", " 1\n", " -0.507\n", " 1.0\n", - " 1.0\n", - " 1\n", - " -0.357\n", + " 0.0\n", + " 0\n", + " -1.204\n", " 0\n", " -0.597\n", " ...\n", - " 0\n", - " -1.204\n", + " 1\n", + " -0.357\n", " 0\n", " -0.597\n", " False\n", @@ -931,20 +929,20 @@ " -1.000000e+08\n", " \n", " \n", - " 17\n", + " 25\n", " 1.0\n", " 1.0\n", " 1\n", " -0.507\n", " 1.0\n", - " 0.0\n", - " 0\n", - " -1.204\n", + " 1.0\n", + " 1\n", + " -0.357\n", " 0\n", " -0.597\n", " ...\n", - " 1\n", - " -0.357\n", + " 0\n", + " -1.204\n", " 0\n", " -0.597\n", " False\n", @@ -955,7 +953,7 @@ " -1.000000e+08\n", " \n", " \n", - " 18\n", + " 27\n", " 1.0\n", " 1.0\n", " 1\n", @@ -979,7 +977,7 @@ " -1.000000e+08\n", " \n", " \n", - " 19\n", + " 29\n", " 1.0\n", " 1.0\n", " 1\n", @@ -1009,69 +1007,69 @@ ], "text/plain": [ " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -0.922 1.0 1.0 \n", - "1 1.0 0.0 0 -0.922 1.0 0.0 \n", - "2 1.0 0.0 0 -0.922 1.0 0.0 \n", - "3 1.0 0.0 0 -0.922 1.0 1.0 \n", - "4 1.0 1.0 1 -0.507 1.0 0.0 \n", + "0 1.0 0.0 0 -0.922 1.0 0.0 \n", + "1 1.0 0.0 0 -0.922 1.0 1.0 \n", + "2 1.0 0.0 0 -0.922 1.0 1.0 \n", "5 1.0 0.0 0 -0.922 1.0 0.0 \n", - "8 1.0 1.0 1 -0.507 1.0 1.0 \n", - "9 1.0 1.0 1 -0.507 1.0 1.0 \n", - "10 1.0 1.0 1 -0.507 1.0 1.0 \n", - "11 1.0 1.0 1 -0.507 1.0 1.0 \n", - "12 1.0 0.0 0 -0.922 1.0 1.0 \n", - "13 1.0 0.0 0 -0.922 1.0 1.0 \n", - "14 1.0 0.0 0 -0.922 1.0 1.0 \n", - "15 1.0 0.0 0 -0.922 1.0 1.0 \n", + "8 1.0 1.0 1 -0.507 1.0 0.0 \n", + "12 1.0 0.0 0 -0.922 1.0 0.0 \n", "16 1.0 1.0 1 -0.507 1.0 1.0 \n", - "17 1.0 1.0 1 -0.507 1.0 0.0 \n", + "17 1.0 1.0 1 -0.507 1.0 1.0 \n", "18 1.0 1.0 1 -0.507 1.0 1.0 \n", - "19 1.0 1.0 1 -0.507 1.0 0.0 \n", + "19 1.0 1.0 1 -0.507 1.0 1.0 \n", + "20 1.0 0.0 0 -0.922 1.0 1.0 \n", + "21 1.0 0.0 0 -0.922 1.0 1.0 \n", + "22 1.0 0.0 0 -0.922 1.0 1.0 \n", + "23 1.0 0.0 0 -0.922 1.0 1.0 \n", + "24 1.0 1.0 1 -0.507 1.0 0.0 \n", + "25 1.0 1.0 1 -0.507 1.0 1.0 \n", + "27 1.0 1.0 1 -0.507 1.0 1.0 \n", + "29 1.0 1.0 1 -0.507 1.0 0.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", - "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", - "2 0 -1.204 0 -0.597 ... 1 -0.357 \n", - "3 1 -0.357 1 -0.799 ... 0 -1.204 \n", - "4 0 -1.204 0 -0.597 ... 0 -1.204 \n", - "5 0 -1.204 0 -0.597 ... 0 -1.204 \n", - "8 1 -0.357 0 -0.597 ... 1 -0.357 \n", - "9 1 -0.357 0 -0.597 ... 1 -0.357 \n", - "10 1 -0.357 1 -0.799 ... 1 -0.357 \n", - "11 1 -0.357 1 -0.799 ... 1 -0.357 \n", - "12 1 -0.357 0 -0.597 ... 1 -0.357 \n", - "13 1 -0.357 1 -0.799 ... 1 -0.357 \n", - "14 1 -0.357 0 -0.597 ... 1 -0.357 \n", - "15 1 -0.357 1 -0.799 ... 1 -0.357 \n", - "16 1 -0.357 0 -0.597 ... 0 -1.204 \n", - "17 0 -1.204 0 -0.597 ... 1 -0.357 \n", - "18 1 -0.357 1 -0.799 ... 0 -1.204 \n", - "19 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "0 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "2 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "8 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "12 0 -1.204 0 -0.597 ... 0 -1.204 \n", + "16 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "17 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "18 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "19 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "20 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "21 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "22 1 -0.357 0 -0.597 ... 1 -0.357 \n", + "23 1 -0.357 1 -0.799 ... 1 -0.357 \n", + "24 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "25 1 -0.357 0 -0.597 ... 0 -1.204 \n", + "27 1 -0.357 1 -0.799 ... 0 -1.204 \n", + "29 0 -1.204 0 -0.597 ... 1 -0.357 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", "0 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", "1 0 -0.597 True 0.0 2 0 2 -3.678000e+00 \n", - "2 1 -0.799 True 0.0 2 1 3 -3.880000e+00 \n", - "3 0 -0.597 True 0.0 2 1 3 -3.880000e+00 \n", - "4 0 -0.597 True 0.0 2 0 2 -4.109000e+00 \n", - "5 0 -0.597 True 0.0 3 0 3 -4.525000e+00 \n", - "8 0 -0.597 False -100000000.0 0 0 0 -1.000000e+08 \n", - "9 1 -0.799 False -100000000.0 0 1 1 -1.000000e+08 \n", - "10 0 -0.597 False -100000000.0 0 1 1 -1.000000e+08 \n", - "11 1 -0.799 False -100000000.0 0 2 2 -1.000000e+08 \n", - "12 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", - "13 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", - "14 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", - "15 1 -0.799 False -100000000.0 1 2 3 -1.000000e+08 \n", - "16 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", - "17 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", - "18 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", - "19 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", + "2 0 -0.597 True 0.0 2 1 3 -3.880000e+00 \n", + "5 1 -0.799 True 0.0 2 1 3 -3.880000e+00 \n", + "8 0 -0.597 True 0.0 2 0 2 -4.109000e+00 \n", + "12 0 -0.597 True 0.0 3 0 3 -4.525000e+00 \n", + "16 0 -0.597 False -100000000.0 0 0 0 -1.000000e+08 \n", + "17 1 -0.799 False -100000000.0 0 1 1 -1.000000e+08 \n", + "18 0 -0.597 False -100000000.0 0 1 1 -1.000000e+08 \n", + "19 1 -0.799 False -100000000.0 0 2 2 -1.000000e+08 \n", + "20 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "21 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "22 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", + "23 1 -0.799 False -100000000.0 1 2 3 -1.000000e+08 \n", + "24 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "25 0 -0.597 False -100000000.0 1 0 1 -1.000000e+08 \n", + "27 0 -0.597 False -100000000.0 1 1 2 -1.000000e+08 \n", + "29 1 -0.799 False -100000000.0 1 1 2 -1.000000e+08 \n", "\n", "[18 rows x 22 columns]" ] }, - "execution_count": 15, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -1096,7 +1094,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1106,12 +1104,12 @@ "MAP estimate: 0.5\n", "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -0.922 1.0 1.0 \n", - "1 1.0 0.0 0 -0.922 1.0 0.0 \n", + "0 1.0 0.0 0 -0.922 1.0 0.0 \n", + "1 1.0 0.0 0 -0.922 1.0 1.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote2 alp_vote2 \\\n", - "0 1 -0.357 0 -0.597 ... 0 -1.204 \n", - "1 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "0 0 -1.204 0 -0.597 ... 1 -0.357 \n", + "1 1 -0.357 0 -0.597 ... 0 -1.204 \n", "\n", " wpr_vote2 wlp_vote2 cdif clp int wpr changes sum_lp \n", "0 0 -0.597 True 0.0 2 0 2 -3.678 \n", @@ -1127,7 +1125,7 @@ "0.5" ] }, - "execution_count": 16, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -1138,7 +1136,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -1170,14 +1168,15 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.5337756999955469 , w: 0.5037726755473835 .\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.5337756999955469 , w: 0.5037726755473835 .\n", + "IndexSet({'vote0': {0, 1}, 'vote1': {0, 1}, 'vote2': {0, 1}, 'vote3': {0, 1}, 'vote4': {0, 1}, 'vote5': {0, 1}, 'vote6': {0, 1}})\n" ] } ], @@ -1214,7 +1213,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -1224,22 +1223,26 @@ "MAP estimate: 0.25\n", "Minimal scenarios:\n", " obs_vote0 int_vote0 epr_vote0 elp_vote0 obs_vote1 int_vote1 \\\n", - "0 1.0 0.0 0 -0.763 1.0 0.0 \n", - "1 1.0 0.0 0 -0.763 1.0 1.0 \n", - "2 1.0 0.0 0 -0.763 1.0 1.0 \n", - "3 1.0 0.0 0 -0.763 1.0 0.0 \n", - "4 1.0 0.0 0 -0.763 1.0 1.0 \n", + "0 1.0 0.0 0 -0.763 1.0 1.0 \n", + "1 1.0 0.0 0 -0.763 1.0 0.0 \n", + "2 1.0 0.0 0 -0.763 1.0 0.0 \n", + "3 1.0 0.0 0 -0.763 1.0 1.0 \n", + "4 1.0 0.0 0 -0.763 1.0 0.0 \n", "5 1.0 0.0 0 -0.763 1.0 0.0 \n", - "6 1.0 0.0 0 -0.763 1.0 0.0 \n", + "6 1.0 0.0 0 -0.763 1.0 1.0 \n", + "7 1.0 0.0 0 -0.763 1.0 1.0 \n", + "8 1.0 0.0 0 -0.763 1.0 1.0 \n", "\n", " apr_vote1 alp_vote1 wpr_vote1 wlp_vote1 ... apr_vote6 alp_vote6 \\\n", - "0 0 -1.204 0 -0.686 ... 1 -0.357 \n", - "1 1 -0.357 0 -0.686 ... 0 -1.204 \n", - "2 1 -0.357 0 -0.686 ... 1 -0.357 \n", - "3 0 -1.204 0 -0.686 ... 1 -0.357 \n", - "4 1 -0.357 0 -0.686 ... 0 -1.204 \n", - "5 0 -1.204 0 -0.686 ... 1 -0.357 \n", - "6 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "0 1 -0.357 0 -0.686 ... 1 -0.357 \n", + "1 0 -1.204 0 -0.686 ... 0 -1.204 \n", + "2 0 -1.204 0 -0.686 ... 1 -0.357 \n", + "3 1 -0.357 0 -0.686 ... 1 -0.357 \n", + "4 0 -1.204 0 -0.686 ... 0 -1.204 \n", + "5 0 -1.204 0 -0.686 ... 0 -1.204 \n", + "6 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "7 1 -0.357 0 -0.686 ... 0 -1.204 \n", + "8 1 -0.357 0 -0.686 ... 0 -1.204 \n", "\n", " wpr_vote6 wlp_vote6 cdif clp int wpr changes sum_lp \n", "0 0 -0.686 True 0.0 4 0 4 -9.559 \n", @@ -1249,8 +1252,10 @@ "4 0 -0.686 True 0.0 4 0 4 -9.559 \n", "5 0 -0.686 True 0.0 4 0 4 -9.559 \n", "6 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "7 0 -0.686 True 0.0 4 0 4 -9.559 \n", + "8 0 -0.686 True 0.0 4 0 4 -9.559 \n", "\n", - "[7 rows x 46 columns]\n", + "[9 rows x 46 columns]\n", "Secondary check: 0.25\n" ] }, @@ -1260,7 +1265,7 @@ "0.25" ] }, - "execution_count": 19, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1287,7 +1292,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -1346,14 +1351,15 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Preemption biases used (upper) - t: 0.7 , n: 0.5693194340229132 , w: 0.5214645870036448 .\n" + "Preemption biases used (upper) - t: 0.7 , n: 0.698 , w: 0.696 .\n", + "IndexSet({'sally_throws': {0, 1}, 'bill_hits': {0, 1}})\n" ] } ], @@ -1363,9 +1369,9 @@ "stones_sallyHPR = HalpernPearlResponsibilityApproximate(\n", " model = stones_model,\n", " evaluated_node_counterfactual= {\"sally_throws\": 0.0},\n", - " treatment_candidates = {\"sally_hits\": 0.0, \"bill_hits\": 1.0, \"bill_throws\": 0.0},\n", + " treatment_candidates = {\"bill_hits\": 1.0},\n", " outcome = \"bottle_shatters\",\n", - " witness_candidates = [\"bill_hits\", \"bill_throws\", \"sally_hits\"],\n", + " witness_candidates = [\"bill_hits\"],\n", " observations = {\"prob_sally_throws\": 1.0, \n", " \"prob_bill_throws\": 1.0,\n", " \"prob_sally_hits\": 1.0,\n", @@ -1373,14 +1379,182 @@ " \"prob_bottle_shatters_if_sally\": 1.0,\n", " \"prob_bottle_shatters_if_bill\": 1.0})\n", "\n", - "with pyro.plate(\"runs\",10000):\n", + "with pyro.plate(\"runs\",1000):\n", " stones_sallyHPR()" ] }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "odict_keys(['prob_sally_throws', 'prob_bill_throws', 'prob_sally_hits', 'prob_bill_hits', 'prob_bottle_shatters_if_sally', 'prob_bottle_shatters_if_bill', '__evaluated_split_sally_throws', 'sally_throws', 'bill_throws', 'sally_hits', '__witness_split_bill_hits', '__treatment_split_bill_hits', 'bill_hits', 'bottle_shatters', 'consequent_differs_binary', 'consequent_differs'])\n", + "tensor([[[[[1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 0., 1.]]]]])\n", + "tensor([1, 1, 1, 0, 1])\n", + "tensor([1., 1., 1., 1., 1.])\n" + ] + } + ], + "source": [ + "str = stones_sallyHPR.trace.trace.nodes\n", + "\n", + "print(str.keys())\n", + "\n", + "print(str[\"sally_throws\"][\"value\"])\n", + "print(str[\"__evaluated_split_sally_throws\"][\"value\"])\n", + "\n", + "print(str[\"bill_throws\"][\"value\"])\n" + ] + }, { "cell_type": "code", "execution_count": 22, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[[[[0., 0., 0., 0., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 0., 0., 1., 0.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[0., 1., 1., 1., 0.]]]],\n", + "\n", + "\n", + "\n", + " [[[[0., 1., 1., 1., 0.]]]]]])\n", + "treatment: tensor([1, 0, 0, 0, 1])\n", + "witness: tensor([0, 1, 1, 0, 0])\n" + ] + } + ], + "source": [ + "print(str[\"bill_hits\"][\"value\"])\n", + "print(\"treatment:\", str[\"__treatment_split_bill_hits\"][\"value\"])\n", + "print(\"witness:\", str[\"__witness_split_bill_hits\"][\"value\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "IndexSet({})" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "indices_of(stones_sallyHPR.run[\"bill_hits\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "IndexSet({})" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "indices_of(stones_sallyHPR.run[\"bottle_shatters\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[[[1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1.]]]]],\n", + "\n", + "\n", + "\n", + "\n", + " [[[[[1., 1., 1., 1., 1.]]]],\n", + "\n", + "\n", + "\n", + " [[[[1., 1., 1., 1., 1.]]]]]])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "str[\"bottle_shatters\"][\"value\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "All arrays must be of the same length", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/rafal/UGPOP/projectsUGPOP/chirho/docs/source/responsibility_exploration.ipynb Cell 32\u001b[0m in \u001b[0;36m1\n\u001b[0;32m----> 1\u001b[0m tab \u001b[39m=\u001b[39m get_table(stones_sallyHPR\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mtrace\u001b[39m.\u001b[39;49mnodes, \u001b[39m\"\u001b[39;49m\u001b[39msally_throws\u001b[39;49m\u001b[39m\"\u001b[39;49m, stones_sallyHPR\u001b[39m.\u001b[39;49mtreatment_candidates, \n\u001b[1;32m 2\u001b[0m stones_sallyHPR\u001b[39m.\u001b[39;49mwitness_candidates)\n\u001b[1;32m 4\u001b[0m tab\n", + "\u001b[1;32m/home/rafal/UGPOP/projectsUGPOP/chirho/docs/source/responsibility_exploration.ipynb Cell 32\u001b[0m in \u001b[0;36m3\n\u001b[1;32m 29\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame([values_table])\n\u001b[1;32m 30\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 31\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39;49mDataFrame(values_table)\n\u001b[1;32m 33\u001b[0m values_df \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mDataFrame(values_table)\n\u001b[1;32m 35\u001b[0m summands_ant \u001b[39m=\u001b[39m [\u001b[39m'\u001b[39m\u001b[39malp_\u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m antecedent \u001b[39mfor\u001b[39;00m antecedent \u001b[39min\u001b[39;00m antecedents]\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/frame.py:662\u001b[0m, in \u001b[0;36mDataFrame.__init__\u001b[0;34m(self, data, index, columns, dtype, copy)\u001b[0m\n\u001b[1;32m 656\u001b[0m mgr \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_init_mgr(\n\u001b[1;32m 657\u001b[0m data, axes\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39mindex\u001b[39m\u001b[39m\"\u001b[39m: index, \u001b[39m\"\u001b[39m\u001b[39mcolumns\u001b[39m\u001b[39m\"\u001b[39m: columns}, dtype\u001b[39m=\u001b[39mdtype, copy\u001b[39m=\u001b[39mcopy\n\u001b[1;32m 658\u001b[0m )\n\u001b[1;32m 660\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, \u001b[39mdict\u001b[39m):\n\u001b[1;32m 661\u001b[0m \u001b[39m# GH#38939 de facto copy defaults to False only in non-dict cases\u001b[39;00m\n\u001b[0;32m--> 662\u001b[0m mgr \u001b[39m=\u001b[39m dict_to_mgr(data, index, columns, dtype\u001b[39m=\u001b[39;49mdtype, copy\u001b[39m=\u001b[39;49mcopy, typ\u001b[39m=\u001b[39;49mmanager)\n\u001b[1;32m 663\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(data, ma\u001b[39m.\u001b[39mMaskedArray):\n\u001b[1;32m 664\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mma\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmrecords\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mmrecords\u001b[39;00m\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:493\u001b[0m, in \u001b[0;36mdict_to_mgr\u001b[0;34m(data, index, columns, dtype, typ, copy)\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 490\u001b[0m \u001b[39m# dtype check to exclude e.g. range objects, scalars\u001b[39;00m\n\u001b[1;32m 491\u001b[0m arrays \u001b[39m=\u001b[39m [x\u001b[39m.\u001b[39mcopy() \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(x, \u001b[39m\"\u001b[39m\u001b[39mdtype\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39melse\u001b[39;00m x \u001b[39mfor\u001b[39;00m x \u001b[39min\u001b[39;00m arrays]\n\u001b[0;32m--> 493\u001b[0m \u001b[39mreturn\u001b[39;00m arrays_to_mgr(arrays, columns, index, dtype\u001b[39m=\u001b[39;49mdtype, typ\u001b[39m=\u001b[39;49mtyp, consolidate\u001b[39m=\u001b[39;49mcopy)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:118\u001b[0m, in \u001b[0;36marrays_to_mgr\u001b[0;34m(arrays, columns, index, dtype, verify_integrity, typ, consolidate)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[39mif\u001b[39;00m verify_integrity:\n\u001b[1;32m 116\u001b[0m \u001b[39m# figure out the index, if necessary\u001b[39;00m\n\u001b[1;32m 117\u001b[0m \u001b[39mif\u001b[39;00m index \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 118\u001b[0m index \u001b[39m=\u001b[39m _extract_index(arrays)\n\u001b[1;32m 119\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 120\u001b[0m index \u001b[39m=\u001b[39m ensure_index(index)\n", + "File \u001b[0;32m~/.local/lib/python3.10/site-packages/pandas/core/internals/construction.py:666\u001b[0m, in \u001b[0;36m_extract_index\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 664\u001b[0m lengths \u001b[39m=\u001b[39m \u001b[39mlist\u001b[39m(\u001b[39mset\u001b[39m(raw_lengths))\n\u001b[1;32m 665\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(lengths) \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[0;32m--> 666\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mAll arrays must be of the same length\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 668\u001b[0m \u001b[39mif\u001b[39;00m have_dicts:\n\u001b[1;32m 669\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 670\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mMixing dicts with non-Series may lead to ambiguous ordering.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 671\u001b[0m )\n", + "\u001b[0;31mValueError\u001b[0m: All arrays must be of the same length" + ] + } + ], + "source": [ + "\n", + "# tab = get_table(stones_sallyHPR.trace.trace.nodes, \"sally_throws\", stones_sallyHPR.treatment_candidates, \n", + "# stones_sallyHPR.witness_candidates)\n", + "\n", + "# tab" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, "outputs": [ { "data": { @@ -1405,18 +1579,10 @@ " \n", " epr_sally_throws\n", " elp_sally_throws\n", - " apr_sally_hits\n", - " alp_sally_hits\n", - " wpr_sally_hits\n", - " wlp_sally_hits\n", " apr_bill_hits\n", " alp_bill_hits\n", " wpr_bill_hits\n", " wlp_bill_hits\n", - " apr_bill_throws\n", - " alp_bill_throws\n", - " wpr_bill_throws\n", - " wlp_bill_throws\n", " clp\n", " int\n", " wpr\n", @@ -1428,1189 +1594,85 @@ " \n", " 0\n", " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", + " -0.360\n", " 1\n", " -0.357\n", " 0\n", - " -0.651\n", - " 0.0\n", - " 1\n", - " 1\n", - " 2\n", - " -4.520000e+00\n", + " -0.362\n", + " -100000000.0\n", + " 0\n", + " 0\n", + " 0\n", + " -1.000000e+08\n", " \n", " \n", " 1\n", " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " -0.360\n", " 1\n", " -0.357\n", " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0.0\n", - " 1\n", - " 2\n", - " 3\n", - " -4.606000e+00\n", - " \n", - " \n", - " 2\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0.0\n", - " 2\n", - " 0\n", - " 2\n", - " -4.713000e+00\n", - " \n", - " \n", - " 3\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0.0\n", - " 2\n", - " 1\n", - " 3\n", - " -4.799000e+00\n", - " \n", - " \n", - " 4\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0.0\n", - " 2\n", - " 0\n", - " 2\n", - " -5.281000e+00\n", - " \n", - " \n", - " 5\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0.0\n", - " 2\n", - " 1\n", - " 3\n", - " -5.367000e+00\n", - " \n", - " \n", - " 9\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0.0\n", - " 3\n", - " 0\n", - " 3\n", - " -5.560000e+00\n", - " \n", - " \n", - " 13\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 0\n", - " 0\n", - " 0\n", - " -1.000000e+08\n", - " \n", - " \n", - " 14\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 0\n", - " 1\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 15\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 0\n", - " 1\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 16\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 0\n", - " 1\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 17\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 0\n", - " 2\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 18\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 0\n", - " 2\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 19\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 0\n", - " 2\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 20\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 0\n", - " 3\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 21\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 0\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 22\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 23\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 24\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 25\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 2\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 26\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 2\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 27\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 2\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 28\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 3\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 29\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 0\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 30\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 0\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 31\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 0\n", - " 1\n", - " -1.000000e+08\n", - " \n", - " \n", - " 32\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 33\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 36\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 37\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 38\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 1\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 40\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 1\n", - " 2\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 47\n", - " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 1\n", - " 2\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 51\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 2\n", - " 0\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 52\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 2\n", - " 0\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 53\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 54\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 57\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 59\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " -1.191\n", " -100000000.0\n", - " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 60\n", " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 2\n", " 1\n", - " 3\n", " -1.000000e+08\n", " \n", " \n", - " 61\n", + " 2\n", " 0\n", - " -0.842\n", + " -1.197\n", " 1\n", " -0.357\n", - " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " -0.362\n", " -100000000.0\n", - " 2\n", - " 2\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 67\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", " 1\n", - " -0.737\n", - " -100000000.0\n", - " 2\n", - " 2\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 68\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 2\n", - " 2\n", - " 4\n", " -1.000000e+08\n", " \n", " \n", - " 73\n", - " 1\n", - " -0.563\n", + " 3\n", " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " -0.360\n", " 0\n", " -1.204\n", " 0\n", - " -0.651\n", + " -0.362\n", " -100000000.0\n", - " 2\n", - " 0\n", - " 2\n", - " -1.000000e+08\n", - " \n", - " \n", - " 74\n", " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 2\n", - " 0\n", - " 2\n", " -1.000000e+08\n", " \n", " \n", - " 75\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", + " 4\n", " 0\n", - " -0.651\n", + " -1.197\n", " 1\n", " -0.357\n", " 1\n", - " -0.737\n", + " -1.191\n", " -100000000.0\n", - " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 77\n", " 1\n", - " -0.563\n", - " 1\n", - " -0.357\n", " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", " 2\n", - " 1\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 92\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 3\n", - " 0\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 93\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 3\n", - " 0\n", - " 3\n", - " -1.000000e+08\n", - " \n", - " \n", - " 95\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 3\n", - " 1\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 96\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 3\n", - " 1\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 99\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 1\n", - " -0.357\n", - " 1\n", - " -0.737\n", - " -100000000.0\n", - " 3\n", - " 1\n", - " 4\n", - " -1.000000e+08\n", - " \n", - " \n", - " 114\n", - " 1\n", - " -0.563\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", - " -100000000.0\n", - " 3\n", - " 0\n", - " 3\n", " -1.000000e+08\n", " \n", " \n", - " 120\n", - " 0\n", - " -0.842\n", - " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " 6\n", " 0\n", - " -1.204\n", - " 0\n", - " -0.651\n", + " -1.197\n", " 0\n", " -1.204\n", " 0\n", - " -0.651\n", + " -0.362\n", " -100000000.0\n", - " 4\n", + " 2\n", " 0\n", - " 4\n", + " 2\n", " -1.000000e+08\n", " \n", " \n", @@ -2618,288 +1680,24 @@ "" ], "text/plain": [ - " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", - "0 1 -0.563 0 -1.204 \n", - "1 1 -0.563 0 -1.204 \n", - "2 0 -0.842 1 -0.357 \n", - "3 0 -0.842 1 -0.357 \n", - "4 1 -0.563 0 -1.204 \n", - "5 1 -0.563 0 -1.204 \n", - "9 0 -0.842 0 -1.204 \n", - "13 1 -0.563 1 -0.357 \n", - "14 1 -0.563 1 -0.357 \n", - "15 1 -0.563 1 -0.357 \n", - "16 1 -0.563 1 -0.357 \n", - "17 1 -0.563 1 -0.357 \n", - "18 1 -0.563 1 -0.357 \n", - "19 1 -0.563 1 -0.357 \n", - "20 1 -0.563 1 -0.357 \n", - "21 0 -0.842 1 -0.357 \n", - "22 0 -0.842 1 -0.357 \n", - "23 0 -0.842 1 -0.357 \n", - "24 0 -0.842 1 -0.357 \n", - "25 0 -0.842 1 -0.357 \n", - "26 0 -0.842 1 -0.357 \n", - "27 0 -0.842 1 -0.357 \n", - "28 0 -0.842 1 -0.357 \n", - "29 1 -0.563 1 -0.357 \n", - "30 1 -0.563 1 -0.357 \n", - "31 1 -0.563 0 -1.204 \n", - "32 1 -0.563 1 -0.357 \n", - "33 1 -0.563 1 -0.357 \n", - "36 1 -0.563 1 -0.357 \n", - "37 1 -0.563 1 -0.357 \n", - "38 1 -0.563 0 -1.204 \n", - "40 1 -0.563 1 -0.357 \n", - "47 1 -0.563 1 -0.357 \n", - "51 0 -0.842 1 -0.357 \n", - "52 0 -0.842 0 -1.204 \n", - "53 0 -0.842 1 -0.357 \n", - "54 0 -0.842 1 -0.357 \n", - "57 0 -0.842 0 -1.204 \n", - "59 0 -0.842 1 -0.357 \n", - "60 0 -0.842 0 -1.204 \n", - "61 0 -0.842 1 -0.357 \n", - "67 0 -0.842 1 -0.357 \n", - "68 0 -0.842 0 -1.204 \n", - "73 1 -0.563 1 -0.357 \n", - "74 1 -0.563 0 -1.204 \n", - "75 1 -0.563 0 -1.204 \n", - "77 1 -0.563 1 -0.357 \n", - "92 0 -0.842 1 -0.357 \n", - "93 0 -0.842 0 -1.204 \n", - "95 0 -0.842 1 -0.357 \n", - "96 0 -0.842 0 -1.204 \n", - "99 0 -0.842 0 -1.204 \n", - "114 1 -0.563 0 -1.204 \n", - "120 0 -0.842 0 -1.204 \n", + " epr_sally_throws elp_sally_throws apr_bill_hits alp_bill_hits \\\n", + "0 1 -0.360 1 -0.357 \n", + "1 1 -0.360 1 -0.357 \n", + "2 0 -1.197 1 -0.357 \n", + "3 1 -0.360 0 -1.204 \n", + "4 0 -1.197 1 -0.357 \n", + "6 0 -1.197 0 -1.204 \n", "\n", - " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", - "0 0 -0.651 1 -0.357 \n", - "1 0 -0.651 1 -0.357 \n", - "2 0 -0.651 1 -0.357 \n", - "3 1 -0.737 1 -0.357 \n", - "4 0 -0.651 1 -0.357 \n", - "5 0 -0.651 1 -0.357 \n", - "9 0 -0.651 1 -0.357 \n", - "13 0 -0.651 1 -0.357 \n", - "14 0 -0.651 1 -0.357 \n", - "15 0 -0.651 1 -0.357 \n", - "16 1 -0.737 1 -0.357 \n", - "17 1 -0.737 1 -0.357 \n", - "18 0 -0.651 1 -0.357 \n", - "19 1 -0.737 1 -0.357 \n", - "20 1 -0.737 1 -0.357 \n", - "21 0 -0.651 1 -0.357 \n", - "22 1 -0.737 1 -0.357 \n", - "23 0 -0.651 1 -0.357 \n", - "24 0 -0.651 1 -0.357 \n", - "25 0 -0.651 1 -0.357 \n", - "26 1 -0.737 1 -0.357 \n", - "27 1 -0.737 1 -0.357 \n", - "28 1 -0.737 1 -0.357 \n", - "29 0 -0.651 0 -1.204 \n", - "30 0 -0.651 1 -0.357 \n", - "31 0 -0.651 1 -0.357 \n", - "32 0 -0.651 0 -1.204 \n", - "33 0 -0.651 1 -0.357 \n", - "36 1 -0.737 1 -0.357 \n", - "37 1 -0.737 0 -1.204 \n", - "38 0 -0.651 1 -0.357 \n", - "40 1 -0.737 0 -1.204 \n", - "47 1 -0.737 1 -0.357 \n", - "51 0 -0.651 0 -1.204 \n", - "52 0 -0.651 1 -0.357 \n", - "53 0 -0.651 0 -1.204 \n", - "54 1 -0.737 0 -1.204 \n", - "57 0 -0.651 1 -0.357 \n", - "59 0 -0.651 1 -0.357 \n", - "60 0 -0.651 1 -0.357 \n", - "61 1 -0.737 1 -0.357 \n", - "67 1 -0.737 0 -1.204 \n", - "68 0 -0.651 1 -0.357 \n", - "73 0 -0.651 0 -1.204 \n", - "74 0 -0.651 0 -1.204 \n", - "75 0 -0.651 0 -1.204 \n", - "77 1 -0.737 0 -1.204 \n", - "92 0 -0.651 0 -1.204 \n", - "93 0 -0.651 0 -1.204 \n", - "95 1 -0.737 0 -1.204 \n", - "96 0 -0.651 1 -0.357 \n", - "99 0 -0.651 0 -1.204 \n", - "114 0 -0.651 0 -1.204 \n", - "120 0 -0.651 0 -1.204 \n", - "\n", - " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", - "0 1 -0.737 1 -0.357 \n", - "1 1 -0.737 1 -0.357 \n", - "2 0 -0.651 0 -1.204 \n", - "3 0 -0.651 0 -1.204 \n", - "4 0 -0.651 0 -1.204 \n", - "5 1 -0.737 0 -1.204 \n", - "9 0 -0.651 0 -1.204 \n", - "13 0 -0.651 1 -0.357 \n", - "14 1 -0.737 1 -0.357 \n", - "15 0 -0.651 1 -0.357 \n", - "16 0 -0.651 1 -0.357 \n", - "17 1 -0.737 1 -0.357 \n", - "18 1 -0.737 1 -0.357 \n", - "19 0 -0.651 1 -0.357 \n", - "20 1 -0.737 1 -0.357 \n", - "21 0 -0.651 1 -0.357 \n", - "22 0 -0.651 1 -0.357 \n", - "23 1 -0.737 1 -0.357 \n", - "24 0 -0.651 1 -0.357 \n", - "25 1 -0.737 1 -0.357 \n", - "26 1 -0.737 1 -0.357 \n", - "27 0 -0.651 1 -0.357 \n", - "28 1 -0.737 1 -0.357 \n", - "29 0 -0.651 1 -0.357 \n", - "30 0 -0.651 0 -1.204 \n", - "31 0 -0.651 1 -0.357 \n", - "32 0 -0.651 1 -0.357 \n", - "33 1 -0.737 0 -1.204 \n", - "36 0 -0.651 0 -1.204 \n", - "37 0 -0.651 1 -0.357 \n", - "38 0 -0.651 1 -0.357 \n", - "40 0 -0.651 1 -0.357 \n", - "47 1 -0.737 0 -1.204 \n", - "51 0 -0.651 1 -0.357 \n", - "52 0 -0.651 1 -0.357 \n", - "53 0 -0.651 1 -0.357 \n", - "54 0 -0.651 1 -0.357 \n", - "57 0 -0.651 1 -0.357 \n", - "59 1 -0.737 0 -1.204 \n", - "60 1 -0.737 1 -0.357 \n", - "61 1 -0.737 0 -1.204 \n", - "67 0 -0.651 1 -0.357 \n", - "68 1 -0.737 1 -0.357 \n", - "73 0 -0.651 0 -1.204 \n", - "74 0 -0.651 1 -0.357 \n", - "75 0 -0.651 1 -0.357 \n", - "77 0 -0.651 0 -1.204 \n", - "92 0 -0.651 0 -1.204 \n", - "93 0 -0.651 1 -0.357 \n", - "95 0 -0.651 0 -1.204 \n", - "96 1 -0.737 0 -1.204 \n", - "99 0 -0.651 1 -0.357 \n", - "114 0 -0.651 0 -1.204 \n", - "120 0 -0.651 0 -1.204 \n", - "\n", - " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", - "0 0 -0.651 0.0 1 1 2 \n", - "1 1 -0.737 0.0 1 2 3 \n", - "2 0 -0.651 0.0 2 0 2 \n", - "3 0 -0.651 0.0 2 1 3 \n", - "4 0 -0.651 0.0 2 0 2 \n", - "5 0 -0.651 0.0 2 1 3 \n", - "9 0 -0.651 0.0 3 0 3 \n", - "13 0 -0.651 -100000000.0 0 0 0 \n", - "14 0 -0.651 -100000000.0 0 1 1 \n", - "15 1 -0.737 -100000000.0 0 1 1 \n", - "16 0 -0.651 -100000000.0 0 1 1 \n", - "17 0 -0.651 -100000000.0 0 2 2 \n", - "18 1 -0.737 -100000000.0 0 2 2 \n", - "19 1 -0.737 -100000000.0 0 2 2 \n", - "20 1 -0.737 -100000000.0 0 3 3 \n", - "21 0 -0.651 -100000000.0 1 0 1 \n", - "22 0 -0.651 -100000000.0 1 1 2 \n", - "23 0 -0.651 -100000000.0 1 1 2 \n", - "24 1 -0.737 -100000000.0 1 1 2 \n", - "25 1 -0.737 -100000000.0 1 2 3 \n", - "26 0 -0.651 -100000000.0 1 2 3 \n", - "27 1 -0.737 -100000000.0 1 2 3 \n", - "28 1 -0.737 -100000000.0 1 3 4 \n", - "29 0 -0.651 -100000000.0 1 0 1 \n", - "30 0 -0.651 -100000000.0 1 0 1 \n", - "31 0 -0.651 -100000000.0 1 0 1 \n", - "32 1 -0.737 -100000000.0 1 1 2 \n", - "33 0 -0.651 -100000000.0 1 1 2 \n", - "36 0 -0.651 -100000000.0 1 1 2 \n", - "37 0 -0.651 -100000000.0 1 1 2 \n", - "38 1 -0.737 -100000000.0 1 1 2 \n", - "40 1 -0.737 -100000000.0 1 2 3 \n", - "47 0 -0.651 -100000000.0 1 2 3 \n", - "51 0 -0.651 -100000000.0 2 0 2 \n", - "52 0 -0.651 -100000000.0 2 0 2 \n", - "53 1 -0.737 -100000000.0 2 1 3 \n", - "54 0 -0.651 -100000000.0 2 1 3 \n", - "57 1 -0.737 -100000000.0 2 1 3 \n", - "59 0 -0.651 -100000000.0 2 1 3 \n", - "60 0 -0.651 -100000000.0 2 1 3 \n", - "61 0 -0.651 -100000000.0 2 2 4 \n", - "67 1 -0.737 -100000000.0 2 2 4 \n", - "68 1 -0.737 -100000000.0 2 2 4 \n", - "73 0 -0.651 -100000000.0 2 0 2 \n", - "74 0 -0.651 -100000000.0 2 0 2 \n", - "75 1 -0.737 -100000000.0 2 1 3 \n", - "77 0 -0.651 -100000000.0 2 1 3 \n", - "92 0 -0.651 -100000000.0 3 0 3 \n", - "93 0 -0.651 -100000000.0 3 0 3 \n", - "95 0 -0.651 -100000000.0 3 1 4 \n", - "96 0 -0.651 -100000000.0 3 1 4 \n", - "99 1 -0.737 -100000000.0 3 1 4 \n", - "114 0 -0.651 -100000000.0 3 0 3 \n", - "120 0 -0.651 -100000000.0 4 0 4 \n", - "\n", - " sum_lp \n", - "0 -4.520000e+00 \n", - "1 -4.606000e+00 \n", - "2 -4.713000e+00 \n", - "3 -4.799000e+00 \n", - "4 -5.281000e+00 \n", - "5 -5.367000e+00 \n", - "9 -5.560000e+00 \n", - "13 -1.000000e+08 \n", - "14 -1.000000e+08 \n", - "15 -1.000000e+08 \n", - "16 -1.000000e+08 \n", - "17 -1.000000e+08 \n", - "18 -1.000000e+08 \n", - "19 -1.000000e+08 \n", - "20 -1.000000e+08 \n", - "21 -1.000000e+08 \n", - "22 -1.000000e+08 \n", - "23 -1.000000e+08 \n", - "24 -1.000000e+08 \n", - "25 -1.000000e+08 \n", - "26 -1.000000e+08 \n", - "27 -1.000000e+08 \n", - "28 -1.000000e+08 \n", - "29 -1.000000e+08 \n", - "30 -1.000000e+08 \n", - "31 -1.000000e+08 \n", - "32 -1.000000e+08 \n", - "33 -1.000000e+08 \n", - "36 -1.000000e+08 \n", - "37 -1.000000e+08 \n", - "38 -1.000000e+08 \n", - "40 -1.000000e+08 \n", - "47 -1.000000e+08 \n", - "51 -1.000000e+08 \n", - "52 -1.000000e+08 \n", - "53 -1.000000e+08 \n", - "54 -1.000000e+08 \n", - "57 -1.000000e+08 \n", - "59 -1.000000e+08 \n", - "60 -1.000000e+08 \n", - "61 -1.000000e+08 \n", - "67 -1.000000e+08 \n", - "68 -1.000000e+08 \n", - "73 -1.000000e+08 \n", - "74 -1.000000e+08 \n", - "75 -1.000000e+08 \n", - "77 -1.000000e+08 \n", - "92 -1.000000e+08 \n", - "93 -1.000000e+08 \n", - "95 -1.000000e+08 \n", - "96 -1.000000e+08 \n", - "99 -1.000000e+08 \n", - "114 -1.000000e+08 \n", - "120 -1.000000e+08 " + " wpr_bill_hits wlp_bill_hits clp int wpr changes sum_lp \n", + "0 0 -0.362 -100000000.0 0 0 0 -1.000000e+08 \n", + "1 1 -1.191 -100000000.0 0 1 1 -1.000000e+08 \n", + "2 0 -0.362 -100000000.0 1 0 1 -1.000000e+08 \n", + "3 0 -0.362 -100000000.0 1 0 1 -1.000000e+08 \n", + "4 1 -1.191 -100000000.0 1 1 2 -1.000000e+08 \n", + "6 0 -0.362 -100000000.0 2 0 2 -1.000000e+08 " ] }, - "execution_count": 22, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -2972,7 +1770,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -2998,18 +1796,10 @@ " \n", " epr_sally_throws\n", " elp_sally_throws\n", - " apr_sally_hits\n", - " alp_sally_hits\n", - " wpr_sally_hits\n", - " wlp_sally_hits\n", " apr_bill_hits\n", " alp_bill_hits\n", " wpr_bill_hits\n", " wlp_bill_hits\n", - " apr_bill_throws\n", - " alp_bill_throws\n", - " wpr_bill_throws\n", - " wlp_bill_throws\n", " clp\n", " int\n", " wpr\n", @@ -3019,21 +1809,13 @@ " \n", " \n", " \n", - " 23\n", - " 0\n", - " -0.842\n", - " 1\n", - " -0.357\n", + " 4\n", " 0\n", - " -0.651\n", + " -1.197\n", " 1\n", " -0.357\n", " 1\n", - " -0.737\n", - " 1\n", - " -0.357\n", - " 0\n", - " -0.651\n", + " -1.191\n", " -100000000.0\n", " 1\n", " 1\n", @@ -3045,23 +1827,14 @@ "" ], "text/plain": [ - " epr_sally_throws elp_sally_throws apr_sally_hits alp_sally_hits \\\n", - "23 0 -0.842 1 -0.357 \n", + " epr_sally_throws elp_sally_throws apr_bill_hits alp_bill_hits \\\n", + "4 0 -1.197 1 -0.357 \n", "\n", - " wpr_sally_hits wlp_sally_hits apr_bill_hits alp_bill_hits \\\n", - "23 0 -0.651 1 -0.357 \n", - "\n", - " wpr_bill_hits wlp_bill_hits apr_bill_throws alp_bill_throws \\\n", - "23 1 -0.737 1 -0.357 \n", - "\n", - " wpr_bill_throws wlp_bill_throws clp int wpr changes \\\n", - "23 0 -0.651 -100000000.0 1 1 2 \n", - "\n", - " sum_lp \n", - "23 -1.000000e+08 " + " wpr_bill_hits wlp_bill_hits clp int wpr changes sum_lp \n", + "4 1 -1.191 -100000000.0 1 1 2 -1.000000e+08 " ] }, - "execution_count": 23, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -3069,77 +1842,42 @@ "source": [ "#this is worrying, should be on top with clp == 0\n", "\n", - "tab.query(\"epr_sally_throws == 0 & apr_sally_hits == 1 & wpr_sally_hits == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1 & apr_bill_throws == 1 & wpr_bill_throws == 0\")" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'type': 'sample',\n", - " 'name': 'consequent_differs_binary',\n", - " 'fn': MaskedDistribution(),\n", - " 'is_observed': True,\n", - " 'args': (),\n", - " 'kwargs': {},\n", - " 'value': tensor([[[[[False, False, False, False, False, False, False, False, False,\n", - " False]]]]]),\n", - " 'infer': {'_deterministic': True},\n", - " 'scale': 1.0,\n", - " 'mask': None,\n", - " 'cond_indep_stack': (CondIndepStackFrame(name='runs', dim=-1, size=10, counter=0),),\n", - " 'done': True,\n", - " 'stop': False,\n", - " 'continuation': None}" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "stones_sallyHPR.trace.trace.nodes[\"consequent_differs_binary\"]" + "tab.query(\"epr_sally_throws == 0 & apr_bill_hits == 1 & wpr_bill_hits == 1\")" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'obs_sally_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", - " 'int_sally_throws': [1.0, 1.0, 1.0, 1.0, 0.0],\n", - " 'epr_sally_throws': [1, 1, 1, 1, 0],\n", - " 'elp_sally_throws': [-0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -0.35953617095947266,\n", - " -1.1973283290863037],\n", - " 'obs_bill_throws': [1.0, 1.0, 1.0, 1.0, 1.0],\n", - " 'int_bill_throws': [1.0, 0.0, 1.0, 1.0, 1.0],\n", - " 'apr_bill_throws': [1, 0, 1, 1, 1],\n", - " 'alp_bill_throws': [-0.3566749691963196,\n", + " 'int_sally_throws': [1.0, 1.0, 1.0, 0.0, 1.0],\n", + " 'epr_sally_throws': [1, 1, 1, 0, 1],\n", + " 'elp_sally_throws': [-0.35953614115715027,\n", + " -0.35953614115715027,\n", + " -0.35953614115715027,\n", + " -1.1973283290863037,\n", + " -0.35953614115715027],\n", + " 'obs_bill_hits': [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0]],\n", + " 'int_bill_hits': [[0.0, 1.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 1.0, 0.0]],\n", + " 'apr_bill_hits': [1, 0, 0, 0, 1],\n", + " 'alp_bill_hits': [-0.3566749691963196,\n", + " -1.2039728164672852,\n", + " -1.2039728164672852,\n", " -1.2039728164672852,\n", - " -0.3566749691963196,\n", - " -0.3566749691963196,\n", " -0.3566749691963196],\n", - " 'obs_bill_hits': [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]],\n", - " 'wpr_bill_hits': [0, 1, 0, 0, 0],\n", + " 'wpr_bill_hits': [0, 1, 1, 0, 0],\n", " 'wlp_bill_hits': [-0.3624056577682495,\n", " -1.1907275915145874,\n", - " -0.3624056577682495,\n", + " -1.1907275915145874,\n", " -0.3624056577682495,\n", " -0.3624056577682495]}" ] }, - "execution_count": 39, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -3470,7 +2208,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.6" }, "orig_nbformat": 4 }, From 609ed2c5338f8acd2a4b0e9aeff78385f440216f Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 25 Aug 2023 17:06:57 +0200 Subject: [PATCH 13/13] PartOfCause --- docs/source/responsibility_exploration.ipynb | 104 ++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/docs/source/responsibility_exploration.ipynb b/docs/source/responsibility_exploration.ipynb index 89cd252a..e5bcbe02 100644 --- a/docs/source/responsibility_exploration.ipynb +++ b/docs/source/responsibility_exploration.ipynb @@ -91,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -111,7 +111,8 @@ "\n", "from chirho.indexed.ops import IndexSet, gather, indices_of, scatter\n", "from chirho.interventional.handlers import do\n", - "from chirho.counterfactual.ops import preempt, intervene\n", + "from chirho.counterfactual.ops import preempt, split\n", + "from chirho.interventional.ops import intervene, Intervention\n", "from chirho.counterfactual.handlers.counterfactual import MultiWorldCounterfactual, Preemptions\n", "from chirho.observational.handlers import condition\n" ] @@ -146,6 +147,19 @@ " msg[\"args\"] = (obs, acts, case)\n", " msg[\"stop\"] = True\n", "\n", + " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", + " try:\n", + " action = self.actions[msg[\"name\"]]\n", + " except KeyError:\n", + " return\n", + " msg[\"value\"] = preempt(\n", + " msg[\"value\"],\n", + " (action,),\n", + " None,\n", + " event_dim=len(msg[\"fn\"].event_shape),\n", + " name=msg[\"name\"],\n", + " )\n", + "\n", " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", " with pyro.poutine.messenger.block_messengers(\n", " lambda m : (isinstance(m, Preemptions) and (m is not self))\n", @@ -153,6 +167,84 @@ " super()._pyro_post_sample(msg) " ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,\n", + "# antecedents = [node]) for\n", + "# node in self.evaluated_node_counterfactual.keys()}\n", + "\n", + "# with do(actions=self.evaluated_node_counterfactual):\n", + "# with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", + "# prefix = \"__evaluated_split_\"):\n", + "\n", + "class PartOfCause(pyro.poutine.messenger.Messenger):\n", + "\n", + " def __init__(self, evaluated_node_counterfactual: Dict[str, Intervention[torch.Tensor]],\n", + " bias: float = 0.0, prefix: str = \"__split_\") -> None:\n", + " \n", + " self.bias = bias\n", + " self.prefix = prefix\n", + " self.evaluated_node_counterfactual = evaluated_node_counterfactual\n", + " self.evaluated_node_preemptions = {node: functools.partial(self.preempt_with_factual,\n", + " antecedents = [node]) for\n", + " node in self.evaluated_node_counterfactual.keys()}\n", + " super().__init__()\n", + "\n", + " @staticmethod \n", + " def preempt_with_factual(value: torch.Tensor, *,\n", + " antecedents: List[str] = None, event_dim: int = 0):\n", + " \n", + " if antecedents is None:\n", + " antecedents = []\n", + "\n", + " antecedents = [a for a in antecedents if a in indices_of(value, event_dim=event_dim)]\n", + "\n", + " factual_value = gather(value, IndexSet(**{antecedent: {0} for antecedent in antecedents}),\n", + " event_dim=event_dim)\n", + " \n", + " return scatter({\n", + " IndexSet(**{antecedent: {0} for antecedent in antecedents}): factual_value,\n", + " IndexSet(**{antecedent: {1} for antecedent in antecedents}): factual_value,\n", + " }, event_dim=event_dim)\n", + " \n", + " \n", + " \n", + "\n", + " def _pyro_post_sample(self, msg: Dict[str, Any]) -> None:\n", + " if msg[\"name\"] not in self.evaluated_node_counterfactual:\n", + " return \n", + " \n", + " with pyro.poutine.messenger.block_messengers(\n", + " lambda m : (isinstance(m, Preemptions) and (m is not self))\n", + " ):\n", + " \n", + " msg['value'] = split(msg['value'], (self.evaluated_node_counterfactual[msg[\"name\"]],),\n", + " event_dim=len(msg[\"fn\"].event_shape), name=msg[\"name\"])\n", + " \n", + " weights = torch.tensor([.5-self.bias, .5+self.bias], device = msg['value'].device)\n", + " case_dist = pyro.distributions.Categorical(weights)\n", + " case = pyro.sample(f\"{self.prefix}{msg['name']}\", case_dist)\n", + " \n", + "\n", + " msg[\"value\"] = preempt(\n", + " msg[\"value\"],\n", + " (self.evaluated_node_preemptions[msg[\"name\"]],),\n", + " case,\n", + " event_dim=len(msg[\"fn\"].event_shape),\n", + " name=f\"{self.prefix}{msg['name']}\",\n", + " )\n", + "\n", + " \n", + "\n", + " \n", + "\n", + " \n" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -224,6 +316,8 @@ " \n", " return w\n", "\n", + "\n", + "\n", " @staticmethod \n", " def preempt_with_factual(value: torch.Tensor, *,\n", " antecedents: List[str] = None, event_dim: int = 0):\n", @@ -245,14 +339,20 @@ " def __call__(self, *args, **kwargs):\n", " print(\"Preemption biases used (upper) - t:\",.5+ self.bias_t, \", n:\", .5 + self.bias_n, \", w:\", .5 + self.bias_w, \".\")\n", " with MultiWorldCounterfactual():\n", + "\n", " with BiasedPreemptions(actions = self.witness_preemptions, weights = torch.tensor([.5 + self.bias_w, .5-self.bias_w]),\n", " prefix = \"__witness_split_\"):\n", + " \n", + "\n", " with do(actions=self.evaluated_node_counterfactual):\n", " with BiasedPreemptions(actions = self.evaluated_node_preemptions, weights = torch.tensor([.5-self.bias_n, .5+self.bias_n]),\n", " prefix = \"__evaluated_split_\"):\n", + " \n", + "\n", " with do(actions=self.treatment_candidates):\n", " with BiasedPreemptions(actions = self.treatment_preemptions, weights = torch.tensor([.5-self.bias_t, .5+self.bias_t]),\n", " prefix = \"__treatment_split_\"):\n", + " \n", " with condition(data={k: torch.as_tensor(v) for k, v in self.observations.items()}):\n", " with pyro.poutine.trace() as self.trace:\n", " self.run = self.model(*args, **kwargs)\n",