From 585009755c8bc464dbccada1bae34d32e30333c2 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 14 Dec 2023 09:48:21 -0500
Subject: [PATCH 1/2] repeat wgpu

---
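Note for reviewers (below the fold, so not part of the commit message): the
WGSL kernel maps each output element `id` back to an input element with
stride arithmetic. A minimal CPU sketch of that mapping, assuming the
`build_info` buffer is packed as [rank, input strides, output strides,
input shape, output shape] with the repeated dim pushed on the end; the
function name is illustrative only:

    fn input_index(info: &[u32], id: u32) -> u32 {
        let rank = info[0] as usize;
        let repeat_dim = info[4 * rank + 1];
        let mut index_input = 0u32;

        for i in 1..=rank {
            let stride_input = info[i];
            let stride_output = info[i + rank];
            let shape_output = info[i + 3 * rank];

            // The repeated dim has size 1 in the input, so it contributes
            // no offset: all output blocks along it read the same slice.
            if repeat_dim != (i - 1) as u32 {
                // Coordinate of `id` along dimension i - 1, converted back
                // to a linear offset with the input stride.
                let num_block = id / stride_output % shape_output;
                index_input += num_block * stride_input;
            }
        }

        index_input
    }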
 burn-wgpu/src/kernel/index/mod.rs        |   2 +
 burn-wgpu/src/kernel/index/repeat.rs     | 119 +++++++++++++++++++++++
 burn-wgpu/src/ops/bool_ops.rs            |   8 ++
 burn-wgpu/src/ops/float_ops.rs           |   8 ++
 burn-wgpu/src/ops/int_ops.rs             |   8 ++
 burn-wgpu/src/template/index/repeat.wgsl |  38 ++++++++
 6 files changed, 183 insertions(+)
 create mode 100644 burn-wgpu/src/kernel/index/repeat.rs
 create mode 100644 burn-wgpu/src/template/index/repeat.wgsl

diff --git a/burn-wgpu/src/kernel/index/mod.rs b/burn-wgpu/src/kernel/index/mod.rs
index c1d9594f51..dd4aa25b8c 100644
--- a/burn-wgpu/src/kernel/index/mod.rs
+++ b/burn-wgpu/src/kernel/index/mod.rs
@@ -1,9 +1,11 @@
 mod gather;
+mod repeat;
 mod scatter;
 mod select;
 mod slice;
 
 pub use gather::*;
+pub use repeat::*;
 pub use scatter::*;
 pub use select::*;
 pub use slice::*;
diff --git a/burn-wgpu/src/kernel/index/repeat.rs b/burn-wgpu/src/kernel/index/repeat.rs
new file mode 100644
index 0000000000..2b05892329
--- /dev/null
+++ b/burn-wgpu/src/kernel/index/repeat.rs
@@ -0,0 +1,119 @@
+use crate::{
+    compute::StaticKernel,
+    element::WgpuElement,
+    kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
+    kernel_wgsl,
+    tensor::WgpuTensor,
+};
+
+kernel_wgsl!(RepeatRaw, "../../template/index/repeat.wgsl");
+
+pub(crate) fn repeat<E: WgpuElement, const D1: usize>(
+    input: WgpuTensor<E, D1>,
+    dim: usize,
+    times: usize,
+) -> WgpuTensor<E, D1> {
+    let mut shape = input.shape.clone();
+    if shape.dims[dim] != 1 {
+        panic!("Can only repeat a dimension of size 1");
+    }
+
+    // Create output handle
+    shape.dims[dim] = times;
+    let num_elems_output = shape.num_elements();
+    let handle = input
+        .client
+        .empty(num_elems_output * core::mem::size_of::<E>());
+    let output = WgpuTensor::new(
+        input.client.clone(),
+        input.device.clone(),
+        shape.clone(),
+        handle,
+    );
+
+    let mut info = build_info(&[&input, &output]);
+    info.push(dim as u32);
+    let info_handle = input.client.create(bytemuck::cast_slice(&info));
+
+    let kernel = StaticKernel::<
+        KernelSettings<RepeatRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
+    >::new(elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT));
+
+    input.client.execute(
+        Box::new(kernel),
+        &[&input.handle, &output.handle, &info_handle],
+    );
+
+    output
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tests::{ReferenceBackend, TestBackend};
+    use burn_tensor::{Bool, Distribution, Tensor};
+
+    #[test]
+    fn repeat_dim_0_few_times() {
+        let tensor = Tensor::<TestBackend, 3>::random([1, 6, 6], Distribution::Default);
+        let dim = 0;
+        let times = 4;
+        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());
+
+        let actual = repeat(tensor.into_primitive(), dim, times);
+        let expected = tensor_ref.repeat(dim, times);
+
+        expected.into_data().assert_approx_eq(
+            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
+            3,
+        );
+    }
+
+    #[test]
+    fn repeat_dim_1_few_times() {
+        let tensor = Tensor::<TestBackend, 3>::random([6, 1, 6], Distribution::Default);
+        let dim = 1;
+        let times = 4;
+        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());
+
+        let actual = repeat(tensor.into_primitive(), dim, times);
+        let expected = tensor_ref.repeat(dim, times);
+
+        expected.into_data().assert_approx_eq(
+            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
+            3,
+        );
+    }
+
+    #[test]
+    fn repeat_dim_2_few_times() {
+        let tensor = Tensor::<TestBackend, 3>::random([6, 6, 1], Distribution::Default);
+        let dim = 2;
+        let times = 4;
+        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());
+
+        let actual = repeat(tensor.into_primitive(), dim, times);
+        let expected = tensor_ref.repeat(dim, times);
+
+        expected.into_data().assert_approx_eq(
+            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
+            3,
+        );
+    }
+
+    #[test]
+    fn repeat_dim_2_many_times() {
+        let tensor = Tensor::<TestBackend, 3>::random([10, 10, 1], Distribution::Default);
+        let dim = 2;
+        let times = 200;
+        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());
+
+        let actual = repeat(tensor.into_primitive(), dim, times);
+        let expected = tensor_ref.repeat(dim, times);
+
+        expected.into_data().assert_approx_eq(
+            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
+            3,
+        );
+    }
+}
diff --git a/burn-wgpu/src/ops/bool_ops.rs b/burn-wgpu/src/ops/bool_ops.rs
index 82470694f2..411ac93459 100644
--- a/burn-wgpu/src/ops/bool_ops.rs
+++ b/burn-wgpu/src/ops/bool_ops.rs
@@ -123,4 +123,12 @@ where
 
         tensor
     }
+
+    fn bool_repeat<const D: usize>(
+        tensor: BoolTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> BoolTensor<Self, D> {
+        kernel::repeat(tensor, dim, times)
+    }
 }
diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs
index 3f0b008b88..7e91b86ef0 100644
--- a/burn-wgpu/src/ops/float_ops.rs
+++ b/burn-wgpu/src/ops/float_ops.rs
@@ -525,4 +525,12 @@ where
 
         unary_default::(tensor)
     }
+
+    fn repeat<const D: usize>(
+        tensor: FloatTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> FloatTensor<Self, D> {
+        kernel::repeat(tensor, dim, times)
+    }
 }
diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs
index bbef2dd6a6..9aa7480364 100644
--- a/burn-wgpu/src/ops/int_ops.rs
+++ b/burn-wgpu/src/ops/int_ops.rs
@@ -327,4 +327,12 @@ where
 
         tensor
     }
+
+    fn int_repeat<const D: usize>(
+        tensor: IntTensor<Self, D>,
+        dim: usize,
+        times: usize,
+    ) -> IntTensor<Self, D> {
+        kernel::repeat(tensor, dim, times)
+    }
 }
diff --git a/burn-wgpu/src/template/index/repeat.wgsl b/burn-wgpu/src/template/index/repeat.wgsl
new file mode 100644
index 0000000000..54d6d59839
--- /dev/null
+++ b/burn-wgpu/src/template/index/repeat.wgsl
@@ -0,0 +1,38 @@
+@group(0)
+@binding(0)
+var<storage, read> input: array<{{ elem }}>;
+
+@group(0)
+@binding(1)
+var<storage, read_write> output: array<{{ elem }}>;
+
+@group(0)
+@binding(2)
+var<storage, read> info: array<u32>;
+
+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
+
+@compute
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
+) {
+    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
+    let rank: u32 = info[0];
+    let repeat_dim = info[4u * rank + 1u];
+    var index_input: u32 = 0u;
+
+    for (var i: u32 = 1u; i <= rank; i++) {
+        let stride_input = info[i];
+        let stride_output = info[i + rank];
+        let shape_output = info[i + 3u * rank];
+
+        if repeat_dim != i - 1u {
+            let num_block = id / stride_output % shape_output; // coordinate along dim i - 1
+            index_input += num_block * stride_input;
+        }
+    }
+
+    output[id] = input[index_input];
+}

From 99399f5dfc069f3a7b1f8fc73c6098acc6c49f7e Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 14 Dec 2023 10:04:21 -0500
Subject: [PATCH 2/2] test bool and int

---
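Note for reviewers (below the fold): the wgpu kernel tests compare against a
reference backend for float tensors only, so these burn-tensor tests cover
the Bool and Int paths, which reuse the same element-copying kernel. A usage
sketch of the API they pin down, assuming any backend with these ops:

    let tensor = Tensor::<TestBackend, 2, Int>::from_data(Data::from([[0, 1, 2]]));
    let repeated = tensor.repeat(0, 4); // dim 0 grows from 1 to 4 -> shape [4, 3]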
 burn-tensor/src/tests/ops/repeat.rs  | 29 +++++++++++++++++++++++++++-
 burn-wgpu/src/kernel/index/repeat.rs |  2 +-
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/burn-tensor/src/tests/ops/repeat.rs b/burn-tensor/src/tests/ops/repeat.rs
index 8725decb26..b24e6f5295 100644
--- a/burn-tensor/src/tests/ops/repeat.rs
+++ b/burn-tensor/src/tests/ops/repeat.rs
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(repeat)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Bool, Data, Int, Tensor};
 
     #[test]
     fn should_support_repeat_ops() {
@@ -18,4 +18,31 @@ mod tests {
         ]);
         assert_eq!(data_expected, data_actual);
     }
+
+    #[test]
+    fn should_support_bool_repeat_ops() {
+        let data = Data::from([[true, false, false]]);
+        let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data);
+
+        let data_actual = tensor.repeat(0, 4).into_data();
+
+        let data_expected = Data::from([
+            [true, false, false],
+            [true, false, false],
+            [true, false, false],
+            [true, false, false],
+        ]);
+        assert_eq!(data_expected, data_actual);
+    }
+
+    #[test]
+    fn should_support_int_repeat_ops() {
+        let data = Data::from([[0, 1, 2]]);
+        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);
+
+        let data_actual = tensor.repeat(0, 4).into_data();
+
+        let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
+        assert_eq!(data_expected, data_actual);
+    }
 }
diff --git a/burn-wgpu/src/kernel/index/repeat.rs b/burn-wgpu/src/kernel/index/repeat.rs
index 2b05892329..16d2418622 100644
--- a/burn-wgpu/src/kernel/index/repeat.rs
+++ b/burn-wgpu/src/kernel/index/repeat.rs
@@ -51,7 +51,7 @@ pub(crate) fn repeat<E: WgpuElement, const D1: usize>(
 mod tests {
     use super::*;
     use crate::tests::{ReferenceBackend, TestBackend};
-    use burn_tensor::{Bool, Distribution, Tensor};
+    use burn_tensor::{Distribution, Tensor};
 
     #[test]
     fn repeat_dim_0_few_times() {