Skip to content

Commit

Permalink
Make nested models share coords with parents
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Jan 12, 2022
1 parent b29124b commit ec18b5d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- New named dimensions can be introduced to the model via `pm.Data(..., dims=...)`. For mutable data variables (see above) the lengths of these dimensions are symbolic, so they can be re-sized via `pm.set_data()`.
- `pm.Data` now passes additional kwargs to `aesara.shared`/`at.as_tensor`. [#5098](https://github.com/pymc-devs/pymc/pull/5098).
- Univariate censored distributions are now available via `pm.Censored`. [#5169](https://github.com/pymc-devs/pymc/pull/5169)
- Nested models now inherit the parent model's coordinates. [#5344](https://github.com/pymc-devs/pymc/pull/5344)
- ...


Expand Down
11 changes: 7 additions & 4 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,10 +628,6 @@ def __init__(
rng_seeder: Optional[Union[int, np.random.RandomState]] = None,
):
self.name = name
self._coords = {}
self._RV_dims = {}
self._dim_lengths = {}
self.add_coords(coords)
self.check_bounds = check_bounds

if rng_seeder is None:
Expand All @@ -654,6 +650,9 @@ def __init__(
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
self.deterministics = treelist(parent=self.parent.deterministics)
self.potentials = treelist(parent=self.parent.potentials)
self._coords = self.parent._coords
self._RV_dims = treedict(parent=self.parent._RV_dims)
self._dim_lengths = self.parent._dim_lengths
else:
self.named_vars = treedict()
self.values_to_rvs = treedict()
Expand All @@ -663,6 +662,10 @@ def __init__(
self.auto_deterministics = treelist()
self.deterministics = treelist()
self.potentials = treelist()
self._coords = {}
self._RV_dims = treedict()
self._dim_lengths = {}
self.add_coords(coords)

from pymc.printing import str_for_model

Expand Down
14 changes: 14 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,17 @@ def test_datalogpt_multiple_shapes():
# This would raise a TypeError, see #4803 and #4804
x_val = m.rvs_to_values[x]
m.datalogpt.eval({x_val: 0})


def test_nested_model_coords():
COORDS = {"dim": range(10)}
with pm.Model(name="m1", coords=COORDS) as m1:
a = pm.Normal("a")
with pm.Model(name="m2") as m2:
b = pm.Normal("b")
c = pm.HalfNormal("c")
d = pm.Normal("d", b, c, dims="dim")
e = pm.Normal("e", a + d, dims="dim")
assert m1.coords is m2.coords
assert m1.dim_lengths is m2.dim_lengths
assert set(m2.RV_dims) < set(m1.RV_dims)

0 comments on commit ec18b5d

Please sign in to comment.