[Optimization] Implicit gemm rewrite #2545

Merged on Nov 29, 2024 (69 commits)

Changes shown below are from a single commit (135d256, "Temp fix for cubecl strategy"); the full commit list follows.

Commits (69)
116a582  Add SPIR-V backend (wingertge, Oct 19, 2024)
5eb4dfe  Update READMEs (wingertge, Oct 19, 2024)
6f8f589  More doc updates and testing (wingertge, Oct 19, 2024)
636b280  Ensure SPIR-V tests actually run (wingertge, Oct 19, 2024)
325f659  Disable SPIR-V tests if WGPU is disabled in general (wingertge, Oct 20, 2024)
00a3036  Disable SPIR-V tests on MacOS (wingertge, Oct 20, 2024)
08f212f  Update `cubecl` (wingertge, Oct 20, 2024)
1cf435a  Merge branch 'main' into feat/wgpu-spirv-backend (wingertge, Oct 20, 2024)
44f116e  Disable SPIR-V CI tests until I can figure out what causes the segfau… (wingertge, Oct 20, 2024)
351cf07  Merge branch 'main' into feat/wgpu-spirv-backend (wingertge, Oct 20, 2024)
bf81855  Reenable SPIR-V tests to see if fixes work (wingertge, Oct 20, 2024)
2c124ae  Temporarily point to fork to check fixes (wingertge, Oct 20, 2024)
9eecef1  Revert to main (wingertge, Oct 20, 2024)
19a9f3f  Disable SPIR-V tests on CI (wingertge, Oct 20, 2024)
82d4a58  More conv2d benches (wingertge, Oct 20, 2024)
27b93a5  Optimize implicit GEMM (wingertge, Nov 11, 2024)
3bfdfd6  Ensure weight isn't vectorized if it's loaded directly (wingertge, Nov 11, 2024)
7535046  Merge branch 'main' into opt/implicit_gemm (wingertge, Nov 11, 2024)
95822d0  Implement tf32 for implicit GEMM (wingertge, Nov 11, 2024)
d472a17  Fixes (wingertge, Nov 14, 2024)
efe36e1  Merge branch 'main' into opt/implicit_gemm (wingertge, Nov 14, 2024)
7543f7b  Make reduce checked since we're still getting segfaults (wingertge, Nov 16, 2024)
b17713e  Merge branch 'main' into opt/implicit_gemm (wingertge, Nov 16, 2024)
07bd12f  Make bicubic interp checked (wingertge, Nov 16, 2024)
ea9e872  Undo direct weight loader because it was backfiring (wingertge, Nov 16, 2024)
449e464  Undo version change (wingertge, Nov 16, 2024)
a95b692  Use git version of cubecl (wingertge, Nov 16, 2024)
6577812  Update cubecl (wingertge, Nov 17, 2024)
ca64d18  Use select to ensure correctness in bilinear interpolate (wingertge, Nov 17, 2024)
83adb2f  Disable reduce_dim_subcube if warp size isn't known, to prevent poten… (wingertge, Nov 18, 2024)
8dd1a39  Initial commit (wingertge, Nov 18, 2024)
5f9ce2f  Broken version (wingertge, Nov 20, 2024)
8971d23  a (wingertge, Nov 20, 2024)
3157e93  Merge branch 'main' into feat/conv2d-benches (wingertge, Nov 20, 2024)
34e460a  Use more descriptive naming (wingertge, Nov 20, 2024)
b1f8af1  Add custom NCHW to NHWC kernel to speed up implicit GEMM (wingertge, Nov 24, 2024)
83601aa  Tune block size (wingertge, Nov 24, 2024)
3b65d28  Cleanup (wingertge, Nov 24, 2024)
8b3b0df  Simplify swizzle (wingertge, Nov 24, 2024)
47e9a83  Migrate (wingertge, Nov 25, 2024)
42c73b1  Merge branch 'opt/conv-custom-transpose' into opt/implicit-gemm-rewrite (wingertge, Nov 25, 2024)
3257cbd  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 25, 2024)
6853af4  Merge branch 'feat/conv2d-benches' into opt/implicit-gemm-rewrite (wingertge, Nov 25, 2024)
55e207b  Revert default conv to 16x16 (wingertge, Nov 25, 2024)
ad307c8  Check k bounds (wingertge, Nov 25, 2024)
d488193  Attempt fixes (wingertge, Nov 25, 2024)
41f2069  Update matmul (nathanielsimard, Nov 25, 2024)
d8d8ccf  Fix bias loading, refactor (wingertge, Nov 26, 2024)
25a8adb  Refactor and documentation (wingertge, Nov 26, 2024)
774e49f  Refactor (wingertge, Nov 26, 2024)
7b50d93  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 26, 2024)
378a77c  Revert accidental changes (wingertge, Nov 26, 2024)
70b4532  Add newline (wingertge, Nov 26, 2024)
55cfbb5  Update cubecl (wingertge, Nov 27, 2024)
100780e  Merge remote-tracking branch 'upstream/matmulupdate' into opt/implici… (wingertge, Nov 27, 2024)
3f17791  Vectorize SMEM for `implicit_gemm` (wingertge, Nov 27, 2024)
ac8ea65  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 27, 2024)
7d8d119  Update cubecl (wingertge, Nov 27, 2024)
135d256  Temp fix for cubecl strategy (wingertge, Nov 27, 2024)
5565f8e  Fix deform_conv_transpose2d (wingertge, Nov 27, 2024)
86dd009  Cleanup and generic typing (wingertge, Nov 28, 2024)
253f85b  Update cubecl (wingertge, Nov 28, 2024)
9ce83c8  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 28, 2024)
266c07b  Cleanup (wingertge, Nov 28, 2024)
cec10e1  Fix clippy (wingertge, Nov 28, 2024)
927e026  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 28, 2024)
fdf73c9  Make conv2d bench more consistent (wingertge, Nov 28, 2024)
41a7f06  Merge branch 'main' into opt/implicit-gemm-rewrite (nathanielsimard, Nov 29, 2024)
4eda62a  Merge branch 'main' into opt/implicit-gemm-rewrite (wingertge, Nov 29, 2024)
Temp fix for cubecl strategy
wingertge committed Nov 27, 2024

commit 135d256c3ffcee752c72d359598a9fe5a859efd5
crates/burn-jit/src/kernel/conv/conv2d/im2col.rs (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@ use burn_tensor::{
 use cubecl::{calculate_cube_count_elemwise, linalg::matmul, prelude::*};
 
 use crate::{
-    kernel::into_contiguous,
+    kernel::{into_contiguous, matmul::cube_strategy},
     ops::{numeric::empty_device, reshape, swap_dims},
     tensor::JitTensor,
     FloatElement, IntElement, JitBackend, JitRuntime,

@@ -298,7 +298,7 @@ fn execute<R: JitRuntime, E: FloatElement>(
     let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
 
     matmul::launch_ref::<R, E>(
-        &Default::default(),
+        &cube_strategy::<R>(&client),
         &client,
         weight.as_handle_ref(),
         columns.as_handle_ref(),
crates/burn-jit/src/kernel/matmul/base.rs (30 changes: 28 additions & 2 deletions)

@@ -1,7 +1,12 @@
 use super::{init_matmul_output, matmul_simple};
 use crate::{tensor::JitTensor, FloatElement, JitRuntime};
 use burn_tensor::Shape;
-use cubecl::prelude::*;
+use cubecl::{
+    ir::{Elem, FloatKind},
+    linalg::matmul::Strategy,
+    prelude::*,
+    Feature,
+};
 
 #[cfg(feature = "autotune")]
 use super::matmul_autotune;

@@ -49,8 +54,9 @@ pub fn matmul<R: JitRuntime, E: FloatElement>(
             let out = init_matmul_output::<R, E>(&lhs, &rhs);
 
             let client = &lhs.client;
+
             cubecl::linalg::matmul::launch_ref::<R, E>(
-                &Default::default(),
+                &cube_strategy::<R>(client),
                 client,
                 lhs.as_handle_ref(),
                 rhs.as_handle_ref(),

@@ -63,6 +69,26 @@ pub fn matmul<R: JitRuntime, E: FloatElement>(
     }
 }
 
+pub(crate) fn cube_strategy<R: JitRuntime>(
+    client: &ComputeClient<R::Server, R::Channel>,
+) -> Strategy {
+    // TODO: Replace with auto option once cubecl has one
+    let cmma_available = client.properties().feature_enabled(Feature::Cmma {
+        a: Elem::Float(FloatKind::F16),
+        b: Elem::Float(FloatKind::F16),
+        c: Elem::Float(FloatKind::F32),
+        m: 16,
+        k: 16,
+        n: 16,
+    });
+    let plane_available = client.properties().feature_enabled(Feature::Plane);
+    match (cmma_available, plane_available) {
+        (true, _) => Strategy::Accelerated,
+        (false, true) => Strategy::PlaneMma,
+        _ => Strategy::Tiling2D(Default::default()),
+    }
+}
+
 pub(crate) fn simple_cube_count(
     lhs_shape: &Shape,
     rhs_shape: &Shape,
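For context, the new `cube_strategy` helper added above probes the compute client's reported features and falls through the available cubecl matmul strategies: CMMA support (f16 inputs accumulated in f32 on 16x16x16 tiles) selects `Strategy::Accelerated`, plane (subgroup/warp) operations select `Strategy::PlaneMma`, and everything else falls back to the portable `Strategy::Tiling2D`. The following is a commented sketch that mirrors the hunk above rather than an authoritative restatement of it; it assumes the same cubecl imports shown in the diff and the surrounding crate's `JitRuntime` trait.

    use crate::JitRuntime;
    use cubecl::{
        ir::{Elem, FloatKind},
        linalg::matmul::Strategy,
        prelude::*,
        Feature,
    };

    pub(crate) fn cube_strategy<R: JitRuntime>(
        client: &ComputeClient<R::Server, R::Channel>,
    ) -> Strategy {
        // Cooperative matrix multiply-add: f16 operands, f32 accumulator, 16x16x16 tiles.
        let cmma_available = client.properties().feature_enabled(Feature::Cmma {
            a: Elem::Float(FloatKind::F16),
            b: Elem::Float(FloatKind::F16),
            c: Elem::Float(FloatKind::F32),
            m: 16,
            k: 16,
            n: 16,
        });
        // Plane (subgroup/warp) operations, required for the plane-level MMA path.
        let plane_available = client.properties().feature_enabled(Feature::Plane);

        match (cmma_available, plane_available) {
            // Hardware-accelerated CMMA takes priority whenever the runtime reports it.
            (true, _) => Strategy::Accelerated,
            // Otherwise use plane-level MMA if subgroup operations are available.
            (false, true) => Strategy::PlaneMma,
            // Portable tiling fallback that works on any backend.
            _ => Strategy::Tiling2D(Default::default()),
        }
    }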
crates/burn-jit/src/kernel/matmul/tune/base.rs (8 changes: 6 additions & 2 deletions)

@@ -5,7 +5,10 @@ use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTu
 
 use crate::{
     element::FloatElement,
-    kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform},
+    kernel::{
+        matmul::{cube_strategy, utils::init_matmul_output},
+        prng::random_like_uniform,
+    },
     ops::numeric::empty_device,
     tensor::JitTensor,
     tune_key::JitAutotuneKey,

@@ -149,8 +152,9 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| {
 matmul_tune_ops!(
     MatmulCube,
     |lhs: JitTensor<R>, rhs: JitTensor<R>, out: JitTensor<R>| {
+        let strategy = cube_strategy::<R>(&lhs.client);
         cubecl::linalg::matmul::launch_ref::<R, E>(
-            &Default::default(),
+            &strategy,
             &lhs.client,
             lhs.as_handle_ref(),
             rhs.as_handle_ref(),
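Taken together, all three call sites touched by this commit follow the same pattern: resolve the strategy from the tensor's compute client once, then pass it to `launch_ref` where `&Default::default()` used to be. Below is a minimal sketch of that shape, reusing the `lhs`/`rhs` names from the hunks above; the trailing arguments to `launch_ref` are elided here, just as they are in the diff.

    // Sketch of the shared call-site pattern introduced by this commit.
    let strategy = cube_strategy::<R>(&lhs.client);
    cubecl::linalg::matmul::launch_ref::<R, E>(
        &strategy, // previously `&Default::default()`
        &lhs.client,
        lhs.as_handle_ref(),
        rhs.as_handle_ref(),
        // remaining arguments unchanged (not shown in the diff above)
    );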