
Introduce ValuedVariable #78

Merged

Conversation

@brandonwillard (Member) commented Oct 23, 2021

This PR replaces #71 and addresses the issues and discussions therein.

The ValuedVariable Op adds the value variable to the graph, so that PreserveRVMappings is no longer needed. It also clarifies the definition and actions of rewrites that truly apply to a MeasurableVariable and its value variable simultaneously, by making ValuedVariables both the targets and the outputs of such rewrites.
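For intuition, here is a minimal sketch of what such an Op could look like in Aesara (illustrative only; the actual implementation in this PR may differ in its details):

import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op

class ValuedVariable(Op):
    """Pair a measurable variable with its value variable in the graph."""

    view_map = {0: [0]}

    def make_node(self, rv, value):
        rv = at.as_tensor_variable(rv)
        value = at.as_tensor_variable(value)
        # The output stands in for the measurable variable itself
        return Apply(self, [rv, value], [rv.type()])

    def perform(self, node, inputs, outputs):
        # At evaluation time, the node acts as an identity on its first input
        outputs[0][0] = inputs[0]

valued_variable = ValuedVariable()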

For instance, the old naive_bcast_rv_lift can be refactored into a rewrite that takes and produces ValuedVariables, so that the MeasurableVariable/RandomVariable terms being broadcasted within a ValuedVariable can be lifted and their corresponding scalar value variables broadcasted, all in the same place and in the same way. This clarifies the nature of the rewrite and makes the entire operation consistent with respect to the relation it represents.

Our current approach (i.e. naive_bcast_rv_lift) is inconsistent as a stand-alone rewrite, because it doesn't also handle the value-variable broadcasting. Instead, that is handled much further down the line within _logprob implementations, and that approach is fraught with potentially hard(er)-to-track issues.
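To make the intended shape of such a rewrite concrete, here is a rough sketch (it assumes the ValuedVariable/valued_variable sketch above; broadcast_rv_params is a hypothetical helper, not an aeppl function):

import aesara.tensor as at
from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.extra_ops import BroadcastTo

@node_rewriter([ValuedVariable])
def bcast_lift_valued(fgraph, node):
    rv, value = node.inputs
    if rv.owner is None or not isinstance(rv.owner.op, BroadcastTo):
        return None
    base_rv, *shape = rv.owner.inputs
    # Lift the broadcast into the RV's parameters...
    new_rv = broadcast_rv_params(base_rv, shape)  # hypothetical helper
    # ...and broadcast the value variable in the same place and way
    new_value = at.broadcast_to(value, shape)
    return [valued_variable(new_rv, new_value)]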

  • Refactor user-specified transforms (i.e. TransformValuesRewrite)
    Our whole approach to this needs to be rewritten, mostly because the current one is too convoluted. We can simply have a helper function that performs a FunctionGraph.replace_all and adds the Jacobian terms to the appropriate variables directly (see the sketch after this list).
  • Finish refactoring *Subtensor* support
  • Finish refactoring Scan support
  • Change use of the term "valued" to "bound"
  • Determine how we want to handle test_persist_inputs, test_warn_random_not_found, test_multiple_rvs_to_same_value_raises, and test_fail_multiple_clip_single_base
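As a sketch of that helper idea (assuming aeppl-style transform objects with backward and log_jac_det methods; the function name and signature are illustrative, not this PR's implementation):

def apply_value_transforms(fgraph, transforms):
    """Re-express each value variable in transformed space and collect the
    corresponding Jacobian correction terms."""
    replacements = []
    jacobian_terms = []
    for value_var, transform in transforms.items():
        # A new input representing the value in transformed space
        trans_value = value_var.type()
        trans_value.name = f"{value_var.name}-trans"
        # Map the transformed input back wherever the old value appeared
        replacements.append((value_var, transform.backward(trans_value)))
        jacobian_terms.append(transform.log_jac_det(trans_value))
    fgraph.replace_all(replacements, import_missing=True)
    return jacobian_terms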

@brandonwillard added the enhancement, important, graph rewriting, request discussion, rv-transforms, and refactoring labels on Oct 23, 2021
@brandonwillard force-pushed the add-value-vars-to-graph branch 2 times, most recently from a6a9cb4 to f7b12b9 on October 25, 2021
@brandonwillard deleted the add-value-vars-to-graph branch on November 15, 2021
@brandonwillard restored the add-value-vars-to-graph branch on November 15, 2021
@brandonwillard force-pushed the add-value-vars-to-graph branch 3 times, most recently from 9e8e262 to bf5212b on January 14, 2022
@brandonwillard force-pushed the add-value-vars-to-graph branch 5 times, most recently from 03dc97e to fd2a534 on December 4, 2022
@rlouf marked this pull request as ready for review on December 4, 2022
@brandonwillard force-pushed the add-value-vars-to-graph branch from fd2a534 to fbe946f on December 5, 2022
Review thread on the diff lines below:

# ],
# [],
# )
sorted_rv_vars = rv_vars

Is this still necessary?

Reply: If anything, we should first perform operations on variables that are conditioned on but are not RandomVariables. Their value variables should be used to set the expression of the value variable of the upstream RandomVariable, so that we can support the examples in #119.

@rlouf (Member) commented Dec 5, 2022
import aesara.tensor as at
from aeppl import joint_logprob

srng = at.random.RandomStream(0)

x = srng.normal(0, 1)
y = at.log(x)
z = srng.normal(y)

y_vv = y.clone()
z_vv = z.clone()
joint_logprob({y: y_vv, z: z_vv})

For instance, in this example we should also replace x with a ValuedVariable whose value variable is bound to x_vv = at.exp(y_vv). The disintegrator can then see that x is bound and use this expression to compute the logdensity conditional on z. It currently assumes that x is an unbound random variable and thus takes random realizations of x as inputs to the logdensity.
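Schematically, using the names from the snippet above (valued_variable here is a hypothetical constructor for the proposed Op, not an existing aeppl function):

x_vv = at.exp(y_vv)                 # the value implied by y = log(x)
x_bound = valued_variable(x, x_vv)  # pair x with its implied value
# downstream uses of `x` would then refer to `x_bound`, so the
# disintegrator sees `x` as bound rather than as a free random input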

Note: I naturally used the vocabulary you suggested in this message, so I think we can move forward with the name BoundVariable.

Resolved (outdated) review threads on aeppl/tensor.py and tests/test_joint_logprob.py.
@brandonwillard force-pushed the add-value-vars-to-graph branch 2 times, most recently from 49c2c2c to 0b28a51 on December 25, 2022
@brandonwillard force-pushed the add-value-vars-to-graph branch from 0b28a51 to 7cd24fe on December 25, 2022
@brandonwillard force-pushed the add-value-vars-to-graph branch from 7cd24fe to 436c754 on January 2, 2023
@rlouf (Member) commented Jan 2, 2023

Will the following snippet output the correct logdensity?

import aeppl
import aesara.tensor as at

srng = at.random.RandomStream(0)

x = srng.normal(0, 1)
z = srng.normal(0, 1)

x_tr = at.log(x)
logprob, vvs = aeppl.joint_logprob(x_tr, z)

AePPL currently returns a logdensity graph where the input to z's conditional logdensity is a stochastic variable, instead of the exponential of the value variable bound to x_tr.

If not we can tackle this in another PR.

@brandonwillard (Member, Author) commented Jan 2, 2023

> Will the following snippet output the correct logdensity? […] If not we can tackle this in another PR.

The following passes:

import aesara
import aesara.tensor as at
import numpy as np
import scipy as sp
import scipy.stats  # make `sp.stats` available

from aeppl import joint_logprob

srng = at.random.RandomStream(0)

X_rv = srng.normal(0, 1, name="X")
Z_rv = srng.normal(0, 1, name="Z")

x_tr = at.exp(X_rv)
logprob, vvs = joint_logprob(x_tr, Z_rv)

logp_fn = aesara.function(vvs, logprob)

x_val, z_val = 0.1, 0.1
exp_res = sp.stats.lognorm(s=1).logpdf(x_val) + sp.stats.norm().logpdf(z_val)

np.testing.assert_allclose(logp_fn(x_val, z_val), exp_res)

So does the case of "shared" realized value variables. For example:

srng = at.random.RandomStream(0)

X_rv = srng.normal(0, 1, name="X")
Z_tr = at.exp(X_rv)

Y_rv = srng.normal(0, 1, name="Y")

z_vv = at.dscalar(name="z_vv")

logprob, vvs = joint_logprob(realized={Z_tr: z_vv, Y_rv: z_vv})

logp_fn = aesara.function([z_vv], logprob)

z_val = 0.1

exp_res = sp.stats.lognorm(s=1).logpdf(z_val) + sp.stats.norm().logpdf(z_val)

np.testing.assert_allclose(logp_fn(z_val), exp_res)

@brandonwillard (Member, Author) commented Jan 2, 2023

This case doesn't work, though:

srng = at.random.RandomStream(0)

X_rv = srng.normal(0, 1, name="X")
Z_tr = at.exp(X_rv)

z_vv = at.dscalar(name="z_vv")

logprob, vvs = joint_logprob(realized={Z_tr: z_vv, X_rv: z_vv})

The IR for this case is as follows:

ValuedVariable [id A] 1
 |Elemwise{exp,no_inplace} [id B] 0
 | |z_vv [id C]
 |z_vv [id C]
ValuedVariable [id D] 3
 |normal_rv{0, (0, 0), floatX, False}.1 [id E] 'X' 2
 | |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F2C85101660>) [id F]
 | |TensorConstant{[]} [id G]
 | |TensorConstant{11} [id H]
 | |TensorConstant{0} [id I]
 | |TensorConstant{1} [id J]
 |z_vv [id C]

A log-probability isn't produced because we replaced the original X_rv with its value variable in the graph of Z_tr. We could replace the z_vv in the Elemwise with the ValuedVariable for X_rv (i.e. id D). That was actually the original intention of this approach, but I haven't made that addition/change yet. We can work on that in a follow-up PR, though.
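A sketch of that replacement against the IR above (illustrative only; valued_variable is the hypothetical constructor sketched earlier in this thread, and the FunctionGraph plumbing is elided):

exp_node = Z_tr.owner                    # Elemwise{exp,no_inplace} [id B] in the IR
x_valued = valued_variable(X_rv, z_vv)   # plays the role of ValuedVariable [id D]
# rebind the exp's input from the raw z_vv [id C] to the valued X
fgraph.replace(exp_node.inputs[0], x_valued)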

@brandonwillard (Member, Author) commented:

> A log-probability isn't produced because we replaced the original X_rv with its value variable in the graph of Z_tr. […]

I made some small changes, and it looks like this might not be particularly difficult to add. The value variable transform approach needs to be updated, though.

@brandonwillard force-pushed the add-value-vars-to-graph branch from 9e822a1 to ae7c307 on January 3, 2023
@brandonwillard (Member, Author) commented Jan 3, 2023

OK, all the functionality is there: we can get the conditional densities for transformed variables and the variables that were transformed.

@brandonwillard force-pushed the add-value-vars-to-graph branch from ae7c307 to c3a4162 on January 3, 2023
@rlouf (Member) commented Jan 3, 2023

Great! How about this one?

import aeppl
import aesara.tensor as at

srng = at.random.RandomStream(0)

x = srng.normal(0, 1)
z = srng.normal(x, 1)

x_tr = at.exp(x)
logprob, vvs = aeppl.joint_logprob(x_tr, z)

(Important if we want to refactor the approach to transforms).

We should merge this even if this example doesn't pass, and make the necessary changes in another PR. This PR is already a huge improvement.

@brandonwillard (Member, Author) commented:

> Great! How about this one? […] We should merge this even if this example doesn't pass, and make the necessary changes in another PR.

That example appears to work as expected. In other words, the following succeeds:

srng = at.random.RandomStream(203920)

X = srng.normal(0, 1, name="X")
Y = srng.normal(X, 1, name="Y")
Z = at.exp(X)

logprob, (z_vv, y_vv) = joint_logprob(Z, Y)

aesara.dprint(logprob)
# Sum{acc_dtype=float64} [id A]
#  |MakeVector{dtype='float64'} [id B]
#    |Sum{acc_dtype=float64} [id C]
#    | |Elemwise{add,no_inplace} [id D]
#    |   |Check{sigma > 0} [id E]
#    |   | |Elemwise{sub,no_inplace} [id F]
#    |   | | |Elemwise{sub,no_inplace} [id G]
#    |   | | | |Elemwise{mul,no_inplace} [id H]
#    |   | | | | |TensorConstant{-0.5} [id I]
#    |   | | | | |Elemwise{pow,no_inplace} [id J]
#    |   | | | |   |Elemwise{true_div,no_inplace} [id K]
#    |   | | | |   | |Elemwise{sub,no_inplace} [id L]
#    |   | | | |   | | |Elemwise{log,no_inplace} [id M]
#    |   | | | |   | | | |<TensorType(float64, ())> [id N]
#    |   | | | |   | | |TensorConstant{0} [id O]
#    |   | | | |   | |TensorConstant{1} [id P]
#    |   | | | |   |TensorConstant{2} [id Q]
#    |   | | | |Elemwise{log,no_inplace} [id R]
#    |   | | |   |Elemwise{sqrt,no_inplace} [id S]
#    |   | | |     |TensorConstant{6.283185307179586} [id T]
#    |   | | |Elemwise{log,no_inplace} [id U]
#    |   | |   |TensorConstant{1} [id P]
#    |   | |All [id V]
#    |   |   |Elemwise{gt,no_inplace} [id W]
#    |   |     |TensorConstant{1} [id P]
#    |   |     |TensorConstant{0.0} [id X]
#    |   |Elemwise{neg,no_inplace} [id Y]
#    |     |Elemwise{log,no_inplace} [id Z]
#    |       |<TensorType(float64, ())> [id N]
#    |Sum{acc_dtype=float64} [id BA]
#      |Check{sigma > 0} [id BB] 'Y_logprob'
#        |Elemwise{sub,no_inplace} [id BC]
#        | |Elemwise{sub,no_inplace} [id BD]
#        | | |Elemwise{mul,no_inplace} [id BE]
#        | | | |TensorConstant{-0.5} [id BF]
#        | | | |Elemwise{pow,no_inplace} [id BG]
#        | | |   |Elemwise{true_div,no_inplace} [id BH]
#        | | |   | |Elemwise{sub,no_inplace} [id BI]
#        | | |   | | |Y_vv [id BJ]
#        | | |   | | |normal_rv{0, (0, 0), floatX, False}.1 [id BK] 'X'
#        | | |   | |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCA2701FD60>) [id BL]
#        | | |   | |   |TensorConstant{[]} [id BM]
#        | | |   | |   |TensorConstant{11} [id BN]
#        | | |   | |   |TensorConstant{0} [id O]
#        | | |   | |   |TensorConstant{1} [id P]
#        | | |   | |TensorConstant{1} [id P]
#        | | |   |TensorConstant{2} [id BO]
#        | | |Elemwise{log,no_inplace} [id BP]
#        | |   |Elemwise{sqrt,no_inplace} [id BQ]
#        | |     |TensorConstant{6.283185307179586} [id BR]
#        | |Elemwise{log,no_inplace} [id BS]
#        |   |TensorConstant{1} [id P]
#        |All [id BT]
#          |Elemwise{gt,no_inplace} [id BU]
#            |TensorConstant{1} [id P]
#            |TensorConstant{0.0} [id BV]

logp_fn = aesara.function((z_vv, y_vv), logprob)

z_val, y_val = 0.1, 0.3

mean_val = at.random.RandomStream(203920).normal().eval()
exp_res = sp.stats.lognorm(s=1).logpdf(z_val) + sp.stats.norm(loc=mean_val).logpdf(y_val)

np.testing.assert_allclose(logp_fn(z_val, y_val), exp_res)

@rlouf (Member) commented Jan 3, 2023

# Elemwise{sub,no_inplace} [id BI]
#        | | |   | | |Y_vv [id BJ]
#        | | |   | | |normal_rv{0, (0, 0), floatX, False}.1 [id BK] 'X'
#        | | |   | |   |RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FCA2701FD60>) [id BL]
#        | | |   | |   |TensorConstant{[]} [id BM]
#        | | |   | |   |TensorConstant{11} [id BN]
#        | | |   | |   |TensorConstant{0} [id O]
#        | | |   | |   |TensorConstant{1} [id P]
#        | | |   | |TensorConstant{1} [id P]
#        | | |   |TensorConstant{2} [id BO]

Shouldn't we have X_vv = at.exp(Z_vv) here instead of X?

@brandonwillard (Member, Author) commented Jan 3, 2023

> Shouldn't we have X_vv = at.exp(Z_vv) here instead of X?

No, I don't think so. As I understand it, Z = at.exp(X) is the only transformed and valued/bound term for which we're requesting a conditional density in that example, and the other conditional density (i.e. for Y) shouldn't be affected.

You might be asking about "value/bound variable transforms" (i.e. the transforms constructed by TransformValuesRewrite, which change all instances of their transformed values in the conditional densities produced). We really need a good, distinct name for those transforms, in contrast to the kind of "transform" that simply produces a density for a measurable function of measurable variables (which is what the Z = at.exp(X) example represents).

Perhaps we could call the latter "derived densities/measures" or something similar.

Basically, Z = at.exp(X) is a type of "explicit" transform that we want to be measurable, and use of TransformValuesRewrite({X: ExpTransform()}) is an "implicit" one that is applied to all occurrences of X in the model graph (and to the results of conditional_logprob). In other words, the TransformValuesRewrite approach is like performing the variable replacement $X \to Z = \exp(X)$ and the value replacement $x \to z$ all throughout the model graph.
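For reference, both flavors apply the same change-of-variables correction: with $Z = \exp(X)$, $\log p_Z(z) = \log p_X(\log z) - \log z$, and that $-\log z$ term is exactly the Elemwise{neg}(Elemwise{log}(...)) Jacobian correction visible in the dprint outputs above.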

@brandonwillard (Member, Author) commented Jan 3, 2023

Here's an example of the TransformValuesRewrite approach, which might be what you're looking for:

srng = at.random.RandomStream(203920)

X = srng.normal(0, 1, name="X")
Y = srng.normal(X, 1, name="Y")

transform_rewrite = TransformValuesRewrite({X: ExpTransform()})

logprob, (x_vv, y_vv) = joint_logprob(X, Y, extra_rewrites=transform_rewrite)

aesara.dprint(logprob)
# Sum{acc_dtype=float64} [id A]
#  |MakeVector{dtype='float64'} [id B]
#    |Sum{acc_dtype=float64} [id C]
#    | |Elemwise{add,no_inplace} [id D] 'X_logprob'
#    |   |Check{sigma > 0} [id E]
#    |   | |Elemwise{sub,no_inplace} [id F]
#    |   | | |Elemwise{sub,no_inplace} [id G]
#    |   | | | |Elemwise{mul,no_inplace} [id H]
#    |   | | | | |TensorConstant{-0.5} [id I]
#    |   | | | | |Elemwise{pow,no_inplace} [id J]
#    |   | | | |   |Elemwise{true_div,no_inplace} [id K]
#    |   | | | |   | |Elemwise{sub,no_inplace} [id L]
#    |   | | | |   | | |Elemwise{log,no_inplace} [id M]
#    |   | | | |   | | | |X_vv-trans [id N]
#    |   | | | |   | | |TensorConstant{0} [id O]
#    |   | | | |   | |TensorConstant{1} [id P]
#    |   | | | |   |TensorConstant{2} [id Q]
#    |   | | | |Elemwise{log,no_inplace} [id R]
#    |   | | |   |Elemwise{sqrt,no_inplace} [id S]
#    |   | | |     |TensorConstant{6.283185307179586} [id T]
#    |   | | |Elemwise{log,no_inplace} [id U]
#    |   | |   |TensorConstant{1} [id P]
#    |   | |All [id V]
#    |   |   |Elemwise{gt,no_inplace} [id W]
#    |   |     |TensorConstant{1} [id P]
#    |   |     |TensorConstant{0.0} [id X]
#    |   |Elemwise{neg,no_inplace} [id Y]
#    |     |Elemwise{log,no_inplace} [id Z]
#    |       |X_vv-trans [id N]
#    |Sum{acc_dtype=float64} [id BA]
#      |Check{sigma > 0} [id BB] 'Y_logprob'
#        |Elemwise{sub,no_inplace} [id BC]
#        | |Elemwise{sub,no_inplace} [id BD]
#        | | |Elemwise{mul,no_inplace} [id BE]
#        | | | |TensorConstant{-0.5} [id BF]
#        | | | |Elemwise{pow,no_inplace} [id BG]
#        | | |   |Elemwise{true_div,no_inplace} [id BH]
#        | | |   | |Elemwise{sub,no_inplace} [id BI]
#        | | |   | | |Y_vv [id BJ]
#        | | |   | | |TransformedVariable [id BK]
#        | | |   | |   |Elemwise{log,no_inplace} [id M]
#        | | |   | |   |X_vv-trans [id N]
#        | | |   | |TensorConstant{1} [id P]
#        | | |   |TensorConstant{2} [id BL]
#        | | |Elemwise{log,no_inplace} [id BM]
#        | |   |Elemwise{sqrt,no_inplace} [id BN]
#        | |     |TensorConstant{6.283185307179586} [id BO]
#        | |Elemwise{log,no_inplace} [id BP]
#        |   |TensorConstant{1} [id P]
#        |All [id BQ]
#          |Elemwise{gt,no_inplace} [id BR]
#            |TensorConstant{1} [id P]
#            |TensorConstant{0.0} [id BS]

logp_fn = aesara.function((x_vv, y_vv), logprob)

x_val, y_val = np.exp(0.1), 0.3

exp_res = sp.stats.lognorm(s=1).logpdf(x_val) + sp.stats.norm(loc=np.log(x_val)).logpdf(y_val)

np.testing.assert_allclose(logp_fn(x_val, y_val), exp_res)

@rlouf (Member) commented Jan 3, 2023

My ideas aren't completely clear, although I still feel what I described should be the expected behavior. Anyway, this PR is not the best place to discuss this. I suggest we merge this PR and open a separate issue/discussion.

@brandonwillard force-pushed the add-value-vars-to-graph branch from 99900d9 to 896fa24 on January 5, 2023
@brandonwillard force-pushed the add-value-vars-to-graph branch from 896fa24 to 7929f87 on January 6, 2023
@brandonwillard (Member, Author) commented:

OK, I'm going to merge this now and follow up with a few more improvements shortly.

@brandonwillard merged commit 7890228 into aesara-devs:main on January 6, 2023
@brandonwillard deleted the add-value-vars-to-graph branch on January 6, 2023