Construct a means of maintaining RNG consistency between graph rewrites #209

Open · brandonwillard opened this issue Dec 4, 2020 · 0 comments
Labels: enhancement, graph rewriting, important, question, random variables

brandonwillard commented Dec 4, 2020

While attempting to create "lift" rewrites for DimShuffles and *Subtensor*s on RandomVariables, I came across a couple of related issues involving graph consistency and RNGs.

Pre- and post-rewrite numeric equality

The problem is neatly summarized by the following NumPy-only example:

>>> np.random.RandomState(123).normal(mean, std).T
array([[0.99989144, 3.99984937],
       [2.00009973, 4.99994214],
       [3.0000283 , 6.00016514]])

>>> np.random.RandomState(123).normal(mean.T, std.T)
array([[0.99989144, 4.00009973],
       [2.0000283 , 4.99984937],
       [2.99994214, 6.00016514]])

The first case is the numeric result one would obtain from a DimShuffled (i.e. transposed) RandomVariable graph. The second is the lifted version of the same graph. Both results are theoretically equivalent and—ideally—should produce the same numeric result for the same RNG and seed. As we can see, they do not.

Here's an example of how it could be made to work:

>>> (mean + std * np.random.RandomState(123).standard_normal((2, 3))).T
array([[0.99989144, 3.99984937],
       [2.00009973, 4.99994214],
       [3.0000283 , 6.00016514]])

>>> mean.T + std.T * np.random.RandomState(123).standard_normal((2, 3)).T
array([[0.99989144, 3.99984937],
       [2.00009973, 4.99994214],
       [3.0000283 , 6.00016514]])

Simply put, by implementing the affine transform that distinguishes RandomState.normal from RandomState.standard_normal, we can transpose the underlying block of standard normals and preserve consistency between the graphs.

In other words, if we implement the underlying sampling processes, we can get what we want—in this case, at least.
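As a quick, self-contained check of that affine equivalence (the mean and std values below are illustrative, not necessarily the ones used in the examples above):

import numpy as np

# Illustrative values only; the `mean` and `std` used above are not shown in
# this excerpt.
mean = np.arange(1.0, 7.0).reshape(2, 3)
std = 0.1

a = np.random.RandomState(123).normal(mean, std)
b = mean + std * np.random.RandomState(123).standard_normal((2, 3))
assert np.allclose(a, b)

# The transposed variants agree as well, which is what the lifted graph
# relies on.
assert np.allclose(a.T, mean.T + std * np.random.RandomState(123).standard_normal((2, 3)).T)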

Since I don't think we want to effectively reimplement all the samplers in NumPy's RandomState, we can either think of a good workaround to preserve consistency at a higher level, or we can accept the fact that the two graphs will produce different results although they're theoretically equivalent.

The latter isn't entirely acceptable, so we need to consider some workarounds.

Rewrite limitations

The issue described in the previous section only involved numeric reproducibility between graphs before and after the DimShuffle and *Subtensor* lift rewrites. This section is concerned with the limitations on when those rewrites can be used without producing inconsistent graphs.

If we want to apply these rewrites more often—or simply use them as part of our canonicalizations—we will need to overcome these limitations.

Problem

Imagine that we're creating an Aesara graph for the NumPy operations that produce z in the following:

import numpy as np

seed = 34893
rng = np.random.RandomState(seed)

x = rng.normal(np.arange(2))

z = x - x[1]
>>>  z
array([-0.7960794,  0.       ])

Just as with NumPy, we would expect an Aesara graph for z to necessarily have a 0 for the element at index 1. This should also hold for any RNG state.

The naive local_subtensor_rv_lift rewrite rule would effectively substitute x[1] with np.random.RandomState(seed).normal(np.arange(2)[1]), which would only imply that the expectation of z[1] is 0. I.e.

rng = np.random.RandomState(seed)

x = rng.normal(np.arange(2))

rng_2 = np.random.RandomState(seed)
y = rng_2.normal(np.arange(2)[1])

z = x - y
>>>  z
array([-1.       , -0.2039206])

Unfortunately, that's not what the graph actually represents, so this rewrite is inconsistent.

As a simple way to avoid introducing this issue, we should not perform the rewrite if there's another reference to x in the graph; however, that would limit the applicability of the optimization. This restriction can be loosened a bit by allowing references to invariant properties (e.g. the shape of x) and not the values in x themselves.
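A minimal sketch of such a guard, assuming an Aesara-style FunctionGraph whose clients mapping lists the consumers of each variable (the helper name and structure here are illustrative, not the actual rewrite code):

def can_lift_subtensor(fgraph, rv_output, subtensor_node):
    # All apply nodes that consume the RandomVariable's output.
    clients = fgraph.clients.get(rv_output, [])
    # Ignore the *Subtensor* node we intend to rewrite; any other reference to
    # the sampled values makes the lift unsafe.  References to invariant
    # properties (e.g. a Shape node) could be whitelisted here.
    other_clients = [(node, i) for node, i in clients if node is not subtensor_node]
    return len(other_clients) == 0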

Workarounds

RNG-based

We might be able to solve a larger number of cases using an RNG-based approach. Such an approach might also preserve numeric equality between graphs (i.e. equality of graphs pre- and post-rewrite, as described above), but it will require some additional Aesara functionality.

The idea is that we track the number of elements to skip, which might not be too difficult in most cases, especially since we're already computing all the requisite shape and index information for the rewrites themselves. In other words, the Aesara RNG objects would carry a set of state "jumps" that determine the evolution of the internal RNG state based on the indexing applied to it.

The most basic way of implementing this could use a seed-based approach (offsets from a seed, really). This would work with all RNGs and samplers, but I'm not sure if it could be efficiently extended to blocks/slices of indices. It seems like we would have to ensure that all values were drawn individually from a flattened version of the array. This isn't difficult to do, and it could be implemented in C/Cython/Numba to cut costs.
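A rough sketch of the seed-offset idea, drawing each element from a flattened version of the array so that any single entry can be re-drawn in isolation (the names and the per-element RandomState construction are purely illustrative; a real implementation would need to be far cheaper):

import numpy as np

def elementwise_normal(seed, mean):
    # Draw each element with its own seed offset so that any individual entry
    # can be reproduced on its own.
    mean = np.asarray(mean, dtype=float)
    out = np.empty(mean.size)
    for i, m in enumerate(mean.ravel()):
        out[i] = np.random.RandomState(seed + i).normal(m)
    return out.reshape(mean.shape)

seed = 34893
x = elementwise_normal(seed, np.arange(2))

# The lifted graph re-draws only the indexed element, using the matching offset.
y = np.random.RandomState(seed + 1).normal(np.arange(2)[1])

z = x - y
# z[1] is exactly 0, as the un-rewritten graph requires.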

Alternatively, we could—and eventually should—add support for at least one of the two more flexible NumPy BitGenerators: PCG64 and/or Philox. These RNGs implement an .advance method that would allow us to manipulate the state in a manner that preserves consistency between shuffles and subsets of RandomVariable arrays.

Our simple example above can be fixed in this way:

drng = np.random.default_rng(seed)

x = drng.normal(np.arange(2))

drng = np.random.default_rng(seed)
# Move the state forward so that the next sample matches the second entry in
# `x`
drng.bit_generator.advance(1)
y = drng.normal(np.arange(2)[1])

z = x - y
>>>  z
array([-2.68521984,  0.        ])

Naturally, this .advance-based approach won't work for certain samplers (e.g. rejection-based ones), but it should work for more than a few of the samplers for basic random variables.

Unfortunately, this approach would end up sampling the same value multiple times throughout a graph if it's implemented without some form of caching.

Drawbacks aside, these RNG-based approaches have a direct correspondence with the actual source of change between rewrites (i.e. the RNG state), which adds to their appeal. In other words, indexing is equivalent to shifting an abstract RNG state: normal(mean, stddev, rng)[index] is converted to normal(mean[index], stddev[index], new_rng).

Graph-based

We could also attempt to synchronize slices of x throughout the graph by replacing the rewritten RandomVariables with stand-ins that are updated in-place. In effect, we would replace indexed random arrays with a kind of sparse, lazy random array: it behaves like a sparse array, except that when elements are indexed, values are generated and permanently saved at those index locations.
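A minimal sketch of such a lazy random array in plain NumPy (the class name and interface are illustrative, not Aesara code):

import numpy as np

class LazyRandomNormal:
    # Each element is drawn on first access and memoized, so every part of the
    # graph that indexes the same position sees the same sample.
    def __init__(self, mean, rng):
        self.mean = np.asarray(mean, dtype=float)
        self.rng = rng
        self._cache = {}

    def __getitem__(self, idx):
        if idx not in self._cache:
            self._cache[idx] = self.rng.normal(self.mean[idx])
        return self._cache[idx]

    def toarray(self):
        # Materialize the full array, reusing any values already drawn.
        return np.array([self[i] for i in np.ndindex(self.mean.shape)]).reshape(self.mean.shape)

rng = np.random.default_rng(34893)
x = LazyRandomNormal(np.arange(2), rng)
z = x.toarray() - x[(1,)]
# z[1] is exactly 0, because x[(1,)] returns the cached sample.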

This is a nice solution because it would work for any RNG and sampling method. It would also avoid the RNG-based issue of producing duplicate samples, since it's effectively an extreme type of the caching needed to reduce duplicates in that approach.

Unfortunately, it would incur much of the same overhead that sparse arrays do, though some of that could be ameliorated by a simple, low-level C implementation, at least for certain key steps. It also doesn't address the simpler pre- and post-rewrite numerical consistency issue described in the first section.

Originally posted by @brandonwillard in #137 (comment)
