Skip to content

Commit

Permalink
Fix str representations for KroneckerNormal and MatrixNormal (#4243)
Browse files Browse the repository at this point in the history
* fallback __str__ to default Theano on error
* fix str repr for KroneckerNormal and MatrixNormal
* black formatting
* update release notes
  • Loading branch information
Spaak authored Nov 25, 2020
1 parent 68d5201 commit a3a63da
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 3 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- Add alternative parametrization to NegativeBinomial distribution in terms of n and p (see [#4126](https://github.com/pymc-devs/pymc3/issues/4126))
- Add Bayesian Additive Regression Trees (BARTs) [#4183](https://github.com/pymc-devs/pymc3/pull/4183))
- Added a new `MixtureSameFamily` distribution to handle mixtures of arbitrary dimensions in vectorized form (see [#4185](https://github.com/pymc-devs/pymc3/issues/4185)).
- Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), and [#4217](https://github.com/pymc-devs/pymc3/pull/4217))
- Added semantically meaningful `str` representations to PyMC3 objects for console, notebook, and GraphViz use (see [#4076](https://github.com/pymc-devs/pymc3/pull/4076), [#4065](https://github.com/pymc-devs/pymc3/pull/4065), [#4159](https://github.com/pymc-devs/pymc3/pull/4159), [#4217](https://github.com/pymc-devs/pymc3/pull/4217), and [#4243](https://github.com/pymc-devs/pymc3/pull/4243)).



Expand Down
5 changes: 4 additions & 1 deletion pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)
try:
return self._str_repr(formatting="plain", **kwargs)
except:
return super().__str__()

def _repr_latex_(self, **kwargs):
"""Magic method name for IPython to use for LaTeX formatting."""
Expand Down
10 changes: 10 additions & 0 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,9 @@ def logp(self, x):
broadcast_conditions=False,
)

def _distr_parameters_for_repr(self):
return ["eta", "n"]


class MatrixNormal(Continuous):
R"""
Expand Down Expand Up @@ -1712,6 +1715,10 @@ def logp(self, value):
norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet

def _distr_parameters_for_repr(self):
mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]


class KroneckerNormal(Continuous):
R"""
Expand Down Expand Up @@ -1954,3 +1961,6 @@ def logp(self, value):
"""
quad, logdet = self._quaddist(value)
return -(quad + logdet + self.N * tt.log(2 * np.pi)) / 2.0

def _distr_parameters_for_repr(self):
return ["mu"]
5 changes: 4 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def _repr_latex_(self, **kwargs):
return self._str_repr(formatting="latex", **kwargs)

def __str__(self, **kwargs):
return self._str_repr(formatting="plain", **kwargs)
try:
return self._str_repr(formatting="plain", **kwargs)
except:
return super().__str__()

__latex__ = _repr_latex_

Expand Down
19 changes: 19 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,8 +1782,23 @@ def setup_class(self):
# add a bounded variable as well
bound_var = Bound(Normal, lower=1.0)("bound_var", mu=0, sigma=10)

# KroneckerNormal
n, m = 3, 4
covs = [np.eye(n), np.eye(m)]
kron_normal = KroneckerNormal("kron_normal", mu=np.zeros(n * m), covs=covs, shape=n * m)

# MatrixNormal
matrix_normal = MatrixNormal(
"mat_normal",
mu=np.random.normal(size=n),
rowcov=np.eye(n),
colchol=np.linalg.cholesky(np.eye(n)),
shape=(n, n),
)

# Likelihood (sampling distribution) of observations
Y_obs = Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

self.distributions = [alpha, sigma, mu, b, Z, Y_obs, bound_var]
self.expected_latex = (
r"$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
Expand All @@ -1793,6 +1808,8 @@ def setup_class(self):
r"$\text{Z} \sim \text{MvNormal}(\mathit{mu}=array,~\mathit{chol_cov}=array)$",
r"$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$",
r"$\text{bound_var} \sim \text{Bound}(\mathit{lower}=1.0,~\mathit{upper}=\text{None})$ -- \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$",
r"$\text{kron_normal} \sim \text{KroneckerNormal}(\mathit{mu}=array)$",
r"$\text{mat_normal} \sim \text{MatrixNormal}(\mathit{mu}=array,~\mathit{rowcov}=array,~\mathit{colchol_cov}=array)$",
)
self.expected_str = (
r"alpha ~ Normal(mu=0.0, sigma=10.0)",
Expand All @@ -1802,6 +1819,8 @@ def setup_class(self):
r"Z ~ MvNormal(mu=array, chol_cov=array)",
r"Y_obs ~ Normal(mu=mu, sigma=f(sigma))",
r"bound_var ~ Bound(lower=1.0, upper=None)-Normal(mu=0.0, sigma=10.0)",
r"kron_normal ~ KroneckerNormal(mu=array)",
r"mat_normal ~ MatrixNormal(mu=array, rowcov=array, colchol_cov=array)",
)

def test__repr_latex_(self):
Expand Down

0 comments on commit a3a63da

Please sign in to comment.