This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

general code cleanup to allow faster broadcasting implementations #86

Merged
merged 47 commits into main from ap/patches
Jul 21, 2024

Conversation

avik-pal
Member

@avik-pal avik-pal commented Jul 14, 2024

Main Description

This PR doesn't primarily attempt to make things faster, but we do see speed improvements from:

  • broadcasting rewritten as loops on CPUs for the uniform broadcasting case
  • more rrules for common operations
  • general type-stability improvements
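
As a rough illustration of the first bullet, here is a minimal sketch of replacing a broadcast with an explicit loop on the CPU. The function names `add_broadcast` and `add_loop!` are hypothetical and not taken from this PR, and "uniform" is assumed here to mean that all arguments share the same axes, so no singleton-dimension expansion is needed.

```julia
# Hypothetical helper names for illustration only; not the PR's implementation.

# Reference version using Julia's generic broadcasting machinery.
add_broadcast(x::AbstractArray, b::AbstractArray) = x .+ b

# Loop version for CPU arrays whose axes all match (the assumed "uniform" case).
function add_loop!(y::Array{T}, x::Array{T}, b::Array{T}) where {T}
    axes(y) == axes(x) == axes(b) || throw(DimensionMismatch("expected identical axes"))
    # Plain linear-index loop; @simd allows the compiler to vectorize it.
    @inbounds @simd for i in eachindex(y, x, b)
        y[i] = x[i] + b[i]
    end
    return y
end

x = rand(Float32, 4, 3)
b = rand(Float32, 4, 3)
y = add_loop!(similar(x), x, b)
```

Both versions should produce the same result for same-shaped inputs; the loop version simply avoids the generic broadcast machinery in that restricted case.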

This PR was mostly opened to check why #85 was stalling CUDA -- that is fixed now.

We get rid of FastBroadcast.jl (for now) and replace it with loops for the CPU case. The idea here is to bring it back in a later PR, which solely focuses on performance in cases where we cannot use LoopVectorization.

My intuition was that Enzyme would be great for loops, but current Zygote with rrules absolutely destroys it for some of the operations -- at least for dropout, which I checked. We need to add rules to help out Enzyme here (especially now that it is the recommended autodiff backend).
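
For readers unfamiliar with rrules, the following is a minimal sketch of a hand-written reverse rule using ChainRulesCore.jl for a dropout-like masking step. The function `apply_mask` is hypothetical and is not the dropout implementation from this PR; treating `mask` and `scale` as non-differentiable is an assumption made to keep the example short.

```julia
using ChainRulesCore

# Hypothetical function standing in for the masking step of dropout.
apply_mask(x::AbstractArray, mask::AbstractArray, scale::Real) = x .* mask .* scale

function ChainRulesCore.rrule(::typeof(apply_mask), x::AbstractArray,
                              mask::AbstractArray, scale::Real)
    y = apply_mask(x, mask, scale)
    function apply_mask_pullback(Δ)
        # Gradient with respect to x only; mask and scale are assumed
        # non-differentiable in this sketch.
        ∂x = unthunk(Δ) .* mask .* scale
        return NoTangent(), ∂x, NoTangent(), NoTangent()
    end
    return y, apply_mask_pullback
end

x = Float32[1, 2, 3, 4]
mask = Float32[1, 0, 1, 0]
y, back = rrule(apply_mask, x, mask, 2.0f0)
_, ∂x, _, _ = back(ones(Float32, 4))
```

A rule like this lets a ChainRules-aware backend such as Zygote skip differentiating through the function body, which is one plausible reason the rrule path can outperform differentiating the loop directly.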

TODOs

  • dropout
    • impl
    • test type stability
    • GPU tests fail due to evals for type stability
  • downstream testing
    • type stability regression
    • enzyme support regression
  • normalization impl update
  • affine normalize fuse operations -- layernorm has a very general structure making it very hard to write a loop for it.
    • groupnorm
    • batchnorm -- deferred to later
    • instancenorm -- deferred to later; some refactor
  • rework the fused impls using get_device_type to minimize code in extensions.
    • fused_conv
    • fused_dense
  • testing type stability
    • BN
    • GN
    • IN
    • LN
  • Add Enzyme Tests
    • dropout
    • alpha_dropout
    • conv
    • dense
    • batch_norm
    • group_norm
    • layer_norm
    • instance_norm
  • rework broadcasting function and replace more smaller components directly
  • doc strings for bias_activation
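
The type-stability items above could be checked in several ways; the PR does not say which tool its tests use. As one possible approach, here is a small sketch using `@inferred` from Julia's `Test` standard library. The functions `stable` and `unstable` are hypothetical examples.

```julia
using Test

# Hypothetical functions for illustration only.
# For a Float64 input this can return either a Float64 or an Int.
unstable(x) = x > 0 ? x : 0
# Always returns the same type as its input.
stable(x) = x > 0 ? x : zero(x)

# Passes: the inferred return type matches the actual return type.
@inferred stable(1.0)

# @inferred raises an error when inference yields a Union of types.
@test_throws ErrorException @inferred unstable(1.0)
```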

Before Merging

  • remove installing master for packages from runtests
  • remove parallel testing on gpus
  • skip enzyme float16 on windows
  • remove -g2 from downstream testing

@avik-pal avik-pal force-pushed the ap/patches branch 19 times, most recently from d765b03 to 5c287fd Compare July 15, 2024 04:49
@avik-pal avik-pal force-pushed the ap/patches branch 2 times, most recently from 10298c6 to 86b4475 Compare July 15, 2024 06:07
@avik-pal avik-pal force-pushed the ap/patches branch 3 times, most recently from 80363da to 27d1475 Compare July 20, 2024 22:38
@avik-pal avik-pal force-pushed the ap/patches branch 4 times, most recently from f3840f8 to 8b23dbe Compare July 21, 2024 01:21
@avik-pal avik-pal force-pushed the ap/patches branch 5 times, most recently from 5311500 to a7de3f9 Compare July 21, 2024 05:02
@avik-pal avik-pal force-pushed the ap/patches branch 2 times, most recently from b4d1ab8 to 0ec4ecd Compare July 21, 2024 19:24
@avik-pal avik-pal merged commit e52d6b1 into main Jul 21, 2024
20 of 26 checks passed
@avik-pal avik-pal deleted the ap/patches branch July 21, 2024 21:21