feat: improve batched jacobian #848
[skip benchmarks]
Benchmark Results (ASV)
Benchmark Plots
A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Benchmark Results
| Benchmark suite | Current: 2b795ce | Previous: b6171a6 | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3741.875 ns | 3675.625 ns | 1.02 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 6728.285714285715 ns | 8093.5 ns | 0.83 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 21050 ns | 21210 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9856.4 ns | 9748.2 ns | 1.01 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 9185 ns | 9167.2 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4667.5 ns | 4470.875 ns | 1.04 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 4988.125 ns | 4956.875 ns | 1.01 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1046.3151515151515 ns | 2373.4 ns | 0.44 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1062.6564417177915 ns | 2270.3 ns | 0.47 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1793.1967213114754 ns | 1790.017543859649 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 180.0140056022409 ns | 179.70239774330042 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17312 ns | 17562.5 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 12673 ns | 24787 ns | 0.51 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36298 ns | 38393 ns | 0.95 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28804 ns | 29025 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 20047 ns | 21590 ns | 0.93 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17183 ns | 17092 ns | 1.01 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 25788 ns | 25648 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 1415.6 ns | 20248 ns | 0.06991307783484788 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 1445.6 ns | 14448 ns | 0.10 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4889.142857142857 ns | 4846.285714285715 ns | 1.01 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1660 ns | 1659.2 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 117792312 ns | 77690170 ns | 1.52 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 48750528 ns | 76782338 ns | 0.63 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 190814626.5 ns | 155414925 ns | 1.23 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 205639818.5 ns | 167638289.5 ns | 1.23 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 210638995 ns | 142842293.5 ns | 1.47 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 12019986 ns | 11557321.5 ns | 1.04 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 182404684 ns | 199234044.5 ns | 0.92 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 5924764 ns | 15528408.5 ns | 0.38 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 5917701 ns | 15540189 ns | 0.38 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 47268995 ns | 30661456 ns | 1.54 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6384243 ns | 6376663 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 1023131234 ns | 1064055959.5 ns | 0.96 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2847075600 ns | 2970205700 ns | 0.96 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 216104427 ns | 178121161 ns | 1.21 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 1374030570 ns | 1320655778 ns | 1.04 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 4091197249 ns | 3516351096 ns | 1.16 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 401783545 ns | 344809509 ns | 1.17 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 1515463169 ns | 1431616033 ns | 1.06 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 4404884213 ns | 4058579611 ns | 1.09 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 429373736.5 ns | 436008182 ns | 0.98 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 417449321 ns | 381866129 ns | 1.09 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 923421065.5 ns | 905256978 ns | 1.02 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 54364697.5 ns | 54567006.5 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 402764024.5 ns | 382293897 ns | 1.05 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 922893504 ns | 870357323.5 ns | 1.06 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 29382300 ns | 54472914.5 ns | 0.54 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 534852007 ns | 551222188 ns | 0.97 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 1447265873 ns | 1387168504 ns | 1.04 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 166049072 ns | 164122645 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1284285410.5 ns | 1180058919 ns | 1.09 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1573918068 ns | 1610297742 ns | 0.98 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2266896153 ns | 2289727615.5 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2524001196 ns | 2640437136 ns | 0.96 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 2264631726 ns | 2193753011.5 ns | 1.03 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 2425508620 ns | 2122924359 ns | 1.14 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 286158756 ns | 282003619 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 284746326 ns | 286261947 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 458549355.5 ns | 437257287 ns | 1.05 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11691323 ns | 11806435 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 15153808 ns | 34527638 ns | 0.44 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 16445650 ns | 16364743 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 21065487 ns | 21004093 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 15201542 ns | 15284140 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1146511 ns | 1148921.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 19185166 ns | 35777843.5 ns | 0.54 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 1846077 ns | 4500694 ns | 0.41 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 1843918.5 ns | 4506207 ns | 0.41 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 1979827 ns | 2045686 ns | 0.97 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 193021 ns | 196300 ns | 0.98 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 375492 ns | 378068 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 208168.5 ns | 314462 ns | 0.66 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 374379 ns | 377972 ns | 0.99 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 524910 ns | 520691 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 290211.5 ns | 289716 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 404816 ns | 401777 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 425615 ns | 425321 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 56655 ns | 157406 ns | 0.36 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 60202 ns | 162456 ns | 0.37 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 92022 ns | 91953 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104375 ns | 104407 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 339878291.5 ns | 297649242 ns | 1.14 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 269324213 ns | 287837994 ns | 0.94 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 588762738.5 ns | 545531151.5 ns | 1.08 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 685038653 ns | 655809148 ns | 1.04 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 682011247 ns | 554893727 ns | 1.23 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 325025762.5 ns | 316084028.5 ns | 1.03 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 622892766 ns | 583442251.5 ns | 1.07 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 38430370.5 ns | 40159465 ns | 0.96 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 38442335 ns | 40173961.5 ns | 0.96 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 111238041.5 ns | 96663497 ns | 1.15 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 28030484 ns | 28321531 ns | 0.99 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 21176816 ns | 21078472 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19294631 ns | 17393481 ns | 1.11 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 22824568.5 ns | 22657728 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 28022340 ns | 28019412 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19353731 ns | 19298592.5 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 20894939 ns | 20720819 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6585712.5 ns | 6086608 ns | 1.08 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6540915 ns | 6101998 ns | 1.07 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6561562 ns | 6509879.5 ns | 1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
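The Ratio column above appears to be the current timing divided by the previous one, so values below 1 mean the current commit is faster. A minimal sketch that reproduces two of the reported ratios from their timings (the helper name `ratio` is hypothetical and not part of github-action-benchmark):

```python
def ratio(current_ns, previous_ns):
    """Current / previous timing; a value below 1 means the current commit is faster."""
    return current_ns / previous_ns

# Timings copied from two rows of the benchmark results: (current, previous) in ns.
rows = {
    "Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128)": (1415.6, 20248.0),
    "Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128)": (117792312.0, 77690170.0),
}

for name, (current, previous) in rows.items():
    r = ratio(current, previous)
    verdict = "faster" if r < 1 else "slower"
    print(f"{name}: ratio {r:.2f} ({verdict})")
```

Read this way, the forward-pass NamedTuple and ComponentArray rows show the largest improvements, while a ratio such as 1.52 marks a slowdown relative to the previous commit.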
e25b6ee to 29d28ba
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #848 +/- ##
==========================================
+ Coverage 95.10% 95.34% +0.24%
==========================================
Files 58 58
Lines 2840 2859 +19
==========================================
+ Hits 2701 2726 +25
+ Misses 139 133 -6
☔ View full report in Codecov by Sentry.
batched_jacobian calls that do not use Lux layers are now differentiable
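For readers unfamiliar with the term, a batched Jacobian is usually taken to mean one Jacobian per sample in a batch, with samples treated as independent. The sketch below illustrates that idea in plain Python using central finite differences. It is not Lux's implementation or API (Lux relies on automatic differentiation), and the names `jacobian_fd` and `batched_jacobian_fd` are hypothetical.

```python
def jacobian_fd(f, x, eps=1e-6):
    """Central-difference Jacobian of f at x (a list of floats), as an m x n nested list."""
    m, n = len(f(x)), len(x)
    jac = [[0.0] * n for _ in range(m)]
    for j in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = f(x_plus), f(x_minus)
        for i in range(m):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac

def batched_jacobian_fd(f, batch):
    """One Jacobian per sample; samples are assumed independent of each other."""
    return [jacobian_fd(f, x) for x in batch]

def f(x):
    # Example function R^2 -> R^2 with analytic Jacobian [[x1, x0], [2*x0, 0]].
    return [x[0] * x[1], x[0] ** 2]

batch = [[1.0, 2.0], [3.0, 4.0]]
jacs = batched_jacobian_fd(f, batch)
for x, jac in zip(batch, jacs):
    print(x, jac)
```

For a batch of B samples with n inputs and m outputs each, the result holds B separate m x n matrices rather than one large (m·B) x (n·B) Jacobian, which is what makes the batched form cheaper.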