Speed-up normalization layers #2220

Merged: 2 commits into FluxML:master, May 5, 2023

Conversation

@pxl-th (Member) commented on Mar 28, 2023:

  • Speed up other normalization layers by rearranging the expression a bit (a sketch of the rearrangement follows below):

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta = \frac{\gamma}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot (x - \mathrm{E}[x]) + \beta$$

$$\text{scale} = \frac{\gamma}{\sqrt{\mathrm{Var}[x] + \epsilon}}$$

$$\text{bias} = -\text{scale} \cdot \mathrm{E}[x] + \beta$$

$$y = \text{scale} \cdot x + \text{bias}$$
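For illustration, a minimal standalone sketch of the rearranged computation (not the actual Flux implementation; the function name, default dims, and eps value are assumptions, and γ/β are assumed pre-reshaped so they broadcast against x):

using Statistics

# Rearranged normalization: all sqrt/divide work happens on the small
# statistics arrays, and the big array x is touched by a single
# broadcast multiply-add.
function norm_rearranged(x, γ, β; dims=(1, 2), eps=1f-5)
    μ = mean(x; dims)
    σ² = var(x; dims, mean=μ, corrected=false)
    scale = γ ./ sqrt.(σ² .+ eps)   # small-array work
    bias = -scale .* μ .+ β         # small-array work
    return scale .* x .+ bias      # one pass over the big array
end

Benchmark setup:
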
julia> using BenchmarkTools, Flux

julia> gn = GroupNorm(128, 32);

julia> x = rand(Float32, 256, 256, 128, 16);

Before:

julia> @benchmark gn(x)
BenchmarkTools.Trial: 9 samples with 1 evaluation.
 Range (min … max):  545.960 ms … 576.558 ms  ┊ GC (min … max): 0.27% … 6.13%
 Time  (median):     575.436 ms               ┊ GC (median):    6.13%
 Time  (mean ± σ):   570.931 ms ±  10.306 ms  ┊ GC (mean ± σ):  5.27% ± 2.01%

                                                              █  
  ▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▇▇▇█ ▁
  546 ms           Histogram: frequency by time          577 ms <

 Memory estimate: 1.00 GiB, allocs estimate: 32.

This PR:

julia> @benchmark gn(x)
BenchmarkTools.Trial: 23 samples with 1 evaluation.
 Range (min … max):  202.554 ms … 221.089 ms  ┊ GC (min … max): 0.34% … 7.94%
 Time  (median):     219.391 ms               ┊ GC (median):    7.99%
 Time  (mean ± σ):   218.166 ms ±   4.402 ms  ┊ GC (mean ± σ):  7.34% ± 2.20%

                                                        ▆█▂      
  ▄▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▄███▄▁▁▄ ▁
  203 ms           Histogram: frequency by time          221 ms <

 Memory estimate: 512.03 MiB, allocs estimate: 35.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@pxl-th changed the title from "Speedup normalization layers" to "Speed-up normalization layers" on Mar 28, 2023
@ToucheSir (Member) commented:

Have you benchmarked gradient times? Otherwise LGTM.

@pxl-th (Member, Author) commented on Mar 28, 2023:

Before:

julia> @benchmark Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.530 s …  1.593 s  ┊ GC (min … max): 3.85% … 4.30%
 Time  (median):     1.570 s              ┊ GC (median):    3.90%
 Time  (mean ± σ):   1.566 s ± 31.406 ms  ┊ GC (mean ± σ):  3.83% ± 0.54%

  █                █                                    █ █  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁█ ▁
  1.53 s         Histogram: frequency by time        1.59 s <

 Memory estimate: 5.50 GiB, allocs estimate: 611.

This PR:

julia> @benchmark Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  1.026 s …  1.057 s  ┊ GC (min … max): 4.64% … 5.15%
 Time  (median):     1.040 s              ┊ GC (median):    4.58%
 Time  (mean ± σ):   1.040 s ± 14.209 ms  ┊ GC (mean ± σ):  4.38% ± 0.96%

  █                        ▁                      ▁       ▁  
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁█ ▁
  1.03 s         Histogram: frequency by time        1.06 s <

 Memory estimate: 3.50 GiB, allocs estimate: 675.

@chengchingwen (Member) commented:

What about benchmarks on GPU?

@chengchingwen (Member) commented:

And a potential regression issue: the result is now aligned with the PyTorch norm, which means all old models trained with the old Flux norm will be affected.
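
(For context, the difference is where ϵ enters the denominator; the "old" line below is a sketch of the previous normalise behaviour, not verbatim Flux code:)

old_y = (x .- μ) ./ (sqrt.(σ²) .+ ϵ)   # ϵ added to the standard deviation
new_y = (x .- μ) ./ sqrt.(σ² .+ ϵ)     # ϵ added to the variance, as in PyTorch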

@pxl-th (Member, Author) commented on Mar 28, 2023:

> What about benchmarks on GPU?
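
(The GPU numbers below don't show the setup; presumably it was something like the following, which is an assumption rather than part of the thread:)

julia> using CUDA, Flux, Zygote, BenchmarkTools

julia> gn = GroupNorm(128, 32) |> gpu;

julia> x = CUDA.rand(Float32, 256, 256, 128, 16);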

Before:

julia> @benchmark CUDA.@sync gn(x)
BenchmarkTools.Trial: 109 samples with 1 evaluation.
 Range (min … max):  41.979 ms … 52.871 ms  ┊ GC (min … max): 0.00% … 6.65%
 Time  (median):     44.962 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   45.950 ms ±  3.069 ms  ┊ GC (mean ± σ):  1.51% ± 2.50%

    ▂         ▁▂  █ ▁                                          
  ▃▆█▄█▄▆▆▁▁▃▄███▇█▆█▇▁▃▁▁▁▁▁▁▃▁▃▁▁▁▁▁▁▃▄▁▄█▃▃▃▃▃▄▄▇▁▄▄▃▁▁▁▁▃ ▃
  42 ms           Histogram: frequency by time        52.8 ms <

 Memory estimate: 28.30 KiB, allocs estimate: 321.


julia> @benchmark CUDA.@sync Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 25 samples with 1 evaluation.
 Range (min … max):  200.516 ms … 210.522 ms  ┊ GC (min … max): 4.68% … 1.57%
 Time  (median):     202.030 ms               ┊ GC (median):    4.64%
 Time  (mean ± σ):   202.604 ms ±   2.110 ms  ┊ GC (mean ± σ):  4.49% ± 0.63%

      ▁   █ ▁▄   ▁                                               
  ▆▆▁▆█▆▆▆█▆██▁▁▆█▁▁▁▆▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  201 ms           Histogram: frequency by time          211 ms <

 Memory estimate: 93.84 KiB, allocs estimate: 1551.

This PR:

julia> @benchmark CUDA.@sync gn(x)
BenchmarkTools.Trial: 153 samples with 1 evaluation.
 Range (min … max):  30.696 ms … 39.091 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     32.063 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   32.806 ms ±  2.100 ms  ┊ GC (mean ± σ):  0.88% ± 2.26%

          ▁█▄                                                  
  ▄▃▃▃▃▄▅▇███▅▆▅▄▁▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▄▂▂▁▂▁▂▂▂▃▂▃▂▂▂ ▂
  30.7 ms         Histogram: frequency by time        38.8 ms <

 Memory estimate: 34.22 KiB, allocs estimate: 385.

julia> @benchmark CUDA.@sync Zygote.gradient(gn -> sum(gn(x)), gn)
BenchmarkTools.Trial: 39 samples with 1 evaluation.
 Range (min … max):  126.358 ms … 138.385 ms  ┊ GC (min … max): 2.34% … 2.65%
 Time  (median):     130.110 ms               ┊ GC (median):    4.00%
 Time  (mean ± σ):   130.436 ms ±   2.064 ms  ┊ GC (mean ± σ):  3.75% ± 0.89%

                   █▅▂ ▅▂                                        
  ▅▁▁▅▁▁▁▅▁▁▁▅▁█▁█████▅██▅▁█▅▅▅▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▅ ▁
  126 ms           Histogram: frequency by time          138 ms <

 Memory estimate: 99.58 KiB, allocs estimate: 1647.

@CarloLucibello (Member) commented:

> And a potential regression issue: the result is now aligned with the PyTorch norm, which means all old models trained with the old Flux norm will be affected.

I think we should mark this as breaking, and tag a breaking release

@CarloLucibello (Member) commented:

I think it is worth separating the speed-up of the PR from the breaking change, since it may take some time to tag a breaking release.

@CarloLucibello (Member) commented:

It would be great if we could make the eps change not affect deserialized old models, though I don't see how to do that. I tried to think along the lines of adding an extra type parameter to the types, acting as a flag, e.g. struct LayerNorm{VERSION} .. but this would break deserialization (see the sketch below).
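
A sketch of the idea and of why it breaks (hypothetical type, not proposed code): once the definition gains a VERSION parameter, values serialized under the old definition carry mismatched type parameters and can no longer be loaded against it.

# Hypothetical versioned layer; VERSION selects the behaviour via dispatch.
struct VersionedLayerNorm{VERSION, D, T}
    diag::D
    ϵ::T
end

denom(l::VersionedLayerNorm{1}, σ²) = sqrt.(σ²) .+ l.ϵ   # old placement (assumed)
denom(l::VersionedLayerNorm{2}, σ²) = sqrt.(σ² .+ l.ϵ)   # PyTorch-aligned

# Deserializers reconstruct a value from its stored type, so a model saved
# with the un-versioned struct no longer matches this definition.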

@pxl-th (Member, Author) commented on May 2, 2023:

Reverted the normalise change.

@CarloLucibello merged commit 1e1da28 into FluxML:master on May 5, 2023
@pxl-th deleted the norm branch on May 5, 2023

@mcabbott (Member) left a comment:

I see that I never posted my comment here...

Comment on lines +233 to +235
scale = γ ./ sqrt.(σ² .+ eps)
bias = -scale .* μ .+ β
l.λ.(scale .* x .+ bias)

Why are these split up this way?

If γ is the full size, then I think this will call sqrt N^2 times when it could do it N times. Avoiding such N^2 work is one reason not to fuse things. (Zygote will un-fuse, but Diffractor should be lazier, at least about +, -, *, /.)

Suggested change:
- scale = γ ./ sqrt.(σ² .+ eps)
- bias = -scale .* μ .+ β
- l.λ.(scale .* x .+ bias)
+ den = sqrt.(σ² .+ eps)
+ l.λ.(γ .* x ./ den .- γ .* μ ./ den .+ β)
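
To make the N vs N^2 point concrete, a small standalone timing sketch (shapes are hypothetical; γ here stands in for a full-size array broadcast against small statistics):

julia> using BenchmarkTools

julia> σ² = rand(Float32, 1, 32); γ = rand(Float32, 4096, 32); ϵ = 1f-5;

julia> @btime $γ ./ sqrt.($σ² .+ $ϵ);   # fused: sqrt evaluated once per output element (4096×32 calls)

julia> den = sqrt.(σ² .+ ϵ);            # materialized first: only 32 sqrt calls

julia> @btime $γ ./ $den;               # then one cheap division per element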
