From a453a440c8dd5d86161de24dc80db78a93f862a1 Mon Sep 17 00:00:00 2001 From: Aayush Gupta Date: Fri, 19 Jul 2024 19:07:27 +0000 Subject: [PATCH] Expanded Host BLAS support --- CMakeLists.txt | 33 ++- cmake/FindBLIS.cmake | 81 ++++++++ cmake/FindOpenBLAS.cmake | 47 +++++ include/matx.h | 6 +- include/matx/executors/support.h | 2 +- include/matx/transforms/matmul/matmul_cblas.h | 188 +++++++++++++----- 6 files changed, 300 insertions(+), 57 deletions(-) create mode 100644 cmake/FindBLIS.cmake create mode 100644 cmake/FindOpenBLAS.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index d618b123..c7a403da 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,8 @@ option(MATX_EN_CUTENSOR OFF) option(MATX_EN_FILEIO OFF) option(MATX_EN_X86_FFTW OFF "Enable x86 FFTW support") option(MATX_EN_NVPL OFF, "Enable NVIDIA Performance Libraries for optimized ARM CPU support") +option(MATX_EN_BLIS OFF "Enable BLIS support") +option(MATX_EN_OPENBLAS OFF "Enable OpenBLAS support") option(MATX_DISABLE_CUB_CACHE "Disable caching for CUB allocations" ON) option(MATX_EN_COVERAGE OFF "Enable code coverage reporting") @@ -68,6 +70,7 @@ set(MATX_EN_PYBIND11 OFF CACHE BOOL "Enable pybind11 support") set(cutensor_DIR "" CACHE PATH "Directory where cuTENSOR is installed.") set(cutensornet_DIR "" CACHE PATH "Directory where cuTensorNet is installed.") set(eigen_DIR "" CACHE PATH "Directory where Eigen is installed") +set(blas_DIR "" CACHE PATH "Directory where a BLAS library (NVPL/OpenBLAS/BLIS) is installed (install prefix)") # Enable compile_commands.json set(CMAKE_EXPORT_COMPILE_COMMANDS ON) @@ -187,21 +190,34 @@ else() endif() # Host support -if (MATX_EN_NVPL OR MATX_EN_X86_FFTW) +if (MATX_EN_NVPL OR MATX_EN_X86_FFTW OR MATX_EN_BLIS OR MATX_EN_OPENBLAS) message(STATUS "Enabling OpenMP support") find_package(OpenMP REQUIRED) target_link_libraries(matx INTERFACE OpenMP::OpenMP_CXX) target_compile_options(matx INTERFACE ${OpenMP_CXX_FLAGS}) target_compile_definitions(matx INTERFACE MATX_EN_OMP=1) + + set(BLAS_FLAGS MATX_EN_NVPL MATX_EN_BLIS MATX_EN_OPENBLAS) + set(ENABLED_BLAS_COUNT 0) + foreach(BLAS_FLAG IN LISTS BLAS_FLAGS) + if(${BLAS_FLAG}) + math(EXPR ENABLED_BLAS_COUNT "${ENABLED_BLAS_COUNT} + 1") + endif() + endforeach() + if(ENABLED_BLAS_COUNT GREATER 1) + message(WARNING "Multiple Host BLAS libraries (${ENABLED_BLAS_COUNT}) are enabled. Only 1 will be used.") + endif() + if (MATX_EN_NVPL) message(STATUS "Enabling NVPL library support for ARM CPUs with ${INT_TYPE} interface") - find_package(nvpl REQUIRED COMPONENTS fft blas) + find_package(nvpl REQUIRED COMPONENTS fft blas HINTS ${blas_DIR}) if (NOT MATX_BUILD_32_BIT) target_compile_definitions(matx INTERFACE NVPL_ILP64) endif() target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp) target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1) else() + # FFTW if (MATX_EN_X86_FFTW) message(STATUS "Enabling x86 FFTW") find_library(FFTW_LIB fftw3 REQUIRED) @@ -211,6 +227,19 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW) target_link_libraries(matx INTERFACE ${FFTW_LIB} ${FFTWF_LIB} ${FFTW_OMP_LIB} ${FFTWF_OMP_LIB}) target_compile_definitions(matx INTERFACE MATX_EN_X86_FFTW=1) endif() + + # BLAS + if (MATX_EN_BLIS) + message(STATUS "Enabling BLIS") + include(cmake/FindBLIS.cmake) + target_link_libraries(matx INTERFACE BLIS::BLIS) + target_compile_definitions(matx INTERFACE MATX_EN_BLIS=1) + elseif(MATX_EN_OPENBLAS) + message(STATUS "Enabling OpenBLAS") + include(cmake/FindOpenBLAS.cmake) + target_link_libraries(matx INTERFACE OpenBLAS::OpenBLAS) + target_compile_definitions(matx INTERFACE MATX_EN_OPENBLAS=1) + endif() endif() endif() diff --git a/cmake/FindBLIS.cmake b/cmake/FindBLIS.cmake new file mode 100644 index 00000000..f436131f --- /dev/null +++ b/cmake/FindBLIS.cmake @@ -0,0 +1,81 @@ +# //////////////////////////////////////////////////////////////////////////////// +# // BSD 3-Clause License +# // +# // Copyright (c) 2021, NVIDIA Corporation +# // All rights reserved. +# // +# // Redistribution and use in source and binary forms, with or without +# // modification, are permitted provided that the following conditions are met: +# // +# // 1. Redistributions of source code must retain the above copyright notice, this +# // list of conditions and the following disclaimer. +# // +# // 2. Redistributions in binary form must reproduce the above copyright notice, +# // this list of conditions and the following disclaimer in the documentation +# // and/or other materials provided with the distribution. +# // +# // 3. Neither the name of the copyright holder nor the names of its +# // contributors may be used to endorse or promote products derived from +# // this software without specific prior written permission. +# // +# // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ///////////////////////////////////////////////////////////////////////////////// + +# Try using the .pc file first in case it was installed from source +find_package(PkgConfig) + +if(PkgConfig_FOUND) + set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:${blas_DIR}/share/pkgconfig") + pkg_check_modules(BLIS QUIET blis) + if(BLIS_FOUND) + set(BLIS_LIBRARIES ${pkgcfg_lib_BLIS_blis}) + endif() +endif() + +# If not found, search for the BLIS library directly +if(NOT BLIS_FOUND) + find_library(BLIS_LIBRARIES NAMES blis64 blis HINTS ${blas_DIR}/lib) + + if(BLIS_LIBRARIES) + if(BLIS_LIBRARIES MATCHES ".*blis64.*") + # If the 64-bit index version is installed using a package manager like apt, + # the header files are blis64.h and cblas64.h. + set(BLIS_INCLUDE_NAME blis64.h) + set(BLIS_64_HEADER TRUE) + else() + set(BLIS_INCLUDE_NAME blis.h) + endif() + + find_path(BLIS_INCLUDE_DIRS + NAMES ${BLIS_INCLUDE_NAME} + HINTS ${blas_DIR}/include + REQUIRED + ) + + set(BLIS_FOUND TRUE) + endif() +endif() + +if(NOT BLIS_FOUND) + message(FATAL_ERROR "BLIS not found") +endif() + +if(NOT TARGET BLIS::BLIS) + add_library(BLIS::BLIS INTERFACE IMPORTED) + set_target_properties(BLIS::BLIS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${BLIS_INCLUDE_DIRS}" + INTERFACE_LINK_LIBRARIES "${BLIS_LIBRARIES}" + ) + if(BLIS_64_HEADER) + target_compile_definitions(matx INTERFACE MATX_BLIS_64_HEADER=1) + endif() +endif() \ No newline at end of file diff --git a/cmake/FindOpenBLAS.cmake b/cmake/FindOpenBLAS.cmake new file mode 100644 index 00000000..22fbae3d --- /dev/null +++ b/cmake/FindOpenBLAS.cmake @@ -0,0 +1,47 @@ +# //////////////////////////////////////////////////////////////////////////////// +# // BSD 3-Clause License +# // +# // Copyright (c) 2021, NVIDIA Corporation +# // All rights reserved. +# // +# // Redistribution and use in source and binary forms, with or without +# // modification, are permitted provided that the following conditions are met: +# // +# // 1. Redistributions of source code must retain the above copyright notice, this +# // list of conditions and the following disclaimer. +# // +# // 2. Redistributions in binary form must reproduce the above copyright notice, +# // this list of conditions and the following disclaimer in the documentation +# // and/or other materials provided with the distribution. +# // +# // 3. Neither the name of the copyright holder nor the names of its +# // contributors may be used to endorse or promote products derived from +# // this software without specific prior written permission. +# // +# // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ///////////////////////////////////////////////////////////////////////////////// + +set(OPENBLAS_PATH_SUFFIXES "cmake/openblas") + +find_package(OpenBLAS CONFIG QUIET + HINTS ${blas_DIR} + PATH_SUFFIXES ${OPENBLAS_PATH_SUFFIXES} + REQUIRED +) + +if(NOT TARGET OpenBLAS::OpenBLAS) + add_library(OpenBLAS::OpenBLAS INTERFACE IMPORTED) + set_target_properties(OpenBLAS::OpenBLAS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${OpenBLAS_INCLUDE_DIRS}" + INTERFACE_LINK_LIBRARIES "${OpenBLAS_LIBRARIES}" + ) +endif() \ No newline at end of file diff --git a/include/matx.h b/include/matx.h index 96cb477a..a5ecad55 100644 --- a/include/matx.h +++ b/include/matx.h @@ -53,8 +53,10 @@ #include "matx/operators/operators.h" #include "matx/transforms/transforms.h" -using fcomplex = cuda::std::complex; -using dcomplex = cuda::std::complex; +namespace matx { + using fcomplex = cuda::std::complex; + using dcomplex = cuda::std::complex; +} #define TEST_VECTOR_PATH "generated/" diff --git a/include/matx/executors/support.h b/include/matx/executors/support.h index f5d83837..04098b9c 100644 --- a/include/matx/executors/support.h +++ b/include/matx/executors/support.h @@ -47,7 +47,7 @@ namespace matx { #endif // MatMul -#if defined(MATX_EN_NVPL) +#if defined(MATX_EN_NVPL) || defined(MATX_EN_OPENBLAS) || defined(MATX_EN_BLIS) #define MATX_EN_CPU_MATMUL 1 #else #define MATX_EN_CPU_MATMUL 0 diff --git a/include/matx/transforms/matmul/matmul_cblas.h b/include/matx/transforms/matmul/matmul_cblas.h index ddb1ae33..a1bd2a22 100644 --- a/include/matx/transforms/matmul/matmul_cblas.h +++ b/include/matx/transforms/matmul/matmul_cblas.h @@ -46,6 +46,18 @@ #ifdef MATX_EN_NVPL #include using cblas_int_t = nvpl_int_t; +#elif defined(MATX_EN_OPENBLAS) + #include + using cblas_int_t = blasint; +#elif defined(MATX_EN_BLIS) + #ifdef MATX_BLIS_64_HEADER + #include + #include + #else + #include + #include + #endif + using cblas_int_t = f77_int; #endif namespace matx { @@ -113,20 +125,20 @@ static MatMulCBLASParams_t GetGemmParams(TensorTypeC &c, // If we have a 3D or above tensor, the upper dims are batch dimensions. if constexpr (RANK >= 3) { - params.batch = (c.Size(RANK - 3)); + params.batch = static_cast(c.Size(RANK - 3)); if constexpr (TensorTypeA::Rank() == RANK) { - params.astride = (a.Stride(TensorTypeA::Rank() - 3)); + params.astride = static_cast(a.Stride(TensorTypeA::Rank() - 3)); } else { params.astride = 0; } if constexpr (TensorTypeB::Rank() == RANK) { - params.bstride = (b.Stride(TensorTypeB::Rank() - 3)); + params.bstride = static_cast(b.Stride(TensorTypeB::Rank() - 3)); } else { params.bstride = 0; } - params.cstride = (c.Stride(RANK - 3)); + params.cstride = static_cast(c.Stride(RANK - 3)); } // At this point, the transpose mode on C case has already been handled @@ -147,24 +159,24 @@ static MatMulCBLASParams_t GetGemmParams(TensorTypeC &c, // doesn't like even though it's unused. Set it to something that it would be // if the matrix had more than 1 row. if (params.opB == CblasTrans) { - params.ldb = b.Stride(TensorTypeB::Rank() - 1); + params.ldb = static_cast(b.Stride(TensorTypeB::Rank() - 1)); } else { - params.ldb = b.Stride(TensorTypeB::Rank() - 2); - params.ldb = (params.ldb == 0) ? b.Size(TensorTypeB::Rank() - 1) : params.ldb; + params.ldb = static_cast(b.Stride(TensorTypeB::Rank() - 2)); + params.ldb = (params.ldb == 0) ? static_cast(b.Size(TensorTypeB::Rank() - 1)) : params.ldb; } if (params.opA == CblasTrans) { - params.lda = a.Stride(TensorTypeA::Rank() - 1); + params.lda = static_cast(a.Stride(TensorTypeA::Rank() - 1)); } else { - params.lda = a.Stride(TensorTypeA::Rank() - 2); - params.lda = (params.lda == 0) ? a.Size(TensorTypeA::Rank() - 1) : params.lda; + params.lda = static_cast(a.Stride(TensorTypeA::Rank() - 2)); + params.lda = (params.lda == 0) ? static_cast(a.Size(TensorTypeA::Rank() - 1)) : params.lda; } - params.ldc = c.Stride(RANK - 2); + params.ldc = static_cast(c.Stride(RANK - 2)); - params.m = a.Size(TensorTypeA::Rank() - 2); - params.n = b.Size(TensorTypeB::Rank() - 1); - params.k = a.Size(TensorTypeA::Rank() - 1); + params.m = static_cast(a.Size(TensorTypeA::Rank() - 2)); + params.n = static_cast(b.Size(TensorTypeB::Rank() - 1)); + params.k = static_cast(a.Size(TensorTypeA::Rank() - 1)); return params; } @@ -228,7 +240,6 @@ __MATX_INLINE__ void matmul_exec(TensorTypeC &c, MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) static constexpr int RANK = TensorTypeC::Rank(); - static constexpr int GROUP_COUNT = 1; using scalar_type = typename TensorTypeC::scalar_type; // Prep for batch looping @@ -238,11 +249,11 @@ __MATX_INLINE__ void matmul_exec(TensorTypeC &c, [[maybe_unused]] cuda::std::array c_idx{0}; [[maybe_unused]] size_t total_iter = 1; - if constexpr (RANK > 2) { + if constexpr (RANK > 3) { // Get total number of batches auto c_shape = c.Shape(); total_iter = std::accumulate(c_shape.begin(), - c_shape.begin() + TensorTypeC::Rank() - 2, 1, + c_shape.begin() + TensorTypeC::Rank() - 3, 1, std::multiplies()); } @@ -259,53 +270,126 @@ __MATX_INLINE__ void matmul_exec(TensorTypeC &c, sbeta = beta; } - std::vector a_array(total_iter); - std::vector b_array(total_iter); - std::vector c_array(total_iter); - +#ifdef MATX_EN_NVPL + if constexpr (RANK <= 3) { + auto a_ptr = a.Data(); + auto b_ptr = b.Data(); + auto c_ptr = c.Data(); + + if constexpr (std::is_same_v) { + cblas_sgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + a_ptr, params.lda, params.astride, + b_ptr, params.ldb, params.bstride, sbeta, + c_ptr, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v) { + cblas_dgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + a_ptr, params.lda, params.astride, + b_ptr, params.ldb, params.bstride, sbeta, + c_ptr, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v>) { + cblas_cgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (void *)a_ptr, params.lda, params.astride, + (void *)b_ptr, params.ldb, params.bstride, (void *)&sbeta, + (void *)c_ptr, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v>) { + cblas_zgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (void *)a_ptr, params.lda, params.astride, + (void *)b_ptr, params.ldb, params.bstride, (void *)&sbeta, + (void *)c_ptr, params.ldc, params.cstride, params.batch); + } + } else { + for (size_t iter = 0; iter < total_iter; iter++) { + + // Get pointers into A/B/C for this round + auto ap = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, a_idx); + auto bp = cuda::std::apply([&b](auto... param) { return b.GetPointer(param...); }, b_idx); + auto cp = cuda::std::apply([&c](auto... param) { return c.GetPointer(param...); }, c_idx); + + if constexpr (std::is_same_v) { + cblas_sgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + ap, params.lda, params.astride, + bp, params.ldb, params.bstride, sbeta, + cp, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v) { + cblas_dgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + ap, params.lda, params.astride, + bp, params.ldb, params.bstride, sbeta, + cp, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v>) { + cblas_cgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (void *)ap, params.lda, params.astride, + (void *)bp, params.ldb, params.bstride, (void *)&sbeta, + (void *)cp, params.ldc, params.cstride, params.batch); + } else if constexpr (std::is_same_v>) { + cblas_zgemm_batch_strided(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (void *)ap, params.lda, params.astride, + (void *)bp, params.ldb, params.bstride, (void *)&sbeta, + (void *)cp, params.ldc, params.cstride, params.batch); + } + + // Update all but the last 3 indices + UpdateIndices(a, a_idx, 3); + UpdateIndices(b, b_idx, 3); + UpdateIndices(c, c_idx, 3); + } + } +#else + // The batch api is a new addition to BLIS and OpenBLAS, so it may not be present. + // Thus, we default to the standard gemm api and loop over anything above the 2nd dimension. + + #ifdef MATX_EN_OPENBLAS + openblas_set_num_threads(exec.GetNumThreads()); + #elif defined(MATX_EN_BLIS) + bli_thread_set_num_threads(exec.GetNumThreads()); + #endif + + total_iter *= params.batch; for (size_t iter = 0; iter < total_iter; iter++) { // Get pointers into A/B/C for this round auto ap = cuda::std::apply([&a](auto... param) { return a.GetPointer(param...); }, a_idx); auto bp = cuda::std::apply([&b](auto... param) { return b.GetPointer(param...); }, b_idx); auto cp = cuda::std::apply([&c](auto... param) { return c.GetPointer(param...); }, c_idx); - a_array[iter] = ap; - b_array[iter] = bp; - c_array[iter] = cp; + if constexpr (std::is_same_v) { + cblas_sgemm(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + ap, params.lda, + bp, params.ldb, sbeta, + cp, params.ldc); + } else if constexpr (std::is_same_v) { + cblas_dgemm(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, salpha, + ap, params.lda, + bp, params.ldb, sbeta, + cp, params.ldc); + } else if constexpr (std::is_same_v>) { + cblas_cgemm(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (const void **)ap, params.lda, + (const void **)bp, params.ldb, (void *)&sbeta, + (void **)cp, params.ldc); + } else if constexpr (std::is_same_v>) { + cblas_zgemm(CblasRowMajor, params.opA, params.opB, + params.m, params.n, params.k, (void *)&salpha, + (const void **)ap, params.lda, + (const void **)bp, params.ldb, (void *)&sbeta, + (void **)cp, params.ldc); + } // Update all but the last 2 indices UpdateIndices(a, a_idx, 2); UpdateIndices(b, b_idx, 2); UpdateIndices(c, c_idx, 2); } - - cblas_int_t group_size = static_cast(total_iter); - - if constexpr (std::is_same_v) { - cblas_sgemm_batch(CblasRowMajor, ¶ms.opA, ¶ms.opB, - ¶ms.m, ¶ms.n, ¶ms.k, &salpha, - a_array.data(), ¶ms.lda, - b_array.data(), ¶ms.ldb, &sbeta, - c_array.data(), ¶ms.ldc, GROUP_COUNT, &group_size); - } else if constexpr (std::is_same_v) { - cblas_dgemm_batch(CblasRowMajor, ¶ms.opA, ¶ms.opB, - ¶ms.m, ¶ms.n, ¶ms.k, &salpha, - a_array.data(), ¶ms.lda, - b_array.data(), ¶ms.ldb, &sbeta, - c_array.data(), ¶ms.ldc, GROUP_COUNT, &group_size); - } else if constexpr (std::is_same_v>) { - cblas_cgemm_batch(CblasRowMajor, ¶ms.opA, ¶ms.opB, - ¶ms.m, ¶ms.n, ¶ms.k, (void *)&salpha, - (const void **)a_array.data(), ¶ms.lda, - (const void **)b_array.data(), ¶ms.ldb, (void *)&sbeta, - (void **)c_array.data(), ¶ms.ldc, GROUP_COUNT, &group_size); - } else if constexpr (std::is_same_v>) { - cblas_zgemm_batch(CblasRowMajor, ¶ms.opA, ¶ms.opB, - ¶ms.m, ¶ms.n, ¶ms.k, (void *)&salpha, - (const void **)a_array.data(), ¶ms.lda, - (const void **)b_array.data(), ¶ms.ldb, (void *)&sbeta, - (void **)c_array.data(), ¶ms.ldc, GROUP_COUNT, &group_size); - } +#endif } template