Skip to content

Commit

Permalink
Generalize attribute assignment to support passing module instances
Browse files Browse the repository at this point in the history
Co-authored-by: Anselm Levskaya <[email protected]>
  • Loading branch information
jheek and levskaya committed Feb 12, 2021
1 parent feb9119 commit 9bf7161
Show file tree
Hide file tree
Showing 6 changed files with 454 additions and 116 deletions.
2 changes: 1 addition & 1 deletion flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _dedup_scopes(scopes):
max_parent_path = tuple(reversed(path))
path.append(scope.name)
scope = scope.parent
if max_parent is not leaf:
if max_parent is not leaf and leaf in minimal_set:
del minimal_set[leaf]
paths.append((max_parent, max_parent_path))
return tuple(minimal_set), tuple(paths)
Expand Down
45 changes: 22 additions & 23 deletions flax/linen/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def _normalize_axes(axes, ndim):
return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
if isinstance(x, Iterable):
return tuple(x)
else:
return (x,)


class DenseGeneral(Module):
"""A linear transformation with flexible axes.
Expand All @@ -64,23 +71,6 @@ class DenseGeneral(Module):
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
precision: Any = None

def setup(self):
"""Normalize hyperparameters."""
if not isinstance(self.features, Iterable):
self.features = (self.features,)
if not isinstance(self.axis, Iterable):
self.axis = (self.axis,)
if not isinstance(self.batch_dims, Iterable):
self.batch_dims = (self.batch_dims,)
self.features = tuple(self.features)
self.axis = tuple(self.axis)
self.batch_dims = tuple(self.batch_dims)
if self.batch_dims:
max_dim = np.max(self.batch_dims)
if set(self.batch_dims) != set(range(max_dim + 1)):
raise ValueError('batch_dims %s must be consecutive leading '
'dimensions starting from 0.' % str(self.batch_dims))

@compact
def __call__(self, inputs: Array) -> Array:
"""Applies a linear transformation to the inputs along multiple dimensions.
Expand All @@ -91,13 +81,22 @@ def __call__(self, inputs: Array) -> Array:
Returns:
The transformed input.
"""
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
batch_dims = _canonicalize_tuple(self.batch_dims)
if batch_dims:
max_dim = np.max(batch_dims)
if set(batch_dims) != set(range(max_dim + 1)):
raise ValueError('batch_dims %s must be consecutive leading '
'dimensions starting from 0.' % str(batch_dims))

inputs = jnp.asarray(inputs, self.dtype)

ndim = inputs.ndim
n_batch_dims = len(self.batch_dims)
axis = _normalize_axes(self.axis, ndim)
batch_dims = _normalize_axes(self.batch_dims, ndim)
n_axis, n_features = len(axis), len(self.features)
n_batch_dims = len(batch_dims)
axis = _normalize_axes(axis, ndim)
batch_dims = _normalize_axes(batch_dims, ndim)
n_axis, n_features = len(axis), len(features)

def kernel_init_wrap(rng, shape, dtype=jnp.float32):
size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
Expand All @@ -108,7 +107,7 @@ def kernel_init_wrap(rng, shape, dtype=jnp.float32):
return jnp.reshape(kernel, shape)

batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + self.features
kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape)
kernel = jnp.asarray(kernel, self.dtype)

Expand All @@ -126,7 +125,7 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32):
for _ in range(size_batch_dims)], axis=0)
return jnp.reshape(bias, shape)

bias = self.param('bias', bias_init_wrap, batch_shape + self.features)
bias = self.param('bias', bias_init_wrap, batch_shape + features)

# Reshape bias for broadcast.
expand_dims = sorted(
Expand Down
141 changes: 81 additions & 60 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import inspect
import os
import threading
import weakref

from typing import (Any, Callable, Sequence, Iterable, List, Optional, Tuple,
Set, Type, Union, TypeVar, Generic, Dict)

Expand Down Expand Up @@ -130,16 +132,33 @@ def disable_named_call():

# Utilities for pytrees of Modules defined inside setup()
# -----------------------------------------------------------------------------

def _sorted_items(x):
"""Returns items of a dict ordered by keys."""
return sorted(x.items(), key=lambda x: x[0])


def _get_suffix_value_pairs(
tree_or_leaf: Any) -> List[Tuple[str, Type["Module"]]]:
"""Helper for naming pytrees of submodules."""
dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
if dict_or_leaf == {} or not isinstance(dict_or_leaf, dict):
if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
return [('', tree_or_leaf)]
else:
flat_dict = traverse_util.flatten_dict(dict_or_leaf)
return [('_' + '_'.join(k), v) for k, v in flat_dict.items()]
return [('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict)]

def _map_over_modules_in_tree(fn, tree_or_leaf):
  """Helper for mapping function over submodules.

  Args:
    fn: callable invoked as `fn(suffix, value)`. For a leaf the suffix is the
      empty string; for an entry of a tree it is `'_'` followed by the
      `'_'`-joined path of that entry in the flattened state dict.
    tree_or_leaf: a single value, or a container that
      `serialization.to_state_dict` turns into a non-empty dict.

  Returns:
    `fn('', tree_or_leaf)` when the state dict of `tree_or_leaf` is not a dict
    or is empty. Otherwise `tree_or_leaf` restored through
    `serialization.from_state_dict` from the mapped entries.
  """
  dict_or_leaf = serialization.to_state_dict(tree_or_leaf)
  if not isinstance(dict_or_leaf, dict) or dict_or_leaf == {}:
    return fn('', tree_or_leaf)
  flat_dict = traverse_util.flatten_dict(dict_or_leaf)
  # Visit entries in sorted key order so `fn` is called in the same order
  # regardless of the dict's insertion order.
  mapped_flat_dict = {k: fn('_' + '_'.join(k), v)
                      for k, v in _sorted_items(flat_dict)}
  return serialization.from_state_dict(
      tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict))

def _all_names_on_object(obj: Any) -> Set[str]:
"""Gets all names of attributes on `obj` and its classes throughout MRO.
Expand Down Expand Up @@ -233,8 +252,6 @@ def wrapped_module_method(self, *args, **kwargs):
if self.scope is None:
raise ValueError("Can't call compact methods on unbound modules")
self._state.in_compact_method = True
elif is_setup_method:
self._state.in_setup = True
_context.module_stack.append(self)
try:
return fun(self, *args, **kwargs)
Expand Down Expand Up @@ -287,15 +304,13 @@ class _ModuleInternalState:
in_compact_method: bool = False
in_setup: bool = False
setup_called: bool = False
last_varname: Optional[str] = None
autoname_cursor: Optional[dict] = dataclasses.field(default_factory=dict)
children: Dict[str, Union[str, 'Module']] = dataclasses.field(default_factory=dict)

def reset(self):
"""Resets transient state."""
self.in_compact_method = False
self.in_setup = False
self.last_varname = None
self.autoname_cursor = dict()

def export(self):
Expand All @@ -304,15 +319,13 @@ def export(self):
in_compact_method=self.in_compact_method,
in_setup=self.in_setup,
setup_called=False, # setup_called is object local, not shared.
last_varname=self.last_varname,
autoname_cursor=dict(self.autoname_cursor))
return cloned

def reimport(self, other):
"""Re-imports transform-preserved state from across transform boundary."""
self.in_compact_method = other.in_compact_method
self.in_setup = other.in_setup
self.last_varname = other.last_varname
self.autoname_cursor = dict(other.autoname_cursor)

_uninitialized_module_internal_state = _ModuleInternalState()
Expand All @@ -322,6 +335,9 @@ def reimport(self, other):
'__getstate__', '__setstate__', '__getnewargs_ex__',
'__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')


_caches = weakref.WeakKeyDictionary()

This comment has been minimized.

Copy link
@avital

avital Feb 16, 2021

Contributor

I'm late to this party but can we add a comment explaining the caching logic here? What's the motivation and general approach?


# Base Module definition.
# -----------------------------------------------------------------------------
class Module:
Expand Down Expand Up @@ -452,47 +468,18 @@ def __setattr__(self, name: str, val: Any):
name: Attribute to set.
val: Value of the attribute.
"""
if name != '_state' and self._state.setup_called:
is_dataclass_attr = name in self.__dataclass_fields__ and self.__dataclass_fields__[name].init # pytype: disable=attribute-error

if not self._state.in_setup and not is_dataclass_attr:
# Raises a TypeError just like frozen python dataclasses.
raise TypeError("Module instance is frozen outside of setup method.")

# We don't mess with the parent module.
if name == 'parent':
pass
# Modules have been passed in as dataclass args and set in __init__.
elif name in self.__dataclass_fields__ and self.__dataclass_fields__[name].init: # pytype: disable=attribute-error
pass
if is_dataclass_attr:
if self._state.in_setup:
raise TypeError("Module construction attributes are frozen.")
object.__setattr__(self, name, val)
# Submodules are being defined and attached in setup()
else:
val = _freeze_attr(val)
for suffix, subvalue in _get_suffix_value_pairs(val):
if isinstance(subvalue, Module):
if not self._state.in_setup:
raise ValueError(
"You can only assign submodules to self in setup().")
if subvalue.parent is _unspecified_parent:
subvalue.parent = self
elif subvalue.parent is not self:
raise ValueError("Can't attach to remote parent in setup, pass in "
"bound Modules from outside as an argument.")
if subvalue.name is not None:
raise ValueError(
"In setup, assign names of Modules via self.<name> and not "
"using keyword argument name=\"<name>\"")
subvalue.name = f'{name}{suffix}'
subvalue.__post_init__()
# val is a parameter array or a Variable reference class.
elif isinstance(subvalue, (np.ndarray, jax.interpreters.xla.DeviceArray,
Variable)) and self._state.in_setup:
var_name = f'{name}{suffix}'
# namecheck to ensure named variable matches self attribute name.
if (suffix == '' and # not when assigning lists or dicts
self._state.last_varname and self._state.last_varname != var_name):
raise ValueError(f'Variable name {self._state.last_varname} must '
f'equal attribute name {var_name}.')
self._state.last_varname = None
# Finally, always run default __setattr__ to attach to self.__dict__.
object.__setattr__(self, name, val)
self._register_submodules(name, val)

def __getattr__(self, name: str) -> Any:
"""Call setup() before getting any setup-defined attributes."""
Expand All @@ -519,11 +506,11 @@ def __post_init__(self):
# initialization, attach this Module as a submodule of a parent, or bind
# this Module at the top-level to variables and rngs.

self._state = _ModuleInternalState()
object.__setattr__(self, '_state', _ModuleInternalState())

# Typically we set the parent based on the dynamic module context.
if self.parent is _unspecified_parent: # pytype: disable=attribute-error
self.parent = _context.module_stack[-1]
object.__setattr__(self, 'parent', _context.module_stack[-1])

# Initialization is deferred for top level Modules or any other "orphan"
# Modules until attachment by __setattr__ i.e. MyModule(..., parent=None)
Expand All @@ -547,19 +534,18 @@ def __post_init__(self):
cursor = self.parent._state.autoname_cursor.get(prefix, 0)
self.name = f"{prefix}_{cursor}"
self.parent._state.autoname_cursor[prefix] = cursor + 1
if self.parent._name_taken(self.name):
if self.parent._name_taken(self.name, self):
raise ValueError(
f"A variable of name {self.name} exists already, or "
f"trying to share submodule {self.__class__.__name__} by name "
f"{self.name}. To share submodules, store module instances as a"
f" Python object or as an attribute on self and reuse.")
self.parent._state.children[self.name] = self
self.scope = self.parent.scope.push(self.name)
object.__setattr__(self, 'scope', self.parent.scope.push(self.name))

# Top-level invocation with a functional Scope.
elif isinstance(self.parent, Scope):
self.scope = self.parent

object.__setattr__(self, 'scope', self.parent)
else:
raise ValueError("parent must be None, Module or Scope")

Expand Down Expand Up @@ -601,17 +587,56 @@ def setup(self):
"""
pass

def _try_setup(self):
def _register_submodules(self, name, val):
  """Adopts Modules found in `val` and stores the result as `self.<name>`.

  Every value in `val` (a single value or a tree of values) is visited.
  Modules whose parent is None are replaced by a clone that is cached per
  scope root; Modules that have no name yet get `self` as parent and a name
  derived from `name` plus their suffix inside the tree. The frozen result is
  written to the attribute `name`, after which `__post_init__` is run on each
  newly named Module.

  Args:
    name: attribute name under which `val` is stored on `self`.
    val: value or tree of values that may contain Module instances.
  """
  assert self.scope, 'Trying to register submodules on unbound scope.'
  root = self.scope.root
  # One clone cache per scope root; `_caches` is a WeakKeyDictionary.
  cache = _caches.setdefault(root, {})
  queue = []

  def adopt_attr_modules(cache, queue, suffix, subvalue):
    if isinstance(subvalue, Module):
      if subvalue.parent is None:
        # module was passed from outside. It needs to be cloned
        # NOTE(review): the cache is keyed by id(subvalue); ids can be reused
        # once the original object is garbage collected -- confirm the
        # original is kept alive for the lifetime of the cache.
        key = id(subvalue)
        if key not in cache:
          cache[key] = subvalue.clone()
        subvalue = cache[key]
      if subvalue.name is None:
        object.__setattr__(subvalue, 'parent', self)
        object.__setattr__(subvalue, 'name', f'{name}{suffix}')
        queue.append(subvalue)
    return subvalue

  val = _freeze_attr(_map_over_modules_in_tree(
      functools.partial(adopt_attr_modules, cache, queue), val))
  object.__setattr__(self, name, val)
  # Run __post_init__ only after the attribute has been assigned on self.
  for x in queue:
    x.__post_init__()

def _try_setup(self, shallow=False):
"""Tries to setup module if scope is available and setup has not been called yet."""
if self.scope and not self._state.setup_called and not self._state.in_setup:
try:
self.setup()
self._state.in_setup = True
# A shallow setup only registers attribute submodules; it does not call the
# user's setup(). This avoids running setup() before a transformation.
for field in dataclasses.fields(self):
if field.name != 'parent' and field.init:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
finally:
self._state.in_setup = False
self._state.setup_called = True

def _name_taken(self, name: str) -> bool:
return (name in self.scope.reservations or
name in _all_names_on_object(self))
def _name_taken(self, name: str, module: Optional['Module'] = None) -> bool:
  """Returns whether `name` is already in use on this Module.

  Args:
    name: the name to check.
    module: optional Module that is about to receive `name`. If the attribute
      `name` on `self` is this very object, the name is not reported as taken.

  Returns:
    True if `name` is an attribute name on `self` (other than the `module`
    exception above) or is present in `self.scope.reservations`.
  """
  if name in _all_names_on_object(self):
    val = getattr(self, name, None)
    if module is not None and val is module:
      # name is taken by the value itself because
      # field assignment happened before naming
      return False
    return True
  return name in self.scope.reservations

@property
def _initialization_allowed(self):
Expand Down Expand Up @@ -659,8 +684,6 @@ def variable(self, col: str, name: str, init_fn, *init_args) -> Variable:
if self._name_taken(name):
raise ValueError(
f'Name {name} already in use in {self.__class__.__name__}.')
# ephemeral state for setattr name-equality-check
self._state.last_varname = name
v = self.scope.variable(col, name, init_fn, *init_args)
self._state.children[name] = col
return v
Expand Down Expand Up @@ -688,8 +711,6 @@ def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
if self._name_taken(name):
raise ValueError(
f'Name {name} already in use in {self.__class__.__name__}.')
# ephemeral state for setattr name-equality-check
self._state.last_varname = name
v = self.scope.param(name, init_fn, *init_args)
self._state.children[name] = 'params'
return v
Expand Down
7 changes: 4 additions & 3 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_module_scopes(module):
A list of all functional-core Scopes bound on self and inside dataclass
fields.
"""
module._try_setup(shallow=True)
outer_scopes = []
def get_scope(x):
nonlocal outer_scopes
Expand Down Expand Up @@ -138,7 +139,7 @@ def core_fn(scopes, *args, **kwargs):
# we reference module_class, not self.__class__ to avoid infinite loop
cloned = module_class(parent=None, **attrs)
cloned = set_module_scopes(cloned, scopes)
cloned._state = self._state.export() # pylint: disable=protected-access
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
res = fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
Expand Down Expand Up @@ -167,7 +168,7 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned = set_module_scopes(self, scopes)
cloned._state = self._state.export() # pylint: disable=protected-access
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
Expand Down Expand Up @@ -216,7 +217,7 @@ def wrapped_fn(self, *args, **kwargs):
# make a scope-function to transform
def core_fn(scopes, *args, **kwargs):
cloned = set_module_scopes(self, scopes)
cloned._state = self._state.export() # pylint: disable=protected-access
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
res = prewrapped_fn(cloned, *args, **kwargs)
self._state.reimport(cloned._state) # pylint: disable=protected-access
return res
Expand Down
Loading

0 comments on commit 9bf7161

Please sign in to comment.