Add gradients for conv_bias_act, and a similar dense_bias_act #346

Closed · 15 commits

Conversation


@mcabbott (Member) commented Aug 9, 2021

This aims to add gradient definitions for the existing conv_bias_act. That part, however, is still very much WIP, and I don't recommend anyone try to read it just yet.

It also adds an analogous dense_bias_act, which is closer to done. What this gains you over σ.(w*x .+ b) is memory savings. Zygote will by default un-fuse the broadcast, allocating 3 arrays on the forward pass, but in fact we can often over-write the result of w*x, saving 2 copies. This should happen both on CPU and GPU. There is one more copy you could save on the reverse pass, bringing you to 1/2 of the original memory usage, but only if you were sure the pullback would be called no more than once. That isn't true for, say, Zygote.jacobian, and I don't think there's a way to know when it will be safe. So inside Zygote we save 1/3, not 1/2.
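To make that concrete, here is a minimal sketch of the forward-pass idea (illustration only, with a made-up name; the PR's real code also handles the gradient machinery):

```julia
# Sketch: re-use the buffer allocated by w*x for the bias + activation,
# rather than letting the broadcast allocate a fresh array.
function dense_bias_act_sketch(σ, w, x, b)
    y = w * x           # one unavoidable allocation
    y .= σ.(y .+ b)     # overwrite y in place, saving the broadcast's copy
    return y
end
```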

I say "often" because over-writing w*x only works when the gradient of σ can be written in terms of its output, without saving its input. That's true for tanh and relu and some others, which were at first explicitly whitelisted here as INPLACE_ACTS; this is now handled using JuliaDiff/ChainRulesCore.jl#453 instead.
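For instance (my example, not code from the PR), tanh and relu both have derivatives expressible from the output y = σ(x) alone, so the pullback never needs x and its buffer can be re-used:

```julia
using NNlib: relu

# Derivatives written in terms of the output y = σ(x). Illustrative only;
# the PR now obtains this property via ChainRulesCore rather than a whitelist.
deriv_from_output(::typeof(tanh), y) = 1 - y^2                         # tanh'(x) = 1 - tanh(x)^2
deriv_from_output(::typeof(relu), y) = ifelse(y > 0, one(y), zero(y))  # relu'(x) = (x > 0)
```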

This was written before seeing FluxML/NNlibCPU.jl#1 . But they may work well together -- for instance the function dense! there could (after we adjust signatures a little) simply overload a function here, providing a fast path when that package is loaded. Likewise it can overload conv_bias_act! to run a fused activation-and-convolution on the CPU, a bit like the existing NNlibCUDA routine. (From a first glance it looks like dense! has a trait for deciding which functions are in-place-safe, which is good.) Again, not fully baked, but opened now to start discussing.
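To make the proposed split concrete, the hook could look something like this (hypothetical signatures, including the name dense_bias_act!; NNlibCPU's actual dense! may look different):

```julia
using LinearAlgebra: mul!

# Generic fallback that this package could define:
function dense_bias_act!(y, σ, w, x, b)
    mul!(y, w, x)       # y = w*x with no extra allocation
    y .= σ.(y .+ b)
    return y
end

# NNlibCPU, when loaded, could then overload this with its fused kernel, e.g.
# NNlib.dense_bias_act!(y::Matrix{Float32}, σ, w, x, b) = NNlibCPU.dense!(y, σ, w, x, b)
```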


@darsnack (Member) commented Jan 12, 2022

Is the original message up top still accurate? It looks like the implementation is there. What help is necessary to get this through?


@mcabbott (Member, Author) commented Jan 13, 2022

My memory is that this basically worked, but the performance was disappointing due to JuliaLang/julia#43153. Writing back into the same x (when safe) saved memory but not time, unless you pirated Base methods as suggested there. (Which it looks like I didn't do on this branch?)

Edit: OK, I've updated things. I think the most honest benchmark looks like this. It shows a serious improvement from tanh_fast, plus one copy saved by bias_act!, which now avoids a serious slowdown but is still slower than ideal. Why 71 allocations?

julia> w, b = rand(Float32, 100, 100), rand(Float32, 100); x = rand(Float32, size(w)...);

julia> @btime gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 44.792 μs, mean 79.901 μs (71 allocations, 198.37 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 114.583 μs, mean 158.989 μs (39 allocations, 275.25 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh_fast.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100))) evals=1;
  min 40.125 μs, mean 75.140 μs (39 allocations, 275.25 KiB)

It would be worthwhile to benchmark this on other computers (this is an M1 + Apple's BLAS), and on GPUs, and for conv... and ideally in bigger examples. Whatever happened to https://github.com/FluxML/FluxBench.jl ?


@mcabbott (Member, Author) commented Sep 2, 2022

Rebased at https://github.com/mcabbott/NNlib.jl/tree/bias_act_22 after squashing, but its own tests fail.

@mcabbott closed this Nov 24, 2022
@mcabbott mentioned this pull request Jan 5, 2023