Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement to the @compact API #584

Merged
merged 5 commits into from
Apr 13, 2024
Merged

Improvement to the @compact API #584

merged 5 commits into from
Apr 13, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Apr 12, 2024

  • Move compact out of experimental
  • Update user-facing tutorials/examples to recommend the compact API. Preserve the abstract explicit layer API for advanced users
  • Allow passing in the parameters optionally, needed for sciml applications

@avik-pal avik-pal force-pushed the ap/promote_compact branch from 1305ab7 to b3e9a5f Compare April 12, 2024 15:31
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 72fc49f Previous: fc591bd Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3661.8125 ns 3653 ns 1.00
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7711 ns 7729.5 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 13966 ns 14106 ns 0.99
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9826.4 ns 9916 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 8677 ns 8698.75 ns 1.00
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4509.625 ns 4506.5625 ns 1.00
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1982.6 ns 1971.7 ns 1.01
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1648.5874125874127 ns 1648.5314685314686 ns 1.00
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1820.7111111111112 ns 1824.8510638297873 ns 1.00
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 179.83053221288515 ns 179.4728789986092 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17443 ns 17333 ns 1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 18614 ns 18394 ns 1.01
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 35205 ns 35396 ns 0.99
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 28363 ns 28633 ns 0.99
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 19687 ns 19607 ns 1.00
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17061.5 ns 17032 ns 1.00
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 4753.142857142857 ns 4768.857142857143 ns 1.00
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 4836.142857142857 ns 4800.428571428572 ns 1.01
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4829 ns 4800.428571428572 ns 1.01
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1655 ns 1659.1 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 41172020 ns 48367699 ns 0.85
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 78878498 ns 90662926 ns 0.87
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 97703863 ns 97653785.5 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 92936983 ns 107727588 ns 0.86
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 78828640 ns 108249388 ns 0.73
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 12304830.5 ns 12110710 ns 1.02
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 12531653 ns 18210910.5 ns 0.69
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 11675075 ns 18544073 ns 0.63
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 12051370 ns 18466654 ns 0.65
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6429716 ns 6396982 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) 106969935.5 ns 106620467.5 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 759005268 ns 763640160 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 3125408604 ns 2762978316 ns 1.13
vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) 172310631 ns 163403619 ns 1.05
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 1225694918.5 ns 1198898689 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3819391281 ns 3765767577 ns 1.01
vgg16/cpu/reverse/Flux/(32, 32, 3, 1) 84377856 ns 85276372.5 ns 0.99
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 816856115 ns 840374369 ns 0.97
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2998787436 ns 3347793443 ns 0.90
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) 26381644 ns 25080614.5 ns 1.05
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 226022755 ns 232258093 ns 0.97
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 932997268 ns 1019038431 ns 0.92
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) 24924118 ns 25059892 ns 0.99
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 230912789 ns 236184814.5 ns 0.98
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 996753621 ns 999233211 ns 1.00
vgg16/cpu/forward/Flux/(32, 32, 3, 1) 24262998 ns 24562440.5 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 227728053.5 ns 211748278 ns 1.08
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 751884917 ns 712431369.5 ns 1.06
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1148419496.5 ns 1132641019 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1990129747 ns 1842889677.5 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2279710311 ns 2124383065.5 ns 1.07
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2559376520 ns 2365462129 ns 1.08
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1954493244 ns 1854224454.5 ns 1.05
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 477209829 ns 456010240 ns 1.05
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 479257712 ns 359691595 ns 1.33
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 443087543 ns 359652717.5 ns 1.23
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12072333 ns 11966091 ns 1.01
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17998985 ns 18076793 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19301344.5 ns 19252254 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23860866 ns 23893264 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 18020223 ns 18061934 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1177048.5 ns 1158025 ns 1.02
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2096454 ns 2075109 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2103272 ns 2081892 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2093289 ns 2071516.5 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 215139 ns 200054 ns 1.08
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 303112 ns 298147 ns 1.02
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 275331 ns 273642 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 367351 ns 365467.5 ns 1.01
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 415050 ns 414444.5 ns 1.00
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275491 ns 275154 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 417454 ns 410968 ns 1.02
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 89971.5 ns 89371.5 ns 1.01
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 91751 ns 89357.5 ns 1.03
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87803 ns 87022 ns 1.01
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104775 ns 104495 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 197328424.5 ns 197534448 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 387363381 ns 372121710 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 433622724 ns 403011132 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 498380590 ns 482377826 ns 1.03
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 385257573 ns 371969112 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 340926248 ns 334264188.5 ns 1.02
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 64683242 ns 59961589 ns 1.08
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 70828052 ns 53644168 ns 1.32
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 62517317 ns 56527647 ns 1.11
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 28838255 ns 29291598.5 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19628137 ns 19730534 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19708854.5 ns 19802579.5 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23786882 ns 23663463 ns 1.01
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24221193 ns 24349385 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19759687 ns 19922312 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6615741.5 ns 6620742 ns 1.00
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6613753 ns 6614070 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6581654 ns 6529781 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

@avik-pal avik-pal changed the title [WIP] Improvement to the @compact API Improvement to the @compact API Apr 12, 2024
@avik-pal avik-pal force-pushed the ap/promote_compact branch from 944bb0d to 0046a85 Compare April 12, 2024 22:29
@avik-pal avik-pal force-pushed the ap/promote_compact branch from 0046a85 to 72fc49f Compare April 12, 2024 22:32
@avik-pal avik-pal merged commit 9f1d902 into main Apr 13, 2024
18 of 21 checks passed
@avik-pal avik-pal deleted the ap/promote_compact branch April 13, 2024 00:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant