Requesting support for arbitrary pytrees of Module params #705

Closed
j-towns opened this issue Dec 4, 2020 · 4 comments

Comments

@j-towns
Contributor

j-towns commented Dec 4, 2020

As mentioned in the docs, Module.__setattr__ allows for assigning arbitrary pytrees of submodules in setup. I would like to assign a list of params during setup. The following fails:

from jax import random
import flax.linen as nn


class Foo(nn.Module):
    def setup(self):
        self.param_list = [
            self.param('param_list', nn.initializers.zeros, (1, 2, 3))
            for _ in range(4)
        ]

foo = Foo()
foo.init({'params': random.PRNGKey(0)})

with

~/dev/flax/flax/linen/module.py in param(self, name, init_fn, *init_args)
    614           'wrapped in `@compact`')
    615     if self._name_taken(name):
--> 616       raise ValueError(
    617           f'Name {name} already in use in {self.__class__.__name__}.')
    618     # ephemeral state for setattr name-equality-check

ValueError: Name param_list already in use in Foo.

Setting the names of each list element manually also fails:

class Bar(nn.Module):
    def setup(self):
        self.param_list = [
            self.param(f'param_list_{i}', nn.initializers.zeros, (1, 2, 3))
            for i in range(4)
        ]

bar = Bar()
bar.init({'params': random.PRNGKey(0)})

This gives a different error, though:

~/dev/flax/flax/linen/module.py in __setattr__(self, name, val)
    439           # namecheck to ensure named variable matches self attribute name.
    440           if self._state.last_varname and self._state.last_varname != var_name:
--> 441             raise ValueError(f'Variable name {self._state.last_varname} must '
    442                              f'equal attribute name {var_name}.')
    443           self._state.last_varname = None

ValueError: Variable name param_list_3 must equal attribute name param_list_0.

Would it be possible for an appropriate name string to be inferred automatically rather than supplied by the user? I'm sure there's a good reason it is the way it is. Apologies if this has already been brought up elsewhere and I've missed it; I'm just curious.

@j-towns j-towns changed the title Request support for arbitrary pytrees of Module params Requesting support for arbitrary pytrees of Module params Dec 4, 2020
@j-towns
Contributor Author

j-towns commented Dec 4, 2020

A workaround for lists is to define a list of name strings, and then iterate over it, assigning params directly to the Module. This seems to work but obviously is quite cumbersome.

class Bar(nn.Module):
    def setup(self):
        param_names = [f'param_list_{i}' for i in range(4)]
        for n in param_names:
            # setattr goes through Module.__setattr__, like a normal attribute assignment.
            setattr(self, n, self.param(n, nn.initializers.zeros, (1, 2, 3)))
        self.param_names = param_names

    def __call__(self):
        return self.param_list_0

bar = Bar().init({'params': random.PRNGKey(0)})

@avital
Contributor

avital commented Dec 4, 2020

Yes, this exposes a problem with our current design for setup (that's related to #686, from a different angle).

Here's the gist of the problem: at what point should parameters (or, in general, variables) be bound, meaning given a parent and a name, and therefore a location in the variable tree? Our current solution says "give it a parent and a name during the call to self.param, and then when it's assigned to an attribute, make sure you haven't given it the wrong name". But we use a broken heuristic there, based on the last generated parameter name.

Instead, we can entirely remove the bad heuristic and say: parameters defined in setup should not use self.param and should instead use a new nn.param, which doesn't allow you to specify a parent or a name. Then, when you assign to an attribute, the parent and name become known (and we can use the same naming scheme for nested submodules, as you had pointed out). We would somehow have to wrap the return value from nn.param, because it's an array whose value can't be read until it is assigned to an attribute (an "unbound parameter", a concept we don't have right now in Flax).
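To make the "unbound parameter" idea above concrete, here is a minimal pure-Python sketch. It is not Flax code: `UnboundParam`, `BindOnAssignModule`, the toy `zeros` initializer, and the `<attribute>_<index>` naming scheme are all hypothetical, chosen only to illustrate binding at assignment time rather than at creation time.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class UnboundParam:
    """Hypothetical placeholder: knows how to initialize, but has no name or parent yet."""
    init_fn: Callable[[Tuple[int, ...]], Any]
    shape: Tuple[int, ...]


def zeros(shape):
    """Toy initializer returning nested lists of zeros (a stand-in for a real array)."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


class BindOnAssignModule:
    """Toy module that names and initializes parameters when they are assigned."""

    def __init__(self):
        # Bypass our own __setattr__ for internal bookkeeping.
        object.__setattr__(self, "params", {})

    def _bind(self, name, value):
        """Recursively replace UnboundParam leaves with initialized, named values."""
        if isinstance(value, UnboundParam):
            if name in self.params:
                raise ValueError(f"Name {name} already in use.")
            self.params[name] = value.init_fn(value.shape)
            return self.params[name]
        if isinstance(value, (list, tuple)):
            # Assumed naming scheme: attribute name plus the element index.
            return type(value)(
                self._bind(f"{name}_{i}", v) for i, v in enumerate(value))
        if isinstance(value, dict):
            return {k: self._bind(f"{name}_{k}", v) for k, v in value.items()}
        return value

    def __setattr__(self, name, value):
        # The attribute name only becomes known here, so binding happens here.
        object.__setattr__(self, name, self._bind(name, value))


module = BindOnAssignModule()
module.param_list = [UnboundParam(zeros, (1, 2)) for _ in range(4)]
print(sorted(module.params))
```

Because names are derived from the attribute and the position in the container, the user never supplies a name, so there is nothing left for a name-equality check to validate.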

I can't commit to a timeline for this change as there are many details to get right...

For now, maybe we can find a reasonable intermediate solution: Perhaps we simply shouldn't apply our "test that the attribute name matches the defined parameter name" for nested parameters... Seems like our heuristic can only be wrong in this case, so we might as well let people define their names like you've been doing.
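The heuristic and the proposed relaxation can be modeled with a toy class. This is a simplified sketch, not Flax's implementation: `NameCheckModule`, its `relaxed` flag, and the assumption that the first leaf of a list is named `<attribute>_0` are illustrative, with the naming inferred from the error message quoted earlier in the thread.

```python
class NameCheckModule:
    """Toy model of the 'last generated name' check; not Flax's real code."""

    def __init__(self, relaxed=False):
        # Use object.__setattr__ so bookkeeping skips the check below.
        object.__setattr__(self, "_relaxed", relaxed)
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_last_varname", None)

    def param(self, name, value):
        """Bind a parameter immediately and remember the most recent name."""
        if name in self._params:
            raise ValueError(f"Name {name} already in use.")
        self._params[name] = value
        object.__setattr__(self, "_last_varname", name)
        return value

    def __setattr__(self, attr, val):
        nested = isinstance(val, (list, tuple))
        if not (nested and self._relaxed):
            # Assumed naming: the first leaf of a list is "<attr>_0".
            var_name = f"{attr}_0" if nested else attr
            if self._last_varname and self._last_varname != var_name:
                raise ValueError(
                    f"Variable name {self._last_varname} must "
                    f"equal attribute name {var_name}.")
        object.__setattr__(self, "_last_varname", None)
        object.__setattr__(self, attr, val)


def build(module):
    """Assign a list of manually named parameters, as in the Bar example."""
    module.param_list = [
        module.param(f"param_list_{i}", 0.0) for i in range(4)]
    return module


strict_error = None
try:
    build(NameCheckModule(relaxed=False))
except ValueError as err:
    strict_error = str(err)
print(strict_error)

relaxed_module = build(NameCheckModule(relaxed=True))
print(sorted(relaxed_module._params))
```

In the strict mode, the last name created (`param_list_3`) is compared against the first leaf name, so the check can only fail for a list of more than one parameter. Skipping it for containers lets manually chosen names through.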

avital added a commit to avital/flax that referenced this issue Dec 4, 2020
We currently try to enforce that a user can't do:

```py
class MyModule(nn.Module):
  def setup(self):
    self.name1 = self.param('name2', ...)
```

But the check is brittle and entirely disallows assigning dicts
or lists of parameters. This change relaxes our heuristic such
that it doesn't run at all for lists or dicts of parameters
assigned to attributes during setup.

See google#705 (comment)
for a bit more context.
@avital
Contributor

avital commented Dec 4, 2020

Hi @j-towns -- could you check if this PR resolves your problem? #706

@j-towns
Contributor Author

j-towns commented Dec 4, 2020

Yes it does, thanks!

@avital avital added this to the Improve Linen milestone Dec 12, 2020