Add gradients for `conv_bias_act`, and a similar `dense_bias_act` #346
Conversation
Is the original message up top still accurate? It looks like the implementation is there. What help is necessary to get this through?
My memory is that this basically worked, but the performance was disappointing due to JuliaLang/julia#43153. Writing back into the same […]

Edit: ok, I've updated things. I think the most honest benchmark looks like this, and shows a serious improvement from […]:

```julia
julia> w, b = rand(Float32, 100, 100), rand(Float32, 100); x = rand(Float32, size(w)...);

julia> @btime gradient((w,x,b) -> sum(abs2, dense_bias_act(tanh, w, x, b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 44.792 μs, mean 79.901 μs (71 allocations, 198.37 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 114.583 μs, mean 158.989 μs (39 allocations, 275.25 KiB)

julia> @btime gradient((w,x,b) -> sum(abs2, tanh_fast.((w * x) .+ b)), wr[], $x, $b)  setup=(wr=Ref(randn(Float32,100,100)))  evals=1;
  min 40.125 μs, mean 75.140 μs (39 allocations, 275.25 KiB)
```

Would be worthwhile to benchmark on other computers. (This is M1 + Apple's BLAS.) And on GPUs. And […]
Rebased at https://github.com/mcabbott/NNlib.jl/tree/bias_act_22 after squashing, but its own tests fail.
This aims to add gradient definitions for the existing `conv_bias_act`. That is, however, very much WIP, and I don't recommend anyone try to read it just yet. It also adds an analogous `dense_bias_act`, which is closer to done.

What this gains you over `σ.(w*x .+ b)` is memory savings. Zygote will by default un-fuse the broadcast, allocating 3 arrays on the forward pass, but in fact we can often over-write the result of `w*x`, saving 2 copies. This should happen both on CPU and GPU. There is one more copy you could save on the reverse pass, bringing you to 1/2 the memory usage of before, but only if you were sure that the pullback would only be called once. That isn't true for say `Zygote.jacobian`, and I don't think there's a way to know when it will be safe. So we save 1/3 not 1/2, when inside Zygote.
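For concreteness, here is a minimal sketch of the idea. It is not the PR's code: `dense_bias_act_sketch` is a made-up name, and it is specialised to `tanh`, whose derivative can be recovered from the output alone (more on that below). The forward pass writes bias and activation back into the array holding `w*x`, and the pullback re-uses that same output.

```julia
using ChainRulesCore

# Hypothetical stand-in for dense_bias_act, specialised to tanh for this sketch.
function dense_bias_act_sketch(σ::typeof(tanh), w, x, b)
    y = w * x               # the only allocation on the forward pass
    y .= σ.(y .+ b)         # bias + activation written back into the same array
    return y
end

function ChainRulesCore.rrule(::typeof(dense_bias_act_sketch), σ::typeof(tanh), w, x, b)
    y = dense_bias_act_sketch(σ, w, x, b)
    function sketch_pullback(dy)
        dz = unthunk(dy) .* (1 .- y .^ 2)   # tanh'(z) = 1 - tanh(z)^2, needs only the output y
        dw = dz * x'
        dx = w' * dz
        db = vec(sum(dz; dims = 2))
        return (NoTangent(), NoTangent(), dw, dx, db)
    end
    return y, sketch_pullback
end
```

Zygote would pick this rule up via ChainRules, so something like `gradient((w,x,b) -> sum(abs2, dense_bias_act_sketch(tanh, w, x, b)), w, x, b)` should allocate one forward-pass array instead of three.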
I say "often" because over-writing `w*x` only works when the gradient of `σ` can be written in terms of its output, without saving its input. That's true for `tanh` and `relu` and some others, which are ~~explicitly whitelisted here as `INPLACE_ACTS`~~ now handled using JuliaDiff/ChainRulesCore.jl#453. Surely a more extensible method for that could be invented.
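To spell out what "written in terms of its output" means (the actual mechanism lives in JuliaDiff/ChainRulesCore.jl#453; the helper name below is invented purely for illustration):

```julia
# For these activations, σ'(x) is a function of y = σ(x) alone, so the input x
# need not be kept around and y may safely over-write it.
deriv_from_output(::typeof(tanh), y)     = 1 - y^2                  # tanh'(x) = 1 - tanh(x)^2
deriv_from_output(::typeof(identity), y) = one(y)
myrelu(x) = max(x, zero(x))                                         # stand-in for NNlib.relu
deriv_from_output(::typeof(myrelu), y)   = ifelse(y > 0, one(y), zero(y))

# By contrast, sin does not qualify: sin'(x) = cos(x) = ±sqrt(1 - sin(x)^2),
# and the sign cannot be recovered from sin(x) alone.
```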
This was written before seeing FluxML/NNlibCPU.jl#1. But they may work well together -- for instance the function `dense!` there could (after we adjust signatures a little) simply overload a function here, providing a fast path when that package is loaded. Likewise it can overload `conv_bias_act!` to run a fused activation-and-convolution on the CPU, a bit like the existing NNlibCUDA routine. (From a first glance it looks like `dense!` has a trait for deciding which functions are in-place-safe, which is good.) Again, not fully baked, but opened now to start discussing.
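A rough sketch of how that layering might look; every name here (the `dense_bias_act!` signature, the overload) is assumed for illustration, not taken from either package:

```julia
using LinearAlgebra: mul!

# Generic in-place fallback which NNlib could own:
function dense_bias_act!(y::AbstractMatrix, σ, w::AbstractMatrix, x::AbstractMatrix, b)
    mul!(y, w, x)        # y = w * x without allocating a new array
    y .= σ.(y .+ b)      # bias + activation fused into the same array
    return y
end

# A CPU package like NNlibCPU could then provide the fast path simply by overloading it,
# for example (names hypothetical):
#     NNlib.dense_bias_act!(y, σ, w, x, b) = dense!(y, σ, w, x, b)
# and similarly for conv_bias_act!, giving fused CPU kernels whenever that package is loaded.
```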