Skip to content

Commit

Permalink
Remove Plots & Diagnostics (#4397), Closes #4371
Browse files Browse the repository at this point in the history
* 🎉 Start removing Diagnostics & Plots in PyMC3 development

🔥 Remove arviz plots
* Remove directly imported arviz plots used in pymc3 plots

🔥 Remove all plots from PyMC3 Plots module

🔥 Remove PyMC3 plots references in Docs

🎨 Mention Plotting & Diagnostics in API page and remove plots reference in __init__.py

⏪ Revert posterior_plot function, test, and docs

🎨 Add deprecation warning to posterior_plot function

🎨 Add context on plot import and import back into __init__.py

✏️ Add warning and details of posterior_plot added

Update docs/source/api/plots.rst

Co-authored-by: Alexandre ANDORRA <[email protected]>

Update docs/source/api/plots.rst

Co-authored-by: Alexandre ANDORRA <[email protected]>

Update pymc3/plots/__init__.py

Co-authored-by: Alexandre ANDORRA <[email protected]>

Update pymc3/plots/__init__.py

Co-authored-by: Alexandre ANDORRA <[email protected]>

✏️ Update docs to add stats.rst details

✏️ Minor docs notation for posterioplot function(s)

📝 Add breakline before docstring title

* ✏️ small typo and remove summary in plots.rst
  • Loading branch information
CloudChaoszero authored Jan 19, 2021
1 parent 32b5c94 commit 37ca5ea
Show file tree
Hide file tree
Showing 12 changed files with 50 additions and 221 deletions.
10 changes: 5 additions & 5 deletions benchmarks/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import timeit

import arviz as az
import numpy as np
import pandas as pd
import theano
Expand Down Expand Up @@ -192,7 +192,7 @@ def track_glm_hierarchical_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
return ess / tot

def track_marginal_mixture_model_ess(self, init):
Expand All @@ -214,7 +214,7 @@ def track_marginal_mixture_model_ess(self, init):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = pm.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
ess = az.ess(trace, var_names=["mu"])["mu"].values.min() # worst case
return ess / tot


Expand Down Expand Up @@ -245,7 +245,7 @@ def track_glm_hierarchical_ess(self, step):
compute_convergence_checks=False,
)
tot = time.time() - t0
ess = float(pm.ess(trace, var_names=["mu_a"])["mu_a"].values)
ess = float(az.ess(trace, var_names=["mu_a"])["mu_a"].values)
return ess / tot


Expand Down Expand Up @@ -304,7 +304,7 @@ def freefall(y, t, p):
t0 = time.time()
trace = pm.sample(500, tune=1000, chains=2, cores=2, random_seed=0)
tot = time.time() - t0
ess = pm.ess(trace)
ess = az.ess(trace)
return np.mean([ess.sigma, ess.gamma]) / tot


Expand Down
1 change: 0 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ API Reference
api/shape_utils
api/ode


Indices and tables
===================

Expand Down
15 changes: 3 additions & 12 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,7 @@ Plots are delegated to the
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
library, a general purpose library for
"exploratory analysis of Bayesian models."
For plots, ``pymc3.<function>`` are now aliases
for ArviZ functions. Thus, the links below will redirect you to
ArviZ docs:
Refer to its documentation to use the plotting functions directly.

- :func:`pymc3.traceplot <arviz:arviz.plot_trace>`
- :func:`pymc3.plot_posterior <arviz:arviz.plot_posterior>`
- :func:`pymc3.forestplot <arviz:arviz.plot_forest>`
- :func:`pymc3.compareplot <arviz:arviz.plot_compare>`
- :func:`pymc3.autocorrplot <arviz:arviz.plot_autocorr>`
- :func:`pymc3.energyplot <arviz:arviz.plot_energy>`
- :func:`pymc3.kdeplot <arviz:arviz.plot_kde>`
- :func:`pymc3.densityplot <arviz:arviz.plot_density>`
- :func:`pymc3.pairplot <arviz:arviz.plot_pair>`
.. automodule:: pymc3.plots.posteriorplot
:members:
19 changes: 1 addition & 18 deletions docs/source/api/stats.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,4 @@ Statistics and diagnostics are delegated to the
`ArviZ <https://arviz-devs.github.io/arviz/index.html>`_.
library, a general purpose library for
"exploratory analysis of Bayesian models."
For statistics and diagnostics, ``pymc3.<function>`` are now aliases
for ArviZ functions. Thus, the links below will redirect you to
ArviZ docs:

.. currentmodule:: pymc3.stats


- :func:`pymc3.bfmi <arviz:arviz.bfmi>`
- :func:`pymc3.compare <arviz:arviz.compare>`
- :func:`pymc3.ess <arviz:arviz.ess>`
- :data:`pymc3.geweke <arviz:arviz.geweke>`
- :func:`pymc3.hpd <arviz:arviz.hpd>`
- :func:`pymc3.loo <arviz:arviz.loo>`
- :func:`pymc3.mcse <arviz:arviz.mcse>`
- :func:`pymc3.r2_score <arviz:arviz.r2_score>`
- :func:`pymc3.rhat <arviz:arviz.rhat>`
- :func:`pymc3.summary <arviz:arviz.summary>`
- :func:`pymc3.waic <arviz:arviz.waic>`
Refer to its documentation to use the diagnostics functions directly.
1 change: 0 additions & 1 deletion pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __set_compiler_flags():
from pymc3.plots import *
from pymc3.sampling import *
from pymc3.smc import *
from pymc3.stats import *
from pymc3.step_methods import *
from pymc3.tests import test
from pymc3.theanof import *
Expand Down
101 changes: 5 additions & 96 deletions pymc3/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,108 +14,17 @@

"""PyMC3 Plotting.
Plots are delegated to the ArviZ library, a general purpose library for
"exploratory analysis of Bayesian models." See https://arviz-devs.github.io/arviz/
for details on plots.
Plots are delegated to the `ArviZ <https://arviz-devs.github.io/arviz/>`_ library, a general purpose library for
exploratory analysis of Bayesian models. For more details, see https://arviz-devs.github.io/arviz/.
Only `plot_posterior_predictive_glm` is kept in the PyMC code base for now, but it will move to ArviZ once the latter adds features for regression plots.
"""
import functools
import sys
import warnings

import arviz as az


def map_args(func):
swaps = [("varnames", "var_names")]

@functools.wraps(func)
def wrapped(*args, **kwargs):
for (old, new) in swaps:
if old in kwargs and new not in kwargs:
warnings.warn(
f"Keyword argument `{old}` renamed to `{new}`, and will be removed in pymc3 3.8"
)
kwargs[new] = kwargs.pop(old)
return func(*args, **kwargs)

return wrapped


# pymc3 custom plots: override these names for custom behavior
autocorrplot = map_args(az.plot_autocorr)
forestplot = map_args(az.plot_forest)
kdeplot = map_args(az.plot_kde)
plot_posterior = map_args(az.plot_posterior)
energyplot = map_args(az.plot_energy)
densityplot = map_args(az.plot_density)
pairplot = map_args(az.plot_pair)

# Use compact traceplot by default
@map_args
@functools.wraps(az.plot_trace)
def traceplot(*args, **kwargs):
try:
kwargs.setdefault("compact", True)
return az.plot_trace(*args, **kwargs)
except TypeError:
kwargs.pop("compact")
return az.plot_trace(*args, **kwargs)


# addition arg mapping for compare plot
@functools.wraps(az.plot_compare)
def compareplot(*args, **kwargs):
if "comp_df" in kwargs:
comp_df = kwargs["comp_df"].copy()
else:
args = list(args)
comp_df = args[0].copy()
if "WAIC" in comp_df.columns:
comp_df = comp_df.rename(
index=str,
columns={
"WAIC": "waic",
"pWAIC": "p_waic",
"dWAIC": "d_waic",
"SE": "se",
"dSE": "dse",
"var_warn": "warning",
},
)
elif "LOO" in comp_df.columns:
comp_df = comp_df.rename(
index=str,
columns={
"LOO": "loo",
"pLOO": "p_loo",
"dLOO": "d_loo",
"SE": "se",
"dSE": "dse",
"shape_warn": "warning",
},
)
if "comp_df" in kwargs:
kwargs["comp_df"] = comp_df
else:
args[0] = comp_df
return az.plot_compare(*args, **kwargs)


from pymc3.plots.posteriorplot import plot_posterior_predictive_glm

# Access to arviz plots: base plots provided by arviz
for plot in az.plots.__all__:
setattr(sys.modules[__name__], plot, map_args(getattr(az.plots, plot)))

__all__ = tuple(az.plots.__all__) + (
"autocorrplot",
"compareplot",
"forestplot",
"kdeplot",
"plot_posterior",
"traceplot",
"energyplot",
"densityplot",
"pairplot",
"plot_posterior_predictive_glm",
)
__all__ = ["plot_posterior_predictive_glm"]
41 changes: 28 additions & 13 deletions pymc3/plots/posteriorplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from __future__ import annotations

import warnings

from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import matplotlib.pyplot as plt
Expand All @@ -33,20 +35,33 @@ def plot_posterior_predictive_glm(
**kwargs: Any
) -> None:
"""Plot posterior predictive of a linear model.
:Arguments:
trace: InferenceData or MultiTrace
Output of pm.sample()
eval: <array>
Array over which to evaluate lm
lm: function <default: linear function>
Function mapping parameters at different points
to their respective outputs.
input: point, sample
output: estimated value
samples: int <default=30>
How many posterior samples to draw.
Additional keyword arguments are passed to pylab.plot().
Parameters
----------
trace: InferenceData or MultiTrace
Output of pm.sample()
eval: <array>
Array over which to evaluate lm
lm: function <default: linear function>
Function mapping parameters at different points
to their respective outputs.
input: point, sample
output: estimated value
samples: int <default=30>
How many posterior samples to draw.
kwargs : mapping, optional
Additional keyword arguments are passed to ``matplotlib.pyplot.plot()``.
Warnings
--------
The `plot_posterior_predictive_glm` function will be removed in a future PyMC3 release.
"""
warnings.warn(
"The `plot_posterior_predictive_glm` function will migrate to Arviz in a future release. "
"\nKeep up to date with `ArviZ <https://arviz-devs.github.io/arviz/>`_ for future updates.",
DeprecationWarning,
)

if lm is None:
lm = lambda x, sample: sample["Intercept"] + sample["x"] * x

Expand Down
2 changes: 1 addition & 1 deletion pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def sample(
...: y = pm.Binomial("y", n=n, p=p, observed=h)
...: trace = pm.sample()
In [3]: pm.summary(trace, kind="stats")
In [3]: az.summary(trace, kind="stats")
Out[3]:
mean sd hdi_3% hdi_97%
Expand Down
69 changes: 0 additions & 69 deletions pymc3/stats/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion pymc3/step_methods/mlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class MLDA(ArrayStepShared):
... tune=100, step=step_method,
... random_seed=123)
...
... pm.summary(trace, kind="stats")
... az.summary(trace, kind="stats")
mean sd hdi_3% hdi_97%
x 0.99 0.987 -0.734 2.992
Expand Down
5 changes: 3 additions & 2 deletions pymc3/tests/sampler_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import arviz as az
import numpy as np
import numpy.testing as npt
import theano.tensor as tt
Expand Down Expand Up @@ -146,12 +147,12 @@ def setup_class(cls):

def test_neff(self):
if hasattr(self, "min_n_eff"):
n_eff = pm.ess(self.trace[self.burn :])
n_eff = az.ess(self.trace[self.burn :])
for var in n_eff:
npt.assert_array_less(self.min_n_eff, n_eff[var])

def test_Rhat(self):
rhat = pm.rhat(self.trace[self.burn :])
rhat = az.rhat(self.trace[self.burn :])
for var in rhat:
npt.assert_allclose(rhat[var], 1, rtol=0.01)

Expand Down
Loading

0 comments on commit 37ca5ea

Please sign in to comment.