From f7c72b97e675d87fbf141c3ae79a60cd637a60e9 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Tue, 17 Oct 2023 16:56:03 +0100 Subject: [PATCH] Re-use macros in function_table --- src/sparse_blas/function_table.hpp | 301 +++++-------------------- src/sparse_blas/macros.hpp | 39 ++++ src/sparse_blas/sparse_blas_loader.cpp | 21 +- 3 files changed, 95 insertions(+), 266 deletions(-) create mode 100644 src/sparse_blas/macros.hpp diff --git a/src/sparse_blas/function_table.hpp b/src/sparse_blas/function_table.hpp index d4656190a..4d2296ff5 100644 --- a/src/sparse_blas/function_table.hpp +++ b/src/sparse_blas/function_table.hpp @@ -21,6 +21,52 @@ #define _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ #include "oneapi/mkl/sparse_blas/types.hpp" +#include "sparse_blas/macros.hpp" + +#define DEFINE_SET_CSR_DATA(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ + void (*set_csr_data_buffer##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, \ + sycl::buffer & row_ptr, sycl::buffer & col_ind, \ + sycl::buffer & val); \ + sycl::event (*set_csr_data_usm##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, INT_TYPE * row_ptr, \ + INT_TYPE * col_ind, FP_TYPE * val, const std::vector &dependencies) + +#define DEFINE_GEMV(FP_TYPE, FP_SUFFIX) \ + void (*gemv_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::transpose transpose_val, const FP_TYPE alpha, \ + oneapi::mkl::sparse::matrix_handle_t A_handle, sycl::buffer &x, \ + const FP_TYPE beta, sycl::buffer &y); \ + sycl::event (*gemv_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::transpose transpose_val, const FP_TYPE alpha, \ + oneapi::mkl::sparse::matrix_handle_t A_handle, const FP_TYPE *x, const FP_TYPE beta, \ + FP_TYPE *y, const std::vector &dependencies) + +#define 
DEFINE_TRSV(FP_TYPE, FP_SUFFIX) \ + void (*trsv_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ + oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + sycl::buffer & x, sycl::buffer & y); \ + sycl::event (*trsv_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ + oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + const FP_TYPE *x, FP_TYPE *y, const std::vector &dependencies) + +#define DEFINE_GEMM(FP_TYPE, FP_SUFFIX) \ + void (*gemm_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::layout dense_matrix_layout, \ + oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, \ + const FP_TYPE alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + sycl::buffer &B, const std::int64_t columns, const std::int64_t ldb, \ + const FP_TYPE beta, sycl::buffer &C, const std::int64_t ldc); \ + sycl::event (*gemm_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::layout dense_matrix_layout, \ + oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, \ + const FP_TYPE alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, const FP_TYPE *B, \ + const std::int64_t columns, const std::int64_t ldb, const FP_TYPE beta, FP_TYPE *C, \ + const std::int64_t ldc, const std::vector &dependencies) typedef struct { int version; @@ -30,117 +76,7 @@ typedef struct { oneapi::mkl::sparse::matrix_handle_t *p_handle, const std::vector &dependencies); - // set_csr_data - void (*set_csr_data_buffer_rf_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer &val); - void (*set_csr_data_buffer_rd_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t 
num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer &val); - void (*set_csr_data_buffer_cf_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer, 1> &val); - void (*set_csr_data_buffer_cd_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer, 1> &val); - void (*set_csr_data_buffer_rf_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer &val); - void (*set_csr_data_buffer_rd_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer &val); - void (*set_csr_data_buffer_cf_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer, 1> &val); - void (*set_csr_data_buffer_cd_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - sycl::buffer &row_ptr, - sycl::buffer &col_ind, - sycl::buffer, 1> &val); - sycl::event (*set_csr_data_usm_rf_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base 
index, - std::int32_t *row_ptr, std::int32_t *col_ind, float *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_rd_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - std::int32_t *row_ptr, std::int32_t *col_ind, - double *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_cf_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - std::int32_t *row_ptr, std::int32_t *col_ind, - std::complex *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_cd_i32)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int32_t num_rows, std::int32_t num_cols, - std::int32_t nnz, oneapi::mkl::index_base index, - std::int32_t *row_ptr, std::int32_t *col_ind, - std::complex *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_rf_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - std::int64_t *row_ptr, std::int64_t *col_ind, float *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_rd_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - std::int64_t *row_ptr, std::int64_t *col_ind, - double *val, - const std::vector &dependencies); - sycl::event (*set_csr_data_usm_cf_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - std::int64_t *row_ptr, std::int64_t *col_ind, - std::complex *val, - const std::vector &dependencies); - sycl::event 
(*set_csr_data_usm_cd_i64)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t handle, - std::int64_t num_rows, std::int64_t num_cols, - std::int64_t nnz, oneapi::mkl::index_base index, - std::int64_t *row_ptr, std::int64_t *col_ind, - std::complex *val, - const std::vector &dependencies); + FOR_EACH_FP_AND_INT_TYPE(DEFINE_SET_CSR_DATA); // optimize_* sycl::event (*optimize_gemv)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, @@ -151,141 +87,14 @@ typedef struct { oneapi::mkl::sparse::matrix_handle_t handle, const std::vector &dependencies); - // gemv - void (*gemv_buffer_rf)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const float alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer &x, const float beta, sycl::buffer &y); - void (*gemv_buffer_rd)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const double alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer &x, const double beta, - sycl::buffer &y); - void (*gemv_buffer_cf)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &x, const std::complex beta, - sycl::buffer, 1> &y); - void (*gemv_buffer_cd)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &x, - const std::complex beta, - sycl::buffer, 1> &y); - sycl::event (*gemv_usm_rf)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const float alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - const float *x, const float beta, float *y, - const std::vector &dependencies); - sycl::event (*gemv_usm_rd)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const double alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - const double *x, const double beta, double *y, - const std::vector &dependencies); - sycl::event (*gemv_usm_cf)(sycl::queue &queue, 
oneapi::mkl::transpose transpose_val, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *x, const std::complex beta, - std::complex *y, - const std::vector &dependencies); - sycl::event (*gemv_usm_cd)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *x, const std::complex beta, - std::complex *y, - const std::vector &dependencies); - - // trsv - void (*trsv_buffer_rf)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, sycl::buffer &x, - sycl::buffer &y); - void (*trsv_buffer_rd)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer &x, sycl::buffer &y); - void (*trsv_buffer_cf)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &x, - sycl::buffer, 1> &y); - void (*trsv_buffer_cd)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &x, - sycl::buffer, 1> &y); - sycl::event (*trsv_usm_rf)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, const float *x, - float *y, const std::vector &dependencies); - sycl::event (*trsv_usm_rd)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, const double *x, - double *y, const std::vector &dependencies); - sycl::event (*trsv_usm_cf)(sycl::queue &queue, 
oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *x, std::complex *y, - const std::vector &dependencies); - sycl::event (*trsv_usm_cd)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, - oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *x, std::complex *y, - const std::vector &dependencies); - - // gemm - void (*gemm_buffer_rf)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, - const float alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer &B, const std::int64_t columns, - const std::int64_t ldb, const float beta, sycl::buffer &C, - const std::int64_t ldc); - void (*gemm_buffer_rd)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, - const double alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer &B, const std::int64_t columns, - const std::int64_t ldb, const double beta, sycl::buffer &C, - const std::int64_t ldc); - void (*gemm_buffer_cf)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &B, const std::int64_t columns, - const std::int64_t ldb, const std::complex beta, - sycl::buffer, 1> &C, const std::int64_t ldc); - void (*gemm_buffer_cd)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, - const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - sycl::buffer, 1> &B, const std::int64_t columns, - const std::int64_t ldb, const std::complex beta, - sycl::buffer, 1> &C, const 
std::int64_t ldc); - sycl::event (*gemm_usm_rf)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, - oneapi::mkl::transpose transpose_B, const float alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, const float *B, - const std::int64_t columns, const std::int64_t ldb, const float beta, - float *C, const std::int64_t ldc, - const std::vector &dependencies); - sycl::event (*gemm_usm_rd)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, - oneapi::mkl::transpose transpose_B, const double alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, const double *B, - const std::int64_t columns, const std::int64_t ldb, - const double beta, double *C, const std::int64_t ldc, - const std::vector &dependencies); - sycl::event (*gemm_usm_cf)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, - oneapi::mkl::transpose transpose_B, const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *B, const std::int64_t columns, - const std::int64_t ldb, const std::complex beta, - std::complex *C, const std::int64_t ldc, - const std::vector &dependencies); - sycl::event (*gemm_usm_cd)(sycl::queue &queue, oneapi::mkl::layout dense_matrix_layout, - oneapi::mkl::transpose transpose_A, - oneapi::mkl::transpose transpose_B, const std::complex alpha, - oneapi::mkl::sparse::matrix_handle_t A_handle, - const std::complex *B, const std::int64_t columns, - const std::int64_t ldb, const std::complex beta, - std::complex *C, const std::int64_t ldc, - const std::vector &dependencies); + FOR_EACH_FP_TYPE(DEFINE_GEMV); + FOR_EACH_FP_TYPE(DEFINE_TRSV); + FOR_EACH_FP_TYPE(DEFINE_GEMM); } sparse_blas_function_table_t; +#undef DEFINE_SET_CSR_DATA +#undef DEFINE_GEMV +#undef DEFINE_TRSV +#undef DEFINE_GEMM + #endif // _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ diff --git a/src/sparse_blas/macros.hpp b/src/sparse_blas/macros.hpp 
new file mode 100644 index 000000000..a4ef88e35 --- /dev/null +++ b/src/sparse_blas/macros.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_MACROS_HPP_ +#define _ONEMKL_SPARSE_BLAS_MACROS_HPP_ + +#define FOR_EACH_FP_TYPE(DEFINE_MACRO) \ + DEFINE_MACRO(float, _rf); \ + DEFINE_MACRO(double, _rd); \ + DEFINE_MACRO(std::complex, _cf); \ + DEFINE_MACRO(std::complex, _cd) + +#define FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, INT_TYPE, INT_SUFFIX) \ + DEFINE_MACRO(float, _rf, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(double, _rd, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(std::complex, _cf, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(std::complex, _cd, INT_TYPE, INT_SUFFIX) + +#define FOR_EACH_FP_AND_INT_TYPE(DEFINE_MACRO) \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int32_t, _i32); \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int64_t, _i64) + +#endif // _ONEMKL_SPARSE_BLAS_MACROS_HPP_ diff --git a/src/sparse_blas/sparse_blas_loader.cpp b/src/sparse_blas/sparse_blas_loader.cpp index 07d065561..3488d52ba 100644 --- a/src/sparse_blas/sparse_blas_loader.cpp +++ b/src/sparse_blas/sparse_blas_loader.cpp @@ -21,24 +21,9 @@ #include "function_table_initializer.hpp" 
#include "sparse_blas/function_table.hpp" +#include "sparse_blas/macros.hpp" #include "oneapi/mkl/detail/get_device_id.hpp" -#define FOR_EACH_FP_TYPE(DEFINE_MACRO) \ - DEFINE_MACRO(float, _rf); \ - DEFINE_MACRO(double, _rd); \ - DEFINE_MACRO(std::complex, _cf); \ - DEFINE_MACRO(std::complex, _cd) - -#define FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, INT_TYPE, INT_SUFFIX) \ - DEFINE_MACRO(float, _rf, INT_TYPE, INT_SUFFIX); \ - DEFINE_MACRO(double, _rd, INT_TYPE, INT_SUFFIX); \ - DEFINE_MACRO(std::complex, _cf, INT_TYPE, INT_SUFFIX); \ - DEFINE_MACRO(std::complex, _cd, INT_TYPE, INT_SUFFIX) - -#define FOR_EACH_FP_AND_INT_TYPE(DEFINE_MACRO) \ - FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int32_t, _i32); \ - FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int64_t, _i64) - namespace oneapi::mkl::sparse { static oneapi::mkl::detail::table_initializer