Skip to content

Commit

Permalink
Move all custom Exceptions to exceptions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 26, 2023
1 parent 4856e22 commit e7e8a54
Show file tree
Hide file tree
Showing 29 changed files with 101 additions and 98 deletions.
4 changes: 0 additions & 4 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
logger = logging.getLogger(__name__)


class BackendError(Exception):
pass


class IBaseTrace(ABC, Sized):
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""

Expand Down
5 changes: 3 additions & 2 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import pymc as pm

from pymc.exceptions import ShapeError
from pymc.pytensorf import convert_observed_data

__all__ = [
Expand Down Expand Up @@ -237,7 +238,7 @@ def determine_coords(

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
raise pm.exceptions.ShapeError(
raise ShapeError(

Check warning on line 241 in pymc/data.py

View check run for this annotation

Codecov / codecov/patch

pymc/data.py#L241

Added line #L241 was not covered by tests
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
actual=value.shape,
expected=value.ndim,
Expand Down Expand Up @@ -445,7 +446,7 @@ def Data(
if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == x.ndim):
raise pm.exceptions.ShapeError(
raise ShapeError(

Check warning on line 449 in pymc/data.py

View check run for this annotation

Codecov / codecov/patch

pymc/data.py#L449

Added line #L449 was not covered by tests
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=x.ndim,
Expand Down
59 changes: 52 additions & 7 deletions pymc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
"SamplingError",
"ImputationWarning",
"ShapeWarning",
"ShapeError",
]


class SamplingError(RuntimeError):
pass
Expand Down Expand Up @@ -74,3 +67,55 @@ class NotConstantValueError(ValueError):

class BlockModelAccessError(RuntimeError):
pass


class ParallelSamplingError(Exception):
def __init__(self, message, chain):
super().__init__(message)
self._chain = chain

Check warning on line 75 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L74-L75

Added lines #L74 - L75 were not covered by tests


class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

Check warning on line 80 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L80

Added line #L80 was not covered by tests

def __str__(self):
return self.tb

Check warning on line 83 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L83

Added line #L83 was not covered by tests


class VariationalInferenceError(Exception):
"""Exception for VI specific cases"""


class NotImplementedInference(VariationalInferenceError, NotImplementedError):
"""Marking non functional parts of code"""


class ExplicitInferenceError(VariationalInferenceError, TypeError):
"""Exception for bad explicit inference"""


class ParametrizationError(VariationalInferenceError, ValueError):
"""Error raised in case of bad parametrization"""


class GroupError(VariationalInferenceError, TypeError):
"""Error related to VI groups"""


class IntegrationError(RuntimeError):
pass


class PositiveDefiniteError(ValueError):
def __init__(self, msg, idx):
super().__init__(msg)
self.idx = idx
self.msg = msg

Check warning on line 114 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L112-L114

Added lines #L112 - L114 were not covered by tests

def __str__(self):
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."

Check warning on line 117 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L117

Added line #L117 was not covered by tests


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""
5 changes: 1 addition & 4 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import ParameterValueError
from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.util import makeiter

Expand Down Expand Up @@ -231,10 +232,6 @@ def check_potential_measurability(
return False


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""


class CheckParameterValue(CheckAndRaise):
"""Implements a parameter value check in a logprob graph.
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
ParameterValueError,
SamplingError,
ShapeError,
ShapeWarning,
)
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.utils import ParameterValueError
from pymc.model_graph import model_to_graphviz
from pymc.pytensorf import (
PointFunc,
Expand Down
4 changes: 2 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.exceptions import ParallelSamplingError, SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import Draw, _cpu_count
Expand Down Expand Up @@ -1199,7 +1199,7 @@ def _mp_sample(
if callback is not None:
callback(trace=strace, draw=draw)

except ps.ParallelSamplingError as error:
except ParallelSamplingError as error:

Check warning on line 1202 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L1202

Added line #L1202 was not covered by tests
strace = traces[error._chain]
for strace in traces:
strace.close()
Expand Down
14 changes: 1 addition & 13 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,13 @@
from fastprogress.fastprogress import progress_bar

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.exceptions import ParallelSamplingError, RemoteTraceback, SamplingError
from pymc.util import RandomSeed

logger = logging.getLogger(__name__)


class ParallelSamplingError(Exception):
def __init__(self, message, chain):
super().__init__(message)
self._chain = chain


# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

def __str__(self):
return self.tb


class ExceptionWithTraceback:
Expand Down
4 changes: 2 additions & 2 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import numpy as np

from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
from pymc.exceptions import SamplingError
from pymc.exceptions import IntegrationError, SamplingError
from pymc.model import Point, modelcontext
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import step_sizes
from pymc.step_methods.arraystep import GradientSharedStep
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.integration import State
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
from pymc.tuning import guess_scaling
from pymc.util import get_value_vars_from_user_vars
Expand Down
3 changes: 2 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

import numpy as np

from pymc.exceptions import IntegrationError
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.integration import State
from pymc.vartypes import discrete_types

__all__ = ["HamiltonianMC"]
Expand Down
5 changes: 1 addition & 4 deletions pymc/step_methods/hmc/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from scipy import linalg

from pymc.blocking import RaveledVars
from pymc.exceptions import IntegrationError
from pymc.step_methods.hmc.quadpotential import QuadPotential


Expand All @@ -32,10 +33,6 @@ class State(NamedTuple):
index_in_trajectory: int


class IntegrationError(RuntimeError):
pass


class CpuLeapfrogIntegrator:
def __init__(self, potential: QuadPotential, logp_dlogp_func):
"""Leapfrog integrator using CPU."""
Expand Down
3 changes: 2 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

import numpy as np

from pymc.exceptions import IntegrationError
from pymc.math import logbern
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.integration import State
from pymc.vartypes import continuous_types

__all__ = ["NUTS"]
Expand Down
11 changes: 1 addition & 10 deletions pymc/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from numpy.random import normal
from scipy.sparse import issparse

from pymc.exceptions import PositiveDefiniteError
from pymc.pytensorf import floatX

__all__ = [
Expand Down Expand Up @@ -87,16 +88,6 @@ def partial_check_positive_definite(C):
raise PositiveDefiniteError("Simple check failed. Diagonal contains negatives", i)


class PositiveDefiniteError(ValueError):
def __init__(self, msg, idx):
super().__init__(msg)
self.idx = idx
self.msg = msg

def __str__(self):
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."


class QuadPotential:
dtype: np.dtype

Expand Down
3 changes: 2 additions & 1 deletion pymc/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@

from pymc.distributions.distribution import Distribution
from pymc.distributions.shape_utils import change_dist_size
from pymc.exceptions import ParameterValueError

Check warning on line 37 in pymc/testing.py

View check run for this annotation

Codecov / codecov/patch

pymc/testing.py#L37

Added line #L37 was not covered by tests
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import icdf, logcdf, logp, transformed_conditional_logp
from pymc.logprob.utils import ParameterValueError, find_rvs_in_graph
from pymc.logprob.utils import find_rvs_in_graph

Check warning on line 40 in pymc/testing.py

View check run for this annotation

Codecov / codecov/patch

pymc/testing.py#L40

Added line #L40 was not covered by tests
from pymc.pytensorf import (
compile_pymc,
floatX,
Expand Down
5 changes: 2 additions & 3 deletions pymc/variational/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@

from pymc.blocking import DictToArrayBijection
from pymc.distributions.dist_math import rho2sigma
from pymc.exceptions import NotImplementedInference
from pymc.util import makeiter
from pymc.variational import opvi
from pymc.variational.opvi import (
Approximation,
Group,
NotImplementedInference,
_known_scan_ignored_inputs,
node_property,
)
Expand Down Expand Up @@ -212,7 +211,7 @@ def __init_group__(self, group):
def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
if trace is None:
if size is None:
raise opvi.ParametrizationError("Need `trace` or `size` to initialize")
raise pymc.exceptions.ParametrizationError("Need `trace` or `size` to initialize")

Check warning on line 214 in pymc/variational/approximations.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/approximations.py#L214

Added line #L214 was not covered by tests
else:
start = self._prepare_start(start)
# Initialize particles
Expand Down
4 changes: 2 additions & 2 deletions pymc/variational/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import pymc as pm

from pymc.exceptions import NotImplementedInference, ParametrizationError
from pymc.variational import opvi
from pymc.variational.opvi import (
NotImplementedInference,
ObjectiveFunction,
Operator,
_known_scan_ignored_inputs,
Expand Down Expand Up @@ -81,7 +81,7 @@ class KSDObjective(ObjectiveFunction):

def __init__(self, op: KSD, tf: opvi.TestFunction):
if not isinstance(op, KSD):
raise opvi.ParametrizationError("Op should be KSD")
raise ParametrizationError("Op should be KSD")

Check warning on line 84 in pymc/variational/operators.py

View check run for this annotation

Codecov / codecov/patch

pymc/variational/operators.py#L84

Added line #L84 was not covered by tests
super().__init__(op, tf)

@pytensor.config.change_flags(compute_test_value="off")
Expand Down
30 changes: 6 additions & 24 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
from pymc.backends.base import MultiTrace
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import (
ExplicitInferenceError,
GroupError,
ParametrizationError,
VariationalInferenceError,
)
from pymc.initial_point import make_initial_point_fn
from pymc.model import modelcontext
from pymc.pytensorf import (
Expand All @@ -91,30 +97,6 @@
__all__ = ["ObjectiveFunction", "Operator", "TestFunction", "Group", "Approximation"]


class VariationalInferenceError(Exception):
"""Exception for VI specific cases"""


class NotImplementedInference(VariationalInferenceError, NotImplementedError):
"""Marking non functional parts of code"""


class ExplicitInferenceError(VariationalInferenceError, TypeError):
"""Exception for bad explicit inference"""


class AEVBInferenceError(VariationalInferenceError, TypeError):
"""Exception for bad aevb inference"""


class ParametrizationError(VariationalInferenceError, ValueError):
"""Error raised in case of bad parametrization"""


class GroupError(VariationalInferenceError, TypeError):
"""Error related to VI groups"""


def _known_scan_ignored_inputs(terms):
# TODO: remove when scan issue with grads is fixed
from pymc.data import MinibatchIndexRV
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

from pymc.distributions.continuous import Normal, Uniform, get_tau_sigma, interpolated
from pymc.distributions.dist_math import clipped_beta_rvs
from pymc.exceptions import ParameterValueError
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import floatX
from pymc.testing import (
BaseTestDistributionRandom,
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import pymc as pm

from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit
from pymc.exceptions import ParameterValueError
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import floatX
from pymc.testing import (
BaseTestDistributionRandom,
Expand Down
Loading

0 comments on commit e7e8a54

Please sign in to comment.