Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make transforms stateless #4551

Merged
merged 10 commits into from
Mar 23, 2021

Conversation

brandonwillard
Copy link
Contributor

This PR addresses a few more transform changes/issues.

The primary change is that transforms are now stateless (i.e. they no longer carry their own parameters). Stateful transforms make it very easy to accidentally introduce old and/or irrelevant parameters into a graph, and are a source for some extremely confusing and difficult bugs. That's why this change was made.

Now, transforms only take a "parameter extraction function" that, when applied to a random variable, will extract the required transform parameters.

In other words, transform objects are no longer random variable instance-specific, but random variable class-specfic.

@brandonwillard brandonwillard self-assigned this Mar 17, 2021
Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading the entire diff I'm now quite sure I got the purposes of the rv_var and rv_value args wrong.

pymc3/distributions/__init__.py Outdated Show resolved Hide resolved
pymc3/distributions/__init__.py Show resolved Hide resolved

if transform is not None and rv_var is None:
warnings.warn(
f"A transform was found for {measure_var}" " but no corresponding random variable"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String is a bit messed up.
More importantly: The sentence is a bit incomplete - no variable corresponding to what?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That measure_var doesn't have a random variable associated with it, so there's really nothing else to print or say. If anything, this should probably be an error condition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might actually make more sense to associate the transform object with the rv_var (i.e. the random variable). I'll have to think about that.

rv_var
The random variable being transformed
rv_value
The parameters required for the transform.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rv_value doesn't sound very intuitive for something that holds the transform parameters. (I was confused by this above already.)

How about rv and transform_params?
Or rv and params?

Copy link
Contributor Author

@brandonwillard brandonwillard Mar 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to be clear about the rv_var and rv_value[_var] distinctions.

rv_vars are the "sample-space" variables that are produced by RandomVariable Ops.
rv_value[_var]s are the "measure-space" (or log-likelihood) variables that correspond to a specific value of an rv_var.

These are the same two types of variables described here, where the sloppy P(X = x) or x ~ X notation denotes the rv_var with X (i.e. the random variable), and the value variable with rv_value[_var].

These transform methods are getting those two variables, so any new name that involves "params" would be inaccurate, because the rv_value variable does not provide parameters. The first argument, rv_var, does provide access to a random variable's parameters via rv_var.owner.inputs, and—again—rv_value is a value that's compatible with the random variable rv_var (i.e. a value that could've been a sample from it).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So rv_var is the tensor of the user-provided, observed values? (A TensorConstant?)

We might still want to copy parts of your explanation into the docstring.

pymc3/distributions/transforms.py Show resolved Hide resolved
@brandonwillard brandonwillard force-pushed the more-transform-updates branch 3 times, most recently from b97dc37 to 0762608 Compare March 17, 2021 23:56
@brandonwillard brandonwillard force-pushed the more-transform-updates branch from 0762608 to 43bd711 Compare March 17, 2021 23:57
michaelosthege
michaelosthege previously approved these changes Mar 18, 2021
Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few threads still open.
Nevertheless I'll say LGTM, but don't count too much on my judgement. Most of my trust comes from the facts that Brandon did this and that the CI Tests are now ✔.

pymc3/tests/test_distributions.py Show resolved Hide resolved
with pytest.warns(
DeprecationWarning, match="The argument `eps` is deprecated and will not be used."
):
tr.StickBreaking(eps=1e-9)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Where) do we keep a list of these changes? We should mention them in the release notes. The alternative is to raise the DeprecationWarning which saves users from complicated digging.

@brandonwillard
Copy link
Contributor Author

Sorry, been pretty busy, but I have another commit to push, and it's a big refactor that should address most/all of the open logp-related issues.

@brandonwillard brandonwillard force-pushed the more-transform-updates branch 2 times, most recently from 6d8c136 to 05d4e19 Compare March 22, 2021 05:36
@@ -161,80 +157,119 @@ def rv_log_likelihood_args(
variable).

"""
if not var.owner:
return None, None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't match with the return type hints and docstring.

Can you explain (maybe in the docstring) why and under what circumstances None, None is returned?

rv_value = rv_var.type.filter_variable(rv_value.astype(rv_var.dtype))

if rv_value_var is None:
rv_value_var = rv_value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the case when rv_value has no observations, right?

mean = alpha / (alpha + beta)
variance = (alpha * beta) / ((alpha + beta) ** 2 * (alpha + beta + 1))
# mean = alpha / (alpha + beta)
# variance = (alpha * beta) / ((alpha + beta) ** 2 * (alpha + beta + 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be removed?

#
# @logp_transform.register(rv_type)
# def transform(op, *args, **kwargs):
# return class_transform(*args, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO

super().__init__(shape, dtype, defaults=defaults, *args, **kwargs)
if kwargs.get("transform", None):
raise ValueError("Transformations for discrete distributions")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep the dtype checks? (Based on intX.)

Copy link
Contributor Author

@brandonwillard brandonwillard Mar 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those are done at the Aesara Op-level now (i.e. within RandomVariable.make_node); although I'm not sure if float-to-int conversion is part of that. It might only raise an exception for the wrong dtype. If it's not, then we might need to add that at this level.

@michaelosthege
Copy link
Member

The failing test looks like the non-deterministic logpt that @ricardoV94 noticed a few days ago?

@brandonwillard brandonwillard force-pushed the more-transform-updates branch from 05d4e19 to 00dcfad Compare March 23, 2021 00:32
@brandonwillard brandonwillard force-pushed the more-transform-updates branch from 00dcfad to 049c5f8 Compare March 23, 2021 00:58
@brandonwillard brandonwillard merged commit 4b07810 into pymc-devs:v4 Mar 23, 2021
@brandonwillard brandonwillard deleted the more-transform-updates branch March 23, 2021 03:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants