Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PRNG handling in nn.jit under nn.scan. #4359

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 144 additions & 2 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import warnings

from flax import traceback_util
from flax import traverse_util
from flax.typing import (
In,
InOutAxis,
Expand Down Expand Up @@ -1499,6 +1500,81 @@ def _hashable_filter(x):
return x


class CountsHolder:

def __init__(self, flat_d):
self.flat_d = flat_d

@classmethod
def make(cls, d):
flat_d = traverse_util.flatten_dict(d)
flat_d = {k: v for k, v in flat_d.items()}
return cls(flat_d)

def sub(self, other):
delta_flat_d = {}
new_flat_d = collections.defaultdict(int, self.flat_d)
old_flat_d = collections.defaultdict(int, other.flat_d)
for k in new_flat_d:
delta_flat_d[k] = new_flat_d[k] - old_flat_d[k]
return CountsHolder(delta_flat_d)

def add(self, other):
delta_flat_d = {}
new_flat_d = collections.defaultdict(int, self.flat_d)
old_flat_d = collections.defaultdict(int, other.flat_d)
for k in new_flat_d:
delta_flat_d[k] = new_flat_d[k] + old_flat_d[k]
return CountsHolder(delta_flat_d)

def unflat(self):
return traverse_util.unflatten_dict(self.flat_d)


def set_from_dict(original, updates):
for k in updates:
if k not in original:
original[k] = updates[k]
else:
if isinstance(updates[k], dict):
set_from_dict(original[k], updates[k])
else:
original[k] = updates[k]


class _SideEffectCache(threading.local):

def __init__(self):
self.cache = {}


_side_effect_cache = _SideEffectCache()


def _restore_rng_counters(scopes, fingerprint, capture_old_counts):
if fingerprint not in _side_effect_cache.cache:
capture_new_counts = jax.tree.map(
lambda s: CountsHolder.make(s.rng_counters), scopes
)
capture_delta_counts = jax.tree.map(
lambda old, new: new.sub(old),
capture_old_counts,
capture_new_counts,
)
_side_effect_cache.cache[fingerprint] = capture_delta_counts
else:
updated_counts = jax.tree.map(
lambda x, y: x.add(y).unflat(),
_side_effect_cache.cache[fingerprint],
capture_old_counts,
)
jax.tree.map(
lambda s, u: set_from_dict(s.rng_counters, u),
scopes,
updated_counts,
)


def jit(
fn: Callable[..., Any],
variables: CollectionFilter = True,
Expand Down Expand Up @@ -1599,13 +1675,18 @@ def inner(
mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)

rng_groups = jax.tree.map(
lambda x: x.fold() if isinstance(x, LazyRng) else x,
lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x,
rng_groups,
is_leaf=lambda x: isinstance(x, LazyRng),
)

fingerprint = (mutable, module_hash_key)
return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
capture_old_counts = jax.tree.map(
lambda s: CountsHolder.make(s.rng_counters), scopes
)
res = jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
_restore_rng_counters(scopes, fingerprint, capture_old_counts)
return res

return pack(
inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True
Expand Down Expand Up @@ -1692,3 +1773,64 @@ def inner_loop(scope, carry):
def _unzip2(xs):
ys = tuple(zip(*xs))
return ys if ys else ((), ())


def fold_rngs(
fn: Callable[..., Any],
variables: CollectionFilter = True,
rngs: PRNGSequenceFilter = True,
) -> Callable[..., Any]:
# Close over scope_fn & repack_fn to avoid recompilation
# this is impure but we use the fingerprint arg to differentiate between cases
# where scope_fn or repack_fn actually produce non-identical results.
fold_rngs_context = TransformContext[tuple[Callable, Callable]]()

@functools.wraps(fn)
def wrapped_fold_rngs(fingerprint, variable_groups, rng_groups, *args, **kwargs):
scope_fn, repack_fn = fold_rngs_context.get()
hash_key = fingerprint[1]
# fingerprint is only used to differentiate the cache signature
# del fingerprint
scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable
y = fn(scope, hash_key, *args, **kwargs)
return y, repack_fn(scope) # pylint: disable=not-callable

def inner_fold_rngs(
scope_fn,
repack_fn,
variable_groups,
rng_groups,
module_hash_key,
*args,
**kwargs,
):
with fold_rngs_context.push((scope_fn, repack_fn)):
scopes: list[Scope] = jax.tree_util.tree_leaves(
scope_fn(variable_groups, rng_groups)
)
mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)

rng_groups = jax.tree.map(
lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x,
rng_groups,
is_leaf=lambda x: isinstance(x, LazyRng),
)

fingerprint = (mutable, module_hash_key)
capture_old_counts = jax.tree.map(
lambda s: CountsHolder.make(s.rng_counters), scopes
)
res = wrapped_fold_rngs(
fingerprint, variable_groups, rng_groups, *args, **kwargs
)
_restore_rng_counters(scopes, fingerprint, capture_old_counts)
return res

return pack(
inner_fold_rngs,
(variables,),
(variables,),
(rngs,),
name='fold_rngs',
enable_kwargs=True,
)
11 changes: 2 additions & 9 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def create(
else:
return LazyRng(rng, suffix)

def fold(self):
key = self.as_jax_rng()
def clear_suffix(self):
key = self.rng
return LazyRng(key, ())


Expand Down Expand Up @@ -583,13 +583,6 @@ def default_name(self, prefix: str) -> str:
return name
i += 1

def fold_rngs(self):
"""Folds the rngs of this scope into the parent scope."""
self._check_valid()
for name, rng in self.rngs.items():
assert isinstance(rng, LazyRng)
self.rngs[name] = rng.fold()

def push(
self, name: str | None = None, prefix: str = '', reuse=False
) -> 'Scope':
Expand Down
15 changes: 8 additions & 7 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
from .batch_apply import BatchApply as BatchApply
from .combinators import Sequential as Sequential
from .fp8_ops import (
Fp8DotGeneralOp as Fp8DotGeneralOp,
Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp,
Fp8DotGeneralOp as Fp8DotGeneralOp,
NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp,
)
from .initializers import (
Expand All @@ -95,8 +95,8 @@
Module as Module,
Variable as Variable,
apply as apply,
compact as compact,
compact_name_scope as compact_name_scope,
compact as compact,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
Expand All @@ -114,19 +114,19 @@
LayerNorm as LayerNorm,
RMSNorm as RMSNorm,
SpectralNorm as SpectralNorm,
WeightNorm as WeightNorm
WeightNorm as WeightNorm,
)
from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
from .recurrent import (
Bidirectional as Bidirectional,
ConvLSTMCell as ConvLSTMCell,
SimpleCell as SimpleCell,
GRUCell as GRUCell,
MGUCell as MGUCell,
LSTMCell as LSTMCell,
MGUCell as MGUCell,
OptimizedLSTMCell as OptimizedLSTMCell,
RNNCellBase as RNNCellBase,
RNN as RNN,
SimpleCell as SimpleCell,
)
from .spmd import (
LogicallyPartitioned as LogicallyPartitioned,
Expand All @@ -146,6 +146,8 @@
checkpoint as checkpoint,
cond as cond,
custom_vjp as custom_vjp,
fold_rngs as fold_rngs,
grad as grad,
jit as jit,
jvp as jvp,
map_variables as map_variables,
Expand All @@ -154,9 +156,8 @@
remat as remat,
scan as scan,
switch as switch,
vjp as vjp,
grad as grad,
value_and_grad as value_and_grad,
vjp as vjp,
vmap as vmap,
while_loop as while_loop,
)
Expand Down
Loading
Loading