Skip to content

Commit

Permalink
Test/cmma/strided (#483)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Feb 17, 2025
1 parent 7f07d39 commit c305e93
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,130 @@ pub fn test_simple_tf32<R: Runtime>(
assert_eq!(expected, actual);
}

#[cube(launch)]
pub fn kernel_strided(
lhs: &Array<f16>,
rhs: &Array<f16>,
out: &mut Array<f32>,
#[comptime] stride_lhs: u32,
#[comptime] stride_rhs: u32,
) {
let a = cmma::Matrix::<f16>::from_slice(
cmma::MatrixIdent::A,
16,
16,
16,
cmma::MatrixLayout::RowMajor,
&lhs.to_slice(),
stride_lhs,
);
let b = cmma::Matrix::<f16>::from_slice(
cmma::MatrixIdent::B,
16,
16,
16,
cmma::MatrixLayout::ColMajor,
&rhs.to_slice(),
stride_rhs,
);
let c = cmma::Matrix::<f32>::from_value(
cmma::MatrixIdent::Accumulator,
16,
16,
16,
cmma::MatrixLayout::Undefined,
0.0,
);

cmma::execute::<f16, f16, f32, f32>(&a, &b, &c, &c);

cmma::store(
&mut out.to_slice_mut(),
&c,
16,
cmma::MatrixLayout::RowMajor,
);
}

pub fn test_cmma_strided<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
) {
// Lhs (row major) will have strided tiles
let (m, n, k) = (16, 16, 32);
let (t_m, t_n, t_k) = (16, 16, 16);
if !client.properties().feature_enabled(Feature::Cmma {
a: Elem::Float(FloatKind::F16),
b: Elem::Float(FloatKind::F16),
c: Elem::Float(FloatKind::F32),
m: t_m as u8,
k: t_k as u8,
n: t_n as u8,
}) {
// We can't execute the test, skip.
return;
}

// Fills left tile while right tile is zero
let lhs: Vec<f16> = (0..m * k)
.map(|i| {
if (i % k) < t_k {
f16::from_f32((i - (i / k) * t_k) as f32)
} else {
f16::from_f32(0.)
}
})
.collect();
let rhs: Vec<f16> = (0..n * k).map(|i| f16::from_f32((i % 8) as f32)).collect();

let lhs = client.create(f16::as_bytes(&lhs));
let rhs = client.create(f16::as_bytes(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * m * n);

unsafe {
kernel_strided::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
cube_dimensions,
ArrayArg::from_raw_parts::<f16>(&lhs, m * k, 1),
ArrayArg::from_raw_parts::<f16>(&rhs, k * n, 1),
ArrayArg::from_raw_parts::<f32>(&out, m * n, 1),
k as u32,
n as u32,
)
};

let actual = client.read_one(out.binding());
let actual = f32::from_bytes(&actual);

let expected = [
504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504.,
504., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400.,
1400., 1400., 1400., 1400., 2296., 2296., 2296., 2296., 2296., 2296., 2296., 2296., 2296.,
2296., 2296., 2296., 2296., 2296., 2296., 2296., 3192., 3192., 3192., 3192., 3192., 3192.,
3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 4088., 4088., 4088.,
4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088.,
4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984.,
4984., 4984., 4984., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880.,
5880., 5880., 5880., 5880., 5880., 5880., 6776., 6776., 6776., 6776., 6776., 6776., 6776.,
6776., 6776., 6776., 6776., 6776., 6776., 6776., 6776., 6776., 7672., 7672., 7672., 7672.,
7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 8568.,
8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568.,
8568., 8568., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464.,
9464., 9464., 9464., 9464., 9464., 10360., 10360., 10360., 10360., 10360., 10360., 10360.,
10360., 10360., 10360., 10360., 10360., 10360., 10360., 10360., 10360., 11256., 11256.,
11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256.,
11256., 11256., 11256., 12152., 12152., 12152., 12152., 12152., 12152., 12152., 12152.,
12152., 12152., 12152., 12152., 12152., 12152., 12152., 12152., 13048., 13048., 13048.,
13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048.,
13048., 13048., 13944., 13944., 13944., 13944., 13944., 13944., 13944., 13944., 13944.,
13944., 13944., 13944., 13944., 13944., 13944., 13944.,
];

assert_eq!(expected, actual);
// assert!(false);
}

#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_cmma {
Expand Down Expand Up @@ -383,5 +507,15 @@ macro_rules! testgen_cmma {
cube_dimensions,
);
}

#[test]
fn test_cmma_strided() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(16, 16, 1);
cubecl_core::runtime_tests::cmma::test_cmma_strided::<TestRuntime>(
client,
cube_dimensions,
);
}
};
}

0 comments on commit c305e93

Please sign in to comment.