From 82566e5026444717991adffbd1c809989ae79aea Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Tue, 9 Jul 2024 15:12:22 +0200 Subject: [PATCH] Test symmetric with complex types and hermitian and conjtrans with real types --- .../sparse_blas/include/test_spmm.hpp | 7 +-- .../sparse_blas/include/test_spmv.hpp | 61 ++++++++++++++++--- .../sparse_blas/source/sparse_spmv_buffer.cpp | 20 ++---- .../sparse_blas/source/sparse_spmv_usm.cpp | 20 ++---- 4 files changed, 61 insertions(+), 47 deletions(-) diff --git a/tests/unit_tests/sparse_blas/include/test_spmm.hpp b/tests/unit_tests/sparse_blas/include/test_spmm.hpp index 049d58b88..6188d4268 100644 --- a/tests/unit_tests/sparse_blas/include/test_spmm.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spmm.hpp @@ -205,7 +205,6 @@ void test_helper_with_format_with_transpose( /** * Helper function to test combination of transpose vals. - * Only test \p conjtrans if \p fpType is complex. * * @tparam fpType Complex or scalar, single or double precision type * @tparam testFunctorI32 Test functor for fpType and int32 @@ -223,10 +222,8 @@ void test_helper_with_format( const std::vector &non_default_algorithms, int &num_passed, int &num_skipped) { std::vector transpose_vals{ oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::trans }; - if (complex_info::is_complex) { - transpose_vals.push_back(oneapi::mkl::transpose::conjtrans); - } + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::conjtrans }; for (auto transpose_A : transpose_vals) { for (auto transpose_B : transpose_vals) { test_helper_with_format_with_transpose( diff --git a/tests/unit_tests/sparse_blas/include/test_spmv.hpp b/tests/unit_tests/sparse_blas/include/test_spmv.hpp index 43599e9d3..6ee256adb 100644 --- a/tests/unit_tests/sparse_blas/include/test_spmv.hpp +++ b/tests/unit_tests/sparse_blas/include/test_spmv.hpp @@ -51,7 +51,7 @@ * The test functions will use different sizes if the configuration implies a symmetric matrix. */ template -void test_helper_with_format( +void test_helper_with_format_with_transpose( testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64, sycl::device *dev, sparse_matrix_format_t format, const std::vector &non_default_algorithms, @@ -153,22 +153,37 @@ void test_helper_with_format( no_reset_data, no_scalars_on_device), num_passed, num_skipped); if (transpose_val != oneapi::mkl::transpose::conjtrans) { - // Lower symmetric or hermitian + // Do not test conjtrans with symmetric or hermitian views as no backend supports it. + // Lower symmetric oneapi::mkl::sparse::matrix_view symmetric_view( - complex_info::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian - : oneapi::mkl::sparse::matrix_descr::symmetric); + oneapi::mkl::sparse::matrix_descr::symmetric); EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, symmetric_view, no_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); - // Upper symmetric or hermitian + // Upper symmetric symmetric_view.uplo_view = oneapi::mkl::uplo::upper; EXPECT_TRUE_OR_FUTURE_SKIP( test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, default_alg, symmetric_view, no_properties, no_reset_data, no_scalars_on_device), num_passed, num_skipped); + // Lower hermitian + oneapi::mkl::sparse::matrix_view hermitian_view( + oneapi::mkl::sparse::matrix_descr::hermitian); + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, + transpose_val, fp_one, fp_zero, default_alg, hermitian_view, + no_properties, no_reset_data, no_scalars_on_device), + num_passed, num_skipped); + // Upper hermitian + hermitian_view.uplo_view = oneapi::mkl::uplo::upper; + EXPECT_TRUE_OR_FUTURE_SKIP( + test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero, + transpose_val, fp_one, fp_zero, default_alg, hermitian_view, + no_properties, no_reset_data, no_scalars_on_device), + num_passed, num_skipped); } // Test other algorithms for (auto alg : non_default_algorithms) { @@ -188,6 +203,34 @@ void test_helper_with_format( } } +/** + * Helper function to test combination of transpose vals. + * + * @tparam fpType Complex or scalar, single or double precision type + * @tparam testFunctorI32 Test functor for fpType and int32 + * @tparam testFunctorI64 Test functor for fpType and int64 + * @param dev Device to test + * @param format Sparse matrix format to use + * @param non_default_algorithms Algorithms compatible with the given format, other than default_alg + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper_with_format( + testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64, sycl::device *dev, + sparse_matrix_format_t format, + const std::vector &non_default_algorithms, int &num_passed, + int &num_skipped) { + std::vector transpose_vals{ oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::conjtrans }; + for (auto transpose_A : transpose_vals) { + test_helper_with_format_with_transpose(test_functor_i32, test_functor_i64, dev, + format, non_default_algorithms, transpose_A, + num_passed, num_skipped); + } +} + /** * Helper function to test multiple sparse matrix format and choose valid algorithms. * @@ -195,24 +238,22 @@ void test_helper_with_format( * @tparam testFunctorI32 Test functor for fpType and int32 * @tparam testFunctorI64 Test functor for fpType and int64 * @param dev Device to test - * @param transpose_val Transpose value for the input matrix * @param num_passed Increase the number of configurations passed * @param num_skipped Increase the number of configurations skipped */ template void test_helper(testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64, - sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed, - int &num_skipped) { + sycl::device *dev, int &num_passed, int &num_skipped) { test_helper_with_format( test_functor_i32, test_functor_i64, dev, sparse_matrix_format_t::CSR, { oneapi::mkl::sparse::spmv_alg::no_optimize_alg, oneapi::mkl::sparse::spmv_alg::csr_alg1, oneapi::mkl::sparse::spmv_alg::csr_alg2, oneapi::mkl::sparse::spmv_alg::csr_alg3 }, - transpose_val, num_passed, num_skipped); + num_passed, num_skipped); test_helper_with_format( test_functor_i32, test_functor_i64, dev, sparse_matrix_format_t::COO, { oneapi::mkl::sparse::spmv_alg::no_optimize_alg, oneapi::mkl::sparse::spmv_alg::coo_alg1, oneapi::mkl::sparse::spmv_alg::coo_alg2 }, - transpose_val, num_passed, num_skipped); + num_passed, num_skipped); } /// Compute spmv reference as a dense operation diff --git a/tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp index 12b449e61..0ba5afb9c 100644 --- a/tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp +++ b/tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp @@ -184,9 +184,7 @@ TEST_P(SparseSpmvBufferTests, RealSinglePrecision) { using fpType = float; int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -199,9 +197,7 @@ TEST_P(SparseSpmvBufferTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -213,11 +209,7 @@ TEST_P(SparseSpmvBufferTests, ComplexSinglePrecision) { using fpType = std::complex; int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -230,11 +222,7 @@ TEST_P(SparseSpmvBufferTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped diff --git a/tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp b/tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp index 85feacbda..fdeb57913 100644 --- a/tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp +++ b/tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp @@ -233,9 +233,7 @@ TEST_P(SparseSpmvUsmTests, RealSinglePrecision) { using fpType = float; int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -248,9 +246,7 @@ TEST_P(SparseSpmvUsmTests, RealDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -262,11 +258,7 @@ TEST_P(SparseSpmvUsmTests, ComplexSinglePrecision) { using fpType = std::complex; int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped @@ -279,11 +271,7 @@ TEST_P(SparseSpmvUsmTests, ComplexDoublePrecision) { CHECK_DOUBLE_ON_DEVICE(GetParam()); int num_passed = 0, num_skipped = 0; test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::nontrans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::trans, num_passed, num_skipped); - test_helper(test_spmv, test_spmv, GetParam(), - oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + num_passed, num_skipped); if (num_skipped > 0) { // Mark that some tests were skipped GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped