From 598dd9de2b818a58480071720a9f3da63177be89 Mon Sep 17 00:00:00 2001 From: Austin Rochford Date: Sun, 17 Oct 2021 10:11:23 -0400 Subject: [PATCH] Add dims to *Ordered probs variables when appropriate (#5084) * Add dims to *Ordered probs variables when appropriate * Make black happy * For real about black this time --- pymc/distributions/discrete.py | 4 ++-- pymc/distributions/multivariate.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d2c32c1a49e..6ed1a901da6 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1745,7 +1745,7 @@ class OrderedLogistic: def __new__(cls, name, *args, compute_p=True, **kwargs): out_rv = _OrderedLogistic(name, *args, **kwargs) if compute_p: - pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3]) + pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3], dims=kwargs.get("dims")) return out_rv @classmethod @@ -1856,7 +1856,7 @@ class OrderedProbit: def __new__(cls, name, *args, compute_p=True, **kwargs): out_rv = _OrderedProbit(name, *args, **kwargs) if compute_p: - pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3]) + pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[3], dims=kwargs.get("dims")) return out_rv @classmethod diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index ac25f7308cc..abeb4909994 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -752,7 +752,7 @@ class OrderedMultinomial: def __new__(cls, name, *args, compute_p=True, **kwargs): out_rv = _OrderedMultinomial(name, *args, **kwargs) if compute_p: - pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4]) + pm.Deterministic(f"{name}_probs", out_rv.owner.inputs[4], dims=kwargs.get("dims")) return out_rv @classmethod