Update to use LuxLib #156

Merged
merged 1 commit into from
Sep 25, 2022

avik-pal
Member

@avik-pal avik-pal commented Sep 19, 2022

Fixes #10 #98

Performance Updates

Performance changes only for GroupNorm, thanks to the new KernelAbstractions (KA) kernels.

Code

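using BenchmarkTools, Lux, Random, CUDA, Zygote

# Sweep input sizes (W, H, C, N) and group counts; sz[3] is the channel dimension.
for sz in ((32, 32, 128, 16), (64, 64, 64, 16), (128, 128, 32, 16)), groups in (4, 8, 16)
    gn = GroupNorm(sz[3], groups)
    x = randn(Float32, sz...) |> gpu
    ps, st = Lux.setup(Random.default_rng(), gn) .|> gpu

    # Forward pass only.
    t1 = @belapsed CUDA.@sync Lux.apply($gn, $x, $ps, $st)
    # Run the gradient once up front so the backward pass is compiled before timing.
    Zygote.gradient((x, ps) -> sum(Lux.apply(gn, x, ps, st)[1]), x, ps)
    # Forward + backward pass.
    t2 = @belapsed CUDA.@sync Zygote.gradient((x, ps) -> sum(Lux.apply($gn, x, ps, $st)[1]),
                                              $x, $ps)

    # @belapsed returns the minimum elapsed time in seconds.
    println("$(sz), $(groups), $(t1), $(t2)")
end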
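
For readers unfamiliar with the operation being timed, here is a minimal CPU sketch of what group normalization computes on a (W, H, C, N) array. It is not the LuxLib KA kernel: the function name `groupnorm_reference`, its argument order, and the `epsilon` default are assumptions made for illustration only.

```julia
using Statistics

# Hypothetical reference implementation for illustration; not the LuxLib kernel.
function groupnorm_reference(x::AbstractArray{T, 4}, scale::AbstractVector,
                             bias::AbstractVector, groups::Int;
                             epsilon = T(1.0f-5)) where {T}
    W, H, C, N = size(x)
    C % groups == 0 ||
        throw(ArgumentError("channels ($C) must be divisible by groups ($groups)"))
    # Split the channel dimension into (channels per group, groups).
    xg = reshape(x, W, H, C ÷ groups, groups, N)
    # Statistics are taken per (group, sample) over the spatial and in-group channel dims.
    μ = mean(xg; dims = (1, 2, 3))
    σ² = var(xg; mean = μ, dims = (1, 2, 3), corrected = false)
    x̂ = reshape((xg .- μ) ./ sqrt.(σ² .+ epsilon), W, H, C, N)
    # Per-channel affine transform.
    return x̂ .* reshape(scale, 1, 1, C, 1) .+ reshape(bias, 1, 1, C, 1)
end

x = randn(Float32, 4, 4, 8, 2)
y = groupnorm_reference(x, ones(Float32, 8), zeros(Float32, 8), 4)
```

With unit scale and zero bias, each (group, sample) slice of `y` has mean close to 0 and variance close to 1, which is a quick sanity check for any faster kernel.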

Timings

| Size | N Groups | Time Forward (New, s) | Time Fwd + Bwd (New, s) | Time Forward (Old, s) | Time Fwd + Bwd (Old, s) |
|---|---|---|---|---|---|
| (32, 32, 128, 16) | 4 | 0.001171762 | 0.003418343 | 0.001432062 | 0.006942468 |
| (32, 32, 128, 16) | 8 | 0.001179864 | 0.003470376 | 0.001454761 | 0.006201573 |
| (32, 32, 128, 16) | 16 | 0.001208639 | 0.003496287 | 0.001452037 | 0.006244667 |
| (64, 64, 64, 16) | 4 | 0.002307972 | 0.005459583 | 0.002810693 | 0.011882327 |
| (64, 64, 64, 16) | 8 | 0.002342824 | 0.00553138 | 0.002867196 | 0.011813671 |
| (64, 64, 64, 16) | 16 | 0.002331648 | 0.005505468 | 0.002855742 | 0.012082637 |
| (128, 128, 32, 16) | 4 | 0.00444357 | 0.010073916 | 0.005577281 | 0.024195243 |
| (128, 128, 32, 16) | 8 | 0.004467874 | 0.010130907 | 0.005562752 | 0.024235608 |
| (128, 128, 32, 16) | 16 | 0.004479398 | 0.010194882 | 0.005528737 | 0.024055475 |
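
Read as ratios, the timings above work out to roughly a 1.2x speedup on the forward pass and about 1.8x to 2.4x on forward plus backward. A small sketch that derives these ratios from the numbers in the table (the variable names `timings` and `speedups` are illustrative, not part of the PR):

```julia
# Timings in seconds, copied from the table above:
# (size, groups, fwd_new, fwdbwd_new, fwd_old, fwdbwd_old)
timings = [
    ((32, 32, 128, 16), 4, 0.001171762, 0.003418343, 0.001432062, 0.006942468),
    ((32, 32, 128, 16), 8, 0.001179864, 0.003470376, 0.001454761, 0.006201573),
    ((32, 32, 128, 16), 16, 0.001208639, 0.003496287, 0.001452037, 0.006244667),
    ((64, 64, 64, 16), 4, 0.002307972, 0.005459583, 0.002810693, 0.011882327),
    ((64, 64, 64, 16), 8, 0.002342824, 0.00553138, 0.002867196, 0.011813671),
    ((64, 64, 64, 16), 16, 0.002331648, 0.005505468, 0.002855742, 0.012082637),
    ((128, 128, 32, 16), 4, 0.00444357, 0.010073916, 0.005577281, 0.024195243),
    ((128, 128, 32, 16), 8, 0.004467874, 0.010130907, 0.005562752, 0.024235608),
    ((128, 128, 32, 16), 16, 0.004479398, 0.010194882, 0.005528737, 0.024055475),
]

# Speedup = old time / new time, for the forward pass and for forward + backward.
speedups = [(sz, groups, fwd_old / fwd_new, fb_old / fb_new)
            for (sz, groups, fwd_new, fb_new, fwd_old, fb_old) in timings]

for (sz, groups, s_fwd, s_fb) in speedups
    println(sz, ", ", groups, ": forward ", round(s_fwd; digits = 2),
            "x, fwd + bwd ", round(s_fb; digits = 2), "x")
end
```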

@codecov

codecov bot commented Sep 24, 2022

Codecov Report

Base: 83.69% // Head: 84.65% // Increases project coverage by +0.96% 🎉

Coverage data is based on head (0bcd7d1) compared to base (076d030).
Patch coverage: 78.57% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #156      +/-   ##
==========================================
+ Coverage   83.69%   84.65%   +0.96%     
==========================================
  Files          18       18              
  Lines        1208     1108     -100     
==========================================
- Hits         1011      938      -73     
+ Misses        197      170      -27     
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/Lux.jl | 100.00% <ø> (ø) | |
| src/autodiff.jl | 62.00% <ø> (-0.86%) | ⬇️ |
| src/core.jl | 100.00% <ø> (ø) | |
| src/deprecated.jl | 80.00% <0.00%> (-20.00%) | ⬇️ |
| src/layers/recurrent.jl | 100.00% <ø> (ø) | |
| src/nnlib.jl | 77.77% <ø> (-4.17%) | ⬇️ |
| src/layers/containers.jl | 90.26% <100.00%> (ø) | |
| src/layers/dropout.jl | 100.00% <100.00%> (ø) | |
| src/layers/normalize.jl | 98.43% <100.00%> (+10.35%) | ⬆️ |
| src/utils.jl | 93.75% <100.00%> (+2.57%) | ⬆️ |

... and 2 more

@avik-pal
Member Author

Blocked on #160

@avik-pal force-pushed the ap/luxlib branch 4 times, most recently from 1749196 to 409ca26 (September 25, 2022 17:19)
@avik-pal changed the title from "[WIP] Update to use LuxLib" to "Update to use LuxLib" (Sep 25, 2022)
@avik-pal merged commit 3c0aecf into main (Sep 25, 2022)
@avik-pal deleted the ap/luxlib branch (September 25, 2022 22:08)
Successfully merging this pull request may close these issues.

Generalize normalization to work for unconstrained types
Suboptimal GroupNorm Implementation on GPUs