Skip to content

Commit

Permalink
Fix failing VI test due to pytest change (#7144)
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov authored Feb 8, 2024
1 parent 3693198 commit 8745974
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions tests/variational/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import io
import operator
import warnings

from contextlib import nullcontext

Expand Down Expand Up @@ -196,18 +197,26 @@ def test_fit_start(inference_spec, simple_model):

# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
[observed_value] = [simple_model.rvs_to_values[obs] for obs in simple_model.observed_RVs]
if observed_value.name.startswith("minibatch"):
warn_ctxt = pytest.warns(
UserWarning, match="Could not extract data from symbolic observation"
)
else:
warn_ctxt = nullcontext()

try:
with warn_ctxt:
# We can`t use pytest.warns here because after version 8.0 it`s still check for warning when
# exception raised and test failed instead being skipped
warning_raised = False
expected_warning = observed_value.name.startswith("minibatch")
with warnings.catch_warnings(record=True) as record:
warnings.simplefilter("always")
try:
trace = inference.fit(n=0).sample(10000)
except NotImplementedInference as e:
pytest.skip(str(e))
except NotImplementedInference as e:
pytest.skip(str(e))

if expected_warning:
assert len(record) > 0
for item in record:
assert issubclass(item.category, UserWarning)
assert "Could not extract data from symbolic observation" in str(item.message)
if not expected_warning:
assert not record

np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
if has_start_sigma:
np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
Expand Down

0 comments on commit 8745974

Please sign in to comment.