Skip to content

Commit

Permalink
Prevent Model from turning on test value computations
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 23, 2021
1 parent 20ddb4c commit 4b07810
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,10 +809,7 @@ def __new__(cls, *args, **kwargs):
instance._parent = kwargs.get("model")
else:
instance._parent = cls.get_context(error_if_none=False)
aesara_config = kwargs.get("aesara_config", None)
if aesara_config is None or "compute_test_value" not in aesara_config:
aesara_config = {"compute_test_value": "ignore"}
instance._aesara_config = aesara_config
instance._aesara_config = kwargs.get("aesara_config", {})
return instance

def __init__(self, name="", model=None, aesara_config=None, coords=None, check_bounds=True):
Expand Down Expand Up @@ -1007,7 +1004,20 @@ def independent_vars(self):
@property
def test_point(self):
"""Test point used to check that the model doesn't generate errors"""
return Point(((var, var.tag.test_value) for var in self.vars), model=self)
points = []
for var in self.free_RVs:
var_value = getattr(var.tag, "test_value", None)

if var_value is None:
try:
var_value = var.eval()
var.tag.test_value = var_value
except Exception:
raise Exception(f"Couldn't generate an initial value for {var}")

points.append((getattr(var.tag, "value_var", var), var_value))

return Point(points, model=self)

@property
def disc_vars(self):
Expand Down Expand Up @@ -1594,11 +1604,11 @@ def make_obs_var(rv_var: TensorVariable, data: Union[np.ndarray]) -> TensorVaria
else:
new_size = data.shape

test_value = getattr(rv_var.tag, "test_value", None)

rv_var = change_rv_size(rv_var, new_size)

if aesara.config.compute_test_value != "off":
test_value = getattr(rv_var.tag, "test_value", None)

if test_value is not None:
# We try to reuse the old test value
rv_var.tag.test_value = np.broadcast_to(test_value, rv_var.tag.test_value.shape)
Expand Down

0 comments on commit 4b07810

Please sign in to comment.