From dd042e788a76d345153455d9e14e53176de5375a Mon Sep 17 00:00:00 2001
From: changqi1
Date: Tue, 28 May 2024 12:33:19 +0800
Subject: [PATCH 1/5] [Kernel] Add dynamic onednn matmul.

---
 ci_build                  |  1 +
 src/utils/matmul_helper.h | 34 +++++++++++++++++-----------------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/ci_build b/ci_build
index 073f9fbb..00efbc21 100755
--- a/ci_build
+++ b/ci_build
@@ -14,6 +14,7 @@
 # limitations under the License.
 # ============================================================================
 
+source ~/.bashrc
 pushd 3rdparty/
 sh prepare_oneccl.sh
 source ./oneccl/build/_install/env/setvars.sh
diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h
index 39839c35..5b70cd24 100644
--- a/src/utils/matmul_helper.h
+++ b/src/utils/matmul_helper.h
@@ -1369,7 +1369,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims output_dims = {M, N};
 
@@ -1427,9 +1427,9 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1476,7 +1476,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims output_dims = {M, N};
@@ -1507,9 +1507,9 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1558,7 +1558,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims output_dims = {M, N};
@@ -1597,9 +1597,9 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1648,7 +1648,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims scale_dims = {M, N};
             memory::dims output_dims = {M, N};
@@ -1703,9 +1703,9 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1755,7 +1755,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims shift_dims = {M, N};
@@ -1807,9 +1807,9 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1860,7 +1860,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims output_dims = {M, N};
 
@@ -1880,7 +1880,7 @@ class MMHelper {
             matmul_hub[key] = value;
         }
 
-        auto input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+        auto input_mem = memory({{M, K}, dt::s8, tag::ab}, *engine, const_cast<Tin *>(A));
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<int8_t *>(B));
         auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
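
Patch 1 builds the oneDNN matmul source descriptor with DNNL_RUNTIME_DIM_VAL so that one cached primitive can serve any batch size M; the concrete M is supplied only at execution time through the memory objects. Because a descriptor with a runtime dimension has no defined size, matmul_pd->src_desc() can no longer be used to allocate or wrap the input, which is why the hunks above switch to explicit {{M, K}, dt::bf16, tag::ab} descriptors. A minimal standalone sketch of the same pattern, assuming oneDNN v3.x on CPU (the file name, f32 data type, and sizes are illustrative, not taken from the patch):

    // dynamic_matmul_sketch.cpp -- build: g++ -std=c++17 dynamic_matmul_sketch.cpp -ldnnl
    #include <oneapi/dnnl/dnnl.hpp>
    #include <vector>

    using namespace dnnl;

    int main() {
        engine eng(engine::kind::cpu, 0);
        stream s(eng);
        const memory::dim K = 64, N = 32;

        // M is a runtime placeholder: the primitive is created once and
        // reused for every M, which is what the cached primitives rely on.
        memory::desc src_md({DNNL_RUNTIME_DIM_VAL, K}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc wei_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc dst_md({DNNL_RUNTIME_DIM_VAL, N}, memory::data_type::f32, memory::format_tag::ab);
        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
        matmul prim(pd);

        std::vector<float> B(K * N, 1.0f);
        memory wei_mem({{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng, B.data());

        // Execute the same primitive with two different M values; the
        // concrete shapes travel with the memory objects, not the primitive.
        for (memory::dim M : {1, 8}) {
            std::vector<float> A(M * K, 1.0f), C(M * N, 0.0f);
            memory src_mem({{M, K}, memory::data_type::f32, memory::format_tag::ab}, eng, A.data());
            memory dst_mem({{M, N}, memory::data_type::f32, memory::format_tag::ab}, eng, C.data());
            prim.execute(s, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem}, {DNNL_ARG_DST, dst_mem}});
            s.wait();
        }
        return 0;
    }

Note that this first patch makes only the source dynamic and leaves output_dims as {M, N}; patch 3 below makes the destination dynamic as well and adjusts the cache key to match.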
From 8f938f6020d882d745f021c2b51f0eedeb3f92f7 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Fri, 31 May 2024 10:53:41 +0800
Subject: [PATCH 2/5] temp

---
 src/utils/matmul_helper.h | 114 ++++++++++++++++++++------------------
 1 file changed, 61 insertions(+), 53 deletions(-)

diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h
index 5b70cd24..ee00b060 100644
--- a/src/utils/matmul_helper.h
+++ b/src/utils/matmul_helper.h
@@ -20,6 +20,7 @@
 #include "dtype.h"
 #include "environment.h"
 #include "float16.h"
+#include "intrinsics_util.h"
 #include "my_types.h"
 #include "normal_float4x2.h"
 #include "oneapi/dnnl/dnnl.hpp"
@@ -645,14 +646,20 @@ class MMHelper {
                 xdnn_sgemm_f32bf16f32_compute_biasadd_relu(
                         transA, M, N, K, alpha, A, lda, (const XDNN_UINT4x2 *)packedB, beta, C, ldc, bias));
 #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16)
-        if (M > AMXThresholdM) {
-            GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu",
-                    onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu(
-                            transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
+        if constexpr (std::is_same_v<Tin, bfloat16_t>) {
+            GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu",
+                    onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu(
+                            transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
         } else {
-            GEMMVERBOSE("xdnn_bgemm_f32bf16f32_compute_biasadd_relu",
-                    xdnn_bgemm_f32bf16f32_compute_biasadd_relu(
-                            transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias));
+            if (M > AMXThresholdM) {
+                GEMMVERBOSE("onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu",
+                        onednn_amx_sgemm_f32bf16f32_compute_biasadd_relu(
+                                transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
+            } else {
+                GEMMVERBOSE("xdnn_bgemm_f32bf16f32_compute_biasadd_relu",
+                        xdnn_bgemm_f32bf16f32_compute_biasadd_relu(
+                                transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias));
+            }
         }
 #else
         printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__);
@@ -1369,7 +1376,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims output_dims = {M, N};
 
@@ -1387,35 +1394,34 @@ class MMHelper {
             }
 
             // Create primitive descriptor and primitive.
-            switch (postAlg)
-            {
-            case matmul_kinds::Basic:
-                matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md);
-                break;
-            case matmul_kinds::Silu:{
-                const float post_alpha = 1.0f;
-                const float post_beta = 0.0f;
-                post_ops matmul_ops;
-                matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta);
-                primitive_attr matmul_attr;
-                matmul_attr.set_post_ops(matmul_ops);
-                matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr);
-                break;
-            }
-            case matmul_kinds::Gelu:{
-                const float post_alpha = 1.0f;
-                const float post_beta = 0.0f;
-                post_ops matmul_ops;
-                matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta);
-                primitive_attr matmul_attr;
-                matmul_attr.set_post_ops(matmul_ops);
-                matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr);
-                break;
-            }
-            default:
-                printf(">>> onednn amx postAlg type %s not supported.", std::to_string(postAlg).c_str());
-                exit(-1);
-                break;
+            switch (postAlg) {
+                case matmul_kinds::Basic:
+                    matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md);
+                    break;
+                case matmul_kinds::Silu: {
+                    const float post_alpha = 1.0f;
+                    const float post_beta = 0.0f;
+                    post_ops matmul_ops;
+                    matmul_ops.append_eltwise(algorithm::eltwise_swish, post_alpha, post_beta);
+                    primitive_attr matmul_attr;
+                    matmul_attr.set_post_ops(matmul_ops);
+                    matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr);
+                    break;
+                }
+                case matmul_kinds::Gelu: {
+                    const float post_alpha = 1.0f;
+                    const float post_beta = 0.0f;
+                    post_ops matmul_ops;
+                    matmul_ops.append_eltwise(algorithm::eltwise_gelu_tanh, post_alpha, post_beta);
+                    primitive_attr matmul_attr;
+                    matmul_attr.set_post_ops(matmul_ops);
+                    matmul_pd = new matmul::primitive_desc(*engine, input_md, weight_md, output_md, matmul_attr);
+                    break;
+                }
+                default:
+                    printf(">>> onednn amx postAlg type %s not supported.", std::to_string(postAlg).c_str());
+                    exit(-1);
+                    break;
             }
             matmul_prim = new matmul(*matmul_pd);
             // Cache primitive_desc and matmul
@@ -1427,9 +1433,11 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            // input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            input_mem = memory(matmul_pd->src_desc(), *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            // input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1476,7 +1484,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims output_dims = {M, N};
@@ -1507,9 +1515,9 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            input_mem = memory(matmul_pd->src_desc(), *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1558,7 +1566,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims output_dims = {M, N};
@@ -1597,9 +1605,9 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            input_mem = memory(matmul_pd->src_desc(), *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1648,7 +1656,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims scale_dims = {M, N};
             memory::dims output_dims = {M, N};
@@ -1703,9 +1711,9 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            input_mem = memory(matmul_pd->src_desc(), *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1755,7 +1763,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
             memory::dims shift_dims = {M, N};
@@ -1807,9 +1815,9 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
+            input_mem = memory(matmul_pd->src_desc(), *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
+            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
@@ -1860,7 +1868,7 @@ class MMHelper {
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
+            memory::dims input_dims = {M, K};
             memory::dims weight_dims = {K, N};
             memory::dims output_dims = {M, N};
 
@@ -1880,7 +1888,7 @@ class MMHelper {
             matmul_hub[key] = value;
         }
 
-        auto input_mem = memory({{M, K}, dt::s8, tag::ab}, *engine, const_cast<Tin *>(A));
+        auto input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<int8_t *>(B));
         auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
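
Patch 2 ("temp") temporarily backs the runtime dimensions out and returns to matmul_pd->src_desc(); its other functional change is in compute_biasadd_relu, where a compile-time branch sends bf16 inputs straight to the oneDNN AMX kernel and keeps the M-threshold heuristic only for float inputs. A self-contained sketch of that if constexpr / std::is_same_v dispatch shape (dispatch_gemm and the bfloat16_t stub are illustrative stand-ins, not the helper's API):

    #include <cstdio>
    #include <type_traits>

    struct bfloat16_t { unsigned short raw; };  // stand-in for the real bf16 type

    // bf16 input: always the oneDNN path; float input: keep the size heuristic.
    template <typename Tin>
    void dispatch_gemm(int M, int AMXThresholdM) {
        if constexpr (std::is_same_v<Tin, bfloat16_t>) {
            std::printf("onednn amx path (bf16 input)\n");
        } else {
            if (M > AMXThresholdM)
                std::printf("onednn amx path (large M)\n");
            else
                std::printf("xdnn path (small M)\n");
        }
    }

    int main() {
        dispatch_gemm<bfloat16_t>(4, 32);  // oneDNN regardless of M
        dispatch_gemm<float>(4, 32);       // xdnn: M below threshold
        dispatch_gemm<float>(64, 32);      // oneDNN: M above threshold
        return 0;
    }

Because the branch is resolved at compile time, the untaken side is never instantiated for a given Tin, so each specialization carries only the code path it actually uses.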
From a42b367c65c0f594ca432b1d102fdcba87a8b5f1 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Mon, 3 Jun 2024 09:19:46 +0800
Subject: [PATCH 3/5] Fix oneDNN MatMul issue.

---
 ci_build                  |   1 -
 src/utils/matmul_helper.h | 119 +++++++++++++++++++++++++-------------
 2 files changed, 79 insertions(+), 41 deletions(-)

diff --git a/ci_build b/ci_build
index 00efbc21..073f9fbb 100755
--- a/ci_build
+++ b/ci_build
@@ -14,7 +14,6 @@
 # limitations under the License.
 # ============================================================================
 
-source ~/.bashrc
 pushd 3rdparty/
 sh prepare_oneccl.sh
 source ./oneccl/build/_install/env/setvars.sh
diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h
index ee00b060..2a925b15 100644
--- a/src/utils/matmul_helper.h
+++ b/src/utils/matmul_helper.h
@@ -1369,16 +1369,16 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;
-        std::string key = create_key(transA, M, N, K, postAlg);
+        std::string key = create_key(transA, 0, N, K, postAlg);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
-            memory::dims output_dims = {M, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
 
             // Create memory descriptors and memory objects for src, weights, bias, and dst.
             auto input_md = memory::desc(input_dims, dt::bf16, tag::ab);
@@ -1425,7 +1425,7 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);
             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, postAlg);
+            std::string key = create_key(transA, 0, N, K, postAlg);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -1433,17 +1433,23 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            // input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            // input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
 
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<bfloat16_t *>(packedB));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output date type not supported.");
+            exit(-1);
+        }
 
         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -1477,17 +1483,17 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;
-        std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
-            memory::dims output_dims = {M, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
 
             // Create memory descriptors and memory objects for src, weights, bias, and dst.
             auto input_md = memory::desc(input_dims, dt::bf16, tag::ab);
@@ -1507,7 +1513,7 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);
             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -1515,16 +1521,24 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
 
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<bfloat16_t *>(packedB));
         auto bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output date type not supported.");
+            exit(-1);
+        }
 
         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -1559,17 +1573,17 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;
-        std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd_Relu);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
-            memory::dims output_dims = {M, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
 
             // Create primitive descriptor.
             auto input_md = memory::desc(input_dims, dt::bf16, tag::ab);
@@ -1597,7 +1611,7 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);
             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd_Relu);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd_Relu);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -1605,16 +1619,24 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
 
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<bfloat16_t *>(packedB));
         auto bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output date type not supported.");
+            exit(-1);
+        }
 
         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -1649,17 +1671,17 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;
-        std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::Resmul);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
-            memory::dims scale_dims = {M, N};
-            memory::dims output_dims = {M, N};
+            memory::dims scale_dims = {DNNL_RUNTIME_DIM_VAL, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
 
             // Create primitive descriptor.
             auto input_md = memory::desc(input_dims, dt::bf16, tag::ab);
@@ -1688,9 +1710,10 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);
             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::Resmul);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
+            printf(">>> onednn");
         }
 
         // Repack and convert input data.
@@ -1711,15 +1734,23 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
 
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<bfloat16_t *>(packedB));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output date type not supported.");
+            exit(-1);
+        }
 
         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -1756,18 +1787,18 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;
-        std::string key = create_key(transA, M, N, K, matmul_kinds::Residential);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::Residential);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
-            memory::dims shift_dims = {M, N};
-            memory::dims output_dims = {M, N};
+            memory::dims shift_dims = {DNNL_RUNTIME_DIM_VAL, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
 
             // Create primitive descriptor.
             auto input_md = memory::desc(input_dims, dt::bf16, tag::ab);
@@ -1802,7 +1833,7 @@ class MMHelper {
             }
 
             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::Residential);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::Residential);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -1815,18 +1846,26 @@ class MMHelper {
 
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<Tin *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<Tin *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }
 
         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<bfloat16_t *>(packedB));
         memory bias_mem;
-        auto shift_mem = memory(shift_md, *engine, (void *)res);
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
         if (bias != nullptr) { bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias)); }
+        auto shift_mem = memory(shift_md, *engine, (void *)res);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output date type not supported.");
+            exit(-1);
+        }
 
         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
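
Patch 3 is the working combination: source and destination descriptors both use DNNL_RUNTIME_DIM_VAL, and create_key folds M to 0, so a single cached primitive per (transA, N, K, kind) combination serves every batch size, with the concrete M supplied only through the execution-time memory objects. A sketch of that caching scheme, assuming oneDNN v3.x and a bf16-capable CPU (create_key_sketch, get_dynamic_matmul, and the hub map are hypothetical names patterned on matmul_hub, not the helper's actual API):

    #include <oneapi/dnnl/dnnl.hpp>
    #include <string>
    #include <unordered_map>

    using namespace dnnl;

    static std::unordered_map<std::string, matmul> hub;

    // M is folded to 0 in the key: all batch sizes share one cache entry.
    std::string create_key_sketch(bool transA, memory::dim M, memory::dim N, memory::dim K, int kind) {
        return std::to_string(transA) + "_" + std::to_string(M) + "_" + std::to_string(N) + "_"
                + std::to_string(K) + "_" + std::to_string(kind);
    }

    matmul &get_dynamic_matmul(engine &eng, memory::dim K, memory::dim N) {
        std::string key = create_key_sketch(false, /*M=*/0, N, K, /*kind=*/0);
        auto it = hub.find(key);
        if (it != hub.end()) return it->second;

        // Runtime M in src/dst; the weights descriptor must stay fully defined.
        memory::desc src_md({DNNL_RUNTIME_DIM_VAL, K}, memory::data_type::bf16, memory::format_tag::ab);
        memory::desc wei_md({K, N}, memory::data_type::bf16, memory::format_tag::ab);
        memory::desc dst_md({DNNL_RUNTIME_DIM_VAL, N}, memory::data_type::f32, memory::format_tag::ab);
        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
        return hub.emplace(key, matmul(pd)).first->second;
    }

The same reasoning explains the new output_mem branches above: once dst_desc() carries a runtime dimension, the destination memory must also be described explicitly as {{M, N}, ...} at execution time, with one branch per supported output data type.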
From 32e518eda241ba524e3410024f618230bc199544 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Wed, 5 Jun 2024 20:09:08 +0800
Subject: [PATCH 4/5] Add Benchmark

---
 tests/ut/gemm_kernel_test.cpp | 179 ++++++++++++++++++++++++++++++++++
 1 file changed, 179 insertions(+)
 create mode 100644 tests/ut/gemm_kernel_test.cpp

diff --git a/tests/ut/gemm_kernel_test.cpp b/tests/ut/gemm_kernel_test.cpp
new file mode 100644
index 00000000..702b5655
--- /dev/null
+++ b/tests/ut/gemm_kernel_test.cpp
@@ -0,0 +1,179 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include "matmul_helper.h"
+
+#include <chrono>
+#include <memory>
+#include <random>
+
+#include "gtest/gtest.h"
+
+// Test function to compare reference and optimized implementations
+template <typename TA, typename TB, typename TC>
+void test_gemm(MMHelper *mm, int M, int N, int K) {
+    std::unique_ptr<TA[]> A = std::make_unique<TA[]>(M * K);
+    std::unique_ptr<TB[]> B = std::make_unique<TB[]>(K * N);
+    std::unique_ptr<TC[]> C = std::make_unique<TC[]>(M * N);
+    std::unique_ptr<float[]> bias = std::make_unique<float[]>(N);
+
+    // Generate random matrices A and B
+    std::random_device dev;
+    std::mt19937 rng(dev());
+    std::uniform_real_distribution<float> distr(0.0f, 1.0f);
+
+    for (int i = 0; i < M * K; i++) {
+        A[i] = static_cast<TA>(distr(rng));
+    }
+    for (int i = 0; i < K * N; i++) {
+        B[i] = static_cast<TB>(distr(rng));
+    }
+    for (int i = 0; i < N; i++) {
+        bias[i] = static_cast<float>(distr(rng));
+    }
+
+    std::chrono::system_clock::time_point start, end;
+    float during_time;
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_bias M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_bias(
+            false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, bias.get());
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_biasadd_relu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_biasadd_relu(
+            false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, bias.get());
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_gelu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_gelu(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_silu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_silu(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_resmul M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_resmul(
+    //         false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_resext M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_resext(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N,
+    //         bias.get(), 1.0, A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_residential M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_residential(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N,
+    //         bias.get(), A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+}
+
+// TEST(MMHelper, gemm_f32f16f32) {
+//     std::vector<int> M(8192);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<float, float16_t, float>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//         std::string name;
+//         std::cout << "Enter your name: ";
+//         std::getline(std::cin, name);
+//     }
+// }
+
+TEST(MMHelper, gemm_f32bf16f32) {
+    std::vector<int> M(8192);
+    std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+    std::vector<int> N = {4096, 5120, 7168, 8192};
+    std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+    for (int j = 0; j < N.size(); ++j) {
+        std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+        for (int t = j; t < K.size(); ++t) {
+            for (int i = 0; i < M.size(); ++i) {
+                test_gemm<float, bfloat16_t, float>(mm.get(), M[i], N[j], K[t]);
+            }
+        }
+        std::string name;
+        std::cout << "Enter your name: ";
+        std::getline(std::cin, name);
+    }
+}
+
+// TEST(MMHelper, gemm_bf16bf16bf16) {
+//     std::vector<int> M(8192);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};

+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<bfloat16_t, bfloat16_t, bfloat16_t>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//         std::string name;
+//         std::cout << "Enter your name: ";
+//         std::getline(std::cin, name);
+//     }
+// }
+
+int main(int argc, char **argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}
\ No newline at end of file

From 88a4485aac81e08ccdfcf4bde492b95a3f405a99 Mon Sep 17 00:00:00 2001
From: changqi1
Date: Wed, 5 Jun 2024 21:41:28 +0800
Subject: [PATCH 5/5] update

---
 tests/ut/gemm_kernel_test.cpp | 43 ++++++++++++++---------------------
 1 file changed, 17 insertions(+), 26 deletions(-)

diff --git a/tests/ut/gemm_kernel_test.cpp b/tests/ut/gemm_kernel_test.cpp
index 702b5655..43aaf014 100644
--- a/tests/ut/gemm_kernel_test.cpp
+++ b/tests/ut/gemm_kernel_test.cpp
@@ -117,7 +117,7 @@ void test_gemm(MMHelper *mm, int M, int N, int K) {
 }
 
 // TEST(MMHelper, gemm_f32f16f32) {
-//     std::vector<int> M(8192);
+//     std::vector<int> M(1024);
 //     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
 //     std::vector<int> N = {4096, 5120, 7168, 8192};
 //     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
@@ -129,33 +129,27 @@ void test_gemm(MMHelper *mm, int M, int N, int K) {
 //                 test_gemm<float, float16_t, float>(mm.get(), M[i], N[j], K[t]);
 //             }
 //         }
-//         std::string name;
-//         std::cout << "Enter your name: ";
-//         std::getline(std::cin, name);
 //     }
 // }
 
-TEST(MMHelper, gemm_f32bf16f32) {
-    std::vector<int> M(8192);
-    std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
-    std::vector<int> N = {4096, 5120, 7168, 8192};
-    std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
-
-    for (int j = 0; j < N.size(); ++j) {
-        std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
-        for (int t = j; t < K.size(); ++t) {
-            for (int i = 0; i < M.size(); ++i) {
-                test_gemm<float, bfloat16_t, float>(mm.get(), M[i], N[j], K[t]);
-            }
-        }
-        std::string name;
-        std::cout << "Enter your name: ";
-        std::getline(std::cin, name);
-    }
-}
+// TEST(MMHelper, gemm_f32bf16f32) {
+//     std::vector<int> M(1024);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<float, bfloat16_t, float>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//     }
+// }
 
 // TEST(MMHelper, gemm_bf16bf16bf16) {
-//     std::vector<int> M(8192);
+//     std::vector<int> M(1024);
 //     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
 //     std::vector<int> N = {4096, 5120, 7168, 8192};
 //     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
@@ -167,9 +161,6 @@ void test_gemm(MMHelper *mm, int M, int N, int K) {
 //                 test_gemm<bfloat16_t, bfloat16_t, bfloat16_t>(mm.get(), M[i], N[j], K[t]);
 //             }
 //         }
-//         std::string name;
-//         std::cout << "Enter your name: ";
-//         std::getline(std::cin, name);
 //     }
 // }
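
One property of the benchmark worth keeping in mind: each variant is timed over a single call, so the first measurement for a given (M, N, K) also pays for oneDNN primitive creation and cache insertion. A small timing helper that separates warm-up from measurement, as a sketch (time_gemm is a hypothetical helper, not part of the test above):

    #include <chrono>
    #include <cstdio>
    #include <functional>

    // Warm-up runs absorb one-time costs (primitive creation, caching,
    // first-touch page faults); the steady-clock average over the remaining
    // iterations then reflects steady-state throughput.
    double time_gemm(const std::function<void()> &fn, int warmup = 2, int iters = 10) {
        for (int i = 0; i < warmup; ++i) fn();
        auto start = std::chrono::steady_clock::now();
        for (int i = 0; i < iters; ++i) fn();
        auto end = std::chrono::steady_clock::now();
        return std::chrono::duration<double>(end - start).count() / iters;
    }

    int main() {
        double sec = time_gemm([] { /* e.g. mm->compute(...) for one shape */ });
        std::printf("%.6f sec/iter\n", sec);
        return 0;
    }

Measured this way, the dynamic-primitive change in patches 1-3 shows up directly: with the cache keyed on M = 0, only the very first shape in a (N, K) sweep pays the primitive-creation cost, rather than every new M.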