From ec18b5d40691a697eebc2929cf2e0e860f643554 Mon Sep 17 00:00:00 2001 From: lucianopaz Date: Wed, 12 Jan 2022 15:05:43 +0100 Subject: [PATCH] Make nested models share coords with parents --- RELEASE-NOTES.md | 1 + pymc/model.py | 11 +++++++---- pymc/tests/test_model.py | 14 ++++++++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 6d8b086002d..2f69e8a4822 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -120,6 +120,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01 - New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`. - `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098). - Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169) +- Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344) - ... diff --git a/pymc/model.py b/pymc/model.py index fabf608f470..dc43ecc338a 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -628,10 +628,6 @@ def __init__( rng_seeder: Optional[Union[int, np.random.RandomState]] = None, ): self.name = name - self._coords = {} - self._RV_dims = {} - self._dim_lengths = {} - self.add_coords(coords) self.check_bounds = check_bounds if rng_seeder is None: @@ -654,6 +650,9 @@ def __init__( self.auto_deterministics = treelist(parent=self.parent.auto_deterministics) self.deterministics = treelist(parent=self.parent.deterministics) self.potentials = treelist(parent=self.parent.potentials) + self._coords = self.parent._coords + self._RV_dims = treedict(parent=self.parent._RV_dims) + self._dim_lengths = self.parent._dim_lengths else: self.named_vars = treedict() self.values_to_rvs = treedict() @@ -663,6 +662,10 @@ def __init__( self.auto_deterministics = treelist() self.deterministics = treelist() self.potentials = treelist() + self._coords = {} + self._RV_dims = treedict() + self._dim_lengths = {} + self.add_coords(coords) from pymc.printing import str_for_model diff --git a/pymc/tests/test_model.py b/pymc/tests/test_model.py index 9fa5dbc8278..adabb05abdb 100644 --- a/pymc/tests/test_model.py +++ b/pymc/tests/test_model.py @@ -651,3 +651,17 @@ def test_datalogpt_multiple_shapes(): # This would raise a TypeError, see #4803 and #4804 x_val = m.rvs_to_values[x] m.datalogpt.eval({x_val: 0}) + + +def test_nested_model_coords(): + COORDS = {"dim": range(10)} + with pm.Model(name="m1", coords=COORDS) as m1: + a = pm.Normal("a") + with pm.Model(name="m2") as m2: + b = pm.Normal("b") + c = pm.HalfNormal("c") + d = pm.Normal("d", b, c, dims="dim") + e = pm.Normal("e", a + d, dims="dim") + assert m1.coords is m2.coords + assert m1.dim_lengths is m2.dim_lengths + assert set(m2.RV_dims) < set(m1.RV_dims)