Skip to content

Commit

Permalink
refactor (#580)
Browse files Browse the repository at this point in the history
* refactor

* remove logging

* ref
  • Loading branch information
aloctavodia authored Nov 4, 2024
1 parent 551dfcc commit 28a0488
Show file tree
Hide file tree
Showing 20 changed files with 382 additions and 420 deletions.
10 changes: 1 addition & 9 deletions preliz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Tools to help you pick a prior
"""
import logging
from os import path as os_path

from matplotlib import rcParams
Expand All @@ -18,13 +17,6 @@

__version__ = "0.11.0"

_log = logging.getLogger("preliz")

if not logging.root.handlers:
_log.setLevel(logging.INFO)
if len(_log.handlers) == 0:
handler = logging.StreamHandler()
_log.addHandler(handler)

# Allow legend outside plot in maxent to be included when saving a figure
# We may want to make this more explicit by having preliz.rcParams
Expand All @@ -37,4 +29,4 @@
style.core.reload_library()

# clean namespace
del logging, os_path, rcParams, _preliz_style_path, _log
del os_path, rcParams, _preliz_style_path
29 changes: 3 additions & 26 deletions preliz/internal/distribution_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def from_precision(precision):


def to_precision(sigma):
precision = 1 / sigma**2
precision = 1 / (eps + sigma**2)
return precision


Expand Down Expand Up @@ -148,38 +148,15 @@ def num_kurtosis(dist):
}


def get_distributions(dist_names=None, exclude=None):
def get_distributions(dist_names=None):

if dist_names is None:
all_distributions = modules["preliz.distributions"].__all__
else:
all_distributions = dist_names

if exclude is None:
exclude = []
if exclude == "auto":
exclude = [
"Beta",
"BetaScaled",
"Triangular",
"TruncatedNormal",
"Uniform",
"VonMises",
"Categorical",
"DiscreteUniform",
"HyperGeometric",
"zeroInflatedBinomial",
"ZeroInflatedNegativeBinomial",
"ZeroInflatedPoisson",
"MvNormal",
"Mixture",
]

distributions = []
for a_dist in all_distributions:
dist = getattr(modules["preliz.distributions"], a_dist)()
if dist.__class__.__name__ not in exclude:
distributions.append(dist)
if exclude:
return exclude, distributions
distributions.append(dist)
return distributions
13 changes: 0 additions & 13 deletions preliz/internal/logging.py

This file was deleted.

208 changes: 0 additions & 208 deletions preliz/internal/parser.py

This file was deleted.

61 changes: 1 addition & 60 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
try:
from IPython import get_ipython
from ipywidgets import FloatSlider, IntSlider, FloatText, IntText, Checkbox, ToggleButton
from pymc import sample_prior_predictive
except ImportError:
pass

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import _pylab_helpers, get_backend
from matplotlib.ticker import MaxNLocator
from .logging import disable_pymc_sampling_logs
from .narviz import hdi, kde
from preliz.internal.narviz import hdi, kde


def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False, ax=None):
Expand Down Expand Up @@ -425,63 +423,6 @@ def looper(*args, **kwargs):
return looper


def bambi_plot_decorator(func, iterations, kind_plot, references, plot_func):
def looper(*args, **kwargs):
kwargs.pop("__resample__")
x_min = kwargs.pop("__x_min__")
x_max = kwargs.pop("__x_max__")
if not kwargs.pop("__set_xlim__"):
x_min = None
x_max = None
auto = True
else:
auto = False

model = func(*args, **kwargs)
model.build()
with disable_pymc_sampling_logs():
idata = model.prior_predictive(iterations)
results = (
idata["prior_predictive"].stack(sample=("chain", "draw"))[model.response_name].values.T
)

_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)
if plot_func is None:
plot_repr(results, kind_plot, references, iterations, ax)
else:
plot_func(results, ax)

return looper


def pymc_plot_decorator(func, iterations, kind_plot, references, plot_func):
def looper(*args, **kwargs):
kwargs.pop("__resample__")
x_min = kwargs.pop("__x_min__")
x_max = kwargs.pop("__x_max__")
if not kwargs.pop("__set_xlim__"):
x_min = None
x_max = None
auto = True
else:
auto = False
with func(*args, **kwargs) as model:
obs_name = model.observed_RVs[0].name
with disable_pymc_sampling_logs():
idata = sample_prior_predictive(samples=iterations)
results = idata["prior_predictive"].stack(sample=("chain", "draw"))[obs_name].values.T

_, ax = plt.subplots()
ax.set_xlim(x_min, x_max, auto=auto)
if plot_func is None:
plot_repr(results, kind_plot, references, iterations, ax)
else:
plot_func(results, ax)

return looper


def plot_repr(results, kind_plot, references, iterations, ax):
alpha = max(0.01, 1 - iterations * 0.009)

Expand Down
Loading

0 comments on commit 28a0488

Please sign in to comment.