Feat/wgpu/repeat #1068

Merged · 2 commits · Dec 14, 2023
29 changes: 28 additions & 1 deletion burn-tensor/src/tests/ops/repeat.rs
@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(repeat)]
mod tests {
    use super::*;
-    use burn_tensor::{Data, Tensor};
+    use burn_tensor::{Bool, Data, Int, Tensor};

    #[test]
    fn should_support_repeat_ops() {
@@ -18,4 +18,31 @@ mod tests {
        ]);
        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_bool_repeat_ops() {
        let data = Data::from([[true, false, false]]);
        let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data);

        let data_actual = tensor.repeat(0, 4).into_data();

        let data_expected = Data::from([
            [true, false, false],
            [true, false, false],
            [true, false, false],
            [true, false, false],
        ]);
        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_int_repeat_ops() {
        let data = Data::from([[0, 1, 2]]);
        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data);

        let data_actual = tensor.repeat(0, 4).into_data();

        let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
        assert_eq!(data_expected, data_actual);
    }
}
2 changes: 2 additions & 0 deletions burn-wgpu/src/kernel/index/mod.rs
@@ -1,9 +1,11 @@
mod gather;
mod repeat;
mod scatter;
mod select;
mod slice;

pub use gather::*;
pub use repeat::*;
pub use scatter::*;
pub use select::*;
pub use slice::*;
119 changes: 119 additions & 0 deletions burn-wgpu/src/kernel/index/repeat.rs
@@ -0,0 +1,119 @@
use crate::{
    compute::StaticKernel,
    element::WgpuElement,
    kernel::{build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
    kernel_wgsl,
    tensor::WgpuTensor,
};

kernel_wgsl!(RepeatRaw, "../../template/index/repeat.wgsl");

pub(crate) fn repeat<E: WgpuElement, const D1: usize>(
    input: WgpuTensor<E, D1>,
    dim: usize,
    times: usize,
) -> WgpuTensor<E, D1> {
    let mut shape = input.shape.clone();
    if shape.dims[dim] != 1 {
        panic!("Can only repeat a dimension of size 1");
    }

    // Create output handle
    shape.dims[dim] = times;
    let num_elems_output = shape.num_elements();
    let handle = input
        .client
        .empty(num_elems_output * core::mem::size_of::<E>());
    let output = WgpuTensor::new(
        input.client.clone(),
        input.device.clone(),
        shape.clone(),
        handle,
    );

    let mut info = build_info(&[&input, &output]);
    info.push(dim as u32);
    let info_handle = input.client.create(bytemuck::cast_slice(&info));

    let kernel = StaticKernel::<
        KernelSettings<RepeatRaw, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
    >::new(elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT));

    input.client.execute(
        Box::new(kernel),
        &[&input.handle, &output.handle, &info_handle],
    );

    output
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{Distribution, Tensor};

    #[test]
    fn repeat_dim_0_few_times() {
        let tensor = Tensor::<TestBackend, 3>::random([1, 6, 6], Distribution::Default);
        let dim = 0;
        let times = 4;
        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());

        let actual = repeat(tensor.into_primitive(), dim, times);
        let expected = tensor_ref.repeat(dim, times);

        expected.into_data().assert_approx_eq(
            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
            3,
        );
    }

    #[test]
    fn repeat_dim_1_few_times() {
        let tensor = Tensor::<TestBackend, 3>::random([6, 1, 6], Distribution::Default);
        let dim = 1;
        let times = 4;
        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());

        let actual = repeat(tensor.into_primitive(), dim, times);
        let expected = tensor_ref.repeat(dim, times);

        expected.into_data().assert_approx_eq(
            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
            3,
        );
    }

    #[test]
    fn repeat_dim_2_few_times() {
        let tensor = Tensor::<TestBackend, 3>::random([6, 6, 1], Distribution::Default);
        let dim = 2;
        let times = 4;
        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());

        let actual = repeat(tensor.into_primitive(), dim, times);
        let expected = tensor_ref.repeat(dim, times);

        expected.into_data().assert_approx_eq(
            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
            3,
        );
    }

    #[test]
    fn repeat_dim_2_many_times() {
        let tensor = Tensor::<TestBackend, 3>::random([10, 10, 1], Distribution::Default);
        let dim = 2;
        let times = 200;
        let tensor_ref = Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data());

        let actual = repeat(tensor.into_primitive(), dim, times);
        let expected = tensor_ref.repeat(dim, times);

        expected.into_data().assert_approx_eq(
            &Tensor::<TestBackend, 3>::from_primitive(actual).into_data(),
            3,
        );
    }
}
8 changes: 8 additions & 0 deletions burn-wgpu/src/ops/bool_ops.rs
@@ -123,4 +123,12 @@ where

        tensor
    }

    fn bool_repeat<const D: usize>(
        tensor: BoolTensor<Self, D>,
        dim: usize,
        times: usize,
    ) -> BoolTensor<Self, D> {
        kernel::repeat(tensor, dim, times)
    }
}
8 changes: 8 additions & 0 deletions burn-wgpu/src/ops/float_ops.rs
@@ -525,4 +525,12 @@ where

        unary_default::<Recip, F, D>(tensor)
    }

    fn repeat<const D: usize>(
        tensor: FloatTensor<Self, D>,
        dim: usize,
        times: usize,
    ) -> FloatTensor<Self, D> {
        kernel::repeat(tensor, dim, times)
    }
}
8 changes: 8 additions & 0 deletions burn-wgpu/src/ops/int_ops.rs
@@ -327,4 +327,12 @@ where

        tensor
    }

    fn int_repeat<const D: usize>(
        tensor: IntTensor<Self, D>,
        dim: usize,
        times: usize,
    ) -> IntTensor<Self, D> {
        kernel::repeat(tensor, dim, times)
    }
}
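
With the three backend methods above in place, the kernel is reachable through the public `Tensor::repeat` API. Below is a minimal usage sketch on the wgpu backend; the exact backend alias and its type parameters are assumptions for illustration and depend on how `burn-wgpu` is configured in the consuming crate:

```rust
use burn_tensor::{Data, Tensor};
// Assumed backend alias; adjust graphics API / element types to your setup.
use burn_wgpu::{AutoGraphicsApi, WgpuBackend};

type B = WgpuBackend<AutoGraphicsApi, f32, i32>;

fn main() {
    let tensor = Tensor::<B, 2>::from_data(Data::from([[0.0f32, 1.0, 2.0]]));

    // Repeat the size-1 dimension 0 four times -> shape [4, 3].
    let repeated = tensor.repeat(0, 4);
    assert_eq!(repeated.dims(), [4, 3]);
}
```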
38 changes: 38 additions & 0 deletions burn-wgpu/src/template/index/repeat.wgsl
@@ -0,0 +1,38 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read> info: array<u32>;

const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let rank: u32 = info[0];
    let repeat_dim = info[4u * rank + 1u];
    var index_input: u32 = 0u;

    for (var i: u32 = 1u; i <= rank; i++) {
        let stride_input = info[i];
        let stride_output = info[i + rank];
        let shape_output = info[i + 3u * rank];
        // The repeated dimension has input size 1, so it contributes
        // nothing to the input offset and is skipped.
        if repeat_dim != i - 1u {
            let num_block = id / stride_output % shape_output;
            index_input += num_block * stride_input;
        }
    }

    output[id] = input[index_input];
}