Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove several functions and objects from PyMC root namespace #6973

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 1 addition & 18 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,36 +46,19 @@ def __set_compiler_flags():

__set_compiler_flags()

from pymc import _version, gp, ode, sampling
from pymc.backends import *
from pymc.blocking import *
from pymc import _version, gp, ode, plots, sampling, stats
from pymc.data import *
from pymc.distributions import *
from pymc.exceptions import *
from pymc.func_utils import find_constrained_prior
from pymc.logprob import *
from pymc.math import (
expand_packed_triangular,
invlogit,
invprobit,
logaddexp,
logit,
logsumexp,
probit,
)
from pymc.model.core import *
from pymc.model.transform.conditioning import do, observe
from pymc.model_graph import model_to_graphviz, model_to_networkx
from pymc.plots import *
from pymc.printing import *
from pymc.pytensorf import *
from pymc.sampling import *
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
from pymc.tuning import *
from pymc.util import drop_warning_stat
from pymc.variational import *
from pymc.vartypes import *

__version__ = _version.get_versions()["version"]
4 changes: 0 additions & 4 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@
logger = logging.getLogger(__name__)


class BackendError(Exception):
pass


class IBaseTrace(ABC, Sized):
"""Minimal interface needed to record and access draws and stats for one MCMC chain."""

Expand Down
3 changes: 0 additions & 3 deletions pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@

from typing_extensions import TypeAlias

__all__ = ["DictToArrayBijection"]


T = TypeVar("T")
PointType: TypeAlias = Dict[str, np.ndarray]
StatsDict: TypeAlias = Dict[str, Any]
Expand Down
6 changes: 3 additions & 3 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@

import pymc as pm

from pymc.exceptions import ShapeError
from pymc.pytensorf import convert_observed_data

__all__ = [
"get_data",
"GeneratorAdapter",
"Minibatch",
"Data",
"ConstantData",
Expand Down Expand Up @@ -238,7 +238,7 @@

if isinstance(value, np.ndarray) and dims is not None:
if len(dims) != value.ndim:
raise pm.exceptions.ShapeError(
raise ShapeError(

Check warning on line 241 in pymc/data.py

View check run for this annotation

Codecov / codecov/patch

pymc/data.py#L241

Added line #L241 was not covered by tests
"Invalid data shape. The rank of the dataset must match the " "length of `dims`.",
actual=value.shape,
expected=value.ndim,
Expand Down Expand Up @@ -446,7 +446,7 @@
if isinstance(dims, str):
dims = (dims,)
if not (dims is None or len(dims) == x.ndim):
raise pm.exceptions.ShapeError(
raise ShapeError(

Check warning on line 449 in pymc/data.py

View check run for this annotation

Codecov / codecov/patch

pymc/data.py#L449

Added line #L449 was not covered by tests
"Length of `dims` must match the dimensions of the dataset.",
actual=len(dims),
expected=x.ndim,
Expand Down
71 changes: 52 additions & 19 deletions pymc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
"SamplingError",
"IncorrectArgumentsError",
"TraceDirectoryError",
"ImputationWarning",
"ShapeWarning",
"ShapeError",
]


class SamplingError(RuntimeError):
pass


class IncorrectArgumentsError(ValueError):
pass


class TraceDirectoryError(ValueError):
"""Error from trying to load a trace from an incorrectly-structured directory,"""

pass


class ImputationWarning(UserWarning):
"""Warning that there are missing values that will be imputed."""

Expand Down Expand Up @@ -86,3 +67,55 @@

class BlockModelAccessError(RuntimeError):
pass


class ParallelSamplingError(Exception):
def __init__(self, message, chain):
super().__init__(message)
self._chain = chain

Check warning on line 75 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L74-L75

Added lines #L74 - L75 were not covered by tests


class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

Check warning on line 80 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L80

Added line #L80 was not covered by tests

def __str__(self):
return self.tb

Check warning on line 83 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L83

Added line #L83 was not covered by tests


class VariationalInferenceError(Exception):
"""Exception for VI specific cases"""


class NotImplementedInference(VariationalInferenceError, NotImplementedError):
"""Marking non functional parts of code"""


class ExplicitInferenceError(VariationalInferenceError, TypeError):
"""Exception for bad explicit inference"""


class ParametrizationError(VariationalInferenceError, ValueError):
"""Error raised in case of bad parametrization"""


class GroupError(VariationalInferenceError, TypeError):
"""Error related to VI groups"""


class IntegrationError(RuntimeError):
pass


class PositiveDefiniteError(ValueError):
def __init__(self, msg, idx):
super().__init__(msg)
self.idx = idx
self.msg = msg

Check warning on line 114 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L112-L114

Added lines #L112 - L114 were not covered by tests

def __str__(self):
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."

Check warning on line 117 in pymc/exceptions.py

View check run for this annotation

Codecov / codecov/patch

pymc/exceptions.py#L117

Added line #L117 was not covered by tests


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""
5 changes: 1 addition & 4 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import ParameterValueError
from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.util import makeiter

Expand Down Expand Up @@ -231,10 +232,6 @@ def check_potential_measurability(
return False


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""


class CheckParameterValue(CheckAndRaise):
"""Implements a parameter value check in a logprob graph.

Expand Down
2 changes: 1 addition & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@
from pymc.exceptions import (
BlockModelAccessError,
ImputationWarning,
ParameterValueError,
SamplingError,
ShapeError,
ShapeWarning,
)
from pymc.initial_point import make_initial_point_fn
from pymc.logprob.basic import transformed_conditional_logp
from pymc.logprob.utils import ParameterValueError
from pymc.model_graph import model_to_graphviz
from pymc.pytensorf import (
PointFunc,
Expand Down
38 changes: 0 additions & 38 deletions pymc/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"""
import functools
import sys
import warnings

import arviz as az

Expand All @@ -29,40 +28,3 @@
obj = getattr(az.plots, attr)
if not attr.startswith("__"):
setattr(sys.modules[__name__], attr, obj)


def alias_deprecation(func, alias: str):
original = func.__name__

@functools.wraps(func)
def wrapped(*args, **kwargs):
raise FutureWarning(
f"The function `{alias}` from PyMC was an alias for `{original}` from ArviZ. "
"It was removed in PyMC 4.0. "
f"Switch to `pymc.{original}` or `arviz.{original}`."
)

return wrapped


# Aliases of ArviZ functions
autocorrplot = alias_deprecation(az.plot_autocorr, alias="autocorrplot")
forestplot = alias_deprecation(az.plot_forest, alias="forestplot")
kdeplot = alias_deprecation(az.plot_kde, alias="kdeplot")
energyplot = alias_deprecation(az.plot_energy, alias="energyplot")
densityplot = alias_deprecation(az.plot_density, alias="densityplot")
pairplot = alias_deprecation(az.plot_pair, alias="pairplot")
traceplot = alias_deprecation(az.plot_trace, alias="traceplot")
compareplot = alias_deprecation(az.plot_compare, alias="compareplot")
Comment on lines -34 to -56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me like this has already been issuing deprecation warnings. Was this working and warning against using everything slated to be removed since v4?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was just for utilities whose names have changed. Those could be safely removed by now yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, then I agree with @ColCarroll that we should definitely deprecate all the other stuff well in advance of removing it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why remove the deprecation warning? In case you want to clear up the namespace then you could refactor it into something like https://peps.python.org/pep-0562/

Copy link
Member Author

@ricardoV94 ricardoV94 Oct 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this deprecation warning because I removed the objects that were being deprecated as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These one specifically were deprecated since v4, seems safe no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't there still lots of people who use PyMC3 because of the name recognition and all the SEO? Seems pretty low-effort to provide explicit instructions for them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds silly. This will not be the thing that people switching from v3 to v5 will find challenging

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, I agree it's silly in the case of these ArviZ warnings, but you could do something similar to provide a transition period for the rest of the stuff. Proof-of-concept: ricardoV94#4

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I am going to try and do that!



__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
"forestplot",
"kdeplot",
"traceplot",
"energyplot",
"densityplot",
"pairplot",
)
7 changes: 0 additions & 7 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,9 @@
"hessian",
"hessian_diag",
"inputvars",
"cont_inputs",
"floatX",
"intX",
"smartfloatX",
"jacobian",
"CallableTensor",
"join_nonshared_inputs",
"make_shared_replacements",
"generator",
"convert_observed_data",
"compile_pymc",
]

Expand Down
14 changes: 8 additions & 6 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@
from pytensor.tensor.sharedvar import SharedVariable
from typing_extensions import TypeAlias

import pymc as pm

from pymc.backends.arviz import _DefaultTrace
from pymc.backends.arviz import (
_DefaultTrace,
predictions_to_inference_data,
to_inference_data,
)
from pymc.backends.base import MultiTrace
from pymc.blocking import PointType
from pymc.model import Model, modelcontext
Expand Down Expand Up @@ -438,7 +440,7 @@
ikwargs: Dict[str, Any] = dict(model=model)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(prior=prior, **ikwargs)
return to_inference_data(prior=prior, **ikwargs)

Check warning on line 443 in pymc/sampling/forward.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/forward.py#L443

Added line #L443 was not covered by tests


def sample_posterior_predictive(
Expand Down Expand Up @@ -669,8 +671,8 @@
if extend_inferencedata:
ikwargs.setdefault("idata_orig", idata)
ikwargs.setdefault("inplace", True)
return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
return predictions_to_inference_data(ppc_trace, **ikwargs)
idata_pp = to_inference_data(posterior_predictive=ppc_trace, **ikwargs)

Check warning on line 675 in pymc/sampling/forward.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/forward.py#L674-L675

Added lines #L674 - L675 were not covered by tests

if extend_inferencedata and idata is not None:
idata.extend(idata_pp)
Expand Down
7 changes: 4 additions & 3 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
to_inference_data,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.exceptions import ParallelSamplingError, SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import Draw, _cpu_count
Expand Down Expand Up @@ -892,7 +893,7 @@
if compute_convergence_checks or return_inferencedata:
ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
ikwargs.update(idata_kwargs)
idata = pm.to_inference_data(mtrace, **ikwargs)
idata = to_inference_data(mtrace, **ikwargs)

if compute_convergence_checks:
warns = run_convergence_checks(idata, model)
Expand Down Expand Up @@ -1198,7 +1199,7 @@
if callback is not None:
callback(trace=strace, draw=draw)

except ps.ParallelSamplingError as error:
except ParallelSamplingError as error:

Check warning on line 1202 in pymc/sampling/mcmc.py

View check run for this annotation

Codecov / codecov/patch

pymc/sampling/mcmc.py#L1202

Added line #L1202 was not covered by tests
strace = traces[error._chain]
for strace in traces:
strace.close()
Expand Down
14 changes: 1 addition & 13 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,13 @@
from fastprogress.fastprogress import progress_bar

from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.exceptions import ParallelSamplingError, RemoteTraceback, SamplingError
from pymc.util import RandomSeed

logger = logging.getLogger(__name__)


class ParallelSamplingError(Exception):
def __init__(self, message, chain):
super().__init__(message)
self._chain = chain


# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
def __init__(self, tb):
self.tb = tb

def __str__(self):
return self.tb


class ExceptionWithTraceback:
Expand Down
2 changes: 0 additions & 2 deletions pymc/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,3 @@
setattr(sys.modules[__name__], attr, obj)

from pymc.stats.log_likelihood import compute_log_likelihood

__all__ = ("compute_log_likelihood",) + tuple(az.stats.__all__)
4 changes: 2 additions & 2 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import numpy as np

from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
from pymc.exceptions import SamplingError
from pymc.exceptions import IntegrationError, SamplingError
from pymc.model import Point, modelcontext
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import step_sizes
from pymc.step_methods.arraystep import GradientSharedStep
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.step_methods.hmc.integration import State
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
from pymc.tuning import guess_scaling
from pymc.util import get_value_vars_from_user_vars
Expand Down
Loading
Loading