Define activation functions taking arrays as input #423

Merged (2 commits merged into FluxML:master on Jun 19, 2022)

Conversation

theabhirath
Member

An attempt to fix #422...hopefully this doesn't break anything

@darsnack
Member

Convolution fuzzing tests are already failing on master

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jun 19, 2022 via email

@darsnack
Member

What do you mean? The canonical definition of what?

If it's the layer forward definitions, then this doesn't change those. All this does is allow relu directly in a Chain instead of x -> relu.(x). And since the implementation is just broadcasting the function, it will hit the same paths.
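
For concreteness, the change under discussion amounts to something like the following (a rough sketch of the idea, not the literal diff):

```julia
# Inside NNlib, roughly: the scalar method already exists, and the new method
# just broadcasts it, so that `relu` can sit directly in a Chain.
relu(x::Real) = max(x, zero(x))      # existing scalar definition (schematic)
relu(x::AbstractArray) = relu.(x)    # new: apply elementwise to arrays

# With this, Chain(Dense(2 => 2), relu) works without writing x -> relu.(x).
```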

@DhairyaLGandhi
Member

The forward passes in many cases look like

act.(f(x))

I.e. with the activation broadcast over an object. In this case, the activation function provided has to be relu itself. When something like x -> relu.(x) is provided with the above form of the forward pass, the anonymous function is itself broadcast and generates

(x -> relu.(x)).(f(x))

Thus what actually happens is that the anonymous function receives a scalar, and it works anyway since numbers themselves support broadcasting. I think in most cases it doesn't make much difference, but that is what is happening unless the compiler can optimise the extra broadcast away. In AD, we would actually have to see both the outer broadcast and the inner broadcast and generate a pullback for each (this hopefully isn't the case with a lens-based system that can interleave optimisation and compilation, but it is the case elsewhere).
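
As a small standalone illustration of that double broadcast (relu here comes from NNlib; y is just a stand-in for f(x)):

```julia
using NNlib: relu

act = x -> relu.(x)          # user-supplied activation that already broadcasts internally
y   = randn(Float32, 3, 3)   # stand-in for f(x)

act.(y)                      # the outer broadcast hands each scalar element to act,
                             # whose inner relu.(x) then broadcasts over that scalar
```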

To be clear, I'm in favour of picking points of optimisation and simplification; I just wanted to clarify that this is mostly useful for cases where the activation function sees a collection of arrays (like a tuple or vector of arrays). And if there are specific advantages to automatically calling broadcasted operations (say, for fusion), then perhaps overloading broadcasted would give the compiler more hints.
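
For reference, the broadcasted overload mentioned above could look something like this for a made-up activation (an illustrative sketch only, not something Flux or NNlib is known to do for its activations):

```julia
import Base.Broadcast: broadcasted

# A made-up scalar activation:
myact(x::Real) = max(x, zero(x))

# Intercept `myact.(x)` on arrays and route it to a specialised (e.g. fused) kernel:
broadcasted(::typeof(myact), x::AbstractArray) = fused_myact(x)

# Stand-in for the real fused implementation:
fused_myact(x) = max.(x, zero(eltype(x)))
```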

@darsnack darsnack merged commit aad08b5 into FluxML:master Jun 19, 2022
@theabhirath theabhirath deleted the broadcast-act branch June 19, 2022 16:27
@theabhirath
Member Author

Could we get a patch release with this? Would be helpful 😅

@mcabbott
Member

mcabbott commented Jun 23, 2022

I missed this, but before we release, is it a good idea?

It means Chain(Dense(2 => 2), relu)(rand(2, 2)) will work, instead of giving an explanatory error. But Chain(Dense(2 => 2), tanh)(rand(2, 2)) will do something quite different, and using your own function f will probably not work.
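
For anyone following along, the "quite different" behaviour comes from LinearAlgebra already defining tanh of a square matrix as the matrix function:

```julia
using LinearAlgebra

A = rand(2, 2)
tanh(A)    # matrix hyperbolic tangent (a matrix function), not elementwise
tanh.(A)   # elementwise tanh, which is what the layer is meant to compute
```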

Another level at which this "do what I mean" could be implemented is to make the Chain constructor replace any lonely activation functions with Base.Fix1(broadcast, f). Perhaps with a warning. Then when you enter such a chain at the REPL, at least you know that it's been fixed up somehow.
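
To make that concrete, Base.Fix1(broadcast, f) is simply a callable that broadcasts f over its argument (relu here comes from NNlib):

```julia
using NNlib: relu

g = Base.Fix1(broadcast, relu)   # g(x) == broadcast(relu, x) == relu.(x)
x = rand(Float32, 2, 2)
g(x) == relu.(x)                 # true

# The idea above: a Chain constructor could wrap a lone activation function in
# such an object, perhaps emitting a warning while doing so.
```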

@darsnack
Member

Unfortunately, we already released this change.

It means Chain(Dense(2 => 2), relu)(rand(2, 2)) will work, instead of giving an explanatory error. But Chain(Dense(2 => 2), tanh)(rand(2, 2)) will do something quite different, and using your own function f will probably not work.

My immediate thought is: when is activation(x::AbstractArray) not an element-wise operation? In almost any description of the model at the start of a paper, this is the assumed understanding. And this is clearly true for all the activations changed in this PR. So, my gut says there is almost no case where broadcasting isn't the correct and only thing to do, making the interpretation in this PR fairly safe. I'm not entirely clear on why the examples you gave should error.

Random thoughts:

  • Doing a replacement in a Chain constructor excludes Parallel, etc. We could do this in Parallel's constructor too, but then we exclude custom layers.
  • You could make it work for any model by doing a walk and replacing (a rough sketch is shown just after this list), but then the user has to do that to get more informative errors.
  • Doing things this way means that Dense(2 => 2, tanh) and Chain(Dense(2 => 2), tanh) do different things under the hood. Maybe Chain(..., tanh, ...) should be able to replace tanh with tanh_fast. Doesn't Chain(..., x -> tanh.(x), ...) also suffer from this issue?
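
A rough sketch of the walk-and-replace idea from the second bullet, using Functors.fmap (illustrative only; the wrapping rule and the exclude predicate here are assumptions, not an existing Flux API):

```julia
using Flux, Functors

# Wrap a bare `tanh` wherever it appears as a node in the model; leave everything else alone.
wrap(f::typeof(tanh)) = Base.Fix1(broadcast, f)
wrap(x) = x

model = Chain(Dense(2 => 2), tanh)
fixed = fmap(wrap, model; exclude = x -> x isa Function)

fixed(rand(Float32, 2, 3))   # the lone tanh is now broadcast over the Dense output
```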

@mcabbott
Member

this is the assumed understanding

I agree it's what's normally meant, it just seems a bit contrary to Julia's normal behaviour.

My tanh example is meant to highlight that some functions already have matrix definitions, which are quite different. Any function you write yourself out of exp, log etc. will also tend to have such a definition. Although in practice getting a square matrix input is going to be unlikely.

My Chain constructor idea may indeed not be a great one. If it's noisy, then it could be an earlier place to remind people "you need to broadcast that!" than waiting until a runtime error. There would be less expectation that such a "training wheels" feature also apply to Parallel etc.

@darsnack
Member

darsnack commented Jun 23, 2022

My tanh example is meant to highlight that some functions already have matrix definitions, which are quite different.

Good point, and actually these cases should probably be reverted as piracy? Still a good point, but I misread the source; there is no explicit piracy here.

@ToucheSir
Member

The nuclear option is to do what PyTorch does and make vectorized activation layer types/constructors, then not export any of the activation functions themselves so that users are incentivized to use the former. That gets around some of the confusion with act.(x) vs act(x) not always being equivalent, but historically we've not wanted to go down this route.
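
For reference, "vectorized activation layer types" here means something along the lines of the following hypothetical sketch (not an existing Flux type):

```julia
# A layer type that wraps a scalar activation and always applies it elementwise:
struct Activation{F}
    f::F
end
(a::Activation)(x::AbstractArray) = a.f.(x)

Activation(tanh)(rand(2, 2))   # elementwise tanh, never the matrix function

# In a model this would read Chain(Dense(2 => 2), Activation(relu)),
# analogous to PyTorch's nn.ReLU().
```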

Successfully merging this pull request may close these issues.

Activation functions have to be broadcasted by the user to act on arrays