Commit bac4405

Add should_run for convs instead of panicking (#2403)
ArthurBrussee authored Oct 23, 2024
1 parent 3b51c26 commit bac4405
Showing 5 changed files with 91 additions and 44 deletions.
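
In short: the conv2d and conv_transpose2d autotune entries gain `should_run` predicates, so the autotuner skips candidate kernels that cannot handle a given problem shape instead of panicking mid-tune. To support this, `batches_per_run` now returns `Option<usize>` rather than panicking, and `can_do_implicit_gemm` is reworked to take plain shape values (recoverable from the autotune key) instead of constructed tensors. The non-tuned entry points keep the old fail-fast behavior via `.expect(...)`.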
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
@@ -56,7 +56,8 @@ pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement, I: IntElement>(
     );
     let im_channels = im_ch_per_group * groups;
 
-    let batches_per_run = batches_per_run(batch_size, input_h, input_w);
+    let batches_per_run = batches_per_run(batch_size, input_h, input_w)
+        .expect("Image too large to run even one batch at once");
     let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;
 
     let weight = reshape(
21 changes: 12 additions & 9 deletions crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
@@ -93,23 +93,25 @@ fn im2col_kernel<F: Float>(
 }
 
 #[cfg(not(test))]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
+pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
     let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::SUBCUBE_DIM_APPROX);
     let max_cube_count = u16::MAX as usize;
     let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
     if max_simultaneous == 0 {
-        panic!("Image too large to run even one batch at once");
+        return None;
     }
-    (0..=max_simultaneous)
-        .rev()
-        .find(|per_run| batch_size % per_run == 0)
-        .unwrap()
+    Some(
+        (0..=max_simultaneous)
+            .rev()
+            .find(|per_run| batch_size % per_run == 0)
+            .expect("Logically not possible"),
+    )
 }
 
 #[cfg(test)]
 #[allow(unused)]
-pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
-    1
+pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
+    Some(1)
 }
 
 fn im2col<R: JitRuntime, E: FloatElement>(

@@ -207,7 +209,8 @@ pub fn conv2d_im2col<R: JitRuntime, E: FloatElement, I: IntElement>(
         return execute_1x1_kernel::<R, E, I>(input, weight, bias, options);
     }
 
-    let batches_per_run = batches_per_run(batch_size, out_h, out_w);
+    let batches_per_run = batches_per_run(batch_size, out_h, out_w)
+        .expect("Image too large to run even one batch at once");
     let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]);
 
     let mut out = if batches_per_run != batch_size {
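
As context for the hunks above, here is a self-contained sketch of the new `batches_per_run` logic, with `cubecl::SUBCUBE_DIM_APPROX` stubbed as 16 (an assumed value; the real constant comes from cubecl). It also shows why the committed code's `expect("Logically not possible")` is safe: the descending search reaches `per_run == 1`, which divides any batch size, before it could ever hit zero.

// Self-contained sketch; SUBCUBE_DIM_APPROX is assumed to be 16 here.
const SUBCUBE_DIM_APPROX: usize = 16;
const MAX_CUBE_COUNT: usize = u16::MAX as usize;

fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
    let cube_count_per_batch = (out_h * out_w).div_ceil(SUBCUBE_DIM_APPROX);
    let max_simultaneous = (MAX_CUBE_COUNT / cube_count_per_batch).min(batch_size);
    if max_simultaneous == 0 {
        // Even a single batch exceeds the cube-count budget; the autotuner
        // rejects this kernel via is_some() instead of panicking.
        return None;
    }
    // Largest divisor of batch_size not exceeding max_simultaneous; the
    // countdown hits 1 (which divides everything) before it could reach 0.
    (0..=max_simultaneous)
        .rev()
        .find(|per_run| batch_size % per_run == 0)
}

fn main() {
    assert_eq!(batches_per_run(8, 32, 32), Some(8)); // 8 batches x 64 cubes fits
    assert_eq!(batches_per_run(6, 512, 512), Some(3)); // only 3 of 6 batches fit per run
    assert_eq!(batches_per_run(1, 4096, 4096), None); // one batch is already too large
}

The same `Option` feeds both call sites: the tuner filters with `.is_some()`, while the direct entry points in col2im.rs and im2col.rs keep the old panic message through `.expect(...)`.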
30 changes: 20 additions & 10 deletions crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs
@@ -52,7 +52,16 @@ pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
 
     let padded_batch_size = padded_batch_size(batch_size, out_h, out_w);
 
-    if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) {
+    if !can_do_implicit_gemm::<R, F>(
+        batch_size,
+        in_channels,
+        out_channels,
+        [kernel_h, kernel_w],
+        options.groups,
+        out_h,
+        out_w,
+        &input.device,
+    ) {
         panic!(
             "Requirements for implicit GEMM not met:
 - CMMA must be available

@@ -645,32 +654,33 @@ fn load_bias_tile<F: Float>(
     }
 }
 
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
-    input: &JitTensor<R, E>,
-    weight: &JitTensor<R, E>,
-    options: &ConvOptions<2>,
+    batch_size: usize,
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: [usize; 2],
+    groups: usize,
     out_h: usize,
     out_w: usize,
+    device: &R::Device,
 ) -> bool {
-    let [batch_size, in_channels, _, _] = input.shape.dims();
-    let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
-    let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_h, kernel_w);
+    let (in_channels, kernel_h, kernel_w) = padded_k(in_channels, kernel_size[0], kernel_size[1]);
     let batch_size = padded_batch_size(batch_size, out_h, out_w);
     let out_channels = out_channels.div_ceil(16) * 16;
 
     let gemm_m = batch_size * out_h * out_w;
     let gemm_n = out_channels;
     let gemm_k = in_channels * kernel_h * kernel_w;
 
-    let size =
-        find_cmma_size::<R, f16, E>(&input.device, gemm_m as u32, gemm_k as u32, gemm_n as u32);
+    let size = find_cmma_size::<R, f16, E>(device, gemm_m as u32, gemm_k as u32, gemm_n as u32);
 
     if let Some((cmma_m, cmma_k, cmma_n)) = size {
         let warps_per_cube = 8;
 
         let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
 
-        <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && options.groups == 1
+        <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size && groups == 1
     } else {
         false
     }
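
The refactor lets this capability check run from the autotune key alone, with no tensors allocated. The core test is a shared-memory bound; as a worked example, under the assumption that `find_cmma_size` picks the common 16x16x16 tile:

// Worked example of the shared-memory bound, assuming a 16x16x16 CMMA tile
// (an assumption; the real tile shape comes from find_cmma_size).
fn main() {
    let (cmma_m, cmma_k, cmma_n) = (16usize, 16usize, 16usize);
    let warps_per_cube = 8usize; // hard-coded in the function above
    let bytes_per_f16 = 2usize; // stands in for size_of::<f16>() in the committed code
    let smem_size = (cmma_m + cmma_n) * cmma_k * warps_per_cube * bytes_per_f16;
    assert_eq!(smem_size, 8192); // well under a typical 48 KiB shared-memory limit
}

Note that `groups == 1` remains a hard requirement: grouped convolutions still disqualify the implicit GEMM kernel.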
60 changes: 38 additions & 22 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
@@ -9,7 +9,10 @@ use cubecl::{
 
 use crate::{
     kernel::{
-        conv::{can_do_implicit_gemm, conv2d_direct, conv2d_im2col, conv2d_implicit_gemm},
+        conv::{
+            batches_per_run, can_do_implicit_gemm, conv2d_direct, conv2d_im2col,
+            conv2d_implicit_gemm,
+        },
         prng::random_uniform,
     },
     tensor::JitTensor,

@@ -73,30 +76,43 @@ pub fn conv2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
 
 fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
     op: &Conv2dOperations<R, F, I>,
-    _key: &JitAutotuneKey,
+    key: &JitAutotuneKey,
     index: usize,
 ) -> bool {
+    let key = match key {
+        JitAutotuneKey::Conv2d(key) => key,
+        _ => unreachable!(),
+    };
+
+    let out_h = calculate_conv_output_size(
+        key.kernel_size[0],
+        key.stride[0],
+        key.padding[0],
+        key.dilation[0],
+        key.height,
+    );
+    let out_w = calculate_conv_output_size(
+        key.kernel_size[1],
+        key.stride[1],
+        key.padding[1],
+        key.dilation[1],
+        key.width,
+    );
+
     match index {
-        2 => {
-            let [_, _, height, width] = op.input.shape.dims();
-            let [_, _, kernel_h, kernel_w] = op.weights.shape.dims();
-            let o = &op.options;
-            let out_h = calculate_conv_output_size(
-                kernel_h,
-                o.stride[0],
-                o.padding[0],
-                o.dilation[0],
-                height,
-            );
-            let out_w = calculate_conv_output_size(
-                kernel_w,
-                o.stride[1],
-                o.padding[1],
-                o.dilation[1],
-                width,
-            );
-            can_do_implicit_gemm(&op.input, &op.weights, &op.options, out_h, out_w)
-        }
+        // im2col
+        1 => batches_per_run(key.batch_size, out_h, out_w).is_some(),
+        // Implicit gemm.
+        2 => can_do_implicit_gemm::<R, F>(
+            key.batch_size,
+            key.in_channels,
+            key.out_channels,
+            key.kernel_size,
+            op.options.groups,
+            out_h,
+            out_w,
+            &op.input.device,
+        ),
        _ => true,
     }
 }
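
The point of this hunk is that `should_run` now derives the output size from the autotune key before any candidate kernel is launched, and the im2col candidate (index 1, per the comment) gets a guard of its own. For reference, a sketch of the standard output-size formula that `calculate_conv_output_size` is assumed to compute, with the argument order inferred from the call sites above:

// Assumed to match burn's calculate_conv_output_size; argument order
// inferred from the call sites in the hunk above.
fn calculate_conv_output_size(
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    size: usize,
) -> usize {
    (size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
}

fn main() {
    // A 3x3 kernel with stride 1, padding 1, dilation 1 preserves spatial size.
    assert_eq!(calculate_conv_output_size(3, 1, 1, 1, 64), 64);
    // Stride 2 halves it (rounding down).
    assert_eq!(calculate_conv_output_size(3, 2, 1, 1, 64), 32);
}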
21 changes: 19 additions & 2 deletions crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs
@@ -6,7 +6,7 @@ use cubecl::{
 
 use crate::{
     kernel::{
-        conv::{conv_transpose2d_col2im, conv_transpose2d_direct},
+        conv::{batches_per_run, conv_transpose2d_col2im, conv_transpose2d_direct},
        prng::random_uniform,
     },
     tensor::JitTensor,

@@ -35,7 +35,7 @@ pub fn conv_transpose2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
     )
 }
 
-#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key)]
+#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key, should_run = should_run)]
 pub fn conv_transpose2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
     key: JitAutotuneKey,
     input: JitTensor<R, E>,

@@ -93,3 +93,20 @@ fn create_key<R: JitRuntime, E: FloatElement>(
         bias.is_some(),
     ))
 }
+
+fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
+    _op: &ConvTranspose2dOperations<R, F, I>,
+    key: &JitAutotuneKey,
+    index: usize,
+) -> bool {
+    let key = match key {
+        JitAutotuneKey::ConvTranspose2d(key) => key,
+        _ => unreachable!(),
+    };
+
+    match index {
+        // im2col
+        1 => batches_per_run(key.batch_size, key.height, key.width).is_some(),
+        _ => true,
+    }
+}
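
Since the `#[tune]` attribute lists `conv_transpose2d_direct` first and `conv_transpose2d_col2im` second, index 1 here is the col2im kernel. It is filtered with the key's raw `height` and `width`, which matches `conv_transpose2d_col2im` itself: that kernel sizes its runs by the input image (`input_h`/`input_w`) rather than by the output.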
