From 9fd14d64ab25df33ffb5dfdbabbda5397e2358e4 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 26 Oct 2023 13:10:45 +0200 Subject: [PATCH] Move all custom Exceptions to exceptions.py --- pymc/backends/base.py | 4 -- pymc/data.py | 5 +- pymc/exceptions.py | 59 +++++++++++++++++--- pymc/logprob/utils.py | 5 +- pymc/model/core.py | 2 +- pymc/sampling/mcmc.py | 4 +- pymc/sampling/parallel.py | 14 +---- pymc/step_methods/hmc/base_hmc.py | 4 +- pymc/step_methods/hmc/hmc.py | 3 +- pymc/step_methods/hmc/integration.py | 5 +- pymc/step_methods/hmc/nuts.py | 3 +- pymc/step_methods/hmc/quadpotential.py | 11 +--- pymc/testing.py | 3 +- pymc/variational/approximations.py | 5 +- pymc/variational/operators.py | 4 +- pymc/variational/opvi.py | 30 ++-------- tests/distributions/test_continuous.py | 2 +- tests/distributions/test_discrete.py | 2 +- tests/distributions/test_dist_math.py | 2 +- tests/distributions/test_multivariate.py | 2 +- tests/distributions/test_truncated.py | 3 +- tests/logprob/test_transforms.py | 2 +- tests/logprob/test_utils.py | 2 +- tests/sampling/test_parallel.py | 5 +- tests/step_methods/hmc/test_quadpotential.py | 3 +- tests/test_pytensorf.py | 6 +- tests/variational/test_inference.py | 7 ++- tests/variational/test_opvi.py | 3 +- 28 files changed, 102 insertions(+), 98 deletions(-) diff --git a/pymc/backends/base.py b/pymc/backends/base.py index fc470a9a7ce..89398e3e74f 100644 --- a/pymc/backends/base.py +++ b/pymc/backends/base.py @@ -45,10 +45,6 @@ logger = logging.getLogger(__name__) -class BackendError(Exception): - pass - - class IBaseTrace(ABC, Sized): """Minimal interface needed to record and access draws and stats for one MCMC chain.""" diff --git a/pymc/data.py b/pymc/data.py index 02194640bc9..01f76b9cd95 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -36,6 +36,7 @@ import pymc as pm +from pymc.exceptions import ShapeError from pymc.pytensorf import convert_observed_data __all__ = [ @@ -237,7 +238,7 @@ def determine_coords( if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: - raise pm.exceptions.ShapeError( + raise ShapeError( "Invalid data shape. The rank of the dataset must match the " "length of `dims`.", actual=value.shape, expected=value.ndim, @@ -445,7 +446,7 @@ def Data( if isinstance(dims, str): dims = (dims,) if not (dims is None or len(dims) == x.ndim): - raise pm.exceptions.ShapeError( + raise ShapeError( "Length of `dims` must match the dimensions of the dataset.", actual=len(dims), expected=x.ndim, diff --git a/pymc/exceptions.py b/pymc/exceptions.py index c10d2e76624..157e61abeb9 100644 --- a/pymc/exceptions.py +++ b/pymc/exceptions.py @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - "SamplingError", - "ImputationWarning", - "ShapeWarning", - "ShapeError", -] - class SamplingError(RuntimeError): pass @@ -74,3 +67,55 @@ class NotConstantValueError(ValueError): class BlockModelAccessError(RuntimeError): pass + + +class ParallelSamplingError(Exception): + def __init__(self, message, chain): + super().__init__(message) + self._chain = chain + + +class RemoteTraceback(Exception): + def __init__(self, tb): + self.tb = tb + + def __str__(self): + return self.tb + + +class VariationalInferenceError(Exception): + """Exception for VI specific cases""" + + +class NotImplementedInference(VariationalInferenceError, NotImplementedError): + """Marking non functional parts of code""" + + +class ExplicitInferenceError(VariationalInferenceError, TypeError): + """Exception for bad explicit inference""" + + +class ParametrizationError(VariationalInferenceError, ValueError): + """Error raised in case of bad parametrization""" + + +class GroupError(VariationalInferenceError, TypeError): + """Error related to VI groups""" + + +class IntegrationError(RuntimeError): + pass + + +class PositiveDefiniteError(ValueError): + def __init__(self, msg, idx): + super().__init__(msg) + self.idx = idx + self.msg = msg + + def __str__(self): + return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}." + + +class ParameterValueError(ValueError): + """Exception for invalid parameters values in logprob graphs""" diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 783b9ad95de..a9f2693de0d 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -63,6 +63,7 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable +from pymc.exceptions import ParameterValueError from pymc.logprob.abstract import MeasurableVariable, _logprob from pymc.util import makeiter @@ -231,10 +232,6 @@ def check_potential_measurability( return False -class ParameterValueError(ValueError): - """Exception for invalid parameters values in logprob graphs""" - - class CheckParameterValue(CheckAndRaise): """Implements a parameter value check in a logprob graph. diff --git a/pymc/model/core.py b/pymc/model/core.py index 9fe8626d3fe..5b1ce825710 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -59,13 +59,13 @@ from pymc.exceptions import ( BlockModelAccessError, ImputationWarning, + ParameterValueError, SamplingError, ShapeError, ShapeWarning, ) from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import transformed_conditional_logp -from pymc.logprob.utils import ParameterValueError from pymc.model_graph import model_to_graphviz from pymc.pytensorf import ( PointFunc, diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f5468372900..19c372750f6 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -55,7 +55,7 @@ ) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.blocking import DictToArrayBijection -from pymc.exceptions import SamplingError +from pymc.exceptions import ParallelSamplingError, SamplingError from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext from pymc.sampling.parallel import Draw, _cpu_count @@ -1198,7 +1198,7 @@ def _mp_sample( if callback is not None: callback(trace=strace, draw=draw) - except ps.ParallelSamplingError as error: + except ParallelSamplingError as error: strace = traces[error._chain] for strace in traces: strace.close() diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 91af1a58dbd..e0072dbc431 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -29,25 +29,13 @@ from fastprogress.fastprogress import progress_bar from pymc.blocking import DictToArrayBijection -from pymc.exceptions import SamplingError +from pymc.exceptions import ParallelSamplingError, RemoteTraceback, SamplingError from pymc.util import RandomSeed logger = logging.getLogger(__name__) -class ParallelSamplingError(Exception): - def __init__(self, message, chain): - super().__init__(message) - self._chain = chain - - # Taken from https://hg.python.org/cpython/rev/c4f92b597074 -class RemoteTraceback(Exception): - def __init__(self, tb): - self.tb = tb - - def __str__(self): - return self.tb class ExceptionWithTraceback: diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 7b4719f81d0..b41fec96e14 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -23,14 +23,14 @@ import numpy as np from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType -from pymc.exceptions import SamplingError +from pymc.exceptions import IntegrationError, SamplingError from pymc.model import Point, modelcontext from pymc.pytensorf import floatX from pymc.stats.convergence import SamplerWarning, WarningType from pymc.step_methods import step_sizes from pymc.step_methods.arraystep import GradientSharedStep from pymc.step_methods.hmc import integration -from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.hmc.integration import State from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential from pymc.tuning import guess_scaling from pymc.util import get_value_vars_from_user_vars diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 2a92a0c3322..a84ea848155 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -18,10 +18,11 @@ import numpy as np +from pymc.exceptions import IntegrationError from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData -from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.hmc.integration import State from pymc.vartypes import discrete_types __all__ = ["HamiltonianMC"] diff --git a/pymc/step_methods/hmc/integration.py b/pymc/step_methods/hmc/integration.py index 5e0bdb8ee09..f501c1b1230 100644 --- a/pymc/step_methods/hmc/integration.py +++ b/pymc/step_methods/hmc/integration.py @@ -19,6 +19,7 @@ from scipy import linalg from pymc.blocking import RaveledVars +from pymc.exceptions import IntegrationError from pymc.step_methods.hmc.quadpotential import QuadPotential @@ -32,10 +33,6 @@ class State(NamedTuple): index_in_trajectory: int -class IntegrationError(RuntimeError): - pass - - class CpuLeapfrogIntegrator: def __init__(self, potential: QuadPotential, logp_dlogp_func): """Leapfrog integrator using CPU.""" diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 61dc56a8a67..5431421c8a2 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -18,13 +18,14 @@ import numpy as np +from pymc.exceptions import IntegrationError from pymc.math import logbern from pymc.pytensorf import floatX from pymc.stats.convergence import SamplerWarning from pymc.step_methods.compound import Competence from pymc.step_methods.hmc import integration from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData -from pymc.step_methods.hmc.integration import IntegrationError, State +from pymc.step_methods.hmc.integration import State from pymc.vartypes import continuous_types __all__ = ["NUTS"] diff --git a/pymc/step_methods/hmc/quadpotential.py b/pymc/step_methods/hmc/quadpotential.py index 9dd43b1748f..f04e4c72310 100644 --- a/pymc/step_methods/hmc/quadpotential.py +++ b/pymc/step_methods/hmc/quadpotential.py @@ -25,6 +25,7 @@ from numpy.random import normal from scipy.sparse import issparse +from pymc.exceptions import PositiveDefiniteError from pymc.pytensorf import floatX __all__ = [ @@ -87,16 +88,6 @@ def partial_check_positive_definite(C): raise PositiveDefiniteError("Simple check failed. Diagonal contains negatives", i) -class PositiveDefiniteError(ValueError): - def __init__(self, msg, idx): - super().__init__(msg) - self.idx = idx - self.msg = msg - - def __str__(self): - return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}." - - class QuadPotential: dtype: np.dtype diff --git a/pymc/testing.py b/pymc/testing.py index 3eb1b7ba819..517293c9e1d 100644 --- a/pymc/testing.py +++ b/pymc/testing.py @@ -34,9 +34,10 @@ from pymc.distributions.distribution import Distribution from pymc.distributions.shape_utils import change_dist_size +from pymc.exceptions import ParameterValueError from pymc.initial_point import make_initial_point_fn from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp -from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph +from pymc.logprob.utils import find_rvs_in_graph from pymc.pytensorf import ( compile_pymc, floatX, diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index bac4a9926b1..485603859b5 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -26,12 +26,11 @@ from pymc.blocking import DictToArrayBijection from pymc.distributions.dist_math import rho2sigma +from pymc.exceptions import NotImplementedInference from pymc.util import makeiter -from pymc.variational import opvi from pymc.variational.opvi import ( Approximation, Group, - NotImplementedInference, _known_scan_ignored_inputs, node_property, ) @@ -212,7 +211,7 @@ def __init_group__(self, group): def create_shared_params(self, trace=None, size=None, jitter=1, start=None): if trace is None: if size is None: - raise opvi.ParametrizationError("Need `trace` or `size` to initialize") + raise pymc.exceptions.ParametrizationError("Need `trace` or `size` to initialize") else: start = self._prepare_start(start) # Initialize particles diff --git a/pymc/variational/operators.py b/pymc/variational/operators.py index a57f1272ac0..b4c721a60d7 100644 --- a/pymc/variational/operators.py +++ b/pymc/variational/operators.py @@ -19,9 +19,9 @@ import pymc as pm +from pymc.exceptions import NotImplementedInference, ParametrizationError from pymc.variational import opvi from pymc.variational.opvi import ( - NotImplementedInference, ObjectiveFunction, Operator, _known_scan_ignored_inputs, @@ -81,7 +81,7 @@ class KSDObjective(ObjectiveFunction): def __init__(self, op: KSD, tf: opvi.TestFunction): if not isinstance(op, KSD): - raise opvi.ParametrizationError("Op should be KSD") + raise ParametrizationError("Op should be KSD") super().__init__(op, tf) @pytensor.config.change_flags(compute_test_value="off") diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 99261f026e0..dfae01c9857 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -67,6 +67,12 @@ from pymc.backends.base import MultiTrace from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection +from pymc.exceptions import ( + ExplicitInferenceError, + GroupError, + ParametrizationError, + VariationalInferenceError, +) from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext from pymc.pytensorf import ( @@ -90,30 +96,6 @@ __all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"] -class VariationalInferenceError(Exception): - """Exception for VI specific cases""" - - -class NotImplementedInference(VariationalInferenceError, NotImplementedError): - """Marking non functional parts of code""" - - -class ExplicitInferenceError(VariationalInferenceError, TypeError): - """Exception for bad explicit inference""" - - -class AEVBInferenceError(VariationalInferenceError, TypeError): - """Exception for bad aevb inference""" - - -class ParametrizationError(VariationalInferenceError, ValueError): - """Error raised in case of bad parametrization""" - - -class GroupError(VariationalInferenceError, TypeError): - """Error related to VI groups""" - - def _known_scan_ignored_inputs(terms): # TODO: remove when scan issue with grads is fixed from pymc.data import MinibatchIndexRV diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 3d74b00bff4..567c53aeffd 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -29,8 +29,8 @@ from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated from pymc.distributions.dist_math import clipped_beta_rvs +from pymc.exceptions import ParameterValueError from pymc.logprob.basic import icdf, logcdf, logp -from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX from pymc.testing import ( BaseTestDistributionRandom, diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index e9543c07465..7c95e5f7e4a 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -29,8 +29,8 @@ import pymc as pm from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit +from pymc.exceptions import ParameterValueError from pymc.logprob.basic import icdf, logcdf, logp -from pymc.logprob.utils import ParameterValueError from pymc.pytensorf import floatX from pymc.testing import ( BaseTestDistributionRandom, diff --git a/tests/distributions/test_dist_math.py b/tests/distributions/test_dist_math.py index 99f22af31e5..c264155979d 100644 --- a/tests/distributions/test_dist_math.py +++ b/tests/distributions/test_dist_math.py @@ -33,7 +33,7 @@ incomplete_beta, multigammaln, ) -from pymc.logprob.utils import ParameterValueError +from pymc.exceptions import ParameterValueError from pymc.pytensorf import floatX from tests.helpers import verify_grad diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index f9f0546fa50..a8a66bf121a 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -39,8 +39,8 @@ quaddist_matrix, ) from pymc.distributions.shape_utils import change_dist_size, to_tuple +from pymc.exceptions import ParameterValueError from pymc.logprob.basic import logp -from pymc.logprob.utils import ParameterValueError from pymc.math import kronecker from pymc.pytensorf import compile_pymc, floatX, intX from pymc.sampling.forward import draw diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index d9d007c51f1..b2f7d6490c6 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -24,11 +24,10 @@ from pymc.distributions.shape_utils import change_dist_size from pymc.distributions.transforms import _default_transform from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated -from pymc.exceptions import TruncationError +from pymc.exceptions import ParameterValueError, TruncationError from pymc.logprob.abstract import _icdf from pymc.logprob.basic import logcdf, logp from pymc.logprob.transforms import IntervalTransform -from pymc.logprob.utils import ParameterValueError from pymc.testing import assert_moment_is_expected diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 32924a37d20..c27f446b4ad 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -49,6 +49,7 @@ from pymc.distributions.continuous import Cauchy from pymc.distributions.transforms import _default_transform, log, logodds +from pymc.exceptions import ParameterValueError from pymc.logprob.abstract import MeasurableVariable, _logprob from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp from pymc.logprob.transforms import ( @@ -72,7 +73,6 @@ TransformValuesMapping, TransformValuesRewrite, ) -from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs from tests.distributions.test_transform import check_jacobian_det diff --git a/tests/logprob/test_utils.py b/tests/logprob/test_utils.py index 320de6a36a0..4ee492c01af 100644 --- a/tests/logprob/test_utils.py +++ b/tests/logprob/test_utils.py @@ -47,10 +47,10 @@ import pymc as pm +from pymc.exceptions import ParameterValueError from pymc.logprob.abstract import MeasurableVariable from pymc.logprob.basic import logp, transformed_conditional_logp from pymc.logprob.utils import ( - ParameterValueError, check_potential_measurability, dirac_delta, rvs_to_value_vars, diff --git a/tests/sampling/test_parallel.py b/tests/sampling/test_parallel.py index 2b56882f3ca..024f61efc4a 100644 --- a/tests/sampling/test_parallel.py +++ b/tests/sampling/test_parallel.py @@ -27,6 +27,7 @@ from pytensor.tensor.type import TensorType import pymc as pm +import pymc.exceptions import pymc.sampling.parallel as ps from pymc.pytensorf import floatX @@ -87,7 +88,9 @@ def test_remote_pipe_closed(): pm.Normal("y", mu=_crash_remote_process(x, at_pid), shape=2) step = pm.Metropolis() - with pytest.raises(ps.ParallelSamplingError, match="Chain [0-9] failed with") as ex: + with pytest.raises( + pymc.exceptions.ParallelSamplingError, match="Chain [0-9] failed with" + ) as ex: pm.sample(step=step, mp_ctx="spawn", tune=2, draws=2, cores=2, chains=2) diff --git a/tests/step_methods/hmc/test_quadpotential.py b/tests/step_methods/hmc/test_quadpotential.py index ce9719215a0..a12cfdda2b5 100644 --- a/tests/step_methods/hmc/test_quadpotential.py +++ b/tests/step_methods/hmc/test_quadpotential.py @@ -20,13 +20,14 @@ import pymc +from pymc.exceptions import PositiveDefiniteError from pymc.pytensorf import floatX from pymc.step_methods.hmc import quadpotential def test_elemwise_posdef(): scaling = np.array([0, 2, 3]) - with pytest.raises(quadpotential.PositiveDefiniteError): + with pytest.raises(PositiveDefiniteError): quadpotential.quad_potential(scaling, True) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index fac7b704625..e204aca49be 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -25,6 +25,7 @@ from pytensor import scan, shared from pytensor.compile.builders import OpFromGraph from pytensor.graph.basic import Variable, equal_computations +from pytensor.tensor.exceptions import N, NotScalarConstantError from pytensor.tensor.random.basic import normal, uniform from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.var import RandomStateSharedVariable @@ -36,8 +37,7 @@ from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import SymbolicRandomVariable from pymc.distributions.transforms import Interval -from pymc.exceptions import NotConstantValueError -from pymc.logprob.utils import ParameterValueError +from pymc.exceptions import ParameterValueError from pymc.pytensorf import ( _replace_vars_in_graphs, collect_default_updates, @@ -661,7 +661,7 @@ def test_constant_fold_raises(): x = pt.random.normal(size=(size,)) y = pt.arange(x.size) - with pytest.raises(NotConstantValueError): + with pytest.raises(NotScalarConstantError): constant_fold((y, y.shape)) res = constant_fold((y, y.shape), raise_not_constant=False) diff --git a/tests/variational/test_inference.py b/tests/variational/test_inference.py index b38c94740a0..a7324acd40f 100644 --- a/tests/variational/test_inference.py +++ b/tests/variational/test_inference.py @@ -24,11 +24,12 @@ import pytest import pymc as pm +import pymc.exceptions import pymc.variational.opvi as opvi +from pymc.exceptions import NotImplementedInference from pymc.pytensorf import intX from pymc.variational.inference import ADVI, ASVGD, SVGD, FullRankADVI -from pymc.variational.opvi import NotImplementedInference from tests import models pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test", "fail_on_warning") @@ -288,14 +289,14 @@ def test_replacements(binomial_model_inference): try: p_z = approx.sample_node(p_t, deterministic=True, size=10) assert p_z.shape.eval() == (10,) - except opvi.NotImplementedInference: + except pymc.exceptions.NotImplementedInference: pass try: p_d = approx.sample_node(p_t, deterministic=True) sampled = [pm.draw(p_d) for _ in range(100)] assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic - except opvi.NotImplementedInference: + except pymc.exceptions.NotImplementedInference: pass p_r = approx.sample_node(p_t, deterministic=d) diff --git a/tests/variational/test_opvi.py b/tests/variational/test_opvi.py index 84214197a45..a27ef812f8b 100644 --- a/tests/variational/test_opvi.py +++ b/tests/variational/test_opvi.py @@ -19,6 +19,7 @@ import pytest import pymc as pm +import pymc.exceptions from pymc.variational import opvi from pymc.variational.approximations import ( @@ -41,7 +42,7 @@ def test_discrete_not_allowed(): mu = pm.Normal("mu", mu=0, sigma=10, size=3) z = pm.Categorical("z", p=pt.ones(3) / 3, size=len(y)) pm.Normal("y_obs", mu=mu[z], sigma=1.0, observed=y) - with pytest.raises(opvi.ParametrizationError, match="Discrete variables"): + with pytest.raises(pymc.exceptions.ParametrizationError, match="Discrete variables"): pm.fit(n=1) # fails