Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ResBlock combinator in stax #1569

Closed
wants to merge 3 commits into from
Closed

Conversation

aeftimia
Copy link
Contributor

I have seen (and used) the ResNet-style "split, apply, combine" design pattern enough that I thought it might be nice to have it built in. This may be trivial, but I thought it made code a little cleaner and could potentially clean up import statements for architectures with residual connections.

I confirmed the resnet50 example exhibited identical losses for each update with and without the new ResBlock call.

Please let me know if you like the idea, but would prefer some other function name and/or signature--I would be happy to modify it to suite maintainer preferences.

@shoyer
Copy link
Collaborator

shoyer commented Oct 26, 2019

Looks reasonable to me, though my preference would be to avoid the final “tail” layer defaulting to ReLU and keep all activations explicit. This would be more consistent with the rest of stax, and in my experience unexpected activations are a common source of bugs.

@@ -260,6 +260,22 @@ def apply_fun(params, inputs, **kwargs):
return init_fun, apply_fun


def ResBlock(*layers, fan_in=FanInSum, tail=Relu):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can’t use this syntax yet because JAX still supports Python 2 (for a little while longer)

Copy link
Contributor Author

@aeftimia aeftimia Oct 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks reasonable to me, though my preference would be to avoid the final “tail” layer defaulting to ReLU and keep all activations explicit. This would be more consistent with the rest of stax, and in my experience unexpected activations are a common source of bugs.

Oh interesting--in hindsight I could see that potentially being an issue but I don't think I would have thought of it before you mentioned it. I modified the default tail to be Identity.

You can’t use this syntax yet because JAX still supports Python 2 (for a little while longer)

I submitted a patch that works for Python 2 and 3 with a TODO to replace it with the more readable function head when Python 2 support is dropped.

Please feel free to let me know if you would like any further modifications.

@jekbradbury
Copy link
Contributor

Thanks for the PR! We're leaning towards declining this change even thought it's correct and increases the readability of the ResNet definition, because stax.py is meant more as an example (forking encouraged!) than as a growing or comprehensive library and more full-featured neural network creation APIs are available or coming soon (see https://github.com/google/trax and https://github.com/JuliusKunze/jaxnet for two of them).

@aeftimia
Copy link
Contributor Author

Thanks for the PR! We're leaning towards declining this change even thought it's correct and increases the readability of the ResNet definition, because stax.py is meant more as an example (forking encouraged!) than as a growing or comprehensive library and more full-featured neural network creation APIs are available or coming soon (see https://github.com/google/trax and https://github.com/JuliusKunze/jaxnet for two of them).

Thanks for the update. I was not aware of either of those two projects. It looks like Trax already has a similar residual layer.

@hawkinsp
Copy link
Collaborator

We're no longer evolving stax in favor of other libraries such as Flax, Haiku, and Trax. Closing!

@hawkinsp hawkinsp closed this May 12, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants