diff --git a/src/utils/matmul_helper.h b/src/utils/matmul_helper.h
index a8cc291c..df06849d 100644
--- a/src/utils/matmul_helper.h
+++ b/src/utils/matmul_helper.h
@@ -793,14 +793,20 @@ class MMHelper {
                 onednn_amx_gemm_compute_biasadd_relu(
                         transA, M, N, K, alpha, A, lda, (const float16_t *)packedB, beta, C, ldc, bias));
 #elif defined(AVX512_BF16_WEIGHT_ONLY_BF16)
-        if (M > AMXThresholdM) {
-            GEMMVERBOSE("onednn_amx_gemm_compute_biasadd_relu",
-                    onednn_amx_gemm_compute_biasadd_relu(
-                            transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
+        if constexpr (std::is_same_v<InT, bfloat16_t>) {
+            GEMMVERBOSE("onednn_amx_gemm_compute_biasadd_relu",
+                    onednn_amx_gemm_compute_biasadd_relu(
+                            transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
         } else {
-            GEMMVERBOSE("xdnn_bgemm_f32bf16f32_compute_biasadd_relu",
-                    xdnn_bgemm_f32bf16f32_compute_biasadd_relu(
-                            transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias));
+            if (M > AMXThresholdM) {
+                GEMMVERBOSE("onednn_amx_gemm_compute_biasadd_relu",
+                        onednn_amx_gemm_compute_biasadd_relu(
+                                transA, M, N, K, alpha, A, lda, packedB, beta, C, ldc, bias));
+            } else {
+                GEMMVERBOSE("xdnn_bgemm_f32bf16f32_compute_biasadd_relu",
+                        xdnn_bgemm_f32bf16f32_compute_biasadd_relu(
+                                transA, M, N, K, alpha, A, lda, (const XDNN_BF16 *)packedB, beta, C, ldc, bias));
+            }
         }
 #else
         printf("%s:%d: Need to define WEIGHT_ONLY_BF16 kernel data type.\n", __FILE__, __LINE__);
@@ -1571,8 +1577,7 @@ class MMHelper {
         Resext,
     };

-    template <typename Twei>
-    std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const Twei *packedB) {
+    std::string create_key(bool transA, int M, int N, int K, int matmul_kind, const void *packedB = nullptr) {
         std::stringstream key;
         key << transA << "_" << M << "_" << N << "_" << K << "_" << matmul_kind << "_" << packedB;
         return key.str();
@@ -1722,7 +1727,7 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr);
+        std::string key = create_key(transA, M, N, K, postAlg);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
@@ -1797,7 +1802,7 @@ class MMHelper {
         matmul_prim = new matmul(*matmul_pd);

         // Cache primitive_desc and matmul
-        std::string key = create_key(transA, M, N, K, postAlg, (const float *)nullptr);
+        std::string key = create_key(transA, M, N, K, postAlg);
         std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
         matmul_hub[key] = value;
     }
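The hunks above and below change two things at once: create_key() no longer folds the packed-weight pointer into the cache key, and the oneDNN matmul primitives are built with a runtime M dimension (DNNL_RUNTIME_DIM_VAL), so M is zeroed out of the key. One cached primitive per (transA, N, K, post-op kind) then serves every token count, instead of one primitive per exact shape. A minimal standalone sketch of that pattern, assuming the oneDNN v3.x C++ API (engine, shapes, and variable names here are illustrative, not from the patch):

    #include <vector>
    #include "dnnl.hpp"

    using namespace dnnl;

    int main() {
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        const memory::dim K = 4, N = 8;
        using dt = memory::data_type;
        using tag = memory::format_tag;

        // M is a runtime dimension: the primitive is created once and reused
        // for any row count, mirroring create_key(transA, 0, N, K, ...).
        memory::desc src_md({DNNL_RUNTIME_DIM_VAL, K}, dt::f32, tag::ab);
        memory::desc wei_md({K, N}, dt::f32, tag::ab);
        memory::desc dst_md({DNNL_RUNTIME_DIM_VAL, N}, dt::f32, tag::ab);
        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md);
        matmul prim(pd);

        std::vector<float> B(K * N, 1.0f);
        memory wei_mem(wei_md, eng, B.data());

        for (memory::dim M : {2, 16}) { // two batch sizes, one cached primitive
            std::vector<float> A(M * K, 1.0f), C(M * N, 0.0f);
            // Concrete dims are supplied through the memory objects at execution
            // time, exactly as the new input_mem/output_mem code below does.
            memory src_mem({{M, K}, dt::f32, tag::ab}, eng, A.data());
            memory dst_mem({{M, N}, dt::f32, tag::ab}, eng, C.data());
            prim.execute(strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem}, {DNNL_ARG_DST, dst_mem}});
            strm.wait();
        }
        return 0;
    }

The trade-off is that oneDNN can no longer specialize the primitive for one particular M, so it may leave some shape-specific tuning on the table in exchange for a warm primitive cache at every sequence length.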
@@ -1875,16 +1880,16 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, postAlg, packedB);
+        std::string key = create_key(transA, 0, N, K, postAlg);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B) and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
-            memory::dims output_dims = {M, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};
             dt dt_x16 = std::is_same_v<Twei, bfloat16_t> ? dt::bf16
                     : std::is_same_v<Twei, float16_t>    ? dt::f16
                                                          : dt::undef;
@@ -1937,7 +1942,7 @@ class MMHelper {
         }
         matmul_prim = new matmul(*matmul_pd);
         // Cache primitive_desc and matmul
-        std::string key = create_key(transA, M, N, K, postAlg, packedB);
+        std::string key = create_key(transA, 0, N, K, postAlg);
         std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
         matmul_hub[key] = value;
     }
@@ -1945,17 +1950,27 @@ class MMHelper {
     // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::f32, get_onednn_input_layout(dt::f32)}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<bfloat16_t *>(A));
+            input_mem = memory({{M, K}, dt::bf16, get_onednn_input_layout(dt::bf16)}, *engine, const_cast<bfloat16_t *>(A));
         } else if constexpr (std::is_same_v<Tin, float16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<float16_t *>(A));
+            input_mem = memory({{M, K}, dt::f16, get_onednn_input_layout(dt::f16)}, *engine, const_cast<float16_t *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }

-        auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, get_onednn_output_layout(dt::f32)}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, get_onednn_output_layout(dt::bf16)}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, float16_t>) {
+            output_mem = memory({{M, N}, dt::f16, get_onednn_output_layout(dt::f16)}, *engine, C);
+        } else {
+            printf(">>> onednn amx output data type not supported.");
+            exit(-1);
+        }

         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -1993,23 +2008,23 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
-            memory::dims output_dims = {M, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};

             dt dt_x16 = std::is_same_v<Twei, bfloat16_t> ? dt::bf16
                     : std::is_same_v<Twei, float16_t>    ? dt::f16
                                                          : dt::undef;

             // Create memory descriptors and memory objects for src, weights, bias, and dst.
-            auto input_md = memory::desc(input_dims, dt_x16, tag::ab);
+            auto input_md = memory::desc(input_dims, dt_x16, get_onednn_input_layout(dt_x16));
             auto weight_md = memory::desc(weight_dims, dt_x16, get_onednn_weight_layout(dt_x16));
             auto bias_md = memory::desc(bias_dims, dt::f32, tag::ab);
             memory::desc output_md;
@@ -2028,7 +2043,7 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);

             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::BiasAdd, packedB);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::BiasAdd);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -2036,18 +2051,28 @@ class MMHelper {
         // Repack and convert input data.
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<bfloat16_t *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<bfloat16_t *>(A));
         } else if constexpr (std::is_same_v<Tin, float16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<float16_t *>(A));
+            input_mem = memory({{M, K}, dt::f16, tag::ab}, *engine, const_cast<float16_t *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }

         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
         auto bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, float16_t>) {
+            output_mem = memory({{M, N}, dt::f16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output data type not supported.");
+            exit(-1);
+        }

         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
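In the float-input branch just above, input_mem wraps no user pointer: a bf16 buffer is allocated and the f32 activations are presumably converted into it before execution (the conversion itself sits outside these hunks). For reference, a scalar round-to-nearest-even f32-to-bf16 conversion of the kind that buffer needs; the real code would use a vectorized kernel or a oneDNN reorder, and NaN handling is omitted here:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // bfloat16 keeps the top 16 bits of an IEEE-754 float; round to nearest-even.
    static inline uint16_t f32_to_bf16(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));              // safe type-pun
        uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u); // nearest-even tie-break
        return static_cast<uint16_t>((bits + rounding) >> 16);
    }

    // Fill a bf16 buffer (e.g. the unwrapped input_mem) from f32 activations.
    static void convert_f32_to_bf16(const float *src, uint16_t *dst, std::size_t n) {
        for (std::size_t i = 0; i < n; ++i)
            dst[i] = f32_to_bf16(src[i]);
    }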
@@ -2086,17 +2111,17 @@ class MMHelper {
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<bfloat16_t *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<bfloat16_t *>(A));
         } else if constexpr (std::is_same_v<Tin, float16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<float16_t *>(A));
+            input_mem = memory({{M, K}, dt::f16, tag::ab}, *engine, const_cast<float16_t *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }

         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
         auto bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, float16_t>) {
+            output_mem = memory({{M, N}, dt::f16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output data type not supported.");
+            exit(-1);
+        }

         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -2187,17 +2222,17 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::Resmul);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
-            memory::dims scale_dims = {M, N};
-            memory::dims output_dims = {M, N};
+            memory::dims scale_dims = {DNNL_RUNTIME_DIM_VAL, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};

             dt dt_x16 = std::is_same_v<Twei, bfloat16_t> ? dt::bf16
                     : std::is_same_v<Twei, float16_t>    ? dt::f16
                                                          : dt::undef;
@@ -2232,9 +2267,9 @@ class MMHelper {
             matmul_prim = new matmul(*matmul_pd);

             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::Resmul, packedB);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::Resmul);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }

         // Repack and convert input data.
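The Resmul kind multiplies the matmul result elementwise with a residual tensor; in oneDNN that is naturally expressed as a binary_mul post-op whose second input (the scale_dims tensor above) is bound at execution time, which is presumably what the elided primitive-creation code does. A standalone sketch of that wiring against the oneDNN v3.x API, with fixed illustrative shapes (the patch itself uses DNNL_RUNTIME_DIM_VAL for the first dim of scale_md, src_md, and dst_md):

    #include <vector>
    #include "dnnl.hpp"
    using namespace dnnl;

    int main() {
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);
        const memory::dim M = 2, K = 4, N = 8;
        using dt = memory::data_type;
        using tag = memory::format_tag;

        memory::desc src_md({M, K}, dt::f32, tag::ab);
        memory::desc wei_md({K, N}, dt::f32, tag::ab);
        memory::desc dst_md({M, N}, dt::f32, tag::ab);
        memory::desc scale_md({M, N}, dt::f32, tag::ab);

        primitive_attr attr;
        post_ops ops;
        ops.append_binary(algorithm::binary_mul, scale_md); // dst = (src x wei) .* scale
        attr.set_post_ops(ops);

        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
        matmul prim(pd);

        std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N), S(M * N, 0.5f);
        memory src(src_md, eng, A.data()), wei(wei_md, eng, B.data());
        memory dst(dst_md, eng, C.data()), scale(scale_md, eng, S.data());

        // The post-op tensor gets a dedicated argument id at execution.
        prim.execute(strm,
                {{DNNL_ARG_SRC, src}, {DNNL_ARG_WEIGHTS, wei}, {DNNL_ARG_DST, dst},
                 {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, scale}});
        strm.wait(); // each C element: (1*1*4) * 0.5 = 2.0
        return 0;
    }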
@@ -2256,17 +2292,27 @@ class MMHelper {
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<bfloat16_t *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<bfloat16_t *>(A));
         } else if constexpr (std::is_same_v<Tin, float16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<float16_t *>(A));
+            input_mem = memory({{M, K}, dt::f16, tag::ab}, *engine, const_cast<float16_t *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }

-        auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
+        auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, float16_t>) {
+            output_mem = memory({{M, N}, dt::f16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output data type not supported.");
+            exit(-1);
+        }

         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -2307,18 +2353,18 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB);
+        std::string key = create_key(transA, 0, N, K, matmul_kinds::Residential);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
             matmul_prim = std::get<1>(it->second);
         } else {
             // Source (A), weights (B), and destination (C) matrix dimensions.
-            memory::dims input_dims = {M, K};
+            memory::dims input_dims = {DNNL_RUNTIME_DIM_VAL, K};
             memory::dims weight_dims = {K, N};
             memory::dims bias_dims = {1, N};
-            memory::dims shift_dims = {M, N};
-            memory::dims output_dims = {M, N};
+            memory::dims shift_dims = {DNNL_RUNTIME_DIM_VAL, N};
+            memory::dims output_dims = {DNNL_RUNTIME_DIM_VAL, N};

             dt dt_x16 = std::is_same_v<Twei, bfloat16_t> ? dt::bf16
                     : std::is_same_v<Twei, float16_t>    ? dt::f16
                                                          : dt::undef;
@@ -2359,7 +2405,7 @@ class MMHelper {
             }

             // Cache primitive_desc and matmul
-            std::string key = create_key(transA, M, N, K, matmul_kinds::Residential, packedB);
+            std::string key = create_key(transA, 0, N, K, matmul_kinds::Residential);
             std::tuple<matmul::primitive_desc *, matmul *> value(matmul_pd, matmul_prim);
             matmul_hub[key] = value;
         }
@@ -2373,20 +2419,30 @@ class MMHelper {
         memory input_mem;
         if constexpr (std::is_same_v<Tin, float>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine);
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine);
         } else if constexpr (std::is_same_v<Tin, bfloat16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<bfloat16_t *>(A));
+            input_mem = memory({{M, K}, dt::bf16, tag::ab}, *engine, const_cast<bfloat16_t *>(A));
         } else if constexpr (std::is_same_v<Tin, float16_t>) {
-            input_mem = memory(matmul_pd->src_desc(), *engine, const_cast<float16_t *>(A));
+            input_mem = memory({{M, K}, dt::f16, tag::ab}, *engine, const_cast<float16_t *>(A));
         } else {
             printf(">>> onednn amx input date type not supported.");
         }

         auto weight_mem = memory(matmul_pd->weights_desc(), *engine, const_cast<Twei *>(packedB));
         memory bias_mem;
-        auto shift_mem = memory(shift_md, *engine, (void *)res);
-        auto output_mem = memory(matmul_pd->dst_desc(), *engine, C);
         if (bias != nullptr) {
             bias_mem = memory(matmul_pd->bias_desc(), *engine, const_cast<float *>(bias));
         }
+        auto shift_mem = memory(shift_md, *engine, (void *)res);
+        memory output_mem;
+        if constexpr (std::is_same_v<Tout, float>) {
+            output_mem = memory({{M, N}, dt::f32, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, bfloat16_t>) {
+            output_mem = memory({{M, N}, dt::bf16, tag::ab}, *engine, C);
+        } else if constexpr (std::is_same_v<Tout, float16_t>) {
+            output_mem = memory({{M, N}, dt::f16, tag::ab}, *engine, C);
+        } else {
+            printf(">>> onednn amx output data type not supported.");
+            exit(-1);
+        }

         // Create the primitive args.
         std::unordered_map<int, memory> matmul_args;
@@ -2425,7 +2481,7 @@ class MMHelper {
         matmul::primitive_desc *matmul_pd;
         matmul *matmul_prim;

-        std::string key = create_key(transA, M, N, K, matmul_kinds::Basic, B);
+        std::string key = create_key(transA, M, N, K, matmul_kinds::Basic);
         auto it = matmul_hub.find(key);
         if (it != matmul_hub.end()) {
             matmul_pd = std::get<0>(it->second);
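The new unit test below times every fused GEMM entry point with the same stopwatch boilerplate. A hypothetical helper, not part of the patch, that those repeated blocks could collapse into:

    #include <chrono>
    #include <cstdio>
    #include <utility>

    // Print the shape, run the kernel once, report wall-clock seconds.
    template <typename Fn>
    float time_gemm(const char *name, int M, int N, int K, Fn &&fn) {
        printf("[ RUNTIME ] %s M: %d, N: %d, K: %d\t", name, M, N, K);
        auto start = std::chrono::high_resolution_clock::now();
        std::forward<Fn>(fn)();
        auto end = std::chrono::high_resolution_clock::now();
        float sec = std::chrono::duration<float>(end - start).count();
        printf("%.6f sec\n", sec);
        return sec;
    }

    // Usage: time_gemm("MMHelper::compute", M, N, K,
    //         [&] { mm->compute(false, M, N, K, 1.0f, A.get(), K, B.get(),
    //                 nullptr, nullptr, nullptr, 0.0f, C.get(), N); });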
diff --git a/tests/ut/gemm_kernel_test.cpp b/tests/ut/gemm_kernel_test.cpp
new file mode 100644
index 00000000..43aaf014
--- /dev/null
+++ b/tests/ut/gemm_kernel_test.cpp
@@ -0,0 +1,170 @@
+// Copyright (c) 2023 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ============================================================================
+#include "matmul_helper.h"
+
+#include <chrono>
+#include <cstdio>
+#include <cstring>
+#include <memory>
+#include <random>
+
+#include "gtest/gtest.h"
+
+// Time each fused GEMM entry point of MMHelper on random inputs.
+template <typename TA, typename TB, typename TC>
+void test_gemm(MMHelper *mm, int M, int N, int K) {
+    std::unique_ptr<TA[]> A = std::make_unique<TA[]>(M * K);
+    std::unique_ptr<TB[]> B = std::make_unique<TB[]>(K * N);
+    std::unique_ptr<TC[]> C = std::make_unique<TC[]>(M * N);
+    std::unique_ptr<float[]> bias = std::make_unique<float[]>(N);
+
+    // Generate random matrices A and B
+    std::random_device dev;
+    std::mt19937 rng(dev());
+    std::uniform_real_distribution<float> distr(0.0f, 1.0f);
+
+    for (int i = 0; i < M * K; i++) {
+        A[i] = static_cast<TA>(distr(rng));
+    }
+    for (int i = 0; i < K * N; i++) {
+        B[i] = static_cast<TB>(distr(rng));
+    }
+    for (int i = 0; i < N; i++) {
+        bias[i] = static_cast<float>(distr(rng));
+    }
+
+    std::chrono::high_resolution_clock::time_point start, end;
+    float during_time;
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_bias M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_bias(
+            false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, bias.get());
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_biasadd_relu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_biasadd_relu(
+            false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, bias.get());
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_gelu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_gelu(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    memset(C.get(), 0, M * N * sizeof(TC));
+    printf("[ RUNTIME ] MMHelper::compute_silu M: %d, N: %d, K: %d\t", M, N, K);
+    start = std::chrono::high_resolution_clock::now();
+    mm->compute_silu(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N);
+    end = std::chrono::high_resolution_clock::now();
+    during_time = std::chrono::duration<float>(end - start).count();
+    printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_resmul M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_resmul(
+    //         false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N, A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_resext M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_resext(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N,
+    //         bias.get(), 1.0, A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+
+    // memset(C.get(), 0, M * N * sizeof(TC));
+    // printf("[ RUNTIME ] MMHelper::compute_residential M: %d, N: %d, K: %d\t", M, N, K);
+    // start = std::chrono::high_resolution_clock::now();
+    // mm->compute_residential(false, M, N, K, 1.0f, A.get(), K, B.get(), nullptr, nullptr, nullptr, 0.0f, C.get(), N,
+    //         bias.get(), A.get(), K);
+    // end = std::chrono::high_resolution_clock::now();
+    // during_time = std::chrono::duration<float>(end - start).count();
+    // printf("%.6f sec\n", during_time);
+}
+
+// TEST(MMHelper, gemm_f32f16f32) {
+//     std::vector<int> M(1024);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<float, float16_t, float>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//     }
+// }
+
+// TEST(MMHelper, gemm_f32bf16f32) {
+//     std::vector<int> M(1024);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<float, bfloat16_t, float>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//     }
+// }
+
+// TEST(MMHelper, gemm_bf16bf16bf16) {
+//     std::vector<int> M(1024);
+//     std::generate(M.begin(), M.end(), [n = 1]() mutable { return n += 1; });
+//     std::vector<int> N = {4096, 5120, 7168, 8192};
+//     std::vector<int> K = {4096, 5120, 7168, 8192, 11008, 13696, 13824, 28672};
+
+//     for (int j = 0; j < N.size(); ++j) {
+//         std::unique_ptr<MMHelper> mm = std::make_unique<MMHelper>(xft::DeviceKind::iCPU, 0);
+//         for (int t = j; t < K.size(); ++t) {
+//             for (int i = 0; i < M.size(); ++i) {
+//                 test_gemm<bfloat16_t, bfloat16_t, bfloat16_t>(mm.get(), M[i], N[j], K[t]);
+//             }
+//         }
+//     }
+// }
+
+int main(int argc, char **argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}
\ No newline at end of file
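Each timed call is one full M x N x K GEMM, so the printed seconds convert directly to throughput. A possible extension of the harness (not in the patch), appended after each stopwatch block; note that the first invocation for a given shape also pays one-time weight-packing and primitive-creation cost, so a warm-up call before timing would make the numbers more representative:

    // 2*M*N*K floating-point operations per GEMM; during_time is in seconds.
    double gflops = 2.0 * M * N * K / (during_time * 1e9);
    printf("%8.2f GFLOPS\n", gflops);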