Commit

Make transform objects stateless
brandonwillard committed Mar 17, 2021
1 parent 6a4cdd6 commit 0762608
Showing 8 changed files with 341 additions and 319 deletions.
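The thread running through these hunks: transform objects no longer capture per-variable parameters (bounds, packed-matrix sizes) at construction time. Instead, `backward` and `jacobian_det` take the random variable itself as a new first argument, and parameterized transforms receive a callback that extracts their parameters from the variable's graph on demand. A minimal sketch of the pattern with an illustrative stand-in class (the real transform classes live in pymc3/distributions/transforms.py, presumably among the changed files not expanded on this page):

    import aesara.tensor as aet

    class IntervalTransform:
        """Illustrative stand-in for a stateless interval transform."""

        def __init__(self, args_fn):
            # Store a function that extracts (lower, upper) from an rv_var,
            # instead of storing concrete bound tensors on the instance.
            self.args_fn = args_fn

        def backward(self, rv_var, rv_value):
            # Parameters are read off the variable's graph at call time, so
            # one transform instance is valid for any compatible rv_var.
            lower, upper = self.args_fn(rv_var)
            sigm = 1 / (1 + aet.exp(-rv_value))
            return lower + (upper - lower) * sigm

        def jacobian_det(self, rv_var, rv_value):
            # log|dx/dy| for x = lower + (upper - lower) * sigmoid(y)
            lower, upper = self.args_fn(rv_var)
            sigm = 1 / (1 + aet.exp(-rv_value))
            return aet.log(upper - lower) + aet.log(sigm) + aet.log(1 - sigm)

Because an instance holds only `args_fn`, one transform can be registered per distribution class rather than rebuilt per variable; that is exactly the change made to the `logp_transform` registrations in pymc3/distributions/continuous.py below.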
2 changes: 1 addition & 1 deletion pymc3/backends/base.py
@@ -68,7 +68,7 @@ def __init__(self, name, model=None, vars=None, test_point=None):
             if transform:
                 # We need to create and add an un-transformed version of
                 # each transformed variable
-                untrans_var = transform.backward(var)
+                untrans_var = transform.backward(v, var)
                 untrans_var.name = v.name
                 vars.append(untrans_var)
             vars.append(var)
92 changes: 55 additions & 37 deletions pymc3/distributions/__init__.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
+
 from functools import singledispatch
 from itertools import chain
 from typing import Generator, List, Optional, Tuple, Union
@@ -20,7 +22,7 @@

 from aesara import config
 from aesara.graph.basic import Variable, ancestors, clone_replace
-from aesara.graph.op import compute_test_value
+from aesara.graph.op import Op, compute_test_value
 from aesara.tensor.random.op import Observed, RandomVariable
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.var import TensorVariable
@@ -33,7 +35,7 @@


 @singledispatch
-def logp_transform(op, inputs):
+def logp_transform(op: Op):
     return None


@@ -141,7 +143,8 @@ def change_rv_size(

 def rv_log_likelihood_args(
     rv_var: TensorVariable,
-    transformed: Optional[bool] = True,
+    *,
+    return_observations: bool = True,
 ) -> Tuple[TensorVariable, TensorVariable]:
     """Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
@@ -151,8 +154,9 @@
         A variable corresponding to a `RandomVariable`, whether directly or
         indirectly (e.g. an observed variable that's the output of an
         `Observed` `Op`).
-    transformed
-        When ``True``, return the transformed value var.
+    return_observations
+        When ``True``, return the observed values in place of the log-likelihood
+        value variable.

     Returns
     =======
@@ -163,12 +167,14 @@
     """

     if rv_var.owner and isinstance(rv_var.owner.op, Observed):
-        return tuple(rv_var.owner.inputs)
-    elif hasattr(rv_var.tag, "value_var"):
-        rv_value = rv_var.tag.value_var
-        return rv_var, rv_value
-    else:
-        return rv_var, None
+        rv_var, obs_var = rv_var.owner.inputs
+        if return_observations:
+            return rv_var, obs_var
+        else:
+            return rv_var, rv_log_likelihood_args(rv_var)[1]
+
+    rv_value = getattr(rv_var.tag, "value_var", None)
+    return rv_var, rv_value


 def rv_ancestors(graphs: List[TensorVariable]) -> Generator[TensorVariable, None, None]:
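In short, the function now unpacks `Observed` nodes explicitly and otherwise falls back to the tag lookup. A usage sketch of the new keyword-only flag (hypothetical names; `y_rv` stands for the output of an `Observed` node):

    rv, obs = rv_log_likelihood_args(y_rv)  # default: the observed data
    rv, val = rv_log_likelihood_args(y_rv, return_observations=False)
    # `val` is the value variable attached via rv.tag.value_var, or None
    # when no value variable has been set.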
@@ -217,7 +223,7 @@ def sample_to_measure_vars(
         if not (anc.owner and isinstance(anc.owner.op, RandomVariable)):
             continue

-        _, value_var = rv_log_likelihood_args(anc)
+        _, value_var = rv_log_likelihood_args(anc, return_observations=False)

         if value_var is not None:
             replace[anc] = value_var
@@ -233,8 +239,9 @@
 def logpt(
     rv_var: TensorVariable,
     rv_value: Optional[TensorVariable] = None,
-    jacobian: Optional[bool] = True,
-    scaling: Optional[bool] = True,
+    jacobian: bool = True,
+    scaling: bool = True,
+    transformed: bool = True,
     **kwargs,
 ) -> TensorVariable:
     """Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.
@@ -257,6 +264,8 @@
         Whether or not to include the Jacobian term.
     scaling
         A scaling term to apply to the generated log-likelihood graph.
+    transformed
+        Apply transforms.
     """

@@ -282,22 +291,22 @@

         raise NotImplementedError("Missing value support is incomplete")

-        # "Flatten" and sum an array of indexed RVs' log-likelihoods
-        rv_var, missing_values = rv_node.inputs
-
-        missing_values = missing_values.data
-        logp_var = aet.sum(
-            [
-                logpt(
-                    rv_var,
-                )
-                for idx, missing in zip(
-                    np.ndindex(missing_values.shape), missing_values.flatten()
-                )
-                if missing
-            ]
-        )
-        return logp_var
+        # # "Flatten" and sum an array of indexed RVs' log-likelihoods
+        # rv_var, missing_values = rv_node.inputs
+        #
+        # missing_values = missing_values.data
+        # logp_var = aet.sum(
+        #     [
+        #         logpt(
+        #             rv_var,
+        #         )
+        #         for idx, missing in zip(
+        #             np.ndindex(missing_values.shape), missing_values.flatten()
+        #         )
+        #         if missing
+        #     ]
+        # )
+        # return logp_var

         return aet.zeros_like(rv_var)

@@ -312,15 +321,16 @@ def logpt(
     # If any of the measure vars are transformed measure-space variables
     # (signified by having a `transform` value in their tags), then we apply
     # the their transforms and add their Jacobians (when enabled)
-    if transform:
-        logp_var = _logp(rv_node.op, transform.backward(rv_value), *dist_params, **kwargs)
+    if transform and transformed:
+        logp_var = _logp(rv_node.op, transform.backward(rv_var, rv_value), *dist_params, **kwargs)

         logp_var = transform_logp(
             logp_var,
             tuple(replacements.values()),
         )

         if jacobian:
-            transformed_jacobian = transform.jacobian_det(rv_value)
+            transformed_jacobian = transform.jacobian_det(rv_var, rv_value)
             if transformed_jacobian:
                 if logp_var.ndim > transformed_jacobian.ndim:
                     logp_var = logp_var.sum(axis=-1)
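The net effect of the new `transformed` flag, as a usage sketch (hypothetical names: `x_rv` is a transformed random variable, `x_val` its measure-space value variable):

    # transformed=True (default): rv_value is mapped back through the
    # transform and, when jacobian=True, the log-Jacobian term is added.
    logp_trans = logpt(x_rv, x_val)
    # transformed=False: the backward mapping and Jacobian are skipped.
    logp_plain = logpt(x_rv, x_val, transformed=False)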
@@ -345,11 +355,17 @@ def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> TensorVariable:
     for measure_var in inputs:

         transform = getattr(measure_var.tag, "transform", None)
+        rv_var = getattr(measure_var.tag, "rv_var", None)
+
+        if transform is not None and rv_var is None:
+            warnings.warn(
+                f"A transform was found for {measure_var} but not a corresponding random variable"
+            )

-        if transform is None:
+        if transform is None or rv_var is None:
             continue

-        trans_rv_value = transform.backward(measure_var)
+        trans_rv_value = transform.backward(rv_var, measure_var)
         trans_replacements[measure_var] = trans_rv_value

     if trans_replacements:
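Implicit in this hunk is a new tag contract: a measure-space variable must carry a back-reference to the random variable it values, or its transform is skipped with a warning. A minimal sketch of the assumed setup (attribute names taken from this hunk; where the tags are populated is in one of the changed files not expanded here):

    # Pairing that transform_logp now relies on (hypothetical variables):
    x_val.tag.transform = some_transform  # set when the value var is created
    x_val.tag.rv_var = x_rv               # back-reference added by this commit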
@@ -359,7 +375,7 @@


 @singledispatch
-def _logp(op, value, *dist_params, **kwargs):
+def _logp(op: Op, value: TensorVariable, *dist_params, **kwargs):
     """Create a log-likelihood graph.

     This function dispatches on the type of `op`, which should be a subclass
@@ -370,7 +386,9 @@
         return aet.zeros_like(value)


-def logcdf(rv_var, rv_value, jacobian=True, **kwargs):
+def logcdf(
+    rv_var: TensorVariable, rv_value: Optional[TensorVariable], jacobian: bool = True, **kwargs
+):
     """Create a log-CDF graph."""

     rv_var, _ = rv_log_likelihood_args(rv_var)
27 changes: 10 additions & 17 deletions pymc3/distributions/continuous.py
@@ -104,31 +104,24 @@ class BoundedContinuous(Continuous):


 @logp_transform.register(PositiveContinuous)
-def pos_cont_transform(op, rv_var):
+def pos_cont_transform(op):
     return transforms.log


 @logp_transform.register(UnitContinuous)
-def unit_cont_transform(op, rv_var):
+def unit_cont_transform(op):
     return transforms.logodds


 @logp_transform.register(BoundedContinuous)
-def bounded_cont_transform(op, rv_var):
-    _, _, _, lower, upper = rv_var.owner.inputs
-    lower = aet.as_tensor_variable(lower) if lower is not None else None
-    upper = aet.as_tensor_variable(upper) if upper is not None else None
-
-    if lower is None and upper is None:
-        transform = None
-    elif lower is not None and upper is None:
-        transform = transforms.lowerbound(lower)
-    elif lower is None and upper is not None:
-        transform = transforms.upperbound(upper)
-    else:
-        transform = transforms.interval(lower, upper)
-
-    return transform
+def bounded_cont_transform(op):
+    def transform_params(rv_var):
+        _, _, _, lower, upper = rv_var.owner.inputs
+        lower = aet.as_tensor_variable(lower) if lower is not None else None
+        upper = aet.as_tensor_variable(upper) if upper is not None else None
+        return lower, upper
+
+    return transforms.interval(transform_params)


 def assert_negative_support(var, label, distname, value=-1e-6):
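Note the consolidation here: the old four-way branch (no transform, lowerbound, upperbound, interval) collapses into a single `transforms.interval` registration that receives a parameter-extraction callback, so one registration serves every `BoundedContinuous` subclass and the bounds are read off each variable's graph only when the transform is applied. A sketch of the implied call pattern (illustrative; the `interval` internals are not part of this page's hunks):

    # One stateless transform, reusable across bounded variables:
    tr = transforms.interval(transform_params)
    # Inside tr.backward(rv_var, rv_value) / tr.jacobian_det(rv_var, rv_value),
    # the transform can call transform_params(rv_var) to recover that
    # variable's (lower, upper) pair on demand.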
10 changes: 7 additions & 3 deletions pymc3/distributions/multivariate.py
@@ -125,7 +125,7 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):


 def quaddist_chol(delta, chol_mat):
-    diag = aet.nlinalg.diag(chol_mat)
+    diag = aet.diag(chol_mat)
     # Check if the covariance matrix is positive definite.
     ok = aet.all(diag > 0)
     # If not, replace the diagonal. We return -inf later, but
@@ -222,7 +222,7 @@ class MvNormal(Continuous):
     def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
         mu = aet.as_tensor_variable(mu)
         cov = quaddist_matrix(cov, tau, chol, lower)
-        return super().__init__([mu, cov], **kwargs)
+        return super().dist([mu, cov], **kwargs)

     def logp(value, mu, cov):
         """
@@ -968,7 +968,11 @@ def __init__(self, eta, n, sd_dist, *args, **kwargs):
         if sd_dist.shape.ndim not in [0, 1]:
             raise ValueError("Invalid shape for sd_dist.")

-        transform = transforms.CholeskyCovPacked(n)
+        def transform_params(rv_var):
+            _, _, _, n, eta = rv_var.owner.inputs
+            return np.arange(1, n + 1).cumsum() - 1
+
+        transform = transforms.CholeskyCovPacked(transform_params)

         kwargs["shape"] = shape
         kwargs["transform"] = transform
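The callback's arithmetic locates the diagonal entries of a packed lower-triangular matrix. A quick standalone check for n = 3 (plain Python, not part of the commit):

    import numpy as np

    # Packed lower triangle for n = 3 is [L00, L10, L11, L20, L21, L22],
    # so the diagonal entries sit at indices 0, 2, and 5.
    n = 3
    diag_idxs = np.arange(1, n + 1).cumsum() - 1
    print(diag_idxs)  # [0 2 5]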