Consider changing the semantics of Module.setup() to be called immediately during init rather than when a module is bound #686
Comments
Thanks for this @avital! |
Actually, the workaround I just described doesn't work either, because in
We should try, as part of this change, to set
Here is the best workaround I could verify that works:

from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import Dense

class Backbone(nn.Module):
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        return Dense(10, dtype=self.dtype)(inputs)

class ResNet(nn.Module):
    num_classes: int = 10
    dtype: Any = jnp.float32

    # Takes the module explicitly, so it can also be called on an unbound ResNet instance.
    @staticmethod
    def make_backbone(self):
        return Backbone(self.dtype)

    def setup(self):
        self.backbone = ResNet.make_backbone(self)
        self.classifier = nn.Dense(self.num_classes, dtype=self.dtype)

    def __call__(self, inputs, train: bool = False):
        x = self.backbone(inputs)
        x = self.classifier(x)
        return x
See also the discussion on #665 |
@avital - How would this work w. |
@avital - I mean I guess, going back to the motivating example, what do we get by being able to grab
the other issue is that

class Foo(nn.Module):
    def setup(self):
        self.m = Bar()
        self.v = self.variable('myvars', 'v', init_fn)

    def __call__(self, x):
        return self.m(x) + self.v.value

with 'strict setup()' behavior, we'd be able to get |
I was thinking it would end up looking like this:
where Not sure though yet how deep this change would have to be. Another totally different use-case is:
Right now there's no way to access (These are just some more thoughts -- I don't know what the best path forward is, but I think @jheek is dwelling on these questions) |
Does this issue have anything to do with the initialization of the "last" submodule like here? (The line right above the header linked to) I'm not sure how to approach this. |
Actually, historically we went in the opposite direction of calling
@shashank2000 - I actually have no idea what that sentence is even talking about - I'll have to figure out what the author meant by that and remove or redact that statement. It would be better to open a new issue if you have a specific question about initialization or writing a model. |
Right now, we tell users that setup acts like __init__, but this isn't completely true, as setup is only called when a Module is bound, that is, when it has variables. Therefore, the following code doesn't work:
A workaround is to expose the backbone as a method, e.g.
But this is quite confusing.
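To make the problem concrete, here is a minimal plain-Python sketch of the behavior described above. It is a toy stand-in, not Flax code: `LazyModule` and its in-place `bind` are hypothetical names invented for illustration, and the only thing they model is that `setup()` runs when variables are attached rather than at construction.

```python
class LazyModule:
    """Toy stand-in (not Flax): setup() runs only when the module is bound."""

    def __init__(self):
        self.variables = None  # None means "unbound"

    def setup(self):
        pass  # subclasses define their submodules here

    def bind(self, variables):
        # Toy simplification: attach variables in place, then run setup().
        self.variables = variables
        self.setup()
        return self


class Backbone(LazyModule):
    pass


class ResNet(LazyModule):
    def setup(self):
        self.backbone = Backbone()


unbound = ResNet()
# setup() has not run yet, so the attribute does not exist on the unbound module.
has_backbone_before_bind = hasattr(unbound, "backbone")

bound = ResNet().bind({"params": {}})
# Binding triggered setup(), so the submodule is now reachable.
has_backbone_after_bind = hasattr(bound, "backbone")
```

In this toy model, grabbing `backbone` for transfer learning only works after `bind`, which is the surprise the issue describes for users who expect `setup` to behave like `__init__`.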
The proposal here, in short, is to make setup actually fire immediately during module construction, but defer variables from actually being present until the module is bound. This means that accessing the value of parameters defined during setup would throw an error until the module is bound.
There are some open design questions about this proposal, but I think it would simplify transfer learning use-cases, as well as the mental model around setup.
@jheek and I have discussed this, and we may consider investigating this soon. I believe the change should mostly not impact user code, though in some cases modules that define variables during setup() and access their value may have to be thought about more carefully.
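One way to picture the proposed semantics is the plain-Python sketch below. It is only an illustration under stated assumptions, not how Flax implements or would implement this: `EagerModule`, `Variable`, the single-argument `variable(name)` helper, and the in-place `bind` are all hypothetical. It shows `setup()` firing in the constructor while variable values stay unavailable until the module is bound.

```python
class Variable:
    """Toy variable handle: declared in setup(), readable only once the owner is bound."""

    def __init__(self, owner, name):
        self.owner = owner
        self.name = name

    @property
    def value(self):
        if self.owner.variables is None:
            raise RuntimeError(f"variable {self.name!r} has no value: module is unbound")
        return self.owner.variables[self.name]


class EagerModule:
    """Toy stand-in (not Flax): setup() fires immediately during construction."""

    def __init__(self):
        self.variables = None  # None means "unbound"
        self.setup()

    def setup(self):
        pass  # subclasses define submodules and variables here

    def variable(self, name):
        # Declares the variable without requiring a value yet.
        return Variable(self, name)

    def bind(self, variables):
        # Toy simplification: attach variables in place.
        self.variables = variables
        return self


class Foo(EagerModule):
    def setup(self):
        self.m = EagerModule()
        self.v = self.variable("v")


foo = Foo()
submodule_available = hasattr(foo, "m")  # submodules exist right after construction

try:
    foo.v.value
    unbound_read_failed = False
except RuntimeError:
    unbound_read_failed = True  # reading a value before binding raises

bound_value = foo.bind({"v": 3.0}).v.value
```

The trade-off this sketch makes visible is the one raised above: submodules become reachable on unbound modules, but any `setup()` that reads a variable's value directly would need to be reconsidered.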
cc @rolandgvc @andsteing