From 788ed9ff0072d28048977779c679793e07e9c5d6 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Thu, 10 Oct 2024 15:30:06 +0200 Subject: [PATCH 1/3] Pad implicit GEMM to allow using it with any input shape --- .../src/kernel/conv/conv2d/implicit_gemm.rs | 204 ++++++++++-------- 1 file changed, 115 insertions(+), 89 deletions(-) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 906b03fa9a..7b51f176b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -7,16 +7,13 @@ use cubecl::{cube, prelude::*, Compiler, CubeCount, CubeDim, Feature}; use half::f16; use crate::{ - kernel::{into_contiguous, slice_assign}, - ops::{ - numeric::{empty_device, zeros_device}, - permute, reshape, - }, + kernel::{into_contiguous, slice}, + ops::{numeric::empty_device, permute, reshape}, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime, }; -/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available. +/// Perform a 2D convolution using the implicit GEMM algorithm. Requries `cmma` to be available. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel @@ -29,9 +26,10 @@ pub fn conv2d_implicit_gemm( bias: Option>, options: ConvOptions<2>, ) -> JitTensor { - let [batch_size, mut in_channels, height, width] = input.shape.dims(); + let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); - let padded_channels = padded_in_channels(in_channels, kernel_h, kernel_w); + let padded_in_channels = padded_in_channels(in_channels, kernel_h, kernel_w); + let padded_out_channels = out_channels.div_ceil(16) * 16; let out_h = calculate_conv_output_size( kernel_h, @@ -48,49 +46,30 @@ pub fn conv2d_implicit_gemm( width, ); + let padded_batch_size = padded_batch_size(batch_size, out_h, out_w); + println!("Padded batches: {padded_batch_size}"); + if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) { panic!( "Requirements for implicit GEMM not met: - CMMA must be available -- `batch_size * out_h * out_w` must be divisible by 16 -- `out_channels` must be divisible by 16 -- `in_channels * kernel_h * kernel_w` must be divisible by 16 - `groups` must be 1 " ); } - let (input, weight) = if padded_channels != in_channels { - let input = permute(input, &[0, 2, 3, 1]); - let weight = permute(weight, &[0, 2, 3, 1]); - - let in_shape = Shape::new([batch_size, height, width, padded_channels]); - let in_slice = &[0..batch_size, 0..height, 0..width, 0..in_channels]; - let new_input = zeros_device(input.client.clone(), input.device.clone(), in_shape); - let new_input = slice_assign(new_input, in_slice, input); - - let weight_shape = Shape::new([out_channels, kernel_h, kernel_w, padded_channels]); - let weight_slice = &[0..out_channels, 0..kernel_h, 0..kernel_w, 0..in_channels]; - let new_weight = zeros_device(weight.client.clone(), weight.device.clone(), weight_shape); - let new_weight = slice_assign(new_weight, weight_slice, weight); - - in_channels = padded_channels; - (new_input, new_weight) - } else { - // channel last is more efficient even with the extra into_contiguous kernel - let input = into_contiguous(permute(input, &[0, 2, 3, 1])); - let weight = into_contiguous(permute(weight, &[0, 2, 3, 1])); - (input, weight) - }; + let input = into_contiguous(permute(input, &[0, 2, 3, 1])); + let weight = into_contiguous(permute(weight, &[0, 2, 3, 1])); - let out_shape = Shape::new([batch_size, out_h, out_w, out_channels]); - let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape); + let out_shape = Shape::new([padded_batch_size, out_h, out_w, padded_out_channels]); + let out = empty_device(input.client.clone(), input.device.clone(), out_shape); // Implicit GEMM matrix size - let gemm_m = (batch_size * out_h * out_w) as u32; - let gemm_n = out_channels as u32; - let gemm_k = (in_channels * kernel_h * kernel_w) as u32; - let slice_size = kernel_h * kernel_w * in_channels; + let gemm_m = (padded_batch_size * out_h * out_w) as u32; + let gemm_n = padded_out_channels as u32; + let gemm_k = (padded_in_channels * kernel_h * kernel_w) as u32; + + let slice_size = kernel_h * kernel_w * padded_in_channels; let (cmma_m, cmma_n, cmma_k) = find_cmma_size::(&input.device, gemm_m, gemm_k, gemm_n).unwrap(); @@ -104,21 +83,22 @@ pub fn conv2d_implicit_gemm( let warp_size = 32; let warps_per_cube = (cube_dim_y * cube_dim_x) / warp_size; - let max_vectorization = u8::MAX; // TODO: Fetch this based on backend + let supported_vecs = R::supported_line_sizes(); let input_elems_per_thread = input_tile_size / warp_size; - let input_vectorization = u8::min( - find_common_vec(in_channels, input_elems_per_thread), - max_vectorization, - ); + let input_vectorization = find_common_vec(in_channels, input_elems_per_thread, supported_vecs); let weight_elems_per_thread = weight_tile_size / warp_size; - let weight_vectorization = u8::min(weight_elems_per_thread as u8, max_vectorization); + let weight_vectorization = + find_common_vec(in_channels, weight_elems_per_thread, supported_vecs); let settings = GemmSettings { cmma_m, cmma_n, cmma_k, + check_m: batch_size != padded_batch_size, + check_n: out_channels != padded_out_channels, + check_k: in_channels != padded_in_channels, warp_size, warps_per_cube, cube_dim_x, @@ -153,6 +133,7 @@ pub fn conv2d_implicit_gemm( ScalarArg::new(gemm_n), ScalarArg::new(gemm_k), ScalarArg::new(slice_size as u32), + ScalarArg::new(padded_in_channels as u32), ScalarArg::new(out_h as u32), ScalarArg::new(out_w as u32), ), @@ -165,7 +146,7 @@ pub fn conv2d_implicit_gemm( ScalarArg::new(options.dilation[1] as u32), ), settings, - KernelSettings { + ConvSettings { kernel_h: kernel_h as u32, kernel_w: kernel_w as u32, padding_h: options.padding[0] as i32, @@ -174,22 +155,26 @@ pub fn conv2d_implicit_gemm( }, ); + let mut out = slice(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); + if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, 1, 1, out_channels])); out = JitBackend::::float_add(out, bias); } + // Reset to NCHW permute(out, &[0, 3, 1, 2]) } -fn find_common_vec(channels: usize, elems_per_thread: u32) -> u8 { +fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 { let channels = channels as u8; let elems_per_thread = elems_per_thread as u8; let smaller = u8::min(channels, elems_per_thread); (1..=smaller) .rev() + .filter(|it| supported_vecs.contains(it)) .find(|vec| channels % *vec == 0 && elems_per_thread % *vec == 0) - .unwrap() + .unwrap_or(1) } #[derive(CubeLaunch)] @@ -209,6 +194,8 @@ struct Dimensions { gemm_k: u32, slice_size: u32, + padded_channels: u32, + out_h: u32, out_w: u32, } @@ -219,6 +206,10 @@ struct GemmSettings { cmma_n: u32, cmma_k: u32, + check_m: bool, + check_n: bool, + check_k: bool, + warp_size: u32, warps_per_cube: u32, @@ -226,7 +217,7 @@ struct GemmSettings { } #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -struct KernelSettings { +struct ConvSettings { kernel_h: u32, kernel_w: u32, padding_h: i32, @@ -259,7 +250,7 @@ fn implicit_gemm_kernel( dims: &Dimensions, args: &ConvArgs, #[comptime] gemm_settings: GemmSettings, - #[comptime] kernel_settings: KernelSettings, + #[comptime] kernel_settings: ConvSettings, ) { let GemmSettings { cmma_m, @@ -346,35 +337,34 @@ fn make_matrices( .. } = gemm_settings; - let matrices = unsafe { - Matrices:: { - a: Matrix::::uninitialized( + Matrices:: { + a: unsafe { + Matrix::::uninitialized( MatrixIdent::A, cmma_m, cmma_n, cmma_k, MatrixLayout::RowMajor, - ), - b: Matrix::::uninitialized( + ) + }, + b: unsafe { + Matrix::::uninitialized( MatrixIdent::B, cmma_m, cmma_n, cmma_k, MatrixLayout::ColMajor, - ), - acc: Matrix::::uninitialized( - MatrixIdent::Accumulator, - cmma_m, - cmma_n, - cmma_k, - MatrixLayout::Undefined, - ), - } - }; - - cmma::fill(&matrices.acc, FAcc::new(0.0)); - - matrices + ) + }, + acc: Matrix::::from_value( + MatrixIdent::Accumulator, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::Undefined, + FAcc::new(0.0), + ), + } } #[cube] @@ -388,7 +378,7 @@ fn execute_gemm( pos: &Positions, args: &ConvArgs, #[comptime] g_settings: GemmSettings, - #[comptime] k_settings: KernelSettings, + #[comptime] k_settings: ConvSettings, ) { let GemmSettings { cmma_n, cmma_k, .. } = g_settings; @@ -404,7 +394,7 @@ fn execute_gemm( input, args, input_tile, dims, pos, k, g_settings, k_settings, ); - load_weight_tile(weight, weight_tile, pos, k, g_settings); + load_weight_tile(weight, weight_tile, dims, pos, k, g_settings, k_settings); // Run CMMA cmma::load(&matrices.a, input_tile.as_slice(), cmma_k); @@ -425,16 +415,18 @@ fn load_input_tile( pos: &Positions, k: u32, #[comptime] gemm_settings: GemmSettings, - #[comptime] kernel_settings: KernelSettings, + #[comptime] kernel_settings: ConvSettings, ) { let GemmSettings { cmma_m, cmma_k, warp_size, + check_m, + check_k, .. } = gemm_settings; - let KernelSettings { + let ConvSettings { kernel_w, padding_h, padding_w, @@ -447,7 +439,7 @@ fn load_input_tile( let height = input.shape(1) as i32; let width = input.shape(2) as i32; - let channels = input.shape(3); + let channels = dims.padded_channels; // Row strides in the implicit GEMM matrix let batch_stride = dims.out_h * dims.out_w; @@ -458,6 +450,18 @@ fn load_input_tile( let slice_start_idx = k % dims.slice_size; let start = pos.intra_warp_unit_idx * elems_per_thread; + let rel_slice_row = start / cmma_k; // Relative row (0 - 15) + let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on + + // Given the row of the matrix that the slice is in, and the index of the thread + // within a slice, want to compute what input element to load... + // first compute coordinates in output space (center of the kernel in MxK matrix A) + let batch = abs_slice_row / batch_stride; + + let m_in_bounds = !check_m || batch < input.shape(0); + let out_y = (abs_slice_row % batch_stride) / y_stride; + let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride; + #[unroll] for m in range_stepped(0, elems_per_thread, vec) { let m = m + start; @@ -466,22 +470,13 @@ fn load_input_tile( // Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice // we are and also which row the slice is in relative to the start of the CMMA matrix - let rel_slice_row = m / cmma_k; // Relative row (0 - 15) - let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on - - // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is - // responsible for + // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is repsonsible for let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size; - // Given the row of the matrix that the slice is in, and the index of the thread - // within a slice, want to compute what input element to load... - // first compute coordinates in output space (center of the kernel in MxK matrix A) - let batch = abs_slice_row / batch_stride; - let out_y = (abs_slice_row % batch_stride) / y_stride; - let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride; - let channel = my_slice_idx % channels; + let k_in_bounds = !check_k || channel < input.shape(3); + let kernel_x = (my_slice_idx / channels) % kernel_w; let kernel_y = my_slice_idx / (channels * kernel_w); @@ -494,7 +489,7 @@ fn load_input_tile( + x as u32 * input.stride(2) + channel; let value = select( - in_bounds, + in_bounds && m_in_bounds && k_in_bounds, FMat::cast_from(input[idx / vec]), FMat::vectorized(0.0, vec), ); @@ -510,21 +505,31 @@ fn load_input_tile( fn load_weight_tile( weight: &Tensor>, tile: &mut SliceMut, + dims: &Dimensions, pos: &Positions, k: u32, #[comptime] gemm_settings: GemmSettings, + #[comptime] kernel_settings: ConvSettings, ) { let GemmSettings { cmma_n, cmma_k, warp_size, + check_n, + check_k, .. } = gemm_settings; + let ConvSettings { kernel_w, .. } = kernel_settings; + let vec = vectorization_of(weight); let cmma_filter_tile_size = cmma_k * cmma_n; let elems_per_thread = cmma_filter_tile_size / warp_size; let start = pos.intra_warp_unit_idx * elems_per_thread; + let abs_slice_col = pos.global_n + (start / cmma_k); // Row of the matrix the slice is on + + let n_in_bounds = !check_n || abs_slice_col < weight.shape(0); + let col_idx = abs_slice_col * weight.stride(0); #[unroll] for n in range_stepped(0, elems_per_thread, vec) { @@ -532,9 +537,18 @@ fn load_weight_tile( // Compute where in the slice we are starting let rel_slice_row = n % cmma_k; // Relative row (0 - 15) let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on - let abs_slice_col = pos.global_n + (n / cmma_k); // Row of the matrix the slice is on - let idx = abs_slice_col * weight.stride(0) + abs_slice_row; + let channel = abs_slice_row % dims.padded_channels; + let k_in_bounds = !check_k || channel < weight.shape(3); + + let idx = if check_k { + let kernel_x = abs_slice_row / dims.padded_channels % kernel_w; + let kernel_y = abs_slice_row / (dims.padded_channels * kernel_w); + col_idx + kernel_y * weight.stride(1) + kernel_x * weight.stride(2) + channel + } else { + col_idx + abs_slice_row + }; let value = FMat::cast_from(weight[idx / vec]); + let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0)); #[unroll] for i in 0..vec { @@ -553,6 +567,8 @@ pub(crate) fn can_do_implicit_gemm( let [batch_size, in_channels, _, _] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let in_channels = padded_in_channels(in_channels, kernel_h, kernel_w); + let batch_size = padded_batch_size(batch_size, out_h, out_w); + let out_channels = out_channels.div_ceil(16) * 16; let gemm_m = batch_size * out_h * out_w; let gemm_n = out_channels; @@ -587,6 +603,16 @@ fn padded_in_channels(in_channels: usize, kernel_h: usize, kernel_w: usize) -> u } } +fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { + let out_size = out_h * out_w; + let target = if out_size % 2 == 0 { + (16usize).div_ceil(out_size) + } else { + 16 + }; + batch_size.div_ceil(target) * target +} + fn find_cmma_size( device: &R::JitDevice, gemm_m: u32, From aada8d1ae2bfe23bf76119c43cc1a84e58724c4c Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Thu, 10 Oct 2024 16:16:49 +0200 Subject: [PATCH 2/3] Improve k padding --- .../src/kernel/conv/conv2d/implicit_gemm.rs | 75 +++++++++++-------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 7b51f176b9..74265355f3 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -28,7 +28,7 @@ pub fn conv2d_implicit_gemm( ) -> JitTensor { let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); - let padded_in_channels = padded_in_channels(in_channels, kernel_h, kernel_w); + let (pad_in_channels, pad_kh, pad_kw) = padded_k(in_channels, kernel_h, kernel_w); let padded_out_channels = out_channels.div_ceil(16) * 16; let out_h = calculate_conv_output_size( @@ -47,7 +47,6 @@ pub fn conv2d_implicit_gemm( ); let padded_batch_size = padded_batch_size(batch_size, out_h, out_w); - println!("Padded batches: {padded_batch_size}"); if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) { panic!( @@ -67,9 +66,9 @@ pub fn conv2d_implicit_gemm( // Implicit GEMM matrix size let gemm_m = (padded_batch_size * out_h * out_w) as u32; let gemm_n = padded_out_channels as u32; - let gemm_k = (padded_in_channels * kernel_h * kernel_w) as u32; + let gemm_k = (pad_in_channels * pad_kh * pad_kw) as u32; - let slice_size = kernel_h * kernel_w * padded_in_channels; + let slice_size = pad_kh * pad_kw * pad_in_channels; let (cmma_m, cmma_n, cmma_k) = find_cmma_size::(&input.device, gemm_m, gemm_k, gemm_n).unwrap(); @@ -98,7 +97,7 @@ pub fn conv2d_implicit_gemm( cmma_k, check_m: batch_size != padded_batch_size, check_n: out_channels != padded_out_channels, - check_k: in_channels != padded_in_channels, + check_k: (kernel_h * kernel_w * in_channels) as u32 != gemm_k, warp_size, warps_per_cube, cube_dim_x, @@ -133,7 +132,8 @@ pub fn conv2d_implicit_gemm( ScalarArg::new(gemm_n), ScalarArg::new(gemm_k), ScalarArg::new(slice_size as u32), - ScalarArg::new(padded_in_channels as u32), + ScalarArg::new(pad_kw as u32), + ScalarArg::new(pad_in_channels as u32), ScalarArg::new(out_h as u32), ScalarArg::new(out_w as u32), ), @@ -194,7 +194,8 @@ struct Dimensions { gemm_k: u32, slice_size: u32, - padded_channels: u32, + pad_kw: u32, + pad_channels: u32, out_h: u32, out_w: u32, @@ -428,6 +429,7 @@ fn load_input_tile( let ConvSettings { kernel_w, + kernel_h, padding_h, padding_w, .. @@ -439,7 +441,7 @@ fn load_input_tile( let height = input.shape(1) as i32; let width = input.shape(2) as i32; - let channels = dims.padded_channels; + let channels = dims.pad_channels; // Row strides in the implicit GEMM matrix let batch_stride = dims.out_h * dims.out_w; @@ -475,10 +477,11 @@ fn load_input_tile( let channel = my_slice_idx % channels; - let k_in_bounds = !check_k || channel < input.shape(3); + let kernel_x = (my_slice_idx / channels) % dims.pad_kw; + let kernel_y = my_slice_idx / (channels * dims.pad_kw); - let kernel_x = (my_slice_idx / channels) % kernel_w; - let kernel_y = my_slice_idx / (channels * kernel_w); + let k_in_bounds = + !check_k || (channel < input.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h); let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - padding_h; let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - padding_w; @@ -520,7 +523,9 @@ fn load_weight_tile( .. } = gemm_settings; - let ConvSettings { kernel_w, .. } = kernel_settings; + let ConvSettings { + kernel_w, kernel_h, .. + } = kernel_settings; let vec = vectorization_of(weight); let cmma_filter_tile_size = cmma_k * cmma_n; @@ -537,15 +542,17 @@ fn load_weight_tile( // Compute where in the slice we are starting let rel_slice_row = n % cmma_k; // Relative row (0 - 15) let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on - let channel = abs_slice_row % dims.padded_channels; - let k_in_bounds = !check_k || channel < weight.shape(3); - let idx = if check_k { - let kernel_x = abs_slice_row / dims.padded_channels % kernel_w; - let kernel_y = abs_slice_row / (dims.padded_channels * kernel_w); - col_idx + kernel_y * weight.stride(1) + kernel_x * weight.stride(2) + channel + let (idx, k_in_bounds) = if check_k { + let channel = abs_slice_row % dims.pad_channels; + let kernel_x = abs_slice_row / dims.pad_channels % dims.pad_kw; + let kernel_y = abs_slice_row / (dims.pad_channels * dims.pad_kw); + let k_in_bounds = !check_k + || (channel < weight.shape(3) && kernel_x < kernel_w && kernel_y < kernel_h); + let idx = col_idx + kernel_y * weight.stride(1) + kernel_x * weight.stride(2) + channel; + (idx, k_in_bounds) } else { - col_idx + abs_slice_row + (col_idx + abs_slice_row, true) }; let value = FMat::cast_from(weight[idx / vec]); let value = select(k_in_bounds && n_in_bounds, value, FMat::new(0.0)); @@ -566,7 +573,7 @@ pub(crate) fn can_do_implicit_gemm( ) -> bool { let [batch_size, in_channels, _, _] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); - let in_channels = padded_in_channels(in_channels, kernel_h, kernel_w); + let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_h, kernel_w); let batch_size = padded_batch_size(batch_size, out_h, out_w); let out_channels = out_channels.div_ceil(16) * 16; @@ -588,19 +595,23 @@ pub(crate) fn can_do_implicit_gemm( } } -fn padded_in_channels(in_channels: usize, kernel_h: usize, kernel_w: usize) -> usize { - let kernel_size = kernel_h * kernel_w; - let target = if kernel_size % 2 == 0 { - (16usize).div_ceil(kernel_size) - } else { - 16 - }; - if in_channels % target != 0 { - let tiles = in_channels.div_ceil(target); - tiles * target - } else { - in_channels +fn padded_k(in_channels: usize, kernel_h: usize, kernel_w: usize) -> (usize, usize, usize) { + let target = 16; + if in_channels * kernel_h * kernel_w % target == 0 { + return (in_channels, kernel_h, kernel_w); + } + let kernel_h = kernel_h.next_power_of_two(); + let target = target.div_ceil(kernel_h); + if in_channels * kernel_w % target == 0 { + return (in_channels, kernel_h, kernel_w); + } + let kernel_w = kernel_w.next_power_of_two(); + let target = target.div_ceil(kernel_w); + if in_channels % target == 0 { + return (in_channels, kernel_h, kernel_w); } + let in_channels = in_channels.div_ceil(target) * target; + (in_channels, kernel_h, kernel_w) } fn padded_batch_size(batch_size: usize, out_h: usize, out_w: usize) -> usize { From ec8a356bae1b2b0cb428a7d5ab079104da0c95ab Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Thu, 10 Oct 2024 16:23:30 +0200 Subject: [PATCH 3/3] Fix typo --- crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 74265355f3..f6057326ef 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -13,7 +13,7 @@ use crate::{ FloatElement, IntElement, JitBackend, JitRuntime, }; -/// Perform a 2D convolution using the implicit GEMM algorithm. Requries `cmma` to be available. +/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available. /// /// * `input` - The input feature map /// * `weight` - The weights (filter) applied to each kernel