
[Optimization] Implicit gemm rewrite #2545

Merged: 69 commits into tracel-ai:main on Nov 29, 2024

Conversation

@wingertge (Contributor) commented on Nov 26, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Requires tracel-ai/cubecl#309 to land first

Changes

Adds a brand new implicit GEMM implementation that uses the matmul primitives in cubecl. It is slower for small k sizes, but much faster for large ones, and more flexible. I'm keeping the current implementation because it's significantly faster for certain sizes and uses a significantly different loader strategy (loading only within each warp, which skips cross-warp syncs).
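
For intuition, here is a minimal CPU sketch of the im2col view that an implicit GEMM resolves on the fly. It is illustrative only: it assumes NHWC layout, unit stride, and no padding or dilation, and the names are made up rather than taken from the actual cubecl kernel.

// Implicit GEMM treats conv2d as a matmul over a *virtual* im2col matrix:
//   M = n * out_h * out_w  (one row per output position)
//   K = kh * kw * in_c     (one column per kernel element)
//   N = out_channels       (the weight side of the GEMM)
// The Lhs is never materialized; a reader maps GEMM coordinates back into
// the input tensor whenever a tile is loaded.
fn im2col_lhs(
    input: &[f32],              // NHWC: [n, in_h, in_w, in_c]
    (in_h, in_w, in_c): (usize, usize, usize),
    (kh, kw): (usize, usize),   // kernel size
    (out_h, out_w): (usize, usize),
    row: usize,                 // GEMM row, in 0..n * out_h * out_w
    col: usize,                 // GEMM col, in 0..kh * kw * in_c
) -> f32 {
    debug_assert!(col < kh * kw * in_c);
    // Decompose the GEMM coordinates back into convolution coordinates.
    let b = row / (out_h * out_w);
    let oy = (row / out_w) % out_h;
    let ox = row % out_w;
    let ky = col / (kw * in_c);
    let kx = (col / in_c) % kw;
    let c = col % in_c;
    // Unit stride, no padding or dilation in this sketch.
    let (y, x) = (oy + ky, ox + kx);
    input[((b * in_h + y) * in_w + x) * in_c + c]
}

A tile loader reads [tile_m, tile_k] blocks through a mapping like this instead of expanding the full im2col matrix up front, which is what makes the GEMM "implicit" and lets it reuse the regular matmul pipeline.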

Adds a number of new convolution benchmarks to test performance with different sizes and characteristics.

Testing

All non-group tests pass, and CRAFT produces the expected output with all layers using the new implicit GEMM, which exercises many different and relatively large layers. Adds two new regression tests for bugs discovered during implementation.

@wingertge marked this pull request as ready for review on November 27, 2024

codecov bot commented on Nov 27, 2024

Codecov Report

Attention: Patch coverage is 20.72264% with 746 lines in your changes missing coverage. Please review.

Project coverage is 81.85%. Comparing base (42e7c1f) to head (4eda62a).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
...it/src/kernel/conv/conv2d/gemm/homogeneous/base.rs 1.81% 270 Missing ⚠️
...tes/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs 30.93% 125 Missing ⚠️
...n-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs 0.00% 103 Missing ⚠️
...urn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs 0.00% 83 Missing ⚠️
...n-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs 0.00% 64 Missing ⚠️
.../burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs 31.57% 26 Missing ⚠️
...s/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs 0.00% 23 Missing ⚠️
...urn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs 0.00% 17 Missing ⚠️
...rates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs 0.00% 13 Missing ⚠️
crates/burn-jit/src/kernel/matmul/base.rs 64.00% 9 Missing ⚠️
... and 7 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2545      +/-   ##
==========================================
- Coverage   82.38%   81.85%   -0.53%     
==========================================
  Files         826      834       +8     
  Lines      105711   106589     +878     
==========================================
+ Hits        87087    87247     +160     
- Misses      18624    19342     +718     


@nathanielsimard (Member) left a comment


It looks awesome! Feels great to reuse a lot of components. There are still some improvements that we can make in our "design paradigm", especially in how we pass around the config. But this is beyond the scope of this PR.

I have a few comments, but it would also be great for @louisfd to review.

Resolved review threads (outdated):
crates/burn-jit/src/kernel/conv/conv2d/base.rs
crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs (two threads)
Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
}

A member commented:

Somehow, adding a sync_units after the for loop improved performance for the matmul. I think it makes sure all units in a plane are synced following the loop, which improves the execution of the following operations.

@wingertge (author) replied:

I'll benchmark it

@wingertge (author) followed up:

So the benchmark results are very odd. I tried four different ways of syncing: the one I used initially was the fastest overall on CUDA, but adding a sync before the load and one after the loop was significantly faster on SPIR-V, and syncing only where absolutely needed was by far the slowest. Very odd behaviour. I'll stick with lots of syncs for now, since that variant is only 10-15% slower than the current implementation on CUDA (within margin of error), but 30% faster on SPIR-V.
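
For context, a rough sketch of the "lots of syncs" variant described above (a sync before the load and one after the loop), written against the loop tail quoted earlier. The fill_stage name and the loop shape are illustrative, not the exact cubecl API:

for _ in 0..num_loops {
    sync_units(); // before the load: the previous stage is fully consumed
    Self::LhsLoader::fill_stage(&mut lhs_loader, config);
    Self::RhsLoader::fill_stage(&mut rhs_loader, config);
    sync_units(); // stages are visible to all units before the stage matmul

    // ... stage matmul on the loaded tiles ...

    Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
    Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
}
sync_units(); // after the loop: re-converge all units before the epilogue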

Resolved review threads (outdated):
crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/base.rs
@nathanielsimard (Member) left a comment


LGTM, we can merge once the conflicts are resolved!

@wingertge (author) replied:
Done 👍

@nathanielsimard merged commit a5624c1 into tracel-ai:main on Nov 29, 2024
10 of 11 checks passed
@wingertge deleted the opt/implicit-gemm-rewrite branch on November 29, 2024