From 0522d5ebff32671a4744af080430a944b6bcb914 Mon Sep 17 00:00:00 2001 From: PoorvaGarg Date: Tue, 27 Aug 2024 13:26:13 -0400 Subject: [PATCH] tweaks --- docs/source/explainable_sir.ipynb | 50 +++++++++++++++---------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/docs/source/explainable_sir.ipynb b/docs/source/explainable_sir.ipynb index c3a17e23..b0f6f5eb 100644 --- a/docs/source/explainable_sir.ipynb +++ b/docs/source/explainable_sir.ipynb @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -130,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -170,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -237,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -279,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -384,7 +384,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -431,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -589,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -651,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -697,7 +697,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -709,8 +709,8 @@ " mask_intervened &= trace.nodes[i][\"value\"] == v\n", "\n", " with mwc_imp:\n", - " mask_os_too_high = (gather(trace.nodes[\"mask\"][\"value\"], IndexSet(**{\"mask\": {0}, \"lockdown\": {0}})) == 1) & (gather(trace.nodes[\"lockdown\"][\"value\"], IndexSet(**{\"mask\": {0}, \"lockdown\": {0}})) == 1)\n", - " mask_intervened &= mask_os_too_high\n", + " mask_tensor = (gather(trace.nodes[\"mask\"][\"value\"], IndexSet(**{\"mask\": {0}, \"lockdown\": {0}})) == 1) & (gather(trace.nodes[\"lockdown\"][\"value\"], IndexSet(**{\"mask\": {0}, \"lockdown\": {0}})) == 1)\n", + " mask_intervened &= mask_tensor\n", "\n", " print(\n", " mask,\n", @@ -723,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -785,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -834,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -876,7 +876,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -907,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -992,7 +992,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +1023,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1108,7 +1108,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -1201,7 +1201,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1233,7 +1233,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -1306,7 +1306,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -1336,7 +1336,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 23, "metadata": {}, "outputs": [ {