Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor _repr_latex functionality #4065

Merged
merged 14 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 36 additions & 272 deletions pymc3/distributions/continuous.py

Large diffs are not rendered by default.

140 changes: 2 additions & 138 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from scipy import stats
import warnings

from pymc3.util import get_variable_name
from .dist_math import bound, factln, binomln, betaln, logpow, random_choice
from .distribution import Discrete, draw_values, generate_samples
from .shape_utils import broadcast_distribution_samples
Expand Down Expand Up @@ -123,15 +122,6 @@ def logp(self, value):
0 <= value, value <= n,
0 <= p, p <= 1)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
n = dist.n
p = dist.p
name = r'\text{%s}' % name
return r'${} \sim \text{{Binomial}}(\mathit{{n}}={},~\mathit{{p}}={})$'.format(name,
get_variable_name(n),
get_variable_name(p))

class BetaBinomial(Discrete):
R"""
Expand Down Expand Up @@ -259,16 +249,6 @@ def logp(self, value):
value >= 0, value <= self.n,
alpha > 0, beta > 0)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
alpha = dist.alpha
beta = dist.beta
name = r'\text{%s}' % name
return r'${} \sim \text{{BetaBinomial}}(\mathit{{alpha}}={},~\mathit{{beta}}={})$'.format(name,
get_variable_name(alpha),
get_variable_name(beta))


class Bernoulli(Discrete):
R"""Bernoulli log-likelihood
Expand Down Expand Up @@ -371,13 +351,8 @@ def logp(self, value):
value >= 0, value <= 1,
p >= 0, p <= 1)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
p = dist.p
name = r'\text{%s}' % name
return r'${} \sim \text{{Bernoulli}}(\mathit{{p}}={})$'.format(name,
get_variable_name(p))
def _distr_parameters_for_repr(self):
return ["p"]


class DiscreteWeibull(Discrete):
Expand Down Expand Up @@ -486,16 +461,6 @@ def random(self, point=None, size=None):
dist_shape=self.shape,
size=size)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
q = dist.q
beta = dist.beta
name = r'\text{%s}' % name
return r'${} \sim \text{{DiscreteWeibull}}(\mathit{{q}}={},~\mathit{{beta}}={})$'.format(name,
get_variable_name(q),
get_variable_name(beta))


class Poisson(Discrete):
R"""
Expand Down Expand Up @@ -590,14 +555,6 @@ def logp(self, value):
return tt.switch(tt.eq(mu, 0) * tt.eq(value, 0),
0, log_prob)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
mu = dist.mu
name = r'\text{%s}' % name
return r'${} \sim \text{{Poisson}}(\mathit{{mu}}={})$'.format(name,
get_variable_name(mu))


class NegativeBinomial(Discrete):
R"""
Expand Down Expand Up @@ -717,16 +674,6 @@ def logp(self, value):
Poisson.dist(self.mu).logp(value),
negbinom)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
mu = dist.mu
alpha = dist.alpha
name = r'\text{%s}' % name
return r'${} \sim \text{{NegativeBinomial}}(\mathit{{mu}}={},~\mathit{{alpha}}={})$'.format(name,
get_variable_name(mu),
get_variable_name(alpha))


class Geometric(Discrete):
R"""
Expand Down Expand Up @@ -810,14 +757,6 @@ def logp(self, value):
return bound(tt.log(p) + logpow(1 - p, value - 1),
0 <= p, p <= 1, value >= 1)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
p = dist.p
name = r'\text{%s}' % name
return r'${} \sim \text{{Geometric}}(\mathit{{p}}={})$'.format(name,
get_variable_name(p))


class DiscreteUniform(Discrete):
R"""
Expand Down Expand Up @@ -913,16 +852,6 @@ def logp(self, value):
return bound(-tt.log(upper - lower + 1),
lower <= value, value <= upper)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
lower = dist.lower
upper = dist.upper
name = r'\text{%s}' % name
return r'${} \sim \text{{DiscreteUniform}}(\mathit{{lower}}={},~\mathit{{upper}}={})$'.format(name,
get_variable_name(lower),
get_variable_name(upper))


class Categorical(Discrete):
R"""
Expand Down Expand Up @@ -1044,14 +973,6 @@ def logp(self, value):
return bound(a, value >= 0, value <= (k - 1),
tt.all(p_ >= 0, axis=-1), tt.all(p <= 1, axis=-1))

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
p = dist.p
name = r'\text{%s}' % name
return r'${} \sim \text{{Categorical}}(\mathit{{p}}={})$'.format(name,
get_variable_name(p))


class Constant(Discrete):
r"""
Expand Down Expand Up @@ -1112,12 +1033,6 @@ def logp(self, value):
c = self.c
return bound(0, tt.eq(value, c))

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
name = r'\text{%s}' % name
return r'${} \sim \text{{Constant}}()$'.format(name)


ConstantDist = Constant

Expand Down Expand Up @@ -1231,16 +1146,6 @@ def logp(self, value):
0 <= psi, psi <= 1,
0 <= theta)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
theta = dist.theta
psi = dist.psi
name = r'\text{%s}' % name
return r'${} \sim \text{{ZeroInflatedPoisson}}(\mathit{{theta}}={},~\mathit{{psi}}={})$'.format(name,
get_variable_name(theta),
get_variable_name(psi))


class ZeroInflatedBinomial(Discrete):
R"""
Expand Down Expand Up @@ -1354,22 +1259,6 @@ def logp(self, value):
0 <= psi, psi <= 1,
0 <= p, p <= 1)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
n = dist.n
p = dist.p
psi = dist.psi

name_n = get_variable_name(n)
name_p = get_variable_name(p)
name_psi = get_variable_name(psi)
name = r'\text{%s}' % name
return (r'${} \sim \text{{ZeroInflatedBinomial}}'
r'(\mathit{{n}}={},~\mathit{{p}}={},~'
r'\mathit{{psi}}={})$'
.format(name, name_n, name_p, name_psi))


class ZeroInflatedNegativeBinomial(Discrete):
R"""
Expand Down Expand Up @@ -1523,22 +1412,6 @@ def logp(self, value):
0 <= psi, psi <= 1,
mu > 0, alpha > 0)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
mu = dist.mu
alpha = dist.alpha
psi = dist.psi

name_mu = get_variable_name(mu)
name_alpha = get_variable_name(alpha)
name_psi = get_variable_name(psi)
name = r'\text{%s}' % name
return (r'${} \sim \text{{ZeroInflatedNegativeBinomial}}'
r'(\mathit{{mu}}={},~\mathit{{alpha}}={},~'
r'\mathit{{psi}}={})$'
.format(name, name_mu, name_alpha, name_psi))


class OrderedLogistic(Categorical):
R"""
Expand Down Expand Up @@ -1619,12 +1492,3 @@ def __init__(self, eta, cutpoints, *args, **kwargs):
p = p_cum[..., 1:] - p_cum[..., :-1]

super().__init__(p=p, *args, **kwargs)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
name_eta = get_variable_name(dist.eta)
name_cutpoints = get_variable_name(dist.cutpoints)
return (r'${} \sim \text{{OrderedLogistic}}'
r'(\mathit{{eta}}={}, \mathit{{cutpoints}}={}$'
.format(name, name_eta, name_cutpoints))
49 changes: 47 additions & 2 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
import numbers
import contextvars
import dill
import inspect
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable

import numpy as np
import theano.tensor as tt
from theano import function
from ..util import get_repr_for_variable
import theano
from ..memoize import memoize
from ..model import (
Expand Down Expand Up @@ -135,9 +137,46 @@ def getattr_value(self, val):

return val

def _repr_latex_(self, name=None, dist=None):
def _distr_parameters_for_repr(self):
"""Return the names of the parameters for this distribution (e.g. "mu"
and "sigma" for Normal). Used in generating string (and LaTeX etc.)
representations of Distribution objects. By default based on inspection
of __init__, but can be overwritten if necessary (e.g. to avoid including
"sd" and "tau").
"""
return inspect.getfullargspec(self.__init__).args[1:]

def _distr_name_for_repr(self):
return self.__class__.__name__

def _str_repr(self, name=None, dist=None, formatting='plain'):
"""Generate string representation for this distribution, optionally
including LaTeX markup (formatting='latex').
"""
if dist is None:
dist = self
if name is None:
name = '[unnamed]'

param_names = self._distr_parameters_for_repr()
param_values = [get_repr_for_variable(getattr(dist, x), formatting=formatting)
for x in param_names]

if formatting == "latex":
param_string = ",~".join([r"\mathit{{{name}}}={value}".format(name=name,
value=value) for name, value in zip(param_names, param_values)])
return r"$\text{{{var_name}}} \sim \text{{{distr_name}}}({params})$".format(var_name=name,
distr_name=dist._distr_name_for_repr(), params=param_string)
else:
# 'plain' is default option
param_string = ", ".join(["{name}={value}".format(name=name,
value=value) for name, value in zip(param_names, param_values)])
return "{var_name} ~ {distr_name}({params})".format(var_name=name,
distr_name=dist._distr_name_for_repr(), params=param_string)

def _repr_latex_(self, **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
return None
return self._str_repr(formatting="latex", **kwargs)

def logp_nojac(self, *args, **kwargs):
"""Return the logp, but do not include a jacobian term for transforms.
Expand Down Expand Up @@ -200,6 +239,9 @@ def logp(self, x):
"""
return tt.zeros_like(x)

def _distr_parameters_for_repr(self):
return []


class Discrete(Distribution):
"""Base class for discrete distributions"""
Expand Down Expand Up @@ -501,6 +543,9 @@ def random(self, point=None, size=None, **kwargs):
"Define a custom random method and pass it as kwarg random"
)

def _distr_parameters_for_repr(self):
return []


class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
""" A context manager class used while drawing values with draw_values
Expand Down
16 changes: 4 additions & 12 deletions pymc3/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import theano.tensor as tt
import warnings

from pymc3.util import get_variable_name
from ..math import logsumexp
from .dist_math import bound, random_choice
from .distribution import (Discrete, Distribution, draw_values,
Expand Down Expand Up @@ -578,6 +577,8 @@ def random(self, point=None, size=None):
samples = np.reshape(samples, size + dist_shape)
return samples

def _distr_parameters_for_repr(self):
return []

class NormalMixture(Mixture):
R"""
Expand Down Expand Up @@ -627,14 +628,5 @@ def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, *
super().__init__(w, Normal.dist(mu, sigma=sigma, shape=comp_shape),
*args, **kwargs)

def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
mu = dist.mu
w = dist.w
sigma = dist.sigma
name = r'\text{%s}' % name
return r'${} \sim \text{{NormalMixture}}(\mathit{{w}}={},~\mathit{{mu}}={},~\mathit{{sigma}}={})$'.format(name,
get_variable_name(w),
get_variable_name(mu),
get_variable_name(sigma))
def _distr_parameters_for_repr(self):
return ["w", "mu", "sigma"]
Loading