Speed-up normalization layers #2220
Conversation
Have you benchmarked gradient times? Otherwise LGTM.
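The benchmark setup is not shown in the thread. A plausible reproduction might look like the following, where the layer and input sizes are assumptions rather than the values actually used below:

```julia
using Flux, Zygote, BenchmarkTools

# Hypothetical setup — sizes are illustrative, not taken from the thread.
gn = GroupNorm(128, 32)                 # 128 channels split into 32 groups
x = rand(Float32, 32, 32, 128, 64)      # WHCN input batch

@benchmark gn(x)
@benchmark Zygote.gradient(gn -> sum(gn(x)), gn)
```

For the GPU numbers further down, the same objects would presumably be moved to the device first (e.g. `gn = gn |> gpu; x = x |> gpu`) and timed under `CUDA.@sync`.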
Before:

```julia
julia> @benchmark Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.530 s … 1.593 s  ┊ GC (min … max): 3.85% … 4.30%
 Time  (median):     1.570 s            ┊ GC (median):    3.90%
 Time  (mean ± σ):   1.566 s ± 31.406 ms  ┊ GC (mean ± σ):  3.83% ± 0.54%

  █                █                                     █ █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█ ▁
  1.53 s        Histogram: frequency by time         1.59 s <

 Memory estimate: 5.50 GiB, allocs estimate: 611.
```

This PR:

```julia
julia> @benchmark Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  1.026 s … 1.057 s  ┊ GC (min … max): 4.64% … 5.15%
 Time  (median):     1.040 s            ┊ GC (median):    4.58%
 Time  (mean ± σ):   1.040 s ± 14.209 ms  ┊ GC (mean ± σ):  4.38% ± 0.96%

  █                        ▁                      ▁       ▁
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█ ▁
  1.03 s        Histogram: frequency by time         1.06 s <

 Memory estimate: 3.50 GiB, allocs estimate: 675.
```
What about benchmarks on GPU?
And there is a potential regression issue: the result is now aligned with the PyTorch norm, but that means all old models trained with the old Flux norm will be affected.
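A minimal sketch of the semantic difference being discussed, assuming the change moves `ϵ` inside the square root to match PyTorch (the exact layer internals are not shown in this thread):

```julia
σ² = 1f-8            # a tiny variance makes the difference visible
ϵ  = 1f-5

old = sqrt(σ²) + ϵ   # ε added outside the square root
new = sqrt(σ² + ϵ)   # ε added inside, as PyTorch does

# old ≈ 0.00011, new ≈ 0.00316 — the normalized outputs differ noticeably,
# so weights trained against one convention won't match the other.
```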
Before:

```julia
julia> @benchmark CUDA.@sync gn(x)
BenchmarkTools.Trial: 109 samples with 1 evaluation.
 Range (min … max):  41.979 ms … 52.871 ms  ┊ GC (min … max): 0.00% … 6.65%
 Time  (median):     44.962 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   45.950 ms ± 3.069 ms   ┊ GC (mean ± σ):  1.51% ± 2.50%

     ▂ ▁▂ █                                 ▁
  ▃▆█▄█▄▆▆▁▁▃▄███▇█▆█▇▁▃▁▁▁▁▁▁▃▁▃▁▁▁▁▁▁▃▄▁▄█▃▃▃▃▃▄▄▇▁▄▄▃▁▁▁▁▃ ▃
  42 ms          Histogram: frequency by time        52.8 ms <

 Memory estimate: 28.30 KiB, allocs estimate: 321.

julia> @benchmark CUDA.@sync Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 25 samples with 1 evaluation.
 Range (min … max):  200.516 ms … 210.522 ms  ┊ GC (min … max): 4.68% … 1.57%
 Time  (median):     202.030 ms               ┊ GC (median):    4.64%
 Time  (mean ± σ):   202.604 ms ± 2.110 ms    ┊ GC (mean ± σ):  4.49% ± 0.63%

  ▁    █   ▁▄    ▁
  ▆▆▁▆█▆▆▆█▆██▁▁▆█▁▁▁▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  201 ms         Histogram: frequency by time         211 ms <

 Memory estimate: 93.84 KiB, allocs estimate: 1551.
```

This PR:

```julia
julia> @benchmark CUDA.@sync gn(x)
BenchmarkTools.Trial: 153 samples with 1 evaluation.
 Range (min … max):  30.696 ms … 39.091 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     32.063 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.806 ms ± 2.100 ms   ┊ GC (mean ± σ):  0.88% ± 2.26%

        ▁█▄
  ▄▃▃▃▃▄▅▇███▅▆▅▄▁▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▄▂▂▁▂▁▂▂▂▃▂▃▂▂▂ ▂
  30.7 ms        Histogram: frequency by time        38.8 ms <

 Memory estimate: 34.22 KiB, allocs estimate: 385.

julia> @benchmark CUDA.@sync Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 39 samples with 1 evaluation.
 Range (min … max):  126.358 ms … 138.385 ms  ┊ GC (min … max): 2.34% … 2.65%
 Time  (median):     130.110 ms               ┊ GC (median):    4.00%
 Time  (mean ± σ):   130.436 ms ± 2.064 ms    ┊ GC (mean ± σ):  3.75% ± 0.89%

               █▅▂ ▅▂
  ▅▁▁▅▁▁▁▅▁▁▁▅▁█▁█████▅██▅▁█▅▅▅▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▅ ▁
  126 ms         Histogram: frequency by time         138 ms <

 Memory estimate: 99.58 KiB, allocs estimate: 1647.
```
I think we should mark this as breaking, and tag a breaking release |
I think it is worth separating the speedup in this PR from the breaking change, since it may take some time to tag a breaking release.
It would be great if we could make the eps change not affect deserialized old models, but I don't see how to do it. I tried to think along the lines of adding an extra parameter to the types to act as a flag, e.g. as sketched below.
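A hypothetical sketch of that flag idea — none of these names exist in Flux; the struct layout and `legacy_eps` field are invented for illustration:

```julia
# Hypothetical sketch — `legacy_eps` is an invented field, not Flux API.
struct MyGroupNorm{F,V}
    λ::F
    γ::V
    β::V
    ϵ::Float32
    legacy_eps::Bool   # true  => old behaviour: sqrt(σ²) + ϵ
                       # false => new behaviour: sqrt(σ² + ϵ)
end

# New constructors default to the new behaviour; deserialization of an old
# model would need to set the flag instead. The catch is that old serialized
# structs simply don't carry this field, so loading them still requires
# custom deserialization logic — which is why the idea stalls.
MyGroupNorm(λ, γ, β, ϵ) = MyGroupNorm(λ, γ, β, ϵ, false)

denom(l::MyGroupNorm, σ²) = l.legacy_eps ? sqrt.(σ²) .+ l.ϵ : sqrt.(σ² .+ l.ϵ)
```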
Reverted |
I see that I never posted my comment here...
```julia
scale = γ ./ sqrt.(σ² .+ eps)
bias = -scale .* μ .+ β
l.λ.(scale .* x .+ bias)
```
Why are these split up this way? If `γ` is the full size, then I think this will call `sqrt` N² times when it could do N. Avoiding such N² work is one reason not to fuse things. (Zygote will un-fuse, but Diffractor should be lazier, at least about `+`, `-`, `*`, `/`.)
Suggested change:

```diff
-scale = γ ./ sqrt.(σ² .+ eps)
-bias = -scale .* μ .+ β
-l.λ.(scale .* x .+ bias)
+den = sqrt.(σ² .+ eps)
+l.λ.(γ .* x ./ den .- γ .* μ ./ den .+ β)
```
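A quick way to see the operation-count argument — a standalone sketch where `N` and the counting wrapper are illustrative, not from the PR:

```julia
# Count sqrt calls under the two broadcast shapes.
const SQRT_CALLS = Ref(0)
counted_sqrt(x) = (SQRT_CALLS[] += 1; sqrt(x))

N  = 4
γ  = rand(Float32, N, N)   # "full size" affine parameter, N² elements
σ² = rand(Float32, 1, N)   # per-column statistics, N elements
ϵ  = 1f-5

SQRT_CALLS[] = 0
scale = γ ./ counted_sqrt.(σ² .+ ϵ)   # fused: sqrt runs once per output element
@show SQRT_CALLS[]                    # N² = 16

SQRT_CALLS[] = 0
den = counted_sqrt.(σ² .+ ϵ)          # un-fused: sqrt runs once per statistic
scale2 = γ ./ den
@show SQRT_CALLS[]                    # N = 4
```

Because Julia fuses nested dot calls into a single broadcast kernel, the one-liner evaluates `sqrt` once per element of the N² output, while hoisting `den` out computes it only over the N statistics — which is the reviewer's point.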