Construct a means of maintaining RNG consistency between graph rewrites #209
Labels: enhancement, graph rewriting, important, question, random variables
While attempting to create "lift" rewrites for `DimShuffle`s and `*Subtensor*`s on `RandomVariable`s, I came across a couple of related issues involving graph consistency and RNGs.

## Pre- and post-rewrite numeric equality
The problem is neatly summarized by the following NumPy-only example:
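The original snippet is not preserved in this extract; a minimal NumPy sketch of the comparison (the seed and the 2×2 shape are illustrative) might look like:

```python
import numpy as np

seed = 2309
loc = np.arange(4).reshape((2, 2))

# Case 1: sample, then transpose (the graph before the DimShuffle lift).
y1 = np.random.RandomState(seed).normal(loc).T

# Case 2: transpose the parameter, then sample (the lifted graph).
y2 = np.random.RandomState(seed).normal(loc.T)

print(np.array_equal(y1, y2))  # False: the off-diagonal draws land in different places
```

Both arrays have the same distribution, but the underlying standard normals are consumed in C order of each call's output shape, so the off-diagonal draws end up in different positions.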
The first case is the numeric result one would obtain from a `DimShuffle`d (i.e. reshaped) `RandomVariable` graph. The second is the lifted version of the same graph. Both results are theoretically equivalent and, ideally, should produce the same numeric result for the same RNG and seed. As we can see, they do not.

Here's an example of how it could be made to work:
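A sketch of the fix, assuming we drop down to `RandomState.standard_normal` and apply the affine transform ourselves (seed and shape are illustrative):

```python
import numpy as np

seed = 2309
loc = np.arange(4).reshape((2, 2))

# Draw the underlying block of standard normals once.
z = np.random.RandomState(seed).standard_normal(loc.shape)

# normal(loc) is the affine transform loc + 1.0 * z, so transposing
# commutes with sampling when we also transpose the standard-normal block.
y1 = (loc + z).T      # sample, then transpose
y2 = loc.T + z.T      # "lifted": transform the transposed inputs

print(np.array_equal(y1, y2))  # True: the two graphs now agree exactly
```

Since transposition only permutes elements, applying the affine transform before or after the permutation yields bit-identical results.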
Simply put, by implementing the affine transform that distinguishes `RandomState.normal` from `RandomState.standard_normal`, we can transpose the underlying block of standard normals and preserve consistency between the graphs. In other words, if we implement the underlying sampling processes, we can get what we want, in this case at least.

Since I don't think we want to effectively reimplement all the samplers in NumPy's `RandomState`, we can either think of a good workaround to preserve consistency at a higher level, or we can accept the fact that the two graphs will produce different results although they're theoretically equivalent. The latter isn't entirely acceptable, so we need to consider some workarounds.
## Rewrite limitations
The issue described in the previous section only involved numeric reproducibility between graphs before and after the `DimShuffle` and `*Subtensor*` lift rewrites. This section is concerned with the limitations on when those rewrites can be used without producing inconsistent graphs. If we want to apply these rewrites more often, or simply use them as part of our canonicalizations, we will need to overcome these limitations.
### Problem
Imagine that we're creating an Aesara graph for the NumPy operations that produce `z` in the following:

Just as with NumPy, we would expect an Aesara graph for `z` to necessarily have a 0 for the element at index 1. This should also hold for any RNG state.

The naive `local_subtensor_rv_lift` rewrite rule would effectively substitute `x[1]` with `np.random.RandomState(seed).normal(np.arange(2)[1])`, which would only imply that the expectation of `z[1]` is 0: the substituted term is an independent draw with the same distribution as `x[1]`, not the same value. Unfortunately, that's not what the graph actually represents, so this rewrite is inconsistent.
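The referenced snippet is not preserved in this extract; a NumPy reconstruction consistent with the surrounding description (the seed is illustrative, and `z = x - x[1]` is inferred from the claims about `z[1]`):

```python
import numpy as np

seed = 34098
x = np.random.RandomState(seed).normal(np.arange(2))
z = x - x[1]
print(z[1])  # exactly 0.0, for any seed

# The naive lift replaces x[1] with an independent draw from a fresh
# stream, which has the right distribution but not the right value:
x1_lifted = np.random.RandomState(seed).normal(np.arange(2)[1])
print(x[1] - x1_lifted)  # generally nonzero
```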
As a simple way to avoid introducing this issue, we should not perform the rewrite if there's another reference to `x` in the graph; however, that would limit the applicability of the optimization. This restriction can be loosened a bit by allowing references to invariant properties of `x` (e.g. its shape) and not to the values in `x` themselves.

### Workarounds
#### RNG-based
We might be able to solve a larger number of cases using an RNG-based approach. Such an approach might also preserve numeric equality between graphs (i.e. equality pre- and post-rewrite, as described above), but it will require some additional Aesara functionality.
The idea is that we track the number of elements to skip, which might not be too difficult in most cases, especially since we're already computing all the requisite shape and index information for the rewrites themselves. In other words, the Aesara RNG objects would carry a set of state "jumps" that determine the evolution of the internal RNG state based on the indexing applied to it.
The most basic way of implementing this could use a seed-based approach (offsets from a seed, really). This would work with all RNGs and samplers, but I'm not sure if it could be efficiently extended to blocks/slices of indices. It seems like we would have to ensure that all values were drawn individually from a flattened version of the array. This isn't difficult to do, and it could be implemented in C/Cython/Numba to cut costs.
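A minimal sketch of that seed-offset idea (the helper names are hypothetical): each element is drawn from its own stream, seeded by a base seed plus the element's flat index, so indexing the sampled array and sampling the indexed parameter agree by construction.

```python
import numpy as np

def normal_offset(seed, loc, index):
    """Draw the single normal variate at flat position `index`,
    using an offset seed so the draw is independent of array layout."""
    return np.random.RandomState(seed + index).normal(np.ravel(loc)[index])

def normal_offset_array(seed, loc):
    """Sample a whole array element-by-element from its flattened form."""
    flat = [normal_offset(seed, loc, i) for i in range(np.size(loc))]
    return np.array(flat).reshape(np.shape(loc))

seed = 123
loc = np.arange(4.0)

x = normal_offset_array(seed, loc)
# The Subtensor lift is now exact: sampling then indexing equals
# sampling the indexed parameter with the matching offset.
print(x[2] == normal_offset(seed, loc, 2))  # True
```

This is the element-by-element sampling the text describes, which is why a low-level (C/Cython/Numba) implementation would be needed to make it affordable.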
Alternatively, we could, and eventually should, add support for at least one of the two more flexible NumPy `BitGenerator`s: `PCG64` and/or `Philox`. These RNGs implement an `.advance` method that would allow us to manipulate the state in a manner that preserves consistency between shuffles and subsets of `RandomVariable` arrays. Our simple example above can be fixed in this way:
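A sketch of the `.advance`-based approach with `PCG64`. Uniform draws are used here because each one consumes exactly one 64-bit output of the bit generator, so the number of states to skip is just the number of preceding elements; the ziggurat-based `Generator.normal` consumes a variable number of states per draw, which is exactly the rejection-sampler caveat noted below.

```python
import numpy as np

seed = 2309

# Un-rewritten graph: draw the whole array, then index it.
x = np.random.Generator(np.random.PCG64(seed)).random(4)

# Lifted graph: skip the two states consumed by x[0] and x[1],
# then draw only the element we need.
bg = np.random.PCG64(seed)
bg.advance(2)
x2 = np.random.Generator(bg).random()

print(x[2] == x2)  # True: the lifted draw matches the indexed one
```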
Naturally, this `.advance`-based approach won't work for certain samplers (e.g. rejection-based ones), but it should work for more than a few of the samplers for basic random variables. Unfortunately, without some form of caching, this approach would end up sampling the same value multiple times throughout a graph.
Otherwise, these RNG-based approaches have a direct correspondence with the actual source of change between rewrites (i.e. the RNG state), which adds to their appeal. In other words, indexing is equivalent to shifting an abstract `rng` state: `normal(mean, stddev, rng)[index]` is converted to `normal(mean[index], stddev[index], new_rng)`.

#### Graph-based
We could also attempt to synchronize slices of `x` throughout the graph by replacing the rewritten `RandomVariable`s with stand-ins that are updated in-place. In effect, we would replace indexed random arrays with some type of sparse, lazy random array: one that operates like a sparse array would, except that when elements are indexed, a value is generated and permanently saved for those index locations.

This is a nice solution because it would work for any RNG and sampling method. It would also avoid the RNG-based issue of producing duplicate samples, since it's effectively an extreme form of the caching needed to reduce duplicates in that approach.
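A rough sketch of such a lazy, cached random array (the class name and the per-element offset seeding are hypothetical stand-ins for whatever a real implementation would use):

```python
import numpy as np

class LazyRandomNormal:
    """A lazily sampled normal array: each element is drawn on first
    access and cached, so every reference to the same index in a graph
    sees the same value."""

    def __init__(self, loc, seed):
        self.loc = np.ravel(np.asarray(loc, dtype=float))
        self.seed = seed
        self._cache = {}

    def __getitem__(self, index):
        if index not in self._cache:
            # Offset seeding keeps each element's draw layout-independent.
            rs = np.random.RandomState(self.seed + index)
            self._cache[index] = rs.normal(self.loc[index])
        return self._cache[index]

x = LazyRandomNormal(np.arange(2), seed=34098)
# Repeated references to x[1] are now synchronized, so the problematic
# x - x[1] computation above would yield an exact 0 at index 1.
print(x[1] - x[1])  # 0.0
```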
Unfortunately, it would incur most of the same overhead that sparse arrays do, though some of that could be ameliorated by a simple, low-level C implementation, at least for certain key steps. It also doesn't address the simpler pre- and post-rewrite numerical consistency described above.
Originally posted by @brandonwillard in #137 (comment)