ResBlock combinator in stax #1569
Conversation
Looks reasonable to me, though my preference would be to avoid the final “tail” layer defaulting to ReLU and keep all activations explicit. This would be more consistent with the rest of stax, and in my experience unexpected activations are a common source of bugs.
jax/experimental/stax.py (outdated)

```diff
@@ -260,6 +260,22 @@ def apply_fun(params, inputs, **kwargs):
     return init_fun, apply_fun


 def ResBlock(*layers, fan_in=FanInSum, tail=Relu):
```
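To make the "split, apply, combine" structure under discussion concrete, here is a hypothetical sketch of what such a `ResBlock` combinator could do. The names `ResBlock`, `serial`, `parallel`, `FanOut`, `FanInSum`, and `Identity` follow stax's conventions, but the tiny stand-in layer implementations below are illustrative only, not stax's real code.

```python
# Minimal stand-ins for stax's (init_fun, apply_fun) layer pairs,
# just enough to demonstrate the residual-block combinator pattern.

def Scale(factor):
    """Toy parameterless 'layer' that multiplies its input by `factor`."""
    def init_fun(rng, input_shape):
        return input_shape, ()
    def apply_fun(params, inputs, **kwargs):
        return inputs * factor
    return init_fun, apply_fun

Identity = Scale(1)  # pass-through layer for the shortcut branch

def serial(*layers):
    """Compose layers sequentially."""
    init_funs, apply_funs = zip(*layers)
    def init_fun(rng, input_shape):
        params = []
        for init in init_funs:
            input_shape, p = init(rng, input_shape)
            params.append(p)
        return input_shape, params
    def apply_fun(params, inputs, **kwargs):
        for f, p in zip(apply_funs, params):
            inputs = f(p, inputs, **kwargs)
        return inputs
    return init_fun, apply_fun

def FanOut(num):
    """Duplicate the input into `num` parallel branches."""
    init_fun = lambda rng, input_shape: ([input_shape] * num, ())
    apply_fun = lambda params, inputs, **kwargs: [inputs] * num
    return init_fun, apply_fun

def parallel(*layers):
    """Apply one layer per branch."""
    init_funs, apply_funs = zip(*layers)
    def init_fun(rng, input_shapes):
        out_shapes, params = [], []
        for init, shape in zip(init_funs, input_shapes):
            out_shape, p = init(rng, shape)
            out_shapes.append(out_shape)
            params.append(p)
        return out_shapes, params
    def apply_fun(params, inputs, **kwargs):
        return [f(p, x, **kwargs) for f, p, x in zip(apply_funs, params, inputs)]
    return init_fun, apply_fun

# Merge parallel branches by summing them.
FanInSum = (lambda rng, input_shapes: (input_shapes[0], ()),
            lambda params, inputs, **kwargs: sum(inputs))

def ResBlock(*layers, fan_in=FanInSum, tail=Identity):
    """Residual block: run `layers` on one branch and Identity on the
    other, merge with `fan_in`, then apply `tail` (Identity by default,
    per the change made later in this thread)."""
    return serial(FanOut(2), parallel(serial(*layers), Identity), fan_in, tail)

# x -> 2*x on the main branch, plus the x shortcut: 3.0 -> 6.0 + 3.0 = 9.0
init_fun, apply_fun = ResBlock(Scale(2.0))
_, params = init_fun(None, (1,))
print(apply_fun(params, 3.0))  # 9.0
```

The key design point is that `ResBlock` is itself just a composition of existing combinators, so it adds no new machinery, only a shorthand for a very common wiring.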
You can’t use this syntax yet because JAX still supports Python 2 (for a little while longer)
> Looks reasonable to me, though my preference would be to avoid the final “tail” layer defaulting to ReLU and keep all activations explicit. This would be more consistent with the rest of stax, and in my experience unexpected activations are a common source of bugs.

Oh interesting; in hindsight I could see that potentially being an issue, but I don't think I would have thought of it before you mentioned it. I modified the default `tail` to be `Identity`.
> You can’t use this syntax yet because JAX still supports Python 2 (for a little while longer)

I submitted a patch that works for Python 2 and 3, with a TODO to replace it with the more readable function head when Python 2 support is dropped. Please feel free to let me know if you would like any further modifications.
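The Python 2 compatibility issue is that keyword-only arguments (`def ResBlock(*layers, fan_in=..., tail=...)`) are Python-3-only syntax. A sketch of the usual 2/3-compatible workaround is below; the `FanInSum` and `Identity` here are stand-in sentinels rather than stax's real layers, and the returned tuple is only to make the argument handling visible.

```python
# Python 2/3-compatible spelling of a keyword-only-argument signature:
# pull the options out of **kwargs by hand and reject anything unexpected.

FanInSum = object()  # placeholder for stax.FanInSum
Identity = object()  # placeholder for stax.Identity

def ResBlock(*layers, **kwargs):
    # TODO: switch to keyword-only arguments once Python 2 support is dropped.
    fan_in = kwargs.pop('fan_in', FanInSum)
    tail = kwargs.pop('tail', Identity)
    if kwargs:
        raise TypeError('unexpected keyword arguments: %r' % sorted(kwargs))
    # ...build and return the residual block from layers/fan_in/tail...
    return layers, fan_in, tail
```

The explicit `kwargs.pop` plus the leftover-kwargs check preserves the Python 3 behavior of rejecting unknown keyword arguments, which the bare `**kwargs` signature would otherwise silently accept.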
Thanks for the PR! We're leaning towards declining this change even though it's correct and improves the readability of the ResNet definition, because stax.py is meant more as an example (forking encouraged!) than as a growing or comprehensive library, and more full-featured neural network creation APIs are available or coming soon (see https://github.com/google/trax and https://github.com/JuliusKunze/jaxnet for two of them).
Thanks for the update. I was not aware of either of those two projects. It looks like Trax already has a similar residual layer.
We're no longer evolving stax in favor of other libraries such as Flax, Haiku, and Trax. Closing!
I have seen (and used) the ResNet-style "split, apply, combine" design pattern enough that I thought it might be nice to have it built in. This may be trivial, but I thought it made code a little cleaner and could potentially clean up `import` statements for architectures with residual connections. I confirmed the resnet50 example exhibited identical losses for each update with and without the new `ResBlock` call. Please let me know if you like the idea but would prefer some other function name and/or signature; I would be happy to modify it to suit maintainer preferences.