Skip to content

Commit

Permalink
Circular intervals for VonMises (#592)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Nov 20, 2024
1 parent eea8dd1 commit 16a47fb
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 16 deletions.
26 changes: 26 additions & 0 deletions preliz/distributions/vonmises.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def _fit_mle(self, sample):
mu = np.mod(mu + np.pi, 2 * np.pi) - np.pi
self._update(mu, kappa)

def eti(self, mass=0.94, fmt=".2f"):
mean = self.mu
self.mu = 0
hdi_min, hdi_max = super().eti(mass=mass, fmt=fmt)
self.mu = mean
return _warp_interval(hdi_min, hdi_max, self.mu, fmt)

def hdi(self, mass=0.94, fmt=".2f"):
mean = self.mu
self.mu = 0
hdi_min, hdi_max = super().hdi(mass=mass, fmt=fmt)
self.mu = mean
return _warp_interval(hdi_min, hdi_max, self.mu, fmt)


def nb_cdf(x, pdf):
if isinstance(x, (int, float)):
Expand Down Expand Up @@ -170,3 +184,15 @@ def nb_logpdf(x, mu, kappa):

def nb_neg_logpdf(x, mu, kappa):
return -(nb_logpdf(x, mu, kappa)).sum()


def _warp_interval(hdi_min, hdi_max, mu, fmt):
hdi_min = hdi_min + mu
hdi_max = hdi_max + mu

lower_tail = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
upper_tail = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
if fmt != "none":
lower_tail = float(f"{lower_tail:{fmt}}")
upper_tail = float(f"{upper_tail:{fmt}}")
return (lower_tail, upper_tail)
49 changes: 33 additions & 16 deletions preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,19 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False,
Whether to do the plot along the x-axis (default) or on the y-axis
ax : matplotlib axis
"""

if isinstance(distribution, (np.ndarray, list, tuple)):
dist_type = "sample"
else:
dist_type = "preliz"

if interval == "quantiles":
if levels is None:
levels = [0.05, 0.25, 0.5, 0.75, 0.95]
elif len(levels) not in (5, 3, 1, 0):
raise ValueError("levels should have 5, 3, 1 or 0 elements")

if isinstance(distribution, (np.ndarray, list, tuple)):
if dist_type == "sample":
q_s = np.quantile(distribution, levels).tolist()
else:
q_s = distribution.ppf(levels).tolist()
Expand All @@ -52,7 +58,7 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False,
elif len(levels) not in (2, 1):
raise ValueError("levels should have 2 or 1 elements")

if isinstance(distribution, (np.ndarray, list, tuple)):
if dist_type == "sample":
if interval == "hdi":
func = hdi
if interval == "eti":
Expand All @@ -77,21 +83,32 @@ def plot_pointinterval(distribution, interval="hdi", levels=None, rotated=False,

q_s_size = len(q_s)

if rotated:
if q_s_size == 5:
ax.plot([0, 0], (q_s.pop(0), q_s.pop(-1)), "k", solid_capstyle="butt", lw=1.5)
if q_s_size > 2:
ax.plot([0, 0], (q_s.pop(0), q_s.pop(-1)), "k", solid_capstyle="butt", lw=4)
if q_s_size > 0:
ax.plot(0, q_s[0], "wo", mec="k")
if q_s_size == 5:
_plot_sub_iterval(q_s, lw=1.5, rotated=rotated, ax=ax)
if q_s_size > 2:
_plot_sub_iterval(q_s, lw=4, rotated=rotated, ax=ax)
if q_s_size > 0:
x, y = q_s[0], 0
if rotated:
x, y = y, x
ax.plot(x, y, "wo", mec="k")


def _plot_sub_iterval(q_s, lw, rotated, ax):
lower, upper = q_s.pop(0), q_s.pop(-1)
if lower < upper:
x, y = (lower, upper), [0, 0]
if rotated:
x, y = y, x
ax.plot(x, y, "k", solid_capstyle="butt", lw=lw)
else:
if q_s_size == 5:
ax.plot((q_s.pop(0), q_s.pop(-1)), [0, 0], "k", solid_capstyle="butt", lw=1.5)
if q_s_size > 2:
ax.plot((q_s.pop(0), q_s.pop(-1)), [0, 0], "k", solid_capstyle="butt", lw=4)

if q_s_size > 0:
ax.plot(q_s[0], 0, "wo", mec="k")
x0, y0 = (lower, np.pi), [0, 0]
x1, y1 = (-np.pi, upper), [0, 0]
if rotated:
x0, y0 = y0, x0
x1, y1 = y1, x1
ax.plot(x0, y0, "k", solid_capstyle="butt", lw=lw)
ax.plot(x1, y1, "k", solid_capstyle="butt", lw=lw)


def eti(distribution, mass):
Expand Down

0 comments on commit 16a47fb

Please sign in to comment.