diff --git a/preliz/internal/plot_helper_multivariate.py b/preliz/internal/plot_helper_multivariate.py index 3adc8af..29d2a8b 100644 --- a/preliz/internal/plot_helper_multivariate.py +++ b/preliz/internal/plot_helper_multivariate.py @@ -1,11 +1,12 @@ from functools import reduce from operator import mul +import warnings import numpy as np import matplotlib.pyplot as plt from matplotlib import tri -from scipy.special import gamma -from .plot_helper import repr_to_matplotlib +from preliz.internal.plot_helper import repr_to_matplotlib +from preliz.internal.special import gammaln def get_cols_rows(n_plots): @@ -26,7 +27,9 @@ def __init__(self, alpha): """ self._alpha = np.array(alpha) - self._coef = gamma(np.sum(self._alpha)) / reduce(mul, [gamma(a) for a in self._alpha]) + self._coef = np.exp( + gammaln(np.sum(self._alpha)) - np.sum([gammaln(a) for a in self._alpha]) + ) self._corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.75**0.5]]) self._triangle = tri.Triangulation(self._corners[:, 0], self._corners[:, 1]) @@ -34,6 +37,19 @@ def __init__(self, alpha): (self._corners[(i + 1) % 3] + self._corners[(i + 2) % 3]) / 2.0 for i in range(3) ] + refiner = tri.UniformTriRefiner(self._triangle) + self.trimesh = refiner.refine_triangulation(subdiv=8) + self.pvals = np.nan_to_num( + [self.pdf(self.xy2bc(xy)) for xy in zip(self.trimesh.x, self.trimesh.y)] + ) + self.ok = True + if not np.any(self.pvals): + self.ok = False + warnings.warn( + "The joint pdf is to concentrated to plot, use `marginals=True` instead", + stacklevel=2, + ) + def xy2bc(self, x_y, tol=1.0e-3): """ Converts 2D Cartesian coordinates to barycentric coordinates. @@ -64,18 +80,14 @@ def plot(self, ax=None): subdiv: int Number of recursive mesh subdivisions to create. """ - refiner = tri.UniformTriRefiner(self._triangle) - trimesh = refiner.refine_triangulation(subdiv=8) - pvals = [self.pdf(self.xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)] - hdi_probs = [0.1, 0.5, 0.94] - contour_levels = find_hdi_contours(pvals, hdi_probs) + contour_levels = find_hdi_contours(self.pvals, hdi_probs) if all(contour_levels == contour_levels[0]): - ax.tricontourf(trimesh, pvals) + ax.tricontourf(self.trimesh, self.pvals) else: ax.tricontour( - trimesh, - pvals, + self.trimesh, + self.pvals, levels=contour_levels, ) ax.triplot(self._triangle, color="0.8", linestyle="--", linewidth=2) @@ -168,11 +180,13 @@ def plot_dirichlet( else: if dim == 3: - if axes is None: - _, axes = plt.subplots(1, 1) - DirichletOnSimplex(alpha).plot(ax=axes) - if legend == "title": - axes.set_title(repr_to_matplotlib(dist)) + dirichlet_ = DirichletOnSimplex(alpha) + if dirichlet_.ok: + if axes is None: + _, axes = plt.subplots(1, 1) + dirichlet_.plot(ax=axes) + if legend == "title": + axes.set_title(repr_to_matplotlib(dist)) else: raise ValueError("joint only works for Dirichlet of dim=3")