-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvement to the @compact
API
#584
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Member
avik-pal
commented
Apr 12, 2024
•
edited
Loading
edited
- Move compact out of experimental
- Update user-facing tutorials/examples to recommend the compact API. Preserve the abstract explicit layer API for advanced users
- Allow passing in the parameters optionally, needed for sciml applications
avik-pal
force-pushed
the
ap/promote_compact
branch
from
April 12, 2024 15:31
1305ab7
to
b3e9a5f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results
Benchmark suite | Current: 72fc49f | Previous: fc591bd | Ratio |
---|---|---|---|
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) |
3661.8125 ns |
3653 ns |
1.00 |
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) |
7711 ns |
7729.5 ns |
1.00 |
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) |
13966 ns |
14106 ns |
0.99 |
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) |
9826.4 ns |
9916 ns |
0.99 |
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) |
8677 ns |
8698.75 ns |
1.00 |
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) |
4509.625 ns |
4506.5625 ns |
1.00 |
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) |
1982.6 ns |
1971.7 ns |
1.01 |
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) |
1648.5874125874127 ns |
1648.5314685314686 ns |
1.00 |
Dense(2 => 2)/cpu/forward/Flux/(2, 128) |
1820.7111111111112 ns |
1824.8510638297873 ns |
1.00 |
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) |
179.83053221288515 ns |
179.4728789986092 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) |
17443 ns |
17333 ns |
1.01 |
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) |
18614 ns |
18394 ns |
1.01 |
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) |
35205 ns |
35396 ns |
0.99 |
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) |
28363 ns |
28633 ns |
0.99 |
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) |
19687 ns |
19607 ns |
1.00 |
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) |
17061.5 ns |
17032 ns |
1.00 |
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) |
4753.142857142857 ns |
4768.857142857143 ns |
1.00 |
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) |
4836.142857142857 ns |
4800.428571428572 ns |
1.01 |
Dense(20 => 20)/cpu/forward/Flux/(20, 128) |
4829 ns |
4800.428571428572 ns |
1.01 |
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) |
1655 ns |
1659.1 ns |
1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) |
41172020 ns |
48367699 ns |
0.85 |
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) |
78878498 ns |
90662926 ns |
0.87 |
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) |
97703863 ns |
97653785.5 ns |
1.00 |
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) |
92936983 ns |
107727588 ns |
0.86 |
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) |
78828640 ns |
108249388 ns |
0.73 |
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) |
12304830.5 ns |
12110710 ns |
1.02 |
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) |
12531653 ns |
18210910.5 ns |
0.69 |
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) |
11675075 ns |
18544073 ns |
0.63 |
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) |
12051370 ns |
18466654 ns |
0.65 |
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) |
6429716 ns |
6396982 ns |
1.01 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) |
106969935.5 ns |
106620467.5 ns |
1.00 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) |
759005268 ns |
763640160 ns |
0.99 |
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) |
3125408604 ns |
2762978316 ns |
1.13 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) |
172310631 ns |
163403619 ns |
1.05 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) |
1225694918.5 ns |
1198898689 ns |
1.02 |
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) |
3819391281 ns |
3765767577 ns |
1.01 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) |
84377856 ns |
85276372.5 ns |
0.99 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) |
816856115 ns |
840374369 ns |
0.97 |
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) |
2998787436 ns |
3347793443 ns |
0.90 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) |
26381644 ns |
25080614.5 ns |
1.05 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) |
226022755 ns |
232258093 ns |
0.97 |
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) |
932997268 ns |
1019038431 ns |
0.92 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) |
24924118 ns |
25059892 ns |
0.99 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) |
230912789 ns |
236184814.5 ns |
0.98 |
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) |
996753621 ns |
999233211 ns |
1.00 |
vgg16/cpu/forward/Flux/(32, 32, 3, 1) |
24262998 ns |
24562440.5 ns |
0.99 |
vgg16/cpu/forward/Flux/(32, 32, 3, 16) |
227728053.5 ns |
211748278 ns |
1.08 |
vgg16/cpu/forward/Flux/(32, 32, 3, 64) |
751884917 ns |
712431369.5 ns |
1.06 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) |
1148419496.5 ns |
1132641019 ns |
1.01 |
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) |
1990129747 ns |
1842889677.5 ns |
1.08 |
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) |
2279710311 ns |
2124383065.5 ns |
1.07 |
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) |
2559376520 ns |
2365462129 ns |
1.08 |
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) |
1954493244 ns |
1854224454.5 ns |
1.05 |
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) |
477209829 ns |
456010240 ns |
1.05 |
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) |
479257712 ns |
359691595 ns |
1.33 |
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) |
443087543 ns |
359652717.5 ns |
1.23 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) |
12072333 ns |
11966091 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) |
17998985 ns |
18076793 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) |
19301344.5 ns |
19252254 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) |
23860866 ns |
23893264 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) |
18020223 ns |
18061934 ns |
1.00 |
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) |
1177048.5 ns |
1158025 ns |
1.02 |
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) |
2096454 ns |
2075109 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) |
2103272 ns |
2081892 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) |
2093289 ns |
2071516.5 ns |
1.01 |
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) |
215139 ns |
200054 ns |
1.08 |
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) |
303112 ns |
298147 ns |
1.02 |
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) |
275331 ns |
273642 ns |
1.01 |
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) |
367351 ns |
365467.5 ns |
1.01 |
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) |
415050 ns |
414444.5 ns |
1.00 |
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) |
275491 ns |
275154 ns |
1.00 |
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) |
417454 ns |
410968 ns |
1.02 |
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) |
89971.5 ns |
89371.5 ns |
1.01 |
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) |
91751 ns |
89357.5 ns |
1.03 |
Dense(200 => 200)/cpu/forward/Flux/(200, 128) |
87803 ns |
87022 ns |
1.01 |
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) |
104775 ns |
104495 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) |
197328424.5 ns |
197534448 ns |
1.00 |
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) |
387363381 ns |
372121710 ns |
1.04 |
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) |
433622724 ns |
403011132 ns |
1.08 |
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) |
498380590 ns |
482377826 ns |
1.03 |
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) |
385257573 ns |
371969112 ns |
1.04 |
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) |
340926248 ns |
334264188.5 ns |
1.02 |
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) |
64683242 ns |
59961589 ns |
1.08 |
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) |
70828052 ns |
53644168 ns |
1.32 |
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) |
62517317 ns |
56527647 ns |
1.11 |
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) |
28838255 ns |
29291598.5 ns |
0.98 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) |
19628137 ns |
19730534 ns |
0.99 |
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) |
19708854.5 ns |
19802579.5 ns |
1.00 |
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) |
23786882 ns |
23663463 ns |
1.01 |
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) |
24221193 ns |
24349385 ns |
0.99 |
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) |
19759687 ns |
19922312 ns |
0.99 |
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) |
6615741.5 ns |
6620742 ns |
1.00 |
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) |
6613753 ns |
6614070 ns |
1.00 |
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) |
6581654 ns |
6529781 ns |
1.01 |
This comment was automatically generated by workflow using github-action-benchmark.
avik-pal
changed the title
[WIP] Improvement to the
Improvement to the Apr 12, 2024
@compact
API@compact
API
avik-pal
force-pushed
the
ap/promote_compact
branch
from
April 12, 2024 22:29
944bb0d
to
0046a85
Compare
avik-pal
force-pushed
the
ap/promote_compact
branch
from
April 12, 2024 22:32
0046a85
to
72fc49f
Compare
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.