Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax parameter attribute check #706

Merged
merged 1 commit into from
Dec 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def __setattr__(self, name: str, val: Any):
Variable)) and self._state.in_setup:
var_name = f'{name}{suffix}'
# namecheck to ensure named variable matches self attribute name.
if self._state.last_varname and self._state.last_varname != var_name:
if (suffix == '' and # not when assigning lists or dicts
self._state.last_varname and self._state.last_varname != var_name):
raise ValueError(f'Variable name {self._state.last_varname} must '
f'equal attribute name {var_name}.')
self._state.last_varname = None
Expand Down
36 changes: 36 additions & 0 deletions tests/linen/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,42 @@ def __call__(self, x):
with self.assertRaisesRegex(ValueError, 'notbias.*must equal.*bias'):
y = Dummy(x.shape, parent=scope)(x)

def test_setattr_name_var_disagreement_allowed_in_lists(self):
rngkey = jax.random.PRNGKey(0)
class Dummy(nn.Module):
xshape: Tuple[int]
def setup(self):
self.biases = [
self.param(f'bias_{i}', initializers.ones, self.xshape)
for i in range(4)]
def __call__(self, x):
return x + self.biases[0]

x = jnp.array([1.])
scope = Scope({}, {'params': rngkey}, mutable=['params'])
y = Dummy(x.shape, parent=scope)(x)
self.assertEqual(y, jnp.array([2.]))

def test_setattr_name_var_disagreement_allowed_in_dicts(self):
rngkey = jax.random.PRNGKey(0)
class Dummy(nn.Module):
xshape: Tuple[int]
def setup(self):
self.biases = {
# NOTE that keys still must be strings. This is to make a possible
# future transition to automatically derived parameter names when assigned
# as a dict easier (like we currently have with submodules).
# See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853
str(i): self.param(f'bias_{i}', initializers.ones, self.xshape)
for i in range(4)}
def __call__(self, x):
return x + self.biases['0']

x = jnp.array([1.])
scope = Scope({}, {'params': rngkey}, mutable=['params'])
y = Dummy(x.shape, parent=scope)(x)
self.assertEqual(y, jnp.array([2.]))

def test_submodule_var_collision(self):
rngkey = jax.random.PRNGKey(0)
class Dummy(nn.Module):
Expand Down