From 7d913ce69ca2714e08837f8d6e7d6fd7a1a62b0f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Date: Mon, 10 Jun 2024 18:14:02 +0200
Subject: [PATCH] Allow opting out of model nesting

---
 pymc/model/core.py                   | 20 +++++++++-----
 pymc/model/fgraph.py                 |  4 +--
 pymc/model/transform/basic.py        |  2 +-
 pymc/model/transform/conditioning.py |  2 +-
 pymc/sampling/deterministic.py       |  2 +-
 pymc/stats/log_density.py            | 39 +++++++++++-----------------
 tests/model/test_core.py             | 17 ++++++++++--
 tests/model/test_fgraph.py           | 11 +++++---
 8 files changed, 55 insertions(+), 42 deletions(-)

diff --git a/pymc/model/core.py b/pymc/model/core.py
index 7aec544d79f..475ff42b5f7 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -25,6 +25,7 @@
     Literal,
     Optional,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -441,7 +442,7 @@ class Model(WithMemoization, metaclass=ContextMeta):
 
         coords = {
             "feature", ["A", "B", "C"],
-             "trial", [1, 2, 3, 4, 5],
+            "trial", [1, 2, 3, 4, 5],
         }
 
         with pm.Model(coords=coords) as model:
@@ -476,6 +477,11 @@ class Model(WithMemoization, metaclass=ContextMeta):
                 # Variable will belong to root and second
                 z = pm.Normal("z", mu=y)  # Variable wil be named "root::second::z"
 
+            # Set None for standalone model
+            with pm.Model(name="third", model=None) as third:
+                # Variable will belong to third only
+                w = pm.Normal("w")  # Variable wil be named "third::w"
+
 
     Set `check_bounds` to False for models with only continuous variables and default transformers
     PyMC will remove the bounds check from the model logp which can speed up sampling
@@ -497,13 +503,13 @@ def __enter__(self: Self) -> Self: ...
 
         def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...
 
-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
         # resolves the parent instance
         instance = super().__new__(cls)
-        if kwargs.get("model") is not None:
-            instance._parent = kwargs.get("model")
-        else:
+        if model is UNSET:
             instance._parent = cls.get_context(error_if_none=False)
+        else:
+            instance._parent = model
         return instance
 
     @staticmethod
@@ -519,9 +525,9 @@ def __init__(
         check_bounds=True,
         *,
         coords_mutable=None,
-        model=None,
+        model: Union[Literal[UNSET], None, "Model"] = UNSET,
     ):
-        del model  # used in __new__
+        del model  # used in __new__ to define the parent of this model
         self.name = self._validate_name(name)
         self.check_bounds = check_bounds
 
diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py
index ce15c40760f..b1d67fd07b0 100644
--- a/pymc/model/fgraph.py
+++ b/pymc/model/fgraph.py
@@ -299,9 +299,7 @@ def first_non_model_var(var):
         else:
             return var
 
-    model = Model()
-    if model.parent is not None:
-        raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
+    model = Model(model=None)  # Do not inherit from any model in the context manager
 
     _coords = getattr(fgraph, "_coords", {})
     _dim_lengths = getattr(fgraph, "_dim_lengths", {})
diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py
index aff6042de5a..76556ae08ab 100644
--- a/pymc/model/transform/basic.py
+++ b/pymc/model/transform/basic.py
@@ -16,7 +16,7 @@
 from pytensor import Variable
 from pytensor.graph import ancestors
 
-from pymc import Model
+from pymc.model.core import Model
 from pymc.model.fgraph import (
     ModelObservedRV,
     ModelVar,
diff --git a/pymc/model/transform/conditioning.py b/pymc/model/transform/conditioning.py
index 0979964eecf..531ec2bd6b9 100644
--- a/pymc/model/transform/conditioning.py
+++ b/pymc/model/transform/conditioning.py
@@ -19,9 +19,9 @@
 from pytensor.graph import ancestors
 from pytensor.tensor import TensorVariable
 
-from pymc import Model
 from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import rvs_in_graph
+from pymc.model.core import Model
 from pymc.model.fgraph import (
     ModelDeterministic,
     ModelFreeRV,
diff --git a/pymc/sampling/deterministic.py b/pymc/sampling/deterministic.py
index b0b04f38ec7..3d8398c3a7e 100644
--- a/pymc/sampling/deterministic.py
+++ b/pymc/sampling/deterministic.py
@@ -83,7 +83,7 @@ def compute_deterministics(
     model = modelcontext(model)
 
     if var_names is None:
-        deterministics = model.deterministics
+        deterministics = list(model.deterministics)
         var_names = [det.name for det in deterministics]
     else:
         deterministics = [model[var_name] for var_name in var_names]
diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py
index 5b6406d02b1..a26f8aa60df 100644
--- a/pymc/stats/log_density.py
+++ b/pymc/stats/log_density.py
@@ -25,6 +25,8 @@
 
 __all__ = ("compute_log_likelihood", "compute_log_prior")
 
+from pymc.model.transform.conditioning import remove_value_transforms
+
 
 def compute_log_likelihood(
     idata: InferenceData,
@@ -126,46 +128,35 @@ def compute_log_density(
     if kind not in ("likelihood", "prior"):
         raise ValueError("kind must be either 'likelihood' or 'prior'")
 
+    # We need to disable transforms, because the InferenceData only keeps the untransformed values
+    umodel = remove_value_transforms(model)
+
     if kind == "likelihood":
-        target_rvs = model.observed_RVs
+        target_rvs = list(umodel.observed_RVs)
         target_str = "observed_RVs"
     else:
-        target_rvs = model.free_RVs
+        target_rvs = list(umodel.free_RVs)
         target_str = "free_RVs"
 
     if var_names is None:
         vars = target_rvs
         var_names = tuple(rv.name for rv in vars)
     else:
-        vars = [model.named_vars[name] for name in var_names]
+        vars = [umodel.named_vars[name] for name in var_names]
         if not set(vars).issubset(target_rvs):
             raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")
 
-    # We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
-    try:
-        original_rvs_to_values = model.rvs_to_values
-        original_rvs_to_transforms = model.rvs_to_transforms
-
-        model.rvs_to_values = {
-            rv: rv.clone() if rv not in model.observed_RVs else value
-            for rv, value in model.rvs_to_values.items()
-        }
-        model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}
-
-        elemwise_logdens_fn = model.compile_fn(
-            inputs=model.value_vars,
-            outs=model.logp(vars=vars, sum=False),
-            on_unused_input="ignore",
-        )
-    finally:
-        model.rvs_to_values = original_rvs_to_values
-        model.rvs_to_transforms = original_rvs_to_transforms
+    elemwise_logdens_fn = umodel.compile_fn(
+        inputs=umodel.value_vars,
+        outs=umodel.logp(vars=vars, sum=False),
+        on_unused_input="ignore",
+    )
 
-    coords, dims = coords_and_dims_for_inferencedata(model)
+    coords, dims = coords_and_dims_for_inferencedata(umodel)
 
     logdens_dataset = apply_function_over_dataset(
         elemwise_logdens_fn,
-        posterior[[rv.name for rv in model.free_RVs]],
+        posterior[[rv.name for rv in umodel.free_RVs]],
         output_var_names=var_names,
         sample_dims=sample_dims,
         dims=dims,
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index edee73cfdd2..d4a429bc074 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -143,13 +143,20 @@ def test_docstring_example(self):
                 # Variable will belong to root and second
                 z = pm.Normal("z", mu=y)  # Variable wil be named "root::second::z"
 
+            # Set None for standalone model
+            with pm.Model(name="third", model=None) as third:
+                # Variable will belong to third only
+                w = pm.Normal("w")  # Variable wil be named "third::w"
+
         assert x.name == "root::x"
         assert y.name == "root::first::y"
         assert z.name == "root::second::z"
+        assert w.name == "third::w"
 
         assert set(root.basic_RVs) == {x, y, z}
         assert set(first.basic_RVs) == {y}
         assert set(second.basic_RVs) == {z}
+        assert set(third.basic_RVs) == {w}
 
 
 class TestNested:
@@ -1106,11 +1113,17 @@ def test_model_parent_set_programmatically():
         y = pm.Normal("y")
 
     with model:
+        # Default inherits from model
+        with pm.Model():
+            z_in = pm.Normal("z_in")
+
+        # Explict None opts out of model context
         with pm.Model(model=None):
-            z = pm.Normal("z")
+            z_out = pm.Normal("z_out")
 
     assert "y" in model.named_vars
-    assert "z" in model.named_vars
+    assert "z_in" in model.named_vars
+    assert "z_out" not in model.named_vars
 
 
 class TestModelContext:
diff --git a/tests/model/test_fgraph.py b/tests/model/test_fgraph.py
index 7a57bfc16a4..a964f1faf6d 100644
--- a/tests/model/test_fgraph.py
+++ b/tests/model/test_fgraph.py
@@ -267,10 +267,15 @@ def test_context_error():
     with pm.Model() as m:
         x = pm.Normal("x")
 
-        fg = fgraph_from_model(m)
+        fg, _ = fgraph_from_model(m)
 
-        with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
-            model_from_fgraph(fg)
+        new_m = model_from_fgraph(fg)
+        new_x = new_m["x"]
+
+    assert new_m.parent is None
+    assert x != new_x
+    assert m.named_vars == {"x": x}
+    assert new_m.named_vars == {"x": new_x}
 
 
 def test_sub_model_error():