diff --git a/code/centered_eight/centered_eight.ipynb b/code/centered_eight/centered_eight.ipynb index 4f26eb7..ffa676c 100644 --- a/code/centered_eight/centered_eight.ipynb +++ b/code/centered_eight/centered_eight.ipynb @@ -13,17 +13,15 @@ "metadata": {}, "outputs": [], "source": [ - "# Loading libraries\n", + "import pathlib\n", "import arviz as az\n", - "import pymc3 as pm\n", - "import numpy as np\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" + "import pymc as pm\n", + "import numpy as np" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +29,7 @@ "chains = 4\n", "\n", "J = 8\n", - "y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", + "scores = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])\n", "schools = np.array(\n", " [\n", @@ -49,48 +47,110 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "Sampling: [mu, obs, tau, theta]\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [theta, tau, mu]\n", - "Sampling 4 chains, 56 divergences: 100%|███████████████████████████████████████| 4000/4000 [00:07<00:00, 512.82draws/s]\n", - "There were 7 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The acceptance probability does not match the target. It is 0.7173789646779772, but should be close to 0.8. Try to increase the number of tuning steps.\n", - "There were 25 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The acceptance probability does not match the target. It is 0.6336897568153758, but should be close to 0.8. Try to increase the number of tuning steps.\n", - "There were 4 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "There were 20 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The acceptance probability does not match the target. It is 0.6939536410310974, but should be close to 0.8. Try to increase the number of tuning steps.\n", - "The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n", - "The estimated number of effective samples is smaller than 200 for some parameters.\n", - "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 2343.60it/s]\n" + "Multiprocess sampling (4 chains in 2 jobs)\n", + "NUTS: [mu, tau, theta]\n" ] - } - ], - "source": [ - "with pm.Model() as centered_eight:\n", - " mu = pm.Normal(\"mu\", mu=0, sd=5)\n", - " tau = pm.HalfCauchy(\"tau\", beta=5)\n", - " theta = pm.Normal(\"theta\", mu=mu, sd=tau, shape=J)\n", - " obs = pm.Normal(\"obs\", mu=theta, sd=sigma, observed=y)\n", - "\n", - " prior = pm.sample_prior_predictive()\n", - " centered_eight_trace = pm.sample(draws, chains=chains)\n", - " posterior_predictive = pm.sample_posterior_predictive(centered_eight_trace)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [6000/6000 00:19<00:00 Sampling 4 chains, 15 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 51 seconds.\n", + "Sampling: [obs]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [2000/2000 00:00<00:00]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -99,11 +159,11 @@ "
\n", "
arviz.InferenceData
\n", "
\n", - "
    \n", + "
      \n", " \n", "
    • \n", - " \n", - " \n", + " \n", + " \n", "
      \n", "
      \n", "
        \n", @@ -150,7 +210,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -422,7 +482,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -464,76 +525,78 @@ " * draw (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", - " mu (chain, draw) float64 1.262 1.808 0.4999 ... -3.844 -1.01 -1.406\n", - " theta (chain, draw, school) float64 -0.3224 1.191 2.575 ... 6.112 -3.574\n", - " tau (chain, draw) float64 2.393 3.606 3.46 7.343 ... 5.343 6.199 8.825\n", + " mu (chain, draw) float64 6.438 3.713 4.46 8.17 ... -2.423 -0.5905 4.55\n", + " theta (chain, draw, school) float64 5.455 -0.8535 3.493 ... 6.798 0.4667\n", + " tau (chain, draw) float64 3.155 3.783 3.178 6.008 ... 3.654 3.234 6.417\n", "Attributes:\n", - " created_at: 2020-07-24T15:58:37.667622\n", - " arviz_version: 0.9.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.8
      • tau
        (chain, draw)
        float64
        3.155 3.783 3.178 ... 3.234 6.417
        array([[ 3.15475326,  3.78346627,  3.17791785, ...,  8.19147994,\n",
        +       "         8.20533847,  2.04024261],\n",
        +       "       [ 1.89369534,  2.66553171,  2.09548752, ..., 11.28039308,\n",
        +       "         7.41875925,  4.85740746],\n",
        +       "       [ 4.18072998,  7.71290966,  6.30393371, ..., 14.26822882,\n",
        +       "         8.02070217, 14.01415341],\n",
        +       "       [ 2.86524509,  1.70655648,  2.45879685, ...,  3.65393223,\n",
        +       "         3.23385673,  6.41657985]])
    • created_at :
      2022-10-12T09:58:50.551909
      arviz_version :
      0.13.0.dev0
      inference_library :
      pymc
      inference_library_version :
      4.2.2+7.g8239daa7
      sampling_time :
      51.28652000427246
      tuning_steps :
      1000

    \n", "
\n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -580,7 +643,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -852,7 +915,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -894,60 +958,60 @@ " * draw (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", - " obs (chain, draw, school) float64 -11.12 1.5 12.39 ... -2.56 -28.07\n", + " obs (chain, draw, school) float64 -6.003 -0.4622 ... 2.731 -3.391\n", "Attributes:\n", - " created_at: 2020-07-24T15:58:37.813592\n", - " arviz_version: 0.9.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.8
  • created_at :
    2022-10-12T09:58:55.044789
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -994,7 +1058,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -1266,7 +1330,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -1308,60 +1373,60 @@ " * draw (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", - " obs (chain, draw, school) float64 -5.41 -3.453 -3.752 ... -3.928 -4.184\n", + " obs (chain, draw, school) float64 -4.757 -3.613 ... -3.849 -4.015\n", "Attributes:\n", - " created_at: 2020-07-24T15:58:37.811611\n", - " arviz_version: 0.9.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.8
  • created_at :
    2022-10-12T09:58:50.807156
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -1408,7 +1473,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -1680,7 +1745,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -1716,91 +1782,156 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:           (chain: 4, draw: 500)\n",
      +       "Dimensions:                (chain: 4, draw: 500, warning_dim_0: 1)\n",
              "Coordinates:\n",
      -       "  * chain             (chain) int32 0 1 2 3\n",
      -       "  * draw              (draw) int32 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
      -       "Data variables:\n",
      -       "    energy            (chain, draw) float64 57.35 57.39 59.79 ... 69.22 66.6\n",
      -       "    diverging         (chain, draw) bool False False False ... False False False\n",
      -       "    energy_error      (chain, draw) float64 0.09743 0.003592 ... 0.04705\n",
      -       "    step_size         (chain, draw) float64 0.1929 0.1929 ... 0.3387 0.3387\n",
      -       "    max_energy_error  (chain, draw) float64 0.3468 0.9223 ... -0.04375 -0.06971\n",
      -       "    lp                (chain, draw) float64 -53.06 -53.59 ... -61.78 -62.33\n",
      -       "    step_size_bar     (chain, draw) float64 0.2378 0.2378 ... 0.2673 0.2673\n",
      -       "    depth             (chain, draw) int64 4 4 5 4 5 4 4 4 4 ... 4 5 5 4 4 4 4 4\n",
      -       "    mean_tree_accept  (chain, draw) float64 0.8429 0.804 ... 0.9954 0.9692\n",
      -       "    tree_size         (chain, draw) float64 15.0 15.0 31.0 ... 15.0 15.0 15.0\n",
      +       "  * chain                  (chain) int32 0 1 2 3\n",
      +       "  * draw                   (draw) int32 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
      +       "  * warning_dim_0          (warning_dim_0) int32 0\n",
      +       "Data variables: (12/18)\n",
      +       "    step_size_bar          (chain, draw) float64 0.2456 0.2456 ... 0.2822 0.2822\n",
      +       "    n_steps                (chain, draw) float64 7.0 7.0 7.0 ... 15.0 7.0 15.0\n",
      +       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    lp                     (chain, draw) float64 -57.12 -53.18 ... -53.89 -58.03\n",
      +       "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    reached_max_treedepth  (chain, draw) bool False False False ... False False\n",
      +       "    ...                     ...\n",
      +       "    warning                (chain, draw, warning_dim_0) object None ... None\n",
      +       "    energy                 (chain, draw) float64 60.48 58.94 59.19 ... 61.4 60.7\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    energy_error           (chain, draw) float64 0.1657 -0.244 ... 0.5359\n",
      +       "    acceptance_rate        (chain, draw) float64 0.8868 1.0 ... 0.9868 0.5437\n",
      +       "    process_time_diff      (chain, draw) float64 0.0 0.0 0.0 ... 0.0 0.01562 0.0\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T15:58:37.672590\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
    • perf_counter_start
      (chain, draw)
      float64
      39.7 39.71 39.71 ... 26.29 26.29
      array([[39.7038528, 39.7072551, 39.710123 , ..., 42.2710844, 42.7336796,\n",
      +       "        42.7462516],\n",
      +       "       [31.6845042, 31.6881609, 31.6898322, ..., 34.7179709, 34.7279999,\n",
      +       "        34.7384898],\n",
      +       "       [32.8140127, 32.8193974, 32.8228105, ..., 34.8907805, 34.8953685,\n",
      +       "        34.8979765],\n",
      +       "       [24.6782131, 24.6818946, 24.6837041, ..., 26.2858427, 26.2910588,\n",
      +       "        26.2946144]])
    • step_size
      (chain, draw)
      float64
      0.2276 0.2276 ... 0.3893 0.3893
      array([[0.22761111, 0.22761111, 0.22761111, ..., 0.22761111, 0.22761111,\n",
      +       "        0.22761111],\n",
      +       "       [0.15808089, 0.15808089, 0.15808089, ..., 0.15808089, 0.15808089,\n",
      +       "        0.15808089],\n",
      +       "       [0.3686671 , 0.3686671 , 0.3686671 , ..., 0.3686671 , 0.3686671 ,\n",
      +       "        0.3686671 ],\n",
      +       "       [0.38927184, 0.38927184, 0.38927184, ..., 0.38927184, 0.38927184,\n",
      +       "        0.38927184]])
    • max_energy_error
      (chain, draw)
      float64
      0.2926 -0.244 ... -0.5895 29.55
      array([[ 2.92551164e-01, -2.43975166e-01,  2.88159485e-01, ...,\n",
      +       "         2.18616622e-01,  1.92673885e-01, -9.25260019e-01],\n",
      +       "       [ 1.72086021e+00,  1.34875443e+00, -8.12768358e-01, ...,\n",
      +       "         5.98687101e-01,  9.99166258e-02,  4.43614520e-01],\n",
      +       "       [-2.90753276e-01,  3.46758548e-01, -6.30542449e-01, ...,\n",
      +       "         2.35970929e+00,  7.80685142e-01, -4.98574407e-01],\n",
      +       "       [-1.11676154e+00,  3.27268370e+02,  1.07727496e+02, ...,\n",
      +       "         1.44885097e+00, -5.89538907e-01,  2.95480101e+01]])
    • perf_counter_diff
      (chain, draw)
      float64
      0.003042 0.002529 ... 0.002583
      array([[0.0030421, 0.0025291, 0.0025143, ..., 0.007314 , 0.0122639,\n",
      +       "        0.0090372],\n",
      +       "       [0.0033026, 0.0013466, 0.0052764, ..., 0.009622 , 0.0101186,\n",
      +       "        0.0052873],\n",
      +       "       [0.0051658, 0.0031416, 0.0041322, ..., 0.0043331, 0.0024177,\n",
      +       "        0.0023986],\n",
      +       "       [0.003448 , 0.0015753, 0.0031503, ..., 0.0048967, 0.0033298,\n",
      +       "        0.0025834]])
    • index_in_trajectory
      (chain, draw)
      int64
      -4 -3 2 6 -6 -4 ... 3 -1 -5 -3 4 7
      array([[ -4,  -3,   2, ...,   8,  -5,  11],\n",
      +       "       [  1,  -2,  -2, ...,   4,  -3,  12],\n",
      +       "       [  4,  -2, -12, ...,  -8,  -6,   3],\n",
      +       "       [ -8,   2,   2, ...,  -3,   4,   7]], dtype=int64)
    • tree_depth
      (chain, draw)
      int64
      3 3 3 5 4 5 3 4 ... 4 3 4 2 4 4 3 4
      array([[3, 3, 3, ..., 4, 5, 5],\n",
      +       "       [3, 2, 4, ..., 4, 4, 4],\n",
      +       "       [5, 4, 5, ..., 4, 4, 4],\n",
      +       "       [4, 3, 4, ..., 4, 3, 4]], dtype=int64)
    • warning
      (chain, draw, warning_dim_0)
      object
      None None None ... None None None
      array([[[None],\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        ...,\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        [None]],\n",
      +       "\n",
      +       "       [[None],\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        ...,\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        [None]],\n",
      +       "\n",
      +       "       [[None],\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        ...,\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        [None]],\n",
      +       "\n",
      +       "       [[None],\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        ...,\n",
      +       "        [None],\n",
      +       "        [None],\n",
      +       "        [None]]], dtype=object)
    • energy
      (chain, draw)
      float64
      60.48 58.94 59.19 ... 61.4 60.7
      array([[60.4847877 , 58.94032905, 59.19468076, ..., 62.78117276,\n",
      +       "        69.31635828, 61.54467182],\n",
      +       "       [56.93537973, 57.78532454, 59.29574367, ..., 66.25289619,\n",
      +       "        68.19093626, 75.28309756],\n",
      +       "       [64.5673875 , 66.01343496, 68.46851294, ..., 68.38460504,\n",
      +       "        71.38278386, 72.14683375],\n",
      +       "       [63.64125714, 54.45052517, 52.60680157, ..., 60.71941523,\n",
      +       "        61.40191872, 60.70298822]])
    • diverging
      (chain, draw)
      bool
      False False False ... False False
      array([[False, False, False, ..., False, False, False],\n",
      +       "       [False, False, False, ..., False, False, False],\n",
      +       "       [False, False, False, ..., False, False, False],\n",
      +       "       [False, False, False, ..., False, False, False]])
    • energy_error
      (chain, draw)
      float64
      0.1657 -0.244 ... -0.5895 0.5359
      array([[ 0.16565028, -0.24397517, -0.07494055, ...,  0.10677853,\n",
      +       "         0.03820266, -0.53659508],\n",
      +       "       [ 1.00439851,  1.12586386,  0.3150327 , ..., -0.07401262,\n",
      +       "         0.09991663,  0.23167204],\n",
      +       "       [ 0.0180871 ,  0.2596273 ,  0.09448181, ...,  0.26458177,\n",
      +       "         0.25796784, -0.13860717],\n",
      +       "       [-0.2704943 , -0.73321445, -0.28304404, ...,  0.49669728,\n",
      +       "        -0.58953891,  0.53591373]])
    • acceptance_rate
      (chain, draw)
      float64
      0.8868 1.0 0.9575 ... 0.9868 0.5437
      array([[0.88675481, 1.        , 0.95746331, ..., 0.88140606, 0.93964454,\n",
      +       "        0.99007237],\n",
      +       "       [0.40076539, 0.39275394, 0.97554501, ..., 0.95809153, 0.97829063,\n",
      +       "        0.83876602],\n",
      +       "       [0.99360741, 0.82589775, 0.96369187, ..., 0.5675964 , 0.91406341,\n",
      +       "        0.97567385],\n",
      +       "       [0.98033282, 0.28794532, 0.614792  , ..., 0.61305512, 0.98678056,\n",
      +       "        0.54370112]])
    • process_time_diff
      (chain, draw)
      float64
      0.0 0.0 0.0 0.0 ... 0.0 0.01562 0.0
      array([[0.      , 0.      , 0.      , ..., 0.      , 0.03125 , 0.      ],\n",
      +       "       [0.015625, 0.      , 0.      , ..., 0.      , 0.015625, 0.      ],\n",
      +       "       [0.      , 0.      , 0.      , ..., 0.015625, 0.      , 0.      ],\n",
      +       "       [0.015625, 0.      , 0.      , ..., 0.      , 0.015625, 0.      ]])
  • created_at :
    2022-10-12T09:58:50.563910
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7
    sampling_time :
    51.28652000427246
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -1847,7 +1978,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -2119,7 +2250,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -2155,161 +2287,120 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:    (chain: 1, draw: 500, school: 8)\n",
      +       "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
              "Coordinates:\n",
      -       "  * chain      (chain) int32 0\n",
      -       "  * draw       (draw) int32 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * school     (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * chain    (chain) int32 0\n",
      +       "  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
      +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
              "Data variables:\n",
      -       "    tau_log__  (chain, draw) float64 0.3158 2.123 1.79 ... 1.178 1.796 1.372\n",
      -       "    tau        (chain, draw) float64 1.371 8.359 5.987 ... 3.249 6.027 3.945\n",
      -       "    mu         (chain, draw) float64 1.557 0.8172 2.985 ... 3.694 0.02784 -2.71\n",
      -       "    theta      (chain, draw, school) float64 1.638 1.044 2.263 ... 2.327 -7.168\n",
      +       "    theta    (chain, draw, school) float64 4.769 4.619 -8.079 ... -2.282 -6.272\n",
      +       "    tau      (chain, draw) float64 6.891 12.48 7.972 5.559 ... 6.112 50.19 1.671\n",
      +       "    mu       (chain, draw) float64 4.185 1.54 1.823 ... -1.644 1.584 -5.634\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T15:58:37.815636\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
  • created_at :
    2022-10-12T09:57:52.918242
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -2356,7 +2447,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -2628,7 +2719,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -2670,32 +2762,32 @@ " * draw (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n", " * school (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n", "Data variables:\n", - " obs (chain, draw, school) float64 17.09 -19.86 -21.88 ... 7.722 -1.332\n", + " obs (chain, draw, school) float64 9.56 21.66 12.41 ... -7.506 30.73\n", "Attributes:\n", - " created_at: 2020-07-24T15:58:37.817621\n", - " arviz_version: 0.9.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.8
  • created_at :
    2022-10-12T09:57:52.920241
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -2742,7 +2834,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -3014,7 +3106,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -3056,18 +3149,37 @@ "Data variables:\n", " obs (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n", "Attributes:\n", - " created_at: 2020-07-24T15:58:37.818623\n", - " arviz_version: 0.9.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.8

    \n", + " created_at: 2022-10-12T09:57:52.921252\n", + " arviz_version: 0.13.0.dev0\n", + " inference_library: pymc\n", + " inference_library_version: 4.2.2+7.g8239daa7
    \n", " \n", " \n", "
  • \n", " \n", - " \n", - " \n", - "
    <xarray.Dataset>\n",
    +       "Dimensions:  (school: 8)\n",
    +       "Coordinates:\n",
    +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
    +       "Data variables:\n",
    +       "    scores   (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
    +       "Attributes:\n",
    +       "    created_at:                 2022-10-12T09:57:52.922243\n",
    +       "    arviz_version:              0.13.0.dev0\n",
    +       "    inference_library:          pymc\n",
    +       "    inference_library_version:  4.2.2+7.g8239daa7

    \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " " ], "text/plain": [ @@ -3412,46 +3882,50 @@ "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", - "\t> observed_data" + "\t> observed_data\n", + "\t> constant_data" ] }, - "execution_count": 8, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "data = az.from_pymc3(\n", - " trace=centered_eight_trace,\n", - " prior=prior,\n", - " posterior_predictive=posterior_predictive,\n", - " model=centered_eight,\n", - " coords={\"school\": schools},\n", - " dims={\"theta\": [\"school\"], \"obs\": [\"school\"]},\n", - ")\n", + "with pm.Model(coords={\n", + " \"school\": schools,\n", + "}) as centered_eight:\n", + " mu = pm.Normal(\"mu\", mu=0, sigma=5)\n", + " tau = pm.HalfCauchy(\"tau\", beta=5)\n", + " theta = pm.Normal(\"theta\", mu=mu, sigma=tau, shape=J, dims=\"school\")\n", + " y_obs = pm.ConstantData(\"scores\", scores, dims=\"school\")\n", + " obs = pm.Normal(\"obs\", mu=theta, sigma=sigma, observed=y_obs, dims=\"school\")\n", "\n", - "data" + " idata = pm.sample_prior_predictive()\n", + " idata.extend(pm.sample(draws, chains=chains))\n", + " idata.extend(pm.sample_posterior_predictive(idata))\n", + "idata" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'centered_eight.nc'" + "WindowsPath('../../data/centered_eight.nc')" ] }, - "execution_count": 9, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Storing the model to .nc format\n", - "data.to_netcdf('centered_eight.nc')" + "idata.to_netcdf(pathlib.Path(\"..\", \"..\", \"data\", \"centered_eight.nc\"))" ] }, { @@ -3464,7 +3938,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.13 ('pmdev')", "language": "python", "name": "python3" }, @@ -3478,7 +3952,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "a3357ae232635c60236cc6fd817e8b6b9d1594a7c105bd2f3bd11f1b30a66c24" + } } }, "nbformat": 4, diff --git a/code/non_centered_eight/non_centered_eight.ipynb b/code/non_centered_eight/non_centered_eight.ipynb index 6bc0af7..51001ee 100644 --- a/code/non_centered_eight/non_centered_eight.ipynb +++ b/code/non_centered_eight/non_centered_eight.ipynb @@ -13,12 +13,10 @@ "metadata": {}, "outputs": [], "source": [ - "# Loading libraries\n", + "import pathlib\n", "import arviz as az\n", - "import pymc3 as pm\n", - "import numpy as np\n", - "import warnings\n", - "warnings.filterwarnings(\"ignore\")" + "import pymc as pm\n", + "import numpy as np" ] }, { @@ -31,7 +29,7 @@ "chains = 4\n", "\n", "J = 8\n", - "y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", + "scores = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])\n", "sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])\n", "schools = np.array(\n", " [\n", @@ -49,48 +47,110 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "Sampling: [mu, obs, tau, theta, theta_t]\n", "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [theta, theta_t, tau, mu]\n", - "Sampling 4 chains, 127 divergences: 100%|██████████████████████████████████████| 4000/4000 [00:09<00:00, 416.79draws/s]\n", - "There were 52 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The acceptance probability does not match the target. It is 0.4104186119686267, but should be close to 0.8. Try to increase the number of tuning steps.\n", - "There were 41 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The acceptance probability does not match the target. It is 0.6817166948956205, but should be close to 0.8. Try to increase the number of tuning steps.\n", - "There were 26 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "There were 8 divergences after tuning. Increase `target_accept` or reparameterize.\n", - "The rhat statistic is larger than 1.05 for some parameters. This indicates slight problems during sampling.\n", - "The estimated number of effective samples is smaller than 200 for some parameters.\n", - "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:00<00:00, 2405.47it/s]\n" + "Multiprocess sampling (4 chains in 2 jobs)\n", + "NUTS: [mu, tau, theta_t, theta]\n" ] - } - ], - "source": [ - "with pm.Model() as non_centered_eight:\n", - " mu = pm.Normal(\"mu\", mu=0, sd=5)\n", - " tau = pm.HalfCauchy(\"tau\", beta=5)\n", - " theta_tilde = pm.Normal(\"theta_t\", mu=0, sd=1, shape=J)\n", - " theta = pm.Normal(\"theta\", mu=mu, sd=tau, shape=J)\n", - " obs = pm.Normal(\"obs\", mu=theta, sd=sigma, observed=y)\n", - "\n", - " prior = pm.sample_prior_predictive()\n", - " non_centered_eight_trace = pm.sample(draws, chains=chains)\n", - " posterior_predictive = pm.sample_posterior_predictive(non_centered_eight_trace)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
    \n", + " \n", + " 100.00% [6000/6000 00:37<00:00 Sampling 4 chains, 49 divergences]\n", + "
    \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 94 seconds.\n", + "Sampling: [obs]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
    \n", + " \n", + " 100.00% [2000/2000 00:00<00:00]\n", + "
    \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -99,11 +159,11 @@ "
    \n", "
    arviz.InferenceData
    \n", "
    \n", - "
      \n", + "
        \n", " \n", "
      • \n", - " \n", - " \n", + " \n", + " \n", "
        \n", "
        \n", "
          \n", @@ -150,7 +210,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -422,7 +482,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -458,123 +519,126 @@ " fill: currentColor;\n", "}\n", "
          <xarray.Dataset>\n",
          -       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
          +       "Dimensions:      (chain: 4, draw: 500, school: 8, theta_dim_0: 8)\n",
                  "Coordinates:\n",
          -       "  * chain    (chain) int32 0 1 2 3\n",
          -       "  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
          -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
          +       "  * chain        (chain) int32 0 1 2 3\n",
          +       "  * draw         (draw) int32 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
          +       "  * school       (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
          +       "  * theta_dim_0  (theta_dim_0) int32 0 1 2 3 4 5 6 7\n",
                  "Data variables:\n",
          -       "    mu       (chain, draw) float64 6.79 5.731 4.465 3.535 ... 5.957 -1.642 4.704\n",
          -       "    theta_t  (chain, draw, school) float64 -1.514 1.738 -1.498 ... 1.1 0.1812\n",
          -       "    theta    (chain, draw, school) float64 16.4 9.137 2.87 ... 15.01 2.525\n",
          -       "    tau      (chain, draw) float64 3.931 4.349 3.913 2.099 ... 4.856 6.098 5.005\n",
          +       "    mu           (chain, draw) float64 8.491 8.185 -1.778 ... 0.2131 -0.9438\n",
          +       "    theta_t      (chain, draw, school) float64 -0.01047 -0.4168 ... 0.04088\n",
          +       "    theta        (chain, draw, theta_dim_0) float64 4.086 11.74 ... -1.689\n",
          +       "    tau          (chain, draw) float64 5.868 6.223 7.5 ... 4.117 3.276 1.282\n",
                  "Attributes:\n",
          -       "    created_at:                 2020-07-24T16:02:06.283176\n",
          -       "    arviz_version:              0.9.0\n",
          -       "    inference_library:          pymc3\n",
          -       "    inference_library_version:  3.8
        • tau
          (chain, draw)
          float64
          5.868 6.223 7.5 ... 3.276 1.282
          array([[ 5.8679342 ,  6.22333068,  7.49971373, ...,  3.10320571,\n",
          +       "         3.06593982,  5.65934737],\n",
          +       "       [ 1.12644612,  2.56678196,  6.62530348, ...,  6.7548095 ,\n",
          +       "        10.14834745,  7.79967145],\n",
          +       "       [ 2.65297436,  3.18686697,  3.73830155, ...,  2.87252546,\n",
          +       "         4.25396636,  2.03125639],\n",
          +       "       [ 6.13414944,  3.4716475 ,  9.54572839, ...,  4.11732407,\n",
          +       "         3.27584461,  1.28233781]])
      • created_at :
        2022-10-12T10:04:05.357340
        arviz_version :
        0.13.0.dev0
        inference_library :
        pymc
        inference_library_version :
        4.2.2+7.g8239daa7
        sampling_time :
        93.89331722259521
        tuning_steps :
        1000

      \n", "
    \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -621,7 +685,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -893,7 +957,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -929,66 +994,65 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
      +       "Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)\n",
              "Coordinates:\n",
      -       "  * chain    (chain) int32 0 1 2 3\n",
      -       "  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
      -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * chain      (chain) int32 0 1 2 3\n",
      +       "  * draw       (draw) int32 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * obs_dim_0  (obs_dim_0) int32 0 1 2 3 4 5 6 7\n",
              "Data variables:\n",
      -       "    obs      (chain, draw, school) float64 -10.73 4.706 18.01 ... -1.436 1.874\n",
      +       "    obs        (chain, draw, obs_dim_0) float64 3.283 16.96 ... -0.339 28.01\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.411835\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
  • created_at :
    2022-10-12T10:04:16.958157
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -1035,7 +1099,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -1307,7 +1371,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -1343,66 +1408,65 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:  (chain: 4, draw: 500, school: 8)\n",
      +       "Dimensions:    (chain: 4, draw: 500, obs_dim_0: 8)\n",
              "Coordinates:\n",
      -       "  * chain    (chain) int32 0 1 2 3\n",
      -       "  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
      -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * chain      (chain) int32 0 1 2 3\n",
      +       "  * draw       (draw) int32 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * obs_dim_0  (obs_dim_0) int32 0 1 2 3 4 5 6 7\n",
              "Data variables:\n",
      -       "    obs      (chain, draw, school) float64 -3.926 -3.228 ... -3.266 -3.948\n",
      +       "    obs        (chain, draw, obs_dim_0) float64 -4.898 -3.291 ... -4.843 -4.098\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.409825\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
  • created_at :
    2022-10-12T10:04:05.773403
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -1449,7 +1513,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -1721,7 +1785,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -1757,91 +1822,156 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:           (chain: 4, draw: 500)\n",
      +       "Dimensions:                (chain: 4, draw: 500, warning_dim_0: 1)\n",
              "Coordinates:\n",
      -       "  * chain             (chain) int32 0 1 2 3\n",
      -       "  * draw              (draw) int32 0 1 2 3 4 5 6 ... 493 494 495 496 497 498 499\n",
      -       "Data variables:\n",
      -       "    energy            (chain, draw) float64 75.88 75.52 73.19 ... 83.26 82.12\n",
      -       "    lp                (chain, draw) float64 -69.98 -64.94 ... -76.14 -70.44\n",
      -       "    diverging         (chain, draw) bool False False False ... False False False\n",
      -       "    tree_size         (chain, draw) float64 7.0 15.0 15.0 ... 31.0 31.0 15.0\n",
      -       "    depth             (chain, draw) int64 3 4 4 4 4 4 4 4 4 ... 4 4 4 4 4 5 5 4\n",
      -       "    step_size_bar     (chain, draw) float64 0.2395 0.2395 ... 0.2295 0.2295\n",
      -       "    step_size         (chain, draw) float64 0.3458 0.3458 ... 0.2013 0.2013\n",
      -       "    energy_error      (chain, draw) float64 -0.3258 -0.07222 ... 0.1261 -0.138\n",
      -       "    max_energy_error  (chain, draw) float64 0.5255 1.304 1.13 ... 0.2224 -0.1699\n",
      -       "    mean_tree_accept  (chain, draw) float64 0.9066 0.9396 0.8767 ... 0.9212 1.0\n",
      +       "  * chain                  (chain) int32 0 1 2 3\n",
      +       "  * draw                   (draw) int32 0 1 2 3 4 5 ... 494 495 496 497 498 499\n",
      +       "  * warning_dim_0          (warning_dim_0) int32 0\n",
      +       "Data variables: (12/18)\n",
      +       "    reached_max_treedepth  (chain, draw) bool False False False ... False False\n",
      +       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    lp                     (chain, draw) float64 -73.68 -72.15 ... -66.62 -61.77\n",
      +       "    warning                (chain, draw, warning_dim_0) object None ... None\n",
      +       "    energy_error           (chain, draw) float64 0.1046 -0.0208 ... -0.9957\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    ...                     ...\n",
      +       "    n_steps                (chain, draw) float64 15.0 15.0 15.0 ... 15.0 31.0\n",
      +       "    index_in_trajectory    (chain, draw) int64 7 -1 7 -4 11 ... 24 -8 13 6 -19\n",
      +       "    tree_depth             (chain, draw) int64 4 4 4 4 4 4 4 5 ... 4 5 5 5 5 4 5\n",
      +       "    step_size_bar          (chain, draw) float64 0.2228 0.2228 ... 0.1901 0.1901\n",
      +       "    perf_counter_start     (chain, draw) float64 68.42 68.43 ... 50.34 50.34\n",
      +       "    max_energy_error       (chain, draw) float64 0.1417 1.044 ... -0.9957\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.291204\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
    • acceptance_rate
      (chain, draw)
      float64
      0.9379 0.9508 ... 0.9733 0.9365
      array([[9.37925468e-01, 9.50766829e-01, 9.85269135e-01, ...,\n",
      +       "        8.58381777e-01, 9.80816014e-01, 8.33199292e-01],\n",
      +       "       [6.04336900e-65, 3.21143038e-01, 7.20882697e-01, ...,\n",
      +       "        9.77457430e-01, 9.22237423e-01, 9.79150951e-01],\n",
      +       "       [8.25465907e-01, 9.62276506e-01, 8.07298496e-01, ...,\n",
      +       "        8.91888034e-01, 8.49341800e-01, 9.10208208e-01],\n",
      +       "       [9.73775304e-01, 9.75197463e-01, 8.14491164e-01, ...,\n",
      +       "        9.82560778e-01, 9.73262555e-01, 9.36519359e-01]])
    • energy
      (chain, draw)
      float64
      78.15 83.13 86.76 ... 78.44 80.63
      array([[78.15489685, 83.12810597, 86.75582639, ..., 76.28973877,\n",
      +       "        72.03900284, 72.74829834],\n",
      +       "       [76.94764131, 68.06677293, 71.86066075, ..., 80.90386448,\n",
      +       "        86.40705324, 87.31683977],\n",
      +       "       [72.37764624, 74.84615038, 73.93984455, ..., 77.42517555,\n",
      +       "        80.33226351, 73.91492323],\n",
      +       "       [81.2466136 , 77.08417544, 86.09558459, ..., 76.17190404,\n",
      +       "        78.43987044, 80.62910484]])
    • process_time_diff
      (chain, draw)
      float64
      0.0 0.0 0.01562 ... 0.0 0.01562
      array([[0.      , 0.      , 0.015625, ..., 0.      , 0.015625, 0.      ],\n",
      +       "       [0.      , 0.      , 0.      , ..., 0.      , 0.015625, 0.      ],\n",
      +       "       [0.      , 0.      , 0.      , ..., 0.      , 0.015625, 0.      ],\n",
      +       "       [0.015625, 0.      , 0.015625, ..., 0.015625, 0.      , 0.015625]])
    • largest_eigval
      (chain, draw)
      float64
      nan nan nan nan ... nan nan nan nan
      array([[nan, nan, nan, ..., nan, nan, nan],\n",
      +       "       [nan, nan, nan, ..., nan, nan, nan],\n",
      +       "       [nan, nan, nan, ..., nan, nan, nan],\n",
      +       "       [nan, nan, nan, ..., nan, nan, nan]])
    • step_size
      (chain, draw)
      float64
      0.2607 0.2607 ... 0.2852 0.2852
      array([[0.26073707, 0.26073707, 0.26073707, ..., 0.26073707, 0.26073707,\n",
      +       "        0.26073707],\n",
      +       "       [0.15417116, 0.15417116, 0.15417116, ..., 0.15417116, 0.15417116,\n",
      +       "        0.15417116],\n",
      +       "       [0.17311246, 0.17311246, 0.17311246, ..., 0.17311246, 0.17311246,\n",
      +       "        0.17311246],\n",
      +       "       [0.28517361, 0.28517361, 0.28517361, ..., 0.28517361, 0.28517361,\n",
      +       "        0.28517361]])
    • perf_counter_diff
      (chain, draw)
      float64
      0.01012 0.01267 ... 0.01345
      array([[0.0101213, 0.0126721, 0.007963 , ..., 0.0052501, 0.0068283,\n",
      +       "        0.0052923],\n",
      +       "       [0.0008652, 0.007083 , 0.0034577, ..., 0.0039797, 0.005531 ,\n",
      +       "        0.005027 ],\n",
      +       "       [0.0035033, 0.004197 , 0.0033946, ..., 0.0028006, 0.0052307,\n",
      +       "        0.0045019],\n",
      +       "       [0.0083025, 0.0036725, 0.0075466, ..., 0.0139159, 0.0066084,\n",
      +       "        0.0134534]])
    • n_steps
      (chain, draw)
      float64
      15.0 15.0 15.0 ... 31.0 15.0 31.0
      array([[15., 15., 15., ..., 15., 31., 15.],\n",
      +       "       [ 1., 31., 15., ..., 15., 15., 15.],\n",
      +       "       [15., 15., 15., ...,  7., 15., 15.],\n",
      +       "       [31., 15., 31., ..., 31., 15., 31.]])
    • index_in_trajectory
      (chain, draw)
      int64
      7 -1 7 -4 11 -9 ... 24 -8 13 6 -19
      array([[  7,  -1,   7, ...,  -7,  27,  -4],\n",
      +       "       [  0,  -6,  -6, ...,   8,  12,  12],\n",
      +       "       [ -9,  13, -12, ...,  -7,  -5,  13],\n",
      +       "       [-18,  -7, -14, ...,  13,   6, -19]], dtype=int64)
    • tree_depth
      (chain, draw)
      int64
      4 4 4 4 4 4 4 5 ... 5 4 5 5 5 5 4 5
      array([[4, 4, 4, ..., 4, 5, 4],\n",
      +       "       [1, 5, 4, ..., 4, 4, 4],\n",
      +       "       [4, 4, 4, ..., 3, 4, 4],\n",
      +       "       [5, 4, 5, ..., 5, 4, 5]], dtype=int64)
    • step_size_bar
      (chain, draw)
      float64
      0.2228 0.2228 ... 0.1901 0.1901
      array([[0.22278064, 0.22278064, 0.22278064, ..., 0.22278064, 0.22278064,\n",
      +       "        0.22278064],\n",
      +       "       [0.21940501, 0.21940501, 0.21940501, ..., 0.21940501, 0.21940501,\n",
      +       "        0.21940501],\n",
      +       "       [0.2678179 , 0.2678179 , 0.2678179 , ..., 0.2678179 , 0.2678179 ,\n",
      +       "        0.2678179 ],\n",
      +       "       [0.190107  , 0.190107  , 0.190107  , ..., 0.190107  , 0.190107  ,\n",
      +       "        0.190107  ]])
    • perf_counter_start
      (chain, draw)
      float64
      68.42 68.43 68.44 ... 50.34 50.34
      array([[68.41524  , 68.4265393, 68.4395377, ..., 73.201897 , 73.2074558,\n",
      +       "        73.2145901],\n",
      +       "       [53.5475599, 53.5505434, 53.5578993, ..., 57.9020315, 57.9063038,\n",
      +       "        57.9122315],\n",
      +       "       [54.5895335, 54.593322 , 54.5978106, ..., 57.4393978, 57.4425156,\n",
      +       "        57.4483009],\n",
      +       "       [43.8126236, 43.8212469, 43.8252068, ..., 50.3229112, 50.3373333,\n",
      +       "        50.3444173]])
    • max_energy_error
      (chain, draw)
      float64
      0.1417 1.044 ... -0.4225 -0.9957
      array([[ 1.41747263e-01,  1.04362073e+00, -5.38911342e-02, ...,\n",
      +       "         4.41053707e-01, -5.41529925e-01,  1.49802290e+01],\n",
      +       "       [ 1.47869069e+02,  1.25415033e+01,  4.24994300e-01, ...,\n",
      +       "        -1.76456892e-01,  6.14174619e-01, -4.71332171e-01],\n",
      +       "       [ 1.30906492e+00, -1.15457186e+00,  5.66484734e-01, ...,\n",
      +       "         4.69185377e-01,  5.51302186e-01, -8.70687360e-01],\n",
      +       "       [-4.07278967e-01, -2.04084680e-01,  5.39089895e-01, ...,\n",
      +       "         1.40975003e-01, -4.22518564e-01, -9.95736363e-01]])
  • created_at :
    2022-10-12T10:04:05.376512
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7
    sampling_time :
    93.89331722259521
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -1888,7 +2018,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -2160,7 +2290,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -2196,174 +2327,134 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:    (chain: 1, draw: 500, school: 8)\n",
      +       "Dimensions:      (chain: 1, draw: 500, school: 8, theta_dim_0: 8)\n",
              "Coordinates:\n",
      -       "  * chain      (chain) int32 0\n",
      -       "  * draw       (draw) int32 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * school     (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * chain        (chain) int32 0\n",
      +       "  * draw         (draw) int32 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
      +       "  * school       (school) <U16 'Choate' 'Deerfield' ... 'Mt. Hermon'\n",
      +       "  * theta_dim_0  (theta_dim_0) int32 0 1 2 3 4 5 6 7\n",
              "Data variables:\n",
      -       "    mu         (chain, draw) float64 -0.7153 10.56 3.256 ... -5.293 -3.131 4.294\n",
      -       "    tau_log__  (chain, draw) float64 0.5034 -2.498 0.9147 ... 0.2578 0.8701 1.41\n",
      -       "    theta      (chain, draw, school) float64 2.238 -1.46 -3.312 ... 4.409 2.352\n",
      -       "    tau        (chain, draw) float64 1.654 0.08225 2.496 ... 1.294 2.387 4.097\n",
      -       "    theta_t    (chain, draw, school) float64 -1.308 -0.07714 ... 0.1931 0.1864\n",
      +       "    tau          (chain, draw) float64 2.494 6.88 3.25 3.46 ... 10.61 6.3 3.141\n",
      +       "    theta_t      (chain, draw, school) float64 1.248 -0.07138 ... 0.6023 0.6472\n",
      +       "    mu           (chain, draw) float64 2.021 6.311 2.41 ... 0.4196 -1.567 -1.855\n",
      +       "    theta        (chain, draw, theta_dim_0) float64 3.307 3.307 ... -3.564 2.77\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.413834\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
  • created_at :
    2022-10-12T10:02:07.611352
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -2410,7 +2501,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -2682,7 +2773,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -2718,38 +2810,37 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:  (chain: 1, draw: 500, school: 8)\n",
      +       "Dimensions:    (chain: 1, draw: 500, obs_dim_0: 8)\n",
              "Coordinates:\n",
      -       "  * chain    (chain) int32 0\n",
      -       "  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
      -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * chain      (chain) int32 0\n",
      +       "  * draw       (draw) int32 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * obs_dim_0  (obs_dim_0) int32 0 1 2 3 4 5 6 7\n",
              "Data variables:\n",
      -       "    obs      (chain, draw, school) float64 14.62 -8.063 -32.39 ... 12.56 28.55\n",
      +       "    obs        (chain, draw, obs_dim_0) float64 12.22 9.156 ... 11.48 11.51\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.415834\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8
  • created_at :
    2022-10-12T10:02:07.618182
    arviz_version :
    0.13.0.dev0
    inference_library :
    pymc
    inference_library_version :
    4.2.2+7.g8239daa7

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
      \n", @@ -2796,7 +2887,7 @@ "}\n", "\n", ".xr-wrap {\n", - " display: block;\n", + " display: block !important;\n", " min-width: 300px;\n", " max-width: 700px;\n", "}\n", @@ -3068,7 +3159,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -3104,24 +3196,42 @@ " fill: currentColor;\n", "}\n", "
      <xarray.Dataset>\n",
      -       "Dimensions:  (school: 8)\n",
      +       "Dimensions:    (obs_dim_0: 8)\n",
              "Coordinates:\n",
      -       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
      +       "  * obs_dim_0  (obs_dim_0) int32 0 1 2 3 4 5 6 7\n",
              "Data variables:\n",
      -       "    obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
      +       "    obs        (obs_dim_0) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
              "Attributes:\n",
      -       "    created_at:                 2020-07-24T16:02:06.416834\n",
      -       "    arviz_version:              0.9.0\n",
      -       "    inference_library:          pymc3\n",
      -       "    inference_library_version:  3.8

    \n", + " created_at: 2022-10-12T10:02:07.618182\n", + " arviz_version: 0.13.0.dev0\n", + " inference_library: pymc\n", + " inference_library_version: 4.2.2+7.g8239daa7
    \n", " \n", " \n", "
  • \n", " \n", - " \n", - " \n", - "
    <xarray.Dataset>\n",
    +       "Dimensions:  (school: 8)\n",
    +       "Coordinates:\n",
    +       "  * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'\n",
    +       "Data variables:\n",
    +       "    scores   (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0\n",
    +       "Attributes:\n",
    +       "    created_at:                 2022-10-12T10:02:07.618182\n",
    +       "    arviz_version:              0.13.0.dev0\n",
    +       "    inference_library:          pymc\n",
    +       "    inference_library_version:  4.2.2+7.g8239daa7

    \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " " ], "text/plain": [ @@ -3466,7 +3934,8 @@ "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", - "\t> observed_data" + "\t> observed_data\n", + "\t> constant_data" ] }, "execution_count": 4, @@ -3475,16 +3944,20 @@ } ], "source": [ - "data = az.from_pymc3(\n", - " trace=non_centered_eight_trace,\n", - " prior=prior,\n", - " posterior_predictive=posterior_predictive,\n", - " model=non_centered_eight,\n", - " coords={\"school\": schools},\n", - " dims={\"theta\": [\"school\"], \"obs\": [\"school\"], \"theta_t\": [\"school\"]},\n", - ")\n", + "with pm.Model(coords={\n", + " \"school\": schools,\n", + "}) as non_centered_eight:\n", + " mu = pm.Normal(\"mu\", mu=0, sigma=5)\n", + " tau = pm.HalfCauchy(\"tau\", beta=5)\n", + " theta_tilde = pm.Normal(\"theta_t\", mu=0, sigma=1, shape=J, dims=\"school\")\n", + " theta = pm.Normal(\"theta\", mu=mu, sigma=tau, shape=J)\n", + " y_obs = pm.ConstantData(\"scores\", scores, dims=\"school\")\n", + " obs = pm.Normal(\"obs\", mu=theta, sigma=sigma, observed=y_obs)\n", "\n", - "data" + " idata = pm.sample_prior_predictive()\n", + " idata.extend(pm.sample(draws, chains=chains))\n", + " idata.extend(pm.sample_posterior_predictive(idata))\n", + "idata" ] }, { @@ -3495,7 +3968,7 @@ { "data": { "text/plain": [ - "'non_centered_eight.nc'" + "WindowsPath('../../data/non_centered_eight.nc')" ] }, "execution_count": 5, @@ -3505,7 +3978,7 @@ ], "source": [ "# Storing the model to .nc format\n", - "data.to_netcdf('non_centered_eight.nc')" + "idata.to_netcdf(pathlib.Path(\"..\", \"..\", \"data\", \"non_centered_eight.nc\"))" ] }, { @@ -3518,7 +3991,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.13 ('pmdev')", "language": "python", "name": "python3" }, @@ -3532,7 +4005,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "a3357ae232635c60236cc6fd817e8b6b9d1594a7c105bd2f3bd11f1b30a66c24" + } } }, "nbformat": 4, diff --git a/data/centered_eight.nc b/data/centered_eight.nc index a22e183..3b556dc 100644 Binary files a/data/centered_eight.nc and b/data/centered_eight.nc differ diff --git a/data/non_centered_eight.nc b/data/non_centered_eight.nc index e48a348..cde9e3a 100644 Binary files a/data/non_centered_eight.nc and b/data/non_centered_eight.nc differ