-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bring back stats and plotting aliases #4536
Changes from 1 commit
04f2612
b98a895
1982170
2abb029
83e4a2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,21 @@ Statistics and diagnostics are delegated to the | |
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_. | ||
library, a general purpose library for | ||
"exploratory analysis of Bayesian models." | ||
Refer to its documentation to use the diagnostics functions directly. | ||
For statistics and diagnostics, ``pymc3.<function>`` are now aliases | ||
for ArviZ functions. Thus, the links below will redirect you to | ||
ArviZ docs: | ||
|
||
.. currentmodule:: pymc3.stats | ||
|
||
|
||
- :func:`pymc3.bfmi <arviz:arviz.bfmi>` | ||
- :func:`pymc3.compare <arviz:arviz.compare>` | ||
- :func:`pymc3.ess <arviz:arviz.ess>` | ||
- :data:`pymc3.geweke <arviz:arviz.geweke>` | ||
- :func:`pymc3.hpd <arviz:arviz.hpd>` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand wanting to keep the original Note that currently the code does not do any hpd-hdi aliasing on pymc3 side though. |
||
- :func:`pymc3.loo <arviz:arviz.loo>` | ||
- :func:`pymc3.mcse <arviz:arviz.mcse>` | ||
- :func:`pymc3.r2_score <arviz:arviz.r2_score>` | ||
- :func:`pymc3.rhat <arviz:arviz.rhat>` | ||
- :func:`pymc3.summary <arviz:arviz.summary>` | ||
- :func:`pymc3.waic <arviz:arviz.waic>` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,17 +14,108 @@ | |
|
||
"""PyMC3 Plotting. | ||
|
||
Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for | ||
exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/. | ||
|
||
Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots. | ||
Plots are delegated to the ArviZ library, a general purpose library for | ||
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/ | ||
for details on plots. | ||
""" | ||
import functools | ||
import sys | ||
import warnings | ||
|
||
import arviz as az | ||
|
||
|
||
def map_args(func): | ||
swaps = [("varnames", "var_names")] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These should have been around for long enough - I replaced the soft with a hard warning. |
||
|
||
@functools.wraps(func) | ||
def wrapped(*args, **kwargs): | ||
for (old, new) in swaps: | ||
if old in kwargs and new not in kwargs: | ||
warnings.warn( | ||
f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8" | ||
) | ||
kwargs[new] = kwargs.pop(old) | ||
return func(*args, **kwargs) | ||
|
||
return wrapped | ||
|
||
|
||
# pymc3 custom plots: override these names for custom behavior | ||
autocorrplot = map_args(az.plot_autocorr) | ||
forestplot = map_args(az.plot_forest) | ||
kdeplot = map_args(az.plot_kde) | ||
plot_posterior = map_args(az.plot_posterior) | ||
energyplot = map_args(az.plot_energy) | ||
densityplot = map_args(az.plot_density) | ||
pairplot = map_args(az.plot_pair) | ||
|
||
# Use compact traceplot by default | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ArviZ switched their default in |
||
@map_args | ||
@functools.wraps(az.plot_trace) | ||
def traceplot(*args, **kwargs): | ||
try: | ||
kwargs.setdefault("compact", True) | ||
return az.plot_trace(*args, **kwargs) | ||
except TypeError: | ||
kwargs.pop("compact") | ||
return az.plot_trace(*args, **kwargs) | ||
|
||
|
||
# addition arg mapping for compare plot | ||
@functools.wraps(az.plot_compare) | ||
def compareplot(*args, **kwargs): | ||
if "comp_df" in kwargs: | ||
comp_df = kwargs["comp_df"].copy() | ||
else: | ||
args = list(args) | ||
comp_df = args[0].copy() | ||
if "WAIC" in comp_df.columns: | ||
comp_df = comp_df.rename( | ||
index=str, | ||
columns={ | ||
"WAIC": "waic", | ||
"pWAIC": "p_waic", | ||
"dWAIC": "d_waic", | ||
"SE": "se", | ||
"dSE": "dse", | ||
"var_warn": "warning", | ||
}, | ||
) | ||
elif "LOO" in comp_df.columns: | ||
comp_df = comp_df.rename( | ||
index=str, | ||
columns={ | ||
"LOO": "loo", | ||
"pLOO": "p_loo", | ||
"dLOO": "d_loo", | ||
"SE": "se", | ||
"dSE": "dse", | ||
"shape_warn": "warning", | ||
}, | ||
) | ||
if "comp_df" in kwargs: | ||
kwargs["comp_df"] = comp_df | ||
else: | ||
args[0] = comp_df | ||
return az.plot_compare(*args, **kwargs) | ||
|
||
|
||
from pymc3.plots.posteriorplot import plot_posterior_predictive_glm | ||
|
||
__all__ = ["plot_posterior_predictive_glm"] | ||
# Access to arviz plots: base plots provided by arviz | ||
for plot in az.plots.__all__: | ||
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot))) | ||
|
||
__all__ = tuple(az.plots.__all__) + ( | ||
"autocorrplot", | ||
"compareplot", | ||
"forestplot", | ||
"kdeplot", | ||
"plot_posterior", | ||
"traceplot", | ||
"energyplot", | ||
"densityplot", | ||
"pairplot", | ||
"plot_posterior_predictive_glm", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2020 The PyMC Developers | ||
michaelosthege marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Statistical utility functions for PyMC3 | ||
|
||
Diagnostics and auxiliary statistical functions are delegated to the ArviZ library, a general | ||
purpose library for "exploratory analysis of Bayesian models." See | ||
https://arviz-devs.github.io/arviz/ for details. | ||
""" | ||
import functools | ||
import warnings | ||
|
||
import arviz as az | ||
|
||
|
||
def map_args(func): | ||
swaps = [("varnames", "var_names")] | ||
|
||
@functools.wraps(func) | ||
def wrapped(*args, **kwargs): | ||
for (old, new) in swaps: | ||
if old in kwargs and new not in kwargs: | ||
warnings.warn( | ||
"Keyword argument `{old}` renamed to `{new}`, and will be removed in " | ||
"pymc3 3.9".format(old=old, new=new) | ||
) | ||
kwargs[new] = kwargs.pop(old) | ||
return func(*args, **kwargs) | ||
|
||
return wrapped | ||
|
||
|
||
bfmi = map_args(az.bfmi) | ||
compare = map_args(az.compare) | ||
ess = map_args(az.ess) | ||
geweke = map_args(az.geweke) | ||
hpd = map_args(az.hpd) | ||
loo = map_args(az.loo) | ||
mcse = map_args(az.mcse) | ||
r2_score = map_args(az.r2_score) | ||
rhat = map_args(az.rhat) | ||
summary = map_args(az.summary) | ||
waic = map_args(az.waic) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All the names are identical - therefore they can be imported without wrapping. |
||
|
||
|
||
__all__ = [ | ||
"bfmi", | ||
"compare", | ||
"ess", | ||
"geweke", | ||
"hpd", | ||
"loo", | ||
"mcse", | ||
"r2_score", | ||
"rhat", | ||
"summary", | ||
"waic", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we deleted geweke in ArviZ, it had basically been broken for a while (worked only on 1d arrays) and nobody had complained. This is why there are questions on discourse about it being missing, but so far nobody has complained about it being missing, only about the import error it causes with old pymc3 versions.