Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 24, 2024
1 parent 30f5f7e commit 4b56feb
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions sycl/test-e2e/Matrix/joint_matrix_fill_store_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@
#include "common.hpp"
#define SG_SZ 16

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

template <typename T1, typename T2, size_t TM, size_t TN, size_t TK>
void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
big_matrix<T2, TK / 2, TN * 2> &B) {
template <typename TC, typename Tab, size_t TM, size_t TN, size_t TK>
void matrix_fill_store(big_matrix<TC, TM, TN> &C, big_matrix<Tab, TM, TK> &A,
big_matrix<Tab, TK / 2, TN * 2> &B) {
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(TM, TK));
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(TK / 2, TN * 2));
buffer<float, 2> bufC((float *)C.get_data(), range<2>(TM, TN));
Expand All @@ -35,14 +32,14 @@ void matrix_fill_store(big_matrix<T1, TM, TN> &C, big_matrix<T2, TM, TK> &A,
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
joint_matrix<sub_group, Tab, use::a, TM, TK, layout::row_major>
sub_a;

// For B, we assume B has been already VNNIed.
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
joint_matrix<sub_group, Tab, use::b, TK, TN,
layout::ext_intel_packed>
sub_b;
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
joint_matrix<sub_group, TC, use::accumulator, TM, TN> sub_c;

// TODO: uncomment these calls to add testing for other types of
// matrices
Expand Down Expand Up @@ -100,6 +97,7 @@ int main() {
// TODO: add all supported size and types combinations
bool res = run_test<8, 16, 16>();
res &= run_test<32, 64, 16>();
res &= run_test<16, 16, 16>();
std::cout << (res ? "passed" : "failed") << std::endl;
return !res;
}

0 comments on commit 4b56feb

Please sign in to comment.