Consider changing the semantics of Module.setup() to be called immediately during init rather than when a module is bound #686

Closed
avital opened this issue Nov 27, 2020 · 8 comments
Labels
Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required)

Comments

@avital
Contributor

avital commented Nov 27, 2020

Right now, we tell users that setup acts like __init__, but this isn't completely true: setup is only called when a Module is bound, that is, when it has variables. Therefore, the following code doesn't work:

class ResNet(nn.Module):
  def setup(self):
    self.backbone = Backbone()
    self.classifier = Classifier()

ResNet().backbone  # not available
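To make the current behavior concrete, here is a toy pure-Python model of lazy setup(). It is not Flax's implementation; `ToyModule`, `bind()` and `variables` are hypothetical names. Because setup() only runs at bind time, the `backbone` attribute simply does not exist on an unbound instance:

```python
# Toy model of today's lazy setup(), NOT Flax's implementation.
# `ToyModule`, `bind()` and `variables` are hypothetical names.


class ToyModule:
    def __init__(self):
        self.variables = None  # unbound: no variables attached yet

    def setup(self):
        pass  # subclasses define submodules here

    def bind(self, variables):
        # Lazy semantics: setup() runs only once variables are attached.
        self.variables = variables
        self.setup()
        return self


class Backbone(ToyModule):
    pass


class ResNet(ToyModule):
    def setup(self):
        self.backbone = Backbone()


print(hasattr(ResNet(), "backbone"))                       # False
print(hasattr(ResNet().bind({"params": {}}), "backbone"))  # True
```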

A workaround is to expose the backbone as a method, e.g.

class ResNet(nn.Module):
  def make_backbone(self):
    return Backbone()

  def make_classifier(self):
    return Classifier()

  def setup(self):
    self.backbone = self.make_backbone()
    self.classifier = self.make_classifier()

ResNet().make_backbone()

But this is quite confusing.

The proposal here, in short, is to make setup actually fire immediately during module construction, but defer variables from being present until the module is bound. This means that accessing the value of parameters defined during setup on an unbound module would raise an error.
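A minimal pure-Python sketch of the proposed semantics, assuming hypothetical `EagerModule`, `DeferredVariable` and `bind()` names (this is not Flax's API): setup() runs in the constructor, so submodules are reachable right away, while reading a variable's value raises until the module is bound.

```python
# Hypothetical sketch of the proposed semantics, NOT Flax's API: setup() runs
# eagerly at construction, while variable values stay deferred until binding.
# `EagerModule`, `DeferredVariable` and `bind()` are illustrative names.


class UnboundVariableError(Exception):
    """Raised when a variable's value is read before the module is bound."""


class DeferredVariable:
    def __init__(self, init_fn):
        self.init_fn = init_fn
        self._value = None
        self._bound = False

    def materialize(self):
        # Called at bind time: compute and store the actual value.
        self._value = self.init_fn()
        self._bound = True

    @property
    def value(self):
        if not self._bound:
            raise UnboundVariableError("variable read before the module was bound")
        return self._value


class EagerModule:
    def __init__(self):
        self.setup()  # proposal: setup() fires during construction

    def setup(self):
        pass

    def bind(self):
        # Materialize every deferred variable defined directly in setup().
        for attr in vars(self).values():
            if isinstance(attr, DeferredVariable):
                attr.materialize()
        return self


class Classifier(EagerModule):
    pass


class ResNet(EagerModule):
    def setup(self):
        self.classifier = Classifier()              # usable immediately
        self.scale = DeferredVariable(lambda: 1.0)  # value deferred


model = ResNet()
print(type(model.classifier).__name__)  # Classifier, without binding
try:
    model.scale.value
except UnboundVariableError as err:
    print("error:", err)
print(model.bind().scale.value)  # 1.0
```

The split mirrors the trade-off raised later in the thread: module structure becomes available eagerly, but anything that needs an actual variable value still has to wait for binding.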

There are some open design questions about this proposal, but I think it would simplify transfer-learning use cases as well as the mental model around setup.

@jheek and I have discussed this and may investigate it soon. I believe the change should mostly not affect user code, though modules that define variables during setup() and access their values may need to be handled more carefully.

cc @rolandgvc @andsteing

@rolandgvc
Contributor

Thanks for this @avital!

@avital
Contributor Author

avital commented Dec 1, 2020

Actually, the workaround I just described doesn't work either, because in __post_init__, we set parent based on the module_stack (which is designed for nn.compact methods).

We should try, as part of this change, to set parent only when assigning the module to an attribute. I think that will make the workaround described in my previous comment work.
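The idea of setting parent at assignment time can be sketched in plain Python with a custom `__setattr__`. This is a toy under stated assumptions (`ToyModule` is hypothetical, and Flax's real parent handling is more involved): a module built by a helper method stays parentless until it is stored on an attribute, at which point it adopts the owner and the attribute name.

```python
# Hypothetical sketch, NOT Flax's implementation: a submodule's parent and
# name are assigned when it is stored on an attribute, not from a global
# module stack at construction time. `ToyModule` is an illustrative name.


class ToyModule:
    def __init__(self):
        # Bypass our own __setattr__ for bookkeeping fields.
        object.__setattr__(self, "parent", None)
        object.__setattr__(self, "name", None)

    def __setattr__(self, attr, value):
        if isinstance(value, ToyModule) and value.parent is None:
            # Adopt an orphan submodule at assignment time.
            object.__setattr__(value, "parent", self)
            object.__setattr__(value, "name", attr)
        object.__setattr__(self, attr, value)


class Backbone(ToyModule):
    pass


class ResNet(ToyModule):
    def make_backbone(self):
        return Backbone()  # constructed with no parent

    def setup(self):
        self.backbone = self.make_backbone()  # parent/name set here


resnet = ResNet()
standalone = resnet.make_backbone()
print(standalone.parent)  # None: usable outside setup()
resnet.setup()
print(resnet.backbone.parent is resnet, resnet.backbone.name)  # True backbone
```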

Here is the best workaround I could verify that works:

from typing import Any

import jax.numpy as jnp
from flax import linen as nn


class Backbone(nn.Module):
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, inputs):
    return nn.Dense(10, dtype=self.dtype)(inputs)


class ResNet(nn.Module):
  num_classes: int = 10
  dtype: Any = jnp.float32

  @staticmethod
  def make_backbone(module):
    # Takes the ResNet instance explicitly, so it can also be called as
    # ResNet.make_backbone(resnet) from outside the module.
    return Backbone(dtype=module.dtype)

  def setup(self):
    self.backbone = ResNet.make_backbone(self)
    self.classifier = nn.Dense(self.num_classes, dtype=self.dtype)

  def __call__(self, inputs, train: bool = False):
    x = self.backbone(inputs)
    x = self.classifier(x)
    return x

@avital
Contributor Author

avital commented Dec 1, 2020

See also the discussion on #665

@avital avital added this to the Linen design iteration milestone Dec 1, 2020
@BertrandRdp BertrandRdp added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Jan 14, 2021
@jheek jheek added Priority: P0 (urgent) Response within 1 business day. Resolution as soon as possible. (Assignee required) and removed Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) labels Jan 20, 2021
@levskaya
Collaborator

@avital - How would this work with self.param('foo', init_fn, (1,2,3)) in setup()? That returns a JAX array, so we can't proxy access to it.

@levskaya
Collaborator

levskaya commented Jan 20, 2021

@avital - Going back to the motivating example, what do we gain by being able to grab ResNet().backbone over just instantiating Backbone() directly? You're not going to get automatic parameter subsets from this suggested change alone, since there would be no parameters at construction time. What are we trying to simplify?

The other issue is that

class Foo(nn.Module):
  def setup(self):
    self.m = Bar()
    self.v = self.variable('myvars', 'v', init_fn)
  def __call__(self, x):
    return self.m(x) + self.v.value

With 'strict setup()' behavior, we'd be able to get m but not v, which feels no less confusing.

@avital
Contributor Author

avital commented Jan 22, 2021

I was thinking it would end up looking like this:

class Foo(nn.Module):
  def setup(self):
    self.m = Bar()
    self.v = nn.variable('myvars', init_fn)
  def __call__(self, x):
    return self.m(x) + self.v.value

where nn.variable() would return an "unbound variable" that derives its name via __setattr__, like submodules do. This would remove the wonky logic we currently use to match variable names with attributes when they are defined during setup.
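A toy pure-Python sketch of such an "unbound variable" follows. `variable`, `UnboundVariable`, `ToyModule` and `init_variables` are hypothetical names, not Flax APIs; the point is only that the variable's name comes from the attribute it is assigned to, so no separate name-matching step is needed when collecting initial values.

```python
# Hypothetical sketch of the proposed nn.variable(): an "unbound variable"
# learns its name from the attribute it is assigned to, like submodules do.
# `variable`, `UnboundVariable`, `ToyModule` and `init_variables` are
# illustrative names, not Flax APIs.


class UnboundVariable:
    def __init__(self, collection, init_fn):
        self.collection = collection
        self.init_fn = init_fn
        self.name = None  # filled in by ToyModule.__setattr__


def variable(collection, init_fn):
    return UnboundVariable(collection, init_fn)


class ToyModule:
    def __setattr__(self, attr, value):
        if isinstance(value, UnboundVariable) and value.name is None:
            value.name = attr  # name derived from the attribute
        object.__setattr__(self, attr, value)

    def init_variables(self):
        # Group initial values by collection, then by derived name.
        out = {}
        for value in vars(self).values():
            if isinstance(value, UnboundVariable):
                out.setdefault(value.collection, {})[value.name] = value.init_fn()
        return out


class Foo(ToyModule):
    def setup(self):
        self.v = variable("myvars", lambda: 0.0)


foo = Foo()
foo.setup()
print(foo.v.name)            # v
print(foo.init_variables())  # {'myvars': {'v': 0.0}}
```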

I'm not sure yet how deep this change would have to be.

Another, completely different use case is:

class Conv(nn.Module):
  # define kernel, dilation, etc
 
  def setup(self):
    self.receptive_field = ...do some math...

Right now there's no way to access conv.receptive_field if you're not inside a module. You can work around this by defining a function, but again, this means we can't really say that "setup is like __init__", and we don't make it particularly easy for people to hook into __post_init__ at the moment...
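Since linen Modules are dataclasses, one function-style workaround is a derived value computed purely from the fields. A minimal sketch with a plain Python dataclass, assuming hypothetical `ConvConfig` fields `kernel_size` and `dilation` and stride 1 throughout; no setup() or binding is involved:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvConfig:
    # Hypothetical fields standing in for "kernel, dilation, etc".
    kernel_size: int = 3
    dilation: int = 1

    @property
    def receptive_field(self) -> int:
        # Input positions seen by one output of a stride-1, dilated 1-D conv.
        return (self.kernel_size - 1) * self.dilation + 1


def stacked_receptive_field(layers):
    # With stride 1 everywhere, each layer widens the field by (k - 1) * d.
    return 1 + sum(layer.receptive_field - 1 for layer in layers)


conv = ConvConfig(kernel_size=3, dilation=2)
print(conv.receptive_field)  # 5, no binding or setup() involved
print(stacked_receptive_field([ConvConfig(3, 1), ConvConfig(3, 2)]))  # 7
```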

(These are just some more thoughts -- I don't know what the best path forward is, but I think @jheek is dwelling on these questions)

@shashank2000

Does this issue have anything to do with the initialization of the "last" submodule like here? (The line right above the header linked to)

I'm not sure how to approach this.

@levskaya
Collaborator

Actually, historically we went in the opposite direction: setup() is deliberately called lazily. We were forced to do this to avoid an exponential-time blowup from double recursion, which occurs with any eager initialization scheme in the presence of nested transformations (these currently occur frequently, for instance when using JAX's named_call machinery to label profiling traces). As such, I'm closing this old issue.
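A toy recurrence can show why double recursion is exponential. It rests on a stated assumption, not on Flax's actual code path: under eager initialization, each nesting level initializes its child twice (once when the wrapper is constructed, once more inside the transformed function), while lazy initialization runs each level's setup() once.

```python
def count_setup_calls(depth: int, eager: bool) -> int:
    """Toy count of setup() calls for `depth` nested transforms.

    Assumes (illustration only) that under eager initialization every nesting
    level initializes its child twice: once when the wrapper is constructed and
    once more inside the transformed function. Lazy initialization runs each
    level's setup() once.
    """
    if depth == 0:
        return 1
    fanout = 2 if eager else 1
    return 1 + fanout * count_setup_calls(depth - 1, eager)


for depth in (1, 5, 10):
    eager = count_setup_calls(depth, eager=True)
    lazy = count_setup_calls(depth, eager=False)
    print(depth, eager, lazy)  # eager grows as 2**(depth + 1) - 1, lazy as depth + 1
```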

@shashank2000 - I actually have no idea what that sentence is even talking about - I'll have to figure out what the author meant by that and remove or redact that statement. It would be better to open a new issue if you have a specific question about initialization or writing a model.

6 participants