diff --git a/cmake/RepositoryDependenciesSetup.cmake b/cmake/RepositoryDependenciesSetup.cmake index 80147ad78fde..b8526dbbe0d1 100644 --- a/cmake/RepositoryDependenciesSetup.cmake +++ b/cmake/RepositoryDependenciesSetup.cmake @@ -1,16 +1,16 @@ -if (TPL_ENABLE_CUDA AND NOT Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE) - if ("${${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho}" STREQUAL "") - message( - "-- " "NOTE: Setting ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho=OFF by default since TPL_ENABLE_CUDA='${TPL_ENABLE_CUDA}' AND Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE='${Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE}'!\n" - "-- NOTE: To allow the enable of ShyLU_NodeTacho, please set Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE=ON.") - set(${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho OFF) - # NOTE: Above we set the non-cache var - # ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho so that each reconfigure will - # show this same note. - elseif (${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho) - message(FATAL_ERROR "ERROR: ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho=ON but TPL_ENABLE_CUDA='${TPL_ENABLE_CUDA}' AND Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE='${Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE}' which is not allowed!") - endif() -endif() +# if (TPL_ENABLE_CUDA AND NOT Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE) +# if ("${${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho}" STREQUAL "") +# message( +# "-- " "NOTE: Setting ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho=OFF by default since TPL_ENABLE_CUDA='${TPL_ENABLE_CUDA}' AND Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE='${Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE}'!\n" +# "-- NOTE: To allow the enable of ShyLU_NodeTacho, please set Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE=ON.") +# set(${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho OFF) +# # NOTE: Above we set the non-cache var +# # ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho so that each reconfigure will +# # show this same note. +# elseif (${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho) +# message(FATAL_ERROR "ERROR: ${PROJECT_NAME}_ENABLE_ShyLU_NodeTacho=ON but TPL_ENABLE_CUDA='${TPL_ENABLE_CUDA}' AND Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE='${Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE}' which is not allowed!") +# endif() +# endif() ######################################################################### # STKBalance does not work with GO=INT or GO=UNSIGNED diff --git a/packages/shylu/shylu_node/tacho/CMakeLists.txt b/packages/shylu/shylu_node/tacho/CMakeLists.txt index 14f9d72338c6..2efce7b95211 100644 --- a/packages/shylu/shylu_node/tacho/CMakeLists.txt +++ b/packages/shylu/shylu_node/tacho/CMakeLists.txt @@ -5,39 +5,29 @@ IF (Kokkos_ENABLE_CUDA) IF (DEFINED CUDA_VERSION AND (CUDA_VERSION VERSION_LESS "8.0")) MESSAGE(FATAL_ERROR "Tacho requires CUDA 8 if CUDA is enabled") ENDIF() - # If RDC is off, emits a warning message - IF (NOT Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE) - MESSAGE(WARNING "Tacho requires CUDA relocatable device code to be enabled if CUDA is enabled. Set: Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE=ON ") - ENDIF() ENDIF() + IF (Kokkos_ENABLE_THREADS) IF (NOT Kokkos_ENABLE_OPENMP) MESSAGE(FATAL_ERROR "Tacho can not be build with Pthreads as the Kokkos Host Backend.") ENDIF() ENDIF() -# Set cmake variable to control examples and tests -IF (Kokkos_ENABLE_CUDA) - IF (DEFINED CUDA_VERSION AND (CUDA_VERSION VERSION_LESS "8.0")) - SET(TACHO_HAVE_KOKKOS_TASK OFF) - ELSE() - IF (Kokkos_ENABLE_CUDA_RELOCATABLE_DEVICE_CODE) - SET(TACHO_HAVE_KOKKOS_TASK ON) - ELSE() - SET(TACHO_HAVE_KOKKOS_TASK OFF) - ENDIF() - ENDIF() -ELSE() - SET(TACHO_HAVE_KOKKOS_TASK ON) +ADD_SUBDIRECTORY(src) + +IF (NOT DEFINED Tacho_ENABLE_EXAMPLES) + SET(Tacho_ENABLE_EXAMPLES ${Trilinos_ENABLE_EXAMPLES}) +ENDIF() +IF (NOT DEFINED Tacho_ENABLE_TESTS) + SET(Tacho_ENABLE_TESTS ${Trilinos_ENABLE_TESTS}) ENDIF() -ADD_SUBDIRECTORY(src) -IF (TACHO_HAVE_KOKKOS_TASK) - IF (Trilinos_ENABLE_Gtest) - TRIBITS_ADD_EXAMPLE_DIRECTORIES(example) - TRIBITS_ADD_TEST_DIRECTORIES(unit-test) - ELSE() - MESSAGE(STATUS "Tacho disables examples and tests as Trilinos disables Gtest") - ENDIF() +IF (Tacho_ENABLE_EXAMPLES) + TRIBITS_ADD_EXAMPLE_DIRECTORIES(example) ENDIF() + +IF (Tacho_ENABLE_TESTS) + TRIBITS_ADD_TEST_DIRECTORIES(unit-test) +ENDIF() + TRIBITS_SUBPACKAGE_POSTPROCESS() diff --git a/packages/shylu/shylu_node/tacho/cmake/Dependencies.cmake b/packages/shylu/shylu_node/tacho/cmake/Dependencies.cmake index c184b938a4c3..c1752bf3fde0 100644 --- a/packages/shylu/shylu_node/tacho/cmake/Dependencies.cmake +++ b/packages/shylu/shylu_node/tacho/cmake/Dependencies.cmake @@ -1,8 +1,8 @@ SET(LIB_REQUIRED_DEP_PACKAGES Kokkos) SET(LIB_OPTIONAL_DEP_PACKAGES) -SET(TEST_REQUIRED_DEP_PACKAGES Kokkos KokkosAlgorithms Gtest) +SET(TEST_REQUIRED_DEP_PACKAGES Kokkos KokkosAlgorithms) SET(TEST_OPTIONAL_DEP_PACKAGES) SET(LIB_REQUIRED_DEP_TPLS) -SET(LIB_OPTIONAL_DEP_TPLS METIS Scotch Cholmod HWLOC HYPRE MKL LAPACK BLAS Pthread QTHREAD VTune CUSOLVER CUSPARSE CUBLAS CUDA) +SET(LIB_OPTIONAL_DEP_TPLS METIS HWLOC HYPRE MKL LAPACK BLAS Pthread QTHREAD VTune CUSOLVER CUSPARSE CUBLAS CUDA) SET(TEST_REQUIRED_DEP_TPLS BLAS LAPACK) SET(TEST_OPTIONAL_DEP_TPLS METIS HWLOC Cholmod MKL LAPACK BLAS Pthread QTHREAD CUSOLVER CUSPARSE CUBLAS CUDA) diff --git a/packages/shylu/shylu_node/tacho/cmake/Tacho_config.h.in b/packages/shylu/shylu_node/tacho/cmake/Tacho_config.h.in index 816081e6750f..ce02a75e210c 100644 --- a/packages/shylu/shylu_node/tacho/cmake/Tacho_config.h.in +++ b/packages/shylu/shylu_node/tacho/cmake/Tacho_config.h.in @@ -1,9 +1,6 @@ #ifndef __TACHO_CONFIG_H__ #define __TACHO_CONFIG_H__ -/* Define if kokkos tasking is enabled */ -#cmakedefine TACHO_HAVE_KOKKOS_TASK - /* Define if want to build with size_type (int) enabled */ #cmakedefine TACHO_USE_INT_INT @@ -22,21 +19,12 @@ /* Define if want to build with METIS enabled */ #cmakedefine TACHO_HAVE_METIS -/* Define if want to build with METIS enabled */ -//#cmakedefine TACHO_HAVE_METIS_MT - -/* Define if want to build with Scotch enabled */ -//#cmakedefine TACHO_HAVE_SCOTCH - /* Define if want to build with CHOLMOD enabled */ #cmakedefine TACHO_HAVE_SUITESPARSE /* Define if want to build with VTune enabled */ #cmakedefine TACHO_HAVE_VTUNE -///* Define if want to build with Teuchos enabled */ -#cmakedefine TACHO_HAVE_TRILINOS_SS - #ifndef F77_BLAS_MANGLE # define F77_BLAS_MANGLE@F77_BLAS_MANGLE@ #endif diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholLevelSet.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholLevelSet.cpp similarity index 52% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholLevelSet.cpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholLevelSet.cpp index 33db60f960c7..26014be5621b 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholLevelSet.cpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholLevelSet.cpp @@ -2,12 +2,12 @@ #include #include -#include "Tacho_Internal.hpp" #include "Tacho_CommandLineParser.hpp" +#include "Tacho_Internal.hpp" //#define TACHO_ENABLE_PROFILE #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) -#include "cuda_profiler_api.h" +#include "cuda_profiler_api.h" #endif //#define TACHO_ENABLE_MPI_TEST @@ -15,10 +15,10 @@ #include "mpi.h" #endif -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +static const char *scheduler_name = "TaskSchedulerMultiple"; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { Tacho::CommandLineParser opts("This example program measure the performance of Tacho level set tools"); bool serial = false; @@ -40,8 +40,10 @@ int main (int argc, char *argv[]) { opts.set_option("nrhs", "Number of RHS vectors", &nrhs); opts.set_option("nstreams", "# of streams used in level set factorization and solve", &nstreams); opts.set_option("device-level-cut", "tree cut level to force device function", &device_level_cut); - opts.set_option("device-factor-thres", "device function is used above this threshold in factorization", &device_factor_thres); - opts.set_option("device-solve-thres", "device function is used above this threshold in solve", &device_solve_thres); + opts.set_option("device-factor-thres", "device function is used above this threshold in factorization", + &device_factor_thres); + opts.set_option("device-solve-thres", "device function is used above this threshold in solve", + &device_solve_thres); #if defined(TACHO_ENABLE_MPI_TEST) int nrank, irank; @@ -51,12 +53,12 @@ int main (int argc, char *argv[]) { MPI_Comm_size(comm, &nrank); std::vector t_all(nrank, double(0)); - auto is_root = [irank]()->bool { return irank == 0; }; + auto is_root = [irank]() -> bool { return irank == 0; }; #else - auto is_root = []()->bool { return true; }; + auto is_root = []() -> bool { return true; }; #endif -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) // override serial flag if (serial) { std::cout << "CUDA is enabled and serial code cannot be instanciated\n"; @@ -65,7 +67,8 @@ int main (int argc, char *argv[]) { #endif const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); @@ -78,22 +81,23 @@ int main (int argc, char *argv[]) { typedef TaskSchedulerType scheduler_type; - Tacho::printExecSpaceConfiguration ("DeviceSpace", false); - Tacho::printExecSpaceConfiguration ("HostSpace", false); + Tacho::printExecSpaceConfiguration("DeviceSpace", false); + Tacho::printExecSpaceConfiguration("HostSpace", false); printf("Scheduler Type = %s\n", scheduler_name); - + int r_val = 0; - + { typedef double value_type; - typedef Tacho::CrsMatrixBase CrsMatrixBaseType; - typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; - typedef Kokkos::View DenseMatrixBaseType; - + typedef Tacho::CrsMatrixBase CrsMatrixBaseType; + typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; + typedef Kokkos::View DenseMatrixBaseType; + Kokkos::Timer timer; double t = 0.0; - - if (is_root()) std::cout << "CholLevelSet:: import input file = " << file << std::endl; + + if (is_root()) + std::cout << "CholLevelSet:: import input file = " << file << std::endl; CrsMatrixBaseTypeHost A; timer.reset(); { @@ -109,11 +113,13 @@ int main (int argc, char *argv[]) { } Tacho::Graph G(A); t = timer.seconds(); - if (is_root()) std::cout << "CholLevelSet:: import input file::time = " << t << std::endl; + if (is_root()) + std::cout << "CholLevelSet:: import input file::time = " << t << std::endl; - if (is_root()) std::cout << "CholLevelSet:: analyze matrix" << std::endl; + if (is_root()) + std::cout << "CholLevelSet:: analyze matrix" << std::endl; timer.reset(); -#if defined(TACHO_HAVE_METIS) +#if defined(TACHO_HAVE_METIS) Tacho::GraphTools_Metis T(G); #elif defined(TACHO_HAVE_SCOTCH) Tacho::GraphTools_Scotch T(G); @@ -121,55 +127,51 @@ int main (int argc, char *argv[]) { Tacho::GraphTools T(G); #endif T.reorder(verbose); - + Tacho::SymbolicTools S(A, T); S.symbolicFactorize(verbose); t = timer.seconds(); - if (is_root()) std::cout << "CholLevelSet:: analyze matrix::time = " << t << std::endl; - - typedef typename device_type::memory_space device_memory_space; - - auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); - auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); - auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); - - auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); - auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); - auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); - auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); - auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); - auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); - auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); - auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); - auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); - auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); + if (is_root()) + std::cout << "CholLevelSet:: analyze matrix::time = " << t << std::endl; + + typedef typename device_type::memory_space device_memory_space; + + auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); + auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); + auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); + + auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); + auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); + auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); + auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); + auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); + auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); + auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); + auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); + auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); + auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); auto s_snodes_tree_children = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeChildren()); - - Kokkos::deep_copy(a_row_ptr , A.RowPtr()); - Kokkos::deep_copy(a_cols , A.Cols()); - Kokkos::deep_copy(a_values , A.Values()); - - Kokkos::deep_copy(t_perm , T.PermVector()); - Kokkos::deep_copy(t_peri , T.InvPermVector()); - Kokkos::deep_copy(s_supernodes , S.Supernodes()); - Kokkos::deep_copy(s_gid_spanel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(s_gid_spanel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(s_sid_spanel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(s_sid_spanel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(s_blk_spanel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(s_snodes_tree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(s_snodes_tree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(s_snodes_tree_children , S.SupernodesTreeChildren()); - - Tacho::NumericTools - N(A.NumRows(), a_row_ptr, a_cols, - t_perm, t_peri, - S.NumSupernodes(), s_supernodes, - s_gid_spanel_ptr, s_gid_spanel_colidx, - s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, - s_snodes_tree_parent, s_snodes_tree_ptr, s_snodes_tree_children, - S.SupernodesTreeLevel(), - S.SupernodesTreeRoots()); + + Kokkos::deep_copy(a_row_ptr, A.RowPtr()); + Kokkos::deep_copy(a_cols, A.Cols()); + Kokkos::deep_copy(a_values, A.Values()); + + Kokkos::deep_copy(t_perm, T.PermVector()); + Kokkos::deep_copy(t_peri, T.InvPermVector()); + Kokkos::deep_copy(s_supernodes, S.Supernodes()); + Kokkos::deep_copy(s_gid_spanel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(s_gid_spanel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(s_sid_spanel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(s_sid_spanel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(s_blk_spanel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(s_snodes_tree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(s_snodes_tree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(s_snodes_tree_children, S.SupernodesTreeChildren()); + + Tacho::NumericTools N( + A.NumRows(), a_row_ptr, a_cols, t_perm, t_peri, S.NumSupernodes(), s_supernodes, s_gid_spanel_ptr, + s_gid_spanel_colidx, s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, s_snodes_tree_parent, + s_snodes_tree_ptr, s_snodes_tree_children, S.SupernodesTreeLevel(), S.SupernodesTreeRoots()); N.printMemoryStat(verbose); #if defined(TACHO_USE_LEVELSET_VARIANT) @@ -177,84 +179,86 @@ int main (int argc, char *argv[]) { #else constexpr int variant = 0; #endif - Tacho::LevelSetTools L(N); + Tacho::LevelSetTools L(N); L.initialize(device_level_cut, device_factor_thres, device_solve_thres, verbose); L.createStream(nstreams); - - if (is_root()) std::cout << "CholLevelSet:: factorize matrix" << std::endl; + if (is_root()) + std::cout << "CholLevelSet:: factorize matrix" << std::endl; #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) cudaProfilerStart(); #endif - timer.reset(); + timer.reset(); L.factorizeCholesky(a_values, verbose); - t = timer.seconds(); + t = timer.seconds(); #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) cudaProfilerStop(); #endif - if (is_root()) std::cout << "CholLevelSet:: factorize matrix::time = " << t << std::endl; + if (is_root()) + std::cout << "CholLevelSet:: factorize matrix::time = " << t << std::endl; #if defined(TACHO_ENABLE_MPI_TEST) { - MPI_Gather(&t, 1, MPI_DOUBLE, t_all.data(), 1, MPI_DOUBLE, 0, comm); - + MPI_Gather(&t, 1, MPI_DOUBLE, t_all.data(), 1, MPI_DOUBLE, 0, comm); + if (is_root()) { double t_min(1000000), t_max(0), t_avg(0); - for (int i=0;i random(13718); + Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(B, random, value_type(1)); } /// /// solve via level set /// - if (is_root()) std::cout << "CholLevelSet:: solve matrix via LevelSetTools" << std::endl; + if (is_root()) + std::cout << "CholLevelSet:: solve matrix via LevelSetTools" << std::endl; - timer.reset(); + timer.reset(); #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) cudaProfilerStart(); #endif constexpr int niter = 10; - for (int i=0;i using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +static const char *scheduler_name = "TaskSchedulerMultiple"; - -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { CommandLineParser opts("This example program measure the performance of Tacho on Kokkos::OpenMP"); bool serial = false; @@ -38,7 +37,7 @@ int main (int argc, char *argv[]) { opts.set_option("mb", "Blocksize", &mb); opts.set_option("nb", "Panelsize", &nb); -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) // override serial flag if (serial) { std::cout << "CUDA is enabled and serial code cannot be instanciated\n"; @@ -47,27 +46,28 @@ int main (int argc, char *argv[]) { #endif const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); typedef Kokkos::DefaultExecutionSpace exec_space; - //typedef Kokkos::DefaultHostExecutionSpace exec_space; + // typedef Kokkos::DefaultHostExecutionSpace exec_space; typedef Kokkos::DefaultHostExecutionSpace host_space; typedef TaskSchedulerType scheduler_type; - printExecSpaceConfiguration ("DeviceSpace", false); - printExecSpaceConfiguration ("HostSpace", false); + printExecSpaceConfiguration("DeviceSpace", false); + printExecSpaceConfiguration("HostSpace", false); printf("Scheduler Type = %s\n", scheduler_name); - + int r_val = 0; - + { typedef double value_type; - typedef CrsMatrixBase CrsMatrixBaseTypeHost; - typedef Kokkos::View DenseMatrixBaseType; - + typedef CrsMatrixBase CrsMatrixBaseTypeHost; + typedef Kokkos::View DenseMatrixBaseType; + Kokkos::Timer timer; double t = 0.0; @@ -91,7 +91,7 @@ int main (int argc, char *argv[]) { std::cout << "CholSupernodes:: analyze matrix" << std::endl; timer.reset(); -#if defined(TACHO_HAVE_METIS) +#if defined(TACHO_HAVE_METIS) GraphTools_Metis T(G); #elif defined(TACHO_HAVE_SCOTCH) GraphTools_Scotch T(G); @@ -99,62 +99,57 @@ int main (int argc, char *argv[]) { GraphTools T(G); #endif T.reorder(verbose); - + SymbolicTools S(A, T); S.symbolicFactorize(verbose); t = timer.seconds(); std::cout << "CholSupernodes:: analyze matrix::time = " << t << std::endl; - typedef typename exec_space::memory_space device_memory_space; - - auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); - auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); - auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); - - auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); - auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); - auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); - auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); - auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); - auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); - auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); - auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); - auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); - auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); + typedef typename exec_space::memory_space device_memory_space; + + auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); + auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); + auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); + + auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); + auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); + auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); + auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); + auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); + auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); + auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); + auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); + auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); + auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); auto s_snodes_tree_children = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeChildren()); - - Kokkos::deep_copy(a_row_ptr , A.RowPtr()); - Kokkos::deep_copy(a_cols , A.Cols()); - Kokkos::deep_copy(a_values , A.Values()); - - Kokkos::deep_copy(t_perm , T.PermVector()); - Kokkos::deep_copy(t_peri , T.InvPermVector()); - Kokkos::deep_copy(s_supernodes , S.Supernodes()); - Kokkos::deep_copy(s_gid_spanel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(s_gid_spanel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(s_sid_spanel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(s_sid_spanel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(s_blk_spanel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(s_snodes_tree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(s_snodes_tree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(s_snodes_tree_children , S.SupernodesTreeChildren()); - - NumericTools - N(A.NumRows(), a_row_ptr, a_cols, - t_perm, t_peri, - S.NumSupernodes(), s_supernodes, - s_gid_spanel_ptr, s_gid_spanel_colidx, - s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, - s_snodes_tree_parent, s_snodes_tree_ptr, s_snodes_tree_children, - S.SupernodesTreeLevel(), - S.SupernodesTreeRoots()); + + Kokkos::deep_copy(a_row_ptr, A.RowPtr()); + Kokkos::deep_copy(a_cols, A.Cols()); + Kokkos::deep_copy(a_values, A.Values()); + + Kokkos::deep_copy(t_perm, T.PermVector()); + Kokkos::deep_copy(t_peri, T.InvPermVector()); + Kokkos::deep_copy(s_supernodes, S.Supernodes()); + Kokkos::deep_copy(s_gid_spanel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(s_gid_spanel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(s_sid_spanel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(s_sid_spanel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(s_blk_spanel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(s_snodes_tree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(s_snodes_tree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(s_snodes_tree_children, S.SupernodesTreeChildren()); + + NumericTools N( + A.NumRows(), a_row_ptr, a_cols, t_perm, t_peri, S.NumSupernodes(), s_supernodes, s_gid_spanel_ptr, + s_gid_spanel_colidx, s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, s_snodes_tree_parent, + s_snodes_tree_ptr, s_snodes_tree_children, S.SupernodesTreeLevel(), S.SupernodesTreeRoots()); N.setSerialThresholdSize(serial_thres_size); std::cout << "CholSupernodes:: factorize matrix" << std::endl; - timer.reset(); + timer.reset(); if (serial) { -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) N.factorizeCholesky_Serial(a_values, verbose); #endif } else { @@ -170,33 +165,30 @@ int main (int argc, char *argv[]) { N.factorizeCholesky_ParallelPanel(a_values, nb, verbose); } } - t = timer.seconds(); + t = timer.seconds(); std::cout << "CholSupernodes:: factorize matrix::time = " << t << std::endl; - - DenseMatrixBaseType - B("B", A.NumRows(), nrhs), - X("X", A.NumRows(), nrhs), - Y("Y", A.NumRows(), nrhs); + + DenseMatrixBaseType B("B", A.NumRows(), nrhs), X("X", A.NumRows(), nrhs), Y("Y", A.NumRows(), nrhs); { - Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(B, random, value_type(1)); } std::cout << "CholSupernodes:: solve matrix" << std::endl; - timer.reset(); + timer.reset(); if (serial) { -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) N.solveCholesky_Serial(X, B, Y, verbose); #endif } else { N.solveCholesky_Parallel(X, B, Y, verbose); } - t = timer.seconds(); + t = timer.seconds(); std::cout << "CholSupernodes:: solve matrix::time = " << t << std::endl; const double res = N.computeRelativeResidual(X, B); - //const double eps = std::numeric_limits::epsilon()*100; + // const double eps = std::numeric_limits::epsilon()*100; std::cout << "CholSupernodes:: residual = " << res << std::endl; } Kokkos::finalize(); diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholTriSolve.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholTriSolve.cpp similarity index 60% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholTriSolve.cpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholTriSolve.cpp index 304b25f36f40..b3b77e2f96f6 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholTriSolve.cpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCholTriSolve.cpp @@ -2,18 +2,18 @@ #include #include -#include "Tacho_Internal.hpp" #include "Tacho_CommandLineParser.hpp" +#include "Tacho_Internal.hpp" #define TACHO_ENABLE_PROFILE #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) -#include "cuda_profiler_api.h" +#include "cuda_profiler_api.h" #endif -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +static const char *scheduler_name = "TaskSchedulerMultiple"; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { Tacho::CommandLineParser opts("This example program measure the performance of Tacho on Kokkos::OpenMP"); bool serial = false; @@ -46,7 +46,7 @@ int main (int argc, char *argv[]) { opts.set_option("device-level-cut", "tree cut level to force device function", &device_level_cut); opts.set_option("device-function-thres", "device function is used above this threshold", &device_function_thres); -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) // override serial flag if (serial) { std::cout << "CUDA is enabled and serial code cannot be instanciated\n"; @@ -55,7 +55,8 @@ int main (int argc, char *argv[]) { #endif const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); @@ -68,17 +69,17 @@ int main (int argc, char *argv[]) { typedef TaskSchedulerType scheduler_type; - Tacho::printExecSpaceConfiguration ("DeviceSpace", false); - Tacho::printExecSpaceConfiguration ("HostSpace", false); + Tacho::printExecSpaceConfiguration("DeviceSpace", false); + Tacho::printExecSpaceConfiguration("HostSpace", false); printf("Scheduler Type = %s\n", scheduler_name); - + int r_val = 0; - + { typedef double value_type; - typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; - typedef Kokkos::View DenseMatrixBaseType; - + typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; + typedef Kokkos::View DenseMatrixBaseType; + Kokkos::Timer timer; double t = 0.0; @@ -102,7 +103,7 @@ int main (int argc, char *argv[]) { std::cout << "CholTriSolve:: analyze matrix" << std::endl; timer.reset(); -#if defined(TACHO_HAVE_METIS) +#if defined(TACHO_HAVE_METIS) Tacho::GraphTools_Metis T(G); #elif defined(TACHO_HAVE_SCOTCH) Tacho::GraphTools_Scotch T(G); @@ -110,63 +111,58 @@ int main (int argc, char *argv[]) { Tacho::GraphTools T(G); #endif T.reorder(verbose); - + Tacho::SymbolicTools S(A, T); S.symbolicFactorize(verbose); t = timer.seconds(); std::cout << "CholTriSolve:: analyze matrix::time = " << t << std::endl; - typedef typename device_type::memory_space device_memory_space; - - auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); - auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); - auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); - - auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); - auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); - auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); - auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); - auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); - auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); - auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); - auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); - auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); - auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); + typedef typename device_type::memory_space device_memory_space; + + auto a_row_ptr = Kokkos::create_mirror_view(device_memory_space(), A.RowPtr()); + auto a_cols = Kokkos::create_mirror_view(device_memory_space(), A.Cols()); + auto a_values = Kokkos::create_mirror_view(device_memory_space(), A.Values()); + + auto t_perm = Kokkos::create_mirror_view(device_memory_space(), T.PermVector()); + auto t_peri = Kokkos::create_mirror_view(device_memory_space(), T.InvPermVector()); + auto s_supernodes = Kokkos::create_mirror_view(device_memory_space(), S.Supernodes()); + auto s_gid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelPtr()); + auto s_gid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.gidSuperPanelColIdx()); + auto s_sid_spanel_ptr = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelPtr()); + auto s_sid_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.sidSuperPanelColIdx()); + auto s_blk_spanel_colidx = Kokkos::create_mirror_view(device_memory_space(), S.blkSuperPanelColIdx()); + auto s_snodes_tree_parent = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeParent()); + auto s_snodes_tree_ptr = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreePtr()); auto s_snodes_tree_children = Kokkos::create_mirror_view(device_memory_space(), S.SupernodesTreeChildren()); - - Kokkos::deep_copy(a_row_ptr , A.RowPtr()); - Kokkos::deep_copy(a_cols , A.Cols()); - Kokkos::deep_copy(a_values , A.Values()); - - Kokkos::deep_copy(t_perm , T.PermVector()); - Kokkos::deep_copy(t_peri , T.InvPermVector()); - Kokkos::deep_copy(s_supernodes , S.Supernodes()); - Kokkos::deep_copy(s_gid_spanel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(s_gid_spanel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(s_sid_spanel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(s_sid_spanel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(s_blk_spanel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(s_snodes_tree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(s_snodes_tree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(s_snodes_tree_children , S.SupernodesTreeChildren()); - - Tacho::NumericTools - N(A.NumRows(), a_row_ptr, a_cols, - t_perm, t_peri, - S.NumSupernodes(), s_supernodes, - s_gid_spanel_ptr, s_gid_spanel_colidx, - s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, - s_snodes_tree_parent, s_snodes_tree_ptr, s_snodes_tree_children, - S.SupernodesTreeLevel(), - S.SupernodesTreeRoots()); + + Kokkos::deep_copy(a_row_ptr, A.RowPtr()); + Kokkos::deep_copy(a_cols, A.Cols()); + Kokkos::deep_copy(a_values, A.Values()); + + Kokkos::deep_copy(t_perm, T.PermVector()); + Kokkos::deep_copy(t_peri, T.InvPermVector()); + Kokkos::deep_copy(s_supernodes, S.Supernodes()); + Kokkos::deep_copy(s_gid_spanel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(s_gid_spanel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(s_sid_spanel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(s_sid_spanel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(s_blk_spanel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(s_snodes_tree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(s_snodes_tree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(s_snodes_tree_children, S.SupernodesTreeChildren()); + + Tacho::NumericTools N( + A.NumRows(), a_row_ptr, a_cols, t_perm, t_peri, S.NumSupernodes(), s_supernodes, s_gid_spanel_ptr, + s_gid_spanel_colidx, s_sid_spanel_ptr, s_sid_spanel_colidx, s_blk_spanel_colidx, s_snodes_tree_parent, + s_snodes_tree_ptr, s_snodes_tree_children, S.SupernodesTreeLevel(), S.SupernodesTreeRoots()); N.setSerialThresholdSize(serial_thres_size); N.setMaxNumberOfSuperblocks(max_num_superblocks); std::cout << "CholTriSolve:: factorize matrix" << std::endl; - timer.reset(); + timer.reset(); if (serial) { -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) N.factorizeCholesky_Serial(a_values, verbose); #endif } else { @@ -182,16 +178,13 @@ int main (int argc, char *argv[]) { N.factorizeCholesky_ParallelPanel(a_values, nb, verbose); } } - t = timer.seconds(); + t = timer.seconds(); std::cout << "CholTriSolve:: factorize matrix::time = " << t << std::endl; - - DenseMatrixBaseType - B("B", A.NumRows(), nrhs), - X("X", A.NumRows(), nrhs), - W("W", A.NumRows(), nrhs); + + DenseMatrixBaseType B("B", A.NumRows(), nrhs), X("X", A.NumRows(), nrhs), W("W", A.NumRows(), nrhs); { - Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(B, random, value_type(1)); } @@ -201,29 +194,28 @@ int main (int argc, char *argv[]) { /// std::cout << "CholTriSolve:: solve matrix via TriSolveTools" << std::endl; - timer.reset(); + timer.reset(); #if defined(TACHO_USE_TRISOLVE_VARIANT) constexpr int variant = TACHO_USE_TRISOLVE_VARIANT; #else constexpr int variant = 0; #endif - Tacho::TriSolveTools - TS(N, nrhs); + Tacho::TriSolveTools TS(N, nrhs); TS.initialize(device_level_cut, device_function_thres, verbose); TS.createStream(nstreams_solve); - TS.prepareSolve(nstreams_prepare, verbose); - t = timer.seconds(); + TS.prepareSolve(nstreams_prepare, verbose); + t = timer.seconds(); std::cout << "CholTriSolve:: TriSolve prepare::time = " << t << std::endl; - timer.reset(); + timer.reset(); #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) cudaProfilerStart(); #endif - TS.solveCholesky(X, B, W, verbose); + TS.solveCholesky(X, B, W, verbose); #if defined(KOKKOS_ENABLE_CUDA) && defined(TACHO_ENABLE_PROFILE) cudaProfilerStop(); #endif - t = timer.seconds(); + t = timer.seconds(); std::cout << "CholTriSolve:: solve matrix::time = " << t << std::endl; TS.release(verbose); #else @@ -231,15 +223,15 @@ int main (int argc, char *argv[]) { /// solve via tasking /// std::cout << "CholTriSolve:: solve matrix" << std::endl; - timer.reset(); + timer.reset(); if (serial) { -#if !defined (KOKKOS_ENABLE_CUDA) +#if !defined(KOKKOS_ENABLE_CUDA) N.solveCholesky_Serial(X, B, W0, verbose); #endif } else { N.solveCholesky_Parallel(X, B, W0, verbose); } - t = timer.seconds(); + t = timer.seconds(); std::cout << "CholTriSolve:: solve matrix::time = " << t << std::endl; #endif diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCuSparseTriSolve.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCuSparseTriSolve.cpp similarity index 79% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExampleCuSparseTriSolve.cpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCuSparseTriSolve.cpp index 801b5f6f093c..961b138dcf47 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCuSparseTriSolve.cpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleCuSparseTriSolve.cpp @@ -1,17 +1,17 @@ #include #include +#include "Tacho_CommandLineParser.hpp" #include "Tacho_Internal.hpp" #include "Tacho_Solver.hpp" -#include "Tacho_CommandLineParser.hpp" -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) #include "Tacho_CuSparseTriSolve.hpp" #endif using namespace Tacho; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { CommandLineParser opts("This example program measure the performance of cuSparse TriSolve on Kokkos::Cuda"); @@ -41,7 +41,8 @@ int main (int argc, char *argv[]) { opts.set_option("nb", "Internal panel size", &nb); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); @@ -51,11 +52,9 @@ int main (int argc, char *argv[]) { typedef UseThisDevice::type device_type; typedef UseThisDevice::type host_device_type; - - Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); - typedef Kokkos::TaskSchedulerMultiple scheduler_type; + Tacho::printExecSpaceConfiguration("DeviceSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); Kokkos::Timer timer; int r_val = 0; @@ -64,9 +63,9 @@ int main (int argc, char *argv[]) { /// /// read from crs matrix /// - typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; - typedef Tacho::CrsMatrixBase CrsMatrixBaseType; - typedef Kokkos::View DenseMultiVectorType; + typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; + typedef Tacho::CrsMatrixBase CrsMatrixBaseType; + typedef Kokkos::View DenseMultiVectorType; /// read a spd matrix of matrix market format CrsMatrixBaseTypeHost h_A; @@ -79,18 +78,18 @@ int main (int argc, char *argv[]) { } Tacho::MatrixMarket::read(file, h_A, sanitize, verbose); } - + /// /// A on device /// CrsMatrixBaseType A; A.createConfTo(h_A); A.copy(h_A); - + /// /// Tacho Solver (factorization) /// - Tacho::Solver solver; + Tacho::Solver solver; { solver.setMatrixType(sym, posdef); solver.setVerbose(verbose); @@ -101,10 +100,8 @@ int main (int argc, char *argv[]) { } /// inputs are used for graph reordering and analysis - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols()); - + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols()); + /// symbolic structure can be reused solver.factorize(A.Values()); @@ -118,11 +115,11 @@ int main (int argc, char *argv[]) { printf("ExampleCuSparseTriSolve: Construction of CrsMatrix of factors\n"); printf("=============================================================\n"); printf(" Time\n"); - printf(" time for construction of F in CRS: %10.6f s\n", t_factor_export); + printf(" time for construction of F in CRS: %10.6f s\n", t_factor_export); printf("\n"); F.showMe(std::cout, false); - } + } } /// @@ -130,27 +127,21 @@ int main (int argc, char *argv[]) { /// CuSparseTriSolve trisolve; trisolve.setVerbose(verbose); - + /// /// CuSparse analyze /// - { - trisolve.analyze(F.NumRows(), - F.RowPtr(), - F.Cols(), - F.Values()); - } + { trisolve.analyze(F.NumRows(), F.RowPtr(), F.Cols(), F.Values()); } /// /// random right hand side /// - DenseMultiVectorType - b("b", F.NumRows(), nrhs), // rhs multivector - x("x", F.NumRows(), nrhs), // solution multivector - t("t", F.NumRows(), nrhs), // temporary workvector - bb("bb", F.NumRows(), nrhs), // temp workspace (store permuted rhs) - xx("xx", F.NumRows(), nrhs); // temp workspace (store permuted rhs) - + DenseMultiVectorType b("b", F.NumRows(), nrhs), // rhs multivector + x("x", F.NumRows(), nrhs), // solution multivector + t("t", F.NumRows(), nrhs), // temporary workvector + bb("bb", F.NumRows(), nrhs), // temp workspace (store permuted rhs) + xx("xx", F.NumRows(), nrhs); // temp workspace (store permuted rhs) + { Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(b, random, value_type(1)); @@ -196,12 +187,12 @@ int main (int argc, char *argv[]) { printf(" time for permute and solve: %10.6f s\n", t_solve); printf("\n"); } -#else +#else timer.reset(); - for (int iter=0;iter #include +#include "Tacho_CommandLineParser.hpp" #include "Tacho_Internal.hpp" -#include "Tacho_CommandLineParser.hpp" #ifdef TACHO_HAVE_MKL #include "mkl_service.h" @@ -11,18 +11,18 @@ using namespace Tacho; -#define PRINT_TIMER \ - printf(" Time \n"); \ - printf(" byblocks/reference (speedup): %10.6f\n", t_reference/t_byblocks); \ - printf(" Task Scheduler (%s) \n", scheduler_name); \ - printf(" allocation count %10d\n", sched.queue().allocation_count());\ - printf("\n"); +#define PRINT_TIMER \ + printf(" Time \n"); \ + printf(" byblocks/reference (speedup): %10.6f\n", t_reference / t_byblocks); \ + printf(" Task Scheduler (%s) \n", scheduler_name); \ + printf(" allocation count %10d\n", sched.queue().allocation_count()); \ + printf("\n"); -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +static const char *scheduler_name = "TaskSchedulerMultiple"; -int main (int argc, char *argv[]) { - CommandLineParser opts("This example program measure the performance of dense-by-blocks on Kokkos::OpenMP"); +int main(int argc, char *argv[]) { + CommandLineParser opts("This example program measure the performance of dense-by-blocks on Kokkos::OpenMP"); bool serial = false; int nthreads = 1; @@ -35,103 +35,100 @@ int main (int argc, char *argv[]) { opts.set_option("serial", "Flag for invoking serial algorithm", &serial); opts.set_option("kokkos-threads", "Number of threads", &nthreads); opts.set_option("verbose", "Flag for verbose printing", &verbose); - opts.set_option("begin", "Test problem begin size", &mbeg); - opts.set_option("end", "Test problem end size", &mend); - opts.set_option("step", "Test problem step size", &step); - opts.set_option("mb", "Blocksize", &mb); + opts.set_option("begin", "Test problem begin size", &mbeg); + opts.set_option("end", "Test problem end size", &mend); + opts.set_option("step", "Test problem step size", &step); + opts.set_option("mb", "Blocksize", &mb); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); typedef double value_type; - typedef Kokkos::pair range_type; + typedef Kokkos::pair range_type; typedef Kokkos::DefaultExecutionSpace exec_space; typedef Kokkos::DefaultHostExecutionSpace host_exec_space; - typedef TaskSchedulerType< exec_space> scheduler_type; - typedef TaskSchedulerType host_scheduler_type; + typedef TaskSchedulerType scheduler_type; + typedef TaskSchedulerType host_scheduler_type; printExecSpaceConfiguration("Default HostSpace"); - printExecSpaceConfiguration< exec_space>("Default DeviceSpace"); + printExecSpaceConfiguration("Default DeviceSpace"); printf("Scheduler Type = %s\n", scheduler_name); int r_val = 0; - const double eps = std::numeric_limits::epsilon()*10000; + const double eps = std::numeric_limits::epsilon() * 10000; { - typedef DenseMatrixView DenseMatrixViewType; - typedef DenseMatrixView DenseMatrixOfBlocksType; + typedef DenseMatrixView DenseMatrixViewType; + typedef DenseMatrixView DenseMatrixOfBlocksType; - typedef DenseMatrixView DenseMatrixViewHostType; - typedef DenseMatrixView DenseMatrixOfBlocksHostType; + typedef DenseMatrixView DenseMatrixViewHostType; + typedef DenseMatrixView DenseMatrixOfBlocksHostType; Kokkos::Timer timer; scheduler_type sched; - typedef TaskFunctor_Chol task_functor_chol; - typedef TaskFunctor_Trsm task_functor_trsm; - typedef TaskFunctor_Gemm task_functor_gemm; - typedef TaskFunctor_Herk task_functor_herk; - - const ordinal_type max_functor_size = 4*sizeof(task_functor_gemm); - - Kokkos::DualView - a("a", mend*mend), a1("a1", mend*mend), a2("a2", mend*mend), - b("b", mend*mend); - - const ordinal_type bmend = (mend/mb) + 1; - Kokkos::DualView - ha("ha", bmend*bmend), hb("hb", bmend*bmend), hc("hc", bmend*bmend); - - { - const ordinal_type - task_queue_capacity_tmp = 2*bmend*bmend*bmend*max_functor_size, - min_block_size = 16, - max_block_size = 4*max_functor_size, - num_superblock = 4, - superblock_size = std::max(task_queue_capacity_tmp/num_superblock,max_block_size), - task_queue_capacity = std::max(task_queue_capacity_tmp,superblock_size*num_superblock); - + typedef TaskFunctor_Chol task_functor_chol; + typedef TaskFunctor_Trsm + task_functor_trsm; + typedef TaskFunctor_Gemm + task_functor_gemm; + typedef TaskFunctor_Herk + task_functor_herk; + + const ordinal_type max_functor_size = 4 * sizeof(task_functor_gemm); + + Kokkos::DualView a("a", mend * mend), a1("a1", mend * mend), a2("a2", mend * mend), + b("b", mend * mend); + + const ordinal_type bmend = (mend / mb) + 1; + Kokkos::DualView ha("ha", bmend * bmend), hb("hb", bmend * bmend), + hc("hc", bmend * bmend); + + { + const ordinal_type task_queue_capacity_tmp = 2 * bmend * bmend * bmend * max_functor_size, min_block_size = 16, + max_block_size = 4 * max_functor_size, num_superblock = 4, + superblock_size = std::max(task_queue_capacity_tmp / num_superblock, max_block_size), + task_queue_capacity = std::max(task_queue_capacity_tmp, superblock_size * num_superblock); + std::cout << "capacity = " << task_queue_capacity << "\n"; std::cout << "min_block_size = " << min_block_size << "\n"; std::cout << "max_block_size = " << max_block_size << "\n"; std::cout << "superblock_size = " << superblock_size << "\n"; - sched = scheduler_type(typename scheduler_type::memory_space(), - (size_t)task_queue_capacity, - (unsigned)min_block_size, - (unsigned)max_block_size, - (unsigned)superblock_size); + sched = scheduler_type(typename scheduler_type::memory_space(), (size_t)task_queue_capacity, + (unsigned)min_block_size, (unsigned)max_block_size, (unsigned)superblock_size); } const ordinal_type dry = -2, niter = 3; - //const ordinal_type dry = 0, niter = 1; + // const ordinal_type dry = 0, niter = 1; double t_reference = 0, t_byblocks = 0; Random random; auto randomize = [&](const DenseMatrixViewHostType &mat) { const ordinal_type m = mat.extent(0), n = mat.extent(1); - for (ordinal_type j=0;j::invoke(A); - t_reference += (iter >= 0)*timer.seconds(); + Chol::invoke(A); + t_reference += (iter >= 0) * timer.seconds(); } t_reference /= niter; } - + // dense by blocks { - sub_a. sync_device(); + sub_a.sync_device(); sub_a2.modify_device(); DenseMatrixViewType A; A.set_view(m, m); A.attach_buffer(1, m, sub_a2.d_view.data()); - const ordinal_type bm = (m/mb) + (m%mb>0); + const ordinal_type bm = (m / mb) + (m % mb > 0); DenseMatrixOfBlocksHostType HA; HA.set_view(bm, bm); HA.attach_buffer(1, bm, ha.h_view.data()); - + DenseMatrixOfBlocksType DA; DA.set_view(bm, bm); DA.attach_buffer(1, bm, ha.d_view.data()); - - { + + { ha.modify_host(); - + setMatrixOfBlocks(HA, m, m, mb); attachBaseBuffer(HA, A.data(), A.stride_0(), A.stride_1()); ha.sync_device(); - - for (ordinal_type iter=dry;iter=0)*timer.seconds(); + t_byblocks += (iter >= 0) * timer.seconds(); } t_byblocks /= niter; clearFutureOfBlocks(HA); @@ -214,24 +210,23 @@ int main (int argc, char *argv[]) { a2.sync_host(); double diff = 0.0, norm = 0.0; - for (ordinal_type p=0;p<(m*m);++p) { - norm += a1.h_view(p)*a1.h_view(p); - diff += (a1.h_view(p) - a2.h_view(p))*(a1.h_view(p) - a2.h_view(p)); + for (ordinal_type p = 0; p < (m * m); ++p) { + norm += a1.h_view(p) * a1.h_view(p); + diff += (a1.h_view(p) - a2.h_view(p)) * (a1.h_view(p) - a2.h_view(p)); } - const double relerr = sqrt(diff/norm); + const double relerr = sqrt(diff / norm); if (relerr > eps) { - printf("******* chol problem %d fails, reltaive error against reference is %10.4f\n", - m, relerr); + printf("******* chol problem %d fails, reltaive error against reference is %10.4f\n", m, relerr); r_val = -1; break; } } - + { - const double kilo = 1024, gflop = DenseFlopCount::Chol(m)/kilo/kilo/kilo; - printf("chol problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", - m, gflop, gflop/t_reference, gflop/t_byblocks); + const double kilo = 1024, gflop = DenseFlopCount::Chol(m) / kilo / kilo / kilo; + printf("chol problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", m, gflop, + gflop / t_reference, gflop / t_byblocks); PRINT_TIMER; } } @@ -240,16 +235,17 @@ int main (int argc, char *argv[]) { /// /// Trsm /// - + #if 1 - for (ordinal_type m=mbeg;m<=mend;m+=step) { - t_reference = 0; t_byblocks = 0; - auto sub_a = Kokkos::subview(a, range_type(0,m*m)); - auto sub_a1 = Kokkos::subview(a1, range_type(0,m*m)); - auto sub_a2 = Kokkos::subview(a2, range_type(0,m*m)); + for (ordinal_type m = mbeg; m <= mend; m += step) { + t_reference = 0; + t_byblocks = 0; + auto sub_a = Kokkos::subview(a, range_type(0, m * m)); + auto sub_a1 = Kokkos::subview(a1, range_type(0, m * m)); + auto sub_a2 = Kokkos::subview(a2, range_type(0, m * m)); { - sub_a. modify_host(); + sub_a.modify_host(); sub_a1.modify_host(); DenseMatrixViewHostType A, B; @@ -266,9 +262,9 @@ int main (int argc, char *argv[]) { Kokkos::deep_copy(sub_a2.d_view, sub_a1.h_view); } - // reference + // reference { - sub_a. sync_host(); + sub_a.sync_host(); sub_a1.sync_host(); sub_a1.modify_host(); @@ -278,20 +274,19 @@ int main (int argc, char *argv[]) { B.set_view(m, m); B.attach_buffer(1, m, sub_a1.h_view.data()); - + const double alpha = -1.0; - for (ordinal_type iter=dry;iter - ::invoke(Diag::NonUnit(), alpha, A, B); - t_reference += (iter >= 0)*timer.seconds(); + Trsm::invoke(Diag::NonUnit(), alpha, A, B); + t_reference += (iter >= 0) * timer.seconds(); } t_reference /= niter; } - + // dense by blocks { - sub_a. sync_device(); + sub_a.sync_device(); sub_a2.sync_device(); sub_a2.modify_device(); @@ -302,7 +297,7 @@ int main (int argc, char *argv[]) { B.set_view(m, m); B.attach_buffer(1, m, sub_a2.d_view.data()); - const ordinal_type bm = (m/mb) + (m%mb>0); + const ordinal_type bm = (m / mb) + (m % mb > 0); ha.modify_host(); hb.modify_host(); @@ -317,55 +312,53 @@ int main (int argc, char *argv[]) { setMatrixOfBlocks(HA, m, m, mb); attachBaseBuffer(HA, A.data(), A.stride_0(), A.stride_1()); - + setMatrixOfBlocks(HB, m, m, mb); attachBaseBuffer(HB, B.data(), B.stride_0(), B.stride_1()); - + ha.sync_device(); hb.sync_device(); DenseMatrixOfBlocksType DA, DB; - + DA.set_view(bm, bm); DA.attach_buffer(1, bm, ha.d_view.data()); - + DB.set_view(bm, bm); DB.attach_buffer(1, bm, hb.d_view.data()); { const double alpha = -1.0; - for (ordinal_type iter=dry;iter=0)*timer.seconds(); + t_byblocks += (iter >= 0) * timer.seconds(); } t_byblocks /= niter; clearFutureOfBlocks(HB); } } - + { a1.sync_host(); a2.sync_host(); double diff = 0.0, norm = 0.0; - for (ordinal_type p=0;p<(m*m);++p) { - norm += a1.h_view(p)*a1.h_view(p); - diff += (a1.h_view(p) - a2.h_view(p))*(a1.h_view(p) - a2.h_view(p)); + for (ordinal_type p = 0; p < (m * m); ++p) { + norm += a1.h_view(p) * a1.h_view(p); + diff += (a1.h_view(p) - a2.h_view(p)) * (a1.h_view(p) - a2.h_view(p)); } - const double relerr = sqrt(diff/norm); + const double relerr = sqrt(diff / norm); - if (relerr > eps) - printf("******* trsm problem %d fails, reltaive error against reference is %10.4f\n", - m, relerr); + if (relerr > eps) + printf("******* trsm problem %d fails, reltaive error against reference is %10.4f\n", m, relerr); } - + { - const double kilo = 1024, gflop = DenseFlopCount::Trsm(true, m, m)/kilo/kilo/kilo; - printf("trsm problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", - m, gflop, gflop/t_reference, gflop/t_byblocks); + const double kilo = 1024, gflop = DenseFlopCount::Trsm(true, m, m) / kilo / kilo / kilo; + printf("trsm problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", m, gflop, + gflop / t_reference, gflop / t_byblocks); PRINT_TIMER; } } @@ -376,17 +369,18 @@ int main (int argc, char *argv[]) { /// Gemm /// #if 1 - for (ordinal_type m=mbeg;m<=mend;m+=step) { - t_reference = 0; t_byblocks = 0; - auto sub_a = Kokkos::subview(a, range_type(0,m*m)); - auto sub_b = Kokkos::subview(b, range_type(0,m*m)); - auto sub_a1 = Kokkos::subview(a1, range_type(0,m*m)); - auto sub_a2 = Kokkos::subview(a2, range_type(0,m*m)); + for (ordinal_type m = mbeg; m <= mend; m += step) { + t_reference = 0; + t_byblocks = 0; + auto sub_a = Kokkos::subview(a, range_type(0, m * m)); + auto sub_b = Kokkos::subview(b, range_type(0, m * m)); + auto sub_a1 = Kokkos::subview(a1, range_type(0, m * m)); + auto sub_a2 = Kokkos::subview(a2, range_type(0, m * m)); { - sub_a. modify_host(); - sub_b. modify_host(); + sub_a.modify_host(); + sub_b.modify_host(); sub_a1.modify_host(); - + DenseMatrixViewHostType A, B, C; A.set_view(m, m); A.attach_buffer(1, m, sub_a.h_view.data()); @@ -405,38 +399,37 @@ int main (int argc, char *argv[]) { Kokkos::deep_copy(sub_a2.d_view, sub_a1.h_view); } - // reference + // reference { - sub_a. sync_host(); - sub_b. sync_host(); + sub_a.sync_host(); + sub_b.sync_host(); sub_a1.sync_host(); sub_a1.modify_host(); DenseMatrixViewType A, B, C; A.set_view(m, m); A.attach_buffer(1, m, sub_a.h_view.data()); - + B.set_view(m, m); B.attach_buffer(1, m, sub_b.h_view.data()); C.set_view(m, m); C.attach_buffer(1, m, sub_a1.h_view.data()); - + const double alpha = -1.0, beta = 1.0; - for (ordinal_type iter=dry;iter - ::invoke(alpha, A, B, beta, C); - t_reference += (iter >= 0)*timer.seconds(); + Gemm::invoke(alpha, A, B, beta, C); + t_reference += (iter >= 0) * timer.seconds(); } t_reference /= niter; } - + // dense by blocks { - sub_a. sync_device(); - sub_b. sync_device(); + sub_a.sync_device(); + sub_b.sync_device(); sub_a2.sync_device(); sub_a2.modify_device(); @@ -450,7 +443,7 @@ int main (int argc, char *argv[]) { C.set_view(m, m); C.attach_buffer(1, m, sub_a2.d_view.data()); - const ordinal_type bm = (m/mb) + (m%mb>0); + const ordinal_type bm = (m / mb) + (m % mb > 0); ha.modify_host(); hb.modify_host(); @@ -469,19 +462,19 @@ int main (int argc, char *argv[]) { setMatrixOfBlocks(HA, m, m, mb); attachBaseBuffer(HA, A.data(), A.stride_0(), A.stride_1()); - + setMatrixOfBlocks(HB, m, m, mb); attachBaseBuffer(HB, B.data(), B.stride_0(), B.stride_1()); - + setMatrixOfBlocks(HC, m, m, mb); attachBaseBuffer(HC, C.data(), C.stride_0(), C.stride_1()); ha.sync_device(); hb.sync_device(); hc.sync_device(); - + DenseMatrixOfBlocksType DA, DB, DC; - + DA.set_view(bm, bm); DA.attach_buffer(1, bm, ha.d_view.data()); @@ -493,37 +486,36 @@ int main (int argc, char *argv[]) { { const double alpha = -1.0, beta = 1.0; - for (ordinal_type iter=dry;iter=0)*timer.seconds(); + t_byblocks += (iter >= 0) * timer.seconds(); } t_byblocks /= niter; clearFutureOfBlocks(HC); } } - + { a1.sync_host(); a2.sync_host(); double diff = 0.0, norm = 0.0; - for (ordinal_type p=0;p<(m*m);++p) { - norm += a1.h_view(p)*a1.h_view(p); - diff += (a1.h_view(p) - a2.h_view(p))*(a1.h_view(p) - a2.h_view(p)); + for (ordinal_type p = 0; p < (m * m); ++p) { + norm += a1.h_view(p) * a1.h_view(p); + diff += (a1.h_view(p) - a2.h_view(p)) * (a1.h_view(p) - a2.h_view(p)); } - const double relerr = sqrt(diff/norm); - if (relerr > eps) - printf("******* gemm problem %d fails, reltaive error against reference is %10.8e\n", - m, relerr); + const double relerr = sqrt(diff / norm); + if (relerr > eps) + printf("******* gemm problem %d fails, reltaive error against reference is %10.8e\n", m, relerr); } - + { - const double kilo = 1024, gflop = DenseFlopCount::Gemm(m, m, m)/kilo/kilo/kilo; - printf("gemm problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", - m, gflop, gflop/t_reference, gflop/t_byblocks); + const double kilo = 1024, gflop = DenseFlopCount::Gemm(m, m, m) / kilo / kilo / kilo; + printf("gemm problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", m, gflop, + gflop / t_reference, gflop / t_byblocks); PRINT_TIMER; } } @@ -533,16 +525,17 @@ int main (int argc, char *argv[]) { /// Herk /// #if 1 - for (ordinal_type m=mbeg;m<=mend;m+=step) { - t_reference = 0; t_byblocks = 0; - auto sub_a = Kokkos::subview(a, range_type(0,m*m)); - auto sub_a1 = Kokkos::subview(a1, range_type(0,m*m)); - auto sub_a2 = Kokkos::subview(a2, range_type(0,m*m)); + for (ordinal_type m = mbeg; m <= mend; m += step) { + t_reference = 0; + t_byblocks = 0; + auto sub_a = Kokkos::subview(a, range_type(0, m * m)); + auto sub_a1 = Kokkos::subview(a1, range_type(0, m * m)); + auto sub_a2 = Kokkos::subview(a2, range_type(0, m * m)); { - sub_a. modify_host(); + sub_a.modify_host(); sub_a1.modify_host(); - + DenseMatrixViewHostType A, C; A.set_view(m, m); A.attach_buffer(1, m, sub_a.h_view.data()); @@ -557,33 +550,32 @@ int main (int argc, char *argv[]) { Kokkos::deep_copy(sub_a2.d_view, sub_a1.h_view); } - // reference + // reference { - sub_a. sync_host(); + sub_a.sync_host(); sub_a1.sync_host(); sub_a1.modify_host(); DenseMatrixViewType A, C; A.set_view(m, m); A.attach_buffer(1, m, sub_a.h_view.data()); - + C.set_view(m, m); C.attach_buffer(1, m, sub_a1.h_view.data()); - + const double alpha = -1.0, beta = 1.0; - for (ordinal_type iter=dry;iter - ::invoke(alpha, A, beta, C); - t_reference += (iter >= 0)*timer.seconds(); + Herk::invoke(alpha, A, beta, C); + t_reference += (iter >= 0) * timer.seconds(); } t_reference /= niter; } - + // dense by blocks { - sub_a. sync_device(); + sub_a.sync_device(); sub_a2.sync_device(); sub_a2.modify_device(); @@ -594,7 +586,7 @@ int main (int argc, char *argv[]) { C.set_view(m, m); C.attach_buffer(1, m, sub_a2.d_view.data()); - const ordinal_type bm = (m/mb) + (m%mb>0); + const ordinal_type bm = (m / mb) + (m % mb > 0); ha.modify_host(); hc.modify_host(); @@ -609,15 +601,15 @@ int main (int argc, char *argv[]) { setMatrixOfBlocks(HA, m, m, mb); attachBaseBuffer(HA, A.data(), A.stride_0(), A.stride_1()); - + setMatrixOfBlocks(HC, m, m, mb); attachBaseBuffer(HC, C.data(), C.stride_0(), C.stride_1()); ha.sync_device(); hc.sync_device(); - + DenseMatrixOfBlocksType DA, DC; - + DA.set_view(bm, bm); DA.attach_buffer(1, bm, ha.d_view.data()); @@ -626,44 +618,41 @@ int main (int argc, char *argv[]) { { const double alpha = -1.0, beta = 1.0; - for (ordinal_type iter=dry;iter=0)*timer.seconds(); + t_byblocks += (iter >= 0) * timer.seconds(); } t_byblocks /= niter; clearFutureOfBlocks(HC); } } - + { a1.sync_host(); a2.sync_host(); double diff = 0.0, norm = 0.0; - for (ordinal_type p=0;p<(m*m);++p) { - norm += a1.h_view(p)*a1.h_view(p); - diff += (a1.h_view(p) - a2.h_view(p))*(a1.h_view(p) - a2.h_view(p)); + for (ordinal_type p = 0; p < (m * m); ++p) { + norm += a1.h_view(p) * a1.h_view(p); + diff += (a1.h_view(p) - a2.h_view(p)) * (a1.h_view(p) - a2.h_view(p)); } - const double relerr = sqrt(diff/norm); - if (relerr > eps) - printf("******* herk problem %d fails, reltaive error against reference is %10.8e\n", - m, relerr); + const double relerr = sqrt(diff / norm); + if (relerr > eps) + printf("******* herk problem %d fails, reltaive error against reference is %10.8e\n", m, relerr); } - + { - const double kilo = 1024, gflop = DenseFlopCount::Gemm(m, m, m)/kilo/kilo/kilo; - printf("herk problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", - m, gflop, gflop/t_reference, gflop/t_byblocks); + const double kilo = 1024, gflop = DenseFlopCount::Gemm(m, m, m) / kilo / kilo / kilo; + printf("herk problem %10d, gflop %10.2f, gflop/s :: reference %10.2f, byblocks %10.2f\n", m, gflop, + gflop / t_reference, gflop / t_byblocks); PRINT_TIMER; } } printf("\n\n"); #endif - - } Kokkos::finalize(); diff --git a/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseCholesky.hpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseCholesky.hpp new file mode 100644 index 000000000000..b50a5188662a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseCholesky.hpp @@ -0,0 +1,183 @@ +#ifndef __TACHO_EXAMPLE_DEVICE_DENSE_CHOLESKY_HPP__ +#define __TACHO_EXAMPLE_DEVICE_DENSE_CHOLESKY_HPP__ + +#include "Kokkos_Random.hpp" + +#include "Tacho_Blas_External.hpp" +#include "Tacho_Lapack_External.hpp" +#include "Tacho_Util.hpp" + +#include "Tacho_Chol.hpp" +#include "Tacho_Chol_OnDevice.hpp" + +#include "Tacho_Trsv.hpp" +#include "Tacho_Trsv_OnDevice.hpp" + +#include "Tacho_Gemv.hpp" +#include "Tacho_Gemv_OnDevice.hpp" + +template int driver_chol(const int m, const bool verbose) { + int max_iter = 1; + + Kokkos::Timer timer; + const bool detail = false; + + typedef typename Tacho::UseThisDevice::type device_type; + typedef typename Tacho::UseThisDevice::type host_device_type; + + Tacho::printExecSpaceConfiguration("DeviceSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); + printf("\n\n"); + +#if defined(KOKKOS_ENABLE_CUDA) + printf("CUDA testing\n"); +#else + printf("Host testing\n"); +#endif + int r_val = 0; + { + const value_type one(1), zero(0); + +#if defined(KOKKOS_ENABLE_CUDA) + cublasHandle_t handle_blas; + cusolverDnHandle_t handle_lapack; + { + { + const int status = cublasCreate(&handle_blas); + if (status) + printf("Nonzero error from cublasCreate %d\n", status); + } + { + const int status = cusolverDnCreate(&handle_lapack); + if (status) + printf("Nonzero error from cusolverDnCreate %d\n", status); + } + } +#else + int handle_blas, handle_lapack; // dummy +#endif + + Kokkos::View Arand("Arand", m, m); + { + Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::fill_random(Arand, random, one); + } + + Kokkos::View A("A", m, m); + { + const value_type *Arand_ptr = Arand.data(); + value_type *A_ptr = A.data(); + + timer.reset(); +#if defined(KOKKOS_ENABLE_CUDA) + Tacho::Blas::gemm(handle_blas, CUBLAS_OP_N, CUBLAS_OP_T, m, m, m, one, Arand_ptr, m, Arand_ptr, m, + zero, A_ptr, m); +#else + Tacho::Blas::gemm('N', 'T', m, m, m, one, Arand_ptr, m, Arand_ptr, m, zero, A_ptr, m); +#endif + Kokkos::fence(); + const double t = timer.seconds(); + printf("hermitianize time %e\n", t); + } + + Kokkos::View Aback("Aback", m, m); + { + Kokkos::deep_copy(Aback, A); + auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Aback); + if (m < 20) { + printf("AA = \n"); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < m; ++j) + std::cout << AA(i, j) << " "; + printf("\n"); + } + } + } + + Kokkos::View x("x", m, 1); + { + Kokkos::parallel_for( + Kokkos::RangePolicy(0, m), + KOKKOS_LAMBDA(const int i) { x(i, 0) = i + 1; }); + Kokkos::fence(); + } + + Kokkos::View b("b", m, 1); + { + Tacho::Gemv::invoke(handle_blas, one, A, x, zero, b); + Kokkos::fence(); + } + + /// factorizeCholesky + for (int iter = 0; iter < max_iter; ++iter) { + Kokkos::deep_copy(A, Aback); + Kokkos::fence(); + + Kokkos::View dev("dev"); + + timer.reset(); + int lwork(0); +#if defined(KOKKOS_ENABLE_CUDA) + Tacho::Lapack::potrf_buffersize(handle_lapack, CUBLAS_FILL_MODE_UPPER, m, A.data(), m, &lwork); + printf("Cholesky lwork %d\n", lwork); +#endif + Kokkos::View W("W", lwork); + Tacho::Chol::invoke(handle_lapack, A, W); + Kokkos::fence(); + { + const double t = timer.seconds(); + printf("Cholesky factorization time %e\n", t); + } + { + const auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev); + if (dev_host()) + printf("Cholesky returns non-zero dev info %d\n", dev_host()); + } + + Kokkos::deep_copy(x, b); + Kokkos::fence(); + + timer.reset(); + { + Tacho::Trsv::invoke( + handle_blas, Tacho::Diag::NonUnit(), A, x); + Kokkos::fence(); + Tacho::Trsv::invoke( + handle_blas, Tacho::Diag::NonUnit(), A, x); + Kokkos::fence(); + } + { + const double t = timer.seconds(); + printf("Cholesky solve time %e\n", t); + } + + { + const auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); + if (m < 20) { + printf("x = \n"); + for (int i = 0; i < m; ++i) + std::cout << x(i, 0) << std::endl; + } + } + } + +#if defined(KOKKOS_ENABLE_CUDA) + { + { + const int status = cublasDestroy(handle_blas); + if (status) + printf("Nonzero error from cublasDestroy %d\n", status); + } + { + const int status = cusolverDnDestroy(handle_lapack); + if (status) + printf("Nonzero error from cusolverDnDestroy %d\n", status); + } + } +#endif + } + + return r_val; +} + +#endif diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseLDL.hpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseLDL.hpp similarity index 50% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseLDL.hpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseLDL.hpp index 0b384c6c8784..3b37804d9522 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseLDL.hpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseLDL.hpp @@ -3,9 +3,9 @@ #include "Kokkos_Random.hpp" -#include "Tacho_Util.hpp" #include "Tacho_Blas_External.hpp" #include "Tacho_Lapack_External.hpp" +#include "Tacho_Util.hpp" #include "Tacho_LDL.hpp" #include "Tacho_LDL_External.hpp" @@ -24,8 +24,7 @@ #include "Tacho_ApplyPermutation.hpp" #include "Tacho_ApplyPermutation_OnDevice.hpp" -template -int driver_ldl (const int m, const bool verbose) { +template int driver_ldl(const int m, const bool verbose) { int max_iter = 1; Kokkos::Timer timer; @@ -35,107 +34,99 @@ int driver_ldl (const int m, const bool verbose) { typedef typename Tacho::UseThisDevice::type host_device_type; Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); printf("\n\n"); -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) printf("CUDA testing\n"); #else printf("Host testing\n"); #endif - int r_val = 0; + int r_val = 0; { const value_type one(1), zero(0); Kokkos::DefaultExecutionSpace exec_instance; -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) cublasHandle_t handle_blas; cusolverDnHandle_t handle_lapack; { { - const int status = cublasCreate(&handle_blas); - if (status) printf("Nonzero error from cublasCreate %d\n", status); + const int status = cublasCreate(&handle_blas); + if (status) + printf("Nonzero error from cublasCreate %d\n", status); } { - const int status = cusolverDnCreate(&handle_lapack); - if (status) printf("Nonzero error from cusolverDnCreate %d\n", status); + const int status = cusolverDnCreate(&handle_lapack); + if (status) + printf("Nonzero error from cusolverDnCreate %d\n", status); } } #else int handle_blas(0), handle_lapack(0); // dummy #endif - Kokkos::View A("A", m, m); + Kokkos::View A("A", m, m); { Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(A, random, one); Kokkos::fence(); - - Kokkos::parallel_for - (Kokkos::RangePolicy(0,m*m), - KOKKOS_LAMBDA(const int ij) { - const int i = ij%m, j = ij/m; - A(i,j) = (i < j ? A(j,i) : A(i,j)); - }); - } - - Kokkos::View Aback("Aback", m, m); + + Kokkos::parallel_for( + Kokkos::RangePolicy(0, m * m), KOKKOS_LAMBDA(const int ij) { + const int i = ij % m, j = ij / m; + A(i, j) = (i < j ? A(j, i) : A(i, j)); + }); + } + + Kokkos::View Aback("Aback", m, m); { Kokkos::deep_copy(Aback, A); if (m < 20) { auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Aback); printf("AA = \n"); - for (int i=0;i x("x", m, 1); + Kokkos::View x("x", m, 1); { - Kokkos::parallel_for - (Kokkos::RangePolicy(0,m), - KOKKOS_LAMBDA(const int i) { - x(i, 0) = i+1; - }); + Kokkos::parallel_for( + Kokkos::RangePolicy(0, m), + KOKKOS_LAMBDA(const int i) { x(i, 0) = i + 1; }); } - Kokkos::View b("b", m, 1); + Kokkos::View b("b", m, 1); { - Tacho::Gemv - ::invoke(handle_blas, - one, - A, x, - zero, - b); + Tacho::Gemv::invoke(handle_blas, one, A, x, zero, b); Kokkos::fence(); } /// factorize LDLt - for (int iter=0;iter D("D", m, 2); - Kokkos::View p("pivots", 4*m); + Kokkos::View D("D", m, 2); + Kokkos::View p("pivots", 4 * m); + + // value_type * A_ptr = A.data(); + // value_type * x_ptr = x.data(); + // int * p_ptr = p.data(); - //value_type * A_ptr = A.data(); - //value_type * x_ptr = x.data(); - //int * p_ptr = p.data(); - timer.reset(); - Kokkos::View W; - const int lwork = Tacho::LDL - ::invoke(handle_lapack, A, p, W); - W = Kokkos::View("W", lwork); + Kokkos::View W; + const int lwork = Tacho::LDL::invoke(handle_lapack, A, p, W); + W = Kokkos::View("W", lwork); printf("LDLt lwork %d\n", lwork); - Tacho::LDL - ::invoke(handle_lapack, A, p, W); + Tacho::LDL::invoke(handle_lapack, A, p, W); // { // using policy_type = Kokkos::TeamPolicy; // policy_type policy(1, 1, 1); @@ -145,113 +136,113 @@ int driver_ldl (const int m, const bool verbose) { // }); // } Kokkos::fence(); - - Tacho::LDL - ::modify(exec_instance, A, p, D); + + Tacho::LDL::modify(exec_instance, A, p, D); Kokkos::fence(); const double t = timer.seconds(); - + auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A); auto pp = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), p); auto DD = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), D); { printf("L = \n"); - for (int i=0;i(2*m, 3*m)); - auto peri = Kokkos::subview(p, Kokkos::pair(3*m, 4*m)); + auto perm = Kokkos::subview(p, Kokkos::pair(2 * m, 3 * m)); + auto peri = Kokkos::subview(p, Kokkos::pair(3 * m, 4 * m)); /// copy and transpose - Tacho::ApplyPermutation - ::invoke(exec_instance, b, perm, x); + Tacho::ApplyPermutation::invoke( + exec_instance, b, perm, x); { auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); printf("x (permuted b) = \n"); - for (int i=0;i - ::invoke(handle_blas, Tacho::Diag::Unit(), A, x); + Tacho::Trsv::invoke( + handle_blas, Tacho::Diag::Unit(), A, x); Kokkos::fence(); { auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); printf("x (solve first) = \n"); - for (int i=0;i - ::invoke(exec_instance, p, D, x); + Tacho::Scale2x2_BlockInverseDiagonals::invoke(exec_instance, p, D, x); Kokkos::fence(); { auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); printf("x (inverse scale diagoanl) = \n"); - for (int i=0;i - ::invoke(handle_blas, Tacho::Diag::Unit(), A, x); + Tacho::Trsv::invoke( + handle_blas, Tacho::Diag::Unit(), A, x); Kokkos::fence(); { auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); printf("x (solve b) = \n"); - for (int i=0;i z("z",m,1); - Tacho::ApplyPermutation - ::invoke(exec_instance, x, peri, z); + Kokkos::View z("z", m, 1); + Tacho::ApplyPermutation::invoke( + exec_instance, x, peri, z); { auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), z); printf("x (permute) = \n"); - for (int i=0;i int driver(int argc, char *argv[]) { + int nthreads = 1; + bool verbose = true; + + int max_iter = 1; + int m = 10; + + Tacho::CommandLineParser opts("This example program measure the Tacho on Kokkos::OpenMP"); + + opts.set_option("kokkos-threads", "Number of threads", &nthreads); + opts.set_option("verbose", "Flag for verbose printing", &verbose); + opts.set_option("m", "Dense problem size", &m); + + const bool r_parse = opts.parse(argc, argv); + if (r_parse) + return 0; // print help return + + Kokkos::initialize(argc, argv); + Kokkos::Timer timer; + const bool detail = false; + + typedef typename Tacho::UseThisDevice::type device_type; + typedef typename Tacho::UseThisDevice::type host_device_type; + + Tacho::printExecSpaceConfiguration("DeviceSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); + printf("\n\n"); + +#if defined(KOKKOS_ENABLE_CUDA) + printf("CUDA testing\n"); +#else + printf("Host testing\n"); +#endif + int r_val = 0; + { + const value_type one(1), zero(0); + +#if defined(KOKKOS_ENABLE_CUDA) + cublasHandle_t handle_blas; + cusolverDnHandle_t handle_lapack; + { + if (verbose) + printf("cublas/cusolver handle create begin\n"); + { + const int status = cublasCreate(&handle_blas); + if (status) + printf("Nonzero error from cublasCreate %d\n", status); + } + { + const int status = cusolverDnCreate(&handle_lapack); + if (status) + printf("Nonzero error from cusolverDnCreate %d\n", status); + } + if (verbose) + printf("cublas/cusolver handle create end\n"); + } +#endif + + Kokkos::View Arand("Arand", m, m); + { + if (verbose) + printf("test problem randomization\n"); + Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::fill_random(Arand, random, one); + } + + Kokkos::View A("A", m, m); + { + if (verbose) + printf("test problem symmetrization\n"); + const value_type *Arand_ptr = Arand.data(); + value_type *A_ptr = A.data(); + + timer.reset(); +#if defined(KOKKOS_ENABLE_CUDA) + Tacho::Blas::gemm(handle_blas, CUBLAS_OP_N, CUBLAS_OP_T, m, m, m, one, Arand_ptr, m, Arand_ptr, m, + zero, A_ptr, m); +#else + Tacho::Blas::gemm('N', 'T', m, m, m, one, Arand_ptr, m, Arand_ptr, m, zero, A_ptr, m); +#endif + Kokkos::fence(); + const double t = timer.seconds(); + printf("symmetrization time %e\n", t); + } + + Kokkos::View Aback("Aback", m, m); + { + if (verbose) + printf("test problem backup\n"); + Kokkos::deep_copy(Aback, A); + auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Aback); + if (0) { + printf("AA = \n"); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < m; ++j) + std::cout << AA(i, j) << " "; + printf("\n"); + } + } + } + + /// run Cholesky + for (int iter = 0; iter < max_iter; ++iter) { + Kokkos::deep_copy(A, Aback); + { + value_type *A_ptr = A.data(); + int dev(0); + + timer.reset(); +#if defined(KOKKOS_ENABLE_CUDA) + int lwork(0); + Tacho::Lapack::potrf_buffersize(handle_lapack, CUBLAS_FILL_MODE_UPPER, m, A_ptr, m, &lwork); + printf("Cholesky lwork %d\n", lwork); + Kokkos::View W("W", lwork); + value_type *W_ptr = W.data(); + + Kokkos::View dev_view("dev"); + Tacho::Lapack::potrf(handle_lapack, CUBLAS_FILL_MODE_UPPER, m, A_ptr, m, W_ptr, lwork, + dev_view.data()); + auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); + dev = dev_host(); +#else + Tacho::Lapack::potrf('U', m, A_ptr, m, &dev); +#endif + Kokkos::fence(); + const double t = timer.seconds(); + printf("Cholesky time %e\n", t); + if (dev) + printf("Cholesky returns non-zero dev info %d\n", dev); + } + } + + /// run LDLt + for (int iter = 0; iter < max_iter; ++iter) { + Kokkos::deep_copy(A, Aback); + Kokkos::View ipiv("pivot", m); + { + value_type *A_ptr = A.data(); + int *ipiv_ptr = ipiv.data(); + int dev(0); + + timer.reset(); +#if defined(KOKKOS_ENABLE_CUDA) + int lwork(0); + Tacho::Lapack::sytrf_buffersize(handle_lapack, m, A_ptr, m, &lwork); + printf("LDLt lwork %d\n", lwork); + Kokkos::View W("W", lwork); + value_type *W_ptr = W.data(); + + Kokkos::View dev_view("dev"); + Tacho::Lapack::sytrf(handle_lapack, CUBLAS_FILL_MODE_LOWER, m, A_ptr, m, ipiv_ptr, W_ptr, lwork, + dev_view.data()); + auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); + dev = dev_host(); +#else + int lwork(m * 32); + printf("LDLt lwork %d\n", lwork); + Kokkos::View W("W", lwork); + value_type *W_ptr = W.data(); + Tacho::Lapack::sytrf('L', m, A_ptr, m, ipiv_ptr, W_ptr, lwork, &dev); +#endif + Kokkos::fence(); + const double t = timer.seconds(); + auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A); + auto pp = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv); + if (0) { + printf("LDL = \n"); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < m; ++j) + std::cout << AA(i, j) << " "; + printf("\n"); + } + printf("piv = \n"); + for (int i = 0; i < m; ++i) { + printf("%d\n", pp(i)); + } + } + + printf("LDLt time %e\n", t); + if (dev) + printf("LDLt returns non-zero dev info %d\n", dev); + } + } + + /// run LU + for (int iter = 0; iter < 4; ++iter) { + Kokkos::deep_copy(A, Aback); + Kokkos::View ipiv("pivot", m); + { + value_type *A_ptr = A.data(); + int *ipiv_ptr = ipiv.data(); + int dev(0); + + timer.reset(); +#if defined(KOKKOS_ENABLE_CUDA) + int lwork(0); + Tacho::Lapack::getrf_buffersize(handle_lapack, m, m, A_ptr, m, &lwork); + printf("LU lwork %d\n", lwork); + Kokkos::View W("W", lwork); + value_type *W_ptr = W.data(); + + Kokkos::View dev_view("dev"); + Tacho::Lapack::getrf(handle_lapack, m, m, A_ptr, m, W_ptr, ipiv_ptr, dev_view.data()); + auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); + dev = dev_host(); +#else + Tacho::Lapack::getrf(m, m, A_ptr, m, ipiv_ptr, &dev); +#endif + Kokkos::fence(); + const double t = timer.seconds(); + printf("LU time %e\n", t); + if (dev) + printf("LU returns non-zero dev info %d\n", dev); + } + } + +#if defined(KOKKOS_ENABLE_CUDA) + { + if (verbose) + printf("cublas/cusolver handle destroy begin\n"); + { + const int status = cublasDestroy(handle_blas); + if (status) + printf("Nonzero error from cublasDestroy %d\n", status); + } + { + const int status = cusolverDnDestroy(handle_lapack); + if (status) + printf("Nonzero error from cusolverDnDestroy %d\n", status); + } + if (verbose) + printf("cublas/cusolver handle destroy end\n"); + } +#endif + } + Kokkos::finalize(); + + return r_val; +} diff --git a/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp new file mode 100644 index 000000000000..65352cbe7a58 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp @@ -0,0 +1,3 @@ +#include "Tacho_ExampleDeviceDenseSolver.hpp" + +int main(int argc, char *argv[]) { return driver>(argc, argv); } diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_double.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_double.cpp similarity index 71% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_double.cpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_double.cpp index 2415e0a8cf5e..d75da02ec2bd 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_double.cpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleDeviceDenseSolver_double.cpp @@ -1,15 +1,15 @@ -#include "Tacho_CommandLineParser.hpp" +#include "Tacho_CommandLineParser.hpp" #include "Tacho_ExampleDeviceDenseCholesky.hpp" #include "Tacho_ExampleDeviceDenseLDL.hpp" -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { Tacho::CommandLineParser opts("This example program measure the Tacho on Kokkos::OpenMP"); int m = 10; bool verbose = false; bool test_chol = false; bool test_ldl = false; - //bool test_lu = false; + // bool test_lu = false; opts.set_option("m", "Dense problem size", &m); opts.set_option("verbose", "Flag for verbose printing", &verbose); @@ -17,21 +17,23 @@ int main (int argc, char *argv[]) { opts.set_option("test-ldl", "Flag for testing LDL", &test_ldl); // opts.set_option("test-lu", "Flag for testing LU", &test_lu); - const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return int r_val(0); Kokkos::initialize(argc, argv); { if (test_chol) { - const int r_val_chol = driver_chol(m, verbose); r_val += r_val_chol; + const int r_val_chol = driver_chol(m, verbose); + r_val += r_val_chol; } if (test_ldl) { - const int r_val_ldl = driver_ldl (m, verbose); r_val += r_val_ldl; + const int r_val_ldl = driver_ldl(m, verbose); + r_val += r_val_ldl; } } Kokkos::finalize(); - + return r_val; } diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExamplePerfTest.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExamplePerfTest.cpp similarity index 67% rename from packages/shylu/shylu_node/tacho/example/Tacho_ExamplePerfTest.cpp rename to packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExamplePerfTest.cpp index 0de486a71d6b..6449765540da 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExamplePerfTest.cpp +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExamplePerfTest.cpp @@ -8,24 +8,24 @@ #define TACHO_ITT_PAUSE __itt_pause() #define TACHO_ITT_RESUME __itt_resume() #else -#define TACHO_ITT_PAUSE +#define TACHO_ITT_PAUSE #define TACHO_ITT_RESUME #endif -#if defined( __INTEL_MKL__ ) -#include "mkl_service.h" +#if defined(__INTEL_MKL__) #include "Tacho_Pardiso.hpp" +#include "mkl_service.h" #endif -#if defined( TACHO_HAVE_SUITESPARSE ) +#if defined(TACHO_HAVE_SUITESPARSE) #include "cholmod.h" #endif -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +static const char *scheduler_name = "TaskSchedulerMultiple"; -int main (int argc, char *argv[]) { - int nthreads = 1; +int main(int argc, char *argv[]) { + int nthreads = 1; bool verbose = true; bool sanitize = false; @@ -44,7 +44,8 @@ int main (int argc, char *argv[]) { bool use_same_ordering = true; - Tacho::CommandLineParser opts("This is Tacho performance test comparing with Pardiso and Cholmod on OpenMP and Cuda spaces"); + Tacho::CommandLineParser opts( + "This is Tacho performance test comparing with Pardiso and Cholmod on OpenMP and Cuda spaces"); // threading environment opts.set_option("kokkos-threads", "Number of threads", &nthreads); @@ -64,19 +65,20 @@ int main (int argc, char *argv[]) { // testing flags opts.set_option("test-tacho", "Flag for testing Tacho", &test_tacho); -#if defined( __INTEL_MKL__ ) +#if defined(__INTEL_MKL__) opts.set_option("test-pardiso", "Flag for testing Pardiso", &test_pardiso); #endif -#if defined( TACHO_HAVE_SUITESPARSE ) +#if defined(TACHO_HAVE_SUITESPARSE) opts.set_option("test-cholmod", "Flag for testing Cholmod", &test_cholmod); #endif opts.set_option("use-same-ordering", "Same Metis ordering is used for all tests", &use_same_ordering); TACHO_ITT_PAUSE; - + const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; + if (r_parse) + return 0; int r_val = 0; #if !defined(KOKKOS_ENABLE_CUDA) @@ -92,9 +94,9 @@ int main (int argc, char *argv[]) { typedef typename Tacho::UseThisDevice::type host_device_type; /// crs matrix format and dense multi vector - typedef Tacho::CrsMatrixBase CrsMatrixBaseType; - typedef Kokkos::View DenseMultiVectorType; - //typedef Kokkos::View OrdinalTypeArray; + typedef Tacho::CrsMatrixBase CrsMatrixBaseType; + typedef Kokkos::View DenseMultiVectorType; + // typedef Kokkos::View OrdinalTypeArray; /// /// problem setting @@ -112,13 +114,12 @@ int main (int argc, char *argv[]) { Tacho::MatrixMarket::read(file, A, sanitize, verbose); } - DenseMultiVectorType - b("b", A.NumRows(), nrhs), // rhs multivector - x("x", A.NumRows(), nrhs), // solution multivector - t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) + DenseMultiVectorType b("b", A.NumRows(), nrhs), // rhs multivector + x("x", A.NumRows(), nrhs), // solution multivector + t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) Tacho::Graph graph(A.NumRows(), A.NumNonZeros(), A.RowPtr(), A.Cols()); -#if defined(TACHO_HAVE_METIS) +#if defined(TACHO_HAVE_METIS) Tacho::GraphTools_Metis G(graph); #elif defined(TACHO_HAVE_SCOTCH) Tacho::GraphTools_Scotch G(graph); @@ -130,20 +131,20 @@ int main (int argc, char *argv[]) { { Tacho::Random random; const ordinal_type m = A.NumRows(); - for (ordinal_type rhs=0;rhs flush; // ----------------------------------------------------------------- if (test_pardiso) { -#if defined( __INTEL_MKL__ ) +#if defined(__INTEL_MKL__) flush.run(); Kokkos::Timer timer; @@ -158,12 +159,12 @@ int main (int argc, char *argv[]) { Pardiso pardiso; constexpr int AlgoChol = 2; - r_val = pardiso.init(); + r_val = pardiso.init(); if (r_val) { std::cout << "PardisoChol:: Pardiso init error = " << r_val << std::endl; pardiso.showErrorCode(std::cout) << std::endl; } - pardiso.setParameter(2, 2); // metis ordering is used + pardiso.setParameter(2, 2); // metis ordering is used if (use_same_ordering) pardiso.setParameter(5, 1); // user permutation is used for mkl permutation @@ -172,9 +173,9 @@ int main (int argc, char *argv[]) { Asym.createConfTo(A); { size_t nnz = 0; - for (ordinal_type i=0;i rowptr("rowptr", Asym.NumRows()+1); - for (ordinal_type i=0;i<=Asym.NumRows();++i) + Kokkos::View rowptr("rowptr", Asym.NumRows() + 1); + for (ordinal_type i = 0; i <= Asym.NumRows(); ++i) rowptr(i) = Asym.RowPtrBegin(i); - pardiso.setProblem(Asym.NumRows(), - (double*)Asym.Values().data(), - (int*)rowptr.data(),// (int*)Asym.RowPtr().data(), - (int*)Asym.Cols().data(), - (int*)G.PermVector().data(), - nrhs, - (double*)b.data(), - (double*)x.data()); - + pardiso.setProblem(Asym.NumRows(), (double *)Asym.Values().data(), + (int *)rowptr.data(), // (int*)Asym.RowPtr().data(), + (int *)Asym.Cols().data(), (int *)G.PermVector().data(), nrhs, (double *)b.data(), + (double *)x.data()); + r_val = pardiso.run(Pardiso::Analyze, 1); - if (r_val) { + if (r_val) { std::cout << "PardisoChol:: Pardiso analyze error = " << r_val << std::endl; - pardiso.showErrorCode(std::cout) << std::endl; - } else { - pardiso.showStat(std::cout, Pardiso::Analyze) << std::endl; + pardiso.showErrorCode(std::cout) << std::endl; + } else { + pardiso.showStat(std::cout, Pardiso::Analyze) << std::endl; } // compute inverse permutation { const auto peri = G.InvPermVector(); const auto perm = G.PermVector(); - for (ordinal_type i=0;inrow); printf(" max number of nonzeros: %10zu\n", AA->nzmax); - printf("\n"); + printf("\n"); printf(" Factors L\n"); printf(" number of nonzeros: %10.0f\n", c.lnz); - //printf(" current method: %10d\n", c.current); - printf("\n"); + // printf(" current method: %10d\n", c.current); + printf("\n"); } TACHO_ITT_RESUME; timer.reset(); - cholmod_factorize (AA, LL, &c) ; /* factorize */ + cholmod_factorize(AA, LL, &c); /* factorize */ t_factor = timer.seconds(); TACHO_ITT_PAUSE; @@ -348,7 +344,7 @@ int main (int argc, char *argv[]) { printf("==================\n"); printf(" Time\n"); printf(" total time spent: %10.6f s\n", t_factor); - printf("\n"); + printf("\n"); printf(" Property\n"); printf(" is_super: %10d\n", LL->is_super); printf(" ordering: %10d\n", LL->ordering); @@ -358,20 +354,24 @@ int main (int argc, char *argv[]) { printf(" 3 - METIS\n"); printf(" 4 - NESDIS (CHOLMOD nd)\n"); printf(" 5 - AMD for A, COLAMD for A*A'\n"); - printf("\n"); + printf("\n"); printf(" Memory\n"); - printf(" memory used: %10.6f MB\n", double(c.memory_inuse)/1024/1024); - printf(" peak memory used in factorization: %10.6f MB\n", double(c.memory_usage)/1024/1024); - printf("\n"); + printf(" memory used: %10.6f MB\n", + double(c.memory_inuse) / 1024 / 1024); + printf(" peak memory used in factorization: %10.6f MB\n", + double(c.memory_usage) / 1024 / 1024); + printf("\n"); printf(" FLOPs\n"); - printf(" gflop for numeric factorization: %10.6f GFLOP\n", c.fl/1024/1024/1024); - printf(" gflop/s for numeric factorization: %10.6f GFLOP/s\n", c.fl/1024/1024/1024/t_factor); - printf("\n"); + printf(" gflop for numeric factorization: %10.6f GFLOP\n", + c.fl / 1024 / 1024 / 1024); + printf(" gflop/s for numeric factorization: %10.6f GFLOP/s\n", + c.fl / 1024 / 1024 / 1024 / t_factor); + printf("\n"); } timer.reset(); - for (int iter=0;iternrow, n = nrhs, lda = AA->nrow; - double *xxx = (double*)xx->x; - for (int j=0;jx; + for (int j = 0; j < n; ++j) + for (int i = 0; i < m; ++i) { + norm += x(i, j) * x(i, j); + const double tmp = xxx[i + j * lda] - x(i, j); + diff += tmp * tmp; } - printf ("CHOLMOD: diff to tacho %10.6e\n", diff/norm); + printf("CHOLMOD: diff to tacho %10.6e\n", diff / norm); } - rr = cholmod_copy_dense (bb, &c) ; /* r = b */ - cholmod_sdmult (AA, 0, m1, one, xx, rr, &c) ; /* r = r-Ax */ - printf ("CHOLMOD: residual %10.6e\n", cholmod_norm_dense (rr, 0, &c)); + rr = cholmod_copy_dense(bb, &c); /* r = b */ + cholmod_sdmult(AA, 0, m1, one, xx, rr, &c); /* r = r-Ax */ + printf("CHOLMOD: residual %10.6e\n", cholmod_norm_dense(rr, 0, &c)); - cholmod_free_factor (&LL, &c) ; /* free matrices */ - cholmod_free_sparse (&AA, &c) ; - cholmod_free_dense (&rr, &c) ; - cholmod_free_dense (&xx, &c) ; - cholmod_free_dense (&bb, &c) ; - cholmod_finish (&c) ; /* finish CHOLMOD */ + cholmod_free_factor(&LL, &c); /* free matrices */ + cholmod_free_sparse(&AA, &c); + cholmod_free_dense(&rr, &c); + cholmod_free_dense(&xx, &c); + cholmod_free_dense(&bb, &c); + cholmod_finish(&c); /* finish CHOLMOD */ printf("CHOLMOD: Finished\n"); printf("=================\n"); #else @@ -425,12 +425,12 @@ int main (int argc, char *argv[]) { /// /// tacho /// - typedef TaskSchedulerType scheduler_type; + typedef TaskSchedulerType scheduler_type; printf("Scheduler Type = %s\n", scheduler_name); - Tacho::Solver solver; + Tacho::Solver solver; - //solver.setMatrixType(sym, posdef); + // solver.setMatrixType(sym, posdef); solver.setVerbose(verbose); solver.setMaxNumberOfSuperblocks(max_num_superblocks); solver.setSmallProblemThresholdsize(small_problem_thres); @@ -439,11 +439,7 @@ int main (int argc, char *argv[]) { solver.setFrontUpdateMode(front_update_mode); /// inputs are used for graph reordering and analysis - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols(), - G.PermVector(), - G.InvPermVector()); + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols(), G.PermVector(), G.InvPermVector()); /// symbolic structure can be reused TACHO_ITT_RESUME; @@ -452,7 +448,7 @@ int main (int argc, char *argv[]) { solver.setVerbose(0); // disable verbose out for the iteration timer.reset(); - for (int iter=0;iter using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; - -template -int driver (int argc, char *argv[]) { +template int driver(int argc, char *argv[]) { int nthreads = 1; int max_num_superblocks = 4; bool verbose = true; @@ -53,13 +49,15 @@ int driver (int argc, char *argv[]) { opts.set_option("nb", "Internal panel size", &nb); opts.set_option("levelset", "Enable levelset scheduling", &levelset); opts.set_option("device-level-cut", "Device function is used above this level", &device_level_cut); - opts.set_option("device-factor-thres", "Device function is used above this subproblem size", &device_factor_thres); + opts.set_option("device-factor-thres", "Device function is used above this subproblem size", + &device_factor_thres); opts.set_option("device-solve-thres", "Device function is used above this subproblem size", &device_solve_thres); opts.set_option("variant", "algorithm variant in levelset scheduling; 0 or 1", &variant); opts.set_option("nstreams", "# of streams used in CUDA; on host, it is ignored", &nstreams); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); @@ -68,21 +66,18 @@ int driver (int argc, char *argv[]) { typedef typename Tacho::UseThisDevice::type device_type; typedef typename Tacho::UseThisDevice::type host_device_type; - typedef TaskSchedulerType scheduler_type; - Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); - printf("Scheduler Type = %s\n", scheduler_name); + Tacho::printExecSpaceConfiguration("HostSpace", detail); int r_val = 0; - + { /// crs matrix format and dense multi vector - //typedef Tacho::CrsMatrixBase CrsMatrixBaseType; - typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; - - typedef Kokkos::View DenseMultiVectorType; - //typedef Kokkos::View DenseMultiVectorTypeHost; + // typedef Tacho::CrsMatrixBase CrsMatrixBaseType; + typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; + + typedef Kokkos::View DenseMultiVectorType; + // typedef Kokkos::View DenseMultiVectorTypeHost; /// read a spd matrix of matrix market format CrsMatrixBaseTypeHost A; @@ -101,7 +96,7 @@ int driver (int argc, char *argv[]) { /// read graph file if available using size_type_array_host = typename CrsMatrixBaseTypeHost::size_type_array; using ordinal_type_array_host = typename CrsMatrixBaseTypeHost::ordinal_type_array; - + ordinal_type m_graph(0); size_type_array_host ap_graph; ordinal_type_array_host aw_graph, aj_graph; @@ -115,19 +110,19 @@ int driver (int argc, char *argv[]) { return -1; } in >> m_graph; - - ap_graph = size_type_array_host("ap", m_graph+1); - for (ordinal_type i=0,iend=m_graph+1;i> ap_graph(i); - + aj_graph = ordinal_type_array_host("aj", ap_graph(m_graph)); - for (ordinal_type i=0;i> aj_graph(j); } } - + { std::ifstream in; in.open(weight_file); @@ -139,12 +134,12 @@ int driver (int argc, char *argv[]) { in >> m; in >> m_graph; aw_graph = ordinal_type_array_host("aw", m_graph); - for (ordinal_type i=0;i> aw_graph(i); } } } - + /// /// * to wrap triple pointers, declare following view types /// typedef Kokkos::View ordinal_type_array; @@ -156,16 +151,16 @@ int driver (int argc, char *argv[]) { /// typedef typename CrsMatrixBaseType::size_type_array size_type_array; /// typedef typename CrsMatrixBaseType::value_type_array value_type_array; /// - /// * wrap triple pointers (row_ptr, colidx_ptr, value_ptr) with views + /// * wrap triple pointers (row_ptr, colidx_ptr, value_ptr) with views /// size_type_array ap(row_ptr, nrows + 1); /// ordinal_type_array aj(colidx_ptr, nnz); /// value_type_array ax(value_ptr, nnz); - /// - /// * attach views into csr matrix + /// + /// * attach views into csr matrix /// CrsMatrixBaseType A; /// A.setExternalMatrix(nrows, ncols, nnzm ap, aj, ax); - /// - Tacho::Solver solver; + /// + Tacho::Solver solver; /// common options solver.setMatrixType(sym, posdef); @@ -192,37 +187,28 @@ int driver (int argc, char *argv[]) { /// inputs are used for graph reordering and analysis if (m_graph > 0 && m_graph < A.NumRows()) - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols(), - m_graph, - ap_graph, - aj_graph, - aw_graph); - else - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols()); + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols(), m_graph, ap_graph, aj_graph, aw_graph); + else + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols()); /// create numeric tools and levelset tools solver.initialize(); /// symbolic structure can be reused solver.factorize(values_on_device); - - DenseMultiVectorType - b("b", A.NumRows(), nrhs), // rhs multivector - x("x", A.NumRows(), nrhs), // solution multivector - t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) - + + DenseMultiVectorType b("b", A.NumRows(), nrhs), // rhs multivector + x("x", A.NumRows(), nrhs), // solution multivector + t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) + { Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(b, random, value_type(1)); } - for (int i=0;i<3;++i) + for (int i = 0; i < 3; ++i) solver.solve(x, b, t); - + const double res = solver.computeRelativeResidual(values_on_device, x, b); std::cout << "TachoSolver: residual = " << res << "\n\n"; diff --git a/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_dcomplex.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_dcomplex.cpp new file mode 100644 index 000000000000..284c77a88e21 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_dcomplex.cpp @@ -0,0 +1,3 @@ +#include "Tacho_ExampleSolver.hpp" + +int main(int argc, char *argv[]) { return driver>(argc, argv); } diff --git a/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_double.cpp b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_double.cpp new file mode 100644 index 000000000000..3fbf38162276 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/example/.do-not-use/Tacho_ExampleSolver_double.cpp @@ -0,0 +1,3 @@ +#include "Tacho_ExampleSolver.hpp" + +int main(int argc, char *argv[]) { return driver(argc, argv); } diff --git a/packages/shylu/shylu_node/tacho/example/CMakeLists.txt b/packages/shylu/shylu_node/tacho/example/CMakeLists.txt index acecf8e099b4..db060b446210 100644 --- a/packages/shylu/shylu_node/tacho/example/CMakeLists.txt +++ b/packages/shylu/shylu_node/tacho/example/CMakeLists.txt @@ -9,157 +9,74 @@ FILE(GLOB SOURCES *.cpp) SET(LIBRARIES tacho) TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCombineDataFileToMatrixMarketFile + Tacho_ExampleCombineDataFileToMatrixMarketFile.x + NOEXESUFFIX NOEXEPREFIX SOURCES Tacho_ExampleCombineDataFileToMatrixMarketFile.cpp COMM serial mpi ) +# +# Driver +# TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleDeviceDenseSolverDouble + Tacho_ExampleDriverDouble.x + NOEXESUFFIX NOEXEPREFIX - SOURCES Tacho_ExampleDeviceDenseSolver_double.cpp + SOURCES Tacho_ExampleDriver_double.cpp COMM serial mpi ) TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleDeviceDenseSolverDoubleComplex + Tacho_ExampleDriverDoubleComplex.x + NOEXESUFFIX NOEXEPREFIX - SOURCES Tacho_ExampleDeviceDenseSolver_dcomplex.cpp + SOURCES Tacho_ExampleDriver_dcomplex.cpp COMM serial mpi ) -IF (TACHO_HAVE_KOKKOS_TASK) - # - # Supernodes - # - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholSupernodes - NOEXEPREFIX - SOURCES Tacho_ExampleCholSupernodes.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholTriSolveVar0 - TARGET_DEFINES -DTACHO_USE_TRISOLVE_VARIANT=0 - NOEXEPREFIX - SOURCES Tacho_ExampleCholTriSolve.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholTriSolveVar1 - TARGET_DEFINES -DTACHO_USE_TRISOLVE_VARIANT=1 - NOEXEPREFIX - SOURCES Tacho_ExampleCholTriSolve.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholLevelSetVar0 - TARGET_DEFINES -DTACHO_USE_TRISOLVE_VARIANT=0 - NOEXEPREFIX - SOURCES Tacho_ExampleCholLevelSet.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholLevelSetVar1 - TARGET_DEFINES -DTACHO_USE_TRISOLVE_VARIANT=1 - NOEXEPREFIX - SOURCES Tacho_ExampleCholLevelSet.cpp - COMM serial mpi - ) - - # - # Solver - # - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleSolverDouble - NOEXEPREFIX - SOURCES Tacho_ExampleSolver_double.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleSolverDoubleComplex - NOEXEPREFIX - SOURCES Tacho_ExampleSolver_dcomplex.cpp - COMM serial mpi - ) - - # - # Driver - # - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleDriverDouble - NOEXEPREFIX - SOURCES Tacho_ExampleDriver_double.cpp - COMM serial mpi - ) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleDriverDoubleComplex - NOEXEPREFIX - SOURCES Tacho_ExampleDriver_dcomplex.cpp - COMM serial mpi - ) - - # - # External Interface - # - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleExternalInterface - NOEXEPREFIX - SOURCES Tacho_ExampleExternalInterface.cpp - COMM serial mpi - ) +# +# External Interface +# +TRIBITS_ADD_EXECUTABLE( + Tacho_ExampleExternalInterface.x + NOEXESUFFIX + NOEXEPREFIX + SOURCES Tacho_ExampleExternalInterface.cpp + COMM serial mpi +) - # - # DenseByBlocks - # +# +# NVIDIA cuSolver +# +IF(Kokkos_ENABLE_CUDA) TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleDenseByBlocks + Tacho_ExampleCholCuSolver.x + NOEXESUFFIX NOEXEPREFIX - SOURCES Tacho_ExampleDenseByBlocks.cpp + SOURCES Tacho_ExampleCholCuSolver.cpp COMM serial mpi ) +ENDIF() - # - # NVIDIA cuSolver - # - IF(Kokkos_ENABLE_CUDA) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholCuSolver - NOEXEPREFIX - SOURCES Tacho_ExampleCholCuSolver.cpp - COMM serial mpi - ) +# +# Intel MKL Pardiso +# +IF(Kokkos_ENABLE_OPENMP OR Kokkos_ENABLE_SERIAL) + IF(TPL_ENABLE_MKL) TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCuSparseTriSolve + Tacho_ExampleCholPardiso.x + NOEXESUFFIX NOEXEPREFIX - SOURCES Tacho_ExampleCuSparseTriSolve.cpp + SOURCES Tacho_ExampleCholPardiso.cpp COMM serial mpi ) ENDIF() +ENDIF() - # - # Intel MKL Pardiso and PerfTest - # - IF(Kokkos_ENABLE_OPENMP OR Kokkos_ENABLE_SERIAL) - IF(TPL_ENABLE_MKL) - TRIBITS_ADD_EXECUTABLE( - Tacho_ExampleCholPardiso +TRIBITS_COPY_FILES_TO_BINARY_DIR(Tacho_SimpleSparseTest_File + SOURCE_FILES test2.mtx test.mtx graph.dat weight.dat graph_test.mtx + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR} + DEST_DIR ${CMAKE_CURRENT_BINARY_DIR} NOEXEPREFIX - SOURCES Tacho_ExampleCholPardiso.cpp - COMM serial mpi - ) - ENDIF() - TRIBITS_ADD_EXECUTABLE( - Tacho_ExamplePerfTest - NOEXEPREFIX - SOURCES Tacho_ExamplePerfTest.cpp - COMM serial mpi - ) - ENDIF() +) - TRIBITS_COPY_FILES_TO_BINARY_DIR(Tacho_SimpleSparseTest_File - SOURCE_FILES test.mtx graph.dat weight.dat graph_test.mtx - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR} - DEST_DIR ${CMAKE_CURRENT_BINARY_DIR} - ) -ENDIF() diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholCuSolver.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholCuSolver.cpp index 1cdb45c17357..84395173e36f 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholCuSolver.cpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholCuSolver.cpp @@ -1,16 +1,16 @@ #include #include -#include "Tacho_Internal.hpp" #include "Tacho_CommandLineParser.hpp" +#include "Tacho_Internal.hpp" -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) #include "Tacho_CuSolver.hpp" #endif using namespace Tacho; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { CommandLineParser opts("This example program measure the performance of cuSolver on Kokkos::Cuda"); bool verbose = true; @@ -24,7 +24,8 @@ int main (int argc, char *argv[]) { opts.set_option("nrhs", "Number of RHS vectors", &nrhs); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); @@ -34,9 +35,9 @@ int main (int argc, char *argv[]) { typedef UseThisDevice::type device_type; typedef UseThisDevice::type host_device_type; - + Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); Kokkos::Timer timer; int r_val = 0; @@ -45,9 +46,9 @@ int main (int argc, char *argv[]) { /// /// read from crs matrix /// - typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; - typedef Tacho::CrsMatrixBase CrsMatrixBaseType; - typedef Kokkos::View DenseMultiVectorType; + typedef Tacho::CrsMatrixBase CrsMatrixBaseTypeHost; + typedef Tacho::CrsMatrixBase CrsMatrixBaseType; + typedef Kokkos::View DenseMultiVectorType; /// read a spd matrix of matrix market format CrsMatrixBaseTypeHost h_A; @@ -60,7 +61,7 @@ int main (int argc, char *argv[]) { } Tacho::MatrixMarket::read(file, h_A, sanitize, verbose); } - + /// /// cuSolver /// @@ -73,22 +74,24 @@ int main (int argc, char *argv[]) { #if defined(TACHO_HAVE_METIS) typedef GraphTools_Metis graph_tools_type; #else - typedef GraphTools graph_tools_type; + typedef GraphTools graph_tools_type; #endif Graph graph(h_A.NumRows(), h_A.NumNonZeros(), h_A.RowPtr(), h_A.Cols()); graph_tools_type G(graph); G.reorder(verbose); - + const auto h_perm = G.PermVector(); const auto h_peri = G.InvPermVector(); - const auto perm = Kokkos::create_mirror_view(typename device_type::memory_space(), h_perm); Kokkos::deep_copy(perm, h_perm); - const auto peri = Kokkos::create_mirror_view(typename device_type::memory_space(), h_peri); Kokkos::deep_copy(peri, h_peri); + const auto perm = Kokkos::create_mirror_view(typename device_type::memory_space(), h_perm); + Kokkos::deep_copy(perm, h_perm); + const auto peri = Kokkos::create_mirror_view(typename device_type::memory_space(), h_peri); + Kokkos::deep_copy(peri, h_peri); CrsMatrixBaseType A; A.createConfTo(h_A); A.copy(h_A); - + /// permute ondevice CrsMatrixBaseType Ap; { @@ -101,34 +104,29 @@ int main (int argc, char *argv[]) { if (verbose) { printf("ExampleCuSolver: Construction of permuted matrix A\n"); printf(" Time\n"); - printf(" time for permutation of A: %10.6f s\n", t_permute_A); + printf(" time for permutation of A: %10.6f s\n", t_permute_A); printf("\n"); - } + } } /// /// analyze /// - { - cusolver.analyze(Ap.NumRows(), Ap.RowPtr(), Ap.Cols()); - } + { cusolver.analyze(Ap.NumRows(), Ap.RowPtr(), Ap.Cols()); } /// /// factorize /// - { - cusolver.factorize(Ap.Values()); - } + { cusolver.factorize(Ap.Values()); } /// /// random right hand side /// - DenseMultiVectorType - b("b", A.NumRows(), nrhs), // rhs multivector - x("x", A.NumRows(), nrhs), // solution multivector - bb("bb", A.NumRows(), nrhs), // temp workspace (store permuted rhs) - xx("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) - + DenseMultiVectorType b("b", A.NumRows(), nrhs), // rhs multivector + x("x", A.NumRows(), nrhs), // solution multivector + bb("bb", A.NumRows(), nrhs), // temp workspace (store permuted rhs) + xx("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) + { Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(b, random, value_type(1)); @@ -139,15 +137,16 @@ int main (int argc, char *argv[]) { /// { timer.reset(); - applyRowPermutationToDenseMatrix(bb, b, perm); + const auto exec_instance = typename device_type::execution_space(); + ApplyPermutation::invoke(exec_instance, b, perm, bb); cusolver.solve(xx, bb); - applyRowPermutationToDenseMatrix(x, xx, peri); + ApplyPermutation::invoke(exec_instance, xx, peri, x); Kokkos::fence(); const double t_solve = timer.seconds(); if (verbose) { printf("ExampleCuSolver: P b, solve, and P^{-1} x\n"); printf(" Time\n"); - printf(" time for permute and solve: %10.6f s\n", t_solve); + printf(" time for permute and solve: %10.6f s\n", t_solve); printf("\n"); } } @@ -155,16 +154,15 @@ int main (int argc, char *argv[]) { /// /// compute residual to check solutions /// - const double res = computeRelativeResidual(A, x, b); - - std::cout << "cuSolver: residual = " << res << "\n\n"; + const double res = computeRelativeResidual(A, x, b); + std::cout << "cuSolver: residual = " << res << "\n\n"; } #else r_val = -1; std::cout << "CUDA is NOT configured in Trilinos" << std::endl; #endif - + Kokkos::finalize(); return r_val; diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholPardiso.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholPardiso.cpp index 4b6804d3ed01..8cd63fcbb9a3 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholPardiso.cpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleCholPardiso.cpp @@ -1,17 +1,17 @@ #include #include -#include "Tacho_Internal.hpp" #include "Tacho_CommandLineParser.hpp" +#include "Tacho_Internal.hpp" -#if defined (TACHO_HAVE_MKL) -#include "mkl.h" +#if defined(TACHO_HAVE_MKL) #include "Tacho_Pardiso.hpp" +#include "mkl.h" #endif using namespace Tacho; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { CommandLineParser opts("This example program measure the performance of Pardiso on Kokkos::OpenMP"); int nthreads = 1; @@ -27,26 +27,27 @@ int main (int argc, char *argv[]) { opts.set_option("nrhs", "Number of RHS vectors", &nrhs); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return const bool skip_factorize = false, skip_solve = false; - + Kokkos::initialize(argc, argv); typedef typename UseThisDevice::type host_device_type; printExecSpaceConfiguration("HostDevice", false); int r_val = 0; -#if defined (__INTEL_MKL__) +#if defined(__INTEL_MKL__) { typedef double value_type; - typedef CrsMatrixBase CrsMatrixBaseType; - typedef Kokkos::View DenseMatrixBaseType; - - // mkl nthreads setting + typedef CrsMatrixBase CrsMatrixBaseType; + typedef Kokkos::View DenseMatrixBaseType; + + // mkl nthreads setting mkl_set_dynamic(0); mkl_set_num_threads(nthreads); - + Kokkos::Timer timer; double t = 0.0; Pardiso pardiso; @@ -55,16 +56,16 @@ int main (int argc, char *argv[]) { std::cout << "PardisoChol:: init" << std::endl; { timer.reset(); - r_val = pardiso.init(); + r_val = pardiso.init(); t = timer.seconds(); - + if (r_val) { std::cout << "PardisoChol:: Pardiso init error = " << r_val << std::endl; pardiso.showErrorCode(std::cout) << std::endl; } } std::cout << "PardisoChol:: init ::time = " << t << std::endl; - + std::cout << "PardisoChol:: import input file = " << file_input << std::endl; CrsMatrixBaseType A, Asym; timer.reset(); @@ -78,14 +79,14 @@ int main (int argc, char *argv[]) { } } MatrixMarket::read(file_input, A, sanitize, verbose); - + // somehow pardiso does not like symmetric full matrix (store only half) Asym.createConfTo(A); { size_type nnz = 0; - for (ordinal_type i=0;i rowptr("rowptr", Asym.NumRows()+1); - { - for (ordinal_type i=0;i<=Asym.NumRows();++i) + Kokkos::View rowptr("rowptr", Asym.NumRows() + 1); + { + for (ordinal_type i = 0; i <= Asym.NumRows(); ++i) rowptr(i) = Asym.RowPtrBegin(i); - } + } std::cout << "PardisoChol:: import input file::time = " << t << std::endl; - - DenseMatrixBaseType - B("B", Asym.NumRows(), nrhs), - X("X", Asym.NumRows(), nrhs), - P("P", Asym.NumRows(), 1); - + + DenseMatrixBaseType B("B", Asym.NumRows(), nrhs), X("X", Asym.NumRows(), nrhs), P("P", Asym.NumRows(), 1); + { const auto m = Asym.NumRows(); Random random; - for (ordinal_type rhs=0;rhs #include -#include "Tacho_Internal.hpp" #include "Tacho_CommandLineParser.hpp" +#include "Tacho_Internal.hpp" using namespace Tacho; -int main (int argc, char *argv[]) { +int main(int argc, char *argv[]) { CommandLineParser opts("This example program combines data file into a single matrix market file"); std::string graph_data_file = "graph.dat"; @@ -19,14 +19,15 @@ int main (int argc, char *argv[]) { opts.set_option("matrix-market-file", "Output matrixmarket file", &matrix_market_file); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); typedef Kokkos::DefaultHostExecutionSpace host_space; { typedef double value_type; - typedef CrsMatrixBase CrsMatrixBaseTypeHost; + typedef CrsMatrixBase CrsMatrixBaseTypeHost; CrsMatrixBaseTypeHost A; using ordinal_type_array = typename CrsMatrixBaseTypeHost::ordinal_type_array; @@ -46,13 +47,13 @@ int main (int argc, char *argv[]) { } in >> m; - ap = size_type_array("ap", m+1); - for (ordinal_type i=0;i<(m+1);++i) + ap = size_type_array("ap", m + 1); + for (ordinal_type i = 0; i < (m + 1); ++i) in >> ap(i); nnz = ap(m); aj = ordinal_type_array("aj", nnz); - for (ordinal_type k=0;k> aj(k); } { @@ -62,9 +63,9 @@ int main (int argc, char *argv[]) { std::cout << "Failed in open the file: " << value_data_file << std::endl; return -1; } - + ax = value_type_array("ax", nnz); - for (ordinal_type k=0;k> ax(k); } diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseCholesky.hpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseCholesky.hpp deleted file mode 100644 index 4bb3d92e22ba..000000000000 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseCholesky.hpp +++ /dev/null @@ -1,203 +0,0 @@ -#ifndef __TACHO_EXAMPLE_DEVICE_DENSE_CHOLESKY_HPP__ -#define __TACHO_EXAMPLE_DEVICE_DENSE_CHOLESKY_HPP__ - -#include "Kokkos_Random.hpp" - -#include "Tacho_Util.hpp" -#include "Tacho_Blas_External.hpp" -#include "Tacho_Lapack_External.hpp" - -#include "Tacho_Chol.hpp" -#include "Tacho_Chol_OnDevice.hpp" - -#include "Tacho_Trsv.hpp" -#include "Tacho_Trsv_OnDevice.hpp" - -#include "Tacho_Gemv.hpp" -#include "Tacho_Gemv_OnDevice.hpp" - -template -int driver_chol (const int m, const bool verbose) { - int max_iter = 1; - - Kokkos::Timer timer; - const bool detail = false; - - typedef typename Tacho::UseThisDevice::type device_type; - typedef typename Tacho::UseThisDevice::type host_device_type; - - Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); - printf("\n\n"); - -#if defined (KOKKOS_ENABLE_CUDA) - printf("CUDA testing\n"); -#else - printf("Host testing\n"); -#endif - int r_val = 0; - { - const value_type one(1), zero(0); - -#if defined (KOKKOS_ENABLE_CUDA) - cublasHandle_t handle_blas; - cusolverDnHandle_t handle_lapack; - { - { - const int status = cublasCreate(&handle_blas); - if (status) printf("Nonzero error from cublasCreate %d\n", status); - } - { - const int status = cusolverDnCreate(&handle_lapack); - if (status) printf("Nonzero error from cusolverDnCreate %d\n", status); - } - } -#else - int handle_blas, handle_lapack; // dummy -#endif - - Kokkos::View Arand("Arand", m, m); - { - Kokkos::Random_XorShift64_Pool random(13718); - Kokkos::fill_random(Arand, random, one); - } - - Kokkos::View A("A", m, m); - { - const value_type * Arand_ptr = Arand.data(); - value_type * A_ptr = A.data(); - - timer.reset(); -#if defined (KOKKOS_ENABLE_CUDA) - Tacho::Blas::gemm(handle_blas, - CUBLAS_OP_N, CUBLAS_OP_T, - m, m, m, - one, - Arand_ptr, m, - Arand_ptr, m, - zero, - A_ptr, m); -#else - Tacho::Blas::gemm('N', 'T', - m, m, m, - one, - Arand_ptr, m, - Arand_ptr, m, - zero, - A_ptr, m); -#endif - Kokkos::fence(); - const double t = timer.seconds(); - printf("hermitianize time %e\n", t); - } - - Kokkos::View Aback("Aback", m, m); - { - Kokkos::deep_copy(Aback, A); - auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Aback); - if (m < 20) { - printf("AA = \n"); - for (int i=0;i x("x", m, 1); - { - Kokkos::parallel_for - (Kokkos::RangePolicy(0,m), - KOKKOS_LAMBDA(const int i) { - x(i,0) = i+1; - }); - Kokkos::fence(); - } - - Kokkos::View b("b", m, 1); - { - Tacho::Gemv - ::invoke(handle_blas, - one, - A, x, - zero, - b); - Kokkos::fence(); - } - - /// factorizeCholesky - for (int iter=0;iter dev("dev"); - - timer.reset(); - int lwork(0); -#if defined (KOKKOS_ENABLE_CUDA) - Tacho::Lapack::potrf_buffersize(handle_lapack, - CUBLAS_FILL_MODE_UPPER, - m, - A.data(), m, - &lwork); - printf("Cholesky lwork %d\n", lwork); -#endif - Kokkos::View W("W", lwork); - Tacho::Chol::invoke(handle_lapack, A, W); - Kokkos::fence(); - { - const double t = timer.seconds(); - printf("Cholesky factorization time %e\n", t); - } - { - const auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev); - if (dev_host()) printf("Cholesky returns non-zero dev info %d\n", dev_host()); - } - - Kokkos::deep_copy(x, b); - Kokkos::fence(); - - timer.reset(); - { - Tacho::Trsv - ::invoke(handle_blas, Tacho::Diag::NonUnit(), A, x); - Kokkos::fence(); - Tacho::Trsv - ::invoke(handle_blas, Tacho::Diag::NonUnit(), A, x); - Kokkos::fence(); - } - { - const double t = timer.seconds(); - printf("Cholesky solve time %e\n", t); - } - - { - const auto x_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); - if (m < 20) { - printf("x = \n"); - for (int i=0;i -int driver (int argc, char *argv[]) { - int nthreads = 1; - bool verbose = true; - - int max_iter = 1; - int m = 10; - - Tacho::CommandLineParser opts("This example program measure the Tacho on Kokkos::OpenMP"); - - opts.set_option("kokkos-threads", "Number of threads", &nthreads); - opts.set_option("verbose", "Flag for verbose printing", &verbose); - opts.set_option("m", "Dense problem size", &m); - - const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return - - Kokkos::initialize(argc, argv); - Kokkos::Timer timer; - const bool detail = false; - - typedef typename Tacho::UseThisDevice::type device_type; - typedef typename Tacho::UseThisDevice::type host_device_type; - - Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); - printf("\n\n"); - -#if defined (KOKKOS_ENABLE_CUDA) - printf("CUDA testing\n"); -#else - printf("Host testing\n"); -#endif - int r_val = 0; - { - const value_type one(1), zero(0); - -#if defined (KOKKOS_ENABLE_CUDA) - cublasHandle_t handle_blas; - cusolverDnHandle_t handle_lapack; - { - if (verbose) printf("cublas/cusolver handle create begin\n"); - { - const int status = cublasCreate(&handle_blas); - if (status) printf("Nonzero error from cublasCreate %d\n", status); - } - { - const int status = cusolverDnCreate(&handle_lapack); - if (status) printf("Nonzero error from cusolverDnCreate %d\n", status); - } - if (verbose) printf("cublas/cusolver handle create end\n"); - } -#endif - - Kokkos::View Arand("Arand", m, m); - { - if (verbose) printf("test problem randomization\n"); - Kokkos::Random_XorShift64_Pool random(13718); - Kokkos::fill_random(Arand, random, one); - } - - Kokkos::View A("A", m, m); - { - if (verbose) printf("test problem symmetrization\n"); - const value_type * Arand_ptr = Arand.data(); - value_type * A_ptr = A.data(); - - timer.reset(); -#if defined (KOKKOS_ENABLE_CUDA) - Tacho::Blas::gemm(handle_blas, - CUBLAS_OP_N, CUBLAS_OP_T, - m, m, m, - one, - Arand_ptr, m, - Arand_ptr, m, - zero, - A_ptr, m); -#else - Tacho::Blas::gemm('N', 'T', - m, m, m, - one, - Arand_ptr, m, - Arand_ptr, m, - zero, - A_ptr, m); -#endif - Kokkos::fence(); - const double t = timer.seconds(); - printf("symmetrization time %e\n", t); - } - - Kokkos::View Aback("Aback", m, m); - { - if (verbose) printf("test problem backup\n"); - Kokkos::deep_copy(Aback, A); - auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), Aback); - if (0) { - printf("AA = \n"); - for (int i=0;i::potrf_buffersize(handle_lapack, - CUBLAS_FILL_MODE_UPPER, - m, - A_ptr, m, - &lwork); - printf("Cholesky lwork %d\n", lwork); - Kokkos::View W("W", lwork); - value_type * W_ptr = W.data(); - - Kokkos::View dev_view("dev"); - Tacho::Lapack::potrf(handle_lapack, - CUBLAS_FILL_MODE_UPPER, - m, - A_ptr, m, - W_ptr, lwork, - dev_view.data()); - auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); - dev = dev_host(); -#else - Tacho::Lapack::potrf('U', - m, - A_ptr, m, - &dev); -#endif - Kokkos::fence(); - const double t = timer.seconds(); - printf("Cholesky time %e\n", t); - if (dev) printf("Cholesky returns non-zero dev info %d\n", dev); - } - } - - /// run LDLt - for (int iter=0;iter ipiv("pivot", m); - { - value_type * A_ptr = A.data(); - int * ipiv_ptr = ipiv.data(); - int dev(0); - - timer.reset(); -#if defined (KOKKOS_ENABLE_CUDA) - int lwork(0); - Tacho::Lapack::sytrf_buffersize(handle_lapack, - m, - A_ptr, m, - &lwork); - printf("LDLt lwork %d\n", lwork); - Kokkos::View W("W", lwork); - value_type * W_ptr = W.data(); - - Kokkos::View dev_view("dev"); - Tacho::Lapack::sytrf(handle_lapack, - CUBLAS_FILL_MODE_LOWER, - m, - A_ptr, m, - ipiv_ptr, - W_ptr, lwork, - dev_view.data()); - auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); - dev = dev_host(); -#else - int lwork(m*32); - printf("LDLt lwork %d\n", lwork); - Kokkos::View W("W", lwork); - value_type * W_ptr = W.data(); - Tacho::Lapack::sytrf('L', - m, - A_ptr, m, - ipiv_ptr, - W_ptr, lwork, - &dev); -#endif - Kokkos::fence(); - const double t = timer.seconds(); - auto AA = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A); - auto pp = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv); - if (0) { - printf("LDL = \n"); - for (int i=0;i ipiv("pivot", m); - { - value_type * A_ptr = A.data(); - int * ipiv_ptr = ipiv.data(); - int dev(0); - - timer.reset(); -#if defined (KOKKOS_ENABLE_CUDA) - int lwork(0); - Tacho::Lapack::getrf_buffersize(handle_lapack, - m, m, - A_ptr, m, - &lwork); - printf("LU lwork %d\n", lwork); - Kokkos::View W("W", lwork); - value_type * W_ptr = W.data(); - - Kokkos::View dev_view("dev"); - Tacho::Lapack::getrf(handle_lapack, - m, m, - A_ptr, m, - W_ptr, - ipiv_ptr, - dev_view.data()); - auto dev_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), dev_view); - dev = dev_host(); -#else - Tacho::Lapack::getrf(m, m, - A_ptr, m, - ipiv_ptr, - &dev); -#endif - Kokkos::fence(); - const double t = timer.seconds(); - printf("LU time %e\n", t); - if (dev) printf("LU returns non-zero dev info %d\n", dev); - } - } - -#if defined (KOKKOS_ENABLE_CUDA) - { - if (verbose) printf("cublas/cusolver handle destroy begin\n"); - { - const int status = cublasDestroy(handle_blas); - if (status) printf("Nonzero error from cublasDestroy %d\n", status); - } - { - const int status = cusolverDnDestroy(handle_lapack); - if (status) printf("Nonzero error from cusolverDnDestroy %d\n", status); - } - if (verbose) printf("cublas/cusolver handle destroy end\n"); - } -#endif - } - Kokkos::finalize(); - - return r_val; -} diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp deleted file mode 100644 index 58bc547e1e23..000000000000 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDeviceDenseSolver_dcomplex.cpp +++ /dev/null @@ -1,5 +0,0 @@ -#include "Tacho_ExampleDeviceDenseSolver.hpp" - -int main (int argc, char *argv[]) { - return driver >(argc, argv); -} diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver.hpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver.hpp index 7100158bc876..749916381622 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver.hpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver.hpp @@ -1,14 +1,13 @@ #include "Kokkos_Random.hpp" -#include "Tacho_Driver.hpp" +#include "Tacho_CommandLineParser.hpp" #include "Tacho_CrsMatrixBase.hpp" +#include "Tacho_Driver.hpp" #include "Tacho_MatrixMarket.hpp" -#include "Tacho_CommandLineParser.hpp" using ordinal_type = Tacho::ordinal_type; -template -int driver (int argc, char *argv[]) { +template int driver(int argc, char *argv[]) { int nthreads = 1; bool verbose = true; bool sanitize = false; @@ -17,8 +16,8 @@ int driver (int argc, char *argv[]) { std::string graph_file = ""; std::string weight_file = ""; int nrhs = 1; - int sym = 2; - int posdef = 1; + std::string method_name = "chol"; + int method = 1; // 1 - Chol, 2 - LDL, 3 - SymLU int small_problem_thres = 1024; int device_factor_thres = 64; int device_solve_thres = 128; @@ -35,16 +34,28 @@ int driver (int argc, char *argv[]) { opts.set_option("graph", "Input condensed graph", &graph_file); opts.set_option("weight", "Input condensed graph weight", &weight_file); opts.set_option("nrhs", "Number of RHS vectors", &nrhs); - opts.set_option("symmetric", "Symmetric type: 0 - unsym, 1 - structure sym, 2 - symmetric", &sym); - opts.set_option("posdef", "Positive definite: 0 - indef, 1 - positive definite", &posdef); + opts.set_option("method", "Solution method: chol, ldl, lu", &method_name); opts.set_option("small-problem-thres", "LAPACK is used smaller than this thres", &small_problem_thres); - opts.set_option("device-factor-thres", "Device function is used above this subproblem size", &device_factor_thres); + opts.set_option("device-factor-thres", "Device function is used above this subproblem size", + &device_factor_thres); opts.set_option("device-solve-thres", "Device function is used above this subproblem size", &device_solve_thres); opts.set_option("variant", "algorithm variant in levelset scheduling; 0, 1 and 2", &variant); opts.set_option("nstreams", "# of streams used in CUDA; on host, it is ignored", &nstreams); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return + + if (method_name == "chol") + method = 1; + else if (method_name == "ldl") + method = 2; + else if (method_name == "lu") + method = 3; + else { + std::cout << "Error: not supported solution method\n"; + return -1; + } Kokkos::initialize(argc, argv); @@ -54,13 +65,13 @@ int driver (int argc, char *argv[]) { using host_device_type = typename Tacho::UseThisDevice::type; Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); - int r_val = 0; - { + int r_val = 0; + try { /// crs matrix format and dense multi vector - using CrsMatrixBaseTypeHost = Tacho::CrsMatrixBase; - using DenseMultiVectorType = Kokkos::View; + using CrsMatrixBaseTypeHost = Tacho::CrsMatrixBase; + using DenseMultiVectorType = Kokkos::View; /// read a spd matrix of matrix market format CrsMatrixBaseTypeHost A; @@ -79,7 +90,7 @@ int driver (int argc, char *argv[]) { /// read graph file if available using size_type_array_host = typename CrsMatrixBaseTypeHost::size_type_array; using ordinal_type_array_host = typename CrsMatrixBaseTypeHost::ordinal_type_array; - + ordinal_type m_graph(0); size_type_array_host ap_graph; ordinal_type_array_host aw_graph, aj_graph; @@ -93,19 +104,19 @@ int driver (int argc, char *argv[]) { return -1; } in >> m_graph; - - ap_graph = size_type_array_host("ap", m_graph+1); - for (ordinal_type i=0,iend=m_graph+1;i> ap_graph(i); - + aj_graph = ordinal_type_array_host("aj", ap_graph(m_graph)); - for (ordinal_type i=0;i> aj_graph(j); } } - + { std::ifstream in; in.open(weight_file); @@ -117,16 +128,16 @@ int driver (int argc, char *argv[]) { in >> m; in >> m_graph; aw_graph = ordinal_type_array_host("aw", m_graph); - for (ordinal_type i=0;i> aw_graph(i); } } } - - Tacho::Driver solver; + + Tacho::Driver solver; /// common options - solver.setMatrixType(sym, posdef); + solver.setSolutionMethod(method); solver.setSmallProblemThresholdsize(small_problem_thres); solver.setVerbose(verbose); @@ -143,43 +154,36 @@ int driver (int argc, char *argv[]) { /// inputs are used for graph reordering and analysis if (m_graph > 0 && m_graph < A.NumRows()) - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols(), - m_graph, - ap_graph, - aj_graph, - aw_graph); - else - solver.analyze(A.NumRows(), - A.RowPtr(), - A.Cols()); + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols(), m_graph, ap_graph, aj_graph, aw_graph); + else + solver.analyze(A.NumRows(), A.RowPtr(), A.Cols()); /// create numeric tools and levelset tools solver.initialize(); /// symbolic structure can be reused - for (int i=0;i<2;++i) + for (int i = 0; i < 2; ++i) solver.factorize(values_on_device); - - DenseMultiVectorType - b("b", A.NumRows(), nrhs), // rhs multivector - x("x", A.NumRows(), nrhs), // solution multivector - t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) - + + DenseMultiVectorType b("b", A.NumRows(), nrhs), // rhs multivector + x("x", A.NumRows(), nrhs), // solution multivector + t("t", A.NumRows(), nrhs); // temp workspace (store permuted rhs) + { Kokkos::Random_XorShift64_Pool random(13718); Kokkos::fill_random(b, random, value_type(1)); } - for (int i=0;i<3;++i) + for (int i = 0; i < 3; ++i) solver.solve(x, b, t); - + const double res = solver.computeRelativeResidual(values_on_device, x, b); std::cout << "TachoSolver: residual = " << res << "\n\n"; solver.release(); + } catch (const std::exception &e) { + std::cerr << "Error: exception is caught: \n" << e.what() << "\n"; } Kokkos::finalize(); diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_dcomplex.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_dcomplex.cpp index 0d47232e5a2d..de350ec20bbc 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_dcomplex.cpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_dcomplex.cpp @@ -1,5 +1,3 @@ #include "Tacho_ExampleDriver.hpp" -int main (int argc, char *argv[]) { - return driver >(argc, argv); -} +int main(int argc, char *argv[]) { return driver>(argc, argv); } diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_double.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_double.cpp index 8280e0e05693..ebf7e4274d7d 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_double.cpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleDriver_double.cpp @@ -1,5 +1,3 @@ #include "Tacho_ExampleDriver.hpp" -int main (int argc, char *argv[]) { - return driver(argc, argv); -} +int main(int argc, char *argv[]) { return driver(argc, argv); } diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleExternalInterface.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleExternalInterface.cpp index 6fcb2b9969a5..959196af4498 100644 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleExternalInterface.cpp +++ b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleExternalInterface.cpp @@ -1,43 +1,33 @@ #include "Tacho_ExampleExternalInterface.hpp" -#include "Tacho_CommandLineParser.hpp" +#include "Tacho_CommandLineParser.hpp" -#if defined(TACHO_USE_INT_INT) +#if defined(TACHO_USE_INT_INT) -double solutionError(const int numRows, - const double* rhs, - const double* sol, - const int* rowBegin, - const int* columns, - const double* values) -{ +double solutionError(const int numRows, const double *rhs, const double *sol, const int *rowBegin, const int *columns, + const double *values) { double normRhsSquared(0), normErrorSquared(0); - for (int i=0; i 0 uses the array interface /// - if (nSolvers > 0) + if (nSolvers > 0) solver.Initialize(nSolvers, numRows, rowBegin, columns, values); else solver.Initialize(numRows, rowBegin, columns, values); Kokkos::fence(); - /// /// Export supernodes /// @@ -108,50 +96,43 @@ void testTachoSolver(int nSolvers, if (nSolvers > 0) { const int iSolver = 0; - solver.exportUpperTriangularFactorsToCrsMatrix(iSolver, - rowBeginU, - columnsU, - valuesU, - perm); + solver.exportUpperTriangularFactorsToCrsMatrix(iSolver, rowBeginU, columnsU, valuesU, perm); } else { - solver.exportUpperTriangularFactorsToCrsMatrix(rowBeginU, - columnsU, - valuesU, - perm); + solver.exportUpperTriangularFactorsToCrsMatrix(rowBeginU, columnsU, valuesU, perm); } /// /// std vector right hand side /// if an application uses std vector for interfacing rhs, - /// it requires additional copy. it is better to directly + /// it requires additional copy. it is better to directly /// use a kokkos device view. /// const int nProb = nSolvers > 0 ? nSolvers : 1; - std::vector rhs(numRows*nProb), sol(numRows*nProb); + std::vector rhs(numRows * nProb), sol(numRows * nProb); { /// randomize rhs const unsigned int seed = 0; srand(seed); - for (int i=0;i ViewVectorType; - ViewVectorType x("x", numRows*nProb, NRHS); + typedef Kokkos::View ViewVectorType; + ViewVectorType x("x", numRows * nProb, NRHS); -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) /// transfer b into device - ViewVectorType b(Kokkos::ViewAllocateWithoutInitializing("b"), numRows*nProb, NRHS); - Kokkos::deep_copy(Kokkos::subview(b, Kokkos::ALL(), 0), - Kokkos::View(rhs.data(), numRows*nProb)); + ViewVectorType b(Kokkos::ViewAllocateWithoutInitializing("b"), numRows * nProb, NRHS); + Kokkos::deep_copy(Kokkos::subview(b, Kokkos::ALL(), 0), + Kokkos::View(rhs.data(), numRows * nProb)); #else /// wrap rhs data with view - ViewVectorType b(rhs.data(), numRows*nProb, NRHS); + ViewVectorType b(rhs.data(), numRows * nProb, NRHS); #endif timer.reset(); - - for (int run=0; run("file", "Input file (MatrixMarket SPD matrix)", &file); opts.set_option("nsolver", "# of solvers for testing array solver interface", &nsolver); opts.set_option("niter", "# of solver iterations", &niter); opts.set_option("verbose", "Flag for verbose printing", &verbose); - opts.set_option("posdef", "Flag to indicate that the matrix positive definite", &posdef); + opts.set_option("method", "Solution method 1 - chol, 2 - LDL, 3 - SymLU", &solution_method); const bool r_parse = opts.parse(argc, argv); - if (r_parse) return 0; // print help return + if (r_parse) + return 0; // print help return Kokkos::initialize(argc, argv); const bool detail = false; Tacho::printExecSpaceConfiguration("DeviceSpace", detail); - Tacho::printExecSpaceConfiguration("HostSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); { - Tacho::CrsMatrixBase A; + Tacho::CrsMatrixBase A; { std::ifstream in; in.open(file); @@ -211,18 +192,17 @@ int main(int argc, char *argv[]) { { int numRows = A.NumRows(), *rowBegin = A.RowPtr().data(), *columns = A.Cols().data(); - double * values = nullptr; - Kokkos::View - As("As", nsolver, A.Values().extent(0)); + double *values = nullptr; + Kokkos::View As("As", nsolver, A.Values().extent(0)); if (nsolver) { /// duplicate A - for (int i=0;i::type; - using exec_space = typename device_type::execution_space; - - using host_device_type = typename Tacho::UseThisDevice::type; - using host_space = typename host_device_type::execution_space; - - using ViewVectorType = Kokkos::View; - using ViewVectorTypeInt = Kokkos::View; - - template class tachoSolver - { - public: - - typedef Tacho::Solver solver_type; - - typedef Tacho::ordinal_type ordinal_type; - typedef Tacho::size_type size_type; - - typedef typename solver_type::value_type value_type; - typedef typename solver_type::ordinal_type_array ordinal_type_array; - typedef typename solver_type::size_type_array size_type_array; - typedef typename solver_type::value_type_array value_type_array; - typedef typename solver_type::value_type_matrix value_type_matrix; - - // host typedefs - typedef typename solver_type::ordinal_type_array_host ordinal_type_array_host; - typedef typename solver_type::size_type_array_host size_type_array_host; - typedef typename solver_type::value_type_array_host value_type_array_host; - typedef typename solver_type::value_type_matrix_host value_type_matrix_host; - - tachoSolver(const int* solverParams) : - m_numRows(0), - m_Solver() - { - setSolverParameters(solverParams); - } - - ~tachoSolver() - { - for (auto & solver : m_SolverArray) { - solver.release(); - } - m_Solver.release(); +enum TACHO_PARAM_INDICES { + USEDEFAULTSOLVERPARAMETERS, + VERBOSITY, + SMALLPROBLEMTHRESHOLDSIZE, + SOLUTION_METHOD, + TASKING_OPTION_BLOCKSIZE, + TASKING_OPTION_PANELSIZE, + TASKING_OPTION_MAXNUMSUPERBLOCKS, + LEVELSET_OPTION_SCHEDULING, + LEVELSET_OPTION_DEVICE_LEVEL_CUT, + LEVELSET_OPTION_DEVICE_FACTOR_THRES, + LEVELSET_OPTION_DEVICE_SOLVE_THRES, + LEVELSET_OPTION_NSTREAMS, + LEVELSET_OPTION_VARIANT, + INDEX_LENGTH +}; + +using device_type = typename Tacho::UseThisDevice::type; +using exec_space = typename device_type::execution_space; + +using host_device_type = typename Tacho::UseThisDevice::type; +using host_space = typename host_device_type::execution_space; + +using ViewVectorType = Kokkos::View; +using ViewVectorTypeInt = Kokkos::View; + +template class tachoSolver { +public: + typedef Tacho::Solver solver_type; + + typedef Tacho::ordinal_type ordinal_type; + typedef Tacho::size_type size_type; + + typedef typename solver_type::value_type value_type; + typedef typename solver_type::ordinal_type_array ordinal_type_array; + typedef typename solver_type::size_type_array size_type_array; + typedef typename solver_type::value_type_array value_type_array; + typedef typename solver_type::value_type_matrix value_type_matrix; + + // host typedefs + typedef typename solver_type::ordinal_type_array_host ordinal_type_array_host; + typedef typename solver_type::size_type_array_host size_type_array_host; + typedef typename solver_type::value_type_array_host value_type_array_host; + typedef typename solver_type::value_type_matrix_host value_type_matrix_host; + + tachoSolver(const int *solverParams) : m_numRows(0), m_Solver() { setSolverParameters(solverParams); } + + ~tachoSolver() { + for (auto &solver : m_SolverArray) { + solver.release(); } + m_Solver.release(); + } + + int Initialize(int numRows, + /// with TACHO_ENABLE_INT_INT, size_type is "int" + int *rowBegin, int *columns, SX *values, int numGraphRows = 0, int *graphRowBegin = nullptr, + int *graphColumns = nullptr, int *graphWeights = nullptr) { + m_numRows = numRows; + if (m_numRows == 0) + return 0; - int Initialize(int numRows, - /// with TACHO_ENABLE_INT_INT, size_type is "int" - int* rowBegin, - int* columns, - SX* values, - int numGraphRows = 0, - int* graphRowBegin = nullptr, - int* graphColumns = nullptr, - int* graphWeights = nullptr) - { - m_numRows = numRows; - if (m_numRows == 0) return 0; - - const int numTerms = rowBegin[numRows]; - size_type_array_host ap_host((size_type*) rowBegin, numRows+1); - ordinal_type_array_host aj_host((ordinal_type*)columns, numTerms); + const int numTerms = rowBegin[numRows]; + size_type_array_host ap_host((size_type *)rowBegin, numRows + 1); + ordinal_type_array_host aj_host((ordinal_type *)columns, numTerms); - size_type_array_host graph_ap_host; - ordinal_type_array_host graph_aj_host; - ordinal_type_array_host graph_aw_host; + size_type_array_host graph_ap_host; + ordinal_type_array_host graph_aj_host; + ordinal_type_array_host graph_aw_host; - if (numGraphRows > 0) { - if (graphRowBegin == nullptr || - graphColumns == nullptr || - graphWeights == nullptr) { - std::cout << "ExternalInterface::Error, with non-zero numGraphRows, graph pointers should not be nullptr\n"; - std::logic_error("Error: one of graph pointers is nullptr"); - } else { - const size_type nnz_graph = graph_ap_host(numGraphRows); - graph_ap_host = size_type_array_host((size_type*)graphRowBegin, numGraphRows+1); - graph_aj_host = ordinal_type_array_host((ordinal_type*)graphColumns, nnz_graph); - graph_aw_host = ordinal_type_array_host((ordinal_type*)graphWeights, nnz_graph); - } + if (numGraphRows > 0) { + if (graphRowBegin == nullptr || graphColumns == nullptr || graphWeights == nullptr) { + std::cout << "ExternalInterface::Error, with non-zero numGraphRows, graph pointers should not be nullptr\n"; + std::logic_error("Error: one of graph pointers is nullptr"); + } else { + const size_type nnz_graph = graph_ap_host(numGraphRows); + graph_ap_host = size_type_array_host((size_type *)graphRowBegin, numGraphRows + 1); + graph_aj_host = ordinal_type_array_host((ordinal_type *)graphColumns, nnz_graph); + graph_aw_host = ordinal_type_array_host((ordinal_type *)graphWeights, nnz_graph); } + } -#if defined (KOKKOS_ENABLE_CUDA) - /// transfer A into device - value_type_array ax(Kokkos::ViewAllocateWithoutInitializing("ax"), - numTerms); - value_type_array_host ax_host(values, numTerms); - Kokkos::deep_copy(ax, ax_host); +#if defined(KOKKOS_ENABLE_CUDA) + /// transfer A into device + value_type_array ax(Kokkos::ViewAllocateWithoutInitializing("ax"), numTerms); + value_type_array_host ax_host(values, numTerms); + Kokkos::deep_copy(ax, ax_host); #else - /// wrap pointer on host - value_type_array ax(values, numTerms); + /// wrap pointer on host + value_type_array ax(values, numTerms); #endif - Kokkos::Timer timer; - { - timer.reset(); - if (numGraphRows > 0) { - m_Solver.analyze(numRows, ap_host, aj_host, - numGraphRows, graph_ap_host, graph_aj_host, graph_aw_host); - } else { - m_Solver.analyze(numRows, ap_host, aj_host); - } - const double t = timer.seconds(); - std::cout << "ExternalInterface:: analyze time " << t << std::endl; + Kokkos::Timer timer; + { + timer.reset(); + if (numGraphRows > 0) { + m_Solver.analyze(numRows, ap_host, aj_host, numGraphRows, graph_ap_host, graph_aj_host, graph_aw_host); + } else { + m_Solver.analyze(numRows, ap_host, aj_host); } + const double t = timer.seconds(); + std::cout << "ExternalInterface:: analyze time " << t << std::endl; + } - { - timer.reset(); - m_Solver.initialize(); - const double t = timer.seconds(); - std::cout << "ExternalInterface:: initialize time " << t << std::endl; - } - - /// I recommend to separate factorization from the solver initialization - { - timer.reset(); - m_Solver.factorize(ax); - const double t = timer.seconds(); - std::cout << "ExternalInterface:: factorize time " << t << std::endl; - } + { + timer.reset(); + m_Solver.initialize(); + const double t = timer.seconds(); + std::cout << "ExternalInterface:: initialize time " << t << std::endl; + } - return 0; + /// I recommend to separate factorization from the solver initialization + { + timer.reset(); + m_Solver.factorize(ax); + const double t = timer.seconds(); + std::cout << "ExternalInterface:: factorize time " << t << std::endl; } + return 0; + } - int Initialize(int numSolvers, - int numRows, - /// with TACHO_ENABLE_INT_INT, size_type is "int" - int* rowBegin, - int* columns, - SX* values, - int numGraphRows = 0, - int* graphRowBegin = nullptr, - int* graphColumns = nullptr, - int* graphWeights = nullptr) - { - m_numRows = numRows; - if (m_numRows == 0) return 0; - if (numSolvers > 0) { - /// this is okay - } else { - std::cout << "ExternalInterface::Error, nSolver is not a positive number\n"; - std::logic_error("Error: nSolver must be a positive number"); - } + int Initialize(int numSolvers, int numRows, + /// with TACHO_ENABLE_INT_INT, size_type is "int" + int *rowBegin, int *columns, SX *values, int numGraphRows = 0, int *graphRowBegin = nullptr, + int *graphColumns = nullptr, int *graphWeights = nullptr) { + m_numRows = numRows; + if (m_numRows == 0) + return 0; + if (numSolvers > 0) { + /// this is okay + } else { + std::cout << "ExternalInterface::Error, nSolver is not a positive number\n"; + std::logic_error("Error: nSolver must be a positive number"); + } - const int numTerms = rowBegin[numRows]; - size_type_array_host ap_host((size_type*) rowBegin, numRows+1); - ordinal_type_array_host aj_host((ordinal_type*)columns, numTerms); + const int numTerms = rowBegin[numRows]; + size_type_array_host ap_host((size_type *)rowBegin, numRows + 1); + ordinal_type_array_host aj_host((ordinal_type *)columns, numTerms); - size_type_array_host graph_ap_host; - ordinal_type_array_host graph_aj_host; - ordinal_type_array_host graph_aw_host; + size_type_array_host graph_ap_host; + ordinal_type_array_host graph_aj_host; + ordinal_type_array_host graph_aw_host; - if (numGraphRows > 0) { - if (graphRowBegin == nullptr || - graphColumns == nullptr || - graphWeights == nullptr) { - std::cout << "ExternalInterface::Error, with non-zero numGraphRows, graph pointers should not be nullptr\n"; - std::logic_error("Error: one of graph pointers is nullptr"); - } else { - const size_type nnz_graph = graph_ap_host(numGraphRows); - graph_ap_host = size_type_array_host((size_type*)graphRowBegin, numGraphRows+1); - graph_aj_host = ordinal_type_array_host((ordinal_type*)graphColumns, nnz_graph); - graph_aw_host = ordinal_type_array_host((ordinal_type*)graphWeights, nnz_graph); - } + if (numGraphRows > 0) { + if (graphRowBegin == nullptr || graphColumns == nullptr || graphWeights == nullptr) { + std::cout << "ExternalInterface::Error, with non-zero numGraphRows, graph pointers should not be nullptr\n"; + std::logic_error("Error: one of graph pointers is nullptr"); + } else { + const size_type nnz_graph = graph_ap_host(numGraphRows); + graph_ap_host = size_type_array_host((size_type *)graphRowBegin, numGraphRows + 1); + graph_aj_host = ordinal_type_array_host((ordinal_type *)graphColumns, nnz_graph); + graph_aw_host = ordinal_type_array_host((ordinal_type *)graphWeights, nnz_graph); } + } -#if defined (KOKKOS_ENABLE_CUDA) - /// transfer A into device - value_type_array ax(Kokkos::ViewAllocateWithoutInitializing("ax"), numSolvers*numTerms); - value_type_array_host ax_host(values, numSolvers*numTerms); - Kokkos::deep_copy(ax, ax_host); +#if defined(KOKKOS_ENABLE_CUDA) + /// transfer A into device + value_type_array ax(Kokkos::ViewAllocateWithoutInitializing("ax"), numSolvers * numTerms); + value_type_array_host ax_host(values, numSolvers * numTerms); + Kokkos::deep_copy(ax, ax_host); #else - /// wrap pointer on host - value_type_array ax(values, numSolvers*numTerms); + /// wrap pointer on host + value_type_array ax(values, numSolvers * numTerms); #endif - Kokkos::Timer timer; - { - timer.reset(); - /// m_Solver holds symbolic factorization to be shared with array solver - if (numGraphRows > 0) { - m_Solver.analyze(numRows, ap_host, aj_host, - numGraphRows, graph_ap_host, graph_aj_host, graph_aw_host); - } else { - m_Solver.analyze(numRows, ap_host, aj_host); - } - /// array solver soft copy - m_SolverArray.resize(numSolvers); - for (auto & solver : m_SolverArray) { - /// duplicate perform soft copy of symbolic factors and - /// nullify numeric tools object - solver = m_Solver.duplicate(); - } - const double t = timer.seconds(); - std::cout << "ExternalInterface:: analyze time " << t << std::endl; + Kokkos::Timer timer; + { + timer.reset(); + /// m_Solver holds symbolic factorization to be shared with array solver + if (numGraphRows > 0) { + m_Solver.analyze(numRows, ap_host, aj_host, numGraphRows, graph_ap_host, graph_aj_host, graph_aw_host); + } else { + m_Solver.analyze(numRows, ap_host, aj_host); } - - /// the solver objects in the array still need to be initalized - /// to allocate required work space - { - timer.reset(); - for (auto & solver : m_SolverArray) { - solver.initialize(); - } - const double t = timer.seconds(); - std::cout << "ExternalInterface:: initialize time (" << numSolvers << ") " << t << std::endl; + /// array solver soft copy + m_SolverArray.resize(numSolvers); + for (auto &solver : m_SolverArray) { + /// duplicate perform soft copy of symbolic factors and + /// nullify numeric tools object + solver = m_Solver.duplicate(); } + const double t = timer.seconds(); + std::cout << "ExternalInterface:: analyze time " << t << std::endl; + } - /// I recommend to separate factorization from the solver initialization - { - timer.reset(); - for (ordinal_type i=0,iend=m_SolverArray.size();i &supernodes) { - const auto supernodes_device = m_Solver.getSupernodes(); - auto supernodes_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), supernodes_device); - { - const int n = supernodes_host.extent(0); - supernodes.resize(n); - std::copy(supernodes_host.data(), supernodes_host.data()+n, supernodes.data()); + /// I recommend to separate factorization from the solver initialization + { + timer.reset(); + for (ordinal_type i = 0, iend = m_SolverArray.size(); i < iend; ++i) { + m_SolverArray[i].factorize(value_type_array(ax.data() + numTerms * i, numTerms)); } + const double t = timer.seconds(); + std::cout << "ExternalInterface:: factorize time (" << numSolvers << ") " << t << std::endl; } - void exportUpperTriangularFactorsToCrsMatrix(const int iSolver, - std::vector &rowBeginU, - std::vector &columnsU, - std::vector &valuesU, - std::vector &perm) { - solver_type * solver = nullptr; - if (iSolver < 0) { - solver = &m_Solver; + return 0; + } + + void exportSupernodes(std::vector &supernodes) { + const auto supernodes_device = m_Solver.getSupernodes(); + auto supernodes_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), supernodes_device); + { + const int n = supernodes_host.extent(0); + supernodes.resize(n); + std::copy(supernodes_host.data(), supernodes_host.data() + n, supernodes.data()); + } + } + + void exportUpperTriangularFactorsToCrsMatrix(const int iSolver, std::vector &rowBeginU, + std::vector &columnsU, std::vector &valuesU, + std::vector &perm) { + solver_type *solver = nullptr; + if (iSolver < 0) { + solver = &m_Solver; + } else { + const int solver_array_size(m_SolverArray.size()); + if (iSolver < solver_array_size) { + solver = &m_SolverArray[iSolver]; } else { - const int solver_array_size(m_SolverArray.size()); - if (iSolver < solver_array_size) { - solver = &m_SolverArray[iSolver]; - } else { - std::cout << "ExternalInterface::Error, non-zero iSolver (" - << iSolver - << ") is selected where m_SolverArray is sized by (" - << solver_array_size - << ")\n"; - std::logic_error("Error: iSolver is out of range in m_SolverArray"); - } + std::cout << "ExternalInterface::Error, non-zero iSolver (" << iSolver + << ") is selected where m_SolverArray is sized by (" << solver_array_size << ")\n"; + std::logic_error("Error: iSolver is out of range in m_SolverArray"); } + } + { + const auto perm_device = solver->getPermutationVector(); + auto perm_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), perm_device); { - const auto perm_device = solver->getPermutationVector(); - auto perm_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), perm_device); - { - const int n = perm_host.extent(0); - perm.resize(n); - std::copy(perm_host.data(), perm_host.data()+n, perm.data()); - } - } - - { - typename solver_type::crs_matrix_type A; - solver->exportFactorsToCrsMatrix(A); - - typename solver_type::crs_matrix_type_host A_host; - A_host.createMirror(A); - A_host.copy(A); - { - const auto ap = A_host.RowPtr(); - const int n = ap.extent(0); - rowBeginU.resize(n); - std::copy(ap.data(), ap.data()+n, rowBeginU.data()); - } - { - const auto aj = A_host.Cols(); - const int n = aj.extent(0); - columnsU.resize(n); - std::copy(aj.data(), aj.data()+n, columnsU.data()); - } - { - const auto ax = A_host.Values(); - const int n = ax.extent(0); - valuesU.resize(n); - std::copy(ax.data(), ax.data()+n, valuesU.data()); - } + const int n = perm_host.extent(0); + perm.resize(n); + std::copy(perm_host.data(), perm_host.data() + n, perm.data()); } } - - void exportUpperTriangularFactorsToCrsMatrix(std::vector &rowBeginU, - std::vector &columnsU, - std::vector &valuesU, - std::vector &perm) { - exportUpperTriangularFactorsToCrsMatrix(-1, rowBeginU, columnsU, valuesU, perm); - } - - void MySolve(int NRHS, - value_type_matrix &b, - value_type_matrix &x) { - if (m_numRows == 0) return; - - const int m = m_numRows; - if (static_cast(m_TempRhs.extent(0)) < m || - static_cast(m_TempRhs.extent(1)) < NRHS) - m_TempRhs = value_type_matrix("temp rhs", m, NRHS); - - if (m_SolverArray.empty()) { - /// solve for a single instance m_Solver - m_Solver.solve(x, b, m_TempRhs); - } else { - /// solve multiple with m_SolverArray - using range_type = Kokkos::pair; - const auto range_b = range_type(0, NRHS); - for (int i=0,iend=m_SolverArray.size();iexportFactorsToCrsMatrix(A); + + typename solver_type::crs_matrix_type_host A_host; + A_host.createMirror(A); + A_host.copy(A); + { + const auto ap = A_host.RowPtr(); + const int n = ap.extent(0); + rowBeginU.resize(n); + std::copy(ap.data(), ap.data() + n, rowBeginU.data()); + } + { + const auto aj = A_host.Cols(); + const int n = aj.extent(0); + columnsU.resize(n); + std::copy(aj.data(), aj.data() + n, columnsU.data()); + } + { + const auto ax = A_host.Values(); + const int n = ax.extent(0); + valuesU.resize(n); + std::copy(ax.data(), ax.data() + n, valuesU.data()); } } - - private: - int m_numRows; - solver_type m_Solver; - std::vector m_SolverArray; - value_type_matrix m_TempRhs; - - void setSolverParameters(const int* solverParams) - { - if (solverParams[USEDEFAULTSOLVERPARAMETERS]) return; - // common options - m_Solver.setVerbose (solverParams[VERBOSITY]); - m_Solver.setSmallProblemThresholdsize (solverParams[SMALLPROBLEMTHRESHOLDSIZE]); - - // matrix type - m_Solver.setMatrixType(solverParams[MATRIX_SYMMETRIC], solverParams[MATRIX_POSITIVE_DEFINITE]); - - // tasking options - m_Solver.setBlocksize (solverParams[TASKING_OPTION_BLOCKSIZE]); - m_Solver.setPanelsize (solverParams[TASKING_OPTION_PANELSIZE]); - m_Solver.setMaxNumberOfSuperblocks (solverParams[TASKING_OPTION_MAXNUMSUPERBLOCKS]); - - // levelset options - m_Solver.setLevelSetScheduling (solverParams[LEVELSET_OPTION_SCHEDULING]); - m_Solver.setLevelSetOptionDeviceLevelCut (solverParams[LEVELSET_OPTION_DEVICE_LEVEL_CUT]); - m_Solver.setLevelSetOptionDeviceFunctionThreshold - (solverParams[LEVELSET_OPTION_DEVICE_FACTOR_THRES], - solverParams[LEVELSET_OPTION_DEVICE_SOLVE_THRES]); - m_Solver.setLevelSetOptionNumStreams (solverParams[LEVELSET_OPTION_NSTREAMS]); - m_Solver.setLevelSetOptionAlgorithmVariant(solverParams[LEVELSET_OPTION_VARIANT]); + } + + void exportUpperTriangularFactorsToCrsMatrix(std::vector &rowBeginU, std::vector &columnsU, + std::vector &valuesU, std::vector &perm) { + exportUpperTriangularFactorsToCrsMatrix(-1, rowBeginU, columnsU, valuesU, perm); + } + + void MySolve(int NRHS, value_type_matrix &b, value_type_matrix &x) { + if (m_numRows == 0) + return; + + const int m = m_numRows; + if (static_cast(m_TempRhs.extent(0)) < m || static_cast(m_TempRhs.extent(1)) < NRHS) + m_TempRhs = value_type_matrix("temp rhs", m, NRHS); + + if (m_SolverArray.empty()) { + /// solve for a single instance m_Solver + m_Solver.solve(x, b, m_TempRhs); + } else { + /// solve multiple with m_SolverArray + using range_type = Kokkos::pair; + const auto range_b = range_type(0, NRHS); + for (int i = 0, iend = m_SolverArray.size(); i < iend; ++i) { + const auto range_a = range_type(m * i, m * i + m); + m_SolverArray[i].solve(Kokkos::subview(x, range_a, range_b), Kokkos::subview(b, range_a, range_b), m_TempRhs); + } } - - }; + } + +private: + int m_numRows; + solver_type m_Solver; + std::vector m_SolverArray; + value_type_matrix m_TempRhs; + + void setSolverParameters(const int *solverParams) { + if (solverParams[USEDEFAULTSOLVERPARAMETERS]) + return; + // common options + m_Solver.setVerbose(solverParams[VERBOSITY]); + m_Solver.setSmallProblemThresholdsize(solverParams[SMALLPROBLEMTHRESHOLDSIZE]); + + // solution method + m_Solver.setSolutionMethod(solverParams[SOLUTION_METHOD]); + + // tasking options + m_Solver.setBlocksize(solverParams[TASKING_OPTION_BLOCKSIZE]); + m_Solver.setPanelsize(solverParams[TASKING_OPTION_PANELSIZE]); + m_Solver.setMaxNumberOfSuperblocks(solverParams[TASKING_OPTION_MAXNUMSUPERBLOCKS]); + + // levelset options + m_Solver.setLevelSetScheduling(solverParams[LEVELSET_OPTION_SCHEDULING]); + m_Solver.setLevelSetOptionDeviceLevelCut(solverParams[LEVELSET_OPTION_DEVICE_LEVEL_CUT]); + m_Solver.setLevelSetOptionDeviceFunctionThreshold(solverParams[LEVELSET_OPTION_DEVICE_FACTOR_THRES], + solverParams[LEVELSET_OPTION_DEVICE_SOLVE_THRES]); + m_Solver.setLevelSetOptionNumStreams(solverParams[LEVELSET_OPTION_NSTREAMS]); + m_Solver.setLevelSetOptionAlgorithmVariant(solverParams[LEVELSET_OPTION_VARIANT]); + } +}; } // namespace tacho diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_dcomplex.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_dcomplex.cpp deleted file mode 100644 index d1a653410343..000000000000 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_dcomplex.cpp +++ /dev/null @@ -1,5 +0,0 @@ -#include "Tacho_ExampleSolver.hpp" - -int main (int argc, char *argv[]) { - return driver >(argc, argv); -} diff --git a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_double.cpp b/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_double.cpp deleted file mode 100644 index d55db8b36a6b..000000000000 --- a/packages/shylu/shylu_node/tacho/example/Tacho_ExampleSolver_double.cpp +++ /dev/null @@ -1,5 +0,0 @@ -#include "Tacho_ExampleSolver.hpp" - -int main (int argc, char *argv[]) { - return driver(argc, argv); -} diff --git a/packages/shylu/shylu_node/tacho/example/test2.mtx b/packages/shylu/shylu_node/tacho/example/test2.mtx new file mode 100644 index 000000000000..70c7e3da11a2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/example/test2.mtx @@ -0,0 +1,93 @@ +%%MatrixMarket matrix coordinate real general +%------------------------------------------------------------------------------- +12 12 90 + 1 1 102 + 2 1 20 + 3 1 20 + 4 1 1 + 5 1 3 + 6 1 1 + 1 2 20 + 2 2 106 + 3 2 3 + 4 2 20 + 5 2 40 + 7 2 1 + 8 2 5 + 9 2 4 + 1 3 20 + 2 3 3 + 3 3 103 + 5 3 20 + 6 3 20 + 8 3 2 + 9 3 3 + 1 4 1 + 2 4 20 + 4 4 103 + 5 4 4 + 7 4 20 + 8 4 20 + 10 4 4 + 11 4 1 + 1 5 3 + 2 5 40 + 3 5 20 + 4 5 4 + 5 5 113 + 6 5 3 + 8 5 40 + 9 5 40 + 10 5 4 + 11 5 4 + 1 6 1 + 3 6 20 + 5 6 3 + 6 6 102 + 9 6 20 + 11 6 1 + 2 7 1 + 4 7 20 + 7 7 105 + 8 7 5 + 10 7 40 + 12 7 4 + 2 8 5 + 3 8 2 + 4 8 20 + 5 8 40 + 7 8 5 + 8 8 110 + 9 8 5 + 10 8 40 + 11 8 20 + 12 8 5 + 2 9 4 + 3 9 3 + 5 9 40 + 6 9 20 + 8 9 5 + 9 9 106 + 11 9 20 + 12 9 1 + 4 10 4 + 5 10 4 + 7 10 40 + 8 10 40 + 10 10 112 + 11 10 4 + 12 10 40 + 4 11 1 + 5 11 4 + 6 11 1 + 8 11 20 + 9 11 20 + 10 11 4 + 11 11 103 + 12 11 20 + 7 12 4 + 8 12 5 + 9 12 1 + 10 12 40 + 11 12 20 + 12 12 105 diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_DenseMatrixView.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_DenseMatrixView.hpp new file mode 100644 index 000000000000..5910ad4e4347 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_DenseMatrixView.hpp @@ -0,0 +1,324 @@ +#ifndef __TACHO_DENSE_MATRIX_VIEW_HPP__ +#define __TACHO_DENSE_MATRIX_VIEW_HPP__ + +#include "Tacho_Util.hpp" + +/// \file Tacho_DenseMatrixView.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +namespace Tacho { + +template struct DenseMatrixView { +public: + enum : ordinal_type { rank = 2 }; + + using value_type = ValueType; + using non_const_value_type = typename std::remove_const::type; + using device_type = DeviceType; + + using future_type = typename UseThisFuture::type; + +private: + ordinal_type _offm, _offn, _m, _n, _rs, _cs; + value_type *_buf; + future_type _future; + +public: + KOKKOS_INLINE_FUNCTION + DenseMatrixView() : _offm(0), _offn(0), _m(0), _n(0), _rs(0), _cs(0), _buf(NULL), _future() {} + + KOKKOS_INLINE_FUNCTION + DenseMatrixView(value_type *buf, const ordinal_type m, const ordinal_type n) + : _offm(0), _offn(0), _m(m), _n(n), _rs(1), _cs(m), _buf(buf), _future() {} + + KOKKOS_INLINE_FUNCTION + DenseMatrixView(const DenseMatrixView &b) + : _offm(b._offm), _offn(b._offn), _m(b._m), _n(b._n), _rs(b._rs), _cs(b._cs), _buf(b._buf), _future() {} + + KOKKOS_INLINE_FUNCTION + value_type &operator[](const ordinal_type k) const { return _buf[k]; } + + KOKKOS_INLINE_FUNCTION + value_type &operator()(const ordinal_type i, const ordinal_type j) const { + return _buf[(i + _offm) * _rs + (j + _offn) * _cs]; + } + + KOKKOS_INLINE_FUNCTION + void set_view(const DenseMatrixView &base, const ordinal_type offm, const ordinal_type m, const ordinal_type offn, + const ordinal_type n) { + _rs = base._rs; + _cs = base._cs; + _buf = base._buf; + + _offm = offm; + _m = m; + _offn = offn; + _n = n; + } + + KOKKOS_INLINE_FUNCTION + void set_view(const ordinal_type offm, const ordinal_type m, const ordinal_type offn, const ordinal_type n) { + _offm = offm; + _m = m; + _offn = offn; + _n = n; + } + + KOKKOS_INLINE_FUNCTION + void set_view(const ordinal_type m, const ordinal_type n) { + _offm = 0; + _m = m; + _offn = 0; + _n = n; + } + + KOKKOS_INLINE_FUNCTION + void attach_buffer(const ordinal_type rs, const ordinal_type cs, const value_type *buf) { + _rs = rs; + _cs = cs; + _buf = const_cast(buf); + } + + KOKKOS_INLINE_FUNCTION + void set_future(const future_type &f) { _future = f; } + + KOKKOS_INLINE_FUNCTION + void set_future() { _future.clear(); } + + /// get methods + + KOKKOS_INLINE_FUNCTION + ordinal_type offset_0() const { return _offm; } + + KOKKOS_INLINE_FUNCTION + ordinal_type offset_1() const { return _offn; } + + KOKKOS_INLINE_FUNCTION + ordinal_type extent(const ordinal_type r) const { return (r == 0) ? _m : _n; } + + KOKKOS_INLINE_FUNCTION + ordinal_type stride_0() const { return _rs; } + + KOKKOS_INLINE_FUNCTION + ordinal_type stride_1() const { return _cs; } + + KOKKOS_INLINE_FUNCTION + value_type *data() const { return _buf + _offm * _rs + _offn * _cs; } + + KOKKOS_INLINE_FUNCTION + future_type future() const { return _future; } +}; + +template +KOKKOS_INLINE_FUNCTION void clearFutureOfBlocks(const MatrixOfBlocksViewType &H) { + const ordinal_type m = H.extent(0); + const ordinal_type n = H.extent(1); + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < m; ++i) + H(i, j).set_future(); +} + +template +KOKKOS_INLINE_FUNCTION void clearFutureOfBlocks(/* */ MemberType &member, const MatrixOfBlocksViewType &H) { + const ordinal_type m = H.extent(0); + const ordinal_type n = H.extent(1); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), [&](const int &i) { H(i, j).set_future(); }); + }); +} + +template +KOKKOS_INLINE_FUNCTION void setMatrixOfBlocks(const MatrixOfBlocksViewType &H, const ordinal_type m, + const ordinal_type n, const ordinal_type mb, const ordinal_type nb) { + const ordinal_type bm = H.extent(0); + const ordinal_type bn = H.extent(1); + + for (ordinal_type j = 0; j < bn; ++j) { + const ordinal_type jbeg = j * nb, jtmp = jbeg + nb, jend = jtmp > n ? n : jtmp, + jdiff = (jend > jbeg) * (jend - jbeg); + + for (ordinal_type i = 0; i < bm; ++i) { + const ordinal_type ibeg = i * mb, itmp = ibeg + mb, iend = itmp > m ? m : itmp, + idiff = (iend > ibeg) * (iend - ibeg); + + H(i, j).set_view(ibeg, idiff, jbeg, jdiff); + } + } +} + +template +KOKKOS_INLINE_FUNCTION void setMatrixOfBlocks(const MatrixOfBlocksViewType &H, const ordinal_type m, + const ordinal_type n, const ordinal_type mb) { + setMatrixOfBlocks(H, m, n, mb, mb); +} + +template +KOKKOS_INLINE_FUNCTION void setMatrixOfBlocks(/* */ MemberType &member, const MatrixOfBlocksViewType &H, + const ordinal_type m, const ordinal_type n, const ordinal_type mb, + const ordinal_type nb) { + const ordinal_type bm = H.extent(0); + const ordinal_type bn = H.extent(1); + + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, bn), [&](const int &j) { + const ordinal_type jbeg = j * nb, jtmp = jbeg + nb, jend = jtmp > n ? n : jtmp, + jdiff = (jend > jbeg) * (jend - jbeg); + + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, bm), [&](const int &i) { + const ordinal_type ibeg = i * mb, itmp = ibeg + mb, iend = itmp > m ? m : itmp, + idiff = (iend > ibeg) * (iend - ibeg); + + H(i, j).set_view(ibeg, idiff, jbeg, jdiff); + }); + }); +} + +template +KOKKOS_INLINE_FUNCTION void setMatrixOfBlocks(/* */ MemberType &member, const MatrixOfBlocksViewType &H, + const ordinal_type m, const ordinal_type n, const ordinal_type mb) { + setMatrixOfBlocks(member, H, m, n, mb, mb); +} + +template +KOKKOS_INLINE_FUNCTION void attachBaseBuffer(const MatrixOfBlocksViewType &H, const BaseBufferPtrType ptr, + const ordinal_type rs, const ordinal_type cs) { + const ordinal_type m = H.extent(0), n = H.extent(1); + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < m; ++i) + H(i, j).attach_buffer(rs, cs, ptr); +} + +template +KOKKOS_INLINE_FUNCTION void attachBaseBuffer(/* */ MemberType &member, const MatrixOfBlocksViewType &H, + const BaseBufferPtrType ptr, const ordinal_type rs, + const ordinal_type cs) { + const ordinal_type m = H.extent(0); + const ordinal_type n = H.extent(1); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const int &i) { H(i, j).attach_buffer(rs, cs, ptr); }); + }); +} + +template +KOKKOS_INLINE_FUNCTION void allocateStorageByBlocks(const MatrixOfBlocksViewType &H, const MemoryPoolType &pool) { + typedef typename MatrixOfBlocksViewType::value_type dense_block_type; + typedef typename dense_block_type::value_type value_type; + + const ordinal_type m = H.extent(0); + const ordinal_type n = H.extent(1); + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type mm = H(i, j).extent(0), nn = H(i, j).extent(1); + if (mm > 0 && nn > 0) { + auto ptr = (value_type *)pool.allocate(mm * nn * sizeof(value_type)); + TACHO_TEST_FOR_ABORT(ptr == NULL, "memory pool allocation fails"); + + H(i, j).set_view(mm, nn); // whatever offsets are defined here, they are gone. + H(i, j).attach_buffer(1, mm, ptr); + } + } +} + +template +KOKKOS_INLINE_FUNCTION void deallocateStorageByBlocks(const MatrixOfBlocksViewType &H, const MemoryPoolType &pool) { + typedef typename MatrixOfBlocksViewType::value_type dense_block_type; + typedef typename dense_block_type::value_type value_type; + + const ordinal_type m = H.extent(0), n = H.extent(1); + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < m; ++i) { + auto &blk = H(i, j); + const ordinal_type mm = blk.extent(0), nn = blk.extent(1); + if (mm > 0 && nn > 0) + pool.deallocate(blk.data(), mm * nn * sizeof(value_type)); + } +} + +template +KOKKOS_INLINE_FUNCTION void +copyElementwise(const DenseMatrixView &F, + const DenseMatrixView, ExecSpace> &H) { + const ordinal_type hm = H.extent(0), hn = H.extent(0), fm = F.extent(0), fn = F.extent(0); + + if (hm > 0 && hn > 0) { + ordinal_type offj = 0; + for (ordinal_type j = 0; j < hn; ++j) { + ordinal_type offi = 0; + for (ordinal_type i = 0; i < hm; ++i) { + const auto &blk = H(i, j); + const ordinal_type mm = blk.extent(0), nn = blk.extent(1); + for (ordinal_type jj = 0; jj < nn; ++jj) { + const ordinal_type jjj = offj + jj; + for (ordinal_type ii = 0; ii < mm; ++ii) { + const ordinal_type iii = offi + ii; + if (iii < fm && jjj < fn) + F(iii, jjj) = blk(ii, jj); + } + } + offi += mm; + } + offj += H(0, j).extent(1); + } + } +} + +template +KOKKOS_INLINE_FUNCTION void copyElementwise(const DenseMatrixView, ExecSpace> &H, + const DenseMatrixView &F) { + const ordinal_type hm = H.extent(0), hn = H.extent(0), fm = F.extent(0), fn = F.extent(0); + + if (hm > 0 && hn > 0) { + ordinal_type offj = 0; + for (ordinal_type j = 0; j < hn; ++j) { + ordinal_type offi = 0; + for (ordinal_type i = 0; i < hm; ++i) { + const auto &blk = H(i, j); + const ordinal_type mm = blk.extent(0), nn = blk.extent(1); + for (ordinal_type jj = 0; jj < nn; ++jj) { + const ordinal_type jjj = offj + jj; + for (ordinal_type ii = 0; ii < mm; ++ii) { + const ordinal_type iii = offi + ii; + if (iii < fm && jjj < fn) + blk(ii, jj) = F(iii, jjj); + } + } + offi += mm; + } + offj += H(0, j).extent(1); + } + } +} + +/// A = P B +template +inline void applyRowPermutationToDenseMatrix(const DenseMatrixViewType &A, const DenseMatrixViewType &B, + const OrdinalTypeArray &p) { + const ordinal_type m = A.extent(0), n = A.extent(1); + typedef typename DenseMatrixViewType::device_type::execution_space exec_space; + + if (true) { // std::is_same::value) { + // serial copy on host + Kokkos::RangePolicy> policy(0, m); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { + for (ordinal_type j = 0; j < n; ++j) + A(i, j) = B(p(i), j); + }); + } else { + // gcc has compiler errors + // Kokkos::TeamPolicy > policy(m, 1); + // Kokkos::parallel_for + // (policy, KOKKOS_LAMBDA (const typename Kokkos::TeamPolicy::member_type &member) { + // const ordinal_type i = member.league_rank(); + // Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n),[&](const int &j) { + // Kokkos::single(Kokkos::PerThread(member), [&]() { + // A(i, j) = B(p(i), j); + // }); + // }); + // }); + } +} + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_CAMD.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_CAMD.hpp new file mode 100644 index 000000000000..dccc3e1a7e11 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_CAMD.hpp @@ -0,0 +1,213 @@ +#ifndef __TACHO_GRAPH_TOOLS_CAMD_HPP__ +#define __TACHO_GRAPH_TOOLS_CAMD_HPP__ + +/// \file Tacho_GraphTools_CAMD.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Graph.hpp" +#include "Tacho_Util.hpp" + +// #if defined (TACHO_HAVE_SUITESPARSE) +// TPL SuiteSparse +// #include "camd.h" +// #define TACHO_SUITESPARSE(run) run +// #define TRILINOS_CAMD_CONTROL CAMD_CONTROL +// #define TRILINOS_CAMD_INFO CAMD_INFO +// #define TRILINOS_CAMD_STATUS CAMD_STATUS +// #define TRILINOS_CAMD_OK CAMD_OK + +#if defined(TACHO_HAVE_TRILINOS_SS) +// Trilinos SuiteSparse +#include "trilinos_camd.h" +#define TACHO_SUITESPARSE(run) trilinos_##run +typedef UF_long SuiteSparse_long; +#endif + +#if defined(TACHO_HAVE_TRILINOS_SS) //|| defined(TACHO_HAVE_SUITESPARSE) + +namespace Tacho { + +template class CAMD; + +template <> class CAMD { +public: + static void run(int n, int Pe[], int Iw[], int Len[], int iwlen, int pfree, int Nv[], int Next[], int Last[], + int Head[], int Elen[], int Degree[], int W[], double Control[], double Info[], const int C[], + int BucketSet[]) { + TACHO_SUITESPARSE(camd_2) + (n, Pe, Iw, Len, iwlen, pfree, Nv, Next, Last, Head, Elen, Degree, W, Control, Info, C, BucketSet); + } +}; + +template <> class CAMD { +public: + static void run(SuiteSparse_long n, SuiteSparse_long Pe[], SuiteSparse_long Iw[], SuiteSparse_long Len[], + SuiteSparse_long iwlen, SuiteSparse_long pfree, SuiteSparse_long Nv[], SuiteSparse_long Next[], + SuiteSparse_long Last[], SuiteSparse_long Head[], SuiteSparse_long Elen[], SuiteSparse_long Degree[], + SuiteSparse_long W[], double Control[], double Info[], const SuiteSparse_long C[], + SuiteSparse_long BucketSet[]) { + TACHO_SUITESPARSE(camd_l2) + (n, Pe, Iw, Len, iwlen, pfree, Nv, Next, Last, Head, Elen, Degree, W, Control, Info, C, BucketSet); + } +}; + +class GraphTools_CAMD { +public: + typedef Kokkos::DefaultHostExecutionSpace host_exec_space; + typedef Kokkos::View ordinal_type_array; + +private: + // graph input + ordinal_type _m; + size_type _nnz; + ordinal_type_array _rptr, _cidx, _cnst; + + // CAMD output + ordinal_type_array _pe, _nv, _el, _next, _perm, _peri; // perm = last, peri = next + + double _control[TRILINOS_CAMD_CONTROL], _info[TRILINOS_CAMD_INFO]; + + bool _is_ordered; + +public: + // static assert is necessary to enforce to use host space only + GraphTools_CAMD() = default; + GraphTools_CAMD(const GraphTools_CAMD &b) = default; + + GraphTools_CAMD(const Graph &g) { + _m = g.NumRows(); + _nnz = g.NumNonZeros(); + + _rptr = g.RowPtr(); + _cidx = g.ColIdx(); + _cnst = ordinal_type_array("CAMD::ConstraintArray", _m + 1); + + // permutation vector + _pe = ordinal_type_array("CAMD::EliminationArray", _m); + _nv = ordinal_type_array("CAMD::SupernodesArray", _m); + _el = ordinal_type_array("CAMD::DegreeArray", _m); + _next = ordinal_type_array("CAMD::InvPermSupernodesArray", _m); + _perm = ordinal_type_array("CAMD::PermutationArray", _m); + _peri = ordinal_type_array("CAMD::InvPermutationArray", _m); + } + virtual ~GraphTools_CAMD() = default; + + void setConstraint(const ordinal_type nblk, const ordinal_type_array range, const ordinal_type_array peri) { + for (ordinal_type i = 0; i < nblk; ++i) + for (ordinal_type j = range(i); j < range(i + 1); ++j) + _cnst(peri(j)) = i; + } + + void reorder(const ordinal_type verbose = 0) { + Kokkos::Timer timer; + double t_camd = 0; + + TACHO_SUITESPARSE(camd_defaults)(_control); + TACHO_SUITESPARSE(camd_control)(_control); + + ordinal_type *rptr = reinterpret_cast(_rptr.data()); + ordinal_type *cidx = reinterpret_cast(_cidx.data()); + ordinal_type *cnst = reinterpret_cast(_cnst.data()); + + ordinal_type *next = reinterpret_cast(_next.data()); + ordinal_type *perm = reinterpret_cast(_perm.data()); + + // length array + ordinal_type_array lwork("CAMD::LWorkArray", _m); + ordinal_type *lwork_ptr = reinterpret_cast(lwork.data()); + for (ordinal_type i = 0; i < _m; ++i) + lwork_ptr[i] = rptr[i + 1] - rptr[i]; + + // workspace + const size_type swlen = _nnz + _nnz / 5 + 5 * (_m + 1); + ; + ordinal_type_array swork("CAMD::SWorkArray", swlen); + ordinal_type *swork_ptr = reinterpret_cast(swork.data()); + + ordinal_type *pe_ptr = reinterpret_cast(_pe.data()); // 1) Pe + size_type pfree = 0; + for (ordinal_type i = 0; i < _m; ++i) { + pe_ptr[i] = pfree; + pfree += lwork_ptr[i]; + } + TACHO_TEST_FOR_EXCEPTION(_nnz != pfree, std::logic_error, + ">> nnz in the graph does not match to nnz count (pfree)"); + + ordinal_type *nv_ptr = reinterpret_cast(_nv.data()); // 2) Nv + ordinal_type *hd_ptr = swork_ptr; + swork_ptr += (_m + 1); // 3) Head + ordinal_type *el_ptr = reinterpret_cast(_el.data()); // 4) Elen + ordinal_type *dg_ptr = swork_ptr; + swork_ptr += _m; // 5) Degree + ordinal_type *wk_ptr = swork_ptr; + swork_ptr += (_m + 1); // 6) W + ordinal_type *bk_ptr = swork_ptr; + swork_ptr += _m; // 7) BucketSet + + const size_type iwlen = swlen - (4 * _m + 2); + ordinal_type *iw_ptr = swork_ptr; + swork_ptr += iwlen; // Iw + for (size_type i = 0; i < pfree; ++i) + iw_ptr[i] = cidx[i]; + + timer.reset(); + CAMD::run(_m, pe_ptr, iw_ptr, lwork_ptr, iwlen, pfree, + // output + nv_ptr, next, perm, hd_ptr, el_ptr, dg_ptr, wk_ptr, _control, _info, cnst, bk_ptr); + t_camd = timer.seconds(); + + TACHO_TEST_FOR_EXCEPTION(_info[TRILINOS_CAMD_STATUS] != TRILINOS_CAMD_OK, std::runtime_error, "CAMD fails"); + + for (ordinal_type i = 0; i < _m; ++i) + _peri[_perm[i]] = i; + + _is_ordered = true; + + if (verbose) { + printf("Summary: GraphTools (CAMD)\n"); + printf("===========================\n"); + + switch (verbose) { + case 1: { + printf(" Time\n"); + printf(" time for reordering: %10.6f s\n", t_camd); + printf("\n"); + } + } + } + } + + ordinal_type_array PermVector() const { return _perm; } + ordinal_type_array InvPermVector() const { return _peri; } + ordinal_type_array ConstraintVector() const { return _cnst; } + + std::ostream &showMe(std::ostream &os, const bool detail) const { + std::streamsize prec = os.precision(); + os.precision(8); + os << std::scientific; + + os << " -- CAMD input -- " << std::endl + << " # of Rows = " << _m << std::endl + << " # of NonZeros = " << _nnz << std::endl; + + if (_is_ordered) + os << " -- Ordering -- " << std::endl + << " CNST PERM PERI PE NV NEXT ELEN" << std::endl; + + const int w = 6; + for (ordinal_type i = 0; i < _m; ++i) + os << std::setw(w) << _cnst[i] << " " << std::setw(w) << _perm[i] << " " << std::setw(w) << _peri[i] << " " + << std::setw(w) << _pe[i] << " " << std::setw(w) << _nv[i] << " " << std::setw(w) << _next[i] << " " + << std::setw(w) << _el[i] << " " << std::endl; + + os.unsetf(std::ios::scientific); + os.precision(prec); + + return os; + } +}; + +} // namespace Tacho + +#endif +#endif diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_MetisMT.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_MetisMT.hpp new file mode 100644 index 000000000000..83c1f2334f50 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_MetisMT.hpp @@ -0,0 +1,176 @@ +#ifndef __TACHO_GRAPH_TOOLS_METIS_MT_HPP__ +#define __TACHO_GRAPH_TOOLS_METIS_MT_HPP__ + +/// \file Tacho_GraphTools_Metis_MT.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" +#if defined(TACHO_HAVE_METIS_MT) + +#include "Tacho_Graph.hpp" + +#include "mtmetis.h" + +namespace Tacho { + +class GraphTools_MetisMT { +public: + typedef typename UseThisDevice host_device_type; + typedef typename host_device_type::execution_space host_space; + + typedef Kokkos::View mtmetis_vtx_type_array; + typedef Kokkos::View mtmetis_adj_type_array; + typedef Kokkos::View mtmetis_wgt_type_array; + typedef Kokkos::View mtmetis_pid_type_array; + + typedef Kokkos::View ordinal_type_array; + +private: + // metis main data structure + mtmetis_vtx_type _nvts; + mtmetis_vtx_type_array _xadj; + mtmetis_adj_type_array _adjncy; + mtmetis_wgt_type_array _vwgt; + + double _options[MTMETIS_NOPTIONS]; + + // metis output + mtmetis_pid_type_array _perm_t, _peri_t; + ordinal_type_array _perm, _peri; + + // status flag + bool _is_ordered, _verbose; + +public: + GraphTools_MetisMT() = default; + GraphTools_MetisMT(const GraphTools_MetisMT &b) = default; + + /// + /// construction of scotch graph + /// + GraphTools_MetisMT(const Graph &g) { + _is_ordered = false; + _verbose = false; + + // input + _nvts = g.NumRows(); + + _xadj = mtmetis_vtx_type_array("vtx_type_xadj", g.RowPtr().extent(0)); + _adjncy = mtmetis_adj_type_array("adj_type_adjncy", g.ColIdx().extent(0)); + _vwgt = mtmetis_wgt_type_array(); + + const auto &g_row_ptr = g.RowPtr(); + const auto &g_col_idx = g.ColIdx(); + + for (ordinal_type i = 0; i < static_cast(_xadj.extent(0)); ++i) + _xadj(i) = g_row_ptr(i); + for (ordinal_type i = 0; i < static_cast(_adjncy.extent(0)); ++i) + _adjncy(i) = g_col_idx(i); + + // default + for (ordinal_type i = 0; i < static_cast(MTMETIS_NOPTIONS); ++i) + _options[i] = MTMETIS_VAL_OFF; + + // by default, metis use + // # of threads : omp_get_max_threads + // seed : (unsigned int)time(NULL) + // internal verbose options are : + // MTMETIS_VERBOSITY_NONE, + // MTMETIS_VERBOSITY_LOW, + // MTMETIS_VERBOSITY_MEDIUM, + // MTMETIS_VERBOSITY_HIGH, + // MTMETIS_VERBOSITY_MAXIMUM + + _options[MTMETIS_OPTION_NTHREADS] = host_space::thread_pool_size(0); // from kokkos + //_options[MTMETIS_OPTION_SEED] = 0; // for testing, use the same seed now + //_options[MTMETIS_OPTION_PTYPE] = MTMETIS_PTYPE_ND; // when explicit interface is used + //_options[MTMETIS_OPTION_VERBOSITY] = MTMETIS_VERBOSITY_NONE; + //_options[MTMETIS_OPTION_METIS] = 1; // flag to use serial metis + + _perm_t = mtmetis_pid_type_array("pid_type_perm", _nvts); + _peri_t = mtmetis_pid_type_array("pid_type_peri", _nvts); + + // output + _perm = ordinal_type_array("MetisMT::PermutationArray", _nvts); + _peri = ordinal_type_array("MetisMT::InvPermutationArray", _nvts); + } + virtual ~GraphTools_MetisMT() {} + + /// + /// setup metis parameters + /// + + void setVerbose(const bool verbose) { _verbose = verbose; } + void setOption(const int id, const double value) { _options[id] = value; } + + /// + /// reorder by metis + /// + + void reorder(const ordinal_type verbose = 0) { + Kokkos::Timer timer; + double t_metis = 0; + + int ierr = 0; + + mtmetis_vtx_type *xadj = (mtmetis_vtx_type *)_xadj.data(); + mtmetis_adj_type *adjncy = (mtmetis_adj_type *)_adjncy.data(); + mtmetis_wgt_type *vwgt = (mtmetis_wgt_type *)_vwgt.data(); + + mtmetis_pid_type *perm = (mtmetis_pid_type *)_perm_t.data(); + mtmetis_pid_type *peri = (mtmetis_pid_type *)_peri_t.data(); + + timer.reset(); + ierr = MTMETIS_NodeND(&_nvts, xadj, adjncy, vwgt, _options, perm, peri); + t_metis = timer.seconds(); + + for (mtmetis_vtx_type i = 0; i < _nvts; ++i) { + _perm(i) = _perm_t(i); + _peri(i) = _peri_t(i); + } + + TACHO_TEST_FOR_EXCEPTION(ierr != MTMETIS_SUCCESS, std::runtime_error, "Failed in METIS_NodeND"); + _is_ordered = true; + + if (verbose) { + printf("Summary: GraphTools (MetisMT)\n"); + printf("=============================\n"); + + switch (verbose) { + case 1: { + printf(" Time\n"); + printf(" time for reordering: %10.6f s\n", t_metis); + printf("\n"); + } + } + } + } + + ordinal_type_array PermVector() const { return _perm; } + ordinal_type_array InvPermVector() const { return _peri; } + + std::ostream &showMe(std::ostream &os, const bool detail = false) const { + std::streamsize prec = os.precision(); + os.precision(4); + os << std::scientific; + + if (_is_ordered) + os << " -- MetisMT Ordering -- " << std::endl << " PERM PERI " << std::endl; + else + os << " -- Not Ordered -- " << std::endl; + + if (detail) { + const ordinal_type w = 6, m = _perm.extent(0); + for (ordinal_type i = 0; i < m; ++i) + os << std::setw(w) << _perm[i] << " " << std::setw(w) << _peri[i] << " " << std::endl; + } + os.unsetf(std::ios::scientific); + os.precision(prec); + + return os; + } +}; + +} // namespace Tacho +#endif +#endif diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_Scotch.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_Scotch.hpp new file mode 100644 index 000000000000..16dbff700618 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_GraphTools_Scotch.hpp @@ -0,0 +1,220 @@ +#ifndef __TACHO_GRAPH_TOOLS_SCOTCH_HPP__ +#define __TACHO_GRAPH_TOOLS_SCOTCH_HPP__ + +/// \file Tacho_GraphTools_Scotch.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +#if defined(TACHO_HAVE_SCOTCH) +#include "Tacho_Graph.hpp" + +#include "scotch.h" + +namespace Tacho { + +class GraphTools_Scotch { +public: + typedef typename UseThisDevice host_device_type; + typedef Kokkos::View ordinal_type_array; + + enum : int { DefaultRandomSeed = -1 }; + +private: + // scotch main data structure + SCOTCH_Graph _graph; + SCOTCH_Num _strat; + int _level; + + // scotch output + ordinal_type _cblk; + ordinal_type_array _perm, _peri, _range, _tree; + + // status flag + bool _is_ordered, _verbose; + +public: + GraphTools_Scotch() = default; + GraphTools_Scotch(const GraphTools_Scotch &b) = default; + + /// + /// construction of scotch graph + /// + GraphTools_Scotch(const Graph &g) { + _is_ordered = false; + _verbose = false; + + // input + const ordinal_type base = 0; + const ordinal_type m = g.NumRows(); + const size_type nnz = g.NumNonZeros(); + + // scotch control parameter + _strat = 0; + _level = 0; + + // output + _cblk = 0; + _perm = ordinal_type_array("Scotch::PermutationArray", m); + _peri = ordinal_type_array("Scotch::InvPermutationArray", m); + _range = ordinal_type_array("Scotch::RangeArray", m); + _tree = ordinal_type_array("Scotch::TreeArray", m); + + // construct scotch graph + int ierr = 0; + const SCOTCH_Num *rptr_ptr = reinterpret_cast(g.RowPtr().data()); + const SCOTCH_Num *cidx_ptr = reinterpret_cast(g.ColIdx().data()); + + ierr = SCOTCH_graphInit(&_graph); + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphInit"); + + ierr = SCOTCH_graphBuild(&_graph, // scotch graph + base, // base value + m, // # of vertices + rptr_ptr, // column index array pointer begin + rptr_ptr + 1, // column index array pointer end + NULL, // weights on vertices (optional) + NULL, // label array on vertices (optional) + nnz, // # of nonzeros + cidx_ptr, // column index array + NULL); // edge load array (optional) + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphBuild"); + + ierr = SCOTCH_graphCheck(&_graph); + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphCheck"); + } + virtual ~GraphTools_Scotch() { SCOTCH_graphFree(&_graph); } + + /// + /// setup scotch parameters + /// + + void setVerbose(const bool verbose) { _verbose = verbose; } + void setSeed(const int seed = DefaultRandomSeed) { + if (seed != DefaultRandomSeed) { + SCOTCH_randomSeed(seed); + SCOTCH_randomReset(); + } + } + + void setStrategy(const SCOTCH_Num strat = 0) { + // a typical choice + //(SCOTCH_STRATLEVELMAX));// | + // SCOTCH_STRATLEVELMIN | + // SCOTCH_STRATLEAFSIMPLE | + // SCOTCH_STRATSEPASIMPLE); + _strat = strat; + } + + void setTreeLevel(const unsigned int level = 0) { _level = level; } + + /// + /// setup scotch parameters + /// + + void reorder(const ordinal_type verbose = 0) { + Kokkos::Timer timer; + double t_scotch = 0; + + _verbose = verbose; + + const int treecut = 0; + int ierr = 0; + + // pointers for global graph ordering + ordinal_type *perm = _perm.data(); + ordinal_type *peri = _peri.data(); + ordinal_type *range = _range.data(); + ordinal_type *tree = _tree.data(); + + timer.reset(); + { + // set desired tree level + if (_strat & SCOTCH_STRATLEVELMAX || _strat & SCOTCH_STRATLEVELMIN) { + TACHO_TEST_FOR_EXCEPTION(_level == 0, std::logic_error, + "SCOTCH_STRATLEVEL(MIN/MAX) is used but level is not specified"); + } + const int level = max(1, _level - treecut); + + SCOTCH_Strat stradat; + SCOTCH_Num straval = _strat; + + ierr = SCOTCH_stratInit(&stradat); + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_stratInit"); + + // if both are zero, do not build strategy + if (_strat || _level) { + ierr = SCOTCH_stratGraphOrderBuild(&stradat, straval, level, 0.2); + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_stratGraphOrderBuild"); + } + ierr = SCOTCH_graphOrder(&_graph, &stradat, perm, peri, &_cblk, range, tree); + TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphOrder"); + SCOTCH_stratExit(&stradat); + } + t_scotch = timer.seconds(); + _is_ordered = true; + + if (_verbose) { + printf("Summary: GraphTools (Scotch)\n"); + printf("===========================\n"); + printf(" Time\n"); + printf(" time for reordering: %10.6f s\n", t_scotch); + printf("\n"); + if (_strat || _level) { + printf(" User provided strategy ( %d ) and/or level ( %d )\n", _strat, _level); + printf(" strategy & SCOTCH_STRATLEVELMAX: %3d\n", (_strat & SCOTCH_STRATLEVELMAX)); + printf(" strategy & SCOTCH_STRATLEVELMIN: %3d\n", (_strat & SCOTCH_STRATLEVELMIN)); + printf(" strategy & SCOTCH_STRATLEAFSIMPLE: %3d\n", (_strat & SCOTCH_STRATLEAFSIMPLE)); + printf(" strategy & SCOTCH_STRATSEPASIMPLE: %3d\n", (_strat & SCOTCH_STRATSEPASIMPLE)); + printf("\n"); + } + printf(" Partitions\n"); + printf(" number of block partitions: %3d\n", _cblk); + printf("\n"); + } + } + + ordinal_type_array PermVector() const { return _perm; } + ordinal_type_array InvPermVector() const { return _peri; } + + ordinal_type_array RangeVector() const { return _range; } + ordinal_type_array TreeVector() const { return _tree; } + + ordinal_type NumBlocks() const { return _cblk; } + ordinal_type TreeLevel() const { + ordinal_type r_val; + if (_strat & SCOTCH_STRATLEVELMAX || _strat & SCOTCH_STRATLEVELMIN) + r_val = _level; + else + r_val = 0; + return r_val; + } + + std::ostream &showMe(std::ostream &os, const bool detail = false) const { + std::streamsize prec = os.precision(); + os.precision(4); + os << std::scientific; + + if (_is_ordered) + os << " -- Scotch Ordering -- " << std::endl + << " CBLK = " << _cblk << std::endl + << " PERM PERI RANG TREE" << std::endl; + else + os << " -- Not Ordered -- " << std::endl; + + if (detail) { + const ordinal_type w = 6, m = _perm.extent(0); + for (ordinal_type i = 0; i < m; ++i) + os << std::setw(w) << _perm[i] << " " << std::setw(w) << _peri[i] << " " << std::setw(w) << _range[i] + << " " << std::setw(w) << _tree[i] << std::endl; + } + os.unsetf(std::ios::scientific); + os.precision(prec); + + return os; + } +}; + +} // namespace Tacho +#endif +#endif diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_NumericTools_Factory.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_NumericTools_Factory.hpp new file mode 100644 index 000000000000..f6c7e3dc5969 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_NumericTools_Factory.hpp @@ -0,0 +1,304 @@ +#ifndef __TACHO_NUMERIC_TOOLS_FACTORY_HPP__ +#define __TACHO_NUMERIC_TOOLS_FACTORY_HPP__ + +/// \file Tacho_NumericTools_Serial.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_NumericTools_Base.hpp" +#include "Tacho_NumericTools_LevelSet.hpp" +#include "Tacho_NumericTools_Serial.hpp" + +namespace Tacho { + +/// +/// +/// +template class NumericToolsFactory; + +/// partial specialization +#define TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING \ + using ordinal_type_array = typename numeric_tools_base_type::ordinal_type_array; \ + using size_type_array = typename numeric_tools_base_type::size_type_array; \ + using ordinal_type_array_host = typename numeric_tools_base_type::ordinal_type_array_host; \ + using size_type_array_host = typename numeric_tools_base_type::size_type_array_host + +#define TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER \ + ordinal_type _method; \ + ordinal_type _m; \ + size_type_array _ap; \ + ordinal_type_array _aj; \ + ordinal_type_array _perm; \ + ordinal_type_array _peri; \ + ordinal_type _nsupernodes; \ + ordinal_type_array _supernodes; \ + size_type_array _gid_ptr; \ + ordinal_type_array _gid_colidx; \ + size_type_array _sid_ptr; \ + ordinal_type_array _sid_colidx; \ + ordinal_type_array _blk_colidx; \ + ordinal_type_array _stree_parent; \ + size_type_array _stree_ptr; \ + ordinal_type_array _stree_children; \ + ordinal_type_array_host _stree_level; \ + ordinal_type_array_host _stree_roots; \ + ordinal_type _verbose + +#define TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER \ + do { \ + _method = method; \ + _m = m; \ + _ap = ap; \ + _aj = aj; \ + _perm = perm; \ + _peri = peri; \ + _nsupernodes = nsupernodes; \ + _supernodes = supernodes; \ + _gid_ptr = gid_ptr; \ + _gid_colidx = gid_colidx; \ + _sid_ptr = sid_ptr; \ + _sid_colidx = sid_colidx; \ + _blk_colidx = blk_colidx; \ + _stree_parent = stree_parent; \ + _stree_ptr = stree_ptr; \ + _stree_children = stree_children; \ + _stree_level = stree_level; \ + _stree_roots = stree_roots; \ + _verbose = verbose; \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER \ + ordinal_type _variant; \ + ordinal_type _device_level_cut; \ + ordinal_type _device_factor_thres; \ + ordinal_type _device_solve_thres; \ + ordinal_type _nstreams + +#define TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER \ + do { \ + _variant = variant; \ + _device_level_cut = device_level_cut; \ + _device_factor_thres = device_factor_thres; \ + _device_solve_thres = device_solve_thres; \ + _nstreams = nstreams; \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY \ + do { \ + if (object == nullptr) \ + object = (numeric_tools_base_type *)::operator new(sizeof(numeric_tools_serial_type)); \ + \ + new (object) numeric_tools_serial_type(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_ptr, \ + _gid_colidx, _sid_ptr, _sid_colidx, _blk_colidx, _stree_parent, _stree_ptr, \ + _stree_children, _stree_level, _stree_roots); \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_name) \ + do { \ + if (object == nullptr) \ + object = (numeric_tools_base_type *)::operator new(sizeof(numeric_tools_levelset_name)); \ + new (object) numeric_tools_levelset_name(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_ptr, \ + _gid_colidx, _sid_ptr, _sid_colidx, _blk_colidx, _stree_parent, \ + _stree_ptr, _stree_children, _stree_level, _stree_roots); \ + numeric_tools_levelset_name *N = dynamic_cast(object); \ + N->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); \ + N->createStream(_nstreams, _verbose); \ + } while (false) + +/// +/// Serial construction +/// +#if defined(KOKKOS_ENABLE_SERIAL) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { +#if !defined(__CUDA_ARCH__) + // TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY; + switch (_variant) { + case 0: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type); + break; + } + case 1: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var1_type); + break; + } + case 2: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var2_type); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input"); + break; + } + } +#endif + } +}; +#endif + +#if defined(KOKKOS_ENABLE_OPENMP) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { +#if !defined(__CUDA_ARCH__) + // TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY; + switch (_variant) { + case 0: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type); + break; + } + case 1: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var1_type); + break; + } + case 2: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var2_type); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input"); + break; + } + } +#endif + } +}; +#endif + +#if defined(KOKKOS_ENABLE_CUDA) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { + switch (_variant) { + case 0: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type); + break; + } + case 1: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var1_type); + break; + } + case 2: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var2_type); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input"); + break; + } + } + } +}; +#endif + +#undef TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING +#undef TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER +#undef TACHO_NUMERIC_TOOLS_SERIAL_BODY + +} // namespace Tacho +#endif diff --git a/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_Solver_Impl.hpp b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_Solver_Impl.hpp new file mode 100644 index 000000000000..a0be7ef91131 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/.do-not-use/Tacho_Solver_Impl.hpp @@ -0,0 +1,828 @@ +#ifndef __TACHO_SOLVER_IMPL_HPP__ +#define __TACHO_SOLVER_IMPL_HPP__ + +/// \file Tacho_Solver_Impl.hpp +/// \brief solver interface +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Internal.hpp" +#include "Tacho_Solver.hpp" + +namespace Tacho { + +template +Solver::Solver() + : _transpose(0), _mode(0), _order_connected_graph_separately(0), _m(0), _nnz(0), _ap(), _h_ap(), _aj(), _h_aj(), + _perm(), _h_perm(), _peri(), _h_peri(), _m_graph(0), _nnz_graph(0), _h_ap_graph(), _h_aj_graph(), _h_perm_graph(), + _h_peri_graph(), _N(nullptr), _L0(nullptr), _L1(nullptr), _L2(nullptr), _verbose(0), _small_problem_thres(1024), + _serial_thres_size(-1), _mb(-1), _nb(-1), _front_update_mode(-1), _levelset(0), _device_level_cut(0), + _device_factor_thres(64), _device_solve_thres(128), _variant(2), _nstreams(16), _max_num_superblocks(-1) {} + +/// deleted +// template +// Solver +// ::Solver(const Solver &b) = default; + +/// +/// common options +/// +template void Solver::setVerbose(const ordinal_type verbose) { _verbose = verbose; } + +template +void Solver::setSmallProblemThresholdsize(const ordinal_type small_problem_thres) { + _small_problem_thres = small_problem_thres; +} + +// template +// void +// Solver +// ::setTransposeSolve(const bool transpose) { +// _transpose = transpose; // this option is not used yet +// } + +template +void Solver::setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric + const bool is_positive_definite) { + switch (symmetric) { + case 0: { + _mode = LU; + break; + } + case 1: { + _mode = SymLU; + break; + } + case 2: { + if (is_positive_definite) { + if (std::is_same::value || std::is_same::value || + std::is_same>::value || + std::is_same>::value) { + // real symmetric posdef + _mode = Cholesky; + } + } else { // real or complex symmetric indef + _mode = LDL; + } + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "symmetric argument is wrong"); + } + } +} + +template +void Solver::setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately) { + _order_connected_graph_separately = order_connected_graph_separately; +} + +/// +/// tasking options +/// +template void Solver::setSerialThresholdsize(const ordinal_type serial_thres_size) { + _serial_thres_size = serial_thres_size; +} + +template void Solver::setBlocksize(const ordinal_type mb) { _mb = mb; } + +template void Solver::setPanelsize(const ordinal_type nb) { _nb = nb; } + +template void Solver::setFrontUpdateMode(const ordinal_type front_update_mode) { + _front_update_mode = front_update_mode; +} + +template +void Solver::setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks) { + _max_num_superblocks = max_num_superblocks; +} + +/// +/// Level set tools options +/// +template void Solver::setLevelSetScheduling(const bool levelset) { + _levelset = levelset; +} + +template +void Solver::setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut) { + _device_level_cut = device_level_cut; +} + +template +void Solver::setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, + const ordinal_type device_solve_thres) { + _device_factor_thres = device_factor_thres; + _device_solve_thres = device_solve_thres; +} + +template void Solver::setLevelSetOptionAlgorithmVariant(const ordinal_type variant) { + if (variant > 2 || variant < 0) { + std::logic_error("levelset algorithm variants range from 0 to 2"); + } + _variant = variant; +} + +template void Solver::setLevelSetOptionNumStreams(const ordinal_type nstreams) { + _nstreams = nstreams; +} + +/// +/// get interface +/// +template ordinal_type Solver::getNumSupernodes() const { return _nsupernodes; } + +template typename Solver::ordinal_type_array Solver::getSupernodes() const { + return _supernodes; +} + +template +typename Solver::ordinal_type_array Solver::getPermutationVector() const { + return _perm; +} + +template +typename Solver::ordinal_type_array Solver::getInversePermutationVector() const { + return _peri; +} + +// internal only +template int Solver::analyze() { + int r_val(0); + if (_m < _small_problem_thres) { + /// do nothing + if (_verbose) { + printf("TachoSolver: Analyze\n"); + printf("====================\n"); + printf(" Linear system A\n"); + printf(" number of equations: %10d\n", _m); + printf("\n"); + printf(" A is a small problem ( < %d ) and LAPACK is used\n", _small_problem_thres); + printf("\n"); + } + } else { + const bool use_condensed_graph = (_m_graph > 0 && _m_graph < _m); + if (use_condensed_graph) { + Graph graph(_m_graph, _nnz_graph, _h_ap_graph, _h_aj_graph); + graph_tools_type G(graph); +#if defined(TACHO_HAVE_METIS) + if (_order_connected_graph_separately) { + G.setOption(METIS_OPTION_CCORDER, 1); + } +#endif + G.reorder(_verbose); + + _h_perm_graph = G.PermVector(); + _h_peri_graph = G.InvPermVector(); + + r_val = analyze_condensed_graph(); + } else { + const bool use_graph_partitioner = (_h_perm.extent(0) == 0 && _h_peri.extent(0) == 0); + if (use_graph_partitioner) { + Graph graph(_m, _nnz, _h_ap, _h_aj); + graph_tools_type G(graph); +#if defined(TACHO_HAVE_METIS) + if (_order_connected_graph_separately) { + G.setOption(METIS_OPTION_CCORDER, 1); + } +#endif + G.reorder(_verbose); + + _h_perm = G.PermVector(); + _h_peri = G.InvPermVector(); + + r_val = analyze_linear_system(); + } else { + r_val = analyze_linear_system(); + } + } + } + return r_val; +} + +template int Solver::analyze_linear_system() { + if (_verbose) { + printf("TachoSolver: Analyze Linear System\n"); + printf("==================================\n"); + } + + { + symbolic_tools_type S(_m, _h_ap, _h_aj, _h_perm, _h_peri); + S.symbolicFactorize(_verbose); + + _nsupernodes = S.NumSupernodes(); + _stree_level = S.SupernodesTreeLevel(); + _stree_roots = S.SupernodesTreeRoots(); + + _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); + _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); + _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); + _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); + _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); + _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); + _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); + _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); + _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); + + Kokkos::deep_copy(_supernodes, S.Supernodes()); + Kokkos::deep_copy(_gid_super_panel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(_gid_super_panel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(_sid_super_panel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(_sid_super_panel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(_blk_super_panel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(_stree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(_stree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(_stree_children, S.SupernodesTreeChildren()); + + // perm and peri is updated during symbolic factorization + _perm = Kokkos::create_mirror_view(exec_memory_space(), _h_perm); + _peri = Kokkos::create_mirror_view(exec_memory_space(), _h_peri); + + Kokkos::deep_copy(_perm, _h_perm); + Kokkos::deep_copy(_peri, _h_peri); + } + return 0; +} + +template int Solver::analyze_condensed_graph() { + if (_verbose) { + printf("TachoSolver: Analyze Condensed Graph and Evaporate the Graph\n"); + printf("============================================================\n"); + } + + { + symbolic_tools_type S(_m_graph, _h_ap_graph, _h_aj_graph, _h_perm_graph, _h_peri_graph); + S.symbolicFactorize(_verbose); + S.evaporateSymbolicFactors(_h_aw_graph, _verbose); + + _nsupernodes = S.NumSupernodes(); + _stree_level = S.SupernodesTreeLevel(); + _stree_roots = S.SupernodesTreeRoots(); + + _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); + _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); + _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); + _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); + _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); + _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); + _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); + _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); + _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); + _perm = Kokkos::create_mirror_view(exec_memory_space(), S.PermVector()); + _peri = Kokkos::create_mirror_view(exec_memory_space(), S.InvPermVector()); + + Kokkos::deep_copy(_supernodes, S.Supernodes()); + Kokkos::deep_copy(_gid_super_panel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(_gid_super_panel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(_sid_super_panel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(_sid_super_panel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(_blk_super_panel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(_stree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(_stree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(_stree_children, S.SupernodesTreeChildren()); + Kokkos::deep_copy(_perm, S.PermVector()); + Kokkos::deep_copy(_peri, S.InvPermVector()); + + _h_perm = S.PermVector(); + _h_peri = S.InvPermVector(); + } + return 0; +} + +template int Solver::initialize() { + if (_verbose) { + printf("TachoSolver: Initialize\n"); + printf("=======================\n"); + } + + /// + /// initialize numeric tools + /// + if (_m < _small_problem_thres) { + //_A = value_type_matrix_host("A", _m, _m); + } else { + if (_N == nullptr) + _N = (numeric_tools_type *)::operator new(sizeof(numeric_tools_type)); + + new (_N) numeric_tools_type(_m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_super_panel_ptr, + _gid_super_panel_colidx, _sid_super_panel_ptr, _sid_super_panel_colidx, + _blk_super_panel_colidx, _stree_parent, _stree_ptr, _stree_children, _stree_level, + _stree_roots); + + if (_serial_thres_size < 0) { // set default values + _serial_thres_size = 64; + } + _N->setSerialThresholdSize(_serial_thres_size); + + if (_max_num_superblocks < 0) { // set default values + _max_num_superblocks = 16; + } + _N->setMaxNumberOfSuperblocks(_max_num_superblocks); + + if (_front_update_mode < 0) { // set default values + _front_update_mode = 1; // atomic is default + } + _N->setFrontUpdateMode(_front_update_mode); + _N->printMemoryStat(_verbose); + + /// + /// initialize levelset tools + /// + if (_levelset) { + if (_variant == 0) { + if (_L0 == nullptr) + _L0 = (levelset_tools_var0_type *)::operator new(sizeof(levelset_tools_var0_type)); + new (_L0) levelset_tools_var0_type(*_N); + _L0->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); + _L0->createStream(_nstreams); + } else if (_variant == 1) { + if (_L1 == nullptr) + _L1 = (levelset_tools_var1_type *)::operator new(sizeof(levelset_tools_var1_type)); + new (_L1) levelset_tools_var1_type(*_N); + _L1->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); + _L1->createStream(_nstreams); + } else if (_variant == 2) { + if (_L2 == nullptr) + _L2 = (levelset_tools_var2_type *)::operator new(sizeof(levelset_tools_var2_type)); + new (_L2) levelset_tools_var1_type(*_N); + _L2->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); + _L2->createStream(_nstreams); + } + } + } + return 0; +} + +template int Solver::factorize(const value_type_array &ax) { + switch (_mode) { + case Cholesky: + factorize_chol(ax); + break; + case LDL: + factorize_ldl(ax); + break; + default: + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); + } + return 0; +} + +template int Solver::factorize_chol(const value_type_array &ax) { + if (_verbose) { + printf("TachoSolver: Factorize\n"); + printf("======================\n"); + } + if (_m < _small_problem_thres) { + Kokkos::Timer timer; + + timer.reset(); + _A = value_type_matrix_host("A", _m, _m); + auto h_ax = Kokkos::create_mirror_view(host_memory_space(), ax); + Kokkos::deep_copy(h_ax, ax); + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbeg = _h_ap(i), jend = _h_ap(i + 1); + for (size_type j = jbeg; j < jend; ++j) { + const ordinal_type col = _h_aj(j); + if (i <= col) + _A(i, col) = h_ax(j); + } + } + const double t_copy = timer.seconds(); + + timer.reset(); + Tacho::Chol::invoke(_A); + const double t_factor = timer.seconds(); + + if (_verbose) { + printf("Summary: NumericTools (SmallDenseFactorization)\n"); + printf("===============================================\n"); + printf(" Time\n"); + printf(" time for copying A into U: %10.6f s\n", t_copy); + printf(" time for numeric factorization: %10.6f s\n", t_factor); + printf(" total time spent: %10.6f s\n", (t_copy + t_factor)); + printf("\n"); + } + } else { + +#if !defined(KOKKOS_ENABLE_CUDA) + const ordinal_type nthreads = host_space::impl_thread_pool_size(0); +#endif + if (_levelset) { + if (_variant == 0) + _L0->factorizeCholesky(ax, _verbose); + else if (_variant == 1) + _L1->factorizeCholesky(ax, _verbose); + else if (_variant == 2) + _L2->factorizeCholesky(ax, _verbose); + } +#if !defined(KOKKOS_ENABLE_CUDA) + else if (nthreads == 1) { + if (_nb < 0) + _N->factorizeCholesky_Serial(ax, _verbose); + else + _N->factorizeCholesky_SerialPanel(ax, _nb, _verbose); + } +#endif + else { + const ordinal_type max_dense_size = max(_N->getMaxSupernodeSize(), _N->getMaxSchurSize()); + if (std::is_same::value) { + if (_nb < 0) { + _nb = 64; + // if (max_dense_size < 256) _nb = -1; + // else if (max_dense_size < 512) _nb = 64; + // else if (max_dense_size < 1024) _nb = 128; + // else if (max_dense_size < 8192) _nb = 256; + // else _nb = 256; + } + if (_mb < 0) { + if (max_dense_size < 256) + _mb = -1; + else if (max_dense_size < 512) + _mb = 64; + else if (max_dense_size < 1024) + _mb = 128; + else if (max_dense_size < 8192) + _mb = 256; + else + _mb = 256; + } + } else { + if (_nb < 0) { + _nb = 40; + // if (max_dense_size < 256) _nb = -1; + // else if (max_dense_size < 512) _nb = 64; + // else if (max_dense_size < 1024) _nb = 128; + // else if (max_dense_size < 8192) _nb = 256; + // else _nb = 256; + } + if (_mb < 0) { + if (max_dense_size < 256) + _mb = -1; + else if (max_dense_size < 512) + _mb = 80; + else if (max_dense_size < 1024) + _mb = 120; + else if (max_dense_size < 8192) + _mb = 160; + else + _mb = 160; + } + } + + if (_nb <= 0) + if (_mb > 0) + _N->factorizeCholesky_ParallelByBlocks(ax, _mb, _verbose); + else + _N->factorizeCholesky_Parallel(ax, _verbose); + else if (_mb > 0) + _N->factorizeCholesky_ParallelByBlocksPanel(ax, _mb, _nb, _verbose); + else + _N->factorizeCholesky_ParallelPanel(ax, _nb, _verbose); + } + } + return 0; +} + +template int Solver::factorize_ldl(const value_type_array &ax) { + if (_verbose) { + printf("TachoSolver: Factorize\n"); + printf("======================\n"); + } + if (_m < _small_problem_thres) { + Kokkos::Timer timer; + + timer.reset(); + _A = value_type_matrix_host("A", _m, _m); + auto h_ax = Kokkos::create_mirror_view(host_memory_space(), ax); + Kokkos::deep_copy(h_ax, ax); + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbeg = _h_ap(i), jend = _h_ap(i + 1); + for (size_type j = jbeg; j < jend; ++j) { + const ordinal_type col = _h_aj(j); + if (i >= col) + _A(i, col) = h_ax(j); + } + } + const double t_copy = timer.seconds(); + + timer.reset(); + _P = ordinal_type_array_host("P", 4 * _m); + _D = value_type_matrix_host("D", _m, 2); + auto W = value_type_array_host("W", 32 * _m); + Tacho::LDL::invoke(_A, _P, W); + Tacho::LDL::modify(_A, _P, _D); + const double t_factor = timer.seconds(); + + if (_verbose) { + printf("Summary: NumericTools (SmallDenseFactorization)\n"); + printf("===============================================\n"); + printf(" Time\n"); + printf(" time for copying A into L: %10.6f s\n", t_copy); + printf(" time for numeric factorization: %10.6f s\n", t_factor); + printf(" total time spent: %10.6f s\n", (t_copy + t_factor)); + printf("\n"); + } + } else { +#if !defined(KOKKOS_ENABLE_CUDA) + const ordinal_type nthreads = host_space::impl_thread_pool_size(0); +#endif + if (_levelset) { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); + // if (_variant == 0) _L0->factorizeLDL(ax, _verbose); + // else if (_variant == 1) _L1->factorizeLDL(ax, _verbose); + // else if (_variant == 2) _L2->factorizeLDL(ax, _verbose); + } +#if !defined(KOKKOS_ENABLE_CUDA) + else if (nthreads == 1) { + _N->factorizeLDL_Serial(ax, _verbose); + // if (_nb < 0) + // _N->factorizeLDL_Serial(ax, _verbose); + // else + // _N->factorizeLDL_SerialPanel(ax, _nb, _verbose); + } +#endif + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); + } + } + return 0; +} + +template +int Solver::solve(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t) { + switch (_mode) { + case Cholesky: + solve_chol(x, b, t); + break; + case LDL: + solve_ldl(x, b, t); + break; + default: + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); + } + return 0; +} + +template +int Solver::solve_chol(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t) { + if (_verbose) { + printf("TachoSolver: Solve\n"); + printf("==================\n"); + } + + if (_m < _small_problem_thres) { + Kokkos::Timer timer; + + timer.reset(); + Kokkos::deep_copy(x, b); + const double t_copy = timer.seconds(); + + timer.reset(); + auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); + Kokkos::deep_copy(h_x, x); + Trsm::invoke(Diag::NonUnit(), 1.0, _A, h_x); + Trsm::invoke(Diag::NonUnit(), 1.0, _A, h_x); + Kokkos::deep_copy(x, h_x); + const double t_solve = timer.seconds(); + + if (_verbose) { + printf("Summary: NumericTools (SmallDenseSolve)\n"); + printf("=======================================\n"); + printf(" Time\n"); + printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); + printf(" time for numeric solve: %10.6f s\n", t_solve); + printf(" total time spent: %10.6f s\n", (t_solve + t_copy)); + printf("\n"); + } + } else { +#if !defined(KOKKOS_ENABLE_CUDA) + const ordinal_type nthreads = host_space::impl_thread_pool_size(0); +#endif + TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || t.extent(1) < x.extent(1), std::logic_error, + "Temporary rhs vector t is smaller than x"); + auto tt = Kokkos::subview(t, Kokkos::pair(0, x.extent(0)), + Kokkos::pair(0, x.extent(1))); + if (_levelset) { + if (_variant == 0) + _L0->solveCholesky(x, b, tt, _verbose); + else if (_variant == 1) + _L1->solveCholesky(x, b, tt, _verbose); + else if (_variant == 2) + _L2->solveCholesky(x, b, tt, _verbose); + } +#if !defined(KOKKOS_ENABLE_CUDA) + else if (nthreads == 1) { + _N->solveCholesky_Serial(x, b, tt, _verbose); + } +#endif + else { + _N->solveCholesky_Parallel(x, b, tt, _verbose); + } + } + return 0; +} + +template +int Solver::solve_ldl(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t) { + if (_verbose) { + printf("TachoSolver: Solve\n"); + printf("==================\n"); + } + + if (_m < _small_problem_thres) { + Kokkos::Timer timer; + + timer.reset(); + Kokkos::deep_copy(x, b); + const double t_copy = timer.seconds(); + + timer.reset(); + auto perm = ordinal_type_array_host(_P.data() + 2 * _m, _m); + auto peri = ordinal_type_array_host(_P.data() + 3 * _m, _m); + auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); + auto h_t = Kokkos::create_mirror_view(host_memory_space(), t); + + ApplyPermutation::invoke(h_x, perm, h_t); + Trsm::invoke(Diag::Unit(), 1.0, _A, h_t); + Scale2x2_BlockInverseDiagonals::invoke(_P, _D, h_t); + Trsm::invoke(Diag::Unit(), 1.0, _A, h_t); + ApplyPermutation::invoke(h_t, peri, h_x); + Kokkos::deep_copy(x, h_x); + const double t_solve = timer.seconds(); + + if (_verbose) { + printf("Summary: NumericTools (SmallDenseSolve)\n"); + printf("=======================================\n"); + printf(" Time\n"); + printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); + printf(" time for numeric solve: %10.6f s\n", t_solve); + printf(" total time spent: %10.6f s\n", (t_solve + t_copy)); + printf("\n"); + } + } else { +#if !defined(KOKKOS_ENABLE_CUDA) + const ordinal_type nthreads = host_space::impl_thread_pool_size(0); +#endif + TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || t.extent(1) < x.extent(1), std::logic_error, + "Temporary rhs vector t is smaller than x"); + auto tt = Kokkos::subview(t, Kokkos::pair(0, x.extent(0)), + Kokkos::pair(0, x.extent(1))); + if (_levelset) { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); + // if (_variant == 0) _L0->solveCholesky(x, b, tt, _verbose); + // else if (_variant == 1) _L1->solveCholesky(x, b, tt, _verbose); + // else if (_variant == 2) _L2->solveCholesky(x, b, tt, _verbose); + } +#if !defined(KOKKOS_ENABLE_CUDA) + else if (nthreads == 1) { + _N->solveLDL_Serial(x, b, tt, _verbose); + } +#endif + else { + //_N->solveCholesky_Parallel(x, b, tt, _verbose); + } + } + return 0; +} + +template +double Solver::computeRelativeResidual(const value_type_array &ax, const value_type_matrix &x, + const value_type_matrix &b) { + CrsMatrixBase A; + A.setExternalMatrix(_m, _m, _nnz, _ap, _aj, ax); + + return Tacho::computeRelativeResidual(A, x, b); +} + +template int Solver::exportFactorsToCrsMatrix(crs_matrix_type &A) { + if (_m < _small_problem_thres) { + typedef ArithTraits ats; + const typename ats::mag_type zero(0); + + /// count nonzero elements in dense U + const ordinal_type m = _m; + size_type_array_host h_ap("h_ap", m + 1); + for (ordinal_type i = 0; i < m; ++i) + for (ordinal_type j = 0; j < m; ++j) + h_ap(i + 1) += (ats::abs(_A(i, j)) > zero); + + /// serial scan; this is a small problem + h_ap(0) = 0; + for (ordinal_type i = 0; i < m; ++i) + h_ap(i + 1) += h_ap(i); + + /// create a host crs matrix + const ordinal_type nnz = h_ap(m); + ordinal_type_array_host h_aj(do_not_initialize_tag("h_aj"), nnz); + value_type_array_host h_ax(do_not_initialize_tag("h_ax"), nnz); + + for (ordinal_type i = 0, k = 0; i < m; ++i) + for (ordinal_type j = i; j < m; ++j) + if (ats::abs(_A(i, j)) > zero) { + h_aj(k) = j; + h_ax(k) = _A(i, j); + ++k; + } + + crs_matrix_type_host h_A; + h_A.setExternalMatrix(m, m, nnz, h_ap, h_aj, h_ax); + /// h_A.showMe(std::cout, true); + A.clear(); + A.createConfTo(h_A); + A.copy(h_A); + } else { + _N->exportFactorsToCrsMatrix(A, false); + } + return 0; +} + +template int Solver::release() { + if (_verbose) { + printf("TachoSolver: Release\n"); + printf("====================\n"); + } + + if (_levelset) { + if (_variant == 0) { + if (_L0 != nullptr) + _L0->release(_verbose); + delete _L0; + _L0 = nullptr; + } else if (_variant == 1) { + if (_L1 != nullptr) + _L1->release(_verbose); + delete _L1; + _L1 = nullptr; + } else if (_variant == 2) { + if (_L2 != nullptr) + _L2->release(_verbose); + delete _L2; + _L2 = nullptr; + } + } + + { + if (_N != nullptr) + _N->release(_verbose); + delete _N; + _N = nullptr; + } + { + _transpose = false; + _mode = 0; + + _m = 0; + _nnz = 0; + + _ap = size_type_array(); + _h_ap = size_type_array_host(); + _aj = ordinal_type_array(); + _h_aj = ordinal_type_array_host(); + + _perm = ordinal_type_array(); + _h_perm = ordinal_type_array_host(); + _peri = ordinal_type_array(); + _h_peri = ordinal_type_array_host(); + + _m_graph = 0; + _nnz_graph = 0; + + _h_ap_graph = size_type_array_host(); + _h_aj_graph = ordinal_type_array_host(); + + _h_perm_graph = ordinal_type_array_host(); + _h_peri_graph = ordinal_type_array_host(); + + _nsupernodes = 0; + _supernodes = ordinal_type_array(); + + _gid_super_panel_ptr = size_type_array(); + _gid_super_panel_colidx = ordinal_type_array(); + + _sid_super_panel_ptr = size_type_array(); + + _sid_super_panel_colidx = ordinal_type_array(); + _blk_super_panel_colidx = ordinal_type_array(); + + _stree_ptr = size_type_array(); + _stree_children = ordinal_type_array(); + + _stree_parent = ordinal_type_array(); + _stree_roots = ordinal_type_array_host(); + + _A = value_type_matrix_host(); + + _verbose = 0; + _small_problem_thres = 1024; + _serial_thres_size = -1; + _mb = -1; + _nb = -1; + _front_update_mode = -1; + + _max_num_superblocks = -1; + } + return 0; +} + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/CMakeLists.txt b/packages/shylu/shylu_node/tacho/src/CMakeLists.txt index 35ec6c9e41f2..ac72b762f7cd 100644 --- a/packages/shylu/shylu_node/tacho/src/CMakeLists.txt +++ b/packages/shylu/shylu_node/tacho/src/CMakeLists.txt @@ -38,64 +38,82 @@ SET(SOURCES "") FILE(GLOB HEADERS_PUBLIC *.hpp impl/*.hpp) LIST( APPEND HEADERS_PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/Tacho_config.h ) - APPEND_SET(SOURCES impl/Tacho_Util.cpp impl/Tacho_Blas_External.cpp impl/Tacho_Lapack_External.cpp impl/Tacho_GraphTools_Metis.cpp impl/Tacho_SymbolicTools.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Cuda.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Serial.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Cuda.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Serial.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Cuda.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Serial.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Cuda.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Serial.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Serial.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Serial.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_double_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_double_Serial.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_float_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskScheduler_float_Serial.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Serial.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Serial.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Serial.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Cuda.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_OpenMP.cpp - # eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Serial.cpp - eti/Tacho_Driver_ETI_complex_double_Cuda.cpp - eti/Tacho_Driver_ETI_complex_double_OpenMP.cpp - eti/Tacho_Driver_ETI_complex_double_Serial.cpp - eti/Tacho_Driver_ETI_complex_float_Cuda.cpp - eti/Tacho_Driver_ETI_complex_float_OpenMP.cpp - eti/Tacho_Driver_ETI_complex_float_Serial.cpp - eti/Tacho_Driver_ETI_double_Cuda.cpp - eti/Tacho_Driver_ETI_double_OpenMP.cpp - eti/Tacho_Driver_ETI_double_Serial.cpp - eti/Tacho_Driver_ETI_float_Cuda.cpp - eti/Tacho_Driver_ETI_float_OpenMP.cpp - eti/Tacho_Driver_ETI_float_Serial.cpp - ) +SET(TACHO_ETI_FILE "Tacho_Driver") +SET(TACHO_ETI_DEVICE_NAME "") +SET(TACHO_ETI_DEVICE_TYPE "") +SET(TACHO_ETI_WITH_TASK "") + +IF (Kokkos_ENABLE_SERIAL) + LIST(APPEND TACHO_ETI_DEVICE_NAME "Serial") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() +IF (Kokkos_ENABLE_OPENMP) + LIST(APPEND TACHO_ETI_DEVICE_NAME "OpenMP") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "1") +ENDIF() +IF (Kokkos_ENABLE_CUDA) + LIST(APPEND TACHO_ETI_DEVICE_NAME "CUDA") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() +IF (Kokkos_ENABLE_HIP) + LIST(APPEND TACHO_ETI_DEVICE_NAME "HIP") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() + +LIST(LENGTH TACHO_ETI_DEVICE_NAME ETI_DEVICE_COUNT) +MATH(EXPR ETI_DEVICE_COUNT "${ETI_DEVICE_COUNT}-1") + +SET(TACHO_ETI_VALUE_NAME "") +SET(TACHO_ETI_VALUE_TYPE "") + +LIST(APPEND TACHO_ETI_VALUE_NAME "double") +LIST(APPEND TACHO_ETI_VALUE_TYPE "double") + +LIST(APPEND TACHO_ETI_VALUE_NAME "float") +LIST(APPEND TACHO_ETI_VALUE_TYPE "float") + +LIST(APPEND TACHO_ETI_VALUE_NAME "complex_double") +LIST(APPEND TACHO_ETI_VALUE_TYPE "Kokkos::complex") + +LIST(APPEND TACHO_ETI_VALUE_NAME "complex_float") +LIST(APPEND TACHO_ETI_VALUE_TYPE "Kokkos::complex") + +LIST(LENGTH TACHO_ETI_VALUE_NAME ETI_VALUE_COUNT) +MATH(EXPR ETI_VALUE_COUNT "${ETI_VALUE_COUNT}-1") + +FOREACH(I RANGE ${ETI_DEVICE_COUNT}) + LIST(GET TACHO_ETI_DEVICE_NAME ${I} ETI_DEVICE_NAME) + LIST(GET TACHO_ETI_DEVICE_TYPE ${I} ETI_DEVICE_TYPE) + LIST(GET TACHO_ETI_WITH_TASK ${I} ETI_WITH_TASK) + + FOREACH(J RANGE ${ETI_VALUE_COUNT}) + LIST(GET TACHO_ETI_VALUE_NAME ${J} ETI_VALUE_NAME) + LIST(GET TACHO_ETI_VALUE_TYPE ${J} ETI_VALUE_TYPE) + + FOREACH(ETI_FILE IN LISTS TACHO_ETI_FILE) + SET(ETI_NAME "${ETI_FILE}_ETI_${ETI_VALUE_NAME}_${ETI_DEVICE_NAME}") + MESSAGE(STATUS "Generating ETI: ${ETI_NAME}.cpp") + CONFIGURE_FILE(eti/${ETI_FILE}_ETI.in eti/${ETI_NAME}.cpp) + + APPEND_SET(SOURCES + eti/${ETI_NAME}.cpp) + + ENDFOREACH() + ENDFOREACH() +ENDFOREACH() + #----------------------------------------------------------------------------- TRIBITS_ADD_LIBRARY( diff --git a/packages/shylu/shylu_node/tacho/src/Tacho.hpp b/packages/shylu/shylu_node/tacho/src/Tacho.hpp index 85a1be13afff..dcc1d2d1cb35 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho.hpp @@ -1,19 +1,19 @@ #ifndef __TACHO_HPP__ #define __TACHO_HPP__ -#include "Tacho_config.h" +#include "Tacho_config.h" #include "Kokkos_Core.hpp" #include "Kokkos_Timer.hpp" #include #include -#include #include +#include #include #include -#include #include +#include #include /// \file Tacho.hpp @@ -22,151 +22,241 @@ namespace Tacho { - /// - /// default ordinal and size type - /// +/// +/// default ordinal and size type +/// -#if defined( TACHO_USE_INT_INT ) - typedef int ordinal_type; - typedef int size_type; -#elif defined( TACHO_USE_INT_SIZE_T ) - typedef int ordinal_type; - typedef size_t size_type; +#if defined(TACHO_USE_INT_INT) +typedef int ordinal_type; +typedef int size_type; +#elif defined(TACHO_USE_INT_SIZE_T) +typedef int ordinal_type; +typedef size_t size_type; #else - typedef int ordinal_type; - typedef size_t size_type; +typedef int ordinal_type; +typedef size_t size_type; #endif - /// - /// default device type used in tacho - /// - template - struct UseThisDevice { - using exec_space = ExecSpace; - using memory_space = typename exec_space::memory_space; - using type = Kokkos::Device; - using device_type = type; - }; - - template - struct UseThisScheduler { - using type = Kokkos::TaskSchedulerMultiple; - using scheduler_type = type; - }; - - /// until kokkos dual view issue is resolved, we follow the default space in Trilinos (uvm) +/// +/// default Kokkos types (non-specialized code path is error) +/// +template struct UseThisDevice; + +template struct UseThisScheduler; + +template struct UseThisFuture; + +/// +/// dummy objects when kokkos tasking is not used +/// +template struct DummyTaskScheduler { + static_assert(Kokkos::is_execution_space::value, "Error: ExecSpace is not an execution space"); + using execution_space = ExecSpace; +}; + +template struct DummyFuture { + DummyFuture() = default; + DummyFuture(const DummyFuture &b) = default; + + void clear() {} +}; + +/// until kokkos dual view issue is resolved, we follow the default space in Trilinos (uvm) #if defined(KOKKOS_ENABLE_CUDA) - template<> - struct UseThisDevice { - using type = Kokkos::Device; - using device_type = type; - }; +template <> struct UseThisDevice { + using type = Kokkos::Device; + using device_type = type; +}; +template <> struct UseThisScheduler { + using type = DummyTaskScheduler; + using scheduler_type = type; +}; +template struct UseThisFuture { + using type = DummyFuture; + using future_type = type; +}; +#endif +#if defined(KOKKOS_ENABLE_HIP) +template <> struct UseThisDevice { + using type = Kokkos::Device; + using device_type = type; +}; +template <> struct UseThisScheduler { + using type = DummyTaskScheduler; + using scheduler_type = type; +}; +template struct UseThisFuture { + using type = DummyFuture; + using future_type = type; +}; #endif #if defined(KOKKOS_ENABLE_OPENMP) - template<> - struct UseThisDevice { - using type = Kokkos::Device; - using device_type = type; - }; +template <> struct UseThisDevice { + using type = Kokkos::Device; + using device_type = type; +}; +template <> struct UseThisScheduler { +#if defined(KOKKOS_ENABLE_TASKDAG) && false + using type = Kokkos::TaskSchedulerMultiple; +#else + using type = DummyTaskScheduler; +#endif + using scheduler_type = type; +}; +template struct UseThisFuture { +#if defined(KOKKOS_ENABLE_TASKDAG) && false + using type = Kokkos::BasicFuture; +#else + using type = DummyFuture; +#endif + using future_type = type; +}; #endif #if defined(KOKKOS_ENABLE_SERIAL) - template<> - struct UseThisDevice { - using type = Kokkos::Device; - using device_type = type; - }; +template <> struct UseThisDevice { + using type = Kokkos::Device; + using device_type = type; +}; +template <> struct UseThisScheduler { + using type = DummyTaskScheduler; + using scheduler_type = type; +}; +template struct UseThisFuture { + using type = DummyFuture; + using future_type = type; +}; #endif - /// - /// print execution spaces - /// - template - void printExecSpaceConfiguration(std::string name, const bool detail = false) { - if (!Kokkos::is_space::value) { - std::string msg("SpT is not Kokkos execution space"); - fprintf(stderr, ">> Error in file %s, line %d\n",__FILE__,__LINE__); - fprintf(stderr, " %s\n", msg.c_str()); - throw std::logic_error(msg.c_str()); - } - std::cout << std::setw(16) << name << ":: "; - SpT::print_configuration(std::cout, detail); +/// +/// print execution spaces +/// +template void printExecSpaceConfiguration(std::string name, const bool detail = false) { + if (!Kokkos::is_space::value) { + std::string msg("SpT is not Kokkos execution space"); + fprintf(stderr, ">> Error in file %s, line %d\n", __FILE__, __LINE__); + fprintf(stderr, " %s\n", msg.c_str()); + throw std::logic_error(msg.c_str()); + } + bool is_printed(false); +#if defined(KOKKOS_ENABLE_SERIAL) + if (std::is_same::value) { + is_printed = true; + std::cout << std::setw(16) << name << ":: Serial \n"; } +#endif +#if defined(KOKKOS_ENABLE_OPENMP) + if (std::is_same::value) { + is_printed = true; + std::cout << std::setw(16) << name << ":: OpenMP \n"; + } +#endif +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value) { + is_printed = true; + std::cout << std::setw(16) << name << ":: Cuda \n"; + } +#endif +#if defined(KOKKOS_ENABLE_HIP) + if (std::is_same::value) { + is_printed = true; + std::cout << std::setw(16) << name << ":: HIP \n"; + } +#endif + if (!is_printed) { + std::cout << std::setw(16) << name << ":: not supported Kokkos execution space\n"; + SpT().print_configuration(std::cout, detail); + throw std::logic_error("Error: not supported Kokkos execution space"); + } + if (detail) + SpT().print_configuration(std::cout, true); +} - template - struct ArithTraits; - - template<> - struct ArithTraits { - typedef float val_type; - typedef float mag_type; - - enum : bool { is_complex = false }; - static KOKKOS_FORCEINLINE_FUNCTION mag_type abs (const val_type& x) { return x > 0 ? x : -x; } - static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type& x) { return x; } - static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type& x) { return x; } - static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type& x) { return x; } - }; - - template<> - struct ArithTraits { - typedef double val_type; - typedef double mag_type; - - enum : bool { is_complex = false }; - static KOKKOS_FORCEINLINE_FUNCTION mag_type abs (const val_type& x) { return x > 0 ? x : -x; } - static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type& x) { return x; } - static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type& x) { return x; } - static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type& x) { return x; } - }; - - template<> - struct ArithTraits > { - typedef std::complex val_type; - typedef float mag_type; - - enum : bool { is_complex = true }; - static inline mag_type abs (const val_type& x) { return std::abs(x); } - static inline mag_type real(const val_type& x) { return x.real(); } - static inline mag_type imag(const val_type& x) { return x.imag(); } - static inline val_type conj(const val_type& x) { return std::conj(x); } - }; - - template<> - struct ArithTraits > { - typedef std::complex val_type; - typedef double mag_type; - - enum : bool { is_complex = true }; - static inline mag_type abs (const val_type& x) { return std::abs(x); } - static inline mag_type real(const val_type& x) { return x.real(); } - static inline mag_type imag(const val_type& x) { return x.imag(); } - static inline val_type conj(const val_type& x) { return std::conj(x); } - }; - - template<> - struct ArithTraits > { - typedef Kokkos::complex val_type; - typedef float mag_type; - - enum : bool { is_complex = true }; - static KOKKOS_FORCEINLINE_FUNCTION mag_type abs (const val_type& x) { return Kokkos::abs(x); } - static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type& x) { return x.real(); } - static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type& x) { return x.imag(); } - static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type& x) { return Kokkos::conj(x); } - }; - - template<> - struct ArithTraits > { - typedef Kokkos::complex val_type; - typedef double mag_type; - - enum : bool { is_complex = true }; - static KOKKOS_FORCEINLINE_FUNCTION mag_type abs (const val_type& x) { return Kokkos::abs(x); } - static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type& x) { return x.real(); } - static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type& x) { return x.imag(); } - static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type& x) { return Kokkos::conj(x); } - }; +template struct ArithTraits; -} +template <> struct ArithTraits { + typedef float val_type; + typedef float mag_type; + + enum : bool { is_complex = false }; + static KOKKOS_FORCEINLINE_FUNCTION mag_type abs(const val_type &x) { return x > 0 ? x : -x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type &x) { return x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type &x) { return 0; } + static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type &x) { return x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type epsilon() { return FLT_EPSILON; } + static KOKKOS_FORCEINLINE_FUNCTION void set_real(val_type &x, const mag_type &val) { x = val; } + static KOKKOS_FORCEINLINE_FUNCTION void set_imag(val_type &x, const mag_type &val) {} +}; + +template <> struct ArithTraits { + typedef double val_type; + typedef double mag_type; + + enum : bool { is_complex = false }; + static KOKKOS_FORCEINLINE_FUNCTION mag_type abs(const val_type &x) { return x > 0 ? x : -x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type &x) { return x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type &x) { return 0; } + static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type &x) { return x; } + static KOKKOS_FORCEINLINE_FUNCTION mag_type epsilon() { return DBL_EPSILON; } + static KOKKOS_FORCEINLINE_FUNCTION void set_real(val_type &x, const mag_type &val) { x = val; } + static KOKKOS_FORCEINLINE_FUNCTION void set_imag(val_type &x, const mag_type &val) {} +}; + +template <> struct ArithTraits> { + typedef std::complex val_type; + typedef float mag_type; + + enum : bool { is_complex = true }; + static inline mag_type abs(const val_type &x) { return std::abs(x); } + static inline mag_type real(const val_type &x) { return x.real(); } + static inline mag_type imag(const val_type &x) { return x.imag(); } + static inline val_type conj(const val_type &x) { return std::conj(x); } + static inline mag_type epsilon() { return FLT_EPSILON; } + static inline void set_real(val_type &x, const mag_type &val) { x.real(val); } + static inline void set_imag(val_type &x, const mag_type &val) { x.imag(val); } +}; + +template <> struct ArithTraits> { + typedef std::complex val_type; + typedef double mag_type; + + enum : bool { is_complex = true }; + static inline mag_type abs(const val_type &x) { return std::abs(x); } + static inline mag_type real(const val_type &x) { return x.real(); } + static inline mag_type imag(const val_type &x) { return x.imag(); } + static inline val_type conj(const val_type &x) { return std::conj(x); } + static inline mag_type epsilon() { return DBL_EPSILON; } + static inline void set_real(val_type &x, const mag_type &val) { x.real(val); } + static inline void set_imag(val_type &x, const mag_type &val) { x.imag(val); } +}; + +template <> struct ArithTraits> { + typedef Kokkos::complex val_type; + typedef float mag_type; + + enum : bool { is_complex = true }; + static KOKKOS_FORCEINLINE_FUNCTION mag_type abs(const val_type &x) { return Kokkos::abs(x); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type &x) { return x.real(); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type &x) { return x.imag(); } + static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type &x) { return Kokkos::conj(x); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type epsilon() { return FLT_EPSILON; } + static KOKKOS_FORCEINLINE_FUNCTION void set_real(val_type &x, const mag_type &val) { x.real(val); } + static KOKKOS_FORCEINLINE_FUNCTION void set_imag(val_type &x, const mag_type &val) { x.imag(val); } +}; + +template <> struct ArithTraits> { + typedef Kokkos::complex val_type; + typedef double mag_type; + + enum : bool { is_complex = true }; + static KOKKOS_FORCEINLINE_FUNCTION mag_type abs(const val_type &x) { return Kokkos::abs(x); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type real(const val_type &x) { return x.real(); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type imag(const val_type &x) { return x.imag(); } + static KOKKOS_FORCEINLINE_FUNCTION val_type conj(const val_type &x) { return Kokkos::conj(x); } + static KOKKOS_FORCEINLINE_FUNCTION mag_type epsilon() { return DBL_EPSILON; } + static KOKKOS_FORCEINLINE_FUNCTION void set_real(val_type &x, const mag_type &val) { x.real(val); } + static KOKKOS_FORCEINLINE_FUNCTION void set_imag(val_type &x, const mag_type &val) { x.imag(val); } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_CommandLineParser.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_CommandLineParser.hpp index ed620bec9d3a..40f5b4ff985e 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_CommandLineParser.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_CommandLineParser.hpp @@ -6,188 +6,175 @@ /// \author Kyungjoo Kim (kyukim@sandia.gov) // "std" includes -#include -#include #include +#include #include +#include #include namespace Tacho { - template - struct Option { - std::string _desc; - std::string _type; - std::string _val; - Option(std::string desc, T val); - Option(std::string val) : _val(val) {}; - T value() const; - }; - template<> Option::Option(std::string desc, bool val) - : _desc(desc), _type("bool"), _val(std::to_string(val)) {} - template<> Option::Option(std::string desc, int val) - : _desc(desc), _type("int"), _val(std::to_string(val)) {} - template<> Option::Option(std::string desc, std::string val) - : _desc(desc), _type("string"), _val(val) {} - - struct CommandLineParser { - private: - typedef std::tuple option_type; - - std::string _desc; - std::map _map; - - - public: - CommandLineParser() {}; - CommandLineParser(const std::string desc) : _desc(desc) {} - - template - void set_option(const std::string opt, - const std::string desc, - T *val) { - const auto in = Option(desc, *val); - _map[opt] = std::make_tuple(in._desc, in._type, in._val, val); - } - void print_option(const std::string argv0) { - std::cout << "Usage: " << argv0 << " [options]\n"; - std::cout << " options:\n"; - - const std::string prefix = "--"; - for (auto it=_map.begin();it!=_map.end();++it) { - auto key = prefix+it->first; - auto val = it->second; - std::cout << std::left - << " " - << std::setw(30) << key - << std::setw(10) << std::get<1>(val) - << std::get<0>(val) - << "\n"; - if (std::get<1>(val) == "bool") { - if (it->first != "help" && it->first != "echo-command-line") - std::cout << std::setw(42) << " " - << "(default: " << key << "=" << (std::get<2>(val) == "0" ? "false" : "true") - << ")\n"; - } else { - std::cout << std::setw(42) << " " - << "(default: " << key << "=" << std::get<2>(val) - << ")\n"; - } - } - std::cout << "Description:\n"; - std::cout << " " << _desc << "\n\n"; +template struct Option { + std::string _desc; + std::string _type; + std::string _val; + Option(std::string desc, T val); + Option(std::string val) : _val(val){}; + T value() const; +}; +template <> Option::Option(std::string desc, bool val) : _desc(desc), _type("bool"), _val(std::to_string(val)) {} +template <> Option::Option(std::string desc, int val) : _desc(desc), _type("int"), _val(std::to_string(val)) {} +template <> Option::Option(std::string desc, std::string val) : _desc(desc), _type("string"), _val(val) {} + +struct CommandLineParser { +private: + typedef std::tuple option_type; + + std::string _desc; + std::map _map; + +public: + CommandLineParser(){}; + CommandLineParser(const std::string desc) : _desc(desc) {} + + template void set_option(const std::string opt, const std::string desc, T *val) { + const auto in = Option(desc, *val); + _map[opt] = std::make_tuple(in._desc, in._type, in._val, val); + } + void print_option(const std::string argv0) { + std::cout << "Usage: " << argv0 << " [options]\n"; + std::cout << " options:\n"; + + const std::string prefix = "--"; + for (auto it = _map.begin(); it != _map.end(); ++it) { + auto key = prefix + it->first; + auto val = it->second; + std::cout << std::left << " " << std::setw(30) << key << std::setw(10) << std::get<1>(val) << std::get<0>(val) + << "\n"; + if (std::get<1>(val) == "bool") { + if (it->first != "help" && it->first != "echo-command-line") + std::cout << std::setw(42) << " " + << "(default: " << key << "=" << (std::get<2>(val) == "0" ? "false" : "true") << ")\n"; + } else { + std::cout << std::setw(42) << " " + << "(default: " << key << "=" << std::get<2>(val) << ")\n"; } - bool parse(int argc, char **argv) { - bool help = false, echo = false; - this->set_option("help", "Print this help message", &help); - this->set_option("echo-command-line", "Echo the command-line but continue as normal", &echo); - - // check help - for (int i=1;isecond; - // find option starting with -- - if (s.find("--") != std::string::npos && s.length() != 2) { - std::string desc = std::get<0>(t); - std::string type = std::get<1>(t); - - size_t pos = s.find("="); - if (pos != std::string::npos) { - // --opt=val - std::string key = s.substr(2, pos-2); - if (key == it->first) { - std::string sval = s.substr(pos+1, s.length()); - void *tval = std::get<3>(t); - if (tval != NULL) { - if (type == "int") { - *((int*)tval) = atoi(sval.c_str()); - } else if (type == "string") { - *((std::string*)tval) = sval; - } else if (type == "bool") { - *((bool*)tval) = (sval == "true"); - } else { - std::cout << " int somethng wrong\n"; - } - _map[it->first] = std::make_tuple(desc, type, sval, (void*)NULL); - } + } + std::cout << "Description:\n"; + std::cout << " " << _desc << "\n\n"; + } + bool parse(int argc, char **argv) { + bool help = false, echo = false; + this->set_option("help", "Print this help message", &help); + this->set_option("echo-command-line", "Echo the command-line but continue as normal", &echo); + + // check help + for (int i = 1; i < argc; ++i) { + std::string s = argv[i]; + if (s == "--help") + help = true; + if (s == "--echo-command-line") + echo = true; + } + + if (help) { + print_option(argv[0]); + } else { + // parse + for (int i = 1; i < argc; ++i) { + std::string s = argv[i]; + for (auto it = _map.begin(); it != _map.end(); ++it) { + const auto t = it->second; + // find option starting with -- + if (s.find("--") != std::string::npos && s.length() != 2) { + std::string desc = std::get<0>(t); + std::string type = std::get<1>(t); + + size_t pos = s.find("="); + if (pos != std::string::npos) { + // --opt=val + std::string key = s.substr(2, pos - 2); + if (key == it->first) { + std::string sval = s.substr(pos + 1, s.length()); + void *tval = std::get<3>(t); + if (tval != NULL) { + if (type == "int") { + *((int *)tval) = atoi(sval.c_str()); + } else if (type == "string") { + *((std::string *)tval) = sval; + } else if (type == "bool") { + *((bool *)tval) = (sval == "true"); + } else { + std::cout << " int somethng wrong\n"; } - } - else { - // --opt (bool) - std::string key = s.substr(2, s.length()); - if (key == it->first) { - void *tval = std::get<3>(t); - if (tval != NULL) { - if (type == "bool") { - *((bool*)tval) = true; - } else { - std::cout << " bool somethng wrong\n"; - } - _map[it->first] = std::make_tuple(desc, type, "1", (void*)NULL); - } + _map[it->first] = std::make_tuple(desc, type, sval, (void *)NULL); + } + } + } else { + // --opt (bool) + std::string key = s.substr(2, s.length()); + if (key == it->first) { + void *tval = std::get<3>(t); + if (tval != NULL) { + if (type == "bool") { + *((bool *)tval) = true; + } else { + std::cout << " bool somethng wrong\n"; } + _map[it->first] = std::make_tuple(desc, type, "1", (void *)NULL); } } } } + } + } - // print out unused options - if (echo) { - // echo command - std::cout << "Echoing the command-line:\n\n"; - for (int i=0;ifirst; - std::string type = std::get<1>(it->second); - std::string val = std::get<2>(it->second); - bool used = std::get<3>(it->second) == NULL; - if (used) { - std::cout << " "; - if (type == "bool") - std::cout << "--" << key << "\n"; - else - std::cout << "--" << key << "=" << val << "\n"; - } - } - std::cout << "\n"; - - // not used options - std::cout << "Not used options:\n\n"; - for (auto it=_map.begin();it!=_map.end();++it) { - std::string key = it->first; - std::string type = std::get<1>(it->second); - std::string val = std::get<2>(it->second); - bool used = std::get<3>(it->second) == NULL; - if (!used) { - std::cout << " "; - if (type == "bool") - std::cout << "--" << key << "\n"; - else - std::cout << "--" << key << "=" << val << " (default)\n"; - } - } - std::cout << "\n"; + // print out unused options + if (echo) { + // echo command + std::cout << "Echoing the command-line:\n\n"; + for (int i = 0; i < argc; ++i) + std::cout << argv[i] << " "; + std::cout << "\n\n"; + + // used options + std::cout << "Used options:\n\n"; + for (auto it = _map.begin(); it != _map.end(); ++it) { + std::string key = it->first; + std::string type = std::get<1>(it->second); + std::string val = std::get<2>(it->second); + bool used = std::get<3>(it->second) == NULL; + if (used) { + std::cout << " "; + if (type == "bool") + std::cout << "--" << key << "\n"; + else + std::cout << "--" << key << "=" << val << "\n"; + } + } + std::cout << "\n"; + + // not used options + std::cout << "Not used options:\n\n"; + for (auto it = _map.begin(); it != _map.end(); ++it) { + std::string key = it->first; + std::string type = std::get<1>(it->second); + std::string val = std::get<2>(it->second); + bool used = std::get<3>(it->second) == NULL; + if (!used) { + std::cout << " "; + if (type == "bool") + std::cout << "--" << key << "\n"; + else + std::cout << "--" << key << "=" << val << " (default)\n"; } } - return help; + std::cout << "\n"; } - }; -} + } + return help; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_CrsMatrixBase.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_CrsMatrixBase.hpp index 74b6f188e186..dfcc39ed15d3 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_CrsMatrixBase.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_CrsMatrixBase.hpp @@ -7,327 +7,263 @@ #include "Tacho.hpp" -namespace Tacho { - - /// \class CrsMatrixBase - /// \breif CRS matrix base object using Kokkos view and subview - template - class CrsMatrixBase { - public: - typedef ValueType value_type; - typedef DeviceType device_type; - - typedef typename device_type::execution_space exec_space; - typedef typename device_type::memory_space exec_memory_space; - - typedef Kokkos::Device host_device_type; - - typedef typename host_device_type::execution_space host_space; - typedef typename host_device_type::memory_space host_memory_space; - - typedef Kokkos::View value_type_array; - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View size_type_array; - - template - friend class CrsMatrixBase; - - private: - ordinal_type _m; //!< # of rows - ordinal_type _n; //!< # of cols - size_type _nnz; //!< # of nonzeros - - size_type_array _ap; //!< pointers to column index and values - ordinal_type_array _aj; //!< column index compressed format - value_type_array _ax; //!< values - - protected: - - void createInternal(const ordinal_type m, - const ordinal_type n, - const size_type nnz) { - _m = m; - _n = n; - _nnz = nnz; - - if (static_cast(_ap.extent(0)) < (m+1)) - _ap = size_type_array("CrsMatrixBase::RowPtrArray", m+1); - else - Kokkos::deep_copy(_ap, size_type()); - - if (static_cast(_aj.extent(0)) < nnz) - _aj = ordinal_type_array("CrsMatrixBase::ColsArray", nnz); - else - Kokkos::deep_copy(_aj, ordinal_type()); - - if (static_cast(_ax.extent(0)) < nnz) - _ax = value_type_array("CrsMatrixBase::ValuesArray", nnz); - else - Kokkos::deep_copy(_ax, value_type()); - } - - public: - - /// Interface functions - /// ------------------------------------------------------------------ - void setExternalMatrix(const ordinal_type m, - const ordinal_type n, - const size_type nnz, - const size_type_array &ap, - const ordinal_type_array &aj, - const value_type_array &ax) { - // TODO:: check each array's space - _m = m; - _n = n; - _nnz = nnz; - _ap = ap; - _aj = aj; - _ax = ax; - } +namespace Tacho { + +/// \class CrsMatrixBase +/// \breif CRS matrix base object using Kokkos view and subview +template class CrsMatrixBase { +public: + typedef ValueType value_type; + typedef DeviceType device_type; + + typedef typename device_type::execution_space exec_space; + typedef typename device_type::memory_space exec_memory_space; + + typedef Kokkos::Device host_device_type; + + typedef typename host_device_type::execution_space host_space; + typedef typename host_device_type::memory_space host_memory_space; + + typedef Kokkos::View value_type_array; + typedef Kokkos::View ordinal_type_array; + typedef Kokkos::View size_type_array; + + template friend class CrsMatrixBase; + +private: + ordinal_type _m; //!< # of rows + ordinal_type _n; //!< # of cols + size_type _nnz; //!< # of nonzeros + + size_type_array _ap; //!< pointers to column index and values + ordinal_type_array _aj; //!< column index compressed format + value_type_array _ax; //!< values + +protected: + void createInternal(const ordinal_type m, const ordinal_type n, const size_type nnz) { + _m = m; + _n = n; + _nnz = nnz; + + if (static_cast(_ap.extent(0)) < (m + 1)) + _ap = size_type_array("CrsMatrixBase::RowPtrArray", m + 1); + else + Kokkos::deep_copy(_ap, size_type()); + + if (static_cast(_aj.extent(0)) < nnz) + _aj = ordinal_type_array("CrsMatrixBase::ColsArray", nnz); + else + Kokkos::deep_copy(_aj, ordinal_type()); + + if (static_cast(_ax.extent(0)) < nnz) + _ax = value_type_array("CrsMatrixBase::ValuesArray", nnz); + else + Kokkos::deep_copy(_ax, value_type()); + } - void setExternalMatrix(const ordinal_type m, - const ordinal_type n, - const ordinal_type nnz, - const size_type* ap, - const ordinal_type* aj, - const value_type* ax) { - _m = m; - _n = n; - _nnz = nnz; - _ap = size_type_array(ap, _m+1); - _aj = ordinal_type_array(aj, _nnz); - _ax = value_type_array(ax, _nnz); - } +public: + /// Interface functions + /// ------------------------------------------------------------------ + void setExternalMatrix(const ordinal_type m, const ordinal_type n, const size_type nnz, const size_type_array &ap, + const ordinal_type_array &aj, const value_type_array &ax) { + // TODO:: check each array's space + _m = m; + _n = n; + _nnz = nnz; + _ap = ap; + _aj = aj; + _ax = ax; + } - void setNumNonZeros() { - if (_m) { - auto last = Kokkos::subview(_ap, _m); - auto h_last = Kokkos::create_mirror_view(host_memory_space(), last); Kokkos::deep_copy(h_last, last); - _nnz = h_last(); - } + void setExternalMatrix(const ordinal_type m, const ordinal_type n, const ordinal_type nnz, const size_type *ap, + const ordinal_type *aj, const value_type *ax) { + _m = m; + _n = n; + _nnz = nnz; + _ap = size_type_array(ap, _m + 1); + _aj = ordinal_type_array(aj, _nnz); + _ax = value_type_array(ax, _nnz); + } + + void setNumNonZeros() { + if (_m) { + auto last = Kokkos::subview(_ap, _m); + auto h_last = Kokkos::create_mirror_view(host_memory_space(), last); + Kokkos::deep_copy(h_last, last); + _nnz = h_last(); } + } - KOKKOS_INLINE_FUNCTION - size_type_array& RowPtr() { return _ap; } + KOKKOS_INLINE_FUNCTION + size_type_array &RowPtr() { return _ap; } - KOKKOS_INLINE_FUNCTION - ordinal_type_array& Cols() { return _aj; } + KOKKOS_INLINE_FUNCTION + ordinal_type_array &Cols() { return _aj; } - KOKKOS_INLINE_FUNCTION - value_type_array& Values() { return _ax; } - - KOKKOS_INLINE_FUNCTION - ordinal_type NumRows() const { return _m; } + KOKKOS_INLINE_FUNCTION + value_type_array &Values() { return _ax; } - KOKKOS_INLINE_FUNCTION - ordinal_type NumCols() const { return _n; } + KOKKOS_INLINE_FUNCTION + ordinal_type NumRows() const { return _m; } - KOKKOS_INLINE_FUNCTION - size_type NumNonZeros() const { return _nnz; } + KOKKOS_INLINE_FUNCTION + ordinal_type NumCols() const { return _n; } - KOKKOS_INLINE_FUNCTION - size_type& RowPtrBegin(const ordinal_type i) { return _ap[i]; } + KOKKOS_INLINE_FUNCTION + size_type NumNonZeros() const { return _nnz; } - KOKKOS_INLINE_FUNCTION - size_type RowPtrBegin(const ordinal_type i) const { return _ap[i]; } + KOKKOS_INLINE_FUNCTION + size_type &RowPtrBegin(const ordinal_type i) { return _ap[i]; } - KOKKOS_INLINE_FUNCTION - size_type& RowPtrEnd(const ordinal_type i) { return _ap[i+1]; } - - KOKKOS_INLINE_FUNCTION - size_type RowPtrEnd(const ordinal_type i) const { return _ap[i+1]; } + KOKKOS_INLINE_FUNCTION + size_type RowPtrBegin(const ordinal_type i) const { return _ap[i]; } - KOKKOS_INLINE_FUNCTION - ordinal_type& Col(const ordinal_type k) { return _aj[k]; } + KOKKOS_INLINE_FUNCTION + size_type &RowPtrEnd(const ordinal_type i) { return _ap[i + 1]; } - KOKKOS_INLINE_FUNCTION - ordinal_type Col(const ordinal_type k) const { return _aj[k]; } + KOKKOS_INLINE_FUNCTION + size_type RowPtrEnd(const ordinal_type i) const { return _ap[i + 1]; } - KOKKOS_INLINE_FUNCTION - value_type& Value(const ordinal_type k) { return _ax[k]; } + KOKKOS_INLINE_FUNCTION + ordinal_type &Col(const ordinal_type k) { return _aj[k]; } - KOKKOS_INLINE_FUNCTION - value_type Value(const ordinal_type k) const { return _ax[k]; } + KOKKOS_INLINE_FUNCTION + ordinal_type Col(const ordinal_type k) const { return _aj[k]; } + KOKKOS_INLINE_FUNCTION + value_type &Value(const ordinal_type k) { return _ax[k]; } - /// Constructors - /// ------------------------------------------------------------------ + KOKKOS_INLINE_FUNCTION + value_type Value(const ordinal_type k) const { return _ax[k]; } - /// \brief Default constructor. - CrsMatrixBase() - : _m(0), - _n(0), - _nnz(0), - _ap(), - _aj(), - _ax() - { - } + /// Constructors + /// ------------------------------------------------------------------ - /// \brief Constructor with label - CrsMatrixBase(const ordinal_type m, - const ordinal_type n, - const size_type nnz) - : _m(0), - _n(0), - _nnz(0), - _ap(), - _aj(), - _ax() - { - createInternal(m, n, nnz); - } + /// \brief Default constructor. + CrsMatrixBase() : _m(0), _n(0), _nnz(0), _ap(), _aj(), _ax() {} - /// \brief Copy constructor (shallow copy), for deep-copy use a method copy - KOKKOS_INLINE_FUNCTION - CrsMatrixBase(const CrsMatrixBase &b) - : _m(b._m), - _n(b._n), - _nnz(b._nnz), - _ap(b._ap), - _aj(b._aj), - _ax(b._ax) - { - } - - KOKKOS_INLINE_FUNCTION - ~CrsMatrixBase() {} //= default; + /// \brief Constructor with label + CrsMatrixBase(const ordinal_type m, const ordinal_type n, const size_type nnz) + : _m(0), _n(0), _nnz(0), _ap(), _aj(), _ax() { + createInternal(m, n, nnz); + } + /// \brief Copy constructor (shallow copy), for deep-copy use a method copy + KOKKOS_INLINE_FUNCTION + CrsMatrixBase(const CrsMatrixBase &b) : _m(b._m), _n(b._n), _nnz(b._nnz), _ap(b._ap), _aj(b._aj), _ax(b._ax) {} - /// Create - /// ------------------------------------------------------------------ + KOKKOS_INLINE_FUNCTION + ~CrsMatrixBase() {} //= default; - void - clear() { - _m = 0; _n = 0; _nnz = 0; + /// Create + /// ------------------------------------------------------------------ - _ap = size_type_array(); - _aj = ordinal_type_array(); - _ax = value_type_array(); - } + void clear() { + _m = 0; + _n = 0; + _nnz = 0; - void - create(const ordinal_type m, - const ordinal_type n, - const size_type nnz) { - createInternal(m, n, nnz); - } + _ap = size_type_array(); + _aj = ordinal_type_array(); + _ax = value_type_array(); + } - template - void - createConfTo(const CrsMatrixBase &b) { - createInternal(b._m, b._n, b._nnz); - } + void create(const ordinal_type m, const ordinal_type n, const size_type nnz) { createInternal(m, n, nnz); } - - /// Create - /// ------------------------------------------------------------------ - - /// \brief deep copy of matrix b - template - inline - void - createMirror(const CrsMatrixBase &b) { - _m = b._m; - _n = b._n; - _nnz = b._nnz; - - _ap = Kokkos::create_mirror_view(exec_memory_space(), b._ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), b._aj); - _ax = Kokkos::create_mirror_view(exec_memory_space(), b._ax); - } + template void createConfTo(const CrsMatrixBase &b) { + createInternal(b._m, b._n, b._nnz); + } - /// \brief deep copy of matrix b - template - inline - void - copy(const CrsMatrixBase &b) { - Kokkos::deep_copy(_ap, b._ap); - Kokkos::deep_copy(_aj, b._aj); - Kokkos::deep_copy(_ax, b._ax); - } + /// Create + /// ------------------------------------------------------------------ - /// \brief print out to stream - inline - std::ostream& showMe(std::ostream &os, const bool detail = false) const { - std::streamsize prec = os.precision(); - os.precision(16); - os << std::scientific; - - os << " -- CrsMatrixBase -- " << std::endl - << " # of Rows = " << _m << std::endl - << " # of Cols = " << _n << std::endl - << " # of NonZeros = " << _nnz << std::endl - << std::endl - << " RowPtrArray length = " << _ap.extent(0) << std::endl - << " ColArray length = " << _aj.extent(0) << std::endl - << " ValueArray length = " << _ax.extent(0) << std::endl - << std::endl - << " Memory = " - << double( _ap.span()*sizeof(size_type) + - _aj.span()*sizeof(ordinal_type) + - _ax.span()*sizeof(value_type) )/1e6 << " MB" - << std::endl << std::endl; - - if (detail) { - const int w = 10; - if ( (ordinal_type(_ap.size()) > _m ) && - (size_type(_aj.size()) >= _nnz) && - (size_type(_ax.size()) >= _nnz) ) { - os << std::setw(w) << "Row" << " " - << std::setw(w) << "Col" << " " - << std::setw(w) << "Val" << std::endl; - auto h_ap = Kokkos::create_mirror_view_and_copy(host_memory_space(), _ap); - auto h_aj = Kokkos::create_mirror_view_and_copy(host_memory_space(), _aj); - auto h_ax = Kokkos::create_mirror_view_and_copy(host_memory_space(), _ax); - for (ordinal_type i=0;i<_m;++i) { - const size_type jbegin = h_ap[i], jend = h_ap[i+1]; - for (size_type j=jbegin;j inline void createMirror(const CrsMatrixBase &b) { + _m = b._m; + _n = b._n; + _nnz = b._nnz; + + _ap = Kokkos::create_mirror_view(exec_memory_space(), b._ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), b._aj); + _ax = Kokkos::create_mirror_view(exec_memory_space(), b._ax); + } + + /// \brief deep copy of matrix b + template inline void copy(const CrsMatrixBase &b) { + Kokkos::deep_copy(_ap, b._ap); + Kokkos::deep_copy(_aj, b._aj); + Kokkos::deep_copy(_ax, b._ax); + } + + /// \brief print out to stream + inline std::ostream &showMe(std::ostream &os, const bool detail = false) const { + std::streamsize prec = os.precision(); + os.precision(16); + os << std::scientific; + + os << " -- CrsMatrixBase -- " << std::endl + << " # of Rows = " << _m << std::endl + << " # of Cols = " << _n << std::endl + << " # of NonZeros = " << _nnz << std::endl + << std::endl + << " RowPtrArray length = " << _ap.extent(0) << std::endl + << " ColArray length = " << _aj.extent(0) << std::endl + << " ValueArray length = " << _ax.extent(0) << std::endl + << std::endl + << " Memory = " + << double(_ap.span() * sizeof(size_type) + _aj.span() * sizeof(ordinal_type) + _ax.span() * sizeof(value_type)) / + 1e6 + << " MB" << std::endl + << std::endl; + + if (detail) { + const int w = 10; + if ((ordinal_type(_ap.size()) > _m) && (size_type(_aj.size()) >= _nnz) && (size_type(_ax.size()) >= _nnz)) { + os << std::setw(w) << "Row" + << " " << std::setw(w) << "Col" + << " " << std::setw(w) << "Val" << std::endl; + auto h_ap = Kokkos::create_mirror_view_and_copy(host_memory_space(), _ap); + auto h_aj = Kokkos::create_mirror_view_and_copy(host_memory_space(), _aj); + auto h_ax = Kokkos::create_mirror_view_and_copy(host_memory_space(), _ax); + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbegin = h_ap[i], jend = h_ap[i + 1]; + for (size_type j = jbegin; j < jend; ++j) { + value_type val = h_ax[j]; + os << std::setw(w) << i << " " << std::setw(w) << h_aj[j] << " " << std::setw(w) << std::showpos << val + << std::noshowpos << std::endl; } } } - - os.unsetf(std::ios::scientific); - os.precision(prec); - - return os; } - }; - - // A = P B P^{-1} - template - inline - static void - applyPermutationToCrsMatrix(/* */ CrsMatrixType &A, - const CrsMatrixType &B, - const OrdinalTypeArray &p, - const OrdinalTypeArray &ip) { - const ordinal_type m = A.NumRows();//, n = A.NumCols(); - typedef typename CrsMatrixType::exec_space exec_space; - typedef typename CrsMatrixType::exec_memory_space exec_memory_space; - - /// temporary matrix with the same structure - auto ap = A.RowPtr(); - auto aj = A.Cols(); - auto ax = A.Values(); - - auto perm = Kokkos::create_mirror_view(exec_memory_space(), p); Kokkos::deep_copy(perm, p); - auto peri = Kokkos::create_mirror_view(exec_memory_space(), ip); Kokkos::deep_copy(peri, ip); - { /// permute row indices (exclusive scan) - Kokkos::RangePolicy > policy(0, m+1); - Kokkos::parallel_scan - (policy, KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, - const bool &final) { + os.unsetf(std::ios::scientific); + os.precision(prec); + + return os; + } +}; + +// A = P B P^{-1} +template +inline static void applyPermutationToCrsMatrix(/* */ CrsMatrixType &A, const CrsMatrixType &B, + const OrdinalTypeArray &p, const OrdinalTypeArray &ip) { + const ordinal_type m = A.NumRows(); //, n = A.NumCols(); + typedef typename CrsMatrixType::exec_space exec_space; + typedef typename CrsMatrixType::exec_memory_space exec_memory_space; + + /// temporary matrix with the same structure + auto ap = A.RowPtr(); + auto aj = A.Cols(); + auto ax = A.Values(); + + auto perm = Kokkos::create_mirror_view(exec_memory_space(), p); + Kokkos::deep_copy(perm, p); + auto peri = Kokkos::create_mirror_view(exec_memory_space(), ip); + Kokkos::deep_copy(peri, ip); + { /// permute row indices (exclusive scan) + Kokkos::RangePolicy> policy(0, m + 1); + Kokkos::parallel_scan( + policy, KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, const bool &final) { if (final) ap(i) = update; @@ -336,88 +272,78 @@ namespace Tacho { update += (B.RowPtrEnd(row) - B.RowPtrBegin(row)); } }); - Kokkos::fence(); - } - { /// permute col indices (do not sort) - typedef Kokkos::TeamPolicy team_policy_type; - team_policy_type policy(m, 1, 1); //Kokkos::AUTO()); ///, Kokkos::AUTO()); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const typename team_policy_type::member_type &member) { - const ordinal_type - i = member.league_rank(), /// row in A - kbeg = ap(i), - kend = ap(i+1), - nk = kend - kbeg, - row = perm(i), /// row in B - colbeg = B.RowPtrBegin(row); - //colend = B.RowPtrEnd(row); - - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, nk), - [&,kbeg,colbeg](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround - const ordinal_type tk = kbeg+k, sk = colbeg+k; - aj(tk) = peri(B.Col(sk)); - ax(tk) = B.Value(sk); - }); + Kokkos::fence(); + } + { /// permute col indices (do not sort) + typedef Kokkos::TeamPolicy team_policy_type; + team_policy_type policy(m, 1, 1); // Kokkos::AUTO()); ///, Kokkos::AUTO()); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const typename team_policy_type::member_type &member) { + const ordinal_type i = member.league_rank(), /// row in A + kbeg = ap(i), kend = ap(i + 1), nk = kend - kbeg, + row = perm(i), /// row in B + colbeg = B.RowPtrBegin(row); + // colend = B.RowPtrEnd(row); + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, nk), + [&, kbeg, colbeg](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type tk = kbeg + k, sk = colbeg + k; + aj(tk) = peri(B.Col(sk)); + ax(tk) = B.Value(sk); + }); }); - Kokkos::fence(); - } + Kokkos::fence(); } - - // A = P B P^{-1} - template - inline - static void - applyPermutationToCrsMatrixLower(/* */ CrsMatrixType &A, - const CrsMatrixType &B, - const OrdinalTypeArray &p, - const OrdinalTypeArray &ip) { - const ordinal_type m = A.NumRows(); //, n = A.NumCols(); - typedef typename CrsMatrixType::exec_space exec_space; - typedef typename CrsMatrixType::exec_memory_space exec_memory_space; - - /// temporary matrix with the same structure - auto ap = A.RowPtr(); - auto aj = A.Cols(); - auto ax = A.Values(); - - auto perm = Kokkos::create_mirror_view(exec_memory_space(), p); Kokkos::deep_copy(perm, p); - auto peri = Kokkos::create_mirror_view(exec_memory_space(), ip); Kokkos::deep_copy(peri, ip); - - { /// permute row indices (exclusive scan) - Kokkos::RangePolicy > policy(0, m+1); - Kokkos::parallel_scan - (policy, KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, - const bool &final) { +} + +// A = P B P^{-1} +template +inline static void applyPermutationToCrsMatrixLower(/* */ CrsMatrixType &A, const CrsMatrixType &B, + const OrdinalTypeArray &p, const OrdinalTypeArray &ip) { + const ordinal_type m = A.NumRows(); //, n = A.NumCols(); + typedef typename CrsMatrixType::exec_space exec_space; + typedef typename CrsMatrixType::exec_memory_space exec_memory_space; + + /// temporary matrix with the same structure + auto ap = A.RowPtr(); + auto aj = A.Cols(); + auto ax = A.Values(); + + auto perm = Kokkos::create_mirror_view(exec_memory_space(), p); + Kokkos::deep_copy(perm, p); + auto peri = Kokkos::create_mirror_view(exec_memory_space(), ip); + Kokkos::deep_copy(peri, ip); + + { /// permute row indices (exclusive scan) + Kokkos::RangePolicy> policy(0, m + 1); + Kokkos::parallel_scan( + policy, KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, const bool &final) { if (final) ap(i) = update; if (i < m) { ordinal_type count(0); const ordinal_type row = perm(i); /// row in B - for (ordinal_type k=B.RowPtrBegin(row),kend=B.RowPtrEnd(row);k=j); /// lower triangular + count += (i >= j); /// lower triangular } update += count; } }); - Kokkos::fence(); - } - { /// permute col indices (do not sort) - Kokkos::RangePolicy > policy(0, m); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &i) { /// row in A - const ordinal_type - kbeg = ap(i), - row = perm(i), /// row in B - colbeg = B.RowPtrBegin(row), - colend = B.RowPtrEnd(row), - nk = colend - colbeg; - - for (ordinal_type k=0,t=0;k> policy(0, m); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { /// row in A + const ordinal_type kbeg = ap(i), + row = perm(i), /// row in B + colbeg = B.RowPtrBegin(row), colend = B.RowPtrEnd(row), nk = colend - colbeg; + + for (ordinal_type k = 0, t = 0; k < nk; ++k) { + const ordinal_type tk = kbeg + t, sk = colbeg + k; const ordinal_type j = peri(B.Col(sk)); /// col in A if (i >= j) { aj(tk) = j; @@ -426,54 +352,52 @@ namespace Tacho { } } }); - Kokkos::fence(); - } - - /// update A with new nnz (for now let's not resize) - A.setNumNonZeros(); + Kokkos::fence(); } - template - inline - double - computeRelativeResidual(const CrsMatrixBase &A, - const Kokkos::View &x, - const Kokkos::View &b) { - const bool test = (size_t(A.NumRows()) != size_t(A.NumCols()) || - size_t(A.NumRows()) != size_t(b.extent(0)) || - size_t(x.extent(0)) != size_t(b.extent(0)) || - size_t(x.extent(1)) != size_t(b.extent(1))); - if (test) - throw std::logic_error("A,x and b dimensions are not compatible"); - - typedef ValueType value_type; - typedef typename UseThisDevice::type host_device_type; - typedef typename host_device_type::memory_space host_memory_space; - - CrsMatrixBase h_A; - h_A.createMirror(A); h_A.copy(A); - - auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); Kokkos::deep_copy(h_x, x); - auto h_b = Kokkos::create_mirror_view(host_memory_space(), b); Kokkos::deep_copy(h_b, b); - - typedef ArithTraits arith_traits; - const ordinal_type m = h_A.NumRows(), k = h_b.extent(1); - double diff = 0, norm = 0; - for (ordinal_type i=0;i +inline double computeRelativeResidual(const CrsMatrixBase &A, + const Kokkos::View &x, + const Kokkos::View &b) { + const bool test = (size_t(A.NumRows()) != size_t(A.NumCols()) || size_t(A.NumRows()) != size_t(b.extent(0)) || + size_t(x.extent(0)) != size_t(b.extent(0)) || size_t(x.extent(1)) != size_t(b.extent(1))); + if (test) + throw std::logic_error("A,x and b dimensions are not compatible"); + + typedef ValueType value_type; + typedef typename UseThisDevice::type host_device_type; + typedef typename host_device_type::memory_space host_memory_space; + + CrsMatrixBase h_A; + h_A.createMirror(A); + h_A.copy(A); + + auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); + Kokkos::deep_copy(h_x, x); + auto h_b = Kokkos::create_mirror_view(host_memory_space(), b); + Kokkos::deep_copy(h_b, b); + + typedef ArithTraits arith_traits; + const ordinal_type m = h_A.NumRows(), k = h_b.extent(1); + double diff = 0, norm = 0; + for (ordinal_type i = 0; i < m; ++i) { + for (ordinal_type p = 0; p < k; ++p) { + value_type s = 0; + const ordinal_type jbeg = h_A.RowPtrBegin(i), jend = h_A.RowPtrEnd(i); + for (ordinal_type j = jbeg; j < jend; ++j) { + const ordinal_type col = h_A.Col(j); + s += h_A.Value(j) * h_x(col, p); } + norm += arith_traits::real(h_b(i, p) * arith_traits::conj(h_b(i, p))); + diff += arith_traits::real((h_b(i, p) - s) * arith_traits::conj(h_b(i, p) - s)); } - return sqrt(diff/norm); } - + return sqrt(diff / norm); +} -} // namespace tacho +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_CuSolver.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_CuSolver.hpp index 06296e3f21dd..d6da1a0f3461 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_CuSolver.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_CuSolver.hpp @@ -1,215 +1,206 @@ #ifndef __TACHO_CUSOLVER_HPP__ #define __TACHO_CUSOLVER_HPP__ -#if defined (KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) #include "Tacho_Util.hpp" //#include "cusparse.h" -#include "cusparse_v2.h" #include "cusolverSp.h" #include "cusolverSp_LOWLEVEL_PREVIEW.h" +#include "cusparse_v2.h" #include #include namespace Tacho { - class CuSolver { - public: - typedef double value_type; - - typedef typename UseThisDevice::type device_type; - - typedef typename device_type::execution_space exec_space; - typedef typename device_type::memory_space exec_memory_space; - - typedef typename UseThisDevice::type host_device_type; - typedef typename host_device_type::execution_space host_space; - typedef typename host_device_type::memory_space host_memory_space; - - typedef Kokkos::View size_type_array; - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View value_type_array; - typedef Kokkos::View value_type_matrix; - - typedef Kokkos::View size_type_array_host; - typedef Kokkos::View ordinal_type_array_host; - typedef Kokkos::View value_type_array_host; - typedef Kokkos::View value_type_matrix_host; - - private: - cusolverSpHandle_t _handle; - csrcholInfo_t _chol_info; - cusparseMatDescr_t _desc; - int _status; - - ordinal_type _m; - size_type _nnz; - - size_type_array _ap; /// ap_ordinal is used to interface to cuSolver - ordinal_type_array _ap_ordinal, _aj; - - value_type_array _buf; - - ordinal_type _verbose; - - void checkStatus(const char *s) { - if (_status != 0) - printf("Error: %s, status %d\n", s, _status); +class CuSolver { +public: + typedef double value_type; + + typedef typename UseThisDevice::type device_type; + + typedef typename device_type::execution_space exec_space; + typedef typename device_type::memory_space exec_memory_space; + + typedef typename UseThisDevice::type host_device_type; + typedef typename host_device_type::execution_space host_space; + typedef typename host_device_type::memory_space host_memory_space; + + typedef Kokkos::View size_type_array; + typedef Kokkos::View ordinal_type_array; + typedef Kokkos::View value_type_array; + typedef Kokkos::View value_type_matrix; + + typedef Kokkos::View size_type_array_host; + typedef Kokkos::View ordinal_type_array_host; + typedef Kokkos::View value_type_array_host; + typedef Kokkos::View value_type_matrix_host; + +private: + cusolverSpHandle_t _handle; + csrcholInfo_t _chol_info; + cusparseMatDescr_t _desc; + int _status; + + ordinal_type _m; + size_type _nnz; + + size_type_array _ap; /// ap_ordinal is used to interface to cuSolver + ordinal_type_array _ap_ordinal, _aj; + + value_type_array _buf; + + ordinal_type _verbose; + + void checkStatus(const char *s) { + if (_status != 0) + printf("Error: %s, status %d\n", s, _status); + } + +public: + CuSolver() { + _status = cusolverSpCreate(&_handle); + checkStatus("cusolverSpCreate"); + _status = cusolverSpCreateCsrcholInfo(&_chol_info); + checkStatus("cusolverSpCreateCsrcholInfo"); + _status = cusparseCreateMatDescr(&_desc); + checkStatus("cusparseCreateMatDescr"); + } + virtual ~CuSolver() { + _status = cusparseDestroyMatDescr(_desc); + checkStatus("cusparseDestroyMatDescr"); + _status = cusolverSpDestroyCsrcholInfo(_chol_info); + checkStatus("cusolverSpDestroyCsrcholInfo"); + _status = cusolverSpDestroy(_handle); + checkStatus("cusolverSpDestroy"); + } + + void setVerbose(const ordinal_type verbose = 1) { _verbose = verbose; } + + template + int analyze(const ordinal_type m, const arg_size_type_array &ap, const arg_ordinal_type_array &aj) { + if (_verbose) { + printf("cuSolver: Analyze\n"); + printf("=================\n"); } - - public: - CuSolver() { - _status = cusolverSpCreate(&_handle); checkStatus("cusolverSpCreate"); - _status = cusolverSpCreateCsrcholInfo(&_chol_info); checkStatus("cusolverSpCreateCsrcholInfo"); - _status = cusparseCreateMatDescr(&_desc); checkStatus("cusparseCreateMatDescr"); + Kokkos::Timer timer; + + _m = m; + + timer.reset(); + + _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); + Kokkos::deep_copy(_ap, ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); + Kokkos::deep_copy(_aj, aj); + Kokkos::fence(); +#if defined(TACHO_USE_INT_INT) + _ap_ordinal = ap; +#else + /// LAMBDA cannot capture this pointer; make all variables local + { + ordinal_type_array l_ap_ordinal(do_not_initialize_tag("CuSolver::ap_ordinal"), _ap.extent(0)); + auto l_ap = _ap; + Kokkos::RangePolicy> policy(0, l_ap.extent(0)); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const int i) { l_ap_ordinal(i) = static_cast(l_ap(i)); }); + _ap_ordinal = l_ap_ordinal; } - virtual~CuSolver() { - _status = cusparseDestroyMatDescr(_desc); checkStatus("cusparseDestroyMatDescr"); - _status = cusolverSpDestroyCsrcholInfo(_chol_info); checkStatus("cusolverSpDestroyCsrcholInfo"); - _status = cusolverSpDestroy(_handle); checkStatus("cusolverSpDestroy"); +#endif + auto last = Kokkos::subview(_ap, _m); + auto h_last = Kokkos::create_mirror_view(host_memory_space(), last); + Kokkos::deep_copy(h_last, last); + _nnz = h_last(); + const double t_copy = timer.seconds(); + + timer.reset(); + _status = cusolverSpXcsrcholAnalysis(_handle, _m, _nnz, _desc, _ap_ordinal.data(), _aj.data(), _chol_info); + checkStatus("cusolverSpXcsrcholAnalysis"); + Kokkos::fence(); + const double t_analyze = timer.seconds(); + + if (_verbose) { + printf(" Linear system A\n"); + printf(" number of equations: %10d\n", _m); + printf(" number of nonzeros: %10d\n", _nnz); + printf("\n"); + printf(" Time\n"); + printf(" time for copying A into U: %10.6f s\n", t_copy); + printf(" time for analysis: %10.6f s\n", t_analyze); + printf(" total time spent: %10.6f s\n", (t_copy + t_analyze)); + printf("\n"); } - void setVerbose(const ordinal_type verbose = 1) { - _verbose = verbose; + return 0; + } + + int factorize(const value_type_array &ax) { + if (_verbose) { + printf("cuSolver: Factorize\n"); + printf("===================\n"); } - - template - int analyze(const ordinal_type m, - const arg_size_type_array &ap, - const arg_ordinal_type_array &aj) { - if (_verbose) { - printf("cuSolver: Analyze\n"); - printf("=================\n"); - } - Kokkos::Timer timer; - - _m = m; - - timer.reset(); - - _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); Kokkos::deep_copy(_ap, ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); Kokkos::deep_copy(_aj, aj); - Kokkos::fence(); -#if defined( TACHO_USE_INT_INT ) - _ap_ordinal = ap; -#else - /// LAMBDA cannot capture this pointer; make all variables local - { - ordinal_type_array l_ap_ordinal(do_not_initialize_tag("CuSolver::ap_ordinal"), _ap.extent(0)); - auto l_ap = _ap; - Kokkos::RangePolicy > policy(0,l_ap.extent(0)); - Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const int i) { - l_ap_ordinal(i) = static_cast(l_ap(i)); - }); - _ap_ordinal = l_ap_ordinal; - } -#endif - auto last = Kokkos::subview(_ap, _m); - auto h_last = Kokkos::create_mirror_view(host_memory_space(), last); - Kokkos::deep_copy(h_last, last); - _nnz = h_last(); - const double t_copy = timer.seconds(); - - timer.reset(); - _status = cusolverSpXcsrcholAnalysis(_handle, - _m, _nnz, - _desc, - _ap_ordinal.data(), _aj.data(), - _chol_info); checkStatus("cusolverSpXcsrcholAnalysis"); - Kokkos::fence(); - const double t_analyze = timer.seconds(); - - if (_verbose) { - printf(" Linear system A\n"); - printf(" number of equations: %10d\n", _m); - printf(" number of nonzeros: %10d\n", _nnz); - printf("\n"); - printf(" Time\n"); - printf(" time for copying A into U: %10.6f s\n", t_copy); - printf(" time for analysis: %10.6f s\n", t_analyze); - printf(" total time spent: %10.6f s\n", (t_copy+t_analyze)); - printf("\n"); - } - - return 0; + Kokkos::Timer timer; + + timer.reset(); + size_t internalDataInBytes, workspaceInBytes; + _status = cusolverSpDcsrcholBufferInfo(_handle, _m, _nnz, _desc, ax.data(), _ap_ordinal.data(), _aj.data(), + _chol_info, &internalDataInBytes, &workspaceInBytes); + checkStatus("cusolverSpDcsrcholBufferInfo"); + + const size_t bufsize = workspaceInBytes / sizeof(value_type); + if (bufsize > _buf.extent(0)) + _buf = value_type_array(do_not_initialize_tag("cusolver buf"), bufsize); + Kokkos::fence(); + const double t_alloc = timer.seconds(); + + timer.reset(); + _status = cusolverSpDcsrcholFactor(_handle, _m, _nnz, _desc, ax.data(), _ap_ordinal.data(), _aj.data(), _chol_info, + _buf.data()); + checkStatus("cusolverSpDcsrcholFactor"); + Kokkos::fence(); + const double t_factor = timer.seconds(); + if (_verbose) { + printf(" Time\n"); + printf(" time for workspace allocation: %10.6f s\n", t_alloc); + printf(" time for numeric factorization: %10.6f s\n", t_factor); + printf(" total time spent: %10.6f s\n", (t_alloc + t_factor)); + printf("\n"); + printf(" Workspace\n"); + printf(" internal data in MB: %10.3f MB\n", + double(internalDataInBytes) / 1.e6); + printf(" workspace in MB: %10.3f MB\n", double(workspaceInBytes) / 1.e6); + printf("\n"); + } + + return 0; + } + + int solve(const value_type_matrix &x, const value_type_matrix &b) { + if (_verbose) { + printf("cuSolver: Solve\n"); + printf("===============\n"); } - - int factorize(const value_type_array &ax) { - if (_verbose) { - printf("cuSolver: Factorize\n"); - printf("===================\n"); - } - Kokkos::Timer timer; - - timer.reset(); - size_t internalDataInBytes, workspaceInBytes; - _status = cusolverSpDcsrcholBufferInfo(_handle, - _m, _nnz, _desc, - ax.data(), _ap_ordinal.data(), _aj.data(), - _chol_info, - &internalDataInBytes, - &workspaceInBytes); checkStatus("cusolverSpDcsrcholBufferInfo"); - - const size_t bufsize = workspaceInBytes/sizeof(value_type); - if (bufsize > _buf.extent(0)) - _buf = value_type_array(do_not_initialize_tag("cusolver buf"), bufsize); - Kokkos::fence(); - const double t_alloc = timer.seconds(); - - timer.reset(); - _status = cusolverSpDcsrcholFactor(_handle, - _m, _nnz, _desc, - ax.data(), _ap_ordinal.data(), _aj.data(), - _chol_info, - _buf.data()); checkStatus("cusolverSpDcsrcholFactor"); - Kokkos::fence(); - const double t_factor = timer.seconds(); - if (_verbose) { - printf(" Time\n"); - printf(" time for workspace allocation: %10.6f s\n", t_alloc); - printf(" time for numeric factorization: %10.6f s\n", t_factor); - printf(" total time spent: %10.6f s\n", (t_alloc+t_factor)); - printf("\n"); - printf(" Workspace\n"); - printf(" internal data in MB: %10.3f MB\n", double(internalDataInBytes)/1.e6); - printf(" workspace in MB: %10.3f MB\n", double(workspaceInBytes)/1.e6); - printf("\n"); - } - - return 0; + Kokkos::Timer timer; + + timer.reset(); + /// solve A x = t + const ordinal_type len = x.extent(0), nrhs = x.extent(1); + for (ordinal_type i = 0; i < nrhs; ++i) { + _status = cusolverSpDcsrcholSolve(_handle, _m, b.data() + i * len, x.data() + i * len, _chol_info, _buf.data()); + checkStatus("cusolverSpDcsrcholSolve"); } - - int solve(const value_type_matrix &x, - const value_type_matrix &b) { - if (_verbose) { - printf("cuSolver: Solve\n"); - printf("===============\n"); - } - Kokkos::Timer timer; - - timer.reset(); - /// solve A x = t - const ordinal_type len = x.extent(0), nrhs = x.extent(1); - for (ordinal_type i=0;i::type device_type; + typedef typename UseThisDevice::type device_type; - typedef typename device_type::execution_space exec_space; - typedef typename device_type::memory_space exec_memory_space; + typedef typename device_type::execution_space exec_space; + typedef typename device_type::memory_space exec_memory_space; - typedef typename UseThisDevice::type host_device_type; - typedef typename host_device_type::execution_space host_space; - typedef typename host_device_type::memory_space host_memory_space; + typedef typename UseThisDevice::type host_device_type; + typedef typename host_device_type::execution_space host_space; + typedef typename host_device_type::memory_space host_memory_space; - typedef Kokkos::View size_type_array; - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View value_type_array; - typedef Kokkos::View value_type_matrix; + typedef Kokkos::View size_type_array; + typedef Kokkos::View ordinal_type_array; + typedef Kokkos::View value_type_array; + typedef Kokkos::View value_type_matrix; - typedef Kokkos::View size_type_array_host; - typedef Kokkos::View ordinal_type_array_host; - typedef Kokkos::View value_type_array_host; - typedef Kokkos::View value_type_matrix_host; + typedef Kokkos::View size_type_array_host; + typedef Kokkos::View ordinal_type_array_host; + typedef Kokkos::View value_type_array_host; + typedef Kokkos::View value_type_matrix_host; - private: - /// pair of cuSparse objects for Lower and Upper - std::pair _handle; - std::pair _desc; - std::pair _info; - std::pair _policy; - std::pair _trans; +private: + /// pair of cuSparse objects for Lower and Upper + std::pair _handle; + std::pair _desc; + std::pair _info; + std::pair _policy; + std::pair _trans; - cudaStream_t _stream; - cudaGraph_t _graph; - cudaGraphExec_t _instance; - int _graph_status; - int _status; + cudaStream_t _stream; + cudaGraph_t _graph; + cudaGraphExec_t _instance; + int _graph_status; + int _status; - ordinal_type _m; - size_type _nnz; + ordinal_type _m; + size_type _nnz; - size_type_array _ap; /// ap_ordinal is used to interface to cuSparse - ordinal_type_array _ap_ordinal, _aj; + size_type_array _ap; /// ap_ordinal is used to interface to cuSparse + ordinal_type_array _ap_ordinal, _aj; - value_type_array _ax, _buf; + value_type_array _ax, _buf; - ordinal_type _verbose; + ordinal_type _verbose; - void checkStatus(const char *s) { - if (_status != 0) { - printf("Error: %s, status %d\n", s, _status); - throw std::runtime_error("Error: checkStatus returns a non-zero status value"); - } + void checkStatus(const char *s) { + if (_status != 0) { + printf("Error: %s, status %d\n", s, _status); + throw std::runtime_error("Error: checkStatus returns a non-zero status value"); } - - public: - CuSparseTriSolve() { - /// let's use a default stream to interoperate with kokkos without execution space - /// later we can create a separate stream but for now it is not worth - _stream = NULL; - ///_status = cudaStreamCreate(&_stream); checkStatus("cudaStreamCreate"); - - _status = cusparseCreate(&_handle.first); checkStatus("cusparseCreate::Lower"); - _status = cusparseCreate(&_handle.second); checkStatus("cusparseCreate::Upper"); - _status = cusparseCreateMatDescr(&_desc.first); checkStatus("cusparseCreateMatDescr::Lower"); - _status = cusparseCreateMatDescr(&_desc.second); checkStatus("cusparseCreateMatDescr::Upper"); - _status = cusparseCreateCsrsv2Info(&_info.first); checkStatus("cusparseCreateCsrsv2Info::Lower"); - _status = cusparseCreateCsrsv2Info(&_info.second); checkStatus("cusparseCreateCsrsv2Info::Upper"); - - _status = cusparseSetStream(_handle.first, _stream); - _status = cusparseSetStream(_handle.second, _stream); - - _graph_status = 0; - } - virtual~CuSparseTriSolve() { - _status = cusparseDestroyCsrsv2Info(_info.second); checkStatus("cusparseCreateCsrsv2Info::Upper"); - _status = cusparseDestroyCsrsv2Info(_info.first); checkStatus("cusparseCreateCsrsv2Info::Lower"); - _status = cusparseDestroyMatDescr(_desc.second); checkStatus("cusparseDestroyMatDescr::Upper"); - _status = cusparseDestroyMatDescr(_desc.first); checkStatus("cusparseDestroyMatDescr::Lower"); - _status = cusparseDestroy(_handle.second); checkStatus("cusparseDestroy::Upper"); - _status = cusparseDestroy(_handle.first); checkStatus("cusparseDestroy::Lower"); - ///_status = cudaStreamDestroy(_stream); + } + +public: + CuSparseTriSolve() { + /// let's use a default stream to interoperate with kokkos without execution space + /// later we can create a separate stream but for now it is not worth + _stream = NULL; + ///_status = cudaStreamCreate(&_stream); checkStatus("cudaStreamCreate"); + + _status = cusparseCreate(&_handle.first); + checkStatus("cusparseCreate::Lower"); + _status = cusparseCreate(&_handle.second); + checkStatus("cusparseCreate::Upper"); + _status = cusparseCreateMatDescr(&_desc.first); + checkStatus("cusparseCreateMatDescr::Lower"); + _status = cusparseCreateMatDescr(&_desc.second); + checkStatus("cusparseCreateMatDescr::Upper"); + _status = cusparseCreateCsrsv2Info(&_info.first); + checkStatus("cusparseCreateCsrsv2Info::Lower"); + _status = cusparseCreateCsrsv2Info(&_info.second); + checkStatus("cusparseCreateCsrsv2Info::Upper"); + + _status = cusparseSetStream(_handle.first, _stream); + _status = cusparseSetStream(_handle.second, _stream); + + _graph_status = 0; + } + virtual ~CuSparseTriSolve() { + _status = cusparseDestroyCsrsv2Info(_info.second); + checkStatus("cusparseCreateCsrsv2Info::Upper"); + _status = cusparseDestroyCsrsv2Info(_info.first); + checkStatus("cusparseCreateCsrsv2Info::Lower"); + _status = cusparseDestroyMatDescr(_desc.second); + checkStatus("cusparseDestroyMatDescr::Upper"); + _status = cusparseDestroyMatDescr(_desc.first); + checkStatus("cusparseDestroyMatDescr::Lower"); + _status = cusparseDestroy(_handle.second); + checkStatus("cusparseDestroy::Upper"); + _status = cusparseDestroy(_handle.first); + checkStatus("cusparseDestroy::Lower"); + ///_status = cudaStreamDestroy(_stream); + } + + void setVerbose(const ordinal_type verbose = 1) { _verbose = verbose; } + + void getStream(cudaStream_t *stream) { *stream = _stream; } + + template + int analyze(const ordinal_type m, const arg_size_type_array &ap, const arg_ordinal_type_array &aj, + const arg_value_type_array &ax) { + if (_verbose) { + printf("cuSparse: Analyze\n"); + printf("=================\n"); } + Kokkos::Timer timer; - void setVerbose(const ordinal_type verbose = 1) { - _verbose = verbose; - } + _m = m; - void getStream(cudaStream_t* stream) { - *stream = _stream; - } + timer.reset(); - template - int analyze(const ordinal_type m, - const arg_size_type_array &ap, - const arg_ordinal_type_array &aj, - const arg_value_type_array &ax) { - if (_verbose) { - printf("cuSparse: Analyze\n"); - printf("=================\n"); - } - Kokkos::Timer timer; - - _m = m; - - timer.reset(); - - _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); Kokkos::deep_copy(_ap, ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); Kokkos::deep_copy(_aj, aj); - _ax = Kokkos::create_mirror_view(exec_memory_space(), ax); Kokkos::deep_copy(_ax, ax); - Kokkos::fence(); -#if defined( TACHO_USE_INT_INT ) - _ap_ordinal = ap; + _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); + Kokkos::deep_copy(_ap, ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); + Kokkos::deep_copy(_aj, aj); + _ax = Kokkos::create_mirror_view(exec_memory_space(), ax); + Kokkos::deep_copy(_ax, ax); + Kokkos::fence(); +#if defined(TACHO_USE_INT_INT) + _ap_ordinal = ap; #else - /// LAMBDA cannot capture this pointer; make all variables local - { - ordinal_type_array l_ap_ordinal(do_not_initialize_tag("CuSolver::ap_ordinal"), _ap.extent(0)); - auto l_ap = _ap; - Kokkos::RangePolicy > policy(0,l_ap.extent(0)); - Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const int i) { - l_ap_ordinal(i) = static_cast(l_ap(i)); - }); - _ap_ordinal = l_ap_ordinal; - } -#endif - auto last = Kokkos::subview(_ap, _m); - auto h_last = Kokkos::create_mirror_view(host_memory_space(), last); - Kokkos::deep_copy(h_last, last); - _nnz = h_last(); - - Kokkos::fence(); - const double t_copy = timer.seconds(); - - timer.reset(); - std::pair bufSizeInBytes; - - _policy.first = CUSPARSE_SOLVE_POLICY_USE_LEVEL; - _policy.second = CUSPARSE_SOLVE_POLICY_USE_LEVEL; - - _trans.first = CUSPARSE_OPERATION_TRANSPOSE; - _trans.second = CUSPARSE_OPERATION_NON_TRANSPOSE; - - /// - /// query buffersize and analyze the matrix; two separate handles are required for book-keeping - /// - _status = cusparseSetMatIndexBase(_desc.first, CUSPARSE_INDEX_BASE_ZERO); checkStatus("cusparseSetMatIndexBase::Lower"); - _status = cusparseSetMatFillMode(_desc.first, CUSPARSE_FILL_MODE_UPPER); checkStatus("cusparseSetMatFillMode::Lower"); - _status = cusparseSetMatDiagType(_desc.first, CUSPARSE_DIAG_TYPE_NON_UNIT); checkStatus("cusparseSetMatDiagType::Lower"); - _status = cusparseDcsrsv2_bufferSize(_handle.first, - _trans.first, - _m, _nnz, - _desc.first, - _ax.data(), - _ap_ordinal.data(), - _aj.data(), - _info.first, - &bufSizeInBytes.first); checkStatus("cusparseDcsrsv2_bufferSize::Lower"); - - _status = cusparseSetMatIndexBase(_desc.second, CUSPARSE_INDEX_BASE_ZERO); checkStatus("cusparseSetMatIndexBase::Upper"); - _status = cusparseSetMatFillMode(_desc.second, CUSPARSE_FILL_MODE_UPPER); checkStatus("cusparseSetMatFillMode::Upper"); - _status = cusparseSetMatDiagType(_desc.second, CUSPARSE_DIAG_TYPE_NON_UNIT); checkStatus("cusparseSetMatDiagType::Upper"); - _status = cusparseDcsrsv2_bufferSize(_handle.second, - _trans.second, - _m, _nnz, - _desc.second, - _ax.data(), - _ap_ordinal.data(), - _aj.data(), - _info.second, - &bufSizeInBytes.second); checkStatus("cusparseDcsrsv2_bufferSize::Upper"); - - const int maxBufSizeInBytes = std::max(bufSizeInBytes.first, bufSizeInBytes.second); - const int bufsize = maxBufSizeInBytes > 0 ? maxBufSizeInBytes/sizeof(double) : 32; - - if (bufsize > int(_buf.extent(0))) - _buf = value_type_array(do_not_initialize_tag("buf"), bufsize); - - _status = cusparseDcsrsv2_analysis(_handle.first, - _trans.first, - _m, _nnz, - _desc.first, - _ax.data(), - _ap_ordinal.data(), - _aj.data(), - _info.first, - _policy.first, - _buf.data()); checkStatus("cusparseDcsrsv2_analysis::Lower"); - - _status = cusparseDcsrsv2_analysis(_handle.second, - _trans.second, - _m, _nnz, - _desc.second, - _ax.data(), - _ap_ordinal.data(), - _aj.data(), - _info.second, - _policy.second, - _buf.data()); checkStatus("cusparseDcsrsv2_analysis::Upper"); - Kokkos::fence(); - const double t_analyze = timer.seconds(); - - if (_verbose) { - printf(" Linear system A\n"); - printf(" number of equations: %10d\n", _m); - printf(" number of nonzeros: %10d\n", _nnz); - printf("\n"); - printf(" Time\n"); - printf(" time for copying A into U: %10.6f s\n", t_copy); - printf(" time for analysis: %10.6f s\n", t_analyze); - printf(" total time spent: %10.6f s\n", (t_copy+t_analyze)); - printf(" Workspace\n"); - printf(" upper solve workspace in MB: %10.3f MB\n", double(bufSizeInBytes.second)/1.e6); - printf(" lower solve workspace in MB: %10.3f MB\n", double(bufSizeInBytes.first)/1.e6); - printf(" max workspace in MB: %10.3f MB\n", double(maxBufSizeInBytes)/1.e6); - printf("\n"); - } - - return 0; + /// LAMBDA cannot capture this pointer; make all variables local + { + ordinal_type_array l_ap_ordinal(do_not_initialize_tag("CuSolver::ap_ordinal"), _ap.extent(0)); + auto l_ap = _ap; + Kokkos::RangePolicy> policy(0, l_ap.extent(0)); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const int i) { l_ap_ordinal(i) = static_cast(l_ap(i)); }); + _ap_ordinal = l_ap_ordinal; } - - int solve(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t, - const bool apply_fence = true) { - if (_verbose) { - printf("cuSolver: Solve\n"); - printf("===============\n"); - } - Kokkos::Timer timer; - - timer.reset(); - - /// solve A x = t - const value_type one(1); - const ordinal_type len = x.extent(0), nrhs = x.extent(1); - for (ordinal_type i=0;i bufSizeInBytes; + + _policy.first = CUSPARSE_SOLVE_POLICY_USE_LEVEL; + _policy.second = CUSPARSE_SOLVE_POLICY_USE_LEVEL; + + _trans.first = CUSPARSE_OPERATION_TRANSPOSE; + _trans.second = CUSPARSE_OPERATION_NON_TRANSPOSE; + + /// + /// query buffersize and analyze the matrix; two separate handles are required for book-keeping + /// + _status = cusparseSetMatIndexBase(_desc.first, CUSPARSE_INDEX_BASE_ZERO); + checkStatus("cusparseSetMatIndexBase::Lower"); + _status = cusparseSetMatFillMode(_desc.first, CUSPARSE_FILL_MODE_UPPER); + checkStatus("cusparseSetMatFillMode::Lower"); + _status = cusparseSetMatDiagType(_desc.first, CUSPARSE_DIAG_TYPE_NON_UNIT); + checkStatus("cusparseSetMatDiagType::Lower"); + _status = cusparseDcsrsv2_bufferSize(_handle.first, _trans.first, _m, _nnz, _desc.first, _ax.data(), + _ap_ordinal.data(), _aj.data(), _info.first, &bufSizeInBytes.first); + checkStatus("cusparseDcsrsv2_bufferSize::Lower"); + + _status = cusparseSetMatIndexBase(_desc.second, CUSPARSE_INDEX_BASE_ZERO); + checkStatus("cusparseSetMatIndexBase::Upper"); + _status = cusparseSetMatFillMode(_desc.second, CUSPARSE_FILL_MODE_UPPER); + checkStatus("cusparseSetMatFillMode::Upper"); + _status = cusparseSetMatDiagType(_desc.second, CUSPARSE_DIAG_TYPE_NON_UNIT); + checkStatus("cusparseSetMatDiagType::Upper"); + _status = cusparseDcsrsv2_bufferSize(_handle.second, _trans.second, _m, _nnz, _desc.second, _ax.data(), + _ap_ordinal.data(), _aj.data(), _info.second, &bufSizeInBytes.second); + checkStatus("cusparseDcsrsv2_bufferSize::Upper"); + + const int maxBufSizeInBytes = std::max(bufSizeInBytes.first, bufSizeInBytes.second); + const int bufsize = maxBufSizeInBytes > 0 ? maxBufSizeInBytes / sizeof(double) : 32; + + if (bufsize > int(_buf.extent(0))) + _buf = value_type_array(do_not_initialize_tag("buf"), bufsize); + + _status = cusparseDcsrsv2_analysis(_handle.first, _trans.first, _m, _nnz, _desc.first, _ax.data(), + _ap_ordinal.data(), _aj.data(), _info.first, _policy.first, _buf.data()); + checkStatus("cusparseDcsrsv2_analysis::Lower"); + + _status = cusparseDcsrsv2_analysis(_handle.second, _trans.second, _m, _nnz, _desc.second, _ax.data(), + _ap_ordinal.data(), _aj.data(), _info.second, _policy.second, _buf.data()); + checkStatus("cusparseDcsrsv2_analysis::Upper"); + Kokkos::fence(); + const double t_analyze = timer.seconds(); + + if (_verbose) { + printf(" Linear system A\n"); + printf(" number of equations: %10d\n", _m); + printf(" number of nonzeros: %10d\n", _nnz); + printf("\n"); + printf(" Time\n"); + printf(" time for copying A into U: %10.6f s\n", t_copy); + printf(" time for analysis: %10.6f s\n", t_analyze); + printf(" total time spent: %10.6f s\n", (t_copy + t_analyze)); + printf(" Workspace\n"); + printf(" upper solve workspace in MB: %10.3f MB\n", + double(bufSizeInBytes.second) / 1.e6); + printf(" lower solve workspace in MB: %10.3f MB\n", + double(bufSizeInBytes.first) / 1.e6); + printf(" max workspace in MB: %10.3f MB\n", + double(maxBufSizeInBytes) / 1.e6); + printf("\n"); } - int solve_capture(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - _graph_status = 1; /// capture begin - cudaStreamBeginCapture(_stream, cudaStreamCaptureModeGlobal); - solve(x, b, t, false); /// do not apply fence when it capture - cudaStreamEndCapture(_stream, &_graph); - cudaGraphInstantiate(&_instance, _graph, NULL, NULL, 0); + return 0; + } - return 0; + int solve(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t, + const bool apply_fence = true) { + if (_verbose) { + printf("cuSolver: Solve\n"); + printf("===============\n"); } + Kokkos::Timer timer; + + timer.reset(); + + /// solve A x = t + const value_type one(1); + const ordinal_type len = x.extent(0), nrhs = x.extent(1); + for (ordinal_type i = 0; i < nrhs; ++i) { + _status = cusparseDcsrsv2_solve(_handle.first, _trans.first, _m, _nnz, (const double *)&one, _desc.first, + (const double *)_ax.data(), (const int *)_ap_ordinal.data(), + (const int *)_aj.data(), _info.first, (const double *)b.data() + i * len, + (double *)t.data() + i * len, _policy.first, (void *)_buf.data()); + checkStatus("cusparseDcsrsv2_solve::Lower"); + + _status = cusparseDcsrsv2_solve(_handle.second, _trans.second, _m, _nnz, (const double *)&one, _desc.second, + (const double *)_ax.data(), (const int *)_ap_ordinal.data(), + (const int *)_aj.data(), _info.second, (const double *)t.data() + i * len, + (double *)x.data() + i * len, _policy.second, (void *)_buf.data()); + checkStatus("cusparseDcsrsv2_solve::Upper"); + } + if (apply_fence) + Kokkos::fence(); - int solve_launch() { - cudaGraphLaunch(_instance, _stream); - cudaStreamSynchronize(_stream); - return 0; + const double t_solve = timer.seconds(); + if (_verbose) { + printf(" Time\n"); + printf(" time for solve: %10.6f s\n", t_solve); + printf("\n"); } - - }; -} + + return 0; + } + + int solve_capture(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t) { + _graph_status = 1; /// capture begin + cudaStreamBeginCapture(_stream, cudaStreamCaptureModeGlobal); + solve(x, b, t, false); /// do not apply fence when it capture + cudaStreamEndCapture(_stream, &_graph); + cudaGraphInstantiate(&_instance, _graph, NULL, NULL, 0); + + return 0; + } + + int solve_launch() { + cudaGraphLaunch(_instance, _stream); + cudaStreamSynchronize(_stream); + return 0; + } +}; +} // namespace Tacho #endif #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_Driver.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_Driver.hpp index c26f9e188846..a00298b16898 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_Driver.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_Driver.hpp @@ -12,384 +12,374 @@ namespace Tacho { - /// forward decl - class Graph; +/// forward decl +class Graph; #if defined(TACHO_HAVE_METIS) - class GraphTools_Metis; -#else - class GraphTools; +class GraphTools_Metis; +#else +class GraphTools; #endif - - class SymbolicTools; - template class CrsMatrixBase; - template class NumericToolsBase; - template class NumericToolsSerial; - template class NumericToolsLevelSet; - /// - /// Tacho Solver interface - /// - template - struct Driver { - public: - using value_type = ValueType; - using device_type = DeviceType; - using exec_space = typename device_type::execution_space; - using exec_memory_space = typename device_type::memory_space; - using scheduler_type = typename UseThisScheduler::type; - - using host_device_type = typename UseThisDevice::type; - using host_space = typename host_device_type::execution_space; - using host_memory_space = typename host_device_type::memory_space; - - using size_type_array = Kokkos::View; - using ordinal_type_array = Kokkos::View; - using value_type_array = Kokkos::View; - using value_type_matrix = Kokkos::View; - - using size_type_array_host = Kokkos::View; - using ordinal_type_array_host = Kokkos::View; - using value_type_array_host = Kokkos::View; - using value_type_matrix_host = Kokkos::View; - - using crs_matrix_type = CrsMatrixBase; - using crs_matrix_type_host = CrsMatrixBase; +class SymbolicTools; +template class CrsMatrixBase; +template class NumericToolsBase; +template class NumericToolsSerial; +template class NumericToolsLevelSet; + +/// +/// Tacho Solver interface +/// +template struct Driver { +public: + using value_type = ValueType; + using device_type = DeviceType; + using exec_space = typename device_type::execution_space; + using exec_memory_space = typename device_type::memory_space; + + using host_device_type = typename UseThisDevice::type; + using host_space = typename host_device_type::execution_space; + using host_memory_space = typename host_device_type::memory_space; + + using size_type_array = Kokkos::View; + using ordinal_type_array = Kokkos::View; + using value_type_array = Kokkos::View; + using value_type_matrix = Kokkos::View; + + using size_type_array_host = Kokkos::View; + using ordinal_type_array_host = Kokkos::View; + using value_type_array_host = Kokkos::View; + using value_type_matrix_host = Kokkos::View; + + using crs_matrix_type = CrsMatrixBase; + using crs_matrix_type_host = CrsMatrixBase; #if defined(TACHO_HAVE_METIS) - using graph_tools_type = GraphTools_Metis; + using graph_tools_type = GraphTools_Metis; #else - using graph_tools_type = GraphTools; + using graph_tools_type = GraphTools; #endif - using symbolic_tools_type = SymbolicTools; - using numeric_tools_base_type = NumericToolsBase; - using numeric_tools_serial_type = NumericToolsSerial; - using numeric_tools_levelset_var0_type = NumericToolsLevelSet; - using numeric_tools_levelset_var1_type = NumericToolsLevelSet; - using numeric_tools_levelset_var2_type = NumericToolsLevelSet; - - private: - enum : int { Cholesky = 1, - LDL = 2, - SymLU = 3, - LU = 4 }; - - // ** solver mode - ordinal_type _mode; - - // ** ordering options - ordinal_type _order_connected_graph_separately; - - // ** problem - ordinal_type _m; - size_type _nnz; - - size_type_array _ap; size_type_array_host _h_ap; - ordinal_type_array _aj; ordinal_type_array_host _h_aj; - - ordinal_type_array _perm; ordinal_type_array_host _h_perm; - ordinal_type_array _peri; ordinal_type_array_host _h_peri; - - // ** condensed graph - ordinal_type _m_graph; - size_type _nnz_graph; - - size_type_array_host _h_ap_graph; - ordinal_type_array_host _h_aj_graph; - ordinal_type_array_host _h_aw_graph; - - ordinal_type_array_host _h_perm_graph; - ordinal_type_array_host _h_peri_graph; - - // ** symbolic factorization output - // supernodes output - ordinal_type _nsupernodes; - ordinal_type_array _supernodes; - - // dof mapping to sparse matrix - size_type_array _gid_super_panel_ptr; - ordinal_type_array _gid_super_panel_colidx; - - // supernode map and panel size configuration - size_type_array _sid_super_panel_ptr; - ordinal_type_array _sid_super_panel_colidx, _blk_super_panel_colidx; - - // supernode elimination tree (parent - children) - size_type_array _stree_ptr; - ordinal_type_array _stree_children; - - // supernode elimination tree (child - parent) - ordinal_type_array _stree_parent; - - // roots of supernodes - ordinal_type_array_host _stree_level, _stree_roots; - - // ** numeric factorization output - numeric_tools_base_type *_N; - - // small dense matrix - // - chol A is used - // - ldl A D P are used - value_type_matrix_host _A, _D; - ordinal_type_array_host _P; - - // ** options - ordinal_type _verbose; // print - ordinal_type _small_problem_thres; // smaller than this, use lapack - - // // ** tasking options - ordinal_type _serial_thres_size; // serialization threshold size - ordinal_type _mb; // block size for byblocks algorithms - ordinal_type _nb; // panel size for panel algorithms - ordinal_type _front_update_mode; // front update mode 0 - lock, 1 - atomic - - // ** levelset options - bool _levelset; // use level set code instead of tasking - ordinal_type _device_level_cut; // above this level, matrices are computed on device - ordinal_type _device_factor_thres; // bigger than this threshold, device function is used - ordinal_type _device_solve_thres; // bigger than this threshold, device function is used - ordinal_type _variant; // algorithmic variant in levelset 0: naive, 1: invert diagonals - ordinal_type _nstreams; // on cuda, multi streams are used - - // parallelism and memory constraint is made via this parameter - ordinal_type _max_num_superblocks; // # of superblocks in the memoyrpool - - public: - Driver(); - /// delete copy constructor and assignment operator - /// sharing numeric tools for different inputs does not make sense - Driver(const Driver &) = default; - Driver& operator=(const Driver &) = default; - - /// duplicate the solver with sharing symbolic factorization - Driver duplicate(); - - /// - /// common options - /// - void setVerbose(const ordinal_type verbose = 1); - void setSmallProblemThresholdsize(const ordinal_type small_problem_thres = 1024); - void setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric - const bool is_positive_definite); - - /// - /// Graph options - /// - void setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately = 1); - - /// - /// tasking options - /// - void setSerialThresholdsize(const ordinal_type serial_thres_size = -1); - void setBlocksize(const ordinal_type mb = -1); - void setPanelsize(const ordinal_type nb = -1); - void setFrontUpdateMode(const ordinal_type front_update_mode = 1); - void setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks = -1); - - /// - /// Level set tools options - /// - void setLevelSetScheduling(const bool levelset); - void setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut); - void setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, - const ordinal_type device_solve_thres); - void setLevelSetOptionNumStreams(const ordinal_type nstreams); - void setLevelSetOptionAlgorithmVariant(const ordinal_type variant); - - /// - /// get interface - /// - ordinal_type getNumSupernodes() const; - ordinal_type_array getSupernodes() const; - ordinal_type_array getPermutationVector() const; - ordinal_type_array getInversePermutationVector() const; - - // internal only - int analyze(); - int analyze_linear_system(); - int analyze_condensed_graph(); - - template - int analyze(const ordinal_type m, - const arg_size_type_array &ap, - const arg_ordinal_type_array &aj, - const bool duplicate = false) { - _m = m; - - if (duplicate) { - /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device - _h_ap = size_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); Kokkos::deep_copy(_h_ap, ap); - _h_aj = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); Kokkos::deep_copy(_h_aj, aj); - - _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); Kokkos::deep_copy(_ap, _h_ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); Kokkos::deep_copy(_aj, _h_aj); - } else { - /// this does not make any extra deep copy; users should hold the graph data - _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); Kokkos::deep_copy(_ap, ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); Kokkos::deep_copy(_aj, aj); - - _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); Kokkos::deep_copy(_h_ap, ap); - _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); Kokkos::deep_copy(_h_aj, aj); - } - - _h_perm = ordinal_type_array_host(); - _h_peri = ordinal_type_array_host(); - - _nnz = _h_ap(m); - - _m_graph = 0; - _nnz_graph = 0; - - _h_ap_graph = size_type_array_host(); - _h_aj_graph = ordinal_type_array_host(); - - _h_perm_graph = ordinal_type_array_host(); - _h_peri_graph = ordinal_type_array_host(); - - return analyze(); + using symbolic_tools_type = SymbolicTools; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + +private: + enum : int { Cholesky = 1, LDL = 2, SymLU = 3, LU = 4 }; + + // ** solver mode + ordinal_type _method; + + // ** ordering options + ordinal_type _order_connected_graph_separately; + + // ** problem + ordinal_type _m; + size_type _nnz; + + size_type_array _ap; + size_type_array_host _h_ap; + ordinal_type_array _aj; + ordinal_type_array_host _h_aj; + + ordinal_type_array _perm; + ordinal_type_array_host _h_perm; + ordinal_type_array _peri; + ordinal_type_array_host _h_peri; + + // ** condensed graph + ordinal_type _m_graph; + size_type _nnz_graph; + + size_type_array_host _h_ap_graph; + ordinal_type_array_host _h_aj_graph; + ordinal_type_array_host _h_aw_graph; + + ordinal_type_array_host _h_perm_graph; + ordinal_type_array_host _h_peri_graph; + + // ** symbolic factorization output + // supernodes output + ordinal_type _nsupernodes; + ordinal_type_array _supernodes; + + // dof mapping to sparse matrix + size_type_array _gid_super_panel_ptr; + ordinal_type_array _gid_super_panel_colidx; + + // supernode map and panel size configuration + size_type_array _sid_super_panel_ptr; + ordinal_type_array _sid_super_panel_colidx, _blk_super_panel_colidx; + + // supernode elimination tree (parent - children) + size_type_array _stree_ptr; + ordinal_type_array _stree_children; + + // supernode elimination tree (child - parent) + ordinal_type_array _stree_parent; + + // roots of supernodes + ordinal_type_array_host _stree_level, _stree_roots; + + // ** numeric factorization output + numeric_tools_base_type *_N; + + // small dense matrix + // - chol A is used + // - ldl A D P are used + value_type_matrix_host _A, _D; + ordinal_type_array_host _P; + + // ** options + ordinal_type _verbose; // print + ordinal_type _small_problem_thres; // smaller than this, use lapack + + // // ** tasking options + ordinal_type _serial_thres_size; // serialization threshold size + ordinal_type _mb; // block size for byblocks algorithms + ordinal_type _nb; // panel size for panel algorithms + ordinal_type _front_update_mode; // front update mode 0 - lock, 1 - atomic + + // ** levelset options + bool _levelset; // use level set code instead of tasking + ordinal_type _device_level_cut; // above this level, matrices are computed on device + ordinal_type _device_factor_thres; // bigger than this threshold, device function is used + ordinal_type _device_solve_thres; // bigger than this threshold, device function is used + ordinal_type _variant; // algorithmic variant in levelset 0: naive, 1: invert diagonals + ordinal_type _nstreams; // on cuda, multi streams are used + + // parallelism and memory constraint is made via this parameter + ordinal_type _max_num_superblocks; // # of superblocks in the memoyrpool + +public: + Driver(); + /// delete copy constructor and assignment operator + /// sharing numeric tools for different inputs does not make sense + Driver(const Driver &) = default; + Driver &operator=(const Driver &) = default; + + /// duplicate the solver with sharing symbolic factorization + Driver duplicate(); + + /// + /// common options + /// + void setVerbose(const ordinal_type verbose = 1); + void setSmallProblemThresholdsize(const ordinal_type small_problem_thres = 1024); + void setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric + const bool is_positive_definite); + void setSolutionMethod(const int method); /// 1 - cholesky, 2 - LDL, 3 - LU + + /// + /// Graph options + /// + void setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately = 1); + + /// + /// tasking options + /// + void setSerialThresholdsize(const ordinal_type serial_thres_size = -1); + void setBlocksize(const ordinal_type mb = -1); + void setPanelsize(const ordinal_type nb = -1); + void setFrontUpdateMode(const ordinal_type front_update_mode = 1); + void setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks = -1); + + /// + /// Level set tools options + /// + void setLevelSetScheduling(const bool levelset); + void setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut); + void setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, + const ordinal_type device_solve_thres); + void setLevelSetOptionNumStreams(const ordinal_type nstreams); + void setLevelSetOptionAlgorithmVariant(const ordinal_type variant); + + /// + /// get interface + /// + ordinal_type getNumSupernodes() const; + ordinal_type_array getSupernodes() const; + ordinal_type_array getPermutationVector() const; + ordinal_type_array getInversePermutationVector() const; + + // internal only + int analyze(); + int analyze_linear_system(); + int analyze_condensed_graph(); + + template + int analyze(const ordinal_type m, const arg_size_type_array &ap, const arg_ordinal_type_array &aj, + const bool duplicate = false) { + _m = m; + + if (duplicate) { + /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device + _h_ap = size_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); + Kokkos::deep_copy(_h_aj, aj); + + _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); + Kokkos::deep_copy(_ap, _h_ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); + Kokkos::deep_copy(_aj, _h_aj); + } else { + /// this does not make any extra deep copy; users should hold the graph data + _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); + Kokkos::deep_copy(_ap, ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); + Kokkos::deep_copy(_aj, aj); + + _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); + Kokkos::deep_copy(_h_aj, aj); } - template - int analyze(const ordinal_type m, - const arg_size_type_array &ap, - const arg_ordinal_type_array &aj, - const arg_perm_type_array &perm, - const arg_perm_type_array &peri, - const bool duplicate = false) { - _m = m; - - if (duplicate) { - /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device - _h_ap = size_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); Kokkos::deep_copy(_h_ap, ap); - _h_aj = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); Kokkos::deep_copy(_h_aj, aj); - - _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); Kokkos::deep_copy(_ap, _h_ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); Kokkos::deep_copy(_aj, _h_aj); - - _h_perm = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_perm"), perm.extent(0)); - _h_peri = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_peri"), peri.extent(0)); - } else { - /// this does not make any extra deep copy; users should hold the graph data - _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); Kokkos::deep_copy(_ap, ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); Kokkos::deep_copy(_aj, aj); - - _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); Kokkos::deep_copy(_h_ap, ap); - _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); Kokkos::deep_copy(_h_aj, aj); - - _h_perm = Kokkos::create_mirror_view(host_memory_space(), perm); - _h_peri = Kokkos::create_mirror_view(host_memory_space(), peri); - } - - Kokkos::deep_copy(_h_perm, perm); - Kokkos::deep_copy(_h_peri, peri); - - _nnz = _h_ap(m); - - _m_graph = 0; - _nnz_graph = 0; - - _h_ap_graph = size_type_array_host(); - _h_aj_graph = ordinal_type_array_host(); - - _h_perm_graph = ordinal_type_array_host(); - _h_peri_graph = ordinal_type_array_host(); - - return analyze(); + _h_perm = ordinal_type_array_host(); + _h_peri = ordinal_type_array_host(); + + _nnz = _h_ap(m); + + _m_graph = 0; + _nnz_graph = 0; + + _h_ap_graph = size_type_array_host(); + _h_aj_graph = ordinal_type_array_host(); + + _h_perm_graph = ordinal_type_array_host(); + _h_peri_graph = ordinal_type_array_host(); + + return analyze(); + } + + template + int analyze(const ordinal_type m, const arg_size_type_array &ap, const arg_ordinal_type_array &aj, + const arg_perm_type_array &perm, const arg_perm_type_array &peri, const bool duplicate = false) { + _m = m; + + if (duplicate) { + /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device + _h_ap = size_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); + Kokkos::deep_copy(_h_aj, aj); + + _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); + Kokkos::deep_copy(_ap, _h_ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); + Kokkos::deep_copy(_aj, _h_aj); + + _h_perm = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_perm"), perm.extent(0)); + _h_peri = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_peri"), peri.extent(0)); + } else { + /// this does not make any extra deep copy; users should hold the graph data + _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); + Kokkos::deep_copy(_ap, ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); + Kokkos::deep_copy(_aj, aj); + + _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); + Kokkos::deep_copy(_h_aj, aj); + + _h_perm = Kokkos::create_mirror_view(host_memory_space(), perm); + _h_peri = Kokkos::create_mirror_view(host_memory_space(), peri); + } + + Kokkos::deep_copy(_h_perm, perm); + Kokkos::deep_copy(_h_peri, peri); + + _nnz = _h_ap(m); + + _m_graph = 0; + _nnz_graph = 0; + + _h_ap_graph = size_type_array_host(); + _h_aj_graph = ordinal_type_array_host(); + + _h_perm_graph = ordinal_type_array_host(); + _h_peri_graph = ordinal_type_array_host(); + + return analyze(); + } + + template + int analyze(const ordinal_type m, const arg_size_type_array &ap, const arg_ordinal_type_array &aj, + const ordinal_type m_graph, const arg_size_type_array &ap_graph, const arg_ordinal_type_array &aj_graph, + const arg_ordinal_type_array &aw_graph, const bool duplicate = false) { + _m = m; + + if (duplicate) { + /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device + _h_ap = size_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); + Kokkos::deep_copy(_h_aj, aj); + + _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); + Kokkos::deep_copy(_ap, _h_ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); + Kokkos::deep_copy(_aj, _h_aj); + } else { + /// this does not make any extra deep copy; users should hold the graph data + _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); + Kokkos::deep_copy(_ap, ap); + _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); + Kokkos::deep_copy(_aj, aj); + + _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); + Kokkos::deep_copy(_h_ap, ap); + _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); + Kokkos::deep_copy(_h_aj, aj); } - - template - int analyze(const ordinal_type m, - const arg_size_type_array &ap, - const arg_ordinal_type_array &aj, - const ordinal_type m_graph, - const arg_size_type_array &ap_graph, - const arg_ordinal_type_array &aj_graph, - const arg_ordinal_type_array &aw_graph, - const bool duplicate = false) { - _m = m; - - if (duplicate) { - /// for most cases, ap and aj are from host; so construct ap and aj and mirror to device - _h_ap = size_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_ap"), ap.extent(0)); Kokkos::deep_copy(_h_ap, ap); - _h_aj = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_aj"), aj.extent(0)); Kokkos::deep_copy(_h_aj, aj); - - _ap = Kokkos::create_mirror_view(exec_memory_space(), _h_ap); Kokkos::deep_copy(_ap, _h_ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), _h_aj); Kokkos::deep_copy(_aj, _h_aj); - } else { - /// this does not make any extra deep copy; users should hold the graph data - _ap = Kokkos::create_mirror_view(exec_memory_space(), ap); Kokkos::deep_copy(_ap, ap); - _aj = Kokkos::create_mirror_view(exec_memory_space(), aj); Kokkos::deep_copy(_aj, aj); - - _h_ap = Kokkos::create_mirror_view(host_memory_space(), ap); Kokkos::deep_copy(_h_ap, ap); - _h_aj = Kokkos::create_mirror_view(host_memory_space(), aj); Kokkos::deep_copy(_h_aj, aj); - } - - _h_perm = ordinal_type_array_host(); - _h_peri = ordinal_type_array_host(); - - _nnz = _h_ap(m); - - _m_graph = m_graph; - if (duplicate) { - _h_ap_graph = size_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_ap_graph"), ap_graph.extent(0)); - _h_aj_graph = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_aj_graph"), aj_graph.extent(0)); - _h_aw_graph = ordinal_type_array_host - (Kokkos::ViewAllocateWithoutInitializing("h_aw_graph"), aw_graph.extent(0)); - } else { - _h_ap_graph = Kokkos::create_mirror_view(host_memory_space(), ap_graph); - _h_aj_graph = Kokkos::create_mirror_view(host_memory_space(), aj_graph); - _h_aw_graph = Kokkos::create_mirror_view(host_memory_space(), aw_graph); - } - - Kokkos::deep_copy(_h_ap_graph, ap_graph); - Kokkos::deep_copy(_h_aj_graph, aj_graph); - Kokkos::deep_copy(_h_aw_graph, aw_graph); - - _h_perm_graph = ordinal_type_array_host(); - _h_peri_graph = ordinal_type_array_host(); - - _nnz_graph = _h_ap_graph(m_graph); - - return analyze(); + _h_perm = ordinal_type_array_host(); + _h_peri = ordinal_type_array_host(); + + _nnz = _h_ap(m); + + _m_graph = m_graph; + if (duplicate) { + _h_ap_graph = size_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_ap_graph"), ap_graph.extent(0)); + _h_aj_graph = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_aj_graph"), aj_graph.extent(0)); + _h_aw_graph = ordinal_type_array_host(Kokkos::ViewAllocateWithoutInitializing("h_aw_graph"), aw_graph.extent(0)); + } else { + _h_ap_graph = Kokkos::create_mirror_view(host_memory_space(), ap_graph); + _h_aj_graph = Kokkos::create_mirror_view(host_memory_space(), aj_graph); + _h_aw_graph = Kokkos::create_mirror_view(host_memory_space(), aw_graph); } - int initialize(); + Kokkos::deep_copy(_h_ap_graph, ap_graph); + Kokkos::deep_copy(_h_aj_graph, aj_graph); + Kokkos::deep_copy(_h_aw_graph, aw_graph); + + _h_perm_graph = ordinal_type_array_host(); + _h_peri_graph = ordinal_type_array_host(); + + _nnz_graph = _h_ap_graph(m_graph); + + return analyze(); + } + + int initialize(); - int factorize(const value_type_array &ax); - int factorize_small_host(const value_type_array &ax); + int factorize(const value_type_array &ax); + int factorize_small_host(const value_type_array &ax); - int solve(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t); - int solve_small_host(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t); + int solve(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t); + int solve_small_host(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t); - double computeRelativeResidual(const value_type_array &ax, - const value_type_matrix &x, - const value_type_matrix &b); - - int exportFactorsToCrsMatrix(crs_matrix_type &A); - int release(); + double computeRelativeResidual(const value_type_array &ax, const value_type_matrix &x, const value_type_matrix &b); - }; + int exportFactorsToCrsMatrix(crs_matrix_type &A); + int release(); +}; -} +} // namespace Tacho //#include "Tacho_Driver_Impl.hpp" diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_MatrixMarket.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_MatrixMarket.hpp index ccaf7f329268..3310200c8b8b 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_MatrixMarket.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_MatrixMarket.hpp @@ -10,328 +10,294 @@ namespace Tacho { - /// - /// Coo : Sparse coordinate format; (i, j, val). - /// - template - struct Coo { - typedef ValueType value_type; +/// +/// Coo : Sparse coordinate format; (i, j, val). +/// +template struct Coo { + typedef ValueType value_type; + + ordinal_type i, j; + value_type val; + + Coo() = default; + Coo(const ordinal_type ii, const ordinal_type jj, const value_type vval) : i(ii), j(jj), val(vval) {} + Coo(const Coo &b) = default; + + /// \brief Compare "less" index i and j only. + bool operator<(const Coo &y) const { + const auto r_val = (this->i - y.i); + return (r_val == 0 ? this->j < y.j : r_val < 0); + } - ordinal_type i,j; - value_type val; + /// \brief Compare "equality" only index i and j. + bool operator==(const Coo &y) const { return (this->i == y.i) && (this->j == y.j); } - Coo() = default; - Coo(const ordinal_type ii, - const ordinal_type jj, - const value_type vval) - : i(ii), j(jj), val(vval) {} - Coo(const Coo& b) = default; + /// \brief Compare "in-equality" only index i and j. + bool operator!=(const Coo &y) const { return !(*this == y); } +}; - /// \brief Compare "less" index i and j only. - bool operator<(const Coo &y) const { - const auto r_val = (this->i - y.i); - return (r_val == 0 ? this->j < y.j : r_val < 0); - } +template +inline typename std::enable_if>::value || + std::is_same>::value>::type +impl_conj_val(T &val, T &conj_val) { + conj_val = Kokkos::conj(val); +} - /// \brief Compare "equality" only index i and j. - bool operator==(const Coo &y) const { - return (this->i == y.i) && (this->j == y.j); - } +template +inline typename std::enable_if::value || std::is_same::value>::type +impl_conj_val(T &val, T &conj_val) { + conj_val = val; +} - /// \brief Compare "in-equality" only index i and j. - bool operator!=(const Coo &y) const { - return !(*this == y); - } - }; - - - template - inline - typename std::enable_if >::value || std::is_same >::value>::type - impl_conj_val(T &val, T &conj_val) { - conj_val = Kokkos::conj(val); - } +template +inline typename std::enable_if>::value || + std::is_same>::value>::type +impl_is_zero(T &val, bool &is_zero) { + is_zero = Kokkos::abs(val) < std::numeric_limits::epsilon(); +} - template - inline - typename std::enable_if::value || std::is_same::value>::type - impl_conj_val(T &val, T &conj_val) { - conj_val = val; - } +template +inline typename std::enable_if::value || std::is_same::value>::type +impl_is_zero(T &val, bool &is_zero) { + is_zero = std::abs(val) < std::numeric_limits::epsilon(); +} - template - inline - typename std::enable_if >::value || std::is_same >::value>::type - impl_is_zero(T &val, bool &is_zero) { - is_zero = Kokkos::abs(val) < std::numeric_limits::epsilon(); - } +template +inline typename std::enable_if::value || std::is_same::value>::type +impl_read_value_from_file(std::ifstream &file, ordinal_type &row, ordinal_type &col, T &val) { + file >> row >> col >> val; +} - template - inline - typename std::enable_if::value || std::is_same::value>::type - impl_is_zero(T &val, bool &is_zero) { - is_zero = std::abs(val) < std::numeric_limits::epsilon(); - } +template +inline typename std::enable_if>::value || + std::is_same>::value>::type +impl_read_value_from_file(std::ifstream &file, ordinal_type &row, ordinal_type &col, T &val) { + typename T::value_type r, i; + file >> row >> col >> r >> i; + val = T(r, i); +} - template - inline - typename std::enable_if::value || std::is_same::value>::type - impl_read_value_from_file(std::ifstream &file, - ordinal_type &row, - ordinal_type &col, - T &val) { - file >> row >> col >> val; - } +template +inline typename std::enable_if::value || std::is_same::value>::type +impl_write_value_to_file(std::ofstream &file, const ordinal_type row, const ordinal_type col, const T val, + const ordinal_type w = 10) { + file << std::setw(w) << row << " " << std::setw(w) << col << " " << std::setw(w) << val << std::endl; +} - template - inline - typename std::enable_if >::value || std::is_same >::value>::type - impl_read_value_from_file(std::ifstream &file, - ordinal_type &row, - ordinal_type &col, - T &val) { - typename T::value_type r, i; - file >> row >> col >> r >> i; - val = T(r, i); - } +template +inline typename std::enable_if>::value || + std::is_same>::value>::type +impl_write_value_to_file(std::ofstream &file, const ordinal_type row, const ordinal_type col, const T val, + const ordinal_type w = 10) { + file << std::setw(w) << row << " " << std::setw(w) << col << " " << std::setw(w) << val.real() << " " + << std::setw(w) << val.imag() << std::endl; +} - template - inline - typename std::enable_if::value || std::is_same::value>::type - impl_write_value_to_file(std::ofstream &file, - const ordinal_type row, - const ordinal_type col, - const T val, - const ordinal_type w = 10) { - file << std::setw(w) << row << " " - << std::setw(w) << col << " " - << std::setw(w) << val << std::endl; - } - - template - inline - typename std::enable_if >::value || std::is_same >::value>::type - impl_write_value_to_file(std::ofstream &file, - const ordinal_type row, - const ordinal_type col, - const T val, - const ordinal_type w = 10) { - file << std::setw(w) << row << " " - << std::setw(w) << col << " " - << std::setw(w) << val.real() << " " - << std::setw(w) << val.imag() << std::endl; - } - - template - struct MatrixMarket { - - /// \brief matrix market reader - template - static void - read(const std::string &filename, - CrsMatrixBase &A, - const ordinal_type sanitize = 0, - const ordinal_type verbose = 0) { - static_assert(Kokkos::Impl::MemorySpaceAccess< - Kokkos::HostSpace, - typename DeviceType::memory_space >::assignable, - "DeviceType is not assignable from HostSpace" ); - - Kokkos::Timer timer; - - timer.reset(); - - std::ifstream file; - file.open(filename); - - // reading mm header - ordinal_type m, n; - size_type nnz, nnz_input; - bool symmetry = false, hermitian = false; //, cmplx = false; - { - std::string header; - std::getline(file, header); - while (file.good()) { - char c = file.peek(); - if (c == '%' || c == '\n') { - file.ignore(256, '\n'); - continue; - } - break; +template struct MatrixMarket { + + /// \brief matrix market reader + template + static void read(const std::string &filename, CrsMatrixBase &A, + const ordinal_type sanitize = 0, const ordinal_type verbose = 0) { + static_assert(Kokkos::Impl::MemorySpaceAccess::assignable, + "DeviceType is not assignable from HostSpace"); + + Kokkos::Timer timer; + + timer.reset(); + + std::ifstream file; + file.open(filename); + + // reading mm header + ordinal_type m, n; + size_type nnz, nnz_input; + bool symmetry = false, hermitian = false; //, cmplx = false; + { + std::string header; + std::getline(file, header); + while (file.good()) { + char c = file.peek(); + if (c == '%' || c == '\n') { + file.ignore(256, '\n'); + continue; } - std::transform(header.begin(), header.end(), header.begin(), ::tolower); - symmetry = (header.find("symmetric") != std::string::npos || - header.find("hermitian") != std::string::npos); - - hermitian = (header.find("hermitian") != std::string::npos); - - file >> m >> n >> nnz; + break; } + std::transform(header.begin(), header.end(), header.begin(), ::tolower); + symmetry = (header.find("symmetric") != std::string::npos || header.find("hermitian") != std::string::npos); - // read data into coo format - const ordinal_type mm_base = 1; - - typedef ValueType value_type; - typedef Coo ijv_type; - std::vector mm; - { - std::vector mm_org; - mm_org.reserve(nnz*(symmetry ? 2 : 1)); - for (size_type i=0;i> m >> n >> nnz; + } - row -= mm_base; - col -= mm_base; + // read data into coo format + const ordinal_type mm_base = 1; - mm_org.push_back(ijv_type(row, col, val)); - if (symmetry && row != col) { - value_type conj_val; impl_conj_val(val, conj_val); - mm_org.push_back(ijv_type(col, row, hermitian ? conj_val : val)); - } + typedef ValueType value_type; + typedef Coo ijv_type; + std::vector mm; + { + std::vector mm_org; + mm_org.reserve(nnz * (symmetry ? 2 : 1)); + for (size_type i = 0; i < nnz; ++i) { + ordinal_type row, col; + value_type val; + + impl_read_value_from_file(file, row, col, val); + + row -= mm_base; + col -= mm_base; + + mm_org.push_back(ijv_type(row, col, val)); + if (symmetry && row != col) { + value_type conj_val; + impl_conj_val(val, conj_val); + mm_org.push_back(ijv_type(col, row, hermitian ? conj_val : val)); } - std::sort(mm_org.begin(), mm_org.end(), std::less()); - - // update nnz (this is the nnz from input matrix) - nnz = mm_org.size(); - nnz_input = nnz; - - // copy to mm - mm.reserve(nnz); - if (sanitize) { - for (size_type i=0;i()); + + // update nnz (this is the nnz from input matrix) + nnz = mm_org.size(); + nnz_input = nnz; + + // copy to mm + mm.reserve(nnz); + if (sanitize) { + for (size_type i = 0; i < nnz; ++i) { + bool is_zero; + impl_is_zero(mm_org[i].val, is_zero); + if (!is_zero) mm.push_back(mm_org[i]); - nnz = mm.size(); } + nnz = mm.size(); + } else { + for (size_type i = 0; i < nnz; ++i) + mm.push_back(mm_org[i]); + nnz = mm.size(); } + } - // change mm to crs - Kokkos::View ap("ap", m+1); - Kokkos::View aj("aj", nnz); - Kokkos::View ax("ax", nnz); - { - ordinal_type icnt = 0; - size_type jcnt = 0; - ijv_type prev = mm[0]; - - ap[icnt++] = 0; - aj[jcnt] = prev.j; - ax[jcnt++] = prev.val; - - for (auto it=(mm.begin()+1);it<(mm.end());++it) { - const ijv_type aij = (*it); - - if (aij.i != prev.i) - ap[icnt++] = jcnt; - - if (aij == prev) { - aj[jcnt] = aij.j; - ax[jcnt] += aij.val; - } else { - aj[jcnt] = aij.j; - ax[jcnt++] = aij.val; - } - prev = aij; + // change mm to crs + Kokkos::View ap("ap", m + 1); + Kokkos::View aj("aj", nnz); + Kokkos::View ax("ax", nnz); + { + ordinal_type icnt = 0; + size_type jcnt = 0; + ijv_type prev = mm[0]; + + ap[icnt++] = 0; + aj[jcnt] = prev.j; + ax[jcnt++] = prev.val; + + for (auto it = (mm.begin() + 1); it < (mm.end()); ++it) { + const ijv_type aij = (*it); + + if (aij.i != prev.i) + ap[icnt++] = jcnt; + + if (aij == prev) { + aj[jcnt] = aij.j; + ax[jcnt] += aij.val; + } else { + aj[jcnt] = aij.j; + ax[jcnt++] = aij.val; } - ap[icnt++] = jcnt; - nnz = jcnt; + prev = aij; } + ap[icnt++] = jcnt; + nnz = jcnt; + } - // create crs matrix view - A.clear(); - A.setExternalMatrix(m, n, nnz, ap, aj, ax); - - const double t = timer.seconds(); - if (verbose) { - - printf("Summary: MatrixMarket\n"); - printf("=====================\n"); - printf(" File: %s\n", filename.c_str()); - printf(" Time\n"); - printf(" time for reading A: %10.6f s\n", t); - printf("\n"); - printf(" Sparse Matrix (%s) \n", (symmetry ? "symmetric" : "non-symmetric")); - printf(" number of rows: %10d\n", m); - printf(" number of cols: %10d\n", n); - printf(" number of nonzeros from input: %10d\n", ordinal_type(nnz_input)); - printf(" number of nonzeros after sanitized: %10d\n", ordinal_type(nnz)); - printf("\n"); - } + // create crs matrix view + A.clear(); + A.setExternalMatrix(m, n, nnz, ap, aj, ax); + + const double t = timer.seconds(); + if (verbose) { + + printf("Summary: MatrixMarket\n"); + printf("=====================\n"); + printf(" File: %s\n", filename.c_str()); + printf(" Time\n"); + printf(" time for reading A: %10.6f s\n", t); + printf("\n"); + printf(" Sparse Matrix (%s) \n", (symmetry ? "symmetric" : "non-symmetric")); + printf(" number of rows: %10d\n", m); + printf(" number of cols: %10d\n", n); + printf(" number of nonzeros from input: %10d\n", ordinal_type(nnz_input)); + printf(" number of nonzeros after sanitized: %10d\n", ordinal_type(nnz)); + printf("\n"); } + } - /// \brief matrix marker writer - template - static void - write(std::ofstream &file, - const CrsMatrixBase &A, - const int uplo = 0, // 0 - all, 1 - upper, 2 - lower - const std::string comment = "%% Tacho::MatrixMarket::Export") { - static_assert(Kokkos::Impl::MemorySpaceAccess< - Kokkos::HostSpace, - typename DeviceType::memory_space - >::assignable, - "DeviceType is not assignable from HostSpace" ); - - typedef ValueType value_type; - constexpr bool is_complex = (std::is_same >::value || - std::is_same >::value); - - std::streamsize prec = file.precision(); - file.precision(16); - file << std::scientific; - - { - file << "%%MatrixMarket matrix coordinate " - << (is_complex ? "complex " : "real ") - << std::endl; - file << comment << std::endl; - } - // cnt nnz - size_type nnz = 0; - { - for (ordinal_type i=0;i= aj) ++nnz; - if (uplo == 0) ++nnz; - } + /// \brief matrix marker writer + template + static void write(std::ofstream &file, const CrsMatrixBase &A, + const int uplo = 0, // 0 - all, 1 - upper, 2 - lower + const std::string comment = "%% Tacho::MatrixMarket::Export") { + static_assert(Kokkos::Impl::MemorySpaceAccess::assignable, + "DeviceType is not assignable from HostSpace"); + + typedef ValueType value_type; + constexpr bool is_complex = (std::is_same>::value || + std::is_same>::value); + + std::streamsize prec = file.precision(); + file.precision(16); + file << std::scientific; + + { + file << "%%MatrixMarket matrix coordinate " << (is_complex ? "complex " : "real ") << std::endl; + file << comment << std::endl; + } + // cnt nnz + size_type nnz = 0; + { + for (ordinal_type i = 0; i < A.NumRows(); ++i) { + const size_type jbegin = A.RowPtrBegin(i), jend = A.RowPtrEnd(i); + for (size_type j = jbegin; j < jend; ++j) { + const auto aj = A.Col(j); + if (uplo == 1 && i <= aj) + ++nnz; + if (uplo == 2 && i >= aj) + ++nnz; + if (uplo == 0) + ++nnz; } - file << A.NumRows() << " " << A.NumCols() << " " << nnz << std::endl; } + file << A.NumRows() << " " << A.NumCols() << " " << nnz << std::endl; + } - const int w = 10; - { - for (ordinal_type i=0;i= aj) flag = true; - if (uplo == 0) flag = true; - if (flag) { - value_type val = A.Value(j); - impl_write_value_to_file(file, i+1, aj+1, val, w); - } + const int w = 10; + { + for (ordinal_type i = 0; i < A.NumRows(); ++i) { + const size_type jbegin = A.RowPtrBegin(i), jend = A.RowPtrEnd(i); + for (size_type j = jbegin; j < jend; ++j) { + const auto aj = A.Col(j); + bool flag = false; + if (uplo == 1 && i <= aj) + flag = true; + if (uplo == 2 && i >= aj) + flag = true; + if (uplo == 0) + flag = true; + if (flag) { + value_type val = A.Value(j); + impl_write_value_to_file(file, i + 1, aj + 1, val, w); } } } - - file.unsetf(std::ios::scientific); - file.precision(prec); } - }; -} + file.unsetf(std::ios::scientific); + file.precision(prec); + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_Pardiso.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_Pardiso.hpp index 3508785f5d07..09e2372386ac 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_Pardiso.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_Pardiso.hpp @@ -1,7 +1,7 @@ #ifndef __TACHO_EXAMPLE_PARDISO_HPP__ #define __TACHO_EXAMPLE_PARDISO_HPP__ -#if defined (__INTEL_MKL__) +#if defined(__INTEL_MKL__) using namespace std; #include "Tacho_Util.hpp" @@ -9,217 +9,204 @@ using namespace std; namespace Tacho { - class Pardiso { - public: - enum Phase { - Analyze = 11, - AnalyzeFactorize = 12, - AnalyzeFactorizeSolve = 13, - Factorize = 22, - FactorizeSolve = 23, - Solve = 33, - ReleaseInternal = 0, // release internal memory for LU matrix number MNUM - ReleaseAll = -1 // release all internal memory for all matrices - }; - - private: - int _mtype, // matrix type 2 - spd, 4 - hpd - _phase; - - int _maxfct, // maxfct - maximum number of factors (=1) - _mnum, // mnum - actual matrix for the solution phase (=1) - _msglvl; // msglvl - print out level 0: nothing, 1: statistics - - // parameters - int _iparm[64]; - double _dparm[64]; - - // internal data address pointers: NEVER TOUCH - void *_pt[64]; - - // Translate Fortran index to C-index - int Fort(const int i) const { return i-1; } - - // Matrix type - // 1 - real structrue sym - // 2 - real sym pos def - // -2 - real sym indef - // 3 - complex structrue sym - // 4 - complex hermitian pos def - // -4 - complex and hermitian indef - // 6 - complex and sym - // 11 - real and nonsym - // 13 - complex and nonsym - template - void setMatrixType(); - - public: - Pardiso() - : _mtype(0), - _phase(0) { - } - - void setDefaultParameters() { - // initialize arrays - for (int i=0;i<64;++i) { - _iparm[i] = 0; - _dparm[i] = 0.0; - _pt[i] = 0; - } - - _maxfct = 1; - _mnum = 1; - _msglvl = 1; - } - - void setParameter(const int id, const int value) { _iparm[Fort(id)] = value; } - void setParameter(const int id, const double value) { _dparm[Fort(id)] = value; } - - template - int init() { - int ierr = 0; - - setMatrixType(); - setDefaultParameters(); - - // load default param - pardisoinit(_pt, &_mtype, _iparm); - - // overload default parameters - - // setParameter( 1, 1); // default param: 0 - default, 1 - user provided - // setParameter( 2, 0); // reordering: 0 - mindegreem, 2 - nd, 3 - parallel nd - setParameter(27, 1); // mat check: 0 - no, 1 - check - setParameter(35, 1); // row and col index: 0 - fortran, 1 - CXX - - return ierr; - } - - ostream& showErrorCode(ostream &os) const { - os << " 0 No error" << std::endl - << "- 1 Input inconsistent" << std::endl - << "- 2 Not enough memory" << std::endl - << "- 3 Reordering problem" << std::endl - << "- 4 Zero pivot" << std::endl - << "- 5 Unclassified (internal) error" << std::endl - << "- 6 Preordering fail (matrix types 11,13)" << std::endl - << "- 7 Diagonal matrix problem" << std::endl - << "- 8 32-bit integer overflow" << std::endl - << "- 10 No license file pardiso.lic found" << std::endl - << "- 11 License expired " << std::endl - << "- 12 Wrong user name or host" << std::endl - << "-100 over, Krylov fail" << std::endl ; - return os; - } +class Pardiso { +public: + enum Phase { + Analyze = 11, + AnalyzeFactorize = 12, + AnalyzeFactorizeSolve = 13, + Factorize = 22, + FactorizeSolve = 23, + Solve = 33, + ReleaseInternal = 0, // release internal memory for LU matrix number MNUM + ReleaseAll = -1 // release all internal memory for all matrices + }; - ostream& showStat(ostream &os, const Phase phase) const { - switch(phase) { - case Analyze: - os << "- Phase: Analyze -" << std::endl - << "Number of perturbed pivots = " << _iparm[Fort(14)] << std::endl - << "Number of peak memory symbolic = " << _iparm[Fort(15)] << std::endl - << "Number of permenant memory symbolic = " << _iparm[Fort(16)] << " KB " << std::endl; - break; - case Factorize: - os << "- Phase: Factorize -" << std::endl - << "Peak memory used in factorization = " << max(_iparm[Fort(15)], _iparm[Fort(16)]+_iparm[Fort(17)]) << " KB "<< std::endl - << "Memory numerical factorization = " << _iparm[Fort(17)] << " KB " << std::endl - << "Number of nonzeros in factors = " << _iparm[Fort(18)] << std::endl - << "Number of factorization MFLOP = " << _iparm[Fort(19)] << std::endl - << "MFLOPs = " << _iparm[Fort(19)] << std::endl; - break; - case Solve: - os << "- Phase: Solve -" << std::endl - << "Number of iterative refinements = " << _iparm[Fort(7)] << std::endl; - break; - case AnalyzeFactorize: - showStat(os, Analyze); - showStat(os, Factorize); - break; - case AnalyzeFactorizeSolve: - showStat(os, Analyze); - showStat(os, Factorize); - showStat(os, Solve); - break; - case FactorizeSolve: - showStat(os, Factorize); - showStat(os, Solve); - break; - default: - os << "- Phase: " << phase << " -" << std::endl - << "Nothing serious in this phase" << std::endl; - break; - } - return os; - } +private: + int _mtype, // matrix type 2 - spd, 4 - hpd + _phase; - private: - int _ndof, *_ia, *_ja, *_pivot, _nrhs; - double *_a, *_b, *_x; - - public: - void setProblem(const int ndof, - double *a, - int *ia, - int *ja, - int *pivot, - const int nrhs, - double *b, - double *x) { - _ndof = ndof; - _a = a; - _ia = ia; - _ja = ja; - _pivot = pivot; - _nrhs = nrhs; - _b = b; - _x = x; + int _maxfct, // maxfct - maximum number of factors (=1) + _mnum, // mnum - actual matrix for the solution phase (=1) + _msglvl; // msglvl - print out level 0: nothing, 1: statistics + + // parameters + int _iparm[64]; + double _dparm[64]; + + // internal data address pointers: NEVER TOUCH + void *_pt[64]; + + // Translate Fortran index to C-index + int Fort(const int i) const { return i - 1; } + + // Matrix type + // 1 - real structrue sym + // 2 - real sym pos def + // -2 - real sym indef + // 3 - complex structrue sym + // 4 - complex hermitian pos def + // -4 - complex and hermitian indef + // 6 - complex and sym + // 11 - real and nonsym + // 13 - complex and nonsym + template void setMatrixType(); + +public: + Pardiso() : _mtype(0), _phase(0) {} + + void setDefaultParameters() { + // initialize arrays + for (int i = 0; i < 64; ++i) { + _iparm[i] = 0; + _dparm[i] = 0.0; + _pt[i] = 0; } - int run(int phase, int msglvl = 1) { - int ierr = 0; - - _msglvl = msglvl; - pardiso(_pt, - &_maxfct, &_mnum, &_mtype, - &phase, - &_ndof, - _a, _ia, _ja, - _pivot, &_nrhs, - _iparm, &_msglvl, - _b, _x, - &ierr); - - return ierr; + _maxfct = 1; + _mnum = 1; + _msglvl = 1; + } + + void setParameter(const int id, const int value) { _iparm[Fort(id)] = value; } + void setParameter(const int id, const double value) { _dparm[Fort(id)] = value; } + + template int init() { + int ierr = 0; + + setMatrixType(); + setDefaultParameters(); + + // load default param + pardisoinit(_pt, &_mtype, _iparm); + + // overload default parameters + + // setParameter( 1, 1); // default param: 0 - default, 1 - user provided + // setParameter( 2, 0); // reordering: 0 - mindegreem, 2 - nd, 3 - parallel nd + setParameter(27, 1); // mat check: 0 - no, 1 - check + setParameter(35, 1); // row and col index: 0 - fortran, 1 - CXX + + return ierr; + } + + ostream &showErrorCode(ostream &os) const { + os << " 0 No error" << std::endl + << "- 1 Input inconsistent" << std::endl + << "- 2 Not enough memory" << std::endl + << "- 3 Reordering problem" << std::endl + << "- 4 Zero pivot" << std::endl + << "- 5 Unclassified (internal) error" << std::endl + << "- 6 Preordering fail (matrix types 11,13)" << std::endl + << "- 7 Diagonal matrix problem" << std::endl + << "- 8 32-bit integer overflow" << std::endl + << "- 10 No license file pardiso.lic found" << std::endl + << "- 11 License expired " << std::endl + << "- 12 Wrong user name or host" << std::endl + << "-100 over, Krylov fail" << std::endl; + return os; + } + + ostream &showStat(ostream &os, const Phase phase) const { + switch (phase) { + case Analyze: + os << "- Phase: Analyze -" << std::endl + << "Number of perturbed pivots = " << _iparm[Fort(14)] << std::endl + << "Number of peak memory symbolic = " << _iparm[Fort(15)] << std::endl + << "Number of permenant memory symbolic = " << _iparm[Fort(16)] << " KB " << std::endl; + break; + case Factorize: + os << "- Phase: Factorize -" << std::endl + << "Peak memory used in factorization = " << max(_iparm[Fort(15)], _iparm[Fort(16)] + _iparm[Fort(17)]) + << " KB " << std::endl + << "Memory numerical factorization = " << _iparm[Fort(17)] << " KB " << std::endl + << "Number of nonzeros in factors = " << _iparm[Fort(18)] << std::endl + << "Number of factorization MFLOP = " << _iparm[Fort(19)] << std::endl + << "MFLOPs = " << _iparm[Fort(19)] << std::endl; + break; + case Solve: + os << "- Phase: Solve -" << std::endl << "Number of iterative refinements = " << _iparm[Fort(7)] << std::endl; + break; + case AnalyzeFactorize: + showStat(os, Analyze); + showStat(os, Factorize); + break; + case AnalyzeFactorizeSolve: + showStat(os, Analyze); + showStat(os, Factorize); + showStat(os, Solve); + break; + case FactorizeSolve: + showStat(os, Factorize); + showStat(os, Solve); + break; + default: + os << "- Phase: " << phase << " -" << std::endl << "Nothing serious in this phase" << std::endl; + break; } - - }; - - // Pardiso mtype - // 1 - real structure sym - // 2 - real sym posdef - // -2 - real sym indef - // 3 - complex structure sym - // 4 - complex her posdef - // -4 - complex her indef - // 6 - complex structure sym - // 11 - real non sym - // 13 - complex non sym - // - // pardiso does not like 2 4; for not we use 1 and 3 or 11 and 13. - - template<> - void Pardiso::setMatrixType() { _mtype = 2; setParameter(28, 1); } // SPD - - template<> - void Pardiso::setMatrixType() { _mtype = 2; setParameter(28, 0); } // SPD - - template<> - void Pardiso::setMatrixType, 2>() { _mtype = 3; setParameter(28, 1); } // HPD - - template<> - void Pardiso::setMatrixType, 2>() { _mtype = 3; setParameter(28, 0); } // HPD -} + return os; + } + +private: + int _ndof, *_ia, *_ja, *_pivot, _nrhs; + double *_a, *_b, *_x; + +public: + void setProblem(const int ndof, double *a, int *ia, int *ja, int *pivot, const int nrhs, double *b, double *x) { + _ndof = ndof; + _a = a; + _ia = ia; + _ja = ja; + _pivot = pivot; + _nrhs = nrhs; + _b = b; + _x = x; + } + + int run(int phase, int msglvl = 1) { + int ierr = 0; + + _msglvl = msglvl; + pardiso(_pt, &_maxfct, &_mnum, &_mtype, &phase, &_ndof, _a, _ia, _ja, _pivot, &_nrhs, _iparm, &_msglvl, _b, _x, + &ierr); + + return ierr; + } +}; + +// Pardiso mtype +// 1 - real structure sym +// 2 - real sym posdef +// -2 - real sym indef +// 3 - complex structure sym +// 4 - complex her posdef +// -4 - complex her indef +// 6 - complex structure sym +// 11 - real non sym +// 13 - complex non sym +// +// pardiso does not like 2 4; for not we use 1 and 3 or 11 and 13. + +template <> void Pardiso::setMatrixType() { + _mtype = 2; + setParameter(28, 1); +} // SPD + +template <> void Pardiso::setMatrixType() { + _mtype = 2; + setParameter(28, 0); +} // SPD + +template <> void Pardiso::setMatrixType, 2>() { + _mtype = 3; + setParameter(28, 1); +} // HPD + +template <> void Pardiso::setMatrixType, 2>() { + _mtype = 3; + setParameter(28, 0); +} // HPD +} // namespace Tacho #endif #endif diff --git a/packages/shylu/shylu_node/tacho/src/Tacho_Solver.hpp b/packages/shylu/shylu_node/tacho/src/Tacho_Solver.hpp index 7e7d2184c9d0..65de1b59aa05 100644 --- a/packages/shylu/shylu_node/tacho/src/Tacho_Solver.hpp +++ b/packages/shylu/shylu_node/tacho/src/Tacho_Solver.hpp @@ -10,8 +10,8 @@ namespace Tacho { - template - using Solver = Driver::type >; +template +using Solver = Driver::type>; } diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI.in similarity index 57% rename from packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Serial.cpp rename to packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI.in index af16d62f419c..c8d980955aed 100644 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Serial.cpp +++ b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI.in @@ -1,13 +1,15 @@ #include "Kokkos_Core.hpp" +#if @ETI_WITH_TASK@ +#define TACHO_ENABLE_KOKKOS_TASK +#endif + #include "Tacho.hpp" #include "Tacho_Driver.hpp" #include "Tacho_Driver_Impl.hpp" namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_value_type = double; - using eti_device_type = typename UseThisDevice::type; + using eti_value_type = @ETI_VALUE_TYPE@; + using eti_device_type = @ETI_DEVICE_TYPE@; template struct Driver; -#endif } diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Cuda.cpp deleted file mode 100644 index 8bd40f7a4aa2..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Cuda.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_OpenMP.cpp deleted file mode 100644 index c246a467dd36..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_OpenMP.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Serial.cpp deleted file mode 100644 index 765f2a4a57b1..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_double_Serial.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Cuda.cpp deleted file mode 100644 index 71dcc5d0be68..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Cuda.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_OpenMP.cpp deleted file mode 100644 index 057b6fbb029c..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_OpenMP.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Serial.cpp deleted file mode 100644 index 6560b10fb7f5..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_complex_float_Serial.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_value_type = Kokkos::complex; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Cuda.cpp deleted file mode 100644 index d504af78be59..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_Cuda.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_value_type = double; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_OpenMP.cpp deleted file mode 100644 index 1c51420dc0a1..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_double_OpenMP.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_value_type = double; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Cuda.cpp deleted file mode 100644 index 951c99703d97..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Cuda.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_value_type = float; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_OpenMP.cpp deleted file mode 100644 index 8bd1a1305d37..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_OpenMP.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_value_type = float; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Serial.cpp deleted file mode 100644 index 8d1af8fbc159..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Driver_ETI_float_Serial.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Driver.hpp" -#include "Tacho_Driver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_value_type = float; - using eti_device_type = typename UseThisDevice::type; - template struct Driver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Cuda.cpp deleted file mode 100644 index dba7d524603b..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_OpenMP.cpp deleted file mode 100644 index e45832f5ff0b..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Serial.cpp deleted file mode 100644 index 4c702cd79a5d..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_double_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Cuda.cpp deleted file mode 100644 index 83b29f6c9c2f..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_OpenMP.cpp deleted file mode 100644 index 6727b2d3454e..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Serial.cpp deleted file mode 100644 index 20495e2f322a..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_complex_float_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Cuda.cpp deleted file mode 100644 index 0182872153e7..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_OpenMP.cpp deleted file mode 100644 index 91a1d92950bc..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Serial.cpp deleted file mode 100644 index 95e666beec44..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_double_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Cuda.cpp deleted file mode 100644 index 4a761f27f23f..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_OpenMP.cpp deleted file mode 100644 index 3301b45548f1..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Serial.cpp deleted file mode 100644 index daf708a521b4..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_ChseLevTaskScheduler_float_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::ChaseLevTaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Cuda.cpp deleted file mode 100644 index a34c33a4022d..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_OpenMP.cpp deleted file mode 100644 index 1955b3f2570f..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Serial.cpp deleted file mode 100644 index 8e8ccdb4cc08..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_double_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Cuda.cpp deleted file mode 100644 index 7dfbac5d913b..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_OpenMP.cpp deleted file mode 100644 index 13d66fa30d53..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Serial.cpp deleted file mode 100644 index 2c4b2c77f67f..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_complex_float_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Cuda.cpp deleted file mode 100644 index cf87e221a7ed..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_OpenMP.cpp deleted file mode 100644 index 80fe7c71c306..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Serial.cpp deleted file mode 100644 index 6a61443e96c2..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_double_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Cuda.cpp deleted file mode 100644 index 6f223d3c7af3..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Cuda.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_OpenMP.cpp deleted file mode 100644 index a5548a35a438..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_OpenMP.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Serial.cpp deleted file mode 100644 index 1d0d6be85066..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskSchedulerMultiple_float_Serial.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskSchedulerMultiple; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Cuda.cpp deleted file mode 100644 index 15015c662cd4..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Cuda.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_OpenMP.cpp deleted file mode 100644 index bdeb3c6f5622..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_OpenMP.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Serial.cpp deleted file mode 100644 index 6d1dc116b4be..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_double_Serial.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Cuda.cpp deleted file mode 100644 index 46996bdb3343..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Cuda.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_OpenMP.cpp deleted file mode 100644 index 07b493c0b84e..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_OpenMP.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Serial.cpp deleted file mode 100644 index 0417a700ea96..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_complex_float_Serial.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver,eti_scheduler_type>; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Cuda.cpp deleted file mode 100644 index 81682b95868f..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Cuda.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_OpenMP.cpp deleted file mode 100644 index 0241a855f190..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_OpenMP.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Serial.cpp deleted file mode 100644 index 8225354e7184..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_double_Serial.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Cuda.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Cuda.cpp deleted file mode 100644 index efc996983eff..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Cuda.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_CUDA) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_OpenMP.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_OpenMP.cpp deleted file mode 100644 index 343588251a43..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_OpenMP.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_OPENMP) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Serial.cpp b/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Serial.cpp deleted file mode 100644 index 7b1d8bc0a029..000000000000 --- a/packages/shylu/shylu_node/tacho/src/eti/Tacho_Solver_ETI_TaskScheduler_float_Serial.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "Kokkos_Core.hpp" - -#include "Tacho.hpp" -#include "Tacho_Solver.hpp" -#include "Tacho_Solver_Impl.hpp" - -namespace Tacho { -#if defined(KOKKOS_ENABLE_SERIAL) - using eti_scheduler_type = Kokkos::TaskScheduler; - template struct Solver; -#endif -} diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation.hpp index 3419cfacf360..e71317a158c7 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation.hpp @@ -9,13 +9,12 @@ namespace Tacho { - /// - /// Apply Permutation - /// +/// +/// Apply Permutation +/// - /// various implementation for different uplo and algo parameters - template - struct ApplyPermutation; -} +/// various implementation for different uplo and algo parameters +template struct ApplyPermutation; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_Internal.hpp index 4ddfd9fcb98e..c689e1820966 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_Internal.hpp @@ -1,93 +1,68 @@ #ifndef __TACHO_APPLY_PERMUTATION_INTERNAL_HPP__ #define __TACHO_APPLY_PERMUTATION_INTERNAL_HPP__ - /// \file Tacho_ApplyPermutation_Internal.hpp /// \brief Apply pivots /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - /// row exchange - template<> - struct ApplyPermutation { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeB &B) { - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - if (n == 1) { /// vector - for (ordinal_type i=0;i struct ApplyPermutation { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ViewTypeA &A, const ViewTypeP &P, const ViewTypeB &B) { + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + if (n == 1) { /// vector + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type idx = P(i); + B(i, 0) = A(idx, 0); + } + } else { /// matrix + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type idx = P(i); + for (ordinal_type j = 0; j < n; ++j) { + B(i, j) = A(idx, j); } } } - } else { - printf("Error: ApplyPermutation A extent(0) does not match to P extent(0)\n"); } - return 0; + } else { + printf("Error: ApplyPermutation A extent(0) does not match to P extent(0)\n"); } + return 0; + } - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeB &B) { + template + KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeB &B) { #if defined(__CUDA_ARCH__) - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - if (n == 1) { /// vector - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&](const ordinal_type &i) { - const ordinal_type idx = P(i); - B(i,0) = A(idx,0); - }); - } else { /// matrix - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, m), - [&](const ordinal_type &i) { - const ordinal_type idx = P(i); - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, n), - [&](const ordinal_type &j) { - B(i,j) = A(idx,j); - }); - }); - } + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + if (n == 1) { /// vector + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const ordinal_type &i) { + const ordinal_type idx = P(i); + B(i, 0) = A(idx, 0); + }); + } else { /// matrix + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m * n), [&](const ordinal_type &ij) { + const ordinal_type i = ij % m, j = ij / m; + const ordinal_type idx = P(i); + B(i, j) = A(idx, j); + }); } - } else { - printf("Error: ApplyPermutation A extent(0) does not match to P extent(0)\n"); } + } else { + printf("Error: ApplyPermutation A extent(0) does not match to P extent(0)\n"); + } #else - invoke(A, B, P); + invoke(A, P, B); #endif - return 0; - } - - - }; - + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_OnDevice.hpp index 086868ce2513..ae146902d15b 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPermutation_OnDevice.hpp @@ -1,67 +1,48 @@ #ifndef __TACHO_APPLY_PERMUTATION_ON_DEVICE_HPP__ #define __TACHO_APPLY_PERMUTATION_ON_DEVICE_HPP__ - /// \file Tacho_ApplyPivots_OnDevice.hpp /// \brief Apply pivots on device /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template <> - struct ApplyPermutation { - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeB &B) { - //typedef typename ViewTypeA::non_const_value_type value_type; +template <> struct ApplyPermutation { + template + inline static int invoke(const MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeB &B) { + // typedef typename ViewTypeA::non_const_value_type value_type; + + const ordinal_type m = A.extent(0), n = A.extent(1), plen = P.extent(0); + + if (m == plen) { + if (A.span() > 0) { + using exec_space = MemberType; + const auto &exec_instance = member; - const ordinal_type - m = A.extent(0), - n = A.extent(1), - plen = P.extent(0); - - if (m == plen) { + const Kokkos::RangePolicy policy(exec_instance, 0, m * n); if (A.span() > 0) { - using exec_space = MemberType; - const auto exec_instance = member; - if (n == 1) { - Kokkos::RangePolicy policy(exec_instance, 0, m); - Kokkos::parallel_for - (policy, - KOKKOS_LAMBDA(const ordinal_type &i) { - const ordinal_type idx = P(i); - B(i,0) = A(idx,0); - }); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { + const ordinal_type idx = P(i); + B(i, 0) = A(idx, 0); + }); } else { - using policy_type = Kokkos::TeamPolicy; - policy_type policy(exec_instance, m, Kokkos::AUTO); - Kokkos::parallel_for - (policy, - KOKKOS_LAMBDA(const typename policy_type::member_type &member) { - const ordinal_type i = member.league_rank(); - const ordinal_type idx = P(i); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [&](const ordinal_type &j) { - B(i,j) = A(idx,j); - }); - }); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &ij) { + const ordinal_type i = ij % m, j = ij / m; + const ordinal_type idx = P(i); + B(i, j) = A(idx, j); + }); } } - } else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - return 0; + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots.hpp index ad2f4fe73b4b..8fcd521d7b01 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots.hpp @@ -2,20 +2,20 @@ #define __TACHO_APPLY_PIVOTS_HPP__ /// \file Tacho_ApplyPivots.hpp -/// \brief Front interface to apply pivots +/// \brief Front interface to apply pivots /// \author Kyungjoo Kim (kyukim@sandia.gov) #include "Tacho_Util.hpp" namespace Tacho { - /// - /// Apply Pivots - /// +/// +/// Apply Pivots +/// - /// various implementation for different uplo and algo parameters - template - struct ApplyPivots; -} +/// various implementation for different uplo and algo parameters +template struct ApplyPivots; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_Internal.hpp index 3e8f51f887ed..1e87e8985d48 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_Internal.hpp @@ -1,168 +1,139 @@ #ifndef __TACHO_APPLY_PIVOTS_INTERNAL_HPP__ #define __TACHO_APPLY_PIVOTS_INTERNAL_HPP__ - /// \file Tacho_ApplyPivots_Internal.hpp /// \brief Apply pivots /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - /// row exchange - template<> - struct ApplyPivots { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(const ViewTypeP &P, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; - - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - for (ordinal_type i=0;i struct ApplyPivots { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ViewTypeP &P, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type piv = P(i); + if (piv == 0) { + /// no pivot + } else { + /// 1x1 pivot + /// 2x2 pivots are already converted to 1x1 type + const ordinal_type p = i + piv; + for (ordinal_type j = 0; j < n; ++j) { + const value_type tmp = A(i, j); + A(i, j) = A(p, j); + A(p, j) = tmp; + } + } + } + } + } else { + printf("Error: ApplyPivots A is not square\n"); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeA &A) { +#if defined(__CUDA_ARCH__) + typedef typename ViewTypeA::non_const_value_type value_type; + + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), [&](const ordinal_type &j) { + for (ordinal_type i = 0; i < m; ++i) { const ordinal_type piv = P(i); if (piv == 0) { /// no pivot } else { /// 1x1 pivot - /// 2x2 pivots are already converted to 1x1 type const ordinal_type p = i + piv; - for (ordinal_type j=0;j A is not square\n"); + }); } - return 0; + } else { + printf("Error: ApplyPivots A is not square\n"); } - - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeA &A) { -#if defined(__CUDA_ARCH__) - typedef typename ViewTypeA::non_const_value_type value_type; - - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [&](const ordinal_type &j) { - for (ordinal_type i=0;i A is not square\n"); - } #else - invoke(P, A); + invoke(P, A); #endif - return 0; + return 0; + } +}; + +/// row exchange +template <> struct ApplyPivots { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ViewTypeP &P, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + for (ordinal_type i = (m - 1); i >= 0; --i) { + const ordinal_type piv = P(i); + if (piv == 0) { + /// no pivot + } else { + /// 1x1 pivot + /// 2x2 pivots are already converted to 1x1 type + const ordinal_type p = i + piv; + for (ordinal_type j = 0; j < n; ++j) { + const value_type tmp = A(i, j); + A(i, j) = A(p, j); + A(p, j) = tmp; + } + } + } + } + } else { + printf("Error: ApplyPivots A is not square\n"); } - }; - - /// row exchange - template<> - struct ApplyPivots { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(const ViewTypeP &P, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; - - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - for (ordinal_type i=(m-1);i>=0;--i) { + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeA &A) { +#if defined(__CUDA_ARCH__) + typedef typename ViewTypeA::non_const_value_type value_type; + + if (A.extent(0) == P.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), [&](const ordinal_type &j) { + for (ordinal_type i = (m - 1); i >= 0; --i) { const ordinal_type piv = P(i); if (piv == 0) { /// no pivot } else { /// 1x1 pivot - /// 2x2 pivots are already converted to 1x1 type const ordinal_type p = i + piv; - for (ordinal_type j=0;j A is not square\n"); + }); } - return 0; + } else { + printf("Error: ApplyPivots A is not square\n"); } - - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeA &A) { -#if defined(__CUDA_ARCH__) - typedef typename ViewTypeA::non_const_value_type value_type; - - if (A.extent(0) == P.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [&](const ordinal_type &j) { - for (ordinal_type i=(m-1);i>=0;--i) { - const ordinal_type piv = P(i); - if (piv == 0) { - /// no pivot - } else { - /// 1x1 pivot - const ordinal_type p = i + piv; - const value_type tmp = A(i,j); - A(i,j) = A(p,j); - A(p,j) = tmp; - } - } - }); - } - } else { - printf("Error: ApplyPivots A is not square\n"); - } #else - invoke(P, A); + invoke(P, A); #endif - return 0; - } - }; - - + return 0; + } +}; // template<> // struct ApplyPivots { @@ -172,8 +143,8 @@ namespace Tacho { // static int // invoke(const ViewTypeP &P, // const ViewTypeA &A) { -// typedef typename ViewTypeA::non_const_value_type value_type; - +// typedef typename ViewTypeA::non_const_value_type value_type; + // if (A.extent(0) == P.extent(0)) { // if (A.span() > 0) { // const ordinal_type m = A.extent(0), n = A.extent(1); @@ -198,7 +169,7 @@ namespace Tacho { // A(i+1,j) = A(p+1,j); // A(p, j) = tmp_a; // A(p+1,j) = tmp_b; -// } +// } // } // } // } @@ -208,7 +179,6 @@ namespace Tacho { // return 0; // } - // template @@ -217,9 +187,9 @@ namespace Tacho { // invoke(MemberType &member, // const ViewTypeP &P, // const ViewTypeA &A) { -// #if defined(__CUDA_ARCH__) -// typedef typename ViewTypeA::non_const_value_type value_type; - +// #if defined(__CUDA_ARCH__) +// typedef typename ViewTypeA::non_const_value_type value_type; + // if (A.extent(0) == P.extent(0)) { // if (A.span() > 0) { // const ordinal_type m = A.extent(0), n = A.extent(1); @@ -258,8 +228,5 @@ namespace Tacho { // } // }; - - - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_OnDevice.hpp index 4dba8d398851..6f7c04f3a0aa 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ApplyPivots_OnDevice.hpp @@ -1,163 +1,130 @@ #ifndef __TACHO_APPLY_PIVOTS_ON_DEVICE_HPP__ #define __TACHO_APPLY_PIVOTS_ON_DEVICE_HPP__ - /// \file Tacho_ApplyPivots_OnDevice.hpp /// \brief Apply pivots on device /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template<> - struct ApplyPivots { - template - inline - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; +template <> struct ApplyPivots { + template + inline static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m == P.extent(0)) { + if (A.span() > 0) { + using exec_space = MemberType; + using policy_type = Kokkos::RangePolicy; - const ordinal_type - m = A.extent(0), - n = A.extent(1); - - if (m == P.extent(0)) { - if (A.span() > 0) { - using exec_space = MemberType; - using policy_type = Kokkos::RangePolicy; - - const auto exec_instance = member; - const auto policy = policy_type(exec_instance, 0, n); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &j) { - for (ordinal_type i=0;i 0) { /// 1x1 pivot const ordinal_type p = piv - 1; - const value_type tmp = A(i,j); - A(i,j) = A(p,j); - A(p,j) = tmp; + const value_type tmp = A(i, j); + A(i, j) = A(p, j); + A(p, j) = tmp; } else { /// 2x2 pivot const int p = -piv - 1; - const value_type tmp_a = A(i,j), tmp_b = A(i+1,j); - A(i ,j) = A(p ,j); - A(i+1,j) = A(p+1,j); - A(p, j) = tmp_a; - A(p+1,j) = tmp_b; + const value_type tmp_a = A(i, j), tmp_b = A(i + 1, j); + A(i, j) = A(p, j); + A(i + 1, j) = A(p + 1, j); + A(p, j) = tmp_a; + A(p + 1, j) = tmp_b; } } - }); - } - } else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - return 0; + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - }; + return 0; + } +}; + +template <> struct ApplyPivots { + template + inline static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + const ordinal_type m = A.extent(0), n = A.extent(1), plen = P.extent(0); - template<> - struct ApplyPivots { - template - inline - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; + if (m == plen) { + if (A.span() > 0) { + using exec_space = MemberType; + using policy_type = Kokkos::RangePolicy; - const ordinal_type - m = A.extent(0), - n = A.extent(1), - plen = P.extent(0); - - if (m == plen) { - if (A.span() > 0) { - using exec_space = MemberType; - using policy_type = Kokkos::RangePolicy; - - const auto exec_instance = member; - const auto policy = policy_type(exec_instance, 0, n); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &j) { - for (ordinal_type i=0;i struct ApplyPivots { + template + inline static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; - template<> - struct ApplyPivots { - template - inline - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0), n = A.extent(1), plen = P.extent(0); - const ordinal_type - m = A.extent(0), - n = A.extent(1), - plen = P.extent(0); - - if (m == plen) { - if (A.span() > 0) { - using exec_space = MemberType; - using policy_type = Kokkos::RangePolicy; - - const auto exec_instance = member; - const auto policy = policy_type(exec_instance, 0, n); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &j) { - for (ordinal_type i=(m-1);i>=0;--i) { + if (m == plen) { + if (A.span() > 0) { + using exec_space = MemberType; + using policy_type = Kokkos::RangePolicy; + + const auto exec_instance = member; + const auto policy = policy_type(exec_instance, 0, n); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &j) { + for (ordinal_type i = (m - 1); i >= 0; --i) { const ordinal_type piv = P(i); if (piv == 0) { /// no pivot } else { /// 1x1 pivot const ordinal_type p = i + piv; - const value_type tmp = A(i,j); - A(i,j) = A(p,j); - A(p,j) = tmp; - } + const value_type tmp = A(i, j); + A(i, j) = A(p, j); + A(p, j) = tmp; + } } - }); - } - } else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A and P dimension does not match"); } - return 0; + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A and P dimension does not match"); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.cpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.cpp index f9d8466c119d..cc636f5cde39 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.cpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.cpp @@ -2,1103 +2,835 @@ /// \brief BLAS wrapper /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Tacho_config.h" #include "Tacho_Blas_External.hpp" +#include "Tacho_config.h" #include "Kokkos_Core.hpp" extern "C" { - /// - /// Gemv - /// - - void F77_BLAS_MANGLE(sgemv,SGEMV)( const char*, - int*, int*, - const float*, - const float*, int*, - const float*, int*, - const float*, - /* */ float*, int* ); - void F77_BLAS_MANGLE(dgemv,DGEMV)( const char*, - int*, int*, - const double*, - const double*, int*, - const double*, int*, - const double*, - /* */ double*, int* ); - void F77_BLAS_MANGLE(cgemv,CGEMV)( const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - void F77_BLAS_MANGLE(zgemv,ZGEMV)( const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - - /// - /// Trsv - /// - - void F77_BLAS_MANGLE(strsv,STRSV)( const char*, const char*, const char*, - int*, - const float*, int*, - /* */ float*, int* ); - void F77_BLAS_MANGLE(dtrsv,DTRSV)( const char*, const char*, const char*, - int*, - const double*, int*, - /* */ double*, int* ); - void F77_BLAS_MANGLE(ctrsv,CTRSV)( const char*, const char*, const char*, - int*, - const Kokkos::complex*, int*, - /* */ Kokkos::complex*, int* ); - void F77_BLAS_MANGLE(ztrsv,ZTRSV)( const char*, const char*, const char*, - int*, - const Kokkos::complex*, int*, - /* */ Kokkos::complex*, int* ); - - /// - /// Gemm - /// - - void F77_BLAS_MANGLE(sgemm,SGEMM)( const char*, const char*, - int*, int*, int*, - const float*, - const float*, int*, - const float*, int*, - const float*, - /* */ float*, int* ); - void F77_BLAS_MANGLE(dgemm,DGEMM)( const char*, const char*, - int*, int*, int*, - const double*, - const double*, int*, - const double*, int*, - const double*, - /* */ double*, int* ); - void F77_BLAS_MANGLE(cgemm,CGEMM)( const char*, const char*, - int*, int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - void F77_BLAS_MANGLE(zgemm,ZGEMM)( const char*, const char*, - int*, int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - - /// - /// Herk - /// - - void F77_BLAS_MANGLE(ssyrk,SSYRK)( const char*, const char*, - int*, int*, - const float*, - const float*, int*, - const float*, - /* */ float*, int* ); - void F77_BLAS_MANGLE(dsyrk,DSYRK)( const char*, const char*, - int*, int*, - const double*, - const double*, int*, - const double*, - /* */ double*, int* ); - void F77_BLAS_MANGLE(cherk,CHERK)( const char*, const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - void F77_BLAS_MANGLE(zherk,ZHERK)( const char*, const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - const Kokkos::complex*, - /* */ Kokkos::complex*, int* ); - - /// - /// Trsm - /// - - void F77_BLAS_MANGLE(strsm,STRSM)( const char*, const char*, const char*, const char*, - int*, int*, - const float*, - const float*, int*, - /* */ float*, int* ); - void F77_BLAS_MANGLE(dtrsm,DTRSM)( const char*, const char*, const char*, const char*, - int*, int*, - const double*, - const double*, int*, - /* */ double*, int* ); - void F77_BLAS_MANGLE(ctrsm,CTRSM)( const char*, const char*, const char*, const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - /* */ Kokkos::complex*, int* ); - void F77_BLAS_MANGLE(ztrsm,ZTRSM)( const char*, const char*, const char*, const char*, - int*, int*, - const Kokkos::complex*, - const Kokkos::complex*, int*, - /* */ Kokkos::complex*, int* ); +/// +/// Gemv +/// + +void F77_BLAS_MANGLE(sgemv, SGEMV)(const char *, int *, int *, const float *, const float *, int *, const float *, + int *, const float *, + /* */ float *, int *); +void F77_BLAS_MANGLE(dgemv, DGEMV)(const char *, int *, int *, const double *, const double *, int *, const double *, + int *, const double *, + /* */ double *, int *); +void F77_BLAS_MANGLE(cgemv, CGEMV)(const char *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, int *, + const Kokkos::complex *, + /* */ Kokkos::complex *, int *); +void F77_BLAS_MANGLE(zgemv, ZGEMV)(const char *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, int *, + const Kokkos::complex *, + /* */ Kokkos::complex *, int *); + +/// +/// Trsv +/// + +void F77_BLAS_MANGLE(strsv, STRSV)(const char *, const char *, const char *, int *, const float *, int *, + /* */ float *, int *); +void F77_BLAS_MANGLE(dtrsv, DTRSV)(const char *, const char *, const char *, int *, const double *, int *, + /* */ double *, int *); +void F77_BLAS_MANGLE(ctrsv, CTRSV)(const char *, const char *, const char *, int *, const Kokkos::complex *, + int *, + /* */ Kokkos::complex *, int *); +void F77_BLAS_MANGLE(ztrsv, ZTRSV)(const char *, const char *, const char *, int *, const Kokkos::complex *, + int *, + /* */ Kokkos::complex *, int *); + +/// +/// Gemm +/// + +void F77_BLAS_MANGLE(sgemm, SGEMM)(const char *, const char *, int *, int *, int *, const float *, const float *, int *, + const float *, int *, const float *, + /* */ float *, int *); +void F77_BLAS_MANGLE(dgemm, DGEMM)(const char *, const char *, int *, int *, int *, const double *, const double *, + int *, const double *, int *, const double *, + /* */ double *, int *); +void F77_BLAS_MANGLE(cgemm, CGEMM)(const char *, const char *, int *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, int *, + const Kokkos::complex *, + /* */ Kokkos::complex *, int *); +void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char *, const char *, int *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, int *, + const Kokkos::complex *, + /* */ Kokkos::complex *, int *); + +/// +/// Herk +/// + +void F77_BLAS_MANGLE(ssyrk, SSYRK)(const char *, const char *, int *, int *, const float *, const float *, int *, + const float *, + /* */ float *, int *); +void F77_BLAS_MANGLE(dsyrk, DSYRK)(const char *, const char *, int *, int *, const double *, const double *, int *, + const double *, + /* */ double *, int *); +void F77_BLAS_MANGLE(cherk, CHERK)(const char *, const char *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, + /* */ Kokkos::complex *, int *); +void F77_BLAS_MANGLE(zherk, ZHERK)(const char *, const char *, int *, int *, const Kokkos::complex *, + const Kokkos::complex *, int *, const Kokkos::complex *, + /* */ Kokkos::complex *, int *); + +/// +/// Trsm +/// + +void F77_BLAS_MANGLE(strsm, STRSM)(const char *, const char *, const char *, const char *, int *, int *, const float *, + const float *, int *, + /* */ float *, int *); +void F77_BLAS_MANGLE(dtrsm, DTRSM)(const char *, const char *, const char *, const char *, int *, int *, const double *, + const double *, int *, + /* */ double *, int *); +void F77_BLAS_MANGLE(ctrsm, CTRSM)(const char *, const char *, const char *, const char *, int *, int *, + const Kokkos::complex *, const Kokkos::complex *, int *, + /* */ Kokkos::complex *, int *); +void F77_BLAS_MANGLE(ztrsm, ZTRSM)(const char *, const char *, const char *, const char *, int *, int *, + const Kokkos::complex *, const Kokkos::complex *, int *, + /* */ Kokkos::complex *, int *); } -#define F77_FUNC_SGEMV F77_BLAS_MANGLE(sgemv,SGEMV) -#define F77_FUNC_DGEMV F77_BLAS_MANGLE(dgemv,DGEMV) -#define F77_FUNC_CGEMV F77_BLAS_MANGLE(cgemv,CGEMV) -#define F77_FUNC_ZGEMV F77_BLAS_MANGLE(zgemv,ZGEMV) +#define F77_FUNC_SGEMV F77_BLAS_MANGLE(sgemv, SGEMV) +#define F77_FUNC_DGEMV F77_BLAS_MANGLE(dgemv, DGEMV) +#define F77_FUNC_CGEMV F77_BLAS_MANGLE(cgemv, CGEMV) +#define F77_FUNC_ZGEMV F77_BLAS_MANGLE(zgemv, ZGEMV) -#define F77_FUNC_STRSV F77_BLAS_MANGLE(strsv,STRSV) -#define F77_FUNC_DTRSV F77_BLAS_MANGLE(dtrsv,DTRSV) -#define F77_FUNC_CTRSV F77_BLAS_MANGLE(ctrsv,CTRSV) -#define F77_FUNC_ZTRSV F77_BLAS_MANGLE(ztrsv,ZTRSV) +#define F77_FUNC_STRSV F77_BLAS_MANGLE(strsv, STRSV) +#define F77_FUNC_DTRSV F77_BLAS_MANGLE(dtrsv, DTRSV) +#define F77_FUNC_CTRSV F77_BLAS_MANGLE(ctrsv, CTRSV) +#define F77_FUNC_ZTRSV F77_BLAS_MANGLE(ztrsv, ZTRSV) -#define F77_FUNC_SGEMM F77_BLAS_MANGLE(sgemm,SGEMM) -#define F77_FUNC_DGEMM F77_BLAS_MANGLE(dgemm,DGEMM) -#define F77_FUNC_CGEMM F77_BLAS_MANGLE(cgemm,CGEMM) -#define F77_FUNC_ZGEMM F77_BLAS_MANGLE(zgemm,ZGEMM) +#define F77_FUNC_SGEMM F77_BLAS_MANGLE(sgemm, SGEMM) +#define F77_FUNC_DGEMM F77_BLAS_MANGLE(dgemm, DGEMM) +#define F77_FUNC_CGEMM F77_BLAS_MANGLE(cgemm, CGEMM) +#define F77_FUNC_ZGEMM F77_BLAS_MANGLE(zgemm, ZGEMM) -#define F77_FUNC_SSYRK F77_BLAS_MANGLE(ssyrk,SSYRK) -#define F77_FUNC_DSYRK F77_BLAS_MANGLE(dsyrk,DSYRK) -#define F77_FUNC_CHERK F77_BLAS_MANGLE(cherk,CHERK) -#define F77_FUNC_ZHERK F77_BLAS_MANGLE(zherk,ZHERK) +#define F77_FUNC_SSYRK F77_BLAS_MANGLE(ssyrk, SSYRK) +#define F77_FUNC_DSYRK F77_BLAS_MANGLE(dsyrk, DSYRK) +#define F77_FUNC_CHERK F77_BLAS_MANGLE(cherk, CHERK) +#define F77_FUNC_ZHERK F77_BLAS_MANGLE(zherk, ZHERK) -#define F77_FUNC_STRSM F77_BLAS_MANGLE(strsm,STRSM) -#define F77_FUNC_DTRSM F77_BLAS_MANGLE(dtrsm,DTRSM) -#define F77_FUNC_CTRSM F77_BLAS_MANGLE(ctrsm,CTRSM) -#define F77_FUNC_ZTRSM F77_BLAS_MANGLE(ztrsm,ZTRSM) +#define F77_FUNC_STRSM F77_BLAS_MANGLE(strsm, STRSM) +#define F77_FUNC_DTRSM F77_BLAS_MANGLE(dtrsm, DTRSM) +#define F77_FUNC_CTRSM F77_BLAS_MANGLE(ctrsm, CTRSM) +#define F77_FUNC_ZTRSM F77_BLAS_MANGLE(ztrsm, ZTRSM) namespace Tacho { - /// - /// float - /// +/// +/// float +/// - template<> - int - Blas::gemv(const char trans, - int m, int n, - const float alpha, - const float *a, int lda, - const float *b, int ldb, - const float beta, - /* */ float *c, int ldc) { - F77_FUNC_SGEMV(&trans, - &m, &n, - &alpha, - a, &lda, - b, &ldb, - &beta, - c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::gemv(cublasHandle_t handle, - const cublasOperation_t trans, - int m, int n, - const float alpha, - const float *a, int lda, - const float *b, int ldb, - const float beta, - /* */ float *c, int ldc) { - const int r_val = cublasSgemv(handle, - trans, - m, n, - &alpha, - a, lda, - b, ldb, - &beta, - c, ldc); - return r_val; - } +template <> +int Blas::gemv(const char trans, int m, int n, const float alpha, const float *a, int lda, const float *b, + int ldb, const float beta, + /* */ float *c, int ldc) { + F77_FUNC_SGEMV(&trans, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::gemv(cublasHandle_t handle, const cublasOperation_t trans, int m, int n, const float alpha, + const float *a, int lda, const float *b, int ldb, const float beta, + /* */ float *c, int ldc) { + const int r_val = cublasSgemv(handle, trans, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} #endif - - template<> - int - Blas::trsv(const char uplo, const char transa, const char diag, - int m, - const float *a, int lda, - /* */ float *b, int ldb) { - F77_FUNC_STRSV(&uplo, &transa, &diag, - &m, - a, &lda, - b, &ldb); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::trsv(cublasHandle_t handle, - const cublasFillMode_t uplo, - const cublasOperation_t transa, - const cublasDiagType_t diag, - int m, - const float *a, int lda, - /* */ float *b, int ldb) { - const int r_val = cublasStrsv(handle, - uplo, transa, diag, - m, - a, lda, - b, ldb); - return r_val; - } +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::gemv(hipblasHandle_t handle, const hipblasOperation_t trans, int m, int n, const float alpha, + const float *a, int lda, const float *b, int ldb, const float beta, + /* */ float *c, int ldc) { + const int r_val = hipblasSgemv(handle, trans, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} #endif - template<> - int - Blas::gemm(const char transa, const char transb, - int m, int n, int k, - const float alpha, - const float *a, int lda, - const float *b, int ldb, - const float beta, - /* */ float *c, int ldc) { - F77_FUNC_SGEMM(&transa, &transb, - &m, &n, &k, - &alpha, - a, &lda, - b, &ldb, - &beta, - c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::gemm(cublasHandle_t handle, - const cublasOperation_t transa, - const cublasOperation_t transb, - int m, int n, int k, - const float alpha, - const float *a, int lda, - const float *b, int ldb, - const float beta, - /* */ float *c, int ldc) { - const int r_val = cublasSgemm(handle, - transa, transb, - m, n, k, - &alpha, - a, lda, - b, ldb, - &beta, - c, ldc); - return r_val; - } +template <> +int Blas::trsv(const char uplo, const char transa, const char diag, int m, const float *a, int lda, + /* */ float *b, int ldb) { + F77_FUNC_STRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::trsv(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t transa, + const cublasDiagType_t diag, int m, const float *a, int lda, + /* */ float *b, int ldb) { + const int r_val = cublasStrsv(handle, uplo, transa, diag, m, a, lda, b, ldb); + return r_val; +} #endif - - template<> - int - Blas::herk(const char uplo, const char trans, - int n, int k, - const float alpha, - const float *a, int lda, - const float beta, - /* */ float *c, int ldc) { - F77_FUNC_SSYRK(&uplo, &trans, - &n, &k, - &alpha, - a, &lda, - &beta, - c, &ldc); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas::herk(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t trans, - int n, int k, - const float alpha, - const float *a, int lda, - const float beta, - /* */ float *c, int ldc) { - const int r_val = cublasSsyrk(handle, - uplo, trans, - n, k, - &alpha, - a, lda, - &beta, - c, ldc); - return r_val; - } +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::trsv(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t transa, + const hipblasDiagType_t diag, int m, const float *a, int lda, + /* */ float *b, int ldb) { + const int r_val = hipblasStrsv(handle, uplo, transa, diag, m, a, lda, b, ldb); + return r_val; +} #endif - template<> - int - Blas::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const float alpha, - const float *a, int lda, - /* */ float *b, int ldb) { - F77_FUNC_STRSM(&side, &uplo, &transa, &diag, - &m, &n, - &alpha, - a, &lda, - b, &ldb); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas::trsm(cublasHandle_t handle, - const cublasSideMode_t side, const cublasFillMode_t uplo, - const cublasOperation_t transa, const cublasDiagType_t diag, - int m, int n, - const float alpha, - const float *a, int lda, - /* */ float *b, int ldb) { - const int r_val = cublasStrsm(handle, - side, uplo, transa, diag, - m, n, - &alpha, - a, lda, - b, ldb); - return r_val; - } +template <> +int Blas::gemm(const char transa, const char transb, int m, int n, int k, const float alpha, const float *a, + int lda, const float *b, int ldb, const float beta, + /* */ float *c, int ldc) { + F77_FUNC_SGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::gemm(cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, + int n, int k, const float alpha, const float *a, int lda, const float *b, int ldb, + const float beta, + /* */ float *c, int ldc) { + const int r_val = cublasSgemm(handle, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::gemm(hipblasHandle_t handle, const hipblasOperation_t transa, const hipblasOperation_t transb, int m, + int n, int k, const float alpha, const float *a, int lda, const float *b, int ldb, + const float beta, + /* */ float *c, int ldc) { + const int r_val = hipblasSgemm(handle, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} #endif - /// - /// double - /// - - template<> - int - Blas::gemv(const char trans, - int m, int n, - const double alpha, - const double *a, int lda, - const double *b, int ldb, - const double beta, - /* */ double *c, int ldc) { - F77_FUNC_DGEMV(&trans, - &m, &n, - &alpha, - a, &lda, - b, &ldb, - &beta, - c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::gemv(cublasHandle_t handle, - const cublasOperation_t trans, - int m, int n, - const double alpha, - const double *a, int lda, - const double *b, int ldb, - const double beta, - /* */ double *c, int ldc) { - const int r_val = cublasDgemv(handle, - trans, - m, n, - &alpha, - a, lda, - b, ldb, - &beta, - c, ldc); - return r_val; - } +template <> +int Blas::herk(const char uplo, const char trans, int n, int k, const float alpha, const float *a, int lda, + const float beta, + /* */ float *c, int ldc) { + F77_FUNC_SSYRK(&uplo, &trans, &n, &k, &alpha, a, &lda, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::herk(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t trans, int n, int k, + const float alpha, const float *a, int lda, const float beta, + /* */ float *c, int ldc) { + const int r_val = cublasSsyrk(handle, uplo, trans, n, k, &alpha, a, lda, &beta, c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::herk(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t trans, int n, + int k, const float alpha, const float *a, int lda, const float beta, + /* */ float *c, int ldc) { + const int r_val = hipblasSsyrk(handle, uplo, trans, n, k, &alpha, a, lda, &beta, c, ldc); + return r_val; +} #endif - template<> - int - Blas::trsv(const char uplo, const char transa, const char diag, - int m, - const double *a, int lda, - /* */ double *b, int ldb) { - F77_FUNC_DTRSV(&uplo, &transa, &diag, - &m, - a, &lda, - b, &ldb); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::trsv(cublasHandle_t handle, - const cublasFillMode_t uplo, - const cublasOperation_t transa, - const cublasDiagType_t diag, - int m, - const double *a, int lda, - /* */ double *b, int ldb) { - const int r_val = cublasDtrsv(handle, - uplo, transa, diag, - m, - a, lda, - b, ldb); - return r_val; - } +template <> +int Blas::trsm(const char side, const char uplo, const char transa, const char diag, int m, int n, + const float alpha, const float *a, int lda, + /* */ float *b, int ldb) { + F77_FUNC_STRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::trsm(cublasHandle_t handle, const cublasSideMode_t side, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, int n, const float alpha, + const float *a, int lda, + /* */ float *b, int ldb) { + const int r_val = cublasStrsm(handle, side, uplo, transa, diag, m, n, &alpha, a, lda, b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::trsm(hipblasHandle_t handle, const hipblasSideMode_t side, const hipblasFillMode_t uplo, + const hipblasOperation_t transa, const hipblasDiagType_t diag, int m, int n, const float alpha, + const float *a, int lda, + /* */ float *b, int ldb) { + const int r_val = hipblasStrsm(handle, side, uplo, transa, diag, m, n, &alpha, a, lda, b, ldb); + return r_val; +} #endif - template<> - int - Blas::gemm(const char transa, const char transb, - int m, int n, int k, - const double alpha, - const double *a, int lda, - const double *b, int ldb, - const double beta, - /* */ double *c, int ldc) { - F77_FUNC_DGEMM(&transa, &transb, - &m, &n, &k, - &alpha, - a, &lda, - b, &ldb, - &beta, - c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas::gemm(cublasHandle_t handle, - const cublasOperation_t transa, - const cublasOperation_t transb, - int m, int n, int k, - const double alpha, - const double *a, int lda, - const double *b, int ldb, - const double beta, - /* */ double *c, int ldc) { - const int r_val = cublasDgemm(handle, - transa, transb, - m, n, k, - &alpha, - a, lda, - b, ldb, - &beta, - c, ldc); - return r_val; - } +/// +/// double +/// + +template <> +int Blas::gemv(const char trans, int m, int n, const double alpha, const double *a, int lda, const double *b, + int ldb, const double beta, + /* */ double *c, int ldc) { + F77_FUNC_DGEMV(&trans, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::gemv(cublasHandle_t handle, const cublasOperation_t trans, int m, int n, const double alpha, + const double *a, int lda, const double *b, int ldb, const double beta, + /* */ double *c, int ldc) { + const int r_val = cublasDgemv(handle, trans, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::gemv(hipblasHandle_t handle, const hipblasOperation_t trans, int m, int n, const double alpha, + const double *a, int lda, const double *b, int ldb, const double beta, + /* */ double *c, int ldc) { + const int r_val = hipblasDgemv(handle, trans, m, n, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} #endif - template<> - int - Blas::herk(const char uplo, const char trans, - int n, int k, - const double alpha, - const double *a, int lda, - const double beta, - /* */ double *c, int ldc) { - F77_FUNC_DSYRK(&uplo, &trans, - &n, &k, - &alpha, - a, &lda, - &beta, - c, &ldc); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas::herk(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t trans, - int n, int k, - const double alpha, - const double *a, int lda, - const double beta, - /* */ double *c, int ldc) { - const int r_val = cublasDsyrk(handle, - uplo, trans, - n, k, - &alpha, - a, lda, - &beta, - c, ldc); - return r_val; - } +template <> +int Blas::trsv(const char uplo, const char transa, const char diag, int m, const double *a, int lda, + /* */ double *b, int ldb) { + F77_FUNC_DTRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::trsv(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t transa, + const cublasDiagType_t diag, int m, const double *a, int lda, + /* */ double *b, int ldb) { + const int r_val = cublasDtrsv(handle, uplo, transa, diag, m, a, lda, b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::trsv(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t transa, + const hipblasDiagType_t diag, int m, const double *a, int lda, + /* */ double *b, int ldb) { + const int r_val = hipblasDtrsv(handle, uplo, transa, diag, m, a, lda, b, ldb); + return r_val; +} #endif - template<> - int - Blas::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const double alpha, - const double *a, int lda, - /* */ double *b, int ldb) { - F77_FUNC_DTRSM(&side, &uplo, &transa, &diag, - &m, &n, - &alpha, - a, &lda, - b, &ldb); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas::trsm(cublasHandle_t handle, - const cublasSideMode_t side, const cublasFillMode_t uplo, - const cublasOperation_t transa, const cublasDiagType_t diag, - int m, int n, - const double alpha, - const double *a, int lda, - /* */ double *b, int ldb) { - const int r_val = cublasDtrsm(handle, - side, uplo, transa, diag, - m, n, - &alpha, - a, lda, - b, ldb); - return r_val; - } +template <> +int Blas::gemm(const char transa, const char transb, int m, int n, int k, const double alpha, const double *a, + int lda, const double *b, int ldb, const double beta, + /* */ double *c, int ldc) { + F77_FUNC_DGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::gemm(cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, + int n, int k, const double alpha, const double *a, int lda, const double *b, int ldb, + const double beta, + /* */ double *c, int ldc) { + const int r_val = cublasDgemm(handle, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::gemm(hipblasHandle_t handle, const hipblasOperation_t transa, const hipblasOperation_t transb, int m, + int n, int k, const double alpha, const double *a, int lda, const double *b, int ldb, + const double beta, + /* */ double *c, int ldc) { + const int r_val = hipblasDgemm(handle, transa, transb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc); + return r_val; +} #endif - /// - /// Kokkos::complex - /// - - template<> - int - Blas >::gemv(const char trans, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_CGEMV(&trans, - &m, &n, - &alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::gemv(cublasHandle_t handle, - const cublasOperation_t trans, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasCgemv(handle, - trans, - m, n, - (const cuComplex*)&alpha, - (const cuComplex*)a, lda, - (const cuComplex*)b, ldb, - (const cuComplex*)&beta, - (cuComplex*)c, ldc); - return r_val; - } +template <> +int Blas::herk(const char uplo, const char trans, int n, int k, const double alpha, const double *a, int lda, + const double beta, + /* */ double *c, int ldc) { + F77_FUNC_DSYRK(&uplo, &trans, &n, &k, &alpha, a, &lda, &beta, c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::herk(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t trans, int n, int k, + const double alpha, const double *a, int lda, const double beta, + /* */ double *c, int ldc) { + const int r_val = cublasDsyrk(handle, uplo, trans, n, k, &alpha, a, lda, &beta, c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::herk(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t trans, int n, + int k, const double alpha, const double *a, int lda, const double beta, + /* */ double *c, int ldc) { + const int r_val = hipblasDsyrk(handle, uplo, trans, n, k, &alpha, a, lda, &beta, c, ldc); + return r_val; +} #endif - template<> - int - Blas >::trsv(const char uplo, const char transa, const char diag, - int m, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - F77_FUNC_CTRSV(&uplo, &transa, &diag, - &m, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::trsv(cublasHandle_t handle, - const cublasFillMode_t uplo, - const cublasOperation_t transa, - const cublasDiagType_t diag, - int m, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - const int r_val = cublasCtrsv(handle, - uplo, transa, diag, - m, - (const cuComplex*)a, lda, - (cuComplex*)b, ldb); - return r_val; - } +template <> +int Blas::trsm(const char side, const char uplo, const char transa, const char diag, int m, int n, + const double alpha, const double *a, int lda, + /* */ double *b, int ldb) { + F77_FUNC_DTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, a, &lda, b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas::trsm(cublasHandle_t handle, const cublasSideMode_t side, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, int n, const double alpha, + const double *a, int lda, + /* */ double *b, int ldb) { + const int r_val = cublasDtrsm(handle, side, uplo, transa, diag, m, n, &alpha, a, lda, b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas::trsm(hipblasHandle_t handle, const hipblasSideMode_t side, const hipblasFillMode_t uplo, + const hipblasOperation_t transa, const hipblasDiagType_t diag, int m, int n, const double alpha, + const double *a, int lda, + /* */ double *b, int ldb) { + const int r_val = hipblasDtrsm(handle, side, uplo, transa, diag, m, n, &alpha, a, lda, b, ldb); + return r_val; +} #endif - template<> - int - Blas >::gemm(const char transa, const char transb, - int m, int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_CGEMM(&transa, &transb, - &m, &n, &k, - &alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::gemm(cublasHandle_t handle, - const cublasOperation_t transa, - const cublasOperation_t transb, - int m, int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasCgemm(handle, - transa, transb, - m, n, k, - (const cuComplex*)&alpha, - (const cuComplex*)a, lda, - (const cuComplex*)b, ldb, - (const cuComplex*)&beta, - (cuComplex*)c, ldc); - return r_val; - } +/// +/// Kokkos::complex +/// + +template <> +int Blas>::gemv(const char trans, int m, int n, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, const Kokkos::complex *b, + int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_CGEMV(&trans, &m, &n, &alpha, (const Kokkos::complex *)a, &lda, (const Kokkos::complex *)b, + &ldb, &beta, (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::gemv(cublasHandle_t handle, const cublasOperation_t trans, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = cublasCgemv(handle, trans, m, n, (const cuComplex *)&alpha, (const cuComplex *)a, lda, + (const cuComplex *)b, ldb, (const cuComplex *)&beta, (cuComplex *)c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::gemv(hipblasHandle_t handle, const hipblasOperation_t trans, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + hipblasCgemv(handle, trans, m, n, (const hipFloatComplex *)&alpha, (const hipFloatComplex *)a, lda, + (const hipFloatComplex *)b, ldb, (const hipFloatComplex *)&beta, (hipFloatComplex *)c, ldc); + return r_val; +} #endif - template<> - int - Blas >::herk(const char uplo, const char trans, - int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_CHERK(&uplo, &trans, - &n, &k, - &alpha, - (const Kokkos::complex*)a, &lda, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::herk(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t trans, - int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasCherk(handle, - uplo, trans, - n, k, - (const float*)&alpha, - (const cuComplex*)a, lda, - (const float*)&beta, - (cuComplex*)c, ldc); - return r_val; - } +template <> +int Blas>::trsv(const char uplo, const char transa, const char diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + F77_FUNC_CTRSV(&uplo, &transa, &diag, &m, (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::trsv(cublasHandle_t handle, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = cublasCtrsv(handle, uplo, transa, diag, m, (const cuComplex *)a, lda, (cuComplex *)b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::trsv(hipblasHandle_t handle, const hipblasFillMode_t uplo, + const hipblasOperation_t transa, const hipblasDiagType_t diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = + hipblasCtrsv(handle, uplo, transa, diag, m, (const hipFloatComplex *)a, lda, (hipFloatComplex *)b, ldb); + return r_val; +} #endif - template<> - int - Blas >::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - F77_FUNC_CTRSM(&side, &uplo, &transa, &diag, - &m, &n, - &alpha, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::trsm(cublasHandle_t handle, - const cublasSideMode_t side, const cublasFillMode_t uplo, - const cublasOperation_t transa, const cublasDiagType_t diag, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - const int r_val = cublasCtrsm(handle, - side, uplo, transa, diag, - m, n, - (const cuComplex*)&alpha, - (const cuComplex*)a, lda, - (cuComplex*)b, ldb); - return r_val; - } +template <> +int Blas>::gemm(const char transa, const char transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_CGEMM(&transa, &transb, &m, &n, &k, &alpha, (const Kokkos::complex *)a, &lda, + (const Kokkos::complex *)b, &ldb, &beta, (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::gemm(cublasHandle_t handle, const cublasOperation_t transa, + const cublasOperation_t transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = cublasCgemm(handle, transa, transb, m, n, k, (const cuComplex *)&alpha, (const cuComplex *)a, lda, + (const cuComplex *)b, ldb, (const cuComplex *)&beta, (cuComplex *)c, ldc); + return r_val; +} #endif - - /// - /// Kokkos::complex - /// - - template<> - int - Blas >::gemv(const char trans, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_ZGEMV(&trans, - &m, &n, - &alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::gemv(cublasHandle_t handle, - const cublasOperation_t trans, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasZgemv(handle, - trans, - m, n, - (const cuDoubleComplex*)&alpha, - (const cuDoubleComplex*)a, lda, - (const cuDoubleComplex*)b, ldb, - (const cuDoubleComplex*)&beta, - (cuDoubleComplex*)c, ldc); - return r_val; - } +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::gemm(hipblasHandle_t handle, const hipblasOperation_t transa, + const hipblasOperation_t transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + hipblasCgemm(handle, transa, transb, m, n, k, (const hipFloatComplex *)&alpha, (const hipFloatComplex *)a, lda, + (const hipFloatComplex *)b, ldb, (const hipFloatComplex *)&beta, (hipFloatComplex *)c, ldc); + return r_val; +} #endif - template<> - int - Blas >::trsv(const char uplo, const char transa, const char diag, - int m, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - F77_FUNC_ZTRSV(&uplo, &transa, &diag, - &m, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::trsv(cublasHandle_t handle, - const cublasFillMode_t uplo, - const cublasOperation_t transa, - const cublasDiagType_t diag, - int m, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - const int r_val = cublasZtrsv(handle, - uplo, transa, diag, - m, - (const cuDoubleComplex*)a, lda, - (cuDoubleComplex*)b, ldb); - return r_val; - } +template <> +int Blas>::herk(const char uplo, const char trans, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_CHERK(&uplo, &trans, &n, &k, &alpha, (const Kokkos::complex *)a, &lda, &beta, + (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::herk(cublasHandle_t handle, const cublasFillMode_t uplo, + const cublasOperation_t trans, int n, int k, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = cublasCherk(handle, uplo, trans, n, k, (const float *)&alpha, (const cuComplex *)a, lda, + (const float *)&beta, (cuComplex *)c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::herk(hipblasHandle_t handle, const hipblasFillMode_t uplo, + const hipblasOperation_t trans, int n, int k, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = hipblasCherk(handle, uplo, trans, n, k, (const float *)&alpha, (const hipFloatComplex *)a, lda, + (const float *)&beta, (hipFloatComplex *)c, ldc); + return r_val; +} #endif - template<> - int - Blas >::gemm(const char transa, const char transb, - int m, int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_ZGEMM(&transa, &transb, - &m, &n, &k, - &alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined (KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::gemm(cublasHandle_t handle, - const cublasOperation_t transa, - const cublasOperation_t transb, - int m, int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex *b, int ldb, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasZgemm(handle, - transa, transb, - m, n, k, - (const cuDoubleComplex*)&alpha, - (const cuDoubleComplex*)a, lda, - (const cuDoubleComplex*)b, ldb, - (const cuDoubleComplex*)&beta, - (cuDoubleComplex*)c, ldc); - return r_val; - } +template <> +int Blas>::trsm(const char side, const char uplo, const char transa, const char diag, int m, + int n, const Kokkos::complex alpha, const Kokkos::complex *a, + int lda, + /* */ Kokkos::complex *b, int ldb) { + F77_FUNC_CTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const Kokkos::complex *)a, &lda, + (Kokkos::complex *)b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::trsm(cublasHandle_t handle, const cublasSideMode_t side, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = cublasCtrsm(handle, side, uplo, transa, diag, m, n, (const cuComplex *)&alpha, (const cuComplex *)a, + lda, (cuComplex *)b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::trsm(hipblasHandle_t handle, const hipblasSideMode_t side, + const hipblasFillMode_t uplo, const hipblasOperation_t transa, + const hipblasDiagType_t diag, int m, int n, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = hipblasCtrsm(handle, side, uplo, transa, diag, m, n, (const hipFloatComplex *)&alpha, + (const hipFloatComplex *)a, lda, (hipFloatComplex *)b, ldb); + return r_val; +} #endif - template<> - int - Blas >::herk(const char uplo, const char trans, - int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - F77_FUNC_ZHERK(&uplo, &trans, - &n, &k, - &alpha, - (const Kokkos::complex*)a, &lda, - &beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::herk(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t trans, - int n, int k, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - const Kokkos::complex beta, - /* */ Kokkos::complex *c, int ldc) { - const int r_val = cublasZherk(handle, - uplo, trans, - n, k, - (const double*)&alpha, - (const cuDoubleComplex*)a, lda, - (const double*)&beta, - (cuDoubleComplex*)c, ldc); - return r_val; - } +/// +/// Kokkos::complex +/// + +template <> +int Blas>::gemv(const char trans, int m, int n, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, const Kokkos::complex *b, + int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_ZGEMV(&trans, &m, &n, &alpha, (const Kokkos::complex *)a, &lda, (const Kokkos::complex *)b, + &ldb, &beta, (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::gemv(cublasHandle_t handle, const cublasOperation_t trans, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + cublasZgemv(handle, trans, m, n, (const cuDoubleComplex *)&alpha, (const cuDoubleComplex *)a, lda, + (const cuDoubleComplex *)b, ldb, (const cuDoubleComplex *)&beta, (cuDoubleComplex *)c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::gemv(hipblasHandle_t handle, const hipblasOperation_t trans, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + hipblasZgemv(handle, trans, m, n, (const hipDoubleComplex *)&alpha, (const hipDoubleComplex *)a, lda, + (const hipDoubleComplex *)b, ldb, (const hipDoubleComplex *)&beta, (hipDoubleComplex *)c, ldc); + return r_val; +} #endif - template<> - int - Blas >::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - F77_FUNC_ZTRSM(&side, &uplo, &transa, &diag, - &m, &n, - &alpha, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Blas >::trsm(cublasHandle_t handle, - const cublasSideMode_t side, const cublasFillMode_t uplo, - const cublasOperation_t transa, const cublasDiagType_t diag, - int m, int n, - const Kokkos::complex alpha, - const Kokkos::complex *a, int lda, - /* */ Kokkos::complex *b, int ldb) { - const int r_val = cublasZtrsm(handle, - side, uplo, transa, diag, - m, n, - (const cuDoubleComplex*)&alpha, - (const cuDoubleComplex*)a, lda, - (cuDoubleComplex*)b, ldb); - return r_val; - } +template <> +int Blas>::trsv(const char uplo, const char transa, const char diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + F77_FUNC_ZTRSV(&uplo, &transa, &diag, &m, (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, + &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::trsv(cublasHandle_t handle, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = + cublasZtrsv(handle, uplo, transa, diag, m, (const cuDoubleComplex *)a, lda, (cuDoubleComplex *)b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::trsv(hipblasHandle_t handle, const hipblasFillMode_t uplo, + const hipblasOperation_t transa, const hipblasDiagType_t diag, int m, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = + hipblasZtrsv(handle, uplo, transa, diag, m, (const hipDoubleComplex *)a, lda, (hipDoubleComplex *)b, ldb); + return r_val; +} #endif - /// - /// std::complex - /// +template <> +int Blas>::gemm(const char transa, const char transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_ZGEMM(&transa, &transb, &m, &n, &k, &alpha, (const Kokkos::complex *)a, &lda, + (const Kokkos::complex *)b, &ldb, &beta, (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::gemm(cublasHandle_t handle, const cublasOperation_t transa, + const cublasOperation_t transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + cublasZgemm(handle, transa, transb, m, n, k, (const cuDoubleComplex *)&alpha, (const cuDoubleComplex *)a, lda, + (const cuDoubleComplex *)b, ldb, (const cuDoubleComplex *)&beta, (cuDoubleComplex *)c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::gemm(hipblasHandle_t handle, const hipblasOperation_t transa, + const hipblasOperation_t transb, int m, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex *b, int ldb, const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = + hipblasZgemm(handle, transa, transb, m, n, k, (const hipDoubleComplex *)&alpha, (const hipDoubleComplex *)a, lda, + (const hipDoubleComplex *)b, ldb, (const hipDoubleComplex *)&beta, (hipDoubleComplex *)c, ldc); + return r_val; +} +#endif - template<> - int - Blas >::gemv(const char trans, - int m, int n, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex *b, int ldb, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_CGEMV(&trans, - &m, &n, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::trsv(const char uplo, const char transa, const char diag, - int m, - const std::complex *a, int lda, - /* */ std::complex *b, int ldb) { - F77_FUNC_CTRSV(&uplo, &transa, &diag, - &m, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } - template<> - int - Blas >::gemm(const char transa, const char transb, - int m, int n, int k, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex *b, int ldb, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_CGEMM(&transa, &transb, - &m, &n, &k, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::herk(const char transa, const char transb, - int n, int k, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_CHERK(&transa, &transb, - &n, &k, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const std::complex alpha, - const std::complex *a, int lda, - /* */ std::complex *b, int ldb) { - F77_FUNC_CTRSM(&side, &uplo, &transa, &diag, - &m, &n, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } +template <> +int Blas>::herk(const char uplo, const char trans, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + F77_FUNC_ZHERK(&uplo, &trans, &n, &k, &alpha, (const Kokkos::complex *)a, &lda, &beta, + (Kokkos::complex *)c, &ldc); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::herk(cublasHandle_t handle, const cublasFillMode_t uplo, + const cublasOperation_t trans, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = cublasZherk(handle, uplo, trans, n, k, (const double *)&alpha, (const cuDoubleComplex *)a, lda, + (const double *)&beta, (cuDoubleComplex *)c, ldc); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::herk(hipblasHandle_t handle, const hipblasFillMode_t uplo, + const hipblasOperation_t trans, int n, int k, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + const Kokkos::complex beta, + /* */ Kokkos::complex *c, int ldc) { + const int r_val = hipblasZherk(handle, uplo, trans, n, k, (const double *)&alpha, (const hipDoubleComplex *)a, lda, + (const double *)&beta, (hipDoubleComplex *)c, ldc); + return r_val; +} +#endif - /// - /// std::complex - /// - - template<> - int - Blas >::gemv(const char trans, - int m, int n, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex *b, int ldb, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_ZGEMV(&trans, - &m, &n, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::trsv(const char uplo, const char transa, const char diag, - int m, - const std::complex *a, int lda, - /* */ std::complex *b, int ldb) { - F77_FUNC_ZTRSV(&uplo, &transa, &diag, - &m, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } - template<> - int - Blas >::gemm(const char transa, const char transb, - int m, int n, int k, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex *b, int ldb, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_ZGEMM(&transa, &transb, - &m, &n, &k, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)b, &ldb, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::herk(const char transa, const char transb, - int n, int k, - const std::complex alpha, - const std::complex *a, int lda, - const std::complex beta, - /* */ std::complex *c, int ldc) { - F77_FUNC_ZHERK(&transa, &transb, - &n, &k, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - (const Kokkos::complex*)&beta, - ( Kokkos::complex*)c, &ldc); - return 0; - } - template<> - int - Blas >::trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const std::complex alpha, - const std::complex *a, int lda, - /* */ std::complex *b, int ldb) { - F77_FUNC_ZTRSM(&side, &uplo, &transa, &diag, - &m, &n, - (const Kokkos::complex*)&alpha, - (const Kokkos::complex*)a, &lda, - ( Kokkos::complex*)b, &ldb); - return 0; - } +template <> +int Blas>::trsm(const char side, const char uplo, const char transa, const char diag, int m, + int n, const Kokkos::complex alpha, const Kokkos::complex *a, + int lda, + /* */ Kokkos::complex *b, int ldb) { + F77_FUNC_ZTRSM(&side, &uplo, &transa, &diag, &m, &n, &alpha, (const Kokkos::complex *)a, &lda, + (Kokkos::complex *)b, &ldb); + return 0; +} +#if defined(TACHO_ENABLE_CUBLAS) +template <> +int Blas>::trsm(cublasHandle_t handle, const cublasSideMode_t side, const cublasFillMode_t uplo, + const cublasOperation_t transa, const cublasDiagType_t diag, int m, int n, + const Kokkos::complex alpha, const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = cublasZtrsm(handle, side, uplo, transa, diag, m, n, (const cuDoubleComplex *)&alpha, + (const cuDoubleComplex *)a, lda, (cuDoubleComplex *)b, ldb); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPBLAS) +template <> +int Blas>::trsm(hipblasHandle_t handle, const hipblasSideMode_t side, + const hipblasFillMode_t uplo, const hipblasOperation_t transa, + const hipblasDiagType_t diag, int m, int n, const Kokkos::complex alpha, + const Kokkos::complex *a, int lda, + /* */ Kokkos::complex *b, int ldb) { + const int r_val = hipblasZtrsm(handle, side, uplo, transa, diag, m, n, (const hipDoubleComplex *)&alpha, + (const hipDoubleComplex *)a, lda, (hipDoubleComplex *)b, ldb); + return r_val; +} +#endif +/// +/// std::complex +/// + +template <> +int Blas>::gemv(const char trans, int m, int n, const std::complex alpha, + const std::complex *a, int lda, const std::complex *b, int ldb, + const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_CGEMV(&trans, &m, &n, (const Kokkos::complex *)&alpha, (const Kokkos::complex *)a, &lda, + (const Kokkos::complex *)b, &ldb, (const Kokkos::complex *)&beta, + (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::trsv(const char uplo, const char transa, const char diag, int m, + const std::complex *a, int lda, + /* */ std::complex *b, int ldb) { + F77_FUNC_CTRSV(&uplo, &transa, &diag, &m, (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, &ldb); + return 0; +} +template <> +int Blas>::gemm(const char transa, const char transb, int m, int n, int k, + const std::complex alpha, const std::complex *a, int lda, + const std::complex *b, int ldb, const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_CGEMM(&transa, &transb, &m, &n, &k, (const Kokkos::complex *)&alpha, + (const Kokkos::complex *)a, &lda, (const Kokkos::complex *)b, &ldb, + (const Kokkos::complex *)&beta, (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::herk(const char transa, const char transb, int n, int k, const std::complex alpha, + const std::complex *a, int lda, const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_CHERK(&transa, &transb, &n, &k, (const Kokkos::complex *)&alpha, (const Kokkos::complex *)a, + &lda, (const Kokkos::complex *)&beta, (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::trsm(const char side, const char uplo, const char transa, const char diag, int m, int n, + const std::complex alpha, const std::complex *a, int lda, + /* */ std::complex *b, int ldb) { + F77_FUNC_CTRSM(&side, &uplo, &transa, &diag, &m, &n, (const Kokkos::complex *)&alpha, + (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, &ldb); + return 0; } +/// +/// std::complex +/// + +template <> +int Blas>::gemv(const char trans, int m, int n, const std::complex alpha, + const std::complex *a, int lda, const std::complex *b, int ldb, + const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_ZGEMV(&trans, &m, &n, (const Kokkos::complex *)&alpha, (const Kokkos::complex *)a, &lda, + (const Kokkos::complex *)b, &ldb, (const Kokkos::complex *)&beta, + (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::trsv(const char uplo, const char transa, const char diag, int m, + const std::complex *a, int lda, + /* */ std::complex *b, int ldb) { + F77_FUNC_ZTRSV(&uplo, &transa, &diag, &m, (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, + &ldb); + return 0; +} +template <> +int Blas>::gemm(const char transa, const char transb, int m, int n, int k, + const std::complex alpha, const std::complex *a, int lda, + const std::complex *b, int ldb, const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_ZGEMM(&transa, &transb, &m, &n, &k, (const Kokkos::complex *)&alpha, + (const Kokkos::complex *)a, &lda, (const Kokkos::complex *)b, &ldb, + (const Kokkos::complex *)&beta, (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::herk(const char transa, const char transb, int n, int k, + const std::complex alpha, const std::complex *a, int lda, + const std::complex beta, + /* */ std::complex *c, int ldc) { + F77_FUNC_ZHERK(&transa, &transb, &n, &k, (const Kokkos::complex *)&alpha, (const Kokkos::complex *)a, + &lda, (const Kokkos::complex *)&beta, (Kokkos::complex *)c, &ldc); + return 0; +} +template <> +int Blas>::trsm(const char side, const char uplo, const char transa, const char diag, int m, int n, + const std::complex alpha, const std::complex *a, int lda, + /* */ std::complex *b, int ldb) { + F77_FUNC_ZTRSM(&side, &uplo, &transa, &diag, &m, &n, (const Kokkos::complex *)&alpha, + (const Kokkos::complex *)a, &lda, (Kokkos::complex *)b, &ldb); + return 0; +} +} // namespace Tacho diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.hpp index 3b456d78118c..1b10b48daa16 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_External.hpp @@ -5,106 +5,97 @@ /// \brief BLAS wrapper /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Kokkos_Core.hpp" // CUDA specialization +#include "Kokkos_Core.hpp" -#if defined(KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) +#define TACHO_ENABLE_CUBLAS +#endif + +#if defined(KOKKOS_ENABLE_HIP) +// todo: enable hipblas interface after checking on AMD machine +//#define TACHO_ENABLE_HIPBLAS +#endif + +#if defined(TACHO_ENABLE_CUBLAS) #include "cublas_v2.h" #endif +#if defined(TACHO_ENABLE_HIPBLAS) +#include "hipblas.h" +#endif namespace Tacho { - template - struct Blas { - static - int gemv(const char trans, - int m, int n, - const T alpha, - const T *a, int lda, - const T *b, int ldb, - const T beta, - /* */ T *c, int ldc); -#if defined (KOKKOS_ENABLE_CUDA) - static - int gemv(cublasHandle_t handle, - const cublasOperation_t trans, - int m, int n, - const T alpha, - const T *a, int lda, - const T *b, int ldb, - const T beta, - /* */ T *c, int ldc); +template struct Blas { + static int gemv(const char trans, int m, int n, const T alpha, const T *a, int lda, const T *b, int ldb, const T beta, + /* */ T *c, int ldc); +#if defined(TACHO_ENABLE_CUBLAS) + static int gemv(cublasHandle_t handle, const cublasOperation_t trans, int m, int n, const T alpha, const T *a, + int lda, const T *b, int ldb, const T beta, + /* */ T *c, int ldc); +#endif +#if defined(TACHO_ENABLE_HIPBLAS) + static int gemv(hipblasHandle_t handle, const hipblasOperation_t trans, int m, int n, const T alpha, const T *a, + int lda, const T *b, int ldb, const T beta, + /* */ T *c, int ldc); #endif - static - int trsv(const char uplo, const char transa, const char diag, - int m, - const T *a, int lda, - /* */ T *b, int ldb); -#if defined (KOKKOS_ENABLE_CUDA) - static - int trsv(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t transa, const cublasDiagType_t diag, - int m, - const T *a, int lda, - /* */ T *b, int ldb); + static int trsv(const char uplo, const char transa, const char diag, int m, const T *a, int lda, + /* */ T *b, int ldb); +#if defined(TACHO_ENABLE_CUBLAS) + static int trsv(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t transa, + const cublasDiagType_t diag, int m, const T *a, int lda, + /* */ T *b, int ldb); +#endif +#if defined(TACHO_ENABLE_HIPBLAS) + static int trsv(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t transa, + const hipblasDiagType_t diag, int m, const T *a, int lda, + /* */ T *b, int ldb); #endif - static - int gemm(const char transa, const char transb, - int m, int n, int k, - const T alpha, - const T *a, int lda, - const T *b, int ldb, - const T beta, - /* */ T *c, int ldc); -#if defined (KOKKOS_ENABLE_CUDA) - static - int gemm(const cublasHandle_t handle, - const cublasOperation_t transa, const cublasOperation_t transb, - int m, int n, int k, - const T alpha, - const T *a, int lda, - const T *b, int ldb, - const T beta, - /* */ T *c, int ldc); + static int gemm(const char transa, const char transb, int m, int n, int k, const T alpha, const T *a, int lda, + const T *b, int ldb, const T beta, + /* */ T *c, int ldc); +#if defined(TACHO_ENABLE_CUBLAS) + static int gemm(const cublasHandle_t handle, const cublasOperation_t transa, const cublasOperation_t transb, int m, + int n, int k, const T alpha, const T *a, int lda, const T *b, int ldb, const T beta, + /* */ T *c, int ldc); +#endif +#if defined(TACHO_ENABLE_HIPBLAS) + static int gemm(const hipblasHandle_t handle, const hipblasOperation_t transa, const hipblasOperation_t transb, int m, + int n, int k, const T alpha, const T *a, int lda, const T *b, int ldb, const T beta, + /* */ T *c, int ldc); #endif - static - int herk(const char uplo, const char trans, - int n, int k, - const T alpha, - const T *a, int lda, - const T beta, - /* */ T *c, int ldc); -#if defined (KOKKOS_ENABLE_CUDA) - static - int herk(cublasHandle_t handle, - const cublasFillMode_t uplo, const cublasOperation_t trans, - int n, int k, - const T alpha, - const T *a, int lda, - const T beta, - /* */ T *c, int ldc); + static int herk(const char uplo, const char trans, int n, int k, const T alpha, const T *a, int lda, const T beta, + /* */ T *c, int ldc); +#if defined(TACHO_ENABLE_CUBLAS) + static int herk(cublasHandle_t handle, const cublasFillMode_t uplo, const cublasOperation_t trans, int n, int k, + const T alpha, const T *a, int lda, const T beta, + /* */ T *c, int ldc); +#endif +#if defined(TACHO_ENABLE_HIPBLAS) + static int herk(hipblasHandle_t handle, const hipblasFillMode_t uplo, const hipblasOperation_t trans, int n, int k, + const T alpha, const T *a, int lda, const T beta, + /* */ T *c, int ldc); #endif - static - int trsm(const char side, const char uplo, const char transa, const char diag, - int m, int n, - const T alpha, - const T *a, int lda, - /* */ T *b, int ldb); -#if defined (KOKKOS_ENABLE_CUDA) - static - int trsm(cublasHandle_t handle, - const cublasSideMode_t side, const cublasFillMode_t uplo, - const cublasOperation_t trans, const cublasDiagType_t diag, - int m, int n, - const T alpha, - const T *a, int lda, - /* */ T *b, int ldb); + static int trsm(const char side, const char uplo, const char transa, const char diag, int m, int n, const T alpha, + const T *a, int lda, + /* */ T *b, int ldb); +#if defined(TACHO_ENABLE_CUBLAS) + static int trsm(cublasHandle_t handle, const cublasSideMode_t side, const cublasFillMode_t uplo, + const cublasOperation_t trans, const cublasDiagType_t diag, int m, int n, const T alpha, const T *a, + int lda, + /* */ T *b, int ldb); +#endif +#if defined(TACHO_ENABLE_HIPBLAS) + static int trsm(hipblasHandle_t handle, const hipblasSideMode_t side, const hipblasFillMode_t uplo, + const hipblasOperation_t trans, const hipblasDiagType_t diag, int m, int n, const T alpha, const T *a, + int lda, + /* */ T *b, int ldb); #endif - }; +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Team.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Team.hpp index c37c949f2ffe..90dffed312f8 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Team.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Team.hpp @@ -9,1078 +9,812 @@ namespace Tacho { - template - struct BlasTeam { - struct Impl { - template - static - KOKKOS_INLINE_FUNCTION - void set(MemberType &member, - int m, - const T alpha, - /* */ T *__restrict__ a, int as0) { - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) { - a[i*as0] = alpha; +template struct BlasTeam { + struct Impl { + template + static KOKKOS_INLINE_FUNCTION void set(MemberType &member, int m, const T alpha, + /* */ T *__restrict__ a, int as0) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { a[i * as0] = alpha; }); + } + + template + static KOKKOS_INLINE_FUNCTION void scale(MemberType &member, int m, const T alpha, + /* */ T *__restrict__ a, int as0) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { a[i * as0] *= alpha; }); + } + + template + static KOKKOS_INLINE_FUNCTION void set(MemberType &member, int m, int n, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + if (as0 == 1 || as0 < as1) + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const int &i) { a[i * as0 + j * as1] = alpha; }); + }); + else + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), + [&](const int &j) { a[i * as0 + j * as1] = alpha; }); + }); + } + + template + static KOKKOS_INLINE_FUNCTION void scale(MemberType &member, int m, int n, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + if (as0 == 1 || as0 < as1) + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const int &i) { a[i * as0 + j * as1] *= alpha; }); + }); + else + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), + [&](const int &j) { a[i * as0 + j * as1] *= alpha; }); + }); + } + + template + static KOKKOS_INLINE_FUNCTION void set_upper(MemberType &member, int m, int n, int offset, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1 - offset), + [&](const int &i) { a[i * as0 + j * as1] = alpha; }); + }); + } + + template + static KOKKOS_INLINE_FUNCTION void scale_upper(MemberType &member, int m, int n, int offset, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1 - offset), + [&](const int &i) { a[i * as0 + j * as1] *= alpha; }); + }); + } + + template + static KOKKOS_INLINE_FUNCTION void set_lower(MemberType &member, int m, int n, int offset, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + const int jj = j + offset; + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, n - j - offset), + [&, alpha, as0, as1, j, jj](const int &i) { /// compiler bug with c++14 lambda capturing and workaround + a[(i + jj) * as0 + j * as1] = alpha; }); - } - - template - static - KOKKOS_INLINE_FUNCTION - void scale(MemberType &member, - int m, - const T alpha, - /* */ T *__restrict__ a, int as0) { - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) { - a[i*as0] *= alpha; + }); + } + + template + static KOKKOS_INLINE_FUNCTION void scale_lower(MemberType &member, int m, int n, int offset, const T alpha, + /* */ T *__restrict__ a, int as0, int as1) { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + const int jj = j + offset; + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, n - j - offset), + [&, alpha, as0, as1, j, jj](const int &i) { /// compiler bug with c++14 lambda capturing and workaround + a[(i + jj) * as0 + j * as1] *= alpha; + }); + }); + } + + template + static KOKKOS_INLINE_FUNCTION void gemv(MemberType &member, const ConjType &cj, const int m, const int n, + const T alpha, const T *__restrict__ A, const int as0, const int as1, + const T *__restrict__ x, const int xs0, const T beta, + /* */ T *__restrict__ y, const int ys0) { + const T one(1), zero(0); + + if (beta == zero) + set(member, m, zero, y, ys0); + if (beta != one) + scale(member, m, beta, y, ys0); + + if (alpha != zero) { + if (m <= 0 || n <= 0) + return; + + member.team_barrier(); + { + if (as0 == 1 || as0 < as1) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), [&](const int &i) { + T t(0); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), + [&](const int &j) { t += cj(A[i * as0 + j * as1]) * x[j * xs0]; }); + Kokkos::atomic_add(&y[i * ys0], alpha * t); + }); + } else { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) { + T t(0); + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n), + [&](const int &j) { t += cj(A[i * as0 + j * as1]) * x[j * xs0]; }); + Kokkos::atomic_add(&y[i * ys0], alpha * t); }); + } } - - template - static - KOKKOS_INLINE_FUNCTION - void set(MemberType &member, - int m, int n, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - if (as0 == 1 || as0 < as1) - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - a[i*as0+j*as1] = alpha; - }); - }); - else - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,m),[&](const int &i) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n),[&](const int &j) { - a[i*as0+j*as1] = alpha; - }); - }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void trsv_upper(MemberType &member, const ConjType &cjA, const char diag, const int m, + const T *__restrict__ A, const int as0, const int as1, + /* */ T *__restrict__ b, const int bs0) { + if (m <= 0) + return; + + const bool use_unit_diag = diag == 'U' || diag == 'u'; + T *__restrict__ b0 = b; + for (int p = (m - 1); p >= 0; --p) { + const int iend = p; + + const T *__restrict__ a01 = A + p * as1; + /**/ T *__restrict__ beta1 = b + p * bs0; + + /// make sure the previous iteration update is done + member.team_barrier(); + T local_beta1 = *beta1; + if (!use_unit_diag) { + const T alpha11 = cjA(A[p * as0 + p * as1]); + local_beta1 /= alpha11; + /// before modifying beta1 we need make sure + /// that every local_beta1 has the previous beta1 value + member.team_barrier(); + Kokkos::single(Kokkos::PerTeam(member), [&]() { *beta1 = local_beta1; }); } - - template - static - KOKKOS_INLINE_FUNCTION - void scale(MemberType &member, - int m, int n, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - if (as0 == 1 || as0 < as1) - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - a[i*as0+j*as1] *= alpha; - }); - }); - else - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,m),[&](const int &i) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n),[&](const int &j) { - a[i*as0+j*as1] *= alpha; - }); - }); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), + [&](const int &i) { b0[i * bs0] -= cjA(a01[i * as0]) * local_beta1; }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void trsv_lower(MemberType &member, const ConjType &cjA, const char diag, const int m, + const T *__restrict__ A, const int as0, const int as1, + /* */ T *__restrict__ b, const int bs0) { + if (m <= 0) + return; + + const bool use_unit_diag = diag == 'U' || diag == 'u'; + // T *__restrict__ b0 = b; + for (int p = 0; p < m; ++p) { + const int iend = m - p - 1; + + const T *__restrict__ a21 = iend ? A + (p + 1) * as0 + p * as1 : NULL; + + T *__restrict__ beta1 = b + p * bs0, *__restrict__ b2 = iend ? beta1 + bs0 : NULL; + + /// make sure that the previous iteration update is done + member.team_barrier(); + T local_beta1 = *beta1; + if (!use_unit_diag) { + const T alpha11 = A[p * as0 + p * as1]; + local_beta1 /= alpha11; + /// before modifying beta1 we need make sure + /// that every local_beta1 has the previous beta1 value + member.team_barrier(); + Kokkos::single(Kokkos::PerTeam(member), [&]() { *beta1 = local_beta1; }); } - - template - static - KOKKOS_INLINE_FUNCTION - void set_upper(MemberType &member, - int m, int n, int offset, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1-offset),[&](const int &i) { - a[i*as0+j*as1] = alpha; - }); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), + [&](const int &i) { b2[i * bs0] -= cjA(a21[i * as0]) * local_beta1; }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void gemm(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, const int m, + const int n, const int k, const T alpha, const T *__restrict__ A, + const int as0, const int as1, const T *__restrict__ B, const int bs0, + const int bs1, const T beta, + /* */ T *__restrict__ C, const int cs0, const int cs1) { + const T one(1), zero(0); + + if (beta == zero) + set(member, m, n, zero, C, cs0, cs1); + else if (beta != one) + scale(member, m, n, beta, C, cs0, cs1); + + if (alpha != zero) { + if (m <= 0 || n <= 0 || k <= 0) + return; + + member.team_barrier(); + { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), [&](const int &i) { + const T *__restrict__ pA = A + i * as0, *__restrict__ pB = B + j * bs1; + T c(0); + for (int p = 0; p < k; ++p) + c += cjA(pA[p * as1]) * cjB(pB[p * bs0]); + C[i * cs0 + j * cs1] += alpha * c; }); + }); } - - template - static - KOKKOS_INLINE_FUNCTION - void scale_upper(MemberType &member, - int m, int n, int offset, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1-offset),[&](const int &i) { - a[i*as0+j*as1] *= alpha; - }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void gemm_upper(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, + const int m, const int n, const int k, const T alpha, + const T *__restrict__ A, const int as0, const int as1, + const T *__restrict__ B, const int bs0, const int bs1, const T beta, + /* */ T *__restrict__ C, const int cs0, const int cs1) { + const T one(1), zero(0); + + if (beta == zero) + set(member, m, n, zero, C, cs0, cs1); + else if (beta != one) + scale(member, m, n, beta, C, cs0, cs1); + + if (alpha != zero) { + if (m <= 0 || n <= 0 || k <= 0) + return; + + member.team_barrier(); + { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1), [&](const int &i) { + const T *__restrict__ pA = A + i * as0, *__restrict__ pB = B + j * bs1; + T c(0); + for (int p = 0; p < k; ++p) + c += cjA(pA[p * as1]) * cjB(pB[p * bs0]); + C[i * cs0 + j * cs1] += alpha * c; }); + }); } - - - template - static - KOKKOS_INLINE_FUNCTION - void set_lower(MemberType &member, - int m, int n, int offset, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - const int jj = j + offset; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n-j-offset),[&,alpha,as0,as1,j,jj](const int &i) { /// compiler bug with c++14 lambda capturing and workaround - a[(i+jj)*as0+j*as1] = alpha; - }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void herk_upper(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, + const int n, const int k, const T alpha, const T *__restrict__ A, + const int as0, const int as1, const T beta, + /* */ T *__restrict__ C, const int cs0, const int cs1) { + const T one(1), zero(0); + + if (beta == zero) + set_upper(member, n, n, 0, zero, C, cs0, cs1); + else if (beta != one) + scale_upper(member, n, n, 0, beta, C, cs0, cs1); + + if (alpha != zero) { + if (n <= 0 || k <= 0) + return; + + member.team_barrier(); + { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + const T *__restrict__ pA = A + j * as0; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1), [&](const int &i) { + const T *__restrict__ pB = A + i * as0; + T c(0); + for (int p = 0; p < k; ++p) + c += cjA(pA[p * as1]) * cjB(pB[p * as1]); + C[i * cs0 + j * cs1] += alpha * c; }); + }); } - - template - static - KOKKOS_INLINE_FUNCTION - void scale_lower(MemberType &member, - int m, int n, int offset, - const T alpha, - /* */ T *__restrict__ a, int as0, int as1) { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - const int jj = j + offset; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n-j-offset),[&,alpha,as0,as1,j,jj](const int &i) {/// compiler bug with c++14 lambda capturing and workaround - a[(i+jj)*as0+j*as1] *= alpha; - }); + } + } + + template + static KOKKOS_INLINE_FUNCTION void herk_lower(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, + const int n, const int k, const T alpha, const T *__restrict__ A, + const int as0, const int as1, const T beta, + /* */ T *__restrict__ C, const int cs0, const int cs1) { + const T one(1), zero(0); + + if (beta == zero) + set_lower(member, n, n, 0, zero, C, cs0, cs1); + else if (beta != one) + scale_lower(member, n, n, 0, beta, C, cs0, cs1); + + if (alpha != zero) { + if (n <= 0 || k <= 0) + return; + + member.team_barrier(); + { + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, n - j), [&](const int &i) { + const int ii = i + j; + const T *__restrict__ pA = A + j * as0, *__restrict__ pB = A + ii * as0; + T c(0); + for (int p = 0; p < k; ++p) + c += cjA(pA[p * as1]) * cjB(pB[p * as1]); + C[ii * cs0 + j * cs1] += alpha * c; }); + }); } - - template - static - KOKKOS_INLINE_FUNCTION - void gemv(MemberType &member, const ConjType &cj, - const int m, const int n, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - const T *__restrict__ x, const int xs0, - const T beta, - /* */ T *__restrict__ y, const int ys0) { - const T one(1), zero(0); - - if (beta == zero) set (member, m, zero, y, ys0); - if (beta != one ) scale(member, m, beta, y, ys0); - - if (alpha != zero) { - if (m <=0 || n <=0) return; - - member.team_barrier(); - { - if (as0 == 1 || as0 < as1) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - T t(0); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - t += cj(A[i*as0+j*as1])*x[j*xs0]; - }); - Kokkos::atomic_add(&y[i*ys0], alpha*t); - }); - } else { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,m),[&](const int &i) { - T t(0); - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n),[&](const int &j) { - t += cj(A[i*as0+j*as1])*x[j*xs0]; - }); - Kokkos::atomic_add(&y[i*ys0], alpha*t); - }); - } - } - } - } - - template - static - KOKKOS_INLINE_FUNCTION - void trsv_upper(MemberType &member, const ConjType &cjA, - const char diag, - const int m, - const T *__restrict__ A, const int as0, const int as1, - /* */ T *__restrict__ b, const int bs0) { - if (m <= 0) return; - - const bool use_unit_diag = diag == 'U'|| diag == 'u'; - T *__restrict__ b0 = b; - for (int p=(m-1);p>=0;--p) { - const int iend = p; - - const T *__restrict__ a01 = A+p*as1; - /**/ T *__restrict__ beta1 = b+p*bs0; - - /// make sure the previous iteration update is done + } + } + + template + static KOKKOS_INLINE_FUNCTION void trsm_left_lower(MemberType &member, const ConjType &cjA, const char diag, + const int m, const int n, const T alpha, const T *__restrict__ A, + const int as0, const int as1, + /* */ T *__restrict__ B, const int bs0, const int bs1) { + const T one(1), zero(0); + + if (alpha == zero) + set(member, m, n, zero, B, bs0, bs1); + else { + if (alpha != one) + scale(member, m, n, alpha, B, bs0, bs1); + if (m <= 0 || n <= 0) + return; + + const bool use_unit_diag = diag == 'U' || diag == 'u'; + for (int p = 0; p < m; ++p) { + const int iend = m - p - 1, jend = n; + + const T *__restrict__ a21 = iend ? A + (p + 1) * as0 + p * as1 : NULL; + + T *__restrict__ b1t = B + p * bs0, *__restrict__ B2 = iend ? B + (p + 1) * bs0 : NULL; + + member.team_barrier(); + if (!use_unit_diag) { + const T alpha11 = cjA(A[p * as0 + p * as1]); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, jend), [&](const int &j) { b1t[j * bs1] /= alpha11; }); member.team_barrier(); - T local_beta1 = *beta1; - if (!use_unit_diag) { - const T alpha11 = cjA(A[p*as0+p*as1]); - local_beta1 /= alpha11; - /// before modifying beta1 we need make sure - /// that every local_beta1 has the previous beta1 value - member.team_barrier(); - Kokkos::single(Kokkos::PerTeam(member), [&]() { - *beta1 = local_beta1; - }); - } - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,iend),[&](const int &i) { - b0[i*bs0] -= cjA(a01[i*as0]) * local_beta1; - }); } - } - - - template - static - KOKKOS_INLINE_FUNCTION - void trsv_lower(MemberType &member, const ConjType &cjA, - const char diag, - const int m, - const T *__restrict__ A, const int as0, const int as1, - /* */ T *__restrict__ b, const int bs0) { - if (m <= 0) return; - - const bool use_unit_diag = diag == 'U'|| diag == 'u'; - //T *__restrict__ b0 = b; - for (int p=0;p + static KOKKOS_INLINE_FUNCTION void trsm_left_upper(MemberType &member, const ConjType &cjA, const char diag, + const int m, const int n, const T alpha, const T *__restrict__ A, + const int as0, const int as1, + /* */ T *__restrict__ B, const int bs0, const int bs1) { + const T one(1.0), zero(0.0); + + // note that parallel range is different ( m*n vs m-1*n); + if (alpha == zero) + set(member, m, n, zero, B, bs0, bs1); + else { + if (alpha != one) + scale(member, m, n, alpha, B, bs0, bs1); + if (m <= 0 || n <= 0) + return; + + const bool use_unit_diag = diag == 'U' || diag == 'u'; + T *__restrict__ B0 = B; + for (int p = (m - 1); p >= 0; --p) { + const int iend = p, jend = n; + + const T *__restrict__ a01 = A + p * as1; + /**/ T *__restrict__ b1t = B + p * bs0; + + member.team_barrier(); + if (!use_unit_diag) { + const T alpha11 = cjA(A[p * as0 + p * as1]); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, jend), [&](const int &j) { b1t[j * bs1] /= alpha11; }); member.team_barrier(); - T local_beta1 = *beta1; - if (!use_unit_diag) { - const T alpha11 = A[p*as0+p*as1]; - local_beta1 /= alpha11; - /// before modifying beta1 we need make sure - /// that every local_beta1 has the previous beta1 value - member.team_barrier(); - Kokkos::single(Kokkos::PerTeam(member), [&]() { - *beta1 = local_beta1; - }); - } - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,iend),[&](const int &i) { - b2[i*bs0] -= a21[i*as0] * local_beta1; - }); } - } - - template - static - KOKKOS_INLINE_FUNCTION - void gemm(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, - const int m, const int n, const int k, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - const T *__restrict__ B, const int bs0, const int bs1, - const T beta, - /* */ T *__restrict__ C, const int cs0, const int cs1) { - const T one(1), zero(0); - - if (beta == zero) set (member, m, n, zero, C, cs0, cs1); - else if (beta != one ) scale(member, m, n, beta, C, cs0, cs1); - - if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return; - - member.team_barrier(); - { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - const T - *__restrict__ pA = A+i*as0, - *__restrict__ pB = B+j*bs1; - T c(0); - for (int p=0;p - static - KOKKOS_INLINE_FUNCTION - void gemm_upper(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, - const int m, const int n, const int k, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - const T *__restrict__ B, const int bs0, const int bs1, - const T beta, - /* */ T *__restrict__ C, const int cs0, const int cs1) { - const T one(1), zero(0); - - if (beta == zero) set (member, m, n, zero, C, cs0, cs1); - else if (beta != one ) scale(member, m, n, beta, C, cs0, cs1); - - if (alpha != zero) { - if (m <= 0 || n <= 0 || k <= 0) return; - - member.team_barrier(); - { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1),[&](const int &i) { - const T - *__restrict__ pA = A+i*as0, - *__restrict__ pB = B+j*bs1; - T c(0); - for (int p=0;p + static KOKKOS_INLINE_FUNCTION void gemv(MemberType &member, const char trans, const int m, const int n, const T alpha, + const T *__restrict__ a, const int lda, const T *__restrict__ x, const int xs, + const T beta, + /* */ T *__restrict__ y, const int ys) { + switch (trans) { + case 'N': + case 'n': { + const NoConjugate cj; + Impl::gemv(member, cj, m, n, alpha, a, 1, lda, x, xs, beta, y, ys); + break; + } + case 'T': + case 't': { + const NoConjugate cj; + Impl::gemv(member, cj, n, m, alpha, a, lda, 1, x, xs, beta, y, ys); + break; + } + case 'C': + case 'c': { + const Conjugate cj; + Impl::gemv(member, cj, n, m, alpha, a, lda, 1, x, xs, beta, y, ys); + break; + } + default: + Kokkos::abort("Invalid trans character"); + } + } + + template + static KOKKOS_INLINE_FUNCTION void trsv(MemberType &member, const char uplo, const char trans, const char diag, + const int m, const T *__restrict__ a, const int lda, + /* */ T *__restrict__ b, const int bs) { + if (uplo == 'U' || uplo == 'u') { + switch (trans) { + case 'N': + case 'n': { + NoConjugate cjA; + Impl::trsv_upper(member, cjA, diag, m, a, 1, lda, b, bs); + break; + } + case 'T': + case 't': { + NoConjugate cjA; + Impl::trsv_lower(member, cjA, diag, m, a, lda, 1, b, bs); + break; + } + case 'C': + case 'c': { + Conjugate cjA; + Impl::trsv_lower(member, cjA, diag, m, a, lda, 1, b, bs); + break; + } + default: + Kokkos::abort("trans is not valid"); + } + } else if (uplo == 'L' || uplo == 'l') { + switch (trans) { + case 'N': + case 'n': { + NoConjugate cjA; + Impl::trsv_lower(member, cjA, diag, m, a, 1, lda, b, bs); + break; + } + case 'T': + case 't': { + NoConjugate cjA; + Impl::trsv_upper(member, cjA, diag, m, a, lda, 1, b, bs); + break; + } + case 'C': + case 'c': { + Conjugate cjA; + Impl::trsv_upper(member, cjA, diag, m, a, lda, 1, b, bs); + break; + } + default: + Kokkos::abort("trans is not valid"); + } + } + } + + template + static KOKKOS_INLINE_FUNCTION void gemm(MemberType &member, const char transa, const char transb, const int m, + const int n, const int k, const T alpha, const T *__restrict__ a, int lda, + const T *__restrict__ b, int ldb, const T beta, + /* */ T *__restrict__ c, int ldc) { + + if (transa == 'N' || transa == 'n') { + const NoConjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is no trans but transb is not valid"); + } + } else if (transa == 'T' || transa == 't') { + const NoConjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is trans but transb is not valid"); + } + } else if (transa == 'C' || transa == 'c') { + const Conjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is conj trans but transb is not valid"); + } + } else { + Kokkos::abort("transa is not valid"); + } + } + + template + static KOKKOS_INLINE_FUNCTION void gemm_upper(MemberType &member, const char transa, const char transb, const int m, + const int n, const int k, const T alpha, const T *__restrict__ a, + int lda, const T *__restrict__ b, int ldb, const T beta, + /* */ T *__restrict__ c, int ldc) { + + if (transa == 'N' || transa == 'n') { + const NoConjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, 1, lda, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is no trans but transb is not valid"); + } + } else if (transa == 'T' || transa == 't') { + const NoConjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is trans but transb is not valid"); + } + } else if (transa == 'C' || transa == 'c') { + const Conjugate cjA; + switch (transb) { + case 'N': + case 'n': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, 1, ldb, beta, c, 1, ldc); + break; + } + case 'T': + case 't': { + const NoConjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const Conjugate cjB; + Impl::gemm_upper(member, cjA, cjB, m, n, k, alpha, a, lda, 1, b, ldb, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("transa is conj trans but transb is not valid"); + } + } else { + Kokkos::abort("transa is not valid"); + } + } + + template + static KOKKOS_INLINE_FUNCTION void herk(MemberType &member, const char uplo, const char trans, const int n, + const int k, const T alpha, const T *__restrict__ a, const int lda, + const T beta, + /* */ T *__restrict__ c, const int ldc) { + if (uplo == 'U' || uplo == 'u') + switch (trans) { + case 'N': + case 'n': { + const NoConjugate cjA; + const Conjugate cjB; + Impl::herk_upper(member, cjA, cjB, n, k, alpha, a, 1, lda, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const NoConjugate cjA; + const Conjugate cjB; + Impl::herk_upper(member, cjA, cjB, n, k, alpha, a, lda, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("trans is not valid"); + } + else if (uplo == 'L' || uplo == 'l') + switch (trans) { + case 'N': + case 'n': { + const NoConjugate cjA; + const Conjugate cjB; + Impl::herk_lower(member, cjA, cjB, n, k, alpha, a, 1, lda, beta, c, 1, ldc); + break; + } + case 'C': + case 'c': { + const NoConjugate cjA; + const Conjugate cjB; + Impl::herk_lower(member, cjA, cjB, n, k, alpha, a, lda, 1, beta, c, 1, ldc); + break; + } + default: + Kokkos::abort("trans is not valid"); + } + else + Kokkos::abort("uplo is not valid"); + } + + template + static KOKKOS_INLINE_FUNCTION void trsm(MemberType &member, const char side, const char uplo, const char trans, + const char diag, const int m, const int n, const T alpha, + const T *__restrict__ a, const int lda, + /* */ T *__restrict__ b, const int ldb) { + /// + /// side left + /// + if (side == 'L' || side == 'l') { + if (uplo == 'U' || uplo == 'u') { + switch (trans) { + case 'N': + case 'n': { + NoConjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, m, n, alpha, a, 1, lda, b, 1, ldb); + break; } - - - template - static - KOKKOS_INLINE_FUNCTION - void herk_upper(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, - const int n, const int k, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - const T beta, - /* */ T *__restrict__ C, const int cs0, const int cs1) { - const T one(1), zero(0); - - if (beta == zero) set_upper (member, n, n, 0, zero, C, cs0, cs1); - else if (beta != one ) scale_upper(member, n, n, 0, beta, C, cs0, cs1); - - if (alpha != zero) { - if (n <= 0 || k <= 0) return; - - member.team_barrier(); - { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - const T *__restrict__ pA = A+j*as0; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,j+1),[&](const int &i) { - const T *__restrict__ pB = A+i*as0; - T c(0); - for (int p=0;p - static - KOKKOS_INLINE_FUNCTION - void herk_lower(MemberType &member, const ConjTypeA &cjA, const ConjTypeB &cjB, - const int n, const int k, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - const T beta, - /* */ T *__restrict__ C, const int cs0, const int cs1) { - const T one(1), zero(0); - - if (beta == zero) set_lower (member, n, n, 0, zero, C, cs0, cs1); - else if (beta != one ) scale_lower(member, n, n, 0, beta, C, cs0, cs1); - - if (alpha != zero) { - if (n <= 0 || k <= 0) return; - - member.team_barrier(); - { - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n-j),[&](const int &i) { - const int ii = i+j; - const T - *__restrict__ pA = A+j*as0, - *__restrict__ pB = A+ii*as0; - T c(0); - for (int p=0;p - static - KOKKOS_INLINE_FUNCTION - void trsm_left_lower(MemberType &member, const ConjType &cjA, - const char diag, - const int m, const int n, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - /* */ T *__restrict__ B, const int bs0, const int bs1) { - const T one(1), zero(0); - - if (alpha == zero) set (member, m, n, zero, B, bs0, bs1); - else { - if (alpha != one) scale(member, m, n, alpha, B, bs0, bs1); - if (m <= 0 || n <= 0) return; - - const bool use_unit_diag = diag == 'U'|| diag == 'u'; - for (int p=0;p - static - KOKKOS_INLINE_FUNCTION - void trsm_left_upper(MemberType &member, const ConjType &cjA, - const char diag, - const int m, const int n, - const T alpha, - const T *__restrict__ A, const int as0, const int as1, - /* */ T *__restrict__ B, const int bs0, const int bs1) { - const T one(1.0), zero(0.0); - - // note that parallel range is different ( m*n vs m-1*n); - if (alpha == zero) set (member, m, n, zero, B, bs0, bs1); - else { - if (alpha != one) scale(member, m, n, alpha, B, bs0, bs1); - if (m <= 0 || n <= 0) return; - - const bool use_unit_diag = diag == 'U'|| diag == 'u'; - T *__restrict__ B0 = B; - for (int p=(m-1);p>=0;--p) { - const int iend = p, jend = n; - - const T *__restrict__ a01 = A+p*as1; - /**/ T *__restrict__ b1t = B+p*bs0; - - member.team_barrier(); - if (!use_unit_diag) { - const T alpha11 = cjA(A[p*as0+p*as1]); - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,jend),[&](const int &j) { - b1t[j*bs1] /= alpha11; - }); - member.team_barrier(); - } - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,jend),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,iend),[&](const int &i) { - B0[i*bs0+j*bs1] -= cjA(a01[i*as0]) * b1t[j*bs1]; - }); - }); - } - } + default: + Kokkos::abort("trans is not valid"); } - - }; - - template - static - KOKKOS_INLINE_FUNCTION - void gemv(MemberType &member, - const char trans, - const int m, const int n, - const T alpha, - const T *__restrict__ a, const int lda, - const T *__restrict__ x, const int xs, - const T beta, - /* */ T *__restrict__ y, const int ys) { + } else if (uplo == 'L' || uplo == 'l') { switch (trans) { case 'N': case 'n': { - const NoConjugate cj; - Impl::gemv(member, cj, - m, n, - alpha, - a, 1, lda, - x, xs, - beta, - y, ys); + NoConjugate cjA; + Impl::trsm_left_lower(member, cjA, diag, m, n, alpha, a, 1, lda, b, 1, ldb); break; } case 'T': case 't': { - const NoConjugate cj; - Impl::gemv(member, cj, - n, m, - alpha, - a, lda, 1, - x, xs, - beta, - y, ys); + NoConjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); break; } case 'C': case 'c': { - const Conjugate cj; - Impl::gemv(member, cj, - n, m, - alpha, - a, lda, 1, - x, xs, - beta, - y, ys); + Conjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); break; } default: - Kokkos::abort("Invalid trans character"); + Kokkos::abort("trans is not valid"); } } + } - - template - static - KOKKOS_INLINE_FUNCTION - void trsv(MemberType &member, - const char uplo, const char trans, const char diag, - const int m, - const T *__restrict__ a, const int lda, - /* */ T *__restrict__ b, const int bs) { - if (uplo == 'U' || uplo == 'u') { - switch (trans) { - case 'N': - case 'n': { - NoConjugate cjA; - Impl::trsv_upper(member, cjA, diag, - m, - a, 1, lda, - b, bs); - break; - } - case 'T': - case 't': { - NoConjugate cjA; - Impl::trsv_lower(member, cjA, diag, - m, - a, lda, 1, - b, bs); - break; - } - case 'C': - case 'c': { - Conjugate cjA; - Impl::trsv_lower(member, cjA, diag, - m, - a, lda, 1, - b, bs); - break; - } - default: - Kokkos::abort("trans is not valid"); - } - } else if (uplo == 'L' || uplo == 'l') { - switch (trans) { - case 'N': - case 'n': { - NoConjugate cjA; - Impl::trsv_lower(member, cjA, diag, - m, - a, 1, lda, - b, bs); - break; - } - case 'T': - case 't': { - NoConjugate cjA; - Impl::trsv_upper(member, cjA, diag, - m, - a, lda, 1, - b, bs); - break; - } - case 'C': - case 'c': { - Conjugate cjA; - Impl::trsv_upper(member, cjA, diag, - m, - a, lda, 1, - b, bs); - break; - } - default: - Kokkos::abort("trans is not valid"); - } + /// + /// side right + /// + else if (side == 'R' || side == 'r') { + if (uplo == 'U' || uplo == 'u') { + switch (trans) { + case 'N': + case 'n': { + NoConjugate cjA; + Impl::trsm_left_lower(member, cjA, diag, n, m, alpha, a, lda, 1, b, ldb, 1); + break; } - } - - template - static - KOKKOS_INLINE_FUNCTION - void gemm(MemberType &member, - const char transa, const char transb, - const int m, const int n, const int k, - const T alpha, - const T *__restrict__ a, int lda, - const T *__restrict__ b, int ldb, - const T beta, - /* */ T *__restrict__ c, int ldc) { - - if (transa == 'N' || transa == 'n') { - const NoConjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is no trans but transb is not valid"); - } - } else if (transa == 'T' || transa == 't') { - const NoConjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is trans but transb is not valid"); - } - } else if (transa == 'C' || transa == 'c') { - const Conjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is conj trans but transb is not valid"); - } - } else { - Kokkos::abort("transa is not valid"); + case 'T': + case 't': { + Kokkos::abort("no no no"); + NoConjugate cjA; + Impl::trsm_left_lower(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); + break; } - } - - template - static - KOKKOS_INLINE_FUNCTION - void gemm_upper(MemberType &member, - const char transa, const char transb, - const int m, const int n, const int k, - const T alpha, - const T *__restrict__ a, int lda, - const T *__restrict__ b, int ldb, - const T beta, - /* */ T *__restrict__ c, int ldc) { - - if (transa == 'N' || transa == 'n') { - const NoConjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, 1, lda, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is no trans but transb is not valid"); - } - } else if (transa == 'T' || transa == 't') { - const NoConjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is trans but transb is not valid"); - } - } else if (transa == 'C' || transa == 'c') { - const Conjugate cjA; - switch (transb) { - case 'N': - case 'n': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, 1, ldb, - beta, - c, 1, ldc); - break; - } - case 'T': - case 't': { - const NoConjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const Conjugate cjB; - Impl::gemm_upper(member, cjA, cjB, - m, n, k, - alpha, - a, lda, 1, - b, ldb, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("transa is conj trans but transb is not valid"); - } - } else { - Kokkos::abort("transa is not valid"); + case 'C': + case 'c': { + Kokkos::abort("no no no"); + Conjugate cjA; + Impl::trsm_left_lower(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); + break; } - } - - - - template - static - KOKKOS_INLINE_FUNCTION - void herk(MemberType &member, - const char uplo, const char trans, - const int n, const int k, - const T alpha, - const T *__restrict__ a, const int lda, - const T beta, - /* */ T *__restrict__ c, const int ldc) { - if (uplo == 'U' || uplo == 'u') - switch (trans) { - case 'N': - case 'n': { - const NoConjugate cjA; - const Conjugate cjB; - Impl::herk_upper(member, cjA, cjB, - n, k, - alpha, - a, 1, lda, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const NoConjugate cjA; - const Conjugate cjB; - Impl::herk_upper(member, cjA, cjB, - n, k, - alpha, - a, lda, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("trans is not valid"); - } - else if (uplo == 'L' || uplo == 'l') - switch (trans) { - case 'N': - case 'n': { - const NoConjugate cjA; - const Conjugate cjB; - Impl::herk_lower(member, cjA, cjB, - n, k, - alpha, - a, 1, lda, - beta, - c, 1, ldc); - break; - } - case 'C': - case 'c': { - const NoConjugate cjA; - const Conjugate cjB; - Impl::herk_lower(member, cjA, cjB, - n, k, - alpha, - a, lda, 1, - beta, - c, 1, ldc); - break; - } - default: - Kokkos::abort("trans is not valid"); - } - else - Kokkos::abort("uplo is not valid"); - } - - - template - static - KOKKOS_INLINE_FUNCTION - void trsm(MemberType &member, - const char side, const char uplo, const char trans, const char diag, - const int m, const int n, - const T alpha, - const T *__restrict__ a, const int lda, - /* */ T *__restrict__ b, const int ldb) { - /// - /// side left - /// - if (side == 'L' || side == 'l') { - if (uplo == 'U' || uplo == 'u') { - switch (trans) { - case 'N': - case 'n': { - NoConjugate cjA; - Impl::trsm_left_upper(member, cjA, - diag, - m, n, - alpha, - a, 1, lda, - b, 1, ldb); - break; - } - case 'T': - case 't': { - NoConjugate cjA; - Impl::trsm_left_lower(member, cjA, - diag, - m, n, - alpha, - a, lda, 1, - b, 1, ldb); - break; - } - case 'C': - case 'c': { - Conjugate cjA; - Impl::trsm_left_lower(member, cjA, - diag, - m, n, - alpha, - a, lda, 1, - b, 1, ldb); - break; - } - default: - Kokkos::abort("trans is not valid"); - } - } else if (uplo == 'L' || uplo == 'l') { - switch (trans) { - case 'N': - case 'n': { - NoConjugate cjA; - Impl::trsm_left_lower(member, cjA, - diag, - m, n, - alpha, - a, 1, lda, - b, 1, ldb); - break; - } - case 'T': - case 't': { - NoConjugate cjA; - Impl::trsm_left_upper(member, cjA, - diag, - m, n, - alpha, - a, lda, 1, - b, 1, ldb); - break; - } - case 'C': - case 'c': { - Conjugate cjA; - Impl::trsm_left_upper(member, cjA, - diag, - m, n, - alpha, - a, lda, 1, - b, 1, ldb); - break; - } - default: - Kokkos::abort("trans is not valid"); - } - } - } - /// - /// side right - /// - else if (side == 'R' || side == 'r') { - Kokkos::abort("right side is not implemented"); + default: + Kokkos::abort("trans is not valid"); + } + } else if (uplo == 'L' || uplo == 'l') { + switch (trans) { + case 'N': + case 'n': { + NoConjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, n, m, alpha, a, lda, 1, b, ldb, 1); + break; + } + case 'T': + case 't': { + Kokkos::abort("no no no"); + NoConjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); + break; + } + case 'C': + case 'c': { + Kokkos::abort("no no no"); + Conjugate cjA; + Impl::trsm_left_upper(member, cjA, diag, m, n, alpha, a, lda, 1, b, 1, ldb); + break; + } + default: + Kokkos::abort("trans is not valid"); } } - - }; - -} + } + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol.hpp index 23867357867e..194ac4b38929 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol.hpp @@ -8,56 +8,19 @@ #include "Tacho_Util.hpp" namespace Tacho { - - /// - /// Chol: - /// - /// - - /// various implementation for different uplo and algo parameters - template - struct Chol; - - /// task construction for the above chol implementation - // Chol::invoke(_sched, member, _A); - template - struct TaskFunctor_Chol { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; - - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; - - private: - dense_block_type _A; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Chol() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Chol(const dense_block_type &A) - : _A(A) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Chol - ::invoke(member, _A); - - Kokkos::single(Kokkos::PerTeam(member), - [&, ierr]() { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - _A.set_future(); - r_val = ierr; - }); - } - }; - -} + +/// +/// Chol: +/// +/// + +/// various implementation for different uplo and algo parameters +template struct Chol; + +struct CholAlgorithm { + using type = ActiveAlgorithm::type; +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes.hpp index c1235b5f7463..a092c18244c9 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes.hpp @@ -8,8 +8,7 @@ namespace Tacho { - template - struct CholSupernodes; +template struct CholSupernodes; } diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_Serial.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_Serial.hpp index 3f7a8c250a4d..c7117fecddba 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_Serial.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_Serial.hpp @@ -38,194 +38,179 @@ namespace Tacho { - template<> - struct CholSupernodes { - template - KOKKOS_INLINE_FUNCTION - static int - factorize(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &ABR, - const ordinal_type sid) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - // algorithm choice - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type CholAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type HerkAlgoType; - - // get current supernode - const auto &s = info.supernodes(sid); - - // get panel pointer - value_type *ptr = s.buf; - - // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) - const ordinal_type m = s.m, n = s.n - s.m; - - // m and n are available, then factorize the supernode block - if (m > 0) { - UnmanagedViewType ATL(ptr, m, m); ptr += m*m; - Chol::invoke(member, ATL); - - if (n > 0) { - const value_type one(1), zero(0); - UnmanagedViewType ATR(ptr, m, n); // ptr += m*n; - Trsm - ::invoke(member, Diag::NonUnit(), one, ATL, ATR); - - TACHO_TEST_FOR_ABORT(static_cast(ABR.extent(0)) != n || +template <> struct CholSupernodes { + template + KOKKOS_INLINE_FUNCTION static int factorize(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &ABR, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + + // algorithm choice + using CholAlgoType = typename CholAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using HerkAlgoType = typename HerkAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // get panel pointer + value_type *ptr = s.u_buf; + + // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) + const ordinal_type m = s.m, n = s.n - s.m; + + // m and n are available, then factorize the supernode block + if (m > 0) { + UnmanagedViewType ATL(ptr, m, m); + ptr += m * m; + Chol::invoke(member, ATL); + + if (n > 0) { + const value_type one(1), zero(0); + UnmanagedViewType ATR(ptr, m, n); // ptr += m*n; + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ATR); + + TACHO_TEST_FOR_ABORT(static_cast(ABR.extent(0)) != n || static_cast(ABR.extent(1)) != n, - "ABR dimension does not match to supernodes"); - Herk - ::invoke(member, -one, ATR, zero, ABR); - } - } - return 0; + "ABR dimension does not match to supernodes"); + Herk::invoke(member, -one, ATR, zero, ABR); } - - template - KOKKOS_INLINE_FUNCTION - static int - update(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &ABR, - const ordinal_type sid, - const size_type bufsize, - /* */ void *buf) { - - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::dense_block_type dense_block_type; - - const auto &cur = info.supernodes(sid); - - const ordinal_type - sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; - - const ordinal_type - srcbeg = info.sid_block_colidx(sbeg).second, - srcend = info.sid_block_colidx(send).second, - srcsize = srcend - srcbeg; - - // short cut to direct update - if ((send - sbeg) == 1) { - const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - - if (srcsize == tgtsize) { - /* */ value_type *tgt = s.buf; - const value_type *src = (value_type*)ABR.data(); - -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - // lock - while (Kokkos::atomic_compare_exchange(&s.lock, 0, 1)) KOKKOS_IMPL_PAUSE; - Kokkos::store_fence(); - - for (ordinal_type j=0;j + KOKKOS_INLINE_FUNCTION static int update(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &ABR, + const ordinal_type sid, const size_type bufsize, + /* */ void *buf, const bool update_lower = false) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using range_type = typename supernode_info_type::range_type; + + const auto &cur = info.supernodes(sid); + + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type srcbeg = info.sid_block_colidx(sbeg).second, srcend = info.sid_block_colidx(send).second, + srcsize = srcend - srcbeg; + + // short cut to direct update + if ((send - sbeg) == 1) { + const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + if (srcsize == tgtsize) { + /* */ value_type *tgt = s.u_buf; + const value_type *src = (value_type *)ABR.data(); + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // lock + while (Kokkos::atomic_compare_exchange(&s.lock, 0, 1)) + KOKKOS_IMPL_PAUSE; + Kokkos::store_fence(); + + for (ordinal_type j = 0; j < srcsize; ++j) { + const value_type *__restrict__ ss = src + j * srcsize; + /* */ value_type *__restrict__ tt = tgt + j * srcsize; + const ordinal_type iend = update_lower ? srcsize : j + 1; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (ordinal_type i=0;i<(j+1);++i) - tt[i] += ss[i]; - } + for (ordinal_type i = 0; i < iend; ++i) + tt[i] += ss[i]; + } - // unlock - s.lock = 0; - Kokkos::load_fence(); + // unlock + s.lock = 0; + Kokkos::load_fence(); #else - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, srcsize), [&](const ordinal_type &j) { - const value_type *__restrict__ ss = src + j*srcsize; - /* */ value_type *__restrict__ tt = tgt + j*srcsize; - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, j+1), [&](const ordinal_type &i) { - Kokkos::atomic_add(&tt[i], ss[i]); - }); - }); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, srcsize), [&](const ordinal_type &j) { + const value_type *__restrict__ ss = src + j * srcsize; + /* */ value_type *__restrict__ tt = tgt + j * srcsize; + const ordinal_Type iend = update_lower ? srcsize : j + 1; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, iend), + [&](const ordinal_type &i) { Kokkos::atomic_add(&tt[i], ss[i]); }); + }); #endif - return 0; - } - } - - const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; - - // loop over target -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - ordinal_type *s2t = (ordinal_type*)buf; - const size_type s2tsize = srcsize*sizeof(ordinal_type); - TACHO_TEST_FOR_ABORT(bufsize < s2tsize, "bufsize is smaller than required s2t workspace"); - - for (ordinal_type i=sbeg;i U(s.u_buf, s.m, s.n); + UnmanagedViewType Lp(s.l_buf, s.n, s.m); + const auto L = Kokkos::subview(Lp, range_type(s.m, s.n), Kokkos::ALL()); - // lock - while (Kokkos::atomic_compare_exchange(&s.lock, 0, 1)) KOKKOS_IMPL_PAUSE; - Kokkos::store_fence(); + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; - for (ordinal_type jj=ijbeg;jj= s.m) { + L(col - s.m, row) += ABR(jj, ii); } - - // unlock - s.lock = 0; - Kokkos::load_fence(); + } else + break; } } -#else - // CUDA version - const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); - TACHO_TEST_FOR_ABORT(bufsize < s2tsize, "bufsize is smaller than required s2t workspace"); -#if 0 // single version works + + // unlock + s.lock = 0; + Kokkos::load_fence(); + } + } +#else + // CUDA version + const size_type s2tsize = srcsize * sizeof(ordinal_type) * member.team_size(); + TACHO_TEST_FOR_ABORT(bufsize < s2tsize, "bufsize is smaller than required s2t workspace"); +#if 0 // single version for testing only Kokkos::single(Kokkos::PerTeam(member), [&]() { ordinal_type *s2t = (ordinal_type*)buf; for (ordinal_type i=sbeg;i U(s.u_buf, s.m, s.n); + UnmanagedViewType Lp(s.l_buf, s.n, s.m); + const auto L = Kokkos::subview(Lp, range_type(s.m, s.n), Kokkos::ALL()); + ordinal_type ijbeg = 0; for (;s2t[ijbeg] == -1; ++ijbeg) ; - for (ordinal_type jj=ijbeg;jj= s.m) { + Kokkos::atomic_add(&L(col-s.m, row), ABR(jj, ii)); + } + } else break; } + } } } }); #else - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, sbeg, send), [&](const ordinal_type &i) { - ordinal_type *s2t = ((ordinal_type*)(buf)) + member.team_rank()*srcsize; - const auto &s = info.supernodes(info.sid_block_colidx(i).first); - { - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - - const ordinal_type *t_colidx = &info.gid_colidx(s.gid_col_begin + tgtbeg); - // for (ordinal_type k=0,l=0;k U(s.u_buf, s.m, s.n); + UnmanagedViewType Lp(s.l_buf, s.n, s.m); + const auto L = Kokkos::subview(Lp, range_type(s.m, s.n), Kokkos::ALL()); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, ijbeg, srcsize), [&](const ordinal_type &ii) { + const ordinal_type row = s2t[ii]; + if (row < s.m) { + for (ordinal_type jj = ijbeg; jj < srcsize; ++jj) { + const ordinal_type col = s2t[jj]; + Kokkos::atomic_add(&U(row, col), ABR(ii, jj)); + if (update_lower && col >= s.m) { + Kokkos::atomic_add(&L(col - s.m, row), ABR(jj, ii)); + } } - }); + } + }); + // Kokkos::parallel_for + // (Kokkos::ThreadVectorRange(member, srcsize-ijbeg), [&](const ordinal_type &jjj) { + // const ordinal_type jj = jjj + ijbeg; + // for (ordinal_type ii=ijbeg;ii + KOKKOS_INLINE_FUNCTION static int solve_lower(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using range_type = typename supernode_info_type::range_type; + + const auto &s = info.supernodes(sid); + + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + // get panel pointer + value_type *ptr = s.u_buf; + + // panel is divided into diagonal and interface block + const ordinal_type m = s.m, n = s.n - s.m, nrhs = info.x.extent(1); + + // m and n are available, then factorize the supernode block + if (m > 0) { + const value_type one(1), zero(0); + const ordinal_type offm = s.row_begin; + UnmanagedViewType AL(ptr, m, m); + ptr += m * m; + auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + + if (nrhs >= ThresholdSolvePhaseUsingBlas3) + Trsm::invoke(member, Diag::NonUnit(), one, AL, xT); + else + Trsv::invoke(member, Diag::NonUnit(), AL, xT); + + if (n > 0) { + UnmanagedViewType AR(ptr, m, n); // ptr += m*n; + if (nrhs >= ThresholdSolvePhaseUsingBlas3) + Gemm::invoke(member, -one, AR, xT, zero, xB); + else + Gemv::invoke(member, -one, AR, xT, zero, xB); } - - template - KOKKOS_INLINE_FUNCTION - static int - solve_lower(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef Kokkos::pair range_type; - - const auto &s = info.supernodes(sid); - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type GemmAlgoType; - - // get panel pointer - value_type *ptr = s.buf; - - // panel is divided into diagonal and interface block - const ordinal_type m = s.m, n = s.n - s.m, nrhs = info.x.extent(1); - - // m and n are available, then factorize the supernode block - if (m > 0) { - const value_type one(1), zero(0); - const ordinal_type offm = s.row_begin; - UnmanagedViewType AL(ptr, m, m); ptr += m*m; - auto xT = Kokkos::subview(info.x, range_type(offm, offm+m), Kokkos::ALL()); - - if (nrhs >= ThresholdSolvePhaseUsingBlas3) - Trsm - ::invoke(member, Diag::NonUnit(), one, AL, xT); - else - Trsv - ::invoke(member, Diag::NonUnit(), AL, xT); - - if (n > 0) { - UnmanagedViewType AR(ptr, m, n); // ptr += m*n; - if (nrhs >= ThresholdSolvePhaseUsingBlas3) - Gemm - ::invoke(member, -one, AR, xT, zero, xB); - else - Gemv - ::invoke(member, -one, AR, xT, zero, xB); - } - } - return 0; + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int update_solve_lower(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + const auto &cur = info.supernodes(sid); + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type m = xB.extent(0), n = xB.extent(1); + TACHO_TEST_FOR_ABORT(m != (cur.n - cur.m), "# of rows in xB does not match to super blocksize in sid"); + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + for (ordinal_type i = sbeg, is = 0; i < send; ++i) { + const ordinal_type tbeg = info.sid_block_colidx(i).second, tend = info.sid_block_colidx(i + 1).second; + + // lock + const auto &s = info.supernodes(info.sid_block_colidx(i).first); + while (Kokkos::atomic_compare_exchange(&s.lock, 0, 1)) + KOKKOS_IMPL_PAUSE; + Kokkos::store_fence(); + + // both src and tgt increase index + for (ordinal_type it = tbeg; it < tend; ++it, ++is) { + const ordinal_type row = info.gid_colidx(cur.gid_col_begin + it); + for (ordinal_type j = 0; j < n; ++j) + info.x(row, j) += xB(is, j); } - template - KOKKOS_INLINE_FUNCTION - static int - update_solve_lower(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - //typedef SupernodeInfoType supernode_info_type; - //typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - const auto &cur = info.supernodes(sid); - const ordinal_type - sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; - - const ordinal_type m = xB.extent(0), n = xB.extent(1); - TACHO_TEST_FOR_ABORT(m != (cur.n-cur.m), "# of rows in xB does not match to super blocksize in sid"); - -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - for (ordinal_type i=sbeg,is=0;i - KOKKOS_INLINE_FUNCTION - static int - solve_upper(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef Kokkos::pair range_type; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type GemmAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - // get current supernode - const auto &s = info.supernodes(sid); - - // get supernode panel pointer - value_type *ptr = s.buf; - - // panel is divided into diagonal and interface block - const ordinal_type m = s.m, n = s.n - s.m, nrhs = info.x.extent(1); - - // m and n are available, then factorize the supernode block - if (m > 0) { - const value_type one(1); - const UnmanagedViewType AL(ptr, m, m); ptr += m*m; - - const ordinal_type offm = s.row_begin; - const auto xT = Kokkos::subview(info.x, range_type(offm, offm+m), Kokkos::ALL()); - - if (n > 0) { - const UnmanagedViewType AR(ptr, m, n); // ptr += m*n; - if (nrhs >= ThresholdSolvePhaseUsingBlas3) - Gemm - ::invoke(member, -one, AR, xB, one, xT); - else - Gemv - ::invoke(member, -one, AR, xB, one, xT); - } - if (nrhs >= ThresholdSolvePhaseUsingBlas3) - Trsm - ::invoke(member, Diag::NonUnit(), one, AL, xT); - else - Trsv - ::invoke(member, Diag::NonUnit(), AL, xT); + Kokkos::single(Kokkos::PerTeam(member), [&]() { + for (ordinal_type i = sbeg, is = 0; i < send; ++i) { + const ordinal_type tbeg = info.sid_block_colidx(i).second, tend = info.sid_block_colidx(i + 1).second; + + for (ordinal_type it = tbeg; it < tend; ++it, ++is) { + const ordinal_type row = info.gid_colidx(cur.gid_col_begin + it); + for (ordinal_type j = 0; j < n; ++j) + Kokkos::atomic_add(&info.x(row, j), xB(is, j)); } - return 0; } - - template - KOKKOS_INLINE_FUNCTION - static int - update_solve_upper(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - - //typedef SupernodeInfoType supernode_info_type; - //typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - const auto &s = info.supernodes(sid); - - const ordinal_type m = xB.extent(0), n = xB.extent(1); - TACHO_TEST_FOR_ABORT(m != (s.n-s.m), "# of rows in xB does not match to super blocksize in sid"); - - const ordinal_type goffset = s.gid_col_begin + s.m; -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - for (ordinal_type j=0;j + KOKKOS_INLINE_FUNCTION static int solve_upper(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using range_type = typename supernode_info_type::range_type; + + using GemmAlgoType = typename GemmAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // get supernode panel pointer + value_type *ptr = s.u_buf; + + // panel is divided into diagonal and interface block + const ordinal_type m = s.m, n = s.n - s.m, nrhs = info.x.extent(1); + + // m and n are available, then factorize the supernode block + if (m > 0) { + const value_type one(1); + const UnmanagedViewType AL(ptr, m, m); + ptr += m * m; + + const ordinal_type offm = s.row_begin; + const auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + + if (n > 0) { + const UnmanagedViewType AR(ptr, m, n); // ptr += m*n; + if (nrhs >= ThresholdSolvePhaseUsingBlas3) + Gemm::invoke(member, -one, AR, xB, one, xT); + else + Gemv::invoke(member, -one, AR, xB, one, xT); + } + if (nrhs >= ThresholdSolvePhaseUsingBlas3) + Trsm::invoke(member, Diag::NonUnit(), one, AL, xT); + else + Trsv::invoke(member, Diag::NonUnit(), AL, xT); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int update_solve_upper(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + + const auto &s = info.supernodes(sid); + + const ordinal_type m = xB.extent(0), n = xB.extent(1); + TACHO_TEST_FOR_ABORT(m != (s.n - s.m), "# of rows in xB does not match to super blocksize in sid"); + + const ordinal_type goffset = s.gid_col_begin + s.m; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type row = info.gid_colidx(i + goffset); + xB(i, j) = info.x(row, j); + } #else - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, m), [&](const ordinal_type &i) { - const ordinal_type row = info.gid_colidx(i+goffset); - xB(i,j) = info.x(row,j); - }); - }); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), [&](const ordinal_type &i) { + const ordinal_type row = info.gid_colidx(i + goffset); + xB(i, j) = info.x(row, j); + }); + }); #endif - return 0; - } + return 0; + } - template - KOKKOS_INLINE_FUNCTION - static int - factorize_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::value_type *buf, - const size_type bufsize) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - const auto &s = info.supernodes(sid); - - if (final) { - // serial recursion - for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION static int + factorize_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::value_type *buf, const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; - { - const ordinal_type n = s.n - s.m; -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - const size_type bufsize_required = n*(n+1)*sizeof(value_type); + const auto &s = info.supernodes(sid); + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + factorize_recursive_serial(member, info, s.children[i], final, buf, bufsize); + } + + { + const ordinal_type n = s.n - s.m; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + const size_type bufsize_required = n * (n + 1) * sizeof(value_type); #else - const size_type bufsize_required = n*(n+member.team_size())*sizeof(value_type); + const size_type bufsize_required = n * (n + member.team_size()) * sizeof(value_type); #endif - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); - UnmanagedViewType ABR((value_type*)buf, n, n); + UnmanagedViewType ABR((value_type *)buf, n, n); - CholSupernodes - ::factorize(member, info, ABR, sid); + CholSupernodes::factorize(member, info, ABR, sid); - CholSupernodes - ::update(member, info, ABR, sid, - bufsize - ABR.span()*sizeof(value_type), - (void*)((value_type*)buf + ABR.span())); - } - return 0; - } + CholSupernodes::update(member, info, ABR, sid, bufsize - ABR.span() * sizeof(value_type), + (void *)((value_type *)buf + ABR.span())); + } + return 0; + } + template + KOKKOS_INLINE_FUNCTION static int + solve_lower_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::value_type *buf, const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; - template - KOKKOS_INLINE_FUNCTION - static int - solve_lower_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::value_type *buf, - const size_type bufsize) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - const auto &s = info.supernodes(sid); - - if (final) { - // serial recursion - for (ordinal_type i=0;i xB((value_type*)buf, n, nrhs); + { + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const size_type bufsize_required = n * nrhs * sizeof(value_type); - CholSupernodes - ::solve_lower(member, info, xB, sid); + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); - CholSupernodes - ::update_solve_lower(member, info, xB, sid); - } - return 0; - } + UnmanagedViewType xB((value_type *)buf, n, nrhs); + CholSupernodes::solve_lower(member, info, xB, sid); - template - KOKKOS_INLINE_FUNCTION - static int - solve_upper_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::value_type *buf, - const ordinal_type bufsize) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - - const auto &s = info.supernodes(sid); - { - const ordinal_type n = s.n - s.m; - const ordinal_type nrhs = info.x.extent(1); - const ordinal_type bufsize_required = n*nrhs*sizeof(value_type); - - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); - - UnmanagedViewType xB((value_type*)buf, n, nrhs); - - CholSupernodes - ::update_solve_upper(member, info, xB, sid); - - CholSupernodes - ::solve_upper(member, info, xB, sid); - } + CholSupernodes::update_solve_lower(member, info, xB, sid); + } + return 0; + } - if (final) { - // serial recursion - for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION static int solve_upper_recursive_serial(MemberType &member, const SupernodeInfoType &info, + const ordinal_type sid, const bool final, + typename SupernodeInfoType::value_type *buf, + const ordinal_type bufsize) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + + const auto &s = info.supernodes(sid); + { + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const ordinal_type bufsize_required = n * nrhs * sizeof(value_type); + + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + UnmanagedViewType xB((value_type *)buf, n, nrhs); + + CholSupernodes::update_solve_upper(member, info, xB, sid); + + CholSupernodes::solve_upper(member, info, xB, sid); + } - }; -} + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + solve_upper_recursive_serial(member, info, s.children[i], final, buf, bufsize); + } + return 0; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_SerialPanel.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_SerialPanel.hpp index 7ea57745c9c6..5750b2e49059 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_SerialPanel.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_CholSupernodes_SerialPanel.hpp @@ -38,377 +38,335 @@ namespace Tacho { - template<> - struct CholSupernodes { - template - KOKKOS_INLINE_FUNCTION - static int - factorize(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - // algorithm choice - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type CholAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - // get current supernode - const auto &s = info.supernodes(sid); - - // get panel pointer - value_type *ptr = s.buf; - - // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) - const ordinal_type m = s.m, n = s.n - s.m; - - // m is available, then factorize the supernode block - if (m > 0) { - UnmanagedViewType ATL(ptr, m, m); ptr += m*m; - Chol::invoke(member, ATL); - - // n is available, then solve interface block - if (n > 0) { - UnmanagedViewType ATR(ptr, m, n); - Trsm - ::invoke(member, Diag::NonUnit(), 1.0, ATL, ATR); - } - } - return 0; +template <> struct CholSupernodes { + template + KOKKOS_INLINE_FUNCTION static int factorize(MemberType &member, const SupernodeInfoType &info, + const ordinal_type sid) { + typedef SupernodeInfoType supernode_info_type; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + + // algorithm choice + using CholAlgoType = typename CholAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // get panel pointer + value_type *ptr = s.buf; + + // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) + const ordinal_type m = s.m, n = s.n - s.m; + + // m is available, then factorize the supernode block + if (m > 0) { + UnmanagedViewType ATL(ptr, m, m); + ptr += m * m; + Chol::invoke(member, ATL); + + // n is available, then solve interface block + if (n > 0) { + UnmanagedViewType ATR(ptr, m, n); + Trsm::invoke(member, Diag::NonUnit(), 1.0, ATL, + ATR); } - - template - KOKKOS_INLINE_FUNCTION - static int - update(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type offn, // ATR and ABR panel offset - const ordinal_type np, // ATR and ABR panel width - const ordinal_type sid, - const size_type bufsize, // ABR size + additional - /* */ void *buf) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - typedef typename supernode_info_type::dense_block_type dense_block_type; - - // algorithm choice - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type HerkAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type GemmAlgoType; - -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int update(MemberType &member, const SupernodeInfoType &info, + const ordinal_type offn, // ATR and ABR panel offset + const ordinal_type np, // ATR and ABR panel width + const ordinal_type sid, + const size_type bufsize, // ABR size + additional + /* */ void *buf) { + typedef SupernodeInfoType supernode_info_type; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + + // algorithm choice + using HerkAlgoType = typename HerkAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) #else - member.team_barrier(); + member.team_barrier(); #endif - // get current supernode - const auto &cur = info.supernodes(sid); - - // panel (cur.m x cur.n) is divided into ATL (m x m) and ATR (m x n) - const ordinal_type - m = cur.m, n = cur.n - cur.m , - nb = min(np, n - offn), nn = offn + nb; - - // m and n are available, then factorize the supernode block - if (m > 0 && n > 0) { - // ** update - const ordinal_type - sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; - - const ordinal_type - srcbeg = info.sid_block_colidx(sbeg).second, - srcend = info.sid_block_colidx(send).second, - srcsize = srcend - srcbeg; - -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - TACHO_TEST_FOR_ABORT(bufsize < size_type(srcsize*sizeof(ordinal_type) + - nn*nb*sizeof(value_type)), - "bufsize is smaller than required workspace"); -#else - TACHO_TEST_FOR_ABORT(bufsize < size_type(srcsize*sizeof(ordinal_type)*member.team_size() + - nn*nb*sizeof(value_type)), - "bufsize is smaller than required workspace"); + // get current supernode + const auto &cur = info.supernodes(sid); + + // panel (cur.m x cur.n) is divided into ATL (m x m) and ATR (m x n) + const ordinal_type m = cur.m, n = cur.n - cur.m, nb = min(np, n - offn), nn = offn + nb; + + // m and n are available, then factorize the supernode block + if (m > 0 && n > 0) { + // ** update + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type srcbeg = info.sid_block_colidx(sbeg).second, srcend = info.sid_block_colidx(send).second, + srcsize = srcend - srcbeg; + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + TACHO_TEST_FOR_ABORT(bufsize < size_type(srcsize * sizeof(ordinal_type) + nn * nb * sizeof(value_type)), + "bufsize is smaller than required workspace"); +#else + TACHO_TEST_FOR_ABORT( + bufsize < size_type(srcsize * sizeof(ordinal_type) * member.team_size() + nn * nb * sizeof(value_type)), + "bufsize is smaller than required workspace"); #endif - - UnmanagedViewType ABL(cur.buf + m*m, m, nn); - UnmanagedViewType ATR(ABL.data() + offn*m, m, nb); - - value_type *ptr = (value_type*)buf; - UnmanagedViewType ABR(ptr, nn, nb); ptr += ABR.span(); - - if (offn == 0 && nb == n) - Herk - ::invoke(member, -1.0, ATR, 0.0, ABR); - else - Gemm - ::invoke(member, -1.0, ABL, ATR, 0.0, ABR); - - // short cut to direct update - if ((send - sbeg) == 1) { - const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - - if (srcsize == tgtsize) { - /* */ value_type *tgt = s.buf; - const value_type *src = (value_type*)ABR.data(); - -#if defined (KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - switch (info.front_update_mode) { - case 1: { - for (ordinal_type js=0;js ABL(cur.buf + m * m, m, nn); + UnmanagedViewType ATR(ABL.data() + offn * m, m, nb); + + value_type *ptr = (value_type *)buf; + UnmanagedViewType ABR(ptr, nn, nb); + ptr += ABR.span(); + + if (offn == 0 && nb == n) + Herk::invoke(member, -1.0, ATR, 0.0, ABR); + else + Gemm::invoke(member, -1.0, ABL, ATR, 0.0, ABR); + + // short cut to direct update + if ((send - sbeg) == 1) { + const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + if (srcsize == tgtsize) { + /* */ value_type *tgt = s.buf; + const value_type *src = (value_type *)ABR.data(); + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + switch (info.front_update_mode) { + case 1: { + for (ordinal_type js = 0; js < nb; ++js) { + const ordinal_type jt = js + offn; + const value_type *__restrict__ ss = src + js * srcsize; + /* */ value_type *__restrict__ tt = tgt + jt * srcsize; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (ordinal_type i=0;i<=jt;++i) - Kokkos::atomic_fetch_add(&tt[i], ss[i]); - } - break; - } - case 0: { - // lock - while (Kokkos::atomic_compare_exchange(&s.lock, 0, 1)) KOKKOS_IMPL_PAUSE; - Kokkos::store_fence(); - - for (ordinal_type js=0;js A(s.u_buf, s.m, s.n); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + + // lock + switch (info.front_update_mode) { + case 1: { + for (ordinal_type jj = max(ijbeg, offn); jj < nn; ++jj) { + const ordinal_type js = jj - offn; #if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) #pragma unroll #endif - for (ordinal_type ii=ijbeg;ii - KOKKOS_INLINE_FUNCTION - static int - factorize_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::value_type *buf, - const size_type bufsize, - const size_type np) { - typedef SupernodeInfoType supernode_info_type; - - typedef typename supernode_info_type::value_type value_type; - - const auto &s = info.supernodes(sid); - - if (final) { - // serial recursion - for (ordinal_type i=0;i A(s.u_buf, s.m, s.n); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, nn - ijbeg), [&](const ordinal_type &iii) { + const ordinal_type ii = ijbeg + iii; + const ordinal_type row = s2t[ii]; + if (row < s.m) + for (ordinal_type jj = max(ijbeg, offn); jj < nn; ++jj) { + const ordinal_type js = jj - offn; + Kokkos::atomic_add(&A(row, s2t[jj]), ABR(ii, js)); + } + }); + // for (ordinal_type jj=max(ijbeg,offn);jj + KOKKOS_INLINE_FUNCTION static int factorize_recursive_serial(MemberType &member, const SupernodeInfoType &info, + const ordinal_type sid, const bool final, + typename SupernodeInfoType::value_type *buf, + const size_type bufsize, const size_type np) { + typedef SupernodeInfoType supernode_info_type; + + typedef typename supernode_info_type::value_type value_type; + + const auto &s = info.supernodes(sid); + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + factorize_recursive_serial(member, info, s.children[i], final, buf, bufsize, np); + } + + { + const ordinal_type n = s.n - s.m; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + const size_type bufsize_required = n * (min(np, n) + 1) * sizeof(value_type); #else - const size_type bufsize_required = n*(min(np,n)+member.team_size())*sizeof(value_type); + const size_type bufsize_required = n * (min(np, n) + member.team_size()) * sizeof(value_type); #endif - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); - - CholSupernodes - ::factorize(member, info, sid); - - for (ordinal_type offn=0; offn - ::update(member, info, offn, np, sid, bufsize, (void*)buf); - } - } - return 0; - } + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + CholSupernodes::factorize(member, info, sid); - }; -} + for (ordinal_type offn = 0; offn < n; offn += np) { + CholSupernodes::update(member, info, offn, np, sid, bufsize, (void *)buf); + } + } + return 0; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_External.hpp index 16d2600ffab9..213743910239 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_External.hpp @@ -9,56 +9,41 @@ namespace Tacho { - /// LAPACK Chol - /// =========== - template - struct Chol { - template - inline - static int - invoke(const ViewTypeA &A) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - - int r_val = 0; - const ordinal_type m = A.extent(0); - if (m > 0) { - Lapack::potrf(ArgUplo::param, - m, - A.data(), A.stride_1(), - &r_val); - TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, - "LAPACK (potrf) returns non-zero error code."); - } - return r_val; +/// LAPACK Chol +/// =========== +template struct Chol { + template inline static int invoke(const ViewTypeA &A) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + int r_val = 0; + const ordinal_type m = A.extent(0); + if (m > 0) { + Lapack::potrf(ArgUplo::param, m, A.data(), A.stride_1(), &r_val); + TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, "LAPACK (potrf) returns non-zero error code."); + } + return r_val; #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space." ); - return -1; + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); + return -1; #endif - } - - - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - int r_val = 0; - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - r_val = invoke(A); - TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, - "LAPACK (potrf) returns non-zero error code."); - //}); - return r_val; + } + + template inline static int invoke(MemberType &member, const ViewTypeA &A) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + int r_val = 0; + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + r_val = invoke(A); + TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, "LAPACK (potrf) returns non-zero error code."); + //}); + return r_val; #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space." ); - return 0; + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); + return 0; #endif - } - }; -} + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_Internal.hpp index ddce2e2d34c3..88c4df0605ef 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_Internal.hpp @@ -9,31 +9,22 @@ namespace Tacho { - /// LAPACK Chol - /// =========== - template - struct Chol { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); +/// LAPACK Chol +/// =========== +template struct Chol { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; - int r_val = 0; - const ordinal_type m = A.extent(0); - if (m > 0) - LapackTeam::potrf(member, - ArgUplo::param, - m, - A.data(), A.stride_1(), - &r_val); - return r_val; - } - }; -} + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + int r_val = 0; + const ordinal_type m = A.extent(0); + if (m > 0) + LapackTeam::potrf(member, ArgUplo::param, m, A.data(), A.stride_1(), &r_val); + return r_val; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_OnDevice.hpp index 0a8ac6fb4814..8867c8903177 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_OnDevice.hpp @@ -1,124 +1,87 @@ #ifndef __TACHO_CHOL_ON_DEVICE_HPP__ #define __TACHO_CHOL_ON_DEVICE_HPP__ - /// \file Tacho_Chol_OnDevice.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template - struct Chol { - template - inline - static int - lapack_invoke(const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = A.extent(0); - - int r_val(0); - if (m > 0) { - Lapack::potrf(ArgUplo::param, - m, - A.data(), A.stride_1(), - &r_val); - } - return r_val; - } +template struct Chol { + template inline static int lapack_invoke(const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0); + + int r_val(0); + if (m > 0) { + Lapack::potrf(ArgUplo::param, m, A.data(), A.stride_1(), &r_val); + } + return r_val; + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cusolver_invoke(cusolverDnHandle_t &handle, - const ViewTypeA &A, - const ViewTypeW &W) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeW::non_const_value_type work_value_type; - const ordinal_type - m = A.extent(0); + template + inline static int cusolver_invoke(cusolverDnHandle_t &handle, const ViewTypeA &A, const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type work_value_type; + const ordinal_type m = A.extent(0); - int r_val(0); - if (m > 0) { - int *devInfo = (int*)W.data(); - value_type *workspace = W.data() + 1; - int lwork = (W.span()-1)*sizeof(work_value_type); - r_val = Lapack::potrf(handle, - ArgUplo::cublas_param, - m, - A.data(), A.stride_1(), - workspace, lwork, - devInfo); - } - return r_val; - } + int r_val(0); + if (m > 0) { + int *devInfo = (int *)W.data(); + value_type *workspace = W.data() + 1; + int lwork = (W.span() - 1) * sizeof(work_value_type); + r_val = Lapack::potrf(handle, ArgUplo::cublas_param, m, A.data(), A.stride_1(), workspace, lwork, + devInfo); + } + return r_val; + } - template - inline - static int - cusolver_buffer_size(cusolverDnHandle_t &handle, - const ViewTypeA &A, - int *lwork) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = A.extent(0); + template + inline static int cusolver_buffer_size(cusolverDnHandle_t &handle, const ViewTypeA &A, int *lwork) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0); - int r_val(0); - if (m > 0) - r_val = Lapack::potrf_buffersize(handle, - ArgUplo::cublas_param, - m, - A.data(), A.stride_1(), - lwork); - return r_val; - } + int r_val(0); + if (m > 0) + r_val = Lapack::potrf_buffersize(handle, ArgUplo::cublas_param, m, A.data(), A.stride_1(), lwork); + return r_val; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeW &W) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeW::non_const_value_type value_type_w; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeW::memory_space memory_space_w; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeW::rank == 1,"W is not rank 1 view."); - - static_assert(std::is_same::value, - "A and W do not have the same value type."); + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type value_type_w; - static_assert(std::is_same::value, - "A and W do not have the same memory space."); - int r_val(0); - if (std::is_same::value) { - r_val = lapack_invoke(A); - } + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeW::memory_space memory_space_w; -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) { - if (W.span() == 0) { - int lwork; - r_val = cusolver_buffer_size(member, A, &lwork); - r_val = (lwork+sizeof(value_type_w))/sizeof(value_type_w) + 1; - } else - r_val = cusolver_invoke(member, A, W); - } -#endif - return r_val; - } + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeW::rank == 1, "W is not rank 1 view."); + + static_assert(std::is_same::value, "A and W do not have the same value type."); - }; + static_assert(std::is_same::value, "A and W do not have the same memory space."); + int r_val(0); + if (std::is_same::value) { + r_val = lapack_invoke(A); + } + +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || + std::is_same::value) { + if (W.span() == 0) { + int lwork; + r_val = cusolver_buffer_size(member, A, &lwork); + r_val = (lwork + sizeof(value_type_w)) / sizeof(value_type_w) + 1; + } else + r_val = cusolver_invoke(member, A, W); + } +#endif + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy.hpp index e6ff7369e3d6..059df0c4974f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy.hpp @@ -9,13 +9,12 @@ namespace Tacho { - /// - /// Copy - /// +/// +/// Copy +/// - /// various implementation for different uplo and algo parameters - template - struct Copy; -} +/// various implementation for different uplo and algo parameters +template struct Copy; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_Internal.hpp index e412f40718be..6632016b8fc0 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_Internal.hpp @@ -1,57 +1,98 @@ #ifndef __TACHO_COPY_INTERNAL_HPP__ #define __TACHO_COPY_INTERNAL_HPP__ - /// \file Tacho_Copy_Internal.hpp /// \brief Copy /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template<> - struct Copy { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - static_assert(std::is_same::value, "A and B does not have the value_type."); +template <> struct Copy { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + static_assert(std::is_same::value, "A and B does not have the value_type."); - /// this should be for contiguous array - //const ordinal_type sA = A.span(), sB = B.span(); - if (A.extent(0) == B.extent(0) && A.extent(0) == B.extent(0)) { - if (A.span() > 0) { + /// this should be for contiguous array + const ordinal_type sA = A.span(), mA = A.extent(0), nA = A.extent(1), as0 = A.stride(0), as1 = A.stride(1); + const ordinal_type /*sB = B.span(), */ mB = B.extent(0), nB = B.extent(1), bs0 = B.stride(0), bs1 = B.stride(1); + if (mA == mB && nA == nB) { + if (sA == (mA * nA) && as0 == bs0 && as1 == bs1) { + /// contiguous array + value_type *ptrA(A.data()); + const value_type *ptrB(B.data()); #if defined(__CUDA_ARCH__) - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, A.extent(1)), - [&](const ordinal_type &j) { - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, A.extent(0)), - [&](const ordinal_type &i) { - A(i,j) = B(i,j); - }); - }); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, sA), + [ptrA, ptrB](const ordinal_type &ij) { ptrA[ij] = ptrB[ij]; }); #else - if (A.span() == (A.extent(0)*A.extent(1)) && - B.span() == (B.extent(0)*B.extent(1))) - memcpy ((void *)A.data(), (const void *)B.data(), A.span()*sizeof(value_type)); - else - for (ordinal_type j=0,jend=A.extent(1);j A and B dimensions are not same\n"); +#if defined(__CUDA_ARCH__) + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, mA * nA), [A, B, mA](const ordinal_type &ij) { + const ordinal_type i = ij % mA, j = ij / mA; + A(i, j) = B(i, j); + }); +#else + for (ordinal_type j = 0; j < nA; ++j) + for (ordinal_type i = 0; i < mA; ++i) + A(i, j) = B(i, j); +#endif + } + } else { + printf("Error: Copy A and B dimensions are not same\n"); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const UploType uploB, + const DiagType diagB, const ViewTypeB &B) { + using value_type = typename ViewTypeA::non_const_value_type; + using value_type_b = typename ViewTypeB::non_const_value_type; + static_assert(std::is_same::value, "A and B does not have the value_type."); + + /// this should be for contiguous array + // const ordinal_type sA = A.span(), sB = B.span(); + if (A.extent(0) == B.extent(0) && A.extent(0) == B.extent(0) && A.span() > 0) { + if (uploB.param == 'U') { +#if defined(__CUDA_ARCH__) + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, A.extent(1)), [&](const ordinal_type &j) { + const ordinal_type tmp = diagB.param == 'U' ? j : j + 1; + const ordinal_type iend = tmp < A.extent(0) ? tmp : A.extent(0); + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, iend), + [&](const ordinal_type &i) { A(i, j) = B(i, j); }); + }); +#else + for (ordinal_type j = 0, jend = A.extent(1); j < jend; ++j) { + const ordinal_type tmp = diagB.param == 'U' ? j : j + 1; + const ordinal_type iend = tmp < A.extent(0) ? tmp : A.extent(0); + for (ordinal_type i = 0; i < iend; ++i) + A(i, j) = B(i, j); + } +#endif + } else if (uploB.param == 'L') { +#if defined(__CUDA_ARCH__) + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, A.extent(1)), [&](const ordinal_type &j) { + const ordinal_type ibeg = diagB.param == 'U' ? j + 1 : j; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, ibeg, A.extent(0)), + [&](const ordinal_type &i) { A(i, j) = B(i, j); }); + }); +#else + for (ordinal_type j = 0, jend = A.extent(1); j < jend; ++j) { + const ordinal_type ibeg = diagB.param == 'U' ? j + 1 : j; + for (ordinal_type i = ibeg, iend = A.extent(0); i < iend; ++i) + A(i, j) = B(i, j); + } +#endif } - return 0; + } else { + printf("Error: Copy A and B dimensions are not same\n"); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_OnDevice.hpp index b37cc4ec9b3a..f3b26707f410 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Copy_OnDevice.hpp @@ -1,29 +1,21 @@ #ifndef __TACHO_COPY_ON_DEVICE_HPP__ #define __TACHO_COPY_ON_DEVICE_HPP__ - /// \file Tacho_COPY_OnDevice.hpp /// \brief COPY /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template<> - struct Copy { - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeB &B) { - const auto exec_instance = member; - Kokkos::deep_copy(exec_instance, A, B); +template <> struct Copy { + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeB &B) { + const auto &exec_instance = member; + Kokkos::deep_copy(exec_instance, A, B); - return 0; - } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseFlopCount.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseFlopCount.hpp index df6feef89d5a..e642afab5c07 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseFlopCount.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseFlopCount.hpp @@ -6,77 +6,76 @@ /// \file Tacho_DenseFlopCount.hpp /// \author Kyungjoo Kim (kyukim@sandia.gov) -// FLOP counting - From LAPACK working note #41 +// FLOP counting - From LAPACK working note #41 namespace Tacho { -#define FLOP_MUL(isComplex) ((isComplex) ? (6.0) : (1.0)) -#define FLOP_ADD(isComplex) ((isComplex) ? (2.0) : (1.0)) +#define FLOP_MUL(isComplex) ((isComplex) ? (6.0) : (1.0)) +#define FLOP_ADD(isComplex) ((isComplex) ? (2.0) : (1.0)) - template - class DenseFlopCount { - public: - static KOKKOS_INLINE_FUNCTION - double Gemm(int mm, int nn, int kk) { - double m = (double)mm; double n = (double)nn; double k = (double)kk; - return (FLOP_MUL(ArithTraits::is_complex)*(m*n*k) + - FLOP_ADD(ArithTraits::is_complex)*(m*n*k)); - } +template class DenseFlopCount { +public: + static KOKKOS_INLINE_FUNCTION double Gemm(int mm, int nn, int kk) { + double m = (double)mm; + double n = (double)nn; + double k = (double)kk; + return (FLOP_MUL(ArithTraits::is_complex) * (m * n * k) + + FLOP_ADD(ArithTraits::is_complex) * (m * n * k)); + } - static KOKKOS_INLINE_FUNCTION - double Syrk(int kk, int nn) { - double k = (double)kk; double n = (double)nn; - return (FLOP_MUL(ArithTraits::is_complex)*(0.5*k*n*(n+1.0)) + - FLOP_ADD(ArithTraits::is_complex)*(0.5*k*n*(n+1.0))); - } + static KOKKOS_INLINE_FUNCTION double Syrk(int kk, int nn) { + double k = (double)kk; + double n = (double)nn; + return (FLOP_MUL(ArithTraits::is_complex) * (0.5 * k * n * (n + 1.0)) + + FLOP_ADD(ArithTraits::is_complex) * (0.5 * k * n * (n + 1.0))); + } - static KOKKOS_INLINE_FUNCTION - double TrsmLower(int mm, int nn) { - double m = (double)mm; double n = (double)nn; - return (FLOP_MUL(ArithTraits::is_complex)*(0.5*n*m*(m+1.0)) + - FLOP_ADD(ArithTraits::is_complex)*(0.5*n*m*(m-1.0))); - } - - static KOKKOS_INLINE_FUNCTION - double TrsmUpper(int mm, int nn) { - double m = (double)mm; double n = (double)nn; - return (FLOP_MUL(ArithTraits::is_complex)*(0.5*m*n*(n+1.0)) + - FLOP_ADD(ArithTraits::is_complex)*(0.5*m*n*(n-1.0))); - } + static KOKKOS_INLINE_FUNCTION double TrsmLower(int mm, int nn) { + double m = (double)mm; + double n = (double)nn; + return (FLOP_MUL(ArithTraits::is_complex) * (0.5 * n * m * (m + 1.0)) + + FLOP_ADD(ArithTraits::is_complex) * (0.5 * n * m * (m - 1.0))); + } - static KOKKOS_INLINE_FUNCTION - double Trsm(int is_lower, int mm, int nn) { - return (is_lower ? - TrsmLower(mm, nn) : - TrsmUpper(mm, nn)); - } - - static KOKKOS_INLINE_FUNCTION - double LU(int mm, int nn) { - double m = (double)mm; double n = (double)nn; - if (m > n) - return (FLOP_MUL(ArithTraits::is_complex)*(0.5*m*n*n-(1.0/6.0)*n*n*n+0.5*m*n-0.5*n*n+(2.0/3.0)*n) + - FLOP_ADD(ArithTraits::is_complex)*(0.5*m*n*n-(1.0/6.0)*n*n*n-0.5*m*n+ (1.0/6.0)*n)); - else - return (FLOP_MUL(ArithTraits::is_complex)*(0.5*n*m*m-(1.0/6.0)*m*m*m+0.5*n*m-0.5*m*m+(2.0/3.0)*m) + - FLOP_ADD(ArithTraits::is_complex)*(0.5*n*m*m-(1.0/6.0)*m*m*m-0.5*n*m+ (1.0/6.0)*m)); - } + static KOKKOS_INLINE_FUNCTION double TrsmUpper(int mm, int nn) { + double m = (double)mm; + double n = (double)nn; + return (FLOP_MUL(ArithTraits::is_complex) * (0.5 * m * n * (n + 1.0)) + + FLOP_ADD(ArithTraits::is_complex) * (0.5 * m * n * (n - 1.0))); + } - static KOKKOS_INLINE_FUNCTION - double Chol(int nn) { - double n = (double)nn; - return (FLOP_MUL(ArithTraits::is_complex)*((1.0/6.0)*n*n*n+0.5*n*n+(1.0/3.0)*n) + - FLOP_ADD(ArithTraits::is_complex)*((1.0/6.0)*n*n*n- (1.0/6.0)*n)); - } + static KOKKOS_INLINE_FUNCTION double Trsm(int is_lower, int mm, int nn) { + return (is_lower ? TrsmLower(mm, nn) : TrsmUpper(mm, nn)); + } - static KOKKOS_INLINE_FUNCTION - double LDL(int nn) { - double n = (double)nn; - return (FLOP_MUL(ArithTraits::is_complex)*((1.0/3.0)*n*n*n + (2.0/3.0)*n) + - FLOP_ADD(ArithTraits::is_complex)*((1.0/3.0)*n*n*n - (1.0/3.0)*n)); - } - }; + static KOKKOS_INLINE_FUNCTION double LU(int mm, int nn) { + double m = (double)mm; + double n = (double)nn; + if (m > n) + return (FLOP_MUL(ArithTraits::is_complex) * + (0.5 * m * n * n - (1.0 / 6.0) * n * n * n + 0.5 * m * n - 0.5 * n * n + (2.0 / 3.0) * n) + + FLOP_ADD(ArithTraits::is_complex) * + (0.5 * m * n * n - (1.0 / 6.0) * n * n * n - 0.5 * m * n + (1.0 / 6.0) * n)); + else + return (FLOP_MUL(ArithTraits::is_complex) * + (0.5 * n * m * m - (1.0 / 6.0) * m * m * m + 0.5 * n * m - 0.5 * m * m + (2.0 / 3.0) * m) + + FLOP_ADD(ArithTraits::is_complex) * + (0.5 * n * m * m - (1.0 / 6.0) * m * m * m - 0.5 * n * m + (1.0 / 6.0) * m)); + } -} + static KOKKOS_INLINE_FUNCTION double Chol(int nn) { + double n = (double)nn; + return (FLOP_MUL(ArithTraits::is_complex) * ((1.0 / 6.0) * n * n * n + 0.5 * n * n + (1.0 / 3.0) * n) + + FLOP_ADD(ArithTraits::is_complex) * ((1.0 / 6.0) * n * n * n - (1.0 / 6.0) * n)); + } + + static KOKKOS_INLINE_FUNCTION double LDL(int nn) { + double n = (double)nn; + return (FLOP_MUL(ArithTraits::is_complex) * ((1.0 / 3.0) * n * n * n + (2.0 / 3.0) * n) + + FLOP_ADD(ArithTraits::is_complex) * ((1.0 / 3.0) * n * n * n - (1.0 / 3.0) * n)); + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseMatrixView.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseMatrixView.hpp deleted file mode 100644 index 6ae759a448a6..000000000000 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_DenseMatrixView.hpp +++ /dev/null @@ -1,410 +0,0 @@ -#ifndef __TACHO_DENSE_MATRIX_VIEW_HPP__ -#define __TACHO_DENSE_MATRIX_VIEW_HPP__ - -#include "Tacho_Util.hpp" - -/// \file Tacho_DenseMatrixView.hpp -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -namespace Tacho { - - template - struct DenseMatrixView { - public: - enum : ordinal_type { rank = 2 }; - - typedef ValueType value_type; - typedef value_type non_const_value_type; - - typedef SchedulerType scheduler_type; - typedef typename UseThisDevice::type device_type; - - typedef Kokkos::BasicFuture future_type; - - private: - ordinal_type _offm, _offn, _m, _n, _rs, _cs; - value_type *_buf; - future_type _future; - - public: - KOKKOS_INLINE_FUNCTION - DenseMatrixView() - : _offm(0), _offn(0), - _m(0), _n(0), - _rs(0), _cs(0), - _buf(NULL), _future() {} - - KOKKOS_INLINE_FUNCTION - DenseMatrixView( value_type *buf, - const ordinal_type m, - const ordinal_type n) - : _offm(0), _offn(0), - _m(m), _n(n), - _rs(1), _cs(m), - _buf(buf), _future() {} - - KOKKOS_INLINE_FUNCTION - DenseMatrixView(const DenseMatrixView &b) - : _offm(b._offm), _offn(b._offn), - _m(b._m), _n(b._n), - _rs(b._rs), _cs(b._cs), - _buf(b._buf), _future() {} - - KOKKOS_INLINE_FUNCTION - value_type& operator[](const ordinal_type k) const { - return _buf[k]; - } - - KOKKOS_INLINE_FUNCTION - value_type& operator()(const ordinal_type i, - const ordinal_type j) const { - return _buf[(i+_offm)*_rs + (j+_offn)*_cs]; - } - - KOKKOS_INLINE_FUNCTION - void set_view(const DenseMatrixView &base, - const ordinal_type offm, const ordinal_type m, - const ordinal_type offn, const ordinal_type n) { - _rs = base._rs; _cs = base._cs; _buf = base._buf; - - _offm = offm; _m = m; - _offn = offn; _n = n; - } - - KOKKOS_INLINE_FUNCTION - void set_view(const ordinal_type offm, const ordinal_type m, - const ordinal_type offn, const ordinal_type n) { - _offm = offm; _m = m; - _offn = offn; _n = n; - } - - KOKKOS_INLINE_FUNCTION - void set_view(const ordinal_type m, - const ordinal_type n) { - _offm = 0; _m = m; - _offn = 0; _n = n; - } - - KOKKOS_INLINE_FUNCTION - void attach_buffer(const ordinal_type rs, - const ordinal_type cs, - const value_type *buf) { - _rs = rs; _cs = cs; _buf = const_cast(buf); - } - - KOKKOS_INLINE_FUNCTION - void set_future(const future_type &f) { _future = f; } - - KOKKOS_INLINE_FUNCTION - void set_future() { _future.clear(); } - - /// get methods - - KOKKOS_INLINE_FUNCTION - ordinal_type offset_0() const { return _offm; } - - KOKKOS_INLINE_FUNCTION - ordinal_type offset_1() const { return _offn; } - - KOKKOS_INLINE_FUNCTION - ordinal_type extent(const ordinal_type r) const { return (r == 0) ? _m : _n; } - - KOKKOS_INLINE_FUNCTION - ordinal_type stride_0() const { return _rs; } - - KOKKOS_INLINE_FUNCTION - ordinal_type stride_1() const { return _cs; } - - KOKKOS_INLINE_FUNCTION - value_type* data() const { return _buf+_offm*_rs+_offn*_cs; } - - KOKKOS_INLINE_FUNCTION - future_type future() const { return _future; } - }; - - template - KOKKOS_INLINE_FUNCTION - void - clearFutureOfBlocks(const MatrixOfBlocksViewType &H) { - const ordinal_type m = H.extent(0); - const ordinal_type n = H.extent(1); - for (ordinal_type j=0;j - KOKKOS_INLINE_FUNCTION - void - clearFutureOfBlocks(/* */ MemberType &member, - const MatrixOfBlocksViewType &H) { - const ordinal_type m = H.extent(0); - const ordinal_type n = H.extent(1); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - H(i,j).set_future(); - }); - }); - } - - template - KOKKOS_INLINE_FUNCTION - void - setMatrixOfBlocks(const MatrixOfBlocksViewType &H, - const ordinal_type m, - const ordinal_type n, - const ordinal_type mb, - const ordinal_type nb) { - const ordinal_type bm = H.extent(0); - const ordinal_type bn = H.extent(1); - - for (ordinal_type j=0;j n ? n : jtmp, - jdiff = (jend > jbeg)*(jend - jbeg); - - for (ordinal_type i=0;i m ? m : itmp, - idiff = (iend > ibeg)*(iend - ibeg); - - H(i,j).set_view(ibeg, idiff, - jbeg, jdiff); - } - } - } - - template - KOKKOS_INLINE_FUNCTION - void - setMatrixOfBlocks(const MatrixOfBlocksViewType &H, - const ordinal_type m, - const ordinal_type n, - const ordinal_type mb) { - setMatrixOfBlocks(H, m, n, mb, mb); - } - - - template - KOKKOS_INLINE_FUNCTION - void - setMatrixOfBlocks(/* */ MemberType &member, - const MatrixOfBlocksViewType &H, - const ordinal_type m, - const ordinal_type n, - const ordinal_type mb, - const ordinal_type nb) { - const ordinal_type bm = H.extent(0); - const ordinal_type bn = H.extent(1); - - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,bn),[&](const int &j) { - const ordinal_type - jbeg = j*nb, jtmp = jbeg + nb, - jend = jtmp > n ? n : jtmp, - jdiff = (jend > jbeg)*(jend - jbeg); - - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,bm),[&](const int &i) { - const ordinal_type - ibeg = i*mb, itmp = ibeg + mb, - iend = itmp > m ? m : itmp, - idiff = (iend > ibeg)*(iend - ibeg); - - H(i,j).set_view(ibeg, idiff, - jbeg, jdiff); - }); - }); - } - - template - KOKKOS_INLINE_FUNCTION - void - setMatrixOfBlocks(/* */ MemberType &member, - const MatrixOfBlocksViewType &H, - const ordinal_type m, - const ordinal_type n, - const ordinal_type mb) { - setMatrixOfBlocks(member, H, m, n, mb, mb); - } - - template - KOKKOS_INLINE_FUNCTION - void - attachBaseBuffer(const MatrixOfBlocksViewType &H, - const BaseBufferPtrType ptr, - const ordinal_type rs, - const ordinal_type cs) { - const ordinal_type m = H.extent(0), n = H.extent(1); - for (ordinal_type j=0;j - KOKKOS_INLINE_FUNCTION - void - attachBaseBuffer(/* */ MemberType &member, - const MatrixOfBlocksViewType &H, - const BaseBufferPtrType ptr, - const ordinal_type rs, - const ordinal_type cs) { - const ordinal_type m = H.extent(0); - const ordinal_type n = H.extent(1); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,n),[&](const int &j) { - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,m),[&](const int &i) { - H(i,j).attach_buffer(rs, cs, ptr); - }); - }); - } - - template - KOKKOS_INLINE_FUNCTION - void - allocateStorageByBlocks(const MatrixOfBlocksViewType &H, - const MemoryPoolType &pool) { - typedef typename MatrixOfBlocksViewType::value_type dense_block_type; - typedef typename dense_block_type::value_type value_type; - - const ordinal_type m = H.extent(0); - const ordinal_type n = H.extent(1); - for (ordinal_type j=0;j 0 && nn > 0) { - auto ptr = (value_type*)pool.allocate(mm*nn*sizeof(value_type)); - TACHO_TEST_FOR_ABORT(ptr == NULL, "memory pool allocation fails"); - - H(i,j).set_view(mm, nn); // whatever offsets are defined here, they are gone. - H(i,j).attach_buffer(1, mm, ptr); - } - } - } - - template - KOKKOS_INLINE_FUNCTION - void - deallocateStorageByBlocks(const MatrixOfBlocksViewType &H, - const MemoryPoolType &pool) { - typedef typename MatrixOfBlocksViewType::value_type dense_block_type; - typedef typename dense_block_type::value_type value_type; - - const ordinal_type m = H.extent(0), n = H.extent(1); - for (ordinal_type j=0;j 0 && nn > 0) - pool.deallocate(blk.data(), mm*nn*sizeof(value_type)); - } - } - - template - KOKKOS_INLINE_FUNCTION - void - copyElementwise(const DenseMatrixView &F, - const DenseMatrixView,ExecSpace> &H) { - const ordinal_type - hm = H.extent(0), hn = H.extent(0), - fm = F.extent(0), fn = F.extent(0); - - if (hm > 0 && hn > 0) { - ordinal_type offj = 0; - for (ordinal_type j=0;j - KOKKOS_INLINE_FUNCTION - void - copyElementwise(const DenseMatrixView,ExecSpace> &H, - const DenseMatrixView &F) { - const ordinal_type - hm = H.extent(0), hn = H.extent(0), - fm = F.extent(0), fn = F.extent(0); - - if (hm > 0 && hn > 0) { - ordinal_type offj = 0; - for (ordinal_type j=0;j - inline - void - applyRowPermutationToDenseMatrix(const DenseMatrixViewType &A, - const DenseMatrixViewType &B, - const OrdinalTypeArray &p) { - const ordinal_type m = A.extent(0), n = A.extent(1); - typedef typename DenseMatrixViewType::device_type::execution_space exec_space; - - if (true) { //std::is_same::value) { - // serial copy on host - Kokkos::RangePolicy > policy(0, m); - Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const ordinal_type &i) { - for (ordinal_type j=0;j > policy(m, 1); - // Kokkos::parallel_for - // (policy, KOKKOS_LAMBDA (const typename Kokkos::TeamPolicy::member_type &member) { - // const ordinal_type i = member.league_rank(); - // Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,n),[&](const int &j) { - // Kokkos::single(Kokkos::PerThread(member), [&]() { - // A(i, j) = B(p(i), j); - // }); - // }); - // }); - } - - } - -} - -#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Driver_Impl.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Driver_Impl.hpp index 80a5ed91812c..ddcaafb1df16 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Driver_Impl.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Driver_Impl.hpp @@ -5,242 +5,196 @@ /// \brief temporary solver interface for refactoring /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Tacho_Internal.hpp" #include "Tacho_Driver.hpp" +#include "Tacho_Internal.hpp" namespace Tacho { - - template - Driver - ::Driver() - : _mode(1), _order_connected_graph_separately(0), - _m(0), _nnz(0), - _ap(), _h_ap(), _aj(), _h_aj(), - _perm(), _h_perm(), _peri(), _h_peri(), - _m_graph(0), _nnz_graph(0), - _h_ap_graph(), _h_aj_graph(), - _h_perm_graph(), _h_peri_graph(), - _nsupernodes(0), - _N(nullptr), - _verbose(0), - _small_problem_thres(1024), - _serial_thres_size(-1), - _mb(-1), - _nb(-1), - _front_update_mode(-1), - _levelset(0), - _device_level_cut(0), - _device_factor_thres(128), - _device_solve_thres(128), - _variant(2), - _nstreams(16), - _max_num_superblocks(-1) {} - /// - /// duplicate the object - /// - template - Driver - Driver - ::duplicate() { - /// input matrix should be given (m and nnz) and analysis is done (nsupernodes is greater than zero) - const bool is_analysis_done = (_m > 0) && (_nnz > 0) && (_nsupernodes > 0); - TACHO_TEST_FOR_EXCEPTION(!is_analysis_done, std::logic_error, "Analysis is not done yet"); - - /// copy constructor of this - Driver r_val(*this); - - /// make sure numeric tool is null pointer - r_val._N = nullptr; - - return r_val; - } +template +Driver::Driver() + : _method(1), _order_connected_graph_separately(0), _m(0), _nnz(0), _ap(), _h_ap(), _aj(), _h_aj(), _perm(), + _h_perm(), _peri(), _h_peri(), _m_graph(0), _nnz_graph(0), _h_ap_graph(), _h_aj_graph(), _h_perm_graph(), + _h_peri_graph(), _nsupernodes(0), _N(nullptr), _verbose(0), _small_problem_thres(1024), _serial_thres_size(-1), + _mb(-1), _nb(-1), _front_update_mode(-1), _levelset(0), _device_level_cut(0), _device_factor_thres(128), + _device_solve_thres(128), _variant(2), _nstreams(16), _max_num_superblocks(-1) {} + +/// +/// duplicate the object +/// +template Driver Driver::duplicate() { + /// input matrix should be given (m and nnz) and analysis is done (nsupernodes is greater than zero) + const bool is_analysis_done = (_m > 0) && (_nnz > 0) && (_nsupernodes > 0); + TACHO_TEST_FOR_EXCEPTION(!is_analysis_done, std::logic_error, "Analysis is not done yet"); + + /// copy constructor of this + Driver r_val(*this); + + /// make sure numeric tool is null pointer + r_val._N = nullptr; + + return r_val; +} - /// - /// common options - /// - template - void - Driver - ::setVerbose(const ordinal_type verbose) { - _verbose = verbose; +/// +/// common options +/// +template void Driver::setVerbose(const ordinal_type verbose) { _verbose = verbose; } + +template +void Driver::setSmallProblemThresholdsize(const ordinal_type small_problem_thres) { + _small_problem_thres = small_problem_thres; +} + +template +void Driver::setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric + const bool is_positive_definite) { + switch (symmetric) { + case 0: { + _method = LU; + break; } - - template - void - Driver - ::setSmallProblemThresholdsize(const ordinal_type small_problem_thres) { - _small_problem_thres = small_problem_thres; + case 1: { + _method = SymLU; + break; } - - template - void - Driver - ::setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric - const bool is_positive_definite) { - switch (symmetric) { - case 0: { _mode = LU; break; } - case 1: { _mode = SymLU; break; } - case 2: { - if (is_positive_definite) { - if (std::is_same::value || - std::is_same::value || - std::is_same >::value || - std::is_same >::value) { - // real symmetric posdef - _mode = Cholesky; - } - } else { // real or complex symmetric indef - _mode = LDL; + case 2: { + if (is_positive_definite) { + if (std::is_same::value || std::is_same::value || + std::is_same>::value || + std::is_same>::value) { + // real symmetric posdef + _method = Cholesky; } - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "symmetric argument is wrong"); - } + } else { // real or complex symmetric indef + _method = LDL; } + break; } - - template - void - Driver - ::setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately) { - _order_connected_graph_separately = order_connected_graph_separately; + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "symmetric argument is wrong"); } - - /// - /// tasking options - /// - template - void - Driver - ::setSerialThresholdsize(const ordinal_type serial_thres_size) { - _serial_thres_size = serial_thres_size; } +} - template - void - Driver - ::setBlocksize(const ordinal_type mb) { - _mb = mb; +template +void Driver::setSolutionMethod(const int method) { // 1 - Chol, 2 - LDL, 3 - LU + { + std::stringstream ss; + ss << "Error: the given method (" << _method << ") is not supported, 1 - Chol, 2 - LDL, 3 - SymLU"; + TACHO_TEST_FOR_EXCEPTION(method != Cholesky && method != LDL && method != SymLU, std::logic_error, + ss.str().c_str()); } + _method = method; +} - template - void - Driver - ::setPanelsize(const ordinal_type nb) { - _nb = nb; - } +template +void Driver::setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately) { + _order_connected_graph_separately = order_connected_graph_separately; +} - template - void - Driver - ::setFrontUpdateMode(const ordinal_type front_update_mode) { - _front_update_mode = front_update_mode; - } +/// +/// tasking options +/// +template void Driver::setSerialThresholdsize(const ordinal_type serial_thres_size) { + _serial_thres_size = serial_thres_size; +} - template - void - Driver - ::setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks) { - _max_num_superblocks = max_num_superblocks; - } +template void Driver::setBlocksize(const ordinal_type mb) { _mb = mb; } - /// - /// Level set tools options - /// - template - void - Driver - ::setLevelSetScheduling(const bool levelset) { - _levelset = levelset; - } +template void Driver::setPanelsize(const ordinal_type nb) { _nb = nb; } - template - void - Driver - ::setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut) { - _device_level_cut = device_level_cut; - } +template void Driver::setFrontUpdateMode(const ordinal_type front_update_mode) { + _front_update_mode = front_update_mode; +} - template - void - Driver - ::setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, - const ordinal_type device_solve_thres) { - _device_factor_thres = device_factor_thres; - _device_solve_thres = device_solve_thres; - } +template +void Driver::setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks) { + _max_num_superblocks = max_num_superblocks; +} - template - void - Driver - ::setLevelSetOptionAlgorithmVariant(const ordinal_type variant) { - if (variant > 2 || variant < 0) { - std::logic_error("levelset algorithm variants range from 0 to 2"); - } - _variant = variant; - } - - template - void - Driver - ::setLevelSetOptionNumStreams(const ordinal_type nstreams) { - _nstreams = nstreams; - } +/// +/// Level set tools options +/// +template void Driver::setLevelSetScheduling(const bool levelset) { + _levelset = levelset; +} - /// - /// get interface - /// - template - ordinal_type - Driver - ::getNumSupernodes() const { - return _nsupernodes; - } - - template - typename Driver::ordinal_type_array - Driver - ::getSupernodes() const { - return _supernodes; - } - - template - typename Driver::ordinal_type_array - Driver - ::getPermutationVector() const { - return _perm; - } - - template - typename Driver::ordinal_type_array - Driver - ::getInversePermutationVector() const { - return _peri; +template +void Driver::setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut) { + _device_level_cut = device_level_cut; +} + +template +void Driver::setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, + const ordinal_type device_solve_thres) { + _device_factor_thres = device_factor_thres; + _device_solve_thres = device_solve_thres; +} + +template void Driver::setLevelSetOptionAlgorithmVariant(const ordinal_type variant) { + if (variant > 2 || variant < 0) { + std::logic_error("levelset algorithm variants range from 0 to 2"); } + _variant = variant; +} + +template void Driver::setLevelSetOptionNumStreams(const ordinal_type nstreams) { + _nstreams = nstreams; +} - // internal only - template - int - Driver - ::analyze() { - int r_val(0); - if (_m < _small_problem_thres) { - /// do nothing - if (_verbose) { - printf("TachoSolver: Analyze\n"); - printf("====================\n"); - printf(" Linear system A\n"); - printf(" number of equations: %10d\n", _m); - printf("\n"); - printf(" A is a small problem ( < %d ) and LAPACK is used\n", _small_problem_thres); - printf("\n"); +/// +/// get interface +/// +template ordinal_type Driver::getNumSupernodes() const { return _nsupernodes; } + +template typename Driver::ordinal_type_array Driver::getSupernodes() const { + return _supernodes; +} + +template +typename Driver::ordinal_type_array Driver::getPermutationVector() const { + return _perm; +} + +template +typename Driver::ordinal_type_array Driver::getInversePermutationVector() const { + return _peri; +} + +// internal only +template int Driver::analyze() { + int r_val(0); + if (_m < _small_problem_thres) { + /// do nothing + if (_verbose) { + printf("TachoSolver: Analyze\n"); + printf("====================\n"); + printf(" Linear system A\n"); + printf(" number of equations: %10d\n", _m); + printf("\n"); + printf(" A is a small problem ( < %d ) and LAPACK is used\n", _small_problem_thres); + printf("\n"); + } + } else { + const bool use_condensed_graph = (_m_graph > 0 && _m_graph < _m); + if (use_condensed_graph) { + Graph graph(_m_graph, _nnz_graph, _h_ap_graph, _h_aj_graph); + graph_tools_type G(graph); +#if defined(TACHO_HAVE_METIS) + if (_order_connected_graph_separately) { + G.setOption(METIS_OPTION_CCORDER, 1); } +#endif + G.reorder(_verbose); + + _h_perm_graph = G.PermVector(); + _h_peri_graph = G.InvPermVector(); + + r_val = analyze_condensed_graph(); } else { - const bool use_condensed_graph = (_m_graph > 0 && _m_graph < _m); - if (use_condensed_graph) { - Graph graph(_m_graph, _nnz_graph, _h_ap_graph, _h_aj_graph); + const bool use_graph_partitioner = (_h_perm.extent(0) == 0 && _h_peri.extent(0) == 0); + if (use_graph_partitioner) { + Graph graph(_m, _nnz, _h_ap, _h_aj); graph_tools_type G(graph); #if defined(TACHO_HAVE_METIS) if (_order_connected_graph_separately) { @@ -248,533 +202,436 @@ namespace Tacho { } #endif G.reorder(_verbose); - - _h_perm_graph = G.PermVector(); - _h_peri_graph = G.InvPermVector(); - - r_val = analyze_condensed_graph(); + + _h_perm = G.PermVector(); + _h_peri = G.InvPermVector(); + + r_val = analyze_linear_system(); } else { - const bool use_graph_partitioner = (_h_perm.extent(0) == 0 && _h_peri.extent(0) == 0); - if (use_graph_partitioner) { - Graph graph(_m, _nnz, _h_ap, _h_aj); - graph_tools_type G(graph); -#if defined(TACHO_HAVE_METIS) - if (_order_connected_graph_separately) { - G.setOption(METIS_OPTION_CCORDER, 1); - } -#endif - G.reorder(_verbose); - - _h_perm = G.PermVector(); - _h_peri = G.InvPermVector(); - - r_val = analyze_linear_system(); - } else { - r_val = analyze_linear_system(); - } + r_val = analyze_linear_system(); } } - return r_val; } + return r_val; +} - template - int - Driver - ::analyze_linear_system() { - if (_verbose) { - printf("TachoSolver: Analyze Linear System\n"); - printf("==================================\n"); - } +template int Driver::analyze_linear_system() { + if (_verbose) { + printf("TachoSolver: Analyze Linear System\n"); + printf("==================================\n"); + } - { - symbolic_tools_type S(_m, _h_ap, _h_aj, _h_perm, _h_peri); - S.symbolicFactorize(_verbose); - - _nsupernodes = S.NumSupernodes(); - _stree_level = S.SupernodesTreeLevel(); - _stree_roots = S.SupernodesTreeRoots(); - - _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); - _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); - _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); - _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); - _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); - _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); - _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); - _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); - _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); - - Kokkos::deep_copy(_supernodes , S.Supernodes()); - Kokkos::deep_copy(_gid_super_panel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(_gid_super_panel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(_sid_super_panel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(_sid_super_panel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(_blk_super_panel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(_stree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(_stree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(_stree_children , S.SupernodesTreeChildren()); - - // perm and peri is updated during symbolic factorization - _perm = Kokkos::create_mirror_view(exec_memory_space(), _h_perm); - _peri = Kokkos::create_mirror_view(exec_memory_space(), _h_peri); - - Kokkos::deep_copy(_perm, _h_perm); - Kokkos::deep_copy(_peri, _h_peri); - } - return 0; + { + symbolic_tools_type S(_m, _h_ap, _h_aj, _h_perm, _h_peri); + S.symbolicFactorize(_verbose); + + _nsupernodes = S.NumSupernodes(); + _stree_level = S.SupernodesTreeLevel(); + _stree_roots = S.SupernodesTreeRoots(); + + _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); + _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); + _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); + _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); + _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); + _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); + _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); + _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); + _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); + + Kokkos::deep_copy(_supernodes, S.Supernodes()); + Kokkos::deep_copy(_gid_super_panel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(_gid_super_panel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(_sid_super_panel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(_sid_super_panel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(_blk_super_panel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(_stree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(_stree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(_stree_children, S.SupernodesTreeChildren()); + + // perm and peri is updated during symbolic factorization + _perm = Kokkos::create_mirror_view(exec_memory_space(), _h_perm); + _peri = Kokkos::create_mirror_view(exec_memory_space(), _h_peri); + + Kokkos::deep_copy(_perm, _h_perm); + Kokkos::deep_copy(_peri, _h_peri); } + return 0; +} - template - int - Driver - ::analyze_condensed_graph() { - if (_verbose) { - printf("TachoSolver: Analyze Condensed Graph and Evaporate the Graph\n"); - printf("============================================================\n"); - } +template int Driver::analyze_condensed_graph() { + if (_verbose) { + printf("TachoSolver: Analyze Condensed Graph and Evaporate the Graph\n"); + printf("============================================================\n"); + } - { - symbolic_tools_type S(_m_graph, _h_ap_graph, _h_aj_graph, _h_perm_graph, _h_peri_graph); - S.symbolicFactorize(_verbose); - S.evaporateSymbolicFactors(_h_aw_graph, _verbose); - - _nsupernodes = S.NumSupernodes(); - _stree_level = S.SupernodesTreeLevel(); - _stree_roots = S.SupernodesTreeRoots(); - - _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); - _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); - _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); - _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); - _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); - _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); - _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); - _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); - _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); - _perm = Kokkos::create_mirror_view(exec_memory_space(), S.PermVector()); - _peri = Kokkos::create_mirror_view(exec_memory_space(), S.InvPermVector()); - - Kokkos::deep_copy(_supernodes , S.Supernodes()); - Kokkos::deep_copy(_gid_super_panel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(_gid_super_panel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(_sid_super_panel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(_sid_super_panel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(_blk_super_panel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(_stree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(_stree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(_stree_children , S.SupernodesTreeChildren()); - Kokkos::deep_copy(_perm , S.PermVector()); - Kokkos::deep_copy(_peri , S.InvPermVector()); - - _h_perm = S.PermVector(); - _h_peri = S.InvPermVector(); - } - return 0; + { + symbolic_tools_type S(_m_graph, _h_ap_graph, _h_aj_graph, _h_perm_graph, _h_peri_graph); + S.symbolicFactorize(_verbose); + S.evaporateSymbolicFactors(_h_aw_graph, _verbose); + + _nsupernodes = S.NumSupernodes(); + _stree_level = S.SupernodesTreeLevel(); + _stree_roots = S.SupernodesTreeRoots(); + + _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); + _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); + _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); + _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); + _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); + _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); + _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); + _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); + _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); + _perm = Kokkos::create_mirror_view(exec_memory_space(), S.PermVector()); + _peri = Kokkos::create_mirror_view(exec_memory_space(), S.InvPermVector()); + + Kokkos::deep_copy(_supernodes, S.Supernodes()); + Kokkos::deep_copy(_gid_super_panel_ptr, S.gidSuperPanelPtr()); + Kokkos::deep_copy(_gid_super_panel_colidx, S.gidSuperPanelColIdx()); + Kokkos::deep_copy(_sid_super_panel_ptr, S.sidSuperPanelPtr()); + Kokkos::deep_copy(_sid_super_panel_colidx, S.sidSuperPanelColIdx()); + Kokkos::deep_copy(_blk_super_panel_colidx, S.blkSuperPanelColIdx()); + Kokkos::deep_copy(_stree_parent, S.SupernodesTreeParent()); + Kokkos::deep_copy(_stree_ptr, S.SupernodesTreePtr()); + Kokkos::deep_copy(_stree_children, S.SupernodesTreeChildren()); + Kokkos::deep_copy(_perm, S.PermVector()); + Kokkos::deep_copy(_peri, S.InvPermVector()); + + _h_perm = S.PermVector(); + _h_peri = S.InvPermVector(); } + return 0; +} - template - int - Driver - ::initialize() { - if (_verbose) { - printf("TachoSolver: Initialize\n"); - printf("=======================\n"); - } +template int Driver::initialize() { + if (_verbose) { + printf("TachoSolver: Initialize\n"); + printf("=======================\n"); + } + /// + /// initialize numeric tools + /// + if (_m < _small_problem_thres) { + /// do nothing + } else { /// - /// initialize numeric tools + /// create numeric tools serial for host space /// - if (_m < _small_problem_thres) { - /// do nothing - } else { - /// - /// create numeric tools serial for host space - /// -#if !defined(__CUDA_ARCH__) - if (std::is_same::value) { - const ordinal_type nthreads = host_space::impl_thread_pool_size(0); - if (nthreads == 1 || true) { - /// single threaded case - if (_N == nullptr) - _N = (numeric_tools_base_type*) ::operator new (sizeof(numeric_tools_serial_type)); - - new (_N) numeric_tools_serial_type(_mode, - _m, _ap, _aj, - _perm, _peri, - _nsupernodes, _supernodes, - _gid_super_panel_ptr, _gid_super_panel_colidx, - _sid_super_panel_ptr, _sid_super_panel_colidx, _blk_super_panel_colidx, - _stree_parent, _stree_ptr, _stree_children, - _stree_level, _stree_roots); - } - //// levelset is not going to be supported on CPUs -#if 0 - else { - /// multi threaded case for test only - if (_levelset || true) { - /// level schedule - -#define TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_name) \ - do { \ - if (_N == nullptr) \ - _N = (numeric_tools_base_type*) ::operator new (sizeof(numeric_tools_levelset_name)); \ - new (_N) numeric_tools_levelset_name(_mode, \ - _m, _ap, _aj, \ - _perm, _peri, \ - _nsupernodes, _supernodes, \ - _gid_super_panel_ptr, _gid_super_panel_colidx, \ - _sid_super_panel_ptr, _sid_super_panel_colidx, _blk_super_panel_colidx, \ - _stree_parent, _stree_ptr, _stree_children, \ - _stree_level, _stree_roots); \ - numeric_tools_levelset_name * N = dynamic_cast(_N); \ - N->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); \ - N->createStream(_nstreams, _verbose); \ - } while (false) - - switch (_variant) { - case 0: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var0_type); break; } - case 1: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var1_type); break; } - case 2: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var2_type); break; } - } -#undef TACHO_CREATE_NUMERICTOOLS_LEVELSET - } else { - /// tasking - } - } -#endif - } -#endif -#if defined(KOKKOS_ENABLE_CUDA) - if (std::is_same::value) { - if (_levelset || true) { - /// level schedule -#define TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_name) \ - do { \ - if (_N == nullptr) \ - _N = (numeric_tools_base_type*) ::operator new (sizeof(numeric_tools_levelset_name)); \ - new (_N) numeric_tools_levelset_name(_mode, \ - _m, _ap, _aj, \ - _perm, _peri, \ - _nsupernodes, _supernodes, \ - _gid_super_panel_ptr, _gid_super_panel_colidx, \ - _sid_super_panel_ptr, _sid_super_panel_colidx, _blk_super_panel_colidx, \ - _stree_parent, _stree_ptr, _stree_children, \ - _stree_level, _stree_roots); \ - numeric_tools_levelset_name * N = dynamic_cast(_N); \ - N->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); \ - N->createStream(_nstreams, _verbose); \ - } while (false) - - switch (_variant) { - case 0: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var0_type); break; } - case 1: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var1_type); break; } - case 2: { TACHO_CREATE_NUMERICTOOLS_LEVELSET(numeric_tools_levelset_var2_type); break; } - } -#undef TACHO_CREATE_NUMERICTOOLS_LEVELSET - } - } -#endif - } - return 0; + NumericToolsFactory factory; + factory.setBaseMember(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_super_panel_ptr, + _gid_super_panel_colidx, _sid_super_panel_ptr, _sid_super_panel_colidx, + _blk_super_panel_colidx, _stree_parent, _stree_ptr, _stree_children, _stree_level, + _stree_roots, _verbose); + + factory.setLevelSetMember(_variant, _device_level_cut, _device_factor_thres, _device_solve_thres, _nstreams); + + factory.createObject(_N); } + return 0; +} - template - int - Driver - ::factorize(const value_type_array &ax) { - if (_verbose) { - switch (_mode) { - case Cholesky: { - printf("TachoSolver: Factorize Cholesky\n"); - printf("===============================\n"); - break; - } - case LDL: { - printf("TachoSolver: Factorize LDL\n"); - printf("==========================\n"); - break; - } - } +template int Driver::factorize(const value_type_array &ax) { + if (_verbose) { + switch (_method) { + case Cholesky: { + printf("TachoSolver: Factorize Cholesky\n"); + printf("===============================\n"); + break; + } + case LDL: { + printf("TachoSolver: Factorize LDL\n"); + printf("==========================\n"); + break; + } + case SymLU: { + printf("TachoSolver: Factorize SymLU\n"); + printf("============================\n"); + break; } - - if (_m < _small_problem_thres) { - factorize_small_host(ax); - } else { - _N->factorize(ax, _verbose); } - return 0; } - template - int - Driver - ::factorize_small_host(const value_type_array &ax) { - double t_copy(0), t_factor(0); - { - Kokkos::Timer timer; - - timer.reset(); - _A = value_type_matrix_host("A", _m, _m); - auto h_ax = Kokkos::create_mirror_view_and_copy(host_memory_space(), ax); - for (ordinal_type i=0;i<_m;++i) { - const size_type jbeg = _h_ap(i), jend = _h_ap(i+1); - for (size_type j=jbeg;j= col)); /// lower - if (flag) - _A(i, col) = h_ax(j); - } - } - t_copy = timer.seconds(); - - timer.reset(); - switch (_mode) { - case Cholesky: { - Tacho::Chol::invoke(_A); - break; - } - case LDL: { - _P = ordinal_type_array_host("P", 4*_m); - _D = value_type_matrix_host("D", _m, 2); - auto W = value_type_array_host("W", 32*_m); - Tacho::LDL::invoke(_A, _P, W); - Tacho::LDL::modify(_A, _P, _D); - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Solution method is not implemented"); - break; - } + if (_m < _small_problem_thres) { + factorize_small_host(ax); + } else { + _N->factorize(ax, _verbose); + } + return 0; +} + +template int Driver::factorize_small_host(const value_type_array &ax) { + double t_copy(0), t_factor(0); + { + Kokkos::Timer timer; + + timer.reset(); + _A = value_type_matrix_host("A", _m, _m); + auto h_ax = Kokkos::create_mirror_view_and_copy(host_memory_space(), ax); + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbeg = _h_ap(i), jend = _h_ap(i + 1); + for (size_type j = jbeg; j < jend; ++j) { + const ordinal_type col = _h_aj(j); + const bool flag = ((_method == Cholesky && i <= col) || /// upper + (_method == LDL && i >= col) || /// lower + (_method == SymLU)); /// full matrix + if (flag) + _A(i, col) = h_ax(j); } - t_factor = timer.seconds(); } + t_copy = timer.seconds(); - if (_verbose) { - printf("Summary: NumericTools (SmallDenseFactorization)\n"); - printf("===============================================\n"); - printf(" Time\n"); - printf(" time for copying A into supernodes: %10.6f s\n", t_copy); - printf(" time for numeric factorization: %10.6f s\n", t_factor); - printf(" total time spent: %10.6f s\n", (t_copy+t_factor)); - printf("\n"); + timer.reset(); + switch (_method) { + case Cholesky: { + Tacho::Chol::invoke(_A); + break; + } + case LDL: { + _P = ordinal_type_array_host("P", 4 * _m); + _D = value_type_matrix_host("D", _m, 2); + auto W = value_type_array_host("W", 32 * _m); + Tacho::LDL::invoke(_A, _P, W); + Tacho::LDL::modify(_A, _P, _D); + break; + } + case SymLU: { + _P = ordinal_type_array_host("P", 4 * _m); + Tacho::LU::invoke(_A, _P); + Tacho::LU::modify(_m, _P); + break; + } + default: { + std::stringstream ss; + ss << "Error: the solution method (" << _method << ") is not supported, 1 - Chol, 2 - LDL, 3 - SymLU"; + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, ss.str().c_str()); + break; + } } + t_factor = timer.seconds(); + } - return 0; + if (_verbose) { + printf("Summary: NumericTools (SmallDenseFactorization)\n"); + printf("===============================================\n"); + printf(" Time\n"); + printf(" time for copying A into supernodes: %10.6f s\n", t_copy); + printf(" time for numeric factorization: %10.6f s\n", t_factor); + printf(" total time spent: %10.6f s\n", (t_copy + t_factor)); + printf("\n"); } - template - int - Driver - ::solve(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - if (_verbose) { - switch (_mode) { - case Cholesky: { - printf("TachoSolver: Solve Cholesky\n"); - printf("===========================\n"); - break; - } - case LDL: { - printf("TachoSolver: Solve LDL\n"); - printf("======================\n"); - break; - } - } + return 0; +} + +template +int Driver::solve(const value_type_matrix &x, const value_type_matrix &b, const value_type_matrix &t) { + if (_verbose) { + switch (_method) { + case Cholesky: { + printf("TachoSolver: Solve Cholesky\n"); + printf("===========================\n"); + break; } - - if (_m < _small_problem_thres) { - solve_small_host(x, b, t); - } else { - TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || - t.extent(1) < x.extent(1), - std::logic_error, - "Temporary rhs vector t is smaller than x"); - auto tt = Kokkos::subview(t, - Kokkos::pair(0, x.extent(0)), - Kokkos::pair(0, x.extent(1))); - _N->solve(x, b, tt, _verbose); + case LDL: { + printf("TachoSolver: Solve LDL\n"); + printf("======================\n"); + break; + } + case SymLU: { + printf("TachoSolver: Solve SymLU\n"); + printf("========================\n"); + break; } - return 0; - } - - template - int - Driver - ::solve_small_host(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - Kokkos::Timer timer; - double t_copy(0), t_solve(0); - { - timer.reset(); - Kokkos::deep_copy(x, b); - t_copy = timer.seconds(); - - timer.reset(); - switch (_mode) { - case Cholesky: { - auto h_x = Kokkos::create_mirror_view_and_copy(host_memory_space(), x); - Trsm - ::invoke(Diag::NonUnit(), 1.0, _A, h_x); - Trsm - ::invoke(Diag::NonUnit(), 1.0, _A, h_x); - Kokkos::deep_copy(x, h_x); - break; - } - case LDL: { - auto perm = ordinal_type_array_host(_P.data()+2*_m, _m); - auto peri = ordinal_type_array_host(_P.data()+3*_m, _m); - auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); - auto h_t = Kokkos::create_mirror_view(host_memory_space(), t); - - ApplyPermutation - ::invoke(h_x, perm, h_t); - Trsm - ::invoke(Diag::Unit(), 1.0, _A, h_t); - Scale2x2_BlockInverseDiagonals - ::invoke(_P, _D, h_t); - Trsm - ::invoke(Diag::Unit(), 1.0, _A, h_t); - ApplyPermutation - ::invoke(h_t, peri, h_x); - Kokkos::deep_copy(x, h_x); - break; - } - } - t_solve = timer.seconds(); } - - if (_verbose) { - printf("Summary: NumericTools (SmallDenseSolve)\n"); - printf("=======================================\n"); - printf(" Time\n"); - printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); - printf(" time for numeric solve: %10.6f s\n", t_solve); - printf(" total time spent: %10.6f s\n", (t_solve+t_copy)); - printf("\n"); - } - return 0; } - template - double - Driver - ::computeRelativeResidual(const value_type_array &ax, - const value_type_matrix &x, - const value_type_matrix &b) { - CrsMatrixBase A; - A.setExternalMatrix(_m, _m, _nnz, _ap, _aj, ax); - - return Tacho::computeRelativeResidual(A, x, b); + if (_m < _small_problem_thres) { + solve_small_host(x, b, t); + } else { + TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || t.extent(1) < x.extent(1), std::logic_error, + "Temporary rhs vector t is smaller than x"); + auto tt = Kokkos::subview(t, Kokkos::pair(0, x.extent(0)), + Kokkos::pair(0, x.extent(1))); + _N->solve(x, b, tt, _verbose); } + return 0; +} - template - int - Driver - ::exportFactorsToCrsMatrix(crs_matrix_type &A) { - if (_m < _small_problem_thres) { - typedef ArithTraits ats; - const typename ats::mag_type zero(0); - - /// count nonzero elements in dense U - const ordinal_type m = _m; - size_type_array_host h_ap("h_ap", m+1); - for (ordinal_type i=0;i zero); - - /// serial scan; this is a small problem - h_ap(0) = 0; - for (ordinal_type i=0;i zero) { - h_aj(k) = j; - h_ax(k) = _A(i,j); - ++k; - } - - crs_matrix_type_host h_A; - h_A.setExternalMatrix(m, m, nnz, h_ap, h_aj, h_ax); - ///h_A.showMe(std::cout, true); - A.clear(); - A.createConfTo(h_A); - A.copy(h_A); - } else { - _N->exportFactorsToCrsMatrix(A, false); +template +int Driver::solve_small_host(const value_type_matrix &x, const value_type_matrix &b, + const value_type_matrix &t) { + Kokkos::Timer timer; + double t_copy(0), t_solve(0); + { + timer.reset(); + Kokkos::deep_copy(x, b); + t_copy = timer.seconds(); + + timer.reset(); + switch (_method) { + case Cholesky: { + auto h_x = Kokkos::create_mirror_view_and_copy(host_memory_space(), x); + Trsm::invoke(Diag::NonUnit(), 1.0, _A, h_x); + Trsm::invoke(Diag::NonUnit(), 1.0, _A, h_x); + Kokkos::deep_copy(x, h_x); + break; } - return 0; - } - - template - int - Driver - ::release() { - if (_verbose) { - printf("TachoSolver: Release\n"); - printf("====================\n"); + case LDL: { + auto perm = ordinal_type_array_host(_P.data() + 2 * _m, _m); + auto peri = ordinal_type_array_host(_P.data() + 3 * _m, _m); + auto h_x = Kokkos::create_mirror_view_and_copy(host_memory_space(), x); + auto h_t = Kokkos::create_mirror_view(host_memory_space(), t); + + ApplyPermutation::invoke(h_x, perm, h_t); + Trsm::invoke(Diag::Unit(), 1.0, _A, h_t); + Scale2x2_BlockInverseDiagonals::invoke(_P, _D, h_t); + Trsm::invoke(Diag::Unit(), 1.0, _A, h_t); + ApplyPermutation::invoke(h_t, peri, h_x); + Kokkos::deep_copy(x, h_x); + break; } - - { - if (_N != nullptr) - _N->release(_verbose); - delete _N; _N = nullptr; + case SymLU: { + auto perm = ordinal_type_array_host(_P.data() + 2 * _m, _m); + auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); + auto h_t = Kokkos::create_mirror_view(host_memory_space(), t); + Kokkos::deep_copy(h_t, x); + ApplyPermutation::invoke(h_t, perm, h_x); + Trsm::invoke(Diag::Unit(), 1.0, _A, h_x); + Trsm::invoke(Diag::NonUnit(), 1.0, _A, h_x); + Kokkos::deep_copy(x, h_x); + break; } - { - _mode = 0; - - _m = 0; - _nnz = 0; - - _ap = size_type_array(); _h_ap = size_type_array_host(); - _aj = ordinal_type_array(); _h_aj = ordinal_type_array_host(); - - _perm = ordinal_type_array(); _h_perm = ordinal_type_array_host(); - _peri = ordinal_type_array(); _h_peri = ordinal_type_array_host(); - - _m_graph = 0; - _nnz_graph = 0; - - _h_ap_graph = size_type_array_host(); - _h_aj_graph = ordinal_type_array_host(); - - _h_perm_graph = ordinal_type_array_host(); - _h_peri_graph = ordinal_type_array_host(); - - _nsupernodes = 0; - _supernodes = ordinal_type_array(); - - _gid_super_panel_ptr = size_type_array(); - _gid_super_panel_colidx = ordinal_type_array(); - - _sid_super_panel_ptr = size_type_array(); - - _sid_super_panel_colidx = ordinal_type_array(); - _blk_super_panel_colidx = ordinal_type_array(); - - _stree_ptr = size_type_array(); - _stree_children = ordinal_type_array(); - - _stree_parent = ordinal_type_array(); - _stree_roots = ordinal_type_array_host(); - - _A = value_type_matrix_host(); - - _verbose = 0; - _small_problem_thres = 1024; } - return 0; + t_solve = timer.seconds(); + } + + if (_verbose) { + printf("Summary: NumericTools (SmallDenseSolve)\n"); + printf("=======================================\n"); + printf(" Time\n"); + printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); + printf(" time for numeric solve: %10.6f s\n", t_solve); + printf(" total time spent: %10.6f s\n", (t_solve + t_copy)); + printf("\n"); + } + return 0; +} + +template +double Driver::computeRelativeResidual(const value_type_array &ax, const value_type_matrix &x, + const value_type_matrix &b) { + CrsMatrixBase A; + A.setExternalMatrix(_m, _m, _nnz, _ap, _aj, ax); + + return Tacho::computeRelativeResidual(A, x, b); +} + +template int Driver::exportFactorsToCrsMatrix(crs_matrix_type &A) { + if (_m < _small_problem_thres) { + typedef ArithTraits ats; + const typename ats::mag_type zero(0); + + /// count nonzero elements in dense U + const ordinal_type m = _m; + size_type_array_host h_ap("h_ap", m + 1); + for (ordinal_type i = 0; i < m; ++i) + for (ordinal_type j = 0; j < m; ++j) + h_ap(i + 1) += (ats::abs(_A(i, j)) > zero); + + /// serial scan; this is a small problem + h_ap(0) = 0; + for (ordinal_type i = 0; i < m; ++i) + h_ap(i + 1) += h_ap(i); + + /// create a host crs matrix + const ordinal_type nnz = h_ap(m); + ordinal_type_array_host h_aj(do_not_initialize_tag("h_aj"), nnz); + value_type_array_host h_ax(do_not_initialize_tag("h_ax"), nnz); + + for (ordinal_type i = 0, k = 0; i < m; ++i) + for (ordinal_type j = i; j < m; ++j) + if (ats::abs(_A(i, j)) > zero) { + h_aj(k) = j; + h_ax(k) = _A(i, j); + ++k; + } + + crs_matrix_type_host h_A; + h_A.setExternalMatrix(m, m, nnz, h_ap, h_aj, h_ax); + /// h_A.showMe(std::cout, true); + A.clear(); + A.createConfTo(h_A); + A.copy(h_A); + } else { + _N->exportFactorsToCrsMatrix(A, false); + } + return 0; +} + +template int Driver::release() { + if (_verbose) { + printf("TachoSolver: Release\n"); + printf("====================\n"); } + { + if (_N != nullptr) + _N->release(_verbose); + delete _N; + _N = nullptr; + } + { + _method = 0; + + _m = 0; + _nnz = 0; + + _ap = size_type_array(); + _h_ap = size_type_array_host(); + _aj = ordinal_type_array(); + _h_aj = ordinal_type_array_host(); + + _perm = ordinal_type_array(); + _h_perm = ordinal_type_array_host(); + _peri = ordinal_type_array(); + _h_peri = ordinal_type_array_host(); + + _m_graph = 0; + _nnz_graph = 0; + + _h_ap_graph = size_type_array_host(); + _h_aj_graph = ordinal_type_array_host(); + + _h_perm_graph = ordinal_type_array_host(); + _h_peri_graph = ordinal_type_array_host(); + + _nsupernodes = 0; + _supernodes = ordinal_type_array(); + + _gid_super_panel_ptr = size_type_array(); + _gid_super_panel_colidx = ordinal_type_array(); + + _sid_super_panel_ptr = size_type_array(); + + _sid_super_panel_colidx = ordinal_type_array(); + _blk_super_panel_colidx = ordinal_type_array(); + + _stree_ptr = size_type_array(); + _stree_children = ordinal_type_array(); + + _stree_parent = ordinal_type_array(); + _stree_roots = ordinal_type_array_host(); + + _A = value_type_matrix_host(); + + _verbose = 0; + _small_problem_thres = 1024; + } + return 0; } +} // namespace Tacho + #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm.hpp index 2828b68c0dc6..601cdf88fc57 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm.hpp @@ -9,66 +9,17 @@ namespace Tacho { - /// - /// Gemm: - /// +/// +/// Gemm: +/// - /// various implementation for different uplo and algo parameters - template - struct Gemm; +/// various implementation for different uplo and algo parameters +template struct Gemm; - /// task construction for the above chol implementation - /// Gemm::invoke(_sched, member, _alpha, _A, _B, _beta, _C); - template - struct TaskFunctor_Gemm { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; +struct GemmAlgorithm { + using type = ActiveAlgorithm::type; +}; - typedef ScalarType scalar_type; - - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; - - private: - scalar_type _alpha, _beta; - dense_block_type _A, _B, _C; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Gemm() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Gemm(const scalar_type alpha, - const dense_block_type &A, - const dense_block_type &B, - const scalar_type beta, - const dense_block_type &C) - : _alpha(alpha), - _beta(beta), - _A(A), - _B(B), - _C(C) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Gemm - ::invoke(member, _alpha, _A, _B, _beta, _C); - - Kokkos::single(Kokkos::PerTeam(member), - [&, ierr]() { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - _C.set_future(); - r_val = ierr; - }); - } - }; - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular.hpp index 06788b492760..f6bbcf7eb533 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular.hpp @@ -9,14 +9,13 @@ namespace Tacho { - /// - /// GemmTriangular: - /// - - /// various implementation for different uplo and algo parameters - template - struct GemmTriangular; +/// +/// GemmTriangular: +/// -} +/// various implementation for different uplo and algo parameters +template struct GemmTriangular; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_External.hpp index 335d13263ec6..715743734f6d 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_TRIANGULAR_EXTERNAL_HPP__ #define __TACHO_GEMM_TRIANGULAR_EXTERNAL_HPP__ - /// \file Tacho_Gemm_External.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,87 +9,58 @@ namespace Tacho { - template<> - struct GemmTriangular { - template - inline - static int - invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); +template <> struct GemmTriangular { + template + inline static int invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = B.extent(0); + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); - if (m > 0 && n > 0 && k > 0) { - if (m == n) { - const ordinal_type b = 32; - value_type * aptr = A.data(), * bptr = B.data(), * cptr = C.data(); - const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); - for (ordinal_type i=0;i m ? m : m2), nn = mm-i; - value_type * aaptr = aptr, * bbptr = bptr+i*bs1, * ccptr = cptr+i*cs1; - Blas::gemm(Trans::Transpose::param, Trans::NoTranspose::param, - mm, nn, k, - value_type(alpha), - aaptr, as1, - bbptr, bs1, - value_type(beta), - ccptr, cs1); - } - } else { - TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); + const ordinal_type m = C.extent(0), n = C.extent(1), k = B.extent(0); + + if (m > 0 && n > 0 && k > 0) { + if (m == n) { + const ordinal_type b = 32; + value_type *aptr = A.data(), *bptr = B.data(), *cptr = C.data(); + const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); + for (ordinal_type i = 0; i < m; i += b) { + const ordinal_type m2 = i + b, mm = (m2 > m ? m : m2), nn = mm - i; + value_type *aaptr = aptr, *bbptr = bptr + i * bs1, *ccptr = cptr + i * cs1; + Blas::gemm(Trans::Transpose::param, Trans::NoTranspose::param, mm, nn, k, value_type(alpha), + aaptr, as1, bbptr, bs1, value_type(beta), ccptr, cs1); } + } else { + TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); } -#else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); -#endif - return 0; } - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(alpha, A, B, beta, C); - //}); #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } + return 0; + } - }; + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(alpha, A, B, beta, C); + //}); +#else + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); +#endif + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_Internal.hpp index bcf4eff5da1b..09560e3ddd13 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_TRIANGULAR_INTERNAL_HPP__ #define __TACHO_GEMM_TRIANGULAR_INTERNAL_HPP__ - /// \file Tacho_Gemm_Internal.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,54 +9,37 @@ namespace Tacho { - template<> - struct GemmTriangular { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type m = C.extent(0); - const ordinal_type n = C.extent(1); - const ordinal_type k = B.extent(0); - - if (m > 0 && n > 0 && k > 0) { - if (m == n) { - BlasTeam::gemm_upper(member, - Trans::Transpose::param, Trans::NoTranspose::param, - m, n, k, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } else { - TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); - } +template <> struct GemmTriangular { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0); + const ordinal_type n = C.extent(1); + const ordinal_type k = B.extent(0); + + if (m > 0 && n > 0 && k > 0) { + if (m == n) { + BlasTeam::gemm_upper(member, Trans::Transpose::param, Trans::NoTranspose::param, m, n, k, + value_type(alpha), A.data(), A.stride_1(), B.data(), B.stride_1(), + value_type(beta), C.data(), C.stride_1()); + } else { + TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); } - return 0; } - }; -} + return 0; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_OnDevice.hpp index f89f6292c48f..60042eff5c92 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GemmTriangular_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_TRIANGULAR_ON_DEVICE_HPP__ #define __TACHO_GEMM_TRIANGULAR_ON_DEVICE_HPP__ - /// \file Tacho_GemmTriangulr_OnDevice.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,172 +9,114 @@ namespace Tacho { - template<> - struct GemmTriangular { - template - inline - static int - blas_invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = B.extent(0); - - if (m > 0 && n > 0 && k > 0) { - if (m == n) { - const ordinal_type b = 32; - value_type * aptr = A.data(), * bptr = B.data(), * cptr = C.data(); - const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); - for (ordinal_type i=0;i m ? m : m2), nn = mm-i; - value_type * aaptr = aptr, * bbptr = bptr+i*bs1, * ccptr = cptr+i*cs1; - Blas::gemm(Trans::Transpose::param, Trans::NoTranspose::param, - mm, nn, k, - value_type(alpha), - aaptr, as1, - bbptr, bs1, - value_type(beta), - ccptr, cs1); - } - } else { - TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); +template <> struct GemmTriangular { + template + inline static int blas_invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0), n = C.extent(1), k = B.extent(0); + + if (m > 0 && n > 0 && k > 0) { + if (m == n) { + const ordinal_type b = 32; + value_type *aptr = A.data(), *bptr = B.data(), *cptr = C.data(); + const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); + for (ordinal_type i = 0; i < m; i += b) { + const ordinal_type m2 = i + b, mm = (m2 > m ? m : m2), nn = mm - i; + value_type *aaptr = aptr, *bbptr = bptr + i * bs1, *ccptr = cptr + i * cs1; + Blas::gemm(Trans::Transpose::param, Trans::NoTranspose::param, mm, nn, k, value_type(alpha), + aaptr, as1, bbptr, bs1, value_type(beta), ccptr, cs1); } + } else { + TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); } - return 0; } + return 0; + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = B.extent(0); - - if (m > 0 && n > 0 && k > 0) { - if (m == n) { - const ordinal_type b = 256; - value_type * aptr = A.data(), * bptr = B.data(), * cptr = C.data(); - const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); - if (m < 2*b) { - Blas::gemm(handle, - Trans::Transpose::cublas_param, - Trans::NoTranspose::cublas_param, - m, n, k, - value_type(alpha), - aptr, as1, - bptr, bs1, - value_type(beta), - cptr, cs1); - } else { - for (ordinal_type i=0;i m ? m : m2), nn = mm-i; - value_type * aaptr = aptr, * bbptr = bptr+i*bs1, * ccptr = cptr+i*cs1; - Blas::gemm(handle, - Trans::Transpose::cublas_param, - Trans::NoTranspose::cublas_param, - mm, nn, k, - value_type(alpha), - aaptr, as1, - bbptr, bs1, - value_type(beta), - ccptr, cs1); - } - } + template + inline static int cublas_invoke(cublasHandle_t &handle, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0), n = C.extent(1), k = B.extent(0); + + if (m > 0 && n > 0 && k > 0) { + if (m == n) { + const ordinal_type b = 256; + value_type *aptr = A.data(), *bptr = B.data(), *cptr = C.data(); + const int as1 = A.stride_1(), bs1 = B.stride_1(), cs1 = C.stride_1(); + if (m < 2 * b) { + Blas::gemm(handle, Trans::Transpose::cublas_param, Trans::NoTranspose::cublas_param, m, n, k, + value_type(alpha), aptr, as1, bptr, bs1, value_type(beta), cptr, cs1); } else { - TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); + for (ordinal_type i = 0; i < m; i += b) { + const ordinal_type m2 = i + b, mm = (m2 > m ? m : m2), nn = mm - i; + value_type *aaptr = aptr, *bbptr = bptr + i * bs1, *ccptr = cptr + i * cs1; + Blas::gemm(handle, Trans::Transpose::cublas_param, Trans::NoTranspose::cublas_param, mm, nn, k, + value_type(alpha), aaptr, as1, bbptr, bs1, value_type(beta), ccptr, cs1); + } } + } else { + TACHO_TEST_FOR_ABORT(true, "C is not a square matrix"); } - return 0; } + return 0; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeB::memory_space memory_space_b; - typedef typename ViewTypeC::memory_space memory_space_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same memory space."); - - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(alpha, A, B, beta, C); + + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeB::memory_space memory_space_b; + typedef typename ViewTypeC::memory_space memory_space_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + static_assert(std::is_same::value && + std::is_same::value, + "A, B and C do not have the same memory space."); + + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(alpha, A, B, beta, C); #if defined(KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, alpha, A, B, beta, C); + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, alpha, A, B, beta, C); #endif - return r_val; - } - - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_External.hpp index fda4cb7f5fc8..4faee3d3aa47 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_EXTERNAL_HPP__ #define __TACHO_GEMM_EXTERNAL_HPP__ - /// \file Tacho_Gemm_External.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,76 +9,48 @@ namespace Tacho { - template - struct Gemm { - template - inline - static int - invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = (std::is_same::value ? B.extent(0) : B.extent(1)); - - if (m > 0 && n > 0 && k > 0) { - Blas::gemm(ArgTransA::param, ArgTransB::param, - m, n, k, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } +template struct Gemm { + template + inline static int invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0), n = C.extent(1), + k = (std::is_same::value ? B.extent(0) : B.extent(1)); + + if (m > 0 && n > 0 && k > 0) { + Blas::gemm(ArgTransA::param, ArgTransB::param, m, n, k, value_type(alpha), A.data(), A.stride_1(), + B.data(), B.stride_1(), value_type(beta), C.data(), C.stride_1()); + } #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(alpha, A, B, beta, C); - //}); + return 0; + } + + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(alpha, A, B, beta, C); + //}); #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } - - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_Internal.hpp index e763172a253b..779ace316204 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_INTERNAL_HPP__ #define __TACHO_GEMM_INTERNAL_HPP__ - /// \file Tacho_Gemm_Internal.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,50 +9,31 @@ namespace Tacho { - template - struct Gemm { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type m = C.extent(0); - const ordinal_type n = C.extent(1); - const ordinal_type - k = (std::is_same::value ? B.extent(0) : B.extent(1)); - - if (m > 0 && n > 0 && k > 0) - BlasTeam::gemm(member, - ArgTransA::param, ArgTransB::param, - m, n, k, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - return 0; - } - }; -} +template struct Gemm { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0); + const ordinal_type n = C.extent(1); + const ordinal_type k = (std::is_same::value ? B.extent(0) : B.extent(1)); + + if (m > 0 && n > 0 && k > 0) + BlasTeam::gemm(member, ArgTransA::param, ArgTransB::param, m, n, k, value_type(alpha), A.data(), + A.stride_1(), B.data(), B.stride_1(), value_type(beta), C.data(), C.stride_1()); + return 0; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_OnDevice.hpp index cd3179423859..ad54b8485795 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMM_ON_DEVICE_HPP__ #define __TACHO_GEMM_ON_DEVICE_HPP__ - /// \file Tacho_Gemm_OnDevice.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,115 +9,70 @@ namespace Tacho { - template - struct Gemm { - template - inline - static int - blas_invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = (std::is_same::value ? B.extent(0) : B.extent(1)); +template struct Gemm { + template + inline static int blas_invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = C.extent(0), n = C.extent(1), + k = (std::is_same::value ? B.extent(0) : B.extent(1)); - if (m > 0 && n > 0 && k > 0) { - Blas::gemm(ArgTransA::param, ArgTransB::param, - m, n, k, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } - return 0; + if (m > 0 && n > 0 && k > 0) { + Blas::gemm(ArgTransA::param, ArgTransB::param, m, n, k, value_type(alpha), A.data(), A.stride_1(), + B.data(), B.stride_1(), value_type(beta), C.data(), C.stride_1()); } + return 0; + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; + template + inline static int cublas_invoke(cublasHandle_t &handle, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = C.extent(0), - n = C.extent(1), - k = (std::is_same::value ? B.extent(0) : B.extent(1)); + const ordinal_type m = C.extent(0), n = C.extent(1), + k = (std::is_same::value ? B.extent(0) : B.extent(1)); - int r_val(0); - if (m > 0 && n > 0 && k > 0) { - r_val = Blas::gemm(handle, - ArgTransA::cublas_param, ArgTransB::cublas_param, - m, n, k, - alpha, - A.data(), A.stride_1(), - B.data(), B.stride_1(), - beta, - C.data(), C.stride_1()); - } - return r_val; + int r_val(0); + if (m > 0 && n > 0 && k > 0) { + r_val = Blas::gemm(handle, ArgTransA::cublas_param, ArgTransB::cublas_param, m, n, k, alpha, A.data(), + A.stride_1(), B.data(), B.stride_1(), beta, C.data(), C.stride_1()); } + return r_val; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeB::memory_space memory_space_b; - typedef typename ViewTypeC::memory_space memory_space_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeB::memory_space memory_space_b; + typedef typename ViewTypeC::memory_space memory_space_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same memory space."); - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(alpha, A, B, beta, C); -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, alpha, A, B, beta, C); + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + static_assert(std::is_same::value && + std::is_same::value, + "A, B and C do not have the same memory space."); + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(alpha, A, B, beta, C); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, alpha, A, B, beta, C); #endif - return r_val; - } - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv.hpp index 26b8529c733e..f3f1b7afd432 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv.hpp @@ -9,64 +9,17 @@ namespace Tacho { - /// - /// Gemm: - /// +/// +/// Gemm: +/// - /// various implementation for different uplo and algo parameters - template - struct Gemv; +/// various implementation for different uplo and algo parameters +template struct Gemv; - /// task construction for the above chol implementation - /// Gemm::invoke(_sched, member, _alpha, _A, _B, _beta, _C); - template - struct TaskFunctor_Gemv { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; - - typedef ScalarType scalar_type; - - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; +struct GemvAlgorithm { + using type = ActiveAlgorithm::type; +}; - private: - scalar_type _alpha, _beta; - dense_block_type _A, _B, _C; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Gemv() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Gemv(const scalar_type alpha, - const dense_block_type &A, - const dense_block_type &B, - const scalar_type beta, - const dense_block_type &C) - : _alpha(alpha), - _beta(beta), - _A(A), - _B(B), - _C(C) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Gemv - ::invoke(member, _alpha, _A, _B, _beta, _C); - - Kokkos::single(Kokkos::PerTeam(member), [&]() { - _C.set_future(); - r_val = ierr; - }); - } - }; - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_External.hpp index cc427308f4b0..92abcfc26714 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMV_EXTERNAL_HPP__ #define __TACHO_GEMV_EXTERNAL_HPP__ - /// \file Tacho_Gemv_External.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,89 +9,54 @@ namespace Tacho { - template - struct Gemv { - template - inline - static int - invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type - m = C.extent(0), - n = C.extent(1); - - - if (m > 0 && n > 0) { - if (n == 1) { - const int mm = A.extent(0), nn = A.extent(1); - Blas::gemv(ArgTrans::param, - mm, nn, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_0(), - value_type(beta), - C.data(), C.stride_0()); - } else { - const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); - Blas::gemm(ArgTrans::param, - Trans::NoTranspose::param, - mm, nn, kk, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } - - } +template struct Gemv { + template + inline static int invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0), n = C.extent(1); + + if (m > 0 && n > 0) { + if (n == 1) { + const int mm = A.extent(0), nn = A.extent(1); + Blas::gemv(ArgTrans::param, mm, nn, value_type(alpha), A.data(), A.stride_1(), B.data(), + B.stride_0(), value_type(beta), C.data(), C.stride_0()); + } else { + const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); + Blas::gemm(ArgTrans::param, Trans::NoTranspose::param, mm, nn, kk, value_type(alpha), A.data(), + A.stride_1(), B.data(), B.stride_1(), value_type(beta), C.data(), C.stride_1()); + } + } #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(alpha, A, B, beta, C); - //}); + return 0; + } + + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(alpha, A, B, beta, C); + //}); #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_Internal.hpp index 1db5b884e9b9..1a0b62f9028b 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMV_INTERNAL_HPP__ #define __TACHO_GEMV_INTERNAL_HPP__ - /// \file Tacho_Gemv_Internal.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,66 +9,39 @@ namespace Tacho { - template - struct Gemv { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); - - const ordinal_type - m = C.extent(0), - n = C.extent(1); - - if (m > 0 && n > 0) { - if (n == 1) { - const int mm = A.extent(0), nn = A.extent(1); - BlasTeam::gemv(member, - ArgTrans::param, - mm, nn, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_0(), - value_type(beta), - C.data(), C.stride_0()); - } else { - const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); - BlasTeam::gemm(member, - ArgTrans::param, - Trans::NoTranspose::param, - mm, nn, kk, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - - } - } - return 0; +template struct Gemv { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + const ordinal_type m = C.extent(0), n = C.extent(1); + + if (m > 0 && n > 0) { + if (n == 1) { + const int mm = A.extent(0), nn = A.extent(1); + BlasTeam::gemv(member, ArgTrans::param, mm, nn, value_type(alpha), A.data(), A.stride_1(), B.data(), + B.stride_0(), value_type(beta), C.data(), C.stride_0()); + } else { + const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); + BlasTeam::gemm(member, ArgTrans::param, Trans::NoTranspose::param, mm, nn, kk, value_type(alpha), + A.data(), A.stride_1(), B.data(), B.stride_1(), value_type(beta), C.data(), + C.stride_1()); } - }; + } + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_OnDevice.hpp index 34397353452c..809611fadf5e 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemv_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_GEMV_ON_DEVICE_HPP__ #define __TACHO_GEMV_ON_DEVICE_HPP__ - /// \file Tacho_Gemv_OnDevice.hpp /// \brief BLAS general matrix matrix multiplication /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,139 +9,82 @@ namespace Tacho { - template - struct Gemv { - template - inline - static int - blas_invoke(const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = C.extent(0), - n = C.extent(1); +template struct Gemv { + template + inline static int blas_invoke(const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, const ScalarType beta, + const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = C.extent(0), n = C.extent(1); - if (m > 0 && n > 0) { - if (n == 1) { - const int mm = A.extent(0), nn = A.extent(1); - Blas::gemv(ArgTrans::param, - mm, nn, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_0(), - value_type(beta), - C.data(), C.stride_0()); - } else { - const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); - Blas::gemm(ArgTrans::param, - Trans::NoTranspose::param, - mm, nn, kk, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } - } - return 0; + if (m > 0 && n > 0) { + if (n == 1) { + const int mm = A.extent(0), nn = A.extent(1); + Blas::gemv(ArgTrans::param, mm, nn, value_type(alpha), A.data(), A.stride_1(), B.data(), + B.stride_0(), value_type(beta), C.data(), C.stride_0()); + } else { + const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); + Blas::gemm(ArgTrans::param, Trans::NoTranspose::param, mm, nn, kk, value_type(alpha), A.data(), + A.stride_1(), B.data(), B.stride_1(), value_type(beta), C.data(), C.stride_1()); } + } + return 0; + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; + template + inline static int cublas_invoke(cublasHandle_t &handle, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B, const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = C.extent(0), - n = C.extent(1); + const ordinal_type m = C.extent(0), n = C.extent(1); - int r_val(0); - if (m > 0 && n > 0) { - if (n == 1) { - const int mm = A.extent(0), nn = A.extent(1); - r_val = Blas::gemv(handle, - ArgTrans::cublas_param, - mm, nn, - alpha, - A.data(), A.stride_1(), - B.data(), B.stride_0(), - beta, - C.data(), C.stride_0()); - } else { - const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); - r_val = Blas::gemm(handle, - ArgTrans::cublas_param, - Trans::NoTranspose::cublas_param, - mm, nn, kk, - alpha, - A.data(), A.stride_1(), - B.data(), B.stride_1(), - beta, - C.data(), C.stride_1()); - } - } - return r_val; + int r_val(0); + if (m > 0 && n > 0) { + if (n == 1) { + const int mm = A.extent(0), nn = A.extent(1); + r_val = Blas::gemv(handle, ArgTrans::cublas_param, mm, nn, alpha, A.data(), A.stride_1(), B.data(), + B.stride_0(), beta, C.data(), C.stride_0()); + } else { + const int mm = C.extent(0), nn = C.extent(1), kk = B.extent(0); + r_val = + Blas::gemm(handle, ArgTrans::cublas_param, Trans::NoTranspose::cublas_param, mm, nn, kk, alpha, + A.data(), A.stride_1(), B.data(), B.stride_1(), beta, C.data(), C.stride_1()); } + } + return r_val; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - typedef typename ViewTypeC::non_const_value_type value_type_c; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeB::memory_space memory_space_b; - typedef typename ViewTypeC::memory_space memory_space_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same value type."); + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B, + const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeB::memory_space memory_space_b; + typedef typename ViewTypeC::memory_space memory_space_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); - static_assert(std::is_same::value && - std::is_same::value, - "A, B and C do not have the same memory space."); - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(alpha, A, B, beta, C); -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, alpha, A, B, beta, C); + static_assert(std::is_same::value && std::is_same::value, + "A, B and C do not have the same value type."); + + static_assert(std::is_same::value && + std::is_same::value, + "A, B and C do not have the same memory space."); + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(alpha, A, B, beta, C); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, alpha, A, B, beta, C); #endif - return r_val; - } - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Graph.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Graph.hpp index 195e25d6dd39..5c9dedddb941 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Graph.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Graph.hpp @@ -4,136 +4,113 @@ /// \file Tacho_Graph.hpp /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Tacho_Util.hpp" #include "Tacho_CrsMatrixBase.hpp" +#include "Tacho_Util.hpp" namespace Tacho { - /// - /// graph reordering happens in host space only - /// - class Graph { - public: - typedef typename UseThisDevice::type host_device_type; - - typedef Kokkos::View ordinal_type_array; - - // scotch use int and long is misinterpreted in pointers - typedef Kokkos::View size_type_array; - - private: - ordinal_type _m; - size_type _nnz; - - size_type_array _rptr; - ordinal_type_array _cidx; - - template - inline - void - init(const ordinal_type m, - const size_type nnz, - const SizeTypeArray &ap, - const OrdinalTypeArray &aj) { - _rptr = size_type_array(do_not_initialize_tag("Graph::rptr"), m + 1); - _cidx = ordinal_type_array(do_not_initialize_tag("Graph::cidx"), nnz); - - _m = m; - _nnz = 0; - for (ordinal_type i=0;i<_m;++i) { - const size_type jbeg = ap(i), jend = ap(i+1); - _rptr(i) = _nnz; - for (size_type j=jbeg;j - Graph(const ordinal_type m, - const size_type nnz, - const SizeTypeArray &ap, - const OrdinalTypeArray &aj) { - init(m, nnz, ap, aj); +/// +/// graph reordering happens in host space only +/// +class Graph { +public: + typedef typename UseThisDevice::type host_device_type; + + typedef Kokkos::View ordinal_type_array; + + // scotch use int and long is misinterpreted in pointers + typedef Kokkos::View size_type_array; + +private: + ordinal_type _m; + size_type _nnz; + + size_type_array _rptr; + ordinal_type_array _cidx; + + template + inline void init(const ordinal_type m, const size_type nnz, const SizeTypeArray &ap, const OrdinalTypeArray &aj) { + _rptr = size_type_array(do_not_initialize_tag("Graph::rptr"), m + 1); + _cidx = ordinal_type_array(do_not_initialize_tag("Graph::cidx"), nnz); + + _m = m; + _nnz = 0; + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbeg = ap(i), jend = ap(i + 1); + _rptr(i) = _nnz; + for (size_type j = jbeg; j < jend; ++j) { + // skip diagonal + const ordinal_type col = aj(j); + if (i != col) + _cidx(_nnz++) = col; } - - template - inline - Graph(const CrsMatrixBase &A) { - // - // host mirroring - // - CrsMatrixBase AA; - AA.createMirror(A); - AA.copy(A); - - init(AA.NumRows(), AA.NumNonZeros(), AA.RowPtr(), AA.Cols()); - } - - inline - size_type_array RowPtr() const { return _rptr; } - - inline - ordinal_type_array ColIdx() const { return _cidx; } - - inline - ordinal_type NumRows() const { return _m; } - - inline - ordinal_type NumNonZeros() const { return _nnz; } - - inline - void clear() { - _m = 0; - _nnz = 0; - _rptr = size_type_array(); - _cidx = ordinal_type_array(); - } - - std::ostream& showMe(std::ostream &os, const bool detail = false) const { - std::streamsize prec = os.precision(); - os.precision(4); - os << std::scientific; - - os << " -- Graph -- " << std::endl - << " # of Rows = " << _m << std::endl - << " # of NonZeros = " << _nnz << std::endl; - - const int w = 10; - if (detail) { - os << std::setw(w) << "Row" << " " - << std::setw(w) << "Col" - << std::endl; - for (ordinal_type i=0;i<_m;++i) { - const size_type jbeg = _rptr[i], jend = _rptr[i+1]; - for (size_type j=jbeg;j + Graph(const ordinal_type m, const size_type nnz, const SizeTypeArray &ap, const OrdinalTypeArray &aj) { + init(m, nnz, ap, aj); + } + + template inline Graph(const CrsMatrixBase &A) { + // + // host mirroring + // + CrsMatrixBase AA; + AA.createMirror(A); + AA.copy(A); + + init(AA.NumRows(), AA.NumNonZeros(), AA.RowPtr(), AA.Cols()); + } + + inline size_type_array RowPtr() const { return _rptr; } + + inline ordinal_type_array ColIdx() const { return _cidx; } + + inline ordinal_type NumRows() const { return _m; } + + inline ordinal_type NumNonZeros() const { return _nnz; } + + inline void clear() { + _m = 0; + _nnz = 0; + _rptr = size_type_array(); + _cidx = ordinal_type_array(); + } + + std::ostream &showMe(std::ostream &os, const bool detail = false) const { + std::streamsize prec = os.precision(); + os.precision(4); + os << std::scientific; + + os << " -- Graph -- " << std::endl + << " # of Rows = " << _m << std::endl + << " # of NonZeros = " << _nnz << std::endl; + + const int w = 10; + if (detail) { + os << std::setw(w) << "Row" + << " " << std::setw(w) << "Col" << std::endl; + for (ordinal_type i = 0; i < _m; ++i) { + const size_type jbeg = _rptr[i], jend = _rptr[i + 1]; + for (size_type j = jbeg; j < jend; ++j) { + os << std::setw(w) << i << " " << std::setw(w) << _cidx[j] << " " << std::endl; } - - os.unsetf(std::ios::scientific); - os.precision(prec); - - return os; } + } + os.unsetf(std::ios::scientific); + os.precision(prec); - }; - -} + return os; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools.hpp index f82c58b17caa..b4bab697bb87 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools.hpp @@ -4,89 +4,85 @@ /// \file Tacho_GraphTools.hpp /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Tacho_Util.hpp" #include "Tacho_Graph.hpp" +#include "Tacho_Util.hpp" namespace Tacho { - class GraphTools { - public: - typedef typename UseThisDevice::type host_device_type; - typedef Kokkos::View ordinal_type_array; - - private: - - ordinal_type_array _perm, _peri; - - // status flag - bool _is_ordered, _verbose; - - public: - GraphTools() = default; - GraphTools(const GraphTools &b) = default; - ~GraphTools() = default; - - /// - /// construction of scotch graph - /// - GraphTools(const Graph &g) { - _is_ordered = false; - _verbose = false; - - // input - const ordinal_type m = g.NumRows(); - - // output - _perm = ordinal_type_array(do_not_initialize_tag("PermutationArray"), m); - _peri = ordinal_type_array(do_not_initialize_tag("InvPermutationArray"), m); - } +class GraphTools { +public: + typedef typename UseThisDevice::type host_device_type; + typedef Kokkos::View ordinal_type_array; + +private: + ordinal_type_array _perm, _peri; + + // status flag + bool _is_ordered, _verbose; + +public: + GraphTools() = default; + GraphTools(const GraphTools &b) = default; + ~GraphTools() = default; + + /// + /// construction of scotch graph + /// + GraphTools(const Graph &g) { + _is_ordered = false; + _verbose = false; - /// - /// setup parameters - /// + // input + const ordinal_type m = g.NumRows(); - void setVerbose(const bool verbose) { _verbose = verbose; } + // output + _perm = ordinal_type_array(do_not_initialize_tag("PermutationArray"), m); + _peri = ordinal_type_array(do_not_initialize_tag("InvPermutationArray"), m); + } + /// + /// setup parameters + /// - /// - /// reorder - /// - void reorder(const ordinal_type verbose = 0) { - // do nothing - // later implement AMD - const ordinal_type m = _perm.extent(0); - for (ordinal_type i=0;i - class CAMD; - - template<> - class CAMD { - public: - static void run( int n, int Pe[], int Iw[], - int Len[], int iwlen, int pfree, - int Nv[], int Next[], int Last[], - int Head[], int Elen[], int Degree[], - int W[], - double Control[], - double Info[], - const int C[], - int BucketSet[] ) { - TACHO_SUITESPARSE(camd_2)( n, Pe, Iw, Len, iwlen, pfree, - Nv, Next, Last, Head, Elen, Degree, W, Control, Info, C, BucketSet ); - } - }; - - template<> - class CAMD { - public: - static void run( SuiteSparse_long n, SuiteSparse_long Pe[], SuiteSparse_long Iw[], - SuiteSparse_long Len[], SuiteSparse_long iwlen, SuiteSparse_long pfree, - SuiteSparse_long Nv[], SuiteSparse_long Next[], SuiteSparse_long Last[], - SuiteSparse_long Head[], SuiteSparse_long Elen[], SuiteSparse_long Degree[], - SuiteSparse_long W[], - double Control[], - double Info[], - const SuiteSparse_long C[], - SuiteSparse_long BucketSet[] ) { - TACHO_SUITESPARSE(camd_l2)( n, Pe, Iw, Len, iwlen, pfree, - Nv, Next, Last, Head, Elen, Degree, W, Control, Info, C, BucketSet ); - } - }; - - class GraphTools_CAMD { - public: - typedef Kokkos::DefaultHostExecutionSpace host_exec_space; - typedef Kokkos::View ordinal_type_array; - - private: - // graph input - ordinal_type _m; - size_type _nnz; - ordinal_type_array _rptr, _cidx, _cnst; - - // CAMD output - ordinal_type_array _pe, _nv, _el, _next, _perm, _peri; // perm = last, peri = next - - double _control[TRILINOS_CAMD_CONTROL], _info[TRILINOS_CAMD_INFO]; - - bool _is_ordered; - - public: - - // static assert is necessary to enforce to use host space only - GraphTools_CAMD() = default; - GraphTools_CAMD(const GraphTools_CAMD &b) = default; - - GraphTools_CAMD(const Graph &g) { - _m = g.NumRows(); - _nnz = g.NumNonZeros(); - - _rptr = g.RowPtr(); - _cidx = g.ColIdx(); - _cnst = ordinal_type_array("CAMD::ConstraintArray", _m+1); - - // permutation vector - _pe = ordinal_type_array("CAMD::EliminationArray", _m); - _nv = ordinal_type_array("CAMD::SupernodesArray", _m); - _el = ordinal_type_array("CAMD::DegreeArray", _m); - _next = ordinal_type_array("CAMD::InvPermSupernodesArray", _m); - _perm = ordinal_type_array("CAMD::PermutationArray", _m); - _peri = ordinal_type_array("CAMD::InvPermutationArray", _m); - - } - virtual~GraphTools_CAMD() = default; - - void setConstraint(const ordinal_type nblk, - const ordinal_type_array range, - const ordinal_type_array peri) { - for (ordinal_type i=0;i(_rptr.data()); - ordinal_type *cidx = reinterpret_cast(_cidx.data()); - ordinal_type *cnst = reinterpret_cast(_cnst.data()); - - ordinal_type *next = reinterpret_cast(_next.data()); - ordinal_type *perm = reinterpret_cast(_perm.data()); - - // length array - ordinal_type_array lwork("CAMD::LWorkArray", _m); - ordinal_type *lwork_ptr = reinterpret_cast(lwork.data()); - for (ordinal_type i=0;i<_m;++i) - lwork_ptr[i] = rptr[i+1] - rptr[i]; - - // workspace - const size_type swlen = _nnz + _nnz/5 + 5*(_m+1);; - ordinal_type_array swork("CAMD::SWorkArray", swlen); - ordinal_type *swork_ptr = reinterpret_cast(swork.data()); - - ordinal_type *pe_ptr = reinterpret_cast(_pe.data()); // 1) Pe - size_type pfree = 0; - for (ordinal_type i=0;i<_m;++i) { - pe_ptr[i] = pfree; - pfree += lwork_ptr[i]; - } - TACHO_TEST_FOR_EXCEPTION( _nnz != pfree, - std::logic_error, - ">> nnz in the graph does not match to nnz count (pfree)"); - - ordinal_type *nv_ptr = reinterpret_cast(_nv.data()); // 2) Nv - ordinal_type *hd_ptr = swork_ptr; swork_ptr += (_m+1); // 3) Head - ordinal_type *el_ptr = reinterpret_cast(_el.data()); // 4) Elen - ordinal_type *dg_ptr = swork_ptr; swork_ptr += _m; // 5) Degree - ordinal_type *wk_ptr = swork_ptr; swork_ptr += (_m+1); // 6) W - ordinal_type *bk_ptr = swork_ptr; swork_ptr += _m; // 7) BucketSet - - const size_type iwlen = swlen - (4*_m+2); - ordinal_type *iw_ptr = swork_ptr; swork_ptr += iwlen; // Iw - for (size_type i=0;i::run(_m, pe_ptr, iw_ptr, lwork_ptr, iwlen, pfree, - // output - nv_ptr, next, perm, hd_ptr, el_ptr, dg_ptr, wk_ptr, - _control, _info, cnst, bk_ptr); - t_camd = timer.seconds(); - - TACHO_TEST_FOR_EXCEPTION(_info[TRILINOS_CAMD_STATUS] != TRILINOS_CAMD_OK, - std::runtime_error, - "CAMD fails"); - - for (ordinal_type i=0;i<_m;++i) - _peri[_perm[i]] = i; - - _is_ordered = true; - - if (verbose) { - printf("Summary: GraphTools (CAMD)\n"); - printf("===========================\n"); - - switch (verbose) { - case 1: { - printf(" Time\n"); - printf(" time for reordering: %10.6f s\n", t_camd); - printf("\n"); - } - } - } - } - - - ordinal_type_array PermVector() const { return _perm; } - ordinal_type_array InvPermVector() const { return _peri; } - ordinal_type_array ConstraintVector() const { return _cnst; } - - std::ostream& showMe(std::ostream &os, const bool detail) const { - std::streamsize prec = os.precision(); - os.precision(8); - os << std::scientific; - - os << " -- CAMD input -- " << std::endl - << " # of Rows = " << _m << std::endl - << " # of NonZeros = " << _nnz << std::endl; - - if (_is_ordered) - os << " -- Ordering -- " << std::endl - << " CNST PERM PERI PE NV NEXT ELEN" << std::endl; - - const int w = 6; - for (ordinal_type i=0;i<_m;++i) - os << std::setw(w) << _cnst[i] << " " - << std::setw(w) << _perm[i] << " " - << std::setw(w) << _peri[i] << " " - << std::setw(w) << _pe[i] << " " - << std::setw(w) << _nv[i] << " " - << std::setw(w) << _next[i] << " " - << std::setw(w) << _el[i] << " " - << std::endl; - - os.unsetf(std::ios::scientific); - os.precision(prec); - - return os; - } - - }; - -} - -#endif -#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_Metis.cpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_Metis.cpp index 8fb318281fe6..53839aad114e 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_Metis.cpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_Metis.cpp @@ -8,119 +8,109 @@ namespace Tacho { - GraphTools_Metis::GraphTools_Metis() = default; - GraphTools_Metis::GraphTools_Metis(const GraphTools_Metis &b) = default; - GraphTools_Metis::GraphTools_Metis(const Graph &g) { - _is_ordered = false; - _verbose = false; - - // input - _nvts = g.NumRows(); - - _xadj = idx_t_array(do_not_initialize_tag("Metis::idx_t_xadj"), g.RowPtr().extent(0)); - _adjncy = idx_t_array(do_not_initialize_tag("Metis::idx_t_adjncy"), g.ColIdx().extent(0)); - _vwgt = idx_t_array(); - - const auto &g_row_ptr = g.RowPtr(); - const auto &g_col_idx = g.ColIdx(); - - for (ordinal_type i=0;i(_xadj.extent(0));++i) - _xadj(i) = g_row_ptr(i); - for (ordinal_type i=0;i(_adjncy.extent(0));++i) - _adjncy(i) = g_col_idx(i); - - METIS_SetDefaultOptions(_options); - _options[METIS_OPTION_NUMBERING] = 0; - - _perm_t = idx_t_array(do_not_initialize_tag("idx_t_perm"), _nvts); - _peri_t = idx_t_array(do_not_initialize_tag("idx_t_peri"), _nvts); - - // output - _perm = ordinal_type_array(do_not_initialize_tag("Metis::PermutationArray"), _nvts); - _peri = ordinal_type_array(do_not_initialize_tag("Metis::InvPermutationArray"), _nvts); - } - GraphTools_Metis::~GraphTools_Metis() {} +GraphTools_Metis::GraphTools_Metis() = default; +GraphTools_Metis::GraphTools_Metis(const GraphTools_Metis &b) = default; +GraphTools_Metis::GraphTools_Metis(const Graph &g) { + _is_ordered = false; + _verbose = false; - void GraphTools_Metis::setVerbose(const bool verbose) { - _verbose = verbose; - } - void GraphTools_Metis::setOption(const int id, const idx_t value) { - _options[id] = value; + // input + _nvts = g.NumRows(); + + _xadj = idx_t_array(do_not_initialize_tag("Metis::idx_t_xadj"), g.RowPtr().extent(0)); + _adjncy = idx_t_array(do_not_initialize_tag("Metis::idx_t_adjncy"), g.ColIdx().extent(0)); + _vwgt = idx_t_array(); + + const auto &g_row_ptr = g.RowPtr(); + const auto &g_col_idx = g.ColIdx(); + + for (ordinal_type i = 0; i < static_cast(_xadj.extent(0)); ++i) + _xadj(i) = g_row_ptr(i); + for (ordinal_type i = 0; i < static_cast(_adjncy.extent(0)); ++i) + _adjncy(i) = g_col_idx(i); + + METIS_SetDefaultOptions(_options); + _options[METIS_OPTION_NUMBERING] = 0; + + _perm_t = idx_t_array(do_not_initialize_tag("idx_t_perm"), _nvts); + _peri_t = idx_t_array(do_not_initialize_tag("idx_t_peri"), _nvts); + + // output + _perm = ordinal_type_array(do_not_initialize_tag("Metis::PermutationArray"), _nvts); + _peri = ordinal_type_array(do_not_initialize_tag("Metis::InvPermutationArray"), _nvts); +} +GraphTools_Metis::~GraphTools_Metis() {} + +void GraphTools_Metis::setVerbose(const bool verbose) { _verbose = verbose; } +void GraphTools_Metis::setOption(const int id, const idx_t value) { _options[id] = value; } + +/// +/// reorder by metis +/// + +void GraphTools_Metis::reorder(const ordinal_type verbose) { + Kokkos::Timer timer; + double t_metis = 0; + + int ierr = 0; + + idx_t *xadj = (idx_t *)_xadj.data(); + idx_t *adjncy = (idx_t *)_adjncy.data(); + idx_t *vwgt = (idx_t *)_vwgt.data(); + + idx_t *perm = (idx_t *)_perm_t.data(); + idx_t *peri = (idx_t *)_peri_t.data(); + + timer.reset(); + ierr = METIS_NodeND(&_nvts, xadj, adjncy, vwgt, _options, perm, peri); + t_metis = timer.seconds(); + + for (idx_t i = 0; i < _nvts; ++i) { + _perm(i) = _perm_t(i); + _peri(i) = _peri_t(i); } - - /// - /// reorder by metis - /// - - void GraphTools_Metis::reorder(const ordinal_type verbose) { - Kokkos::Timer timer; - double t_metis = 0; - - int ierr = 0; - - idx_t *xadj = (idx_t*)_xadj.data(); - idx_t *adjncy = (idx_t*)_adjncy.data(); - idx_t *vwgt = (idx_t*)_vwgt.data(); - - idx_t *perm = (idx_t*)_perm_t.data(); - idx_t *peri = (idx_t*)_peri_t.data(); - - timer.reset(); - ierr = METIS_NodeND(&_nvts, xadj, adjncy, vwgt, _options, - perm, peri); - t_metis = timer.seconds(); - - for (idx_t i=0;i<_nvts;++i) { - _perm(i) = _perm_t(i); - _peri(i) = _peri_t(i); + + TACHO_TEST_FOR_EXCEPTION(ierr != METIS_OK, std::runtime_error, "Failed in METIS_NodeND"); + _is_ordered = true; + + if (verbose) { + printf("Summary: GraphTools (Metis)\n"); + printf("===========================\n"); + + switch (verbose) { + case 1: { + printf(" Time\n"); + printf(" time for reordering: %10.6f s\n", t_metis); + printf("\n"); } - - TACHO_TEST_FOR_EXCEPTION(ierr != METIS_OK, - std::runtime_error, - "Failed in METIS_NodeND"); - _is_ordered = true; - - if (verbose) { - printf("Summary: GraphTools (Metis)\n"); - printf("===========================\n"); - - switch (verbose) { - case 1: { - printf(" Time\n"); - printf(" time for reordering: %10.6f s\n", t_metis); - printf("\n"); - } - } } } - - typename GraphTools_Metis::ordinal_type_array GraphTools_Metis::PermVector() const { return _perm; } - typename GraphTools_Metis::ordinal_type_array GraphTools_Metis::InvPermVector() const { return _peri; } - - std::ostream& GraphTools_Metis::showMe(std::ostream &os, const bool detail) const { - std::streamsize prec = os.precision(); - os.precision(4); - os << std::scientific; - - if (_is_ordered) - os << " -- Metis Ordering -- " << std::endl - << " PERM PERI " << std::endl; - else - os << " -- Not Ordered -- " << std::endl; - - if (detail) { - const ordinal_type w = 6, m = _perm.extent(0); - for (ordinal_type i=0;i::type host_device_type; - - typedef Kokkos::View idx_t_array; - typedef Kokkos::View ordinal_type_array; - - private: - - // metis main data structure - idx_t _nvts; - idx_t_array _xadj, _adjncy, _vwgt; - - idx_t _options[METIS_NOPTIONS]; - - // metis output - idx_t_array _perm_t, _peri_t; - ordinal_type_array _perm, _peri; - - // status flag - bool _is_ordered, _verbose; - - public: - GraphTools_Metis(); - GraphTools_Metis(const GraphTools_Metis &b); - - /// - /// construction of scotch graph - /// - GraphTools_Metis(const Graph &g); - virtual~GraphTools_Metis(); - - /// - /// setup metis parameters - /// - - void setVerbose(const bool verbose); - void setOption(const int id, const idx_t value); - - /// - /// reorder by metis - /// - - void reorder(const ordinal_type verbose = 0); - - ordinal_type_array PermVector() const; - ordinal_type_array InvPermVector() const; - - std::ostream& showMe(std::ostream &os, const bool detail = false) const; - }; - -} +class GraphTools_Metis { +public: + typedef typename UseThisDevice::type host_device_type; + + typedef Kokkos::View idx_t_array; + typedef Kokkos::View ordinal_type_array; + +private: + // metis main data structure + idx_t _nvts; + idx_t_array _xadj, _adjncy, _vwgt; + + idx_t _options[METIS_NOPTIONS]; + + // metis output + idx_t_array _perm_t, _peri_t; + ordinal_type_array _perm, _peri; + + // status flag + bool _is_ordered, _verbose; + +public: + GraphTools_Metis(); + GraphTools_Metis(const GraphTools_Metis &b); + + /// + /// construction of scotch graph + /// + GraphTools_Metis(const Graph &g); + virtual ~GraphTools_Metis(); + + /// + /// setup metis parameters + /// + + void setVerbose(const bool verbose); + void setOption(const int id, const idx_t value); + + /// + /// reorder by metis + /// + + void reorder(const ordinal_type verbose = 0); + + ordinal_type_array PermVector() const; + ordinal_type_array InvPermVector() const; + + std::ostream &showMe(std::ostream &os, const bool detail = false) const; +}; + +} // namespace Tacho #endif #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_MetisMT.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_MetisMT.hpp deleted file mode 100644 index 464995db6011..000000000000 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_GraphTools_MetisMT.hpp +++ /dev/null @@ -1,187 +0,0 @@ -#ifndef __TACHO_GRAPH_TOOLS_METIS_MT_HPP__ -#define __TACHO_GRAPH_TOOLS_METIS_MT_HPP__ - -/// \file Tacho_GraphTools_Metis_MT.hpp -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "Tacho_Util.hpp" -#if defined(TACHO_HAVE_METIS_MT) - -#include "Tacho_Graph.hpp" - -#include "mtmetis.h" - -namespace Tacho { - - class GraphTools_MetisMT { - public: - typedef typename UseThisDevice host_device_type; - typedef typename host_device_type::execution_space host_space; - - typedef Kokkos::View mtmetis_vtx_type_array; - typedef Kokkos::View mtmetis_adj_type_array; - typedef Kokkos::View mtmetis_wgt_type_array; - typedef Kokkos::View mtmetis_pid_type_array; - - typedef Kokkos::View ordinal_type_array; - - private: - - // metis main data structure - mtmetis_vtx_type _nvts; - mtmetis_vtx_type_array _xadj; - mtmetis_adj_type_array _adjncy; - mtmetis_wgt_type_array _vwgt; - - double _options[MTMETIS_NOPTIONS]; - - // metis output - mtmetis_pid_type_array _perm_t, _peri_t; - ordinal_type_array _perm, _peri; - - // status flag - bool _is_ordered, _verbose; - - public: - GraphTools_MetisMT() = default; - GraphTools_MetisMT(const GraphTools_MetisMT &b) = default; - - /// - /// construction of scotch graph - /// - GraphTools_MetisMT(const Graph &g) { - _is_ordered = false; - _verbose = false; - - // input - _nvts = g.NumRows(); - - _xadj = mtmetis_vtx_type_array("vtx_type_xadj", g.RowPtr().extent(0)); - _adjncy = mtmetis_adj_type_array("adj_type_adjncy", g.ColIdx().extent(0)); - _vwgt = mtmetis_wgt_type_array(); - - const auto &g_row_ptr = g.RowPtr(); - const auto &g_col_idx = g.ColIdx(); - - for (ordinal_type i=0;i(_xadj.extent(0));++i) - _xadj(i) = g_row_ptr(i); - for (ordinal_type i=0;i(_adjncy.extent(0));++i) - _adjncy(i) = g_col_idx(i); - - // default - for (ordinal_type i=0;i(MTMETIS_NOPTIONS);++i) - _options[i] = MTMETIS_VAL_OFF; - - // by default, metis use - // # of threads : omp_get_max_threads - // seed : (unsigned int)time(NULL) - // internal verbose options are : - // MTMETIS_VERBOSITY_NONE, - // MTMETIS_VERBOSITY_LOW, - // MTMETIS_VERBOSITY_MEDIUM, - // MTMETIS_VERBOSITY_HIGH, - // MTMETIS_VERBOSITY_MAXIMUM - - _options[MTMETIS_OPTION_NTHREADS] = host_space::thread_pool_size(0); // from kokkos - //_options[MTMETIS_OPTION_SEED] = 0; // for testing, use the same seed now - //_options[MTMETIS_OPTION_PTYPE] = MTMETIS_PTYPE_ND; // when explicit interface is used - //_options[MTMETIS_OPTION_VERBOSITY] = MTMETIS_VERBOSITY_NONE; - //_options[MTMETIS_OPTION_METIS] = 1; // flag to use serial metis - - _perm_t = mtmetis_pid_type_array("pid_type_perm", _nvts); - _peri_t = mtmetis_pid_type_array("pid_type_peri", _nvts); - - // output - _perm = ordinal_type_array("MetisMT::PermutationArray", _nvts); - _peri = ordinal_type_array("MetisMT::InvPermutationArray", _nvts); - } - virtual~GraphTools_MetisMT() {} - - /// - /// setup metis parameters - /// - - void setVerbose(const bool verbose) { _verbose = verbose; } - void setOption(const int id, const double value) { - _options[id] = value; - } - - /// - /// reorder by metis - /// - - void reorder(const ordinal_type verbose = 0) { - Kokkos::Timer timer; - double t_metis = 0; - - int ierr = 0; - - mtmetis_vtx_type *xadj = (mtmetis_vtx_type*)_xadj.data(); - mtmetis_adj_type *adjncy = (mtmetis_adj_type*)_adjncy.data(); - mtmetis_wgt_type *vwgt = (mtmetis_wgt_type*)_vwgt.data(); - - mtmetis_pid_type *perm = (mtmetis_pid_type*)_perm_t.data(); - mtmetis_pid_type *peri = (mtmetis_pid_type*)_peri_t.data(); - - timer.reset(); - ierr = MTMETIS_NodeND(&_nvts, xadj, adjncy, vwgt, _options, - perm, peri); - t_metis = timer.seconds(); - - for (mtmetis_vtx_type i=0;i<_nvts;++i) { - _perm(i) = _perm_t(i); - _peri(i) = _peri_t(i); - } - - TACHO_TEST_FOR_EXCEPTION(ierr != MTMETIS_SUCCESS, - std::runtime_error, - "Failed in METIS_NodeND"); - _is_ordered = true; - - if (verbose) { - printf("Summary: GraphTools (MetisMT)\n"); - printf("=============================\n"); - - switch (verbose) { - case 1: { - printf(" Time\n"); - printf(" time for reordering: %10.6f s\n", t_metis); - printf("\n"); - } - } - } - - } - - ordinal_type_array PermVector() const { return _perm; } - ordinal_type_array InvPermVector() const { return _peri; } - - std::ostream& showMe(std::ostream &os, const bool detail = false) const { - std::streamsize prec = os.precision(); - os.precision(4); - os << std::scientific; - - if (_is_ordered) - os << " -- MetisMT Ordering -- " << std::endl - << " PERM PERI " << std::endl; - else - os << " -- Not Ordered -- " << std::endl; - - if (detail) { - const ordinal_type w = 6, m = _perm.extent(0); - for (ordinal_type i=0;i host_device_type; - typedef Kokkos::View ordinal_type_array; - - enum : int { DefaultRandomSeed = -1 }; - - private: - - // scotch main data structure - SCOTCH_Graph _graph; - SCOTCH_Num _strat; - int _level; - - // scotch output - ordinal_type _cblk; - ordinal_type_array _perm,_peri,_range,_tree; - - // status flag - bool _is_ordered, _verbose; - - public: - GraphTools_Scotch() = default; - GraphTools_Scotch(const GraphTools_Scotch &b) = default; - - /// - /// construction of scotch graph - /// - GraphTools_Scotch(const Graph &g) { - _is_ordered = false; - _verbose = false; - - // input - const ordinal_type base = 0; - const ordinal_type m = g.NumRows(); - const size_type nnz = g.NumNonZeros(); - - // scotch control parameter - _strat = 0; - _level = 0; - - // output - _cblk = 0; - _perm = ordinal_type_array("Scotch::PermutationArray", m); - _peri = ordinal_type_array("Scotch::InvPermutationArray", m); - _range = ordinal_type_array("Scotch::RangeArray", m); - _tree = ordinal_type_array("Scotch::TreeArray", m); - - // construct scotch graph - int ierr = 0; - const SCOTCH_Num *rptr_ptr = reinterpret_cast(g.RowPtr().data()); - const SCOTCH_Num *cidx_ptr = reinterpret_cast(g.ColIdx().data()); - - ierr = SCOTCH_graphInit(&_graph); - TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphInit"); - - ierr = SCOTCH_graphBuild(&_graph, // scotch graph - base, // base value - m, // # of vertices - rptr_ptr, // column index array pointer begin - rptr_ptr+1, // column index array pointer end - NULL, // weights on vertices (optional) - NULL, // label array on vertices (optional) - nnz, // # of nonzeros - cidx_ptr, // column index array - NULL); // edge load array (optional) - TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphBuild"); - - ierr = SCOTCH_graphCheck(&_graph); - TACHO_TEST_FOR_EXCEPTION(ierr, std::runtime_error, "Failed in SCOTCH_graphCheck"); - } - virtual~GraphTools_Scotch() { - SCOTCH_graphFree(&_graph); - } - - /// - /// setup scotch parameters - /// - - void setVerbose(const bool verbose) { _verbose = verbose; } - void setSeed(const int seed = DefaultRandomSeed) { - if (seed != DefaultRandomSeed) { - SCOTCH_randomSeed(seed); - SCOTCH_randomReset(); - } - } - - void setStrategy(const SCOTCH_Num strat = 0) { - // a typical choice - //(SCOTCH_STRATLEVELMAX));// | - //SCOTCH_STRATLEVELMIN | - //SCOTCH_STRATLEAFSIMPLE | - //SCOTCH_STRATSEPASIMPLE); - _strat = strat; - } - - void setTreeLevel(const unsigned int level = 0) { - _level = level; - } - - /// - /// setup scotch parameters - /// - - void reorder(const ordinal_type verbose = 0) { - Kokkos::Timer timer; - double t_scotch = 0; - - _verbose = verbose; - - const int treecut = 0; - int ierr = 0; - - // pointers for global graph ordering - ordinal_type *perm = _perm.data(); - ordinal_type *peri = _peri.data(); - ordinal_type *range = _range.data(); - ordinal_type *tree = _tree.data(); - - timer.reset(); - { - // set desired tree level - if (_strat & SCOTCH_STRATLEVELMAX || - _strat & SCOTCH_STRATLEVELMIN) { - TACHO_TEST_FOR_EXCEPTION(_level == 0, - std::logic_error, - "SCOTCH_STRATLEVEL(MIN/MAX) is used but level is not specified"); - } - const int level = max(1, _level-treecut); - - SCOTCH_Strat stradat; - SCOTCH_Num straval = _strat; - - ierr = SCOTCH_stratInit(&stradat); - TACHO_TEST_FOR_EXCEPTION(ierr, - std::runtime_error, - "Failed in SCOTCH_stratInit"); - - - // if both are zero, do not build strategy - if (_strat || _level) { - ierr = SCOTCH_stratGraphOrderBuild(&stradat, straval, level, 0.2); - TACHO_TEST_FOR_EXCEPTION(ierr, - std::runtime_error, - "Failed in SCOTCH_stratGraphOrderBuild"); - } - ierr = SCOTCH_graphOrder(&_graph, - &stradat, - perm, - peri, - &_cblk, - range, - tree); - TACHO_TEST_FOR_EXCEPTION(ierr, - std::runtime_error, - "Failed in SCOTCH_graphOrder"); - SCOTCH_stratExit(&stradat); - } - t_scotch = timer.seconds(); - _is_ordered = true; - - if (_verbose) { - printf("Summary: GraphTools (Scotch)\n"); - printf("===========================\n"); - printf(" Time\n"); - printf(" time for reordering: %10.6f s\n", t_scotch); - printf("\n"); - if (_strat || _level) { - printf(" User provided strategy ( %d ) and/or level ( %d )\n", _strat, _level); - printf(" strategy & SCOTCH_STRATLEVELMAX: %3d\n", (_strat & SCOTCH_STRATLEVELMAX)); - printf(" strategy & SCOTCH_STRATLEVELMIN: %3d\n", (_strat & SCOTCH_STRATLEVELMIN)); - printf(" strategy & SCOTCH_STRATLEAFSIMPLE: %3d\n", (_strat & SCOTCH_STRATLEAFSIMPLE)); - printf(" strategy & SCOTCH_STRATSEPASIMPLE: %3d\n", (_strat & SCOTCH_STRATSEPASIMPLE)); - printf("\n"); - } - printf(" Partitions\n"); - printf(" number of block partitions: %3d\n", _cblk); - printf("\n"); - } - } - - ordinal_type_array PermVector() const { return _perm; } - ordinal_type_array InvPermVector() const { return _peri; } - - ordinal_type_array RangeVector() const { return _range; } - ordinal_type_array TreeVector() const { return _tree; } - - ordinal_type NumBlocks() const { return _cblk; } - ordinal_type TreeLevel() const { - ordinal_type r_val; - if (_strat & SCOTCH_STRATLEVELMAX || - _strat & SCOTCH_STRATLEVELMIN) - r_val = _level; - else - r_val = 0; - return r_val; - } - - std::ostream& showMe(std::ostream &os, const bool detail = false) const { - std::streamsize prec = os.precision(); - os.precision(4); - os << std::scientific; - - if (_is_ordered) - os << " -- Scotch Ordering -- " << std::endl - << " CBLK = " << _cblk << std::endl - << " PERM PERI RANG TREE" << std::endl; - else - os << " -- Not Ordered -- " << std::endl; - - if (detail) { - const ordinal_type w = 6, m = _perm.extent(0); - for (ordinal_type i=0;i - struct Herk; - - /// task construction for the above chol implementation - /// Herk ::invoke(_sched, member, _alpha, _A, _beta, _C); - template - struct TaskFunctor_Herk { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; +/// +/// Herk: +/// - typedef ScalarType scalar_type; +/// various implementation for different uplo and algo parameters +template struct Herk; - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; +struct HerkAlgorithm { + using type = ActiveAlgorithm::type; +}; - private: - scalar_type _alpha, _beta; - dense_block_type _A, _C; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Herk() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Herk(const scalar_type alpha, - const dense_block_type &A, - const scalar_type beta, - const dense_block_type &C) - : _alpha(alpha), - _beta(beta), - _A(A), - _C(C) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Herk - ::invoke(member, _alpha, _A, _beta, _C); - - Kokkos::single(Kokkos::PerThread(member), - [&, ierr] () { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - _C.set_future(); - r_val = ierr; - }); - } - }; - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_External.hpp index 242ddfa45118..50201537e812 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_HERK_EXTERNAL_HPP__ #define __TACHO_HERK_EXTERNAL_HPP__ - /// \file Tacho_Herk_External.hpp /// \brief BLAS hermitian rank-k update /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,66 +9,43 @@ namespace Tacho { - template - struct Herk { - template - inline - static int - invoke(const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and C do not have the same value type."); - - const ordinal_type - n = C.extent(0), - k = (std::is_same::value ? A.extent(1) : A.extent(0)); - if (n > 0 && k > 0) { - Blas::herk(ArgUplo::param, - ArgTrans::param, - n, k, - value_type(alpha), - A.data(), A.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } +template struct Herk { + template + inline static int invoke(const ScalarType alpha, const ViewTypeA &A, const ScalarType beta, const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and C do not have the same value type."); + + const ordinal_type n = C.extent(0), + k = (std::is_same::value ? A.extent(1) : A.extent(0)); + if (n > 0 && k > 0) { + Blas::herk(ArgUplo::param, ArgTrans::param, n, k, value_type(alpha), A.data(), A.stride_1(), + value_type(beta), C.data(), C.stride_1()); + } #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } + return 0; + } - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(alpha, A, beta, C); - //}); + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ScalarType beta, + const ViewTypeC &C) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(alpha, A, beta, C); + //}); #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return 0; - } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_Internal.hpp index 3278d08abf31..ab085db4c9ba 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_HERK_INTERNAL_HPP__ #define __TACHO_HERK_INTERNAL_HPP__ - /// \file Tacho_Herk_Internal.hpp /// \brief BLAS hermitian rank-k update /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,45 +9,28 @@ namespace Tacho { - template - struct Herk { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { - - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeC::non_const_value_type value_type_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and C do not have the same value type."); - - const ordinal_type n = C.extent(0); - const ordinal_type - k = (std::is_same::value ? A.extent(1) : A.extent(0)); - - if (n > 0 && k > 0) - BlasTeam::herk(member, - ArgUplo::param, - ArgTrans::param, - n, k, - value_type(alpha), - A.data(), A.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - return 0; - } - }; - -} +template struct Herk { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, + const ScalarType beta, const ViewTypeC &C) { + + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and C do not have the same value type."); + + const ordinal_type n = C.extent(0); + const ordinal_type k = (std::is_same::value ? A.extent(1) : A.extent(0)); + + if (n > 0 && k > 0) + BlasTeam::herk(member, ArgUplo::param, ArgTrans::param, n, k, value_type(alpha), A.data(), + A.stride_1(), value_type(beta), C.data(), C.stride_1()); + return 0; + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_OnDevice.hpp index 795ce734c7ab..332126eea202 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_HERK_ON_DEVICE_HPP__ #define __TACHO_HERK_ON_DEVICE_HPP__ - /// \file Tacho_Herk_OnDevice.hpp /// \brief BLAS hermitian rank-k update /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,114 +8,70 @@ #include "Tacho_Blas_External.hpp" namespace Tacho { - - template - struct Herk { - template - inline - static int - blas_invoke(const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - n = C.extent(0), - k = (std::is_same::value ? A.extent(1) : A.extent(0)); - - if (n > 0 && k > 0) { - Blas::herk(ArgUplo::param, - ArgTrans::param, - n, k, - value_type(alpha), - A.data(), A.stride_1(), - value_type(beta), - C.data(), C.stride_1()); - } - return 0; + +template struct Herk { + template + inline static int blas_invoke(const ScalarType alpha, const ViewTypeA &A, const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type n = C.extent(0), + k = (std::is_same::value ? A.extent(1) : A.extent(0)); + + if (n > 0 && k > 0) { + Blas::herk(ArgUplo::param, ArgTrans::param, n, k, value_type(alpha), A.data(), A.stride_1(), + value_type(beta), C.data(), C.stride_1()); } + return 0; + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - n = C.extent(0), - k = (std::is_same::value ? A.extent(1) : A.extent(0)); + template + inline static int cublas_invoke(cublasHandle_t &handle, const ScalarType alpha, const ViewTypeA &A, + const ScalarType beta, const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type n = C.extent(0), + k = (std::is_same::value ? A.extent(1) : A.extent(0)); - int r_val(0); - if (n > 0 && k > 0) { - if (std::is_same::value || - std::is_same::value) - r_val = Blas::herk(handle, - ArgUplo::cublas_param, - std::is_same::value ? - Trans::Transpose::cublas_param : ArgTrans::cublas_param, - n, k, - alpha, - A.data(), A.stride_1(), - beta, - C.data(), C.stride_1()); - else if (std::is_same >::value || - std::is_same >::value) - r_val = Blas::herk(handle, - ArgUplo::cublas_param, - ArgTrans::cublas_param, - n, k, - alpha, - A.data(), A.stride_1(), - beta, - C.data(), C.stride_1()); - } - return r_val; + int r_val(0); + if (n > 0 && k > 0) { + if (std::is_same::value || std::is_same::value) + r_val = + Blas::herk(handle, ArgUplo::cublas_param, + std::is_same::value ? Trans::Transpose::cublas_param + : ArgTrans::cublas_param, + n, k, alpha, A.data(), A.stride_1(), beta, C.data(), C.stride_1()); + else if (std::is_same>::value || + std::is_same>::value) + r_val = Blas::herk(handle, ArgUplo::cublas_param, ArgTrans::cublas_param, n, k, alpha, A.data(), + A.stride_1(), beta, C.data(), C.stride_1()); } + return r_val; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ScalarType alpha, - const ViewTypeA &A, - const ScalarType beta, - const ViewTypeC &C) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeC::non_const_value_type value_type_c; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeC::memory_space memory_space_c; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeC::rank == 2,"C is not rank 2 view."); - - static_assert(std::is_same::value, - "A and C do not have the same value type."); + template + inline static int invoke(MemberType &member, const ScalarType alpha, const ViewTypeA &A, const ScalarType beta, + const ViewTypeC &C) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeC::non_const_value_type value_type_c; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeC::memory_space memory_space_c; - static_assert(std::is_same::value, - "A and C do not have the same memory space."); - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(alpha, A, beta, C); -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, alpha, A, beta, C); + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeC::rank == 2, "C is not rank 2 view."); + + static_assert(std::is_same::value, "A and C do not have the same value type."); + + static_assert(std::is_same::value, "A and C do not have the same memory space."); + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(alpha, A, beta, C); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, alpha, A, beta, C); #endif - return r_val; - } - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Internal.hpp index d426aebdce21..a0285ee61a8b 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Internal.hpp @@ -8,22 +8,19 @@ #include "Tacho_config.h" #include -#include #include +#include #include -#include "Tacho_Util.hpp" #include "Tacho_Partition.hpp" +#include "Tacho_Util.hpp" #include "Tacho_CrsMatrixBase.hpp" -#include "Tacho_DenseMatrixView.hpp" - -#include "Tacho_MatrixMarket.hpp" +#include "Tacho_MatrixMarket.hpp" #include "Tacho_Graph.hpp" -#include "Tacho_GraphTools.hpp" -#include "Tacho_GraphTools_Metis.hpp" -#include "Tacho_GraphTools_Scotch.hpp" +#include "Tacho_GraphTools.hpp" +#include "Tacho_GraphTools_Metis.hpp" #include "Tacho_SupernodeInfo.hpp" #include "Tacho_SymbolicTools.hpp" @@ -43,6 +40,9 @@ #include "Tacho_SetIdentity.hpp" #include "Tacho_SetIdentity_Internal.hpp" +#include "Tacho_ApplyPivots.hpp" +#include "Tacho_ApplyPivots_Internal.hpp" + #include "Tacho_ApplyPermutation.hpp" #include "Tacho_ApplyPermutation_Internal.hpp" @@ -52,27 +52,32 @@ #include "Tacho_Chol.hpp" #include "Tacho_Chol_External.hpp" #include "Tacho_Chol_Internal.hpp" -#include "Tacho_Chol_ByBlocks.hpp" +// #include "Tacho_Chol_ByBlocks.hpp" #include "Tacho_LDL.hpp" #include "Tacho_LDL_External.hpp" #include "Tacho_LDL_Internal.hpp" //#include "Tacho_LDL_ByBlocks.hpp" +#include "Tacho_LU.hpp" +#include "Tacho_LU_External.hpp" +#include "Tacho_LU_Internal.hpp" +//#include "Tacho_LU_ByBlocks.hpp" + #include "Tacho_Trsm.hpp" #include "Tacho_Trsm_External.hpp" #include "Tacho_Trsm_Internal.hpp" -#include "Tacho_Trsm_ByBlocks.hpp" +// #include "Tacho_Trsm_ByBlocks.hpp" #include "Tacho_Herk.hpp" #include "Tacho_Herk_External.hpp" #include "Tacho_Herk_Internal.hpp" -#include "Tacho_Herk_ByBlocks.hpp" +// #include "Tacho_Herk_ByBlocks.hpp" #include "Tacho_Gemm.hpp" #include "Tacho_Gemm_External.hpp" #include "Tacho_Gemm_Internal.hpp" -#include "Tacho_Gemm_ByBlocks.hpp" +// #include "Tacho_Gemm_ByBlocks.hpp" #include "Tacho_GemmTriangular.hpp" #include "Tacho_GemmTriangular_External.hpp" @@ -91,26 +96,28 @@ #include "Tacho_CholSupernodes_Serial.hpp" #include "Tacho_CholSupernodes_SerialPanel.hpp" -#include "Tacho_TaskFunctor_FactorizeChol.hpp" -#include "Tacho_TaskFunctor_FactorizeCholPanel.hpp" -#include "Tacho_TaskFunctor_FactorizeCholByBlocks.hpp" -#include "Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp" +// #include "Tacho_TaskFunctor_FactorizeChol.hpp" +// #include "Tacho_TaskFunctor_FactorizeCholPanel.hpp" +// #include "Tacho_TaskFunctor_FactorizeCholByBlocks.hpp" +// #include "Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp" -#include "Tacho_TaskFunctor_SolveLowerChol.hpp" -#include "Tacho_TaskFunctor_SolveUpperChol.hpp" +// #include "Tacho_TaskFunctor_SolveLowerChol.hpp" +// #include "Tacho_TaskFunctor_SolveUpperChol.hpp" -#include "Tacho_NumericTools.hpp" -#include "Tacho_LevelSetTools.hpp" -#include "Tacho_TriSolveTools.hpp" +// #include "Tacho_NumericTools.hpp" +// #include "Tacho_LevelSetTools.hpp" +// #include "Tacho_TriSolveTools.hpp" // refactoring #include "Tacho_NumericTools_Base.hpp" -#include "Tacho_NumericTools_Serial.hpp" #include "Tacho_NumericTools_LevelSet.hpp" +#include "Tacho_NumericTools_Serial.hpp" + +#include "Tacho_NumericTools_Factory.hpp" -// Do not include this. +// Do not include this. // In a gcc (4.9.x), this causes some multiple definition link error with gcc headers. // No idea yet why it happens as the code is guarded by Tacho::Experimental namespace. -//#include "Tacho_CommandLineParser.hpp" +//#include "Tacho_CommandLineParser.hpp" #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL.hpp index 3816b4532121..10e98629e63b 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL.hpp @@ -8,17 +8,19 @@ #include "Tacho_Util.hpp" namespace Tacho { - - /// - /// LDL: - /// - /// - - /// various implementation for different uplo and algo parameters - template - struct LDL; - -} + +/// +/// LDL: +/// +/// + +/// various implementation for different uplo and algo parameters +template struct LDL; + +struct LDL_Algorithm { + using type = ActiveAlgorithm::type; +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_External.hpp index 0d0cb8ece895..9a4b3ea16499 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_External.hpp @@ -9,178 +9,137 @@ namespace Tacho { - /// LAPACK LDL - /// ========== - template<> - struct LDL { - template - inline - static int - invoke(const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { - int r_val = 0; -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeP::rank == 1,"P is not rank 1 view."); - static_assert(ViewTypeW::rank == 1,"W is not rank 1 view."); - - TACHO_TEST_FOR_EXCEPTION(P.extent(0) < 4*A.extent(0), std::runtime_error, - "P should be 4*A.extent(0) ."); - - const ordinal_type m = A.extent(0); - if (m > 0) { - /// factorize LDL - Lapack::sytrf('L', - m, - A.data(), A.stride_1(), - P.data(), - W.data(), W.extent(0), - &r_val); - TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, - "LAPACK (sytrf) returns non-zero error code."); - } +/// LAPACK LDL +/// ========== +template <> struct LDL { + template + inline static int invoke(const ViewTypeA &A, const ViewTypeP &P, const ViewTypeW &W) { + int r_val = 0; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeW::rank == 1, "W is not rank 1 view."); + + TACHO_TEST_FOR_EXCEPTION(P.extent(0) < 4 * A.extent(0), std::runtime_error, "P should be 4*A.extent(0) ."); + + const ordinal_type m = A.extent(0); + if (m > 0) { + /// factorize LDL + Lapack::sytrf('L', m, A.data(), A.stride_1(), P.data(), W.data(), W.extent(0), &r_val); + TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, "LAPACK (sytrf) returns non-zero error code."); + } #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space." ); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return r_val; - } - - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - int r_val = 0; - r_val = invoke(A, P, W); - return r_val; + return r_val; + } + + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeW &W) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + int r_val = 0; + r_val = invoke(A, P, W); + return r_val; #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space." ); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - } - - template - inline - static int - modify(const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { - int r_val = 0; -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeP::rank == 1,"P is not rank 1 view."); - static_assert(ViewTypeD::rank == 2,"D is not rank 2 view."); - - TACHO_TEST_FOR_EXCEPTION(D.extent(0) < A.extent(0), std::runtime_error, - "D extent(0) is smaller than A extent(0)."); - TACHO_TEST_FOR_EXCEPTION(D.extent(1) != 2, std::runtime_error, - "D is supposed to store 2x2 blocks ."); - TACHO_TEST_FOR_EXCEPTION(P.extent(0) < 4*A.extent(0), std::runtime_error, - "P should be 4*A.extent(0) ."); - - const ordinal_type m = A.extent(0); - if (m > 0) { - value_type - *__restrict__ Aptr = A.data(); - ordinal_type - *__restrict__ ipiv = P.data(), - *__restrict__ fpiv = ipiv + m, - *__restrict__ perm = fpiv + m, - *__restrict__ peri = perm + m; - - const value_type one(1), zero(0); - for (ordinal_type i=0;i + inline static int modify(const ViewTypeA &A, const ViewTypeP &P, const ViewTypeD &D) { + int r_val = 0; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeD::rank == 2, "D is not rank 2 view."); + + TACHO_TEST_FOR_EXCEPTION(D.extent(0) < A.extent(0), std::runtime_error, "D extent(0) is smaller than A extent(0)."); + TACHO_TEST_FOR_EXCEPTION(D.extent(1) != 2, std::runtime_error, "D is supposed to store 2x2 blocks ."); + TACHO_TEST_FOR_EXCEPTION(P.extent(0) < 4 * A.extent(0), std::runtime_error, "P should be 4*A.extent(0) ."); + + const ordinal_type m = A.extent(0); + if (m > 0) { + value_type *__restrict__ Aptr = A.data(); + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + + const value_type one(1), zero(0); + for (ordinal_type i = 0; i < m; ++i) + perm[i] = i; + for (ordinal_type i = 0; i < m; ++i) { + if (ipiv[i] < 0) { + const bool is_first = (i + 1) < m ? (ipiv[i + 1] == ipiv[i]) : false; + if (is_first) { + ipiv[i] = 0; /// invalidate this pivot + fpiv[i] = 0; + + D(i, 0) = A(i, i); + D(i, 1) = A(i + 1, i); /// symmetric + A(i, i) = one; } else { - const ordinal_type fla_pivot = ipiv[i]-i-1; + const ordinal_type fla_pivot = -ipiv[i] - i - 1; fpiv[i] = fla_pivot; if (fla_pivot) { - value_type * src = Aptr + i; - value_type * tgt = src + fla_pivot; - for (ordinal_type j=0;j> This function is only allowed in host space." ); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - return r_val; - } - - - template - inline - static int - modify(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - int r_val = 0; - r_val = modify(A, P, D); - return r_val; + return r_val; + } + + template + inline static int modify(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeD &D) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + int r_val = 0; + r_val = modify(A, P, D); + return r_val; #else - TACHO_TEST_FOR_ABORT( true, ">> This function is only allowed in host space." ); + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); #endif - } - - }; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Internal.hpp index 8733a7a94769..2d488d65147f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Internal.hpp @@ -9,181 +9,140 @@ namespace Tacho { - /// LAPACK LDL - /// ========== - template<> - struct LDL { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { - typedef typename ViewTypeA::non_const_value_type value_type; - //typedef typename ViewTypeP::non_const_value_type p_value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeP::rank == 1,"P is not rank 1 view."); - static_assert(ViewTypeW::rank == 1,"W is not rank 1 view."); - - TACHO_TEST_FOR_ABORT(P.extent(0) < 4*A.extent(0), - "P should be 4*A.extent(0) ."); - - int r_val(0); - const ordinal_type m = A.extent(0); - if (m > 0) { - /// factorize LDL - LapackTeam::sytrf(member, - Uplo::Lower::param, - m, - A.data(), A.stride_1(), - P.data(), - W.data(), - &r_val); - } - return r_val; +/// LAPACK LDL +/// ========== +template <> struct LDL { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + // typedef typename ViewTypeP::non_const_value_type p_value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeW::rank == 1, "W is not rank 1 view."); + + TACHO_TEST_FOR_ABORT(P.extent(0) < 4 * A.extent(0), "P should be 4*A.extent(0) ."); + + int r_val(0); + const ordinal_type m = A.extent(0); + if (m > 0) { + /// factorize LDL + LapackTeam::sytrf(member, Uplo::Lower::param, m, A.data(), A.stride_1(), P.data(), W.data(), &r_val); } + return r_val; + } - template - KOKKOS_INLINE_FUNCTION - static int - modify(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { - typedef typename ViewTypeA::non_const_value_type value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeP::rank == 1,"P is not rank 1 view."); - static_assert(ViewTypeD::rank == 2,"D is not rank 2 view."); - - TACHO_TEST_FOR_ABORT(D.extent(0) < A.extent(0), - "D extent(0) is smaller than A extent(0)."); - TACHO_TEST_FOR_ABORT(D.extent(1) != 2, - "D is supposed to store 2x2 blocks ."); - TACHO_TEST_FOR_ABORT(P.extent(0) < 4*A.extent(0), - "P should be 4*A.extent(0) ."); - - int r_val = 0; - const ordinal_type m = A.extent(0); - if (m > 0) { - value_type - *__restrict__ Aptr = A.data(); - ordinal_type - *__restrict__ ipiv = P.data(), - *__restrict__ fpiv = ipiv + m, - *__restrict__ perm = fpiv + m, - *__restrict__ peri = perm + m; - const value_type one(1); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&](const int &i) { - perm[i] = i; - }); - member.team_barrier(); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member,m), - [&](const int &j) { - const bool single = (j == 0); - for (ordinal_type i=0/*,cnt=0*/;i + KOKKOS_INLINE_FUNCTION static int modify(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeD &D) { + typedef typename ViewTypeA::non_const_value_type value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeD::rank == 2, "D is not rank 2 view."); + + TACHO_TEST_FOR_ABORT(D.extent(0) < A.extent(0), "D extent(0) is smaller than A extent(0)."); + TACHO_TEST_FOR_ABORT(D.extent(1) != 2, "D is supposed to store 2x2 blocks ."); + TACHO_TEST_FOR_ABORT(P.extent(0) < 4 * A.extent(0), "P should be 4*A.extent(0) ."); + + int r_val = 0; + const ordinal_type m = A.extent(0); + if (m > 0) { + value_type *__restrict__ Aptr = A.data(); + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + const value_type one(1); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { perm[i] = i; }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &j) { + const bool single = (j == (m - 1)); + for (ordinal_type i = 0 /*,cnt=0*/; i < m; ++i) { + // if (ipiv[i] <= 0) { + // if (++cnt%2) { + // if (single) { + // ipiv[i] = 0; /// invalidate this pivot + // fpiv[i] = 0; + + // D(i,0) = A(i, i); + // D(i,1) = A(i+1,i); /// symmetric + // A(i,i) = one; + // } + // } else { + // const ordinal_type fla_pivot = -ipiv[i]-i-1; + // if (single) { + // fpiv[i] = fla_pivot; + // } + // if (fla_pivot) { + // value_type *__restrict__ src = Aptr + i; + // value_type *__restrict__ tgt = src + fla_pivot; + // if (j<(i-1)) { + // const ordinal_type idx = j*m; + // swap(src[idx], tgt[idx]); + // } + // } + + // if (single) { + // D(i,0) = A(i,i-1); + // D(i,1) = A(i,i ); + // A(i,i-1) = zero; A(i,i) = one; + // } + // } + // } else + { + const ordinal_type fla_pivot = ipiv[i] - i - 1; + if (single) { + fpiv[i] = fla_pivot; + } + if (fla_pivot) { + value_type *src = Aptr + i; + value_type *tgt = src + fla_pivot; + if (j < i) { + const ordinal_type idx = j * m; + swap(src[idx], tgt[idx]); } } - }); - member.team_barrier(); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&](const int &i) { - peri[perm[i]] = i; - }); - - } - - /// no piv version - // if (m > 0) { - // ordinal_type - // *__restrict__ ipiv = P.data(), - // *__restrict__ fpiv = ipiv + m, - // *__restrict__ perm = fpiv + m, - // *__restrict__ peri = perm + m; - // const value_type one(1); - // Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) { - // D(i,0) = A(i,i); - // A(i,i) = one; - // ipiv[i] = i+1; - // fpiv[i] = 0; - // perm[i] = i; - // peri[i] = i; - // }); - // } - return r_val; + + if (single) { + D(i, 0) = A(i, i); + A(i, i) = one; + } + } + + /// apply pivots to perm vector + if (single) { + if (fpiv[i]) { + const ordinal_type pidx = i + fpiv[i]; + swap(perm[i], perm[pidx]); + } + } + } + }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { peri[perm[i]] = i; }); } - - }; -} + /// no piv version + // if (m > 0) { + // ordinal_type + // *__restrict__ ipiv = P.data(), + // *__restrict__ fpiv = ipiv + m, + // *__restrict__ perm = fpiv + m, + // *__restrict__ peri = perm + m; + // const value_type one(1); + // Kokkos::parallel_for(Kokkos::TeamVectorRange(member,m),[&](const int &i) { + // D(i,0) = A(i,i); + // A(i,i) = one; + // ipiv[i] = i+1; + // fpiv[i] = 0; + // perm[i] = i; + // peri[i] = i; + // }); + // } + return r_val; + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_OnDevice.hpp index f02d3b6fafdf..e3eb0362a79d 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_LDL_ON_DEVICE_HPP__ #define __TACHO_LDL_ON_DEVICE_HPP__ - /// \file Tacho_LDL_OnDevice.hpp /// \brief LDL device solver /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -10,281 +9,209 @@ namespace Tacho { - template<> - struct LDL { - template - inline - static int - lapack_invoke(const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { - return LDL::invoke(A, P, W); - } +template <> struct LDL { + template + inline static int lapack_invoke(const ViewTypeA &A, const ViewTypeP &P, const ViewTypeW &W) { + return LDL::invoke(A, P, W); + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cusolver_invoke(cusolverDnHandle_t &handle, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeW::non_const_value_type work_value_type; - const ordinal_type - m = A.extent(0); + template + inline static int cusolver_invoke(cusolverDnHandle_t &handle, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type work_value_type; + const ordinal_type m = A.extent(0); - int r_val(0); - if (m > 0) { - int *devInfo = (int*)W.data(); - value_type *workspace = W.data() + 1; - int lwork = (W.span()-1); - r_val = Lapack::sytrf(handle, - CUBLAS_FILL_MODE_LOWER, - m, - A.data(), A.stride_1(), - P.data(), - workspace, lwork, - devInfo); - } - return r_val; + int r_val(0); + if (m > 0) { + int *devInfo = (int *)W.data(); + value_type *workspace = W.data() + 1; + int lwork = (W.span() - 1); + r_val = Lapack::sytrf(handle, CUBLAS_FILL_MODE_LOWER, m, A.data(), A.stride_1(), P.data(), workspace, + lwork, devInfo); } + return r_val; + } - template - inline - static int - cusolver_buffer_size(cusolverDnHandle_t &handle, - const ViewTypeA &A, - int *lwork) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = A.extent(0); + template + inline static int cusolver_buffer_size(cusolverDnHandle_t &handle, const ViewTypeA &A, int *lwork) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0); - int r_val(0); - if (m > 0) - r_val = Lapack::sytrf_buffersize(handle, - m, - A.data(), A.stride_1(), - lwork); - return r_val; - } + int r_val(0); + if (m > 0) + r_val = Lapack::sytrf_buffersize(handle, m, A.data(), A.stride_1(), lwork); + return r_val; + } #endif - - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeW &W) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeW::non_const_value_type value_type_w; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeW::memory_space memory_space_w; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeW::rank == 1,"W is not rank 1 view."); - - static_assert(std::is_same::value, - "A and W do not have the same value type."); + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type value_type_w; - static_assert(std::is_same::value, - "A and W do not have the same memory space."); - int r_val(0); - if (std::is_same::value) { - if (W.span() == 0) { - int lwork = A.extent(0)*32; - return lwork; - } else { - r_val = lapack_invoke(A, P, W); - } - } - -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) { - if (W.span() == 0) { - int lwork; - r_val = cusolver_buffer_size(member, A, &lwork); - r_val = (lwork+sizeof(value_type_w))/sizeof(value_type_w) + 1; - } else - r_val = cusolver_invoke(member, A, P, W); + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeW::memory_space memory_space_w; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeW::rank == 1, "W is not rank 1 view."); + + static_assert(std::is_same::value, "A and W do not have the same value type."); + + static_assert(std::is_same::value, "A and W do not have the same memory space."); + int r_val(0); + if (std::is_same::value) { + if (W.span() == 0) { + int lwork = A.extent(0) * 32; + return lwork; + } else { + r_val = lapack_invoke(A, P, W); } -#endif - return r_val; } - template - inline - static int - lapack_modify(const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { - return LDL::modify(A, P, D); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || + std::is_same::value) { + if (W.span() == 0) { + int lwork; + r_val = cusolver_buffer_size(member, A, &lwork); + r_val = (lwork + sizeof(value_type_w)) / sizeof(value_type_w) + 1; + } else + r_val = cusolver_invoke(member, A, P, W); } +#endif + return r_val; + } + + template + inline static int lapack_modify(const ViewTypeA &A, const ViewTypeP &P, const ViewTypeD &D) { + return LDL::modify(A, P, D); + } #if defined(KOKKOS_ENABLE_CUDA) - template - inline - static int - cusolver_modify(ExecSpaceType &exec_instance, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { - using exec_space = ExecSpaceType; - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type - m = A.extent(0); - - int r_val(0); - if (m > 0) { - value_type - *__restrict__ Aptr = A.data(); - ordinal_type - *__restrict__ ipiv = P.data(), - *__restrict__ fpiv = ipiv + m, - *__restrict__ perm = fpiv + m, - *__restrict__ peri = perm + m; - - const value_type one(1), zero(0); - Kokkos::RangePolicy range_policy(exec_instance, 0, m); - Kokkos::parallel_for - ("PermutationSet", range_policy, - KOKKOS_LAMBDA(const ordinal_type i) { - perm[i] = i; - }); - exec_instance.fence(); - Kokkos::parallel_for - ("ExtractDiagonalsAndPostProcessing", - range_policy, - KOKKOS_LAMBDA(const ordinal_type j) { - const bool single = (j == 0); - for (ordinal_type i=0,cnt=0;i + inline static int cusolver_modify(ExecSpaceType &exec_instance, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeD &D) { + using exec_space = ExecSpaceType; + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0); + + int r_val(0); + if (m > 0) { + value_type *__restrict__ Aptr = A.data(); + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + + const value_type one(1), zero(0); + Kokkos::RangePolicy range_policy(exec_instance, 0, m); + Kokkos::parallel_for( + "PermutationSet", range_policy, KOKKOS_LAMBDA(const ordinal_type i) { perm[i] = i; }); + exec_instance.fence(); + Kokkos::parallel_for( + "ExtractDiagonalsAndPostProcessing", range_policy, KOKKOS_LAMBDA(const ordinal_type j) { + const bool single = (j == (m - 1)); + for (ordinal_type i = 0; i < m; ++i) { + if (ipiv[i] < 0) { + const bool is_first = (i + 1) < m ? (ipiv[i + 1] == ipiv[i]) : false; + if (is_first) { if (single) { ipiv[i] = 0; /// invalidate this pivot fpiv[i] = 0; - - D(i,0) = A(i, i); - D(i,1) = A(i+1,i); /// symmetric - A(i,i) = one; + + D(i, 0) = A(i, i); + D(i, 1) = A(i + 1, i); /// symmetric + A(i, i) = one; } } else { - const ordinal_type fla_pivot = -ipiv[i]-i-1; + const ordinal_type fla_pivot = -ipiv[i] - i - 1; if (single) { fpiv[i] = fla_pivot; } if (fla_pivot) { value_type *__restrict__ src = Aptr + i; value_type *__restrict__ tgt = src + fla_pivot; - if (j<(i-1)) { - const ordinal_type idx = j*m; + if (j < (i - 1)) { + const ordinal_type idx = j * m; swap(src[idx], tgt[idx]); } - } - + } + if (single) { - D(i,0) = A(i,i-1); - D(i,1) = A(i,i ); - A(i,i-1) = zero; A(i,i) = one; + D(i, 0) = A(i, i - 1); + D(i, 1) = A(i, i); + A(i, i - 1) = zero; + A(i, i) = one; } } } else { - const ordinal_type fla_pivot = ipiv[i]-i-1; + const ordinal_type fla_pivot = ipiv[i] - i - 1; if (single) { fpiv[i] = fla_pivot; } if (fla_pivot) { - value_type * src = Aptr + i; - value_type * tgt = src + fla_pivot; - if (j - inline - static int - modify(MemberType &member, - const ViewTypeA &A, - const ViewTypeP &P, - const ViewTypeD &D) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeD::non_const_value_type value_type_d; - - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeP::memory_space memory_space_p; - typedef typename ViewTypeD::memory_space memory_space_d; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeP::rank == 1,"P is not rank 1 view."); - static_assert(ViewTypeD::rank == 2,"D is not rank 2 view."); - - static_assert(std::is_same::value, - "A and D do not have the same value type."); - - static_assert(std::is_same::value, - "A and P do not have the same memory space."); - - static_assert(std::is_same::value, - "A and D do not have the same memory space."); - - int r_val(0); - if (std::is_same::value) { - r_val = lapack_modify(A, P, D); - } - -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) { - r_val = cusolver_modify(member, A, P, D); - } -#endif - return r_val; + template + inline static int modify(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeD &D) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeD::non_const_value_type value_type_d; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeP::memory_space memory_space_p; + typedef typename ViewTypeD::memory_space memory_space_d; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeD::rank == 2, "D is not rank 2 view."); + + static_assert(std::is_same::value, "A and D do not have the same value type."); + + static_assert(std::is_same::value, "A and P do not have the same memory space."); + + static_assert(std::is_same::value, "A and D do not have the same memory space."); + + int r_val(0); + if (std::is_same::value) { + r_val = lapack_modify(A, P, D); } - }; -} + +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || + std::is_same::value) { + r_val = cusolver_modify(member, A, P, D); + } +#endif + return r_val; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes.hpp index ce800df4658a..0ecb8bea77ee 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes.hpp @@ -8,8 +8,7 @@ namespace Tacho { - template - struct LDL_Supernodes; +template struct LDL_Supernodes; } diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes_Serial.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes_Serial.hpp index 0ea58303d04d..bc71f253280f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes_Serial.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LDL_Supernodes_Serial.hpp @@ -28,346 +28,289 @@ namespace Tacho { - template<> - struct LDL_Supernodes { - template - KOKKOS_INLINE_FUNCTION - static int - factorize(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::ordinal_type_array &P, - const typename SupernodeInfoType::value_type_matrix &D, - const typename SupernodeInfoType::value_type_array &W, - const typename SupernodeInfoType::value_type_matrix &ABR, - const ordinal_type sid) { - using supernode_info_type = SupernodeInfoType; - using value_type = typename supernode_info_type::value_type; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - // algorithm choice - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - - using LDL_AlgoType = MainAlgoType; - using TrsmAlgoType = MainAlgoType; - using GemmAlgoType = MainAlgoType; - - // get current supernode - const auto &s = info.supernodes(sid); - - // get panel pointer - value_type *ptr = s.buf; - - // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) - const ordinal_type m = s.m, n = s.n - s.m; - - // m and n are available, then factorize the supernode block - if (m > 0) { - /// LDL factorize ATL, extract diag, symmetrize ATL with unit diagonals - UnmanagedViewType ATL(ptr, m, m); ptr += m*m; - Symmetrize::invoke(member, ATL); - LDL::invoke(member, ATL, P, W); - LDL::modify(member, ATL, P, D); - - if (n > 0) { - const value_type one(1), zero(0); - UnmanagedViewType ATR(ptr, m, n); ptr += m*n; - UnmanagedViewType STR(W.data(), m, n); - - auto fpiv = ordinal_type_array(P.data()+m, m); - ApplyPivots /// row inter-change - ::invoke(member, fpiv, ATR); - Trsm - ::invoke(member, Diag::Unit(), one, ATL, ATR); - Copy - ::invoke(member, STR, ATR); - Scale2x2_BlockInverseDiagonals /// row scaling - ::invoke(member, P, D, ATR); - - TACHO_TEST_FOR_ABORT(static_cast(ABR.extent(0)) != n || +template <> struct LDL_Supernodes { + template + KOKKOS_INLINE_FUNCTION static int + factorize(MemberType &member, const SupernodeInfoType &info, const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &D, + const typename SupernodeInfoType::value_type_array &W, + const typename SupernodeInfoType::value_type_matrix &ABR, const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + // algorithm choice + using LDL_AlgoType = typename LDL_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // get panel pointer + value_type *ptr = s.u_buf; + + // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) + const ordinal_type m = s.m, n = s.n - s.m; + + // m and n are available, then factorize the supernode block + if (m > 0) { + /// LDL factorize ATL, extract diag, symmetrize ATL with unit diagonals + UnmanagedViewType ATL(ptr, m, m); + ptr += m * m; + + Symmetrize::invoke(member, ATL); + + LDL::invoke(member, ATL, P, W); + LDL::modify(member, ATL, P, D); + + if (n > 0) { + const value_type one(1), zero(0); + UnmanagedViewType ATR(ptr, m, n); + ptr += m * n; + UnmanagedViewType STR(W.data(), m, n); + + auto fpiv = ordinal_type_array(P.data() + m, m); + ApplyPivots /// row inter-change + ::invoke(member, fpiv, ATR); + Trsm::invoke(member, Diag::Unit(), one, ATL, ATR); + Copy::invoke(member, STR, ATR); + Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(member, P, D, ATR); + + TACHO_TEST_FOR_ABORT(static_cast(ABR.extent(0)) != n || static_cast(ABR.extent(1)) != n, - "ABR dimension does not match to supernodes"); - GemmTriangular - ::invoke(member, -one, ATR, STR, zero, ABR); - } - } - return 0; + "ABR dimension does not match to supernodes"); + GemmTriangular::invoke(member, -one, ATR, STR, + zero, ABR); } - - template - KOKKOS_INLINE_FUNCTION - static int - solve_lower(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::ordinal_type_array &P, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - using supernode_info_type = SupernodeInfoType; - using value_type = typename supernode_info_type::value_type; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - using range_type = Kokkos::pair; - - const auto &s = info.supernodes(sid); - - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - - using TrsvAlgoType = MainAlgoType; - using GemvAlgoType = MainAlgoType; - - // get panel pointer - value_type *ptr = s.buf; - - // panel is divided into diagonal and interface block - const ordinal_type m = s.m, n = s.n - s.m; //, nrhs = info.x.extent(1); - - // m and n are available, then factorize the supernode block - if (m > 0) { - const value_type one(1), zero(0); - const ordinal_type offm = s.row_begin; - UnmanagedViewType AL(ptr, m, m); ptr += m*m; - const auto xT = Kokkos::subview(info.x, range_type(offm, offm+m), Kokkos::ALL()); - const auto fpiv = ordinal_type_array(P.data()+m, m); - - ApplyPivots /// row inter-change - ::invoke(member, fpiv, xT); - - Trsv - ::invoke(member, Diag::Unit(), AL, xT); - - if (n > 0) { - UnmanagedViewType AR(ptr, m, n); // ptr += m*n; - Gemv - ::invoke(member, -one, AR, xT, zero, xB); - } - } - return 0; + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int solve_lower(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + using range_type = Kokkos::pair; + + const auto &s = info.supernodes(sid); + + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + // get panel pointer + value_type *ptr = s.u_buf; + + // panel is divided into diagonal and interface block + const ordinal_type m = s.m, n = s.n - s.m; //, nrhs = info.x.extent(1); + + // m and n are available, then factorize the supernode block + if (m > 0) { + const value_type one(1), zero(0); + const ordinal_type offm = s.row_begin; + UnmanagedViewType AL(ptr, m, m); + ptr += m * m; + const auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + const auto fpiv = ordinal_type_array(P.data() + m, m); + + ApplyPivots /// row inter-change + ::invoke(member, fpiv, xT); + + Trsv::invoke(member, Diag::Unit(), AL, xT); + + if (n > 0) { + UnmanagedViewType AR(ptr, m, n); // ptr += m*n; + Gemv::invoke(member, -one, AR, xT, zero, xB); } + } + return 0; + } - template - KOKKOS_INLINE_FUNCTION - static int - solve_upper(MemberType &member, - const SupernodeInfoType &info, - const typename SupernodeInfoType::ordinal_type_array &P, - const typename SupernodeInfoType::value_type_matrix &D, - const typename SupernodeInfoType::value_type_matrix &xB, - const ordinal_type sid) { - using supernode_info_type = SupernodeInfoType; - - using value_type = typename supernode_info_type::value_type; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - using range_type = Kokkos::pair; - - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - using GemvAlgoType = MainAlgoType; - using TrsvAlgoType = MainAlgoType; - - // get current supernode - const auto &s = info.supernodes(sid); - - // get supernode panel pointer - value_type *ptr = s.buf; - - // panel is divided into diagonal and interface block - const ordinal_type m = s.m, n = s.n - s.m;//, nrhs = info.x.extent(1); - - // m and n are available, then factorize the supernode block - if (m > 0) { - const value_type one(1); - const UnmanagedViewType AL(ptr, m, m); ptr += m*m; - - const ordinal_type offm = s.row_begin; - const auto xT = Kokkos::subview(info.x, range_type(offm, offm+m), Kokkos::ALL()); - const auto fpiv = ordinal_type_array(P.data()+m, m); - - Scale2x2_BlockInverseDiagonals /// row scaling - ::invoke(member, P, D, xT); - - if (n > 0) { - const UnmanagedViewType AR(ptr, m, n); // ptr += m*n; - Gemv - ::invoke(member, -one, AR, xB, one, xT); - } - Trsv - ::invoke(member, Diag::Unit(), AL, xT); - ApplyPivots /// row inter-change - ::invoke(member, fpiv, xT); - } - return 0; - } + template + KOKKOS_INLINE_FUNCTION static int solve_upper(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &D, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; - template - KOKKOS_INLINE_FUNCTION - static int - factorize_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::ordinal_type_array::pointer_type piv, - typename SupernodeInfoType::value_type_array::pointer_type diag, - typename SupernodeInfoType::value_type_array::pointer_type buf, - const size_type bufsize) { - using supernode_info_type = SupernodeInfoType; - using value_type = typename supernode_info_type::value_type; - using value_type_array = typename supernode_info_type::value_type_array; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - const auto &s = info.supernodes(sid); - if (final) { - // serial recursion - for (ordinal_type i=0;i ipiv(piv+rbeg*4, 4*m); - UnmanagedViewType dblk(diag+rbeg*2, m, 2); - - const ordinal_type n = s.n - s.m; - - const ordinal_type mm = m < 32 ? m : 32; - const ordinal_type mn = mm > n ? mm : n; - - const size_type bufsize_required = (n*n + m*mn)*sizeof(value_type); - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); - value_type * bufptr = buf; - UnmanagedViewType ABR(bufptr, n, n); bufptr += ABR.span(); - UnmanagedViewType w(bufptr, m*mn); bufptr += w.span(); - - LDL_Supernodes - ::factorize(member, info, ipiv, dblk, w, ABR, sid); - - /// assembly is same - CholSupernodes - ::update(member, info, ABR, sid, - bufsize - ABR.span()*sizeof(value_type), - (void*)(w.data())); - } - return 0; - } + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using range_type = Kokkos::pair; - template - KOKKOS_INLINE_FUNCTION - static int - solve_lower_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::ordinal_type_array::pointer_type piv, - typename SupernodeInfoType::value_type_array::pointer_type buf, - const size_type bufsize) { - using supernode_info_type = SupernodeInfoType; - - using value_type = typename supernode_info_type::value_type; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - const auto &s = info.supernodes(sid); - - if (final) { - // serial recursion - for (ordinal_type i=0;i ipiv(piv+rbeg*4, 4*m); - - const ordinal_type n = s.n - s.m; - const ordinal_type nrhs = info.x.extent(1); - const size_type bufsize_required = n*nrhs*sizeof(value_type); - - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); - - UnmanagedViewType xB((value_type*)buf, n, nrhs); - - LDL_Supernodes - ::solve_lower(member, info, ipiv, xB, sid); - - CholSupernodes - ::update_solve_lower(member, info, xB, sid); - } - return 0; - } + using GemvAlgoType = typename GemvAlgorithm::type; + using TrsvAlgoType = typename TrsvAlgorithm::type; + // get current supernode + const auto &s = info.supernodes(sid); - template - KOKKOS_INLINE_FUNCTION - static int - solve_upper_recursive_serial(MemberType &member, - const SupernodeInfoType &info, - const ordinal_type sid, - const bool final, - typename SupernodeInfoType::ordinal_type_array::pointer_type piv, - typename SupernodeInfoType::value_type_array::pointer_type diag, - typename SupernodeInfoType::value_type_array::pointer_type buf, - const ordinal_type bufsize) { - using supernode_info_type = SupernodeInfoType; - using value_type = typename supernode_info_type::value_type; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - - const auto &s = info.supernodes(sid); - { - const ordinal_type m = s.m; - const ordinal_type rbeg = s.row_begin; - UnmanagedViewType ipiv(piv+rbeg*4, 4*m); - UnmanagedViewType dblk(diag+rbeg*2, m, 2); - - const ordinal_type n = s.n - s.m; - const ordinal_type nrhs = info.x.extent(1); - const ordinal_type bufsize_required = n*nrhs*sizeof(value_type); - - TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, - "bufsize is smaller than required"); - - UnmanagedViewType xB((value_type*)buf, n, nrhs); - - CholSupernodes - ::update_solve_upper(member, info, xB, sid); - - LDL_Supernodes - ::solve_upper(member, info, ipiv, dblk, xB, sid); - } - - if (final) { - // serial recursion - for (ordinal_type i=0;i 0) { + const value_type one(1); + const UnmanagedViewType AL(ptr, m, m); + ptr += m * m; + + const ordinal_type offm = s.row_begin; + const auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + const auto fpiv = ordinal_type_array(P.data() + m, m); + + Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(member, P, D, xT); + + if (n > 0) { + const UnmanagedViewType AR(ptr, m, n); // ptr += m*n; + Gemv::invoke(member, -one, AR, xB, one, xT); + } + Trsv::invoke(member, Diag::Unit(), AL, xT); + ApplyPivots /// row inter-change + ::invoke(member, fpiv, xT); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + factorize_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type diag, + typename SupernodeInfoType::value_type_array::pointer_type buf, const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + factorize_recursive_serial(member, info, s.children[i], final, piv, diag, buf, bufsize); + } + + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + UnmanagedViewType dblk(diag + rbeg * 2, m, 2); + + const ordinal_type n = s.n - s.m; + + const ordinal_type mm = m < 32 ? m : 32; + const ordinal_type mn = mm > n ? mm : n; + + const size_type bufsize_required = (n * n + m * mn) * sizeof(value_type); + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + value_type *bufptr = buf; + UnmanagedViewType ABR(bufptr, n, n); + bufptr += ABR.span(); + UnmanagedViewType w(bufptr, m * mn); + bufptr += w.span(); + + LDL_Supernodes::factorize(member, info, ipiv, dblk, w, ABR, sid); + + /// assembly is same + CholSupernodes::update(member, info, ABR, sid, bufsize - ABR.span() * sizeof(value_type), + (void *)(w.data())); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + solve_lower_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type buf, + const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + solve_lower_recursive_serial(member, info, s.children[i], final, piv, buf, bufsize); + } + + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const size_type bufsize_required = n * nrhs * sizeof(value_type); + + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + UnmanagedViewType xB((value_type *)buf, n, nrhs); + + LDL_Supernodes::solve_lower(member, info, ipiv, xB, sid); + + CholSupernodes::update_solve_lower(member, info, xB, sid); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + solve_upper_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type diag, + typename SupernodeInfoType::value_type_array::pointer_type buf, + const ordinal_type bufsize) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + UnmanagedViewType dblk(diag + rbeg * 2, m, 2); + + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const ordinal_type bufsize_required = n * nrhs * sizeof(value_type); + + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + UnmanagedViewType xB((value_type *)buf, n, nrhs); + + CholSupernodes::update_solve_upper(member, info, xB, sid); + + LDL_Supernodes::solve_upper(member, info, ipiv, dblk, xB, sid); + } + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + solve_upper_recursive_serial(member, info, s.children[i], final, piv, diag, buf, bufsize); + } + return 0; + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU.hpp new file mode 100644 index 000000000000..83399c67ed27 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU.hpp @@ -0,0 +1,26 @@ +#ifndef __TACHO_LU_HPP__ +#define __TACHO_LU_HPP__ + +/// \file Tacho_LU.hpp +/// \brief Front interface for LU dense factorization +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +namespace Tacho { + +/// +/// LU: +/// +/// + +/// various implementation for different uplo and algo parameters +template struct LU; + +struct LU_Algorithm { + using type = ActiveAlgorithm::type; +}; + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_External.hpp new file mode 100644 index 000000000000..fdb828e03d83 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_External.hpp @@ -0,0 +1,94 @@ +#ifndef __TACHO_LU_EXTERNAL_HPP__ +#define __TACHO_LU_EXTERNAL_HPP__ + +/// \file Tacho_LU_External.hpp +/// \brief LAPACK LU factorization +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Lapack_External.hpp" + +namespace Tacho { + +/// LAPACK LU +/// ========== +template <> struct LU { + template inline static int invoke(const ViewTypeA &A, const ViewTypeP &P) { + int r_val = 0; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + + TACHO_TEST_FOR_EXCEPTION(P.extent(0) < 4 * A.extent(0), std::runtime_error, "P should be 4*A.extent(0) ."); + + const ordinal_type m = A.extent(0), n = A.extent(1); + if (m > 0 && n > 0) { + /// factorize LU + Lapack::getrf(m, n, A.data(), A.stride_1(), P.data(), &r_val); + TACHO_TEST_FOR_EXCEPTION(r_val, std::runtime_error, "LAPACK (getrf) returns non-zero error code."); + } +#else + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); +#endif + return r_val; + } + + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + int r_val = 0; + r_val = invoke(A, P); + return r_val; +#else + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); +#endif + } + + template inline static int modify(const ordinal_type m, const ViewTypeP &P) { + int r_val = 0; +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + + TACHO_TEST_FOR_EXCEPTION(int(P.extent(0)) < 4 * m, std::runtime_error, "P should be 4*m."); + + if (m > 0) { + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + + for (ordinal_type i = 0; i < m; ++i) + perm[i] = i; + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type fla_pivot = ipiv[i] - i - 1; + fpiv[i] = fla_pivot; + + /// apply pivots to perm vector + if (fpiv[i]) { + const ordinal_type pidx = i + fpiv[i]; + swap(perm[i], perm[pidx]); + } + } + for (ordinal_type i = 0; i < m; ++i) + peri[perm[i]] = i; + } +#else + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); +#endif + return r_val; + } + + template + inline static int modify(MemberType &member, ordinal_type m, const ViewTypeP &P) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + int r_val = 0; + r_val = modify(m, P); + return r_val; +#else + TACHO_TEST_FOR_ABORT(true, ">> This function is only allowed in host space."); +#endif + } +}; + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Internal.hpp new file mode 100644 index 000000000000..aa857b4d850a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Internal.hpp @@ -0,0 +1,68 @@ +#ifndef __TACHO_LU_INTERNAL_HPP__ +#define __TACHO_LU_INTERNAL_HPP__ + +/// \file Tacho_LU_Internal.hpp +/// \brief LU team factorization +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Lapack_Team.hpp" + +namespace Tacho { + +/// LAPACK LU +/// ========== +template <> struct LU { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P) { + typedef typename ViewTypeA::non_const_value_type value_type; + // typedef typename ViewTypeP::non_const_value_type p_value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + + TACHO_TEST_FOR_ABORT(P.extent(0) < 4 * A.extent(0), "P should be 4*A.extent(0) ."); + + int r_val(0); + const ordinal_type m = A.extent(0), n = A.extent(1); + if (m > 0 && n > 0) { + /// factorize LU + LapackTeam::getrf(member, m, n, A.data(), A.stride_1(), P.data(), &r_val); + } + return r_val; + } + + template + KOKKOS_INLINE_FUNCTION static int modify(const MemberType &member, const ordinal_type m, const ViewTypeP &P) { + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + + TACHO_TEST_FOR_ABORT(P.extent(0) < 4 * m, "P should be 4*m ."); + + int r_val = 0; + if (m > 0) { + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { + perm[i] = i; + fpiv[i] = ipiv[i] - i - 1; + }); + member.team_barrier(); + Kokkos::single(Kokkos::PerTeam(member), [&]() { + for (ordinal_type i = 0; i < m; ++i) { + /// apply pivots to perm vector + if (fpiv[i]) { + const ordinal_type pidx = i + fpiv[i]; + swap(perm[i], perm[pidx]); + } + } + }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { peri[perm[i]] = i; }); + } + + return r_val; + } +}; + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_OnDevice.hpp new file mode 100644 index 000000000000..ea3758511836 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_OnDevice.hpp @@ -0,0 +1,144 @@ +#ifndef __TACHO_LU_ON_DEVICE_HPP__ +#define __TACHO_LU_ON_DEVICE_HPP__ + +/// \file Tacho_LU_OnDevice.hpp +/// \brief LU device solver +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_LU_External.hpp" + +namespace Tacho { + +template <> struct LU { + template + inline static int lapack_invoke(const ViewTypeA &A, const ViewTypeP &P) { + return LU::invoke(A, P); + } + +#if defined(KOKKOS_ENABLE_CUDA) + template + inline static int cusolver_invoke(cusolverDnHandle_t &handle, const ViewTypeA &A, const ViewTypeP &P, + const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type work_value_type; + const ordinal_type m = A.extent(0), n = A.extent(1); + + int r_val(0); + if (m > 0 && n > 0) { + int *devInfo = (int *)W.data(); + value_type *workspace = W.data() + 1; + // int lwork = (W.span()-1); + r_val = Lapack::getrf(handle, m, n, A.data(), A.stride_1(), workspace, P.data(), devInfo); + } + return r_val; + } + + template + inline static int cusolver_buffer_size(cusolverDnHandle_t &handle, const ViewTypeA &A, int *lwork) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = A.extent(0), n = A.extent(1); + + int r_val(0); + if (m > 0) + r_val = Lapack::getrf_buffersize(handle, m, n, A.data(), A.stride_1(), lwork); + return r_val; + } +#endif + + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ViewTypeP &P, const ViewTypeW &W) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeW::non_const_value_type value_type_w; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeW::memory_space memory_space_w; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + static_assert(ViewTypeW::rank == 1, "W is not rank 1 view."); + + static_assert(std::is_same::value, "A and W do not have the same value type."); + + static_assert(std::is_same::value, "A and W do not have the same memory space."); + int r_val(0); + if (std::is_same::value) { + TACHO_TEST_FOR_EXCEPTION(A.data() == NULL, std::logic_error, "Error: A has null data pointer"); + TACHO_TEST_FOR_EXCEPTION(P.data() == NULL, std::logic_error, "Error: P has null data pointer"); + r_val = lapack_invoke(A, P); + } + +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || + std::is_same::value) { + if (W.span() == 0) { + int lwork; + r_val = cusolver_buffer_size(member, A, &lwork); + r_val = (lwork + sizeof(value_type_w)) / sizeof(value_type_w) + 1; + } else + r_val = cusolver_invoke(member, A, P, W); + } +#endif + return r_val; + } + + template inline static int lapack_modify(const ordinal_type m, const ViewTypeP &P) { + return LU::modify(m, P); + } + +#if defined(KOKKOS_ENABLE_CUDA) + template + inline static int cusolver_modify(ExecSpaceType &exec_instance, const ordinal_type m, const ViewTypeP &P) { + using exec_space = ExecSpaceType; + + int r_val(0); + if (m > 0) { + ordinal_type *__restrict__ ipiv = P.data(), *__restrict__ fpiv = ipiv + m, *__restrict__ perm = fpiv + m, + *__restrict__ peri = perm + m; + + Kokkos::RangePolicy range_policy(exec_instance, 0, m); + Kokkos::RangePolicy single_policy(exec_instance, 0, 1); + Kokkos::parallel_for( + range_policy, KOKKOS_LAMBDA(const ordinal_type i) { + perm[i] = i; + fpiv[i] = ipiv[i] - i - 1; + }); + exec_instance.fence(); + Kokkos::parallel_for( + single_policy, KOKKOS_LAMBDA(const ordinal_type j) { + for (ordinal_type i = 0; i < m; ++i) { + if (fpiv[i]) { + const ordinal_type pidx = i + fpiv[i]; + swap(perm[i], perm[pidx]); + } + } + }); + exec_instance.fence(); + Kokkos::parallel_for( + "PermutationInverse", range_policy, KOKKOS_LAMBDA(const ordinal_type i) { peri[perm[i]] = i; }); + } + return r_val; + } +#endif + + template + inline static int modify(MemberType &member, const ordinal_type m, const ViewTypeP &P) { + typedef typename ViewTypeP::memory_space memory_space; + + static_assert(ViewTypeP::rank == 1, "P is not rank 1 view."); + + int r_val(0); + if (std::is_same::value) { + r_val = lapack_modify(m, P); + } + +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || + std::is_same::value) { + r_val = cusolver_modify(member, m, P); + } +#endif + return r_val; + } +}; +} // namespace Tacho +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes.hpp new file mode 100644 index 000000000000..15427345c9f1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes.hpp @@ -0,0 +1,15 @@ +#ifndef __TACHO_LU_SUPERNODES_HPP__ +#define __TACHO_LU_SUPERNODES_HPP__ + +/// \file Tacho_LU_Supernodes.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +namespace Tacho { + +template struct LU_Supernodes; + +} + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes_Serial.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes_Serial.hpp new file mode 100644 index 000000000000..a9e0f21e8fc4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_LU_Supernodes_Serial.hpp @@ -0,0 +1,292 @@ +#ifndef __TACHO_LU_SUPERNODES_SERIAL_HPP__ +#define __TACHO_LU_SUPERNODES_SERIAL_HPP__ + +/// \file Tacho_LU_Supernodes.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +#include "Tacho_Lapack_External.hpp" +#include "Tacho_Lapack_Team.hpp" + +#include "Tacho_Blas_External.hpp" +#include "Tacho_Blas_Team.hpp" + +#include "Tacho_ApplyPivots.hpp" +#include "Tacho_ApplyPivots_Internal.hpp" + +#include "Tacho_LU.hpp" +#include "Tacho_LU_External.hpp" +#include "Tacho_LU_Internal.hpp" + +#include "Tacho_Trsm.hpp" +#include "Tacho_Trsm_External.hpp" +#include "Tacho_Trsm_Internal.hpp" + +#include "Tacho_Gemm.hpp" +#include "Tacho_Gemm_External.hpp" +#include "Tacho_Gemm_Internal.hpp" + +#include "Tacho_Trsv.hpp" +#include "Tacho_Trsv_External.hpp" +#include "Tacho_Trsv_Internal.hpp" + +#include "Tacho_Gemv.hpp" +#include "Tacho_Gemv_External.hpp" +#include "Tacho_Gemv_Internal.hpp" + +namespace Tacho { + +template <> struct LU_Supernodes { + template + KOKKOS_INLINE_FUNCTION static int + factorize(MemberType &member, const SupernodeInfoType &info, const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &ABR, const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using range_type = typename supernode_info_type::range_type; + + using LU_AlgoType = typename LU_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // panel (s.m x s.n) is divided into ATL (m x m) and ATR (m x n) + const ordinal_type m = s.m, n = s.n, n_m = s.n - s.m; + + // m and n are available, then factorize the supernode block + if (m > 0) { + /// LU factorize ATL + value_type *uptr = s.u_buf; + UnmanagedViewType AT(uptr, m, n); + + LU::invoke(member, AT, P); + LU::modify(member, m, P); + + if (n > 0) { + const value_type one(1), zero(0); + + UnmanagedViewType ATL(uptr, m, m); + uptr += m * m; + + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ABL); + + TACHO_TEST_FOR_ABORT(static_cast(ABR.extent(0)) != n_m || + static_cast(ABR.extent(1)) != n_m, + "ABR dimension does not match to supernodes"); + + UnmanagedViewType ATR(uptr, m, n_m); + Gemm::invoke(member, -one, ABL, ATR, zero, ABR); + } + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int solve_lower(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + using range_type = Kokkos::pair; + + const auto &s = info.supernodes(sid); + + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + // panel is divided into diagonal and interface block + const ordinal_type m = s.m, n = s.n, n_m = s.n - s.m; //, nrhs = info.x.extent(1); + + // m and n are available, then factorize the supernode block + if (m > 0) { + const value_type one(1), zero(0); + const ordinal_type offm = s.row_begin; + value_type *uptr = s.u_buf; + UnmanagedViewType ATL(uptr, m, m); + const auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + const auto fpiv = ordinal_type_array(P.data() + m, m); + + ApplyPivots /// row inter-change + ::invoke(member, fpiv, xT); + + Trsv::invoke(member, Diag::Unit(), ATL, xT); + + if (n_m > 0) { + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + Gemv::invoke(member, -one, ABL, xT, zero, xB); + } + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int solve_upper(MemberType &member, const SupernodeInfoType &info, + const typename SupernodeInfoType::ordinal_type_array &P, + const typename SupernodeInfoType::value_type_matrix &xB, + const ordinal_type sid) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + + using range_type = Kokkos::pair; + + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + // get current supernode + const auto &s = info.supernodes(sid); + + // panel is divided into diagonal and interface block + const ordinal_type m = s.m, n = s.n, n_m = s.n - s.m; //, nrhs = info.x.extent(1); + + // m and n are available, then factorize the supernode block + if (m > 0) { + const value_type one(1); + value_type *uptr = s.u_buf; + const UnmanagedViewType ATL(uptr, m, m); + uptr += m * m; + + const ordinal_type offm = s.row_begin; + const auto xT = Kokkos::subview(info.x, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + const UnmanagedViewType AR(uptr, m, n); // ptr += m*n; + Gemv::invoke(member, -one, AR, xB, one, xT); + } + Trsv::invoke(member, Diag::NonUnit(), ATL, xT); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + factorize_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type buf, const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + factorize_recursive_serial(member, info, s.children[i], final, piv, buf, bufsize); + } + + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + + const ordinal_type n = s.n - s.m; + const size_type bufsize_required = n * n * sizeof(value_type); + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + value_type *bufptr = buf; + UnmanagedViewType ABR(bufptr, n, n); + + LU_Supernodes::factorize(member, info, ipiv, ABR, sid); + + constexpr bool update_lower = true; + CholSupernodes::update(member, info, ABR, sid, bufsize - ABR.span() * sizeof(value_type), + (void *)((value_type *)buf + ABR.span()), update_lower); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + solve_lower_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type buf, + const size_type bufsize) { + using supernode_info_type = SupernodeInfoType; + + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + solve_lower_recursive_serial(member, info, s.children[i], final, piv, buf, bufsize); + } + + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const size_type bufsize_required = n * nrhs * sizeof(value_type); + + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + UnmanagedViewType xB((value_type *)buf, n, nrhs); + + LU_Supernodes::solve_lower(member, info, ipiv, xB, sid); + + CholSupernodes::update_solve_lower(member, info, xB, sid); + } + return 0; + } + + template + KOKKOS_INLINE_FUNCTION static int + solve_upper_recursive_serial(MemberType &member, const SupernodeInfoType &info, const ordinal_type sid, + const bool final, typename SupernodeInfoType::ordinal_type_array::pointer_type piv, + typename SupernodeInfoType::value_type_array::pointer_type buf, + const ordinal_type bufsize) { + using supernode_info_type = SupernodeInfoType; + using value_type = typename supernode_info_type::value_type; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + + const auto &s = info.supernodes(sid); + { + const ordinal_type m = s.m; + const ordinal_type rbeg = s.row_begin; + UnmanagedViewType ipiv(piv + rbeg * 4, 4 * m); + + const ordinal_type n = s.n - s.m; + const ordinal_type nrhs = info.x.extent(1); + const ordinal_type bufsize_required = n * nrhs * sizeof(value_type); + + TACHO_TEST_FOR_ABORT(bufsize < bufsize_required, "bufsize is smaller than required"); + + UnmanagedViewType xB((value_type *)buf, n, nrhs); + + CholSupernodes::update_solve_upper(member, info, xB, sid); + + LU_Supernodes::solve_upper(member, info, ipiv, xB, sid); + } + + if (final) { + // serial recursion + for (ordinal_type i = 0; i < s.nchildren; ++i) + solve_upper_recursive_serial(member, info, s.children[i], final, piv, buf, bufsize); + } + return 0; + } +}; +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.cpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.cpp index 392b26926cf3..b8591821586d 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.cpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.cpp @@ -2,732 +2,535 @@ /// \brief Lapack wrapper /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Tacho_config.h" #include "Tacho_Lapack_External.hpp" +#include "Tacho_config.h" #include "Kokkos_Core.hpp" extern "C" { - /// - /// Cholesky - /// - void F77_BLAS_MANGLE(spotrf,SPOTRF)( const char*, - const int*, - float*, const int*, - int* ); - void F77_BLAS_MANGLE(dpotrf,DPOTRF)( const char*, - const int*, - double*, const int*, - int*); - void F77_BLAS_MANGLE(cpotrf,CPOTRF)( const char*, - const int*, - Kokkos::complex*, const int*, - int*); - void F77_BLAS_MANGLE(zpotrf,ZPOTRF)( const char*, - const int*, - Kokkos::complex*, const int*, - int*); - - /// - /// LDLt - /// - void F77_BLAS_MANGLE(ssytrf,SSYTRF)( const char*, - const int*, - float*, const int*, - int*, - float*, int*, - int*); - void F77_BLAS_MANGLE(dsytrf,DSYTRF)( const char*, - const int*, - double*, const int*, - int*, - double*, int*, - int*); - void F77_BLAS_MANGLE(csytrf,CSYTRF)( const char*, - const int*, - Kokkos::complex*, const int*, - int*, - Kokkos::complex*, int *, - int*); - void F77_BLAS_MANGLE(zsytrf,ZSYTRF)( const char*, - const int*, - Kokkos::complex*, const int*, - int*, - Kokkos::complex*, int*, - int*); - - /// - /// LU - ///M, N, A, LDA, IPIV, INFO ) - void F77_BLAS_MANGLE(sgetrf,SGETRF)( const int*, const int*, - float*, const int*, - int*, - int*); - void F77_BLAS_MANGLE(dgetrf,DGETRF)( const int*, const int*, - double*, const int*, - int*, - int*); - void F77_BLAS_MANGLE(cgetrf,CGETRF)( const int*, const int*, - Kokkos::complex*, const int*, - int*, - int*); - void F77_BLAS_MANGLE(zgetrf,ZGETRF)( const int*, const int*, - Kokkos::complex*, const int*, - int*, - int*); +/// +/// Cholesky +/// +void F77_BLAS_MANGLE(spotrf, SPOTRF)(const char *, const int *, float *, const int *, int *); +void F77_BLAS_MANGLE(dpotrf, DPOTRF)(const char *, const int *, double *, const int *, int *); +void F77_BLAS_MANGLE(cpotrf, CPOTRF)(const char *, const int *, Kokkos::complex *, const int *, int *); +void F77_BLAS_MANGLE(zpotrf, ZPOTRF)(const char *, const int *, Kokkos::complex *, const int *, int *); + +/// +/// LDLt +/// +void F77_BLAS_MANGLE(ssytrf, SSYTRF)(const char *, const int *, float *, const int *, int *, float *, int *, int *); +void F77_BLAS_MANGLE(dsytrf, DSYTRF)(const char *, const int *, double *, const int *, int *, double *, int *, int *); +void F77_BLAS_MANGLE(csytrf, CSYTRF)(const char *, const int *, Kokkos::complex *, const int *, int *, + Kokkos::complex *, int *, int *); +void F77_BLAS_MANGLE(zsytrf, ZSYTRF)(const char *, const int *, Kokkos::complex *, const int *, int *, + Kokkos::complex *, int *, int *); + +/// +/// LU +/// M, N, A, LDA, IPIV, INFO ) +void F77_BLAS_MANGLE(sgetrf, SGETRF)(const int *, const int *, float *, const int *, int *, int *); +void F77_BLAS_MANGLE(dgetrf, DGETRF)(const int *, const int *, double *, const int *, int *, int *); +void F77_BLAS_MANGLE(cgetrf, CGETRF)(const int *, const int *, Kokkos::complex *, const int *, int *, int *); +void F77_BLAS_MANGLE(zgetrf, ZGETRF)(const int *, const int *, Kokkos::complex *, const int *, int *, int *); } namespace Tacho { -#define F77_FUNC_SPOTRF F77_BLAS_MANGLE(spotrf,SPOTRF) -#define F77_FUNC_DPOTRF F77_BLAS_MANGLE(dpotrf,DPOTRF) -#define F77_FUNC_CPOTRF F77_BLAS_MANGLE(cpotrf,CPOTRF) -#define F77_FUNC_ZPOTRF F77_BLAS_MANGLE(zpotrf,ZPOTRF) - -#define F77_FUNC_SSYTRF F77_BLAS_MANGLE(ssytrf,SSYTRF) -#define F77_FUNC_DSYTRF F77_BLAS_MANGLE(dsytrf,DSYTRF) -#define F77_FUNC_CSYTRF F77_BLAS_MANGLE(csytrf,CSYTRF) -#define F77_FUNC_ZSYTRF F77_BLAS_MANGLE(zsytrf,ZSYTRF) - -#define F77_FUNC_SGETRF F77_BLAS_MANGLE(sgetrf,SGETRF) -#define F77_FUNC_DGETRF F77_BLAS_MANGLE(dgetrf,DGETRF) -#define F77_FUNC_CGETRF F77_BLAS_MANGLE(cgetrf,CGETRF) -#define F77_FUNC_ZGETRF F77_BLAS_MANGLE(zgetrf,ZGETRF) - - template<> - int - Lapack::potrf(const char uplo, - const int m, - float *a, const int lda, +#define F77_FUNC_SPOTRF F77_BLAS_MANGLE(spotrf, SPOTRF) +#define F77_FUNC_DPOTRF F77_BLAS_MANGLE(dpotrf, DPOTRF) +#define F77_FUNC_CPOTRF F77_BLAS_MANGLE(cpotrf, CPOTRF) +#define F77_FUNC_ZPOTRF F77_BLAS_MANGLE(zpotrf, ZPOTRF) + +#define F77_FUNC_SSYTRF F77_BLAS_MANGLE(ssytrf, SSYTRF) +#define F77_FUNC_DSYTRF F77_BLAS_MANGLE(dsytrf, DSYTRF) +#define F77_FUNC_CSYTRF F77_BLAS_MANGLE(csytrf, CSYTRF) +#define F77_FUNC_ZSYTRF F77_BLAS_MANGLE(zsytrf, ZSYTRF) + +#define F77_FUNC_SGETRF F77_BLAS_MANGLE(sgetrf, SGETRF) +#define F77_FUNC_DGETRF F77_BLAS_MANGLE(dgetrf, DGETRF) +#define F77_FUNC_CGETRF F77_BLAS_MANGLE(cgetrf, CGETRF) +#define F77_FUNC_ZGETRF F77_BLAS_MANGLE(zgetrf, ZGETRF) + +template <> int Lapack::potrf(const char uplo, const int m, float *a, const int lda, int *info) { + F77_FUNC_SPOTRF(&uplo, &m, a, &lda, info); + return 0; +} +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::potrf_buffersize(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, float *a, + const int lda, int *lwork) { + const int r_val = cusolverDnSpotrf_bufferSize(handle, uplo, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::potrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, float *a, const int lda, + float *w, const int lwork, int *dev) { + const int r_val = cusolverDnSpotrf(handle, uplo, m, a, lda, w, lwork, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::potrf_buffersize(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, float *a, + const int lda, int *lwork) { + const int r_val = hipsolverSpotrf_bufferSize(handle, uplo, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::potrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, float *a, const int lda, + float *w, const int lwork, int *dev) { + const int r_val = hipsolverSpotrf(handle, uplo, m, a, lda, w, lwork, dev); + return r_val; +} +#endif + +template <> +int Lapack::sytrf(const char uplo, const int m, float *a, const int lda, int *ipiv, float *work, int lwork, int *info) { - F77_FUNC_SPOTRF(&uplo, - &m, - a, &lda, - info); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::potrf_buffersize(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - float *a, const int lda, - int *lwork) { - const int r_val = cusolverDnSpotrf_bufferSize(handle, - uplo, - m, - a, lda, - lwork); - return r_val; - } - - template<> - int - Lapack::potrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - float *a, const int lda, - float *w, const int lwork, - int *dev) { - const int r_val = cusolverDnSpotrf(handle, - uplo, - m, - a, lda, - w, lwork, - dev); - return r_val; - } + F77_FUNC_SSYTRF(&uplo, &m, a, &lda, ipiv, work, &lwork, info); + return 0; +} + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::sytrf_buffersize(cusolverDnHandle_t handle, const int m, float *a, const int lda, int *lwork) { + const int r_val = cusolverDnSsytrf_bufferSize(handle, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::sytrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, float *a, const int lda, + int *ipiv, float *w, const int lwork, int *dev) { + const int r_val = cusolverDnSsytrf(handle, uplo, m, a, lda, ipiv, w, lwork, dev); + return r_val; +} #endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::sytrf_buffersize(hipsolverHandle_t handle, const int m, float *a, const int lda, int *lwork) { + const int r_val = hipsolverSsytrf_bufferSize(handle, m, a, lda, lwork); + return r_val; +} - template<> - int - Lapack::sytrf(const char uplo, - const int m, - float *a, const int lda, - int *ipiv, - float *work, int lwork, - int *info) { - F77_FUNC_SSYTRF(&uplo, - &m, - a, &lda, - ipiv, - work, &lwork, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::sytrf_buffersize(cusolverDnHandle_t handle, - const int m, - float *a, const int lda, - int *lwork) { - const int r_val = cusolverDnSsytrf_bufferSize(handle, - m, - a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack::sytrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - float *a, const int lda, - int *ipiv, - float *w, const int lwork, - int *dev) { - const int r_val = cusolverDnSsytrf(handle, - uplo, - m, - a, lda, - ipiv, - w, lwork, - dev); - return r_val; - } +template <> +int Lapack::sytrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, float *a, const int lda, + int *ipiv, float *w, const int lwork, int *dev) { + const int r_val = hipsolverSsytrf(handle, uplo, m, a, lda, ipiv, w, lwork, dev); + return r_val; +} #endif - template<> - int - Lapack::getrf(const int m, const int n, - float *a, const int lda, - int *ipiv, - int *info) { - F77_FUNC_SGETRF(&m, &n, - a, &lda, - ipiv, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::getrf_buffersize(cusolverDnHandle_t handle, - const int m, const int n, - float *a, const int lda, - int *lwork) { - const int r_val = cusolverDnSgetrf_bufferSize(handle, - m, n, - a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack::getrf(cusolverDnHandle_t handle, - const int m, const int n, - float *a, const int lda, - float *w, - int *ipiv, - int *dev) { - const int r_val = cusolverDnSgetrf(handle, - m, n, - a, lda, - w, - ipiv, - dev); - return r_val; - } +template <> int Lapack::getrf(const int m, const int n, float *a, const int lda, int *ipiv, int *info) { + F77_FUNC_SGETRF(&m, &n, a, &lda, ipiv, info); + return 0; +} + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::getrf_buffersize(cusolverDnHandle_t handle, const int m, const int n, float *a, const int lda, + int *lwork) { + const int r_val = cusolverDnSgetrf_bufferSize(handle, m, n, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::getrf(cusolverDnHandle_t handle, const int m, const int n, float *a, const int lda, float *w, + int *ipiv, int *dev) { + const int r_val = cusolverDnSgetrf(handle, m, n, a, lda, w, ipiv, dev); + return r_val; +} #endif - - template<> - int - Lapack::potrf(const char uplo, - const int m, - double *a, const int lda, - int *info) { - F77_FUNC_DPOTRF(&uplo, - &m, - a, &lda, - info); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::potrf_buffersize(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - double *a, const int lda, - int *lwork) { - const int r_val = cusolverDnDpotrf_bufferSize(handle, - uplo, - m, - a, lda, - lwork); - return r_val; - } - - template<> - int - Lapack::potrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - double *a, const int lda, - double *w, const int lwork, - int *dev) { - const int r_val = cusolverDnDpotrf(handle, - uplo, - m, - a, lda, - w, lwork, - dev); - return r_val; - } +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::getrf_buffersize(hipsolverHandle_t handle, const int m, const int n, float *a, const int lda, + int *lwork) { + const int r_val = hipsolverSgetrf_bufferSize(handle, m, n, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::getrf(hipsolverHandle_t handle, const int m, const int n, float *a, const int lda, float *w, + int *ipiv, int *dev) { + const int r_val = hipsolverSgetrf(handle, m, n, a, lda, w, ipiv, dev); + return r_val; +} #endif - template<> - int - Lapack::sytrf(const char uplo, - const int m, - double *a, const int lda, - int *ipiv, - double *work, int lwork, - int *info) { - F77_FUNC_DSYTRF(&uplo, - &m, - a, &lda, - ipiv, - work, &lwork, - info); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::sytrf_buffersize(cusolverDnHandle_t handle, - const int m, - double *a, const int lda, - int *lwork) { - const int r_val = cusolverDnDsytrf_bufferSize(handle, - m, - a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack::sytrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - double *a, const int lda, - int *ipiv, - double *w, const int lwork, - int *dev) { - const int r_val = cusolverDnDsytrf(handle, - uplo, - m, - a, lda, - ipiv, - w, lwork, - dev); - return r_val; - } +template <> int Lapack::potrf(const char uplo, const int m, double *a, const int lda, int *info) { + F77_FUNC_DPOTRF(&uplo, &m, a, &lda, info); + return 0; +} +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::potrf_buffersize(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, double *a, + const int lda, int *lwork) { + const int r_val = cusolverDnDpotrf_bufferSize(handle, uplo, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::potrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, double *a, const int lda, + double *w, const int lwork, int *dev) { + const int r_val = cusolverDnDpotrf(handle, uplo, m, a, lda, w, lwork, dev); + return r_val; +} #endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::potrf_buffersize(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, double *a, + const int lda, int *lwork) { + const int r_val = hipsolverDpotrf_bufferSize(handle, uplo, m, a, lda, lwork); + return r_val; +} - template<> - int - Lapack::getrf(const int m, const int n, - double *a, const int lda, - int *ipiv, - int *info) { - F77_FUNC_DGETRF(&m, &n, - a, &lda, - ipiv, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack::getrf_buffersize(cusolverDnHandle_t handle, - const int m, const int n, - double *a, const int lda, - int *lwork) { - const int r_val = cusolverDnDgetrf_bufferSize(handle, - m, n, - a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack::getrf(cusolverDnHandle_t handle, - const int m, const int n, - double *a, const int lda, - double *w, - int *ipiv, - int *dev) { - const int r_val = cusolverDnDgetrf(handle, - m, n, - a, lda, - w, - ipiv, - dev); - return r_val; - } +template <> +int Lapack::potrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, double *a, + const int lda, double *w, const int lwork, int *dev) { + const int r_val = hipsolverDpotrf(handle, uplo, m, a, lda, w, lwork, dev); + return r_val; +} #endif - template<> - int - Lapack >::potrf(const char uplo, - const int m, - Kokkos::complex *a, const int lda, - int *info) { - F77_FUNC_CPOTRF(&uplo, - &m, - a, &lda, - info); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::potrf_buffersize(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, - int *lwork) { - const int r_val = cusolverDnCpotrf_bufferSize(handle, - uplo, - m, - (cuComplex*)a, lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::potrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, - Kokkos::complex *w, const int lwork, - int *dev) { - const int r_val = cusolverDnCpotrf(handle, - uplo, - m, - (cuComplex*)a, lda, - (cuComplex*)w, lwork, - dev); - return r_val; - } +template <> +int Lapack::sytrf(const char uplo, const int m, double *a, const int lda, int *ipiv, double *work, int lwork, + int *info) { + F77_FUNC_DSYTRF(&uplo, &m, a, &lda, ipiv, work, &lwork, info); + return 0; +} +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::sytrf_buffersize(cusolverDnHandle_t handle, const int m, double *a, const int lda, int *lwork) { + const int r_val = cusolverDnDsytrf_bufferSize(handle, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::sytrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, double *a, const int lda, + int *ipiv, double *w, const int lwork, int *dev) { + const int r_val = cusolverDnDsytrf(handle, uplo, m, a, lda, ipiv, w, lwork, dev); + return r_val; +} #endif - - template<> - int - Lapack >::sytrf(const char uplo, - const int m, - Kokkos::complex *a, const int lda, - int *ipiv, - Kokkos::complex *work, int lwork, - int *info) { - F77_FUNC_CSYTRF(&uplo, - &m, - a, &lda, - ipiv, - work, &lwork, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::sytrf_buffersize(cusolverDnHandle_t handle, - const int m, - Kokkos::complex *a, const int lda, - int *lwork) { - const int r_val = cusolverDnCsytrf_bufferSize(handle, - m, - (cuComplex*)a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::sytrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, - int *ipiv, - Kokkos::complex *w, const int lwork, - int *dev) { - const int r_val = cusolverDnCsytrf(handle, - uplo, - m, - (cuComplex*)a, lda, - ipiv, - (cuComplex*)w, lwork, - dev); - return r_val; - } +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::sytrf_buffersize(hipsolverHandle_t handle, const int m, double *a, const int lda, int *lwork) { + const int r_val = hipsolverDsytrf_bufferSize(handle, m, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::sytrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, double *a, + const int lda, int *ipiv, double *w, const int lwork, int *dev) { + const int r_val = hipsolverDsytrf(handle, uplo, m, a, lda, ipiv, w, lwork, dev); + return r_val; +} #endif +template <> int Lapack::getrf(const int m, const int n, double *a, const int lda, int *ipiv, int *info) { + F77_FUNC_DGETRF(&m, &n, a, &lda, ipiv, info); + return 0; +} - template<> - int - Lapack >::getrf(const int m, const int n, - Kokkos::complex *a, const int lda, - int *ipiv, - int *info) { - F77_FUNC_CGETRF(&m, &n, - a, &lda, - ipiv, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::getrf_buffersize(cusolverDnHandle_t handle, - const int m, const int n, - Kokkos::complex *a, const int lda, - int *lwork) { - const int r_val = cusolverDnCgetrf_bufferSize(handle, - m, n, - (cuComplex*)a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::getrf(cusolverDnHandle_t handle, - const int m, const int n, - Kokkos::complex *a, const int lda, - Kokkos::complex *w, - int *ipiv, - int *dev) { - const int r_val = cusolverDnCgetrf(handle, - m, n, - (cuComplex*)a, lda, - (cuComplex*)w, - ipiv, - dev); - return r_val; - } +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack::getrf_buffersize(cusolverDnHandle_t handle, const int m, const int n, double *a, const int lda, + int *lwork) { + const int r_val = cusolverDnDgetrf_bufferSize(handle, m, n, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::getrf(cusolverDnHandle_t handle, const int m, const int n, double *a, const int lda, double *w, + int *ipiv, int *dev) { + const int r_val = cusolverDnDgetrf(handle, m, n, a, lda, w, ipiv, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack::getrf_buffersize(hipsolverHandle_t handle, const int m, const int n, double *a, const int lda, + int *lwork) { + const int r_val = hipsolverDgetrf_bufferSize(handle, m, n, a, lda, lwork); + return r_val; +} + +template <> +int Lapack::getrf(hipsolverHandle_t handle, const int m, const int n, double *a, const int lda, double *w, + int *ipiv, int *dev) { + const int r_val = hipsolverDgetrf(handle, m, n, a, lda, w, ipiv, dev); + return r_val; +} #endif - template<> - int - Lapack >::potrf(const char uplo, - const int m, - Kokkos::complex *a, const int lda, +template <> +int Lapack>::potrf(const char uplo, const int m, Kokkos::complex *a, const int lda, int *info) { - F77_FUNC_ZPOTRF(&uplo, - &m, - a, &lda, - info); - return 0; - } -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::potrf_buffersize(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, + F77_FUNC_CPOTRF(&uplo, &m, a, &lda, info); + return 0; +} +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::potrf_buffersize(cusolverDnHandle_t handle, const cublasFillMode_t uplo, + const int m, Kokkos::complex *a, const int lda, int *lwork) { - const int r_val = cusolverDnZpotrf_bufferSize(handle, - uplo, - m, - (cuDoubleComplex*)a, lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::potrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, - Kokkos::complex *w, const int lwork, - int *dev) { - const int r_val = cusolverDnZpotrf(handle, - uplo, - m, - (cuDoubleComplex*)a, lda, - (cuDoubleComplex*)w, lwork, - dev); - return r_val; - } + const int r_val = cusolverDnCpotrf_bufferSize(handle, uplo, m, (cuComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::potrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + const int lwork, int *dev) { + const int r_val = cusolverDnCpotrf(handle, uplo, m, (cuComplex *)a, lda, (cuComplex *)w, lwork, dev); + return r_val; +} #endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::potrf_buffersize(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, + const int m, Kokkos::complex *a, const int lda, + int *lwork) { + const int r_val = hipsolverCpotrf_bufferSize(handle, uplo, m, (cuComplex *)a, lda, lwork); + return r_val; +} - template<> - int - Lapack >::sytrf(const char uplo, - const int m, - Kokkos::complex *a, const int lda, - int *ipiv, - Kokkos::complex* work, int lwork, +template <> +int Lapack>::potrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + const int lwork, int *dev) { + const int r_val = hipsolverCpotrf(handle, uplo, m, (cuComplex *)a, lda, (cuComplex *)w, lwork, dev); + return r_val; +} +#endif + +template <> +int Lapack>::sytrf(const char uplo, const int m, Kokkos::complex *a, const int lda, + int *ipiv, Kokkos::complex *work, int lwork, int *info) { + F77_FUNC_CSYTRF(&uplo, &m, a, &lda, ipiv, work, &lwork, info); + return 0; +} + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::sytrf_buffersize(cusolverDnHandle_t handle, const int m, Kokkos::complex *a, + const int lda, int *lwork) { + const int r_val = cusolverDnCsytrf_bufferSize(handle, m, (cuComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::sytrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, int *ipiv, + Kokkos::complex *w, const int lwork, int *dev) { + const int r_val = cusolverDnCsytrf(handle, uplo, m, (cuComplex *)a, lda, ipiv, (cuComplex *)w, lwork, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::sytrf_buffersize(hipsolverHandle_t handle, const int m, Kokkos::complex *a, + const int lda, int *lwork) { + const int r_val = hipsolverCsytrf_bufferSize(handle, m, (cuComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::sytrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, int *ipiv, + Kokkos::complex *w, const int lwork, int *dev) { + const int r_val = hipsolverCsytrf(handle, uplo, m, (cuComplex *)a, lda, ipiv, (cuComplex *)w, lwork, dev); + return r_val; +} +#endif + +template <> +int Lapack>::getrf(const int m, const int n, Kokkos::complex *a, const int lda, int *ipiv, int *info) { - F77_FUNC_ZSYTRF(&uplo, - &m, - a, &lda, - ipiv, - work, &lwork, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::sytrf_buffersize(cusolverDnHandle_t handle, - const int m, - Kokkos::complex *a, const int lda, - int *lwork) { - const int r_val = cusolverDnZsytrf_bufferSize(handle, - m, - (cuDoubleComplex*)a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::sytrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - Kokkos::complex *a, const int lda, - int *ipiv, - Kokkos::complex *w, const int lwork, - int *dev) { - const int r_val = cusolverDnZsytrf(handle, - uplo, - m, - (cuDoubleComplex*)a, lda, - ipiv, - (cuDoubleComplex*)w, lwork, - dev); - return r_val; - } + F77_FUNC_CGETRF(&m, &n, a, &lda, ipiv, info); + return 0; +} + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::getrf_buffersize(cusolverDnHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, int *lwork) { + const int r_val = cusolverDnCgetrf_bufferSize(handle, m, n, (cuComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::getrf(cusolverDnHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + int *ipiv, int *dev) { + const int r_val = cusolverDnCgetrf(handle, m, n, (cuComplex *)a, lda, (cuComplex *)w, ipiv, dev); + return r_val; +} #endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::getrf_buffersize(hipsolverHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, int *lwork) { + const int r_val = hipsolverCgetrf_bufferSize(handle, m, n, (cuComplex *)a, lda, lwork); + return r_val; +} - template<> - int - Lapack >::getrf(const int m, const int n, - Kokkos::complex *a, const int lda, - int *ipiv, - int *info) { - F77_FUNC_ZGETRF(&m, &n, - a, &lda, - ipiv, - info); - return 0; - } - -#if defined(KOKKOS_ENABLE_CUDA) - template<> - int - Lapack >::getrf_buffersize(cusolverDnHandle_t handle, - const int m, const int n, - Kokkos::complex *a, const int lda, - int *lwork) { - const int r_val = cusolverDnZgetrf_bufferSize(handle, - m, n, - (cuDoubleComplex*)a, - lda, - lwork); - return r_val; - } - - template<> - int - Lapack >::getrf(cusolverDnHandle_t handle, - const int m, const int n, - Kokkos::complex *a, const int lda, - Kokkos::complex *w, - int *ipiv, - int *dev) { - const int r_val = cusolverDnZgetrf(handle, - m, n, - (cuDoubleComplex*)a, lda, - (cuDoubleComplex*)w, - ipiv, - dev); - return r_val; - } +template <> +int Lapack>::getrf(hipsolverHandle_t handle, const int m, const int n, Kokkos::complex *a, + const int lda, Kokkos::complex *w, int *ipiv, int *dev) { + const int r_val = hipsolverCgetrf(handle, m, n, (cuComplex *)a, lda, (cuComplex *)w, ipiv, dev); + return r_val; +} #endif +template <> +int Lapack>::potrf(const char uplo, const int m, Kokkos::complex *a, const int lda, + int *info) { + F77_FUNC_ZPOTRF(&uplo, &m, a, &lda, info); + return 0; +} +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::potrf_buffersize(cusolverDnHandle_t handle, const cublasFillMode_t uplo, + const int m, Kokkos::complex *a, const int lda, + int *lwork) { + const int r_val = cusolverDnZpotrf_bufferSize(handle, uplo, m, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::potrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + const int lwork, int *dev) { + const int r_val = cusolverDnZpotrf(handle, uplo, m, (cuDoubleComplex *)a, lda, (cuDoubleComplex *)w, lwork, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::potrf_buffersize(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, + const int m, Kokkos::complex *a, const int lda, + int *lwork) { + const int r_val = hipsolverZpotrf_bufferSize(handle, uplo, m, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} - template<> - int - Lapack >::potrf(const char uplo, - const int m, - std::complex *a, const int lda, - int *info) { - F77_FUNC_CPOTRF(&uplo, - &m, - (Kokkos::complex*)a, &lda, - info); - return 0; - } - - template<> - int - Lapack >::sytrf(const char uplo, - const int m, - std::complex *a, const int lda, - int *ipiv, - std::complex *work, int lwork, - int *info) { - F77_FUNC_CSYTRF(&uplo, - &m, - (Kokkos::complex*)a, &lda, - ipiv, - (Kokkos::complex*)work, &lwork, - info); - return 0; - } - - template<> - int - Lapack >::potrf(const char uplo, - const int m, - std::complex *a, const int lda, - int *info) { - F77_FUNC_ZPOTRF(&uplo, - &m, - (Kokkos::complex*)a, &lda, - info); - return 0; - } - template<> - int - Lapack >::sytrf(const char uplo, - const int m, - std::complex *a, const int lda, - int *ipiv, - std::complex* work, int lwork, - int *info) { - F77_FUNC_ZSYTRF(&uplo, - &m, - (Kokkos::complex*)a, &lda, - ipiv, - (Kokkos::complex*)work, &lwork, - info); - return 0; - } +template <> +int Lapack>::potrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + const int lwork, int *dev) { + const int r_val = hipsolverZpotrf(handle, uplo, m, (cuDoubleComplex *)a, lda, (cuDoubleComplex *)w, lwork, dev); + return r_val; +} +#endif +template <> +int Lapack>::sytrf(const char uplo, const int m, Kokkos::complex *a, const int lda, + int *ipiv, Kokkos::complex *work, int lwork, int *info) { + F77_FUNC_ZSYTRF(&uplo, &m, a, &lda, ipiv, work, &lwork, info); + return 0; } + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::sytrf_buffersize(cusolverDnHandle_t handle, const int m, + Kokkos::complex *a, const int lda, int *lwork) { + const int r_val = cusolverDnZsytrf_bufferSize(handle, m, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::sytrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, int *ipiv, + Kokkos::complex *w, const int lwork, int *dev) { + const int r_val = + cusolverDnZsytrf(handle, uplo, m, (cuDoubleComplex *)a, lda, ipiv, (cuDoubleComplex *)w, lwork, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::sytrf_buffersize(hipsolverHandle_t handle, const int m, Kokkos::complex *a, + const int lda, int *lwork) { + const int r_val = hipsolverZsytrf_bufferSize(handle, m, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::sytrf(hipsolverHandle_t handle, const hipsolverFillMode_t uplo, const int m, + Kokkos::complex *a, const int lda, int *ipiv, + Kokkos::complex *w, const int lwork, int *dev) { + const int r_val = hipsolverZsytrf(handle, uplo, m, (cuDoubleComplex *)a, lda, ipiv, (cuDoubleComplex *)w, lwork, dev); + return r_val; +} +#endif + +template <> +int Lapack>::getrf(const int m, const int n, Kokkos::complex *a, const int lda, + int *ipiv, int *info) { + F77_FUNC_ZGETRF(&m, &n, a, &lda, ipiv, info); + return 0; +} + +#if defined(TACHO_ENABLE_CUSOLVER) +template <> +int Lapack>::getrf_buffersize(cusolverDnHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, int *lwork) { + const int r_val = cusolverDnZgetrf_bufferSize(handle, m, n, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::getrf(cusolverDnHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + int *ipiv, int *dev) { + const int r_val = cusolverDnZgetrf(handle, m, n, (cuDoubleComplex *)a, lda, (cuDoubleComplex *)w, ipiv, dev); + return r_val; +} +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) +template <> +int Lapack>::getrf_buffersize(hipsolverHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, int *lwork) { + const int r_val = hipsolverZgetrf_bufferSize(handle, m, n, (cuDoubleComplex *)a, lda, lwork); + return r_val; +} + +template <> +int Lapack>::getrf(hipsolverHandle_t handle, const int m, const int n, + Kokkos::complex *a, const int lda, Kokkos::complex *w, + int *ipiv, int *dev) { + const int r_val = hipsolverZgetrf(handle, m, n, (cuDoubleComplex *)a, lda, (cuDoubleComplex *)w, ipiv, dev); + return r_val; +} +#endif + +template <> +int Lapack>::potrf(const char uplo, const int m, std::complex *a, const int lda, int *info) { + F77_FUNC_CPOTRF(&uplo, &m, (Kokkos::complex *)a, &lda, info); + return 0; +} + +template <> +int Lapack>::sytrf(const char uplo, const int m, std::complex *a, const int lda, int *ipiv, + std::complex *work, int lwork, int *info) { + F77_FUNC_CSYTRF(&uplo, &m, (Kokkos::complex *)a, &lda, ipiv, (Kokkos::complex *)work, &lwork, info); + return 0; +} + +template <> +int Lapack>::potrf(const char uplo, const int m, std::complex *a, const int lda, + int *info) { + F77_FUNC_ZPOTRF(&uplo, &m, (Kokkos::complex *)a, &lda, info); + return 0; +} +template <> +int Lapack>::sytrf(const char uplo, const int m, std::complex *a, const int lda, int *ipiv, + std::complex *work, int lwork, int *info) { + F77_FUNC_ZSYTRF(&uplo, &m, (Kokkos::complex *)a, &lda, ipiv, (Kokkos::complex *)work, &lwork, info); + return 0; +} + +} // namespace Tacho diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.hpp index b381b3821fcf..a6859a4b5614 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_External.hpp @@ -5,90 +5,72 @@ /// \brief BLAS wrapper /// \author Kyungjoo Kim (kyukim@sandia.gov) -#include "Kokkos_Core.hpp" // CUDA specialization +#include "Kokkos_Core.hpp" -#if defined(KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) +#define TACHO_ENABLE_CUSOLVER +#endif + +#if defined(KOKKOS_ENABLE_HIP) +// todo: enable hipblas interface after checking on AMD machine +//#define TACHO_ENABLE_HIPSOLVER +#endif + +#if defined(TACHO_ENABLE_CUSOLVER) #include "cusolverDn.h" #endif +#if defined(TACHO_ENABLE_HIPSOLVER) +#include "hipsolver.h" +#endif namespace Tacho { - template - struct Lapack { - /// - /// Cholesky - /// - static - int potrf(const char uplo, - const int m, - T *a, const int lda, - int *info); -#if defined (KOKKOS_ENABLE_CUDA) - static - int potrf_buffersize(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - T *a, const int lda, - int *lwork); - static - int potrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - T *a, const int lda, - T *W, const int lwork, - int *dev); +template struct Lapack { + /// + /// Cholesky + /// + static int potrf(const char uplo, const int m, T *a, const int lda, int *info); +#if defined(TACHO_ENABLE_CUSOLVER) + static int potrf_buffersize(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, T *a, const int lda, + int *lwork); + static int potrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, T *a, const int lda, T *W, + const int lwork, int *dev); +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) + static int potrf_buffersize(hipsolverHandle_t handle, const hipblasFillMode_t uplo, const int m, T *a, const int lda, + int *lwork); + static int potrf(hipsolverHandle_t handle, const hipblasFillMode_t uplo, const int m, T *a, const int lda, T *W, + const int lwork, int *dev); #endif - /// - /// LDLt - /// - static - int sytrf(const char uplo, - const int m, - T *a, const int lda, - int *ipiv, - T *work, int lwork, - int *info); -#if defined (KOKKOS_ENABLE_CUDA) - static - int sytrf_buffersize(cusolverDnHandle_t handle, - const int m, - T *a, const int lda, - int *lwork); - static - int sytrf(cusolverDnHandle_t handle, - const cublasFillMode_t uplo, - const int m, - T *a, const int lda, - int *ipiv, - T *W, const int lwork, - int *dev); + /// + /// LDLt + /// + static int sytrf(const char uplo, const int m, T *a, const int lda, int *ipiv, T *work, int lwork, int *info); +#if defined(TACHO_ENABLE_CUSOLVER) + static int sytrf_buffersize(cusolverDnHandle_t handle, const int m, T *a, const int lda, int *lwork); + static int sytrf(cusolverDnHandle_t handle, const cublasFillMode_t uplo, const int m, T *a, const int lda, int *ipiv, + T *W, const int lwork, int *dev); +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) + static int sytrf_buffersize(hipsolverHandle_t handle, const int m, T *a, const int lda, int *lwork); + static int sytrf(hipsolverHandle_t handle, const hipblasFillMode_t uplo, const int m, T *a, const int lda, int *ipiv, + T *W, const int lwork, int *dev); #endif - /// - /// LU - /// - static - int getrf(const int m, const int n, - T *a, const int lda, - int *ipiv, - int *info); -#if defined (KOKKOS_ENABLE_CUDA) - static - int getrf_buffersize(cusolverDnHandle_t handle, - const int m, const int n, - T *a, const int lda, - int *lwork); - static - int getrf(cusolverDnHandle_t handle, - const int m, const int n, - T *a, const int lda, - T *w, - int *ipiv, - int *dev); + /// + /// LU + /// + static int getrf(const int m, const int n, T *a, const int lda, int *ipiv, int *info); +#if defined(TACHO_ENABLE_CUSOLVER) + static int getrf_buffersize(cusolverDnHandle_t handle, const int m, const int n, T *a, const int lda, int *lwork); + static int getrf(cusolverDnHandle_t handle, const int m, const int n, T *a, const int lda, T *w, int *ipiv, int *dev); +#endif +#if defined(TACHO_ENABLE_HIPSOLVER) + static int getrf_buffersize(hipsolverHandle_t handle, const int m, const int n, T *a, const int lda, int *lwork); + static int getrf(hipsolverHandle_t handle, const int m, const int n, T *a, const int lda, T *w, int *ipiv, int *dev); #endif - - }; -} +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_Team.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_Team.hpp index 80bab705f370..3e6ed5148835 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_Team.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Lapack_Team.hpp @@ -9,242 +9,255 @@ namespace Tacho { - template - struct LapackTeam { - struct Impl { - template - static - KOKKOS_INLINE_FUNCTION - void potrf_upper(MemberType &member, - const int m, - T *__restrict__ A, const int as0, const int as1, - int *info) { - if (m <= 0) return; - - typedef ArithTraits arith_traits; - for (int p=0;p struct LapackTeam { + struct Impl { + template + static KOKKOS_INLINE_FUNCTION void potrf_upper(const MemberType &member, const int m, T *__restrict__ A, + const int as0, const int as1, int *info) { + if (m <= 0) + return; - template - static - KOKKOS_INLINE_FUNCTION - void sytrf_lower(MemberType &member, - const int m, - T *__restrict__ A, const int as0, const int as1, - int *__restrict__ ipiv, - int *info) { - if (m <= 0) return; - - using arith_traits = ArithTraits; - using mag_type = typename arith_traits::mag_type; - - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&](const int &i) { - ipiv[i] = i+1; - }); + typedef ArithTraits arith_traits; + for (int p = 0; p < m; ++p) { + const int jend = m - p - 1; - const int as = as0+as1; - const mag_type mu = 0.6404; - for (int p=0;p::value_type; - reducer_value_type value; - Kokkos::MaxLoc reducer_value(value); - Kokkos::parallel_reduce - (Kokkos::TeamVectorRange(member, iend), - [&](const int &i, reducer_value_type &update) { - const mag_type val = arith_traits::abs(a21[i*as0]); - if (val > update.val) { - update.val = val; - update.loc = i; - } - }, reducer_value); - member.team_barrier(); - - lambda1 = value.val; - idx = value.loc; - } - - const mag_type abs_alpha = arith_traits::abs(*alpha11); - if (abs_alpha < mu*lambda1) { - mag_type lambda2(0); - { - using reducer_value_type = typename Kokkos::Max::value_type; - reducer_value_type value; - Kokkos::Max reducer_value(value); - Kokkos::parallel_reduce - (Kokkos::TeamVectorRange(member, iend), - [&](const int &i, reducer_value_type &update) { - mag_type val(0); - if (i < idx) val = arith_traits::abs(a21[idx*as0+i*as1]); - else if (i > idx) val = arith_traits::abs(A22[idx*as+(i-idx)*as0]); - if (val > update) { - update = val; - } - }, reducer_value); - member.team_barrier(); - lambda2 = value; - } - const mag_type abs_alpha_idx = arith_traits::abs(A22[idx*as]); - if (abs_alpha_idx*lambda2 < mu*lambda1*lambda1) { - /// pivot - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,iend),[&](const int &i) { - if (i < idx) swap(a21[i*as0], A22[idx*as0+i*as1]); - else if (i > idx) swap(a21[i*as0], A22[i*as0+idx*as1]); - else { - swap(alpha11[0], A22[idx*as]); - ipiv[p] = ipiv[p+idx+1]; - } - }); - } - } - member.team_barrier(); - const T alpha = *alpha11; - Kokkos::parallel_for(Kokkos::TeamVectorRange(member,iend),[&](const int &i) { - a21[i*as0] /= alpha; - }); - member.team_barrier(); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,iend),[&](const int &i) { - const T aa = a21[i*as0]; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,i+1),[&](const int &j) { - const T bb = a21[j*as0]; - A22[i*as0+j*as1] -= alpha*aa*bb; - }); - }); - member.team_barrier(); - } + T *__restrict__ alpha11 = A + (p)*as0 + (p)*as1, *__restrict__ a12t = A + (p)*as0 + (p + 1) * as1, + *__restrict__ A22 = A + (p + 1) * as0 + (p + 1) * as1; + + Kokkos::single(Kokkos::PerTeam(member), [&]() { *alpha11 = sqrt(arith_traits::real(*alpha11)); }); + member.team_barrier(); + const auto alpha = arith_traits::real(*alpha11); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, jend), [&](const int &j) { a12t[j * as1] /= alpha; }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, jend), [&](const int &j) { + const T aa = arith_traits::conj(a12t[j * as1]); + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1), [&](const int &i) { + const T bb = a12t[i * as1]; + A22[i * as0 + j * as1] -= aa * bb; + }); + }); + member.team_barrier(); + } + } + + template + static KOKKOS_INLINE_FUNCTION void sytrf_lower(const MemberType &member, const int m, T *__restrict__ A, + const int as0, const int as1, int *__restrict__ ipiv, int *info) { + if (m <= 0) + return; + + using arith_traits = ArithTraits; + using mag_type = typename arith_traits::mag_type; + + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const int &i) { ipiv[i] = i + 1; }); + + const int as = as0 + as1; + const mag_type mu = 0.6404; + for (int p = 0; p < m; ++p) { + const int iend = m - p - 1; + + T *__restrict__ alpha11 = A + (p)*as0 + (p)*as1, *__restrict__ a21 = A + (p + 1) * as0 + (p)*as1, + *__restrict__ A22 = A + (p + 1) * as0 + (p + 1) * as1; + + mag_type lambda1(0); + int idx(0); + { + using reducer_value_type = typename Kokkos::MaxLoc::value_type; + reducer_value_type value; + Kokkos::MaxLoc reducer_value(value); + Kokkos::parallel_reduce( + Kokkos::TeamVectorRange(member, iend), + [&](const int &i, reducer_value_type &update) { + const mag_type val = arith_traits::abs(a21[i * as0]); + if (val > update.val) { + update.val = val; + update.loc = i; + } + }, + reducer_value); + member.team_barrier(); + + lambda1 = value.val; + idx = value.loc; } - template - static - KOKKOS_INLINE_FUNCTION - void sytrf_lower_nopiv(MemberType &member, - const int m, - T *__restrict__ A, const int as0, const int as1, - int *info) { - if (m <= 0) return; - - //typedef ArithTraits arith_traits; - for (int p=0;p::value_type; + reducer_value_type value; + Kokkos::Max reducer_value(value); + Kokkos::parallel_reduce( + Kokkos::TeamVectorRange(member, iend), + [&](const int &i, reducer_value_type &update) { + mag_type val(0); + if (i < idx) + val = arith_traits::abs(a21[idx * as0 + i * as1]); + else if (i > idx) + val = arith_traits::abs(A22[idx * as + (i - idx) * as0]); + if (val > update) { + update = val; + } + }, + reducer_value); member.team_barrier(); + lambda2 = value; + } + const mag_type abs_alpha_idx = arith_traits::abs(A22[idx * as]); + if (abs_alpha_idx * lambda2 < mu * lambda1 * lambda1) { + /// pivot + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), [&](const int &i) { + if (i < idx) + swap(a21[i * as0], A22[idx * as0 + i * as1]); + else if (i > idx) + swap(a21[i * as0], A22[i * as0 + idx * as1]); + else { + swap(alpha11[0], A22[idx * as]); + ipiv[p] = p + idx + 2; + } + }); } } - }; - - template - static - KOKKOS_INLINE_FUNCTION - void potrf(MemberType &member, - const char uplo, - const int m, - /* */ T *__restrict__ A, const int lda, - int *info) { - switch (uplo) { - case 'U': - case 'u': { - Impl::potrf_upper(member, - m, - A, 1, lda, - info); - break; - } - case 'L': - case 'l': { - Kokkos::abort("not implemented"); - break; - } - default: - Kokkos::abort("Invalid uplo character"); - } + member.team_barrier(); + const T alpha = *alpha11; + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), [&](const int &i) { a21[i * as0] /= alpha; }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, iend), [&](const int &i) { + const T aa = a21[i * as0]; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, i + 1), [&](const int &j) { + const T bb = a21[j * as0]; + A22[i * as0 + j * as1] -= alpha * aa * bb; + }); + }); + member.team_barrier(); } + } + template + static KOKKOS_INLINE_FUNCTION void sytrf_lower_nopiv(const MemberType &member, const int m, T *__restrict__ A, + const int as0, const int as1, int *info) { + if (m <= 0) + return; - template - static - KOKKOS_INLINE_FUNCTION - void sytrf(MemberType &member, - const char uplo, - const int m, - /* */ T *__restrict__ A, const int lda, - /* */ int *__restrict__ P, - /* */ T *__restrict__ W, - int *info) { - switch (uplo) { - case 'U': - case 'u': { - Kokkos::abort("not implemented"); - break; - } - case 'L': - case 'l': { - Impl::sytrf_lower(member, - m, - A, 1, lda, - P, - info); - break; - } - default: - Kokkos::abort("Invalid uplo character"); + // typedef ArithTraits arith_traits; + for (int p = 0; p < m; ++p) { + const int iend = m - p - 1; + + T *__restrict__ alpha11 = A + (p)*as0 + (p)*as1, *__restrict__ a21 = A + (p + 1) * as0 + (p)*as1, + *__restrict__ A22 = A + (p + 1) * as0 + (p + 1) * as1; + + const auto alpha = *alpha11; // arith_traits::real(*alpha11); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), [&](const int &i) { a21[i * as0] /= alpha; }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, iend), [&](const int &i) { + const T aa = a21[i * as0]; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, i + 1), [&](const int &j) { + const T bb = a21[j * as0]; + A22[i * as0 + j * as1] -= alpha * aa * bb; + }); + }); + member.team_barrier(); + } + } + }; + + template + static KOKKOS_INLINE_FUNCTION void potrf(const MemberType &member, const char uplo, const int m, + /* */ T *__restrict__ A, const int lda, int *info) { + switch (uplo) { + case 'U': + case 'u': { + Impl::potrf_upper(member, m, A, 1, lda, info); + break; + } + case 'L': + case 'l': { + Kokkos::abort("not implemented"); + break; + } + default: + Kokkos::abort("Invalid uplo character"); + } + } + + template + static KOKKOS_INLINE_FUNCTION void sytrf(const MemberType &member, const char uplo, const int m, + /* */ T *__restrict__ A, const int lda, + /* */ int *__restrict__ P, + /* */ T *__restrict__ W, int *info) { + switch (uplo) { + case 'U': + case 'u': { + Kokkos::abort("not implemented"); + break; + } + case 'L': + case 'l': { + Impl::sytrf_lower(member, m, A, 1, lda, P, info); + break; + } + default: + Kokkos::abort("Invalid uplo character"); + } + } + + template + static KOKKOS_INLINE_FUNCTION void getrf(const MemberType &member, const int m, const int n, T *__restrict__ A, + const int as1, int *__restrict__ ipiv, int *info) { + if (m <= 0 || n <= 0) + return; + + using arith_traits = ArithTraits; + using mag_type = typename arith_traits::mag_type; + + const int as0 = 1; + for (int p = 0; p < m; ++p) { + const int iend = m - p - 1, jend = n - p - 1; + T *__restrict__ alpha11 = A + (p)*as0 + (p)*as1, *__restrict__ AB = A + (p)*as0, *__restrict__ ABR = alpha11, + *__restrict__ a21 = A + (p + 1) * as0 + (p)*as1, *__restrict__ a12 = A + (p)*as0 + (p + 1) * as1, + *__restrict__ A22 = A + (p + 1) * as0 + (p + 1) * as1; + + { + int idx(0); + using reducer_value_type = typename Kokkos::MaxLoc::value_type; + reducer_value_type value; + Kokkos::MaxLoc reducer_value(value); + Kokkos::parallel_reduce( + Kokkos::TeamVectorRange(member, 1 + iend), + [&](const int &i, reducer_value_type &update) { + const mag_type val = arith_traits::abs(ABR[i * as0]); + if (val > update.val) { + update.val = val; + update.loc = i; + } + }, + reducer_value); + member.team_barrier(); + idx = value.loc; + + /// pivot + Kokkos::single(Kokkos::PerThread(member), [&]() { ipiv[p] = p + idx + 1; }); + if (idx) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), + [&](const int &j) { swap(AB[j * as1], AB[idx * as0 + j * as1]); }); + member.team_barrier(); } } - }; + member.team_barrier(); + const T alpha = *alpha11; + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, iend), [&](const int &i) { a21[i * as0] /= alpha; }); + member.team_barrier(); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, jend), [&](const int &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, iend), + [&](const int &i) { A22[i * as0 + j * as1] -= a21[i * as0] * a12[j * as1]; }); + }); + member.team_barrier(); + } + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals.hpp index b3efac17db35..8c59addc2518 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals.hpp @@ -9,14 +9,13 @@ namespace Tacho { - /// - /// ModifyDiagonals - /// +/// +/// ModifyDiagonals +/// - /// various implementation for different uplo and algo parameters - template - struct ModifyDiagonals; +/// various implementation for different uplo and algo parameters +template struct ModifyDiagonals; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_Internal.hpp index 79d54d214ca0..5cce9dd86301 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_Internal.hpp @@ -1,40 +1,25 @@ #ifndef __TACHO_MODIFY_DIAGONALS_INTERNAL_HPP__ #define __TACHO_MODIFY_DIAGONALS_INTERNAL_HPP__ - /// \file Tacho_ModifyDiagonals_Internal.hpp /// \brief Modify diagonals /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - - template<> - struct ModifyDiagonals { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ScalarType &alpha) { - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - const ordinal_type - m = A.extent(0), - n = A.extent(1), - min_mn = m > n ? n : m; +template <> struct ModifyDiagonals { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const ScalarType &alpha) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + const ordinal_type m = A.extent(0), n = A.extent(1), min_mn = m > n ? n : m; - if (min_mn > 0) { - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, min_mn), - [&](const ordinal_type &i) { - A(i,i) += alpha; - }); - } - return 0; + if (min_mn > 0) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, min_mn), [&](const ordinal_type &i) { A(i, i) += alpha; }); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_OnDevice.hpp index 584d22c05ebf..42ade7d54139 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_ModifyDiagonals_OnDevice.hpp @@ -1,44 +1,32 @@ #ifndef __TACHO_MODIFY_DIAGONALS_ON_DEVICE_HPP__ #define __TACHO_MODIFY_DIAGONALS_ON_DEVICE_HPP__ - /// \file Tacho_ModifyDiagonals_OnDevice.hpp /// \brief Modify diagonals /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - - template<> - struct ModifyDiagonals { - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ScalarType &alpha) { - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - const ordinal_type - m = A.extent(0), - n = A.extent(1), - min_mn = m > n ? n : m;; - - if (min_mn > 0) { - using exec_space = MemberType; - using range_policy_type = Kokkos::RangePolicy; +template <> struct ModifyDiagonals { + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ScalarType &alpha) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + const ordinal_type m = A.extent(0), n = A.extent(1), min_mn = m > n ? n : m; + ; + + if (min_mn > 0) { + using exec_space = MemberType; + using range_policy_type = Kokkos::RangePolicy; - const auto exec_instance = member; - const auto policy = range_policy_type(exec_instance, 0, min_mn); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &i) { - A(i,i) += alpha; - }); - } - return 0; + const auto exec_instance = member; + const auto policy = range_policy_type(exec_instance, 0, min_mn); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { A(i, i) += alpha; }); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Base.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Base.hpp index ba8f899578c0..21604b5d1b47 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Base.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Base.hpp @@ -7,356 +7,259 @@ #include "Tacho_Util.hpp" #include "Tacho_CrsMatrixBase.hpp" -#include "Tacho_DenseMatrixView.hpp" - #include "Tacho_SupernodeInfo.hpp" namespace Tacho { - template - class NumericToolsBase { - public: - using value_type = ValueType; - using device_type = DeviceType; - using exec_space = typename device_type::execution_space; - using exec_memory_space = typename device_type::memory_space; - - using range_type = Kokkos::pair; - - using scheduler_type = typename UseThisScheduler::type; - using supernode_info_type = SupernodeInfo; - using crs_matrix_type = typename supernode_info_type::crs_matrix_type; - - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - using size_type_array = typename supernode_info_type::size_type_array; - using value_type_array = typename supernode_info_type::value_type_array; - - using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - using supernode_type_array = typename supernode_info_type::supernode_type_array; - - using host_device_type = typename UseThisDevice::type; - using host_space = typename host_device_type::execution_space; - using host_memory_space = typename host_device_type::memory_space; - - using ordinal_type_array_host = typename ordinal_type_array::HostMirror; - using size_type_array_host = typename size_type_array::HostMirror; - using supernode_type_array_host = typename supernode_type_array::HostMirror; - - protected: - - /// - /// supernode data structure memory "managed" - /// this holds all necessary connectivity data - /// - - // solution method - ordinal_type _method; // 1 - cholesky, 2 - LDL - - // matrix input - ordinal_type _m; - size_type_array _ap; - ordinal_type_array _aj; - value_type_array _ax; - - // graph ordering input - ordinal_type_array _perm, _peri; - - // supernodes - ordinal_type _nsupernodes; - supernode_type_array _supernodes; +template class NumericToolsBase { +public: + using value_type = ValueType; + using device_type = DeviceType; + using exec_space = typename device_type::execution_space; + using exec_memory_space = typename device_type::memory_space; + + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfo; + using crs_matrix_type = typename supernode_info_type::crs_matrix_type; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + using value_type_array = typename supernode_info_type::value_type_array; + + using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + using supernode_type_array = typename supernode_info_type::supernode_type_array; + + using host_device_type = typename UseThisDevice::type; + using host_space = typename host_device_type::execution_space; + using host_memory_space = typename host_device_type::memory_space; + + using ordinal_type_array_host = typename ordinal_type_array::HostMirror; + using size_type_array_host = typename size_type_array::HostMirror; + using supernode_type_array_host = typename supernode_type_array::HostMirror; + +protected: + /// + /// supernode data structure memory "managed" + /// this holds all necessary connectivity data + /// + + // solution method + ordinal_type _method; // 1 - cholesky, 2 - LDL, 3 - LU + + // matrix input + ordinal_type _m; + size_type_array _ap; + ordinal_type_array _aj; + value_type_array _ax; + + // graph ordering input + ordinal_type_array _perm, _peri; + + // supernodes + ordinal_type _nsupernodes; + supernode_type_array _supernodes; + + // dof mapping to sparse matrix + ordinal_type_array _gid_colidx; + + // supernode map and panel size configuration (sid and column blksize) + ordinal_pair_type_array _sid_block_colidx; + + // supernode tree + ordinal_type_array_host _stree_level, _stree_roots; + + // output : factors, pivot, diagonal blocks + value_type_array _superpanel_buf; + ordinal_type_array _piv; + value_type_array _diag; + + /// + /// supernode info: supernode data structure with "unamanged" view + /// this is passed into computation algorithm without reference counting + /// + supernode_info_type _info; + + /// + /// statistics + /// + struct { + double t_init, t_mode_classification, t_copy, t_factor, t_solve, t_extra; + double m_used, m_peak; + } stat; + + inline void track_alloc(const double in) { + stat.m_used += in; + stat.m_peak = std::max(stat.m_peak, stat.m_used); + } + + inline void track_free(const double out) { stat.m_used -= out; } + + inline void reset_stat() { + stat.t_init = 0; + stat.t_mode_classification = 0; + stat.t_factor = 0; + stat.t_solve = 0; + stat.t_copy = 0; + stat.t_extra = 0; + stat.m_used = 0; + stat.m_peak = 0; + } + + virtual void print_stat_init() { + /// nothing + } + + virtual void print_stat_factor() { + const double kilo(1024); + printf(" Time\n"); + printf(" time for copying A into supernodes: %10.6f s\n", stat.t_copy); + printf(" time for numeric factorization: %10.6f s\n", stat.t_factor); + printf(" total time spent: %10.6f s\n", (stat.t_copy + stat.t_factor)); + printf("\n"); + printf(" Memory\n"); + printf(" memory used in factorization: %10.3f MB\n", stat.m_used / kilo / kilo); + printf(" peak memory used in factorization: %10.3f MB\n", stat.m_peak / kilo / kilo); + printf("\n"); + } + + virtual void print_stat_solve() { + const double kilo(1024); + printf(" Time\n"); + printf(" time for extra work e.g.,copy rhs: %10.6f s\n", stat.t_extra); + printf(" time for numeric solve: %10.6f s\n", stat.t_solve); + printf(" total time spent: %10.6f s\n", (stat.t_solve + stat.t_extra)); + printf(" Memory\n"); + printf(" memory used in solve: %10.3f MB\n", stat.m_used / kilo / kilo); + printf("\n"); + } + + inline void print_stat_memory() { + const double kilo(1024); + printf(" Memory\n"); + printf(" memory used now: %10.3f MB\n", stat.m_used / kilo / kilo); + printf(" peak memory used: %10.3f MB\n", stat.m_peak / kilo / kilo); + printf("\n"); + } + +public: + NumericToolsBase() : _method(0), _m(0), stat() {} + + NumericToolsBase(const NumericToolsBase &b) = default; + + /// + /// construction (assume input matrix and symbolic are from host) + /// + NumericToolsBase(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, const size_type_array &gid_ptr, + const ordinal_type_array &gid_colidx, const size_type_array &sid_ptr, + const ordinal_type_array &sid_colidx, const ordinal_type_array &blk_colidx, + const ordinal_type_array &stree_parent, const size_type_array &stree_ptr, + const ordinal_type_array &stree_children, const ordinal_type_array_host &stree_level, + const ordinal_type_array_host &stree_roots) + : _method(method), _m(m), _ap(ap), _aj(aj), _perm(perm), _peri(peri), _nsupernodes(nsupernodes), + _gid_colidx(gid_colidx), _stree_level(stree_level), _stree_roots(stree_roots) { + + reset_stat(); - // dof mapping to sparse matrix - ordinal_type_array _gid_colidx; - - // supernode map and panel size configuration (sid and column blksize) - ordinal_pair_type_array _sid_block_colidx; - - // supernode tree - ordinal_type_array_host _stree_level, _stree_roots; - - // output : factors, pivot, diagonal blocks - value_type_array _superpanel_buf; - ordinal_type_array _piv; - value_type_array _diag; - - /// - /// supernode info: supernode data structure with "unamanged" view - /// this is passed into computation algorithm without reference counting /// - supernode_info_type _info; - + /// symbolic input /// - /// statistics - /// - struct { - double t_init, t_mode_classification, t_copy, t_factor, t_solve, t_extra; - double m_used, m_peak; - } stat; - - inline - void - track_alloc(const double in) { - stat.m_used += in; - stat.m_peak = std::max(stat.m_peak, stat.m_used); - } - - inline - void - track_free(const double out) { - stat.m_used -= out; - } - - inline - void - reset_stat(){ - stat.t_init = 0; - stat.t_mode_classification = 0; - stat.t_factor = 0; - stat.t_solve = 0; - stat.t_copy = 0; - stat.t_extra = 0; - stat.m_used = 0; - stat.m_peak = 0; - } - - virtual void - print_stat_init() { - /// nothing - } - - virtual void - print_stat_factor() { - const double kilo(1024); - printf(" Time\n"); - printf(" time for copying A into supernodes: %10.6f s\n", stat.t_copy); - printf(" time for numeric factorization: %10.6f s\n", stat.t_factor); - printf(" total time spent: %10.6f s\n", (stat.t_copy+stat.t_factor)); - printf("\n"); - printf(" Memory\n"); - printf(" memory used in factorization: %10.3f MB\n", stat.m_used/kilo/kilo); - printf(" peak memory used in factorization: %10.3f MB\n", stat.m_peak/kilo/kilo); - printf("\n"); - } - - virtual void - print_stat_solve() { - const double kilo(1024); - printf(" Time\n"); - printf(" time for extra work e.g.,copy rhs: %10.6f s\n", stat.t_extra); - printf(" time for numeric solve: %10.6f s\n", stat.t_solve); - printf(" total time spent: %10.6f s\n", (stat.t_solve+stat.t_extra)); - printf(" Memory\n"); - printf(" memory used in solve: %10.3f MB\n", stat.m_used/kilo/kilo); - printf("\n"); - } + const bool allocate_l_buf = (_method == 3); /// for LU + _info.initialize(_supernodes, _sid_block_colidx, _superpanel_buf, allocate_l_buf, supernodes, gid_ptr, gid_colidx, + sid_ptr, sid_colidx, blk_colidx, stree_parent, stree_ptr, stree_children); + track_alloc(_superpanel_buf.span() * sizeof(value_type)); - inline - void - print_stat_memory() { - const double kilo(1024); - printf(" Memory\n"); - printf(" memory used now: %10.3f MB\n", stat.m_used/kilo/kilo); - printf(" peak memory used: %10.3f MB\n", stat.m_peak/kilo/kilo); - printf("\n"); - } + _piv = ordinal_type_array("piv", 4 * _m); + track_alloc(_piv.span() * sizeof(ordinal_type)); - public: - NumericToolsBase() - : _method(0), - _m(0), - stat() {} + _diag = value_type_array("diag", 2 * _m); + track_alloc(_diag.span() * sizeof(value_type)); + } - NumericToolsBase(const NumericToolsBase &b) = default; - - /// - /// construction (assume input matrix and symbolic are from host) - /// - NumericToolsBase(const ordinal_type method, - // input matrix A - const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - // input permutation - const ordinal_type_array &perm, - const ordinal_type_array &peri, - // supernodes - const ordinal_type nsupernodes, - const ordinal_type_array &supernodes, - const size_type_array &gid_ptr, - const ordinal_type_array &gid_colidx, - const size_type_array &sid_ptr, - const ordinal_type_array &sid_colidx, - const ordinal_type_array &blk_colidx, - const ordinal_type_array &stree_parent, - const size_type_array &stree_ptr, - const ordinal_type_array &stree_children, - const ordinal_type_array_host &stree_level, - const ordinal_type_array_host &stree_roots) - : _method(method), - _m(m), _ap(ap), _aj(aj), - _perm(perm), _peri(peri), - _nsupernodes(nsupernodes), - _gid_colidx(gid_colidx), - _stree_level(stree_level), - _stree_roots(stree_roots) { - - reset_stat(); - - /// - /// symbolic input - /// - _info.initialize(_supernodes, - _sid_block_colidx, - _superpanel_buf, - supernodes, - gid_ptr, - gid_colidx, - sid_ptr, - sid_colidx, - blk_colidx, - stree_parent, - stree_ptr, - stree_children); - track_alloc(_superpanel_buf.span()*sizeof(value_type)); - - _piv = ordinal_type_array("piv", 4*_m); - track_alloc(_piv.span()*sizeof(ordinal_type)); - - _diag = value_type_array("diag", 2*_m); - track_alloc(_diag.span()*sizeof(value_type)); - } + virtual ~NumericToolsBase() {} - virtual~NumericToolsBase() {} + inline ordinal_type getSolutionMethod() const { return _method; } - inline - ordinal_type getSolutionMethod() const { - return _method; - } + inline ordinal_type getNumRows() const { return _m; } - inline - ordinal_type getNumRows() const { - return _m; - } + inline ordinal_type getNumCols() const { return _m; } - inline - ordinal_type getNumCols() const { - return _m; - } + inline size_type_array getRowPtr() const { return _ap; } - inline - size_type_array getRowPtr() const { - return _ap; - } + inline ordinal_type_array getCols() const { return _aj; } - inline - ordinal_type_array getCols() const { - return _aj; - } - - inline - ordinal_type_array getPermutationVector() const { - return _perm; - } + inline ordinal_type_array getPermutationVector() const { return _perm; } - inline - ordinal_type_array getInversePermutationVector() const { - return _peri; - } + inline ordinal_type_array getInversePermutationVector() const { return _peri; } - inline - ordinal_type_array_host getSupernodesTreeLevel() const { - return _stree_level; - } + inline ordinal_type_array_host getSupernodesTreeLevel() const { return _stree_level; } - inline - supernode_info_type getSupernodesInfo() const { - return _info; - } + inline supernode_info_type getSupernodesInfo() const { return _info; } - inline - virtual void - release(const ordinal_type verbose = 0) { - // release diagonal blocks and pivots - track_free(_piv.span()*sizeof(ordinal_type)); - track_free(_diag.span()*sizeof(value_type)); - - // release supernode buffer - track_free(_superpanel_buf.span()*sizeof(value_type)); - _superpanel_buf = value_type_array(); - - if (verbose) { - printf("Summary: NumericTools (Release)\n"); - printf("===============================\n"); - print_stat_memory(); /// should report zero leak - } - } + inline virtual void release(const ordinal_type verbose = 0) { + // release diagonal blocks and pivots + track_free(_piv.span() * sizeof(ordinal_type)); + track_free(_diag.span() * sizeof(value_type)); - inline - ordinal_type - getMaxSupernodeSize() const { - return _info.max_supernode_size; - } + // release supernode buffer + track_free(_superpanel_buf.span() * sizeof(value_type)); + _superpanel_buf = value_type_array(); - inline - ordinal_type - getMaxSchurSize() const { - return _info.max_schur_size; + if (verbose) { + printf("Summary: NumericTools (Release)\n"); + printf("===============================\n"); + print_stat_memory(); /// should report zero leak } + } - inline - void - printMemoryStat(const ordinal_type verbose = 0) { - if (verbose) { - printf("Summary: NumericTools (Memory)\n"); - printf("==============================\n"); - print_stat_memory(); - } - } + inline ordinal_type getMaxSupernodeSize() const { return _info.max_supernode_size; } - inline - virtual void - factorize(const value_type_array &ax, - const ordinal_type verbose = 0) { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, - "The function should be overriden by derived classes"); - } - - inline - virtual void - solve(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose = 0) { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, - "The function should be overriden by derived classes"); - } - - /// - /// Utility on device - /// - inline - void - exportFactorsToCrsMatrix(crs_matrix_type &A, - const bool replace_value_with_one = false) { - _info.createCrsMatrix(A, replace_value_with_one); - } + inline ordinal_type getMaxSchurSize() const { return _info.max_schur_size; } - inline - double - computeRelativeResidual(const value_type_matrix &x, - const value_type_matrix &b) { - crs_matrix_type A; - auto d_last = Kokkos::subview(_ap, _m); - auto h_last = Kokkos::create_mirror_view(host_memory_space(), d_last); - Kokkos::deep_copy(h_last, d_last); - A.setExternalMatrix(_m, _m, h_last(), //_ap(_m), - _ap, _aj, _ax); - - return Tacho::computeRelativeResidual(A, x, b); + inline void printMemoryStat(const ordinal_type verbose = 0) { + if (verbose) { + printf("Summary: NumericTools (Memory)\n"); + printf("==============================\n"); + print_stat_memory(); } - - }; - -} + } + + inline virtual void factorize(const value_type_array &ax, const ordinal_type verbose = 0) { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "The function should be overriden by derived classes"); + } + + inline virtual void solve(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose = 0) { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "The function should be overriden by derived classes"); + } + + /// + /// Utility on device + /// + inline void exportFactorsToCrsMatrix(crs_matrix_type &A, const bool replace_value_with_one = false) { + _info.createCrsMatrix(A, replace_value_with_one); + } + + inline double computeRelativeResidual(const value_type_matrix &x, const value_type_matrix &b) { + crs_matrix_type A; + auto d_last = Kokkos::subview(_ap, _m); + auto h_last = Kokkos::create_mirror_view(host_memory_space(), d_last); + Kokkos::deep_copy(h_last, d_last); + A.setExternalMatrix(_m, _m, h_last(), //_ap(_m), + _ap, _aj, _ax); + + return Tacho::computeRelativeResidual(A, x, b); + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Factory.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Factory.hpp new file mode 100644 index 000000000000..1ee711789c4b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Factory.hpp @@ -0,0 +1,268 @@ +#ifndef __TACHO_NUMERIC_TOOLS_FACTORY_HPP__ +#define __TACHO_NUMERIC_TOOLS_FACTORY_HPP__ + +/// \file Tacho_NumericTools_Serial.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_NumericTools_Base.hpp" +#include "Tacho_NumericTools_LevelSet.hpp" +#include "Tacho_NumericTools_Serial.hpp" + +namespace Tacho { + +/// +/// +/// +template class NumericToolsFactory; + +/// partial specialization +#define TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING \ + using ordinal_type_array = typename numeric_tools_base_type::ordinal_type_array; \ + using size_type_array = typename numeric_tools_base_type::size_type_array; \ + using ordinal_type_array_host = typename numeric_tools_base_type::ordinal_type_array_host; \ + using size_type_array_host = typename numeric_tools_base_type::size_type_array_host + +#define TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER \ + ordinal_type _method; \ + ordinal_type _m; \ + size_type_array _ap; \ + ordinal_type_array _aj; \ + ordinal_type_array _perm; \ + ordinal_type_array _peri; \ + ordinal_type _nsupernodes; \ + ordinal_type_array _supernodes; \ + size_type_array _gid_ptr; \ + ordinal_type_array _gid_colidx; \ + size_type_array _sid_ptr; \ + ordinal_type_array _sid_colidx; \ + ordinal_type_array _blk_colidx; \ + ordinal_type_array _stree_parent; \ + size_type_array _stree_ptr; \ + ordinal_type_array _stree_children; \ + ordinal_type_array_host _stree_level; \ + ordinal_type_array_host _stree_roots; \ + ordinal_type _verbose + +#define TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER \ + do { \ + _method = method; \ + _m = m; \ + _ap = ap; \ + _aj = aj; \ + _perm = perm; \ + _peri = peri; \ + _nsupernodes = nsupernodes; \ + _supernodes = supernodes; \ + _gid_ptr = gid_ptr; \ + _gid_colidx = gid_colidx; \ + _sid_ptr = sid_ptr; \ + _sid_colidx = sid_colidx; \ + _blk_colidx = blk_colidx; \ + _stree_parent = stree_parent; \ + _stree_ptr = stree_ptr; \ + _stree_children = stree_children; \ + _stree_level = stree_level; \ + _stree_roots = stree_roots; \ + _verbose = verbose; \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER \ + ordinal_type _variant; \ + ordinal_type _device_level_cut; \ + ordinal_type _device_factor_thres; \ + ordinal_type _device_solve_thres; \ + ordinal_type _nstreams + +#define TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER \ + do { \ + _variant = variant; \ + _device_level_cut = device_level_cut; \ + _device_factor_thres = device_factor_thres; \ + _device_solve_thres = device_solve_thres; \ + _nstreams = nstreams; \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY \ + do { \ + if (object == nullptr) \ + object = (numeric_tools_base_type *)::operator new(sizeof(numeric_tools_serial_type)); \ + \ + new (object) numeric_tools_serial_type(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_ptr, \ + _gid_colidx, _sid_ptr, _sid_colidx, _blk_colidx, _stree_parent, _stree_ptr, \ + _stree_children, _stree_level, _stree_roots); \ + } while (false) + +#define TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_name) \ + do { \ + if (object == nullptr) \ + object = (numeric_tools_base_type *)::operator new(sizeof(numeric_tools_levelset_name)); \ + new (object) numeric_tools_levelset_name(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_ptr, \ + _gid_colidx, _sid_ptr, _sid_colidx, _blk_colidx, _stree_parent, \ + _stree_ptr, _stree_children, _stree_level, _stree_roots); \ + numeric_tools_levelset_name *N = dynamic_cast(object); \ + N->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); \ + N->createStream(_nstreams, _verbose); \ + } while (false) + +/// +/// Serial construction +/// +#if defined(KOKKOS_ENABLE_SERIAL) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + // TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + // TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { +#if !defined(__CUDA_ARCH__) + TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY; +#endif + } +}; +#endif + +#if defined(KOKKOS_ENABLE_OPENMP) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + // TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + // TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { +#if !defined(__CUDA_ARCH__) + TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY; +#endif + } +}; +#endif + +#if defined(KOKKOS_ENABLE_CUDA) +template class NumericToolsFactory::type> { +public: + using value_type = ValueType; + using device_type = typename UseThisDevice::type; + using numeric_tools_base_type = NumericToolsBase; + using numeric_tools_serial_type = NumericToolsSerial; + using numeric_tools_levelset_var0_type = NumericToolsLevelSet; + using numeric_tools_levelset_var1_type = NumericToolsLevelSet; + using numeric_tools_levelset_var2_type = NumericToolsLevelSet; + + TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING; + TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER; + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER; + + void setBaseMember(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots, + const ordinal_type verbose) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER; + } + + void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut, + const ordinal_type device_factor_thres, const ordinal_type device_solve_thres, + const ordinal_type nstreams) { + TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER; + } + + void createObject(numeric_tools_base_type *&object) { + switch (_variant) { + case 0: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type); + break; + } + case 1: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var1_type); + break; + } + case 2: { + TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var2_type); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input"); + break; + } + } + } +}; +#endif + +#undef TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING +#undef TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_SET_BASE_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER +#undef TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER +#undef TACHO_NUMERIC_TOOLS_SERIAL_BODY + +} // namespace Tacho +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp index 975e0edd9b4a..55fdeef7f924 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp @@ -21,33 +21,22 @@ #include "Tacho_ApplyPivots.hpp" #include "Tacho_ApplyPivots_OnDevice.hpp" +#include "Tacho_ApplyPermutation.hpp" +#include "Tacho_ApplyPermutation_OnDevice.hpp" + #include "Tacho_Scale2x2_BlockInverseDiagonals.hpp" #include "Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp" -#include "Tacho_Chol.hpp" #include "Tacho_Chol_OnDevice.hpp" - -#include "Tacho_LDL.hpp" +#include "Tacho_GemmTriangular_OnDevice.hpp" +#include "Tacho_Gemm_OnDevice.hpp" +#include "Tacho_Gemv_OnDevice.hpp" +#include "Tacho_Herk_OnDevice.hpp" #include "Tacho_LDL_OnDevice.hpp" - -#include "Tacho_Trsm.hpp" +#include "Tacho_LU_OnDevice.hpp" #include "Tacho_Trsm_OnDevice.hpp" - -#include "Tacho_Herk.hpp" -#include "Tacho_Herk_OnDevice.hpp" - -#include "Tacho_Trsv.hpp" #include "Tacho_Trsv_OnDevice.hpp" -#include "Tacho_Gemv.hpp" -#include "Tacho_Gemv_OnDevice.hpp" - -#include "Tacho_Gemm.hpp" -#include "Tacho_Gemm_OnDevice.hpp" - -#include "Tacho_GemmTriangular.hpp" -#include "Tacho_GemmTriangular_OnDevice.hpp" - #include "Tacho_SupernodeInfo.hpp" #include "Tacho_TeamFunctor_FactorizeChol.hpp" @@ -58,2047 +47,3327 @@ #include "Tacho_TeamFunctor_SolveLowerLDL.hpp" #include "Tacho_TeamFunctor_SolveUpperLDL.hpp" +#include "Tacho_TeamFunctor_FactorizeLU.hpp" +#include "Tacho_TeamFunctor_SolveLowerLU.hpp" +#include "Tacho_TeamFunctor_SolveUpperLU.hpp" + //#define TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD //#define TACHO_ENABLE_LEVELSET_TOOLS_USE_LIGHT_KERNEL namespace Tacho { - template - class NumericToolsLevelSet : public NumericToolsBase { - public: - enum { variant = Var, - max_factor_team_size = 64 }; - - /// - /// types - /// - using base_type = NumericToolsBase; - using typename base_type::value_type; - using typename base_type::device_type; - using typename base_type::exec_space; - using typename base_type::exec_memory_space; - using typename base_type::host_space; - using typename base_type::host_memory_space; - using typename base_type::range_type; - using typename base_type::supernode_info_type; - using typename base_type::ordinal_type_array; - using typename base_type::size_type_array; - using typename base_type::value_type_array; - using typename base_type::value_type_matrix; - using typename base_type::ordinal_type_array_host; - using typename base_type::size_type_array_host; - using typename base_type::supernode_type_array_host; - - using base_type::base_type; - - private: - using base_type::_m; - using base_type::_ap; - using base_type::_aj; - using base_type::_ax; - using base_type::_perm; - using base_type::_peri; - using base_type::_stree_roots; - using base_type::_superpanel_buf; - using base_type::_piv; - using base_type::_diag; - using base_type::_info; - using base_type::_nsupernodes; - using base_type::_stree_level; - - using base_type::stat; - using base_type::track_alloc; - using base_type::track_free; - using base_type::reset_stat; - using base_type::print_stat_memory; - - // supernode host information for level kernels launching - supernode_type_array_host _h_supernodes; - - /// - /// level set infrastructure and tuning parameters - /// - - // 0: device level function, 1: team policy, 2: team policy recursive - ordinal_type _device_factorize_thres, _device_solve_thres; - ordinal_type _device_level_cut, _team_serial_level_cut; - - ordinal_type_array_host _h_factorize_mode, _h_solve_mode; - ordinal_type_array _factorize_mode, _solve_mode; - - // level details on host - ordinal_type _nlevel; - size_type_array_host _h_level_ptr; - ordinal_type_array_host _h_level_sids; - - // level sids on device - ordinal_type_array _level_sids; - - // buf level pointer - ordinal_type_array_host _h_buf_level_ptr; - - // workspace metadata for factorization; - size_type_array_host _h_buf_factor_ptr; - size_type_array _buf_factor_ptr; - - // workspace meta data for solve - size_type_array_host _h_buf_solve_ptr, _h_buf_solve_nrhs_ptr; - size_type_array _buf_solve_ptr, _buf_solve_nrhs_ptr; - - // workspace - size_type _bufsize_factorize, _bufsize_solve; - value_type_array _buf; - - // common for host and cuda - int _status; - - // cuda stream - int _nstreams; +template +class NumericToolsLevelSet : public NumericToolsBase { +public: + enum { variant = Var, max_factor_team_size = 64 }; + + /// + /// types + /// + using base_type = NumericToolsBase; + using typename base_type::device_type; + using typename base_type::exec_memory_space; + using typename base_type::exec_space; + using typename base_type::host_memory_space; + using typename base_type::host_space; + using typename base_type::ordinal_type_array; + using typename base_type::ordinal_type_array_host; + using typename base_type::range_type; + using typename base_type::size_type_array; + using typename base_type::size_type_array_host; + using typename base_type::supernode_info_type; + using typename base_type::supernode_type_array_host; + using typename base_type::value_type; + using typename base_type::value_type_array; + using typename base_type::value_type_matrix; + + using base_type::base_type; + +private: + using base_type::_aj; + using base_type::_ap; + using base_type::_ax; + using base_type::_diag; + using base_type::_info; + using base_type::_m; + using base_type::_nsupernodes; + using base_type::_peri; + using base_type::_perm; + using base_type::_piv; + using base_type::_stree_level; + using base_type::_stree_roots; + using base_type::_superpanel_buf; + + using base_type::print_stat_memory; + using base_type::reset_stat; + using base_type::stat; + using base_type::track_alloc; + using base_type::track_free; + + // supernode host information for level kernels launching + supernode_type_array_host _h_supernodes; + + /// + /// level set infrastructure and tuning parameters + /// + + // 0: device level function, 1: team policy, 2: team policy recursive + ordinal_type _device_factorize_thres, _device_solve_thres; + ordinal_type _device_level_cut, _team_serial_level_cut; + + ordinal_type_array_host _h_factorize_mode, _h_solve_mode; + ordinal_type_array _factorize_mode, _solve_mode; + + // level details on host + ordinal_type _nlevel; + size_type_array_host _h_level_ptr; + ordinal_type_array_host _h_level_sids; + + // level sids on device + ordinal_type_array _level_sids; + + // buf level pointer + ordinal_type_array_host _h_buf_level_ptr; + + // workspace metadata for factorization; + size_type_array_host _h_buf_factor_ptr; + size_type_array _buf_factor_ptr; + + // workspace meta data for solve + size_type_array_host _h_buf_solve_ptr, _h_buf_solve_nrhs_ptr; + size_type_array _buf_solve_ptr, _buf_solve_nrhs_ptr; + + // workspace + size_type _bufsize_factorize, _bufsize_solve; + value_type_array _buf; + + // common for host and cuda + int _status; + + // cuda stream + int _nstreams; #if defined(KOKKOS_ENABLE_CUDA) - bool _is_cublas_created, _is_cusolver_dn_created; - cublasHandle_t _handle_blas; - cusolverDnHandle_t _handle_lapack; - using cuda_stream_array_host = std::vector; - cuda_stream_array_host _cuda_streams; - - using exec_instance_array_host = std::vector; - exec_instance_array_host _exec_instances; -#else - int _handle_blas, _handle_lapack; // dummy handle for convenience + bool _is_cublas_created, _is_cusolver_dn_created; + cublasHandle_t _handle_blas; + cusolverDnHandle_t _handle_lapack; + using cuda_stream_array_host = std::vector; + cuda_stream_array_host _cuda_streams; + + using exec_instance_array_host = std::vector; + exec_instance_array_host _exec_instances; +#else + int _handle_blas, _handle_lapack; // dummy handle for convenience #endif - /// - /// statistics - /// - struct { - int n_device_factorize, n_team_factorize, n_kernel_launching_factorize; - int n_device_solve, n_team_solve, n_kernel_launching_solve; - int n_kernel_launching; - } stat_level; - - /// - /// error check for cuda things - /// - inline - void - checkStatus(const char *func, const char *lib) { - if (_status != 0) { - printf("Error: %s, %s returns non-zero status %d\n", - lib, func, _status); - std::runtime_error("checkStatus failed"); - } + /// + /// statistics + /// + struct { + int n_device_factorize, n_team_factorize, n_kernel_launching_factorize; + int n_device_solve, n_team_solve, n_kernel_launching_solve; + int n_kernel_launching; + } stat_level; + + /// + /// error check for cuda things + /// + inline void checkStatus(const char *func, const char *lib) { + if (_status != 0) { + printf("Error: %s, %s returns non-zero status %d\n", lib, func, _status); + std::runtime_error("checkStatus failed"); } - inline void checkDeviceLapackStatus(const char *func) { + } + inline void checkDeviceLapackStatus(const char *func) { #if defined(KOKKOS_ENABLE_CUDA) - constexpr bool is_host = std::is_same::value; - checkStatus(func, is_host ? "HostLapack" : "CuSolverDn"); + constexpr bool is_host = std::is_same::value; + checkStatus(func, is_host ? "HostLapack" : "CuSolverDn"); #else - checkStatus(func, "HostLapack"); -#endif - } - inline void checkDeviceBlasStatus(const char *func) { + checkStatus(func, "HostLapack"); +#endif + } + inline void checkDeviceBlasStatus(const char *func) { #if defined(KOKKOS_ENABLE_CUDA) - constexpr bool is_host = std::is_same::value; - checkStatus(func, is_host ? "HostBlas" : "CuBlas"); + constexpr bool is_host = std::is_same::value; + checkStatus(func, is_host ? "HostBlas" : "CuBlas"); #else - checkStatus(func, "HostBlas"); -#endif - } - inline void checkDeviceStatus(const char *func) { + checkStatus(func, "HostBlas"); +#endif + } + inline void checkDeviceStatus(const char *func) { #if defined(KOKKOS_ENABLE_CUDA) - constexpr bool is_host = std::is_same::value; - checkStatus(func, is_host ? "Host" : "Cuda"); + constexpr bool is_host = std::is_same::value; + checkStatus(func, is_host ? "Host" : "Cuda"); #else - checkStatus(func, "Host"); -#endif - } - - inline - void - print_stat_init() override { - base_type::print_stat_init(); - const double kilo(1024); - printf(" Time\n"); - printf(" time for initialization: %10.6f s\n", stat.t_init); - printf(" time for compute mode classification: %10.6f s\n", stat.t_mode_classification); - printf(" total time spent: %10.6f s\n", (stat.t_init+stat.t_mode_classification)); - printf("\n"); - printf(" Memory\n"); - printf(" workspace allocated for solve: %10.3f MB\n", stat.m_used/kilo/kilo); - printf(" peak memory used: %10.3f MB\n", stat.m_peak/kilo/kilo); - printf("\n"); - printf(" Compute Mode in Factorize with a Threshold(%d)\n", _device_factorize_thres); - printf(" # of subproblems using device functions: %6d\n", stat_level.n_device_factorize); - printf(" # of subproblems using team functions: %6d\n", stat_level.n_team_factorize); - printf(" total # of subproblems: %6d\n", (stat_level.n_device_factorize+stat_level.n_team_factorize)); - printf("\n"); - printf(" Compute Mode in Solve with a Threshold(%d)\n", _device_solve_thres); - printf(" # of subproblems using device functions: %6d\n", stat_level.n_device_solve); - printf(" # of subproblems using team functions: %6d\n", stat_level.n_team_solve); - printf(" total # of subproblems: %6d\n", (stat_level.n_device_solve+stat_level.n_team_solve)); - printf("\n"); - } - - inline - void - print_stat_factor() override { - base_type::print_stat_factor(); - double flop = 0; - switch (this->getSolutionMethod()) { - case 1: { - for (ordinal_type sid=0;sid<_nsupernodes;++sid) { - auto &s = _h_supernodes(sid); - const ordinal_type m = s.m, n = s.n - s.m; - flop += DenseFlopCount::Chol(m); - if (variant == 1) { - flop += DenseFlopCount::Trsm(true, m, m); - } - else if (variant == 2) { - flop += DenseFlopCount::Trsm(true, m, m); - flop += DenseFlopCount::Trsm(true, m, n); - } - flop += DenseFlopCount::Trsm(true, m, n); - flop += DenseFlopCount::Syrk(m, n); + checkStatus(func, "Host"); +#endif + } + + inline void print_stat_init() override { + base_type::print_stat_init(); + const double kilo(1024); + printf(" Time\n"); + printf(" time for initialization: %10.6f s\n", stat.t_init); + printf(" time for compute mode classification: %10.6f s\n", stat.t_mode_classification); + printf(" total time spent: %10.6f s\n", + (stat.t_init + stat.t_mode_classification)); + printf("\n"); + printf(" Memory\n"); + printf(" workspace allocated for solve: %10.3f MB\n", stat.m_used / kilo / kilo); + printf(" peak memory used: %10.3f MB\n", stat.m_peak / kilo / kilo); + printf("\n"); + printf(" Compute Mode in Factorize with a Threshold(%d)\n", _device_factorize_thres); + printf(" # of subproblems using device functions: %6d\n", stat_level.n_device_factorize); + printf(" # of subproblems using team functions: %6d\n", stat_level.n_team_factorize); + printf(" total # of subproblems: %6d\n", + (stat_level.n_device_factorize + stat_level.n_team_factorize)); + printf("\n"); + printf(" Compute Mode in Solve with a Threshold(%d)\n", _device_solve_thres); + printf(" # of subproblems using device functions: %6d\n", stat_level.n_device_solve); + printf(" # of subproblems using team functions: %6d\n", stat_level.n_team_solve); + printf(" total # of subproblems: %6d\n", + (stat_level.n_device_solve + stat_level.n_team_solve)); + printf("\n"); + } + + inline void print_stat_factor() override { + base_type::print_stat_factor(); + double flop = 0; + switch (this->getSolutionMethod()) { + case 1: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = _h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::Chol(m); + if (variant == 1) { + flop += DenseFlopCount::Trsm(true, m, m); + } else if (variant == 2) { + flop += DenseFlopCount::Trsm(true, m, m); + flop += DenseFlopCount::Trsm(true, m, n); } - break; + flop += DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Syrk(m, n); } - case 2: { - for (ordinal_type sid=0;sid<_nsupernodes;++sid) { - auto &s = _h_supernodes(sid); - const ordinal_type m = s.m, n = s.n - s.m; - flop += DenseFlopCount::LDL(m); - flop += DenseFlopCount::Trsm(true, m, n); - flop += DenseFlopCount::Syrk(m, n); + break; + } + case 2: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = _h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::LDL(m); + if (variant == 1) { + flop += DenseFlopCount::Trsm(true, m, m); + } else if (variant == 2) { + flop += DenseFlopCount::Trsm(true, m, m); + flop += DenseFlopCount::Trsm(true, m, n); } - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(false, - std::logic_error, - "The solution method is not supported"); + flop += DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Syrk(m, n); } + break; + } + case 3: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = _h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::LU(m, m); + if (variant == 1) { + flop += 2 * DenseFlopCount::Trsm(true, m, m); + } else if (variant == 2) { + flop += 2 * DenseFlopCount::Trsm(true, m, m); + flop += 2 * DenseFlopCount::Trsm(true, m, n); + } + flop += 2 * DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Gemm(n, n, m); } - const double kilo(1024); - printf(" FLOPs\n"); - printf(" gflop for numeric factorization: %10.3f GFLOP\n", flop/kilo/kilo/kilo); - printf(" gflop/s for numeric factorization: %10.3f GFLOP/s\n", flop/stat.t_factor/kilo/kilo/kilo); - printf("\n"); - printf(" Kernels\n"); - printf(" # of kernels launching: %6d\n", stat_level.n_kernel_launching); - printf("\n"); - } - - inline - void - print_stat_solve() override { - base_type::print_stat_solve(); - printf(" Kernels\n"); - printf(" # of kernels launching: %6d\n", stat_level.n_kernel_launching); - printf("\n"); - } - - public: + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(false, std::logic_error, "The solution method is not supported"); + } + } + const double kilo(1024); + printf(" FLOPs\n"); + printf(" gflop for numeric factorization: %10.3f GFLOP\n", flop / kilo / kilo / kilo); + printf(" gflop/s for numeric factorization: %10.3f GFLOP/s\n", + flop / stat.t_factor / kilo / kilo / kilo); + printf("\n"); + printf(" Kernels\n"); + printf(" # of kernels launching: %6d\n", stat_level.n_kernel_launching); + printf("\n"); + } + + inline void print_stat_solve() override { + base_type::print_stat_solve(); + printf(" Kernels\n"); + printf(" # of kernels launching: %6d\n", stat_level.n_kernel_launching); + printf("\n"); + } + +public: + /// + /// initialization / release + /// + inline void initialize(const ordinal_type device_level_cut, const ordinal_type device_factorize_thres, + const ordinal_type device_solve_thres, const ordinal_type verbose = 0) { + stat_level.n_device_factorize = 0; + stat_level.n_device_solve = 0; + stat_level.n_team_factorize = 0; + stat_level.n_team_solve = 0; + + Kokkos::Timer timer; + + timer.reset(); + /// - /// initialization / release + /// level data structure /// - inline - void - initialize(const ordinal_type device_level_cut, - const ordinal_type device_factorize_thres, - const ordinal_type device_solve_thres, - const ordinal_type verbose = 0) { - stat_level.n_device_factorize = 0; stat_level.n_device_solve = 0; - stat_level.n_team_factorize= 0; stat_level.n_team_solve = 0; - Kokkos::Timer timer; + // # of supernodes + _nsupernodes = _info.supernodes.extent(0); - timer.reset(); + // local host supernodes info + _h_supernodes = Kokkos::create_mirror_view_and_copy(host_memory_space(), _info.supernodes); - /// - /// level data structure - /// - - // # of supernodes - _nsupernodes = _info.supernodes.extent(0); - - // local host supernodes info - _h_supernodes = Kokkos::create_mirror_view_and_copy(host_memory_space(), _info.supernodes); - - // # of levels - _nlevel = 0; - { - for (ordinal_type sid=0;sid<_nsupernodes;++sid) - _nlevel = max(_stree_level(sid), _nlevel); - ++_nlevel; - } + // # of levels + _nlevel = 0; + { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) + _nlevel = max(_stree_level(sid), _nlevel); + ++_nlevel; + } - // create level ptr - _h_level_ptr = size_type_array_host("h_level_ptr", _nlevel+1); - { - // first count # of supernodes in each level - for (ordinal_type sid=0;sid<_nsupernodes;++sid) - ++_h_level_ptr(_stree_level(sid)+1); + // create level ptr + _h_level_ptr = size_type_array_host("h_level_ptr", _nlevel + 1); + { + // first count # of supernodes in each level + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) + ++_h_level_ptr(_stree_level(sid) + 1); - // scan - for (ordinal_type i=0;i<_nlevel;++i) - _h_level_ptr(i+1) += _h_level_ptr(i); - } + // scan + for (ordinal_type i = 0; i < _nlevel; ++i) + _h_level_ptr(i + 1) += _h_level_ptr(i); + } - // fill sids - _h_level_sids = ordinal_type_array_host(do_not_initialize_tag("h_level_sids"), _nsupernodes); - { - size_type_array_host tmp_level_ptr(do_not_initialize_tag("tmp_level_ptr"), _h_level_ptr.extent(0)); - Kokkos::deep_copy(tmp_level_ptr, _h_level_ptr); - for (ordinal_type sid=0;sid<_nsupernodes;++sid) { - const ordinal_type lvl = _stree_level(sid); - _h_level_sids(tmp_level_ptr(lvl)++) = sid; - } + // fill sids + _h_level_sids = ordinal_type_array_host(do_not_initialize_tag("h_level_sids"), _nsupernodes); + { + size_type_array_host tmp_level_ptr(do_not_initialize_tag("tmp_level_ptr"), _h_level_ptr.extent(0)); + Kokkos::deep_copy(tmp_level_ptr, _h_level_ptr); + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + const ordinal_type lvl = _stree_level(sid); + _h_level_sids(tmp_level_ptr(lvl)++) = sid; } - _level_sids = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_level_sids); - track_alloc(_level_sids.span()*sizeof(ordinal_type)); + } + _level_sids = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_level_sids); + track_alloc(_level_sids.span() * sizeof(ordinal_type)); - /// - /// workspace - /// - _h_buf_level_ptr = ordinal_type_array_host(do_not_initialize_tag("h_buf_factor_level_ptr"), _nlevel+1); - { - _h_buf_level_ptr(0) = 0; - for (ordinal_type i=0;i<_nlevel;++i) { - const ordinal_type pbeg = _h_level_ptr(i), pend = _h_level_ptr(i+1); - _h_buf_level_ptr(i+1) = (pend - pbeg + 1) + _h_buf_level_ptr(i); - } + /// + /// workspace + /// + _h_buf_level_ptr = ordinal_type_array_host(do_not_initialize_tag("h_buf_factor_level_ptr"), _nlevel + 1); + { + _h_buf_level_ptr(0) = 0; + for (ordinal_type i = 0; i < _nlevel; ++i) { + const ordinal_type pbeg = _h_level_ptr(i), pend = _h_level_ptr(i + 1); + _h_buf_level_ptr(i + 1) = (pend - pbeg + 1) + _h_buf_level_ptr(i); } + } - // create workspace for factorization / solve - _bufsize_factorize = 0; - _bufsize_solve = 0; - _h_buf_factor_ptr = size_type_array_host(do_not_initialize_tag("h_buf_factor_ptr"), _h_buf_level_ptr(_nlevel)); - _h_buf_solve_ptr = size_type_array_host(do_not_initialize_tag("h_buf_solve_ptr"), _h_buf_level_ptr(_nlevel)); - { - for (ordinal_type i=0;i<_nlevel;++i) { - const ordinal_type lbeg = _h_buf_level_ptr(i); - const ordinal_type pbeg = _h_level_ptr(i), pend = _h_level_ptr(i+1); - - _h_buf_factor_ptr(lbeg) = 0; - _h_buf_solve_ptr(lbeg) = 0; - for (ordinal_type p=pbeg,k=(lbeg+1);pgetSolutionMethod()-1; - const ordinal_type factor_work_size = factor_work_size_variants[index_work_size]; - const ordinal_type solve_work_size = solve_work_size_variants[index_work_size]; - - _h_buf_factor_ptr(k) = factor_work_size + _h_buf_factor_ptr(k-1); - _h_buf_solve_ptr(k) = solve_work_size + _h_buf_solve_ptr(k-1); - } - const ordinal_type last_idx = lbeg+pend-pbeg; - _bufsize_factorize = max(_bufsize_factorize, _h_buf_factor_ptr(last_idx)); - _bufsize_solve = max(_bufsize_solve, _h_buf_solve_ptr(last_idx)); + // create workspace for factorization / solve + _bufsize_factorize = 0; + _bufsize_solve = 0; + _h_buf_factor_ptr = size_type_array_host(do_not_initialize_tag("h_buf_factor_ptr"), _h_buf_level_ptr(_nlevel)); + _h_buf_solve_ptr = size_type_array_host(do_not_initialize_tag("h_buf_solve_ptr"), _h_buf_level_ptr(_nlevel)); + { + for (ordinal_type i = 0; i < _nlevel; ++i) { + const ordinal_type lbeg = _h_buf_level_ptr(i); + const ordinal_type pbeg = _h_level_ptr(i), pend = _h_level_ptr(i + 1); + + _h_buf_factor_ptr(lbeg) = 0; + _h_buf_solve_ptr(lbeg) = 0; + for (ordinal_type p = pbeg, k = (lbeg + 1); p < pend; ++p, ++k) { + const ordinal_type sid = _h_level_sids(p); + const auto s = _h_supernodes(sid); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + const ordinal_type schur_work_size = n_m * (n_m + max_factor_team_size); + const ordinal_type chol_factor_work_size_variants[3] = {schur_work_size, max(m * m, schur_work_size), + m * m + schur_work_size}; + const ordinal_type chol_factor_work_size = chol_factor_work_size_variants[variant]; + const ordinal_type ldl_factor_work_size_variant_0 = chol_factor_work_size_variants[0] + max(32 * m, m * n); + const ordinal_type ldl_factor_work_size_variants[3] = {ldl_factor_work_size_variant_0, + max(m * m, ldl_factor_work_size_variant_0 + m * n_m), + m * m + ldl_factor_work_size_variant_0 + m * n_m}; + const ordinal_type ldl_factor_work_size = ldl_factor_work_size_variants[variant]; + const ordinal_type lu_factor_work_size_variants[3] = {schur_work_size, max(m * m, schur_work_size), + m * m + schur_work_size}; + const ordinal_type lu_factor_work_size = lu_factor_work_size_variants[variant]; + const ordinal_type factor_work_size_variants[3] = {chol_factor_work_size, ldl_factor_work_size, + lu_factor_work_size}; + + const ordinal_type chol_solve_work_size = (variant == 0 ? n_m : n); + const ordinal_type ldl_solve_work_size = chol_solve_work_size; + const ordinal_type lu_solve_work_size = chol_solve_work_size; + const ordinal_type solve_work_size_variants[3] = {chol_solve_work_size, ldl_solve_work_size, + lu_solve_work_size}; + + const ordinal_type index_work_size = this->getSolutionMethod() - 1; + const ordinal_type factor_work_size = factor_work_size_variants[index_work_size]; + const ordinal_type solve_work_size = solve_work_size_variants[index_work_size]; + + _h_buf_factor_ptr(k) = factor_work_size + _h_buf_factor_ptr(k - 1); + _h_buf_solve_ptr(k) = solve_work_size + _h_buf_solve_ptr(k - 1); } + const ordinal_type last_idx = lbeg + pend - pbeg; + _bufsize_factorize = max(_bufsize_factorize, _h_buf_factor_ptr(last_idx)); + _bufsize_solve = max(_bufsize_solve, _h_buf_solve_ptr(last_idx)); } + } - _buf_factor_ptr = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_buf_factor_ptr); - track_alloc(_buf_factor_ptr.span()*sizeof(size_type)); + _buf_factor_ptr = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_buf_factor_ptr); + track_alloc(_buf_factor_ptr.span() * sizeof(size_type)); - _buf_solve_ptr = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_buf_solve_ptr); - track_alloc(_buf_solve_ptr.span()*sizeof(size_type)); + _buf_solve_ptr = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_buf_solve_ptr); + track_alloc(_buf_solve_ptr.span() * sizeof(size_type)); - _h_buf_solve_nrhs_ptr = size_type_array_host(do_not_initialize_tag("h_buf_solve_nrhs_ptr"), _h_buf_solve_ptr.extent(0)); - _buf_solve_nrhs_ptr = Kokkos::create_mirror_view(exec_memory_space(), _h_buf_solve_nrhs_ptr); - track_alloc(_buf_solve_nrhs_ptr.span()*sizeof(size_type)); + _h_buf_solve_nrhs_ptr = + size_type_array_host(do_not_initialize_tag("h_buf_solve_nrhs_ptr"), _h_buf_solve_ptr.extent(0)); + _buf_solve_nrhs_ptr = Kokkos::create_mirror_view(exec_memory_space(), _h_buf_solve_nrhs_ptr); + track_alloc(_buf_solve_nrhs_ptr.span() * sizeof(size_type)); - /// - /// cuda library initialize - /// + /// + /// cuda library initialize + /// #if defined(KOKKOS_ENABLE_CUDA) - if (!_is_cublas_created) { - _status = cublasCreate(&_handle_blas); checkDeviceBlasStatus("cublasCreate"); _is_cublas_created = true; - } - if (!_is_cusolver_dn_created) { - _status = cusolverDnCreate(&_handle_lapack); checkDeviceLapackStatus("cusolverDnCreate"); _is_cusolver_dn_created = true; - } + if (!_is_cublas_created) { + _status = cublasCreate(&_handle_blas); + checkDeviceBlasStatus("cublasCreate"); + _is_cublas_created = true; + } + if (!_is_cusolver_dn_created) { + _status = cusolverDnCreate(&_handle_lapack); + checkDeviceLapackStatus("cusolverDnCreate"); + _is_cusolver_dn_created = true; + } #endif - stat.t_init = timer.seconds(); - - /// - /// classification of problems - /// - timer.reset(); - - _device_level_cut = min(device_level_cut, _nlevel); - _device_factorize_thres = device_factorize_thres; - _device_solve_thres = device_solve_thres; + stat.t_init = timer.seconds(); - _h_factorize_mode = ordinal_type_array_host(do_not_initialize_tag("h_factorize_mode"), _nsupernodes); - Kokkos::deep_copy(_h_factorize_mode, -1); - - _h_solve_mode = ordinal_type_array_host(do_not_initialize_tag("h_solve_mode"), _nsupernodes); - Kokkos::deep_copy(_h_solve_mode, -1); + /// + /// classification of problems + /// + timer.reset(); + + _device_level_cut = min(device_level_cut, _nlevel); + _device_factorize_thres = device_factorize_thres; + _device_solve_thres = device_solve_thres; + + _h_factorize_mode = ordinal_type_array_host(do_not_initialize_tag("h_factorize_mode"), _nsupernodes); + Kokkos::deep_copy(_h_factorize_mode, -1); + + _h_solve_mode = ordinal_type_array_host(do_not_initialize_tag("h_solve_mode"), _nsupernodes); + Kokkos::deep_copy(_h_solve_mode, -1); + + if (_device_level_cut > 0) { + for (ordinal_type lvl = 0; lvl < _device_level_cut; ++lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1); + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + _h_solve_mode(sid) = 0; + _h_factorize_mode(sid) = 0; + ++stat_level.n_device_solve; + ++stat_level.n_device_factorize; + } + } + } - if (_device_level_cut > 0) { - for (ordinal_type lvl=0;lvl<_device_level_cut;++lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1); - for (ordinal_type p=pbeg;p _device_solve_thres) { // || n > _device_solve_thres) { _h_solve_mode(sid) = 0; - _h_factorize_mode(sid) = 0; ++stat_level.n_device_solve; + } else { + _h_solve_mode(sid) = 1; + ++stat_level.n_team_solve; + } + if (m > _device_factorize_thres) { // || n_m > _device_factorize_thres) { + _h_factorize_mode(sid) = 0; ++stat_level.n_device_factorize; + } else { + _h_factorize_mode(sid) = 1; + ++stat_level.n_team_factorize; } } } + } - _team_serial_level_cut = _nlevel; - { - for (ordinal_type lvl=_device_level_cut;lvl<_team_serial_level_cut;++lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1); - for (ordinal_type p=pbeg;p _device_solve_thres) {// || n > _device_solve_thres) { - _h_solve_mode(sid) = 0; - ++stat_level.n_device_solve; - } else { - _h_solve_mode(sid) = 1; - ++stat_level.n_team_solve; - } - if (m > _device_factorize_thres) {// || n_m > _device_factorize_thres) { - _h_factorize_mode(sid) = 0; - ++stat_level.n_device_factorize; - } else { - _h_factorize_mode(sid) = 1; - ++stat_level.n_team_factorize; - } - } - } - } + _factorize_mode = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_factorize_mode); + track_alloc(_factorize_mode.span() * sizeof(ordinal_type)); - _factorize_mode = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_factorize_mode); - track_alloc(_factorize_mode.span()*sizeof(ordinal_type)); - - _solve_mode = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_solve_mode); - track_alloc(_solve_mode.span()*sizeof(ordinal_type)); - - stat.t_mode_classification = timer.seconds(); - if (verbose) { - switch (this->getSolutionMethod()) { - case 1: { - printf("Summary: LevelSetTools-Variant-%d (InitializeCholesky)\n", variant); - printf("======================================================\n"); - break; - } - case 2: { - printf("Summary: LevelSetTools (InitializeLDL)\n"); - printf("======================================================\n"); - break; - } - } - print_stat_init(); - } - } + _solve_mode = Kokkos::create_mirror_view_and_copy(exec_memory_space(), _h_solve_mode); + track_alloc(_solve_mode.span() * sizeof(ordinal_type)); - inline - void - release(const ordinal_type verbose = 0) override { - base_type::release(false); - track_free(_buf_factor_ptr.span()*sizeof(size_type)); - track_free(_buf_solve_ptr.span()*sizeof(size_type)); - track_free(_buf_solve_nrhs_ptr.span()*sizeof(size_type)); - track_free(_buf.span()*sizeof(value_type)); - track_free(_factorize_mode.span()*sizeof(ordinal_type)); - track_free(_solve_mode.span()*sizeof(ordinal_type)); - track_free(_level_sids.span()*sizeof(ordinal_type)); - if (verbose) { - printf("Summary: LevelSetTools-Variant-%d (Release)\n", variant); - printf("============================================\n"); - print_stat_memory(); + stat.t_mode_classification = timer.seconds(); + if (verbose) { + switch (this->getSolutionMethod()) { + case 1: { + printf("Summary: LevelSetTools-Variant-%d (InitializeCholesky)\n", variant); + printf("======================================================\n"); + break; } - } - - NumericToolsLevelSet() : base_type() { - _nlevel = 0; - _bufsize_factorize = 0; - _bufsize_solve = 0; - _nstreams = 0; - stat_level = stat_level(); - } - NumericToolsLevelSet(const NumericToolsLevelSet &b) = default; - - NumericToolsLevelSet(const ordinal_type method, - // input matrix A - const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - // input permutation - const ordinal_type_array &perm, - const ordinal_type_array &peri, - // supernodes - const ordinal_type nsupernodes, - const ordinal_type_array &supernodes, - const size_type_array &gid_ptr, - const ordinal_type_array &gid_colidx, - const size_type_array &sid_ptr, - const ordinal_type_array &sid_colidx, - const ordinal_type_array &blk_colidx, - const ordinal_type_array &stree_parent, - const size_type_array &stree_ptr, - const ordinal_type_array &stree_children, - const ordinal_type_array_host &stree_level, - const ordinal_type_array_host &stree_roots) - : base_type(method, - m, ap, aj, - perm,peri, - nsupernodes, supernodes, - gid_ptr, gid_colidx, - sid_ptr, sid_colidx, blk_colidx, - stree_parent, stree_ptr, stree_children, - stree_level, stree_roots) { - _nstreams = 0; -#if defined(KOKKOS_ENABLE_CUDA) - _is_cublas_created = 0; - _is_cusolver_dn_created = 0; -#endif - } - - virtual~NumericToolsLevelSet() { -#if defined(KOKKOS_ENABLE_CUDA) - // destroy previously created streams - for (ordinal_type i=0;i<_nstreams;++i) { - _status = cudaStreamDestroy(_cuda_streams[i]); checkDeviceStatus("cudaStreamDestroy"); + case 2: { + printf("Summary: LevelSetTools-Variant-%d (InitializeLDL)\n", variant); + printf("=================================================\n"); + break; } - _cuda_streams.clear(); - _exec_instances.clear(); - - if (_is_cublas_created) { - _status = cusolverDnDestroy(_handle_lapack); checkDeviceLapackStatus("cusolverDnDestroy"); + case 3: { + printf("Summary: LevelSetTools-Variant-%d (InitializeLU)\n", variant); + printf("================================================\n"); + break; } - if (_is_cusolver_dn_created) { - _status = cublasDestroy(_handle_blas); checkDeviceBlasStatus("cublasDestroy"); } -#endif + print_stat_init(); + } + } + + inline void release(const ordinal_type verbose = 0) override { + base_type::release(false); + track_free(_buf_factor_ptr.span() * sizeof(size_type)); + track_free(_buf_solve_ptr.span() * sizeof(size_type)); + track_free(_buf_solve_nrhs_ptr.span() * sizeof(size_type)); + track_free(_buf.span() * sizeof(value_type)); + track_free(_factorize_mode.span() * sizeof(ordinal_type)); + track_free(_solve_mode.span() * sizeof(ordinal_type)); + track_free(_level_sids.span() * sizeof(ordinal_type)); + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (Release)\n", variant); + printf("===========================================\n"); + print_stat_memory(); } + } + + NumericToolsLevelSet() : base_type() { + _nlevel = 0; + _bufsize_factorize = 0; + _bufsize_solve = 0; + _nstreams = 0; + stat_level = stat_level(); + } + NumericToolsLevelSet(const NumericToolsLevelSet &b) = default; + + NumericToolsLevelSet(const ordinal_type method, + // input matrix A + const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + // input permutation + const ordinal_type_array &perm, const ordinal_type_array &peri, + // supernodes + const ordinal_type nsupernodes, const ordinal_type_array &supernodes, + const size_type_array &gid_ptr, const ordinal_type_array &gid_colidx, + const size_type_array &sid_ptr, const ordinal_type_array &sid_colidx, + const ordinal_type_array &blk_colidx, const ordinal_type_array &stree_parent, + const size_type_array &stree_ptr, const ordinal_type_array &stree_children, + const ordinal_type_array_host &stree_level, const ordinal_type_array_host &stree_roots) + : base_type(method, m, ap, aj, perm, peri, nsupernodes, supernodes, gid_ptr, gid_colidx, sid_ptr, sid_colidx, + blk_colidx, stree_parent, stree_ptr, stree_children, stree_level, stree_roots) { + _nstreams = 0; +#if defined(KOKKOS_ENABLE_CUDA) + _is_cublas_created = 0; + _is_cusolver_dn_created = 0; +#endif + } - inline - void - createStream(const ordinal_type nstreams, const ordinal_type verbose = 0) { + virtual ~NumericToolsLevelSet() { #if defined(KOKKOS_ENABLE_CUDA) - // destroy previously created streams - for (ordinal_type i=0;i<_nstreams;++i) { - _status = cudaStreamDestroy(_cuda_streams[i]); checkDeviceStatus("cudaStreamDestroy"); - } - // new streams - _nstreams = nstreams; - //_cuda_streams = cuda_stream_array_host(do_not_initialize_tag("cuda streams"), _nstreams); - _cuda_streams.clear(); - _cuda_streams.resize(_nstreams); - for (ordinal_type i=0;i<_nstreams;++i) { - _status = cudaStreamCreateWithFlags(&_cuda_streams[i], cudaStreamNonBlocking); checkDeviceStatus("cudaStreamCreate"); - } + // destroy previously created streams + for (ordinal_type i = 0; i < _nstreams; ++i) { + _status = cudaStreamDestroy(_cuda_streams[i]); + checkDeviceStatus("cudaStreamDestroy"); + } + _cuda_streams.clear(); + _exec_instances.clear(); - _exec_instances.clear(); - _exec_instances.resize(_nstreams); - for (ordinal_type i=0;i<_nstreams;++i) { - ExecSpaceFactory::createInstance(_cuda_streams[i], _exec_instances[i]); - } - if (verbose) { - printf("Summary: CreateStream : %3d\n", _nstreams); - printf("===========================\n"); - } + if (_is_cublas_created) { + _status = cusolverDnDestroy(_handle_lapack); + checkDeviceLapackStatus("cusolverDnDestroy"); + } + if (_is_cusolver_dn_created) { + _status = cublasDestroy(_handle_blas); + checkDeviceBlasStatus("cublasDestroy"); + } #endif + } + + inline void createStream(const ordinal_type nstreams, const ordinal_type verbose = 0) { +#if defined(KOKKOS_ENABLE_CUDA) + // destroy previously created streams + for (ordinal_type i = 0; i < _nstreams; ++i) { + _status = cudaStreamDestroy(_cuda_streams[i]); + checkDeviceStatus("cudaStreamDestroy"); + } + // new streams + _nstreams = nstreams; + //_cuda_streams = cuda_stream_array_host(do_not_initialize_tag("cuda streams"), _nstreams); + _cuda_streams.clear(); + _cuda_streams.resize(_nstreams); + for (ordinal_type i = 0; i < _nstreams; ++i) { + _status = cudaStreamCreateWithFlags(&_cuda_streams[i], cudaStreamNonBlocking); + checkDeviceStatus("cudaStreamCreate"); } - /// - /// Device level functions - /// - inline - void - factorizeCholeskyOnDeviceVar0(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_factor_ptr, - const value_type_array &work) { - const value_type one(1), minus_one(-1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p::createInstance(_cuda_streams[i], _exec_instances[i]); + } + if (verbose) { + printf("Summary: CreateStream : %3d\n", _nstreams); + printf("===========================\n"); + } +#endif + } + + /// + /// Device level functions + /// + inline void factorizeCholeskyOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, + const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_factorize_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; #else - value_type_array W = work; -#endif - const auto &s = _h_supernodes(sid); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - _status = Chol - ::invoke(_handle_lapack, ATL, W); checkDeviceLapackStatus("chol"); - - if (n_m > 0) { - exec_instance.fence(); - UnmanagedViewType ABR(_buf.data()+h_buf_factor_ptr(p-pbeg), n_m, n_m); - UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), one, ATL, ATR); checkDeviceBlasStatus("trsm"); - exec_instance.fence(); - _status = Herk - ::invoke(_handle_blas, minus_one, ATR, zero, ABR); - exec_instance.fence(); - } + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + _status = Chol::invoke(_handle_lapack, ATL, W); + checkDeviceLapackStatus("chol"); + + if (n_m > 0) { + UnmanagedViewType ABR(_buf.data() + h_buf_factor_ptr(p - pbeg), n_m, n_m); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ATR); + checkDeviceBlasStatus("trsm"); + + _status = Herk::invoke(_handle_blas, minus_one, ATR, + zero, ABR); } } } } } + } - inline - void - factorizeCholeskyOnDeviceVar1(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_factor_ptr, - const value_type_array &work) { - const value_type one(1), minus_one(-1), zero(0); + inline void factorizeCholeskyOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, + const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); #if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); + ordinal_type q(0); #endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - _status = Chol - ::invoke(_handle_lapack, ATL, W); checkDeviceLapackStatus("chol"); - - value_type *bptr = _buf.data()+h_buf_factor_ptr(p-pbeg); - UnmanagedViewType T(bptr, m, m); - _status = SetIdentity::invoke(exec_instance, T, one); checkDeviceBlasStatus("SetIdentity"); - exec_instance.fence(); - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), one, ATL, T); checkDeviceBlasStatus("trsm"); - - if (n_m > 0) { - exec_instance.fence(); - UnmanagedViewType ABR(bptr, n_m, n_m); - UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), one, ATL, ATR); checkDeviceBlasStatus("trsm"); - exec_instance.fence(); - _status = Copy::invoke(exec_instance, ATL, T); checkDeviceBlasStatus("Copy"); - exec_instance.fence(); - _status = Herk - ::invoke(_handle_blas, minus_one, ATR, zero, ABR); - exec_instance.fence(); - } else { - exec_instance.fence(); - _status = Copy::invoke(exec_instance, ATL, T); checkDeviceBlasStatus("Copy"); - } + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + _status = Chol::invoke(_handle_lapack, ATL, W); + checkDeviceLapackStatus("chol"); + + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); + UnmanagedViewType T(bptr, m, m); + _status = SetIdentity::invoke(exec_instance, T, one); + checkDeviceBlasStatus("SetIdentity"); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, T); + checkDeviceBlasStatus("trsm"); + + if (n_m > 0) { + UnmanagedViewType ABR(bptr, n_m, n_m); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ATR); + checkDeviceBlasStatus("trsm"); + + _status = Copy::invoke(exec_instance, ATL, T); + checkDeviceBlasStatus("Copy"); + + _status = Herk::invoke(_handle_blas, minus_one, ATR, + zero, ABR); + } else { + _status = Copy::invoke(exec_instance, ATL, T); + checkDeviceBlasStatus("Copy"); } } } } } + } - inline - void - factorizeCholeskyOnDeviceVar2(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_factor_ptr, - const value_type_array &work) { - const value_type one(1), minus_one(-1), zero(0); + inline void factorizeCholeskyOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, + const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); #if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); + ordinal_type q(0); #endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + _status = Chol::invoke(_handle_lapack, ATL, W); + checkDeviceLapackStatus("chol"); + + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); + if (n_m > 0) { + UnmanagedViewType ABR(bptr, n_m, n_m); + bptr += ABR.span(); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ATR); + checkDeviceBlasStatus("trsm"); + + _status = Herk::invoke(_handle_blas, minus_one, ATR, + zero, ABR); + + /// additional things + UnmanagedViewType T(bptr, m, m); + _status = Copy::invoke(exec_instance, T, ATL); + checkDeviceBlasStatus("Copy"); - exec_instance = _exec_instances[qid]; + _status = SetIdentity::invoke(exec_instance, ATL, minus_one); + checkDeviceBlasStatus("SetIdentity"); - const size_type worksize = work.extent(0)/_nstreams; - value_type_array W(work.data() + worksize*qid, worksize); - ++q; -#else - value_type_array W = work; -#endif - const auto &s = _h_supernodes(sid); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - _status = Chol - ::invoke(_handle_lapack, ATL, W); checkDeviceLapackStatus("chol"); - - value_type *bptr = _buf.data()+h_buf_factor_ptr(p-pbeg); - if (n_m > 0) { - exec_instance.fence(); - UnmanagedViewType ABR(bptr, n_m, n_m); bptr += ABR.span(); - UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; - - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), one, ATL, ATR); checkDeviceBlasStatus("trsm"); - exec_instance.fence(); - _status = Herk - ::invoke(_handle_blas, minus_one, ATR, zero, ABR); - exec_instance.fence(); - - /// additional things - UnmanagedViewType T(bptr, m, m); - _status = Copy::invoke(exec_instance, T, ATL); checkDeviceBlasStatus("Copy"); - exec_instance.fence(); - _status = SetIdentity::invoke(exec_instance, ATL, minus_one); checkDeviceBlasStatus("SetIdentity"); - exec_instance.fence(); - - UnmanagedViewType AT(ATL.data(), m, n); - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), minus_one, T, AT); checkDeviceBlasStatus("trsm"); - exec_instance.fence(); - } else { - exec_instance.fence(); - /// additional things - UnmanagedViewType T(bptr, m, m); - _status = Copy::invoke(exec_instance, T, ATL); checkDeviceBlasStatus("Copy"); - exec_instance.fence(); - _status = SetIdentity::invoke(exec_instance, ATL, one); checkDeviceBlasStatus("SetIdentity"); - exec_instance.fence(); - _status = Trsm - ::invoke(_handle_blas, Diag::NonUnit(), one, T, ATL); checkDeviceBlasStatus("trsm"); - } + UnmanagedViewType AT(ATL.data(), m, n); + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), minus_one, T, AT); + checkDeviceBlasStatus("trsm"); + } else { + /// additional things + UnmanagedViewType T(bptr, m, m); + _status = Copy::invoke(exec_instance, T, ATL); + checkDeviceBlasStatus("Copy"); + + _status = SetIdentity::invoke(exec_instance, ATL, one); + checkDeviceBlasStatus("SetIdentity"); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, T, ATL); + checkDeviceBlasStatus("trsm"); } } } } } - - inline - void - factorizeCholeskyOnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_factor_ptr, - const value_type_array &work) { - if (variant == 0) - factorizeCholeskyOnDeviceVar0(pbeg, pend, h_buf_factor_ptr, work); - else if (variant == 1) - factorizeCholeskyOnDeviceVar1(pbeg, pend, h_buf_factor_ptr, work); - else if (variant == 2) - factorizeCholeskyOnDeviceVar2(pbeg, pend, h_buf_factor_ptr, work); - else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, - "LevelSetTools::factorizeCholeskyOnDevice, algorithm variant is not supported"); - } + } + + inline void factorizeCholeskyOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + if (variant == 0) + factorizeCholeskyOnDeviceVar0(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 1) + factorizeCholeskyOnDeviceVar1(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 2) + factorizeCholeskyOnDeviceVar2(pbeg, pend, h_buf_factor_ptr, work); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::factorizeCholeskyOnDevice, algorithm variant is not supported"); } + } - inline - void - factorizeLDL_OnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_factor_ptr, - const value_type_array &work) { - const value_type one(1), minus_one(-1), zero(0); + inline void factorizeLDL_OnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); #if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); + ordinal_type q(0); #endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; - exec_instance = _exec_instances[qid]; + _status = Symmetrize::invoke(exec_instance, ATL); - const size_type worksize = work.extent(0)/_nstreams; - value_type_array W(work.data() + worksize*qid, worksize); - ++q; -#else - value_type_array W = work; -#endif - const auto &s = _h_supernodes(sid); - { - const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - - _status = Symmetrize::invoke(exec_instance, ATL); - exec_instance.fence(); - - ordinal_type *pivptr = _piv.data() + 4*offs; - UnmanagedViewType P(pivptr, 4*m); - _status = LDL::invoke(_handle_lapack, ATL, P, W); checkDeviceLapackStatus("ldl::invoke"); - exec_instance.fence(); - - value_type * dptr = _diag.data() + 2*offs; - UnmanagedViewType D(dptr, m, 2); - _status = LDL::modify(exec_instance, ATL, P, D); checkDeviceLapackStatus("ldl::modify"); - exec_instance.fence(); - - if (n_m > 0) { - UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; - UnmanagedViewType ABR(_buf.data()+h_buf_factor_ptr(p-pbeg), n_m, n_m); - UnmanagedViewType STR(ABR.data()+ABR.span(), m, n_m); - - auto fpiv = ordinal_type_array(P.data()+m, m); - _status = ApplyPivots - ::invoke(exec_instance, fpiv, ATR); - exec_instance.fence(); - - _status = Trsm - ::invoke(_handle_blas, Diag::Unit(), one, ATL, ATR); checkDeviceBlasStatus("trsm"); - exec_instance.fence(); - - _status = Copy - ::invoke(exec_instance, STR, ATR); - exec_instance.fence(); - - _status = Scale2x2_BlockInverseDiagonals - ::invoke(exec_instance, P, D, ATR); - exec_instance.fence(); - - _status = GemmTriangular - ::invoke(_handle_blas, minus_one, ATR, STR, zero, ABR); - exec_instance.fence(); checkDeviceBlasStatus("gemm"); - } - } - } - } - } - } + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LDL::invoke(_handle_lapack, ATL, P, W); + checkDeviceLapackStatus("ldl::invoke"); + value_type *dptr = _diag.data() + 2 * offs; + UnmanagedViewType D(dptr, m, 2); + _status = LDL::modify(exec_instance, ATL, P, D); + checkDeviceLapackStatus("ldl::modify"); - /// - /// Level set factorize - /// - inline - void - factorizeCholesky(const value_type_array &ax, - const ordinal_type verbose) { - constexpr bool is_host = std::is_same::value; - Kokkos::Timer timer; + if (n_m > 0) { + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + UnmanagedViewType ABR(_buf.data() + h_buf_factor_ptr(p - pbeg), n_m, n_m); + UnmanagedViewType STR(ABR.data() + ABR.span(), m, n_m); - timer.reset(); - value_type_array work; - { - _buf = value_type_array(do_not_initialize_tag("buf"), _bufsize_factorize); - track_alloc(_buf.span()*sizeof(value_type)); - -#if defined (KOKKOS_ENABLE_CUDA) - value_type_matrix T(NULL, _info.max_supernode_size, _info.max_supernode_size); - const size_type worksize = Chol - ::invoke(_handle_lapack, T, work); - - work = value_type_array(do_not_initialize_tag("work"), worksize*(_nstreams+1)); - track_alloc(work.span()*sizeof(value_type)); -#endif - } - stat.t_extra = timer.seconds(); + auto fpiv = ordinal_type_array(P.data() + m, m); + _status = ApplyPivots::invoke( + exec_instance, fpiv, ATR); - timer.reset(); - { - _ax = ax; // matrix values - _info.copySparseToSuperpanels(_ap, _aj, _ax, _perm, _peri); - } - stat.t_copy = timer.seconds(); - - stat_level.n_kernel_launching = 0; - timer.reset(); - { - // this should be considered with average problem sizes in levels - const ordinal_type half_level = _nlevel/2; - //const ordinal_type team_size_factor[2] = { 64, 16 }, vector_size_factor[2] = { 8, 8}; - //const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; - const ordinal_type team_size_factor[2] = { 64, 64 }, vector_size_factor[2] = { 8, 4}; - const ordinal_type team_size_update[2] = { 16, 8 }, vector_size_update[2] = { 32, 32}; - { - typedef TeamFunctor_FactorizeChol functor_type; -#if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_factorize; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; -#else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::template FactorizeTag > team_policy_factor; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::UpdateTag> team_policy_update; -#endif - - functor_type functor(_info, - _factorize_mode, - _level_sids, - _buf); - - team_policy_factor policy_factor(1,1,1); - team_policy_update policy_update(1,1,1); - - { - for (ordinal_type lvl=(_team_serial_level_cut-1);lvl>=0;--lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_buf_factor_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - - const auto buf_factor_ptr = Kokkos::subview(_buf_factor_ptr, range_buf_factor_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(buf_factor_ptr); - if (is_host) { - policy_factor = team_policy_factor(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_factor = team_policy_factor(pcnt, team_size_factor[idx], vector_size_factor[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("factor lower", policy_factor, functor); - } else { - Kokkos::parallel_for("factor", policy_factor, functor); - ++stat_level.n_kernel_launching; - } - - const auto h_buf_factor_ptr = Kokkos::subview(_h_buf_factor_ptr, range_buf_factor_ptr); - factorizeCholeskyOnDevice(pbeg, pend, h_buf_factor_ptr, work); - Kokkos::fence(); - - Kokkos::parallel_for("update factor", policy_update, functor); - ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); - } - } - } - } // end of Cholesky - stat.t_factor = timer.seconds(); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, ATL, ATR); + checkDeviceBlasStatus("trsm"); - timer.reset(); - { -#if defined (KOKKOS_ENABLE_CUDA) - track_free(work.span()*sizeof(value_type)); -#endif - track_free(_buf.span()*sizeof(value_type)); - _buf = value_type_array(); - } - stat.t_extra += timer.seconds(); + _status = Copy::invoke(exec_instance, STR, ATR); - if (verbose) { - printf("Summary: LevelSetTools-Variant-%d (CholeskyFactorize)\n", variant); - printf("=====================================================\n"); - print_stat_factor(); - } - } + _status = Scale2x2_BlockInverseDiagonals::invoke(exec_instance, P, D, ATR); - inline - void - solveCholeskyLowerOnDeviceVar0(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type minus_one(-1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf; - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - _status = Trsv - ::invoke(_handle_blas, Diag::NonUnit(), AL, tT); checkDeviceBlasStatus("trsv"); - - if (n_m > 0) { - // solve offdiag - value_type *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg); - UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n_m; - UnmanagedViewType bB(bptr, n_m, nrhs); - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, tT, zero, bB); checkDeviceBlasStatus("gemv"); - } + _status = GemmTriangular::invoke( + _handle_blas, minus_one, ATR, STR, zero, ABR); + checkDeviceBlasStatus("gemm"); } } } } } + } + + inline void factorizeLDL_OnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_factorize_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; +#else + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + _status = Symmetrize::invoke(exec_instance, ATL); + + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LDL::invoke(_handle_lapack, ATL, P, W); + checkDeviceLapackStatus("ldl::invoke"); + + value_type *dptr = _diag.data() + 2 * offs; + UnmanagedViewType D(dptr, m, 2); + _status = LDL::modify(exec_instance, ATL, P, D); + checkDeviceLapackStatus("ldl::modify"); + + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); + UnmanagedViewType T(bptr, m, m); + + _status = SetIdentity::invoke(exec_instance, T, one); + checkDeviceBlasStatus("SetIdentity"); + + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, ATL, T); + checkDeviceBlasStatus("trsm"); + + if (n_m > 0) { + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + UnmanagedViewType ABR(bptr, n_m, n_m); + + const ordinal_type used_span = max(ABR.span(), T.span()); + UnmanagedViewType STR(ABR.data() + used_span, m, n_m); + + ConstUnmanagedViewType perm(P.data() + 2 * m, m); + _status = ApplyPermutation::invoke(exec_instance, ATR, + perm, STR); + + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, ATL, STR); + checkDeviceBlasStatus("trsm"); + + _status = Copy::invoke(exec_instance, ATL, T); + _status = Copy::invoke(exec_instance, ATR, STR); + + _status = Scale2x2_BlockInverseDiagonals::invoke(exec_instance, P, D, ATR); - inline - void - solveCholeskyLowerOnDeviceVar1(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type one(1), minus_one(-1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf; - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - - value_type *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg); - UnmanagedViewType bT(bptr, m, nrhs); bptr += m*nrhs; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - - _status = Gemv - ::invoke(_handle_blas, one, AL, tT, zero, bT); checkDeviceBlasStatus("gemv"); - - if (n_m > 0) { - // solve offdiag - UnmanagedViewType AR(aptr, m, n_m); - UnmanagedViewType bB(bptr, n_m, nrhs); - - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, bT, zero, bB); checkDeviceBlasStatus("gemv"); - } + _status = GemmTriangular::invoke( + _handle_blas, minus_one, ATR, STR, zero, ABR); + checkDeviceBlasStatus("gemm"); + } else { + _status = Copy::invoke(exec_instance, ATL, T); } } } } } + } - inline - void - solveCholeskyLowerOnDeviceVar2(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type one(1), zero(0); + inline void factorizeLDL_OnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); #if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); + ordinal_type q(0); #endif - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf; - UnmanagedViewType A(aptr, m, n); + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; +#else + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; - value_type *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg); - UnmanagedViewType b(bptr, n, nrhs); + _status = Symmetrize::invoke(exec_instance, ATL); - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LDL::invoke(_handle_lapack, ATL, P, W); + checkDeviceLapackStatus("ldl::invoke"); - _status = Gemv - ::invoke(_handle_blas, one, A, tT, zero, b); checkDeviceBlasStatus("gemv"); - } - } - } - } - } + value_type *dptr = _diag.data() + 2 * offs; + UnmanagedViewType D(dptr, m, 2); + _status = LDL::modify(exec_instance, ATL, P, D); + checkDeviceLapackStatus("ldl::modify"); - inline - void - solveCholeskyLowerOnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - if (variant == 0) - solveCholeskyLowerOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); - else if (variant == 1) - solveCholeskyLowerOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); - else if (variant == 2) - solveCholeskyLowerOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); - else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, - "LevelSetTools::solveCholeskyLowerOnDevice, algorithm variant is not supported"); - } - } + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); - inline - void - solveCholeskyUpperOnDeviceVar0(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type minus_one(-1), one(1); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf, *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg);; - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const UnmanagedViewType bB(bptr, n_m, nrhs); - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - - if (n_m > 0) { - const UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, bB, one, tT); checkDeviceBlasStatus("gemv"); - exec_instance.fence(); - } - _status = Trsv - ::invoke(_handle_blas, Diag::NonUnit(), AL, tT); checkDeviceBlasStatus("trsv"); - } - } - } - } - } + if (n_m > 0) { + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + UnmanagedViewType ABR(bptr, n_m, n_m); + UnmanagedViewType T(bptr + ABR.span(), m, m); - inline - void - solveCholeskyUpperOnDeviceVar1(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type minus_one(-1), one(1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf, *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg);; - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const UnmanagedViewType bT(bptr, m, nrhs); bptr += m*nrhs; - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - - if (n_m > 0) { - const UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - const UnmanagedViewType bB(bptr, n_m, nrhs); - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, bB, one, tT); checkDeviceBlasStatus("gemv"); - exec_instance.fence(); - } - - _status = Gemv - ::invoke(_handle_blas, one, AL, tT, zero, bT); checkDeviceBlasStatus("gemv"); - - exec_instance.fence(); - - _status = Copy::invoke(exec_instance, tT, bT); checkDeviceBlasStatus("Copy"); - } - } - } - } - } + const ordinal_type used_span = ABR.span() + T.span(); + UnmanagedViewType STR(bptr + used_span, m, n_m); + + ConstUnmanagedViewType perm(P.data() + 2 * m, m); + _status = ApplyPermutation::invoke(exec_instance, ATR, + perm, STR); + + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, ATL, STR); + checkDeviceBlasStatus("trsm"); + + _status = Copy::invoke(exec_instance, T, ATL); + _status = Copy::invoke(exec_instance, ATR, STR); - inline - void - solveCholeskyUpperOnDeviceVar2(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type one(1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - for (ordinal_type p=pbeg;p 0 && n > 0) { - value_type *aptr = s.buf, *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg);; - const UnmanagedViewType A(aptr, m, n); - const UnmanagedViewType b(bptr, n, nrhs); - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - - _status = Gemv - ::invoke(_handle_blas, one, A, b, zero, tT); checkDeviceBlasStatus("gemv"); + _status = Symmetrize::invoke(exec_instance, T); + _status = SetIdentity::invoke(exec_instance, ATL, minus_one); + _status = Scale2x2_BlockInverseDiagonals::invoke(exec_instance, P, D, ATR); + + _status = GemmTriangular::invoke( + _handle_blas, minus_one, ATR, STR, zero, ABR); + checkDeviceBlasStatus("gemm"); + + UnmanagedViewType AT(ATL.data(), m, n); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), minus_one, T, AT); + } else { + UnmanagedViewType T(bptr, m, m); + _status = Copy::invoke(exec_instance, T, ATL); + + _status = SetIdentity::invoke(exec_instance, ATL, one); + + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, T, ATL); } } } } } - - inline - void - solveCholeskyUpperOnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - if (variant == 0) - solveCholeskyUpperOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); - else if (variant == 1) - solveCholeskyUpperOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); - else if (variant == 2) - solveCholeskyUpperOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); - else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, - "LevelSetTools::solveCholeskyUpperOnDevice, algorithm variant is not supported"); - } + } + + inline void factorizeLDL_OnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + if (variant == 0) + factorizeLDL_OnDeviceVar0(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 1) + factorizeLDL_OnDeviceVar1(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 2) + factorizeLDL_OnDeviceVar2(pbeg, pend, h_buf_factor_ptr, work); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::factorizeLDL_OnDevice, algorithm variant is not supported"); } + } - inline - void - solveLDL_LowerOnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type minus_one(-1), zero(0); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf; - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - - const ordinal_type offm = s.row_begin; - - const auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - const auto fpiv = ordinal_type_array(_piv.data()+4*offm+m, m); - - _status = ApplyPivots /// row inter-change - ::invoke(exec_instance, fpiv, tT); - exec_instance.fence(); - - _status = Trsv - ::invoke(_handle_blas, Diag::Unit(), AL, tT); checkDeviceBlasStatus("trsv"); - exec_instance.fence(); - if (n_m > 0) { - value_type *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg); - UnmanagedViewType AR(aptr, m, n_m); // ptr += m*n_m; - UnmanagedViewType bB(bptr, n_m, nrhs); - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, tT, zero, bB); checkDeviceBlasStatus("gemv"); - } + inline void factorizeLU_OnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_factorize_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; +#else + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *uptr = s.u_buf; + UnmanagedViewType AT(uptr, m, n); + + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LU::invoke(_handle_lapack, AT, P, W); + checkDeviceLapackStatus("lu::invoke"); + + _status = LU::modify(exec_instance, m, P); + checkDeviceLapackStatus("lu::modify"); + + if (n_m > 0) { + UnmanagedViewType ATL(uptr, m, m); + uptr += m * m; + UnmanagedViewType ATR(uptr, m, n_m); + + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + UnmanagedViewType ABR(_buf.data() + h_buf_factor_ptr(p - pbeg), n_m, n_m); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ABL); + checkDeviceBlasStatus("trsm"); + + _status = Gemm::invoke(_handle_blas, minus_one, + ABL, ATR, zero, ABR); + checkDeviceBlasStatus("gemm"); } } } } } + } + inline void factorizeLU_OnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_factorize_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; +#else + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); - inline - void - solveLDL_UpperOnDevice(const ordinal_type pbeg, - const ordinal_type pend, - const size_type_array_host &h_buf_solve_ptr, - const value_type_matrix &t) { - const ordinal_type nrhs = t.extent(1); - const value_type minus_one(-1), one(1); -#if defined(KOKKOS_ENABLE_CUDA) - ordinal_type q(0); -#endif - exec_space exec_instance; - for (ordinal_type p=pbeg;p 0) { - value_type *aptr = s.buf, *bptr = _buf.data()+h_buf_solve_ptr(p-pbeg);; - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const UnmanagedViewType bB(bptr, n_m, nrhs); - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(t, range_type(offm, offm+m), Kokkos::ALL()); - const auto P = ordinal_type_array(_piv.data()+4*offm, 4*m); - const auto D = value_type_matrix(_diag.data()+2*offm, m, 2); - _status = Scale2x2_BlockInverseDiagonals /// row scaling - ::invoke(exec_instance, P, D, tT); + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LU::invoke(_handle_lapack, AT, P, W); + checkDeviceLapackStatus("lu::invoke"); - if (n_m > 0) { - const UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - _status = Gemv - ::invoke(_handle_blas, minus_one, AR, bB, one, tT); checkDeviceBlasStatus("gemv"); - exec_instance.fence(); - } - _status = Trsv - ::invoke(_handle_blas, Diag::Unit(), AL, tT); checkDeviceBlasStatus("trsv"); - - const auto fpiv = ordinal_type_array(P.data()+m, m); - _status = ApplyPivots /// row inter-change - ::invoke(exec_instance, fpiv, tT); - } + _status = LU::modify(exec_instance, m, P); + checkDeviceLapackStatus("lu::modify"); + + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); + UnmanagedViewType T(bptr, m, m); + + if (n_m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL2 = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + UnmanagedViewType ABR(bptr, n_m, n_m); + + _status = Copy::invoke(exec_instance, T, ATL); + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ABL); + checkDeviceBlasStatus("trsm"); + + _status = SetIdentity::invoke(exec_instance, ATL, one); + _status = SetIdentity::invoke(exec_instance, ATL2, one); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, T, ATL); + checkDeviceBlasStatus("trsm"); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, T, ATL2); + checkDeviceBlasStatus("trsm"); + + _status = Gemm::invoke(_handle_blas, minus_one, + ABL, ATR, zero, ABR); + checkDeviceBlasStatus("gemm"); + } else { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATL2(s.l_buf, m, m); + + _status = Copy::invoke(exec_instance, T, ATL); + + _status = SetIdentity::invoke(exec_instance, ATL, one); + _status = SetIdentity::invoke(exec_instance, ATL2, one); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, T, ATL); + checkDeviceBlasStatus("trsm"); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, T, ATL2); + checkDeviceBlasStatus("trsm"); + } + } + } + } + } + } + + inline void factorizeLU_OnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_factorize_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceBlasStatus("cublasSetStream"); + _status = cusolverDnSetStream(_handle_lapack, mystream); + checkDeviceLapackStatus("cusolverDnSetStream"); + + exec_instance = _exec_instances[qid]; + + const size_type worksize = work.extent(0) / _nstreams; + value_type_array W(work.data() + worksize * qid, worksize); + ++q; +#else + value_type_array W = work; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type offs = s.row_begin, m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + + ordinal_type *pivptr = _piv.data() + 4 * offs; + UnmanagedViewType P(pivptr, 4 * m); + _status = LU::invoke(_handle_lapack, AT, P, W); + checkDeviceLapackStatus("lu::invoke"); + + _status = LU::modify(exec_instance, m, P); + checkDeviceLapackStatus("lu::modify"); + + value_type *bptr = _buf.data() + h_buf_factor_ptr(p - pbeg); + + if (n_m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL2 = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + UnmanagedViewType ABR(bptr, n_m, n_m); + UnmanagedViewType T(bptr + ABR.span(), m, m); + + _status = Copy::invoke(exec_instance, T, ATL); + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, ATL, ABL); + checkDeviceBlasStatus("trsm"); + + _status = SetIdentity::invoke(exec_instance, ATL, minus_one); + _status = SetIdentity::invoke(exec_instance, ATL2, minus_one); + + _status = Gemm::invoke(_handle_blas, minus_one, + ABL, ATR, zero, ABR); + checkDeviceBlasStatus("gemm"); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), minus_one, T, AT); + checkDeviceBlasStatus("trsm"); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), minus_one, T, AL); + checkDeviceBlasStatus("trsm"); + + } else { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATL2(s.l_buf, m, m); + UnmanagedViewType T(bptr, m, m); + + _status = Copy::invoke(exec_instance, T, ATL); + + _status = SetIdentity::invoke(exec_instance, ATL, one); + _status = SetIdentity::invoke(exec_instance, ATL2, one); + + _status = Trsm::invoke( + _handle_blas, Diag::NonUnit(), one, T, ATL); + checkDeviceBlasStatus("trsm"); + _status = Trsm::invoke( + _handle_blas, Diag::Unit(), one, T, ATL2); + checkDeviceBlasStatus("trsm"); + } + } + } + } + } + } + + inline void factorizeLU_OnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_factor_ptr, const value_type_array &work) { + if (variant == 0) + factorizeLU_OnDeviceVar0(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 1) + factorizeLU_OnDeviceVar1(pbeg, pend, h_buf_factor_ptr, work); + else if (variant == 2) + factorizeLU_OnDeviceVar2(pbeg, pend, h_buf_factor_ptr, work); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::factorizeLU_OnDevice, algorithm variant is not supported"); + } + } + + /// + /// Level set factorize + /// + inline void factorizeCholesky(const value_type_array &ax, const ordinal_type verbose) { + constexpr bool is_host = std::is_same::value; + Kokkos::Timer timer; + + timer.reset(); + value_type_array work; + { + _buf = value_type_array(do_not_initialize_tag("buf"), _bufsize_factorize); + track_alloc(_buf.span() * sizeof(value_type)); + +#if defined(KOKKOS_ENABLE_CUDA) + value_type_matrix T(NULL, _info.max_supernode_size, _info.max_supernode_size); + const size_type worksize = Chol::invoke(_handle_lapack, T, work); + + /// TODO:: why do i plus one ? fix this; + work = value_type_array(do_not_initialize_tag("work"), worksize * (_nstreams + 1)); + track_alloc(work.span() * sizeof(value_type)); +#endif + } + stat.t_extra = timer.seconds(); + + timer.reset(); + { + _ax = ax; // matrix values + constexpr bool copy_to_l_buf(false); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); + } + stat.t_copy = timer.seconds(); + + stat_level.n_kernel_launching = 0; + timer.reset(); + { + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; + // const ordinal_type team_size_factor[2] = { 64, 16 }, vector_size_factor[2] = { 8, 8}; + // const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; + const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; + const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + { + typedef TeamFunctor_FactorizeChol functor_type; +#if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_factorize; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; +#else + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template FactorizeTag> + team_policy_factor; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> + team_policy_update; +#endif + + functor_type functor(_info, _factorize_mode, _level_sids, _buf); + + team_policy_factor policy_factor(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + { + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_buf_factor_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto buf_factor_ptr = Kokkos::subview(_buf_factor_ptr, range_buf_factor_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(buf_factor_ptr); + if (is_host) { + policy_factor = team_policy_factor(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_factor = team_policy_factor(pcnt, team_size_factor[idx], vector_size_factor[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("factor lower", policy_factor, functor); + } else { + Kokkos::parallel_for("factor", policy_factor, functor); + ++stat_level.n_kernel_launching; + } + + const auto h_buf_factor_ptr = Kokkos::subview(_h_buf_factor_ptr, range_buf_factor_ptr); + factorizeCholeskyOnDevice(pbeg, pend, h_buf_factor_ptr, work); + Kokkos::fence(); + + Kokkos::parallel_for("update factor", policy_update, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); // Kokkos::fence(); + } + } + } + } // end of Cholesky + stat.t_factor = timer.seconds(); + + timer.reset(); + { +#if defined(KOKKOS_ENABLE_CUDA) + track_free(work.span() * sizeof(value_type)); +#endif + track_free(_buf.span() * sizeof(value_type)); + _buf = value_type_array(); + } + stat.t_extra += timer.seconds(); + + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (CholeskyFactorize)\n", variant); + printf("=====================================================\n"); + print_stat_factor(); + } + } + + inline void solveCholeskyLowerOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + _status = cublasSetStream(_handle_blas, _cuda_streams[q % _nstreams]); + checkDeviceStatus("cublasSetStream"); + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += ATL.span(); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + _status = + Trsv::invoke(_handle_blas, Diag::NonUnit(), ATL, tT); + checkDeviceBlasStatus("trsv"); + + if (n_m > 0) { + // solve offdiag + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n_m; + UnmanagedViewType bB(bptr, n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ATR, tT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveCholeskyLowerOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + _status = cublasSetStream(_handle_blas, _cuda_streams[q % _nstreams]); + checkDeviceStatus("cublasSetStream"); + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType bT(bptr, m, nrhs); + bptr += m * nrhs; + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + _status = Gemv::invoke(_handle_blas, one, ATL, tT, zero, bT); + checkDeviceBlasStatus("gemv"); + + if (n_m > 0) { + // solve offdiag + UnmanagedViewType ATR(aptr, m, n_m); + UnmanagedViewType bB(bptr, n_m, nrhs); + + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveCholeskyLowerOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + _status = cublasSetStream(_handle_blas, _cuda_streams[q % _nstreams]); + checkDeviceStatus("cublasSetStream"); + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType AT(aptr, m, n); + + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + _status = Gemv::invoke(_handle_blas, one, AT, tT, zero, b); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + + inline void solveCholeskyLowerOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveCholeskyLowerOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveCholeskyLowerOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveCholeskyLowerOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveCholeskyLowerOnDevice, algorithm variant is not supported"); + } + } + + inline void solveCholeskyUpperOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + ; + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const UnmanagedViewType bB(bptr, n_m, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, tT); + checkDeviceBlasStatus("gemv"); + } + _status = + Trsv::invoke(_handle_blas, Diag::NonUnit(), ATL, tT); + checkDeviceBlasStatus("trsv"); + } + } + } + } + } + + inline void solveCholeskyUpperOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const UnmanagedViewType bT(bptr, m, nrhs); + bptr += m * nrhs; + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + const UnmanagedViewType bB(bptr, n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, tT); + checkDeviceBlasStatus("gemv"); + } + + _status = Gemv::invoke(_handle_blas, one, ATL, tT, zero, bT); + checkDeviceBlasStatus("gemv"); + + _status = Copy::invoke(exec_instance, tT, bT); + checkDeviceBlasStatus("Copy"); + } + } + } + } + } + + inline void solveCholeskyUpperOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + ; + const UnmanagedViewType AT(aptr, m, n); + const UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + _status = Gemv::invoke(_handle_blas, one, AT, b, zero, tT); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + + inline void solveCholeskyUpperOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveCholeskyUpperOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveCholeskyUpperOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveCholeskyUpperOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveCholeskyUpperOnDevice, algorithm variant is not supported"); + } + } + + inline void solveLDL_LowerOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + const ordinal_type offm = s.row_begin; + + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + if (!s.do_not_apply_pivots) { + const auto fpiv = ordinal_type_array(_piv.data() + 4 * offm + m, m); + _status = ApplyPivots /// row inter-change + ::invoke(exec_instance, fpiv, tT); + } + + _status = + Trsv::invoke(_handle_blas, Diag::Unit(), ATL, tT); + checkDeviceBlasStatus("trsv"); + if (n_m > 0) { + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType ATR(aptr, m, n_m); // ptr += m*n_m; + UnmanagedViewType bB(bptr, n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ATR, tT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveLDL_LowerOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType bT(bptr, m, nrhs); + bptr += m * nrhs; + + const ordinal_type offm = s.row_begin; + + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (s.do_not_apply_pivots) { + _status = Copy::invoke(exec_instance, bT, tT); + } else { + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + _status = + ApplyPermutation::invoke(exec_instance, tT, perm, bT); + } + + _status = Gemv::invoke(_handle_blas, one, ATL, bT, zero, tT); + checkDeviceBlasStatus("gemv"); + + if (n_m > 0) { + UnmanagedViewType ATR(aptr, m, n_m); + UnmanagedViewType bB(bptr, n_m, nrhs); + + _status = Gemv::invoke(_handle_blas, minus_one, ATR, tT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveLDL_LowerOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType AT(aptr, m, n); + UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (!s.do_not_apply_pivots) { + UnmanagedViewType bT(bptr, m, nrhs); + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + _status = Copy::invoke(exec_instance, bT, tT); + + _status = + ApplyPermutation::invoke(exec_instance, bT, perm, tT); + } + + _status = Gemv::invoke(_handle_blas, one, AT, tT, zero, b); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + + inline void solveLDL_LowerOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveLDL_LowerOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveLDL_LowerOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveLDL_LowerOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveLDL_LowerOnDevice, algorithm variant is not supported"); + } + } + + inline void solveLDL_UpperOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + ; + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const UnmanagedViewType bB(bptr, n_m, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + const auto P = ordinal_type_array(_piv.data() + 4 * offm, 4 * m); + const auto D = value_type_matrix(_diag.data() + 2 * offm, m, 2); + _status = Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(exec_instance, P, D, tT); + + if (n_m > 0) { + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, tT); + checkDeviceBlasStatus("gemv"); + } + _status = Trsv::invoke(_handle_blas, Diag::Unit(), ATL, tT); + checkDeviceBlasStatus("trsv"); + + if (!s.do_not_apply_pivots) { + const auto fpiv = ordinal_type_array(P.data() + m, m); + _status = ApplyPivots /// row inter-change + ::invoke(exec_instance, fpiv, tT); + } + } + } + } + } + } + + inline void solveLDL_UpperOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + + const UnmanagedViewType ATL(aptr, m, m); + aptr += ATL.span(); + const UnmanagedViewType bT(bptr, m, nrhs); + bptr += bT.span(); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + const auto P = ordinal_type_array(_piv.data() + 4 * offm, 4 * m); + const auto D = value_type_matrix(_diag.data() + 2 * offm, m, 2); + _status = Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(exec_instance, P, D, tT); + + if (n_m > 0) { + const UnmanagedViewType bB(bptr, n_m, nrhs); + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, tT); + checkDeviceBlasStatus("gemv"); + } + + _status = Gemv::invoke(_handle_blas, one, ATL, tT, zero, bT); + checkDeviceBlasStatus("gemv"); + + if (s.do_not_apply_pivots) { + _status = Copy::invoke(exec_instance, tT, bT); + } else { + ConstUnmanagedViewType peri(P.data() + 3 * m, m); + _status = + ApplyPermutation::invoke(exec_instance, bT, peri, tT); + } + } + } + } + } + } + + inline void solveLDL_UpperOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + + const UnmanagedViewType AT(aptr, m, n); + const UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + const UnmanagedViewType bT(bptr, m, nrhs); + + ConstUnmanagedViewType P(_piv.data() + offm * 4, m * 4); + ConstUnmanagedViewType D(_diag.data() + offm * 2, m, 2); + + _status = Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(exec_instance, P, D, bT); + + _status = Gemv::invoke(_handle_blas, one, AT, b, zero, tT); + + if (!s.do_not_apply_pivots) { + _status = Copy::invoke(exec_instance, bT, tT); + + ConstUnmanagedViewType peri(P.data() + 3 * m, m); + _status = + ApplyPermutation::invoke(exec_instance, bT, peri, tT); + } + } + } + } + } + } + + inline void solveLDL_UpperOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveLDL_UpperOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveLDL_UpperOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveLDL_UpperOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveLDL_UpperOnDevice, algorithm variant is not supported"); + } + } + + inline void solveLU_LowerOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + const auto fpiv = ordinal_type_array(_piv.data() + 4 * offm + m, m); + + _status = ApplyPivots /// row inter-change + ::invoke(exec_instance, fpiv, tT); + + _status = + Trsv::invoke(_handle_blas, Diag::Unit(), ATL, tT); + checkDeviceBlasStatus("trsv"); + + if (n_m > 0) { + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + UnmanagedViewType bB(bptr, n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ABL, tT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveLU_LowerOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), minus_one(-1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType bT(bptr, m, nrhs); + + const ordinal_type offm = s.row_begin; + + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + + if (s.do_not_apply_pivots) { + _status = Copy::invoke(exec_instance, bT, tT); + } else { + _status = + ApplyPermutation::invoke(exec_instance, tT, perm, bT); + } + + _status = Gemv::invoke(_handle_blas, one, ATL, bT, zero, tT); + checkDeviceBlasStatus("gemv"); + + if (n_m > 0) { + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + UnmanagedViewType bB(bptr + bT.span(), n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ABL, tT, zero, bB); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + } + + inline void solveLU_LowerOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + UnmanagedViewType AL(s.l_buf, n, m); + + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (!s.do_not_apply_pivots) { + UnmanagedViewType bT(bptr, m, nrhs); + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + _status = Copy::invoke(exec_instance, bT, tT); + + _status = + ApplyPermutation::invoke(exec_instance, bT, perm, tT); + } + + _status = Gemv::invoke(_handle_blas, one, AL, tT, zero, b); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } + + inline void solveLU_LowerOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveLU_LowerOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveLU_LowerOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveLU_LowerOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveLU_LowerOnDevice, algorithm variant is not supported"); + } + } + + inline void solveLU_UpperOnDeviceVar0(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *uptr = s.u_buf, *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + ; + const UnmanagedViewType ATL(uptr, m, m); + uptr += m * m; + const UnmanagedViewType bB(bptr, n_m, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + const UnmanagedViewType ATR(uptr, m, n_m); // uptr += m*n; + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, tT); + checkDeviceBlasStatus("gemv"); + } + _status = + Trsv::invoke(_handle_blas, Diag::NonUnit(), ATL, tT); + checkDeviceBlasStatus("trsv"); } } } } + } + inline void solveLU_UpperOnDeviceVar1(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type minus_one(-1), one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + + const UnmanagedViewType ATL(s.u_buf, m, m); + const UnmanagedViewType bT(bptr, m, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + _status = Copy::invoke(exec_instance, bT, tT); + if (n_m > 0) { + const UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + const UnmanagedViewType bB(bptr + bT.span(), n_m, nrhs); + _status = Gemv::invoke(_handle_blas, minus_one, ATR, bB, one, bT); + checkDeviceBlasStatus("gemv"); + } + + _status = Gemv::invoke(_handle_blas, one, ATL, bT, zero, tT); + checkDeviceBlasStatus("gemv"); + } + } + } + } + } - inline - void - allocateWorkspaceSolve(const ordinal_type nrhs) { - const size_type buf_extent = _bufsize_solve*nrhs; - const size_type buf_span = _buf.span(); - - if (buf_extent != buf_span) { - _buf = value_type_array(do_not_initialize_tag("buf"), buf_extent); - track_free(buf_span*sizeof(value_type)); - track_alloc(_buf.span()*sizeof(value_type)); + inline void solveLU_UpperOnDeviceVar2(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + const ordinal_type nrhs = t.extent(1); + const value_type one(1), zero(0); +#if defined(KOKKOS_ENABLE_CUDA) + ordinal_type q(0); +#endif + exec_space exec_instance; + for (ordinal_type p = pbeg; p < pend; ++p) { + const ordinal_type sid = _h_level_sids(p); + if (_h_solve_mode(sid) == 0) { +#if defined(KOKKOS_ENABLE_CUDA) + const ordinal_type qid = q % _nstreams; + const auto mystream = _cuda_streams[qid]; + _status = cublasSetStream(_handle_blas, mystream); + checkDeviceStatus("cublasSetStream"); + exec_instance = _exec_instances[qid]; + ++q; +#endif + const auto &s = _h_supernodes(sid); { - const Kokkos::RangePolicy policy(0,_buf_solve_ptr.extent(0)); - const auto buf_solve_nrhs_ptr = _buf_solve_nrhs_ptr; - const auto buf_solve_ptr = _buf_solve_ptr; - Kokkos::parallel_for(policy, KOKKOS_LAMBDA(const ordinal_type &i) { - buf_solve_nrhs_ptr(i) = nrhs*buf_solve_ptr(i); - }); + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + value_type *bptr = _buf.data() + h_buf_solve_ptr(p - pbeg); + + const UnmanagedViewType AT(s.u_buf, m, n); + const UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL()); + + _status = Gemv::invoke(_handle_blas, one, AT, b, zero, tT); + checkDeviceBlasStatus("gemv"); + } } - Kokkos::deep_copy(_h_buf_solve_nrhs_ptr, _buf_solve_nrhs_ptr); } } + } + + inline void solveLU_UpperOnDevice(const ordinal_type pbeg, const ordinal_type pend, + const size_type_array_host &h_buf_solve_ptr, const value_type_matrix &t) { + if (variant == 0) + solveLU_UpperOnDeviceVar0(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 1) + solveLU_UpperOnDeviceVar1(pbeg, pend, h_buf_solve_ptr, t); + else if (variant == 2) + solveLU_UpperOnDeviceVar2(pbeg, pend, h_buf_solve_ptr, t); + else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, + "LevelSetTools::solveLU_UpperOnDevice, algorithm variant is not supported"); + } + } + + inline void allocateWorkspaceSolve(const ordinal_type nrhs) { + const size_type buf_extent = _bufsize_solve * nrhs; + const size_type buf_span = _buf.span(); + if (buf_extent != buf_span) { + _buf = value_type_array(do_not_initialize_tag("buf"), buf_extent); + track_free(buf_span * sizeof(value_type)); + track_alloc(_buf.span() * sizeof(value_type)); + { + const Kokkos::RangePolicy policy(0, _buf_solve_ptr.extent(0)); + const auto buf_solve_nrhs_ptr = _buf_solve_nrhs_ptr; + const auto buf_solve_ptr = _buf_solve_ptr; + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { buf_solve_nrhs_ptr(i) = nrhs * buf_solve_ptr(i); }); + } + Kokkos::deep_copy(_h_buf_solve_nrhs_ptr, _buf_solve_nrhs_ptr); + } + } - inline - void - solveCholesky(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, - const ordinal_type verbose) { // temporary workspace (store permuted vectors) - TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || - x.extent(1) != b.extent(1) || - x.extent(0) != t.extent(0) || - x.extent(1) != t.extent(1), std::logic_error, - "x, b, t, and w dimensions do not match"); + inline void solveCholesky(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, + const ordinal_type verbose) { // temporary workspace (store permuted vectors) + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "x, b, t, and w dimensions do not match"); - TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || - x.data() == t.data(), std::logic_error, - "x, b, t, and w have the same data pointer"); - constexpr bool is_host = std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data(), std::logic_error, + "x, b, t, and w have the same data pointer"); + constexpr bool is_host = std::is_same::value; - // solve U^{H} (U x) = b - const ordinal_type nrhs = x.extent(1); - Kokkos::Timer timer; + // solve U^{H} (U x) = b + const ordinal_type nrhs = x.extent(1); + Kokkos::Timer timer; - stat_level.n_kernel_launching = 0; + stat_level.n_kernel_launching = 0; - // one-time operation when nrhs is changed - timer.reset(); - allocateWorkspaceSolve(nrhs); + // one-time operation when nrhs is changed + timer.reset(); + allocateWorkspaceSolve(nrhs); - // 0. permute and copy b -> t - applyRowPermutationToDenseMatrix(t, b, _perm); - stat.t_extra = timer.seconds(); + // 0. permute and copy b -> t + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); - timer.reset(); - { + timer.reset(); + { #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto work_item_property = Kokkos::Experimental::WorkItemProperty::HintLightWeight; + const auto work_item_property = Kokkos::Experimental::WorkItemProperty::HintLightWeight; #endif - // this should be considered with average problem sizes in levels - const ordinal_type half_level = _nlevel/2; - const ordinal_type team_size_solve[2] = { 64, 16 }, vector_size_solve[2] = { 8, 8}; - const ordinal_type team_size_update[2] = { 128, 32}, vector_size_update[2] = { 1, 1}; - { - typedef TeamFunctor_SolveLowerChol functor_type; + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; + const ordinal_type team_size_solve[2] = {64, 16}, vector_size_solve[2] = {8, 8}; + const ordinal_type team_size_update[2] = {128, 32}, vector_size_update[2] = {1, 1}; + { + typedef TeamFunctor_SolveLowerChol functor_type; #if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; #else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::template SolveTag > team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::template UpdateTag > team_policy_update; -#endif - functor_type functor(_info, - _solve_mode, - _level_sids, - t, - _buf); - - team_policy_solve policy_solve(1,1,1); - team_policy_update policy_update(1,1,1); - - // 1. U^{H} w = t - { - for (ordinal_type lvl=(_team_serial_level_cut-1);lvl>=0;--lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - - const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(solve_buf_ptr); - if (is_host) { - policy_solve = team_policy_solve(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 1. U^{H} w = t + { + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto policy_solve_with_work_property = Kokkos::Experimental::require(policy_solve, work_item_property); - const auto policy_update_with_work_property = Kokkos::Experimental::require(policy_update, work_item_property); + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); #else - const auto policy_solve_with_work_property = policy_solve; - const auto policy_update_with_work_property = policy_update; -#endif - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("solve lower", policy_solve, functor); - } else { - Kokkos::parallel_for("solve lower", - policy_solve_with_work_property, - functor); - ++stat_level.n_kernel_launching; - } - const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); - solveCholeskyLowerOnDevice(pbeg, pend, h_buf_solve_ptr, t); - Kokkos::fence(); - - Kokkos::parallel_for("update lower", - policy_update_with_work_property, - functor); + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; +#endif + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve lower", policy_solve, functor); + } else { + Kokkos::parallel_for("solve lower", policy_solve_with_work_property, functor); ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); } + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveCholeskyLowerOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); + + Kokkos::parallel_for("update lower", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); } - } // end of lower tri solve - - { - typedef TeamFunctor_SolveUpperChol functor_type; + } + } // end of lower tri solve + + { + typedef TeamFunctor_SolveUpperChol functor_type; #if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; #else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::template SolveTag > team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::template UpdateTag > team_policy_update; -#endif - functor_type functor(_info, - _solve_mode, - _level_sids, - t, - _buf); - - team_policy_solve policy_solve(1,1,1); - team_policy_update policy_update(1,1,1); - - // 2. U t = w; - { - for (ordinal_type lvl=0;lvl<_team_serial_level_cut;++lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(solve_buf_ptr); - if (is_host) { - policy_solve = team_policy_solve(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 2. U t = w; + { + for (ordinal_type lvl = 0; lvl < _team_serial_level_cut; ++lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto policy_solve_with_work_property = Kokkos::Experimental::require(policy_solve, work_item_property); - const auto policy_update_with_work_property = Kokkos::Experimental::require(policy_update, work_item_property); + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); #else - const auto policy_solve_with_work_property = policy_solve; - const auto policy_update_with_work_property = policy_update; + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; #endif - Kokkos::parallel_for("update upper", - policy_update_with_work_property, - functor); + Kokkos::parallel_for("update upper", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); + + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve upper", policy_solve, functor); + } else { + Kokkos::parallel_for("solve upper", policy_solve_with_work_property, functor); ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); - - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("solve upper", policy_solve, functor); - } else { - Kokkos::parallel_for("solve upper", - policy_solve_with_work_property, - functor); - ++stat_level.n_kernel_launching; - } - - const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); - solveCholeskyUpperOnDevice(pbeg, pend, h_buf_solve_ptr, t); - Kokkos::fence(); } + + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveCholeskyUpperOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); } - }/// end of upper tri solve + } + } /// end of upper tri solve - } // end of solve - stat.t_solve = timer.seconds(); + } // end of solve + stat.t_solve = timer.seconds(); - // permute and copy t -> x - timer.reset(); - applyRowPermutationToDenseMatrix(x, t, _peri); - stat.t_extra += timer.seconds(); + // permute and copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); - if (verbose) { - printf("Summary: LevelSetTools-Variant-%d (Cholesky Solve: %3d)\n", variant, nrhs); - printf("=======================================================\n"); - print_stat_solve(); - } + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (Cholesky Solve: %3d)\n", variant, nrhs); + printf("=======================================================\n"); + print_stat_solve(); } + } - inline - void - factorizeLDL(const value_type_array &ax, - const ordinal_type verbose) { - constexpr bool is_host = std::is_same::value; - Kokkos::Timer timer; - - timer.reset(); - value_type_array work; - { - _buf = value_type_array(do_not_initialize_tag("buf"), _bufsize_factorize); - track_alloc(_buf.span()*sizeof(value_type)); - -#if defined (KOKKOS_ENABLE_CUDA) - value_type_matrix T(NULL, _info.max_supernode_size, _info.max_supernode_size); - ordinal_type_array P(NULL, _info.max_supernode_size); - const size_type worksize = LDL - ::invoke(_handle_lapack, T, P, work); - work = value_type_array(do_not_initialize_tag("work"), worksize*(_nstreams+1)*max(8, _nstreams)); + inline void factorizeLDL(const value_type_array &ax, const ordinal_type verbose) { + constexpr bool is_host = std::is_same::value; + Kokkos::Timer timer; + + timer.reset(); + value_type_array work; + { + _buf = value_type_array(do_not_initialize_tag("buf"), _bufsize_factorize); + track_alloc(_buf.span() * sizeof(value_type)); + +#if defined(KOKKOS_ENABLE_CUDA) + value_type_matrix T(NULL, _info.max_supernode_size, _info.max_supernode_size); + ordinal_type_array P(NULL, _info.max_supernode_size); + const size_type worksize = LDL::invoke(_handle_lapack, T, P, work); + /// TODO:: why do i multiply additional max(8, _nstreams); fix this + work = value_type_array(do_not_initialize_tag("work"), worksize * (_nstreams + 1) * max(8, _nstreams)); #else - const size_type worksize = 32*_info.max_supernode_size; - work = value_type_array(do_not_initialize_tag("work"), worksize); + const size_type worksize = 32 * _info.max_supernode_size; + work = value_type_array(do_not_initialize_tag("work"), worksize); #endif - track_alloc(work.span()*sizeof(value_type)); - } - stat.t_extra = timer.seconds(); - - timer.reset(); - { - _ax = ax; // matrix values - _info.copySparseToSuperpanels(_ap, _aj, _ax, _perm, _peri); - } - stat.t_copy = timer.seconds(); - - stat_level.n_kernel_launching = 0; - timer.reset(); - { - // this should be considered with average problem sizes in levels - const ordinal_type half_level = _nlevel/2; - //const ordinal_type team_size_factor[2] = { 64, 16 }, vector_size_factor[2] = { 8, 8}; - //const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; -#if defined (CUDA_VERSION) + track_alloc(work.span() * sizeof(value_type)); + } + stat.t_extra = timer.seconds(); + + timer.reset(); + { + _ax = ax; // matrix values + constexpr bool copy_to_l_buf(false); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); + } + stat.t_copy = timer.seconds(); + + stat_level.n_kernel_launching = 0; + timer.reset(); + { + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; + // const ordinal_type team_size_factor[2] = { 64, 16 }, vector_size_factor[2] = { 8, 8}; + // const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; +#if defined(CUDA_VERSION) #if (11000 > CUDA_VERSION) - /// cuda 11.1 below - const ordinal_type team_size_factor[2] = { 32, 64 }, vector_size_factor[2] = { 8, 4}; -#else - /// cuda 11.1 and higher - const ordinal_type team_size_factor[2] = { 64, 64 }, vector_size_factor[2] = { 8, 4}; + /// cuda 11.1 below + const ordinal_type team_size_factor[2] = {32, 64}, vector_size_factor[2] = {8, 4}; +#else + /// cuda 11.1 and higher + const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; #endif #else - /// not cuda ... whatever.. - const ordinal_type team_size_factor[2] = { 64, 64 }, vector_size_factor[2] = { 8, 4}; + /// not cuda ... whatever.. + const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; #endif - const ordinal_type team_size_update[2] = { 16, 8 }, vector_size_update[2] = { 32, 32}; - { - typedef TeamFunctor_FactorizeLDL functor_type; + const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + { + typedef TeamFunctor_FactorizeLDL functor_type; #if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_factorize; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_factorize; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; #else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::FactorizeTag> team_policy_factor; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::UpdateTag> team_policy_update; -#endif - functor_type functor(_info, - _factorize_mode, - _level_sids, - _piv, - _diag, - _buf); - - team_policy_factor policy_factor(1,1,1); - team_policy_update policy_update(1,1,1); - - { - for (ordinal_type lvl=(_team_serial_level_cut-1);lvl>=0;--lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_buf_factor_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - - const auto buf_factor_ptr = Kokkos::subview(_buf_factor_ptr, range_buf_factor_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(buf_factor_ptr); - if (is_host) { - policy_factor = team_policy_factor(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_factor = team_policy_factor(pcnt, team_size_factor[idx], vector_size_factor[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("factor lower", policy_factor, functor); - } else { - Kokkos::parallel_for("factor", policy_factor, functor); - ++stat_level.n_kernel_launching; - } - - const auto h_buf_factor_ptr = Kokkos::subview(_h_buf_factor_ptr, range_buf_factor_ptr); - - factorizeLDL_OnDevice(pbeg, pend, h_buf_factor_ptr, work); - Kokkos::fence(); - - Kokkos::parallel_for("update factor", policy_update, functor); + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template FactorizeTag> + team_policy_factor; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _factorize_mode, _level_sids, _piv, _diag, _buf); + + team_policy_factor policy_factor(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + { + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_buf_factor_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto buf_factor_ptr = Kokkos::subview(_buf_factor_ptr, range_buf_factor_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(buf_factor_ptr); + if (is_host) { + policy_factor = team_policy_factor(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_factor = team_policy_factor(pcnt, team_size_factor[idx], vector_size_factor[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("factor lower", policy_factor, functor); + } else { + Kokkos::parallel_for("factor", policy_factor, functor); ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); } + + const auto h_buf_factor_ptr = Kokkos::subview(_h_buf_factor_ptr, range_buf_factor_ptr); + + factorizeLDL_OnDevice(pbeg, pend, h_buf_factor_ptr, work); + Kokkos::fence(); + + Kokkos::parallel_for("update factor", policy_update, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); } + const auto exec_instance = exec_space(); + Kokkos::deep_copy(exec_instance, _h_supernodes, _info.supernodes); } - } // end of LDL - stat.t_factor = timer.seconds(); + } + } // end of LDL + stat.t_factor = timer.seconds(); - timer.reset(); - { -#if defined (KOKKOS_ENABLE_CUDA) - track_free(work.span()*sizeof(value_type)); + timer.reset(); + { +#if defined(KOKKOS_ENABLE_CUDA) + track_free(work.span() * sizeof(value_type)); #endif - track_free(_buf.span()*sizeof(value_type)); - _buf = value_type_array(); - } - stat.t_extra += timer.seconds(); + track_free(_buf.span() * sizeof(value_type)); + _buf = value_type_array(); + } + stat.t_extra += timer.seconds(); - if (verbose) { - printf("Summary: LevelSetTools (LDL Factorize)\n"); - printf("======================================\n"); - print_stat_factor(); - } + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (LDL Factorize)\n", variant); + printf("=================================================\n"); + print_stat_factor(); } + } - inline - void - solveLDL(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose) { - TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || - x.extent(1) != b.extent(1) || - x.extent(0) != t.extent(0) || - x.extent(1) != t.extent(1), std::logic_error, - "x, b, t, and w dimensions do not match"); + inline void solveLDL(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose) { + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "x, b, t, and w dimensions do not match"); - TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || - x.data() == t.data(), std::logic_error, - "x, b, t, and w have the same data pointer"); - constexpr bool is_host = std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data(), std::logic_error, + "x, b, t, and w have the same data pointer"); + constexpr bool is_host = std::is_same::value; - // solve L D L^{H} x = b - const ordinal_type nrhs = x.extent(1); - Kokkos::Timer timer; + // solve L D L^{H} x = b + const ordinal_type nrhs = x.extent(1); + Kokkos::Timer timer; - stat_level.n_kernel_launching = 0; + stat_level.n_kernel_launching = 0; - // one-time operation when nrhs is changed - timer.reset(); - allocateWorkspaceSolve(nrhs); + // one-time operation when nrhs is changed + timer.reset(); + allocateWorkspaceSolve(nrhs); - // 0. permute and copy b -> t - applyRowPermutationToDenseMatrix(t, b, _perm); - stat.t_extra = timer.seconds(); + // 0. permute and copy b -> t + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); - timer.reset(); - { + timer.reset(); + { #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto work_item_property = Kokkos::Experimental::WorkItemProperty::HintLightWeight; + const auto work_item_property = Kokkos::Experimental::WorkItemProperty::HintLightWeight; #endif - // this should be considered with average problem sizes in levels - const ordinal_type half_level = _nlevel/2; -#if defined (CUDA_VERSION) + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; +#if defined(CUDA_VERSION) #if (11000 > CUDA_VERSION) - /// cuda 11.1 below - const ordinal_type team_size_solve[2] = { 32, 16 }, vector_size_solve[2] = { 8, 8}; + /// cuda 11.1 below + const ordinal_type team_size_solve[2] = {32, 16}, vector_size_solve[2] = {8, 8}; #else - /// cuda 11.1 and higher - const ordinal_type team_size_solve[2] = { 32, 16 }, vector_size_solve[2] = { 8, 8}; + /// cuda 11.1 and higher + const ordinal_type team_size_solve[2] = {32, 16}, vector_size_solve[2] = {8, 8}; #endif #else - /// not cuda whatever... - const ordinal_type team_size_solve[2] = { 64, 16 }, vector_size_solve[2] = { 8, 8}; + /// not cuda whatever... + const ordinal_type team_size_solve[2] = {64, 16}, vector_size_solve[2] = {8, 8}; #endif - const ordinal_type team_size_update[2] = { 128, 32}, vector_size_update[2] = { 1, 1}; - { - typedef TeamFunctor_SolveLowerLDL functor_type; + const ordinal_type team_size_update[2] = {128, 32}, vector_size_update[2] = {1, 1}; + { + typedef TeamFunctor_SolveLowerLDL functor_type; #if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; #else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::SolveTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::UpdateTag> team_policy_update; -#endif - functor_type functor(_info, - _solve_mode, - _level_sids, - _piv, - t, - _buf); - - team_policy_solve policy_solve(1,1,1); - team_policy_update policy_update(1,1,1); - - // 1. L w = t - { - for (ordinal_type lvl=(_team_serial_level_cut-1);lvl>=0;--lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - - const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(solve_buf_ptr); - if (is_host) { - policy_solve = team_policy_solve(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, _piv, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 1. L w = t + { + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto policy_solve_with_work_property = Kokkos::Experimental::require(policy_solve, work_item_property); - const auto policy_update_with_work_property = Kokkos::Experimental::require(policy_update, work_item_property); + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); #else - const auto policy_solve_with_work_property = policy_solve; - const auto policy_update_with_work_property = policy_update; -#endif - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("solve lower", policy_solve, functor); - } else { - Kokkos::parallel_for("solve lower", - policy_solve_with_work_property, - functor); - ++stat_level.n_kernel_launching; - } - const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); - solveLDL_LowerOnDevice(pbeg, pend, h_buf_solve_ptr, t); - Kokkos::fence(); - - Kokkos::parallel_for("update lower", - policy_update_with_work_property, - functor); + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; +#endif + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve lower", policy_solve, functor); + } else { + Kokkos::parallel_for("solve lower", policy_solve_with_work_property, functor); ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); } + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveLDL_LowerOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); + + Kokkos::parallel_for("update lower", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); } - } // end of lower tri solve - - { - typedef TeamFunctor_SolveUpperLDL functor_type; + } + } // end of lower tri solve + + { + typedef TeamFunctor_SolveUpperLDL functor_type; #if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::DummyTag> team_policy_update; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; #else - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::SolveTag> team_policy_solve; - typedef Kokkos::TeamPolicy,exec_space, - typename functor_type::UpdateTag> team_policy_update; -#endif - functor_type functor(_info, - _solve_mode, - _level_sids, - _piv, - _diag, - t, - _buf); - - team_policy_solve policy_solve(1,1,1); - team_policy_update policy_update(1,1,1); - - // 2. U t = w; - { - for (ordinal_type lvl=0;lvl<_team_serial_level_cut;++lvl) { - const ordinal_type - pbeg = _h_level_ptr(lvl), - pend = _h_level_ptr(lvl+1), - pcnt = pend - pbeg; - - const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl+1)); - const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); - functor.setRange(pbeg, pend); - functor.setBufferPtr(solve_buf_ptr); - if (is_host) { - policy_solve = team_policy_solve(pcnt, 1, 1); - policy_update = team_policy_update(pcnt, 1, 1); - } else { - const ordinal_type idx = lvl > half_level; - policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); - policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); - } + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, _piv, _diag, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 2. U t = w; + { + for (ordinal_type lvl = 0; lvl < _team_serial_level_cut; ++lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } #if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) - const auto policy_solve_with_work_property = Kokkos::Experimental::require(policy_solve, work_item_property); - const auto policy_update_with_work_property = Kokkos::Experimental::require(policy_update, work_item_property); + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); #else - const auto policy_solve_with_work_property = policy_solve; - const auto policy_update_with_work_property = policy_update; + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; #endif - Kokkos::parallel_for("update upper", - policy_update_with_work_property, - functor); + Kokkos::parallel_for("update upper", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); + + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve upper", policy_solve, functor); + } else { + Kokkos::parallel_for("solve upper", policy_solve_with_work_property, functor); ++stat_level.n_kernel_launching; - exec_space().fence(); //Kokkos::fence(); - - if (lvl < _device_level_cut) { - // do nothing - //Kokkos::parallel_for("solve upper", policy_solve, functor); - } else { - Kokkos::parallel_for("solve upper", - policy_solve_with_work_property, - functor); - ++stat_level.n_kernel_launching; - } - - const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); - solveLDL_UpperOnDevice(pbeg, pend, h_buf_solve_ptr, t); - Kokkos::fence(); } + + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveLDL_UpperOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); } - }/// end of upper tri solve + } + } /// end of upper tri solve - } // end of solve - stat.t_solve = timer.seconds(); + } // end of solve + stat.t_solve = timer.seconds(); - // permute and copy t -> x - timer.reset(); - applyRowPermutationToDenseMatrix(x, t, _peri); - stat.t_extra += timer.seconds(); + // permute and copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); - if (verbose) { - printf("Summary: LevelSetTools (LDL Solve: %3d)\n", nrhs); - printf("=======================================\n"); - print_stat_solve(); - } + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (LDL Solve: %3d)\n", variant, nrhs); + printf("==================================================\n"); + print_stat_solve(); } + } - inline - void - factorize(const value_type_array &ax, - const ordinal_type verbose = 0) override { - Kokkos::deep_copy(_superpanel_buf, value_type(0)); - switch (this->getSolutionMethod()) { - case 1: { /// Cholesky - factorizeCholesky(ax, verbose); - break; + inline void factorizeLU(const value_type_array &ax, const ordinal_type verbose) { + constexpr bool is_host = std::is_same::value; + Kokkos::Timer timer; + + timer.reset(); + value_type_array work; + { + _buf = value_type_array(do_not_initialize_tag("buf"), _bufsize_factorize); + track_alloc(_buf.span() * sizeof(value_type)); + +#if defined(KOKKOS_ENABLE_CUDA) + value_type_matrix T(NULL, _info.max_supernode_size, _info.max_supernode_size); + ordinal_type_array P(NULL, _info.max_supernode_size); + const size_type worksize = LU::invoke(_handle_lapack, T, P, work); + /// TODO:: why do i do this way ? + work = value_type_array(do_not_initialize_tag("work"), worksize * (_nstreams + 1)); + // work = value_type_array(do_not_initialize_tag("work"), worksize*_nstreams); +#endif + track_alloc(work.span() * sizeof(value_type)); + } + stat.t_extra = timer.seconds(); + + timer.reset(); + { + _ax = ax; // matrix values + constexpr bool copy_to_l_buf(true); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); + } + stat.t_copy = timer.seconds(); + + stat_level.n_kernel_launching = 0; + timer.reset(); + { + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; + // const ordinal_type team_size_factor[2] = { 64, 16 }, vector_size_factor[2] = { 8, 8}; + // const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; +#if defined(CUDA_VERSION) +#if (11000 > CUDA_VERSION) + /// cuda 11.1 below + const ordinal_type team_size_factor[2] = {32, 64}, vector_size_factor[2] = {8, 4}; +#else + /// cuda 11.1 and higher + const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; +#endif +#else + /// not cuda ... whatever.. + const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; +#endif + const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + { + typedef TeamFunctor_FactorizeLU functor_type; +#if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_factorize; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; +#else + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template FactorizeTag> + team_policy_factor; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _factorize_mode, _level_sids, _piv, _buf); + + team_policy_factor policy_factor(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + { + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_buf_factor_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto buf_factor_ptr = Kokkos::subview(_buf_factor_ptr, range_buf_factor_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(buf_factor_ptr); + if (is_host) { + policy_factor = team_policy_factor(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_factor = team_policy_factor(pcnt, team_size_factor[idx], vector_size_factor[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("factor lower", policy_factor, functor); + } else { + Kokkos::parallel_for("factor", policy_factor, functor); + ++stat_level.n_kernel_launching; + } + + const auto h_buf_factor_ptr = Kokkos::subview(_h_buf_factor_ptr, range_buf_factor_ptr); + + factorizeLU_OnDevice(pbeg, pend, h_buf_factor_ptr, work); + Kokkos::fence(); + + Kokkos::parallel_for("update factor", policy_update, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); + } + const auto exec_instance = exec_space(); + Kokkos::deep_copy(exec_instance, _h_supernodes, _info.supernodes); + } } - case 2: { /// LDL + } // end of LU + stat.t_factor = timer.seconds(); + + timer.reset(); + { +#if defined(KOKKOS_ENABLE_CUDA) + track_free(work.span() * sizeof(value_type)); +#endif + track_free(_buf.span() * sizeof(value_type)); + _buf = value_type_array(); + } + stat.t_extra += timer.seconds(); + + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (LU Factorize)\n", variant); + printf("================================================\n"); + print_stat_factor(); + } + } + + inline void solveLU(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose) { + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "x, b, t, and w dimensions do not match"); + + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data(), std::logic_error, + "x, b, t, and w have the same data pointer"); + constexpr bool is_host = std::is_same::value; + + // solve LU x = b + const ordinal_type nrhs = x.extent(1); + Kokkos::Timer timer; + + stat_level.n_kernel_launching = 0; + + // one-time operation when nrhs is changed + timer.reset(); + allocateWorkspaceSolve(nrhs); + + // 0. permute and copy b -> t + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); + + timer.reset(); + { +#if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) + const auto work_item_property = Kokkos::Experimental::WorkItemProperty::HintLightWeight; +#endif + // this should be considered with average problem sizes in levels + const ordinal_type half_level = _nlevel / 2; +#if defined(CUDA_VERSION) +#if (11000 > CUDA_VERSION) + /// cuda 11.1 below + const ordinal_type team_size_solve[2] = {32, 16}, vector_size_solve[2] = {8, 8}; +#else + /// cuda 11.1 and higher + const ordinal_type team_size_solve[2] = {32, 16}, vector_size_solve[2] = {8, 8}; +#endif +#else + /// not cuda whatever... + const ordinal_type team_size_solve[2] = {64, 16}, vector_size_solve[2] = {8, 8}; +#endif + const ordinal_type team_size_update[2] = {128, 32}, vector_size_update[2] = {1, 1}; + { + typedef TeamFunctor_SolveLowerLU functor_type; +#if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; +#else + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, _piv, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 1. L w = t { - const ordinal_type rlen = 4*_m, plen = _piv.span(); - if (plen < rlen) { - track_free(_piv.span()*sizeof(ordinal_type)); - _piv = ordinal_type_array("piv", rlen); - track_alloc(_piv.span()*sizeof(ordinal_type)); + for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } +#if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); +#else + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; +#endif + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve lower", policy_solve, functor); + } else { + Kokkos::parallel_for("solve lower", policy_solve_with_work_property, functor); + ++stat_level.n_kernel_launching; + } + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveLU_LowerOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); + + Kokkos::parallel_for("update lower", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); } } + } // end of lower tri solve + + { + typedef TeamFunctor_SolveUpperLU functor_type; +#if defined(TACHO_TEST_SOLVE_CHOLESKY_KERNEL_OVERHEAD) + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> + team_policy_update; +#else + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template SolveTag> + team_policy_solve; + typedef Kokkos::TeamPolicy, exec_space, + typename functor_type::template UpdateTag> + team_policy_update; +#endif + functor_type functor(_info, _solve_mode, _level_sids, t, _buf); + + team_policy_solve policy_solve(1, 1, 1); + team_policy_update policy_update(1, 1, 1); + + // 2. U t = w; { - const ordinal_type rlen = 2*_m, dlen = _diag.span(); - if (dlen < rlen) { - track_free(_diag.span()*sizeof(value_type)); - _diag = value_type_array("diag", rlen); - track_alloc(_diag.span()*sizeof(value_type)); + for (ordinal_type lvl = 0; lvl < _team_serial_level_cut; ++lvl) { + const ordinal_type pbeg = _h_level_ptr(lvl), pend = _h_level_ptr(lvl + 1), pcnt = pend - pbeg; + + const range_type range_solve_buf_ptr(_h_buf_level_ptr(lvl), _h_buf_level_ptr(lvl + 1)); + const auto solve_buf_ptr = Kokkos::subview(_buf_solve_nrhs_ptr, range_solve_buf_ptr); + functor.setRange(pbeg, pend); + functor.setBufferPtr(solve_buf_ptr); + if (is_host) { + policy_solve = team_policy_solve(pcnt, 1, 1); + policy_update = team_policy_update(pcnt, 1, 1); + } else { + const ordinal_type idx = lvl > half_level; + policy_solve = team_policy_solve(pcnt, team_size_solve[idx], vector_size_solve[idx]); + policy_update = team_policy_update(pcnt, team_size_update[idx], vector_size_update[idx]); + } +#if defined(TACHO_ENABLE_SOLVE_CHOLESKY_USE_LIGHT_KERNEL) + const auto policy_solve_with_work_property = + Kokkos::Experimental::require(policy_solve, work_item_property); + const auto policy_update_with_work_property = + Kokkos::Experimental::require(policy_update, work_item_property); +#else + const auto policy_solve_with_work_property = policy_solve; + const auto policy_update_with_work_property = policy_update; +#endif + Kokkos::parallel_for("update upper", policy_update_with_work_property, functor); + ++stat_level.n_kernel_launching; + exec_space().fence(); + + if (lvl < _device_level_cut) { + // do nothing + // Kokkos::parallel_for("solve upper", policy_solve, functor); + } else { + Kokkos::parallel_for("solve upper", policy_solve_with_work_property, functor); + ++stat_level.n_kernel_launching; + } + + const auto h_buf_solve_ptr = Kokkos::subview(_h_buf_solve_nrhs_ptr, range_solve_buf_ptr); + solveLU_UpperOnDevice(pbeg, pend, h_buf_solve_ptr, t); + Kokkos::fence(); } } - factorizeLDL(ax, verbose); - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(false, - std::logic_error, - "The solution method is not supported"); - break; - } - } - } + } /// end of upper tri solve - inline - void - solve(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose = 0) override { - switch (this->getSolutionMethod()) { - case 1: { /// Cholesky - solveCholesky(x, b, t, verbose); - break; - } - case 2: { /// LDL - solveLDL(x, b, t, verbose); - break; + } // end of solve + stat.t_solve = timer.seconds(); + + // permute and copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); + + if (verbose) { + printf("Summary: LevelSetTools-Variant-%d (LU Solve: %3d)\n", variant, nrhs); + printf("=================================================\n"); + print_stat_solve(); + } + } + + inline void factorize(const value_type_array &ax, const ordinal_type verbose = 0) override { + Kokkos::deep_copy(_superpanel_buf, value_type(0)); + switch (this->getSolutionMethod()) { + case 1: { /// Cholesky + factorizeCholesky(ax, verbose); + break; + } + case 2: { /// LDL + { + const ordinal_type rlen = 4 * _m, plen = _piv.span(); + if (plen < rlen) { + track_free(_piv.span() * sizeof(ordinal_type)); + _piv = ordinal_type_array("piv", rlen); + track_alloc(_piv.span() * sizeof(ordinal_type)); + } } - default: { - TACHO_TEST_FOR_EXCEPTION(false, - std::logic_error, - "The solution method is not supported"); - break; + { + const ordinal_type rlen = 2 * _m, dlen = _diag.span(); + if (dlen < rlen) { + track_free(_diag.span() * sizeof(value_type)); + _diag = value_type_array("diag", rlen); + track_alloc(_diag.span() * sizeof(value_type)); + } } + factorizeLDL(ax, verbose); + break; + } + case 3: { /// LU + { + const ordinal_type rlen = 4 * _m, plen = _piv.span(); + if (plen < rlen) { + track_free(_piv.span() * sizeof(ordinal_type)); + _piv = ordinal_type_array("piv", rlen); + track_alloc(_piv.span() * sizeof(ordinal_type)); + } } + factorizeLU(ax, verbose); + break; } - - }; + default: { + TACHO_TEST_FOR_EXCEPTION(false, std::logic_error, "The solution method is not supported"); + break; + } + } + } + + inline void solve(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose = 0) override { + switch (this->getSolutionMethod()) { + case 1: { /// Cholesky + solveCholesky(x, b, t, verbose); + break; + } + case 2: { /// LDL + solveLDL(x, b, t, verbose); + break; + } + case 3: { /// LU + solveLU(x, b, t, verbose); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(false, std::logic_error, "The solution method is not supported"); + break; + } + } + } +}; -} +} // namespace Tacho #endif - diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Serial.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Serial.hpp index 1cc43da13be9..92c81907c088 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Serial.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_Serial.hpp @@ -16,444 +16,543 @@ #include "Tacho_LDL_Supernodes.hpp" #include "Tacho_LDL_Supernodes_Serial.hpp" +#include "Tacho_LU_Supernodes.hpp" +#include "Tacho_LU_Supernodes_Serial.hpp" + namespace Tacho { - template - class NumericToolsSerial : public NumericToolsBase { - public: - /// - /// types - /// - using base_type = NumericToolsBase; - using typename base_type::value_type; - using typename base_type::device_type; - using typename base_type::exec_space; - using typename base_type::exec_memory_space; - using typename base_type::host_memory_space; - using typename base_type::ordinal_type_array; - using typename base_type::size_type_array; - using typename base_type::value_type_array; - using typename base_type::value_type_matrix; - using typename base_type::ordinal_type_array_host; - - using base_type::base_type; - - private: - using base_type::_m; - using base_type::_ap; - using base_type::_aj; - using base_type::_ax; - using base_type::_perm; - using base_type::_peri; - using base_type::_nsupernodes; - using base_type::_supernodes; - using base_type::_stree_roots; - using base_type::_superpanel_buf; - using base_type::_piv; - using base_type::_diag; - using base_type::_info; - - using base_type::stat; - using base_type::track_alloc; - using base_type::track_free; - using base_type::reset_stat; - //using base_type::print_stat_factor; - using base_type::print_stat_solve; - using base_type::print_stat_memory; - - public: - - inline - void - print_stat_factor() override { - base_type::print_stat_factor(); - double flop = 0; - auto h_supernodes = Kokkos::create_mirror_view_and_copy(host_memory_space(), _supernodes); - switch (this->getSolutionMethod()) { - case 1: { - for (ordinal_type sid=0;sid<_nsupernodes;++sid) { - auto &s = h_supernodes(sid); - const ordinal_type m = s.m, n = s.n - s.m; - flop += DenseFlopCount::Chol(m); - flop += DenseFlopCount::Trsm(true, m, n); - flop += DenseFlopCount::Syrk(m, n); - } - break; - } - case 2: { - for (ordinal_type sid=0;sid<_nsupernodes;++sid) { - auto &s = h_supernodes(sid); - const ordinal_type m = s.m, n = s.n - s.m; - flop += DenseFlopCount::LDL(m); - flop += DenseFlopCount::Trsm(true, m, n); - flop += DenseFlopCount::Syrk(m, n); - } - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(false, - std::logic_error, - "The solution method is not supported"); +template +class NumericToolsSerial : public NumericToolsBase { +public: + /// + /// types + /// + using base_type = NumericToolsBase; + using typename base_type::device_type; + using typename base_type::exec_memory_space; + using typename base_type::exec_space; + using typename base_type::host_memory_space; + using typename base_type::ordinal_type_array; + using typename base_type::ordinal_type_array_host; + using typename base_type::size_type_array; + using typename base_type::value_type; + using typename base_type::value_type_array; + using typename base_type::value_type_matrix; + + using base_type::base_type; + +private: + using base_type::_aj; + using base_type::_ap; + using base_type::_ax; + using base_type::_diag; + using base_type::_info; + using base_type::_m; + using base_type::_nsupernodes; + using base_type::_peri; + using base_type::_perm; + using base_type::_piv; + using base_type::_stree_roots; + using base_type::_supernodes; + using base_type::_superpanel_buf; + + using base_type::reset_stat; + using base_type::stat; + using base_type::track_alloc; + using base_type::track_free; + // using base_type::print_stat_factor; + using base_type::print_stat_memory; + using base_type::print_stat_solve; + +public: + inline void print_stat_factor() override { + base_type::print_stat_factor(); + double flop = 0; + auto h_supernodes = Kokkos::create_mirror_view_and_copy(host_memory_space(), _supernodes); + switch (this->getSolutionMethod()) { + case 1: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::Chol(m); + flop += DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Syrk(m, n); } + break; + } + case 2: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::LDL(m); + flop += DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Syrk(m, n); } - const double kilo(1024); - printf(" FLOPs\n"); - printf(" gflop for numeric factorization: %10.3f GFLOP\n", flop/kilo/kilo/kilo); - printf(" gflop/s for numeric factorization: %10.3f GFLOP/s\n", flop/stat.t_factor/kilo/kilo/kilo); - printf("\n"); - } - - /// - /// Choleksy - /// - inline - void - factorizeCholesky(const value_type_array &ax, - const ordinal_type verbose) { - Kokkos::Timer timer; - { - timer.reset(); - { - /// matrix values - _ax = ax; - - /// copy the input matrix into super panels - _info.copySparseToSuperpanels(_ap, _aj, _ax, _perm, _peri); - } - stat.t_copy = timer.seconds(); + break; + } + case 3: { + for (ordinal_type sid = 0; sid < _nsupernodes; ++sid) { + auto &s = h_supernodes(sid); + const ordinal_type m = s.m, n = s.n - s.m; + flop += DenseFlopCount::LU(m, m); + flop += 2 * DenseFlopCount::Trsm(true, m, n); + flop += DenseFlopCount::Gemm(n, n, m); } - + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(false, std::logic_error, "The solution method is not supported"); + } + } + const double kilo(1024); + printf(" FLOPs\n"); + printf(" gflop for numeric factorization: %10.3f GFLOP\n", flop / kilo / kilo / kilo); + printf(" gflop/s for numeric factorization: %10.3f GFLOP/s\n", + flop / stat.t_factor / kilo / kilo / kilo); + printf("\n"); + } + + /// + /// Choleksy + /// + inline void factorizeCholesky(const value_type_array &ax, const ordinal_type verbose) { + Kokkos::Timer timer; + { timer.reset(); { - /// valgrind reports the following buf array as uninitialized even if it is initialized - /// while the task is executed. to remove the valgrind error, we initialize the array with zero. - /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(_info.max_schur_size + 1)); - value_type_array buf("buf", _info.max_schur_size*(_info.max_schur_size + 1)); - const size_t bufsize = buf.span()*sizeof(value_type); - track_alloc(bufsize); - - /// recursive tree traversal - const ordinal_type member = 0, nroots = _stree_roots.extent(0); - for (ordinal_type i=0;i - ::factorize_recursive_serial(member, _info, _stree_roots(i), true, buf.data(), bufsize); - - track_free(bufsize); - } - stat.t_factor = timer.seconds(); - - if (verbose) { - printf("Summary: NumericTools, Cholesky (SerialFactorization)\n"); - printf("=====================================================\n"); - - print_stat_factor(); + /// matrix values + _ax = ax; + + /// copy the input matrix into super panels + const bool copy_to_l_buf(false); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); } + stat.t_copy = timer.seconds(); } - inline - void - factorizeCholesky(const value_type_array &ax, - const ordinal_type panelsize, - const ordinal_type verbose) { - Kokkos::Timer timer; - { - timer.reset(); - { - /// matrix values - _ax = ax; - - /// copy the input matrix into super panels - _info.copySparseToSuperpanels(_ap, _aj, _ax, _perm, _peri); - } - stat.t_copy = timer.seconds(); - } + timer.reset(); + { + /// valgrind reports the following buf array as uninitialized even if it is initialized + /// while the task is executed. to remove the valgrind error, we initialize the array with zero. + /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(_info.max_schur_size + 1)); + value_type_array buf("buf", _info.max_schur_size * (_info.max_schur_size + 1)); + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + CholSupernodes::factorize_recursive_serial(member, _info, _stree_roots(i), true, + buf.data(), bufsize); + + track_free(bufsize); + } + stat.t_factor = timer.seconds(); + + if (verbose) { + printf("Summary: NumericTools, Cholesky (SerialFactorization)\n"); + printf("=====================================================\n"); + + print_stat_factor(); + } + } - const ordinal_type nb = panelsize > 0 ? panelsize : _info.max_schur_size; + inline void factorizeCholesky(const value_type_array &ax, const ordinal_type panelsize, const ordinal_type verbose) { + Kokkos::Timer timer; + { timer.reset(); { - /// valgrind reports the following buf array as uninitialized even if it is initialized - /// while the task is executed. to remove the valgrind error, we initialize the array with zero. - /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(nb + 1)); - value_type_array buf("buf", _info.max_schur_size*(nb + 1)); - const size_t bufsize = buf.span()*sizeof(value_type); - track_alloc(bufsize); - - /// recursive tree traversal - const ordinal_type member = 0, nroots = _stree_roots.extent(0); - for (ordinal_type i=0;i - ::factorize_recursive_serial(member, - _info, _stree_roots(i), - true, buf.data(), bufsize, nb); - - track_free(bufsize); - } - stat.t_factor = timer.seconds(); - - if (verbose) { - printf("Summary: NumericTools, Cholesky (SerialPanelFactorization: %3d)\n", nb); - printf("===============================================================\n"); + /// matrix values + _ax = ax; - print_stat_factor(); + /// copy the input matrix into super panels + const bool copy_to_l_buf(false); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); } + stat.t_copy = timer.seconds(); + } + + const ordinal_type nb = panelsize > 0 ? panelsize : _info.max_schur_size; + timer.reset(); + { + /// valgrind reports the following buf array as uninitialized even if it is initialized + /// while the task is executed. to remove the valgrind error, we initialize the array with zero. + /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(nb + 1)); + value_type_array buf("buf", _info.max_schur_size * (nb + 1)); + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + CholSupernodes::factorize_recursive_serial(member, _info, _stree_roots(i), true, + buf.data(), bufsize, nb); + + track_free(bufsize); } + stat.t_factor = timer.seconds(); - inline - void - solveCholesky(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose) { - Kokkos::Timer timer; + if (verbose) { + printf("Summary: NumericTools, Cholesky (SerialPanelFactorization: %3d)\n", nb); + printf("===============================================================\n"); - _info.x = t; + print_stat_factor(); + } + } + + inline void solveCholesky(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose) { + Kokkos::Timer timer; + + _info.x = t; + + // copy b -> t + timer.reset(); + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); + + timer.reset(); + { + value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size * x.extent(1)); + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + CholSupernodes::solve_lower_recursive_serial(member, _info, _stree_roots(i), true, + buf.data(), bufsize); + for (ordinal_type i = 0; i < nroots; ++i) + CholSupernodes::solve_upper_recursive_serial(member, _info, _stree_roots(i), true, + buf.data(), bufsize); + + track_free(bufsize); + } + stat.t_solve = timer.seconds(); - // copy b -> t - timer.reset(); - applyRowPermutationToDenseMatrix(t, b, _perm); - stat.t_extra = timer.seconds(); - + // copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); + + if (verbose) { + printf("Summary: NumericTools, Cholesky (SerialSolve: %3d)\n", ordinal_type(x.extent(1))); + printf("==================================================\n"); + + print_stat_solve(); + } + } + + /// + /// LDL + /// + inline void factorizeLDL(const value_type_array &ax, const ordinal_type verbose) { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + } + + Kokkos::Timer timer; + { timer.reset(); { - value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*x.extent(1)); - const size_t bufsize = buf.span()*sizeof(value_type); - track_alloc(bufsize); - - /// recursive tree traversal - const ordinal_type member = 0, nroots = _stree_roots.extent(0); - for (ordinal_type i=0;i - ::solve_lower_recursive_serial(member, _info, _stree_roots(i), true, buf.data(), bufsize); - for (ordinal_type i=0;i - ::solve_upper_recursive_serial(member, _info, _stree_roots(i), true, buf.data(), bufsize); - - track_free(bufsize); + /// matrix values + _ax = ax; + + /// copy the input matrix into super panels + const bool copy_to_l_buf(false); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); } - stat.t_solve = timer.seconds(); - - // copy t -> x - timer.reset(); - applyRowPermutationToDenseMatrix(x, t, _peri); - stat.t_extra += timer.seconds(); + stat.t_copy = timer.seconds(); + } - if (verbose) { - printf("Summary: NumericTools, Cholesky (SerialSolve: %3d)\n", ordinal_type(x.extent(1))); - printf("==================================================\n"); + timer.reset(); + { + /// valgrind reports the following buf array as uninitialized even if it is initialized + /// while the task is executed. to remove the valgrind error, we initialize the array with zero. + /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(_info.max_schur_size + 1)); + value_type_array buf("buf", + _info.max_schur_size * (_info.max_schur_size + 1) + // ABR + _info.max_supernode_size * + std::max(32, _info.max_schur_size)); // ATR copy and workspace for LDL + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + LDL_Supernodes::factorize_recursive_serial( + member, _info, _stree_roots(i), true, _piv.data(), _diag.data(), buf.data(), bufsize); + + track_free(bufsize); + } + stat.t_factor = timer.seconds(); - print_stat_solve(); - } + if (verbose) { + printf("Summary: NumericTools, LDL (SerialFactorization)\n"); + printf("================================================\n"); + + print_stat_factor(); + } + } + + inline void solveLDL(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose) { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "supernode data structure is not allocated"); + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data() || t.data() == b.data(), std::logic_error, + "x, b and t have the same data pointer"); } - inline - void - factorizeLDL(const value_type_array &ax, - const ordinal_type verbose) { - { - const bool test = !std::is_same::value; - TACHO_TEST_FOR_EXCEPTION(test, - std::logic_error, - "Serial interface works on host device only"); - } + Kokkos::Timer timer; - Kokkos::Timer timer; - { - timer.reset(); - { - /// matrix values - _ax = ax; - - /// copy the input matrix into super panels - _info.copySparseToSuperpanels(_ap, _aj, _ax, _perm, _peri); - } - stat.t_copy = timer.seconds(); - } - - timer.reset(); - { - /// valgrind reports the following buf array as uninitialized even if it is initialized - /// while the task is executed. to remove the valgrind error, we initialize the array with zero. - /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(_info.max_schur_size + 1)); - value_type_array - buf("buf", - _info.max_schur_size*(_info.max_schur_size + 1) + // ABR - _info.max_supernode_size*std::max(32,_info.max_schur_size)); // ATR copy and workspace for LDL - const size_t bufsize = buf.span()*sizeof(value_type); - track_alloc(bufsize); - - /// recursive tree traversal - const ordinal_type member = 0, nroots = _stree_roots.extent(0); - for (ordinal_type i=0;i - ::factorize_recursive_serial(member, _info, _stree_roots(i), true, - _piv.data(), _diag.data(), buf.data(), bufsize); - - track_free(bufsize); - } - stat.t_factor = timer.seconds(); - - if (verbose) { - printf("Summary: NumericTools, LDL (SerialFactorization)\n"); - printf("================================================\n"); - - print_stat_factor(); - } + _info.x = t; + + // copy b -> t + timer.reset(); + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); + + timer.reset(); + { + value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size * x.extent(1)); + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + LDL_Supernodes::solve_lower_recursive_serial(member, _info, _stree_roots(i), true, + _piv.data(), buf.data(), bufsize); + for (ordinal_type i = 0; i < nroots; ++i) + LDL_Supernodes::solve_upper_recursive_serial( + member, _info, _stree_roots(i), true, _piv.data(), _diag.data(), buf.data(), bufsize); + + track_free(bufsize); } + stat.t_solve = timer.seconds(); - inline - void - solveLDL(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose) { - { - const bool test = !std::is_same::value; - TACHO_TEST_FOR_EXCEPTION(test, - std::logic_error, - "Serial interface works on host device only"); - TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || - x.extent(1) != b.extent(1) || - x.extent(0) != t.extent(0) || - x.extent(1) != t.extent(1), std::logic_error, - "supernode data structure is not allocated"); - TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || - x.data() == t.data() || - t.data() == b.data(), std::logic_error, - "x, b and t have the same data pointer"); - } + // copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); - Kokkos::Timer timer; + if (verbose) { + printf("Summary: NumericTools, LDL (SerialSolve: %3d)\n", ordinal_type(x.extent(1))); + printf("=============================================\n"); - _info.x = t; + print_stat_solve(); + } + } + + /// + /// LU + /// + inline void factorizeLU(const value_type_array &ax, const ordinal_type verbose) { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + } - // copy b -> t - timer.reset(); - applyRowPermutationToDenseMatrix(t, b, _perm); - stat.t_extra = timer.seconds(); - + Kokkos::Timer timer; + { timer.reset(); { - value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*x.extent(1)); - const size_t bufsize = buf.span()*sizeof(value_type); - track_alloc(bufsize); - - /// recursive tree traversal - const ordinal_type member = 0, nroots = _stree_roots.extent(0); - for (ordinal_type i=0;i - ::solve_lower_recursive_serial(member, _info, _stree_roots(i), true, _piv.data(), buf.data(), bufsize); - for (ordinal_type i=0;i - ::solve_upper_recursive_serial(member, _info, _stree_roots(i), true, _piv.data(), _diag.data(), buf.data(), bufsize); - - track_free(bufsize); + /// matrix values + _ax = ax; + + /// copy the input matrix into super panels + const bool copy_to_l_buf(true); + _info.copySparseToSuperpanels(copy_to_l_buf, _ap, _aj, _ax, _perm, _peri); } - stat.t_solve = timer.seconds(); - - // copy t -> x - timer.reset(); - applyRowPermutationToDenseMatrix(x, t, _peri); - stat.t_extra += timer.seconds(); + stat.t_copy = timer.seconds(); + } - if (verbose) { - printf("Summary: NumericTools, LDL (SerialSolve: %3d)\n", ordinal_type(x.extent(1))); - printf("=============================================\n"); + timer.reset(); + { + /// valgrind reports the following buf array as uninitialized even if it is initialized + /// while the task is executed. to remove the valgrind error, we initialize the array with zero. + /// value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size*(_info.max_schur_size + 1)); + value_type_array buf("buf", + _info.max_schur_size * (_info.max_schur_size + 1)); // ABR + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + LU_Supernodes::factorize_recursive_serial(member, _info, _stree_roots(i), true, + _piv.data(), buf.data(), bufsize); + + track_free(bufsize); + } + stat.t_factor = timer.seconds(); - print_stat_solve(); - } + if (verbose) { + printf("Summary: NumericTools, LU (SerialFactorization)\n"); + printf("===============================================\n"); + + print_stat_factor(); + } + } + + inline void solveLU(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose) { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "supernode data structure is not allocated"); + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data() || t.data() == b.data(), std::logic_error, + "x, b and t have the same data pointer"); + } + + Kokkos::Timer timer; + + _info.x = t; + + // copy b -> t + timer.reset(); + const auto exec_instance = exec_space(); + ApplyPermutation::invoke(exec_instance, b, _perm, t); + stat.t_extra = timer.seconds(); + + timer.reset(); + { + value_type_array buf(do_not_initialize_tag("buf"), _info.max_schur_size * x.extent(1)); + const size_t bufsize = buf.span() * sizeof(value_type); + track_alloc(bufsize); + + /// recursive tree traversal + const ordinal_type member = 0, nroots = _stree_roots.extent(0); + for (ordinal_type i = 0; i < nroots; ++i) + LU_Supernodes::solve_lower_recursive_serial(member, _info, _stree_roots(i), true, + _piv.data(), buf.data(), bufsize); + for (ordinal_type i = 0; i < nroots; ++i) + LU_Supernodes::solve_upper_recursive_serial(member, _info, _stree_roots(i), true, + _piv.data(), buf.data(), bufsize); + + track_free(bufsize); } + stat.t_solve = timer.seconds(); - /// - /// Choleksy main interface - /// - inline - void - factorize(const value_type_array &ax, - const ordinal_type verbose = 0) override { + // copy t -> x + timer.reset(); + ApplyPermutation::invoke(exec_instance, t, _peri, x); + stat.t_extra += timer.seconds(); + + if (verbose) { + printf("Summary: NumericTools, LU (SerialSolve: %3d)\n", ordinal_type(x.extent(1))); + printf("============================================\n"); + + print_stat_solve(); + } + } + + /// + /// main interface + /// + inline void factorize(const value_type_array &ax, const ordinal_type verbose = 0) override { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + } + /// reset the supernode buffer for potential reuse cases + Kokkos::deep_copy(_superpanel_buf, value_type(0)); + switch (this->getSolutionMethod()) { + case 1: { /// Cholesky + // if (_nb > 0) { + // factorizeCholesky(ax, _nb, verbose); + // } else { + // factorizeCholesky(ax, verbose); + // } + factorizeCholesky(ax, verbose); + break; + } + case 2: { /// LDL { - const bool test = !std::is_same::value; - TACHO_TEST_FOR_EXCEPTION(test, - std::logic_error, - "Serial interface works on host device only"); - } - /// reset the supernode buffer for potential reuse cases - Kokkos::deep_copy(_superpanel_buf, value_type(0)); - switch (this->getSolutionMethod()) { - case 1: { /// Cholesky - // if (_nb > 0) { - // factorizeCholesky(ax, _nb, verbose); - // } else { - // factorizeCholesky(ax, verbose); - // } - factorizeCholesky(ax, verbose); - break; - } - case 2: { /// LDL - { - const ordinal_type rlen = 4*_m, plen = _piv.span(); - if (plen < rlen) { - track_free(this->_piv.span()*sizeof(ordinal_type)); - this->_piv = ordinal_type_array("piv", rlen); - track_alloc(this->_piv.span()*sizeof(ordinal_type)); - } + const ordinal_type rlen = 4 * _m, plen = _piv.span(); + if (plen < rlen) { + track_free(this->_piv.span() * sizeof(ordinal_type)); + this->_piv = ordinal_type_array("piv", rlen); + track_alloc(this->_piv.span() * sizeof(ordinal_type)); } - { - const ordinal_type rlen = 2*_m, dlen = _diag.span(); - if (dlen < rlen) { - track_free(this->_diag.span()*sizeof(value_type)); - this->_diag = value_type_array("diag", rlen); - track_alloc(this->_diag.span()*sizeof(value_type)); - } - } - factorizeLDL(ax, verbose); - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(false, - std::logic_error, - "The solution method is not supported"); - break; } + { + const ordinal_type rlen = 2 * _m, dlen = _diag.span(); + if (dlen < rlen) { + track_free(this->_diag.span() * sizeof(value_type)); + this->_diag = value_type_array("diag", rlen); + track_alloc(this->_diag.span() * sizeof(value_type)); + } } + factorizeLDL(ax, verbose); + break; } - - inline - void - solve(const value_type_matrix &x, // solution - const value_type_matrix &b, // right hand side - const value_type_matrix &t, // temporary workspace (store permuted vectors) - const ordinal_type verbose = 0) override { + case 3: { /// LU { - const bool test = !std::is_same::value; - TACHO_TEST_FOR_EXCEPTION(test, - std::logic_error, - "Serial interface works on host device only"); - TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || - x.extent(1) != b.extent(1) || - x.extent(0) != t.extent(0) || - x.extent(1) != t.extent(1), std::logic_error, - "Input x, b and t dimensions are not compatible"); - TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || - x.data() == t.data() || - t.data() == b.data(), std::logic_error, - "Input x, b and t have the same data pointer"); - } - - switch (this->getSolutionMethod()) { - case 1: { - solveCholesky(x, b, t, verbose); - break; - } - case 2: { - solveLDL(x, b, t, verbose); - break; - } - default: { - } + const ordinal_type rlen = 4 * _m, plen = _piv.span(); + if (plen < rlen) { + track_free(this->_piv.span() * sizeof(ordinal_type)); + this->_piv = ordinal_type_array("piv", rlen); + track_alloc(this->_piv.span() * sizeof(ordinal_type)); + } } + factorizeLU(ax, verbose); + break; + } + default: { + TACHO_TEST_FOR_EXCEPTION(false, std::logic_error, "The solution method is not supported"); + break; + } + } + } + + inline void solve(const value_type_matrix &x, // solution + const value_type_matrix &b, // right hand side + const value_type_matrix &t, // temporary workspace (store permuted vectors) + const ordinal_type verbose = 0) override { + { + const bool test = !std::is_same::value; + TACHO_TEST_FOR_EXCEPTION(test, std::logic_error, "Serial interface works on host device only"); + TACHO_TEST_FOR_EXCEPTION(x.extent(0) != b.extent(0) || x.extent(1) != b.extent(1) || x.extent(0) != t.extent(0) || + x.extent(1) != t.extent(1), + std::logic_error, "Input x, b and t dimensions are not compatible"); + TACHO_TEST_FOR_EXCEPTION(x.data() == b.data() || x.data() == t.data() || t.data() == b.data(), std::logic_error, + "Input x, b and t have the same data pointer"); } - }; + switch (this->getSolutionMethod()) { + case 1: { + solveCholesky(x, b, t, verbose); + break; + } + case 2: { + solveLDL(x, b, t, verbose); + break; + } + case 3: { + solveLU(x, b, t, verbose); + break; + } + default: { + } + } + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Partition.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Partition.hpp index 06e2801cac84..204c4abe8fb4 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Partition.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Partition.hpp @@ -7,374 +7,265 @@ #include "Tacho_Util.hpp" -namespace Tacho { - - template - KOKKOS_INLINE_FUNCTION - void - Part_2x2(const MatView A, MatView &ATL, MatView &ATR, - /**************/ MatView &ABL, MatView &ABR, - const ordinal_type bm, - const ordinal_type bn, - const int quadrant) { - ordinal_type bmm, bnn; - - switch (quadrant) { - case Partition::TopLeft: - bmm = min(bm, A.extent(0)); - bnn = min(bn, A.extent(1)); - - ATL.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - case Partition::TopRight: - case Partition::BottomLeft: - TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); - break; - case Partition::BottomRight: - bmm = A.extent(0) - min(bm, A.extent(0)); - bnn = A.extent(1) - min(bn, A.extent(1)); - - ATL.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - - ATR.set_view(A, - A.offset_0(), ATL.extent(0), - A.offset_1() + ATL.extent(1), A.extent(1) - ATL.extent(1)); - - ABL.set_view(A, - A.offset_0() + ATL.extent(0), A.extent(0) - ATL.extent(0), - A.offset_1(), ATL.extent(1)); - - ABR.set_view(A, - A.offset_0() + ATL.extent(0), A.extent(0) - ATL.extent(0), - A.offset_1() + ATL.extent(1), A.extent(1) - ATL.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Part_1x2(const MatView A, MatView &AL, MatView &AR, - const ordinal_type bn, - const int side) { - ordinal_type bmm, bnn; - - switch (side) { - case Partition::Left: - bmm = A.extent(0); - bnn = min(bn, A.extent(1)); - - AL.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - case Partition::Right: - bmm = A.extent(0); - bnn = A.extent(1) - min(bn, A.extent(1)); - - AL.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - - AR.set_view(A, - A.offset_0(), A.extent(0), - A.offset_1() + AL.extent(1), A.extent(1) - AL.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Part_2x1(const MatView A, MatView &AT, - /*************/ MatView &AB, - const ordinal_type bm, - const int side) { - ordinal_type bmm, bnn; - - switch (side) { - case Partition::Top: - bmm = min(bm, A.extent(0)); - bnn = A.extent(1); - - AT.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - case Partition::Bottom: - bmm = A.extent(0) - min(bm, A.extent(0)); - bnn = A.extent(1); - - AT.set_view(A, - A.offset_0(), bmm, - A.offset_1(), bnn); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - - AB.set_view(A, - A.offset_0() + AT.extent(0), A.extent(0) - AT.extent(0), - A.offset_1(), A.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Part_2x2_to_3x3(const MatView ATL, const MatView ATR, MatView &A00, MatView &A01, MatView &A02, - /***********************************/ MatView &A10, MatView &A11, MatView &A12, - const MatView ABL, const MatView ABR, MatView &A20, MatView &A21, MatView &A22, - const ordinal_type bm, - const ordinal_type bn, - const int quadrant) { - switch (quadrant) { - case Partition::TopLeft: - Part_2x2(ATL, A00, A01, - /**/ A10, A11, - bm, bn, Partition::BottomRight); - - Part_2x1(ATR, A02, - /**/ A12, - bm, Partition::Bottom); - - Part_1x2(ABL, A20, A21, - bn, Partition::Right); - - A22.set_view(ABR, - ABR.offset_0(), ABR.extent(0), - ABR.offset_1(), ABR.extent(1)); - break; - case Partition::TopRight: - case Partition::BottomLeft: - TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); - break; - case Partition::BottomRight: - A00.set_view(ATL, - ATL.offset_0(), ATL.extent(0), - ATL.offset_1(), ATL.extent(1)); - - Part_1x2(ATR, A01, A02, - bn, Partition::Left); - - Part_2x1(ABL, A10, - /**/ A20, - bm, Partition::Top); - - Part_2x2(ABR, A11, A12, - /**/ A21, A22, - bm, bn, Partition::TopLeft); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } - - template - KOKKOS_INLINE_FUNCTION - void - Part_2x1_to_3x1(const MatView AT, MatView &A0, - /***************/ MatView &A1, - const MatView AB, MatView &A2, - const ordinal_type bm, - const int side) { - switch (side) { - case Partition::Top: - Part_2x1(AT, A0, - /**/ A1, - bm, Partition::Bottom); - - A2.set_view(AB, - AB.offset_0(), AB.extent(0), - AB.offset_1(), AB.extent(1)); - break; - case Partition::Bottom: - A0.set_view(AT, - AT.offset_0(), AT.extent(0), - AT.offset_1(), AT.extent(1)); - - Part_2x1(AB, A1, - /**/ A2, - bm, Partition::Top); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } - - template - KOKKOS_INLINE_FUNCTION - void - Part_1x2_to_1x3(const MatView AL, const MatView AR, - MatView &A0, MatView &A1, MatView &A2, - const ordinal_type bn, - const int side) { - switch (side) { - case Partition::Left: - Part_1x2(AL, A0, A1, - bn, Partition::Right); - - A2.set_view(AR.BaseObaject(), - AR.offset_0(), AR.extent(0), - AR.offset_1(), AR.extent(1)); - break; - case Partition::Right: - A0.set_view(AL, - AL.offset_0(), AL.extent(0), - AL.offset_1(), AL.extent(1)); - - Part_1x2(AR, A1, A2, - bn, Partition::Left); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_2x2(const MatView ATL, const MatView ATR, - const MatView ABL, const MatView ABR, MatView &A) { - A.set_view(ATL, - ATL.offset_0(), ATL.extent(0) + ABR.extent(0), - ATL.offset_1(), ATL.extent(1) + ABR.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_1x2(const MatView AL, const MatView AR, MatView &A) { - A.set_view(AL, - AL.offset_0(), AL.extent(0), - AL.offset_1(), AL.extent(1) + AR.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_2x1(const MatView AT, - const MatView AB, MatView &A) { - A.set_view(AT, - AT.offset_0(), AT.extent(0) + AB.extent(0), - AT.offset_1(), AT.extent(1)); - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_3x3_to_2x2(const MatView A00, const MatView A01, const MatView A02, MatView &ATL, MatView &ATR, - const MatView A10, const MatView A11, const MatView A12, - const MatView A20, const MatView A21, const MatView A22, MatView &ABL, MatView &ABR, - const int quadrant) { - switch (quadrant) { - case Partition::TopLeft: - Merge_2x2(A00, A01, - A10, A11, ATL); - - Merge_2x1(A02, - A12, ATR); - - Merge_1x2(A20, A21, ABL); - - ABR.set_view(A22, - A22.offset_0(), A22.extent(0), - A22.offset_1(), A22.extent(1)); - break; - case Partition::TopRight: - case Partition::BottomLeft: - TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); - break; - case Partition::BottomRight: - ATL.set_view(A00, - A00.offset_0(), A00.extent(0), - A00.offset_1(), A00.extent(1)); - - Merge_1x2(A01, A02, ATR); - - Merge_2x1(A10, - A20, ABL); - - Merge_2x2(A11, A12, - A21, A22, ABR); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_3x1_to_2x1(const MatView A0, MatView &AT, - const MatView A1, - const MatView A2, MatView &AB, - const int side) { - switch (side) { - case Partition::Top: - Merge_2x1(A0, - A1, AT); - - AB.set_view(A2, - A2.offset_0(), A2.extent(0), - A2.offset_1(), A2.extent(1)); - break; - case Partition::Bottom: - AT.set_view(A0, - A0.offset_0(), A0.extent(0), - A0.offset_1(), A0.extent(1)); - - Merge_2x1(A1, - A2, AB); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } - - template - KOKKOS_INLINE_FUNCTION - void - Merge_1x3_to_1x2(const MatView A0, const MatView A1, const MatView A2, - MatView &AL, MatView &AR, - const int side) { - switch (side) { - case Partition::Left: - Merge_1x2(A0, A1, AL); - - AR.set_view(A2, - A2.offset_0(), A2.extent(0), - A2.offset_1(), A2.extent(1)); - break; - case Partition::Right: - AL.set_view(A0, - A0.offset_0(), A0.extent(0), - A0.offset_1(), A0.extent(1)); - - Merge_1x2(A1, A2, AR); - break; - default: - TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); - break; - } - } +namespace Tacho { + +template +KOKKOS_INLINE_FUNCTION void Part_2x2(const MatView A, MatView &ATL, MatView &ATR, + /**************/ MatView &ABL, MatView &ABR, const ordinal_type bm, + const ordinal_type bn, const int quadrant) { + ordinal_type bmm, bnn; + + switch (quadrant) { + case Partition::TopLeft: + bmm = min(bm, A.extent(0)); + bnn = min(bn, A.extent(1)); + + ATL.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + case Partition::TopRight: + case Partition::BottomLeft: + TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); + break; + case Partition::BottomRight: + bmm = A.extent(0) - min(bm, A.extent(0)); + bnn = A.extent(1) - min(bn, A.extent(1)); + + ATL.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } + + ATR.set_view(A, A.offset_0(), ATL.extent(0), A.offset_1() + ATL.extent(1), A.extent(1) - ATL.extent(1)); + + ABL.set_view(A, A.offset_0() + ATL.extent(0), A.extent(0) - ATL.extent(0), A.offset_1(), ATL.extent(1)); + + ABR.set_view(A, A.offset_0() + ATL.extent(0), A.extent(0) - ATL.extent(0), A.offset_1() + ATL.extent(1), + A.extent(1) - ATL.extent(1)); +} + +template +KOKKOS_INLINE_FUNCTION void Part_1x2(const MatView A, MatView &AL, MatView &AR, const ordinal_type bn, const int side) { + ordinal_type bmm, bnn; + + switch (side) { + case Partition::Left: + bmm = A.extent(0); + bnn = min(bn, A.extent(1)); + + AL.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + case Partition::Right: + bmm = A.extent(0); + bnn = A.extent(1) - min(bn, A.extent(1)); + + AL.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } + + AR.set_view(A, A.offset_0(), A.extent(0), A.offset_1() + AL.extent(1), A.extent(1) - AL.extent(1)); +} +template +KOKKOS_INLINE_FUNCTION void Part_2x1(const MatView A, MatView &AT, + /*************/ MatView &AB, const ordinal_type bm, const int side) { + ordinal_type bmm, bnn; + + switch (side) { + case Partition::Top: + bmm = min(bm, A.extent(0)); + bnn = A.extent(1); + + AT.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + case Partition::Bottom: + bmm = A.extent(0) - min(bm, A.extent(0)); + bnn = A.extent(1); + + AT.set_view(A, A.offset_0(), bmm, A.offset_1(), bnn); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } + + AB.set_view(A, A.offset_0() + AT.extent(0), A.extent(0) - AT.extent(0), A.offset_1(), A.extent(1)); +} + +template +KOKKOS_INLINE_FUNCTION void +Part_2x2_to_3x3(const MatView ATL, const MatView ATR, MatView &A00, MatView &A01, MatView &A02, + /***********************************/ MatView &A10, MatView &A11, MatView &A12, const MatView ABL, + const MatView ABR, MatView &A20, MatView &A21, MatView &A22, const ordinal_type bm, + const ordinal_type bn, const int quadrant) { + switch (quadrant) { + case Partition::TopLeft: + Part_2x2(ATL, A00, A01, + /**/ A10, A11, bm, bn, Partition::BottomRight); + + Part_2x1(ATR, A02, + /**/ A12, bm, Partition::Bottom); + + Part_1x2(ABL, A20, A21, bn, Partition::Right); + + A22.set_view(ABR, ABR.offset_0(), ABR.extent(0), ABR.offset_1(), ABR.extent(1)); + break; + case Partition::TopRight: + case Partition::BottomLeft: + TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); + break; + case Partition::BottomRight: + A00.set_view(ATL, ATL.offset_0(), ATL.extent(0), ATL.offset_1(), ATL.extent(1)); + + Part_1x2(ATR, A01, A02, bn, Partition::Left); + + Part_2x1(ABL, A10, + /**/ A20, bm, Partition::Top); + + Part_2x2(ABR, A11, A12, + /**/ A21, A22, bm, bn, Partition::TopLeft); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } +} +template +KOKKOS_INLINE_FUNCTION void Part_2x1_to_3x1(const MatView AT, MatView &A0, + /***************/ MatView &A1, const MatView AB, MatView &A2, + const ordinal_type bm, const int side) { + switch (side) { + case Partition::Top: + Part_2x1(AT, A0, + /**/ A1, bm, Partition::Bottom); + + A2.set_view(AB, AB.offset_0(), AB.extent(0), AB.offset_1(), AB.extent(1)); + break; + case Partition::Bottom: + A0.set_view(AT, AT.offset_0(), AT.extent(0), AT.offset_1(), AT.extent(1)); + + Part_2x1(AB, A1, + /**/ A2, bm, Partition::Top); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } } +template +KOKKOS_INLINE_FUNCTION void Part_1x2_to_1x3(const MatView AL, const MatView AR, MatView &A0, MatView &A1, MatView &A2, + const ordinal_type bn, const int side) { + switch (side) { + case Partition::Left: + Part_1x2(AL, A0, A1, bn, Partition::Right); + + A2.set_view(AR.BaseObaject(), AR.offset_0(), AR.extent(0), AR.offset_1(), AR.extent(1)); + break; + case Partition::Right: + A0.set_view(AL, AL.offset_0(), AL.extent(0), AL.offset_1(), AL.extent(1)); + + Part_1x2(AR, A1, A2, bn, Partition::Left); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } +} + +template +KOKKOS_INLINE_FUNCTION void Merge_2x2(const MatView ATL, const MatView ATR, const MatView ABL, const MatView ABR, + MatView &A) { + A.set_view(ATL, ATL.offset_0(), ATL.extent(0) + ABR.extent(0), ATL.offset_1(), ATL.extent(1) + ABR.extent(1)); +} + +template KOKKOS_INLINE_FUNCTION void Merge_1x2(const MatView AL, const MatView AR, MatView &A) { + A.set_view(AL, AL.offset_0(), AL.extent(0), AL.offset_1(), AL.extent(1) + AR.extent(1)); +} + +template KOKKOS_INLINE_FUNCTION void Merge_2x1(const MatView AT, const MatView AB, MatView &A) { + A.set_view(AT, AT.offset_0(), AT.extent(0) + AB.extent(0), AT.offset_1(), AT.extent(1)); +} + +template +KOKKOS_INLINE_FUNCTION void Merge_3x3_to_2x2(const MatView A00, const MatView A01, const MatView A02, MatView &ATL, + MatView &ATR, const MatView A10, const MatView A11, const MatView A12, + const MatView A20, const MatView A21, const MatView A22, MatView &ABL, + MatView &ABR, const int quadrant) { + switch (quadrant) { + case Partition::TopLeft: + Merge_2x2(A00, A01, A10, A11, ATL); + + Merge_2x1(A02, A12, ATR); + + Merge_1x2(A20, A21, ABL); + + ABR.set_view(A22, A22.offset_0(), A22.extent(0), A22.offset_1(), A22.extent(1)); + break; + case Partition::TopRight: + case Partition::BottomLeft: + TACHO_TEST_FOR_ABORT(true, MSG_NOT_IMPLEMENTED); + break; + case Partition::BottomRight: + ATL.set_view(A00, A00.offset_0(), A00.extent(0), A00.offset_1(), A00.extent(1)); + + Merge_1x2(A01, A02, ATR); + + Merge_2x1(A10, A20, ABL); + + Merge_2x2(A11, A12, A21, A22, ABR); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } +} + +template +KOKKOS_INLINE_FUNCTION void Merge_3x1_to_2x1(const MatView A0, MatView &AT, const MatView A1, const MatView A2, + MatView &AB, const int side) { + switch (side) { + case Partition::Top: + Merge_2x1(A0, A1, AT); + + AB.set_view(A2, A2.offset_0(), A2.extent(0), A2.offset_1(), A2.extent(1)); + break; + case Partition::Bottom: + AT.set_view(A0, A0.offset_0(), A0.extent(0), A0.offset_1(), A0.extent(1)); + + Merge_2x1(A1, A2, AB); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } +} + +template +KOKKOS_INLINE_FUNCTION void Merge_1x3_to_1x2(const MatView A0, const MatView A1, const MatView A2, MatView &AL, + MatView &AR, const int side) { + switch (side) { + case Partition::Left: + Merge_1x2(A0, A1, AL); + + AR.set_view(A2, A2.offset_0(), A2.extent(0), A2.offset_1(), A2.extent(1)); + break; + case Partition::Right: + AL.set_view(A0, A0.offset_0(), A0.extent(0), A0.offset_1(), A0.extent(1)); + + Merge_1x2(A1, A2, AR); + break; + default: + TACHO_TEST_FOR_ABORT(true, MSG_INVALID_INPUT); + break; + } +} + +} // namespace Tacho + #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals.hpp index 297ebb5b6199..bbc28463df34 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals.hpp @@ -9,13 +9,12 @@ namespace Tacho { - /// - /// Scale2x2_BlockInverseDiagonals - /// +/// +/// Scale2x2_BlockInverseDiagonals +/// - /// various implementation for different uplo and algo parameters - template - struct Scale2x2_BlockInverseDiagonals; -} +/// various implementation for different uplo and algo parameters +template struct Scale2x2_BlockInverseDiagonals; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_Internal.hpp index 3990a88b232f..39269a69ba9a 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_Internal.hpp @@ -1,180 +1,125 @@ #ifndef __TACHO_SCALE_2X2_BLOCK_INVERSE_DIAGONALS_INTERNAL_HPP #define __TACHO_SCALE_2X2_BLOCK_INVERSE_DIAGONALS_INTERNAL_HPP - /// \file Tacho_Scale2x2_BlockInverseDiagonals_Internal.hpp /// \brief Inverse scale /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - /// row exchange - template<> - struct Scale2x2_BlockInverseDiagonals { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(const ViewTypeP &P, - const ViewTypeD &D, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; - - if (A.extent(0) == D.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - if (n == 1) { - for (ordinal_type i=0;i struct Scale2x2_BlockInverseDiagonals { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ViewTypeP &P, const ViewTypeD &D, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + if (A.extent(0) == D.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + if (n == 1) { + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type pval = P(i); + if (pval == 0) { + /// do nothing + } else if (pval < 0) { + /// take 2x2 block to D + const value_type a00 = D(i - 1, 0), a01 = D(i - 1, 1), a10 = D(i, 0), a11 = D(i, 1); + const value_type det = a00 * a11 - a10 * a01; + const value_type x0 = A(i - 1, 0), x1 = A(i, 0); + + A(i - 1, 0) = (a11 * x0 - a10 * x1) / det; + A(i, 0) = (-a10 * x0 + a00 * x1) / det; + } else { + const value_type a00 = D(i, 0); + A(i, 0) /= a00; } - } else { - for (ordinal_type i=0;i A is not square\n"); } - return 0; + } else { + printf("Error: Scale2x2_BlockInverseDiagonals A is not square\n"); } + return 0; + } - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeD &D, - const ViewTypeA &A) { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeD &D, + const ViewTypeA &A) { #if defined(__CUDA_ARCH__) - typedef typename ViewTypeA::non_const_value_type value_type; - if (A.extent(0) == D.extent(0)) { - if (A.span() > 0) { - const ordinal_type m = A.extent(0), n = A.extent(1); - if (n == 1) { - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&](const ordinal_type &i) { - const ordinal_type pval = P(i); - if (pval == 0) { - /// do nothing - } else if (pval < 0) { - /// take 2x2 block to D - const value_type - a00 = D(i-1, 0), a01 = D(i-1, 1), - a10 = D(i , 0), a11 = D(i , 1); - const value_type - det = a00*a11-a10*a01; - const value_type - x0 = A(i-1,0), - x1 = A(i,0); - - A(i-1,0) = ( a11*x0 - a10*x1)/det; - A(i ,0) = (-a10*x0 + a00*x1)/det; - } else { - const value_type - a00 = D(i,0); - A(i,0) /= a00; - } + typedef typename ViewTypeA::non_const_value_type value_type; + if (A.extent(0) == D.extent(0)) { + if (A.span() > 0) { + const ordinal_type m = A.extent(0), n = A.extent(1); + if (n == 1) { + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m), [&](const ordinal_type &i) { + const ordinal_type pval = P(i); + if (pval == 0) { + /// do nothing + } else if (pval < 0) { + /// take 2x2 block to D + const value_type a00 = D(i - 1, 0), a01 = D(i - 1, 1), a10 = D(i, 0), a11 = D(i, 1); + const value_type det = a00 * a11 - a10 * a01; + const value_type x0 = A(i - 1, 0), x1 = A(i, 0); + + A(i - 1, 0) = (a11 * x0 - a10 * x1) / det; + A(i, 0) = (-a10 * x0 + a00 * x1) / det; + } else { + const value_type a00 = D(i, 0); + A(i, 0) /= a00; + } + }); + } else { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), [&](const ordinal_type &i) { + const ordinal_type pval = P(i); + if (pval == 0) { + /// do nothing + } else if (pval < 0) { + /// take 2x2 block to D + const value_type a00 = D(i - 1, 0), a01 = D(i - 1, 1), a10 = D(i, 0), a11 = D(i, 1); + const value_type det = a00 * a11 - a10 * a01; + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { + const value_type x0 = A(i - 1, j), x1 = A(i, j); + A(i - 1, j) = (a11 * x0 - a10 * x1) / det; + A(i, j) = (-a10 * x0 + a00 * x1) / det; }); - } else { - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, m), - [&](const ordinal_type &i) { - const ordinal_type pval = P(i); - if (pval == 0) { - /// do nothing - } else if (pval < 0) { - /// take 2x2 block to D - const value_type - a00 = D(i-1, 0), a01 = D(i-1, 1), - a10 = D(i , 0), a11 = D(i , 1); - const value_type - det = a00*a11-a10*a01; - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, n), - [&](const ordinal_type &j) { - const value_type - x0 = A(i-1,j), - x1 = A(i,j); - A(i-1,j) = ( a11*x0 - a10*x1)/det; - A(i ,j) = (-a10*x0 + a00*x1)/det; - }); - } else { - const value_type - a00 = D(i,0); - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, n), - [&](const ordinal_type &j) { - A(i,j) /= a00; - }); - } - }); - } + } else { + const value_type a00 = D(i, 0); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { A(i, j) /= a00; }); + } + }); } - } else { - printf("Error: Scale2x2_BlockInverseDiagonals A is not square\n"); } + } else { + printf("Error: Scale2x2_BlockInverseDiagonals A is not square\n"); + } #else - invoke(P, D, A); + invoke(P, D, A); #endif - return 0; - } - - - - + return 0; + } +}; - }; - - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp index 0dbe9a8c9370..de51a4221c69 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp @@ -1,106 +1,76 @@ #ifndef __TACHO_SCALE_2X2_BLOCK_INVERSE_DIAGONALS_ON_DEVICE_HPP #define __TACHO_SCALE_2X2_BLOCK_INVERSE_DIAGONALS_ON_DEVICE_HPP - /// \file Tacho_Scale2x2_BlockInverseDiagonals_OnDevice.hpp /// \brief inverse scale /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - /// row exchange - template<> - struct Scale2x2_BlockInverseDiagonals { - template - inline - static int - invoke(MemberType &member, - const ViewTypeP &P, - const ViewTypeD &D, - const ViewTypeA &A) { - typedef typename ViewTypeA::non_const_value_type value_type; +/// row exchange +template <> struct Scale2x2_BlockInverseDiagonals { + template + inline static int invoke(MemberType &member, const ViewTypeP &P, const ViewTypeD &D, const ViewTypeA &A) { + typedef typename ViewTypeA::non_const_value_type value_type; + + const ordinal_type m = A.extent(0), n = A.extent(1); + if (A.extent(0) == D.extent(0)) { + if (A.span() > 0) { + using exec_space = MemberType; + const auto exec_instance = member; - const ordinal_type m = A.extent(0), n = A.extent(1); - if (A.extent(0) == D.extent(0)) { - if (A.span() > 0) { - using exec_space = MemberType; - const auto exec_instance = member; - - if (n == 1) { - Kokkos::RangePolicy policy(exec_instance, 0, m); - Kokkos::parallel_for - (policy, - KOKKOS_LAMBDA(const ordinal_type &i) { + if (n == 1) { + Kokkos::RangePolicy policy(exec_instance, 0, m); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &i) { const ordinal_type pval = P(i); if (pval == 0) { /// do nothing } else if (pval < 0) { /// take 2x2 block to D - const value_type - a00 = D(i-1, 0), a01 = D(i-1, 1), - a10 = D(i , 0), a11 = D(i , 1); - const value_type - det = a00*a11-a10*a01; - const value_type - x0 = A(i-1,0), - x1 = A(i,0); - - A(i-1,0) = ( a11*x0 - a10*x1)/det; - A(i ,0) = (-a10*x0 + a00*x1)/det; + const value_type a00 = D(i - 1, 0), a01 = D(i - 1, 1), a10 = D(i, 0), a11 = D(i, 1); + const value_type det = a00 * a11 - a10 * a01; + const value_type x0 = A(i - 1, 0), x1 = A(i, 0); + + A(i - 1, 0) = (a11 * x0 - a10 * x1) / det; + A(i, 0) = (-a10 * x0 + a00 * x1) / det; } else { - const value_type - a00 = D(i,0); - A(i,0) /= a00; + const value_type a00 = D(i, 0); + A(i, 0) /= a00; } }); - } else { - using policy_type= Kokkos::TeamPolicy; - policy_type policy(exec_instance, m, Kokkos::AUTO); - Kokkos::parallel_for - (policy, - KOKKOS_LAMBDA(const typename policy_type::member_type &member) { + } else { + using policy_type = Kokkos::TeamPolicy; + policy_type policy(exec_instance, m, Kokkos::AUTO); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const typename policy_type::member_type &member) { const ordinal_type i = member.league_rank(); const ordinal_type pval = P(i); if (pval == 0) { /// do nothing } else if (pval < 0) { /// take 2x2 block to D - const value_type - a00 = D(i-1, 0), a01 = D(i-1, 1), - a10 = D(i , 0), a11 = D(i , 1); - const value_type - det = a00*a11-a10*a01; - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [=](const ordinal_type &j) { - const value_type - x0 = A(i-1,j), - x1 = A(i,j); - A(i-1,j) = ( a11*x0 - a10*x1)/det; - A(i ,j) = (-a10*x0 + a00*x1)/det; - }); + const value_type a00 = D(i - 1, 0), a01 = D(i - 1, 1), a10 = D(i, 0), a11 = D(i, 1); + const value_type det = a00 * a11 - a10 * a01; + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), [=](const ordinal_type &j) { + const value_type x0 = A(i - 1, j), x1 = A(i, j); + A(i - 1, j) = (a11 * x0 - a10 * x1) / det; + A(i, j) = (-a10 * x0 + a00 * x1) / det; + }); } else { - const value_type - a00 = D(i,0); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [=](const ordinal_type &j) { - A(i,j) /= a00; - }); + const value_type a00 = D(i, 0); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, n), + [=](const ordinal_type &j) { A(i, j) /= a00; }); } }); - } } - } else { - printf("Error: Scale2x2_BlockInverseDiagonals A is not square\n"); } - return 0; + } else { + printf("Error: Scale2x2_BlockInverseDiagonals A is not square\n"); } - }; - - -} + return 0; + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity.hpp index 01fc371a1c30..39f60b537363 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity.hpp @@ -9,14 +9,13 @@ namespace Tacho { - /// - /// SetIdentity - /// +/// +/// SetIdentity +/// - /// various implementation for different uplo and algo parameters - template - struct SetIdentity; +/// various implementation for different uplo and algo parameters +template struct SetIdentity; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_Internal.hpp index e4e4467a8919..3dee7b5e689f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_Internal.hpp @@ -1,45 +1,30 @@ #ifndef __TACHO_SET_IDENTITY_INTERNAL_HPP__ #define __TACHO_SET_IDENTITY_INTERNAL_HPP__ - /// \file Tacho_SetIdentity_Internal.hpp /// \brief Set an identity matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - - template<> - struct SetIdentity { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ScalarType &alpha) { - typedef typename ViewTypeA::non_const_value_type value_type; - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - const ordinal_type - m = A.extent(0), - n = A.extent(1); +template <> struct SetIdentity { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A, const ScalarType &alpha) { + typedef typename ViewTypeA::non_const_value_type value_type; + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + const ordinal_type m = A.extent(0), n = A.extent(1); - if (m > 0 && n > 0) { - const value_type diag(alpha), zero(0); - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, n), - [&](const ordinal_type &j) { - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, m), - [&](const ordinal_type &i) { - A(i,j) = i==j ? diag : zero; - }); - }); - } - return 0; + if (m > 0 && n > 0) { + const value_type diag(alpha), zero(0); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, m), + [&](const ordinal_type &i) { A(i, j) = i == j ? diag : zero; }); + }); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_OnDevice.hpp index 8ca7525057f0..7578995ce9ae 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SetIdentity_OnDevice.hpp @@ -1,52 +1,43 @@ #ifndef __TACHO_SET_IDENTITY_ON_DEVICE_HPP__ #define __TACHO_SET_IDENTITY_ON_DEVICE_HPP__ - /// \file Tacho_SetIdentity_OnDevice.hpp /// \brief Set an identity matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - - template<> - struct SetIdentity { - template - inline - static int - invoke(MemberType &member, - const ViewTypeA &A, - const ScalarType &alpha) { - - typedef typename ViewTypeA::non_const_value_type value_type; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - - const ordinal_type - m = A.extent(0), - n = A.extent(1); - - if (m > 0 && n > 0) { - const value_type diag(alpha), zero(0); - using exec_space = MemberType; - using team_policy_type = Kokkos::TeamPolicy; - - const auto exec_instance = member; - const auto policy = team_policy_type(exec_instance, n, Kokkos::AUTO); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const typename team_policy_type::member_type &member) { + +template <> struct SetIdentity { + template + inline static int invoke(MemberType &member, const ViewTypeA &A, const ScalarType &alpha) { + + typedef typename ViewTypeA::non_const_value_type value_type; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m > 0 && n > 0) { + const value_type diag(alpha), zero(0); + using exec_space = MemberType; + using team_policy_type = Kokkos::TeamPolicy; + + const auto exec_instance = member; + const auto policy = team_policy_type(exec_instance, n, Kokkos::AUTO); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const typename team_policy_type::member_type &member) { const ordinal_type j = member.league_rank(); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m), - [&, diag, zero, A, j](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - A(i,j) = i==j ? diag : zero; - }); + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, m), + [&, diag, zero, A, + j](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + A(i, j) = i == j ? diag : zero; + }); }); - } - return 0; } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Solver_Impl.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Solver_Impl.hpp deleted file mode 100644 index d6f308b3db4a..000000000000 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Solver_Impl.hpp +++ /dev/null @@ -1,919 +0,0 @@ -#ifndef __TACHO_SOLVER_IMPL_HPP__ -#define __TACHO_SOLVER_IMPL_HPP__ - -/// \file Tacho_Solver_Impl.hpp -/// \brief solver interface -/// \author Kyungjoo Kim (kyukim@sandia.gov) - -#include "Tacho_Internal.hpp" -#include "Tacho_Solver.hpp" - -namespace Tacho { - - template - Solver - ::Solver() - : _transpose(0), _mode(0), _order_connected_graph_separately(0), - _m(0), _nnz(0), - _ap(), _h_ap(), _aj(), _h_aj(), - _perm(), _h_perm(), _peri(), _h_peri(), - _m_graph(0), _nnz_graph(0), - _h_ap_graph(), _h_aj_graph(), - _h_perm_graph(), _h_peri_graph(), - _N(nullptr), - _L0(nullptr), - _L1(nullptr), - _L2(nullptr), - _verbose(0), - _small_problem_thres(1024), - _serial_thres_size(-1), - _mb(-1), - _nb(-1), - _front_update_mode(-1), - _levelset(0), - _device_level_cut(0), - _device_factor_thres(64), - _device_solve_thres(128), - _variant(2), - _nstreams(16), - _max_num_superblocks(-1) {} - - /// deleted - // template - // Solver - // ::Solver(const Solver &b) = default; - - /// - /// common options - /// - template - void - Solver - ::setVerbose(const ordinal_type verbose) { - _verbose = verbose; - } - - template - void - Solver - ::setSmallProblemThresholdsize(const ordinal_type small_problem_thres) { - _small_problem_thres = small_problem_thres; - } - - // template - // void - // Solver - // ::setTransposeSolve(const bool transpose) { - // _transpose = transpose; // this option is not used yet - // } - - template - void - Solver - ::setMatrixType(const int symmetric, // 0 - unsymmetric, 1 - structure sym, 2 - symmetric - const bool is_positive_definite) { - switch (symmetric) { - case 0: { _mode = LU; break; } - case 1: { _mode = SymLU; break; } - case 2: { - if (is_positive_definite) { - if (std::is_same::value || - std::is_same::value || - std::is_same >::value || - std::is_same >::value) { - // real symmetric posdef - _mode = Cholesky; - } - } else { // real or complex symmetric indef - _mode = LDL; - } - break; - } - default: { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "symmetric argument is wrong"); - } - } - } - - template - void - Solver - ::setOrderConnectedGraphSeparately(const ordinal_type order_connected_graph_separately) { - _order_connected_graph_separately = order_connected_graph_separately; - } - - /// - /// tasking options - /// - template - void - Solver - ::setSerialThresholdsize(const ordinal_type serial_thres_size) { - _serial_thres_size = serial_thres_size; - } - - template - void - Solver - ::setBlocksize(const ordinal_type mb) { - _mb = mb; - } - - template - void - Solver - ::setPanelsize(const ordinal_type nb) { - _nb = nb; - } - - template - void - Solver - ::setFrontUpdateMode(const ordinal_type front_update_mode) { - _front_update_mode = front_update_mode; - } - - template - void - Solver - ::setMaxNumberOfSuperblocks(const ordinal_type max_num_superblocks) { - _max_num_superblocks = max_num_superblocks; - } - - /// - /// Level set tools options - /// - template - void - Solver - ::setLevelSetScheduling(const bool levelset) { - _levelset = levelset; - } - - template - void - Solver - ::setLevelSetOptionDeviceLevelCut(const ordinal_type device_level_cut) { - _device_level_cut = device_level_cut; - } - - template - void - Solver - ::setLevelSetOptionDeviceFunctionThreshold(const ordinal_type device_factor_thres, - const ordinal_type device_solve_thres) { - _device_factor_thres = device_factor_thres; - _device_solve_thres = device_solve_thres; - } - - template - void - Solver - ::setLevelSetOptionAlgorithmVariant(const ordinal_type variant) { - if (variant > 2 || variant < 0) { - std::logic_error("levelset algorithm variants range from 0 to 2"); - } - _variant = variant; - } - - template - void - Solver - ::setLevelSetOptionNumStreams(const ordinal_type nstreams) { - _nstreams = nstreams; - } - - /// - /// get interface - /// - template - ordinal_type - Solver - ::getNumSupernodes() const { - return _nsupernodes; - } - - template - typename Solver::ordinal_type_array - Solver - ::getSupernodes() const { - return _supernodes; - } - - template - typename Solver::ordinal_type_array - Solver - ::getPermutationVector() const { - return _perm; - } - - template - typename Solver::ordinal_type_array - Solver - ::getInversePermutationVector() const { - return _peri; - } - - - // internal only - template - int - Solver - ::analyze() { - int r_val(0); - if (_m < _small_problem_thres) { - /// do nothing - if (_verbose) { - printf("TachoSolver: Analyze\n"); - printf("====================\n"); - printf(" Linear system A\n"); - printf(" number of equations: %10d\n", _m); - printf("\n"); - printf(" A is a small problem ( < %d ) and LAPACK is used\n", _small_problem_thres); - printf("\n"); - } - } else { - const bool use_condensed_graph = (_m_graph > 0 && _m_graph < _m); - if (use_condensed_graph) { - Graph graph(_m_graph, _nnz_graph, _h_ap_graph, _h_aj_graph); - graph_tools_type G(graph); -#if defined(TACHO_HAVE_METIS) - if (_order_connected_graph_separately) { - G.setOption(METIS_OPTION_CCORDER, 1); - } -#endif - G.reorder(_verbose); - - _h_perm_graph = G.PermVector(); - _h_peri_graph = G.InvPermVector(); - - r_val = analyze_condensed_graph(); - } else { - const bool use_graph_partitioner = (_h_perm.extent(0) == 0 && _h_peri.extent(0) == 0); - if (use_graph_partitioner) { - Graph graph(_m, _nnz, _h_ap, _h_aj); - graph_tools_type G(graph); -#if defined(TACHO_HAVE_METIS) - if (_order_connected_graph_separately) { - G.setOption(METIS_OPTION_CCORDER, 1); - } -#endif - G.reorder(_verbose); - - _h_perm = G.PermVector(); - _h_peri = G.InvPermVector(); - - r_val = analyze_linear_system(); - } else { - r_val = analyze_linear_system(); - } - } - } - return r_val; - } - - template - int - Solver - ::analyze_linear_system() { - if (_verbose) { - printf("TachoSolver: Analyze Linear System\n"); - printf("==================================\n"); - } - - { - symbolic_tools_type S(_m, _h_ap, _h_aj, _h_perm, _h_peri); - S.symbolicFactorize(_verbose); - - _nsupernodes = S.NumSupernodes(); - _stree_level = S.SupernodesTreeLevel(); - _stree_roots = S.SupernodesTreeRoots(); - - _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); - _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); - _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); - _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); - _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); - _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); - _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); - _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); - _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); - - Kokkos::deep_copy(_supernodes , S.Supernodes()); - Kokkos::deep_copy(_gid_super_panel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(_gid_super_panel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(_sid_super_panel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(_sid_super_panel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(_blk_super_panel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(_stree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(_stree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(_stree_children , S.SupernodesTreeChildren()); - - // perm and peri is updated during symbolic factorization - _perm = Kokkos::create_mirror_view(exec_memory_space(), _h_perm); - _peri = Kokkos::create_mirror_view(exec_memory_space(), _h_peri); - - Kokkos::deep_copy(_perm, _h_perm); - Kokkos::deep_copy(_peri, _h_peri); - } - return 0; - } - - template - int - Solver - ::analyze_condensed_graph() { - if (_verbose) { - printf("TachoSolver: Analyze Condensed Graph and Evaporate the Graph\n"); - printf("============================================================\n"); - } - - { - symbolic_tools_type S(_m_graph, _h_ap_graph, _h_aj_graph, _h_perm_graph, _h_peri_graph); - S.symbolicFactorize(_verbose); - S.evaporateSymbolicFactors(_h_aw_graph, _verbose); - - _nsupernodes = S.NumSupernodes(); - _stree_level = S.SupernodesTreeLevel(); - _stree_roots = S.SupernodesTreeRoots(); - - _supernodes = Kokkos::create_mirror_view(exec_memory_space(), S.Supernodes()); - _gid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelPtr()); - _gid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.gidSuperPanelColIdx()); - _sid_super_panel_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelPtr()); - _sid_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.sidSuperPanelColIdx()); - _blk_super_panel_colidx = Kokkos::create_mirror_view(exec_memory_space(), S.blkSuperPanelColIdx()); - _stree_parent = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeParent()); - _stree_ptr = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreePtr()); - _stree_children = Kokkos::create_mirror_view(exec_memory_space(), S.SupernodesTreeChildren()); - _perm = Kokkos::create_mirror_view(exec_memory_space(), S.PermVector()); - _peri = Kokkos::create_mirror_view(exec_memory_space(), S.InvPermVector()); - - Kokkos::deep_copy(_supernodes , S.Supernodes()); - Kokkos::deep_copy(_gid_super_panel_ptr , S.gidSuperPanelPtr()); - Kokkos::deep_copy(_gid_super_panel_colidx , S.gidSuperPanelColIdx()); - Kokkos::deep_copy(_sid_super_panel_ptr , S.sidSuperPanelPtr()); - Kokkos::deep_copy(_sid_super_panel_colidx , S.sidSuperPanelColIdx()); - Kokkos::deep_copy(_blk_super_panel_colidx , S.blkSuperPanelColIdx()); - Kokkos::deep_copy(_stree_parent , S.SupernodesTreeParent()); - Kokkos::deep_copy(_stree_ptr , S.SupernodesTreePtr()); - Kokkos::deep_copy(_stree_children , S.SupernodesTreeChildren()); - Kokkos::deep_copy(_perm , S.PermVector()); - Kokkos::deep_copy(_peri , S.InvPermVector()); - - _h_perm = S.PermVector(); - _h_peri = S.InvPermVector(); - } - return 0; - } - - template - int - Solver - ::initialize() { - if (_verbose) { - printf("TachoSolver: Initialize\n"); - printf("=======================\n"); - } - - /// - /// initialize numeric tools - /// - if (_m < _small_problem_thres) { - //_A = value_type_matrix_host("A", _m, _m); - } else { - if (_N == nullptr) - _N = (numeric_tools_type*) ::operator new (sizeof(numeric_tools_type)); - - new (_N) numeric_tools_type(_m, _ap, _aj, - _perm, _peri, - _nsupernodes, _supernodes, - _gid_super_panel_ptr, _gid_super_panel_colidx, - _sid_super_panel_ptr, _sid_super_panel_colidx, _blk_super_panel_colidx, - _stree_parent, _stree_ptr, _stree_children, - _stree_level, _stree_roots); - - if (_serial_thres_size < 0) { // set default values - _serial_thres_size = 64; - } - _N->setSerialThresholdSize(_serial_thres_size); - - if (_max_num_superblocks < 0) { // set default values - _max_num_superblocks = 16; - } - _N->setMaxNumberOfSuperblocks(_max_num_superblocks); - - if (_front_update_mode < 0) { // set default values - _front_update_mode = 1; // atomic is default - } - _N->setFrontUpdateMode(_front_update_mode); - _N->printMemoryStat(_verbose); - - /// - /// initialize levelset tools - /// - if (_levelset) { - if (_variant == 0) { - if (_L0 == nullptr) - _L0 = (levelset_tools_var0_type*) ::operator new (sizeof(levelset_tools_var0_type)); - new (_L0) levelset_tools_var0_type(*_N); - _L0->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); - _L0->createStream(_nstreams); - } else if (_variant == 1) { - if (_L1 == nullptr) - _L1 = (levelset_tools_var1_type*) ::operator new (sizeof(levelset_tools_var1_type)); - new (_L1) levelset_tools_var1_type(*_N); - _L1->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); - _L1->createStream(_nstreams); - } else if (_variant == 2) { - if (_L2 == nullptr) - _L2 = (levelset_tools_var2_type*) ::operator new (sizeof(levelset_tools_var2_type)); - new (_L2) levelset_tools_var1_type(*_N); - _L2->initialize(_device_level_cut, _device_factor_thres, _device_solve_thres, _verbose); - _L2->createStream(_nstreams); - } - } - } - return 0; - } - - template - int - Solver - ::factorize(const value_type_array &ax) { - switch (_mode) { - case Cholesky: factorize_chol(ax); break; - case LDL: factorize_ldl(ax); break; - default: - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); - } - return 0; - } - - template - int - Solver - ::factorize_chol(const value_type_array &ax) { - if (_verbose) { - printf("TachoSolver: Factorize\n"); - printf("======================\n"); - } - if (_m < _small_problem_thres) { - Kokkos::Timer timer; - - timer.reset(); - _A = value_type_matrix_host("A", _m, _m); - auto h_ax = Kokkos::create_mirror_view(host_memory_space(), ax); Kokkos::deep_copy(h_ax, ax); - for (ordinal_type i=0;i<_m;++i) { - const size_type jbeg = _h_ap(i), jend = _h_ap(i+1); - for (size_type j=jbeg;j::invoke(_A); - const double t_factor = timer.seconds(); - - if (_verbose) { - printf("Summary: NumericTools (SmallDenseFactorization)\n"); - printf("===============================================\n"); - printf(" Time\n"); - printf(" time for copying A into U: %10.6f s\n", t_copy); - printf(" time for numeric factorization: %10.6f s\n", t_factor); - printf(" total time spent: %10.6f s\n", (t_copy+t_factor)); - printf("\n"); - } - } else { - -#if !defined (KOKKOS_ENABLE_CUDA) - const ordinal_type nthreads = host_space::impl_thread_pool_size(0); -#endif - if (_levelset) { - if (_variant == 0) _L0->factorizeCholesky(ax, _verbose); - else if (_variant == 1) _L1->factorizeCholesky(ax, _verbose); - else if (_variant == 2) _L2->factorizeCholesky(ax, _verbose); - } -#if !defined (KOKKOS_ENABLE_CUDA) - else if (nthreads == 1) { - if (_nb < 0) - _N->factorizeCholesky_Serial(ax, _verbose); - else - _N->factorizeCholesky_SerialPanel(ax, _nb, _verbose); - } -#endif - else { - const ordinal_type max_dense_size = max(_N->getMaxSupernodeSize(),_N->getMaxSchurSize()); - if (std::is_same::value) { - if (_nb < 0) { - _nb = 64; - // if (max_dense_size < 256) _nb = -1; - // else if (max_dense_size < 512) _nb = 64; - // else if (max_dense_size < 1024) _nb = 128; - // else if (max_dense_size < 8192) _nb = 256; - // else _nb = 256; - } - if (_mb < 0) { - if (max_dense_size < 256) _mb = -1; - else if (max_dense_size < 512) _mb = 64; - else if (max_dense_size < 1024) _mb = 128; - else if (max_dense_size < 8192) _mb = 256; - else _mb = 256; - } - } else { - if (_nb < 0) { - _nb = 40; - // if (max_dense_size < 256) _nb = -1; - // else if (max_dense_size < 512) _nb = 64; - // else if (max_dense_size < 1024) _nb = 128; - // else if (max_dense_size < 8192) _nb = 256; - // else _nb = 256; - } - if (_mb < 0) { - if (max_dense_size < 256) _mb = -1; - else if (max_dense_size < 512) _mb = 80; - else if (max_dense_size < 1024) _mb = 120; - else if (max_dense_size < 8192) _mb = 160; - else _mb = 160; - } - } - - if (_nb <= 0) - if (_mb > 0) - _N->factorizeCholesky_ParallelByBlocks(ax, _mb, _verbose); - else - _N->factorizeCholesky_Parallel(ax, _verbose); - else - if (_mb > 0) - _N->factorizeCholesky_ParallelByBlocksPanel(ax, _mb, _nb, _verbose); - else - _N->factorizeCholesky_ParallelPanel(ax, _nb, _verbose); - } - } - return 0; - } - - template - int - Solver - ::factorize_ldl(const value_type_array &ax) { - if (_verbose) { - printf("TachoSolver: Factorize\n"); - printf("======================\n"); - } - if (_m < _small_problem_thres) { - Kokkos::Timer timer; - - timer.reset(); - _A = value_type_matrix_host("A", _m, _m); - auto h_ax = Kokkos::create_mirror_view(host_memory_space(), ax); Kokkos::deep_copy(h_ax, ax); - for (ordinal_type i=0;i<_m;++i) { - const size_type jbeg = _h_ap(i), jend = _h_ap(i+1); - for (size_type j=jbeg;j= col) - _A(i, col) = h_ax(j); - } - } - const double t_copy = timer.seconds(); - - timer.reset(); - _P = ordinal_type_array_host("P", 4*_m); - _D = value_type_matrix_host("D", _m, 2); - auto W = value_type_array_host("W", 32*_m); - Tacho::LDL::invoke(_A, _P, W); - Tacho::LDL::modify(_A, _P, _D); - const double t_factor = timer.seconds(); - - if (_verbose) { - printf("Summary: NumericTools (SmallDenseFactorization)\n"); - printf("===============================================\n"); - printf(" Time\n"); - printf(" time for copying A into L: %10.6f s\n", t_copy); - printf(" time for numeric factorization: %10.6f s\n", t_factor); - printf(" total time spent: %10.6f s\n", (t_copy+t_factor)); - printf("\n"); - } - } else { -#if !defined (KOKKOS_ENABLE_CUDA) - const ordinal_type nthreads = host_space::impl_thread_pool_size(0); -#endif - if (_levelset) { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); - // if (_variant == 0) _L0->factorizeLDL(ax, _verbose); - // else if (_variant == 1) _L1->factorizeLDL(ax, _verbose); - // else if (_variant == 2) _L2->factorizeLDL(ax, _verbose); - } -#if !defined (KOKKOS_ENABLE_CUDA) - else if (nthreads == 1) { - _N->factorizeLDL_Serial(ax, _verbose); - // if (_nb < 0) - // _N->factorizeLDL_Serial(ax, _verbose); - // else - // _N->factorizeLDL_SerialPanel(ax, _nb, _verbose); - } -#endif - else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); - } - } - return 0; - } - - template - int - Solver - ::solve(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - switch (_mode) { - case Cholesky: solve_chol(x, b, t); break; - case LDL: solve_ldl(x, b, t); break; - default: - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); - } - return 0; - } - - template - int - Solver - ::solve_chol(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - if (_verbose) { - printf("TachoSolver: Solve\n"); - printf("==================\n"); - } - - if (_m < _small_problem_thres) { - Kokkos::Timer timer; - - timer.reset(); - Kokkos::deep_copy(x, b); - const double t_copy = timer.seconds(); - - timer.reset(); - auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); - Kokkos::deep_copy(h_x, x); - Trsm - ::invoke(Diag::NonUnit(), 1.0, _A, h_x); - Trsm - ::invoke(Diag::NonUnit(), 1.0, _A, h_x); - Kokkos::deep_copy(x, h_x); - const double t_solve = timer.seconds(); - - if (_verbose) { - printf("Summary: NumericTools (SmallDenseSolve)\n"); - printf("=======================================\n"); - printf(" Time\n"); - printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); - printf(" time for numeric solve: %10.6f s\n", t_solve); - printf(" total time spent: %10.6f s\n", (t_solve+t_copy)); - printf("\n"); - } - } else { -#if !defined (KOKKOS_ENABLE_CUDA) - const ordinal_type nthreads = host_space::impl_thread_pool_size(0); -#endif - TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || - t.extent(1) < x.extent(1), std::logic_error, "Temporary rhs vector t is smaller than x"); - auto tt = Kokkos::subview(t, - Kokkos::pair(0, x.extent(0)), - Kokkos::pair(0, x.extent(1))); - if (_levelset) { - if (_variant == 0) _L0->solveCholesky(x, b, tt, _verbose); - else if (_variant == 1) _L1->solveCholesky(x, b, tt, _verbose); - else if (_variant == 2) _L2->solveCholesky(x, b, tt, _verbose); - } -#if !defined (KOKKOS_ENABLE_CUDA) - else if (nthreads == 1) { - _N->solveCholesky_Serial(x, b, tt, _verbose); - } -#endif - else { - _N->solveCholesky_Parallel(x, b, tt, _verbose); - } - } - return 0; - } - - template - int - Solver - ::solve_ldl(const value_type_matrix &x, - const value_type_matrix &b, - const value_type_matrix &t) { - if (_verbose) { - printf("TachoSolver: Solve\n"); - printf("==================\n"); - } - - if (_m < _small_problem_thres) { - Kokkos::Timer timer; - - timer.reset(); - Kokkos::deep_copy(x, b); - const double t_copy = timer.seconds(); - - timer.reset(); - auto perm = ordinal_type_array_host(_P.data()+2*_m, _m); - auto peri = ordinal_type_array_host(_P.data()+3*_m, _m); - auto h_x = Kokkos::create_mirror_view(host_memory_space(), x); - auto h_t = Kokkos::create_mirror_view(host_memory_space(), t); - - ApplyPermutation - ::invoke(h_x, perm, h_t); - Trsm - ::invoke(Diag::Unit(), 1.0, _A, h_t); - Scale2x2_BlockInverseDiagonals - ::invoke(_P, _D, h_t); - Trsm - ::invoke(Diag::Unit(), 1.0, _A, h_t); - ApplyPermutation - ::invoke(h_t, peri, h_x); - Kokkos::deep_copy(x, h_x); - const double t_solve = timer.seconds(); - - if (_verbose) { - printf("Summary: NumericTools (SmallDenseSolve)\n"); - printf("=======================================\n"); - printf(" Time\n"); - printf(" time for extra work e.g.,copy rhs: %10.6f s\n", t_copy); - printf(" time for numeric solve: %10.6f s\n", t_solve); - printf(" total time spent: %10.6f s\n", (t_solve+t_copy)); - printf("\n"); - } - } else { -#if !defined (KOKKOS_ENABLE_CUDA) - const ordinal_type nthreads = host_space::impl_thread_pool_size(0); -#endif - TACHO_TEST_FOR_EXCEPTION(t.extent(0) < x.extent(0) || - t.extent(1) < x.extent(1), std::logic_error, "Temporary rhs vector t is smaller than x"); - auto tt = Kokkos::subview(t, - Kokkos::pair(0, x.extent(0)), - Kokkos::pair(0, x.extent(1))); - if (_levelset) { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Not implemented yet"); - // if (_variant == 0) _L0->solveCholesky(x, b, tt, _verbose); - // else if (_variant == 1) _L1->solveCholesky(x, b, tt, _verbose); - // else if (_variant == 2) _L2->solveCholesky(x, b, tt, _verbose); - } -#if !defined (KOKKOS_ENABLE_CUDA) - else if (nthreads == 1) { - _N->solveLDL_Serial(x, b, tt, _verbose); - } -#endif - else { - //_N->solveCholesky_Parallel(x, b, tt, _verbose); - } - } - return 0; - } - - template - double - Solver - ::computeRelativeResidual(const value_type_array &ax, - const value_type_matrix &x, - const value_type_matrix &b) { - CrsMatrixBase A; - A.setExternalMatrix(_m, _m, _nnz, _ap, _aj, ax); - - return Tacho::computeRelativeResidual(A, x, b); - } - - template - int - Solver - ::exportFactorsToCrsMatrix(crs_matrix_type &A) { - if (_m < _small_problem_thres) { - typedef ArithTraits ats; - const typename ats::mag_type zero(0); - - /// count nonzero elements in dense U - const ordinal_type m = _m; - size_type_array_host h_ap("h_ap", m+1); - for (ordinal_type i=0;i zero); - - /// serial scan; this is a small problem - h_ap(0) = 0; - for (ordinal_type i=0;i zero) { - h_aj(k) = j; - h_ax(k) = _A(i,j); - ++k; - } - - crs_matrix_type_host h_A; - h_A.setExternalMatrix(m, m, nnz, h_ap, h_aj, h_ax); - ///h_A.showMe(std::cout, true); - A.clear(); - A.createConfTo(h_A); - A.copy(h_A); - } else { - _N->exportFactorsToCrsMatrix(A, false); - } - return 0; - } - - template - int - Solver - ::release() { - if (_verbose) { - printf("TachoSolver: Release\n"); - printf("====================\n"); - } - - if (_levelset) { - if (_variant == 0) { - if (_L0 != nullptr) - _L0->release(_verbose); - delete _L0; _L0 = nullptr; - } else if (_variant == 1) { - if (_L1 != nullptr) - _L1->release(_verbose); - delete _L1; _L1 = nullptr; - } else if (_variant == 2) { - if (_L2 != nullptr) - _L2->release(_verbose); - delete _L2; _L2 = nullptr; - } - } - - { - if (_N != nullptr) - _N->release(_verbose); - delete _N; _N = nullptr; - } - { - _transpose = false; - _mode = 0; - - _m = 0; - _nnz = 0; - - _ap = size_type_array(); _h_ap = size_type_array_host(); - _aj = ordinal_type_array(); _h_aj = ordinal_type_array_host(); - - _perm = ordinal_type_array(); _h_perm = ordinal_type_array_host(); - _peri = ordinal_type_array(); _h_peri = ordinal_type_array_host(); - - _m_graph = 0; - _nnz_graph = 0; - - _h_ap_graph = size_type_array_host(); - _h_aj_graph = ordinal_type_array_host(); - - _h_perm_graph = ordinal_type_array_host(); - _h_peri_graph = ordinal_type_array_host(); - - _nsupernodes = 0; - _supernodes = ordinal_type_array(); - - _gid_super_panel_ptr = size_type_array(); - _gid_super_panel_colidx = ordinal_type_array(); - - _sid_super_panel_ptr = size_type_array(); - - _sid_super_panel_colidx = ordinal_type_array(); - _blk_super_panel_colidx = ordinal_type_array(); - - _stree_ptr = size_type_array(); - _stree_children = ordinal_type_array(); - - _stree_parent = ordinal_type_array(); - _stree_roots = ordinal_type_array_host(); - - _A = value_type_matrix_host(); - - _verbose = 0; - _small_problem_thres = 1024; - _serial_thres_size = -1; - _mb = -1; - _nb = -1; - _front_update_mode = -1; - - _max_num_superblocks = -1; - } - return 0; - } - -} - -#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SupernodeInfo.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SupernodeInfo.hpp index 7bdf2b8aab29..402d3a63fbec 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_SupernodeInfo.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_SupernodeInfo.hpp @@ -8,526 +8,451 @@ namespace Tacho { - struct SuperNodeInfoInitReducer { - typedef SuperNodeInfoInitReducer reducer; - struct ValueType { - ordinal_type max_nchildren; - ordinal_type max_supernode_size; - ordinal_type max_schur_size; - size_type nnz; - }; - typedef struct ValueType value_type; - - typedef Kokkos::View > result_view_type; - - value_type *value; - - KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer() : value() {} //= default; - KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer(const SuperNodeInfoInitReducer &b) : value(b.value) {} //= default; - KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer(value_type &val) : value(&val) {} - - KOKKOS_INLINE_FUNCTION void join(value_type &dst, value_type &src) const { - dst.max_nchildren = ( src.max_nchildren > dst.max_nchildren ? - src.max_nchildren : dst.max_nchildren ); - dst.max_supernode_size = ( src.max_supernode_size > dst.max_supernode_size ? - src.max_supernode_size : dst.max_supernode_size ); - dst.max_schur_size = ( src.max_schur_size > dst.max_schur_size ? - src.max_schur_size : dst.max_schur_size ); - dst.nnz += src.nnz; - } - - KOKKOS_INLINE_FUNCTION void join(volatile value_type &dst, const volatile value_type &src) const { - dst.max_nchildren = ( src.max_nchildren > dst.max_nchildren ? - src.max_nchildren : dst.max_nchildren ); - dst.max_supernode_size = ( src.max_supernode_size > dst.max_supernode_size ? - src.max_supernode_size : dst.max_supernode_size ); - dst.max_schur_size = ( src.max_schur_size > dst.max_schur_size ? - src.max_schur_size : dst.max_schur_size ); - dst.nnz += src.nnz; - } - - KOKKOS_INLINE_FUNCTION void init(value_type &val) const { - val.max_nchildren = Kokkos::reduction_identity::max(); - val.max_supernode_size = Kokkos::reduction_identity::max(); - val.max_schur_size = Kokkos::reduction_identity::max(); - val.nnz = Kokkos::reduction_identity::sum(); - } - - KOKKOS_INLINE_FUNCTION - value_type& reference() { - return *value; - } - - KOKKOS_INLINE_FUNCTION - result_view_type view() const { - return result_view_type(value); - } - }; - - template - struct SupernodeInfo { - typedef ValueType value_type; - typedef SchedulerType scheduler_type; - - typedef typename UseThisDevice::type device_type; - typedef typename device_type::execution_space exec_space; - typedef typename device_type::memory_space exec_memory_space; - - typedef typename UseThisDevice::type host_device_type; - typedef typename host_device_type::execution_space host_space; - typedef typename host_device_type::memory_space host_memory_space; - - typedef CrsMatrixBase crs_matrix_type; - - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View size_type_array; - typedef Kokkos::View value_type_array; - - typedef Kokkos::pair ordinal_pair_type; - typedef Kokkos::View ordinal_pair_type_array; - typedef Kokkos::View value_type_matrix; - - typedef DenseMatrixView dense_block_type; - typedef DenseMatrixView dense_matrix_of_blocks_type; - - typedef Kokkos::BasicFuture future_type; - - struct Supernode { - mutable int32_t lock; - - ordinal_type row_begin; // beginning row - ordinal_type m, n; // panel dimension - - // column connectivity (gid - dof, sid - supernode) - ordinal_type gid_col_begin, gid_col_end, sid_col_begin, sid_col_end; - ordinal_type nchildren, *children; //children[MaxDependenceSize]; // hierarchy - - ordinal_type max_decendant_schur_size; // workspace - ordinal_type max_decendant_supernode_size; // workspace - - value_type *buf; - - KOKKOS_INLINE_FUNCTION - Supernode() - : lock(0), row_begin(0), m(0), n(0), - gid_col_begin(0), gid_col_end(0), - sid_col_begin(0), sid_col_end(0), - nchildren(0), children(NULL), - max_decendant_schur_size(0), - max_decendant_supernode_size(0), - buf(NULL) { - //for (ordinal_type i=0;i supernode_type_array; - - /// - /// info for symbolic - /// - //ConstUnmanagedViewType supernodes; - UnmanagedViewType supernodes; - - /// dof mapping to sparse matrix - UnmanagedViewType gid_colidx; - - /// supernode map and panel size configuration - /// first - sid, second - blk , blk_superpanel_colidx; - /// the last sid is dummy but last blk is ending point of the block - UnmanagedViewType sid_block_colidx; - - /// - /// max parameter - /// - ordinal_type max_nchildren, max_supernode_size, max_schur_size; - - /// - /// frontal matrix subassembly mode and serialization parameter - /// - short front_update_mode, serial_thres_size; // 0 - lock, 1 - atomic - - /// - /// info for solve (rhs multivector) - UnmanagedViewType x; - - KOKKOS_INLINE_FUNCTION - SupernodeInfo() - : - supernodes(), gid_colidx(), sid_block_colidx(), - max_nchildren(), max_supernode_size(), max_schur_size(), - front_update_mode(), serial_thres_size(), - x() {} - //= default; - - KOKKOS_INLINE_FUNCTION - SupernodeInfo(const SupernodeInfo &b) - : - supernodes(b.supernodes), gid_colidx(b.gid_colidx), sid_block_colidx(b.sid_block_colidx), +struct SuperNodeInfoInitReducer { + using reducer = SuperNodeInfoInitReducer; + struct ValueType { + ordinal_type max_nchildren; + ordinal_type max_supernode_size; + ordinal_type max_schur_size; + size_type nnz; + }; + using value_type = struct ValueType; + + using result_view_type = Kokkos::View>; + + value_type *value; + + KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer() : value() {} //= default; + KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer(const SuperNodeInfoInitReducer &b) : value(b.value) {} //= default; + KOKKOS_INLINE_FUNCTION SuperNodeInfoInitReducer(value_type &val) : value(&val) {} + + KOKKOS_INLINE_FUNCTION void join(value_type &dst, value_type &src) const { + dst.max_nchildren = (src.max_nchildren > dst.max_nchildren ? src.max_nchildren : dst.max_nchildren); + dst.max_supernode_size = + (src.max_supernode_size > dst.max_supernode_size ? src.max_supernode_size : dst.max_supernode_size); + dst.max_schur_size = (src.max_schur_size > dst.max_schur_size ? src.max_schur_size : dst.max_schur_size); + dst.nnz += src.nnz; + } + + KOKKOS_INLINE_FUNCTION void join(volatile value_type &dst, const volatile value_type &src) const { + dst.max_nchildren = (src.max_nchildren > dst.max_nchildren ? src.max_nchildren : dst.max_nchildren); + dst.max_supernode_size = + (src.max_supernode_size > dst.max_supernode_size ? src.max_supernode_size : dst.max_supernode_size); + dst.max_schur_size = (src.max_schur_size > dst.max_schur_size ? src.max_schur_size : dst.max_schur_size); + dst.nnz += src.nnz; + } + + KOKKOS_INLINE_FUNCTION void init(value_type &val) const { + val.max_nchildren = Kokkos::reduction_identity::max(); + val.max_supernode_size = Kokkos::reduction_identity::max(); + val.max_schur_size = Kokkos::reduction_identity::max(); + val.nnz = Kokkos::reduction_identity::sum(); + } + + KOKKOS_INLINE_FUNCTION + value_type &reference() { return *value; } + + KOKKOS_INLINE_FUNCTION + result_view_type view() const { return result_view_type(value); } +}; + +template struct SupernodeInfo { + using value_type = ValueType; + + using device_type = DeviceType; + using exec_space = typename device_type::execution_space; + + using host_device_type = typename UseThisDevice::type; + using host_memory_space = typename host_device_type::memory_space; + + using crs_matrix_type = CrsMatrixBase; + + using ordinal_type_array = Kokkos::View; + using size_type_array = Kokkos::View; + using value_type_array = Kokkos::View; + + using ordinal_pair_type = Kokkos::pair; + using ordinal_pair_type_array = Kokkos::View; + using value_type_matrix = Kokkos::View; + + using range_type = Kokkos::pair; + + struct Supernode { + mutable int32_t lock; + + ordinal_type row_begin; // beginning row + ordinal_type m, n; // panel dimension + + // column connectivity (gid - dof, sid - supernode) + ordinal_type gid_col_begin, gid_col_end, sid_col_begin, sid_col_end; + ordinal_type nchildren, *children; // children[MaxDependenceSize]; // hierarchy + + ordinal_type max_decendant_schur_size; // workspace + ordinal_type max_decendant_supernode_size; // workspace + + value_type *l_buf, *u_buf; + + bool do_not_apply_pivots; + + KOKKOS_INLINE_FUNCTION + Supernode() + : lock(0), row_begin(0), m(0), n(0), gid_col_begin(0), gid_col_end(0), sid_col_begin(0), sid_col_end(0), + nchildren(0), children(NULL), max_decendant_schur_size(0), max_decendant_supernode_size(0), l_buf(NULL), + u_buf(NULL), do_not_apply_pivots(false) { + // for (ordinal_type i=0;i; + + /// + /// info for symbolic + /// + // ConstUnmanagedViewType supernodes; + UnmanagedViewType supernodes; + + /// dof mapping to sparse matrix + UnmanagedViewType gid_colidx; + + /// supernode map and panel size configuration + /// first - sid, second - blk , blk_superpanel_colidx; + /// the last sid is dummy but last blk is ending point of the block + UnmanagedViewType sid_block_colidx; + + /// + /// max parameter + /// + ordinal_type max_nchildren, max_supernode_size, max_schur_size; + + /// + /// frontal matrix subassembly mode and serialization parameter + /// + short front_update_mode, serial_thres_size; // 0 - lock, 1 - atomic + + /// + /// info for solve (rhs multivector) + UnmanagedViewType x; + + KOKKOS_INLINE_FUNCTION + SupernodeInfo() + : supernodes(), gid_colidx(), sid_block_colidx(), max_nchildren(), max_supernode_size(), max_schur_size(), + front_update_mode(), serial_thres_size(), x() {} + //= default; + + KOKKOS_INLINE_FUNCTION + SupernodeInfo(const SupernodeInfo &b) + : supernodes(b.supernodes), gid_colidx(b.gid_colidx), sid_block_colidx(b.sid_block_colidx), max_nchildren(b.max_nchildren), max_supernode_size(b.max_supernode_size), max_schur_size(b.max_schur_size), - front_update_mode(b.front_update_mode), serial_thres_size(b.serial_thres_size), - x(b.x) {} - //= default; - - static - inline - void - initialize(/* */ SupernodeInfo &self, - /* */ supernode_type_array &supernodes_, - /* */ ordinal_pair_type_array &sid_block_colidx_, - /* */ value_type_array &superpanel_buf_, - // symbolic input - const ordinal_type_array &snodes_, - const size_type_array &gid_ptr_, - const ordinal_type_array &gid_colidx_, - const size_type_array &sid_ptr_, - const ordinal_type_array &sid_colidx_, - const ordinal_type_array &blk_colidx_, - // tree hierarchy - const ordinal_type_array &stree_parent_, - const size_type_array &stree_ptr_, - const ordinal_type_array &stree_children_) { - const ordinal_type nsupernodes = snodes_.extent(0) - 1; - - /// allocate and assign supernodes - supernodes_ = supernode_type_array("supernodes", nsupernodes); // managed view - - sid_block_colidx_ = ordinal_pair_type_array("sid_block_colidx", sid_colidx_.span()); - - // by default, update mode is atomic: 0 - mutex lock, 1 - atomic - self.front_update_mode = 1; - self.serial_thres_size = 0; - - /// workspace parameter initialization - self.max_nchildren = 0; - self.max_supernode_size = 0; - self.max_schur_size = 0; - - Kokkos::RangePolicy supernodes_range_policy(0,nsupernodes); - SuperNodeInfoInitReducer::value_type init_reduce_val; - Kokkos::parallel_reduce - (supernodes_range_policy, - KOKKOS_LAMBDA(const ordinal_type &sid, SuperNodeInfoInitReducer::value_type &update) { - auto &s = supernodes_(sid); - - s.row_begin = snodes_(sid); - s.m = snodes_(sid+1) - snodes_(sid); - s.n = blk_colidx_(sid_ptr_(sid+1)-1); - - s.gid_col_begin = gid_ptr_(sid); s.gid_col_end = gid_ptr_(sid+1); - s.sid_col_begin = sid_ptr_(sid); s.sid_col_end = sid_ptr_(sid+1); - - for (ordinal_type i=s.sid_col_begin;i supernodes_range_policy(0, nsupernodes); + SuperNodeInfoInitReducer::value_type init_reduce_val; + Kokkos::parallel_reduce( + supernodes_range_policy, + KOKKOS_LAMBDA(const ordinal_type &sid, SuperNodeInfoInitReducer::value_type &update) { + auto &s = supernodes_(sid); + + s.row_begin = snodes_(sid); + s.m = snodes_(sid + 1) - snodes_(sid); + s.n = blk_colidx_(sid_ptr_(sid + 1) - 1); + + s.gid_col_begin = gid_ptr_(sid); + s.gid_col_end = gid_ptr_(sid + 1); + s.sid_col_begin = sid_ptr_(sid); + s.sid_col_end = sid_ptr_(sid + 1); + + for (ordinal_type i = s.sid_col_begin; i < s.sid_col_end; ++i) { + sid_block_colidx_(i).first = sid_colidx_(i); + sid_block_colidx_(i).second = blk_colidx_(i); } - Kokkos::deep_copy(supernodes_, h_supernodes); - } - // need to iterate parallel for .. let's not do this way - // Kokkos::parallel_for - // (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid) { - // auto &s = supernodes_(sid); - // const ordinal_type sidpar = stree_parent_(sid); - // if (sidpar != -1) { - // auto &spar = supernodes_(sidpar); - // spar.max_decendant_supernode_size = max(s.max_decendant_supernode_size, - // spar.max_decendant_supernode_size); - // spar.max_decendant_schur_size = max(s.max_decendant_schur_size, - // spar.max_decendant_schur_size); - // } - // }); - - self.max_nchildren = init_reduce_val.max_nchildren; - self.max_supernode_size = init_reduce_val.max_supernode_size; - self.max_schur_size = init_reduce_val.max_schur_size; - - // supernodal factor array; data is held outside with a managed view - // supernode does not include this view - // for the case that the same data structure is reused, the buffer will - // be zero'ed for each numeric factorization. - superpanel_buf_ = value_type_array(do_not_initialize_tag("superpanel_buf"), init_reduce_val.nnz); - Kokkos::parallel_scan - (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid, size_type &update, const bool &final) { - auto &s = supernodes_(sid); - if (final) - s.buf = &superpanel_buf_(update); - update += s.m * s.n; - }); - - self.supernodes = supernodes_; // unmanaged view, data is held outside - self.gid_colidx = gid_colidx_; - self.sid_block_colidx = sid_block_colidx_; - } - inline - void - initialize(/* */ supernode_type_array &supernodes_, - /* */ ordinal_pair_type_array &sid_block_colidx_, - /* */ value_type_array &superpanel_buf_, - // symbolic input - const ordinal_type_array &snodes_, - const size_type_array &gid_ptr_, - const ordinal_type_array &gid_colidx_, - const size_type_array &sid_ptr_, - const ordinal_type_array &sid_colidx_, - const ordinal_type_array &blk_colidx_, - // tree hierarchy - const ordinal_type_array &stree_parent_, - const size_type_array &stree_ptr_, - const ordinal_type_array &stree_children_) { - initialize(*this, - supernodes_, - sid_block_colidx_, - superpanel_buf_, - snodes_, - gid_ptr_, - gid_colidx_, - sid_ptr_, - sid_colidx_, - blk_colidx_, - stree_parent_, - stree_ptr_, - stree_children_); + s.nchildren = stree_ptr_(sid + 1) - stree_ptr_(sid); + s.children = &stree_children_(stree_ptr_(sid)); + // const ordinal_type offset = stree_ptr_(sid); + // for (ordinal_type i=0;i > - policy(nsupernodes, Kokkos::AUTO()); // team and vector sizes are AUTO selected. - - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA (const typename Kokkos::TeamPolicy::member_type &member) { - const ordinal_type sid = member.league_rank(); - const auto s = self.supernodes(sid); - dense_block_type tgt(s.buf, s.m, s.n);; - - // row major access to sparse src - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, s.m), [&](const ordinal_type &i) { - const ordinal_type - ii = i + s.row_begin, // row in U - row = perm(ii), kbeg = ap(row), kend = ap(row+1); // row in A - const ordinal_type kcnt = kend - kbeg; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, kcnt), - [&, kbeg, ii](const ordinal_type &kk) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - const ordinal_type k = kk + kbeg; - const ordinal_type jj = peri(aj(k) /* col in A */); // col in U - if (ii <= jj) { - ordinal_type *first = self.gid_colidx.data() + s.gid_col_begin; - ordinal_type *last = self.gid_colidx.data() + s.gid_col_end; - ordinal_type *loc = lower_bound(first, last, jj, - [](ordinal_type left, ordinal_type right) { - return left < right; }); - TACHO_TEST_FOR_ABORT(*loc != jj, " copy is wrong" ); - tgt(i, loc-first) = ax(k); - } - }); - }); + Kokkos::deep_copy(supernodes_, h_supernodes); + } + + self.max_nchildren = init_reduce_val.max_nchildren; + self.max_supernode_size = init_reduce_val.max_supernode_size; + self.max_schur_size = init_reduce_val.max_schur_size; + + // supernodal factor array; data is held outside with a managed view + // supernode does not include this view + // for the case that the same data structure is reused, the buffer will + // be zero'ed for each numeric factorization. + superpanel_buf_ = value_type_array(do_not_initialize_tag("superpanel_buf"), init_reduce_val.nnz); + Kokkos::parallel_scan( + supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid, size_type &update, const bool &final) { + auto &s = supernodes_(sid); + const ordinal_type u_buf_size = s.m * s.n; + const ordinal_type l_buf_size = allocate_l_buf_ ? (s.m * s.n) : 0; + if (final) { + s.u_buf = &superpanel_buf_(update); + s.l_buf = (allocate_l_buf_ ? s.u_buf + u_buf_size : NULL); + } + update += (u_buf_size + l_buf_size); + }); + + self.supernodes = supernodes_; // unmanaged view, data is held outside + self.gid_colidx = gid_colidx_; + self.sid_block_colidx = sid_block_colidx_; + } + + inline void initialize(/* */ supernode_type_array &supernodes_, + /* */ ordinal_pair_type_array &sid_block_colidx_, + /* */ value_type_array &superpanel_buf_, + /// control + const bool allocate_l_buf_, + // symbolic input + const ordinal_type_array &snodes_, const size_type_array &gid_ptr_, + const ordinal_type_array &gid_colidx_, const size_type_array &sid_ptr_, + const ordinal_type_array &sid_colidx_, const ordinal_type_array &blk_colidx_, + // tree hierarchy + const ordinal_type_array &stree_parent_, const size_type_array &stree_ptr_, + const ordinal_type_array &stree_children_) { + initialize(*this, + /// output + supernodes_, sid_block_colidx_, superpanel_buf_, + /// control + allocate_l_buf_, + /// super node input + snodes_, gid_ptr_, gid_colidx_, sid_ptr_, sid_colidx_, blk_colidx_, stree_parent_, stree_ptr_, + stree_children_); + } + + static inline void copySparseToSuperpanels(SupernodeInfo &self, + /// control + const bool copy_to_l_buf, + /// input from sparse matrix + const size_type_array &ap, const ordinal_type_array &aj, + const value_type_array &ax, const ordinal_type_array &perm, + const ordinal_type_array &peri) { + const ordinal_type nsupernodes = self.supernodes.extent(0); + using policy_type = Kokkos::TeamPolicy>; + + value_type_array axt; + if (copy_to_l_buf) { + axt = value_type_array("axt", ax.extent(0)); + policy_type policy(ap.extent(0) - 1, Kokkos::AUTO()); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const typename policy_type::member_type &member) { + const ordinal_type row = member.league_rank(); + const ordinal_type kbeg = ap(row), kend = ap(row + 1); + + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, kbeg, kend), [&](const ordinal_type &k) { + const ordinal_type i = aj(k), j = row; + { + const ordinal_type lbeg = ap(i), lend = ap(i + 1); + ordinal_type *first = &aj(lbeg); + ordinal_type *last = &aj(lend); + ordinal_type *loc = + lower_bound(first, last, j, [](ordinal_type left, ordinal_type right) { return left < right; }); + TACHO_TEST_FOR_ABORT(*loc != j, "transpose fail"); + axt(lbeg + loc - first) = ax(k); + } + }); }); -#else - const ordinal_type nsupernodes = self.supernodes.extent(0); - const ordinal_type m = ap.extent(0) - 1; - Kokkos::TeamPolicy > - policy(nsupernodes, Kokkos::AUTO()); // team and vector sizes are AUTO selected. - - typedef typename exec_space::scratch_memory_space shmem_space; - typedef Kokkos::View team_shared_memory_view_type; - const ordinal_type lvl = 0, per_team_scratch = team_shared_memory_view_type::shmem_size(m); - - Kokkos::parallel_for - (policy.set_scratch_size(lvl, Kokkos::PerTeam(per_team_scratch)), - KOKKOS_LAMBDA ( const typename Kokkos::TeamPolicy::member_type &member) { - team_shared_memory_view_type work(member.team_shmem(), m); + } + + { + policy_type policy(nsupernodes, Kokkos::AUTO()); // team and vector sizes are AUTO selected. + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const typename policy_type::member_type &member) { const ordinal_type sid = member.league_rank(); const auto s = self.supernodes(sid); - dense_block_type tgt(s.buf, s.m, s.n);; - - // local to global map - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, s.n), [&](const ordinal_type &j) { - Kokkos::single(Kokkos::PerThread(member), [&]() { - work[self.gid_colidx(j+s.gid_col_begin) /* = col */] = j; - }); + /// copy to upper triangular + { + UnmanagedViewType tgt_u(s.u_buf, s.m, s.n); + UnmanagedViewType tgt_lp(s.l_buf, s.n, s.m); + const auto tgt_l = Kokkos::subview(tgt_lp, range_type(s.m, s.n), Kokkos::ALL()); + + // row major access to sparse src + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, s.m), [&](const ordinal_type &i) { + const ordinal_type ii = i + s.row_begin, // row in U + row = perm(ii), kbeg = ap(row), kend = ap(row + 1); // row in A + + const ordinal_type jjbeg = (copy_to_l_buf ? s.row_begin : ii); + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, kbeg, kend), + [&, jjbeg, ii](const ordinal_type &k) { // Value capture is a workaround for cuda + + // gcc-7.2 compiler bug w/c++14 + const ordinal_type jj = peri(aj(k) /* col in A */); // col in U + if (jjbeg <= jj) { + ordinal_type *first = self.gid_colidx.data() + s.gid_col_begin; + ordinal_type *last = self.gid_colidx.data() + s.gid_col_end; + ordinal_type *loc = + lower_bound(first, last, jj, [](ordinal_type left, ordinal_type right) { + return left < right; + }); + TACHO_TEST_FOR_ABORT(*loc != jj, "copy is wrong"); + const ordinal_type j = loc - first; + tgt_u(i, j) = ax(k); + if (j >= s.m && copy_to_l_buf) { + tgt_l(j - s.m, i) = axt(k); + } + } + }); }); - - // row major access to sparse src - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, 0, s.m), [&](const ordinal_type &i) { - const ordinal_type - ii = i + s.row_begin, // row in U - row = perm(ii), kbeg = ap(row), kend = ap(row+1); // row in A - const ordinal_type kcnt = kend - kbeg; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, kcnt), [&](const ordinal_type &kk) { - const ordinal_type k = kk + kbeg; - const ordinal_type jj = peri(aj(k) /* col in A */); // col in U - if (ii <= jj) - tgt(i, work[jj]) = ax(k); - }); - }); - }); -#endif - } - - inline - void - copySparseToSuperpanels(// input from sparse matrix - const size_type_array &ap, - const ordinal_type_array &aj, - const value_type_array &ax, - const ordinal_type_array &perm, - const ordinal_type_array &peri) { - copySparseToSuperpanels(*this, - ap, - aj, - ax, - perm, - peri); - } - - static - inline - void - createCrsMatrix(SupernodeInfo &self, - crs_matrix_type &A, - const bool replace_value_with_one = false) { - // count m, n, nnz - const ordinal_type nsupernodes = self.supernodes.extent(0); - - auto d_last = Kokkos::subview(self.supernodes, nsupernodes - 1); - auto h_last = Kokkos::create_mirror_view(host_memory_space(), d_last); - Kokkos::deep_copy(h_last, d_last); - auto last = h_last(); - //auto &last = supernodes(nsupernodes - 1); - - const ordinal_type mm = last.row_begin + last.m, nn = mm; - - Kokkos::RangePolicy supernodes_range_policy(0,nsupernodes); - - // parallel for/scan version - size_type_array ap_tmp("ap", mm+1); - Kokkos::parallel_for - (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid) { - // row major access to sparse src - const auto &s = self.supernodes(sid); - const ordinal_type soffset = s.row_begin; - for (ordinal_type i=0;i(0,mm+1), - KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, const bool &final) { - if (final) - ap(i) = update; - update += ap_tmp(i); - }); - - // fill the matrix - auto d_nnz = Kokkos::subview(ap, mm); - auto h_nnz = Kokkos::create_mirror_view(host_memory_space(), d_nnz); - Kokkos::deep_copy(h_nnz, d_nnz); - - const auto nnz = h_nnz(); - ordinal_type_array aj("aj", nnz); - value_type_array ax("ax", nnz); - - Kokkos::parallel_for - (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid) { - const auto &s = self.supernodes(sid); - - dense_block_type src; - src.set_view(s.m, s.n); - src.attach_buffer(1, s.m, s.buf); - - // row major access to sparse src - const ordinal_type - soffset = s.row_begin, - goffset = s.gid_col_begin; - - for (ordinal_type i=0;i supernodes_range_policy(0,nsupernodes); + + // // parallel for/scan version + // size_type_array ap_tmp("ap", mm+1); + // Kokkos::parallel_for + // (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid) { + // // row major access to sparse src + // const auto &s = self.supernodes(sid); + // const ordinal_type soffset = s.row_begin; + // for (ordinal_type i=0;i(0,mm+1), + // KOKKOS_LAMBDA(const ordinal_type &i, size_type &update, const bool &final) { + // if (final) + // ap(i) = update; + // update += ap_tmp(i); + // }); + + // // fill the matrix + // auto d_nnz = Kokkos::subview(ap, mm); + // auto h_nnz = Kokkos::create_mirror_view(host_memory_space(), d_nnz); + // Kokkos::deep_copy(h_nnz, d_nnz); + + // const auto nnz = h_nnz(); + // ordinal_type_array aj("aj", nnz); + // value_type_array ax("ax", nnz); + + // Kokkos::parallel_for + // (supernodes_range_policy, KOKKOS_LAMBDA(const ordinal_type &sid) { + // const auto &s = self.supernodes(sid); + + // UnmanagedViewType src(s.u_buf, s.m, s.n); + + // // row major access to sparse src + // const ordinal_type + // soffset = s.row_begin, + // goffset = s.gid_col_begin; + + // for (ordinal_type i=0;i= 0) { - const ordinal_type p = stack(top); - const ordinal_type i = head(p); - if (i == -1) { - --top; - post(k++) = p; - } else { - head(p) = next(i); - stack(++top) = i; - } +} + +ordinal_type SymbolicTools::TreeDepthFirstSearch(const ordinal_type j, const ordinal_type c, + const ordinal_type_array &head, const ordinal_type_array &next, + const ordinal_type_array &post, const ordinal_type_array &stack) { + ordinal_type top = 0, k = c; + stack(top) = j; + while (top >= 0) { + const ordinal_type p = stack(top); + const ordinal_type i = head(p); + if (i == -1) { + --top; + post(k++) = p; + } else { + head(p) = next(i); + stack(++top) = i; } - return k; } - - void SymbolicTools:: - computePostOrdering(const ordinal_type m, - const ordinal_type_array &parent, - const ordinal_type_array &post, - const ordinal_type_array &work) { - auto head = Kokkos::subview(work, range_type(0*m, 1*m)); - auto next = Kokkos::subview(work, range_type(1*m, 2*m)); - auto stack = Kokkos::subview(work, range_type(2*m, 3*m)); - - for (ordinal_type i=0;i=0;--i) { - const ordinal_type p = parent(i); - if (p != -1) { - next(i) = head(p); - head(p) = i; - } + return k; +} + +void SymbolicTools::computePostOrdering(const ordinal_type m, const ordinal_type_array &parent, + const ordinal_type_array &post, const ordinal_type_array &work) { + auto head = Kokkos::subview(work, range_type(0 * m, 1 * m)); + auto next = Kokkos::subview(work, range_type(1 * m, 2 * m)); + auto stack = Kokkos::subview(work, range_type(2 * m, 3 * m)); + + for (ordinal_type i = 0; i < m; ++i) + head(i) = -1; + + for (ordinal_type i = m - 1; i >= 0; --i) { + const ordinal_type p = parent(i); + if (p != -1) { + next(i) = head(p); + head(p) = i; } - ordinal_type k = 0; - for (ordinal_type i=0;i= 0) ++count(parent(i)); - - // parent has more than a child, it becomes a supernode candidate - // roots are supernodes - for (ordinal_type i=0;i 1 || parent(i) < 0) flag(i) = true; - - // accumulate subtree sizes in count. - for (ordinal_type i=0;i= 0) count(parent(i)) += count(i); - - // tree leaves are also supernode candidate (not easy to understand this) - for (ordinal_type i=0;i= 0) + ++count(parent(i)); + + // parent has more than a child, it becomes a supernode candidate + // roots are supernodes + for (ordinal_type i = 0; i < m; ++i) + if (count(i) > 1 || parent(i) < 0) + flag(i) = true; + + // accumulate subtree sizes in count. + for (ordinal_type i = 0; i < m; ++i) + count(i) = 1; + for (ordinal_type i = 0; i < m; ++i) + if (parent(i) >= 0) + count(parent(i)) += count(i); + + // tree leaves are also supernode candidate (not easy to understand this) + for (ordinal_type i = 0; i < m; ++i) { + const ordinal_type ii = perm(i); + for (size_type p = ap(ii); p < ap(ii + 1); ++p) { + const ordinal_type j = peri(aj(p)); + if (i < j) { + const ordinal_type k = prev(j); + if (k < (i - count(i) + 1)) + flag(i) = true; + prev(j) = i; } } - - // count # of supernodes - { - ordinal_type k = 0; - flag(k) = true; // supernodes begin - - for (ordinal_type i=0;i ordinal_type { - // # of columns accessed by this super node - ordinal_type cnt = 0; - - // loop over super node cols (diagonal block) - for (ordinal_type col=sbeg;col ordinal_type { - ordinal_type cnt = 0; - for (ordinal_type k=0;k ordinal_type { + // # of columns accessed by this super node + ordinal_type cnt = 0; + + // loop over super node cols (diagonal block) + for (ordinal_type col = sbeg; col < send; ++col) { + flag(col) = true; // visitation flag on + cid(cnt++) = col; // record the column indicies + } + + // visit each super node row (off diagonal block) + for (ordinal_type i = sbeg; i < send; ++i) { + for (size_type j = ap_(i); j < ap_(i + 1); ++j) { + const ordinal_type col = aj_(j); + if (flag(col)) { + // already visited; pass on + } else { + flag(col) = true; + cid(cnt++) = col; } } - return cnt; - }; - - auto flag = Kokkos::subview(work, range_type(0*m, 1*m)); - auto cid = Kokkos::subview(work, range_type(1*m, 2*m)); - auto rid = Kokkos::subview(work, range_type(2*m, 3*m)); - auto tmp = Kokkos::subview(work, range_type(3*m, 4*m)); - - // zeros - Kokkos::deep_copy(work, 0); - - /// - /// super color in rows - /// - for (ordinal_type sid=0;sid ordinal_type { + ordinal_type cnt = 0; + for (ordinal_type k = 0; k < ndofs; ++k) { + const ordinal_type sid = sid_colored_in_rows(cid(k)); + if (count(sid) == 0) + sids_connected_to_this_row(cnt++) = sid; + ++count(sid); } - } - - void SymbolicTools:: - computeSupernodesAssemblyTree(const ordinal_type_array &parent, - const ordinal_type_array &supernodes, - /* */ ordinal_type_array &stree_level, - /* */ ordinal_type_array &stree_parent, - /* */ size_type_array &stree_ptr, - /* */ ordinal_type_array &stree_children, - /* */ ordinal_type_array &stree_roots, - const ordinal_type_array &work) { - const ordinal_type numSupernodes = supernodes.extent(0) - 1; - const ordinal_type m = supernodes(numSupernodes); - - stree_parent = ordinal_type_array("stree_parent", numSupernodes); - auto flag = Kokkos::subview(work, range_type(0*m, 1*m)); - - // color flag with supernodes (for the ease to detect supernode id from dofs) - for (ordinal_type i=0;i= 0) { - const ordinal_type sidpar = flag(parent(i)); - if (sidpar != sid) stree_parent(sid) = sidpar; - } + if (blks_connected_to_this_row.data() != NULL) { + for (ordinal_type k = 0; k < cnt; ++k) { + const ordinal_type sid = sids_connected_to_this_row(k); + blks_connected_to_this_row(k) = count(sid); } } - - auto clear_array = [](const ordinal_type cnt, - const ordinal_type_array &a) { - memset(a.data(), 0, cnt*sizeof(typename ordinal_type_array::value_type)); - }; - - // construct parent - child relations - { - clear_array(m, flag); - ordinal_type cnt = 0; - for (ordinal_type sid=0;sid= 0) { + const ordinal_type sidpar = flag(parent(i)); + if (sidpar != sid) + stree_parent(sid) = sidpar; } - stree_roots = ordinal_type_array(do_not_initialize_tag("stree_roots"), cnt); - } - - // prefix scan - { - stree_ptr = size_type_array(do_not_initialize_tag("stree_ptr"), numSupernodes + 1); stree_ptr(0) = size_type(); - for (ordinal_type sid=0;sid::type host_device_type; - typedef typename host_device_type::execution_space host_space; - typedef typename host_device_type::memory_space host_memory_space; - - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View size_type_array; - - typedef Kokkos::pair range_type; - - /// - /// supernode tools - /// - - // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 42. - static void - computeEliminationTree(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const ordinal_type_array &perm, - const ordinal_type_array &peri, - const ordinal_type_array &parent, - const ordinal_type_array &ancestor); - - // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 45. - static ordinal_type - TreeDepthFirstSearch(const ordinal_type j, - const ordinal_type c, - const ordinal_type_array &head, - const ordinal_type_array &next, - const ordinal_type_array &post, - const ordinal_type_array &stack); - - // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 45. - static void - computePostOrdering(const ordinal_type m, - const ordinal_type_array &parent, - const ordinal_type_array &post, - const ordinal_type_array &work); - - // Tim Davis, Algorithm 849: A Concise Sparse Cholesky Factorization Package - // ACM TOMS Vol 31 No. 4 pp 587--591. - static void - computeFillPatternUpper(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const ordinal_type_array &perm, - const ordinal_type_array &peri, - /* */ size_type_array &up, - /* */ ordinal_type_array &uj, - const ordinal_type_array &work); - - // Joseph, W. H. Liu, Esmond G. Ng, and Barry W. Peyton, - // "On Finding Supernodes for Sparse Matrix Computations," - // SIAM J. Matrix Anal. Appl., Vol. 14, No. 1, pp. 242-252. - static void - computeSupernodes(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const ordinal_type_array &perm, - const ordinal_type_array &peri, - const ordinal_type_array &parent, - /* */ ordinal_type_array &supernodes, - const ordinal_type_array &work); - - /// Based on the symbolic factors, allocate pannels - static void - allocateSupernodes(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const ordinal_type_array &supernodes, - const ordinal_type_array &work, - /* */ size_type_array &gid_super_panel_ptr, - /* */ ordinal_type_array &gid_super_panel_colidx, - /* */ size_type_array &sid_super_panel_ptr, - /* */ ordinal_type_array &sid_super_panel_colidx, - /* */ ordinal_type_array &blk_super_panel_colidx); - - /// construct tree explicitly - static void - computeSupernodesAssemblyTree(const ordinal_type_array &parent, - const ordinal_type_array &supernodes, - /* */ ordinal_type_array &stree_level, - /* */ ordinal_type_array &stree_parent, - /* */ size_type_array &stree_ptr, - /* */ ordinal_type_array &stree_children, - /* */ ordinal_type_array &stree_roots, - const ordinal_type_array &work); - - /// - /// evaporation tools - /// - static void - scanWeights(const ordinal_type m, - const ordinal_type_array &aw, - const ordinal_type_array &perm, - /* */ size_type_array &as, - /* */ size_type_array &aq); - - static void - evaporateGraph(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const size_type_array &as, - /* */ size_type_array &ap_eva, - /* */ ordinal_type_array &aj_eva); - - static void - evaporatePermutationVectors(const ordinal_type m, - const ordinal_type_array &perm, - const ordinal_type m_eva, - const ordinal_type_array &aw, - const size_type_array &aq, - /* */ ordinal_type_array &perm_eva, - /* */ ordinal_type_array &peri_eva); - - static void - evaporateSupernodes(const ordinal_type_array &supernodes, - const size_type_array &sid_super_panel_ptr, - const size_type_array &gid_super_panel_ptr, - const ordinal_type_array &gid_super_panel_colidx, - const ordinal_type_array &blk_super_panel_colidx, - const ordinal_type_array &perm, - const size_type_array &as, - const ordinal_type_array &peri_eva, - /* */ ordinal_type_array &supernodes_eva, - /* */ size_type_array &gid_super_panel_ptr_eva, - /* */ ordinal_type_array &gid_super_panel_colidx_eva, - /* */ ordinal_type_array &blk_super_panel_colidx_eva); - - private: - // matrix input - ordinal_type _m; - size_type_array _ap; - ordinal_type_array _aj; - - // graph ordering input - ordinal_type_array _perm, _peri; - - // supernodes output - ordinal_type_array _supernodes; - - // dof mapping to sparse matrix - size_type_array _gid_super_panel_ptr; - ordinal_type_array _gid_super_panel_colidx; - - // supernode map and panel size configuration - size_type_array _sid_super_panel_ptr; - ordinal_type_array _sid_super_panel_colidx, _blk_super_panel_colidx; - - // supernode elimination tree (parent - children) - size_type_array _stree_ptr; - ordinal_type_array _stree_children, _stree_roots; - - // supernode elimination tree (child - parent) - ordinal_type_array _stree_parent; - - // level information of supernodes - ordinal_type_array _stree_level; - - // stat - struct { - ordinal_type nrows, nroots; - size_type nnz_a, nnz_u; - ordinal_type nsupernodes, max_nchildren, largest_supernode, largest_schur; - ordinal_type nleaves, height; // tree - } stat; - - public: - SymbolicTools(); - SymbolicTools(const SymbolicTools &b); - - /// - /// construction - /// - SymbolicTools(const ordinal_type m, - const size_type_array &ap, - const ordinal_type_array &aj, - const ordinal_type_array &perm, - const ordinal_type_array &peri); - - template - SymbolicTools(CrsMatBaseType &A, - GraphToolType &G) { - _m = A.NumRows(); - - _ap = Kokkos::create_mirror_view(host_memory_space(), A.RowPtr()); - _aj = Kokkos::create_mirror_view(host_memory_space(), A.Cols()); - _perm = Kokkos::create_mirror_view(host_memory_space(), G.PermVector()); - _peri = Kokkos::create_mirror_view(host_memory_space(), G.InvPermVector()); - - Kokkos::deep_copy(_ap, A.RowPtr()); - Kokkos::deep_copy(_aj, A.Cols()); - Kokkos::deep_copy(_perm, G.PermVector()); - Kokkos::deep_copy(_peri, G.InvPermVector()); - } - - ordinal_type NumSupernodes() const; - ordinal_type_array Supernodes() const; - size_type_array gidSuperPanelPtr() const; - ordinal_type_array gidSuperPanelColIdx() const; - size_type_array sidSuperPanelPtr() const; - ordinal_type_array sidSuperPanelColIdx() const; - ordinal_type_array blkSuperPanelColIdx() const; - ordinal_type_array SupernodesTreeParent() const; - size_type_array SupernodesTreePtr() const; - ordinal_type_array SupernodesTreeChildren() const; - ordinal_type_array SupernodesTreeRoots() const; - ordinal_type_array SupernodesTreeLevel() const; - ordinal_type_array PermVector() const; - ordinal_type_array InvPermVector() const; - - void symbolicFactorize(const ordinal_type verbose = 0); - void evaporateSymbolicFactors(const ordinal_type_array &aw, - const ordinal_type verbose = 0); - }; - -} +class SymbolicTools { +public: + typedef typename UseThisDevice::type host_device_type; + typedef typename host_device_type::execution_space host_space; + typedef typename host_device_type::memory_space host_memory_space; + + typedef Kokkos::View ordinal_type_array; + typedef Kokkos::View size_type_array; + + typedef Kokkos::pair range_type; + + /// + /// supernode tools + /// + + // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 42. + static void computeEliminationTree(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const ordinal_type_array &perm, const ordinal_type_array &peri, + const ordinal_type_array &parent, const ordinal_type_array &ancestor); + + // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 45. + static ordinal_type TreeDepthFirstSearch(const ordinal_type j, const ordinal_type c, const ordinal_type_array &head, + const ordinal_type_array &next, const ordinal_type_array &post, + const ordinal_type_array &stack); + + // Tim Davis, Direct Methods for Sparse Linear Systems, Siam, p 45. + static void computePostOrdering(const ordinal_type m, const ordinal_type_array &parent, + const ordinal_type_array &post, const ordinal_type_array &work); + + // Tim Davis, Algorithm 849: A Concise Sparse Cholesky Factorization Package + // ACM TOMS Vol 31 No. 4 pp 587--591. + static void computeFillPatternUpper(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const ordinal_type_array &perm, const ordinal_type_array &peri, + /* */ size_type_array &up, + /* */ ordinal_type_array &uj, const ordinal_type_array &work); + + // Joseph, W. H. Liu, Esmond G. Ng, and Barry W. Peyton, + // "On Finding Supernodes for Sparse Matrix Computations," + // SIAM J. Matrix Anal. Appl., Vol. 14, No. 1, pp. 242-252. + static void computeSupernodes(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const ordinal_type_array &perm, const ordinal_type_array &peri, + const ordinal_type_array &parent, + /* */ ordinal_type_array &supernodes, const ordinal_type_array &work); + + /// Based on the symbolic factors, allocate pannels + static void allocateSupernodes(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const ordinal_type_array &supernodes, const ordinal_type_array &work, + /* */ size_type_array &gid_super_panel_ptr, + /* */ ordinal_type_array &gid_super_panel_colidx, + /* */ size_type_array &sid_super_panel_ptr, + /* */ ordinal_type_array &sid_super_panel_colidx, + /* */ ordinal_type_array &blk_super_panel_colidx); + + /// construct tree explicitly + static void computeSupernodesAssemblyTree(const ordinal_type_array &parent, const ordinal_type_array &supernodes, + /* */ ordinal_type_array &stree_level, + /* */ ordinal_type_array &stree_parent, + /* */ size_type_array &stree_ptr, + /* */ ordinal_type_array &stree_children, + /* */ ordinal_type_array &stree_roots, const ordinal_type_array &work); + + /// + /// evaporation tools + /// + static void scanWeights(const ordinal_type m, const ordinal_type_array &aw, const ordinal_type_array &perm, + /* */ size_type_array &as, + /* */ size_type_array &aq); + + static void evaporateGraph(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const size_type_array &as, + /* */ size_type_array &ap_eva, + /* */ ordinal_type_array &aj_eva); + + static void evaporatePermutationVectors(const ordinal_type m, const ordinal_type_array &perm, + const ordinal_type m_eva, const ordinal_type_array &aw, + const size_type_array &aq, + /* */ ordinal_type_array &perm_eva, + /* */ ordinal_type_array &peri_eva); + + static void evaporateSupernodes(const ordinal_type_array &supernodes, const size_type_array &sid_super_panel_ptr, + const size_type_array &gid_super_panel_ptr, + const ordinal_type_array &gid_super_panel_colidx, + const ordinal_type_array &blk_super_panel_colidx, const ordinal_type_array &perm, + const size_type_array &as, const ordinal_type_array &peri_eva, + /* */ ordinal_type_array &supernodes_eva, + /* */ size_type_array &gid_super_panel_ptr_eva, + /* */ ordinal_type_array &gid_super_panel_colidx_eva, + /* */ ordinal_type_array &blk_super_panel_colidx_eva); + +private: + // matrix input + ordinal_type _m; + size_type_array _ap; + ordinal_type_array _aj; + + // graph ordering input + ordinal_type_array _perm, _peri; + + // supernodes output + ordinal_type_array _supernodes; + + // dof mapping to sparse matrix + size_type_array _gid_super_panel_ptr; + ordinal_type_array _gid_super_panel_colidx; + + // supernode map and panel size configuration + size_type_array _sid_super_panel_ptr; + ordinal_type_array _sid_super_panel_colidx, _blk_super_panel_colidx; + + // supernode elimination tree (parent - children) + size_type_array _stree_ptr; + ordinal_type_array _stree_children, _stree_roots; + + // supernode elimination tree (child - parent) + ordinal_type_array _stree_parent; + + // level information of supernodes + ordinal_type_array _stree_level; + + // stat + struct { + ordinal_type nrows, nroots; + size_type nnz_a, nnz_u; + ordinal_type nsupernodes, max_nchildren, largest_supernode, largest_schur; + ordinal_type nleaves, height; // tree + } stat; + +public: + SymbolicTools(); + SymbolicTools(const SymbolicTools &b); + + /// + /// construction + /// + SymbolicTools(const ordinal_type m, const size_type_array &ap, const ordinal_type_array &aj, + const ordinal_type_array &perm, const ordinal_type_array &peri); + + template SymbolicTools(CrsMatBaseType &A, GraphToolType &G) { + _m = A.NumRows(); + + _ap = Kokkos::create_mirror_view(host_memory_space(), A.RowPtr()); + _aj = Kokkos::create_mirror_view(host_memory_space(), A.Cols()); + _perm = Kokkos::create_mirror_view(host_memory_space(), G.PermVector()); + _peri = Kokkos::create_mirror_view(host_memory_space(), G.InvPermVector()); + + Kokkos::deep_copy(_ap, A.RowPtr()); + Kokkos::deep_copy(_aj, A.Cols()); + Kokkos::deep_copy(_perm, G.PermVector()); + Kokkos::deep_copy(_peri, G.InvPermVector()); + } + + ordinal_type NumSupernodes() const; + ordinal_type_array Supernodes() const; + size_type_array gidSuperPanelPtr() const; + ordinal_type_array gidSuperPanelColIdx() const; + size_type_array sidSuperPanelPtr() const; + ordinal_type_array sidSuperPanelColIdx() const; + ordinal_type_array blkSuperPanelColIdx() const; + ordinal_type_array SupernodesTreeParent() const; + size_type_array SupernodesTreePtr() const; + ordinal_type_array SupernodesTreeChildren() const; + ordinal_type_array SupernodesTreeRoots() const; + ordinal_type_array SupernodesTreeLevel() const; + ordinal_type_array PermVector() const; + ordinal_type_array InvPermVector() const; + + void symbolicFactorize(const ordinal_type verbose = 0); + void evaporateSymbolicFactors(const ordinal_type_array &aw, const ordinal_type verbose = 0); +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize.hpp index c26e93767c0a..121a47f1b1a6 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize.hpp @@ -9,13 +9,12 @@ namespace Tacho { - /// - /// Symmetrize - /// +/// +/// Symmetrize +/// - /// various implementation for different uplo and algo parameters - template - struct Symmetrize; -} +/// various implementation for different uplo and algo parameters +template struct Symmetrize; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_Internal.hpp index 5b56bd15d860..d8193e266da3 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_Internal.hpp @@ -1,49 +1,59 @@ #ifndef __TACHO_SYMMETRIZE_INTERNAL_HPP__ #define __TACHO_SYMMETRIZE_INTERNAL_HPP__ - /// \file Tacho_Symmetrize_Internal.hpp /// \brief Symmetrize a square block matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template<> - struct Symmetrize { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const ViewTypeA &A) { - const ordinal_type - m = A.extent(0), - n = A.extent(1); - - if (m == n) { - if (A.span() > 0) { +template <> struct Symmetrize { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A) { + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m == n) { + if (A.span() > 0) { +#if defined(__CUDA_ARCH__) + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j), [&](const ordinal_type &i) { A(j, i) = A(i, j); }); + }); +#else + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < j; ++i) + A(j, i) = A(i, j); +#endif + } + } else { + printf("Error: Symmetrize A is not square\n"); + } + return 0; + } +}; + +template <> struct Symmetrize { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const ViewTypeA &A) { + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m == n) { + if (A.span() > 0) { #if defined(__CUDA_ARCH__) - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, n), - [&](const ordinal_type &j) { - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, j), - [&](const ordinal_type &i) { - A(j,i) = A(i,j); - }); - }); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, n), [&](const ordinal_type &j) { + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j), [&](const ordinal_type &i) { A(i, j) = A(j, i); }); + }); #else - for (ordinal_type j=0;j A is not square\n"); } - return 0; + } else { + printf("Error: Symmetrize A is not square\n"); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_OnDevice.hpp index 96cb736557bd..23cee7e36a32 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Symmetrize_OnDevice.hpp @@ -1,68 +1,99 @@ #ifndef __TACHO_SYMMETRIZE_ON_DEVICE_HPP__ #define __TACHO_SYMMETRIZE_ON_DEVICE_HPP__ - /// \file Tacho_Symmetrize_OnDevice.hpp /// \brief Symmetrize a matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - template<> - struct Symmetrize { - template - inline - static int - invoke(const ViewTypeA &A) { - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(std::is_same::value, - "A is not accessible from host"); +template <> struct Symmetrize { + template inline static int invoke(const ViewTypeA &A) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(std::is_same::value, + "A is not accessible from host"); + + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m == n) { + if (A.span() > 0) { + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < j; ++i) + A(j, i) = A(i, j); + } + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); + } + return 0; + } + + template + inline static int invoke(ExecSpaceType &exec_instance, const ViewTypeA &A) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + const ordinal_type m = A.extent(0), n = A.extent(1); + + if (m == n) { + if (A.span() > 0) { + using exec_space = ExecSpaceType; + const Kokkos::RangePolicy policy(exec_instance, 0, m * m); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &ij) { + const ordinal_type i = ij % m; + const ordinal_type j = ij / m; + if (i < j) + A(j, i) = A(i, j); + }); + } + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); + } + return 0; + } +}; + +template <> struct Symmetrize { + template inline static int invoke(const ViewTypeA &A) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(std::is_same::value, + "A is not accessible from host"); - const ordinal_type - m = A.extent(0), - n = A.extent(1); + const ordinal_type m = A.extent(0), n = A.extent(1); - if (m == n) { - if (A.span() > 0) { - for (ordinal_type j=0;j 0) { + for (ordinal_type j = 0; j < n; ++j) + for (ordinal_type i = 0; i < j; ++i) + A(i, j) = A(j, i); } - return 0; + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } + return 0; + } - template - inline - static int - invoke(ExecSpaceType &exec_instance, - const ViewTypeA &A) { - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - const ordinal_type - m = A.extent(0), - n = A.extent(1); + template + inline static int invoke(ExecSpaceType &exec_instance, const ViewTypeA &A) { + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + const ordinal_type m = A.extent(0), n = A.extent(1); - if (m == n) { - if (A.span() > 0) { - using exec_space = ExecSpaceType; - const Kokkos::RangePolicy policy(exec_instance, 0, m*m); - Kokkos::parallel_for - (policy, KOKKOS_LAMBDA(const ordinal_type &ij) { - const ordinal_type i = ij%m; - const ordinal_type j = ij/m; - if (i < j) - A(j,i) = A(i,j); + if (m == n) { + if (A.span() > 0) { + using exec_space = ExecSpaceType; + const Kokkos::RangePolicy policy(exec_instance, 0, m * m); + Kokkos::parallel_for( + policy, KOKKOS_LAMBDA(const ordinal_type &ij) { + const ordinal_type i = ij % m; + const ordinal_type j = ij / m; + if (i < j) + A(i, j) = A(j, i); }); - } - } else { - TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - return 0; + } else { + TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "A is not a square matrix"); } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeChol.hpp index e62b403664af..7c3892eb9b6d 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeChol.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeChol.hpp @@ -10,339 +10,309 @@ namespace Tacho { - template - struct TeamFunctor_FactorizeChol { - public: - typedef Kokkos::pair range_type; - - typedef SupernodeInfoType supernode_info_type; - typedef typename supernode_info_type::supernode_type supernode_type; - - typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; - typedef typename supernode_info_type::size_type_array size_type_array; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_array value_type_array; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef typename supernode_info_type::dense_block_type dense_block_type; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type CholAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type HerkAlgoType; - - private: - supernode_info_type _info; - ordinal_type_array _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - size_type_array _buf_ptr; - value_type_array _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_FactorizeChol() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_FactorizeChol(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const value_type_array buf) - : - _info(info), - _compute_mode(compute_mode), - _level_sids(level_sids), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; - } +template struct TeamFunctor_FactorizeChol { +public: + typedef Kokkos::pair range_type; - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; - } + typedef SupernodeInfoType supernode_info_type; + typedef typename supernode_info_type::supernode_type supernode_type; + + typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; + typedef typename supernode_info_type::size_type_array size_type_array; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_array value_type_array; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + +private: + supernode_info_type _info; + ordinal_type_array _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + size_type_array _buf_ptr; + value_type_array _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeChol() = delete; - /// - /// Main functions - /// - template - KOKKOS_INLINE_FUNCTION - void factorize_var0(MemberType &member, - const supernode_type &s, - const value_type_matrix &ABR) const { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - Chol::invoke(member, ATL); - - if (n_m > 0) { - member.team_barrier(); - const value_type one(1), minus_one(-1), zero(0); - UnmanagedViewType ATR(aptr, m, n_m); - Trsm - ::invoke(member, Diag::NonUnit(), one, ATL, ATR); - member.team_barrier(); - Herk - ::invoke(member, minus_one, ATR, zero, ABR); - } + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeChol(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const value_type_array buf) + : _info(info), _compute_mode(compute_mode), _level_sids(level_sids), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Main functions + /// + template + KOKKOS_INLINE_FUNCTION void factorize_var0(MemberType &member, const supernode_type &s, + const value_type_matrix &ABR) const { + using CholAlgoType = typename CholAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using HerkAlgoType = typename HerkAlgorithm::type; + + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + Chol::invoke(member, ATL); + + if (n_m > 0) { + // member.team_barrier(); + const value_type one(1), minus_one(-1), zero(0); + UnmanagedViewType ATR(aptr, m, n_m); + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ATR); + member.team_barrier(); + Herk::invoke(member, minus_one, ATR, zero, ABR); } } + } + + template + KOKKOS_INLINE_FUNCTION void factorize_var1(MemberType &member, const supernode_type &s, const value_type_matrix &T, + const value_type_matrix &ABR) const { + using CholAlgoType = typename CholAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using HerkAlgoType = typename HerkAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + Chol::invoke(member, ATL); - template - KOKKOS_INLINE_FUNCTION - void factorize_var1(MemberType &member, - const supernode_type &s, - const value_type_matrix &T, - const value_type_matrix &ABR) const { - const value_type one(1), minus_one(-1), zero(0); - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - Chol::invoke(member, ATL); - - if (n_m > 0) { - member.team_barrier(); - UnmanagedViewType ATR(aptr, m, n_m); - Trsm - ::invoke(member, Diag::NonUnit(), one, ATL, ATR); - Copy - ::invoke(member, T, ATL); - SetIdentity - ::invoke(member, ATL, one); - Trsm - ::invoke(member, Diag::NonUnit(), one, T, ATL); - member.team_barrier(); - Herk - ::invoke(member, minus_one, ATR, zero, ABR); - } else { - member.team_barrier(); - Copy - ::invoke(member, T, ATL); - SetIdentity - ::invoke(member, ATL, one); - Trsm - ::invoke(member, Diag::NonUnit(), one, T, ATL); - } + if (n_m > 0) { + // member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ATR); + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); + member.team_barrier(); + + Herk::invoke(member, minus_one, ATR, zero, ABR); + } else { + // member.team_barrier(); + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + member.team_barrier(); + + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); } } + } + + template + KOKKOS_INLINE_FUNCTION void factorize_var2(MemberType &member, const supernode_type &s, const value_type_matrix &T, + const value_type_matrix &ABR) const { + using CholAlgoType = typename CholAlgorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using HerkAlgoType = typename HerkAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + Chol::invoke(member, ATL); + + if (n_m > 0) { + // member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ATR); + member.team_barrier(); + + Herk::invoke(member, minus_one, ATR, zero, ABR); + member.team_barrier(); + + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, minus_one); + member.team_barrier(); + + UnmanagedViewType AT(ATL.data(), m, n); + Trsm::invoke(member, Diag::NonUnit(), minus_one, T, + AT); + } else { + // member.team_barrier(); + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + member.team_barrier(); - template - KOKKOS_INLINE_FUNCTION - void factorize_var2(MemberType &member, - const supernode_type &s, - const value_type_matrix &T, - const value_type_matrix &ABR) const { - const value_type one(1), minus_one(-1), zero(0); - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - Chol::invoke(member, ATL); - - if (n_m > 0) { - member.team_barrier(); - UnmanagedViewType ATR(aptr, m, n_m); - Trsm - ::invoke(member, Diag::NonUnit(), one, ATL, ATR); - member.team_barrier(); - Herk - ::invoke(member, minus_one, ATR, zero, ABR); - member.team_barrier(); - /// additional things - Copy - ::invoke(member, T, ATL); - member.team_barrier(); - SetIdentity::invoke(member, ATL, minus_one); - member.team_barrier(); - UnmanagedViewType AT(ATL.data(), m, n); - Trsm - ::invoke(member, Diag::NonUnit(), minus_one, T, AT); - } else { - /// additional things - Copy - ::invoke(member, T, ATL); - member.team_barrier(); - SetIdentity::invoke(member, ATL, one); - member.team_barrier(); - Trsm - ::invoke(member, Diag::NonUnit(), one, T, ATL); - } + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); } } + } - template - KOKKOS_INLINE_FUNCTION - void update(MemberType &member, - const supernode_type &cur, - const value_type_matrix &ABR) const { - const auto info = _info; - value_type *buf = ABR.data() + ABR.span(); - const ordinal_type - sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; - - const ordinal_type - srcbeg = info.sid_block_colidx(sbeg).second, - srcend = info.sid_block_colidx(send).second, - srcsize = srcend - srcbeg; - - // short cut to direct update - if ((send - sbeg) == 1) { - const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - - if (srcsize == tgtsize) { - /* */ value_type *tgt = s.buf; - const value_type *src = (value_type*)ABR.data(); - - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, srcsize), - [&, srcsize, src, tgt](const ordinal_type &j) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - const value_type *__restrict__ ss = src + j*srcsize; - /* */ value_type *__restrict__ tt = tgt + j*srcsize; - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, j+1), [&](const ordinal_type &i) { - Kokkos::atomic_add(&tt[i], ss[i]); - }); + template + KOKKOS_INLINE_FUNCTION void update(MemberType &member, const supernode_type &cur, + const value_type_matrix &ABR) const { + const auto info = _info; + value_type *buf = ABR.data() + ABR.span(); + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type srcbeg = info.sid_block_colidx(sbeg).second, srcend = info.sid_block_colidx(send).second, + srcsize = srcend - srcbeg; + + // short cut to direct update + if ((send - sbeg) == 1) { + const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + if (srcsize == tgtsize) { + /* */ value_type *tgt = s.u_buf; + const value_type *src = (value_type *)ABR.data(); + + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, srcsize), + [&, srcsize, src, + tgt](const ordinal_type &j) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const value_type *__restrict__ ss = src + j * srcsize; + /* */ value_type *__restrict__ tt = tgt + j * srcsize; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1), + [&](const ordinal_type &i) { Kokkos::atomic_add(&tt[i], ss[i]); }); }); - return; - } - } - - const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; - - // loop over target - //const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, sbeg, send), - [&, buf, srcsize](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - ordinal_type *s2t = ((ordinal_type*)(buf)) + member.team_rank()*srcsize; + return; + } + } + + const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; + + // loop over target + // const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, sbeg, send), + [&, buf, + srcsize](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + ordinal_type *s2t = ((ordinal_type *)(buf)) + member.team_rank() * srcsize; const auto &s = info.supernodes(info.sid_block_colidx(i).first); { - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + const ordinal_type *t_colidx = &info.gid_colidx(s.gid_col_begin + tgtbeg); - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, srcsize), - [&, t_colidx, s_colidx, tgtsize](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - s2t[k] = -1; - auto found = lower_bound(&t_colidx[0], &t_colidx[tgtsize-1], s_colidx[k], - [](ordinal_type left, ordinal_type right) { - return left < right; - }); - if (s_colidx[k] == *found) { - s2t[k] = found - t_colidx; - } - }); + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, srcsize), + [&, t_colidx, s_colidx, tgtsize]( + const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + s2t[k] = -1; + auto found = lower_bound(&t_colidx[0], &t_colidx[tgtsize - 1], s_colidx[k], + [](ordinal_type left, ordinal_type right) { return left < right; }); + if (s_colidx[k] == *found) { + s2t[k] = found - t_colidx; + } + }); } { - dense_block_type A; - A.set_view(s.m, s.n); - A.attach_buffer(1, s.m, s.buf); - - ordinal_type ijbeg = 0; for (;s2t[ijbeg] == -1; ++ijbeg) ; - -#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - for (ordinal_type iii=0;iii<(srcsize-ijbeg);++iii) { + UnmanagedViewType A(s.u_buf, s.m, s.n); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + for (ordinal_type iii = 0; iii < (srcsize - ijbeg); ++iii) { const ordinal_type ii = ijbeg + iii; const ordinal_type row = s2t[ii]; if (row < s.m) { - for (ordinal_type jj=ijbeg;jj struct FactorizeTag { enum {variant = Var }; }; - struct UpdateTag {}; - struct DummyTag {}; + template struct FactorizeTag { + enum { variant = Var }; + }; + struct UpdateTag {}; + struct DummyTag {}; - template - KOKKOS_INLINE_FUNCTION - void operator()(const FactorizeTag &, const MemberType &member) const { - const ordinal_type lid = member.league_rank(); - const ordinal_type p = _pbeg + lid; - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - using factorize_tag_type = FactorizeTag; - - const auto &s = _info.supernodes(sid); - const ordinal_type m = s.m, n = s.n, n_m = n-m; - const auto bufptr = _buf.data()+_buf_ptr(lid); - if (factorize_tag_type::variant == 0) { - UnmanagedViewType ABR(bufptr, n_m, n_m); - factorize_var0(member, s, ABR); - } else if (factorize_tag_type::variant == 1) { - UnmanagedViewType ABR(bufptr, n_m, n_m); - UnmanagedViewType T(bufptr, m, m); - factorize_var1(member, s, T, ABR); - } else if (factorize_tag_type::variant == 2) { - UnmanagedViewType ABR(bufptr, n_m, n_m); - UnmanagedViewType T(bufptr+ABR.span(), m, m); - factorize_var2(member, s, T, ABR); - } - } else if (mode == -1) { - printf("Error: TeamFunctorFactorizeChol, computing mode is not determined\n"); - } else { - // skip - } - } + template + KOKKOS_INLINE_FUNCTION void operator()(const FactorizeTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using factorize_tag_type = FactorizeTag; - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type lid = member.league_rank(); - const ordinal_type p = _pbeg + lid; - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - const auto &s = _info.supernodes(sid); - const ordinal_type n_m = s.n-s.m; - UnmanagedViewType ABR(_buf.data()+_buf_ptr(lid), n_m, n_m); - update(member, s, ABR); - } else { - // skip + const auto &s = _info.supernodes(sid); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + const auto bufptr = _buf.data() + _buf_ptr(lid); + if (factorize_tag_type::variant == 0) { + UnmanagedViewType ABR(bufptr, n_m, n_m); + factorize_var0(member, s, ABR); + } else if (factorize_tag_type::variant == 1) { + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr, m, m); + factorize_var1(member, s, T, ABR); + } else if (factorize_tag_type::variant == 2) { + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr + ABR.span(), m, m); + factorize_var2(member, s, T, ABR); } + } else if (mode == -1) { + printf("Error: TeamFunctorFactorizeChol, computing mode is not determined\n"); + } else { + // skip } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { - // do nothing + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + if (p < _pend) { + const ordinal_type sid = _level_sids(p); + const auto &s = _info.supernodes(sid); + const ordinal_type n_m = s.n - s.m; + UnmanagedViewType ABR(_buf.data() + _buf_ptr(lid), n_m, n_m); + update(member, s, ABR); + } else { + // skip } + } - }; -} + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLDL.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLDL.hpp index c964086be162..64d16cf8669e 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLDL.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLDL.hpp @@ -10,276 +10,394 @@ namespace Tacho { - template - struct TeamFunctor_FactorizeLDL { - public: - using range_type = Kokkos::pair; - - using supernode_info_type = SupernodeInfoType; - using supernode_type = typename supernode_info_type::supernode_type; - - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - using size_type_array = typename supernode_info_type::size_type_array; - - using value_type = typename supernode_info_type::value_type; - using value_type_array = typename supernode_info_type::value_type_array; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - - using dense_block_type = typename supernode_info_type::dense_block_type; - - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - /// testing purpose - /// using MainAlgoType = Algo::Internal; - using LDL_AlgoType = MainAlgoType; - using TrsmAlgoType = MainAlgoType; - using GemmAlgoType = MainAlgoType; - - private: - supernode_info_type _info; - ordinal_type_array _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - ordinal_type_array _piv; - value_type_array _diag; - - size_type_array _buf_ptr; - value_type_array _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_FactorizeLDL() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_FactorizeLDL(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const ordinal_type_array &piv, - const value_type_array &diag, - const value_type_array buf) - : - _info(info), - _compute_mode(compute_mode), - _level_sids(level_sids), - _piv(piv), _diag(diag), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; +template struct TeamFunctor_FactorizeLDL { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + supernode_info_type _info; + ordinal_type_array _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + ordinal_type_array _piv; + value_type_array _diag; + + size_type_array _buf_ptr; + value_type_array _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeLDL() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeLDL(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const ordinal_type_array &piv, + const value_type_array &diag, const value_type_array buf) + : _info(info), _compute_mode(compute_mode), _level_sids(level_sids), _piv(piv), _diag(diag), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Main functions + /// + template + KOKKOS_INLINE_FUNCTION void factorize_var0(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &D, + const value_type_array &W, /// STR and workspace for LDL + const value_type_matrix &ABR) const { + using LDL_AlgoType = typename LDL_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); + Symmetrize::invoke(member, ATL); + member.team_barrier(); + LDL::invoke(member, ATL, P, W); + member.team_barrier(); + LDL::modify(member, ATL, P, D); + member.team_barrier(); + + if (n_m > 0) { + const value_type one(1), minus_one(-1), zero(0); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType STR(W.data(), m, n_m); + + ConstUnmanagedViewType fpiv(P.data() + m, m); + ApplyPivots::invoke(member, fpiv, ATR); + member.team_barrier(); + Trsm::invoke(member, Diag::Unit(), one, ATL, ATR); + member.team_barrier(); + Copy::invoke(member, STR, ATR); + member.team_barrier(); + Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(member, P, D, ATR); + member.team_barrier(); + GemmTriangular::invoke(member, minus_one, ATR, + STR, zero, ABR); + } } + } + + template + KOKKOS_INLINE_FUNCTION void factorize_var1(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &D, + const value_type_array &W, /// STR and workspace for LDL + const value_type_matrix &T, const value_type_matrix &ABR) const { + using LDL_AlgoType = typename LDL_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); + + Symmetrize::invoke(member, ATL); + member.team_barrier(); + + LDL::invoke(member, ATL, P, W); + member.team_barrier(); - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; + LDL::modify(member, ATL, P, D); + member.team_barrier(); + + SetIdentity::invoke(member, T, one); + member.team_barrier(); + + Trsm::invoke(member, Diag::Unit(), one, ATL, T); + + if (n_m > 0) { + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType STR(W.data(), m, n_m); + + ConstUnmanagedViewType perm(P.data() + 2 * m, m); + ApplyPermutation::invoke(member, ATR, perm, STR); + member.team_barrier(); + + Trsm::invoke(member, Diag::Unit(), one, ATL, STR); + member.team_barrier(); + + Copy::invoke(member, ATL, T); + Copy::invoke(member, ATR, STR); + member.team_barrier(); + + Scale2x2_BlockInverseDiagonals /// row scaling + ::invoke(member, P, D, ATR); + member.team_barrier(); + + GemmTriangular::invoke(member, minus_one, ATR, + STR, zero, ABR); + } else { + member.team_barrier(); + Copy::invoke(member, ATL, T); + } } + } + + template + KOKKOS_INLINE_FUNCTION void factorize_var2(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &D, + const value_type_array &W, /// STR and workspace for LDL + const value_type_matrix &T, const value_type_matrix &ABR) const { + using LDL_AlgoType = typename LDL_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); - /// - /// Main functions - /// - template - KOKKOS_INLINE_FUNCTION - void factorize(MemberType &member, - const supernode_type &s, - const ordinal_type_array &P, - const value_type_matrix &D, - const value_type_array &W, /// STR and workspace for LDL - const value_type_matrix &ABR) const { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - UnmanagedViewType ATL(aptr, m, m); aptr += m*m; - Symmetrize::invoke(member, ATL); + Symmetrize::invoke(member, ATL); + member.team_barrier(); + + LDL::invoke(member, ATL, P, W); + member.team_barrier(); + + LDL::modify(member, ATL, P, D); + + if (n_m > 0) { member.team_barrier(); - LDL::invoke(member, ATL, P, W); + + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType STR(W.data(), m, n_m); + + ConstUnmanagedViewType perm(P.data() + 2 * m, m); + ApplyPermutation::invoke(member, ATR, perm, STR); member.team_barrier(); - LDL::modify(member, ATL, P, D); + + Trsm::invoke(member, Diag::Unit(), one, ATL, STR); + member.team_barrier(); + + Copy::invoke(member, T, ATL); + Copy::invoke(member, ATR, STR); member.team_barrier(); - if (n_m > 0) { - const value_type one(1), minus_one(-1), zero(0); - UnmanagedViewType ATR(aptr, m, n_m); - UnmanagedViewType STR(W.data(), m, n_m); - - auto fpiv = ordinal_type_array(P.data()+m, m); - ApplyPivots - ::invoke(member, fpiv, ATR); - member.team_barrier(); - Trsm - ::invoke(member, Diag::Unit(), one, ATL, ATR); - member.team_barrier(); - Copy - ::invoke(member, STR, ATR); - member.team_barrier(); - Scale2x2_BlockInverseDiagonals /// row scaling + Symmetrize::invoke(member, T); + SetIdentity::invoke(member, ATL, minus_one); + Scale2x2_BlockInverseDiagonals /// row scaling ::invoke(member, P, D, ATR); - member.team_barrier(); - GemmTriangular - ::invoke(member, minus_one, ATR, STR, zero, ABR); - } + member.team_barrier(); + + GemmTriangular::invoke(member, minus_one, ATR, + STR, zero, ABR); + member.team_barrier(); + + UnmanagedViewType AT(ATL.data(), m, n); + Trsm::invoke(member, Diag::Unit(), minus_one, T, AT); + } else { + member.team_barrier(); + Copy::invoke(member, T, ATL); + member.team_barrier(); + SetIdentity::invoke(member, ATL, one); + member.team_barrier(); + Trsm::invoke(member, Diag::Unit(), one, T, ATL); } } + } - template - KOKKOS_INLINE_FUNCTION - void update(MemberType &member, - const supernode_type &cur, - const value_type_matrix &ABR) const { - const auto info = _info; - value_type *buf = ABR.data() + ABR.span(); - const ordinal_type - sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; - - const ordinal_type - srcbeg = info.sid_block_colidx(sbeg).second, - srcend = info.sid_block_colidx(send).second, - srcsize = srcend - srcbeg; - - // short cut to direct update - if ((send - sbeg) == 1) { - const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - - if (srcsize == tgtsize) { - /* */ value_type *tgt = s.buf; - const value_type *src = (value_type*)ABR.data(); - - Kokkos::parallel_for - (Kokkos::TeamThreadRange(member, srcsize), - [&, srcsize, src, tgt](const ordinal_type &j) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - const value_type *__restrict__ ss = src + j*srcsize; - /* */ value_type *__restrict__ tt = tgt + j*srcsize; - Kokkos::parallel_for - (Kokkos::ThreadVectorRange(member, j+1), [&](const ordinal_type &i) { - Kokkos::atomic_add(&tt[i], ss[i]); - }); + template + KOKKOS_INLINE_FUNCTION void update(MemberType &member, const supernode_type &cur, + const value_type_matrix &ABR) const { + const auto info = _info; + value_type *buf = ABR.data() + ABR.span(); + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type srcbeg = info.sid_block_colidx(sbeg).second, srcend = info.sid_block_colidx(send).second, + srcsize = srcend - srcbeg; + + // short cut to direct update + if ((send - sbeg) == 1) { + const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + if (srcsize == tgtsize) { + /* */ value_type *tgt = s.u_buf; + const value_type *src = (value_type *)ABR.data(); + + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, srcsize), + [&, srcsize, src, + tgt](const ordinal_type &j) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const value_type *__restrict__ ss = src + j * srcsize; + /* */ value_type *__restrict__ tt = tgt + j * srcsize; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, j + 1), + [&](const ordinal_type &i) { Kokkos::atomic_add(&tt[i], ss[i]); }); }); - return; - } - } - - const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; - - // loop over target - //const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member, sbeg, send), - [&, buf, srcsize](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - ordinal_type *s2t = ((ordinal_type*)(buf)) + member.team_rank()*srcsize; + return; + } + } + + const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; + + // loop over target + // const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, sbeg, send), + [&, buf, + srcsize](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + ordinal_type *s2t = ((ordinal_type *)(buf)) + member.team_rank() * srcsize; const auto &s = info.supernodes(info.sid_block_colidx(i).first); { - const ordinal_type - tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, - tgtend = info.sid_block_colidx(s.sid_col_end-1).second, - tgtsize = tgtend - tgtbeg; - + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + const ordinal_type *t_colidx = &info.gid_colidx(s.gid_col_begin + tgtbeg); - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, srcsize), - [&, t_colidx, s_colidx, tgtsize](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - s2t[k] = -1; - auto found = lower_bound(&t_colidx[0], &t_colidx[tgtsize-1], s_colidx[k], - [](ordinal_type left, ordinal_type right) { - return left < right; - }); - if (s_colidx[k] == *found) { - s2t[k] = found - t_colidx; - } - }); + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, srcsize), + [&, t_colidx, s_colidx, tgtsize]( + const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + s2t[k] = -1; + auto found = lower_bound(&t_colidx[0], &t_colidx[tgtsize - 1], s_colidx[k], + [](ordinal_type left, ordinal_type right) { return left < right; }); + if (s_colidx[k] == *found) { + s2t[k] = found - t_colidx; + } + }); } { - dense_block_type A; - A.set_view(s.m, s.n); - A.attach_buffer(1, s.m, s.buf); - - ordinal_type ijbeg = 0; for (;s2t[ijbeg] == -1; ++ijbeg) ; - -#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) - for (ordinal_type iii=0;iii<(srcsize-ijbeg);++iii) { + UnmanagedViewType A(s.u_buf, s.m, s.n); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + for (ordinal_type iii = 0; iii < (srcsize - ijbeg); ++iii) { const ordinal_type ii = ijbeg + iii; const ordinal_type row = s2t[ii]; if (row < s.m) { - for (ordinal_type jj=ijbeg;jj + KOKKOS_INLINE_FUNCTION void check(MemberType &member, supernode_type &s, const ordinal_type_array &fpiv) const { + ordinal_type val(0); + Kokkos::parallel_reduce( + Kokkos::TeamVectorRange(member, s.m), + [&](const ordinal_type &i, ordinal_type &update) { + const ordinal_type fpiv_at_i = fpiv(i); + update += (fpiv_at_i < 0 ? -fpiv_at_i : fpiv_at_i); + }, + val); + member.team_barrier(); + Kokkos::single(Kokkos::PerTeam(member), [&]() { s.do_not_apply_pivots = (val == 0); }); + return; + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const FactorizeTag &, const MemberType &member) const { - const ordinal_type lid = member.league_rank(); - const ordinal_type p = _pbeg + lid; - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - const auto &s = _info.supernodes(sid); - const ordinal_type m = s.m, n = s.n, n_m = n-m; - const ordinal_type offm = s.row_begin; - - UnmanagedViewType P(_piv.data()+offm*4, m*4); - UnmanagedViewType D(_diag.data()+offm*2, m, 2); - - const int bufbeg = _buf_ptr(lid), bufend = _buf_ptr(lid+1); - value_type * bufptr = _buf.data()+bufbeg; - UnmanagedViewType ABR(bufptr, n_m, n_m); bufptr += ABR.span(); - UnmanagedViewType W(bufptr, int(bufend-bufbeg-ABR.span())); - - /// check the span does not go more than buf_ptr(lid+1) - factorize(member, s, P, D, W, ABR); - } else if (mode == -1) { - printf("Error: TeamFunctorFactorizeChol, computing mode is not determined\n"); - } else { - // skip - } - } + template struct FactorizeTag { + enum { variant = Var }; + }; + struct UpdateTag {}; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const FactorizeTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using factorize_tag_type = FactorizeTag; + + const auto &s = _info.supernodes(sid); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + const ordinal_type offm = s.row_begin; - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type lid = member.league_rank(); - const ordinal_type p = _pbeg + lid; - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - const auto &s = _info.supernodes(sid); - const ordinal_type n_m = s.n-s.m; - value_type * bufptr = _buf.data()+_buf_ptr(lid); + UnmanagedViewType P(_piv.data() + offm * 4, m * 4); + UnmanagedViewType D(_diag.data() + offm * 2, m, 2); + + if (factorize_tag_type::variant == 0) { + const ordinal_type bufbeg = _buf_ptr(lid), bufend = _buf_ptr(lid + 1); + auto bufptr = _buf.data() + bufbeg; UnmanagedViewType ABR(bufptr, n_m, n_m); - update(member, s, ABR); - } else { - // skip + + const ordinal_type used_span = ABR.span(); + UnmanagedViewType W(bufptr + used_span, int(bufend - bufbeg - used_span)); + factorize_var0(member, s, P, D, W, ABR); + } else if (factorize_tag_type::variant == 1) { + const ordinal_type bufbeg = _buf_ptr(lid), bufend = _buf_ptr(lid + 1); + auto bufptr = _buf.data() + bufbeg; + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr, m, m); + const ordinal_type used_span = max(ABR.span(), T.span()); + UnmanagedViewType W(bufptr + used_span, int(bufend - bufbeg - used_span)); + factorize_var1(member, s, P, D, W, T, ABR); + } else if (factorize_tag_type::variant == 2) { + const ordinal_type bufbeg = _buf_ptr(lid), bufend = _buf_ptr(lid + 1); + auto bufptr = _buf.data() + bufbeg; + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr + ABR.span(), m, m); + const ordinal_type used_span = ABR.span() + T.span(); + UnmanagedViewType W(bufptr + used_span, int(bufend - bufbeg - used_span)); + factorize_var2(member, s, P, D, W, T, ABR); } + } else if (mode == -1) { + printf("Error: TeamFunctorFactorizeChol, computing mode is not determined\n"); + } else { + // skip } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { - // do nothing + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + if (p < _pend) { + const ordinal_type sid = _level_sids(p); + auto &s = _info.supernodes(sid); + const ordinal_type n_m = s.n - s.m; + value_type *bufptr = _buf.data() + _buf_ptr(lid); + UnmanagedViewType ABR(bufptr, n_m, n_m); + update(member, s, ABR); + + const ordinal_type offm = s.row_begin; + UnmanagedViewType fpiv(_piv.data() + offm * 4 + s.m, s.m); + check(member, s, fpiv); + } else { + // skip } + } - }; -} + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLU.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLU.hpp new file mode 100644 index 000000000000..539c5355adc5 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_FactorizeLU.hpp @@ -0,0 +1,388 @@ +#ifndef __TACHO_TEAMFUNCTOR_FACTORIZE_LU_HPP__ +#define __TACHO_TEAMFUNCTOR_FACTORIZE_LU_HPP__ + +/// \file Tacho_TeamFunctor_FactorizeLU.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +#include "Tacho_SupernodeInfo.hpp" + +namespace Tacho { + +template struct TeamFunctor_FactorizeLU { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + supernode_info_type _info; + ordinal_type_array _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + ordinal_type_array _piv; + + size_type_array _buf_ptr; + value_type_array _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeLU() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_FactorizeLU(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const ordinal_type_array &piv, + const value_type_array buf) + : _info(info), _compute_mode(compute_mode), _level_sids(level_sids), _piv(piv), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Main functions + /// + template + KOKKOS_INLINE_FUNCTION void factorize_var0(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &ABR) const { + using LU_AlgoType = typename LU_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + + LU::invoke(member, AT, P); + member.team_barrier(); + + LU::modify(member, m, P); + member.team_barrier(); + + if (n_m > 0) { + const value_type one(1), minus_one(-1), zero(0); + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ABL); + member.team_barrier(); + + Gemm::invoke(member, minus_one, ABL, ATR, zero, ABR); + } + } + } + template + KOKKOS_INLINE_FUNCTION void factorize_var1(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &T, const value_type_matrix &ABR) const { + using LU_AlgoType = typename LU_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + + LU::invoke(member, AT, P); + member.team_barrier(); + + LU::modify(member, m, P); + member.team_barrier(); + + if (n_m > 0) { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL2 = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + Copy::invoke(member, T, ATL); + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ABL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + SetIdentity::invoke(member, ATL2, one); + member.team_barrier(); + + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); + Trsm::invoke(member, Diag::Unit(), one, T, ATL2); + member.team_barrier(); + + Gemm::invoke(member, minus_one, ABL, ATR, zero, ABR); + } else { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATL2(s.l_buf, m, m); + + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + SetIdentity::invoke(member, ATL2, one); + member.team_barrier(); + + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); + Trsm::invoke(member, Diag::Unit(), one, T, ATL2); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void factorize_var2(MemberType &member, const supernode_type &s, const ordinal_type_array &P, + const value_type_matrix &T, const value_type_matrix &ABR) const { + using LU_AlgoType = typename LU_Algorithm::type; + using TrsmAlgoType = typename TrsmAlgorithm::type; + using GemmAlgoType = typename GemmAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + + LU::invoke(member, AT, P); + member.team_barrier(); + + LU::modify(member, m, P); + member.team_barrier(); + + if (n_m > 0) { + UnmanagedViewType AT(s.u_buf, m, n); + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL2 = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + + Copy::invoke(member, T, ATL); + Trsm::invoke(member, Diag::NonUnit(), one, ATL, + ABL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, minus_one); + SetIdentity::invoke(member, ATL2, minus_one); + member.team_barrier(); + + Gemm::invoke(member, minus_one, ABL, ATR, zero, ABR); + member.team_barrier(); + + Trsm::invoke(member, Diag::NonUnit(), minus_one, T, + AT); + Trsm::invoke(member, Diag::Unit(), minus_one, T, + AL); + } else { + UnmanagedViewType ATL(s.u_buf, m, m); + UnmanagedViewType ATL2(s.l_buf, m, m); + + Copy::invoke(member, T, ATL); + member.team_barrier(); + + SetIdentity::invoke(member, ATL, one); + SetIdentity::invoke(member, ATL2, one); + member.team_barrier(); + + Trsm::invoke(member, Diag::NonUnit(), one, T, ATL); + Trsm::invoke(member, Diag::Unit(), one, T, ATL2); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update(MemberType &member, const supernode_type &cur, + const value_type_matrix &ABR) const { + const auto info = _info; + value_type *buf = ABR.data() + ABR.span(); + const ordinal_type sbeg = cur.sid_col_begin + 1, send = cur.sid_col_end - 1; + + const ordinal_type srcbeg = info.sid_block_colidx(sbeg).second, srcend = info.sid_block_colidx(send).second, + srcsize = srcend - srcbeg; + + // short cut to direct update + if ((send - sbeg) == 1) { + const auto &s = info.supernodes(info.sid_block_colidx(sbeg).first); + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + if (srcsize == tgtsize) { + /* */ value_type *tgt = s.u_buf; + const value_type *src = (value_type *)ABR.data(); + + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, srcsize), + [&, srcsize, src, + tgt](const ordinal_type &j) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const value_type *__restrict__ ss = src + j * srcsize; + /* */ value_type *__restrict__ tt = tgt + j * srcsize; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, srcsize), + [&](const ordinal_type &i) { Kokkos::atomic_add(&tt[i], ss[i]); }); + }); + return; + } + } + + const ordinal_type *s_colidx = sbeg < send ? &info.gid_colidx(cur.gid_col_begin + srcbeg) : NULL; + + // loop over target + // const size_type s2tsize = srcsize*sizeof(ordinal_type)*member.team_size(); + Kokkos::parallel_for( + Kokkos::TeamThreadRange(member, sbeg, send), + [&, buf, + srcsize](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + ordinal_type *s2t = ((ordinal_type *)(buf)) + member.team_rank() * srcsize; + const auto &s = info.supernodes(info.sid_block_colidx(i).first); + { + const ordinal_type tgtbeg = info.sid_block_colidx(s.sid_col_begin).second, + tgtend = info.sid_block_colidx(s.sid_col_end - 1).second, tgtsize = tgtend - tgtbeg; + + const ordinal_type *t_colidx = &info.gid_colidx(s.gid_col_begin + tgtbeg); + Kokkos::parallel_for( + Kokkos::ThreadVectorRange(member, srcsize), + [&, t_colidx, s_colidx, tgtsize]( + const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + s2t[k] = -1; + auto found = lower_bound(&t_colidx[0], &t_colidx[tgtsize - 1], s_colidx[k], + [](ordinal_type left, ordinal_type right) { return left < right; }); + if (s_colidx[k] == *found) { + s2t[k] = found - t_colidx; + } + }); + } + { + UnmanagedViewType U(s.u_buf, s.m, s.n); + UnmanagedViewType Lp(s.l_buf, s.n, s.m); + const auto L = Kokkos::subview(Lp, range_type(s.m, s.n), Kokkos::ALL()); + + ordinal_type ijbeg = 0; + for (; s2t[ijbeg] == -1; ++ijbeg) + ; + +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + for (ordinal_type ii = ijbeg; ii < srcsize; ++ii) { + const ordinal_type row = s2t[ii]; + if (row < s.m) { + for (ordinal_type jj = ijbeg; jj < srcsize; ++jj) { + const ordinal_type col = s2t[jj]; + Kokkos::atomic_add(&U(row, col), ABR(ii, jj)); + if (col >= s.m) { + Kokkos::atomic_add(&L(col - s.m, row), ABR(jj, ii)); + } + } + } + } +#else + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, ijbeg, srcsize), [&](const ordinal_type &ii) { + const ordinal_type row = s2t[ii]; + if (row < s.m) { + for (ordinal_type jj = ijbeg; jj < srcsize; ++jj) { + const ordinal_type col = s2t[jj]; + Kokkos::atomic_add(&U(row, col), ABR(ii, jj)); + if (col >= s.m) { + Kokkos::atomic_add(&L(col - s.m, row), ABR(jj, ii)); + } + } + } + }); +#endif + } + }); + return; + } + + template + KOKKOS_INLINE_FUNCTION void check(MemberType &member, supernode_type &s, const ordinal_type_array &fpiv) const { + ordinal_type val(0); + Kokkos::parallel_reduce( + Kokkos::TeamVectorRange(member, s.m), + [&](const ordinal_type &i, ordinal_type &update) { + const ordinal_type fpiv_at_i = fpiv(i); + update += (fpiv_at_i < 0 ? -fpiv_at_i : fpiv_at_i); + }, + val); + member.team_barrier(); + Kokkos::single(Kokkos::PerTeam(member), [&]() { s.do_not_apply_pivots = (val == 0); }); + return; + } + + template struct FactorizeTag { + enum { variant = Var }; + }; + struct UpdateTag {}; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const FactorizeTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using factorize_tag_type = FactorizeTag; + + const auto &s = _info.supernodes(sid); + const ordinal_type m = s.m, n = s.n, n_m = n - m; + const ordinal_type offm = s.row_begin; + + UnmanagedViewType P(_piv.data() + offm * 4, m * 4); + + const auto bufptr = _buf.data() + _buf_ptr(lid); + if (factorize_tag_type::variant == 0) { // temporary to push to trilinos + UnmanagedViewType ABR(bufptr, n_m, n_m); + factorize_var0(member, s, P, ABR); + } else if (factorize_tag_type::variant == 1) { + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr, m, m); + factorize_var1(member, s, P, T, ABR); + } else if (factorize_tag_type::variant == 2) { + UnmanagedViewType ABR(bufptr, n_m, n_m); + UnmanagedViewType T(bufptr + ABR.span(), m, m); + factorize_var2(member, s, P, T, ABR); + } + } else if (mode == -1) { + printf("Error: TeamFunctorFactorizeChol, computing mode is not determined\n"); + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type lid = member.league_rank(); + const ordinal_type p = _pbeg + lid; + if (p < _pend) { + const ordinal_type sid = _level_sids(p); + auto &s = _info.supernodes(sid); + const ordinal_type n_m = s.n - s.m; + value_type *bufptr = _buf.data() + _buf_ptr(lid); + UnmanagedViewType ABR(bufptr, n_m, n_m); + update(member, s, ABR); + + const ordinal_type offm = s.row_begin; + UnmanagedViewType fpiv(_piv.data() + offm * 4 + s.m, s.m); + check(member, s, fpiv); + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_InvertPanel.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_InvertPanel.hpp index f4ca348b1569..da639417d16f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_InvertPanel.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_InvertPanel.hpp @@ -9,123 +9,98 @@ #include "Tacho_SupernodeInfo.hpp" namespace Tacho { - - template - struct TeamFunctor_InvertPanel { - public: - typedef Kokkos::pair range_type; - - typedef SupernodeInfoType supernode_info_type; - typedef typename supernode_info_type::supernode_type supernode_type; - - typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; - typedef typename supernode_info_type::size_type_array size_type_array; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_array value_type_array; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsmAlgoType; - - private: - supernode_info_type _info; - ordinal_type_array _prepare_mode; - ordinal_type _scratch_level; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_InvertPanel() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_InvertPanel(const supernode_info_type &info, - const ordinal_type_array &prepare_mode, - const ordinal_type scratch_level) - : - _info(info), - _prepare_mode(prepare_mode), - _scratch_level(scratch_level) - {} - - template - KOKKOS_INLINE_FUNCTION - void copyAndSetIdentity(MemberType &member, - const value_type use_this_one, - value_type_matrix A, - value_type_matrix P) const { - const value_type zero(0); - const ordinal_type m = A.extent(0); - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m*m), - [&](const ordinal_type &k) { - const ordinal_type i = k%m; - const ordinal_type j = k/m; - A(i,j) = i <= j ? P(i,j) : zero; - P(i,j) = i == j ? use_this_one : zero; - }); - } - - template - KOKKOS_INLINE_FUNCTION - void invert(MemberType &member, - const value_type use_this_one, - const value_type_matrix &A, - const value_type_matrix &P) const { - Trsm - ::invoke(member, Diag::NonUnit(), use_this_one, A, P); - } - - template struct VariantTag {}; - - template - KOKKOS_INLINE_FUNCTION - void operator()(const VariantTag<0> &, const MemberType &member) const { - // dummy - } - template - KOKKOS_INLINE_FUNCTION - void operator()(const VariantTag<1> &, const MemberType &member) const { - const ordinal_type sid = member.league_rank(); - const ordinal_type mode = _prepare_mode(sid); - if (mode == 1) { - const auto s = _info.supernodes(sid); - const ordinal_type m = s.m; - UnmanagedViewType A(member.team_scratch(_scratch_level), m, m); - UnmanagedViewType P(s.buf, m, m); - const value_type one(1); - copyAndSetIdentity(member, one, A, P); - invert(member, one, A, P); - } else { - // skip - } +template struct TeamFunctor_InvertPanel { +public: + typedef Kokkos::pair range_type; + + typedef SupernodeInfoType supernode_info_type; + typedef typename supernode_info_type::supernode_type supernode_type; + + typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; + typedef typename supernode_info_type::size_type_array size_type_array; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_array value_type_array; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + +private: + supernode_info_type _info; + ordinal_type_array _prepare_mode; + ordinal_type _scratch_level; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_InvertPanel() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_InvertPanel(const supernode_info_type &info, const ordinal_type_array &prepare_mode, + const ordinal_type scratch_level) + : _info(info), _prepare_mode(prepare_mode), _scratch_level(scratch_level) {} + + template + KOKKOS_INLINE_FUNCTION void copyAndSetIdentity(MemberType &member, const value_type use_this_one, value_type_matrix A, + value_type_matrix P) const { + const value_type zero(0); + const ordinal_type m = A.extent(0); + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m * m), [&](const ordinal_type &k) { + const ordinal_type i = k % m; + const ordinal_type j = k / m; + A(i, j) = i <= j ? P(i, j) : zero; + P(i, j) = i == j ? use_this_one : zero; + }); + } + + template + KOKKOS_INLINE_FUNCTION void invert(MemberType &member, const value_type use_this_one, const value_type_matrix &A, + const value_type_matrix &P) const { + using TrsmAlgoType = typename TrsmAlgorithm::type; + Trsm::invoke(member, Diag::NonUnit(), use_this_one, A, + P); + } + + template struct VariantTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const VariantTag<0> &, const MemberType &member) const { + // dummy + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const VariantTag<1> &, const MemberType &member) const { + const ordinal_type sid = member.league_rank(); + const ordinal_type mode = _prepare_mode(sid); + if (mode == 1) { + const auto s = _info.supernodes(sid); + const ordinal_type m = s.m; + UnmanagedViewType A(member.team_scratch(_scratch_level), m, m); + UnmanagedViewType P(s.buf, m, m); + const value_type one(1); + copyAndSetIdentity(member, one, A, P); + invert(member, one, A, P); + } else { + // skip } - - template - KOKKOS_INLINE_FUNCTION - void operator()(const VariantTag<2> &, const MemberType &member) const { - const ordinal_type sid = member.league_rank(); - const ordinal_type mode = _prepare_mode(sid); - if (mode == 1) { - const auto s = _info.supernodes(sid); - const ordinal_type m = s.m, n = s.n; - UnmanagedViewType A(member.team_scratch(_scratch_level), m, m); - UnmanagedViewType P(s.buf, m, n); - const value_type minus_one(-1); - copyAndSetIdentity(member, minus_one, A, P); - invert(member, minus_one, A, P); - } else { - // skip - } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const VariantTag<2> &, const MemberType &member) const { + const ordinal_type sid = member.league_rank(); + const ordinal_type mode = _prepare_mode(sid); + if (mode == 1) { + const auto s = _info.supernodes(sid); + const ordinal_type m = s.m, n = s.n; + UnmanagedViewType A(member.team_scratch(_scratch_level), m, m); + UnmanagedViewType P(s.buf, m, n); + const value_type minus_one(-1); + copyAndSetIdentity(member, minus_one, A, P); + invert(member, minus_one, A, P); + } else { + // skip } + } +}; - }; - -} +} // namespace Tacho #endif - - - - diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerChol.hpp index edd64a8f801f..95022a951d93 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerChol.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerChol.hpp @@ -10,344 +10,321 @@ namespace Tacho { - template - struct TeamFunctor_SolveLowerChol { - public: - typedef Kokkos::pair range_type; - - typedef SupernodeInfoType supernode_info_type; - typedef typename supernode_info_type::supernode_type supernode_type; - typedef typename supernode_info_type::supernode_type_array supernode_type_array; - - typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; - typedef typename supernode_info_type::size_type_array size_type_array; - - typedef typename supernode_info_type::ordinal_pair_type_array ordinal_pair_type_array; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_array value_type_array; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsvAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type GemvAlgoType; - - private: - ConstUnmanagedViewType _supernodes; - ConstUnmanagedViewType _sid_block_colidx; - ConstUnmanagedViewType _gid_colidx; - - ConstUnmanagedViewType _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - UnmanagedViewType _t; - ordinal_type _nrhs; - - UnmanagedViewType _buf_ptr; - UnmanagedViewType _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveLowerChol() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveLowerChol(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const value_type_matrix t, - const value_type_array buf) - : - _supernodes(info.supernodes), - _sid_block_colidx(info.sid_block_colidx), - _gid_colidx(info.gid_colidx), - _compute_mode(compute_mode), - _level_sids(level_sids), - _t(t), - _nrhs(t.extent(1)), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; - } - - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; - } +template struct TeamFunctor_SolveLowerChol { +public: + typedef Kokkos::pair range_type; + + typedef SupernodeInfoType supernode_info_type; + typedef typename supernode_info_type::supernode_type supernode_type; + typedef typename supernode_info_type::supernode_type_array supernode_type_array; + + typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; + typedef typename supernode_info_type::size_type_array size_type_array; + + typedef typename supernode_info_type::ordinal_pair_type_array ordinal_pair_type_array; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_array value_type_array; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _sid_block_colidx; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerChol() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerChol(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const value_type_matrix t, + const value_type_array buf) + : _supernodes(info.supernodes), _sid_block_colidx(info.sid_block_colidx), _gid_colidx(info.gid_colidx), + _compute_mode(compute_mode), _level_sids(level_sids), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: trsv - gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + Trsv::invoke(member, Diag::NonUnit(), ATL, tT); - /// - /// Algorithm Variant 0: trsv - gemv - /// - template - KOKKOS_INLINE_FUNCTION - void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), zero(0); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - // solve - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - Trsv - ::invoke(member, Diag::NonUnit(), AL, tT); - - if (n_m > 0) { - // update - member.team_barrier(); - UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - UnmanagedViewType bB(bptr, n_m, _nrhs); - Gemv - ::invoke(member, minus_one, AR, tT, zero, bB); - } + if (n_m > 0) { + // update + member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, tT, zero, bB); } } } - - template - KOKKOS_INLINE_FUNCTION - void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (n_m > 0) { - UnmanagedViewType bB(bptr, n_m, _nrhs); - // update - const ordinal_type - sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; - for (ordinal_type i=sbeg,ip=0/*is=0*/;i + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr, n_m, _nrhs); + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, tbeg, + ip](const ordinal_type &ii) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it - KOKKOS_INLINE_FUNCTION - void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), one(1), zero(0); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - UnmanagedViewType bT(bptr, m, _nrhs); bptr += m*_nrhs; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - Gemv - ::invoke(member, one, AL, tT, zero, bT); - - if (n_m > 0) { - // solve offdiag - member.team_barrier(); - UnmanagedViewType AR(aptr, m, n_m); - UnmanagedViewType bB(bptr, n_m, _nrhs); - - Gemv - ::invoke(member, minus_one, AR, bT, zero, bB); - } + /// + /// Algorithm Variant 1: gemv - gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + UnmanagedViewType bT(bptr, m, _nrhs); + bptr += m * _nrhs; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + Gemv::invoke(member, one, ATL, tT, zero, bT); + + if (n_m > 0) { + // solve offdiag + member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); + UnmanagedViewType bB(bptr, n_m, _nrhs); + + Gemv::invoke(member, minus_one, ATR, bT, zero, bB); } } } - - template - KOKKOS_INLINE_FUNCTION - void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - UnmanagedViewType bT(bptr, m, _nrhs); bptr += m*_nrhs; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - // copy to t - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m*_nrhs), - [&, m](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - const ordinal_type i = k%m, j = k/m; - tT(i,j) = bT(i,j); + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType bT(bptr, m, _nrhs); + bptr += m * _nrhs; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + // copy to t + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, m * _nrhs), + [&, m](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type i = k % m, j = k / m; + tT(i, j) = bT(i, j); }); - if (n_m > 0) { - UnmanagedViewType bB(bptr, n_m, _nrhs); - - // update - const ordinal_type - sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; - for (ordinal_type i=sbeg,ip=0/*is=0*/;i 0) { + UnmanagedViewType bB(bptr, n_m, _nrhs); + + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, tbeg]( + const ordinal_type &ii) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it - KOKKOS_INLINE_FUNCTION - void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type one(1), zero(0); - { - const ordinal_type m = s.m, n = s.n; - if (m > 0 && n > 0) { - value_type *aptr = s.buf; - - UnmanagedViewType A(aptr, m, n); - UnmanagedViewType b(bptr, n, _nrhs); - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - Gemv - ::invoke(member, one, A, tT, zero, b); - } + /// + /// Algorithm Variant 2: gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf; + + UnmanagedViewType AT(aptr, m, n); + UnmanagedViewType b(bptr, n, _nrhs); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + Gemv::invoke(member, one, AT, tT, zero, b); } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + UnmanagedViewType b(bptr, n, _nrhs); + if (m > 0) { + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + // copy to t + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m * _nrhs), + [&, m](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type i = k % m, j = k / m; + tT(i, j) = b(i, j); + }); - template - KOKKOS_INLINE_FUNCTION - void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - UnmanagedViewType b(bptr, n, _nrhs); - if (m > 0) { - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - // copy to t - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m*_nrhs), - [&,m](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround - const ordinal_type i = k%m, j = k/m; - tT(i,j) = b(i,j); - }); - - if (n_m > 0) { - // update - const ordinal_type - sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; - for (ordinal_type i=sbeg,ip=0/*is=0*/;i 0) { + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, ip, m, tbeg, + tcnt](const ordinal_type &ii) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it struct SolveTag { enum { variant = Var }; }; - template struct UpdateTag { enum { variant = Var }; }; - struct DummyTag {}; - - template - KOKKOS_INLINE_FUNCTION - void operator()(const SolveTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - typedef SolveTag solve_tag_type; - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - if (solve_tag_type::variant == 0) solve_var0(member, s, bptr); - else if (solve_tag_type::variant == 1) solve_var1(member, s, bptr); - else if (solve_tag_type::variant == 2) solve_var2(member, s, bptr); - else - printf("Error: TeamFunctorSolveLowerChol::SolveTag, algorithm variant is not supported\n"); - } if (mode == -1) { - printf("Error: TeamFunctorSolveLowerChol::SolveTag, computing mode is not determined\n"); - } else { - // skip - } + template struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + typedef SolveTag solve_tag_type; + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) + solve_var0(member, s, bptr); + else if (solve_tag_type::variant == 1) + solve_var1(member, s, bptr); + else if (solve_tag_type::variant == 2) + solve_var2(member, s, bptr); + else + printf("Error: TeamFunctorSolveLowerChol::SolveTag, algorithm variant is not supported\n"); } - - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - typedef UpdateTag update_tag_type; - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - if (update_tag_type::variant == 0) update_var0(member, s, bptr); - else if (update_tag_type::variant == 1) update_var1(member, s, bptr); - else if (update_tag_type::variant == 2) update_var2(member, s, bptr); - else - printf("Error: TeamFunctorSolveLowerChol::UpdateTag, algorithm variant is not supported\n"); - } else { - // skip - } + if (mode == -1) { + printf("Error: TeamFunctorSolveLowerChol::SolveTag, computing mode is not determined\n"); + } else { + // skip } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { - // do nothing + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + const ordinal_type sid = _level_sids(p); + typedef UpdateTag update_tag_type; + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (update_tag_type::variant == 0) + update_var0(member, s, bptr); + else if (update_tag_type::variant == 1) + update_var1(member, s, bptr); + else if (update_tag_type::variant == 2) + update_var2(member, s, bptr); + else + printf("Error: TeamFunctorSolveLowerChol::UpdateTag, algorithm variant is not supported\n"); + } else { + // skip } - - }; -} + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLDL.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLDL.hpp index 5774d747b2f0..eb328e811d4e 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLDL.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLDL.hpp @@ -10,189 +10,332 @@ namespace Tacho { - template - struct TeamFunctor_SolveLowerLDL { - public: - using range_type = Kokkos::pair; - - using supernode_info_type = SupernodeInfoType; - using supernode_type = typename supernode_info_type::supernode_type; - using supernode_type_array = typename supernode_info_type::supernode_type_array; - - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - using size_type_array = typename supernode_info_type::size_type_array; - - using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; - - using value_type = typename supernode_info_type::value_type; - using value_type_array = typename supernode_info_type::value_type_array; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - using TrsvAlgoType = MainAlgoType; - using GemvAlgoType = MainAlgoType; - - private: - ConstUnmanagedViewType _supernodes; - ConstUnmanagedViewType _sid_block_colidx; - ConstUnmanagedViewType _gid_colidx; - - ConstUnmanagedViewType _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - ConstUnmanagedViewType _piv; - - UnmanagedViewType _t; - ordinal_type _nrhs; - - UnmanagedViewType _buf_ptr; - UnmanagedViewType _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveLowerLDL() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveLowerLDL(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const ordinal_type_array &piv, - const value_type_matrix t, - const value_type_array buf) - : - _supernodes(info.supernodes), - _sid_block_colidx(info.sid_block_colidx), - _gid_colidx(info.gid_colidx), - _compute_mode(compute_mode), - _level_sids(level_sids), - _piv(piv), - _t(t), - _nrhs(t.extent(1)), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; +template struct TeamFunctor_SolveLowerLDL { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + using supernode_type_array = typename supernode_info_type::supernode_type_array; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _sid_block_colidx; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + ConstUnmanagedViewType _piv; + + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerLDL() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerLDL(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const ordinal_type_array &piv, + const value_type_matrix t, const value_type_array buf) + : _supernodes(info.supernodes), _sid_block_colidx(info.sid_block_colidx), _gid_colidx(info.gid_colidx), + _compute_mode(compute_mode), _level_sids(level_sids), _piv(piv), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: trsv - gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + auto fpiv = ConstUnmanagedViewType(_piv.data() + 4 * offm + m, m); + + if (!s.do_not_apply_pivots) { + ApplyPivots /// row inter-change + ::invoke(member, fpiv, tT); + member.team_barrier(); + } + Trsv::invoke(member, Diag::Unit(), ATL, tT); + + if (n_m > 0) { + // update + member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, tT, zero, bB); + } + } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr, n_m, _nrhs); + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, tbeg, + ip](const ordinal_type &ii) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it - KOKKOS_INLINE_FUNCTION - void solve(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), zero(0); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - // solve - UnmanagedViewType AL(aptr, m, m); aptr += m*m; - - const ordinal_type offm = s.row_begin; - auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - auto fpiv = ConstUnmanagedViewType(_piv.data()+4*offm+m, m); - - ApplyPivots /// row inter-change - ::invoke(member, fpiv, tT); - Trsv - ::invoke(member, Diag::Unit(), AL, tT); - - if (n_m > 0) { - // update - member.team_barrier(); - UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - UnmanagedViewType bB(bptr, n_m, _nrhs); - Gemv - ::invoke(member, minus_one, AR, tT, zero, bB); - } + template + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + UnmanagedViewType ATL(aptr, m, m); + aptr += ATL.span(); + UnmanagedViewType bT(bptr, m, _nrhs); + bptr += bT.span(); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + auto perm = ConstUnmanagedViewType(_piv.data() + 4 * offm + 2 * m, m); + + if (s.do_not_apply_pivots) { + Copy::invoke(member, bT, tT); + } else { + ApplyPermutation::invoke(member, tT, perm, bT); + } + member.team_barrier(); + + Gemv::invoke(member, one, ATL, bT, zero, tT); + + if (n_m > 0) { + // update + member.team_barrier(); + UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, tT, zero, bB); } } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType bT(bptr, m, _nrhs); + bptr += m * _nrhs; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); - template - KOKKOS_INLINE_FUNCTION - void update(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; if (n_m > 0) { UnmanagedViewType bB(bptr, n_m, _nrhs); + // update - const ordinal_type - sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; - for (ordinal_type i=sbeg,ip=0/*is=0*/;i + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; - template - KOKKOS_INLINE_FUNCTION - void operator()(const SolveTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - solve(member, s, bptr); - } if (mode == -1) { - printf("Error: TeamFunctorSolveLowerChol::SolveTag, computing mode is not determined\n"); - } else { - // skip + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf; + + UnmanagedViewType AT(aptr, m, n); + UnmanagedViewType b(bptr, n, _nrhs); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + UnmanagedViewType bT(bptr, m, _nrhs); + + if (!s.do_not_apply_pivots) { + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + Copy::invoke(member, bT, tT); + member.team_barrier(); + + ApplyPermutation::invoke(member, bT, perm, tT); + member.team_barrier(); + } + Gemv::invoke(member, one, AT, tT, zero, b); } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + UnmanagedViewType b(bptr, n, _nrhs); + if (m > 0) { + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - update(member, s, bptr); - } else { - // skip + // copy to t + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m * _nrhs), + [&, m](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type i = k % m, j = k / m; + tT(i, j) = b(i, j); + }); + + if (n_m > 0) { + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, ip, m, tbeg, + tcnt](const ordinal_type &ii) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using solve_tag_type = SolveTag; - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { - // do nothing + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) { + solve_var0(member, s, bptr); + } else if (solve_tag_type::variant == 1) { + solve_var1(member, s, bptr); + } else if (solve_tag_type::variant == 2) { + solve_var2(member, s, bptr); + } + } + if (mode == -1) { + printf("Error: TeamFunctorSolveLowerChol::SolveTag, computing mode is not determined\n"); + } else { + // skip } + } - }; -} + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + using update_tag_type = UpdateTag; + + const ordinal_type sid = _level_sids(p); + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (update_tag_type::variant == 0) { + update_var0(member, s, bptr); + } else if (update_tag_type::variant == 1) { + update_var1(member, s, bptr); + } else if (update_tag_type::variant == 2) { + update_var2(member, s, bptr); + } + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLU.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLU.hpp new file mode 100644 index 000000000000..6df798f2cfa1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveLowerLU.hpp @@ -0,0 +1,340 @@ +#ifndef __TACHO_TEAMFUNCTOR_SOLVE_LOWER_LU_HPP__ +#define __TACHO_TEAMFUNCTOR_SOLVE_LOWER_LU_HPP__ + +/// \file Tacho_TeamFunctor_SolveLowerLU.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +#include "Tacho_SupernodeInfo.hpp" + +namespace Tacho { + +template struct TeamFunctor_SolveLowerLU { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + using supernode_type_array = typename supernode_info_type::supernode_type_array; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _sid_block_colidx; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + ConstUnmanagedViewType _piv; + + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerLU() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveLowerLU(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const ordinal_type_array &piv, + const value_type_matrix t, const value_type_array buf) + : _supernodes(info.supernodes), _sid_block_colidx(info.sid_block_colidx), _gid_colidx(info.gid_colidx), + _compute_mode(compute_mode), _level_sids(level_sids), _piv(piv), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: trsv - gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + // solve + UnmanagedViewType ATL(s.u_buf, m, m); + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + auto fpiv = ConstUnmanagedViewType(_piv.data() + 4 * offm + m, m); + + if (!s.do_not_apply_pivots) { + ApplyPivots /// row inter-change + ::invoke(member, fpiv, tT); + } + Trsv::invoke(member, Diag::Unit(), ATL, tT); + + if (n_m > 0) { + // update + member.team_barrier(); + UnmanagedViewType AL(s.l_buf, n, m); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ABL, tT, zero, bB); + } + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr, n_m, _nrhs); + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, tbeg, + ip](const ordinal_type &ii) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), minus_one(-1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + // solve + UnmanagedViewType AL(s.l_buf, n, m); + const auto ATL = Kokkos::subview(AL, range_type(0, m), Kokkos::ALL()); + UnmanagedViewType bT(bptr, m, _nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + const auto perm = ConstUnmanagedViewType(_piv.data() + 4 * offm + 2 * m, m); + + if (s.do_not_apply_pivots) { + Copy::invoke(member, bT, tT); + } else { + ApplyPermutation::invoke(member, tT, perm, bT); + } + member.team_barrier(); + + Gemv::invoke(member, one, ATL, bT, zero, tT); + + if (n_m > 0) { + // update + member.team_barrier(); + const auto ABL = Kokkos::subview(AL, range_type(m, n), Kokkos::ALL()); + UnmanagedViewType bB(bptr + bT.span(), n_m, _nrhs); + Gemv::invoke(member, minus_one, ABL, tT, zero, bB); + } + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + UnmanagedViewType bT(bptr, m, _nrhs); + bptr += m * _nrhs; + + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + UnmanagedViewType bB(bptr, n_m, _nrhs); + + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, tbeg]( + const ordinal_type &ii) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + // solve + UnmanagedViewType AL(s.l_buf, n, m); + UnmanagedViewType b(bptr, n, _nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + if (!s.do_not_apply_pivots) { + UnmanagedViewType bT(bptr, m, _nrhs); + ConstUnmanagedViewType perm(_piv.data() + 4 * offm + 2 * m, m); + Copy::invoke(member, bT, tT); + ApplyPermutation::invoke(member, bT, perm, tT); + member.team_barrier(); + } + Gemv::invoke(member, one, AL, tT, zero, b); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + UnmanagedViewType b(bptr, n, _nrhs); + if (m > 0) { + const ordinal_type offm = s.row_begin; + auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + // copy to t + Kokkos::parallel_for(Kokkos::TeamVectorRange(member, m * _nrhs), + [&, m](const ordinal_type &k) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type i = k % m, j = k / m; + tT(i, j) = b(i, j); + }); + + if (n_m > 0) { + // update + const ordinal_type sbeg = s.sid_col_begin + 1, send = s.sid_col_end - 1; + for (ordinal_type i = sbeg, ip = 0 /*is=0*/; i < send; ++i) { + const ordinal_type tbeg = _sid_block_colidx(i).second, tend = _sid_block_colidx(i + 1).second, + tcnt = tend - tbeg; + + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, tcnt), + [&, ip, m, tbeg, + tcnt](const ordinal_type &ii) { /// compiler bug with c++14 lambda capturing and workaround + const ordinal_type it = tbeg + ii; + const ordinal_type is = ip + ii; + // for (ordinal_type it=tbeg;it struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using solve_tag_type = SolveTag; + + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) { + solve_var0(member, s, bptr); + } else if (solve_tag_type::variant == 1) { + solve_var1(member, s, bptr); + } else if (solve_tag_type::variant == 2) { + solve_var2(member, s, bptr); + } + } + if (mode == -1) { + printf("Error: TeamFunctorSolveLowerChol::SolveTag, computing mode is not determined\n"); + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + using update_tag_type = UpdateTag; + + const ordinal_type sid = _level_sids(p); + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (update_tag_type::variant == 0) { + update_var0(member, s, bptr); + } else if (update_tag_type::variant == 1) { + update_var1(member, s, bptr); + } else if (update_tag_type::variant == 2) { + update_var2(member, s, bptr); + } + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const { + // do nothing + } +}; +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperChol.hpp index 83e6c0a0dd39..d61287e173d0 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperChol.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperChol.hpp @@ -10,292 +10,276 @@ namespace Tacho { - template - struct TeamFunctor_SolveUpperChol { - public: - typedef Kokkos::pair range_type; - - typedef SupernodeInfoType supernode_info_type; - typedef typename supernode_info_type::supernode_type supernode_type; - typedef typename supernode_info_type::supernode_type_array supernode_type_array; - - typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; - typedef typename supernode_info_type::size_type_array size_type_array; - - typedef typename supernode_info_type::ordinal_pair_type_array ordinal_pair_type_array; - - typedef typename supernode_info_type::value_type value_type; - typedef typename supernode_info_type::value_type_array value_type_array; - typedef typename supernode_info_type::value_type_matrix value_type_matrix; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type TrsvAlgoType; - - typedef typename std::conditional - ::value, - Algo::External,Algo::Internal>::type GemvAlgoType; - - private: - ConstUnmanagedViewType _supernodes; - ConstUnmanagedViewType _gid_colidx; - - ConstUnmanagedViewType _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - UnmanagedViewType _t; - ordinal_type _nrhs; - - UnmanagedViewType _buf_ptr; - UnmanagedViewType _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveUpperChol() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveUpperChol(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const value_type_matrix t, - const value_type_array buf) - : - _supernodes(info.supernodes), - _gid_colidx(info.gid_colidx), - _compute_mode(compute_mode), - _level_sids(level_sids), - _t(t), - _nrhs(t.extent(1)), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; - } - - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; - } +template struct TeamFunctor_SolveUpperChol { +public: + typedef Kokkos::pair range_type; + + typedef SupernodeInfoType supernode_info_type; + typedef typename supernode_info_type::supernode_type supernode_type; + typedef typename supernode_info_type::supernode_type_array supernode_type_array; + + typedef typename supernode_info_type::ordinal_type_array ordinal_type_array; + typedef typename supernode_info_type::size_type_array size_type_array; + + typedef typename supernode_info_type::ordinal_pair_type_array ordinal_pair_type_array; + + typedef typename supernode_info_type::value_type value_type; + typedef typename supernode_info_type::value_type_array value_type_array; + typedef typename supernode_info_type::value_type_matrix value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperChol() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperChol(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const value_type_matrix t, + const value_type_array buf) + : _supernodes(info.supernodes), _gid_colidx(info.gid_colidx), _compute_mode(compute_mode), + _level_sids(level_sids), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: gemv - trsv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + using TrsvAlgoType = typename TrsvAlgorithm::type; + + const value_type minus_one(-1), one(1); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); - /// - /// Algorithm Variant 0: gemv - trsv - /// - template - KOKKOS_INLINE_FUNCTION - void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), one(1); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - // solve - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - if (n_m > 0) { - // update - const UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - const UnmanagedViewType bB(bptr, n_m, _nrhs); - Gemv - ::invoke(member, minus_one, AR, bB, one, tT); - member.team_barrier(); - } - Trsv - ::invoke(member, Diag::NonUnit(), AL, tT); - } - } - } - - template - KOKKOS_INLINE_FUNCTION - void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; if (n_m > 0) { // update + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; const UnmanagedViewType bB(bptr, n_m, _nrhs); - const ordinal_type goffset = s.gid_col_begin + s.m; - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n_m), - [&, goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - //for (ordinal_type i=0;i::invoke(member, minus_one, ATR, bB, one, tT); + member.team_barrier(); } + Trsv::invoke(member, Diag::NonUnit(), ATL, tT); } } - - /// - /// Algorithm Variant 1: gemv - gemv - /// - template - KOKKOS_INLINE_FUNCTION - void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), one(1), zero(0); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const UnmanagedViewType bT(bptr, m, _nrhs); bptr += m*_nrhs; - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - if (n_m > 0) { - // update - const UnmanagedViewType AR(aptr, m, n_m); - const UnmanagedViewType bB(bptr, n_m, _nrhs); - - Gemv - ::invoke(member, minus_one, AR, bB, one, tT); - member.team_barrier(); - } - Gemv - ::invoke(member, one, AL, tT, zero, bT); - member.team_barrier(); - // copy to t - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, m*_nrhs), - [&, m](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - const ordinal_type i = k%m, j = k/m; - tT(i,j) = bT(i,j); + } + + template + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + // update + const UnmanagedViewType bB(bptr, n_m, _nrhs); + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i - KOKKOS_INLINE_FUNCTION - void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (n_m > 0) { - UnmanagedViewType bB(bptr+m*_nrhs, n_m, _nrhs); + /// + /// Algorithm Variant 1: gemv - gemv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const UnmanagedViewType bT(bptr, m, _nrhs); + bptr += m * _nrhs; + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { // update - const ordinal_type goffset = s.gid_col_begin + s.m; - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n_m), - [&, goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - //for (ordinal_type i=0;i ATR(aptr, m, n_m); + const UnmanagedViewType bB(bptr, n_m, _nrhs); + + Gemv::invoke(member, minus_one, ATR, bB, one, tT); + member.team_barrier(); } + Gemv::invoke(member, one, ATL, tT, zero, bT); + member.team_barrier(); + // copy to t + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, m * _nrhs), + [&, m](const ordinal_type &k) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + const ordinal_type i = k % m, j = k / m; + tT(i, j) = bT(i, j); + }); } } - - /// - /// Algorithm Variant 2: gemv - /// - template - KOKKOS_INLINE_FUNCTION - void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type one(1), zero(0); - { - const ordinal_type m = s.m, n = s.n; - if (m > 0 && n > 0) { - value_type *aptr = s.buf; - - const UnmanagedViewType A(aptr, m, n); - const UnmanagedViewType b(bptr, n, _nrhs); - - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - Gemv - ::invoke(member, one, A, b, zero, tT); - } + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr + m * _nrhs, n_m, _nrhs); + + // update + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf; + + const UnmanagedViewType AT(aptr, m, n); + const UnmanagedViewType b(bptr, n, _nrhs); - template - KOKKOS_INLINE_FUNCTION - void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n; - UnmanagedViewType b(bptr, n, _nrhs); - if (n > 0) { - const ordinal_type offm = s.row_begin; - const ordinal_type goffset = s.gid_col_begin + s.m; - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n), - [&,m,goffset,offm](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - for (ordinal_type j=0;j<_nrhs;++j) { + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + Gemv::invoke(member, one, AT, b, zero, tT); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n; + UnmanagedViewType b(bptr, n, _nrhs); + if (n > 0) { + const ordinal_type offm = s.row_begin; + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n), + [&, m, goffset, + offm](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + for (ordinal_type j = 0; j < _nrhs; ++j) { if (i < m) { - b(i,j) = _t(offm+i,j); + b(i, j) = _t(offm + i, j); } else { - const ordinal_type row = _gid_colidx(i-m+goffset); - b(i,j) = _t(row,j); + const ordinal_type row = _gid_colidx(i - m + goffset); + b(i, j) = _t(row, j); } } }); - } } } + } - template struct SolveTag { enum { variant = Var }; }; - template struct UpdateTag { enum { variant = Var }; }; - struct DummyTag {}; - - template - KOKKOS_INLINE_FUNCTION - void operator()(const SolveTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - typedef SolveTag solve_tag_type; - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - if (solve_tag_type::variant == 0) solve_var0(member, s, bptr); - else if (solve_tag_type::variant == 1) solve_var1(member, s, bptr); - else if (solve_tag_type::variant == 2) solve_var2(member, s, bptr); - else - printf("Error: TeamFunctorSolveUpperChol::SolveTag, algorithm variant is not supported\n"); - } else if (mode == -1) { - printf("Error: TeamFunctorSolveUpperChol::SolveTag, computing mode is not determined\n"); - } else { - // skip - } - } - - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - typedef UpdateTag update_tag_type; - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - if (update_tag_type::variant == 0) update_var0(member, s, bptr); - else if (update_tag_type::variant == 1) update_var1(member, s, bptr); - else if (update_tag_type::variant == 2) update_var2(member, s, bptr); - else - printf("Error: TeamFunctorUpdateUpperChol::SolveTag, algorithm variant is not supported\n"); - } else { - // skip - } + template struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + typedef SolveTag solve_tag_type; + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) + solve_var0(member, s, bptr); + else if (solve_tag_type::variant == 1) + solve_var1(member, s, bptr); + else if (solve_tag_type::variant == 2) + solve_var2(member, s, bptr); + else + printf("Error: TeamFunctorSolveUpperChol::SolveTag, algorithm variant is not supported\n"); + } else if (mode == -1) { + printf("Error: TeamFunctorSolveUpperChol::SolveTag, computing mode is not determined\n"); + } else { + // skip } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { - + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + const ordinal_type sid = _level_sids(p); + typedef UpdateTag update_tag_type; + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (update_tag_type::variant == 0) + update_var0(member, s, bptr); + else if (update_tag_type::variant == 1) + update_var1(member, s, bptr); + else if (update_tag_type::variant == 2) + update_var2(member, s, bptr); + else + printf("Error: TeamFunctorUpdateUpperChol::SolveTag, algorithm variant is not supported\n"); + } else { + // skip } + } - }; + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const {} +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLDL.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLDL.hpp index 723c5b9bd934..b4b6f99f6c15 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLDL.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLDL.hpp @@ -10,184 +10,305 @@ namespace Tacho { - template - struct TeamFunctor_SolveUpperLDL { - public: - using range_type = Kokkos::pair; - - using supernode_info_type = SupernodeInfoType; - using supernode_type = typename supernode_info_type::supernode_type; - using supernode_type_array = typename supernode_info_type::supernode_type_array; - - using ordinal_type_array = typename supernode_info_type::ordinal_type_array; - using size_type_array = typename supernode_info_type::size_type_array; - - using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; - - using value_type = typename supernode_info_type::value_type; - using value_type_array = typename supernode_info_type::value_type_array; - using value_type_matrix = typename supernode_info_type::value_type_matrix; - - using MainAlgoType = typename std::conditional - ::value, - Algo::External,Algo::Internal>::type; - using TrsvAlgoType = MainAlgoType; - using GemvAlgoType = MainAlgoType; - - private: - ConstUnmanagedViewType _supernodes; - ConstUnmanagedViewType _gid_colidx; - - ConstUnmanagedViewType _compute_mode, _level_sids; - ordinal_type _pbeg, _pend; - - ConstUnmanagedViewType _piv; - ConstUnmanagedViewType _diag; - UnmanagedViewType _t; - ordinal_type _nrhs; - - UnmanagedViewType _buf_ptr; - UnmanagedViewType _buf; - - public: - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveUpperLDL() = delete; - - KOKKOS_INLINE_FUNCTION - TeamFunctor_SolveUpperLDL(const supernode_info_type &info, - const ordinal_type_array &compute_mode, - const ordinal_type_array &level_sids, - const ordinal_type_array &piv, - const value_type_array &diag, - const value_type_matrix &t, - const value_type_array &buf) - : - _supernodes(info.supernodes), - _gid_colidx(info.gid_colidx), - _compute_mode(compute_mode), - _level_sids(level_sids), - _piv(piv), - _diag(diag), - _t(t), - _nrhs(t.extent(1)), - _buf(buf) - {} - - inline - void setRange(const ordinal_type pbeg, - const ordinal_type pend) { - _pbeg = pbeg; _pend = pend; - } +template struct TeamFunctor_SolveUpperLDL { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + using supernode_type_array = typename supernode_info_type::supernode_type_array; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + ConstUnmanagedViewType _piv; + ConstUnmanagedViewType _diag; + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperLDL() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperLDL(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const ordinal_type_array &piv, + const value_type_array &diag, const value_type_matrix &t, const value_type_array &buf) + : _supernodes(info.supernodes), _gid_colidx(info.gid_colidx), _compute_mode(compute_mode), + _level_sids(level_sids), _piv(piv), _diag(diag), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: gemv - trsv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + const UnmanagedViewType ATL(aptr, m, m); + aptr += m * m; + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + ConstUnmanagedViewType P(_piv.data() + offm * 4, m * 4); + ConstUnmanagedViewType D(_diag.data() + offm * 2, m, 2); + + Scale2x2_BlockInverseDiagonals::invoke(member, P, D, tT); - inline - void setBufferPtr(const size_type_array &buf_ptr) { - _buf_ptr = buf_ptr; - } + if (n_m > 0) { + // update + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; + const UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, bB, one, tT); + member.team_barrier(); + } + Trsv::invoke(member, Diag::Unit(), ATL, tT); - /// - /// Algorithm Variant 0: gemv - trsv - /// - template - KOKKOS_INLINE_FUNCTION - void solve(MemberType &member, const supernode_type &s, value_type *bptr) const { - const value_type minus_one(-1), one(1); - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; - if (m > 0) { - value_type *aptr = s.buf; - // solve - const UnmanagedViewType AL(aptr, m, m); aptr += m*m; - const ordinal_type offm = s.row_begin; - const auto tT = Kokkos::subview(_t, range_type(offm, offm+m), Kokkos::ALL()); - - ConstUnmanagedViewType P(_piv.data()+offm*4, m*4); - ConstUnmanagedViewType D(_diag.data()+offm*2, m, 2); - - Scale2x2_BlockInverseDiagonals - ::invoke(member, P, D, tT); - - if (n_m > 0) { - // update - const UnmanagedViewType AR(aptr, m, n_m); // aptr += m*n; - const UnmanagedViewType bB(bptr, n_m, _nrhs); - Gemv - ::invoke(member, minus_one, AR, bB, one, tT); - member.team_barrier(); - } - Trsv - ::invoke(member, Diag::Unit(), AL, tT); - - ConstUnmanagedViewType fpiv(P.data()+m, m); - ApplyPivots /// row inter-change - ::invoke(member, fpiv, tT); + if (!s.do_not_apply_pivots) { + ConstUnmanagedViewType fpiv(P.data() + m, m); + ApplyPivots /// row inter-change + ::invoke(member, fpiv, tT); } } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + // update + const UnmanagedViewType bB(bptr, n_m, _nrhs); + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *aptr = s.u_buf; + // solve + const UnmanagedViewType ATL(aptr, m, m); + aptr += ATL.span(); + const UnmanagedViewType bT(bptr, m, _nrhs); + bptr += bT.span(); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + ConstUnmanagedViewType P(_piv.data() + offm * 4, m * 4); + ConstUnmanagedViewType D(_diag.data() + offm * 2, m, 2); + + Scale2x2_BlockInverseDiagonals::invoke(member, P, D, tT); - template - KOKKOS_INLINE_FUNCTION - void update(MemberType &member, const supernode_type &s, value_type *bptr) const { - { - const ordinal_type m = s.m, n = s.n, n_m = n-m; if (n_m > 0) { // update + const UnmanagedViewType ATR(aptr, m, n_m); // aptr += m*n; const UnmanagedViewType bB(bptr, n_m, _nrhs); - const ordinal_type goffset = s.gid_col_begin + s.m; - Kokkos::parallel_for - (Kokkos::TeamVectorRange(member, n_m), - [&, goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - //for (ordinal_type i=0;i::invoke(member, minus_one, ATR, bB, one, tT); + member.team_barrier(); + } + Gemv::invoke(member, one, ATL, tT, zero, bT); + member.team_barrier(); + + if (s.do_not_apply_pivots) { + Copy::invoke(member, tT, bT); + } else { + ConstUnmanagedViewType peri(P.data() + 3 * m, m); + ApplyPermutation::invoke(member, bT, peri, tT); } } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr + m * _nrhs, n_m, _nrhs); + + // update + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; - template - KOKKOS_INLINE_FUNCTION - void operator()(const SolveTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - const ordinal_type sid = _level_sids(p); - const ordinal_type mode = _compute_mode(sid); - if (p < _pend && mode == 1) { - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - solve(member, s, bptr); - } else if (mode == -1) { - printf("Error: TeamFunctorSolveUpperChol::SolveTag, computing mode is not determined\n"); - } else { - // skip + const ordinal_type nrhs = _t.extent(1); + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0 && n > 0) { + value_type *aptr = s.u_buf; + + const UnmanagedViewType AT(aptr, m, n); + const UnmanagedViewType b(bptr, n, nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + const UnmanagedViewType bT(bptr, m, nrhs); + + ConstUnmanagedViewType P(_piv.data() + offm * 4, m * 4); + ConstUnmanagedViewType D(_diag.data() + offm * 2, m, 2); + + Scale2x2_BlockInverseDiagonals::invoke(member, P, D, bT); + member.team_barrier(); + + Gemv::invoke(member, one, AT, b, zero, tT); + member.team_barrier(); + + if (!s.do_not_apply_pivots) { + Copy::invoke(member, bT, tT); + member.team_barrier(); + + ConstUnmanagedViewType peri(P.data() + 3 * m, m); + ApplyPermutation::invoke(member, bT, peri, tT); + } } } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n; + UnmanagedViewType b(bptr, n, _nrhs); + if (n > 0) { + const ordinal_type offm = s.row_begin; + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n), + [&, m, goffset, + offm](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + for (ordinal_type j = 0; j < _nrhs; ++j) { + if (i < m) { + b(i, j) = _t(offm + i, j); + } else { + const ordinal_type row = _gid_colidx(i - m + goffset); + b(i, j) = _t(row, j); + } + } + }); + } + } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const UpdateTag &, const MemberType &member) const { - const ordinal_type p = _pbeg + member.league_rank(); - if (p < _pend) { - const ordinal_type sid = _level_sids(p); - const supernode_type &s = _supernodes(sid); - value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); - update(member, s, bptr); - } else { - // skip + template struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using solve_tag_type = SolveTag; + + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) { + solve_var0(member, s, bptr); + } else if (solve_tag_type::variant == 1) { + solve_var1(member, s, bptr); + } else if (solve_tag_type::variant == 2) { + solve_var2(member, s, bptr); } + } else if (mode == -1) { + printf("Error: TeamFunctorSolveUpperChol::SolveTag, computing mode is not determined\n"); + } else { + // skip } + } - template - KOKKOS_INLINE_FUNCTION - void operator()(const DummyTag &, const MemberType &member) const { + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + using update_tag_type = UpdateTag; + const ordinal_type sid = _level_sids(p); + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + + if (update_tag_type::variant == 0) { + update_var0(member, s, bptr); + } else if (update_tag_type::variant == 1) { + update_var1(member, s, bptr); + } else if (update_tag_type::variant == 2) { + update_var2(member, s, bptr); + } + } else { + // skip } + } - }; + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const {} +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLU.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLU.hpp new file mode 100644 index 000000000000..5696a9ed1f7d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_TeamFunctor_SolveUpperLU.hpp @@ -0,0 +1,275 @@ +#ifndef __TACHO_TEAMFUNCTOR_SOLVE_UPPER_LU_HPP__ +#define __TACHO_TEAMFUNCTOR_SOLVE_UPPER_LU_HPP__ + +/// \file Tacho_TeamFunctor_SolveUpperLU.hpp +/// \author Kyungjoo Kim (kyukim@sandia.gov) + +#include "Tacho_Util.hpp" + +#include "Tacho_SupernodeInfo.hpp" + +namespace Tacho { + +template struct TeamFunctor_SolveUpperLU { +public: + using range_type = Kokkos::pair; + + using supernode_info_type = SupernodeInfoType; + using supernode_type = typename supernode_info_type::supernode_type; + using supernode_type_array = typename supernode_info_type::supernode_type_array; + + using ordinal_type_array = typename supernode_info_type::ordinal_type_array; + using size_type_array = typename supernode_info_type::size_type_array; + + using ordinal_pair_type_array = typename supernode_info_type::ordinal_pair_type_array; + + using value_type = typename supernode_info_type::value_type; + using value_type_array = typename supernode_info_type::value_type_array; + using value_type_matrix = typename supernode_info_type::value_type_matrix; + +private: + ConstUnmanagedViewType _supernodes; + ConstUnmanagedViewType _gid_colidx; + + ConstUnmanagedViewType _compute_mode, _level_sids; + ordinal_type _pbeg, _pend; + + UnmanagedViewType _t; + ordinal_type _nrhs; + + UnmanagedViewType _buf_ptr; + UnmanagedViewType _buf; + +public: + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperLU() = delete; + + KOKKOS_INLINE_FUNCTION + TeamFunctor_SolveUpperLU(const supernode_info_type &info, const ordinal_type_array &compute_mode, + const ordinal_type_array &level_sids, const value_type_matrix &t, + const value_type_array &buf) + : _supernodes(info.supernodes), _gid_colidx(info.gid_colidx), _compute_mode(compute_mode), + _level_sids(level_sids), _t(t), _nrhs(t.extent(1)), _buf(buf) {} + + inline void setRange(const ordinal_type pbeg, const ordinal_type pend) { + _pbeg = pbeg; + _pend = pend; + } + + inline void setBufferPtr(const size_type_array &buf_ptr) { _buf_ptr = buf_ptr; } + + /// + /// Algorithm Variant 0: gemv - trsv + /// + template + KOKKOS_INLINE_FUNCTION void solve_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + using TrsvAlgoType = typename TrsvAlgorithm::type; + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + value_type *uptr = s.u_buf; + // solve + const UnmanagedViewType ATL(uptr, m, m); + uptr += m * m; + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + if (n_m > 0) { + // update + const UnmanagedViewType ATR(uptr, m, n_m); // uptr += m*n; + const UnmanagedViewType bB(bptr, n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, bB, one, tT); + member.team_barrier(); + } + Trsv::invoke(member, Diag::NonUnit(), ATL, tT); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var0(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + // update + const UnmanagedViewType bB(bptr, n_m, _nrhs); + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION void solve_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type minus_one(-1), one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (m > 0) { + // solve + const UnmanagedViewType ATL(s.u_buf, m, m); + const UnmanagedViewType bT(bptr, m, _nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + Copy::invoke(member, bT, tT); + member.team_barrier(); + + if (n_m > 0) { + // update + const UnmanagedViewType ATR(s.u_buf + ATL.span(), m, n_m); + const UnmanagedViewType bB(bptr + bT.span(), n_m, _nrhs); + Gemv::invoke(member, minus_one, ATR, bB, one, bT); + member.team_barrier(); + } + Gemv::invoke(member, one, ATL, bT, zero, tT); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var1(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n, n_m = n - m; + if (n_m > 0) { + UnmanagedViewType bB(bptr + m * _nrhs, n_m, _nrhs); + + // update + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n_m), + [&, + goffset](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + // for (ordinal_type i=0;i + KOKKOS_INLINE_FUNCTION void solve_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + using GemvAlgoType = typename GemvAlgorithm::type; + + const value_type one(1), zero(0); + { + const ordinal_type m = s.m, n = s.n; + if (m > 0) { + // solve + const UnmanagedViewType AT(s.u_buf, m, n); + const UnmanagedViewType b(bptr, n, _nrhs); + + const ordinal_type offm = s.row_begin; + const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL()); + + Gemv::invoke(member, one, AT, b, zero, tT); + } + } + } + + template + KOKKOS_INLINE_FUNCTION void update_var2(MemberType &member, const supernode_type &s, value_type *bptr) const { + { + const ordinal_type m = s.m, n = s.n; + UnmanagedViewType b(bptr, n, _nrhs); + if (n > 0) { + const ordinal_type offm = s.row_begin; + const ordinal_type goffset = s.gid_col_begin + s.m; + Kokkos::parallel_for( + Kokkos::TeamVectorRange(member, n), + [&, m, goffset, + offm](const ordinal_type &i) { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 + for (ordinal_type j = 0; j < _nrhs; ++j) { + if (i < m) { + b(i, j) = _t(offm + i, j); + } else { + const ordinal_type row = _gid_colidx(i - m + goffset); + b(i, j) = _t(row, j); + } + } + }); + } + } + } + + template struct SolveTag { + enum { variant = Var }; + }; + template struct UpdateTag { + enum { variant = Var }; + }; + struct DummyTag {}; + + template + KOKKOS_INLINE_FUNCTION void operator()(const SolveTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + const ordinal_type sid = _level_sids(p); + const ordinal_type mode = _compute_mode(sid); + if (p < _pend && mode == 1) { + using solve_tag_type = SolveTag; + + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (solve_tag_type::variant == 0) { + solve_var0(member, s, bptr); + } else if (solve_tag_type::variant == 1) { + solve_var1(member, s, bptr); + } else if (solve_tag_type::variant == 2) { + solve_var2(member, s, bptr); + } + } else if (mode == -1) { + printf("Error: TeamFunctorSolveUpperChol::SolveTag, computing mode is not determined\n"); + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const UpdateTag &, const MemberType &member) const { + const ordinal_type p = _pbeg + member.league_rank(); + if (p < _pend) { + using update_tag_type = UpdateTag; + + const ordinal_type sid = _level_sids(p); + const supernode_type &s = _supernodes(sid); + value_type *bptr = _buf.data() + _buf_ptr(member.league_rank()); + if (update_tag_type::variant == 0) { + update_var0(member, s, bptr); + } else if (update_tag_type::variant == 1) { + update_var1(member, s, bptr); + } else if (update_tag_type::variant == 2) { + update_var2(member, s, bptr); + } + } else { + // skip + } + } + + template + KOKKOS_INLINE_FUNCTION void operator()(const DummyTag &, const MemberType &member) const {} +}; + +} // namespace Tacho + +#endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm.hpp index f2efa8acb1d8..3d817d7c8d0f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm.hpp @@ -7,68 +7,18 @@ #include "Tacho_Util.hpp" namespace Tacho { - - /// - /// Trsm: - /// - /// various implementation for different uplo and algo parameters - template - struct Trsm; +/// +/// Trsm: +/// - /// task construction for the above chol implementation - /// Trsm::invoke(_sched, member, ArgDiag(), _alpha, _A, _B); - template - struct TaskFunctor_Trsm { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; +/// various implementation for different uplo and algo parameters +template struct Trsm; - typedef ScalarType scalar_type; +struct TrsmAlgorithm { + using type = ActiveAlgorithm::type; +}; - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; - - private: - scalar_type _alpha; - dense_block_type _A, _B; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Trsm() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Trsm(const scalar_type alpha, - const dense_block_type &A, - const dense_block_type &B) - : _alpha(alpha), - _A(A), - _B(B) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Trsm - ::invoke(member, ArgDiag(), _alpha, _A, _B); - - Kokkos::single(Kokkos::PerTeam(member), - [&, ierr] () { // Value capture is a workaround for cuda + gcc-7.2 compiler bug w/c++14 - _B.set_future(); - r_val = ierr; - }); - } - }; - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_External.hpp index 088721e0f4c4..405568b402e8 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_TRSM_EXTERNAL_HPP__ #define __TACHO_TRSM_EXTERNAL_HPP__ - /// \file Tacho_Trsm_External.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,70 +8,46 @@ #include "Tacho_Blas_External.hpp" namespace Tacho { - - template - struct Trsm { - template - inline - static int - invoke(const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); - - const ordinal_type m = B.extent(0); - const ordinal_type n = B.extent(1); - - if (m > 0 && n > 0) - Blas::trsm(ArgSide::param, - ArgUplo::param, - ArgTransA::param, - diagA.param, - m, n, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1()); +template +struct Trsm { + + template + inline static int invoke(const DiagType diagA, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and B do not have the same value type."); + + const ordinal_type m = B.extent(0); + const ordinal_type n = B.extent(1); + + if (m > 0 && n > 0) + Blas::trsm(ArgSide::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, value_type(alpha), + A.data(), A.stride_1(), B.data(), B.stride_1()); #else - TACHO_TEST_FOR_ABORT( true, "This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, "This function is only allowed in host space."); #endif - return 0; - } - - template - inline - static int - invoke(MemberType &member, - const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(diagA, alpha, A, B); - //}); + return 0; + } + + template + inline static int invoke(MemberType &member, const DiagType diagA, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(diagA, alpha, A, B); + //}); #else - TACHO_TEST_FOR_ABORT( true, "This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, "This function is only allowed in host space."); #endif - return 0; - } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_Internal.hpp index f50cb1fc3f98..90f425f87434 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_TRSM_INTERNAL_HPP__ #define __TACHO_TRSM_INTERNAL_HPP__ - /// \file Tacho_Trsm_Internal.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,46 +8,29 @@ #include "Tacho_Blas_Team.hpp" namespace Tacho { - - template - struct Trsm { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); - - const ordinal_type m = B.extent(0); - const ordinal_type n = B.extent(1); - - if (m > 0 && n > 0) - BlasTeam::trsm(member, - ArgSide::param, - ArgUplo::param, - ArgTransA::param, - diagA.param, - m, n, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - return 0; - } - }; - -} + +template +struct Trsm { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const DiagType diagA, const ScalarType alpha, + const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and B do not have the same value type."); + + const ordinal_type m = B.extent(0); + const ordinal_type n = B.extent(1); + + if (m > 0 && n > 0) + BlasTeam::trsm(member, ArgSide::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, + value_type(alpha), A.data(), A.stride_1(), B.data(), B.stride_1()); + return 0; + } +}; + +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_OnDevice.hpp index 8d50ffc27da6..f3ce523438d5 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_OnDevice.hpp @@ -1,110 +1,68 @@ #ifndef __TACHO_TRSM_ON_DEVICE_HPP__ #define __TACHO_TRSM_ON_DEVICE_HPP__ - /// \file Tacho_Trsm_OnDevice.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) namespace Tacho { - - template - struct Trsm { - template - inline - static int - blas_invoke(const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type m = B.extent(0); - const ordinal_type n = B.extent(1); - - if (m > 0 && n > 0) - Blas::trsm(ArgSide::param, - ArgUplo::param, - ArgTransA::param, - diagA.param, - m, n, - value_type(alpha), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - return 0; - } - -#if defined (KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type m = B.extent(0); - const ordinal_type n = B.extent(1); - - int r_val(0); - if (m > 0 && n > 0) - Blas::trsm(handle, - ArgSide::cublas_param, - ArgUplo::cublas_param, - ArgTransA::cublas_param, - diagA.cublas_param, - m, n, - alpha, - A.data(), A.stride_1(), - B.data(), B.stride_1()); - return r_val; - } -#endif - template - inline - static int - invoke(MemberType &member, - const DiagType diagA, - const ScalarType alpha, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeB::memory_space memory_space_b; +template +struct Trsm { + + template + inline static int blas_invoke(const DiagType diagA, const ScalarType alpha, const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = B.extent(0); + const ordinal_type n = B.extent(1); + + if (m > 0 && n > 0) + Blas::trsm(ArgSide::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, value_type(alpha), + A.data(), A.stride_1(), B.data(), B.stride_1()); + return 0; + } + +#if defined(KOKKOS_ENABLE_CUDA) + template + inline static int cublas_invoke(cublasHandle_t &handle, const DiagType diagA, const ScalarType alpha, + const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = B.extent(0); + const ordinal_type n = B.extent(1); + + int r_val(0); + if (m > 0 && n > 0) + Blas::trsm(handle, ArgSide::cublas_param, ArgUplo::cublas_param, ArgTransA::cublas_param, + diagA.cublas_param, m, n, alpha, A.data(), A.stride_1(), B.data(), B.stride_1()); + return r_val; + } +#endif + template + inline static int invoke(MemberType &member, const DiagType diagA, const ScalarType alpha, const ViewTypeA &A, + const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeB::memory_space memory_space_b; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); + static_assert(std::is_same::value, "A and B do not have the same value type."); - static_assert(std::is_same::value, - "A and B do not have the same memory space."); + static_assert(std::is_same::value, "A and B do not have the same memory space."); - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(diagA, alpha, A, B); -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, diagA, alpha, A, B); + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(diagA, alpha, A, B); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, diagA, alpha, A, B); #endif - return r_val; - } - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv.hpp index a85486b2ff35..7b49c1cd066c 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv.hpp @@ -7,62 +7,18 @@ #include "Tacho_Util.hpp" namespace Tacho { - - /// - /// Trsv: - /// - /// various implementation for different uplo and algo parameters - template - struct Trsv; +/// +/// Trsv: +/// - /// task construction for the above chol implementation - /// Trsv::invoke(_sched, member, ArgDiag(), _alpha, _A, _B); - template - struct TaskFunctor_Trsv { - public: - typedef SchedulerType scheduler_type; - typedef typename scheduler_type::member_type member_type; +/// various implementation for different uplo and algo parameters +template struct Trsv; - typedef ScalarType scalar_type; +struct TrsvAlgorithm { + using type = ActiveAlgorithm::type; +}; - typedef DenseMatrixViewType dense_block_type; - typedef typename dense_block_type::future_type future_type; - typedef typename future_type::value_type value_type; - - private: - dense_block_type _A, _B; - - public: - KOKKOS_INLINE_FUNCTION - TaskFunctor_Trsv() = delete; - - KOKKOS_INLINE_FUNCTION - TaskFunctor_Trsv(const dense_block_type &A, - const dense_block_type &B) - : _A(A), - _B(B) {} - - KOKKOS_INLINE_FUNCTION - void operator()(member_type &member, value_type &r_val) { - const int ierr = Trsv - ::invoke(member, ArgDiag(),_A, _B); - - Kokkos::single(Kokkos::PerTeam(member), [&] () { - _B.set_future(); - r_val = ierr; - }); - } - }; - -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_External.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_External.hpp index 999c7f0cff10..e1b92411921a 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_External.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_External.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_TRSV_EXTERNAL_HPP__ #define __TACHO_TRSV_EXTERNAL_HPP__ - /// \file Tacho_Trsv_External.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,71 +8,48 @@ #include "Tacho_Blas_External.hpp" namespace Tacho { - - template - struct Trsv { - template - inline - static int - invoke(const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); - - const ordinal_type m = B.extent(0), n = B.extent(1); - - if (m > 0 && n > 0) { - if (n == 1) { - Blas::trsv(ArgUplo::param, ArgTransA::param, - diagA.param, - m, - A.data(), A.stride_1(), - B.data(), B.stride_0()); - } else { - Blas::trsm(Side::Left::param, ArgUplo::param, ArgTransA::param, - diagA.param, - m, n, - value_type(1), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - } - } + +template struct Trsv { + template + inline static int invoke(const DiagType diagA, const ViewTypeA &A, const ViewTypeB &B) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and B do not have the same value type."); + + const ordinal_type m = B.extent(0), n = B.extent(1); + + if (m > 0 && n > 0) { + if (n == 1) { + Blas::trsv(ArgUplo::param, ArgTransA::param, diagA.param, m, A.data(), A.stride_1(), B.data(), + B.stride_0()); + } else { + Blas::trsm(Side::Left::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, value_type(1), + A.data(), A.stride_1(), B.data(), B.stride_1()); + } + } #else - TACHO_TEST_FOR_ABORT( true, "This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, "This function is only allowed in host space."); #endif - return 0; - } + return 0; + } - template - inline - static int - invoke(MemberType &member, - const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - //Kokkos::single(Kokkos::PerTeam(member), [&]() { - invoke(diagA, A, B); - //}); + template + inline static int invoke(MemberType &member, const DiagType diagA, const ViewTypeA &A, const ViewTypeB &B) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + // Kokkos::single(Kokkos::PerTeam(member), [&]() { + invoke(diagA, A, B); + //}); #else - TACHO_TEST_FOR_ABORT( true, "This function is only allowed in host space."); + TACHO_TEST_FOR_ABORT(true, "This function is only allowed in host space."); #endif - return 0; - } - }; + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_Internal.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_Internal.hpp index bbbc23ca2b24..785ddf30d10f 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_Internal.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_Internal.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_TRSV_INTERNAL_HPP__ #define __TACHO_TRSV_INTERNAL_HPP__ - /// \file Tacho_Trsv_Internal.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,53 +8,34 @@ #include "Tacho_Blas_Team.hpp" namespace Tacho { - - template - struct Trsv { - template - KOKKOS_INLINE_FUNCTION - static int - invoke(MemberType &member, - const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); - - const ordinal_type m = B.extent(0); - const ordinal_type n = B.extent(1); - - if (m > 0 && n > 0) { - if (n == 1) { - BlasTeam::trsv(member, - ArgUplo::param, ArgTransA::param, - diagA.param, - m, - A.data(), A.stride_1(), - B.data(), B.stride_0()); - } else { - BlasTeam::trsm(member, - Side::Left::param, ArgUplo::param, ArgTransA::param, - diagA.param, - m, n, - value_type(1), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - } - - } - return 0; + +template struct Trsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(MemberType &member, const DiagType diagA, const ViewTypeA &A, + const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); + + static_assert(std::is_same::value, "A and B do not have the same value type."); + + const ordinal_type m = B.extent(0); + const ordinal_type n = B.extent(1); + + if (m > 0 && n > 0) { + if (n == 1) { + BlasTeam::trsv(member, ArgUplo::param, ArgTransA::param, diagA.param, m, A.data(), A.stride_1(), + B.data(), B.stride_0()); + } else { + BlasTeam::trsm(member, Side::Left::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, + value_type(1), A.data(), A.stride_1(), B.data(), B.stride_1()); } - }; + } + return 0; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_OnDevice.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_OnDevice.hpp index c0d187a59d09..04642b60dc33 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_OnDevice.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsv_OnDevice.hpp @@ -1,7 +1,6 @@ #ifndef __TACHO_TRSV_ON_DEVICE_HPP__ #define __TACHO_TRSV_ON_DEVICE_HPP__ - /// \file Tacho_Trsv_OnDevice.hpp /// \brief BLAS triangular solve matrix /// \author Kyungjoo Kim (kyukim@sandia.gov) @@ -9,113 +8,69 @@ #include "Tacho_Blas_External.hpp" namespace Tacho { - - template - struct Trsv { - template - inline - static int - blas_invoke(const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type m = B.extent(0), n = B.extent(1); - if (m > 0 && n > 0) { - if (n == 1) { - Blas::trsv(ArgUplo::param, ArgTransA::param, - diagA.param, - m, - A.data(), A.stride_1(), - B.data(), B.stride_0()); - } else { - Blas::trsm(Side::Left::param, - ArgUplo::param, - ArgTransA::param, - diagA.param, - m, n, - value_type(1), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - } - } - return 0; + +template struct Trsv { + template + inline static int blas_invoke(const DiagType diagA, const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = B.extent(0), n = B.extent(1); + if (m > 0 && n > 0) { + if (n == 1) { + Blas::trsv(ArgUplo::param, ArgTransA::param, diagA.param, m, A.data(), A.stride_1(), B.data(), + B.stride_0()); + } else { + Blas::trsm(Side::Left::param, ArgUplo::param, ArgTransA::param, diagA.param, m, n, value_type(1), + A.data(), A.stride_1(), B.data(), B.stride_1()); } + } + return 0; + } -#if defined (KOKKOS_ENABLE_CUDA) - template - inline - static int - cublas_invoke(cublasHandle_t &handle, - const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - const ordinal_type m = B.extent(0), n = B.extent(1); - int r_val(0); - if (m > 0 && n > 0) { - if (n == 1) { - r_val = Blas::trsv(handle, - ArgUplo::cublas_param, - ArgTransA::cublas_param, - diagA.cublas_param, - m, - A.data(), A.stride_1(), - B.data(), B.stride_0()); - } else { - r_val = Blas::trsm(handle, - Side::Left::cublas_param, - ArgUplo::cublas_param, - ArgTransA::cublas_param, - diagA.cublas_param, - m, n, - value_type(1), - A.data(), A.stride_1(), - B.data(), B.stride_1()); - } - } - return r_val; +#if defined(KOKKOS_ENABLE_CUDA) + template + inline static int cublas_invoke(cublasHandle_t &handle, const DiagType diagA, const ViewTypeA &A, + const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + const ordinal_type m = B.extent(0), n = B.extent(1); + int r_val(0); + if (m > 0 && n > 0) { + if (n == 1) { + r_val = Blas::trsv(handle, ArgUplo::cublas_param, ArgTransA::cublas_param, diagA.cublas_param, m, + A.data(), A.stride_1(), B.data(), B.stride_0()); + } else { + r_val = Blas::trsm(handle, Side::Left::cublas_param, ArgUplo::cublas_param, ArgTransA::cublas_param, + diagA.cublas_param, m, n, value_type(1), A.data(), A.stride_1(), B.data(), + B.stride_1()); } + } + return r_val; + } #endif - template - inline - static int - invoke(MemberType &member, - const DiagType diagA, - const ViewTypeA &A, - const ViewTypeB &B) { - typedef typename ViewTypeA::non_const_value_type value_type; - typedef typename ViewTypeB::non_const_value_type value_type_b; + template + inline static int invoke(MemberType &member, const DiagType diagA, const ViewTypeA &A, const ViewTypeB &B) { + typedef typename ViewTypeA::non_const_value_type value_type; + typedef typename ViewTypeB::non_const_value_type value_type_b; + + typedef typename ViewTypeA::memory_space memory_space; + typedef typename ViewTypeB::memory_space memory_space_b; - typedef typename ViewTypeA::memory_space memory_space; - typedef typename ViewTypeB::memory_space memory_space_b; - - static_assert(ViewTypeA::rank == 2,"A is not rank 2 view."); - static_assert(ViewTypeB::rank == 2,"B is not rank 2 view."); - - static_assert(std::is_same::value, - "A and B do not have the same value type."); + static_assert(ViewTypeA::rank == 2, "A is not rank 2 view."); + static_assert(ViewTypeB::rank == 2, "B is not rank 2 view."); - static_assert(std::is_same::value, - "A and B do not have the same memory space."); - int r_val(0); - if (std::is_same::value) - r_val = blas_invoke(diagA, A, B); -#if defined (KOKKOS_ENABLE_CUDA) - if (std::is_same::value || - std::is_same::value) - r_val = cublas_invoke(member, diagA, A, B); + static_assert(std::is_same::value, "A and B do not have the same value type."); + + static_assert(std::is_same::value, "A and B do not have the same memory space."); + int r_val(0); + if (std::is_same::value) + r_val = blas_invoke(diagA, A, B); +#if defined(KOKKOS_ENABLE_CUDA) + if (std::is_same::value || std::is_same::value) + r_val = cublas_invoke(member, diagA, A, B); #endif - return r_val; - } - }; + return r_val; + } +}; -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.cpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.cpp index 7c9612006581..23dd3b5a8361 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.cpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.cpp @@ -3,10 +3,6 @@ namespace Tacho { - const char* Version() { - return "Tacho:: Trilinos Git"; - } - -} - +const char *Version() { return "Tacho:: Trilinos Git"; } +} // namespace Tacho diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.hpp index 29fed790df77..8f018b0351fb 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_Util.hpp @@ -7,16 +7,16 @@ // "std" includes #include -#include -#include -#include #include -#include -#include +#include +#include #include #include -#include #include +#include +#include +#include +#include #include #include @@ -26,11 +26,11 @@ #include "Kokkos_Core.hpp" #include "Kokkos_Timer.hpp" -#if defined (__INTEL_MKL__) +#if defined(__INTEL_MKL__) #include "mkl.h" #endif -#if defined(KOKKOS_ENABLE_CUDA) +#if defined(KOKKOS_ENABLE_CUDA) #include "cublas_v2.h" #include "cusolverDn.h" #endif @@ -43,453 +43,425 @@ namespace Tacho { - const char* Version(); +const char *Version(); - /// - /// error macros - /// - // #define MSG_NOT_YET_IMPLEMENTED(what) "Not yet implemented: " #what - // #define MSG_INVALID_INPUT(what) "Invaid input argument: " #what +/// +/// error macros +/// +// #define MSG_NOT_YET_IMPLEMENTED(what) "Not yet implemented: " #what +// #define MSG_INVALID_INPUT(what) "Invaid input argument: " #what #define MSG_NOT_HAVE_PACKAGE(what) "Tacho does not have a package or library: " #what #define MSG_INVALID_TEMPLATE_ARGS "Invaid template arguments" #define MSG_INVALID_INPUT "Invaid input arguments" #define MSG_NOT_IMPLEMENTED "Not yet implemented" -#define TACHO_TEST_FOR_ABORT(ierr, msg) \ - if ((ierr) != 0) { \ - printf(">> Error in file %s, line %d, error %d \n %s\n",__FILE__,__LINE__,ierr,msg); \ - Kokkos::abort(">> Tacho abort\n"); \ +#define TACHO_TEST_FOR_ABORT(ierr, msg) \ + if ((ierr) != 0) { \ + printf(">> Error in file %s, line %d, error %d \n %s\n", __FILE__, __LINE__, ierr, msg); \ + Kokkos::abort(">> Tacho abort\n"); \ } -#define TACHO_TEST_FOR_EXCEPTION(ierr, x, msg) \ - if ((ierr) != 0) { \ - fprintf(stderr, ">> Error in file %s, line %d, error %d \n",__FILE__,__LINE__,ierr); \ - fprintf(stderr, " %s\n", msg); \ - throw x(msg); \ +#define TACHO_TEST_FOR_EXCEPTION(ierr, x, msg) \ + if ((ierr) != 0) { \ + fprintf(stderr, ">> Error in file %s, line %d, error %d \n", __FILE__, __LINE__, ierr); \ + fprintf(stderr, " %s\n", msg); \ + throw x(msg); \ } -#if defined( KOKKOS_ENABLE_ASM ) -#if defined( __amd64 ) || defined( __amd64__ ) || \ - defined( __x86_64 ) || defined( __x86_64__ ) -#if !defined( _WIN32 ) /* IS NOT Microsoft Windows */ -#define KOKKOS_IMPL_PAUSE asm volatile( "pause\n":::"memory" ); +#if defined(KOKKOS_ENABLE_ASM) +#if defined(__amd64) || defined(__amd64__) || defined(__x86_64) || defined(__x86_64__) +#if !defined(_WIN32) /* IS NOT Microsoft Windows */ +#define KOKKOS_IMPL_PAUSE asm volatile("pause\n" ::: "memory"); #else -#define KOKKOS_IMPL_PAUSE __asm__ __volatile__( "pause\n":::"memory" ); +#define KOKKOS_IMPL_PAUSE __asm__ __volatile__("pause\n" ::: "memory"); #endif #elif defined(__PPC64__) -#define KOKKOS_IMPL_PAUSE asm volatile( "or 27, 27, 27" ::: "memory" ); +#define KOKKOS_IMPL_PAUSE asm volatile("or 27, 27, 27" ::: "memory"); #endif #else #define KOKKOS_IMPL_PAUSE #endif - /// - /// label size used to identify object name - /// - enum : int { - MaxDependenceSize = 4, - ThresholdSolvePhaseUsingBlas3 = 12, - CudaVectorSize = 4 - }; +/// +/// label size used to identify object name +/// +enum : int { MaxDependenceSize = 4, ThresholdSolvePhaseUsingBlas3 = 12, CudaVectorSize = 4 }; - /// - /// util - /// - template - KOKKOS_FORCEINLINE_FUNCTION - static Ta min(const Ta a, const Tb b) { - return (a < static_cast(b) ? a : static_cast(b)); - } +/// +/// util +/// +template KOKKOS_FORCEINLINE_FUNCTION static Ta min(const Ta a, const Tb b) { + return (a < static_cast(b) ? a : static_cast(b)); +} - template - KOKKOS_FORCEINLINE_FUNCTION - static Ta max(const Ta a, const Tb b) { - return (a > static_cast(b) ? a : static_cast(b)); - } +template KOKKOS_FORCEINLINE_FUNCTION static Ta max(const Ta a, const Tb b) { + return (a > static_cast(b) ? a : static_cast(b)); +} - template - KOKKOS_FORCEINLINE_FUNCTION - static void swap(Ta &a, Tb &b) { - Ta c(a); a = static_cast(b); b = static_cast(c); - } +template KOKKOS_FORCEINLINE_FUNCTION static void swap(Ta &a, Tb &b) { + Ta c(a); + a = static_cast(b); + b = static_cast(c); +} - KOKKOS_FORCEINLINE_FUNCTION - static void clear(char *buf, size_type bufsize) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - memset(buf, 0, bufsize); +KOKKOS_FORCEINLINE_FUNCTION +static void clear(char *buf, size_type bufsize) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + memset(buf, 0, bufsize); #else - for (size_type i=0;i - KOKKOS_FORCEINLINE_FUNCTION - static void clear(MemberType &member, char *buf, size_type bufsize) { -#if defined( KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST ) - memset(buf, 0, bufsize); +} + +template +KOKKOS_FORCEINLINE_FUNCTION static void clear(MemberType &member, char *buf, size_type bufsize) { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_HOST) + memset(buf, 0, bufsize); #else - const ordinal_type team_index_range = (bufsize/CudaVectorSize) + (bufsize%CudaVectorSize > 0); - Kokkos::parallel_for(Kokkos::TeamThreadRange(member,team_index_range),[&](const int &idx) { - const int ioff = idx * CudaVectorSize; - const int itmp = bufsize - ioff; - const int icnt = itmp > CudaVectorSize ? CudaVectorSize : itmp; - Kokkos::parallel_for(Kokkos::ThreadVectorRange(member,icnt),[&](const int &ii) { - const int i = ioff + ii; - buf[i] = 0; - }); - }); + const ordinal_type team_index_range = (bufsize / CudaVectorSize) + (bufsize % CudaVectorSize > 0); + Kokkos::parallel_for(Kokkos::TeamThreadRange(member, team_index_range), [&](const int &idx) { + const int ioff = idx * CudaVectorSize; + const int itmp = bufsize - ioff; + const int icnt = itmp > CudaVectorSize ? CudaVectorSize : itmp; + Kokkos::parallel_for(Kokkos::ThreadVectorRange(member, icnt), [&](const int &ii) { + const int i = ioff + ii; + buf[i] = 0; + }); + }); #endif - } +} - template - KOKKOS_INLINE_FUNCTION - static T1* lower_bound(T1* first, T1* last, const T2& val, - CompareType compare) { - T1 *it; - ordinal_type step = 0, count = last - first; - while (count > 0) { - it = first; step = count/2; it += step; - if (compare(*it,val)) { - first = ++it; - count -= step + 1; - } else { - count = step; - } +template +KOKKOS_INLINE_FUNCTION static T1 *lower_bound(T1 *first, T1 *last, const T2 &val, CompareType compare) { + T1 *it; + ordinal_type step = 0, count = last - first; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (compare(*it, val)) { + first = ++it; + count -= step + 1; + } else { + count = step; } - return first; } - - template - struct Flush { - typedef double value_type; - - // flush a large host buffer - Kokkos::View _buf; - Flush() : _buf("Flush::buf", BufSize/sizeof(double)) { - Kokkos::deep_copy(_buf, 1); - } + return first; +} - KOKKOS_INLINE_FUNCTION - void init(value_type &update) { - update = 0; - } - - KOKKOS_INLINE_FUNCTION - void join(volatile value_type &update, - const volatile value_type &input) { - update += input; - } - - KOKKOS_INLINE_FUNCTION - void operator()(const int i, value_type &update) const { - update += _buf[i]; - } - - void run() { - double sum = 0; - Kokkos::parallel_reduce(Kokkos::RangePolicy(0,BufSize/sizeof(double)), *this, sum); - SpaceType().fence(); - FILE *fp = fopen("/dev/null", "w"); - fprintf(fp, "%f\n", sum); - fclose(fp); - } +template struct Flush { + typedef double value_type; - }; + // flush a large host buffer + Kokkos::View _buf; + Flush() : _buf("Flush::buf", BufSize / sizeof(double)) { Kokkos::deep_copy(_buf, 1); } - template - struct Random; + KOKKOS_INLINE_FUNCTION + void init(value_type &update) { update = 0; } - template<> - struct Random { - Random(const unsigned int seed = 0) { srand(seed); } - double value() { return rand()/((double) RAND_MAX + 1.0); } - }; + KOKKOS_INLINE_FUNCTION + void join(volatile value_type &update, const volatile value_type &input) { update += input; } - template<> - struct Random > { - Random(const unsigned int seed = 0) { srand(seed); } - std::complex value() { - return std::complex(rand()/((double) RAND_MAX + 1.0), - rand()/((double) RAND_MAX + 1.0)); - } - }; + KOKKOS_INLINE_FUNCTION + void operator()(const int i, value_type &update) const { update += _buf[i]; } + + void run() { + double sum = 0; + Kokkos::parallel_reduce(Kokkos::RangePolicy(0, BufSize / sizeof(double)), *this, sum); + SpaceType().fence(); + FILE *fp = fopen("/dev/null", "w"); + fprintf(fp, "%f\n", sum); + fclose(fp); + } +}; - template<> - struct Random > { - Random(const unsigned int seed = 0) { srand(seed); } - Kokkos::complex value() { - return Kokkos::complex(rand()/((double) RAND_MAX + 1.0), - rand()/((double) RAND_MAX + 1.0)); - } - }; +template struct Random; +template <> struct Random { + Random(const unsigned int seed = 0) { srand(seed); } + double value() { return rand() / ((double)RAND_MAX + 1.0); } +}; - /// - /// Tag struct - /// - struct NullTag { enum : int { tag = 0 }; }; - struct PivotMode { - struct Flame {}; /// 0 base relative pivot index - struct Lapack {}; /// 1 base index - }; +template <> struct Random> { + Random(const unsigned int seed = 0) { srand(seed); } + std::complex value() { + return std::complex(rand() / ((double)RAND_MAX + 1.0), rand() / ((double)RAND_MAX + 1.0)); + } +}; - struct Partition { - enum : int { Top = 101, - Bottom, - Left = 201, - Right, - TopLeft = 301, - TopRight, - BottomLeft, - BottomRight }; +template <> struct Random> { + Random(const unsigned int seed = 0) { srand(seed); } + Kokkos::complex value() { + return Kokkos::complex(rand() / ((double)RAND_MAX + 1.0), rand() / ((double)RAND_MAX + 1.0)); + } +}; + +/// +/// Tag struct +/// +struct NullTag { + enum : int { tag = 0 }; +}; +struct PivotMode { + struct Flame {}; /// 0 base relative pivot index + struct Lapack {}; /// 1 base index +}; + +struct Partition { + enum : int { Top = 101, Bottom, Left = 201, Right, TopLeft = 301, TopRight, BottomLeft, BottomRight }; +}; +template struct is_valid_partition_tag { + enum : bool { + value = + (T == Partition::TopLeft || T == Partition::Top || T == Partition::TopRight || T == Partition::Left || + T == Partition::Right || T == Partition::BottomLeft || T == Partition::Bottom || T == Partition::BottomRight) }; - - struct Uplo { - enum : int { tag = 400 }; - struct Upper { - enum : int { tag = 401 }; - static constexpr char param = 'U'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_UPLO mkl_param = CblasUpper; +}; + +struct Uplo { + enum : int { tag = 400 }; + struct Upper { + enum : int { tag = 401 }; + static constexpr char param = 'U'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_UPLO mkl_param = CblasUpper; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasFillMode_t cublas_param = CUBLAS_FILL_MODE_UPPER; + static constexpr cublasFillMode_t cublas_param = CUBLAS_FILL_MODE_UPPER; #endif - }; - struct Lower { - enum : int { tag = 402 }; - static constexpr char param = 'L'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_UPLO mkl_param = CblasLower; + }; + struct Lower { + enum : int { tag = 402 }; + static constexpr char param = 'L'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_UPLO mkl_param = CblasLower; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasFillMode_t cublas_param = CUBLAS_FILL_MODE_LOWER; + static constexpr cublasFillMode_t cublas_param = CUBLAS_FILL_MODE_LOWER; #endif - }; }; - template - struct is_valid_uplo_tag { - enum : bool { value = (std::is_same::value || - std::is_same::value ) - }; - }; - template struct transpose_uplo_tag; - template<> struct transpose_uplo_tag { typedef Uplo::Upper type; }; - template<> struct transpose_uplo_tag { typedef Uplo::Lower type; }; - - - struct Side { - enum : int { tag = 500 }; - struct Left { - enum : int { tag = 501 }; - static constexpr char param = 'L'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_SIDE mkl_param = CblasLeft; +}; +template struct is_valid_uplo_tag { + enum : bool { value = (std::is_same::value || std::is_same::value) }; +}; +template struct transpose_uplo_tag; +template <> struct transpose_uplo_tag { typedef Uplo::Upper type; }; +template <> struct transpose_uplo_tag { typedef Uplo::Lower type; }; + +struct Side { + enum : int { tag = 500 }; + struct Left { + enum : int { tag = 501 }; + static constexpr char param = 'L'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_SIDE mkl_param = CblasLeft; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasSideMode_t cublas_param = CUBLAS_SIDE_LEFT; + static constexpr cublasSideMode_t cublas_param = CUBLAS_SIDE_LEFT; #endif - }; - struct Right { - enum : int { tag = 502 }; - static constexpr char param = 'R'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_SIDE mkl_param = CblasRight; + }; + struct Right { + enum : int { tag = 502 }; + static constexpr char param = 'R'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_SIDE mkl_param = CblasRight; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasSideMode_t cublas_param = CUBLAS_SIDE_RIGHT; + static constexpr cublasSideMode_t cublas_param = CUBLAS_SIDE_RIGHT; #endif - }; - }; - template - struct is_valid_side_tag { - enum : bool { value = (std::is_same::value || - std::is_same::value ) - }; }; - template struct flip_side_tag; - template<> struct flip_side_tag { typedef Side::Right type; }; - template<> struct flip_side_tag { typedef Side::Left type; }; - - struct Diag { - enum : int { tag = 600 }; - struct Unit { - enum : int { tag = 601 }; - static constexpr char param = 'U'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_DIAG mkl_param = CblasUnit; +}; +template struct is_valid_side_tag { + enum : bool { value = (std::is_same::value || std::is_same::value) }; +}; +template struct flip_side_tag; +template <> struct flip_side_tag { typedef Side::Right type; }; +template <> struct flip_side_tag { typedef Side::Left type; }; + +struct Diag { + enum : int { tag = 600 }; + struct Unit { + enum : int { tag = 601 }; + static constexpr char param = 'U'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_DIAG mkl_param = CblasUnit; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasDiagType_t cublas_param = CUBLAS_DIAG_UNIT; + static constexpr cublasDiagType_t cublas_param = CUBLAS_DIAG_UNIT; #endif - }; - struct NonUnit { - enum : int { tag = 602 }; - static constexpr char param = 'N'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_DIAG mkl_param = CblasNonUnit; + }; + struct NonUnit { + enum : int { tag = 602 }; + static constexpr char param = 'N'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_DIAG mkl_param = CblasNonUnit; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasDiagType_t cublas_param = CUBLAS_DIAG_NON_UNIT; + static constexpr cublasDiagType_t cublas_param = CUBLAS_DIAG_NON_UNIT; #endif - }; - }; - template - struct is_valid_diag_tag { - enum : bool { value = (std::is_same::value || - std::is_same::value ) - }; }; - - struct Trans { - enum : int { tag = 700 }; - struct Transpose { - enum : int { tag = 701 }; - static constexpr char param = 'T'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_TRANSPOSE mkl_param = CblasTrans; +}; +template struct is_valid_diag_tag { + enum : bool { value = (std::is_same::value || std::is_same::value) }; +}; + +struct Trans { + enum : int { tag = 700 }; + struct Transpose { + enum : int { tag = 701 }; + static constexpr char param = 'T'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_TRANSPOSE mkl_param = CblasTrans; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasOperation_t cublas_param = CUBLAS_OP_T; + static constexpr cublasOperation_t cublas_param = CUBLAS_OP_T; #endif - }; - struct ConjTranspose { - enum : int { tag = 702 }; - static constexpr char param = 'C'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_TRANSPOSE mkl_param = CblasConjTrans; + }; + struct ConjTranspose { + enum : int { tag = 702 }; + static constexpr char param = 'C'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_TRANSPOSE mkl_param = CblasConjTrans; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasOperation_t cublas_param = CUBLAS_OP_C; + static constexpr cublasOperation_t cublas_param = CUBLAS_OP_C; #endif - }; - struct NoTranspose { - enum : int { tag = 703 }; - static constexpr char param = 'N'; -#if defined(__INTEL_MKL__) - static constexpr CBLAS_TRANSPOSE mkl_param = CblasNoTrans; + }; + struct NoTranspose { + enum : int { tag = 703 }; + static constexpr char param = 'N'; +#if defined(__INTEL_MKL__) + static constexpr CBLAS_TRANSPOSE mkl_param = CblasNoTrans; #endif #if defined(CUBLAS_VERSION) - static constexpr cublasOperation_t cublas_param = CUBLAS_OP_N; + static constexpr cublasOperation_t cublas_param = CUBLAS_OP_N; #endif - }; }; - template - struct is_valid_trans_tag { - enum : bool { value = (std::is_same::value || - std::is_same::value || - std::is_same::value) - }; - }; - template struct transpose_trans_tag; - template struct conj_transpose_trans_tag; - - template<> struct transpose_trans_tag { typedef Trans::NoTranspose type; }; - template<> struct transpose_trans_tag { typedef Trans::NoTranspose type; }; - template<> struct transpose_trans_tag { typedef Trans::Transpose type; }; - - template<> struct conj_transpose_trans_tag { typedef Trans::NoTranspose type; }; - template<> struct conj_transpose_trans_tag { typedef Trans::NoTranspose type; }; - template<> struct conj_transpose_trans_tag { typedef Trans::ConjTranspose type; }; - - struct Direct { - enum : int { tag = 800 }; - struct Forward { - enum : int { tag = 801 }; - }; - struct Backward { - enum : int { tag = 802 }; - }; +}; +template struct is_valid_trans_tag { + enum : bool { + value = (std::is_same::value || std::is_same::value || + std::is_same::value) }; +}; +template struct transpose_trans_tag; +template struct conj_transpose_trans_tag; + +template <> struct transpose_trans_tag { typedef Trans::NoTranspose type; }; +template <> struct transpose_trans_tag { typedef Trans::NoTranspose type; }; +template <> struct transpose_trans_tag { typedef Trans::Transpose type; }; + +template <> struct conj_transpose_trans_tag { typedef Trans::NoTranspose type; }; +template <> struct conj_transpose_trans_tag { typedef Trans::NoTranspose type; }; +template <> struct conj_transpose_trans_tag { typedef Trans::ConjTranspose type; }; - /// - /// helper functions - /// - struct Conjugate { +struct Direct { + enum : int { tag = 800 }; + struct Forward { enum : int { tag = 801 }; - - KOKKOS_FORCEINLINE_FUNCTION Conjugate() {} - KOKKOS_FORCEINLINE_FUNCTION Conjugate(const Conjugate &b) {} - - KOKKOS_FORCEINLINE_FUNCTION float operator()(const float &v) const { return v; } - KOKKOS_FORCEINLINE_FUNCTION double operator()(const double &v) const { return v; } - inline std::complex operator()(const std::complex &v) const { return std::conj(v); } - inline std::complex operator()(const std::complex &v) const { return std::conj(v); } - KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return Kokkos::conj(v); } - KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return Kokkos::conj(v); } }; - - struct NoConjugate { - enum : int {tag = 802 }; - - KOKKOS_FORCEINLINE_FUNCTION NoConjugate() {} - KOKKOS_FORCEINLINE_FUNCTION NoConjugate(const NoConjugate &b) {} - - KOKKOS_FORCEINLINE_FUNCTION float operator()(const float &v) const { return v; } - KOKKOS_FORCEINLINE_FUNCTION double operator()(const double &v) const { return v; } - inline std::complex operator()(const std::complex &v) const { return v; } - inline std::complex operator()(const std::complex &v) const { return v; } - KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return v; } - KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return v; } + struct Backward { + enum : int { tag = 802 }; }; +}; + +/// +/// helper functions +/// +struct Conjugate { + enum : int { tag = 801 }; + + KOKKOS_FORCEINLINE_FUNCTION Conjugate() {} + KOKKOS_FORCEINLINE_FUNCTION Conjugate(const Conjugate &b) {} + + KOKKOS_FORCEINLINE_FUNCTION float operator()(const float &v) const { return v; } + KOKKOS_FORCEINLINE_FUNCTION double operator()(const double &v) const { return v; } + inline std::complex operator()(const std::complex &v) const { return std::conj(v); } + inline std::complex operator()(const std::complex &v) const { return std::conj(v); } + KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { + return Kokkos::conj(v); + } + KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { + return Kokkos::conj(v); + } +}; + +struct NoConjugate { + enum : int { tag = 802 }; - struct Algo { - struct External { enum : int { tag = 1001 }; }; - struct Internal { enum : int { tag = 1002 }; }; - struct ByBlocks { enum : int { tag = 1003 }; }; - struct OnDevice { enum : int { tag = 1004 }; }; + KOKKOS_FORCEINLINE_FUNCTION NoConjugate() {} + KOKKOS_FORCEINLINE_FUNCTION NoConjugate(const NoConjugate &b) {} + + KOKKOS_FORCEINLINE_FUNCTION float operator()(const float &v) const { return v; } + KOKKOS_FORCEINLINE_FUNCTION double operator()(const double &v) const { return v; } + inline std::complex operator()(const std::complex &v) const { return v; } + inline std::complex operator()(const std::complex &v) const { return v; } + KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return v; } + KOKKOS_FORCEINLINE_FUNCTION Kokkos::complex operator()(const Kokkos::complex &v) const { return v; } +}; + +struct Algo { + struct External { + enum : int { tag = 1001 }; + }; + struct Internal { + enum : int { tag = 1002 }; + }; + struct ByBlocks { + enum : int { tag = 1003 }; + }; + struct OnDevice { + enum : int { tag = 1004 }; + }; - struct Workflow { - struct Serial { enum : int { tag = 2001 }; }; - struct SerialPanel { enum : int { tag = 2002 }; }; + struct Workflow { + struct Serial { + enum : int { tag = 2001 }; + }; + struct SerialPanel { + enum : int { tag = 2002 }; }; }; +}; - template - using MemoryTraits = Kokkos::MemoryTraits; - - template - using UnmanagedViewType = Kokkos::View >; - template - using ConstViewType = Kokkos::View; - template - using ConstUnmanagedViewType = ConstViewType >; - - using do_not_initialize_tag = Kokkos::ViewAllocateWithoutInitializing; - - template - struct ExecSpaceFactory { - static void createInstance(T &exec_instance) { - exec_instance = T(); - } +struct ActiveAlgorithm { +#if defined(KOKKOS_ACTIVE_EXECUTION_MEMORY_SPACE_CUDA) + using type = Algo::Internal; +#else + using type = Algo::External; +#endif +}; + +template +using MemoryTraits = Kokkos::MemoryTraits; + +template +using UnmanagedViewType = + Kokkos::View>; +template +using ConstViewType = Kokkos::View; +template using ConstUnmanagedViewType = ConstViewType>; + +using do_not_initialize_tag = Kokkos::ViewAllocateWithoutInitializing; + +template struct ExecSpaceFactory { + static void createInstance(T &exec_instance) { exec_instance = T(); } #if defined(KOKKOS_ENABLE_CUDA) - static void createInstance(const cudaStream_t &s, T &exec_instance) { - exec_instance = T(); - } + static void createInstance(const cudaStream_t &s, T &exec_instance) { exec_instance = T(); } #endif - }; +}; #if defined(KOKKOS_ENABLE_CUDA) - template<> - struct ExecSpaceFactory { - static void createInstance(Kokkos::Cuda &exec_instance) { - exec_instance = Kokkos::Cuda(); - } - static void createInstance(const cudaStream_t &s, Kokkos::Cuda &exec_instance) { - exec_instance = Kokkos::Cuda(s); - } - }; +template <> struct ExecSpaceFactory { + static void createInstance(Kokkos::Cuda &exec_instance) { exec_instance = Kokkos::Cuda(); } + static void createInstance(const cudaStream_t &s, Kokkos::Cuda &exec_instance) { exec_instance = Kokkos::Cuda(s); } +}; #endif -} +} // namespace Tacho #endif diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_ByBlocks.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Chol_ByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_Chol_ByBlocks.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Chol_ByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_ByBlocks.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Gemm_ByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_Gemm_ByBlocks.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Gemm_ByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_ByBlocks.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Herk_ByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_Herk_ByBlocks.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Herk_ByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LevelSetTools.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_LevelSetTools.hpp similarity index 99% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_LevelSetTools.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_LevelSetTools.hpp index 702e4d5ebf39..4999df8d85f6 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_LevelSetTools.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_LevelSetTools.hpp @@ -13,19 +13,19 @@ #include "Tacho_SetIdentity.hpp" #include "Tacho_SetIdentity_OnDevice.hpp" -#include "Tacho_Chol.hpp" +//#include "Tacho_Chol.hpp" #include "Tacho_Chol_OnDevice.hpp" -#include "Tacho_Trsm.hpp" +//#include "Tacho_Trsm.hpp" #include "Tacho_Trsm_OnDevice.hpp" -#include "Tacho_Herk.hpp" +//#include "Tacho_Herk.hpp" #include "Tacho_Herk_OnDevice.hpp" -#include "Tacho_Trsv.hpp" +//#include "Tacho_Trsv.hpp" #include "Tacho_Trsv_OnDevice.hpp" -#include "Tacho_Gemv.hpp" +//#include "Tacho_Gemv.hpp" #include "Tacho_Gemv_OnDevice.hpp" #include "Tacho_SupernodeInfo.hpp" diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_NumericTools.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_NumericTools.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeChol.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeChol.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeChol.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholByBlocks.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholByBlocks.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholByBlocksPanel.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholPanel.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholPanel.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_FactorizeCholPanel.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_FactorizeCholPanel.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_MemoryPool.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_MemoryPool.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_MemoryPool.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_MemoryPool.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_SolveLowerChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_SolveLowerChol.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_SolveLowerChol.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_SolveLowerChol.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_SolveUpperChol.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_SolveUpperChol.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TaskFunctor_SolveUpperChol.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TaskFunctor_SolveUpperChol.hpp diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TriSolveTools.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TriSolveTools.hpp similarity index 99% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_TriSolveTools.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TriSolveTools.hpp index 5b3157d00b38..0584da3f2477 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_TriSolveTools.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_TriSolveTools.hpp @@ -6,13 +6,13 @@ #include "Tacho_Util.hpp" -#include "Tacho_Trsm.hpp" +//#include "Tacho_Trsm.hpp" #include "Tacho_Trsm_OnDevice.hpp" -#include "Tacho_Trsv.hpp" +//#include "Tacho_Trsv.hpp" #include "Tacho_Trsv_OnDevice.hpp" -#include "Tacho_Gemv.hpp" +//#include "Tacho_Gemv.hpp" #include "Tacho_Gemv_OnDevice.hpp" #include "Tacho_SupernodeInfo.hpp" diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_ByBlocks.hpp b/packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Trsm_ByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/src/impl/Tacho_Trsm_ByBlocks.hpp rename to packages/shylu/shylu_node/tacho/src/impl/later/Tacho_Trsm_ByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/unit-test/CMakeLists.txt b/packages/shylu/shylu_node/tacho/unit-test/CMakeLists.txt index 58664936863a..b4b338acafc2 100644 --- a/packages/shylu/shylu_node/tacho/unit-test/CMakeLists.txt +++ b/packages/shylu/shylu_node/tacho/unit-test/CMakeLists.txt @@ -1,82 +1,108 @@ INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) INCLUDE_DIRECTORIES(${CMAKE_CURRENT_BINARY_DIR}) -TRIBITS_COPY_FILES_TO_BINARY_DIR(Tacho_UnitTest_SparseMatrixFile +TRIBITS_COPY_FILES_TO_BINARY_DIR( + Tacho_UnitTest_SparseMatrixFile SOURCE_FILES test_double.mtx test_dcomplex.mtx SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR} DEST_DIR ${CMAKE_CURRENT_BINARY_DIR} ) -SET(SOURCES "") +# +# gtest library +# +INCLUDE_DIRECTORIES("${CMAKE_CURRENT_SOURCE_DIR}/googletest") +INCLUDE_DIRECTORIES("${CMAKE_CURRENT_SOURCE_DIR}/googletest/include") -FILE(GLOB SOURCES *.cpp) +TRIBITS_ADD_LIBRARY( + tacho-gtest + SOURCES googletest/src/gtest-all.cc + TESTONLY + NO_INSTALL_LIB_OR_HEADERS +) -SET(LIBRARIES tacho) +TRIBITS_ADD_EXECUTABLE_AND_TEST( + Tacho_Test_Util.x + NOEXESUFFIX + NOEXEPREFIX + SOURCES Tacho_TestUtil.cpp + TESTONLYLIBS tacho-gtest + ARGS PrintItAll + NUM_MPI_PROCS 1 + FAIL_REGULAR_EXPRESSION " FAILED " +) -IF (TACHO_HAVE_KOKKOS_TASK) +SET(TACHO_ETI_FILE "Tacho_Test") +SET(TACHO_ETI_DEVICE_NAME "") +SET(TACHO_ETI_DEVICE_TYPE "") +SET(TACHO_ETI_WITH_TASK "") - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestUtil - NOEXEPREFIX - SOURCES Tacho_TestUtil.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) - IF(Kokkos_ENABLE_SERIAL) - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestSerialDouble - NOEXEPREFIX - SOURCES Tacho_TestSerial_double.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestSerialDoubleComplex - NOEXEPREFIX - SOURCES Tacho_TestSerial_dcomplex.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) - ENDIF() - - IF(Kokkos_ENABLE_OPENMP) - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestOpenMPDouble - NOEXEPREFIX - SOURCES Tacho_TestOpenMP_double.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestOpenMPDoubleComplex - NOEXEPREFIX - SOURCES Tacho_TestOpenMP_dcomplex.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) - ENDIF() +IF (Kokkos_ENABLE_SERIAL) + LIST(APPEND TACHO_ETI_DEVICE_NAME "Serial") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() +IF (Kokkos_ENABLE_OPENMP) + LIST(APPEND TACHO_ETI_DEVICE_NAME "OpenMP") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "1") +ENDIF() +IF (Kokkos_ENABLE_CUDA) + LIST(APPEND TACHO_ETI_DEVICE_NAME "CUDA") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() +IF (Kokkos_ENABLE_HIP) + LIST(APPEND TACHO_ETI_DEVICE_NAME "HIP") + LIST(APPEND TACHO_ETI_DEVICE_TYPE "Kokkos::Device") + LIST(APPEND TACHO_ETI_WITH_TASK "0") +ENDIF() + +LIST(LENGTH TACHO_ETI_DEVICE_NAME ETI_DEVICE_COUNT) +MATH(EXPR ETI_DEVICE_COUNT "${ETI_DEVICE_COUNT}-1") + +SET(TACHO_ETI_VALUE_NAME "") +SET(TACHO_ETI_VALUE_TYPE "") + +LIST(APPEND TACHO_ETI_VALUE_NAME "double") +LIST(APPEND TACHO_ETI_VALUE_TYPE "double") + +LIST(APPEND TACHO_ETI_VALUE_NAME "float") +LIST(APPEND TACHO_ETI_VALUE_TYPE "float") + +LIST(APPEND TACHO_ETI_VALUE_NAME "complex_double") +LIST(APPEND TACHO_ETI_VALUE_TYPE "Kokkos::complex") + +LIST(APPEND TACHO_ETI_VALUE_NAME "complex_float") +LIST(APPEND TACHO_ETI_VALUE_TYPE "Kokkos::complex") + +LIST(LENGTH TACHO_ETI_VALUE_NAME ETI_VALUE_COUNT) +MATH(EXPR ETI_VALUE_COUNT "${ETI_VALUE_COUNT}-1") + +FOREACH(I RANGE ${ETI_DEVICE_COUNT}) + LIST(GET TACHO_ETI_DEVICE_NAME ${I} ETI_DEVICE_NAME) + LIST(GET TACHO_ETI_DEVICE_TYPE ${I} ETI_DEVICE_TYPE) + LIST(GET TACHO_ETI_WITH_TASK ${I} ETI_WITH_TASK) + + FOREACH(J RANGE ${ETI_VALUE_COUNT}) + LIST(GET TACHO_ETI_VALUE_NAME ${J} ETI_VALUE_NAME) + LIST(GET TACHO_ETI_VALUE_TYPE ${J} ETI_VALUE_TYPE) + + FOREACH(ETI_FILE IN LISTS TACHO_ETI_FILE) + SET(ETI_NAME "${ETI_FILE}_ETI_${ETI_VALUE_NAME}_${ETI_DEVICE_NAME}") + MESSAGE(STATUS "Generating ETI: ${ETI_NAME}.cpp") + CONFIGURE_FILE(eti/${ETI_FILE}_ETI.in eti/${ETI_NAME}.cpp) - IF(Kokkos_ENABLE_CUDA) - TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestCudaDouble - NOEXEPREFIX - SOURCES Tacho_TestCuda_double.cpp - ARGS PrintItAll - NUM_MPI_PROCS 1 - FAIL_REGULAR_EXPRESSION " FAILED " - ) TRIBITS_ADD_EXECUTABLE_AND_TEST( - Tacho_TestCudaDoubleComplex + ${ETI_NAME}.x + NOEXESUFFIX NOEXEPREFIX - SOURCES Tacho_TestCuda_dcomplex.cpp + SOURCES eti/${ETI_NAME}.cpp + TESTONLYLIBS tacho-gtest ARGS PrintItAll NUM_MPI_PROCS 1 FAIL_REGULAR_EXPRESSION " FAILED " ) - ENDIF() -ENDIF() + ENDFOREACH() + ENDFOREACH() +ENDFOREACH() diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_Test.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_Test.hpp index eb42bfc044dd..5c48c5c616a8 100644 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_Test.hpp +++ b/packages/shylu/shylu_node/tacho/unit-test/Tacho_Test.hpp @@ -1,17 +1,162 @@ #ifndef __TACHO_TEST_HPP__ #define __TACHO_TEST_HPP__ -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; +#include "Tacho_CrsMatrixBase.hpp" -#include "Tacho_TestCrsMatrixBase.hpp" -#include "Tacho_TestGraph.hpp" -#include "Tacho_TestSymbolic.hpp" -//#include "Tacho_TestNumeric.hpp" -//#include "Tacho_TestTaskFunctor.hpp" +namespace Test { + +/// std cout capture +/// testing::internal::CaptureStdout(); +/// std::string output = testing::internal::GetCapturedStdout(); +/// printf("%s\n", output.c_str()); + +using namespace Tacho; + +using atsv = ArithTraits; +using atsm = ArithTraits; + +using crs_matrix_base_type_host = CrsMatrixBase; +using crs_matrix_base_type = CrsMatrixBase; + +using ordinal_type_array_type_host = Kokkos::View; +using size_type_array_type_host = Kokkos::View; -#include "Tacho_TestDenseMatrixView.hpp" -#include "Tacho_TestDenseByBlocks.hpp" +using value_type_array_type_host = Kokkos::View; +using value_type_matrix_type_host = Kokkos::View; +void fill_spd_tridiag_matrix(const value_type_matrix_type_host &A) { + const int m = A.extent(0), n = A.extent(1); + EXPECT_TRUE(m == n); + { + for (int i = 0; i < m; ++i) + A(i, i) = 4.0; + + for (int i = 0; i < (m - 1); ++i) { + A(i, i + 1) = -1.0; + A(i + 1, i) = -1.0; + } + } +} +void symmetrize_with_upper(const value_type_matrix_type_host &A) { + const int m = A.extent(0), n = A.extent(1); + EXPECT_TRUE(m == n); + for (int i = 0; i < m; ++i) + for (int j = i; i < m; ++j) + A(j, i) = A(i, j); +} +void fill_random_matrix(const value_type_matrix_type_host &A) { + // const int m = A.extent(0), n = A.extent(1); + value_type one; + atsv::set_real(one, 1); + atsv::set_imag(one, 1); + + Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::fill_random(A, random, one); +} +void fill_random_symmetric_matrix(const value_type_matrix_type_host &A) { + fill_random_matrix(A); + symmetrize_with_upper(A); +} +void copy_matrix(const value_type_matrix_type_host A, const value_type_matrix_type_host B) { Kokkos::deep_copy(A, B); } +void copy_lower_triangular(const value_type_matrix_type_host A, const bool transL, const value_type diag, + const value_type_matrix_type_host L) { + const int m = L.extent(0), n = L.extent(1); + { + const int mA = A.extent(0), nA = A.extent(1); + EXPECT_TRUE(m == mA); + EXPECT_TRUE(n == nA); + } + const bool replace_diag = diag != value_type(0); + for (int j = 0; j < n; ++j) + for (int i = j; i < m; ++i) { + if (i == j && replace_diag) + A(i, j) = diag; + else + A(i, j) = (transL ? L(j, i) : L(i, j)); + } +} +void copy_upper_triangular(const value_type_matrix_type_host A, const bool transU, const value_type diag, + const value_type_matrix_type_host U) { + const int m = U.extent(0), n = U.extent(1); + { + const int mA = A.extent(0), nA = A.extent(1); + EXPECT_TRUE(m == mA); + EXPECT_TRUE(n == nA); + } + const bool replace_diag = diag != value_type(0); + for (int i = 0; i < m; ++i) + for (int j = i; j < n; ++j) { + if (i == j && replace_diag) + A(i, j) = diag; + else + A(i, j) = (transU ? U(j, i) : U(i, j)); + } +} +void show_matrix(std::string label, const value_type_matrix_type_host A) { + const int m = A.extent(0), n = A.extent(1); + std::cout << label << "(" << m << " x " << n << ")\n"; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) + std::cout << A(i, j) << " "; + std::cout << "\n"; + } + std::cout << "\n"; +} +void compute_A(const value_type_matrix_type_host A, const value_type_matrix_type_host L, + const value_type_matrix_type_host U) { + const int m = A.extent(0), n = A.extent(1), k = L.extent(1); + { + const int mL = L.extent(0), nU = U.extent(1), kU = U.extent(0); + EXPECT_TRUE(m == mL); + EXPECT_TRUE(n == nU); + EXPECT_TRUE(k == kU); + } + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) { + A(i, j) = 0; + for (int l = 0; l < k; ++l) + A(i, j) += L(i, l) * U(l, j); + } +} +void apply_lapack_pivots_left_no_trans(const value_type_matrix_type_host A, const ordinal_type_array_type_host p) { + const int m = A.extent(0), n = A.extent(1); + { + const int mp = p.extent(0); + EXPECT_TRUE(m <= mp); + } + for (int i = 0; i < m; ++i) { + const int idx = p(i) - 1; + if (idx != i) + for (int j = 0; j < n; ++j) + std::swap(A(i, j), A(idx, j)); + } +} +void check_same_matrix(const value_type_matrix_type_host A, const value_type_matrix_type_host B) { + const int m = A.extent(0), n = A.extent(1); + { + const int mB = B.extent(0), nB = B.extent(1); + EXPECT_TRUE(m == mB); + EXPECT_TRUE(n == nB); + } + const magnitude_type eps = atsv::epsilon() * 100; + for (int i = 0; i < m; ++i) + for (int j = 0; j < m; ++j) { + EXPECT_NEAR(atsv::real(A(i, j)), atsv::real(B(i, j)), eps); + EXPECT_NEAR(atsv::imag(A(i, j)), atsv::imag(B(i, j)), eps); + } +} +} // namespace Test + +#include "Tacho_TestCrsMatrixBase.hpp" #include "Tacho_TestDenseLinearAlgebra.hpp" +#include "Tacho_TestGraphTools.hpp" +#include "Tacho_TestSymbolicTools.hpp" +// //#include "Tacho_TestNumeric.hpp" +// //#include "Tacho_TestTaskFunctor.hpp" + +// #include "Tacho_TestDenseMatrixView.hpp" +// #include "Tacho_TestDenseByBlocks.hpp" + +// #include "Tacho_TestDenseLinearAlgebra.hpp" #endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCrsMatrixBase.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCrsMatrixBase.hpp index 6aebc10573af..f43cd2d7d49d 100644 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCrsMatrixBase.hpp +++ b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCrsMatrixBase.hpp @@ -1,213 +1,166 @@ #ifndef __TACHO_TEST_CRS_MATRIX_BASE_HPP__ #define __TACHO_TEST_CRS_MATRIX_BASE_HPP__ -#include - -#include -#include - #include "Tacho_Util.hpp" -#include "Tacho_CrsMatrixBase.hpp" #include "Tacho_MatrixMarket.hpp" -using namespace Tacho; - -typedef CrsMatrixBase CrsMatrixBaseHostType; -typedef CrsMatrixBase CrsMatrixBaseDeviceType; +namespace Test { -TEST( CrsMatrixBase, coo ) { - TEST_BEGIN; - { - auto a = Coo(); - EXPECT_EQ(a.i, 0); - EXPECT_EQ(a.j, 0); - EXPECT_EQ(a.j, 0.0); - } - { - auto a = Coo(1,3, 3.0); - auto b = Coo(1,3,10.0); - auto c = Coo(2,3, 3.0); - auto d = Coo(1,1, 3.0); + TEST( CrsMatrixBase, coo ) { + { + auto a = Coo(); + EXPECT_EQ(a.i, 0); + EXPECT_EQ(a.j, 0); + EXPECT_EQ(a.j, 0.0); + } + { + auto a = Coo(1,3, 3.0); + auto b = Coo(1,3,10.0); + auto c = Coo(2,3, 3.0); + auto d = Coo(1,1, 3.0); - EXPECT_TRUE(a == b); - EXPECT_TRUE(a != c); - EXPECT_TRUE(a < c); - EXPECT_FALSE(a < d); + EXPECT_TRUE(a == b); + EXPECT_TRUE(a != c); + EXPECT_TRUE(a < c); + EXPECT_FALSE(a < d); + } } - TEST_END; -} -TEST( CrsMatrixBase, constructor ) { - TEST_BEGIN; - /// - /// host space crs matrix base - /// - const ordinal_type - m = 4, - n = 4, - nnz = 16; - - CrsMatrixBaseHostType Ah(m, n, nnz); - EXPECT_EQ(Ah.NumRows(), m); - EXPECT_EQ(Ah.NumCols(), n); - EXPECT_EQ(size_t(Ah.NumNonZeros()), size_t(nnz)); - - ordinal_type cnt = 0; - for (ordinal_type i=0;i::read(inputfilename, Ah); - -// std::ofstream out(outputfilename); -// MatrixMarket::write(out, Ah); - -// CrsMatrixBaseHostType Bh; -// MatrixMarket::read(outputfilename, Bh); - -// /// -// /// read and write the matrix and read again, -// /// then check if they are same -// /// -// EXPECT_EQ(Ah.NumRows(), Bh.NumRows()); -// EXPECT_EQ(Ah.NumCols(), Bh.NumCols()); -// EXPECT_EQ(Ah.NumNonZeros(), Bh.NumNonZeros()); - -// const ordinal_type m = Ah.NumRows(); -// for (ordinal_type i=0;i::read(inputfilename, Bh); - Ah.createConfTo(Bh); - - /// - /// device crs matrix - /// - CrsMatrixBaseDeviceType Ad, Bd; - Ad.createMirror(Ah); - Bd.createMirror(Bh); - Bd.copy(Bh); - - /// - /// random permutation vector - /// - const ordinal_type m = Ad.NumRows(); - typedef Kokkos::View ordinal_type_array_host; - ordinal_type_array_host perm("perm", m), peri("peri", m); - - for (int i=0;i::value ? "test_double.mtx" : "test_dcomplex.mtx"; + + /// + /// host crs matrix read from matrix market + /// + crs_matrix_base_type_host Ah, Bh; + MatrixMarket::read(filename, Bh); + Ah.createConfTo(Bh); + + /// + /// device crs matrix + /// + crs_matrix_base_type Ad, Bd; + Ad.createMirror(Ah); + Bd.createMirror(Bh); + Bd.copy(Bh); + + /// + /// random permutation vector + /// + const int m = Ad.NumRows(); + ordinal_type_array_type_host perm("perm", m), peri("peri", m); + + for (int i=0;i - -#include -#include - -static const std::string MM_TEST_FILE="test_dcomplex"; - -#define TEST_BEGIN -#define TEST_END -//#define TEST_BEGIN Kokkos::initialize() -//#define TEST_END Kokkos::finalize() - -#define __TACHO_TEST_CUDA__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef Kokkos::complex ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCuda_double.cpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCuda_double.cpp deleted file mode 100644 index 75659731f2de..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestCuda_double.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include - -#include -#include - -static const std::string MM_TEST_FILE="test_double"; - -#define TEST_BEGIN -#define TEST_END -//#define TEST_BEGIN Kokkos::initialize() -//#define TEST_END Kokkos::finalize() - -#define __TACHO_TEST_CUDA__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef double ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseLinearAlgebra.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseLinearAlgebra.hpp index 5f266548edad..acdb7cd56cb0 100644 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseLinearAlgebra.hpp +++ b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseLinearAlgebra.hpp @@ -9,70 +9,54 @@ #include #include -#include "Tacho_Util.hpp" -#include "Tacho_Blas_Team.hpp" #include "Tacho_Blas_External.hpp" +#include "Tacho_Blas_Team.hpp" +#include "Tacho_Util.hpp" -#include "Tacho_Lapack_Team.hpp" #include "Tacho_Lapack_External.hpp" +#include "Tacho_Lapack_Team.hpp" using namespace Tacho; -typedef Kokkos::DualView matrix_type; +typedef Kokkos::DualView matrix_type; namespace Test { - struct Functor_TeamGemm { - char _transa, _transb; - int _m, _n, _k; - matrix_type _A, _B, _C; - ValueType _alpha, _beta; - - Functor_TeamGemm(const char transa, const char transb, - const int m, const int n, const int k, - const ValueType alpha, - const matrix_type &A, - const matrix_type &B, - const ValueType beta, - const matrix_type &C) - : _transa(transa), _transb(transb), - _m(m), _n(n), _k(k), - _A(A), _B(B), _C(C), - _alpha(alpha), _beta(beta) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - ::BlasTeam::gemm(member, - _transa, _transb, - _m, _n, _k, - _alpha, - (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - (const ValueType*)_B.d_view.data(), (int)_B.d_view.stride_1(), - _beta, - ( ValueType*)_C.d_view.data(), (int)_C.d_view.stride_1()); - } - - inline - void run() { - _A.sync_device(); - _B.sync_device(); - - _C.sync_device(); - _C.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _C.sync_host(); - } - }; -} +struct Functor_TeamGemm { + char _transa, _transb; + int _m, _n, _k; + matrix_type _A, _B, _C; + value_type _alpha, _beta; + + Functor_TeamGemm(const char transa, const char transb, const int m, const int n, const int k, const value_type alpha, + const matrix_type &A, const matrix_type &B, const value_type beta, const matrix_type &C) + : _transa(transa), _transb(transb), _m(m), _n(n), _k(k), _A(A), _B(B), _C(C), _alpha(alpha), _beta(beta) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + ::BlasTeam::gemm(member, _transa, _transb, _m, _n, _k, _alpha, (const value_type *)_A.d_view.data(), + (int)_A.d_view.stride_1(), (const value_type *)_B.d_view.data(), + (int)_B.d_view.stride_1(), _beta, (value_type *)_C.d_view.data(), + (int)_C.d_view.stride_1()); + } + + inline void run() { + _A.sync_device(); + _B.sync_device(); + + _C.sync_device(); + _C.modify_device(); + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _C.sync_host(); + } +}; +} // namespace Test + +TEST(DenseLinearAlgebra, team_gemm_nn) { -TEST( DenseLinearAlgebra, team_gemm_nn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'N', transb = 'N'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; // test problem setup matrix_type A1("A1", m, k), B1("B1", k, n), C1("C1", m, n); @@ -82,44 +66,35 @@ TEST( DenseLinearAlgebra, team_gemm_nn ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - + // tacho test - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - // reference test - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_nt) { -TEST( DenseLinearAlgebra, team_gemm_nt ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'N', transb = 'T'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", m, k), B1("B1", n, k), C1("C1", m, n); matrix_type A2("A2", m, k), B2("B2", n, k), C2("C2", m, n); @@ -129,124 +104,97 @@ TEST( DenseLinearAlgebra, team_gemm_nt ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - // tacho test - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + // tacho test + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - // reference test - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_nc) { -TEST( DenseLinearAlgebra, team_gemm_nc ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'N', transb = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", m, k), B1("B1", n, k), C1("C1", m, n); - matrix_type A2("A2", m, k), B2("B2", n, k), C2("C2", m, n); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + matrix_type A2("A2", m, k), B2("B2", n, k), C2("C2", m, n); + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_tn) { -TEST( DenseLinearAlgebra, team_gemm_tn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'T', transb = 'N'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, m), B1("B1", k, n), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", k, n), C2("C2", m, n); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_tt) { -TEST( DenseLinearAlgebra, team_gemm_tt ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'T', transb = 'T'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, m), B1("B1", n, k), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", n, k), C2("C2", m, n); @@ -255,42 +203,33 @@ TEST( DenseLinearAlgebra, team_gemm_tt ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_tc) { -TEST( DenseLinearAlgebra, team_gemm_tc ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'T', transb = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, m), B1("B1", n, k), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", n, k), C2("C2", m, n); @@ -298,44 +237,35 @@ TEST( DenseLinearAlgebra, team_gemm_tc ) { A1.modify_device(); B1.modify_device(); C1.modify_device(); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_cn) { -TEST( DenseLinearAlgebra, team_gemm_cn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'C', transb = 'N'; - const ValueType alpha = 1.3, beta = 2.5; - + const value_type alpha = 1.3, beta = 2.5; + matrix_type A1("A1", k, m), B1("B1", k, n), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", k, n), C2("C2", m, n); @@ -343,42 +273,33 @@ TEST( DenseLinearAlgebra, team_gemm_cn ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_ct) { -TEST( DenseLinearAlgebra, team_gemm_ct ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'C', transb = 'T'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, m), B1("B1", n, k), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", n, k), C2("C2", m, n); @@ -387,42 +308,33 @@ TEST( DenseLinearAlgebra, team_gemm_ct ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_gemm_cc) { -TEST( DenseLinearAlgebra, team_gemm_cc ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10, k = 15; const char transa = 'C', transb = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, m), B1("B1", n, k), C1("C1", m, n); matrix_type A2("A2", k, m), B2("B2", n, k), C2("C2", m, n); @@ -431,89 +343,65 @@ TEST( DenseLinearAlgebra, team_gemm_cc ) { B1.modify_device(); C1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(B1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(B1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(B2.h_view, B1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - - ::Test::Functor_TeamGemm test(transa, transb, - m, n, k, - alpha, A1, B1, beta, C1); + + ::Test::Functor_TeamGemm test(transa, transb, m, n, k, alpha, A1, B1, beta, C1); test.run(); - - Blas::gemm(transa, transb, - m, n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - B2.h_view.data(), B2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemm(transa, transb, m, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), B2.h_view.data(), + B2.h_view.stride_1(), beta, C2.h_view.data(), C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } namespace Test { - struct Functor_TeamGemv { - char _trans; - int _m, _n; - matrix_type _A, _x, _y; - ValueType _alpha, _beta; - - Functor_TeamGemv(const char trans, - const int m, const int n, - const ValueType alpha, - const matrix_type &A, - const matrix_type &x, - const ValueType beta, - const matrix_type &y) - : _trans(trans), - _m(m), _n(n), - _A(A), _x(x), _y(y), - _alpha(alpha), _beta(beta) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - ::BlasTeam::gemv(member, - _trans, - _m, _n, - _alpha, - (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - (const ValueType*)_x.d_view.data(), (int)_x.d_view.stride_0(), - _beta, - ( ValueType*)_y.d_view.data(), (int)_y.d_view.stride_0()); - } - - inline - void run() { - _A.sync_device(); - _x.sync_device(); - - _y.sync_device(); - _y.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _y.sync_host(); - } - }; -} +struct Functor_TeamGemv { + char _trans; + int _m, _n; + matrix_type _A, _x, _y; + value_type _alpha, _beta; + + Functor_TeamGemv(const char trans, const int m, const int n, const value_type alpha, const matrix_type &A, + const matrix_type &x, const value_type beta, const matrix_type &y) + : _trans(trans), _m(m), _n(n), _A(A), _x(x), _y(y), _alpha(alpha), _beta(beta) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + ::BlasTeam::gemv(member, _trans, _m, _n, _alpha, (const value_type *)_A.d_view.data(), + (int)_A.d_view.stride_1(), (const value_type *)_x.d_view.data(), + (int)_x.d_view.stride_0(), _beta, (value_type *)_y.d_view.data(), + (int)_y.d_view.stride_0()); + } + + inline void run() { + _A.sync_device(); + _x.sync_device(); + + _y.sync_device(); + _y.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + _y.sync_host(); + } +}; +} // namespace Test + +TEST(DenseLinearAlgebra, team_gemv_n) { -TEST( DenseLinearAlgebra, team_gemv_n ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; const char trans = 'N'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", m, n), x1("B1", n, 1), y1("C1", m, 1); matrix_type A2("A2", m, n), x2("B2", n, 1), y2("C2", m, 1); @@ -522,41 +410,32 @@ TEST( DenseLinearAlgebra, team_gemv_n ) { x1.modify_device(); y1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(x1.d_view, random, ValueType(1)); - Kokkos::fill_random(y1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(x1.d_view, random, value_type(1)); + Kokkos::fill_random(y1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(x2.h_view, x1.d_view); Kokkos::deep_copy(y2.h_view, y1.d_view); - - ::Test::Functor_TeamGemv test(trans, - m, n, - alpha, A1, x1, beta, y1); + + ::Test::Functor_TeamGemv test(trans, m, n, alpha, A1, x1, beta, y1); test.run(); - - Blas::gemv(trans, - m, n, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - x2.h_view.data(), x2.h_view.stride_0(), - beta, - y2.h_view.data(), y2.h_view.stride_0()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemv(trans, m, n, alpha, A2.h_view.data(), A2.h_view.stride_1(), x2.h_view.data(), + x2.h_view.stride_0(), beta, y2.h_view.data(), y2.h_view.stride_0()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + EXPECT_NEAR(ats::abs(y1.h_view(i, 0)), ats::abs(y2.h_view(i, 0)), eps); } +TEST(DenseLinearAlgebra, team_gemv_t) { -TEST( DenseLinearAlgebra, team_gemv_t ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; const char trans = 'T'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", m, n), x1("x1", m, 1), y1("y1", n, 1); matrix_type A2("A2", m, n), x2("x2", m, 1), y2("y2", n, 1); @@ -565,41 +444,32 @@ TEST( DenseLinearAlgebra, team_gemv_t ) { x1.modify_device(); y1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(x1.d_view, random, ValueType(1)); - Kokkos::fill_random(y1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(x1.d_view, random, value_type(1)); + Kokkos::fill_random(y1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(x2.h_view, x1.d_view); Kokkos::deep_copy(y2.h_view, y1.d_view); - ::Test::Functor_TeamGemv test(trans, - m, n, - alpha, A1, x1, beta, y1); + ::Test::Functor_TeamGemv test(trans, m, n, alpha, A1, x1, beta, y1); test.run(); - - Blas::gemv(trans, - m, n, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - x2.h_view.data(), x2.h_view.stride_0(), - beta, - y2.h_view.data(), y2.h_view.stride_0()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemv(trans, m, n, alpha, A2.h_view.data(), A2.h_view.stride_1(), x2.h_view.data(), + x2.h_view.stride_0(), beta, y2.h_view.data(), y2.h_view.stride_0()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + EXPECT_NEAR(ats::abs(y1.h_view(i, 0)), ats::abs(y2.h_view(i, 0)), eps); } +TEST(DenseLinearAlgebra, team_gemv_c) { -TEST( DenseLinearAlgebra, team_gemv_c ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; const char trans = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", m, n), x1("x1", m, 1), y1("y1", n, 1); matrix_type A2("A2", m, n), x2("x2", m, 1), y2("y2", n, 1); @@ -608,774 +478,623 @@ TEST( DenseLinearAlgebra, team_gemv_c ) { x1.modify_device(); y1.modify_device(); - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(x1.d_view, random, ValueType(1)); - Kokkos::fill_random(y1.d_view, random, ValueType(1)); - + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(x1.d_view, random, value_type(1)); + Kokkos::fill_random(y1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(x2.h_view, x1.d_view); Kokkos::deep_copy(y2.h_view, y1.d_view); - ::Test::Functor_TeamGemv test(trans, - m, n, - alpha, A1, x1, beta, y1); + ::Test::Functor_TeamGemv test(trans, m, n, alpha, A1, x1, beta, y1); test.run(); - Blas::gemv(trans, - m, n, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - x2.h_view.data(), x2.h_view.stride_0(), - beta, - y2.h_view.data(), y2.h_view.stride_0()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::gemv(trans, m, n, alpha, A2.h_view.data(), A2.h_view.stride_1(), x2.h_view.data(), + x2.h_view.stride_0(), beta, y2.h_view.data(), y2.h_view.stride_0()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + EXPECT_NEAR(ats::abs(y1.h_view(i, 0)), ats::abs(y2.h_view(i, 0)), eps); } namespace Test { - struct Functor_TeamHerk { - char _uplo, _trans; - int _n, _k; - matrix_type _A, _C; - ValueType _alpha, _beta; - - Functor_TeamHerk(const char uplo, const char trans, - const int n, const int k, - const ValueType alpha, - const matrix_type &A, - const ValueType beta, - const matrix_type &C) - : _uplo(uplo), _trans(trans), - _n(n), _k(k), - _A(A), _C(C), - _alpha(alpha), _beta(beta) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - ::BlasTeam::herk(member, - _uplo, _trans, - _n, _k, - _alpha, - (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - _beta, - ( ValueType*)_C.d_view.data(), (int)_C.d_view.stride_1()); - } - - inline - void run() { - _A.sync_device(); - - _C.sync_device(); - _C.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _C.sync_host(); - } - }; -} +struct Functor_TeamHerk { + char _uplo, _trans; + int _n, _k; + matrix_type _A, _C; + value_type _alpha, _beta; + + Functor_TeamHerk(const char uplo, const char trans, const int n, const int k, const value_type alpha, + const matrix_type &A, const value_type beta, const matrix_type &C) + : _uplo(uplo), _trans(trans), _n(n), _k(k), _A(A), _C(C), _alpha(alpha), _beta(beta) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + ::BlasTeam::herk(member, _uplo, _trans, _n, _k, _alpha, (const value_type *)_A.d_view.data(), + (int)_A.d_view.stride_1(), _beta, (value_type *)_C.d_view.data(), + (int)_C.d_view.stride_1()); + } + + inline void run() { + _A.sync_device(); + + _C.sync_device(); + _C.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + _C.sync_host(); + } +}; +} // namespace Test + +TEST(DenseLinearAlgebra, team_herk_un) { -TEST( DenseLinearAlgebra, team_herk_un ) { - TEST_BEGIN; const ordinal_type n = 20, k = 10; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'N'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", n, k), C1("C1", n, n); matrix_type A2("A2", n, k), C2("C2", n, n); A1.modify_device(); C1.modify_device(); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - ::Test::Functor_TeamHerk test(uplo, trans, - n, k, - alpha, A1, beta, C1); + ::Test::Functor_TeamHerk test(uplo, trans, n, k, alpha, A1, beta, C1); test.run(); - Blas::herk(uplo, trans, - n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::herk(uplo, trans, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), beta, C2.h_view.data(), + C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_herk_uc) { -TEST( DenseLinearAlgebra, team_herk_uc ) { - TEST_BEGIN; const ordinal_type n = 20, k = 10; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, n), C1("C1", n, n); matrix_type A2("A2", k, n), C2("C2", n, n); A1.modify_device(); C1.modify_device(); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - ::Test::Functor_TeamHerk test(uplo, trans, - n, k, - alpha, A1, beta, C1); + ::Test::Functor_TeamHerk test(uplo, trans, n, k, alpha, A1, beta, C1); test.run(); - Blas::herk(uplo, trans, - n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::herk(uplo, trans, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), beta, C2.h_view.data(), + C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_herk_ln) { -TEST( DenseLinearAlgebra, team_herk_ln ) { - TEST_BEGIN; const ordinal_type n = 20, k = 10; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'N'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", n, k), C1("C1", n, n); matrix_type A2("A2", n, k), C2("C2", n, n); A1.modify_device(); C1.modify_device(); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - ::Test::Functor_TeamHerk test(uplo, trans, - n, k, - alpha, A1, beta, C1); + ::Test::Functor_TeamHerk test(uplo, trans, n, k, alpha, A1, beta, C1); test.run(); - Blas::herk(uplo, trans, - n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::herk(uplo, trans, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), beta, C2.h_view.data(), + C2.h_view.stride_1()); + + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); } +TEST(DenseLinearAlgebra, team_herk_lc) { -TEST( DenseLinearAlgebra, team_herk_lc ) { - TEST_BEGIN; const ordinal_type n = 20, k = 10; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'C'; - const ValueType alpha = 1.3, beta = 2.5; + const value_type alpha = 1.3, beta = 2.5; matrix_type A1("A1", k, n), C1("C1", n, n); matrix_type A2("A2", k, n), C2("C2", n, n); A1.modify_device(); C1.modify_device(); - - Kokkos::Random_XorShift64_Pool random(13718); - - Kokkos::fill_random(A1.d_view, random, ValueType(1)); - Kokkos::fill_random(C1.d_view, random, ValueType(1)); - + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, value_type(1)); + Kokkos::fill_random(C1.d_view, random, value_type(1)); + Kokkos::deep_copy(A2.h_view, A1.d_view); Kokkos::deep_copy(C2.h_view, C1.d_view); - ::Test::Functor_TeamHerk test(uplo, trans, - n, k, - alpha, A1, beta, C1); + ::Test::Functor_TeamHerk test(uplo, trans, n, k, alpha, A1, beta, C1); test.run(); - Blas::herk(uplo, trans, - n, k, - alpha, - A2.h_view.data(), A2.h_view.stride_1(), - beta, - C2.h_view.data(), C2.h_view.stride_1()); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::herk(uplo, trans, n, k, alpha, A2.h_view.data(), A2.h_view.stride_1(), beta, C2.h_view.data(), + C2.h_view.stride_1()); + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < n; ++i) + for (int j = 0; j < n; ++j) + EXPECT_NEAR(ats::abs(C1.h_view(i, j)), ats::abs(C2.h_view(i, j)), eps); +} namespace Test { - struct Functor_TeamTrsv { - char _uplo, _trans, _diag; - int _m; - matrix_type _A, _b; - - Functor_TeamTrsv(const char uplo, const char trans, const char diag, - const int m, - const matrix_type &A, - const matrix_type &b) - : _uplo(uplo), _trans(trans), _diag(diag), - _m(m), - _A(A), _b(b) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - ::BlasTeam::trsv(member, - _uplo, _trans, _diag, - _m, - (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - ( ValueType*)_b.d_view.data(), (int)_b.d_view.stride_0()); - } - - inline - void run() { - _A.sync_device(); - - _b.sync_device(); - _b.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _b.sync_host(); - } - }; -} +struct Functor_TeamTrsv { + char _uplo, _trans, _diag; + int _m; + matrix_type _A, _b; + + Functor_TeamTrsv(const char uplo, const char trans, const char diag, const int m, const matrix_type &A, + const matrix_type &b) + : _uplo(uplo), _trans(trans), _diag(diag), _m(m), _A(A), _b(b) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + ::BlasTeam::trsv(member, _uplo, _trans, _diag, _m, (const value_type *)_A.d_view.data(), + (int)_A.d_view.stride_1(), (value_type *)_b.d_view.data(), (int)_b.d_view.stride_0()); + } + + inline void run() { + _A.sync_device(); + + _b.sync_device(); + _b.modify_device(); + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); -#define TEAM_TRSV_TEST_BODY { \ - matrix_type A1("A1", m, m), b1("b1", m, 1); \ - matrix_type A2("A2", m, m), b2("b2", m, 1); \ - \ - A1.modify_device(); \ - b1.modify_device(); \ - \ - Kokkos::Random_XorShift64_Pool random(13718); \ - \ - Kokkos::fill_random(A1.d_view, random, ValueType(1)); \ - Kokkos::fill_random(b1.d_view, random, ValueType(1)); \ - \ - Kokkos::deep_copy(A2.h_view, A1.d_view); \ - Kokkos::deep_copy(b2.h_view, b1.d_view); \ - \ - ::Test::Functor_TeamTrsv test(uplo, trans, diag, \ - m, \ - A1, b1); \ - test.run(); \ - \ - Blas::trsv(uplo, trans, diag, \ - m, \ - A2.h_view.data(), A2.h_view.stride_1(), \ - b2.h_view.data(), b2.h_view.stride_0()); \ - \ - const MagnitudeType eps = std::numeric_limits::epsilon() * 10000; \ - for (int i=0;i random(13718); \ + \ + Kokkos::fill_random(A1.d_view, random, value_type(1)); \ + Kokkos::fill_random(b1.d_view, random, value_type(1)); \ + \ + Kokkos::deep_copy(A2.h_view, A1.d_view); \ + Kokkos::deep_copy(b2.h_view, b1.d_view); \ + \ + ::Test::Functor_TeamTrsv test(uplo, trans, diag, m, A1, b1); \ + test.run(); \ + \ + Blas::trsv(uplo, trans, diag, m, A2.h_view.data(), A2.h_view.stride_1(), b2.h_view.data(), \ + b2.h_view.stride_0()); \ + \ + const magnitude_type eps = std::numeric_limits::epsilon() * 10000; \ + for (int i = 0; i < m; ++i) \ + EXPECT_NEAR(ats::abs(b1.h_view(i, 0)), ats::abs(b2.h_view(i, 0)), eps *ats::abs(b2.h_view(i, 0))); \ + } + +TEST(DenseLinearAlgebra, team_trsv_unu) { -TEST( DenseLinearAlgebra, team_trsv_unu ) { - TEST_BEGIN; const ordinal_type m = 4; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'N'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsv_unn) { -TEST( DenseLinearAlgebra, team_trsv_unn ) { - TEST_BEGIN; const ordinal_type m = 20; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'N'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsv_utu) { -TEST( DenseLinearAlgebra, team_trsv_utu ) { - TEST_BEGIN; const ordinal_type m = 20; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'T'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_utn ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_utn) { + const ordinal_type m = 20; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'T'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_ucu ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_ucu) { + const ordinal_type m = 20; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'C'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_ucn ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_ucn) { + const ordinal_type m = 20; - const char uplo = 'U'; + const char uplo = 'U'; const char trans = 'C'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsv_lnu) { - - - -TEST( DenseLinearAlgebra, team_trsv_lnu ) { - TEST_BEGIN; const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'N'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsv_lnn) { -TEST( DenseLinearAlgebra, team_trsv_lnn ) { - TEST_BEGIN; const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'N'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsv_ltu) { -TEST( DenseLinearAlgebra, team_trsv_ltu ) { - TEST_BEGIN; const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'T'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_ltn ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_ltn) { + const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'T'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_lcu ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_lcu) { + const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'C'; - const char diag = 'U'; + const char diag = 'U'; TEAM_TRSV_TEST_BODY; - TEST_END; } -TEST( DenseLinearAlgebra, team_trsv_lcn ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsv_lcn) { + const ordinal_type m = 20; - const char uplo = 'L'; + const char uplo = 'L'; const char trans = 'C'; - const char diag = 'N'; + const char diag = 'N'; TEAM_TRSV_TEST_BODY; - TEST_END; } #undef TEAM_TRSV_TEST_BODY - namespace Test { - struct Functor_TeamTrsm { - char _side, _uplo, _trans, _diag; - int _m, _n; - matrix_type _A, _B; - ValueType _alpha; - - Functor_TeamTrsm(const char side, const char uplo, const char trans, const char diag, - const int m, const int n, - const ValueType alpha, - const matrix_type &A, - const matrix_type &B) - : _side(side), _uplo(uplo), _trans(trans), _diag(diag), - _m(m), _n(n), - _A(A), _B(B), - _alpha(alpha) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - ::BlasTeam::trsm(member, - _side, _uplo, _trans, _diag, - _m, _n, - _alpha, - (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - ( ValueType*)_B.d_view.data(), (int)_B.d_view.stride_1()); - } - - inline - void run() { - _A.sync_device(); - - _B.sync_device(); - _B.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _B.sync_host(); - } - }; -} +struct Functor_TeamTrsm { + char _side, _uplo, _trans, _diag; + int _m, _n; + matrix_type _A, _B; + value_type _alpha; + + Functor_TeamTrsm(const char side, const char uplo, const char trans, const char diag, const int m, const int n, + const value_type alpha, const matrix_type &A, const matrix_type &B) + : _side(side), _uplo(uplo), _trans(trans), _diag(diag), _m(m), _n(n), _A(A), _B(B), _alpha(alpha) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + ::BlasTeam::trsm(member, _side, _uplo, _trans, _diag, _m, _n, _alpha, + (const value_type *)_A.d_view.data(), (int)_A.d_view.stride_1(), + (value_type *)_B.d_view.data(), (int)_B.d_view.stride_1()); + } + + inline void run() { + _A.sync_device(); + + _B.sync_device(); + _B.modify_device(); -#define TEAM_TRSM_TEST_BODY { \ - matrix_type A1("A1", m, m), B1("b1", m, n); \ - matrix_type A2("A2", m, m), B2("b2", m, n); \ - \ - A1.modify_device(); \ - B1.modify_device(); \ - \ - Kokkos::Random_XorShift64_Pool random(13718); \ - \ - Kokkos::fill_random(A1.d_view, random, ValueType(1)); \ - Kokkos::fill_random(B1.d_view, random, ValueType(1)); \ - \ - Kokkos::deep_copy(A2.h_view, A1.d_view); \ - Kokkos::deep_copy(B2.h_view, B1.d_view); \ - \ - ::Test::Functor_TeamTrsm test(side, uplo, trans, diag, \ - m, n, \ - alpha, \ - A1, B1); \ - test.run(); \ - \ - Blas::trsm(side, uplo, trans, diag, \ - m, n, \ - alpha, \ - A2.h_view.data(), A2.h_view.stride_1(), \ - B2.h_view.data(), B2.h_view.stride_1()); \ - \ - const MagnitudeType eps = std::numeric_limits::epsilon() * 100000; \ - for (int i=0;i(1, Kokkos::AUTO), *this); + + _B.sync_host(); + } +}; +} // namespace Test + +#define TEAM_TRSM_TEST_BODY \ + { \ + matrix_type A1("A1", m, m), B1("b1", m, n); \ + matrix_type A2("A2", m, m), B2("b2", m, n); \ + \ + A1.modify_device(); \ + B1.modify_device(); \ + \ + Kokkos::Random_XorShift64_Pool random(13718); \ + \ + Kokkos::fill_random(A1.d_view, random, value_type(1)); \ + Kokkos::fill_random(B1.d_view, random, value_type(1)); \ + \ + Kokkos::deep_copy(A2.h_view, A1.d_view); \ + Kokkos::deep_copy(B2.h_view, B1.d_view); \ + \ + ::Test::Functor_TeamTrsm test(side, uplo, trans, diag, m, n, alpha, A1, B1); \ + test.run(); \ + \ + Blas::trsm(side, uplo, trans, diag, m, n, alpha, A2.h_view.data(), A2.h_view.stride_1(), \ + B2.h_view.data(), B2.h_view.stride_1()); \ + \ + const magnitude_type eps = std::numeric_limits::epsilon() * 100000; \ + for (int i = 0; i < m; ++i) \ + for (int j = 0; j < n; ++j) \ + EXPECT_NEAR(ats::abs(B1.h_view(i, j)), ats::abs(B2.h_view(i, j)), eps *ats::abs(B2.h_view(i, j))); \ } -TEST( DenseLinearAlgebra, team_trsm_lunu ) { - TEST_BEGIN; +TEST(DenseLinearAlgebra, team_trsm_lunu) { + const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'N'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lunn) { -TEST( DenseLinearAlgebra, team_trsm_lunn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'N'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lutu) { -TEST( DenseLinearAlgebra, team_trsm_lutu ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'T'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lutn) { -TEST( DenseLinearAlgebra, team_trsm_lutn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'T'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lucu) { -TEST( DenseLinearAlgebra, team_trsm_lucu ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'C'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lucn) { -TEST( DenseLinearAlgebra, team_trsm_lucn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'U'; + const char side = 'L'; + const char uplo = 'U'; const char trans = 'C'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_llnu) { -TEST( DenseLinearAlgebra, team_trsm_llnu ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'N'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_llnn) { -TEST( DenseLinearAlgebra, team_trsm_llnn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'N'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lltu) { -TEST( DenseLinearAlgebra, team_trsm_lltu ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'T'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_lltn) { -TEST( DenseLinearAlgebra, team_trsm_lltn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'T'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_llcu) { -TEST( DenseLinearAlgebra, team_trsm_llcu ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'C'; - const char diag = 'U'; - const ValueType alpha = 1.2; + const char diag = 'U'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } +TEST(DenseLinearAlgebra, team_trsm_llcn) { -TEST( DenseLinearAlgebra, team_trsm_llcn ) { - TEST_BEGIN; const ordinal_type m = 20, n = 10; - const char side = 'L'; - const char uplo = 'L'; + const char side = 'L'; + const char uplo = 'L'; const char trans = 'C'; - const char diag = 'N'; - const ValueType alpha = 1.2; + const char diag = 'N'; + const value_type alpha = 1.2; TEAM_TRSM_TEST_BODY; - TEST_END; } #undef TEAM_TRSM_TEST_BODY namespace Test { - struct Functor_TeamChol { - char _uplo; - int _m; - matrix_type _A; - - Functor_TeamChol(const char uplo, - const int m, - const matrix_type &A) - : _uplo(uplo), - _m(m), - _A(A) {} - - template - KOKKOS_INLINE_FUNCTION - void operator()(const MemberType &member) const { - int r_val = 0; - ::LapackTeam::potrf(member, - _uplo, - _m, - (ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), - &r_val); - } - - inline - void run() { - _A.sync_device(); - _A.modify_device(); - - Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); - - _A.sync_host(); - } - }; -} +struct Functor_TeamChol { + char _uplo; + int _m; + matrix_type _A; + + Functor_TeamChol(const char uplo, const int m, const matrix_type &A) : _uplo(uplo), _m(m), _A(A) {} + + template KOKKOS_INLINE_FUNCTION void operator()(const MemberType &member) const { + int r_val = 0; + ::LapackTeam::potrf(member, _uplo, _m, (value_type *)_A.d_view.data(), (int)_A.d_view.stride_1(), + &r_val); + } + + inline void run() { + _A.sync_device(); + _A.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _A.sync_host(); + } +}; +} // namespace Test +TEST(DenseLinearAlgebra, team_chol_u) { -TEST( DenseLinearAlgebra, team_chol_u ) { - TEST_BEGIN; const ordinal_type m = 20; const char uplo = 'U'; - + matrix_type A1("A1", m, m); matrix_type A2("A2", m, m); A2.modify_host(); - - for (int i=0;i::potrf(uplo, - m, - (ValueType*)A2.h_view.data(), (int)A2.h_view.stride_1(), - &r_val); - - const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; - for (int i=0;i::potrf(uplo, m, (value_type *)A2.h_view.data(), (int)A2.h_view.stride_1(), &r_val); + const magnitude_type eps = std::numeric_limits::epsilon() * 1000; + for (int i = 0; i < m; ++i) + for (int j = 0; j < m; ++j) + EXPECT_NEAR(ats::abs(A1.h_view(i, j)), ats::abs(A2.h_view(i, j)), eps); +} #endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestGraph.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestGraph.hpp deleted file mode 100644 index 9e05e32bccce..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestGraph.hpp +++ /dev/null @@ -1,127 +0,0 @@ -#ifndef __TACHO_TEST_GRAPH_HPP__ -#define __TACHO_TEST_GRAPH_HPP__ - -#include - -#include -#include - -#include "Tacho_Util.hpp" -#include "Tacho_CrsMatrixBase.hpp" -#include "Tacho_MatrixMarket.hpp" - -#include "Tacho_Graph.hpp" - -#if defined(TACHO_HAVE_SCOTCH) -#include "Tacho_GraphTools_Scotch.hpp" -#endif - -#if defined(TACHO_HAVE_METIS) -#include "Tacho_GraphTools_Metis.hpp" -#endif - -using namespace Tacho; - -typedef CrsMatrixBase CrsMatrixBaseHostType; -//typedef CrsMatrixBase CrsMatrixBaseDeviceType; - -TEST( Graph, constructor ) { - TEST_BEGIN; - /// - /// host space crs matrix base - /// - const ordinal_type - m = 4, - n = 4, - nnz = 16; - - CrsMatrixBaseHostType Ah(m, n, nnz); - - ordinal_type cnt = 0; - for (ordinal_type i=0;i::read(inputfilename, Ah); - - Graph G(Ah); - GraphTools_Scotch S(G); - - const ordinal_type m = G.NumRows(); - S.setTreeLevel(log2(m)); - S.setStrategy( SCOTCH_STRATSPEED - | SCOTCH_STRATSPEED - | SCOTCH_STRATLEVELMAX - | SCOTCH_STRATLEVELMIN - | SCOTCH_STRATLEAFSIMPLE - | SCOTCH_STRATSEPASIMPLE - ); - S.reorder(); - - /// - /// perm and invperm should be properly setup - /// - const auto perm = S.PermVector(); - const auto peri = S.InvPermVector(); - - for (ordinal_type i=0;i::read(inputfilename, Ah); - - Graph G(Ah); - GraphTools_Metis M(G); - - M.reorder(); - - /// - /// perm and invperm should be properly setup - /// - const auto perm = M.PermVector(); - const auto peri = M.InvPermVector(); - - const ordinal_type m = G.NumRows(); - for (ordinal_type i=0;i + +#include +#include + +#include "Tacho_Util.hpp" +#include "Tacho_CrsMatrixBase.hpp" +#include "Tacho_MatrixMarket.hpp" + +#include "Tacho_Graph.hpp" + +#if defined(TACHO_HAVE_METIS) +#include "Tacho_GraphTools_Metis.hpp" +#endif + +namespace Test { + + TEST( Graph, constructor ) { + + /// + /// host space crs matrix base + /// + const ordinal_type + m = 4, + n = 4, + nnz = 16; + + crs_matrix_base_type_host Ah(m, n, nnz); + + ordinal_type cnt = 0; + for (ordinal_type i=0;i::read(filename, Ah); + + Graph G(Ah); + GraphTools_Metis M(G); + + M.reorder(); + + /// + /// perm and invperm should be properly setup + /// + const auto perm = M.PermVector(); + const auto peri = M.InvPermVector(); + + const ordinal_type m = G.NumRows(); + for (ordinal_type i=0;i - -#include -#include - -static const std::string MM_TEST_FILE="test_dcomplex"; - -#define TEST_BEGIN -#define TEST_END - -#define __TACHO_TEST_OPENMP__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef Kokkos::complex ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} - diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestOpenMP_double.cpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestOpenMP_double.cpp deleted file mode 100644 index 2762908c5fef..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestOpenMP_double.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include - -#include -#include - -static const std::string MM_TEST_FILE="test_double"; - -#define TEST_BEGIN -#define TEST_END - -#define __TACHO_TEST_OPENMP__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef double ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_dcomplex.cpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_dcomplex.cpp deleted file mode 100644 index a4802f1661e0..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_dcomplex.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include - -#include -#include - -static const std::string MM_TEST_FILE="test_dcomplex"; - -#define TEST_BEGIN -#define TEST_END - -#define __TACHO_TEST_SERIAL__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef Kokkos::complex ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} - diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_double.cpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_double.cpp deleted file mode 100644 index 9735bfd956c7..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSerial_double.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include - -#include -#include - -static const std::string MM_TEST_FILE="test_double"; - -#define TEST_BEGIN -#define TEST_END - -#define __TACHO_TEST_SERIAL__ -#include "Tacho_config.h" -#include "Tacho_Util.hpp" - -typedef typename Tacho::UseThisDevice::type HostDeviceType; -typedef typename Tacho::UseThisDevice::type DeviceType; - -typedef double ValueType; -typedef double MagnitudeType; - -#include "Tacho_Test.hpp" - -using namespace Tacho; - -int main (int argc, char *argv[]) { - - Kokkos::initialize(argc, argv); - - TEST_BEGIN; - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); - - TEST_END; - - ::testing::InitGoogleTest(&argc, argv); - const int r_val = RUN_ALL_TESTS(); - - Kokkos::finalize(); - - return r_val; -} diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolic.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolic.hpp deleted file mode 100644 index 2f6952193674..000000000000 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolic.hpp +++ /dev/null @@ -1,186 +0,0 @@ -#ifndef __TACHO_TEST_SYMBOLIC_HPP__ -#define __TACHO_TEST_SYMBOLIC_HPP__ - -#include - -#include -#include - -#include "Tacho_Util.hpp" -#include "Tacho_CrsMatrixBase.hpp" -#include "Tacho_MatrixMarket.hpp" - -#include "Tacho_Graph.hpp" -#include "Tacho_SymbolicTools.hpp" - -#include "Tacho_GraphTools.hpp" - -#if defined(TACHO_HAVE_SCOTCH) -#include "Tacho_GraphTools_Scotch.hpp" -#endif - -#if defined(TACHO_HAVE_METIS) -#include "Tacho_GraphTools_Metis.hpp" -#endif - -using namespace Tacho; - -typedef CrsMatrixBase CrsMatrixBaseHostType; - -// we do not test for device space -//typedef CrsMatrixBase CrsMatrixBaseDeviceType; - -TEST( Symbolic, constructor ) { - TEST_BEGIN; - const ordinal_type - m = 4, - n = 4, - nnz = 16; - - CrsMatrixBaseHostType A(m, n, nnz); - - ordinal_type cnt = 0; - for (ordinal_type i=0;i ordinal_type_array; - - ordinal_type_array idx("idx", m); - for (ordinal_type i=0;i::read(inputfilename, A); - - Graph G(A); - -#if defined(TACHO_HAVE_METIS) - GraphTools_Metis T(G); -#elif defined(TACHO_HAVE_SCOTCH) - GraphTools_Scotch T(G); -#else - GraphTools T(G); -#endif - T.reorder(); - - typedef Kokkos::View ordinal_type_array; - typedef Kokkos::View size_type_array; - - ordinal_type m = A.NumRows(); - size_type_array ap = A.RowPtr(); - ordinal_type_array - aj = A.Cols(), - perm = T.PermVector(), - peri = T.InvPermVector(), - parent("parent", m), - ancestor("ancestor", m); - - SymbolicTools::computeEliminationTree(m, ap, aj, perm, peri, parent, ancestor); - - ordinal_type_array work("work", m*4); - - typedef Kokkos::pair range_type; - auto post = Kokkos::subview(work, range_type(0*m, 1*m)); - auto w = Kokkos::subview(work, range_type(1*m, 4*m)); - SymbolicTools::computePostOrdering(m, parent, post, w); - - size_type_array up; - ordinal_type_array uj; - SymbolicTools::computeFillPatternUpper(m, ap, aj, perm, peri, up, uj, work); - - ordinal_type_array supernodes; - SymbolicTools::computeSupernodes(m, ap, aj, perm, peri, parent, supernodes, work); - - // allocate supernodes - size_type_array gid_super_panel_ptr, sid_super_panel_ptr; - ordinal_type_array gid_super_panel_colidx, sid_super_panel_colidx, blk_super_panel_colidx; - SymbolicTools::allocateSupernodes(m, up, uj, supernodes, work, - gid_super_panel_ptr, - gid_super_panel_colidx, - sid_super_panel_ptr, - sid_super_panel_colidx, - blk_super_panel_colidx); - - size_type_array stree_ptr; - ordinal_type_array stree_level, stree_parent, stree_children, stree_roots; - SymbolicTools::computeSupernodesAssemblyTree(parent, - supernodes, - stree_level, - stree_parent, - stree_ptr, - stree_children, - stree_roots, - work); - - // const size_type numSupernodes = supernodes.extent(0) - 1; - // printf("supernodes = \n"); - // for (size_type i=0;i::read(inputfilename, A); - - Graph G(A); - -#if defined(TACHO_HAVE_METIS) - GraphTools_Metis T(G); -#elif defined(TACHO_HAVE_SCOTCH) - GraphTools_Scotch T(G); -#else - GraphTools T(G); -#endif - T.reorder(); - - { - SymbolicTools S(A.NumRows(), - A.RowPtr(), - A.Cols(), - T.PermVector(), - T.InvPermVector()); - - S.symbolicFactorize(); - } - { - SymbolicTools S(A, T); - S.symbolicFactorize(); - } - TEST_END; -} - - -#endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolicTools.hpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolicTools.hpp new file mode 100644 index 000000000000..96c139db612a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestSymbolicTools.hpp @@ -0,0 +1,163 @@ +#ifndef __TACHO_TEST_SYMBOLIC_TOOLS_HPP__ +#define __TACHO_TEST_SYMBOLIC_TOOLS_HPP__ + +#include + +#include +#include + +#include "Tacho_Util.hpp" +#include "Tacho_CrsMatrixBase.hpp" +#include "Tacho_MatrixMarket.hpp" + +#include "Tacho_Graph.hpp" +#include "Tacho_SymbolicTools.hpp" + +#include "Tacho_GraphTools.hpp" + +#if defined(TACHO_HAVE_METIS) +#include "Tacho_GraphTools_Metis.hpp" +#endif + +namespace Test { + + TEST( Symbolic, constructor ) { + const ordinal_type + m = 4, + n = 4, + nnz = 16; + + crs_matrix_base_type_host A(m, n, nnz); + + ordinal_type cnt = 0; + for (ordinal_type i=0;i; + auto post = Kokkos::subview(work, range_type(0*m, 1*m)); + auto w = Kokkos::subview(work, range_type(1*m, 4*m)); + SymbolicTools::computePostOrdering(m, parent, post, w); + + size_type_array_type_host up; + ordinal_type_array_type_host uj; + SymbolicTools::computeFillPatternUpper(m, ap, aj, perm, peri, up, uj, work); + + ordinal_type_array_type_host supernodes; + SymbolicTools::computeSupernodes(m, ap, aj, perm, peri, parent, supernodes, work); + + // allocate supernodes + size_type_array_type_host gid_super_panel_ptr, sid_super_panel_ptr; + ordinal_type_array_type_host gid_super_panel_colidx, sid_super_panel_colidx, blk_super_panel_colidx; + SymbolicTools::allocateSupernodes(m, up, uj, supernodes, work, + gid_super_panel_ptr, + gid_super_panel_colidx, + sid_super_panel_ptr, + sid_super_panel_colidx, + blk_super_panel_colidx); + + size_type_array_type_host stree_ptr; + ordinal_type_array_type_host stree_level, stree_parent, stree_children, stree_roots; + SymbolicTools::computeSupernodesAssemblyTree(parent, + supernodes, + stree_level, + stree_parent, + stree_ptr, + stree_children, + stree_roots, + work); + + // const size_type numSupernodes = supernodes.extent(0) - 1; + // printf("supernodes = \n"); + // for (size_type i=0;i::read(filename, A); + + Graph G(A); + +#if defined(TACHO_HAVE_METIS) + GraphTools_Metis T(G); +#else + GraphTools T(G); +#endif + T.reorder(); + + { + SymbolicTools S(A.NumRows(), + A.RowPtr(), + A.Cols(), + T.PermVector(), + T.InvPermVector()); + + S.symbolicFactorize(); + } + { + SymbolicTools S(A, T); + S.symbolicFactorize(); + } + } + +} +#endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestUtil.cpp b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestUtil.cpp index e485a8d79924..de03c86d0ce7 100644 --- a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestUtil.cpp +++ b/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestUtil.cpp @@ -2,97 +2,68 @@ #include #include "Tacho.hpp" -#include "Tacho_MatrixMarket.hpp" #include "Tacho_Util.hpp" +#include "Tacho_MatrixMarket.hpp" -using namespace Tacho; - -typedef typename UseThisDevice::type HostDeviceType; -typedef typename UseThisDevice::type DeviceType; - -#define TEST_BEGIN -#define TEST_END - -template using TaskSchedulerType = Kokkos::TaskSchedulerMultiple; -static const char * scheduler_name = "TaskSchedulerMultiple"; +using host_device_type = typename Tacho::UseThisDevice::type; +using device_type = typename Tacho::UseThisDevice::type; TEST( Util, is_complex_type ) { - TEST_BEGIN; - EXPECT_FALSE(int(ArithTraits::is_complex)); - EXPECT_TRUE(int(ArithTraits >::is_complex)); - EXPECT_TRUE(int(ArithTraits >::is_complex)); - TEST_END; + EXPECT_FALSE(int(Tacho::ArithTraits::is_complex)); + EXPECT_TRUE(int(Tacho::ArithTraits >::is_complex)); + EXPECT_TRUE(int(Tacho::ArithTraits >::is_complex)); } -TEST( util, tag ) { - TEST_BEGIN; - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - // EXPECT_TRUE(is_valid_partition_tag::value); - - EXPECT_TRUE(int(is_valid_uplo_tag::value)); - EXPECT_TRUE(int(is_valid_uplo_tag::value)); - - EXPECT_TRUE(int(is_valid_side_tag::value)); - EXPECT_TRUE(int(is_valid_side_tag::value)); - - EXPECT_TRUE(int(is_valid_diag_tag::value)); - EXPECT_TRUE(int(is_valid_diag_tag::value)); - - EXPECT_TRUE(int(is_valid_trans_tag::value)); - EXPECT_TRUE(int(is_valid_trans_tag::value)); - EXPECT_TRUE(int(is_valid_trans_tag::value)); - - // EXPECT_FALSE(is_valid_partition_tag::value); - EXPECT_FALSE(int(is_valid_uplo_tag::value)); - EXPECT_FALSE(int(is_valid_side_tag::value)); - EXPECT_FALSE(int(is_valid_diag_tag::value)); - EXPECT_FALSE(int(is_valid_trans_tag::value)); - TEST_END; +TEST( Util, tag ) { + using Tacho::NullTag; + using Tacho::Partition; + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_partition_tag::value)); + + using Tacho::Uplo; + EXPECT_TRUE(int(Tacho::is_valid_uplo_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_uplo_tag::value)); + EXPECT_FALSE(int(Tacho::is_valid_uplo_tag::value)); + + using Tacho::Side; + EXPECT_TRUE(int(Tacho::is_valid_side_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_side_tag::value)); + EXPECT_FALSE(int(Tacho::is_valid_side_tag::value)); + + using Tacho::Diag; + EXPECT_TRUE(int(Tacho::is_valid_diag_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_diag_tag::value)); + EXPECT_FALSE(int(Tacho::is_valid_diag_tag::value)); + + using Tacho::Trans; + EXPECT_TRUE(int(Tacho::is_valid_trans_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_trans_tag::value)); + EXPECT_TRUE(int(Tacho::is_valid_trans_tag::value)); + EXPECT_FALSE(int(Tacho::is_valid_trans_tag::value)); } -TEST( util, task_scheduler ) { - TEST_BEGIN; - - size_t span = 1024*8; // 100 - unsigned int min_block_size = 8; // 10 - unsigned int max_block_size = 1024; // 10 - unsigned int superblock_size = 1024; // 10 - - typedef TaskSchedulerType host_scheduler_type; - host_scheduler_type host_sched(typename host_scheduler_type::memory_space(), - span, - min_block_size, - max_block_size, - superblock_size); - typedef TaskSchedulerType device_scheduler_type; - device_scheduler_type device_sched(typename device_scheduler_type::memory_space(), - span, - min_block_size, - max_block_size, - superblock_size); - TEST_END; -} +/// TODO:: add task scheduler instanciation test int main (int argc, char *argv[]) { + int r_val(0); Kokkos::initialize(argc, argv); - - const bool detail = false; - printExecSpaceConfiguration("DeviceSpace", detail); - printExecSpaceConfiguration("HostSpace", detail); + { + const bool detail = false; + Tacho::printExecSpaceConfiguration("DeviceSpace", detail); + Tacho::printExecSpaceConfiguration("HostSpace", detail); - printf("Scheduler Type = %s\n", scheduler_name); - - ::testing::InitGoogleTest(&argc, argv); + ::testing::InitGoogleTest(&argc, argv); - int result = RUN_ALL_TESTS(); + r_val = RUN_ALL_TESTS(); + } Kokkos::finalize(); - return result; + return r_val; } diff --git a/packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseByBlocks.hpp b/packages/shylu/shylu_node/tacho/unit-test/do-not-test-yet/Tacho_TestDenseByBlocks.hpp similarity index 100% rename from packages/shylu/shylu_node/tacho/unit-test/Tacho_TestDenseByBlocks.hpp rename to packages/shylu/shylu_node/tacho/unit-test/do-not-test-yet/Tacho_TestDenseByBlocks.hpp diff --git a/packages/shylu/shylu_node/tacho/unit-test/do-not-test-yet/Tacho_TestDenseLinearAlgebra.hpp b/packages/shylu/shylu_node/tacho/unit-test/do-not-test-yet/Tacho_TestDenseLinearAlgebra.hpp new file mode 100644 index 000000000000..73b8670cd621 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/do-not-test-yet/Tacho_TestDenseLinearAlgebra.hpp @@ -0,0 +1,1383 @@ +#ifndef __TACHO_TEST_DENSE_LINEAR_ALGEBRA_HPP__ +#define __TACHO_TEST_DENSE_LINEAR_ALGEBRA_HPP__ + +#include + +#include +#include + +#include +#include + +#include "Tacho_Util.hpp" +#include "Tacho_Blas_Team.hpp" +#include "Tacho_Blas_External.hpp" + +#include "Tacho_Lapack_Team.hpp" +#include "Tacho_Lapack_External.hpp" + +using namespace Tacho; +using std::abs; +using Kokkos::abs; + +typedef Kokkos::DualView matrix_type; + +namespace Test { + struct Functor_TeamGemm { + char _transa, _transb; + int _m, _n, _k; + matrix_type _A, _B, _C; + ValueType _alpha, _beta; + + Functor_TeamGemm(const char transa, const char transb, + const int m, const int n, const int k, + const ValueType alpha, + const matrix_type &A, + const matrix_type &B, + const ValueType beta, + const matrix_type &C) + : _transa(transa), _transb(transb), + _m(m), _n(n), _k(k), + _A(A), _B(B), _C(C), + _alpha(alpha), _beta(beta) {} + + template + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + ::BlasTeam::gemm(member, + _transa, _transb, + _m, _n, _k, + _alpha, + (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + (const ValueType*)_B.d_view.data(), (int)_B.d_view.stride_1(), + _beta, + ( ValueType*)_C.d_view.data(), (int)_C.d_view.stride_1()); + } + + inline + void run() { + _A.sync_device(); + _B.sync_device(); + + _C.sync_device(); + _C.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _C.sync_host(); + } + }; +} + + +TEST( DenseLinearAlgebra, team_gemm_nn ) { + TEST_BEGIN; + const ordinal_type m = 20, n = 10, k = 15; + const char transa = 'N', transb = 'N'; + const ValueType alpha = 1.3, beta = 2.5; + + // test problem setup + matrix_type A1("A1", m, k), B1("B1", k, n), C1("C1", m, n); + matrix_type A2("A2", m, k), B2("B2", k, n), C2("C2", m, n); + + A1.modify_device(); + B1.modify_device(); + C1.modify_device(); + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + // tacho test + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + // reference test + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + // tacho test + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + // reference test + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(B1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(B2.h_view, B1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamGemm test(transa, transb, + m, n, k, + alpha, A1, B1, beta, C1); + test.run(); + + Blas::gemm(transa, transb, + m, n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + B2.h_view.data(), B2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + ::BlasTeam::gemv(member, + _trans, + _m, _n, + _alpha, + (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + (const ValueType*)_x.d_view.data(), (int)_x.d_view.stride_0(), + _beta, + ( ValueType*)_y.d_view.data(), (int)_y.d_view.stride_0()); + } + + inline + void run() { + _A.sync_device(); + _x.sync_device(); + + _y.sync_device(); + _y.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _y.sync_host(); + } + }; +} + + +TEST( DenseLinearAlgebra, team_gemv_n ) { + TEST_BEGIN; + const ordinal_type m = 20, n = 10; + const char trans = 'N'; + const ValueType alpha = 1.3, beta = 2.5; + + matrix_type A1("A1", m, n), x1("B1", n, 1), y1("C1", m, 1); + matrix_type A2("A2", m, n), x2("B2", n, 1), y2("C2", m, 1); + + A1.modify_device(); + x1.modify_device(); + y1.modify_device(); + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(x1.d_view, random, ValueType(1)); + Kokkos::fill_random(y1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(x2.h_view, x1.d_view); + Kokkos::deep_copy(y2.h_view, y1.d_view); + + ::Test::Functor_TeamGemv test(trans, + m, n, + alpha, A1, x1, beta, y1); + test.run(); + + Blas::gemv(trans, + m, n, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + x2.h_view.data(), x2.h_view.stride_0(), + beta, + y2.h_view.data(), y2.h_view.stride_0()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(x1.d_view, random, ValueType(1)); + Kokkos::fill_random(y1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(x2.h_view, x1.d_view); + Kokkos::deep_copy(y2.h_view, y1.d_view); + + ::Test::Functor_TeamGemv test(trans, + m, n, + alpha, A1, x1, beta, y1); + test.run(); + + Blas::gemv(trans, + m, n, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + x2.h_view.data(), x2.h_view.stride_0(), + beta, + y2.h_view.data(), y2.h_view.stride_0()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(x1.d_view, random, ValueType(1)); + Kokkos::fill_random(y1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(x2.h_view, x1.d_view); + Kokkos::deep_copy(y2.h_view, y1.d_view); + + ::Test::Functor_TeamGemv test(trans, + m, n, + alpha, A1, x1, beta, y1); + test.run(); + + Blas::gemv(trans, + m, n, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + x2.h_view.data(), x2.h_view.stride_0(), + beta, + y2.h_view.data(), y2.h_view.stride_0()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + ::BlasTeam::herk(member, + _uplo, _trans, + _n, _k, + _alpha, + (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + _beta, + ( ValueType*)_C.d_view.data(), (int)_C.d_view.stride_1()); + } + + inline + void run() { + _A.sync_device(); + + _C.sync_device(); + _C.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _C.sync_host(); + } + }; +} + + +TEST( DenseLinearAlgebra, team_herk_un ) { + TEST_BEGIN; + const ordinal_type n = 20, k = 10; + const char uplo = 'U'; + const char trans = 'N'; + const ValueType alpha = 1.3, beta = 2.5; + + matrix_type A1("A1", n, k), C1("C1", n, n); + matrix_type A2("A2", n, k), C2("C2", n, n); + + A1.modify_device(); + C1.modify_device(); + + Kokkos::Random_XorShift64_Pool random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamHerk test(uplo, trans, + n, k, + alpha, A1, beta, C1); + test.run(); + + Blas::herk(uplo, trans, + n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamHerk test(uplo, trans, + n, k, + alpha, A1, beta, C1); + test.run(); + + Blas::herk(uplo, trans, + n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamHerk test(uplo, trans, + n, k, + alpha, A1, beta, C1); + test.run(); + + Blas::herk(uplo, trans, + n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i random(13718); + + Kokkos::fill_random(A1.d_view, random, ValueType(1)); + Kokkos::fill_random(C1.d_view, random, ValueType(1)); + + Kokkos::deep_copy(A2.h_view, A1.d_view); + Kokkos::deep_copy(C2.h_view, C1.d_view); + + ::Test::Functor_TeamHerk test(uplo, trans, + n, k, + alpha, A1, beta, C1); + test.run(); + + Blas::herk(uplo, trans, + n, k, + alpha, + A2.h_view.data(), A2.h_view.stride_1(), + beta, + C2.h_view.data(), C2.h_view.stride_1()); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + ::BlasTeam::trsv(member, + _uplo, _trans, _diag, + _m, + (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + ( ValueType*)_b.d_view.data(), (int)_b.d_view.stride_0()); + } + + inline + void run() { + _A.sync_device(); + + _b.sync_device(); + _b.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _b.sync_host(); + } + }; +} + + +#define TEAM_TRSV_TEST_BODY { \ + matrix_type A1("A1", m, m), b1("b1", m, 1); \ + matrix_type A2("A2", m, m), b2("b2", m, 1); \ + \ + A1.modify_device(); \ + b1.modify_device(); \ + \ + Kokkos::Random_XorShift64_Pool random(13718); \ + \ + Kokkos::fill_random(A1.d_view, random, ValueType(1)); \ + Kokkos::fill_random(b1.d_view, random, ValueType(1)); \ + \ + Kokkos::deep_copy(A2.h_view, A1.d_view); \ + Kokkos::deep_copy(b2.h_view, b1.d_view); \ + \ + ::Test::Functor_TeamTrsv test(uplo, trans, diag, \ + m, \ + A1, b1); \ + test.run(); \ + \ + Blas::trsv(uplo, trans, diag, \ + m, \ + A2.h_view.data(), A2.h_view.stride_1(), \ + b2.h_view.data(), b2.h_view.stride_0()); \ + \ + const MagnitudeType eps = std::numeric_limits::epsilon() * 10000; \ + for (int i=0;i + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + ::BlasTeam::trsm(member, + _side, _uplo, _trans, _diag, + _m, _n, + _alpha, + (const ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + ( ValueType*)_B.d_view.data(), (int)_B.d_view.stride_1()); + } + + inline + void run() { + _A.sync_device(); + + _B.sync_device(); + _B.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _B.sync_host(); + } + }; +} + +#define TEAM_TRSM_TEST_BODY { \ + matrix_type A1("A1", m, m), B1("b1", m, n); \ + matrix_type A2("A2", m, m), B2("b2", m, n); \ + \ + A1.modify_device(); \ + B1.modify_device(); \ + \ + Kokkos::Random_XorShift64_Pool random(13718); \ + \ + Kokkos::fill_random(A1.d_view, random, ValueType(1)); \ + Kokkos::fill_random(B1.d_view, random, ValueType(1)); \ + \ + Kokkos::deep_copy(A2.h_view, A1.d_view); \ + Kokkos::deep_copy(B2.h_view, B1.d_view); \ + \ + ::Test::Functor_TeamTrsm test(side, uplo, trans, diag, \ + m, n, \ + alpha, \ + A1, B1); \ + test.run(); \ + \ + Blas::trsm(side, uplo, trans, diag, \ + m, n, \ + alpha, \ + A2.h_view.data(), A2.h_view.stride_1(), \ + B2.h_view.data(), B2.h_view.stride_1()); \ + \ + const MagnitudeType eps = std::numeric_limits::epsilon() * 100000; \ + for (int i=0;i + KOKKOS_INLINE_FUNCTION + void operator()(const MemberType &member) const { + int r_val = 0; + ::LapackTeam::potrf(member, + _uplo, + _m, + (ValueType*)_A.d_view.data(), (int)_A.d_view.stride_1(), + &r_val); + } + + inline + void run() { + _A.sync_device(); + _A.modify_device(); + + Kokkos::parallel_for(Kokkos::TeamPolicy(1, Kokkos::AUTO), *this); + + _A.sync_host(); + } + }; +} + + +TEST( DenseLinearAlgebra, team_chol_u ) { + TEST_BEGIN; + const ordinal_type m = 20; + const char uplo = 'U'; + + matrix_type A1("A1", m, m); + matrix_type A2("A2", m, m); + + A2.modify_host(); + + for (int i=0;i::potrf(uplo, + m, + (ValueType*)A2.h_view.data(), (int)A2.h_view.stride_1(), + &r_val); + + const MagnitudeType eps = std::numeric_limits::epsilon() * 1000; + for (int i=0;i +#include +#include + +#include "Tacho.hpp" +#include "gtest/gtest.h" + +static const std::string ETI_TEST_NAME="@ETI_NAME@"; +using host_device_type = typename Tacho::UseThisDevice::type; +using device_type = typename Tacho::UseThisDevice::type; +using value_type = @ETI_VALUE_TYPE@; +using ats = Tacho::ArithTraits; +using magnitude_type = typename ats::mag_type; + +static const std::string MM_TEST_FILE= std::is_same::value ? "test_double.mtx" : "test_dcomplex.mtx"; + +// #if defined(KOKKOS_ENABLE_SERIAL) +// static Kokkos::Impl::HostThreadTeamMember +// host_serial_member() { +// auto& data = Kokkos::Serial().impl_internal_space_instance()->m_thread_team_data; +// return Kokkos::Impl::HostThreadTeamMember(data); +// } +// #else +// int host_serial_member() { +// throw std::logic_error("Error: host serial member is used while serial execution space is not enabled"); +// return -1; +// } +// #endif + +#include "Tacho_Test.hpp" + +int main (int argc, char *argv[]) { + int r_val(0); + Kokkos::initialize(); + { + const bool detail = false; + Tacho::printExecSpaceConfiguration("Device space", detail); + Tacho::printExecSpaceConfiguration("Host space", detail); + + ::testing::InitGoogleTest(&argc, argv); + r_val = RUN_ALL_TESTS(); + } + Kokkos::finalize(); + + return r_val; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/CMakeLists.txt b/packages/shylu/shylu_node/tacho/unit-test/googletest/CMakeLists.txt new file mode 100644 index 000000000000..eb03bfaf3e0f --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/CMakeLists.txt @@ -0,0 +1,321 @@ +######################################################################## +# Note: CMake support is community-based. The maintainers do not use CMake +# internally. +# +# CMake build script for Google Test. +# +# To run the tests for Google Test itself on Linux, use 'make test' or +# ctest. You can select which tests to run using 'ctest -R regex'. +# For more options, run 'ctest --help'. + +# When other libraries are using a shared version of runtime libraries, +# Google Test also has to use one. +option( + gtest_force_shared_crt + "Use shared (DLL) run-time lib even when Google Test is built as static lib." + OFF) + +option(gtest_build_tests "Build all of gtest's own tests." OFF) + +option(gtest_build_samples "Build gtest's sample programs." OFF) + +option(gtest_disable_pthreads "Disable uses of pthreads in gtest." OFF) + +option( + gtest_hide_internal_symbols + "Build gtest with internal symbols hidden in shared libraries." + OFF) + +# Defines pre_project_set_up_hermetic_build() and set_up_hermetic_build(). +include(cmake/hermetic_build.cmake OPTIONAL) + +if (COMMAND pre_project_set_up_hermetic_build) + pre_project_set_up_hermetic_build() +endif() + +######################################################################## +# +# Project-wide settings + +# Name of the project. +# +# CMake files in this project can refer to the root source directory +# as ${gtest_SOURCE_DIR} and to the root binary directory as +# ${gtest_BINARY_DIR}. +# Language "C" is required for find_package(Threads). + +# Project version: + +cmake_minimum_required(VERSION 3.5) +cmake_policy(SET CMP0048 NEW) +project(gtest VERSION ${GOOGLETEST_VERSION} LANGUAGES CXX C) + +if (POLICY CMP0063) # Visibility + cmake_policy(SET CMP0063 NEW) +endif (POLICY CMP0063) + +if (COMMAND set_up_hermetic_build) + set_up_hermetic_build() +endif() + +# These commands only run if this is the main project +if(CMAKE_PROJECT_NAME STREQUAL "gtest" OR CMAKE_PROJECT_NAME STREQUAL "googletest-distribution") + + # BUILD_SHARED_LIBS is a standard CMake variable, but we declare it here to + # make it prominent in the GUI. + option(BUILD_SHARED_LIBS "Build shared libraries (DLLs)." OFF) + +else() + + mark_as_advanced( + gtest_force_shared_crt + gtest_build_tests + gtest_build_samples + gtest_disable_pthreads + gtest_hide_internal_symbols) + +endif() + + +if (gtest_hide_internal_symbols) + set(CMAKE_CXX_VISIBILITY_PRESET hidden) + set(CMAKE_VISIBILITY_INLINES_HIDDEN 1) +endif() + +# Define helper functions and macros used by Google Test. +include(cmake/internal_utils.cmake) + +config_compiler_and_linker() # Defined in internal_utils.cmake. + +# Needed to set the namespace for both the export targets and the +# alias libraries +set(cmake_package_name GTest CACHE INTERNAL "") + +# Create the CMake package file descriptors. +if (INSTALL_GTEST) + include(CMakePackageConfigHelpers) + set(targets_export_name ${cmake_package_name}Targets CACHE INTERNAL "") + set(generated_dir "${CMAKE_CURRENT_BINARY_DIR}/generated" CACHE INTERNAL "") + set(cmake_files_install_dir "${CMAKE_INSTALL_LIBDIR}/cmake/${cmake_package_name}") + set(version_file "${generated_dir}/${cmake_package_name}ConfigVersion.cmake") + write_basic_package_version_file(${version_file} VERSION ${GOOGLETEST_VERSION} COMPATIBILITY AnyNewerVersion) + install(EXPORT ${targets_export_name} + NAMESPACE ${cmake_package_name}:: + DESTINATION ${cmake_files_install_dir}) + set(config_file "${generated_dir}/${cmake_package_name}Config.cmake") + configure_package_config_file("${gtest_SOURCE_DIR}/cmake/Config.cmake.in" + "${config_file}" INSTALL_DESTINATION ${cmake_files_install_dir}) + install(FILES ${version_file} ${config_file} + DESTINATION ${cmake_files_install_dir}) +endif() + +# Where Google Test's .h files can be found. +set(gtest_build_include_dirs + "${gtest_SOURCE_DIR}/include" + "${gtest_SOURCE_DIR}") +include_directories(${gtest_build_include_dirs}) + +######################################################################## +# +# Defines the gtest & gtest_main libraries. User tests should link +# with one of them. + +# Google Test libraries. We build them using more strict warnings than what +# are used for other targets, to ensure that gtest can be compiled by a user +# aggressive about warnings. +cxx_library(gtest "${cxx_strict}" src/gtest-all.cc) +set_target_properties(gtest PROPERTIES VERSION ${GOOGLETEST_VERSION}) +cxx_library(gtest_main "${cxx_strict}" src/gtest_main.cc) +set_target_properties(gtest_main PROPERTIES VERSION ${GOOGLETEST_VERSION}) +# If the CMake version supports it, attach header directory information +# to the targets for when we are part of a parent build (ie being pulled +# in via add_subdirectory() rather than being a standalone build). +if (DEFINED CMAKE_VERSION AND NOT "${CMAKE_VERSION}" VERSION_LESS "2.8.11") + target_include_directories(gtest SYSTEM INTERFACE + "$" + "$/${CMAKE_INSTALL_INCLUDEDIR}>") + target_include_directories(gtest_main SYSTEM INTERFACE + "$" + "$/${CMAKE_INSTALL_INCLUDEDIR}>") +endif() +if(CMAKE_SYSTEM_NAME MATCHES "QNX") + target_link_libraries(gtest PUBLIC regex) +endif() +target_link_libraries(gtest_main PUBLIC gtest) + +######################################################################## +# +# Install rules +install_project(gtest gtest_main) + +######################################################################## +# +# Samples on how to link user tests with gtest or gtest_main. +# +# They are not built by default. To build them, set the +# gtest_build_samples option to ON. You can do it by running ccmake +# or specifying the -Dgtest_build_samples=ON flag when running cmake. + +if (gtest_build_samples) + cxx_executable(sample1_unittest samples gtest_main samples/sample1.cc) + cxx_executable(sample2_unittest samples gtest_main samples/sample2.cc) + cxx_executable(sample3_unittest samples gtest_main) + cxx_executable(sample4_unittest samples gtest_main samples/sample4.cc) + cxx_executable(sample5_unittest samples gtest_main samples/sample1.cc) + cxx_executable(sample6_unittest samples gtest_main) + cxx_executable(sample7_unittest samples gtest_main) + cxx_executable(sample8_unittest samples gtest_main) + cxx_executable(sample9_unittest samples gtest) + cxx_executable(sample10_unittest samples gtest) +endif() + +######################################################################## +# +# Google Test's own tests. +# +# You can skip this section if you aren't interested in testing +# Google Test itself. +# +# The tests are not built by default. To build them, set the +# gtest_build_tests option to ON. You can do it by running ccmake +# or specifying the -Dgtest_build_tests=ON flag when running cmake. + +if (gtest_build_tests) + # This must be set in the root directory for the tests to be run by + # 'make test' or ctest. + enable_testing() + + ############################################################ + # C++ tests built with standard compiler flags. + + cxx_test(googletest-death-test-test gtest_main) + cxx_test(gtest_environment_test gtest) + cxx_test(googletest-filepath-test gtest_main) + cxx_test(googletest-listener-test gtest_main) + cxx_test(gtest_main_unittest gtest_main) + cxx_test(googletest-message-test gtest_main) + cxx_test(gtest_no_test_unittest gtest) + cxx_test(googletest-options-test gtest_main) + cxx_test(googletest-param-test-test gtest + test/googletest-param-test2-test.cc) + cxx_test(googletest-port-test gtest_main) + cxx_test(gtest_pred_impl_unittest gtest_main) + cxx_test(gtest_premature_exit_test gtest + test/gtest_premature_exit_test.cc) + cxx_test(googletest-printers-test gtest_main) + cxx_test(gtest_prod_test gtest_main + test/production.cc) + cxx_test(gtest_repeat_test gtest) + cxx_test(gtest_sole_header_test gtest_main) + cxx_test(gtest_stress_test gtest) + cxx_test(googletest-test-part-test gtest_main) + cxx_test(gtest_throw_on_failure_ex_test gtest) + cxx_test(gtest-typed-test_test gtest_main + test/gtest-typed-test2_test.cc) + cxx_test(gtest_unittest gtest_main) + cxx_test(gtest-unittest-api_test gtest) + cxx_test(gtest_skip_in_environment_setup_test gtest_main) + cxx_test(gtest_skip_test gtest_main) + + ############################################################ + # C++ tests built with non-standard compiler flags. + + # MSVC 7.1 does not support STL with exceptions disabled. + if (NOT MSVC OR MSVC_VERSION GREATER 1310) + cxx_library(gtest_no_exception "${cxx_no_exception}" + src/gtest-all.cc) + cxx_library(gtest_main_no_exception "${cxx_no_exception}" + src/gtest-all.cc src/gtest_main.cc) + endif() + cxx_library(gtest_main_no_rtti "${cxx_no_rtti}" + src/gtest-all.cc src/gtest_main.cc) + + cxx_test_with_flags(gtest-death-test_ex_nocatch_test + "${cxx_exception} -DGTEST_ENABLE_CATCH_EXCEPTIONS_=0" + gtest test/googletest-death-test_ex_test.cc) + cxx_test_with_flags(gtest-death-test_ex_catch_test + "${cxx_exception} -DGTEST_ENABLE_CATCH_EXCEPTIONS_=1" + gtest test/googletest-death-test_ex_test.cc) + + cxx_test_with_flags(gtest_no_rtti_unittest "${cxx_no_rtti}" + gtest_main_no_rtti test/gtest_unittest.cc) + + cxx_shared_library(gtest_dll "${cxx_default}" + src/gtest-all.cc src/gtest_main.cc) + + cxx_executable_with_flags(gtest_dll_test_ "${cxx_default}" + gtest_dll test/gtest_all_test.cc) + set_target_properties(gtest_dll_test_ + PROPERTIES + COMPILE_DEFINITIONS "GTEST_LINKED_AS_SHARED_LIBRARY=1") + + ############################################################ + # Python tests. + + cxx_executable(googletest-break-on-failure-unittest_ test gtest) + py_test(googletest-break-on-failure-unittest) + + py_test(gtest_skip_check_output_test) + py_test(gtest_skip_environment_check_output_test) + + # Visual Studio .NET 2003 does not support STL with exceptions disabled. + if (NOT MSVC OR MSVC_VERSION GREATER 1310) # 1310 is Visual Studio .NET 2003 + cxx_executable_with_flags( + googletest-catch-exceptions-no-ex-test_ + "${cxx_no_exception}" + gtest_main_no_exception + test/googletest-catch-exceptions-test_.cc) + endif() + + cxx_executable_with_flags( + googletest-catch-exceptions-ex-test_ + "${cxx_exception}" + gtest_main + test/googletest-catch-exceptions-test_.cc) + py_test(googletest-catch-exceptions-test) + + cxx_executable(googletest-color-test_ test gtest) + py_test(googletest-color-test) + + cxx_executable(googletest-env-var-test_ test gtest) + py_test(googletest-env-var-test) + + cxx_executable(googletest-filter-unittest_ test gtest) + py_test(googletest-filter-unittest) + + cxx_executable(gtest_help_test_ test gtest_main) + py_test(gtest_help_test) + + cxx_executable(googletest-list-tests-unittest_ test gtest) + py_test(googletest-list-tests-unittest) + + cxx_executable(googletest-output-test_ test gtest) + py_test(googletest-output-test --no_stacktrace_support) + + cxx_executable(googletest-shuffle-test_ test gtest) + py_test(googletest-shuffle-test) + + # MSVC 7.1 does not support STL with exceptions disabled. + if (NOT MSVC OR MSVC_VERSION GREATER 1310) + cxx_executable(googletest-throw-on-failure-test_ test gtest_no_exception) + set_target_properties(googletest-throw-on-failure-test_ + PROPERTIES + COMPILE_FLAGS "${cxx_no_exception}") + py_test(googletest-throw-on-failure-test) + endif() + + cxx_executable(googletest-uninitialized-test_ test gtest) + py_test(googletest-uninitialized-test) + + cxx_executable(gtest_list_output_unittest_ test gtest) + py_test(gtest_list_output_unittest) + + cxx_executable(gtest_xml_outfile1_test_ test gtest_main) + cxx_executable(gtest_xml_outfile2_test_ test gtest_main) + py_test(gtest_xml_outfiles_test) + py_test(googletest-json-outfiles-test) + + cxx_executable(gtest_xml_output_unittest_ test gtest) + py_test(gtest_xml_output_unittest --no_stacktrace_support) + py_test(googletest-json-output-unittest --no_stacktrace_support) +endif() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/LICENSE b/packages/shylu/shylu_node/tacho/unit-test/googletest/LICENSE new file mode 100644 index 000000000000..1941a11f8ce9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/LICENSE @@ -0,0 +1,28 @@ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/README.md b/packages/shylu/shylu_node/tacho/unit-test/googletest/README.md new file mode 100644 index 000000000000..d26b309ed0d1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/README.md @@ -0,0 +1,217 @@ +### Generic Build Instructions + +#### Setup + +To build GoogleTest and your tests that use it, you need to tell your build +system where to find its headers and source files. The exact way to do it +depends on which build system you use, and is usually straightforward. + +### Build with CMake + +GoogleTest comes with a CMake build script +([CMakeLists.txt](https://github.com/google/googletest/blob/master/CMakeLists.txt)) +that can be used on a wide range of platforms ("C" stands for cross-platform.). +If you don't have CMake installed already, you can download it for free from +. + +CMake works by generating native makefiles or build projects that can be used in +the compiler environment of your choice. You can either build GoogleTest as a +standalone project or it can be incorporated into an existing CMake build for +another project. + +#### Standalone CMake Project + +When building GoogleTest as a standalone project, the typical workflow starts +with + +``` +git clone https://github.com/google/googletest.git -b release-1.11.0 +cd googletest # Main directory of the cloned repository. +mkdir build # Create a directory to hold the build output. +cd build +cmake .. # Generate native build scripts for GoogleTest. +``` + +The above command also includes GoogleMock by default. And so, if you want to +build only GoogleTest, you should replace the last command with + +``` +cmake .. -DBUILD_GMOCK=OFF +``` + +If you are on a \*nix system, you should now see a Makefile in the current +directory. Just type `make` to build GoogleTest. And then you can simply install +GoogleTest if you are a system administrator. + +``` +make +sudo make install # Install in /usr/local/ by default +``` + +If you use Windows and have Visual Studio installed, a `gtest.sln` file and +several `.vcproj` files will be created. You can then build them using Visual +Studio. + +On Mac OS X with Xcode installed, a `.xcodeproj` file will be generated. + +#### Incorporating Into An Existing CMake Project + +If you want to use GoogleTest in a project which already uses CMake, the easiest +way is to get installed libraries and headers. + +* Import GoogleTest by using `find_package` (or `pkg_check_modules`). For + example, if `find_package(GTest CONFIG REQUIRED)` succeeds, you can use the + libraries as `GTest::gtest`, `GTest::gmock`. + +And a more robust and flexible approach is to build GoogleTest as part of that +project directly. This is done by making the GoogleTest source code available to +the main build and adding it using CMake's `add_subdirectory()` command. This +has the significant advantage that the same compiler and linker settings are +used between GoogleTest and the rest of your project, so issues associated with +using incompatible libraries (eg debug/release), etc. are avoided. This is +particularly useful on Windows. Making GoogleTest's source code available to the +main build can be done a few different ways: + +* Download the GoogleTest source code manually and place it at a known + location. This is the least flexible approach and can make it more difficult + to use with continuous integration systems, etc. +* Embed the GoogleTest source code as a direct copy in the main project's + source tree. This is often the simplest approach, but is also the hardest to + keep up to date. Some organizations may not permit this method. +* Add GoogleTest as a git submodule or equivalent. This may not always be + possible or appropriate. Git submodules, for example, have their own set of + advantages and drawbacks. +* Use CMake to download GoogleTest as part of the build's configure step. This + approach doesn't have the limitations of the other methods. + +The last of the above methods is implemented with a small piece of CMake code +that downloads and pulls the GoogleTest code into the main build. + +Just add to your `CMakeLists.txt`: + +```cmake +include(FetchContent) +FetchContent_Declare( + googletest + # Specify the commit you depend on and update it regularly. + URL https://github.com/google/googletest/archive/e2239ee6043f73722e7aa812a459f54a28552929.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +# Now simply link against gtest or gtest_main as needed. Eg +add_executable(example example.cpp) +target_link_libraries(example gtest_main) +add_test(NAME example_test COMMAND example) +``` + +Note that this approach requires CMake 3.14 or later due to its use of the +`FetchContent_MakeAvailable()` command. + +##### Visual Studio Dynamic vs Static Runtimes + +By default, new Visual Studio projects link the C runtimes dynamically but +GoogleTest links them statically. This will generate an error that looks +something like the following: gtest.lib(gtest-all.obj) : error LNK2038: mismatch +detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value +'MDd_DynamicDebug' in main.obj + +GoogleTest already has a CMake option for this: `gtest_force_shared_crt` + +Enabling this option will make gtest link the runtimes dynamically too, and +match the project in which it is included. + +#### C++ Standard Version + +An environment that supports C++11 is required in order to successfully build +GoogleTest. One way to ensure this is to specify the standard in the top-level +project, for example by using the `set(CMAKE_CXX_STANDARD 11)` command. If this +is not feasible, for example in a C project using GoogleTest for validation, +then it can be specified by adding it to the options for cmake via the +`DCMAKE_CXX_FLAGS` option. + +### Tweaking GoogleTest + +GoogleTest can be used in diverse environments. The default configuration may +not work (or may not work well) out of the box in some environments. However, +you can easily tweak GoogleTest by defining control macros on the compiler +command line. Generally, these macros are named like `GTEST_XYZ` and you define +them to either 1 or 0 to enable or disable a certain feature. + +We list the most frequently used macros below. For a complete list, see file +[include/gtest/internal/gtest-port.h](https://github.com/google/googletest/blob/master/googletest/include/gtest/internal/gtest-port.h). + +### Multi-threaded Tests + +GoogleTest is thread-safe where the pthread library is available. After +`#include "gtest/gtest.h"`, you can check the +`GTEST_IS_THREADSAFE` macro to see whether this is the case (yes if the macro is +`#defined` to 1, no if it's undefined.). + +If GoogleTest doesn't correctly detect whether pthread is available in your +environment, you can force it with + + -DGTEST_HAS_PTHREAD=1 + +or + + -DGTEST_HAS_PTHREAD=0 + +When GoogleTest uses pthread, you may need to add flags to your compiler and/or +linker to select the pthread library, or you'll get link errors. If you use the +CMake script, this is taken care of for you. If you use your own build script, +you'll need to read your compiler and linker's manual to figure out what flags +to add. + +### As a Shared Library (DLL) + +GoogleTest is compact, so most users can build and link it as a static library +for the simplicity. You can choose to use GoogleTest as a shared library (known +as a DLL on Windows) if you prefer. + +To compile *gtest* as a shared library, add + + -DGTEST_CREATE_SHARED_LIBRARY=1 + +to the compiler flags. You'll also need to tell the linker to produce a shared +library instead - consult your linker's manual for how to do it. + +To compile your *tests* that use the gtest shared library, add + + -DGTEST_LINKED_AS_SHARED_LIBRARY=1 + +to the compiler flags. + +Note: while the above steps aren't technically necessary today when using some +compilers (e.g. GCC), they may become necessary in the future, if we decide to +improve the speed of loading the library (see + for details). Therefore you are recommended +to always add the above flags when using GoogleTest as a shared library. +Otherwise a future release of GoogleTest may break your build script. + +### Avoiding Macro Name Clashes + +In C++, macros don't obey namespaces. Therefore two libraries that both define a +macro of the same name will clash if you `#include` both definitions. In case a +GoogleTest macro clashes with another library, you can force GoogleTest to +rename its macro to avoid the conflict. + +Specifically, if both GoogleTest and some other code define macro FOO, you can +add + + -DGTEST_DONT_DEFINE_FOO=1 + +to the compiler flags to tell GoogleTest to change the macro's name from `FOO` +to `GTEST_FOO`. Currently `FOO` can be `ASSERT_EQ`, `ASSERT_FALSE`, `ASSERT_GE`, +`ASSERT_GT`, `ASSERT_LE`, `ASSERT_LT`, `ASSERT_NE`, `ASSERT_TRUE`, +`EXPECT_FALSE`, `EXPECT_TRUE`, `FAIL`, `SUCCEED`, `TEST`, or `TEST_F`. For +example, with `-DGTEST_DONT_DEFINE_TEST=1`, you'll need to write + + GTEST_TEST(SomeTest, DoesThis) { ... } + +instead of + + TEST(SomeTest, DoesThis) { ... } + +in order to define a test. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/Config.cmake.in b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/Config.cmake.in new file mode 100644 index 000000000000..12be4498b1a0 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/Config.cmake.in @@ -0,0 +1,9 @@ +@PACKAGE_INIT@ +include(CMakeFindDependencyMacro) +if (@GTEST_HAS_PTHREAD@) + set(THREADS_PREFER_PTHREAD_FLAG @THREADS_PREFER_PTHREAD_FLAG@) + find_dependency(Threads) +endif() + +include("${CMAKE_CURRENT_LIST_DIR}/@targets_export_name@.cmake") +check_required_components("@project_name@") diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest.pc.in b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest.pc.in new file mode 100644 index 000000000000..b4148fae42b1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest.pc.in @@ -0,0 +1,9 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gtest +Description: GoogleTest (without main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Libs: -L${libdir} -lgtest @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest_main.pc.in b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest_main.pc.in new file mode 100644 index 000000000000..38c88c54d538 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/gtest_main.pc.in @@ -0,0 +1,10 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: gtest_main +Description: GoogleTest (with main() function) +Version: @PROJECT_VERSION@ +URL: https://github.com/google/googletest +Requires: gtest = @PROJECT_VERSION@ +Libs: -L${libdir} -lgtest_main @CMAKE_THREAD_LIBS_INIT@ +Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/internal_utils.cmake b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/internal_utils.cmake new file mode 100644 index 000000000000..5a34c07a1b99 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/internal_utils.cmake @@ -0,0 +1,342 @@ +# Defines functions and macros useful for building Google Test and +# Google Mock. +# +# Note: +# +# - This file will be run twice when building Google Mock (once via +# Google Test's CMakeLists.txt, and once via Google Mock's). +# Therefore it shouldn't have any side effects other than defining +# the functions and macros. +# +# - The functions/macros defined in this file may depend on Google +# Test and Google Mock's option() definitions, and thus must be +# called *after* the options have been defined. + +if (POLICY CMP0054) + cmake_policy(SET CMP0054 NEW) +endif (POLICY CMP0054) + +# Tweaks CMake's default compiler/linker settings to suit Google Test's needs. +# +# This must be a macro(), as inside a function string() can only +# update variables in the function scope. +macro(fix_default_compiler_settings_) + if (MSVC) + # For MSVC, CMake sets certain flags to defaults we want to override. + # This replacement code is taken from sample in the CMake Wiki at + # https://gitlab.kitware.com/cmake/community/wikis/FAQ#dynamic-replace. + foreach (flag_var + CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if (NOT BUILD_SHARED_LIBS AND NOT gtest_force_shared_crt) + # When Google Test is built as a shared library, it should also use + # shared runtime libraries. Otherwise, it may end up with multiple + # copies of runtime library data in different modules, resulting in + # hard-to-find crashes. When it is built as a static library, it is + # preferable to use CRT as static libraries, as we don't have to rely + # on CRT DLLs being available. CMake always defaults to using shared + # CRT libraries, so we override that default here. + string(REPLACE "/MD" "-MT" ${flag_var} "${${flag_var}}") + endif() + + # We prefer more strict warning checking for building Google Test. + # Replaces /W3 with /W4 in defaults. + string(REPLACE "/W3" "/W4" ${flag_var} "${${flag_var}}") + + # Prevent D9025 warning for targets that have exception handling + # turned off (/EHs-c- flag). Where required, exceptions are explicitly + # re-enabled using the cxx_exception_flags variable. + string(REPLACE "/EHsc" "" ${flag_var} "${${flag_var}}") + endforeach() + endif() +endmacro() + +# Defines the compiler/linker flags used to build Google Test and +# Google Mock. You can tweak these definitions to suit your need. A +# variable's value is empty before it's explicitly assigned to. +macro(config_compiler_and_linker) + # Note: pthreads on MinGW is not supported, even if available + # instead, we use windows threading primitives + unset(GTEST_HAS_PTHREAD) + if (NOT gtest_disable_pthreads AND NOT MINGW) + # Defines CMAKE_USE_PTHREADS_INIT and CMAKE_THREAD_LIBS_INIT. + find_package(Threads) + if (CMAKE_USE_PTHREADS_INIT) + set(GTEST_HAS_PTHREAD ON) + endif() + endif() + + fix_default_compiler_settings_() + if (MSVC) + # Newlines inside flags variables break CMake's NMake generator. + # TODO(vladl@google.com): Add -RTCs and -RTCu to debug builds. + set(cxx_base_flags "-GS -W4 -WX -wd4251 -wd4275 -nologo -J") + set(cxx_base_flags "${cxx_base_flags} -D_UNICODE -DUNICODE -DWIN32 -D_WIN32") + set(cxx_base_flags "${cxx_base_flags} -DSTRICT -DWIN32_LEAN_AND_MEAN") + set(cxx_exception_flags "-EHsc -D_HAS_EXCEPTIONS=1") + set(cxx_no_exception_flags "-EHs-c- -D_HAS_EXCEPTIONS=0") + set(cxx_no_rtti_flags "-GR-") + # Suppress "unreachable code" warning + # http://stackoverflow.com/questions/3232669 explains the issue. + set(cxx_base_flags "${cxx_base_flags} -wd4702") + # Ensure MSVC treats source files as UTF-8 encoded. + set(cxx_base_flags "${cxx_base_flags} -utf-8") + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(cxx_base_flags "-Wall -Wshadow -Wconversion") + set(cxx_exception_flags "-fexceptions") + set(cxx_no_exception_flags "-fno-exceptions") + set(cxx_strict_flags "-W -Wpointer-arith -Wreturn-type -Wcast-qual -Wwrite-strings -Wswitch -Wunused-parameter -Wcast-align -Wchar-subscripts -Winline -Wredundant-decls") + set(cxx_no_rtti_flags "-fno-rtti") + elseif (CMAKE_COMPILER_IS_GNUCXX) + set(cxx_base_flags "-Wall -Wshadow") + if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0) + set(cxx_base_flags "${cxx_base_flags} -Wno-error=dangling-else") + endif() + set(cxx_exception_flags "-fexceptions") + set(cxx_no_exception_flags "-fno-exceptions") + # Until version 4.3.2, GCC doesn't define a macro to indicate + # whether RTTI is enabled. Therefore we define GTEST_HAS_RTTI + # explicitly. + set(cxx_no_rtti_flags "-fno-rtti -DGTEST_HAS_RTTI=0") + set(cxx_strict_flags + "-Wextra -Wno-unused-parameter -Wno-missing-field-initializers") + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "SunPro") + set(cxx_exception_flags "-features=except") + # Sun Pro doesn't provide macros to indicate whether exceptions and + # RTTI are enabled, so we define GTEST_HAS_* explicitly. + set(cxx_no_exception_flags "-features=no%except -DGTEST_HAS_EXCEPTIONS=0") + set(cxx_no_rtti_flags "-features=no%rtti -DGTEST_HAS_RTTI=0") + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "VisualAge" OR + CMAKE_CXX_COMPILER_ID STREQUAL "XL") + # CMake 2.8 changes Visual Age's compiler ID to "XL". + set(cxx_exception_flags "-qeh") + set(cxx_no_exception_flags "-qnoeh") + # Until version 9.0, Visual Age doesn't define a macro to indicate + # whether RTTI is enabled. Therefore we define GTEST_HAS_RTTI + # explicitly. + set(cxx_no_rtti_flags "-qnortti -DGTEST_HAS_RTTI=0") + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "HP") + set(cxx_base_flags "-AA -mt") + set(cxx_exception_flags "-DGTEST_HAS_EXCEPTIONS=1") + set(cxx_no_exception_flags "+noeh -DGTEST_HAS_EXCEPTIONS=0") + # RTTI can not be disabled in HP aCC compiler. + set(cxx_no_rtti_flags "") + endif() + + # The pthreads library is available and allowed? + if (DEFINED GTEST_HAS_PTHREAD) + set(GTEST_HAS_PTHREAD_MACRO "-DGTEST_HAS_PTHREAD=1") + else() + set(GTEST_HAS_PTHREAD_MACRO "-DGTEST_HAS_PTHREAD=0") + endif() + set(cxx_base_flags "${cxx_base_flags} ${GTEST_HAS_PTHREAD_MACRO}") + + # For building gtest's own tests and samples. + set(cxx_exception "${cxx_base_flags} ${cxx_exception_flags}") + set(cxx_no_exception + "${CMAKE_CXX_FLAGS} ${cxx_base_flags} ${cxx_no_exception_flags}") + set(cxx_default "${cxx_exception}") + set(cxx_no_rtti "${cxx_default} ${cxx_no_rtti_flags}") + + # For building the gtest libraries. + set(cxx_strict "${cxx_default} ${cxx_strict_flags}") +endmacro() + +# Defines the gtest & gtest_main libraries. User tests should link +# with one of them. +function(cxx_library_with_type name type cxx_flags) + # type can be either STATIC or SHARED to denote a static or shared library. + # ARGN refers to additional arguments after 'cxx_flags'. + add_library(${name} ${type} ${ARGN}) + add_library(${cmake_package_name}::${name} ALIAS ${name}) + set_target_properties(${name} + PROPERTIES + COMPILE_FLAGS "${cxx_flags}") + # Set the output directory for build artifacts + set_target_properties(${name} + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") + # make PDBs match library name + get_target_property(pdb_debug_postfix ${name} DEBUG_POSTFIX) + set_target_properties(${name} + PROPERTIES + PDB_NAME "${name}" + PDB_NAME_DEBUG "${name}${pdb_debug_postfix}" + COMPILE_PDB_NAME "${name}" + COMPILE_PDB_NAME_DEBUG "${name}${pdb_debug_postfix}") + + if (BUILD_SHARED_LIBS OR type STREQUAL "SHARED") + set_target_properties(${name} + PROPERTIES + COMPILE_DEFINITIONS "GTEST_CREATE_SHARED_LIBRARY=1") + if (NOT "${CMAKE_VERSION}" VERSION_LESS "2.8.11") + target_compile_definitions(${name} INTERFACE + $) + endif() + endif() + if (DEFINED GTEST_HAS_PTHREAD) + if ("${CMAKE_VERSION}" VERSION_LESS "3.1.0") + set(threads_spec ${CMAKE_THREAD_LIBS_INIT}) + else() + set(threads_spec Threads::Threads) + endif() + target_link_libraries(${name} PUBLIC ${threads_spec}) + endif() + + if (NOT "${CMAKE_VERSION}" VERSION_LESS "3.8") + target_compile_features(${name} PUBLIC cxx_std_11) + endif() +endfunction() + +######################################################################## +# +# Helper functions for creating build targets. + +function(cxx_shared_library name cxx_flags) + cxx_library_with_type(${name} SHARED "${cxx_flags}" ${ARGN}) +endfunction() + +function(cxx_library name cxx_flags) + cxx_library_with_type(${name} "" "${cxx_flags}" ${ARGN}) +endfunction() + +# cxx_executable_with_flags(name cxx_flags libs srcs...) +# +# creates a named C++ executable that depends on the given libraries and +# is built from the given source files with the given compiler flags. +function(cxx_executable_with_flags name cxx_flags libs) + add_executable(${name} ${ARGN}) + if (MSVC) + # BigObj required for tests. + set(cxx_flags "${cxx_flags} -bigobj") + endif() + if (cxx_flags) + set_target_properties(${name} + PROPERTIES + COMPILE_FLAGS "${cxx_flags}") + endif() + if (BUILD_SHARED_LIBS) + set_target_properties(${name} + PROPERTIES + COMPILE_DEFINITIONS "GTEST_LINKED_AS_SHARED_LIBRARY=1") + endif() + # To support mixing linking in static and dynamic libraries, link each + # library in with an extra call to target_link_libraries. + foreach (lib "${libs}") + target_link_libraries(${name} ${lib}) + endforeach() +endfunction() + +# cxx_executable(name dir lib srcs...) +# +# creates a named target that depends on the given libs and is built +# from the given source files. dir/name.cc is implicitly included in +# the source file list. +function(cxx_executable name dir libs) + cxx_executable_with_flags( + ${name} "${cxx_default}" "${libs}" "${dir}/${name}.cc" ${ARGN}) +endfunction() + +# Sets PYTHONINTERP_FOUND and PYTHON_EXECUTABLE. +if ("${CMAKE_VERSION}" VERSION_LESS "3.12.0") + find_package(PythonInterp) +else() + find_package(Python COMPONENTS Interpreter) + set(PYTHONINTERP_FOUND ${Python_Interpreter_FOUND}) + set(PYTHON_EXECUTABLE ${Python_EXECUTABLE}) +endif() + +# cxx_test_with_flags(name cxx_flags libs srcs...) +# +# creates a named C++ test that depends on the given libs and is built +# from the given source files with the given compiler flags. +function(cxx_test_with_flags name cxx_flags libs) + cxx_executable_with_flags(${name} "${cxx_flags}" "${libs}" ${ARGN}) + add_test(NAME ${name} COMMAND "$") +endfunction() + +# cxx_test(name libs srcs...) +# +# creates a named test target that depends on the given libs and is +# built from the given source files. Unlike cxx_test_with_flags, +# test/name.cc is already implicitly included in the source file list. +function(cxx_test name libs) + cxx_test_with_flags("${name}" "${cxx_default}" "${libs}" + "test/${name}.cc" ${ARGN}) +endfunction() + +# py_test(name) +# +# creates a Python test with the given name whose main module is in +# test/name.py. It does nothing if Python is not installed. +function(py_test name) + if (PYTHONINTERP_FOUND) + if ("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" VERSION_GREATER 3.1) + if (CMAKE_CONFIGURATION_TYPES) + # Multi-configuration build generators as for Visual Studio save + # output in a subdirectory of CMAKE_CURRENT_BINARY_DIR (Debug, + # Release etc.), so we have to provide it here. + add_test(NAME ${name} + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/${name}.py + --build_dir=${CMAKE_CURRENT_BINARY_DIR}/$ ${ARGN}) + else (CMAKE_CONFIGURATION_TYPES) + # Single-configuration build generators like Makefile generators + # don't have subdirs below CMAKE_CURRENT_BINARY_DIR. + add_test(NAME ${name} + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/${name}.py + --build_dir=${CMAKE_CURRENT_BINARY_DIR} ${ARGN}) + endif (CMAKE_CONFIGURATION_TYPES) + else() + # ${CMAKE_CURRENT_BINARY_DIR} is known at configuration time, so we can + # directly bind it from cmake. ${CTEST_CONFIGURATION_TYPE} is known + # only at ctest runtime (by calling ctest -c ), so + # we have to escape $ to delay variable substitution here. + add_test(NAME ${name} + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/test/${name}.py + --build_dir=${CMAKE_CURRENT_BINARY_DIR}/\${CTEST_CONFIGURATION_TYPE} ${ARGN}) + endif() + # Make the Python import path consistent between Bazel and CMake. + set_tests_properties(${name} PROPERTIES ENVIRONMENT PYTHONPATH=${CMAKE_SOURCE_DIR}) + endif(PYTHONINTERP_FOUND) +endfunction() + +# install_project(targets...) +# +# Installs the specified targets and configures the associated pkgconfig files. +function(install_project) + if(INSTALL_GTEST) + install(DIRECTORY "${PROJECT_SOURCE_DIR}/include/" + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + # Install the project targets. + install(TARGETS ${ARGN} + EXPORT ${targets_export_name} + RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}") + if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + # Install PDBs + foreach(t ${ARGN}) + get_target_property(t_pdb_name ${t} COMPILE_PDB_NAME) + get_target_property(t_pdb_name_debug ${t} COMPILE_PDB_NAME_DEBUG) + get_target_property(t_pdb_output_directory ${t} PDB_OUTPUT_DIRECTORY) + install(FILES + "${t_pdb_output_directory}/\${CMAKE_INSTALL_CONFIG_NAME}/$<$:${t_pdb_name_debug}>$<$>:${t_pdb_name}>.pdb" + DESTINATION ${CMAKE_INSTALL_LIBDIR} + OPTIONAL) + endforeach() + endif() + # Configure and install pkgconfig files. + foreach(t ${ARGN}) + set(configured_pc "${generated_dir}/${t}.pc") + configure_file("${PROJECT_SOURCE_DIR}/cmake/${t}.pc.in" + "${configured_pc}" @ONLY) + install(FILES "${configured_pc}" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + endforeach() + endif() +endfunction() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/libgtest.la.in b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/libgtest.la.in new file mode 100644 index 000000000000..840c83885f98 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/cmake/libgtest.la.in @@ -0,0 +1,21 @@ +# libgtest.la - a libtool library file +# Generated by libtool (GNU libtool) 2.4.6 + +# Please DO NOT delete this file! +# It is necessary for linking the library. + +# Names of this library. +library_names='libgtest.so' + +# Is this an already installed library? +installed=yes + +# Should we warn about portability when linking against -modules? +shouldnotlink=no + +# Files to dlopen/dlpreopen +dlopen='' +dlpreopen='' + +# Directory that this library needs to be installed in: +libdir='@CMAKE_INSTALL_FULL_LIBDIR@' diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/docs/README.md b/packages/shylu/shylu_node/tacho/unit-test/googletest/docs/README.md new file mode 100644 index 000000000000..1bc57b799cce --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/docs/README.md @@ -0,0 +1,4 @@ +# Content Moved + +We are working on updates to the GoogleTest documentation, which has moved to +the top-level [docs](../../docs) directory. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-assertion-result.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-assertion-result.h new file mode 100644 index 000000000000..e020c48943f4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-assertion-result.h @@ -0,0 +1,232 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file implements the AssertionResult type. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_ASSERTION_RESULT_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_ASSERTION_RESULT_H_ + +#include +#include +#include +#include + +#include "gtest/gtest-message.h" +#include "gtest/internal/gtest-port.h" + +namespace testing { + +// A class for indicating whether an assertion was successful. When +// the assertion wasn't successful, the AssertionResult object +// remembers a non-empty message that describes how it failed. +// +// To create an instance of this class, use one of the factory functions +// (AssertionSuccess() and AssertionFailure()). +// +// This class is useful for two purposes: +// 1. Defining predicate functions to be used with Boolean test assertions +// EXPECT_TRUE/EXPECT_FALSE and their ASSERT_ counterparts +// 2. Defining predicate-format functions to be +// used with predicate assertions (ASSERT_PRED_FORMAT*, etc). +// +// For example, if you define IsEven predicate: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then the failed expectation EXPECT_TRUE(IsEven(Fib(5))) +// will print the message +// +// Value of: IsEven(Fib(5)) +// Actual: false (5 is odd) +// Expected: true +// +// instead of a more opaque +// +// Value of: IsEven(Fib(5)) +// Actual: false +// Expected: true +// +// in case IsEven is a simple Boolean predicate. +// +// If you expect your predicate to be reused and want to support informative +// messages in EXPECT_FALSE and ASSERT_FALSE (negative assertions show up +// about half as often as positive ones in our tests), supply messages for +// both success and failure cases: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess() << n << " is even"; +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then a statement EXPECT_FALSE(IsEven(Fib(6))) will print +// +// Value of: IsEven(Fib(6)) +// Actual: true (8 is even) +// Expected: false +// +// NB: Predicates that support negative Boolean assertions have reduced +// performance in positive ones so be careful not to use them in tests +// that have lots (tens of thousands) of positive Boolean assertions. +// +// To use this class with EXPECT_PRED_FORMAT assertions such as: +// +// // Verifies that Foo() returns an even number. +// EXPECT_PRED_FORMAT1(IsEven, Foo()); +// +// you need to define: +// +// testing::AssertionResult IsEven(const char* expr, int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() +// << "Expected: " << expr << " is even\n Actual: it's " << n; +// } +// +// If Foo() returns 5, you will see the following message: +// +// Expected: Foo() is even +// Actual: it's 5 +// +class GTEST_API_ AssertionResult { + public: + // Copy constructor. + // Used in EXPECT_TRUE/FALSE(assertion_result). + AssertionResult(const AssertionResult& other); + +// C4800 is a level 3 warning in Visual Studio 2015 and earlier. +// This warning is not emitted in Visual Studio 2017. +// This warning is off by default starting in Visual Studio 2019 but can be +// enabled with command-line options. +#if defined(_MSC_VER) && (_MSC_VER < 1910 || _MSC_VER >= 1920) + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4800 /* forcing value to bool */) +#endif + + // Used in the EXPECT_TRUE/FALSE(bool_expression). + // + // T must be contextually convertible to bool. + // + // The second parameter prevents this overload from being considered if + // the argument is implicitly convertible to AssertionResult. In that case + // we want AssertionResult's copy constructor to be used. + template + explicit AssertionResult( + const T& success, + typename std::enable_if< + !std::is_convertible::value>::type* + /*enabler*/ + = nullptr) + : success_(success) {} + +#if defined(_MSC_VER) && (_MSC_VER < 1910 || _MSC_VER >= 1920) + GTEST_DISABLE_MSC_WARNINGS_POP_() +#endif + + // Assignment operator. + AssertionResult& operator=(AssertionResult other) { + swap(other); + return *this; + } + + // Returns true if and only if the assertion succeeded. + operator bool() const { return success_; } // NOLINT + + // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. + AssertionResult operator!() const; + + // Returns the text streamed into this AssertionResult. Test assertions + // use it when they fail (i.e., the predicate's outcome doesn't match the + // assertion's expectation). When nothing has been streamed into the + // object, returns an empty string. + const char* message() const { + return message_.get() != nullptr ? message_->c_str() : ""; + } + // Deprecated; please use message() instead. + const char* failure_message() const { return message(); } + + // Streams a custom failure message into this object. + template + AssertionResult& operator<<(const T& value) { + AppendMessage(Message() << value); + return *this; + } + + // Allows streaming basic output manipulators such as endl or flush into + // this object. + AssertionResult& operator<<( + ::std::ostream& (*basic_manipulator)(::std::ostream& stream)) { + AppendMessage(Message() << basic_manipulator); + return *this; + } + + private: + // Appends the contents of message to message_. + void AppendMessage(const Message& a_message) { + if (message_.get() == nullptr) message_.reset(new ::std::string); + message_->append(a_message.GetString().c_str()); + } + + // Swap the contents of this AssertionResult with other. + void swap(AssertionResult& other); + + // Stores result of the assertion predicate. + bool success_; + // Stores the message describing the condition in case the expectation + // construct is not satisfied with the predicate's outcome. + // Referenced via a pointer to avoid taking too much stack frame space + // with test assertions. + std::unique_ptr< ::std::string> message_; +}; + +// Makes a successful assertion result. +GTEST_API_ AssertionResult AssertionSuccess(); + +// Makes a failed assertion result. +GTEST_API_ AssertionResult AssertionFailure(); + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << msg. +GTEST_API_ AssertionResult AssertionFailure(const Message& msg); + +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_ASSERTION_RESULT_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-death-test.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-death-test.h new file mode 100644 index 000000000000..cd34e1f2e852 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-death-test.h @@ -0,0 +1,346 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file defines the public API for death tests. It is +// #included by gtest.h so a user doesn't need to include this +// directly. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ + +#include "gtest/internal/gtest-death-test-internal.h" + +// This flag controls the style of death tests. Valid values are "threadsafe", +// meaning that the death test child process will re-execute the test binary +// from the start, running only a single death test, or "fast", +// meaning that the child process will execute the test logic immediately +// after forking. +GTEST_DECLARE_string_(death_test_style); + +namespace testing { + +#if GTEST_HAS_DEATH_TEST + +namespace internal { + +// Returns a Boolean value indicating whether the caller is currently +// executing in the context of the death test child process. Tools such as +// Valgrind heap checkers may need this to modify their behavior in death +// tests. IMPORTANT: This is an internal utility. Using it may break the +// implementation of death tests. User code MUST NOT use it. +GTEST_API_ bool InDeathTestChild(); + +} // namespace internal + +// The following macros are useful for writing death tests. + +// Here's what happens when an ASSERT_DEATH* or EXPECT_DEATH* is +// executed: +// +// 1. It generates a warning if there is more than one active +// thread. This is because it's safe to fork() or clone() only +// when there is a single thread. +// +// 2. The parent process clone()s a sub-process and runs the death +// test in it; the sub-process exits with code 0 at the end of the +// death test, if it hasn't exited already. +// +// 3. The parent process waits for the sub-process to terminate. +// +// 4. The parent process checks the exit code and error message of +// the sub-process. +// +// Examples: +// +// ASSERT_DEATH(server.SendMessage(56, "Hello"), "Invalid port number"); +// for (int i = 0; i < 5; i++) { +// EXPECT_DEATH(server.ProcessRequest(i), +// "Invalid request .* in ProcessRequest()") +// << "Failed to die on request " << i; +// } +// +// ASSERT_EXIT(server.ExitNow(), ::testing::ExitedWithCode(0), "Exiting"); +// +// bool KilledBySIGHUP(int exit_code) { +// return WIFSIGNALED(exit_code) && WTERMSIG(exit_code) == SIGHUP; +// } +// +// ASSERT_EXIT(client.HangUpServer(), KilledBySIGHUP, "Hanging up!"); +// +// The final parameter to each of these macros is a matcher applied to any data +// the sub-process wrote to stderr. For compatibility with existing tests, a +// bare string is interpreted as a regular expression matcher. +// +// On the regular expressions used in death tests: +// +// On POSIX-compliant systems (*nix), we use the library, +// which uses the POSIX extended regex syntax. +// +// On other platforms (e.g. Windows or Mac), we only support a simple regex +// syntax implemented as part of Google Test. This limited +// implementation should be enough most of the time when writing +// death tests; though it lacks many features you can find in PCRE +// or POSIX extended regex syntax. For example, we don't support +// union ("x|y"), grouping ("(xy)"), brackets ("[xy]"), and +// repetition count ("x{5,7}"), among others. +// +// Below is the syntax that we do support. We chose it to be a +// subset of both PCRE and POSIX extended regex, so it's easy to +// learn wherever you come from. In the following: 'A' denotes a +// literal character, period (.), or a single \\ escape sequence; +// 'x' and 'y' denote regular expressions; 'm' and 'n' are for +// natural numbers. +// +// c matches any literal character c +// \\d matches any decimal digit +// \\D matches any character that's not a decimal digit +// \\f matches \f +// \\n matches \n +// \\r matches \r +// \\s matches any ASCII whitespace, including \n +// \\S matches any character that's not a whitespace +// \\t matches \t +// \\v matches \v +// \\w matches any letter, _, or decimal digit +// \\W matches any character that \\w doesn't match +// \\c matches any literal character c, which must be a punctuation +// . matches any single character except \n +// A? matches 0 or 1 occurrences of A +// A* matches 0 or many occurrences of A +// A+ matches 1 or many occurrences of A +// ^ matches the beginning of a string (not that of each line) +// $ matches the end of a string (not that of each line) +// xy matches x followed by y +// +// If you accidentally use PCRE or POSIX extended regex features +// not implemented by us, you will get a run-time failure. In that +// case, please try to rewrite your regular expression within the +// above syntax. +// +// This implementation is *not* meant to be as highly tuned or robust +// as a compiled regex library, but should perform well enough for a +// death test, which already incurs significant overhead by launching +// a child process. +// +// Known caveats: +// +// A "threadsafe" style death test obtains the path to the test +// program from argv[0] and re-executes it in the sub-process. For +// simplicity, the current implementation doesn't search the PATH +// when launching the sub-process. This means that the user must +// invoke the test program via a path that contains at least one +// path separator (e.g. path/to/foo_test and +// /absolute/path/to/bar_test are fine, but foo_test is not). This +// is rarely a problem as people usually don't put the test binary +// directory in PATH. +// + +// Asserts that a given `statement` causes the program to exit, with an +// integer exit status that satisfies `predicate`, and emitting error output +// that matches `matcher`. +# define ASSERT_EXIT(statement, predicate, matcher) \ + GTEST_DEATH_TEST_(statement, predicate, matcher, GTEST_FATAL_FAILURE_) + +// Like `ASSERT_EXIT`, but continues on to successive tests in the +// test suite, if any: +# define EXPECT_EXIT(statement, predicate, matcher) \ + GTEST_DEATH_TEST_(statement, predicate, matcher, GTEST_NONFATAL_FAILURE_) + +// Asserts that a given `statement` causes the program to exit, either by +// explicitly exiting with a nonzero exit code or being killed by a +// signal, and emitting error output that matches `matcher`. +# define ASSERT_DEATH(statement, matcher) \ + ASSERT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, matcher) + +// Like `ASSERT_DEATH`, but continues on to successive tests in the +// test suite, if any: +# define EXPECT_DEATH(statement, matcher) \ + EXPECT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, matcher) + +// Two predicate classes that can be used in {ASSERT,EXPECT}_EXIT*: + +// Tests that an exit code describes a normal exit with a given exit code. +class GTEST_API_ ExitedWithCode { + public: + explicit ExitedWithCode(int exit_code); + ExitedWithCode(const ExitedWithCode&) = default; + void operator=(const ExitedWithCode& other) = delete; + bool operator()(int exit_status) const; + private: + const int exit_code_; +}; + +# if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA +// Tests that an exit code describes an exit due to termination by a +// given signal. +class GTEST_API_ KilledBySignal { + public: + explicit KilledBySignal(int signum); + bool operator()(int exit_status) const; + private: + const int signum_; +}; +# endif // !GTEST_OS_WINDOWS + +// EXPECT_DEBUG_DEATH asserts that the given statements die in debug mode. +// The death testing framework causes this to have interesting semantics, +// since the sideeffects of the call are only visible in opt mode, and not +// in debug mode. +// +// In practice, this can be used to test functions that utilize the +// LOG(DFATAL) macro using the following style: +// +// int DieInDebugOr12(int* sideeffect) { +// if (sideeffect) { +// *sideeffect = 12; +// } +// LOG(DFATAL) << "death"; +// return 12; +// } +// +// TEST(TestSuite, TestDieOr12WorksInDgbAndOpt) { +// int sideeffect = 0; +// // Only asserts in dbg. +// EXPECT_DEBUG_DEATH(DieInDebugOr12(&sideeffect), "death"); +// +// #ifdef NDEBUG +// // opt-mode has sideeffect visible. +// EXPECT_EQ(12, sideeffect); +// #else +// // dbg-mode no visible sideeffect. +// EXPECT_EQ(0, sideeffect); +// #endif +// } +// +// This will assert that DieInDebugReturn12InOpt() crashes in debug +// mode, usually due to a DCHECK or LOG(DFATAL), but returns the +// appropriate fallback value (12 in this case) in opt mode. If you +// need to test that a function has appropriate side-effects in opt +// mode, include assertions against the side-effects. A general +// pattern for this is: +// +// EXPECT_DEBUG_DEATH({ +// // Side-effects here will have an effect after this statement in +// // opt mode, but none in debug mode. +// EXPECT_EQ(12, DieInDebugOr12(&sideeffect)); +// }, "death"); +// +# ifdef NDEBUG + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + GTEST_EXECUTE_STATEMENT_(statement, regex) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + GTEST_EXECUTE_STATEMENT_(statement, regex) + +# else + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + EXPECT_DEATH(statement, regex) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + ASSERT_DEATH(statement, regex) + +# endif // NDEBUG for EXPECT_DEBUG_DEATH +#endif // GTEST_HAS_DEATH_TEST + +// This macro is used for implementing macros such as +// EXPECT_DEATH_IF_SUPPORTED and ASSERT_DEATH_IF_SUPPORTED on systems where +// death tests are not supported. Those macros must compile on such systems +// if and only if EXPECT_DEATH and ASSERT_DEATH compile with the same parameters +// on systems that support death tests. This allows one to write such a macro on +// a system that does not support death tests and be sure that it will compile +// on a death-test supporting system. It is exposed publicly so that systems +// that have death-tests with stricter requirements than GTEST_HAS_DEATH_TEST +// can write their own equivalent of EXPECT_DEATH_IF_SUPPORTED and +// ASSERT_DEATH_IF_SUPPORTED. +// +// Parameters: +// statement - A statement that a macro such as EXPECT_DEATH would test +// for program termination. This macro has to make sure this +// statement is compiled but not executed, to ensure that +// EXPECT_DEATH_IF_SUPPORTED compiles with a certain +// parameter if and only if EXPECT_DEATH compiles with it. +// regex - A regex that a macro such as EXPECT_DEATH would use to test +// the output of statement. This parameter has to be +// compiled but not evaluated by this macro, to ensure that +// this macro only accepts expressions that a macro such as +// EXPECT_DEATH would accept. +// terminator - Must be an empty statement for EXPECT_DEATH_IF_SUPPORTED +// and a return statement for ASSERT_DEATH_IF_SUPPORTED. +// This ensures that ASSERT_DEATH_IF_SUPPORTED will not +// compile inside functions where ASSERT_DEATH doesn't +// compile. +// +// The branch that has an always false condition is used to ensure that +// statement and regex are compiled (and thus syntactically correct) but +// never executed. The unreachable code macro protects the terminator +// statement from generating an 'unreachable code' warning in case +// statement unconditionally returns or throws. The Message constructor at +// the end allows the syntax of streaming additional messages into the +// macro, for compilational compatibility with EXPECT_DEATH/ASSERT_DEATH. +# define GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, terminator) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + GTEST_LOG_(WARNING) \ + << "Death tests are not supported on this platform.\n" \ + << "Statement '" #statement "' cannot be verified."; \ + } else if (::testing::internal::AlwaysFalse()) { \ + ::testing::internal::RE::PartialMatch(".*", (regex)); \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + terminator; \ + } else \ + ::testing::Message() + +// EXPECT_DEATH_IF_SUPPORTED(statement, regex) and +// ASSERT_DEATH_IF_SUPPORTED(statement, regex) expand to real death tests if +// death tests are supported; otherwise they just issue a warning. This is +// useful when you are combining death test assertions with normal test +// assertions in one test. +#if GTEST_HAS_DEATH_TEST +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + EXPECT_DEATH(statement, regex) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + ASSERT_DEATH(statement, regex) +#else +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, ) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST(statement, regex, return) +#endif + +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-matchers.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-matchers.h new file mode 100644 index 000000000000..3472db7e17ec --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-matchers.h @@ -0,0 +1,934 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file implements just enough of the matcher interface to allow +// EXPECT_DEATH and friends to accept a matcher argument. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ + +#include +#include +#include +#include +#include + +#include "gtest/gtest-printers.h" +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" + +// MSVC warning C5046 is new as of VS2017 version 15.8. +#if defined(_MSC_VER) && _MSC_VER >= 1915 +#define GTEST_MAYBE_5046_ 5046 +#else +#define GTEST_MAYBE_5046_ +#endif + +GTEST_DISABLE_MSC_WARNINGS_PUSH_( + 4251 GTEST_MAYBE_5046_ /* class A needs to have dll-interface to be used by + clients of class B */ + /* Symbol involving type with internal linkage not defined */) + +namespace testing { + +// To implement a matcher Foo for type T, define: +// 1. a class FooMatcherMatcher that implements the matcher interface: +// using is_gtest_matcher = void; +// bool MatchAndExplain(const T&, std::ostream*); +// (MatchResultListener* can also be used instead of std::ostream*) +// void DescribeTo(std::ostream*); +// void DescribeNegationTo(std::ostream*); +// +// 2. a factory function that creates a Matcher object from a +// FooMatcherMatcher. + +class MatchResultListener { + public: + // Creates a listener object with the given underlying ostream. The + // listener does not own the ostream, and does not dereference it + // in the constructor or destructor. + explicit MatchResultListener(::std::ostream* os) : stream_(os) {} + virtual ~MatchResultListener() = 0; // Makes this class abstract. + + // Streams x to the underlying ostream; does nothing if the ostream + // is NULL. + template + MatchResultListener& operator<<(const T& x) { + if (stream_ != nullptr) *stream_ << x; + return *this; + } + + // Returns the underlying ostream. + ::std::ostream* stream() { return stream_; } + + // Returns true if and only if the listener is interested in an explanation + // of the match result. A matcher's MatchAndExplain() method can use + // this information to avoid generating the explanation when no one + // intends to hear it. + bool IsInterested() const { return stream_ != nullptr; } + + private: + ::std::ostream* const stream_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(MatchResultListener); +}; + +inline MatchResultListener::~MatchResultListener() { +} + +// An instance of a subclass of this knows how to describe itself as a +// matcher. +class GTEST_API_ MatcherDescriberInterface { + public: + virtual ~MatcherDescriberInterface() {} + + // Describes this matcher to an ostream. The function should print + // a verb phrase that describes the property a value matching this + // matcher should have. The subject of the verb phrase is the value + // being matched. For example, the DescribeTo() method of the Gt(7) + // matcher prints "is greater than 7". + virtual void DescribeTo(::std::ostream* os) const = 0; + + // Describes the negation of this matcher to an ostream. For + // example, if the description of this matcher is "is greater than + // 7", the negated description could be "is not greater than 7". + // You are not required to override this when implementing + // MatcherInterface, but it is highly advised so that your matcher + // can produce good error messages. + virtual void DescribeNegationTo(::std::ostream* os) const { + *os << "not ("; + DescribeTo(os); + *os << ")"; + } +}; + +// The implementation of a matcher. +template +class MatcherInterface : public MatcherDescriberInterface { + public: + // Returns true if and only if the matcher matches x; also explains the + // match result to 'listener' if necessary (see the next paragraph), in + // the form of a non-restrictive relative clause ("which ...", + // "whose ...", etc) that describes x. For example, the + // MatchAndExplain() method of the Pointee(...) matcher should + // generate an explanation like "which points to ...". + // + // Implementations of MatchAndExplain() should add an explanation of + // the match result *if and only if* they can provide additional + // information that's not already present (or not obvious) in the + // print-out of x and the matcher's description. Whether the match + // succeeds is not a factor in deciding whether an explanation is + // needed, as sometimes the caller needs to print a failure message + // when the match succeeds (e.g. when the matcher is used inside + // Not()). + // + // For example, a "has at least 10 elements" matcher should explain + // what the actual element count is, regardless of the match result, + // as it is useful information to the reader; on the other hand, an + // "is empty" matcher probably only needs to explain what the actual + // size is when the match fails, as it's redundant to say that the + // size is 0 when the value is already known to be empty. + // + // You should override this method when defining a new matcher. + // + // It's the responsibility of the caller (Google Test) to guarantee + // that 'listener' is not NULL. This helps to simplify a matcher's + // implementation when it doesn't care about the performance, as it + // can talk to 'listener' without checking its validity first. + // However, in order to implement dummy listeners efficiently, + // listener->stream() may be NULL. + virtual bool MatchAndExplain(T x, MatchResultListener* listener) const = 0; + + // Inherits these methods from MatcherDescriberInterface: + // virtual void DescribeTo(::std::ostream* os) const = 0; + // virtual void DescribeNegationTo(::std::ostream* os) const; +}; + +namespace internal { + +struct AnyEq { + template + bool operator()(const A& a, const B& b) const { return a == b; } +}; +struct AnyNe { + template + bool operator()(const A& a, const B& b) const { return a != b; } +}; +struct AnyLt { + template + bool operator()(const A& a, const B& b) const { return a < b; } +}; +struct AnyGt { + template + bool operator()(const A& a, const B& b) const { return a > b; } +}; +struct AnyLe { + template + bool operator()(const A& a, const B& b) const { return a <= b; } +}; +struct AnyGe { + template + bool operator()(const A& a, const B& b) const { return a >= b; } +}; + +// A match result listener that ignores the explanation. +class DummyMatchResultListener : public MatchResultListener { + public: + DummyMatchResultListener() : MatchResultListener(nullptr) {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(DummyMatchResultListener); +}; + +// A match result listener that forwards the explanation to a given +// ostream. The difference between this and MatchResultListener is +// that the former is concrete. +class StreamMatchResultListener : public MatchResultListener { + public: + explicit StreamMatchResultListener(::std::ostream* os) + : MatchResultListener(os) {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamMatchResultListener); +}; + +struct SharedPayloadBase { + std::atomic ref{1}; + void Ref() { ref.fetch_add(1, std::memory_order_relaxed); } + bool Unref() { return ref.fetch_sub(1, std::memory_order_acq_rel) == 1; } +}; + +template +struct SharedPayload : SharedPayloadBase { + explicit SharedPayload(const T& v) : value(v) {} + explicit SharedPayload(T&& v) : value(std::move(v)) {} + + static void Destroy(SharedPayloadBase* shared) { + delete static_cast(shared); + } + + T value; +}; + +// An internal class for implementing Matcher, which will derive +// from it. We put functionalities common to all Matcher +// specializations here to avoid code duplication. +template +class MatcherBase : private MatcherDescriberInterface { + public: + // Returns true if and only if the matcher matches x; also explains the + // match result to 'listener'. + bool MatchAndExplain(const T& x, MatchResultListener* listener) const { + GTEST_CHECK_(vtable_ != nullptr); + return vtable_->match_and_explain(*this, x, listener); + } + + // Returns true if and only if this matcher matches x. + bool Matches(const T& x) const { + DummyMatchResultListener dummy; + return MatchAndExplain(x, &dummy); + } + + // Describes this matcher to an ostream. + void DescribeTo(::std::ostream* os) const final { + GTEST_CHECK_(vtable_ != nullptr); + vtable_->describe(*this, os, false); + } + + // Describes the negation of this matcher to an ostream. + void DescribeNegationTo(::std::ostream* os) const final { + GTEST_CHECK_(vtable_ != nullptr); + vtable_->describe(*this, os, true); + } + + // Explains why x matches, or doesn't match, the matcher. + void ExplainMatchResultTo(const T& x, ::std::ostream* os) const { + StreamMatchResultListener listener(os); + MatchAndExplain(x, &listener); + } + + // Returns the describer for this matcher object; retains ownership + // of the describer, which is only guaranteed to be alive when + // this matcher object is alive. + const MatcherDescriberInterface* GetDescriber() const { + if (vtable_ == nullptr) return nullptr; + return vtable_->get_describer(*this); + } + + protected: + MatcherBase() : vtable_(nullptr) {} + + // Constructs a matcher from its implementation. + template + explicit MatcherBase(const MatcherInterface* impl) { + Init(impl); + } + + template ::type::is_gtest_matcher> + MatcherBase(M&& m) { // NOLINT + Init(std::forward(m)); + } + + MatcherBase(const MatcherBase& other) + : vtable_(other.vtable_), buffer_(other.buffer_) { + if (IsShared()) buffer_.shared->Ref(); + } + + MatcherBase& operator=(const MatcherBase& other) { + if (this == &other) return *this; + Destroy(); + vtable_ = other.vtable_; + buffer_ = other.buffer_; + if (IsShared()) buffer_.shared->Ref(); + return *this; + } + + MatcherBase(MatcherBase&& other) + : vtable_(other.vtable_), buffer_(other.buffer_) { + other.vtable_ = nullptr; + } + + MatcherBase& operator=(MatcherBase&& other) { + if (this == &other) return *this; + Destroy(); + vtable_ = other.vtable_; + buffer_ = other.buffer_; + other.vtable_ = nullptr; + return *this; + } + + ~MatcherBase() override { Destroy(); } + + private: + struct VTable { + bool (*match_and_explain)(const MatcherBase&, const T&, + MatchResultListener*); + void (*describe)(const MatcherBase&, std::ostream*, bool negation); + // Returns the captured object if it implements the interface, otherwise + // returns the MatcherBase itself. + const MatcherDescriberInterface* (*get_describer)(const MatcherBase&); + // Called on shared instances when the reference count reaches 0. + void (*shared_destroy)(SharedPayloadBase*); + }; + + bool IsShared() const { + return vtable_ != nullptr && vtable_->shared_destroy != nullptr; + } + + // If the implementation uses a listener, call that. + template + static auto MatchAndExplainImpl(const MatcherBase& m, const T& value, + MatchResultListener* listener) + -> decltype(P::Get(m).MatchAndExplain(value, listener->stream())) { + return P::Get(m).MatchAndExplain(value, listener->stream()); + } + + template + static auto MatchAndExplainImpl(const MatcherBase& m, const T& value, + MatchResultListener* listener) + -> decltype(P::Get(m).MatchAndExplain(value, listener)) { + return P::Get(m).MatchAndExplain(value, listener); + } + + template + static void DescribeImpl(const MatcherBase& m, std::ostream* os, + bool negation) { + if (negation) { + P::Get(m).DescribeNegationTo(os); + } else { + P::Get(m).DescribeTo(os); + } + } + + template + static const MatcherDescriberInterface* GetDescriberImpl( + const MatcherBase& m) { + // If the impl is a MatcherDescriberInterface, then return it. + // Otherwise use MatcherBase itself. + // This allows us to implement the GetDescriber() function without support + // from the impl, but some users really want to get their impl back when + // they call GetDescriber(). + // We use std::get on a tuple as a workaround of not having `if constexpr`. + return std::get<( + std::is_convertible::value + ? 1 + : 0)>(std::make_tuple(&m, &P::Get(m))); + } + + template + const VTable* GetVTable() { + static constexpr VTable kVTable = {&MatchAndExplainImpl

, + &DescribeImpl

, &GetDescriberImpl

, + P::shared_destroy}; + return &kVTable; + } + + union Buffer { + // Add some types to give Buffer some common alignment/size use cases. + void* ptr; + double d; + int64_t i; + // And add one for the out-of-line cases. + SharedPayloadBase* shared; + }; + + void Destroy() { + if (IsShared() && buffer_.shared->Unref()) { + vtable_->shared_destroy(buffer_.shared); + } + } + + template + static constexpr bool IsInlined() { + return sizeof(M) <= sizeof(Buffer) && alignof(M) <= alignof(Buffer) && + std::is_trivially_copy_constructible::value && + std::is_trivially_destructible::value; + } + + template ()> + struct ValuePolicy { + static const M& Get(const MatcherBase& m) { + // When inlined along with Init, need to be explicit to avoid violating + // strict aliasing rules. + const M *ptr = static_cast( + static_cast(&m.buffer_)); + return *ptr; + } + static void Init(MatcherBase& m, M impl) { + ::new (static_cast(&m.buffer_)) M(impl); + } + static constexpr auto shared_destroy = nullptr; + }; + + template + struct ValuePolicy { + using Shared = SharedPayload; + static const M& Get(const MatcherBase& m) { + return static_cast(m.buffer_.shared)->value; + } + template + static void Init(MatcherBase& m, Arg&& arg) { + m.buffer_.shared = new Shared(std::forward(arg)); + } + static constexpr auto shared_destroy = &Shared::Destroy; + }; + + template + struct ValuePolicy*, B> { + using M = const MatcherInterface; + using Shared = SharedPayload>; + static const M& Get(const MatcherBase& m) { + return *static_cast(m.buffer_.shared)->value; + } + static void Init(MatcherBase& m, M* impl) { + m.buffer_.shared = new Shared(std::unique_ptr(impl)); + } + + static constexpr auto shared_destroy = &Shared::Destroy; + }; + + template + void Init(M&& m) { + using MM = typename std::decay::type; + using Policy = ValuePolicy; + vtable_ = GetVTable(); + Policy::Init(*this, std::forward(m)); + } + + const VTable* vtable_; + Buffer buffer_; +}; + +} // namespace internal + +// A Matcher is a copyable and IMMUTABLE (except by assignment) +// object that can check whether a value of type T matches. The +// implementation of Matcher is just a std::shared_ptr to const +// MatcherInterface. Don't inherit from Matcher! +template +class Matcher : public internal::MatcherBase { + public: + // Constructs a null matcher. Needed for storing Matcher objects in STL + // containers. A default-constructed matcher is not yet initialized. You + // cannot use it until a valid value has been assigned to it. + explicit Matcher() {} // NOLINT + + // Constructs a matcher from its implementation. + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + + template + explicit Matcher( + const MatcherInterface* impl, + typename std::enable_if::value>::type* = + nullptr) + : internal::MatcherBase(impl) {} + + template ::type::is_gtest_matcher> + Matcher(M&& m) : internal::MatcherBase(std::forward(m)) {} // NOLINT + + // Implicit constructor here allows people to write + // EXPECT_CALL(foo, Bar(5)) instead of EXPECT_CALL(foo, Bar(Eq(5))) sometimes + Matcher(T value); // NOLINT +}; + +// The following two specializations allow the user to write str +// instead of Eq(str) and "foo" instead of Eq("foo") when a std::string +// matcher is expected. +template <> +class GTEST_API_ Matcher + : public internal::MatcherBase { + public: + Matcher() {} + + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + + template ::type::is_gtest_matcher> + Matcher(M&& m) // NOLINT + : internal::MatcherBase(std::forward(m)) {} + + // Allows the user to write str instead of Eq(str) sometimes, where + // str is a std::string object. + Matcher(const std::string& s); // NOLINT + + // Allows the user to write "foo" instead of Eq("foo") sometimes. + Matcher(const char* s); // NOLINT +}; + +template <> +class GTEST_API_ Matcher + : public internal::MatcherBase { + public: + Matcher() {} + + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + + template ::type::is_gtest_matcher> + Matcher(M&& m) // NOLINT + : internal::MatcherBase(std::forward(m)) {} + + // Allows the user to write str instead of Eq(str) sometimes, where + // str is a string object. + Matcher(const std::string& s); // NOLINT + + // Allows the user to write "foo" instead of Eq("foo") sometimes. + Matcher(const char* s); // NOLINT +}; + +#if GTEST_INTERNAL_HAS_STRING_VIEW +// The following two specializations allow the user to write str +// instead of Eq(str) and "foo" instead of Eq("foo") when a absl::string_view +// matcher is expected. +template <> +class GTEST_API_ Matcher + : public internal::MatcherBase { + public: + Matcher() {} + + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + + template ::type::is_gtest_matcher> + Matcher(M&& m) // NOLINT + : internal::MatcherBase(std::forward(m)) { + } + + // Allows the user to write str instead of Eq(str) sometimes, where + // str is a std::string object. + Matcher(const std::string& s); // NOLINT + + // Allows the user to write "foo" instead of Eq("foo") sometimes. + Matcher(const char* s); // NOLINT + + // Allows the user to pass absl::string_views or std::string_views directly. + Matcher(internal::StringView s); // NOLINT +}; + +template <> +class GTEST_API_ Matcher + : public internal::MatcherBase { + public: + Matcher() {} + + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + explicit Matcher(const MatcherInterface* impl) + : internal::MatcherBase(impl) {} + + template ::type::is_gtest_matcher> + Matcher(M&& m) // NOLINT + : internal::MatcherBase(std::forward(m)) {} + + // Allows the user to write str instead of Eq(str) sometimes, where + // str is a std::string object. + Matcher(const std::string& s); // NOLINT + + // Allows the user to write "foo" instead of Eq("foo") sometimes. + Matcher(const char* s); // NOLINT + + // Allows the user to pass absl::string_views or std::string_views directly. + Matcher(internal::StringView s); // NOLINT +}; +#endif // GTEST_INTERNAL_HAS_STRING_VIEW + +// Prints a matcher in a human-readable format. +template +std::ostream& operator<<(std::ostream& os, const Matcher& matcher) { + matcher.DescribeTo(&os); + return os; +} + +// The PolymorphicMatcher class template makes it easy to implement a +// polymorphic matcher (i.e. a matcher that can match values of more +// than one type, e.g. Eq(n) and NotNull()). +// +// To define a polymorphic matcher, a user should provide an Impl +// class that has a DescribeTo() method and a DescribeNegationTo() +// method, and define a member function (or member function template) +// +// bool MatchAndExplain(const Value& value, +// MatchResultListener* listener) const; +// +// See the definition of NotNull() for a complete example. +template +class PolymorphicMatcher { + public: + explicit PolymorphicMatcher(const Impl& an_impl) : impl_(an_impl) {} + + // Returns a mutable reference to the underlying matcher + // implementation object. + Impl& mutable_impl() { return impl_; } + + // Returns an immutable reference to the underlying matcher + // implementation object. + const Impl& impl() const { return impl_; } + + template + operator Matcher() const { + return Matcher(new MonomorphicImpl(impl_)); + } + + private: + template + class MonomorphicImpl : public MatcherInterface { + public: + explicit MonomorphicImpl(const Impl& impl) : impl_(impl) {} + + void DescribeTo(::std::ostream* os) const override { impl_.DescribeTo(os); } + + void DescribeNegationTo(::std::ostream* os) const override { + impl_.DescribeNegationTo(os); + } + + bool MatchAndExplain(T x, MatchResultListener* listener) const override { + return impl_.MatchAndExplain(x, listener); + } + + private: + const Impl impl_; + }; + + Impl impl_; +}; + +// Creates a matcher from its implementation. +// DEPRECATED: Especially in the generic code, prefer: +// Matcher(new MyMatcherImpl(...)); +// +// MakeMatcher may create a Matcher that accepts its argument by value, which +// leads to unnecessary copies & lack of support for non-copyable types. +template +inline Matcher MakeMatcher(const MatcherInterface* impl) { + return Matcher(impl); +} + +// Creates a polymorphic matcher from its implementation. This is +// easier to use than the PolymorphicMatcher constructor as it +// doesn't require you to explicitly write the template argument, e.g. +// +// MakePolymorphicMatcher(foo); +// vs +// PolymorphicMatcher(foo); +template +inline PolymorphicMatcher MakePolymorphicMatcher(const Impl& impl) { + return PolymorphicMatcher(impl); +} + +namespace internal { +// Implements a matcher that compares a given value with a +// pre-supplied value using one of the ==, <=, <, etc, operators. The +// two values being compared don't have to have the same type. +// +// The matcher defined here is polymorphic (for example, Eq(5) can be +// used to match an int, a short, a double, etc). Therefore we use +// a template type conversion operator in the implementation. +// +// The following template definition assumes that the Rhs parameter is +// a "bare" type (i.e. neither 'const T' nor 'T&'). +template +class ComparisonBase { + public: + explicit ComparisonBase(const Rhs& rhs) : rhs_(rhs) {} + + using is_gtest_matcher = void; + + template + bool MatchAndExplain(const Lhs& lhs, std::ostream*) const { + return Op()(lhs, Unwrap(rhs_)); + } + void DescribeTo(std::ostream* os) const { + *os << D::Desc() << " "; + UniversalPrint(Unwrap(rhs_), os); + } + void DescribeNegationTo(std::ostream* os) const { + *os << D::NegatedDesc() << " "; + UniversalPrint(Unwrap(rhs_), os); + } + + private: + template + static const T& Unwrap(const T& v) { + return v; + } + template + static const T& Unwrap(std::reference_wrapper v) { + return v; + } + + Rhs rhs_; +}; + +template +class EqMatcher : public ComparisonBase, Rhs, AnyEq> { + public: + explicit EqMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyEq>(rhs) { } + static const char* Desc() { return "is equal to"; } + static const char* NegatedDesc() { return "isn't equal to"; } +}; +template +class NeMatcher : public ComparisonBase, Rhs, AnyNe> { + public: + explicit NeMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyNe>(rhs) { } + static const char* Desc() { return "isn't equal to"; } + static const char* NegatedDesc() { return "is equal to"; } +}; +template +class LtMatcher : public ComparisonBase, Rhs, AnyLt> { + public: + explicit LtMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyLt>(rhs) { } + static const char* Desc() { return "is <"; } + static const char* NegatedDesc() { return "isn't <"; } +}; +template +class GtMatcher : public ComparisonBase, Rhs, AnyGt> { + public: + explicit GtMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyGt>(rhs) { } + static const char* Desc() { return "is >"; } + static const char* NegatedDesc() { return "isn't >"; } +}; +template +class LeMatcher : public ComparisonBase, Rhs, AnyLe> { + public: + explicit LeMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyLe>(rhs) { } + static const char* Desc() { return "is <="; } + static const char* NegatedDesc() { return "isn't <="; } +}; +template +class GeMatcher : public ComparisonBase, Rhs, AnyGe> { + public: + explicit GeMatcher(const Rhs& rhs) + : ComparisonBase, Rhs, AnyGe>(rhs) { } + static const char* Desc() { return "is >="; } + static const char* NegatedDesc() { return "isn't >="; } +}; + +template ::value>::type> +using StringLike = T; + +// Implements polymorphic matchers MatchesRegex(regex) and +// ContainsRegex(regex), which can be used as a Matcher as long as +// T can be converted to a string. +class MatchesRegexMatcher { + public: + MatchesRegexMatcher(const RE* regex, bool full_match) + : regex_(regex), full_match_(full_match) {} + +#if GTEST_INTERNAL_HAS_STRING_VIEW + bool MatchAndExplain(const internal::StringView& s, + MatchResultListener* listener) const { + return MatchAndExplain(std::string(s), listener); + } +#endif // GTEST_INTERNAL_HAS_STRING_VIEW + + // Accepts pointer types, particularly: + // const char* + // char* + // const wchar_t* + // wchar_t* + template + bool MatchAndExplain(CharType* s, MatchResultListener* listener) const { + return s != nullptr && MatchAndExplain(std::string(s), listener); + } + + // Matches anything that can convert to std::string. + // + // This is a template, not just a plain function with const std::string&, + // because absl::string_view has some interfering non-explicit constructors. + template + bool MatchAndExplain(const MatcheeStringType& s, + MatchResultListener* /* listener */) const { + const std::string& s2(s); + return full_match_ ? RE::FullMatch(s2, *regex_) + : RE::PartialMatch(s2, *regex_); + } + + void DescribeTo(::std::ostream* os) const { + *os << (full_match_ ? "matches" : "contains") << " regular expression "; + UniversalPrinter::Print(regex_->pattern(), os); + } + + void DescribeNegationTo(::std::ostream* os) const { + *os << "doesn't " << (full_match_ ? "match" : "contain") + << " regular expression "; + UniversalPrinter::Print(regex_->pattern(), os); + } + + private: + const std::shared_ptr regex_; + const bool full_match_; +}; +} // namespace internal + +// Matches a string that fully matches regular expression 'regex'. +// The matcher takes ownership of 'regex'. +inline PolymorphicMatcher MatchesRegex( + const internal::RE* regex) { + return MakePolymorphicMatcher(internal::MatchesRegexMatcher(regex, true)); +} +template +PolymorphicMatcher MatchesRegex( + const internal::StringLike& regex) { + return MatchesRegex(new internal::RE(std::string(regex))); +} + +// Matches a string that contains regular expression 'regex'. +// The matcher takes ownership of 'regex'. +inline PolymorphicMatcher ContainsRegex( + const internal::RE* regex) { + return MakePolymorphicMatcher(internal::MatchesRegexMatcher(regex, false)); +} +template +PolymorphicMatcher ContainsRegex( + const internal::StringLike& regex) { + return ContainsRegex(new internal::RE(std::string(regex))); +} + +// Creates a polymorphic matcher that matches anything equal to x. +// Note: if the parameter of Eq() were declared as const T&, Eq("foo") +// wouldn't compile. +template +inline internal::EqMatcher Eq(T x) { return internal::EqMatcher(x); } + +// Constructs a Matcher from a 'value' of type T. The constructed +// matcher matches any value that's equal to 'value'. +template +Matcher::Matcher(T value) { *this = Eq(value); } + +// Creates a monomorphic matcher that matches anything with type Lhs +// and equal to rhs. A user may need to use this instead of Eq(...) +// in order to resolve an overloading ambiguity. +// +// TypedEq(x) is just a convenient short-hand for Matcher(Eq(x)) +// or Matcher(x), but more readable than the latter. +// +// We could define similar monomorphic matchers for other comparison +// operations (e.g. TypedLt, TypedGe, and etc), but decided not to do +// it yet as those are used much less than Eq() in practice. A user +// can always write Matcher(Lt(5)) to be explicit about the type, +// for example. +template +inline Matcher TypedEq(const Rhs& rhs) { return Eq(rhs); } + +// Creates a polymorphic matcher that matches anything >= x. +template +inline internal::GeMatcher Ge(Rhs x) { + return internal::GeMatcher(x); +} + +// Creates a polymorphic matcher that matches anything > x. +template +inline internal::GtMatcher Gt(Rhs x) { + return internal::GtMatcher(x); +} + +// Creates a polymorphic matcher that matches anything <= x. +template +inline internal::LeMatcher Le(Rhs x) { + return internal::LeMatcher(x); +} + +// Creates a polymorphic matcher that matches anything < x. +template +inline internal::LtMatcher Lt(Rhs x) { + return internal::LtMatcher(x); +} + +// Creates a polymorphic matcher that matches anything != x. +template +inline internal::NeMatcher Ne(Rhs x) { + return internal::NeMatcher(x); +} +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 5046 + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_MATCHERS_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-message.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-message.h new file mode 100644 index 000000000000..9419229ffa07 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-message.h @@ -0,0 +1,220 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file defines the Message class. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user +// program! + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ + +#include +#include +#include + +#include "gtest/internal/gtest-port.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +// Ensures that there is at least one operator<< in the global namespace. +// See Message& operator<<(...) below for why. +void operator<<(const testing::internal::Secret&, int); + +namespace testing { + +// The Message class works like an ostream repeater. +// +// Typical usage: +// +// 1. You stream a bunch of values to a Message object. +// It will remember the text in a stringstream. +// 2. Then you stream the Message object to an ostream. +// This causes the text in the Message to be streamed +// to the ostream. +// +// For example; +// +// testing::Message foo; +// foo << 1 << " != " << 2; +// std::cout << foo; +// +// will print "1 != 2". +// +// Message is not intended to be inherited from. In particular, its +// destructor is not virtual. +// +// Note that stringstream behaves differently in gcc and in MSVC. You +// can stream a NULL char pointer to it in the former, but not in the +// latter (it causes an access violation if you do). The Message +// class hides this difference by treating a NULL char pointer as +// "(null)". +class GTEST_API_ Message { + private: + // The type of basic IO manipulators (endl, ends, and flush) for + // narrow streams. + typedef std::ostream& (*BasicNarrowIoManip)(std::ostream&); + + public: + // Constructs an empty Message. + Message(); + + // Copy constructor. + Message(const Message& msg) : ss_(new ::std::stringstream) { // NOLINT + *ss_ << msg.GetString(); + } + + // Constructs a Message from a C-string. + explicit Message(const char* str) : ss_(new ::std::stringstream) { + *ss_ << str; + } + + // Streams a non-pointer value to this object. + template + inline Message& operator <<(const T& val) { + // Some libraries overload << for STL containers. These + // overloads are defined in the global namespace instead of ::std. + // + // C++'s symbol lookup rule (i.e. Koenig lookup) says that these + // overloads are visible in either the std namespace or the global + // namespace, but not other namespaces, including the testing + // namespace which Google Test's Message class is in. + // + // To allow STL containers (and other types that has a << operator + // defined in the global namespace) to be used in Google Test + // assertions, testing::Message must access the custom << operator + // from the global namespace. With this using declaration, + // overloads of << defined in the global namespace and those + // visible via Koenig lookup are both exposed in this function. + using ::operator <<; + *ss_ << val; + return *this; + } + + // Streams a pointer value to this object. + // + // This function is an overload of the previous one. When you + // stream a pointer to a Message, this definition will be used as it + // is more specialized. (The C++ Standard, section + // [temp.func.order].) If you stream a non-pointer, then the + // previous definition will be used. + // + // The reason for this overload is that streaming a NULL pointer to + // ostream is undefined behavior. Depending on the compiler, you + // may get "0", "(nil)", "(null)", or an access violation. To + // ensure consistent result across compilers, we always treat NULL + // as "(null)". + template + inline Message& operator <<(T* const& pointer) { // NOLINT + if (pointer == nullptr) { + *ss_ << "(null)"; + } else { + *ss_ << pointer; + } + return *this; + } + + // Since the basic IO manipulators are overloaded for both narrow + // and wide streams, we have to provide this specialized definition + // of operator <<, even though its body is the same as the + // templatized version above. Without this definition, streaming + // endl or other basic IO manipulators to Message will confuse the + // compiler. + Message& operator <<(BasicNarrowIoManip val) { + *ss_ << val; + return *this; + } + + // Instead of 1/0, we want to see true/false for bool values. + Message& operator <<(bool b) { + return *this << (b ? "true" : "false"); + } + + // These two overloads allow streaming a wide C string to a Message + // using the UTF-8 encoding. + Message& operator <<(const wchar_t* wide_c_str); + Message& operator <<(wchar_t* wide_c_str); + +#if GTEST_HAS_STD_WSTRING + // Converts the given wide string to a narrow string using the UTF-8 + // encoding, and streams the result to this Message object. + Message& operator <<(const ::std::wstring& wstr); +#endif // GTEST_HAS_STD_WSTRING + + // Gets the text streamed to this object so far as an std::string. + // Each '\0' character in the buffer is replaced with "\\0". + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + std::string GetString() const; + + private: + // We'll hold the text streamed to this object here. + const std::unique_ptr< ::std::stringstream> ss_; + + // We declare (but don't implement) this to prevent the compiler + // from implementing the assignment operator. + void operator=(const Message&); +}; + +// Streams a Message to an ostream. +inline std::ostream& operator <<(std::ostream& os, const Message& sb) { + return os << sb.GetString(); +} + +namespace internal { + +// Converts a streamable value to an std::string. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". +template +std::string StreamableToString(const T& streamable) { + return (Message() << streamable).GetString(); +} + +} // namespace internal +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-param-test.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-param-test.h new file mode 100644 index 000000000000..96c1c72254f2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-param-test.h @@ -0,0 +1,510 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Macros and functions for implementing parameterized tests +// in Google C++ Testing and Mocking Framework (Google Test) + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ + +// Value-parameterized tests allow you to test your code with different +// parameters without writing multiple copies of the same test. +// +// Here is how you use value-parameterized tests: + +#if 0 + +// To write value-parameterized tests, first you should define a fixture +// class. It is usually derived from testing::TestWithParam (see below for +// another inheritance scheme that's sometimes useful in more complicated +// class hierarchies), where the type of your parameter values. +// TestWithParam is itself derived from testing::Test. T can be any +// copyable type. If it's a raw pointer, you are responsible for managing the +// lifespan of the pointed values. + +class FooTest : public ::testing::TestWithParam { + // You can implement all the usual class fixture members here. +}; + +// Then, use the TEST_P macro to define as many parameterized tests +// for this fixture as you want. The _P suffix is for "parameterized" +// or "pattern", whichever you prefer to think. + +TEST_P(FooTest, DoesBlah) { + // Inside a test, access the test parameter with the GetParam() method + // of the TestWithParam class: + EXPECT_TRUE(foo.Blah(GetParam())); + ... +} + +TEST_P(FooTest, HasBlahBlah) { + ... +} + +// Finally, you can use INSTANTIATE_TEST_SUITE_P to instantiate the test +// case with any set of parameters you want. Google Test defines a number +// of functions for generating test parameters. They return what we call +// (surprise!) parameter generators. Here is a summary of them, which +// are all in the testing namespace: +// +// +// Range(begin, end [, step]) - Yields values {begin, begin+step, +// begin+step+step, ...}. The values do not +// include end. step defaults to 1. +// Values(v1, v2, ..., vN) - Yields values {v1, v2, ..., vN}. +// ValuesIn(container) - Yields values from a C-style array, an STL +// ValuesIn(begin,end) container, or an iterator range [begin, end). +// Bool() - Yields sequence {false, true}. +// Combine(g1, g2, ..., gN) - Yields all combinations (the Cartesian product +// for the math savvy) of the values generated +// by the N generators. +// +// For more details, see comments at the definitions of these functions below +// in this file. +// +// The following statement will instantiate tests from the FooTest test suite +// each with parameter values "meeny", "miny", and "moe". + +INSTANTIATE_TEST_SUITE_P(InstantiationName, + FooTest, + Values("meeny", "miny", "moe")); + +// To distinguish different instances of the pattern, (yes, you +// can instantiate it more than once) the first argument to the +// INSTANTIATE_TEST_SUITE_P macro is a prefix that will be added to the +// actual test suite name. Remember to pick unique prefixes for different +// instantiations. The tests from the instantiation above will have +// these names: +// +// * InstantiationName/FooTest.DoesBlah/0 for "meeny" +// * InstantiationName/FooTest.DoesBlah/1 for "miny" +// * InstantiationName/FooTest.DoesBlah/2 for "moe" +// * InstantiationName/FooTest.HasBlahBlah/0 for "meeny" +// * InstantiationName/FooTest.HasBlahBlah/1 for "miny" +// * InstantiationName/FooTest.HasBlahBlah/2 for "moe" +// +// You can use these names in --gtest_filter. +// +// This statement will instantiate all tests from FooTest again, each +// with parameter values "cat" and "dog": + +const char* pets[] = {"cat", "dog"}; +INSTANTIATE_TEST_SUITE_P(AnotherInstantiationName, FooTest, ValuesIn(pets)); + +// The tests from the instantiation above will have these names: +// +// * AnotherInstantiationName/FooTest.DoesBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.DoesBlah/1 for "dog" +// * AnotherInstantiationName/FooTest.HasBlahBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.HasBlahBlah/1 for "dog" +// +// Please note that INSTANTIATE_TEST_SUITE_P will instantiate all tests +// in the given test suite, whether their definitions come before or +// AFTER the INSTANTIATE_TEST_SUITE_P statement. +// +// Please also note that generator expressions (including parameters to the +// generators) are evaluated in InitGoogleTest(), after main() has started. +// This allows the user on one hand, to adjust generator parameters in order +// to dynamically determine a set of tests to run and on the other hand, +// give the user a chance to inspect the generated tests with Google Test +// reflection API before RUN_ALL_TESTS() is executed. +// +// You can see samples/sample7_unittest.cc and samples/sample8_unittest.cc +// for more examples. +// +// In the future, we plan to publish the API for defining new parameter +// generators. But for now this interface remains part of the internal +// implementation and is subject to change. +// +// +// A parameterized test fixture must be derived from testing::Test and from +// testing::WithParamInterface, where T is the type of the parameter +// values. Inheriting from TestWithParam satisfies that requirement because +// TestWithParam inherits from both Test and WithParamInterface. In more +// complicated hierarchies, however, it is occasionally useful to inherit +// separately from Test and WithParamInterface. For example: + +class BaseTest : public ::testing::Test { + // You can inherit all the usual members for a non-parameterized test + // fixture here. +}; + +class DerivedTest : public BaseTest, public ::testing::WithParamInterface { + // The usual test fixture members go here too. +}; + +TEST_F(BaseTest, HasFoo) { + // This is an ordinary non-parameterized test. +} + +TEST_P(DerivedTest, DoesBlah) { + // GetParam works just the same here as if you inherit from TestWithParam. + EXPECT_TRUE(foo.Blah(GetParam())); +} + +#endif // 0 + +#include +#include + +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-param-util.h" +#include "gtest/internal/gtest-port.h" + +namespace testing { + +// Functions producing parameter generators. +// +// Google Test uses these generators to produce parameters for value- +// parameterized tests. When a parameterized test suite is instantiated +// with a particular generator, Google Test creates and runs tests +// for each element in the sequence produced by the generator. +// +// In the following sample, tests from test suite FooTest are instantiated +// each three times with parameter values 3, 5, and 8: +// +// class FooTest : public TestWithParam { ... }; +// +// TEST_P(FooTest, TestThis) { +// } +// TEST_P(FooTest, TestThat) { +// } +// INSTANTIATE_TEST_SUITE_P(TestSequence, FooTest, Values(3, 5, 8)); +// + +// Range() returns generators providing sequences of values in a range. +// +// Synopsis: +// Range(start, end) +// - returns a generator producing a sequence of values {start, start+1, +// start+2, ..., }. +// Range(start, end, step) +// - returns a generator producing a sequence of values {start, start+step, +// start+step+step, ..., }. +// Notes: +// * The generated sequences never include end. For example, Range(1, 5) +// returns a generator producing a sequence {1, 2, 3, 4}. Range(1, 9, 2) +// returns a generator producing {1, 3, 5, 7}. +// * start and end must have the same type. That type may be any integral or +// floating-point type or a user defined type satisfying these conditions: +// * It must be assignable (have operator=() defined). +// * It must have operator+() (operator+(int-compatible type) for +// two-operand version). +// * It must have operator<() defined. +// Elements in the resulting sequences will also have that type. +// * Condition start < end must be satisfied in order for resulting sequences +// to contain any elements. +// +template +internal::ParamGenerator Range(T start, T end, IncrementT step) { + return internal::ParamGenerator( + new internal::RangeGenerator(start, end, step)); +} + +template +internal::ParamGenerator Range(T start, T end) { + return Range(start, end, 1); +} + +// ValuesIn() function allows generation of tests with parameters coming from +// a container. +// +// Synopsis: +// ValuesIn(const T (&array)[N]) +// - returns a generator producing sequences with elements from +// a C-style array. +// ValuesIn(const Container& container) +// - returns a generator producing sequences with elements from +// an STL-style container. +// ValuesIn(Iterator begin, Iterator end) +// - returns a generator producing sequences with elements from +// a range [begin, end) defined by a pair of STL-style iterators. These +// iterators can also be plain C pointers. +// +// Please note that ValuesIn copies the values from the containers +// passed in and keeps them to generate tests in RUN_ALL_TESTS(). +// +// Examples: +// +// This instantiates tests from test suite StringTest +// each with C-string values of "foo", "bar", and "baz": +// +// const char* strings[] = {"foo", "bar", "baz"}; +// INSTANTIATE_TEST_SUITE_P(StringSequence, StringTest, ValuesIn(strings)); +// +// This instantiates tests from test suite StlStringTest +// each with STL strings with values "a" and "b": +// +// ::std::vector< ::std::string> GetParameterStrings() { +// ::std::vector< ::std::string> v; +// v.push_back("a"); +// v.push_back("b"); +// return v; +// } +// +// INSTANTIATE_TEST_SUITE_P(CharSequence, +// StlStringTest, +// ValuesIn(GetParameterStrings())); +// +// +// This will also instantiate tests from CharTest +// each with parameter values 'a' and 'b': +// +// ::std::list GetParameterChars() { +// ::std::list list; +// list.push_back('a'); +// list.push_back('b'); +// return list; +// } +// ::std::list l = GetParameterChars(); +// INSTANTIATE_TEST_SUITE_P(CharSequence2, +// CharTest, +// ValuesIn(l.begin(), l.end())); +// +template +internal::ParamGenerator< + typename std::iterator_traits::value_type> +ValuesIn(ForwardIterator begin, ForwardIterator end) { + typedef typename std::iterator_traits::value_type ParamType; + return internal::ParamGenerator( + new internal::ValuesInIteratorRangeGenerator(begin, end)); +} + +template +internal::ParamGenerator ValuesIn(const T (&array)[N]) { + return ValuesIn(array, array + N); +} + +template +internal::ParamGenerator ValuesIn( + const Container& container) { + return ValuesIn(container.begin(), container.end()); +} + +// Values() allows generating tests from explicitly specified list of +// parameters. +// +// Synopsis: +// Values(T v1, T v2, ..., T vN) +// - returns a generator producing sequences with elements v1, v2, ..., vN. +// +// For example, this instantiates tests from test suite BarTest each +// with values "one", "two", and "three": +// +// INSTANTIATE_TEST_SUITE_P(NumSequence, +// BarTest, +// Values("one", "two", "three")); +// +// This instantiates tests from test suite BazTest each with values 1, 2, 3.5. +// The exact type of values will depend on the type of parameter in BazTest. +// +// INSTANTIATE_TEST_SUITE_P(FloatingNumbers, BazTest, Values(1, 2, 3.5)); +// +// +template +internal::ValueArray Values(T... v) { + return internal::ValueArray(std::move(v)...); +} + +// Bool() allows generating tests with parameters in a set of (false, true). +// +// Synopsis: +// Bool() +// - returns a generator producing sequences with elements {false, true}. +// +// It is useful when testing code that depends on Boolean flags. Combinations +// of multiple flags can be tested when several Bool()'s are combined using +// Combine() function. +// +// In the following example all tests in the test suite FlagDependentTest +// will be instantiated twice with parameters false and true. +// +// class FlagDependentTest : public testing::TestWithParam { +// virtual void SetUp() { +// external_flag = GetParam(); +// } +// } +// INSTANTIATE_TEST_SUITE_P(BoolSequence, FlagDependentTest, Bool()); +// +inline internal::ParamGenerator Bool() { + return Values(false, true); +} + +// Combine() allows the user to combine two or more sequences to produce +// values of a Cartesian product of those sequences' elements. +// +// Synopsis: +// Combine(gen1, gen2, ..., genN) +// - returns a generator producing sequences with elements coming from +// the Cartesian product of elements from the sequences generated by +// gen1, gen2, ..., genN. The sequence elements will have a type of +// std::tuple where T1, T2, ..., TN are the types +// of elements from sequences produces by gen1, gen2, ..., genN. +// +// Example: +// +// This will instantiate tests in test suite AnimalTest each one with +// the parameter values tuple("cat", BLACK), tuple("cat", WHITE), +// tuple("dog", BLACK), and tuple("dog", WHITE): +// +// enum Color { BLACK, GRAY, WHITE }; +// class AnimalTest +// : public testing::TestWithParam > {...}; +// +// TEST_P(AnimalTest, AnimalLooksNice) {...} +// +// INSTANTIATE_TEST_SUITE_P(AnimalVariations, AnimalTest, +// Combine(Values("cat", "dog"), +// Values(BLACK, WHITE))); +// +// This will instantiate tests in FlagDependentTest with all variations of two +// Boolean flags: +// +// class FlagDependentTest +// : public testing::TestWithParam > { +// virtual void SetUp() { +// // Assigns external_flag_1 and external_flag_2 values from the tuple. +// std::tie(external_flag_1, external_flag_2) = GetParam(); +// } +// }; +// +// TEST_P(FlagDependentTest, TestFeature1) { +// // Test your code using external_flag_1 and external_flag_2 here. +// } +// INSTANTIATE_TEST_SUITE_P(TwoBoolSequence, FlagDependentTest, +// Combine(Bool(), Bool())); +// +template +internal::CartesianProductHolder Combine(const Generator&... g) { + return internal::CartesianProductHolder(g...); +} + +#define TEST_P(test_suite_name, test_name) \ + class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ + : public test_suite_name { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)() {} \ + void TestBody() override; \ + \ + private: \ + static int AddToRegistry() { \ + ::testing::UnitTest::GetInstance() \ + ->parameterized_test_registry() \ + .GetTestSuitePatternHolder( \ + GTEST_STRINGIFY_(test_suite_name), \ + ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ + ->AddTestPattern( \ + GTEST_STRINGIFY_(test_suite_name), GTEST_STRINGIFY_(test_name), \ + new ::testing::internal::TestMetaFactory(), \ + ::testing::internal::CodeLocation(__FILE__, __LINE__)); \ + return 0; \ + } \ + static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_suite_name, \ + test_name)); \ + }; \ + int GTEST_TEST_CLASS_NAME_(test_suite_name, \ + test_name)::gtest_registering_dummy_ = \ + GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::AddToRegistry(); \ + void GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::TestBody() + +// The last argument to INSTANTIATE_TEST_SUITE_P allows the user to specify +// generator and an optional function or functor that generates custom test name +// suffixes based on the test parameters. Such a function or functor should +// accept one argument of type testing::TestParamInfo, and +// return std::string. +// +// testing::PrintToStringParamName is a builtin test suffix generator that +// returns the value of testing::PrintToString(GetParam()). +// +// Note: test names must be non-empty, unique, and may only contain ASCII +// alphanumeric characters or underscore. Because PrintToString adds quotes +// to std::string and C strings, it won't work for these types. + +#define GTEST_EXPAND_(arg) arg +#define GTEST_GET_FIRST_(first, ...) first +#define GTEST_GET_SECOND_(first, second, ...) second + +#define INSTANTIATE_TEST_SUITE_P(prefix, test_suite_name, ...) \ + static ::testing::internal::ParamGenerator \ + gtest_##prefix##test_suite_name##_EvalGenerator_() { \ + return GTEST_EXPAND_(GTEST_GET_FIRST_(__VA_ARGS__, DUMMY_PARAM_)); \ + } \ + static ::std::string gtest_##prefix##test_suite_name##_EvalGenerateName_( \ + const ::testing::TestParamInfo& info) { \ + if (::testing::internal::AlwaysFalse()) { \ + ::testing::internal::TestNotEmpty(GTEST_EXPAND_(GTEST_GET_SECOND_( \ + __VA_ARGS__, \ + ::testing::internal::DefaultParamName, \ + DUMMY_PARAM_))); \ + auto t = std::make_tuple(__VA_ARGS__); \ + static_assert(std::tuple_size::value <= 2, \ + "Too Many Args!"); \ + } \ + return ((GTEST_EXPAND_(GTEST_GET_SECOND_( \ + __VA_ARGS__, \ + ::testing::internal::DefaultParamName, \ + DUMMY_PARAM_))))(info); \ + } \ + static int gtest_##prefix##test_suite_name##_dummy_ \ + GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::UnitTest::GetInstance() \ + ->parameterized_test_registry() \ + .GetTestSuitePatternHolder( \ + GTEST_STRINGIFY_(test_suite_name), \ + ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ + ->AddTestSuiteInstantiation( \ + GTEST_STRINGIFY_(prefix), \ + >est_##prefix##test_suite_name##_EvalGenerator_, \ + >est_##prefix##test_suite_name##_EvalGenerateName_, \ + __FILE__, __LINE__) + + +// Allow Marking a Parameterized test class as not needing to be instantiated. +#define GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(T) \ + namespace gtest_do_not_use_outside_namespace_scope {} \ + static const ::testing::internal::MarkAsIgnored gtest_allow_ignore_##T( \ + GTEST_STRINGIFY_(T)) + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +#define INSTANTIATE_TEST_CASE_P \ + static_assert(::testing::internal::InstantiateTestCase_P_IsDeprecated(), \ + ""); \ + INSTANTIATE_TEST_SUITE_P +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-printers.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-printers.h new file mode 100644 index 000000000000..b097e9886df6 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-printers.h @@ -0,0 +1,1050 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Test - The Google C++ Testing and Mocking Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); +// +// A user can teach this function how to print a class type T by +// defining either operator<<() or PrintTo() in the namespace that +// defines T. More specifically, the FIRST defined function in the +// following list will be used (assuming T is defined in namespace +// foo): +// +// 1. foo::PrintTo(const T&, ostream*) +// 2. operator<<(ostream&, const T&) defined in either foo or the +// global namespace. +// +// However if T is an STL-style container then it is printed element-wise +// unless foo::PrintTo(const T&, ostream*) is defined. Note that +// operator<<() is ignored for container types. +// +// If none of the above is defined, it will print the debug string of +// the value if it is a protocol buffer, or print the raw bytes in the +// value otherwise. +// +// To aid debugging: when T is a reference type, the address of the +// value is also printed; when T is a (const) char pointer, both the +// pointer value and the NUL-terminated string it points to are +// printed. +// +// We also provide some convenient wrappers: +// +// // Prints a value to a string. For a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// std::string ::testing::PrintToString(const T& value); +// +// // Prints a value tersely: for a reference type, the referenced +// // value (but not the address) is printed; for a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// void ::testing::internal::UniversalTersePrint(const T& value, ostream*); +// +// // Prints value using the type inferred by the compiler. The difference +// // from UniversalTersePrint() is that this function prints both the +// // pointer and the NUL-terminated string for a (const or not) char pointer. +// void ::testing::internal::UniversalPrint(const T& value, ostream*); +// +// // Prints the fields of a tuple tersely to a string vector, one +// // element for each field. Tuple support must be enabled in +// // gtest-port.h. +// std::vector UniversalTersePrintTupleFieldsToStrings( +// const Tuple& value); +// +// Known limitation: +// +// The print primitives print the elements of an STL-style container +// using the compiler-inferred type of *iter where iter is a +// const_iterator of the container. When const_iterator is an input +// iterator but not a forward iterator, this inferred type may not +// match value_type, and the print output may be incorrect. In +// practice, this is rarely a problem as for most containers +// const_iterator is a forward iterator. We'll fix this if there's an +// actual need for it. Note that this fix cannot rely on value_type +// being defined as many user-defined container types don't have +// value_type. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ + +#include +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include + +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" + +namespace testing { + +// Definitions in the internal* namespaces are subject to change without notice. +// DO NOT USE THEM IN USER CODE! +namespace internal { + +template +void UniversalPrint(const T& value, ::std::ostream* os); + +// Used to print an STL-style container when the user doesn't define +// a PrintTo() for it. +struct ContainerPrinter { + template (0)) == sizeof(IsContainer)) && + !IsRecursiveContainer::value>::type> + static void PrintValue(const T& container, std::ostream* os) { + const size_t kMaxCount = 32; // The maximum number of elements to print. + *os << '{'; + size_t count = 0; + for (auto&& elem : container) { + if (count > 0) { + *os << ','; + if (count == kMaxCount) { // Enough has been printed. + *os << " ..."; + break; + } + } + *os << ' '; + // We cannot call PrintTo(elem, os) here as PrintTo() doesn't + // handle `elem` being a native array. + internal::UniversalPrint(elem, os); + ++count; + } + + if (count > 0) { + *os << ' '; + } + *os << '}'; + } +}; + +// Used to print a pointer that is neither a char pointer nor a member +// pointer, when the user doesn't define PrintTo() for it. (A member +// variable pointer or member function pointer doesn't really point to +// a location in the address space. Their representation is +// implementation-defined. Therefore they will be printed as raw +// bytes.) +struct FunctionPointerPrinter { + template ::value>::type> + static void PrintValue(T* p, ::std::ostream* os) { + if (p == nullptr) { + *os << "NULL"; + } else { + // T is a function type, so '*os << p' doesn't do what we want + // (it just prints p as bool). We want to print p as a const + // void*. + *os << reinterpret_cast(p); + } + } +}; + +struct PointerPrinter { + template + static void PrintValue(T* p, ::std::ostream* os) { + if (p == nullptr) { + *os << "NULL"; + } else { + // T is not a function type. We just call << to print p, + // relying on ADL to pick up user-defined << for their pointer + // types, if any. + *os << p; + } + } +}; + +namespace internal_stream_operator_without_lexical_name_lookup { + +// The presence of an operator<< here will terminate lexical scope lookup +// straight away (even though it cannot be a match because of its argument +// types). Thus, the two operator<< calls in StreamPrinter will find only ADL +// candidates. +struct LookupBlocker {}; +void operator<<(LookupBlocker, LookupBlocker); + +struct StreamPrinter { + template ::value>::type, + // Only accept types for which we can find a streaming operator via + // ADL (possibly involving implicit conversions). + typename = decltype(std::declval() + << std::declval())> + static void PrintValue(const T& value, ::std::ostream* os) { + // Call streaming operator found by ADL, possibly with implicit conversions + // of the arguments. + *os << value; + } +}; + +} // namespace internal_stream_operator_without_lexical_name_lookup + +struct ProtobufPrinter { + // We print a protobuf using its ShortDebugString() when the string + // doesn't exceed this many characters; otherwise we print it using + // DebugString() for better readability. + static const size_t kProtobufOneLinerMaxLength = 50; + + template ::value>::type> + static void PrintValue(const T& value, ::std::ostream* os) { + std::string pretty_str = value.ShortDebugString(); + if (pretty_str.length() > kProtobufOneLinerMaxLength) { + pretty_str = "\n" + value.DebugString(); + } + *os << ("<" + pretty_str + ">"); + } +}; + +struct ConvertibleToIntegerPrinter { + // Since T has no << operator or PrintTo() but can be implicitly + // converted to BiggestInt, we print it as a BiggestInt. + // + // Most likely T is an enum type (either named or unnamed), in which + // case printing it as an integer is the desired behavior. In case + // T is not an enum, printing it as an integer is the best we can do + // given that it has no user-defined printer. + static void PrintValue(internal::BiggestInt value, ::std::ostream* os) { + *os << value; + } +}; + +struct ConvertibleToStringViewPrinter { +#if GTEST_INTERNAL_HAS_STRING_VIEW + static void PrintValue(internal::StringView value, ::std::ostream* os) { + internal::UniversalPrint(value, os); + } +#endif +}; + + +// Prints the given number of bytes in the given object to the given +// ostream. +GTEST_API_ void PrintBytesInObjectTo(const unsigned char* obj_bytes, + size_t count, + ::std::ostream* os); +struct RawBytesPrinter { + // SFINAE on `sizeof` to make sure we have a complete type. + template + static void PrintValue(const T& value, ::std::ostream* os) { + PrintBytesInObjectTo( + static_cast( + // Load bearing cast to void* to support iOS + reinterpret_cast(std::addressof(value))), + sizeof(value), os); + } +}; + +struct FallbackPrinter { + template + static void PrintValue(const T&, ::std::ostream* os) { + *os << "(incomplete type)"; + } +}; + +// Try every printer in order and return the first one that works. +template +struct FindFirstPrinter : FindFirstPrinter {}; + +template +struct FindFirstPrinter< + T, decltype(Printer::PrintValue(std::declval(), nullptr)), + Printer, Printers...> { + using type = Printer; +}; + +// Select the best printer in the following order: +// - Print containers (they have begin/end/etc). +// - Print function pointers. +// - Print object pointers. +// - Use the stream operator, if available. +// - Print protocol buffers. +// - Print types convertible to BiggestInt. +// - Print types convertible to StringView, if available. +// - Fallback to printing the raw bytes of the object. +template +void PrintWithFallback(const T& value, ::std::ostream* os) { + using Printer = typename FindFirstPrinter< + T, void, ContainerPrinter, FunctionPointerPrinter, PointerPrinter, + internal_stream_operator_without_lexical_name_lookup::StreamPrinter, + ProtobufPrinter, ConvertibleToIntegerPrinter, + ConvertibleToStringViewPrinter, RawBytesPrinter, FallbackPrinter>::type; + Printer::PrintValue(value, os); +} + +// FormatForComparison::Format(value) formats a +// value of type ToPrint that is an operand of a comparison assertion +// (e.g. ASSERT_EQ). OtherOperand is the type of the other operand in +// the comparison, and is used to help determine the best way to +// format the value. In particular, when the value is a C string +// (char pointer) and the other operand is an STL string object, we +// want to format the C string as a string, since we know it is +// compared by value with the string object. If the value is a char +// pointer but the other operand is not an STL string object, we don't +// know whether the pointer is supposed to point to a NUL-terminated +// string, and thus want to print it as a pointer to be safe. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + +// The default case. +template +class FormatForComparison { + public: + static ::std::string Format(const ToPrint& value) { + return ::testing::PrintToString(value); + } +}; + +// Array. +template +class FormatForComparison { + public: + static ::std::string Format(const ToPrint* value) { + return FormatForComparison::Format(value); + } +}; + +// By default, print C string as pointers to be safe, as we don't know +// whether they actually point to a NUL-terminated string. + +#define GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(CharType) \ + template \ + class FormatForComparison { \ + public: \ + static ::std::string Format(CharType* value) { \ + return ::testing::PrintToString(static_cast(value)); \ + } \ + } + +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(char); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const char); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(wchar_t); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const wchar_t); +#ifdef __cpp_lib_char8_t +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(char8_t); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const char8_t); +#endif +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(char16_t); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const char16_t); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(char32_t); +GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_(const char32_t); + +#undef GTEST_IMPL_FORMAT_C_STRING_AS_POINTER_ + +// If a C string is compared with an STL string object, we know it's meant +// to point to a NUL-terminated string, and thus can print it as a string. + +#define GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(CharType, OtherStringType) \ + template <> \ + class FormatForComparison { \ + public: \ + static ::std::string Format(CharType* value) { \ + return ::testing::PrintToString(value); \ + } \ + } + +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(char, ::std::string); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const char, ::std::string); +#ifdef __cpp_char8_t +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(char8_t, ::std::u8string); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const char8_t, ::std::u8string); +#endif +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(char16_t, ::std::u16string); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const char16_t, ::std::u16string); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(char32_t, ::std::u32string); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const char32_t, ::std::u32string); + +#if GTEST_HAS_STD_WSTRING +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(wchar_t, ::std::wstring); +GTEST_IMPL_FORMAT_C_STRING_AS_STRING_(const wchar_t, ::std::wstring); +#endif + +#undef GTEST_IMPL_FORMAT_C_STRING_AS_STRING_ + +// Formats a comparison assertion (e.g. ASSERT_EQ, EXPECT_LT, and etc) +// operand to be used in a failure message. The type (but not value) +// of the other operand may affect the format. This allows us to +// print a char* as a raw pointer when it is compared against another +// char* or void*, and print it as a C string when it is compared +// against an std::string object, for example. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +template +std::string FormatForComparisonFailureMessage( + const T1& value, const T2& /* other_operand */) { + return FormatForComparison::Format(value); +} + +// UniversalPrinter::Print(value, ostream_ptr) prints the given +// value to the given ostream. The caller must ensure that +// 'ostream_ptr' is not NULL, or the behavior is undefined. +// +// We define UniversalPrinter as a class template (as opposed to a +// function template), as we need to partially specialize it for +// reference types, which cannot be done with function templates. +template +class UniversalPrinter; + +// Prints the given value using the << operator if it has one; +// otherwise prints the bytes in it. This is what +// UniversalPrinter::Print() does when PrintTo() is not specialized +// or overloaded for type T. +// +// A user can override this behavior for a class type Foo by defining +// an overload of PrintTo() in the namespace where Foo is defined. We +// give the user this option as sometimes defining a << operator for +// Foo is not desirable (e.g. the coding style may prevent doing it, +// or there is already a << operator but it doesn't do what the user +// wants). +template +void PrintTo(const T& value, ::std::ostream* os) { + internal::PrintWithFallback(value, os); +} + +// The following list of PrintTo() overloads tells +// UniversalPrinter::Print() how to print standard types (built-in +// types, strings, plain arrays, and pointers). + +// Overloads for various char types. +GTEST_API_ void PrintTo(unsigned char c, ::std::ostream* os); +GTEST_API_ void PrintTo(signed char c, ::std::ostream* os); +inline void PrintTo(char c, ::std::ostream* os) { + // When printing a plain char, we always treat it as unsigned. This + // way, the output won't be affected by whether the compiler thinks + // char is signed or not. + PrintTo(static_cast(c), os); +} + +// Overloads for other simple built-in types. +inline void PrintTo(bool x, ::std::ostream* os) { + *os << (x ? "true" : "false"); +} + +// Overload for wchar_t type. +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its decimal code (except for L'\0'). +// The L'\0' char is printed as "L'\\0'". The decimal code is printed +// as signed integer when wchar_t is implemented by the compiler +// as a signed type and is printed as an unsigned integer when wchar_t +// is implemented as an unsigned type. +GTEST_API_ void PrintTo(wchar_t wc, ::std::ostream* os); + +GTEST_API_ void PrintTo(char32_t c, ::std::ostream* os); +inline void PrintTo(char16_t c, ::std::ostream* os) { + PrintTo(ImplicitCast_(c), os); +} +#ifdef __cpp_char8_t +inline void PrintTo(char8_t c, ::std::ostream* os) { + PrintTo(ImplicitCast_(c), os); +} +#endif + +// gcc/clang __{u,}int128_t +#if defined(__SIZEOF_INT128__) +GTEST_API_ void PrintTo(__uint128_t v, ::std::ostream* os); +GTEST_API_ void PrintTo(__int128_t v, ::std::ostream* os); +#endif // __SIZEOF_INT128__ + +// Overloads for C strings. +GTEST_API_ void PrintTo(const char* s, ::std::ostream* os); +inline void PrintTo(char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// signed/unsigned char is often used for representing binary data, so +// we print pointers to it as void* to be safe. +inline void PrintTo(const signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(const unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +#ifdef __cpp_char8_t +// Overloads for u8 strings. +GTEST_API_ void PrintTo(const char8_t* s, ::std::ostream* os); +inline void PrintTo(char8_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +#endif +// Overloads for u16 strings. +GTEST_API_ void PrintTo(const char16_t* s, ::std::ostream* os); +inline void PrintTo(char16_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +// Overloads for u32 strings. +GTEST_API_ void PrintTo(const char32_t* s, ::std::ostream* os); +inline void PrintTo(char32_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// MSVC can be configured to define wchar_t as a typedef of unsigned +// short. It defines _NATIVE_WCHAR_T_DEFINED when wchar_t is a native +// type. When wchar_t is a typedef, defining an overload for const +// wchar_t* would cause unsigned short* be printed as a wide string, +// possibly causing invalid memory accesses. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Overloads for wide C strings +GTEST_API_ void PrintTo(const wchar_t* s, ::std::ostream* os); +inline void PrintTo(wchar_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +#endif + +// Overload for C arrays. Multi-dimensional arrays are printed +// properly. + +// Prints the given number of elements in an array, without printing +// the curly braces. +template +void PrintRawArrayTo(const T a[], size_t count, ::std::ostream* os) { + UniversalPrint(a[0], os); + for (size_t i = 1; i != count; i++) { + *os << ", "; + UniversalPrint(a[i], os); + } +} + +// Overloads for ::std::string. +GTEST_API_ void PrintStringTo(const ::std::string&s, ::std::ostream* os); +inline void PrintTo(const ::std::string& s, ::std::ostream* os) { + PrintStringTo(s, os); +} + +// Overloads for ::std::u8string +#ifdef __cpp_char8_t +GTEST_API_ void PrintU8StringTo(const ::std::u8string& s, ::std::ostream* os); +inline void PrintTo(const ::std::u8string& s, ::std::ostream* os) { + PrintU8StringTo(s, os); +} +#endif + +// Overloads for ::std::u16string +GTEST_API_ void PrintU16StringTo(const ::std::u16string& s, ::std::ostream* os); +inline void PrintTo(const ::std::u16string& s, ::std::ostream* os) { + PrintU16StringTo(s, os); +} + +// Overloads for ::std::u32string +GTEST_API_ void PrintU32StringTo(const ::std::u32string& s, ::std::ostream* os); +inline void PrintTo(const ::std::u32string& s, ::std::ostream* os) { + PrintU32StringTo(s, os); +} + +// Overloads for ::std::wstring. +#if GTEST_HAS_STD_WSTRING +GTEST_API_ void PrintWideStringTo(const ::std::wstring&s, ::std::ostream* os); +inline void PrintTo(const ::std::wstring& s, ::std::ostream* os) { + PrintWideStringTo(s, os); +} +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_INTERNAL_HAS_STRING_VIEW +// Overload for internal::StringView. +inline void PrintTo(internal::StringView sp, ::std::ostream* os) { + PrintTo(::std::string(sp), os); +} +#endif // GTEST_INTERNAL_HAS_STRING_VIEW + +inline void PrintTo(std::nullptr_t, ::std::ostream* os) { *os << "(nullptr)"; } + +#if GTEST_HAS_RTTI +inline void PrintTo(const std::type_info& info, std::ostream* os) { + *os << internal::GetTypeName(info); +} +#endif // GTEST_HAS_RTTI + +template +void PrintTo(std::reference_wrapper ref, ::std::ostream* os) { + UniversalPrinter::Print(ref.get(), os); +} + +inline const void* VoidifyPointer(const void* p) { return p; } +inline const void* VoidifyPointer(volatile const void* p) { + return const_cast(p); +} + +template +void PrintSmartPointer(const Ptr& ptr, std::ostream* os, char) { + if (ptr == nullptr) { + *os << "(nullptr)"; + } else { + // We can't print the value. Just print the pointer.. + *os << "(" << (VoidifyPointer)(ptr.get()) << ")"; + } +} +template ::value && + !std::is_array::value>::type> +void PrintSmartPointer(const Ptr& ptr, std::ostream* os, int) { + if (ptr == nullptr) { + *os << "(nullptr)"; + } else { + *os << "(ptr = " << (VoidifyPointer)(ptr.get()) << ", value = "; + UniversalPrinter::Print(*ptr, os); + *os << ")"; + } +} + +template +void PrintTo(const std::unique_ptr& ptr, std::ostream* os) { + (PrintSmartPointer)(ptr, os, 0); +} + +template +void PrintTo(const std::shared_ptr& ptr, std::ostream* os) { + (PrintSmartPointer)(ptr, os, 0); +} + +// Helper function for printing a tuple. T must be instantiated with +// a tuple type. +template +void PrintTupleTo(const T&, std::integral_constant, + ::std::ostream*) {} + +template +void PrintTupleTo(const T& t, std::integral_constant, + ::std::ostream* os) { + PrintTupleTo(t, std::integral_constant(), os); + GTEST_INTENTIONAL_CONST_COND_PUSH_() + if (I > 1) { + GTEST_INTENTIONAL_CONST_COND_POP_() + *os << ", "; + } + UniversalPrinter::type>::Print( + std::get(t), os); +} + +template +void PrintTo(const ::std::tuple& t, ::std::ostream* os) { + *os << "("; + PrintTupleTo(t, std::integral_constant(), os); + *os << ")"; +} + +// Overload for std::pair. +template +void PrintTo(const ::std::pair& value, ::std::ostream* os) { + *os << '('; + // We cannot use UniversalPrint(value.first, os) here, as T1 may be + // a reference type. The same for printing value.second. + UniversalPrinter::Print(value.first, os); + *os << ", "; + UniversalPrinter::Print(value.second, os); + *os << ')'; +} + +// Implements printing a non-reference type T by letting the compiler +// pick the right overload of PrintTo() for T. +template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4180) + + // Note: we deliberately don't call this PrintTo(), as that name + // conflicts with ::testing::internal::PrintTo in the body of the + // function. + static void Print(const T& value, ::std::ostream* os) { + // By default, ::testing::internal::PrintTo() is used for printing + // the value. + // + // Thanks to Koenig look-up, if T is a class and has its own + // PrintTo() function defined in its namespace, that function will + // be visible here. Since it is more specific than the generic ones + // in ::testing::internal, it will be picked by the compiler in the + // following statement - exactly what we want. + PrintTo(value, os); + } + + GTEST_DISABLE_MSC_WARNINGS_POP_() +}; + +// Remove any const-qualifiers before passing a type to UniversalPrinter. +template +class UniversalPrinter : public UniversalPrinter {}; + +#if GTEST_INTERNAL_HAS_ANY + +// Printer for std::any / absl::any + +template <> +class UniversalPrinter { + public: + static void Print(const Any& value, ::std::ostream* os) { + if (value.has_value()) { + *os << "value of type " << GetTypeName(value); + } else { + *os << "no value"; + } + } + + private: + static std::string GetTypeName(const Any& value) { +#if GTEST_HAS_RTTI + return internal::GetTypeName(value.type()); +#else + static_cast(value); // possibly unused + return ""; +#endif // GTEST_HAS_RTTI + } +}; + +#endif // GTEST_INTERNAL_HAS_ANY + +#if GTEST_INTERNAL_HAS_OPTIONAL + +// Printer for std::optional / absl::optional + +template +class UniversalPrinter> { + public: + static void Print(const Optional& value, ::std::ostream* os) { + *os << '('; + if (!value) { + *os << "nullopt"; + } else { + UniversalPrint(*value, os); + } + *os << ')'; + } +}; + +template <> +class UniversalPrinter { + public: + static void Print(decltype(Nullopt()), ::std::ostream* os) { + *os << "(nullopt)"; + } +}; + +#endif // GTEST_INTERNAL_HAS_OPTIONAL + +#if GTEST_INTERNAL_HAS_VARIANT + +// Printer for std::variant / absl::variant + +template +class UniversalPrinter> { + public: + static void Print(const Variant& value, ::std::ostream* os) { + *os << '('; +#if GTEST_HAS_ABSL + absl::visit(Visitor{os, value.index()}, value); +#else + std::visit(Visitor{os, value.index()}, value); +#endif // GTEST_HAS_ABSL + *os << ')'; + } + + private: + struct Visitor { + template + void operator()(const U& u) const { + *os << "'" << GetTypeName() << "(index = " << index + << ")' with value "; + UniversalPrint(u, os); + } + ::std::ostream* os; + std::size_t index; + }; +}; + +#endif // GTEST_INTERNAL_HAS_VARIANT + +// UniversalPrintArray(begin, len, os) prints an array of 'len' +// elements, starting at address 'begin'. +template +void UniversalPrintArray(const T* begin, size_t len, ::std::ostream* os) { + if (len == 0) { + *os << "{}"; + } else { + *os << "{ "; + const size_t kThreshold = 18; + const size_t kChunkSize = 8; + // If the array has more than kThreshold elements, we'll have to + // omit some details by printing only the first and the last + // kChunkSize elements. + if (len <= kThreshold) { + PrintRawArrayTo(begin, len, os); + } else { + PrintRawArrayTo(begin, kChunkSize, os); + *os << ", ..., "; + PrintRawArrayTo(begin + len - kChunkSize, kChunkSize, os); + } + *os << " }"; + } +} +// This overload prints a (const) char array compactly. +GTEST_API_ void UniversalPrintArray( + const char* begin, size_t len, ::std::ostream* os); + +#ifdef __cpp_char8_t +// This overload prints a (const) char8_t array compactly. +GTEST_API_ void UniversalPrintArray(const char8_t* begin, size_t len, + ::std::ostream* os); +#endif + +// This overload prints a (const) char16_t array compactly. +GTEST_API_ void UniversalPrintArray(const char16_t* begin, size_t len, + ::std::ostream* os); + +// This overload prints a (const) char32_t array compactly. +GTEST_API_ void UniversalPrintArray(const char32_t* begin, size_t len, + ::std::ostream* os); + +// This overload prints a (const) wchar_t array compactly. +GTEST_API_ void UniversalPrintArray( + const wchar_t* begin, size_t len, ::std::ostream* os); + +// Implements printing an array type T[N]. +template +class UniversalPrinter { + public: + // Prints the given array, omitting some elements when there are too + // many. + static void Print(const T (&a)[N], ::std::ostream* os) { + UniversalPrintArray(a, N, os); + } +}; + +// Implements printing a reference type T&. +template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4180) + + static void Print(const T& value, ::std::ostream* os) { + // Prints the address of the value. We use reinterpret_cast here + // as static_cast doesn't compile when T is a function type. + *os << "@" << reinterpret_cast(&value) << " "; + + // Then prints the value itself. + UniversalPrint(value, os); + } + + GTEST_DISABLE_MSC_WARNINGS_POP_() +}; + +// Prints a value tersely: for a reference type, the referenced value +// (but not the address) is printed; for a (const) char pointer, the +// NUL-terminated string (but not the pointer) is printed. + +template +class UniversalTersePrinter { + public: + static void Print(const T& value, ::std::ostream* os) { + UniversalPrint(value, os); + } +}; +template +class UniversalTersePrinter { + public: + static void Print(const T& value, ::std::ostream* os) { + UniversalPrint(value, os); + } +}; +template +class UniversalTersePrinter { + public: + static void Print(const T (&value)[N], ::std::ostream* os) { + UniversalPrinter::Print(value, os); + } +}; +template <> +class UniversalTersePrinter { + public: + static void Print(const char* str, ::std::ostream* os) { + if (str == nullptr) { + *os << "NULL"; + } else { + UniversalPrint(std::string(str), os); + } + } +}; +template <> +class UniversalTersePrinter : public UniversalTersePrinter { +}; + +#ifdef __cpp_char8_t +template <> +class UniversalTersePrinter { + public: + static void Print(const char8_t* str, ::std::ostream* os) { + if (str == nullptr) { + *os << "NULL"; + } else { + UniversalPrint(::std::u8string(str), os); + } + } +}; +template <> +class UniversalTersePrinter + : public UniversalTersePrinter {}; +#endif + +template <> +class UniversalTersePrinter { + public: + static void Print(const char16_t* str, ::std::ostream* os) { + if (str == nullptr) { + *os << "NULL"; + } else { + UniversalPrint(::std::u16string(str), os); + } + } +}; +template <> +class UniversalTersePrinter + : public UniversalTersePrinter {}; + +template <> +class UniversalTersePrinter { + public: + static void Print(const char32_t* str, ::std::ostream* os) { + if (str == nullptr) { + *os << "NULL"; + } else { + UniversalPrint(::std::u32string(str), os); + } + } +}; +template <> +class UniversalTersePrinter + : public UniversalTersePrinter {}; + +#if GTEST_HAS_STD_WSTRING +template <> +class UniversalTersePrinter { + public: + static void Print(const wchar_t* str, ::std::ostream* os) { + if (str == nullptr) { + *os << "NULL"; + } else { + UniversalPrint(::std::wstring(str), os); + } + } +}; +#endif + +template <> +class UniversalTersePrinter { + public: + static void Print(wchar_t* str, ::std::ostream* os) { + UniversalTersePrinter::Print(str, os); + } +}; + +template +void UniversalTersePrint(const T& value, ::std::ostream* os) { + UniversalTersePrinter::Print(value, os); +} + +// Prints a value using the type inferred by the compiler. The +// difference between this and UniversalTersePrint() is that for a +// (const) char pointer, this prints both the pointer and the +// NUL-terminated string. +template +void UniversalPrint(const T& value, ::std::ostream* os) { + // A workarond for the bug in VC++ 7.1 that prevents us from instantiating + // UniversalPrinter with T directly. + typedef T T1; + UniversalPrinter::Print(value, os); +} + +typedef ::std::vector< ::std::string> Strings; + + // Tersely prints the first N fields of a tuple to a string vector, + // one element for each field. +template +void TersePrintPrefixToStrings(const Tuple&, std::integral_constant, + Strings*) {} +template +void TersePrintPrefixToStrings(const Tuple& t, + std::integral_constant, + Strings* strings) { + TersePrintPrefixToStrings(t, std::integral_constant(), + strings); + ::std::stringstream ss; + UniversalTersePrint(std::get(t), &ss); + strings->push_back(ss.str()); +} + +// Prints the fields of a tuple tersely to a string vector, one +// element for each field. See the comment before +// UniversalTersePrint() for how we define "tersely". +template +Strings UniversalTersePrintTupleFieldsToStrings(const Tuple& value) { + Strings result; + TersePrintPrefixToStrings( + value, std::integral_constant::value>(), + &result); + return result; +} + +} // namespace internal + +template +::std::string PrintToString(const T& value) { + ::std::stringstream ss; + internal::UniversalTersePrinter::Print(value, &ss); + return ss.str(); +} + +} // namespace testing + +// Include any custom printer added by the local installation. +// We must include this header at the end to make sure it can use the +// declarations from this file. +#include "gtest/internal/custom/gtest-printers.h" + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-spi.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-spi.h new file mode 100644 index 000000000000..12c94e4a28da --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-spi.h @@ -0,0 +1,235 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Utilities for testing Google Test itself and code that uses Google Test +// (e.g. frameworks built on top of Google Test). + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_SPI_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_SPI_H_ + +#include "gtest/gtest.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { + +// This helper class can be used to mock out Google Test failure reporting +// so that we can test Google Test or code that builds on Google Test. +// +// An object of this class appends a TestPartResult object to the +// TestPartResultArray object given in the constructor whenever a Google Test +// failure is reported. It can either intercept only failures that are +// generated in the same thread that created this object or it can intercept +// all generated failures. The scope of this mock object can be controlled with +// the second argument to the two arguments constructor. +class GTEST_API_ ScopedFakeTestPartResultReporter + : public TestPartResultReporterInterface { + public: + // The two possible mocking modes of this object. + enum InterceptMode { + INTERCEPT_ONLY_CURRENT_THREAD, // Intercepts only thread local failures. + INTERCEPT_ALL_THREADS // Intercepts all failures. + }; + + // The c'tor sets this object as the test part result reporter used + // by Google Test. The 'result' parameter specifies where to report the + // results. This reporter will only catch failures generated in the current + // thread. DEPRECATED + explicit ScopedFakeTestPartResultReporter(TestPartResultArray* result); + + // Same as above, but you can choose the interception scope of this object. + ScopedFakeTestPartResultReporter(InterceptMode intercept_mode, + TestPartResultArray* result); + + // The d'tor restores the previous test part result reporter. + ~ScopedFakeTestPartResultReporter() override; + + // Appends the TestPartResult object to the TestPartResultArray + // received in the constructor. + // + // This method is from the TestPartResultReporterInterface + // interface. + void ReportTestPartResult(const TestPartResult& result) override; + + private: + void Init(); + + const InterceptMode intercept_mode_; + TestPartResultReporterInterface* old_reporter_; + TestPartResultArray* const result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedFakeTestPartResultReporter); +}; + +namespace internal { + +// A helper class for implementing EXPECT_FATAL_FAILURE() and +// EXPECT_NONFATAL_FAILURE(). Its destructor verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. +class GTEST_API_ SingleFailureChecker { + public: + // The constructor remembers the arguments. + SingleFailureChecker(const TestPartResultArray* results, + TestPartResult::Type type, const std::string& substr); + ~SingleFailureChecker(); + private: + const TestPartResultArray* const results_; + const TestPartResult::Type type_; + const std::string substr_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(SingleFailureChecker); +}; + +} // namespace internal + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +// A set of macros for testing Google Test assertions or code that's expected +// to generate Google Test fatal failures. It verifies that the given +// statement will cause exactly one fatal Google Test failure with 'substr' +// being part of the failure message. +// +// There are two different versions of this macro. EXPECT_FATAL_FAILURE only +// affects and considers failures generated in the current thread and +// EXPECT_FATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. +// +// The verification of the assertion is done correctly even when the statement +// throws an exception or aborts the current function. +// +// Known restrictions: +// - 'statement' cannot reference local non-static variables or +// non-static members of the current object. +// - 'statement' cannot return a value. +// - You cannot stream a failure message to this macro. +// +// Note that even though the implementations of the following two +// macros are much alike, we cannot refactor them to use a common +// helper macro, due to some peculiarity in how the preprocessor +// works. The AcceptsMacroThatExpandsToUnprotectedComma test in +// gtest_unittest.cc will fail to compile if we do that. +#define EXPECT_FATAL_FAILURE(statement, substr) \ + do { \ + class GTestExpectFatalFailureHelper {\ + public:\ + static void Execute() { statement; }\ + };\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ONLY_CURRENT_THREAD, >est_failures);\ + GTestExpectFatalFailureHelper::Execute();\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#define EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ + do { \ + class GTestExpectFatalFailureHelper {\ + public:\ + static void Execute() { statement; }\ + };\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kFatalFailure, (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ALL_THREADS, >est_failures);\ + GTestExpectFatalFailureHelper::Execute();\ + }\ + } while (::testing::internal::AlwaysFalse()) + +// A macro for testing Google Test assertions or code that's expected to +// generate Google Test non-fatal failures. It asserts that the given +// statement will cause exactly one non-fatal Google Test failure with 'substr' +// being part of the failure message. +// +// There are two different versions of this macro. EXPECT_NONFATAL_FAILURE only +// affects and considers failures generated in the current thread and +// EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. +// +// 'statement' is allowed to reference local variables and members of +// the current object. +// +// The verification of the assertion is done correctly even when the statement +// throws an exception or aborts the current function. +// +// Known restrictions: +// - You cannot stream a failure message to this macro. +// +// Note that even though the implementations of the following two +// macros are much alike, we cannot refactor them to use a common +// helper macro, due to some peculiarity in how the preprocessor +// works. If we do that, the code won't compile when the user gives +// EXPECT_NONFATAL_FAILURE() a statement that contains a macro that +// expands to code containing an unprotected comma. The +// AcceptsMacroThatExpandsToUnprotectedComma test in gtest_unittest.cc +// catches that. +// +// For the same reason, we have to write +// if (::testing::internal::AlwaysTrue()) { statement; } +// instead of +// GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) +// to avoid an MSVC warning on unreachable code. +#define EXPECT_NONFATAL_FAILURE(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ONLY_CURRENT_THREAD, >est_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#define EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter::INTERCEPT_ALL_THREADS, \ + >est_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_SPI_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-test-part.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-test-part.h new file mode 100644 index 000000000000..39393b212cfa --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-test-part.h @@ -0,0 +1,186 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ + +#include +#include +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-string.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { + +// A copyable object representing the result of a test part (i.e. an +// assertion or an explicit FAIL(), ADD_FAILURE(), or SUCCESS()). +// +// Don't inherit from TestPartResult as its destructor is not virtual. +class GTEST_API_ TestPartResult { + public: + // The possible outcomes of a test part (i.e. an assertion or an + // explicit SUCCEED(), FAIL(), or ADD_FAILURE()). + enum Type { + kSuccess, // Succeeded. + kNonFatalFailure, // Failed but the test can continue. + kFatalFailure, // Failed and the test should be terminated. + kSkip // Skipped. + }; + + // C'tor. TestPartResult does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestPartResult object. + TestPartResult(Type a_type, const char* a_file_name, int a_line_number, + const char* a_message) + : type_(a_type), + file_name_(a_file_name == nullptr ? "" : a_file_name), + line_number_(a_line_number), + summary_(ExtractSummary(a_message)), + message_(a_message) {} + + // Gets the outcome of the test part. + Type type() const { return type_; } + + // Gets the name of the source file where the test part took place, or + // NULL if it's unknown. + const char* file_name() const { + return file_name_.empty() ? nullptr : file_name_.c_str(); + } + + // Gets the line in the source file where the test part took place, + // or -1 if it's unknown. + int line_number() const { return line_number_; } + + // Gets the summary of the failure message. + const char* summary() const { return summary_.c_str(); } + + // Gets the message associated with the test part. + const char* message() const { return message_.c_str(); } + + // Returns true if and only if the test part was skipped. + bool skipped() const { return type_ == kSkip; } + + // Returns true if and only if the test part passed. + bool passed() const { return type_ == kSuccess; } + + // Returns true if and only if the test part non-fatally failed. + bool nonfatally_failed() const { return type_ == kNonFatalFailure; } + + // Returns true if and only if the test part fatally failed. + bool fatally_failed() const { return type_ == kFatalFailure; } + + // Returns true if and only if the test part failed. + bool failed() const { return fatally_failed() || nonfatally_failed(); } + + private: + Type type_; + + // Gets the summary of the failure message by omitting the stack + // trace in it. + static std::string ExtractSummary(const char* message); + + // The name of the source file where the test part took place, or + // "" if the source file is unknown. + std::string file_name_; + // The line in the source file where the test part took place, or -1 + // if the line number is unknown. + int line_number_; + std::string summary_; // The test failure summary. + std::string message_; // The test failure message. +}; + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result); + +// An array of TestPartResult objects. +// +// Don't inherit from TestPartResultArray as its destructor is not +// virtual. +class GTEST_API_ TestPartResultArray { + public: + TestPartResultArray() {} + + // Appends the given TestPartResult to the array. + void Append(const TestPartResult& result); + + // Returns the TestPartResult at the given index (0-based). + const TestPartResult& GetTestPartResult(int index) const; + + // Returns the number of TestPartResult objects in the array. + int size() const; + + private: + std::vector array_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestPartResultArray); +}; + +// This interface knows how to report a test part result. +class GTEST_API_ TestPartResultReporterInterface { + public: + virtual ~TestPartResultReporterInterface() {} + + virtual void ReportTestPartResult(const TestPartResult& result) = 0; +}; + +namespace internal { + +// This helper class is used by {ASSERT|EXPECT}_NO_FATAL_FAILURE to check if a +// statement generates new fatal failures. To do so it registers itself as the +// current test part result reporter. Besides checking if fatal failures were +// reported, it only delegates the reporting to the former result reporter. +// The original result reporter is restored in the destructor. +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +class GTEST_API_ HasNewFatalFailureHelper + : public TestPartResultReporterInterface { + public: + HasNewFatalFailureHelper(); + ~HasNewFatalFailureHelper() override; + void ReportTestPartResult(const TestPartResult& result) override; + bool has_new_fatal_failure() const { return has_new_fatal_failure_; } + private: + bool has_new_fatal_failure_; + TestPartResultReporterInterface* original_reporter_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(HasNewFatalFailureHelper); +}; + +} // namespace internal + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-typed-test.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-typed-test.h new file mode 100644 index 000000000000..343bf6fe98c1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest-typed-test.h @@ -0,0 +1,331 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ + +// This header implements typed tests and type-parameterized tests. + +// Typed (aka type-driven) tests repeat the same test for types in a +// list. You must know which types you want to test with when writing +// typed tests. Here's how you do it: + +#if 0 + +// First, define a fixture class template. It should be parameterized +// by a type. Remember to derive it from testing::Test. +template +class FooTest : public testing::Test { + public: + ... + typedef std::list List; + static T shared_; + T value_; +}; + +// Next, associate a list of types with the test suite, which will be +// repeated for each type in the list. The typedef is necessary for +// the macro to parse correctly. +typedef testing::Types MyTypes; +TYPED_TEST_SUITE(FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// TYPED_TEST_SUITE(FooTest, int); + +// Then, use TYPED_TEST() instead of TEST_F() to define as many typed +// tests for this test suite as you want. +TYPED_TEST(FooTest, DoesBlah) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of FooTest via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the TestFixture:: + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the "typename + // TestFixture::" prefix. + typename TestFixture::List values; + values.push_back(n); + ... +} + +TYPED_TEST(FooTest, HasPropertyA) { ... } + +// TYPED_TEST_SUITE takes an optional third argument which allows to specify a +// class that generates custom test name suffixes based on the type. This should +// be a class which has a static template function GetName(int index) returning +// a string for each type. The provided integer index equals the index of the +// type in the provided type list. In many cases the index can be ignored. +// +// For example: +// class MyTypeNames { +// public: +// template +// static std::string GetName(int) { +// if (std::is_same()) return "char"; +// if (std::is_same()) return "int"; +// if (std::is_same()) return "unsignedInt"; +// } +// }; +// TYPED_TEST_SUITE(FooTest, MyTypes, MyTypeNames); + +#endif // 0 + +// Type-parameterized tests are abstract test patterns parameterized +// by a type. Compared with typed tests, type-parameterized tests +// allow you to define the test pattern without knowing what the type +// parameters are. The defined pattern can be instantiated with +// different types any number of times, in any number of translation +// units. +// +// If you are designing an interface or concept, you can define a +// suite of type-parameterized tests to verify properties that any +// valid implementation of the interface/concept should have. Then, +// each implementation can easily instantiate the test suite to verify +// that it conforms to the requirements, without having to write +// similar tests repeatedly. Here's an example: + +#if 0 + +// First, define a fixture class template. It should be parameterized +// by a type. Remember to derive it from testing::Test. +template +class FooTest : public testing::Test { + ... +}; + +// Next, declare that you will define a type-parameterized test suite +// (the _P suffix is for "parameterized" or "pattern", whichever you +// prefer): +TYPED_TEST_SUITE_P(FooTest); + +// Then, use TYPED_TEST_P() to define as many type-parameterized tests +// for this type-parameterized test suite as you want. +TYPED_TEST_P(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + TypeParam n = 0; + ... +} + +TYPED_TEST_P(FooTest, HasPropertyA) { ... } + +// Now the tricky part: you need to register all test patterns before +// you can instantiate them. The first argument of the macro is the +// test suite name; the rest are the names of the tests in this test +// case. +REGISTER_TYPED_TEST_SUITE_P(FooTest, + DoesBlah, HasPropertyA); + +// Finally, you are free to instantiate the pattern with the types you +// want. If you put the above code in a header file, you can #include +// it in multiple C++ source files and instantiate it multiple times. +// +// To distinguish different instances of the pattern, the first +// argument to the INSTANTIATE_* macro is a prefix that will be added +// to the actual test suite name. Remember to pick unique prefixes for +// different instances. +typedef testing::Types MyTypes; +INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, int); +// +// Similar to the optional argument of TYPED_TEST_SUITE above, +// INSTANTIATE_TEST_SUITE_P takes an optional fourth argument which allows to +// generate custom names. +// INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes, MyTypeNames); + +#endif // 0 + +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" +#include "gtest/internal/gtest-type-util.h" + +// Implements typed tests. + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the typedef for the type parameters of the +// given test suite. +#define GTEST_TYPE_PARAMS_(TestSuiteName) gtest_type_params_##TestSuiteName##_ + +// Expands to the name of the typedef for the NameGenerator, responsible for +// creating the suffixes of the name. +#define GTEST_NAME_GENERATOR_(TestSuiteName) \ + gtest_type_params_##TestSuiteName##_NameGenerator + +#define TYPED_TEST_SUITE(CaseName, Types, ...) \ + typedef ::testing::internal::GenerateTypeList::type \ + GTEST_TYPE_PARAMS_(CaseName); \ + typedef ::testing::internal::NameGeneratorSelector<__VA_ARGS__>::type \ + GTEST_NAME_GENERATOR_(CaseName) + +#define TYPED_TEST(CaseName, TestName) \ + static_assert(sizeof(GTEST_STRINGIFY_(TestName)) > 1, \ + "test-name must not be empty"); \ + template \ + class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ + : public CaseName { \ + private: \ + typedef CaseName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + void TestBody() override; \ + }; \ + static bool gtest_##CaseName##_##TestName##_registered_ \ + GTEST_ATTRIBUTE_UNUSED_ = ::testing::internal::TypeParameterizedTest< \ + CaseName, \ + ::testing::internal::TemplateSel, \ + GTEST_TYPE_PARAMS_( \ + CaseName)>::Register("", \ + ::testing::internal::CodeLocation( \ + __FILE__, __LINE__), \ + GTEST_STRINGIFY_(CaseName), \ + GTEST_STRINGIFY_(TestName), 0, \ + ::testing::internal::GenerateNames< \ + GTEST_NAME_GENERATOR_(CaseName), \ + GTEST_TYPE_PARAMS_(CaseName)>()); \ + template \ + void GTEST_TEST_CLASS_NAME_(CaseName, \ + TestName)::TestBody() + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +#define TYPED_TEST_CASE \ + static_assert(::testing::internal::TypedTestCaseIsDeprecated(), ""); \ + TYPED_TEST_SUITE +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +// Implements type-parameterized tests. + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the namespace name that the type-parameterized tests for +// the given type-parameterized test suite are defined in. The exact +// name of the namespace is subject to change without notice. +#define GTEST_SUITE_NAMESPACE_(TestSuiteName) gtest_suite_##TestSuiteName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the variable used to remember the names of +// the defined tests in the given test suite. +#define GTEST_TYPED_TEST_SUITE_P_STATE_(TestSuiteName) \ + gtest_typed_test_suite_p_state_##TestSuiteName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE DIRECTLY. +// +// Expands to the name of the variable used to remember the names of +// the registered tests in the given test suite. +#define GTEST_REGISTERED_TEST_NAMES_(TestSuiteName) \ + gtest_registered_test_names_##TestSuiteName##_ + +// The variables defined in the type-parameterized test macros are +// static as typically these macros are used in a .h file that can be +// #included in multiple translation units linked together. +#define TYPED_TEST_SUITE_P(SuiteName) \ + static ::testing::internal::TypedTestSuitePState \ + GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName) + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +#define TYPED_TEST_CASE_P \ + static_assert(::testing::internal::TypedTestCase_P_IsDeprecated(), ""); \ + TYPED_TEST_SUITE_P +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +#define TYPED_TEST_P(SuiteName, TestName) \ + namespace GTEST_SUITE_NAMESPACE_(SuiteName) { \ + template \ + class TestName : public SuiteName { \ + private: \ + typedef SuiteName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + void TestBody() override; \ + }; \ + static bool gtest_##TestName##_defined_ GTEST_ATTRIBUTE_UNUSED_ = \ + GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName).AddTestName( \ + __FILE__, __LINE__, GTEST_STRINGIFY_(SuiteName), \ + GTEST_STRINGIFY_(TestName)); \ + } \ + template \ + void GTEST_SUITE_NAMESPACE_( \ + SuiteName)::TestName::TestBody() + +// Note: this won't work correctly if the trailing arguments are macros. +#define REGISTER_TYPED_TEST_SUITE_P(SuiteName, ...) \ + namespace GTEST_SUITE_NAMESPACE_(SuiteName) { \ + typedef ::testing::internal::Templates<__VA_ARGS__> gtest_AllTests_; \ + } \ + static const char* const GTEST_REGISTERED_TEST_NAMES_( \ + SuiteName) GTEST_ATTRIBUTE_UNUSED_ = \ + GTEST_TYPED_TEST_SUITE_P_STATE_(SuiteName).VerifyRegisteredTestNames( \ + GTEST_STRINGIFY_(SuiteName), __FILE__, __LINE__, #__VA_ARGS__) + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +#define REGISTER_TYPED_TEST_CASE_P \ + static_assert(::testing::internal::RegisterTypedTestCase_P_IsDeprecated(), \ + ""); \ + REGISTER_TYPED_TEST_SUITE_P +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +#define INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, SuiteName, Types, ...) \ + static_assert(sizeof(GTEST_STRINGIFY_(Prefix)) > 1, \ + "test-suit-prefix must not be empty"); \ + static bool gtest_##Prefix##_##SuiteName GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTestSuite< \ + SuiteName, GTEST_SUITE_NAMESPACE_(SuiteName)::gtest_AllTests_, \ + ::testing::internal::GenerateTypeList::type>:: \ + Register(GTEST_STRINGIFY_(Prefix), \ + ::testing::internal::CodeLocation(__FILE__, __LINE__), \ + >EST_TYPED_TEST_SUITE_P_STATE_(SuiteName), \ + GTEST_STRINGIFY_(SuiteName), \ + GTEST_REGISTERED_TEST_NAMES_(SuiteName), \ + ::testing::internal::GenerateNames< \ + ::testing::internal::NameGeneratorSelector< \ + __VA_ARGS__>::type, \ + ::testing::internal::GenerateTypeList::type>()) + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +#define INSTANTIATE_TYPED_TEST_CASE_P \ + static_assert( \ + ::testing::internal::InstantiateTypedTestCase_P_IsDeprecated(), ""); \ + INSTANTIATE_TYPED_TEST_SUITE_P +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest.h new file mode 100644 index 000000000000..a4174cd4e535 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest.h @@ -0,0 +1,2316 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file defines the public API for Google Test. It should be +// included by any test program that uses Google Test. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user +// program! +// +// Acknowledgment: Google Test borrowed the idea of automatic test +// registration from Barthelemy Dagenais' (barthelemy@prologique.com) +// easyUnit framework. + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_H_ + +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest-assertion-result.h" +#include "gtest/gtest-death-test.h" +#include "gtest/gtest-matchers.h" +#include "gtest/gtest-message.h" +#include "gtest/gtest-param-test.h" +#include "gtest/gtest-printers.h" +#include "gtest/gtest-test-part.h" +#include "gtest/gtest-typed-test.h" +#include "gtest/gtest_pred_impl.h" +#include "gtest/gtest_prod.h" +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-string.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +// Declares the flags. + +// This flag temporary enables the disabled tests. +GTEST_DECLARE_bool_(also_run_disabled_tests); + +// This flag brings the debugger on an assertion failure. +GTEST_DECLARE_bool_(break_on_failure); + +// This flag controls whether Google Test catches all test-thrown exceptions +// and logs them as failures. +GTEST_DECLARE_bool_(catch_exceptions); + +// This flag enables using colors in terminal output. Available values are +// "yes" to enable colors, "no" (disable colors), or "auto" (the default) +// to let Google Test decide. +GTEST_DECLARE_string_(color); + +// This flag controls whether the test runner should continue execution past +// first failure. +GTEST_DECLARE_bool_(fail_fast); + +// This flag sets up the filter to select by name using a glob pattern +// the tests to run. If the filter is not given all tests are executed. +GTEST_DECLARE_string_(filter); + +// This flag controls whether Google Test installs a signal handler that dumps +// debugging information when fatal signals are raised. +GTEST_DECLARE_bool_(install_failure_signal_handler); + +// This flag causes the Google Test to list tests. None of the tests listed +// are actually run if the flag is provided. +GTEST_DECLARE_bool_(list_tests); + +// This flag controls whether Google Test emits a detailed XML report to a file +// in addition to its normal textual output. +GTEST_DECLARE_string_(output); + +// This flags control whether Google Test prints only test failures. +GTEST_DECLARE_bool_(brief); + +// This flags control whether Google Test prints the elapsed time for each +// test. +GTEST_DECLARE_bool_(print_time); + +// This flags control whether Google Test prints UTF8 characters as text. +GTEST_DECLARE_bool_(print_utf8); + +// This flag specifies the random number seed. +GTEST_DECLARE_int32_(random_seed); + +// This flag sets how many times the tests are repeated. The default value +// is 1. If the value is -1 the tests are repeating forever. +GTEST_DECLARE_int32_(repeat); + +// This flag controls whether Google Test Environments are recreated for each +// repeat of the tests. The default value is true. If set to false the global +// test Environment objects are only set up once, for the first iteration, and +// only torn down once, for the last. +GTEST_DECLARE_bool_(recreate_environments_when_repeating); + +// This flag controls whether Google Test includes Google Test internal +// stack frames in failure stack traces. +GTEST_DECLARE_bool_(show_internal_stack_frames); + +// When this flag is specified, tests' order is randomized on every iteration. +GTEST_DECLARE_bool_(shuffle); + +// This flag specifies the maximum number of stack frames to be +// printed in a failure message. +GTEST_DECLARE_int32_(stack_trace_depth); + +// When this flag is specified, a failed assertion will throw an +// exception if exceptions are enabled, or exit the program with a +// non-zero code otherwise. For use with an external test framework. +GTEST_DECLARE_bool_(throw_on_failure); + +// When this flag is set with a "host:port" string, on supported +// platforms test results are streamed to the specified port on +// the specified host machine. +GTEST_DECLARE_string_(stream_result_to); + +#if GTEST_USE_OWN_FLAGFILE_FLAG_ +GTEST_DECLARE_string_(flagfile); +#endif // GTEST_USE_OWN_FLAGFILE_FLAG_ + +namespace testing { + +// Silence C4100 (unreferenced formal parameter) and 4805 +// unsafe mix of type 'const int' and type 'const bool' +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4805) +#pragma warning(disable : 4100) +#endif + +// The upper limit for valid stack trace depths. +const int kMaxStackTraceDepth = 100; + +namespace internal { + +class AssertHelper; +class DefaultGlobalTestPartResultReporter; +class ExecDeathTest; +class NoExecDeathTest; +class FinalSuccessChecker; +class GTestFlagSaver; +class StreamingListenerTest; +class TestResultAccessor; +class TestEventListenersAccessor; +class TestEventRepeater; +class UnitTestRecordPropertyTestHelper; +class WindowsDeathTest; +class FuchsiaDeathTest; +class UnitTestImpl* GetUnitTestImpl(); +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const std::string& message); +std::set* GetIgnoredParameterizedTestSuites(); + +} // namespace internal + +// The friend relationship of some of these classes is cyclic. +// If we don't forward declare them the compiler might confuse the classes +// in friendship clauses with same named classes on the scope. +class Test; +class TestSuite; + +// Old API is still available but deprecated +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +using TestCase = TestSuite; +#endif +class TestInfo; +class UnitTest; + +// The abstract class that all tests inherit from. +// +// In Google Test, a unit test program contains one or many TestSuites, and +// each TestSuite contains one or many Tests. +// +// When you define a test using the TEST macro, you don't need to +// explicitly derive from Test - the TEST macro automatically does +// this for you. +// +// The only time you derive from Test is when defining a test fixture +// to be used in a TEST_F. For example: +// +// class FooTest : public testing::Test { +// protected: +// void SetUp() override { ... } +// void TearDown() override { ... } +// ... +// }; +// +// TEST_F(FooTest, Bar) { ... } +// TEST_F(FooTest, Baz) { ... } +// +// Test is not copyable. +class GTEST_API_ Test { + public: + friend class TestInfo; + + // The d'tor is virtual as we intend to inherit from Test. + virtual ~Test(); + + // Sets up the stuff shared by all tests in this test suite. + // + // Google Test will call Foo::SetUpTestSuite() before running the first + // test in test suite Foo. Hence a sub-class can define its own + // SetUpTestSuite() method to shadow the one defined in the super + // class. + static void SetUpTestSuite() {} + + // Tears down the stuff shared by all tests in this test suite. + // + // Google Test will call Foo::TearDownTestSuite() after running the last + // test in test suite Foo. Hence a sub-class can define its own + // TearDownTestSuite() method to shadow the one defined in the super + // class. + static void TearDownTestSuite() {} + + // Legacy API is deprecated but still available. Use SetUpTestSuite and + // TearDownTestSuite instead. +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + static void TearDownTestCase() {} + static void SetUpTestCase() {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Returns true if and only if the current test has a fatal failure. + static bool HasFatalFailure(); + + // Returns true if and only if the current test has a non-fatal failure. + static bool HasNonfatalFailure(); + + // Returns true if and only if the current test was skipped. + static bool IsSkipped(); + + // Returns true if and only if the current test has a (either fatal or + // non-fatal) failure. + static bool HasFailure() { return HasFatalFailure() || HasNonfatalFailure(); } + + // Logs a property for the current test, test suite, or for the entire + // invocation of the test program when used outside of the context of a + // test suite. Only the last value for a given key is remembered. These + // are public static so they can be called from utility functions that are + // not members of the test fixture. Calls to RecordProperty made during + // lifespan of the test (from the moment its constructor starts to the + // moment its destructor finishes) will be output in XML as attributes of + // the element. Properties recorded from fixture's + // SetUpTestSuite or TearDownTestSuite are logged as attributes of the + // corresponding element. Calls to RecordProperty made in the + // global context (before or after invocation of RUN_ALL_TESTS and from + // SetUp/TearDown method of Environment objects registered with Google + // Test) will be output as attributes of the element. + static void RecordProperty(const std::string& key, const std::string& value); + static void RecordProperty(const std::string& key, int value); + + protected: + // Creates a Test object. + Test(); + + // Sets up the test fixture. + virtual void SetUp(); + + // Tears down the test fixture. + virtual void TearDown(); + + private: + // Returns true if and only if the current test has the same fixture class + // as the first test in the current test suite. + static bool HasSameFixtureClass(); + + // Runs the test after the test fixture has been set up. + // + // A sub-class must implement this to define the test logic. + // + // DO NOT OVERRIDE THIS FUNCTION DIRECTLY IN A USER PROGRAM. + // Instead, use the TEST or TEST_F macro. + virtual void TestBody() = 0; + + // Sets up, executes, and tears down the test. + void Run(); + + // Deletes self. We deliberately pick an unusual name for this + // internal method to avoid clashing with names used in user TESTs. + void DeleteSelf_() { delete this; } + + const std::unique_ptr gtest_flag_saver_; + + // Often a user misspells SetUp() as Setup() and spends a long time + // wondering why it is never called by Google Test. The declaration of + // the following method is solely for catching such an error at + // compile time: + // + // - The return type is deliberately chosen to be not void, so it + // will be a conflict if void Setup() is declared in the user's + // test fixture. + // + // - This method is private, so it will be another compiler error + // if the method is called from the user's test fixture. + // + // DO NOT OVERRIDE THIS FUNCTION. + // + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return nullptr; } + + // We disallow copying Tests. + GTEST_DISALLOW_COPY_AND_ASSIGN_(Test); +}; + +typedef internal::TimeInMillis TimeInMillis; + +// A copyable object representing a user specified test property which can be +// output as a key/value string pair. +// +// Don't inherit from TestProperty as its destructor is not virtual. +class TestProperty { + public: + // C'tor. TestProperty does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestProperty object. + TestProperty(const std::string& a_key, const std::string& a_value) : + key_(a_key), value_(a_value) { + } + + // Gets the user supplied key. + const char* key() const { + return key_.c_str(); + } + + // Gets the user supplied value. + const char* value() const { + return value_.c_str(); + } + + // Sets a new value, overriding the one supplied in the constructor. + void SetValue(const std::string& new_value) { + value_ = new_value; + } + + private: + // The key supplied by the user. + std::string key_; + // The value supplied by the user. + std::string value_; +}; + +// The result of a single Test. This includes a list of +// TestPartResults, a list of TestProperties, a count of how many +// death tests there are in the Test, and how much time it took to run +// the Test. +// +// TestResult is not copyable. +class GTEST_API_ TestResult { + public: + // Creates an empty TestResult. + TestResult(); + + // D'tor. Do not inherit from TestResult. + ~TestResult(); + + // Gets the number of all test parts. This is the sum of the number + // of successful test parts and the number of failed test parts. + int total_part_count() const; + + // Returns the number of the test properties. + int test_property_count() const; + + // Returns true if and only if the test passed (i.e. no test part failed). + bool Passed() const { return !Skipped() && !Failed(); } + + // Returns true if and only if the test was skipped. + bool Skipped() const; + + // Returns true if and only if the test failed. + bool Failed() const; + + // Returns true if and only if the test fatally failed. + bool HasFatalFailure() const; + + // Returns true if and only if the test has a non-fatal failure. + bool HasNonfatalFailure() const; + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Gets the time of the test case start, in ms from the start of the + // UNIX epoch. + TimeInMillis start_timestamp() const { return start_timestamp_; } + + // Returns the i-th test part result among all the results. i can range from 0 + // to total_part_count() - 1. If i is not in that range, aborts the program. + const TestPartResult& GetTestPartResult(int i) const; + + // Returns the i-th test property. i can range from 0 to + // test_property_count() - 1. If i is not in that range, aborts the + // program. + const TestProperty& GetTestProperty(int i) const; + + private: + friend class TestInfo; + friend class TestSuite; + friend class UnitTest; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::ExecDeathTest; + friend class internal::TestResultAccessor; + friend class internal::UnitTestImpl; + friend class internal::WindowsDeathTest; + friend class internal::FuchsiaDeathTest; + + // Gets the vector of TestPartResults. + const std::vector& test_part_results() const { + return test_part_results_; + } + + // Gets the vector of TestProperties. + const std::vector& test_properties() const { + return test_properties_; + } + + // Sets the start time. + void set_start_timestamp(TimeInMillis start) { start_timestamp_ = start; } + + // Sets the elapsed time. + void set_elapsed_time(TimeInMillis elapsed) { elapsed_time_ = elapsed; } + + // Adds a test property to the list. The property is validated and may add + // a non-fatal failure if invalid (e.g., if it conflicts with reserved + // key names). If a property is already recorded for the same key, the + // value will be updated, rather than storing multiple values for the same + // key. xml_element specifies the element for which the property is being + // recorded and is used for validation. + void RecordProperty(const std::string& xml_element, + const TestProperty& test_property); + + // Adds a failure if the key is a reserved attribute of Google Test + // testsuite tags. Returns true if the property is valid. + // FIXME: Validate attribute names are legal and human readable. + static bool ValidateTestProperty(const std::string& xml_element, + const TestProperty& test_property); + + // Adds a test part result to the list. + void AddTestPartResult(const TestPartResult& test_part_result); + + // Returns the death test count. + int death_test_count() const { return death_test_count_; } + + // Increments the death test count, returning the new count. + int increment_death_test_count() { return ++death_test_count_; } + + // Clears the test part results. + void ClearTestPartResults(); + + // Clears the object. + void Clear(); + + // Protects mutable state of the property vector and of owned + // properties, whose values may be updated. + internal::Mutex test_properties_mutex_; + + // The vector of TestPartResults + std::vector test_part_results_; + // The vector of TestProperties + std::vector test_properties_; + // Running count of death tests. + int death_test_count_; + // The start time, in milliseconds since UNIX Epoch. + TimeInMillis start_timestamp_; + // The elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + + // We disallow copying TestResult. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestResult); +}; // class TestResult + +// A TestInfo object stores the following information about a test: +// +// Test suite name +// Test name +// Whether the test should be run +// A function pointer that creates the test object when invoked +// Test result +// +// The constructor of TestInfo registers itself with the UnitTest +// singleton such that the RUN_ALL_TESTS() macro knows which tests to +// run. +class GTEST_API_ TestInfo { + public: + // Destructs a TestInfo object. This function is not virtual, so + // don't inherit from TestInfo. + ~TestInfo(); + + // Returns the test suite name. + const char* test_suite_name() const { return test_suite_name_.c_str(); } + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + const char* test_case_name() const { return test_suite_name(); } +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Returns the test name. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a typed + // or a type-parameterized test. + const char* type_param() const { + if (type_param_.get() != nullptr) return type_param_->c_str(); + return nullptr; + } + + // Returns the text representation of the value parameter, or NULL if this + // is not a value-parameterized test. + const char* value_param() const { + if (value_param_.get() != nullptr) return value_param_->c_str(); + return nullptr; + } + + // Returns the file name where this test is defined. + const char* file() const { return location_.file.c_str(); } + + // Returns the line where this test is defined. + int line() const { return location_.line; } + + // Return true if this test should not be run because it's in another shard. + bool is_in_another_shard() const { return is_in_another_shard_; } + + // Returns true if this test should run, that is if the test is not + // disabled (or it is disabled but the also_run_disabled_tests flag has + // been specified) and its full name matches the user-specified filter. + // + // Google Test allows the user to filter the tests by their full names. + // The full name of a test Bar in test suite Foo is defined as + // "Foo.Bar". Only the tests that match the filter will run. + // + // A filter is a colon-separated list of glob (not regex) patterns, + // optionally followed by a '-' and a colon-separated list of + // negative patterns (tests to exclude). A test is run if it + // matches one of the positive patterns and does not match any of + // the negative patterns. + // + // For example, *A*:Foo.* is a filter that matches any string that + // contains the character 'A' or starts with "Foo.". + bool should_run() const { return should_run_; } + + // Returns true if and only if this test will appear in the XML report. + bool is_reportable() const { + // The XML report includes tests matching the filter, excluding those + // run in other shards. + return matches_filter_ && !is_in_another_shard_; + } + + // Returns the result of the test. + const TestResult* result() const { return &result_; } + + private: +#if GTEST_HAS_DEATH_TEST + friend class internal::DefaultDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + friend class Test; + friend class TestSuite; + friend class internal::UnitTestImpl; + friend class internal::StreamingListenerTest; + friend TestInfo* internal::MakeAndRegisterTestInfo( + const char* test_suite_name, const char* name, const char* type_param, + const char* value_param, internal::CodeLocation code_location, + internal::TypeId fixture_class_id, internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc, + internal::TestFactoryBase* factory); + + // Constructs a TestInfo object. The newly constructed instance assumes + // ownership of the factory object. + TestInfo(const std::string& test_suite_name, const std::string& name, + const char* a_type_param, // NULL if not a type-parameterized test + const char* a_value_param, // NULL if not a value-parameterized test + internal::CodeLocation a_code_location, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory); + + // Increments the number of death tests encountered in this test so + // far. + int increment_death_test_count() { + return result_.increment_death_test_count(); + } + + // Creates the test object, runs it, records its result, and then + // deletes it. + void Run(); + + // Skip and records the test result for this object. + void Skip(); + + static void ClearTestResult(TestInfo* test_info) { + test_info->result_.Clear(); + } + + // These fields are immutable properties of the test. + const std::string test_suite_name_; // test suite name + const std::string name_; // Test name + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const std::unique_ptr type_param_; + // Text representation of the value parameter, or NULL if this is not a + // value-parameterized test. + const std::unique_ptr value_param_; + internal::CodeLocation location_; + const internal::TypeId fixture_class_id_; // ID of the test fixture class + bool should_run_; // True if and only if this test should run + bool is_disabled_; // True if and only if this test is disabled + bool matches_filter_; // True if this test matches the + // user-specified filter. + bool is_in_another_shard_; // Will be run in another shard. + internal::TestFactoryBase* const factory_; // The factory that creates + // the test object + + // This field is mutable and needs to be reset before running the + // test for the second time. + TestResult result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestInfo); +}; + +// A test suite, which consists of a vector of TestInfos. +// +// TestSuite is not copyable. +class GTEST_API_ TestSuite { + public: + // Creates a TestSuite with the given name. + // + // TestSuite does NOT have a default constructor. Always use this + // constructor to create a TestSuite object. + // + // Arguments: + // + // name: name of the test suite + // a_type_param: the name of the test's type parameter, or NULL if + // this is not a type-parameterized test. + // set_up_tc: pointer to the function that sets up the test suite + // tear_down_tc: pointer to the function that tears down the test suite + TestSuite(const char* name, const char* a_type_param, + internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc); + + // Destructor of TestSuite. + virtual ~TestSuite(); + + // Gets the name of the TestSuite. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a + // type-parameterized test suite. + const char* type_param() const { + if (type_param_.get() != nullptr) return type_param_->c_str(); + return nullptr; + } + + // Returns true if any test in this test suite should run. + bool should_run() const { return should_run_; } + + // Gets the number of successful tests in this test suite. + int successful_test_count() const; + + // Gets the number of skipped tests in this test suite. + int skipped_test_count() const; + + // Gets the number of failed tests in this test suite. + int failed_test_count() const; + + // Gets the number of disabled tests that will be reported in the XML report. + int reportable_disabled_test_count() const; + + // Gets the number of disabled tests in this test suite. + int disabled_test_count() const; + + // Gets the number of tests to be printed in the XML report. + int reportable_test_count() const; + + // Get the number of tests in this test suite that should run. + int test_to_run_count() const; + + // Gets the number of all tests in this test suite. + int total_test_count() const; + + // Returns true if and only if the test suite passed. + bool Passed() const { return !Failed(); } + + // Returns true if and only if the test suite failed. + bool Failed() const { + return failed_test_count() > 0 || ad_hoc_test_result().Failed(); + } + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Gets the time of the test suite start, in ms from the start of the + // UNIX epoch. + TimeInMillis start_timestamp() const { return start_timestamp_; } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. + const TestInfo* GetTestInfo(int i) const; + + // Returns the TestResult that holds test properties recorded during + // execution of SetUpTestSuite and TearDownTestSuite. + const TestResult& ad_hoc_test_result() const { return ad_hoc_test_result_; } + + private: + friend class Test; + friend class internal::UnitTestImpl; + + // Gets the (mutable) vector of TestInfos in this TestSuite. + std::vector& test_info_list() { return test_info_list_; } + + // Gets the (immutable) vector of TestInfos in this TestSuite. + const std::vector& test_info_list() const { + return test_info_list_; + } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. + TestInfo* GetMutableTestInfo(int i); + + // Sets the should_run member. + void set_should_run(bool should) { should_run_ = should; } + + // Adds a TestInfo to this test suite. Will delete the TestInfo upon + // destruction of the TestSuite object. + void AddTestInfo(TestInfo * test_info); + + // Clears the results of all tests in this test suite. + void ClearResult(); + + // Clears the results of all tests in the given test suite. + static void ClearTestSuiteResult(TestSuite* test_suite) { + test_suite->ClearResult(); + } + + // Runs every test in this TestSuite. + void Run(); + + // Skips the execution of tests under this TestSuite + void Skip(); + + // Runs SetUpTestSuite() for this TestSuite. This wrapper is needed + // for catching exceptions thrown from SetUpTestSuite(). + void RunSetUpTestSuite() { + if (set_up_tc_ != nullptr) { + (*set_up_tc_)(); + } + } + + // Runs TearDownTestSuite() for this TestSuite. This wrapper is + // needed for catching exceptions thrown from TearDownTestSuite(). + void RunTearDownTestSuite() { + if (tear_down_tc_ != nullptr) { + (*tear_down_tc_)(); + } + } + + // Returns true if and only if test passed. + static bool TestPassed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Passed(); + } + + // Returns true if and only if test skipped. + static bool TestSkipped(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Skipped(); + } + + // Returns true if and only if test failed. + static bool TestFailed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Failed(); + } + + // Returns true if and only if the test is disabled and will be reported in + // the XML report. + static bool TestReportableDisabled(const TestInfo* test_info) { + return test_info->is_reportable() && test_info->is_disabled_; + } + + // Returns true if and only if test is disabled. + static bool TestDisabled(const TestInfo* test_info) { + return test_info->is_disabled_; + } + + // Returns true if and only if this test will appear in the XML report. + static bool TestReportable(const TestInfo* test_info) { + return test_info->is_reportable(); + } + + // Returns true if the given test should run. + static bool ShouldRunTest(const TestInfo* test_info) { + return test_info->should_run(); + } + + // Shuffles the tests in this test suite. + void ShuffleTests(internal::Random* random); + + // Restores the test order to before the first shuffle. + void UnshuffleTests(); + + // Name of the test suite. + std::string name_; + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const std::unique_ptr type_param_; + // The vector of TestInfos in their original order. It owns the + // elements in the vector. + std::vector test_info_list_; + // Provides a level of indirection for the test list to allow easy + // shuffling and restoring the test order. The i-th element in this + // vector is the index of the i-th test in the shuffled test list. + std::vector test_indices_; + // Pointer to the function that sets up the test suite. + internal::SetUpTestSuiteFunc set_up_tc_; + // Pointer to the function that tears down the test suite. + internal::TearDownTestSuiteFunc tear_down_tc_; + // True if and only if any test in this test suite should run. + bool should_run_; + // The start time, in milliseconds since UNIX Epoch. + TimeInMillis start_timestamp_; + // Elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + // Holds test properties recorded during execution of SetUpTestSuite and + // TearDownTestSuite. + TestResult ad_hoc_test_result_; + + // We disallow copying TestSuites. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestSuite); +}; + +// An Environment object is capable of setting up and tearing down an +// environment. You should subclass this to define your own +// environment(s). +// +// An Environment object does the set-up and tear-down in virtual +// methods SetUp() and TearDown() instead of the constructor and the +// destructor, as: +// +// 1. You cannot safely throw from a destructor. This is a problem +// as in some cases Google Test is used where exceptions are enabled, and +// we may want to implement ASSERT_* using exceptions where they are +// available. +// 2. You cannot use ASSERT_* directly in a constructor or +// destructor. +class Environment { + public: + // The d'tor is virtual as we need to subclass Environment. + virtual ~Environment() {} + + // Override this to define how to set up the environment. + virtual void SetUp() {} + + // Override this to define how to tear down the environment. + virtual void TearDown() {} + private: + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return nullptr; } +}; + +#if GTEST_HAS_EXCEPTIONS + +// Exception which can be thrown from TestEventListener::OnTestPartResult. +class GTEST_API_ AssertionException + : public internal::GoogleTestFailureException { + public: + explicit AssertionException(const TestPartResult& result) + : GoogleTestFailureException(result) {} +}; + +#endif // GTEST_HAS_EXCEPTIONS + +// The interface for tracing execution of tests. The methods are organized in +// the order the corresponding events are fired. +class TestEventListener { + public: + virtual ~TestEventListener() {} + + // Fired before any test activity starts. + virtual void OnTestProgramStart(const UnitTest& unit_test) = 0; + + // Fired before each iteration of tests starts. There may be more than + // one iteration if GTEST_FLAG(repeat) is set. iteration is the iteration + // index, starting from 0. + virtual void OnTestIterationStart(const UnitTest& unit_test, + int iteration) = 0; + + // Fired before environment set-up for each iteration of tests starts. + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test) = 0; + + // Fired after environment set-up for each iteration of tests ends. + virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) = 0; + + // Fired before the test suite starts. + virtual void OnTestSuiteStart(const TestSuite& /*test_suite*/) {} + + // Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + virtual void OnTestCaseStart(const TestCase& /*test_case*/) {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Fired before the test starts. + virtual void OnTestStart(const TestInfo& test_info) = 0; + + // Fired when a test is disabled + virtual void OnTestDisabled(const TestInfo& /*test_info*/) {} + + // Fired after a failed assertion or a SUCCEED() invocation. + // If you want to throw an exception from this function to skip to the next + // TEST, it must be AssertionException defined above, or inherited from it. + virtual void OnTestPartResult(const TestPartResult& test_part_result) = 0; + + // Fired after the test ends. + virtual void OnTestEnd(const TestInfo& test_info) = 0; + + // Fired after the test suite ends. + virtual void OnTestSuiteEnd(const TestSuite& /*test_suite*/) {} + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + virtual void OnTestCaseEnd(const TestCase& /*test_case*/) {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Fired before environment tear-down for each iteration of tests starts. + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test) = 0; + + // Fired after environment tear-down for each iteration of tests ends. + virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) = 0; + + // Fired after each iteration of tests finishes. + virtual void OnTestIterationEnd(const UnitTest& unit_test, + int iteration) = 0; + + // Fired after all test activities have ended. + virtual void OnTestProgramEnd(const UnitTest& unit_test) = 0; +}; + +// The convenience class for users who need to override just one or two +// methods and are not concerned that a possible change to a signature of +// the methods they override will not be caught during the build. For +// comments about each method please see the definition of TestEventListener +// above. +class EmptyTestEventListener : public TestEventListener { + public: + void OnTestProgramStart(const UnitTest& /*unit_test*/) override {} + void OnTestIterationStart(const UnitTest& /*unit_test*/, + int /*iteration*/) override {} + void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) override {} + void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override {} + void OnTestSuiteStart(const TestSuite& /*test_suite*/) override {} +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseStart(const TestCase& /*test_case*/) override {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnTestStart(const TestInfo& /*test_info*/) override {} + void OnTestDisabled(const TestInfo& /*test_info*/) override {} + void OnTestPartResult(const TestPartResult& /*test_part_result*/) override {} + void OnTestEnd(const TestInfo& /*test_info*/) override {} + void OnTestSuiteEnd(const TestSuite& /*test_suite*/) override {} +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseEnd(const TestCase& /*test_case*/) override {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) override {} + void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override {} + void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int /*iteration*/) override {} + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override {} +}; + +// TestEventListeners lets users add listeners to track events in Google Test. +class GTEST_API_ TestEventListeners { + public: + TestEventListeners(); + ~TestEventListeners(); + + // Appends an event listener to the end of the list. Google Test assumes + // the ownership of the listener (i.e. it will delete the listener when + // the test program finishes). + void Append(TestEventListener* listener); + + // Removes the given event listener from the list and returns it. It then + // becomes the caller's responsibility to delete the listener. Returns + // NULL if the listener is not found in the list. + TestEventListener* Release(TestEventListener* listener); + + // Returns the standard listener responsible for the default console + // output. Can be removed from the listeners list to shut down default + // console output. Note that removing this object from the listener list + // with Release transfers its ownership to the caller and makes this + // function return NULL the next time. + TestEventListener* default_result_printer() const { + return default_result_printer_; + } + + // Returns the standard listener responsible for the default XML output + // controlled by the --gtest_output=xml flag. Can be removed from the + // listeners list by users who want to shut down the default XML output + // controlled by this flag and substitute it with custom one. Note that + // removing this object from the listener list with Release transfers its + // ownership to the caller and makes this function return NULL the next + // time. + TestEventListener* default_xml_generator() const { + return default_xml_generator_; + } + + private: + friend class TestSuite; + friend class TestInfo; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::NoExecDeathTest; + friend class internal::TestEventListenersAccessor; + friend class internal::UnitTestImpl; + + // Returns repeater that broadcasts the TestEventListener events to all + // subscribers. + TestEventListener* repeater(); + + // Sets the default_result_printer attribute to the provided listener. + // The listener is also added to the listener list and previous + // default_result_printer is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. + void SetDefaultResultPrinter(TestEventListener* listener); + + // Sets the default_xml_generator attribute to the provided listener. The + // listener is also added to the listener list and previous + // default_xml_generator is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. + void SetDefaultXmlGenerator(TestEventListener* listener); + + // Controls whether events will be forwarded by the repeater to the + // listeners in the list. + bool EventForwardingEnabled() const; + void SuppressEventForwarding(); + + // The actual list of listeners. + internal::TestEventRepeater* repeater_; + // Listener responsible for the standard result output. + TestEventListener* default_result_printer_; + // Listener responsible for the creation of the XML output file. + TestEventListener* default_xml_generator_; + + // We disallow copying TestEventListeners. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventListeners); +}; + +// A UnitTest consists of a vector of TestSuites. +// +// This is a singleton class. The only instance of UnitTest is +// created when UnitTest::GetInstance() is first called. This +// instance is never deleted. +// +// UnitTest is not copyable. +// +// This class is thread-safe as long as the methods are called +// according to their specification. +class GTEST_API_ UnitTest { + public: + // Gets the singleton UnitTest object. The first time this method + // is called, a UnitTest object is constructed and returned. + // Consecutive calls will return the same object. + static UnitTest* GetInstance(); + + // Runs all tests in this UnitTest object and prints the result. + // Returns 0 if successful, or 1 otherwise. + // + // This method can only be called from the main thread. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + int Run() GTEST_MUST_USE_RESULT_; + + // Returns the working directory when the first TEST() or TEST_F() + // was executed. The UnitTest object owns the string. + const char* original_working_dir() const; + + // Returns the TestSuite object for the test that's currently running, + // or NULL if no test is running. + const TestSuite* current_test_suite() const GTEST_LOCK_EXCLUDED_(mutex_); + +// Legacy API is still available but deprecated +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + const TestCase* current_test_case() const GTEST_LOCK_EXCLUDED_(mutex_); +#endif + + // Returns the TestInfo object for the test that's currently running, + // or NULL if no test is running. + const TestInfo* current_test_info() const + GTEST_LOCK_EXCLUDED_(mutex_); + + // Returns the random seed used at the start of the current test run. + int random_seed() const; + + // Returns the ParameterizedTestSuiteRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + internal::ParameterizedTestSuiteRegistry& parameterized_test_registry() + GTEST_LOCK_EXCLUDED_(mutex_); + + // Gets the number of successful test suites. + int successful_test_suite_count() const; + + // Gets the number of failed test suites. + int failed_test_suite_count() const; + + // Gets the number of all test suites. + int total_test_suite_count() const; + + // Gets the number of all test suites that contain at least one test + // that should run. + int test_suite_to_run_count() const; + + // Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + int successful_test_case_count() const; + int failed_test_case_count() const; + int total_test_case_count() const; + int test_case_to_run_count() const; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of skipped tests. + int skipped_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests that will be reported in the XML report. + int reportable_disabled_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of tests to be printed in the XML report. + int reportable_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the time of the test program start, in ms from the start of the + // UNIX epoch. + TimeInMillis start_timestamp() const; + + // Gets the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const; + + // Returns true if and only if the unit test passed (i.e. all test suites + // passed). + bool Passed() const; + + // Returns true if and only if the unit test failed (i.e. some test suite + // failed or something outside of all tests failed). + bool Failed() const; + + // Gets the i-th test suite among all the test suites. i can range from 0 to + // total_test_suite_count() - 1. If i is not in that range, returns NULL. + const TestSuite* GetTestSuite(int i) const; + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + const TestCase* GetTestCase(int i) const; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Returns the TestResult containing information on test failures and + // properties logged outside of individual test suites. + const TestResult& ad_hoc_test_result() const; + + // Returns the list of event listeners that can be used to track events + // inside Google Test. + TestEventListeners& listeners(); + + private: + // Registers and returns a global test environment. When a test + // program is run, all global test environments will be set-up in + // the order they were registered. After all tests in the program + // have finished, all global test environments will be torn-down in + // the *reverse* order they were registered. + // + // The UnitTest object takes ownership of the given environment. + // + // This method can only be called from the main thread. + Environment* AddEnvironment(Environment* env); + + // Adds a TestPartResult to the current TestResult object. All + // Google Test assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) + // eventually call this to report their results. The user code + // should use the assertion macros instead of calling this directly. + void AddTestPartResult(TestPartResult::Type result_type, + const char* file_name, + int line_number, + const std::string& message, + const std::string& os_stack_trace) + GTEST_LOCK_EXCLUDED_(mutex_); + + // Adds a TestProperty to the current TestResult object when invoked from + // inside a test, to current TestSuite's ad_hoc_test_result_ when invoked + // from SetUpTestSuite or TearDownTestSuite, or to the global property set + // when invoked elsewhere. If the result already contains a property with + // the same key, the value will be updated. + void RecordProperty(const std::string& key, const std::string& value); + + // Gets the i-th test suite among all the test suites. i can range from 0 to + // total_test_suite_count() - 1. If i is not in that range, returns NULL. + TestSuite* GetMutableTestSuite(int i); + + // Accessors for the implementation object. + internal::UnitTestImpl* impl() { return impl_; } + const internal::UnitTestImpl* impl() const { return impl_; } + + // These classes and functions are friends as they need to access private + // members of UnitTest. + friend class ScopedTrace; + friend class Test; + friend class internal::AssertHelper; + friend class internal::StreamingListenerTest; + friend class internal::UnitTestRecordPropertyTestHelper; + friend Environment* AddGlobalTestEnvironment(Environment* env); + friend std::set* internal::GetIgnoredParameterizedTestSuites(); + friend internal::UnitTestImpl* internal::GetUnitTestImpl(); + friend void internal::ReportFailureInUnknownLocation( + TestPartResult::Type result_type, + const std::string& message); + + // Creates an empty UnitTest. + UnitTest(); + + // D'tor + virtual ~UnitTest(); + + // Pushes a trace defined by SCOPED_TRACE() on to the per-thread + // Google Test trace stack. + void PushGTestTrace(const internal::TraceInfo& trace) + GTEST_LOCK_EXCLUDED_(mutex_); + + // Pops a trace from the per-thread Google Test trace stack. + void PopGTestTrace() + GTEST_LOCK_EXCLUDED_(mutex_); + + // Protects mutable state in *impl_. This is mutable as some const + // methods need to lock it too. + mutable internal::Mutex mutex_; + + // Opaque implementation object. This field is never changed once + // the object is constructed. We don't mark it as const here, as + // doing so will cause a warning in the constructor of UnitTest. + // Mutable state in *impl_ is protected by mutex_. + internal::UnitTestImpl* impl_; + + // We disallow copying UnitTest. + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTest); +}; + +// A convenient wrapper for adding an environment for the test +// program. +// +// You should call this before RUN_ALL_TESTS() is called, probably in +// main(). If you use gtest_main, you need to call this before main() +// starts for it to take effect. For example, you can define a global +// variable like this: +// +// testing::Environment* const foo_env = +// testing::AddGlobalTestEnvironment(new FooEnvironment); +// +// However, we strongly recommend you to write your own main() and +// call AddGlobalTestEnvironment() there, as relying on initialization +// of global variables makes the code harder to read and may cause +// problems when you register multiple environments from different +// translation units and the environments have dependencies among them +// (remember that the compiler doesn't guarantee the order in which +// global variables from different translation units are initialized). +inline Environment* AddGlobalTestEnvironment(Environment* env) { + return UnitTest::GetInstance()->AddEnvironment(env); +} + +// Initializes Google Test. This must be called before calling +// RUN_ALL_TESTS(). In particular, it parses a command line for the +// flags that Google Test recognizes. Whenever a Google Test flag is +// seen, it is removed from argv, and *argc is decremented. +// +// No value is returned. Instead, the Google Test flag variables are +// updated. +// +// Calling the function for the second time has no user-visible effect. +GTEST_API_ void InitGoogleTest(int* argc, char** argv); + +// This overloaded version can be used in Windows programs compiled in +// UNICODE mode. +GTEST_API_ void InitGoogleTest(int* argc, wchar_t** argv); + +// This overloaded version can be used on Arduino/embedded platforms where +// there is no argc/argv. +GTEST_API_ void InitGoogleTest(); + +namespace internal { + +// Separate the error generating code from the code path to reduce the stack +// frame size of CmpHelperEQ. This helps reduce the overhead of some sanitizers +// when calling EXPECT_* in a tight loop. +template +AssertionResult CmpHelperEQFailure(const char* lhs_expression, + const char* rhs_expression, + const T1& lhs, const T2& rhs) { + return EqFailure(lhs_expression, + rhs_expression, + FormatForComparisonFailureMessage(lhs, rhs), + FormatForComparisonFailureMessage(rhs, lhs), + false); +} + +// This block of code defines operator==/!= +// to block lexical scope lookup. +// It prevents using invalid operator==/!= defined at namespace scope. +struct faketype {}; +inline bool operator==(faketype, faketype) { return true; } +inline bool operator!=(faketype, faketype) { return false; } + +// The helper function for {ASSERT|EXPECT}_EQ. +template +AssertionResult CmpHelperEQ(const char* lhs_expression, + const char* rhs_expression, + const T1& lhs, + const T2& rhs) { + if (lhs == rhs) { + return AssertionSuccess(); + } + + return CmpHelperEQFailure(lhs_expression, rhs_expression, lhs, rhs); +} + +class EqHelper { + public: + // This templatized version is for the general case. + template < + typename T1, typename T2, + // Disable this overload for cases where one argument is a pointer + // and the other is the null pointer constant. + typename std::enable_if::value || + !std::is_pointer::value>::type* = nullptr> + static AssertionResult Compare(const char* lhs_expression, + const char* rhs_expression, const T1& lhs, + const T2& rhs) { + return CmpHelperEQ(lhs_expression, rhs_expression, lhs, rhs); + } + + // With this overloaded version, we allow anonymous enums to be used + // in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous + // enums can be implicitly cast to BiggestInt. + // + // Even though its body looks the same as the above version, we + // cannot merge the two, as it will make anonymous enums unhappy. + static AssertionResult Compare(const char* lhs_expression, + const char* rhs_expression, + BiggestInt lhs, + BiggestInt rhs) { + return CmpHelperEQ(lhs_expression, rhs_expression, lhs, rhs); + } + + template + static AssertionResult Compare( + const char* lhs_expression, const char* rhs_expression, + // Handle cases where '0' is used as a null pointer literal. + std::nullptr_t /* lhs */, T* rhs) { + // We already know that 'lhs' is a null pointer. + return CmpHelperEQ(lhs_expression, rhs_expression, static_cast(nullptr), + rhs); + } +}; + +// Separate the error generating code from the code path to reduce the stack +// frame size of CmpHelperOP. This helps reduce the overhead of some sanitizers +// when calling EXPECT_OP in a tight loop. +template +AssertionResult CmpHelperOpFailure(const char* expr1, const char* expr2, + const T1& val1, const T2& val2, + const char* op) { + return AssertionFailure() + << "Expected: (" << expr1 << ") " << op << " (" << expr2 + << "), actual: " << FormatForComparisonFailureMessage(val1, val2) + << " vs " << FormatForComparisonFailureMessage(val2, val1); +} + +// A macro for implementing the helper functions needed to implement +// ASSERT_?? and EXPECT_??. It is here just to avoid copy-and-paste +// of similar code. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + +#define GTEST_IMPL_CMP_HELPER_(op_name, op)\ +template \ +AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ + const T1& val1, const T2& val2) {\ + if (val1 op val2) {\ + return AssertionSuccess();\ + } else {\ + return CmpHelperOpFailure(expr1, expr2, val1, val2, #op);\ + }\ +} + +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + +// Implements the helper function for {ASSERT|EXPECT}_NE +GTEST_IMPL_CMP_HELPER_(NE, !=) +// Implements the helper function for {ASSERT|EXPECT}_LE +GTEST_IMPL_CMP_HELPER_(LE, <=) +// Implements the helper function for {ASSERT|EXPECT}_LT +GTEST_IMPL_CMP_HELPER_(LT, <) +// Implements the helper function for {ASSERT|EXPECT}_GE +GTEST_IMPL_CMP_HELPER_(GE, >=) +// Implements the helper function for {ASSERT|EXPECT}_GT +GTEST_IMPL_CMP_HELPER_(GT, >) + +#undef GTEST_IMPL_CMP_HELPER_ + +// The helper function for {ASSERT|EXPECT}_STREQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRCASEEQ(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + +// The helper function for {ASSERT|EXPECT}_STRNE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + + +// Helper function for *_STREQ on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2); + +// Helper function for *_STRNE on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2); + +} // namespace internal + +// IsSubstring() and IsNotSubstring() are intended to be used as the +// first argument to {EXPECT,ASSERT}_PRED_FORMAT2(), not by +// themselves. They check whether needle is a substring of haystack +// (NULL is considered a substring of itself only), and return an +// appropriate error message when they fail. +// +// The {needle,haystack}_expr arguments are the stringified +// expressions that generated the two real arguments. +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); + +#if GTEST_HAS_STD_WSTRING +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +// Helper template function for comparing floating-points. +// +// Template parameter: +// +// RawType: the raw floating-point type (either float or double) +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +template +AssertionResult CmpHelperFloatingPointEQ(const char* lhs_expression, + const char* rhs_expression, + RawType lhs_value, + RawType rhs_value) { + const FloatingPoint lhs(lhs_value), rhs(rhs_value); + + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + ::std::stringstream lhs_ss; + lhs_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << lhs_value; + + ::std::stringstream rhs_ss; + rhs_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << rhs_value; + + return EqFailure(lhs_expression, + rhs_expression, + StringStreamToString(&lhs_ss), + StringStreamToString(&rhs_ss), + false); +} + +// Helper function for implementing ASSERT_NEAR. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error); + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// A class that enables one to stream messages to assertion macros +class GTEST_API_ AssertHelper { + public: + // Constructor. + AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message); + ~AssertHelper(); + + // Message assignment is a semantic trick to enable assertion + // streaming; see the GTEST_MESSAGE_ macro below. + void operator=(const Message& message) const; + + private: + // We put our data in a struct so that the size of the AssertHelper class can + // be as small as possible. This is important because gcc is incapable of + // re-using stack space even for temporary variables, so every EXPECT_EQ + // reserves stack space for another AssertHelper. + struct AssertHelperData { + AssertHelperData(TestPartResult::Type t, + const char* srcfile, + int line_num, + const char* msg) + : type(t), file(srcfile), line(line_num), message(msg) { } + + TestPartResult::Type const type; + const char* const file; + int const line; + std::string const message; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelperData); + }; + + AssertHelperData* const data_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelper); +}; + +} // namespace internal + +// The pure interface class that all value-parameterized tests inherit from. +// A value-parameterized class must inherit from both ::testing::Test and +// ::testing::WithParamInterface. In most cases that just means inheriting +// from ::testing::TestWithParam, but more complicated test hierarchies +// may need to inherit from Test and WithParamInterface at different levels. +// +// This interface has support for accessing the test parameter value via +// the GetParam() method. +// +// Use it with one of the parameter generator defining functions, like Range(), +// Values(), ValuesIn(), Bool(), and Combine(). +// +// class FooTest : public ::testing::TestWithParam { +// protected: +// FooTest() { +// // Can use GetParam() here. +// } +// ~FooTest() override { +// // Can use GetParam() here. +// } +// void SetUp() override { +// // Can use GetParam() here. +// } +// void TearDown override { +// // Can use GetParam() here. +// } +// }; +// TEST_P(FooTest, DoesBar) { +// // Can use GetParam() method here. +// Foo foo; +// ASSERT_TRUE(foo.DoesBar(GetParam())); +// } +// INSTANTIATE_TEST_SUITE_P(OneToTenRange, FooTest, ::testing::Range(1, 10)); + +template +class WithParamInterface { + public: + typedef T ParamType; + virtual ~WithParamInterface() {} + + // The current parameter value. Is also available in the test fixture's + // constructor. + static const ParamType& GetParam() { + GTEST_CHECK_(parameter_ != nullptr) + << "GetParam() can only be called inside a value-parameterized test " + << "-- did you intend to write TEST_P instead of TEST_F?"; + return *parameter_; + } + + private: + // Sets parameter value. The caller is responsible for making sure the value + // remains alive and unchanged throughout the current test. + static void SetParam(const ParamType* parameter) { + parameter_ = parameter; + } + + // Static value used for accessing parameter during a test lifetime. + static const ParamType* parameter_; + + // TestClass must be a subclass of WithParamInterface and Test. + template friend class internal::ParameterizedTestFactory; +}; + +template +const T* WithParamInterface::parameter_ = nullptr; + +// Most value-parameterized classes can ignore the existence of +// WithParamInterface, and can just inherit from ::testing::TestWithParam. + +template +class TestWithParam : public Test, public WithParamInterface { +}; + +// Macros for indicating success/failure in test code. + +// Skips test in runtime. +// Skipping test aborts current function. +// Skipped tests are neither successful nor failed. +#define GTEST_SKIP() GTEST_SKIP_("") + +// ADD_FAILURE unconditionally adds a failure to the current test. +// SUCCEED generates a success - it doesn't automatically make the +// current test successful, as a test is only successful when it has +// no failure. +// +// EXPECT_* verifies that a certain condition is satisfied. If not, +// it behaves like ADD_FAILURE. In particular: +// +// EXPECT_TRUE verifies that a Boolean condition is true. +// EXPECT_FALSE verifies that a Boolean condition is false. +// +// FAIL and ASSERT_* are similar to ADD_FAILURE and EXPECT_*, except +// that they will also abort the current function on failure. People +// usually want the fail-fast behavior of FAIL and ASSERT_*, but those +// writing data-driven tests often find themselves using ADD_FAILURE +// and EXPECT_* more. + +// Generates a nonfatal failure with a generic message. +#define ADD_FAILURE() GTEST_NONFATAL_FAILURE_("Failed") + +// Generates a nonfatal failure at the given source file location with +// a generic message. +#define ADD_FAILURE_AT(file, line) \ + GTEST_MESSAGE_AT_(file, line, "Failed", \ + ::testing::TestPartResult::kNonFatalFailure) + +// Generates a fatal failure with a generic message. +#define GTEST_FAIL() GTEST_FATAL_FAILURE_("Failed") + +// Like GTEST_FAIL(), but at the given source file location. +#define GTEST_FAIL_AT(file, line) \ + GTEST_MESSAGE_AT_(file, line, "Failed", \ + ::testing::TestPartResult::kFatalFailure) + +// Define this macro to 1 to omit the definition of FAIL(), which is a +// generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_FAIL +# define FAIL() GTEST_FAIL() +#endif + +// Generates a success with a generic message. +#define GTEST_SUCCEED() GTEST_SUCCESS_("Succeeded") + +// Define this macro to 1 to omit the definition of SUCCEED(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_SUCCEED +# define SUCCEED() GTEST_SUCCEED() +#endif + +// Macros for testing exceptions. +// +// * {ASSERT|EXPECT}_THROW(statement, expected_exception): +// Tests that the statement throws the expected exception. +// * {ASSERT|EXPECT}_NO_THROW(statement): +// Tests that the statement doesn't throw any exception. +// * {ASSERT|EXPECT}_ANY_THROW(statement): +// Tests that the statement throws an exception. + +#define EXPECT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_NONFATAL_FAILURE_) +#define EXPECT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define EXPECT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define ASSERT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_FATAL_FAILURE_) +#define ASSERT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_FATAL_FAILURE_) +#define ASSERT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_FATAL_FAILURE_) + +// Boolean assertions. Condition can be either a Boolean expression or an +// AssertionResult. For more information on how to use AssertionResult with +// these macros see comments on that class. +#define GTEST_EXPECT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_NONFATAL_FAILURE_) +#define GTEST_EXPECT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_NONFATAL_FAILURE_) +#define GTEST_ASSERT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_FATAL_FAILURE_) +#define GTEST_ASSERT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_FATAL_FAILURE_) + +// Define these macros to 1 to omit the definition of the corresponding +// EXPECT or ASSERT, which clashes with some users' own code. + +#if !GTEST_DONT_DEFINE_EXPECT_TRUE +#define EXPECT_TRUE(condition) GTEST_EXPECT_TRUE(condition) +#endif + +#if !GTEST_DONT_DEFINE_EXPECT_FALSE +#define EXPECT_FALSE(condition) GTEST_EXPECT_FALSE(condition) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_TRUE +#define ASSERT_TRUE(condition) GTEST_ASSERT_TRUE(condition) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_FALSE +#define ASSERT_FALSE(condition) GTEST_ASSERT_FALSE(condition) +#endif + +// Macros for testing equalities and inequalities. +// +// * {ASSERT|EXPECT}_EQ(v1, v2): Tests that v1 == v2 +// * {ASSERT|EXPECT}_NE(v1, v2): Tests that v1 != v2 +// * {ASSERT|EXPECT}_LT(v1, v2): Tests that v1 < v2 +// * {ASSERT|EXPECT}_LE(v1, v2): Tests that v1 <= v2 +// * {ASSERT|EXPECT}_GT(v1, v2): Tests that v1 > v2 +// * {ASSERT|EXPECT}_GE(v1, v2): Tests that v1 >= v2 +// +// When they are not, Google Test prints both the tested expressions and +// their actual values. The values must be compatible built-in types, +// or you will get a compiler error. By "compatible" we mean that the +// values can be compared by the respective operator. +// +// Note: +// +// 1. It is possible to make a user-defined type work with +// {ASSERT|EXPECT}_??(), but that requires overloading the +// comparison operators and is thus discouraged by the Google C++ +// Usage Guide. Therefore, you are advised to use the +// {ASSERT|EXPECT}_TRUE() macro to assert that two objects are +// equal. +// +// 2. The {ASSERT|EXPECT}_??() macros do pointer comparisons on +// pointers (in particular, C strings). Therefore, if you use it +// with two C strings, you are testing how their locations in memory +// are related, not how their content is related. To compare two C +// strings by content, use {ASSERT|EXPECT}_STR*(). +// +// 3. {ASSERT|EXPECT}_EQ(v1, v2) is preferred to +// {ASSERT|EXPECT}_TRUE(v1 == v2), as the former tells you +// what the actual value is when it fails, and similarly for the +// other comparisons. +// +// 4. Do not depend on the order in which {ASSERT|EXPECT}_??() +// evaluate their arguments, which is undefined. +// +// 5. These macros evaluate their arguments exactly once. +// +// Examples: +// +// EXPECT_NE(Foo(), 5); +// EXPECT_EQ(a_pointer, NULL); +// ASSERT_LT(i, array_size); +// ASSERT_GT(records.size(), 0) << "There is no record left."; + +#define EXPECT_EQ(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2) +#define EXPECT_NE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2) +#define EXPECT_LE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define EXPECT_LT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define EXPECT_GE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define EXPECT_GT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +#define GTEST_ASSERT_EQ(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::EqHelper::Compare, val1, val2) +#define GTEST_ASSERT_NE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2) +#define GTEST_ASSERT_LE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define GTEST_ASSERT_LT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define GTEST_ASSERT_GE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define GTEST_ASSERT_GT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +// Define macro GTEST_DONT_DEFINE_ASSERT_XY to 1 to omit the definition of +// ASSERT_XY(), which clashes with some users' own code. + +#if !GTEST_DONT_DEFINE_ASSERT_EQ +# define ASSERT_EQ(val1, val2) GTEST_ASSERT_EQ(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_NE +# define ASSERT_NE(val1, val2) GTEST_ASSERT_NE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LE +# define ASSERT_LE(val1, val2) GTEST_ASSERT_LE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LT +# define ASSERT_LT(val1, val2) GTEST_ASSERT_LT(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GE +# define ASSERT_GE(val1, val2) GTEST_ASSERT_GE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GT +# define ASSERT_GT(val1, val2) GTEST_ASSERT_GT(val1, val2) +#endif + +// C-string Comparisons. All tests treat NULL and any non-NULL string +// as different. Two NULLs are equal. +// +// * {ASSERT|EXPECT}_STREQ(s1, s2): Tests that s1 == s2 +// * {ASSERT|EXPECT}_STRNE(s1, s2): Tests that s1 != s2 +// * {ASSERT|EXPECT}_STRCASEEQ(s1, s2): Tests that s1 == s2, ignoring case +// * {ASSERT|EXPECT}_STRCASENE(s1, s2): Tests that s1 != s2, ignoring case +// +// For wide or narrow string objects, you can use the +// {ASSERT|EXPECT}_??() macros. +// +// Don't depend on the order in which the arguments are evaluated, +// which is undefined. +// +// These macros evaluate their arguments exactly once. + +#define EXPECT_STREQ(s1, s2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, s1, s2) +#define EXPECT_STRNE(s1, s2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define EXPECT_STRCASEEQ(s1, s2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, s1, s2) +#define EXPECT_STRCASENE(s1, s2)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +#define ASSERT_STREQ(s1, s2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, s1, s2) +#define ASSERT_STRNE(s1, s2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define ASSERT_STRCASEEQ(s1, s2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, s1, s2) +#define ASSERT_STRCASENE(s1, s2)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +// Macros for comparing floating-point numbers. +// +// * {ASSERT|EXPECT}_FLOAT_EQ(val1, val2): +// Tests that two float values are almost equal. +// * {ASSERT|EXPECT}_DOUBLE_EQ(val1, val2): +// Tests that two double values are almost equal. +// * {ASSERT|EXPECT}_NEAR(v1, v2, abs_error): +// Tests that v1 and v2 are within the given distance to each other. +// +// Google Test uses ULP-based comparison to automatically pick a default +// error bound that is appropriate for the operands. See the +// FloatingPoint template class in gtest-internal.h if you are +// interested in the implementation details. + +#define EXPECT_FLOAT_EQ(val1, val2)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + val1, val2) + +#define EXPECT_DOUBLE_EQ(val1, val2)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + val1, val2) + +#define ASSERT_FLOAT_EQ(val1, val2)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + val1, val2) + +#define ASSERT_DOUBLE_EQ(val1, val2)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + val1, val2) + +#define EXPECT_NEAR(val1, val2, abs_error)\ + EXPECT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +#define ASSERT_NEAR(val1, val2, abs_error)\ + ASSERT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +// These predicate format functions work on floating-point values, and +// can be used in {ASSERT|EXPECT}_PRED_FORMAT2*(), e.g. +// +// EXPECT_PRED_FORMAT2(testing::DoubleLE, Foo(), 5.0); + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +GTEST_API_ AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2); +GTEST_API_ AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2); + + +#if GTEST_OS_WINDOWS + +// Macros that test for HRESULT failure and success, these are only useful +// on Windows, and rely on Windows SDK macros and APIs to compile. +// +// * {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED}(expr) +// +// When expr unexpectedly fails or succeeds, Google Test prints the +// expected result and the actual result with both a human-readable +// string representation of the error, if available, as well as the +// hex result code. +# define EXPECT_HRESULT_SUCCEEDED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define ASSERT_HRESULT_SUCCEEDED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define EXPECT_HRESULT_FAILED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +# define ASSERT_HRESULT_FAILED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +#endif // GTEST_OS_WINDOWS + +// Macros that execute statement and check that it doesn't generate new fatal +// failures in the current thread. +// +// * {ASSERT|EXPECT}_NO_FATAL_FAILURE(statement); +// +// Examples: +// +// EXPECT_NO_FATAL_FAILURE(Process()); +// ASSERT_NO_FATAL_FAILURE(Process()) << "Process() failed"; +// +#define ASSERT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_FATAL_FAILURE_) +#define EXPECT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_NONFATAL_FAILURE_) + +// Causes a trace (including the given source file path and line number, +// and the given message) to be included in every test failure message generated +// by code in the scope of the lifetime of an instance of this class. The effect +// is undone with the destruction of the instance. +// +// The message argument can be anything streamable to std::ostream. +// +// Example: +// testing::ScopedTrace trace("file.cc", 123, "message"); +// +class GTEST_API_ ScopedTrace { + public: + // The c'tor pushes the given source file location and message onto + // a trace stack maintained by Google Test. + + // Template version. Uses Message() to convert the values into strings. + // Slow, but flexible. + template + ScopedTrace(const char* file, int line, const T& message) { + PushTrace(file, line, (Message() << message).GetString()); + } + + // Optimize for some known types. + ScopedTrace(const char* file, int line, const char* message) { + PushTrace(file, line, message ? message : "(null)"); + } + + ScopedTrace(const char* file, int line, const std::string& message) { + PushTrace(file, line, message); + } + + // The d'tor pops the info pushed by the c'tor. + // + // Note that the d'tor is not virtual in order to be efficient. + // Don't inherit from ScopedTrace! + ~ScopedTrace(); + + private: + void PushTrace(const char* file, int line, std::string message); + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedTrace); +} GTEST_ATTRIBUTE_UNUSED_; // A ScopedTrace object does its job in its + // c'tor and d'tor. Therefore it doesn't + // need to be used otherwise. + +// Causes a trace (including the source file path, the current line +// number, and the given message) to be included in every test failure +// message generated by code in the current scope. The effect is +// undone when the control leaves the current scope. +// +// The message argument can be anything streamable to std::ostream. +// +// In the implementation, we include the current line number as part +// of the dummy variable name, thus allowing multiple SCOPED_TRACE()s +// to appear in the same block - as long as they are on different +// lines. +// +// Assuming that each thread maintains its own stack of traces. +// Therefore, a SCOPED_TRACE() would (correctly) only affect the +// assertions in its own thread. +#define SCOPED_TRACE(message) \ + ::testing::ScopedTrace GTEST_CONCAT_TOKEN_(gtest_trace_, __LINE__)(\ + __FILE__, __LINE__, (message)) + +// Compile-time assertion for type equality. +// StaticAssertTypeEq() compiles if and only if type1 and type2 +// are the same type. The value it returns is not interesting. +// +// Instead of making StaticAssertTypeEq a class template, we make it a +// function template that invokes a helper class template. This +// prevents a user from misusing StaticAssertTypeEq by +// defining objects of that type. +// +// CAVEAT: +// +// When used inside a method of a class template, +// StaticAssertTypeEq() is effective ONLY IF the method is +// instantiated. For example, given: +// +// template class Foo { +// public: +// void Bar() { testing::StaticAssertTypeEq(); } +// }; +// +// the code: +// +// void Test1() { Foo foo; } +// +// will NOT generate a compiler error, as Foo::Bar() is never +// actually instantiated. Instead, you need: +// +// void Test2() { Foo foo; foo.Bar(); } +// +// to cause a compiler error. +template +constexpr bool StaticAssertTypeEq() noexcept { + static_assert(std::is_same::value, "T1 and T2 are not the same type"); + return true; +} + +// Defines a test. +// +// The first parameter is the name of the test suite, and the second +// parameter is the name of the test within the test suite. +// +// The convention is to end the test suite name with "Test". For +// example, a test suite for the Foo class can be named FooTest. +// +// Test code should appear between braces after an invocation of +// this macro. Example: +// +// TEST(FooTest, InitializesCorrectly) { +// Foo foo; +// EXPECT_TRUE(foo.StatusIsOK()); +// } + +// Note that we call GetTestTypeId() instead of GetTypeId< +// ::testing::Test>() here to get the type ID of testing::Test. This +// is to work around a suspected linker bug when using Google Test as +// a framework on Mac OS X. The bug causes GetTypeId< +// ::testing::Test>() to return different values depending on whether +// the call is from the Google Test framework itself or from user test +// code. GetTestTypeId() is guaranteed to always return the same +// value, as it always calls GetTypeId<>() from the Google Test +// framework. +#define GTEST_TEST(test_suite_name, test_name) \ + GTEST_TEST_(test_suite_name, test_name, ::testing::Test, \ + ::testing::internal::GetTestTypeId()) + +// Define this macro to 1 to omit the definition of TEST(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_TEST +#define TEST(test_suite_name, test_name) GTEST_TEST(test_suite_name, test_name) +#endif + +// Defines a test that uses a test fixture. +// +// The first parameter is the name of the test fixture class, which +// also doubles as the test suite name. The second parameter is the +// name of the test within the test suite. +// +// A test fixture class must be declared earlier. The user should put +// the test code between braces after using this macro. Example: +// +// class FooTest : public testing::Test { +// protected: +// void SetUp() override { b_.AddElement(3); } +// +// Foo a_; +// Foo b_; +// }; +// +// TEST_F(FooTest, InitializesCorrectly) { +// EXPECT_TRUE(a_.StatusIsOK()); +// } +// +// TEST_F(FooTest, ReturnsElementCountCorrectly) { +// EXPECT_EQ(a_.size(), 0); +// EXPECT_EQ(b_.size(), 1); +// } +#define GTEST_TEST_F(test_fixture, test_name)\ + GTEST_TEST_(test_fixture, test_name, test_fixture, \ + ::testing::internal::GetTypeId()) +#if !GTEST_DONT_DEFINE_TEST_F +#define TEST_F(test_fixture, test_name) GTEST_TEST_F(test_fixture, test_name) +#endif + +// Returns a path to temporary directory. +// Tries to determine an appropriate directory for the platform. +GTEST_API_ std::string TempDir(); + +#ifdef _MSC_VER +# pragma warning(pop) +#endif + +// Dynamically registers a test with the framework. +// +// This is an advanced API only to be used when the `TEST` macros are +// insufficient. The macros should be preferred when possible, as they avoid +// most of the complexity of calling this function. +// +// The `factory` argument is a factory callable (move-constructible) object or +// function pointer that creates a new instance of the Test object. It +// handles ownership to the caller. The signature of the callable is +// `Fixture*()`, where `Fixture` is the test fixture class for the test. All +// tests registered with the same `test_suite_name` must return the same +// fixture type. This is checked at runtime. +// +// The framework will infer the fixture class from the factory and will call +// the `SetUpTestSuite` and `TearDownTestSuite` for it. +// +// Must be called before `RUN_ALL_TESTS()` is invoked, otherwise behavior is +// undefined. +// +// Use case example: +// +// class MyFixture : public ::testing::Test { +// public: +// // All of these optional, just like in regular macro usage. +// static void SetUpTestSuite() { ... } +// static void TearDownTestSuite() { ... } +// void SetUp() override { ... } +// void TearDown() override { ... } +// }; +// +// class MyTest : public MyFixture { +// public: +// explicit MyTest(int data) : data_(data) {} +// void TestBody() override { ... } +// +// private: +// int data_; +// }; +// +// void RegisterMyTests(const std::vector& values) { +// for (int v : values) { +// ::testing::RegisterTest( +// "MyFixture", ("Test" + std::to_string(v)).c_str(), nullptr, +// std::to_string(v).c_str(), +// __FILE__, __LINE__, +// // Important to use the fixture type as the return type here. +// [=]() -> MyFixture* { return new MyTest(v); }); +// } +// } +// ... +// int main(int argc, char** argv) { +// ::testing::InitGoogleTest(&argc, argv); +// std::vector values_to_test = LoadValuesFromConfig(); +// RegisterMyTests(values_to_test); +// ... +// return RUN_ALL_TESTS(); +// } +// +template +TestInfo* RegisterTest(const char* test_suite_name, const char* test_name, + const char* type_param, const char* value_param, + const char* file, int line, Factory factory) { + using TestT = typename std::remove_pointer::type; + + class FactoryImpl : public internal::TestFactoryBase { + public: + explicit FactoryImpl(Factory f) : factory_(std::move(f)) {} + Test* CreateTest() override { return factory_(); } + + private: + Factory factory_; + }; + + return internal::MakeAndRegisterTestInfo( + test_suite_name, test_name, type_param, value_param, + internal::CodeLocation(file, line), internal::GetTypeId(), + internal::SuiteApiResolver::GetSetUpCaseOrSuite(file, line), + internal::SuiteApiResolver::GetTearDownCaseOrSuite(file, line), + new FactoryImpl{std::move(factory)}); +} + +} // namespace testing + +// Use this function in main() to run all tests. It returns 0 if all +// tests are successful, or 1 otherwise. +// +// RUN_ALL_TESTS() should be invoked after the command line has been +// parsed by InitGoogleTest(). +// +// This function was formerly a macro; thus, it is in the global +// namespace and has an all-caps name. +int RUN_ALL_TESTS() GTEST_MUST_USE_RESULT_; + +inline int RUN_ALL_TESTS() { + return ::testing::UnitTest::GetInstance()->Run(); +} + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_pred_impl.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_pred_impl.h new file mode 100644 index 000000000000..96b36fabf5ad --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_pred_impl.h @@ -0,0 +1,364 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file is AUTOMATICALLY GENERATED on 07/21/2021 by command +// 'gen_gtest_pred_impl.py 5'. DO NOT EDIT BY HAND! +// +// Implements a family of generic predicate assertion macros. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ + +#include "gtest/gtest-assertion-result.h" +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" + +namespace testing { + +// This header implements a family of generic predicate assertion +// macros: +// +// ASSERT_PRED_FORMAT1(pred_format, v1) +// ASSERT_PRED_FORMAT2(pred_format, v1, v2) +// ... +// +// where pred_format is a function or functor that takes n (in the +// case of ASSERT_PRED_FORMATn) values and their source expression +// text, and returns a testing::AssertionResult. See the definition +// of ASSERT_EQ in gtest.h for an example. +// +// If you don't care about formatting, you can use the more +// restrictive version: +// +// ASSERT_PRED1(pred, v1) +// ASSERT_PRED2(pred, v1, v2) +// ... +// +// where pred is an n-ary function or functor that returns bool, +// and the values v1, v2, ..., must support the << operator for +// streaming to std::ostream. +// +// We also define the EXPECT_* variations. +// +// For now we only support predicates whose arity is at most 5. +// Please email googletestframework@googlegroups.com if you need +// support for higher arities. + +// GTEST_ASSERT_ is the basic statement to which all of the assertions +// in this file reduce. Don't use this in your code. + +#define GTEST_ASSERT_(expression, on_failure) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar = (expression)) \ + ; \ + else \ + on_failure(gtest_ar.failure_message()) + + +// Helper function for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code. +template +AssertionResult AssertPred1Helper(const char* pred_text, + const char* e1, + Pred pred, + const T1& v1) { + if (pred(v1)) return AssertionSuccess(); + + return AssertionFailure() + << pred_text << "(" << e1 << ") evaluates to false, where" + << "\n" + << e1 << " evaluates to " << ::testing::PrintToString(v1); +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT1. +// Don't use this in your code. +#define GTEST_PRED_FORMAT1_(pred_format, v1, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, v1), \ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code. +#define GTEST_PRED1_(pred, v1, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred1Helper(#pred, \ + #v1, \ + pred, \ + v1), on_failure) + +// Unary predicate assertion macros. +#define EXPECT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +template +AssertionResult AssertPred2Helper(const char* pred_text, + const char* e1, + const char* e2, + Pred pred, + const T1& v1, + const T2& v2) { + if (pred(v1, v2)) return AssertionSuccess(); + + return AssertionFailure() + << pred_text << "(" << e1 << ", " << e2 + << ") evaluates to false, where" + << "\n" + << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" + << e2 << " evaluates to " << ::testing::PrintToString(v2); +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT2. +// Don't use this in your code. +#define GTEST_PRED_FORMAT2_(pred_format, v1, v2, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, v1, v2), \ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +#define GTEST_PRED2_(pred, v1, v2, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred2Helper(#pred, \ + #v1, \ + #v2, \ + pred, \ + v1, \ + v2), on_failure) + +// Binary predicate assertion macros. +#define EXPECT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. +template +AssertionResult AssertPred3Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3) { + if (pred(v1, v2, v3)) return AssertionSuccess(); + + return AssertionFailure() + << pred_text << "(" << e1 << ", " << e2 << ", " << e3 + << ") evaluates to false, where" + << "\n" + << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" + << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" + << e3 << " evaluates to " << ::testing::PrintToString(v3); +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT3. +// Don't use this in your code. +#define GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, v1, v2, v3), \ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. +#define GTEST_PRED3_(pred, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred3Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + pred, \ + v1, \ + v2, \ + v3), on_failure) + +// Ternary predicate assertion macros. +#define EXPECT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +template +AssertionResult AssertPred4Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) { + if (pred(v1, v2, v3, v4)) return AssertionSuccess(); + + return AssertionFailure() + << pred_text << "(" << e1 << ", " << e2 << ", " << e3 << ", " << e4 + << ") evaluates to false, where" + << "\n" + << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" + << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" + << e3 << " evaluates to " << ::testing::PrintToString(v3) << "\n" + << e4 << " evaluates to " << ::testing::PrintToString(v4); +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT4. +// Don't use this in your code. +#define GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, v1, v2, v3, v4), \ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +#define GTEST_PRED4_(pred, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred4Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4), on_failure) + +// 4-ary predicate assertion macros. +#define EXPECT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED5. Don't use +// this in your code. +template +AssertionResult AssertPred5Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + const char* e5, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) { + if (pred(v1, v2, v3, v4, v5)) return AssertionSuccess(); + + return AssertionFailure() + << pred_text << "(" << e1 << ", " << e2 << ", " << e3 << ", " << e4 + << ", " << e5 << ") evaluates to false, where" + << "\n" + << e1 << " evaluates to " << ::testing::PrintToString(v1) << "\n" + << e2 << " evaluates to " << ::testing::PrintToString(v2) << "\n" + << e3 << " evaluates to " << ::testing::PrintToString(v3) << "\n" + << e4 << " evaluates to " << ::testing::PrintToString(v4) << "\n" + << e5 << " evaluates to " << ::testing::PrintToString(v5); +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT5. +// Don't use this in your code. +#define GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, #v5, v1, v2, v3, v4, v5), \ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED5. Don't use +// this in your code. +#define GTEST_PRED5_(pred, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred5Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + #v5, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4, \ + v5), on_failure) + +// 5-ary predicate assertion macros. +#define EXPECT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) + + + +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_prod.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_prod.h new file mode 100644 index 000000000000..b22030a8a8f3 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/gtest_prod.h @@ -0,0 +1,60 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google C++ Testing and Mocking Framework definitions useful in production +// code. + +#ifndef GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ +#define GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ + +// When you need to test the private or protected members of a class, +// use the FRIEND_TEST macro to declare your tests as friends of the +// class. For example: +// +// class MyClass { +// private: +// void PrivateMethod(); +// FRIEND_TEST(MyClassTest, PrivateMethodWorks); +// }; +// +// class MyClassTest : public testing::Test { +// // ... +// }; +// +// TEST_F(MyClassTest, PrivateMethodWorks) { +// // Can call MyClass::PrivateMethod() here. +// } +// +// Note: The test class must be in the same namespace as the class being tested. +// For example, putting MyClassTest in an anonymous namespace will not work. + +#define FRIEND_TEST(test_case_name, test_name)\ +friend class test_case_name##_##test_name##_Test + +#endif // GOOGLETEST_INCLUDE_GTEST_GTEST_PROD_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/README.md b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/README.md new file mode 100644 index 000000000000..0af3539abf11 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/README.md @@ -0,0 +1,58 @@ +# Customization Points + +The custom directory is an injection point for custom user configurations. + +## Header `gtest.h` + +### The following macros can be defined: + +* `GTEST_OS_STACK_TRACE_GETTER_` - The name of an implementation of + `OsStackTraceGetterInterface`. +* `GTEST_CUSTOM_TEMPDIR_FUNCTION_` - An override for `testing::TempDir()`. See + `testing::TempDir` for semantics and signature. + +## Header `gtest-port.h` + +The following macros can be defined: + +### Flag related macros: + +* `GTEST_FLAG(flag_name)` +* `GTEST_USE_OWN_FLAGFILE_FLAG_` - Define to 0 when the system provides its + own flagfile flag parsing. +* `GTEST_DECLARE_bool_(name)` +* `GTEST_DECLARE_int32_(name)` +* `GTEST_DECLARE_string_(name)` +* `GTEST_DEFINE_bool_(name, default_val, doc)` +* `GTEST_DEFINE_int32_(name, default_val, doc)` +* `GTEST_DEFINE_string_(name, default_val, doc)` +* `GTEST_FLAG_GET(flag_name)` +* `GTEST_FLAG_SET(flag_name, value)` + +### Logging: + +* `GTEST_LOG_(severity)` +* `GTEST_CHECK_(condition)` +* Functions `LogToStderr()` and `FlushInfoLog()` have to be provided too. + +### Threading: + +* `GTEST_HAS_NOTIFICATION_` - Enabled if Notification is already provided. +* `GTEST_HAS_MUTEX_AND_THREAD_LOCAL_` - Enabled if `Mutex` and `ThreadLocal` + are already provided. Must also provide `GTEST_DECLARE_STATIC_MUTEX_(mutex)` + and `GTEST_DEFINE_STATIC_MUTEX_(mutex)` +* `GTEST_EXCLUSIVE_LOCK_REQUIRED_(locks)` +* `GTEST_LOCK_EXCLUDED_(locks)` + +### Underlying library support features + +* `GTEST_HAS_CXXABI_H_` + +### Exporting API symbols: + +* `GTEST_API_` - Specifier for exported symbols. + +## Header `gtest-printers.h` + +* See documentation at `gtest/gtest-printers.h` for details on how to define a + custom printer. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-port.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-port.h new file mode 100644 index 000000000000..db02881c0c89 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-port.h @@ -0,0 +1,37 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Injection point for custom user configurations. See README for details +// +// ** Custom implementation starts here ** + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PORT_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-printers.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-printers.h new file mode 100644 index 000000000000..b9495d83783b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest-printers.h @@ -0,0 +1,42 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// This file provides an injection point for custom printers in a local +// installation of gTest. +// It will be included from gtest-printers.h and the overrides in this file +// will be visible to everyone. +// +// Injection point for custom user configurations. See README for details +// +// ** Custom implementation starts here ** + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_PRINTERS_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest.h new file mode 100644 index 000000000000..afaaf17ba28e --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/custom/gtest.h @@ -0,0 +1,37 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Injection point for custom user configurations. See README for details +// +// ** Custom implementation starts here ** + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_CUSTOM_GTEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-death-test-internal.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-death-test-internal.h new file mode 100644 index 000000000000..128e0f4c28c4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-death-test-internal.h @@ -0,0 +1,305 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file defines internal utilities needed for implementing +// death tests. They are subject to change without notice. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ + +#include "gtest/gtest-matchers.h" +#include "gtest/internal/gtest-internal.h" + +#include +#include + +GTEST_DECLARE_string_(internal_run_death_test); + +namespace testing { +namespace internal { + +// Names of the flags (needed for parsing Google Test flags). +const char kDeathTestStyleFlag[] = "death_test_style"; +const char kDeathTestUseFork[] = "death_test_use_fork"; +const char kInternalRunDeathTestFlag[] = "internal_run_death_test"; + +#if GTEST_HAS_DEATH_TEST + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +// DeathTest is a class that hides much of the complexity of the +// GTEST_DEATH_TEST_ macro. It is abstract; its static Create method +// returns a concrete class that depends on the prevailing death test +// style, as defined by the --gtest_death_test_style and/or +// --gtest_internal_run_death_test flags. + +// In describing the results of death tests, these terms are used with +// the corresponding definitions: +// +// exit status: The integer exit information in the format specified +// by wait(2) +// exit code: The integer code passed to exit(3), _exit(2), or +// returned from main() +class GTEST_API_ DeathTest { + public: + // Create returns false if there was an error determining the + // appropriate action to take for the current death test; for example, + // if the gtest_death_test_style flag is set to an invalid value. + // The LastMessage method will return a more detailed message in that + // case. Otherwise, the DeathTest pointer pointed to by the "test" + // argument is set. If the death test should be skipped, the pointer + // is set to NULL; otherwise, it is set to the address of a new concrete + // DeathTest object that controls the execution of the current test. + static bool Create(const char* statement, Matcher matcher, + const char* file, int line, DeathTest** test); + DeathTest(); + virtual ~DeathTest() { } + + // A helper class that aborts a death test when it's deleted. + class ReturnSentinel { + public: + explicit ReturnSentinel(DeathTest* test) : test_(test) { } + ~ReturnSentinel() { test_->Abort(TEST_ENCOUNTERED_RETURN_STATEMENT); } + private: + DeathTest* const test_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ReturnSentinel); + } GTEST_ATTRIBUTE_UNUSED_; + + // An enumeration of possible roles that may be taken when a death + // test is encountered. EXECUTE means that the death test logic should + // be executed immediately. OVERSEE means that the program should prepare + // the appropriate environment for a child process to execute the death + // test, then wait for it to complete. + enum TestRole { OVERSEE_TEST, EXECUTE_TEST }; + + // An enumeration of the three reasons that a test might be aborted. + enum AbortReason { + TEST_ENCOUNTERED_RETURN_STATEMENT, + TEST_THREW_EXCEPTION, + TEST_DID_NOT_DIE + }; + + // Assumes one of the above roles. + virtual TestRole AssumeRole() = 0; + + // Waits for the death test to finish and returns its status. + virtual int Wait() = 0; + + // Returns true if the death test passed; that is, the test process + // exited during the test, its exit status matches a user-supplied + // predicate, and its stderr output matches a user-supplied regular + // expression. + // The user-supplied predicate may be a macro expression rather + // than a function pointer or functor, or else Wait and Passed could + // be combined. + virtual bool Passed(bool exit_status_ok) = 0; + + // Signals that the death test did not die as expected. + virtual void Abort(AbortReason reason) = 0; + + // Returns a human-readable outcome message regarding the outcome of + // the last death test. + static const char* LastMessage(); + + static void set_last_death_test_message(const std::string& message); + + private: + // A string containing a description of the outcome of the last death test. + static std::string last_death_test_message_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DeathTest); +}; + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +// Factory interface for death tests. May be mocked out for testing. +class DeathTestFactory { + public: + virtual ~DeathTestFactory() { } + virtual bool Create(const char* statement, + Matcher matcher, const char* file, + int line, DeathTest** test) = 0; +}; + +// A concrete DeathTestFactory implementation for normal use. +class DefaultDeathTestFactory : public DeathTestFactory { + public: + bool Create(const char* statement, Matcher matcher, + const char* file, int line, DeathTest** test) override; +}; + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +GTEST_API_ bool ExitedUnsuccessfully(int exit_status); + +// A string passed to EXPECT_DEATH (etc.) is caught by one of these overloads +// and interpreted as a regex (rather than an Eq matcher) for legacy +// compatibility. +inline Matcher MakeDeathTestMatcher( + ::testing::internal::RE regex) { + return ContainsRegex(regex.pattern()); +} +inline Matcher MakeDeathTestMatcher(const char* regex) { + return ContainsRegex(regex); +} +inline Matcher MakeDeathTestMatcher( + const ::std::string& regex) { + return ContainsRegex(regex); +} + +// If a Matcher is passed to EXPECT_DEATH (etc.), it's +// used directly. +inline Matcher MakeDeathTestMatcher( + Matcher matcher) { + return matcher; +} + +// Traps C++ exceptions escaping statement and reports them as test +// failures. Note that trapping SEH exceptions is not implemented here. +# if GTEST_HAS_EXCEPTIONS +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } catch (const ::std::exception& gtest_exception) { \ + fprintf(\ + stderr, \ + "\n%s: Caught std::exception-derived exception escaping the " \ + "death test statement. Exception message: %s\n", \ + ::testing::internal::FormatFileLocation(__FILE__, __LINE__).c_str(), \ + gtest_exception.what()); \ + fflush(stderr); \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } catch (...) { \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } + +# else +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) + +# endif + +// This macro is for implementing ASSERT_DEATH*, EXPECT_DEATH*, +// ASSERT_EXIT*, and EXPECT_EXIT*. +#define GTEST_DEATH_TEST_(statement, predicate, regex_or_matcher, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + ::testing::internal::DeathTest* gtest_dt; \ + if (!::testing::internal::DeathTest::Create( \ + #statement, \ + ::testing::internal::MakeDeathTestMatcher(regex_or_matcher), \ + __FILE__, __LINE__, >est_dt)) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + if (gtest_dt != nullptr) { \ + std::unique_ptr< ::testing::internal::DeathTest> gtest_dt_ptr(gtest_dt); \ + switch (gtest_dt->AssumeRole()) { \ + case ::testing::internal::DeathTest::OVERSEE_TEST: \ + if (!gtest_dt->Passed(predicate(gtest_dt->Wait()))) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + break; \ + case ::testing::internal::DeathTest::EXECUTE_TEST: { \ + ::testing::internal::DeathTest::ReturnSentinel gtest_sentinel( \ + gtest_dt); \ + GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, gtest_dt); \ + gtest_dt->Abort(::testing::internal::DeathTest::TEST_DID_NOT_DIE); \ + break; \ + } \ + } \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__) \ + : fail(::testing::internal::DeathTest::LastMessage()) +// The symbol "fail" here expands to something into which a message +// can be streamed. + +// This macro is for implementing ASSERT/EXPECT_DEBUG_DEATH when compiled in +// NDEBUG mode. In this case we need the statements to be executed and the macro +// must accept a streamed message even though the message is never printed. +// The regex object is not evaluated, but it is used to prevent "unused" +// warnings and to avoid an expression that doesn't compile in debug mode. +#define GTEST_EXECUTE_STATEMENT_(statement, regex_or_matcher) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } else if (!::testing::internal::AlwaysTrue()) { \ + ::testing::internal::MakeDeathTestMatcher(regex_or_matcher); \ + } else \ + ::testing::Message() + +// A class representing the parsed contents of the +// --gtest_internal_run_death_test flag, as it existed when +// RUN_ALL_TESTS was called. +class InternalRunDeathTestFlag { + public: + InternalRunDeathTestFlag(const std::string& a_file, + int a_line, + int an_index, + int a_write_fd) + : file_(a_file), line_(a_line), index_(an_index), + write_fd_(a_write_fd) {} + + ~InternalRunDeathTestFlag() { + if (write_fd_ >= 0) + posix::Close(write_fd_); + } + + const std::string& file() const { return file_; } + int line() const { return line_; } + int index() const { return index_; } + int write_fd() const { return write_fd_; } + + private: + std::string file_; + int line_; + int index_; + int write_fd_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(InternalRunDeathTestFlag); +}; + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag(); + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace internal +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-filepath.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-filepath.h new file mode 100644 index 000000000000..4dfe2e22272d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-filepath.h @@ -0,0 +1,213 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Test filepath utilities +// +// This header file declares classes and functions used internally by +// Google Test. They are subject to change without notice. +// +// This file is #included in gtest/internal/gtest-internal.h. +// Do not include this header file separately! + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ + +#include "gtest/internal/gtest-string.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { +namespace internal { + +// FilePath - a class for file and directory pathname manipulation which +// handles platform-specific conventions (like the pathname separator). +// Used for helper functions for naming files in a directory for xml output. +// Except for Set methods, all methods are const or static, which provides an +// "immutable value object" -- useful for peace of mind. +// A FilePath with a value ending in a path separator ("like/this/") represents +// a directory, otherwise it is assumed to represent a file. In either case, +// it may or may not represent an actual file or directory in the file system. +// Names are NOT checked for syntax correctness -- no checking for illegal +// characters, malformed paths, etc. + +class GTEST_API_ FilePath { + public: + FilePath() : pathname_("") { } + FilePath(const FilePath& rhs) : pathname_(rhs.pathname_) { } + + explicit FilePath(const std::string& pathname) : pathname_(pathname) { + Normalize(); + } + + FilePath& operator=(const FilePath& rhs) { + Set(rhs); + return *this; + } + + void Set(const FilePath& rhs) { + pathname_ = rhs.pathname_; + } + + const std::string& string() const { return pathname_; } + const char* c_str() const { return pathname_.c_str(); } + + // Returns the current working directory, or "" if unsuccessful. + static FilePath GetCurrentDir(); + + // Given directory = "dir", base_name = "test", number = 0, + // extension = "xml", returns "dir/test.xml". If number is greater + // than zero (e.g., 12), returns "dir/test_12.xml". + // On Windows platform, uses \ as the separator rather than /. + static FilePath MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension); + + // Given directory = "dir", relative_path = "test.xml", + // returns "dir/test.xml". + // On Windows, uses \ as the separator rather than /. + static FilePath ConcatPaths(const FilePath& directory, + const FilePath& relative_path); + + // Returns a pathname for a file that does not currently exist. The pathname + // will be directory/base_name.extension or + // directory/base_name_.extension if directory/base_name.extension + // already exists. The number will be incremented until a pathname is found + // that does not already exist. + // Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. + // There could be a race condition if two or more processes are calling this + // function at the same time -- they could both pick the same filename. + static FilePath GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension); + + // Returns true if and only if the path is "". + bool IsEmpty() const { return pathname_.empty(); } + + // If input name has a trailing separator character, removes it and returns + // the name, otherwise return the name string unmodified. + // On Windows platform, uses \ as the separator, other platforms use /. + FilePath RemoveTrailingPathSeparator() const; + + // Returns a copy of the FilePath with the directory part removed. + // Example: FilePath("path/to/file").RemoveDirectoryName() returns + // FilePath("file"). If there is no directory part ("just_a_file"), it returns + // the FilePath unmodified. If there is no file part ("just_a_dir/") it + // returns an empty FilePath (""). + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveDirectoryName() const; + + // RemoveFileName returns the directory path with the filename removed. + // Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". + // If the FilePath is "a_file" or "/a_file", RemoveFileName returns + // FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does + // not have a file, like "just/a/dir/", it returns the FilePath unmodified. + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveFileName() const; + + // Returns a copy of the FilePath with the case-insensitive extension removed. + // Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns + // FilePath("dir/file"). If a case-insensitive extension is not + // found, returns a copy of the original FilePath. + FilePath RemoveExtension(const char* extension) const; + + // Creates directories so that path exists. Returns true if successful or if + // the directories already exist; returns false if unable to create + // directories for any reason. Will also return false if the FilePath does + // not represent a directory (that is, it doesn't end with a path separator). + bool CreateDirectoriesRecursively() const; + + // Create the directory so that path exists. Returns true if successful or + // if the directory already exists; returns false if unable to create the + // directory for any reason, including if the parent directory does not + // exist. Not named "CreateDirectory" because that's a macro on Windows. + bool CreateFolder() const; + + // Returns true if FilePath describes something in the file-system, + // either a file, directory, or whatever, and that something exists. + bool FileOrDirectoryExists() const; + + // Returns true if pathname describes a directory in the file-system + // that exists. + bool DirectoryExists() const; + + // Returns true if FilePath ends with a path separator, which indicates that + // it is intended to represent a directory. Returns false otherwise. + // This does NOT check that a directory (or file) actually exists. + bool IsDirectory() const; + + // Returns true if pathname describes a root directory. (Windows has one + // root directory per disk drive.) + bool IsRootDirectory() const; + + // Returns true if pathname describes an absolute path. + bool IsAbsolutePath() const; + + private: + // Replaces multiple consecutive separators with a single separator. + // For example, "bar///foo" becomes "bar/foo". Does not eliminate other + // redundancies that might be in a pathname involving "." or "..". + // + // A pathname with multiple consecutive separators may occur either through + // user error or as a result of some scripts or APIs that generate a pathname + // with a trailing separator. On other platforms the same API or script + // may NOT generate a pathname with a trailing "/". Then elsewhere that + // pathname may have another "/" and pathname components added to it, + // without checking for the separator already being there. + // The script language and operating system may allow paths like "foo//bar" + // but some of the functions in FilePath will not handle that correctly. In + // particular, RemoveTrailingPathSeparator() only removes one separator, and + // it is called in CreateDirectoriesRecursively() assuming that it will change + // a pathname from directory syntax (trailing separator) to filename syntax. + // + // On Windows this method also replaces the alternate path separator '/' with + // the primary path separator '\\', so that for example "bar\\/\\foo" becomes + // "bar\\foo". + + void Normalize(); + + // Returns a pointer to the last occurrence of a valid path separator in + // the FilePath. On Windows, for example, both '/' and '\' are valid path + // separators. Returns NULL if no path separator was found. + const char* FindLastPathSeparator() const; + + std::string pathname_; +}; // class FilePath + +} // namespace internal +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-internal.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-internal.h new file mode 100644 index 000000000000..fc15e94785d4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-internal.h @@ -0,0 +1,1562 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file declares functions and macros used internally by +// Google Test. They are subject to change without notice. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ + +#include "gtest/internal/gtest-port.h" + +#if GTEST_OS_LINUX +# include +# include +# include +# include +#endif // GTEST_OS_LINUX + +#if GTEST_HAS_EXCEPTIONS +# include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest-message.h" +#include "gtest/internal/gtest-filepath.h" +#include "gtest/internal/gtest-string.h" +#include "gtest/internal/gtest-type-util.h" + +// Due to C++ preprocessor weirdness, we need double indirection to +// concatenate two tokens when one of them is __LINE__. Writing +// +// foo ## __LINE__ +// +// will result in the token foo__LINE__, instead of foo followed by +// the current line number. For more details, see +// http://www.parashift.com/c++-faq-lite/misc-technical-issues.html#faq-39.6 +#define GTEST_CONCAT_TOKEN_(foo, bar) GTEST_CONCAT_TOKEN_IMPL_(foo, bar) +#define GTEST_CONCAT_TOKEN_IMPL_(foo, bar) foo ## bar + +// Stringifies its argument. +// Work around a bug in visual studio which doesn't accept code like this: +// +// #define GTEST_STRINGIFY_(name) #name +// #define MACRO(a, b, c) ... GTEST_STRINGIFY_(a) ... +// MACRO(, x, y) +// +// Complaining about the argument to GTEST_STRINGIFY_ being empty. +// This is allowed by the spec. +#define GTEST_STRINGIFY_HELPER_(name, ...) #name +#define GTEST_STRINGIFY_(...) GTEST_STRINGIFY_HELPER_(__VA_ARGS__, ) + +namespace proto2 { +class MessageLite; +} + +namespace testing { + +// Forward declarations. + +class AssertionResult; // Result of an assertion. +class Message; // Represents a failure message. +class Test; // Represents a test. +class TestInfo; // Information about a test. +class TestPartResult; // Result of a test part. +class UnitTest; // A collection of test suites. + +template +::std::string PrintToString(const T& value); + +namespace internal { + +struct TraceInfo; // Information about a trace point. +class TestInfoImpl; // Opaque implementation of TestInfo +class UnitTestImpl; // Opaque implementation of UnitTest + +// The text used in failure messages to indicate the start of the +// stack trace. +GTEST_API_ extern const char kStackTraceMarker[]; + +// An IgnoredValue object can be implicitly constructed from ANY value. +class IgnoredValue { + struct Sink {}; + public: + // This constructor template allows any value to be implicitly + // converted to IgnoredValue. The object has no data member and + // doesn't try to remember anything about the argument. We + // deliberately omit the 'explicit' keyword in order to allow the + // conversion to be implicit. + // Disable the conversion if T already has a magical conversion operator. + // Otherwise we get ambiguity. + template ::value, + int>::type = 0> + IgnoredValue(const T& /* ignored */) {} // NOLINT(runtime/explicit) +}; + +// Appends the user-supplied message to the Google-Test-generated message. +GTEST_API_ std::string AppendUserMessage( + const std::string& gtest_msg, const Message& user_msg); + +#if GTEST_HAS_EXCEPTIONS + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4275 \ +/* an exported class was derived from a class that was not exported */) + +// This exception is thrown by (and only by) a failed Google Test +// assertion when GTEST_FLAG(throw_on_failure) is true (if exceptions +// are enabled). We derive it from std::runtime_error, which is for +// errors presumably detectable only at run time. Since +// std::runtime_error inherits from std::exception, many testing +// frameworks know how to extract and print the message inside it. +class GTEST_API_ GoogleTestFailureException : public ::std::runtime_error { + public: + explicit GoogleTestFailureException(const TestPartResult& failure); +}; + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4275 + +#endif // GTEST_HAS_EXCEPTIONS + +namespace edit_distance { +// Returns the optimal edits to go from 'left' to 'right'. +// All edits cost the same, with replace having lower priority than +// add/remove. +// Simple implementation of the Wagner-Fischer algorithm. +// See http://en.wikipedia.org/wiki/Wagner-Fischer_algorithm +enum EditType { kMatch, kAdd, kRemove, kReplace }; +GTEST_API_ std::vector CalculateOptimalEdits( + const std::vector& left, const std::vector& right); + +// Same as above, but the input is represented as strings. +GTEST_API_ std::vector CalculateOptimalEdits( + const std::vector& left, + const std::vector& right); + +// Create a diff of the input strings in Unified diff format. +GTEST_API_ std::string CreateUnifiedDiff(const std::vector& left, + const std::vector& right, + size_t context = 2); + +} // namespace edit_distance + +// Calculate the diff between 'left' and 'right' and return it in unified diff +// format. +// If not null, stores in 'total_line_count' the total number of lines found +// in left + right. +GTEST_API_ std::string DiffStrings(const std::string& left, + const std::string& right, + size_t* total_line_count); + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// expected_expression: "foo" +// actual_expression: "bar" +// expected_value: "5" +// actual_value: "6" +// +// The ignoring_case parameter is true if and only if the assertion is a +// *_STRCASEEQ*. When it's true, the string " (ignoring case)" will +// be inserted into the message. +GTEST_API_ AssertionResult EqFailure(const char* expected_expression, + const char* actual_expression, + const std::string& expected_value, + const std::string& actual_value, + bool ignoring_case); + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +GTEST_API_ std::string GetBoolAssertionFailureMessage( + const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value); + +// This template class represents an IEEE floating-point number +// (either single-precision or double-precision, depending on the +// template parameters). +// +// The purpose of this class is to do more sophisticated number +// comparison. (Due to round-off error, etc, it's very unlikely that +// two floating-points will be equal exactly. Hence a naive +// comparison by the == operation often doesn't work.) +// +// Format of IEEE floating-point: +// +// The most-significant bit being the leftmost, an IEEE +// floating-point looks like +// +// sign_bit exponent_bits fraction_bits +// +// Here, sign_bit is a single bit that designates the sign of the +// number. +// +// For float, there are 8 exponent bits and 23 fraction bits. +// +// For double, there are 11 exponent bits and 52 fraction bits. +// +// More details can be found at +// http://en.wikipedia.org/wiki/IEEE_floating-point_standard. +// +// Template parameter: +// +// RawType: the raw floating-point type (either float or double) +template +class FloatingPoint { + public: + // Defines the unsigned integer type that has the same size as the + // floating point number. + typedef typename TypeWithSize::UInt Bits; + + // Constants. + + // # of bits in a number. + static const size_t kBitCount = 8*sizeof(RawType); + + // # of fraction bits in a number. + static const size_t kFractionBitCount = + std::numeric_limits::digits - 1; + + // # of exponent bits in a number. + static const size_t kExponentBitCount = kBitCount - 1 - kFractionBitCount; + + // The mask for the sign bit. + static const Bits kSignBitMask = static_cast(1) << (kBitCount - 1); + + // The mask for the fraction bits. + static const Bits kFractionBitMask = + ~static_cast(0) >> (kExponentBitCount + 1); + + // The mask for the exponent bits. + static const Bits kExponentBitMask = ~(kSignBitMask | kFractionBitMask); + + // How many ULP's (Units in the Last Place) we want to tolerate when + // comparing two numbers. The larger the value, the more error we + // allow. A 0 value means that two numbers must be exactly the same + // to be considered equal. + // + // The maximum error of a single floating-point operation is 0.5 + // units in the last place. On Intel CPU's, all floating-point + // calculations are done with 80-bit precision, while double has 64 + // bits. Therefore, 4 should be enough for ordinary use. + // + // See the following article for more details on ULP: + // http://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + static const uint32_t kMaxUlps = 4; + + // Constructs a FloatingPoint from a raw floating-point number. + // + // On an Intel CPU, passing a non-normalized NAN (Not a Number) + // around may change its bits, although the new value is guaranteed + // to be also a NAN. Therefore, don't expect this constructor to + // preserve the bits in x when x is a NAN. + explicit FloatingPoint(const RawType& x) { u_.value_ = x; } + + // Static methods + + // Reinterprets a bit pattern as a floating-point number. + // + // This function is needed to test the AlmostEquals() method. + static RawType ReinterpretBits(const Bits bits) { + FloatingPoint fp(0); + fp.u_.bits_ = bits; + return fp.u_.value_; + } + + // Returns the floating-point number that represent positive infinity. + static RawType Infinity() { + return ReinterpretBits(kExponentBitMask); + } + + // Returns the maximum representable finite floating-point number. + static RawType Max(); + + // Non-static methods + + // Returns the bits that represents this number. + const Bits &bits() const { return u_.bits_; } + + // Returns the exponent bits of this number. + Bits exponent_bits() const { return kExponentBitMask & u_.bits_; } + + // Returns the fraction bits of this number. + Bits fraction_bits() const { return kFractionBitMask & u_.bits_; } + + // Returns the sign bit of this number. + Bits sign_bit() const { return kSignBitMask & u_.bits_; } + + // Returns true if and only if this is NAN (not a number). + bool is_nan() const { + // It's a NAN if the exponent bits are all ones and the fraction + // bits are not entirely zeros. + return (exponent_bits() == kExponentBitMask) && (fraction_bits() != 0); + } + + // Returns true if and only if this number is at most kMaxUlps ULP's away + // from rhs. In particular, this function: + // + // - returns false if either number is (or both are) NAN. + // - treats really large numbers as almost equal to infinity. + // - thinks +0.0 and -0.0 are 0 DLP's apart. + bool AlmostEquals(const FloatingPoint& rhs) const { + // The IEEE standard says that any comparison operation involving + // a NAN must return false. + if (is_nan() || rhs.is_nan()) return false; + + return DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_) + <= kMaxUlps; + } + + private: + // The data type used to store the actual floating-point number. + union FloatingPointUnion { + RawType value_; // The raw floating-point number. + Bits bits_; // The bits that represent the number. + }; + + // Converts an integer from the sign-and-magnitude representation to + // the biased representation. More precisely, let N be 2 to the + // power of (kBitCount - 1), an integer x is represented by the + // unsigned number x + N. + // + // For instance, + // + // -N + 1 (the most negative number representable using + // sign-and-magnitude) is represented by 1; + // 0 is represented by N; and + // N - 1 (the biggest number representable using + // sign-and-magnitude) is represented by 2N - 1. + // + // Read http://en.wikipedia.org/wiki/Signed_number_representations + // for more details on signed number representations. + static Bits SignAndMagnitudeToBiased(const Bits &sam) { + if (kSignBitMask & sam) { + // sam represents a negative number. + return ~sam + 1; + } else { + // sam represents a positive number. + return kSignBitMask | sam; + } + } + + // Given two numbers in the sign-and-magnitude representation, + // returns the distance between them as an unsigned number. + static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1, + const Bits &sam2) { + const Bits biased1 = SignAndMagnitudeToBiased(sam1); + const Bits biased2 = SignAndMagnitudeToBiased(sam2); + return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1); + } + + FloatingPointUnion u_; +}; + +// We cannot use std::numeric_limits::max() as it clashes with the max() +// macro defined by . +template <> +inline float FloatingPoint::Max() { return FLT_MAX; } +template <> +inline double FloatingPoint::Max() { return DBL_MAX; } + +// Typedefs the instances of the FloatingPoint template class that we +// care to use. +typedef FloatingPoint Float; +typedef FloatingPoint Double; + +// In order to catch the mistake of putting tests that use different +// test fixture classes in the same test suite, we need to assign +// unique IDs to fixture classes and compare them. The TypeId type is +// used to hold such IDs. The user should treat TypeId as an opaque +// type: the only operation allowed on TypeId values is to compare +// them for equality using the == operator. +typedef const void* TypeId; + +template +class TypeIdHelper { + public: + // dummy_ must not have a const type. Otherwise an overly eager + // compiler (e.g. MSVC 7.1 & 8.0) may try to merge + // TypeIdHelper::dummy_ for different Ts as an "optimization". + static bool dummy_; +}; + +template +bool TypeIdHelper::dummy_ = false; + +// GetTypeId() returns the ID of type T. Different values will be +// returned for different types. Calling the function twice with the +// same type argument is guaranteed to return the same ID. +template +TypeId GetTypeId() { + // The compiler is required to allocate a different + // TypeIdHelper::dummy_ variable for each T used to instantiate + // the template. Therefore, the address of dummy_ is guaranteed to + // be unique. + return &(TypeIdHelper::dummy_); +} + +// Returns the type ID of ::testing::Test. Always call this instead +// of GetTypeId< ::testing::Test>() to get the type ID of +// ::testing::Test, as the latter may give the wrong result due to a +// suspected linker bug when compiling Google Test as a Mac OS X +// framework. +GTEST_API_ TypeId GetTestTypeId(); + +// Defines the abstract factory interface that creates instances +// of a Test object. +class TestFactoryBase { + public: + virtual ~TestFactoryBase() {} + + // Creates a test instance to run. The instance is both created and destroyed + // within TestInfoImpl::Run() + virtual Test* CreateTest() = 0; + + protected: + TestFactoryBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestFactoryBase); +}; + +// This class provides implementation of TeastFactoryBase interface. +// It is used in TEST and TEST_F macros. +template +class TestFactoryImpl : public TestFactoryBase { + public: + Test* CreateTest() override { return new TestClass; } +}; + +#if GTEST_OS_WINDOWS + +// Predicate-formatters for implementing the HRESULT checking macros +// {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED} +// We pass a long instead of HRESULT to avoid causing an +// include dependency for the HRESULT type. +GTEST_API_ AssertionResult IsHRESULTSuccess(const char* expr, + long hr); // NOLINT +GTEST_API_ AssertionResult IsHRESULTFailure(const char* expr, + long hr); // NOLINT + +#endif // GTEST_OS_WINDOWS + +// Types of SetUpTestSuite() and TearDownTestSuite() functions. +using SetUpTestSuiteFunc = void (*)(); +using TearDownTestSuiteFunc = void (*)(); + +struct CodeLocation { + CodeLocation(const std::string& a_file, int a_line) + : file(a_file), line(a_line) {} + + std::string file; + int line; +}; + +// Helper to identify which setup function for TestCase / TestSuite to call. +// Only one function is allowed, either TestCase or TestSute but not both. + +// Utility functions to help SuiteApiResolver +using SetUpTearDownSuiteFuncType = void (*)(); + +inline SetUpTearDownSuiteFuncType GetNotDefaultOrNull( + SetUpTearDownSuiteFuncType a, SetUpTearDownSuiteFuncType def) { + return a == def ? nullptr : a; +} + +template +// Note that SuiteApiResolver inherits from T because +// SetUpTestSuite()/TearDownTestSuite() could be protected. This way +// SuiteApiResolver can access them. +struct SuiteApiResolver : T { + // testing::Test is only forward declared at this point. So we make it a + // dependent class for the compiler to be OK with it. + using Test = + typename std::conditional::type; + + static SetUpTearDownSuiteFuncType GetSetUpCaseOrSuite(const char* filename, + int line_num) { +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + SetUpTearDownSuiteFuncType test_case_fp = + GetNotDefaultOrNull(&T::SetUpTestCase, &Test::SetUpTestCase); + SetUpTearDownSuiteFuncType test_suite_fp = + GetNotDefaultOrNull(&T::SetUpTestSuite, &Test::SetUpTestSuite); + + GTEST_CHECK_(!test_case_fp || !test_suite_fp) + << "Test can not provide both SetUpTestSuite and SetUpTestCase, please " + "make sure there is only one present at " + << filename << ":" << line_num; + + return test_case_fp != nullptr ? test_case_fp : test_suite_fp; +#else + (void)(filename); + (void)(line_num); + return &T::SetUpTestSuite; +#endif + } + + static SetUpTearDownSuiteFuncType GetTearDownCaseOrSuite(const char* filename, + int line_num) { +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + SetUpTearDownSuiteFuncType test_case_fp = + GetNotDefaultOrNull(&T::TearDownTestCase, &Test::TearDownTestCase); + SetUpTearDownSuiteFuncType test_suite_fp = + GetNotDefaultOrNull(&T::TearDownTestSuite, &Test::TearDownTestSuite); + + GTEST_CHECK_(!test_case_fp || !test_suite_fp) + << "Test can not provide both TearDownTestSuite and TearDownTestCase," + " please make sure there is only one present at" + << filename << ":" << line_num; + + return test_case_fp != nullptr ? test_case_fp : test_suite_fp; +#else + (void)(filename); + (void)(line_num); + return &T::TearDownTestSuite; +#endif + } +}; + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_suite_name: name of the test suite +// name: name of the test +// type_param: the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param: text representation of the test's value parameter, +// or NULL if this is not a type-parameterized test. +// code_location: code location where the test is defined +// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test suite +// tear_down_tc: pointer to the function that tears down the test suite +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object. +GTEST_API_ TestInfo* MakeAndRegisterTestInfo( + const char* test_suite_name, const char* name, const char* type_param, + const char* value_param, CodeLocation code_location, + TypeId fixture_class_id, SetUpTestSuiteFunc set_up_tc, + TearDownTestSuiteFunc tear_down_tc, TestFactoryBase* factory); + +// If *pstr starts with the given prefix, modifies *pstr to be right +// past the prefix and returns true; otherwise leaves *pstr unchanged +// and returns false. None of pstr, *pstr, and prefix can be NULL. +GTEST_API_ bool SkipPrefix(const char* prefix, const char** pstr); + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +// State of the definition of a type-parameterized test suite. +class GTEST_API_ TypedTestSuitePState { + public: + TypedTestSuitePState() : registered_(false) {} + + // Adds the given test name to defined_test_names_ and return true + // if the test suite hasn't been registered; otherwise aborts the + // program. + bool AddTestName(const char* file, int line, const char* case_name, + const char* test_name) { + if (registered_) { + fprintf(stderr, + "%s Test %s must be defined before " + "REGISTER_TYPED_TEST_SUITE_P(%s, ...).\n", + FormatFileLocation(file, line).c_str(), test_name, case_name); + fflush(stderr); + posix::Abort(); + } + registered_tests_.insert( + ::std::make_pair(test_name, CodeLocation(file, line))); + return true; + } + + bool TestExists(const std::string& test_name) const { + return registered_tests_.count(test_name) > 0; + } + + const CodeLocation& GetCodeLocation(const std::string& test_name) const { + RegisteredTestsMap::const_iterator it = registered_tests_.find(test_name); + GTEST_CHECK_(it != registered_tests_.end()); + return it->second; + } + + // Verifies that registered_tests match the test names in + // defined_test_names_; returns registered_tests if successful, or + // aborts the program otherwise. + const char* VerifyRegisteredTestNames(const char* test_suite_name, + const char* file, int line, + const char* registered_tests); + + private: + typedef ::std::map RegisteredTestsMap; + + bool registered_; + RegisteredTestsMap registered_tests_; +}; + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +using TypedTestCasePState = TypedTestSuitePState; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +// Skips to the first non-space char after the first comma in 'str'; +// returns NULL if no comma is found in 'str'. +inline const char* SkipComma(const char* str) { + const char* comma = strchr(str, ','); + if (comma == nullptr) { + return nullptr; + } + while (IsSpace(*(++comma))) {} + return comma; +} + +// Returns the prefix of 'str' before the first comma in it; returns +// the entire string if it contains no comma. +inline std::string GetPrefixUntilComma(const char* str) { + const char* comma = strchr(str, ','); + return comma == nullptr ? str : std::string(str, comma); +} + +// Splits a given string on a given delimiter, populating a given +// vector with the fields. +void SplitString(const ::std::string& str, char delimiter, + ::std::vector< ::std::string>* dest); + +// The default argument to the template below for the case when the user does +// not provide a name generator. +struct DefaultNameGenerator { + template + static std::string GetName(int i) { + return StreamableToString(i); + } +}; + +template +struct NameGeneratorSelector { + typedef Provided type; +}; + +template +void GenerateNamesRecursively(internal::None, std::vector*, int) {} + +template +void GenerateNamesRecursively(Types, std::vector* result, int i) { + result->push_back(NameGenerator::template GetName(i)); + GenerateNamesRecursively(typename Types::Tail(), result, + i + 1); +} + +template +std::vector GenerateNames() { + std::vector result; + GenerateNamesRecursively(Types(), &result, 0); + return result; +} + +// TypeParameterizedTest::Register() +// registers a list of type-parameterized tests with Google Test. The +// return value is insignificant - we just need to return something +// such that we can call this function in a namespace scope. +// +// Implementation note: The GTEST_TEMPLATE_ macro declares a template +// template parameter. It's defined in gtest-type-util.h. +template +class TypeParameterizedTest { + public: + // 'index' is the index of the test in the type list 'Types' + // specified in INSTANTIATE_TYPED_TEST_SUITE_P(Prefix, TestSuite, + // Types). Valid values for 'index' are [0, N - 1] where N is the + // length of Types. + static bool Register(const char* prefix, const CodeLocation& code_location, + const char* case_name, const char* test_names, int index, + const std::vector& type_names = + GenerateNames()) { + typedef typename Types::Head Type; + typedef Fixture FixtureClass; + typedef typename GTEST_BIND_(TestSel, Type) TestClass; + + // First, registers the first type-parameterized test in the type + // list. + MakeAndRegisterTestInfo( + (std::string(prefix) + (prefix[0] == '\0' ? "" : "/") + case_name + + "/" + type_names[static_cast(index)]) + .c_str(), + StripTrailingSpaces(GetPrefixUntilComma(test_names)).c_str(), + GetTypeName().c_str(), + nullptr, // No value parameter. + code_location, GetTypeId(), + SuiteApiResolver::GetSetUpCaseOrSuite( + code_location.file.c_str(), code_location.line), + SuiteApiResolver::GetTearDownCaseOrSuite( + code_location.file.c_str(), code_location.line), + new TestFactoryImpl); + + // Next, recurses (at compile time) with the tail of the type list. + return TypeParameterizedTest::Register(prefix, + code_location, + case_name, + test_names, + index + 1, + type_names); + } +}; + +// The base case for the compile time recursion. +template +class TypeParameterizedTest { + public: + static bool Register(const char* /*prefix*/, const CodeLocation&, + const char* /*case_name*/, const char* /*test_names*/, + int /*index*/, + const std::vector& = + std::vector() /*type_names*/) { + return true; + } +}; + +GTEST_API_ void RegisterTypeParameterizedTestSuite(const char* test_suite_name, + CodeLocation code_location); +GTEST_API_ void RegisterTypeParameterizedTestSuiteInstantiation( + const char* case_name); + +// TypeParameterizedTestSuite::Register() +// registers *all combinations* of 'Tests' and 'Types' with Google +// Test. The return value is insignificant - we just need to return +// something such that we can call this function in a namespace scope. +template +class TypeParameterizedTestSuite { + public: + static bool Register(const char* prefix, CodeLocation code_location, + const TypedTestSuitePState* state, const char* case_name, + const char* test_names, + const std::vector& type_names = + GenerateNames()) { + RegisterTypeParameterizedTestSuiteInstantiation(case_name); + std::string test_name = StripTrailingSpaces( + GetPrefixUntilComma(test_names)); + if (!state->TestExists(test_name)) { + fprintf(stderr, "Failed to get code location for test %s.%s at %s.", + case_name, test_name.c_str(), + FormatFileLocation(code_location.file.c_str(), + code_location.line).c_str()); + fflush(stderr); + posix::Abort(); + } + const CodeLocation& test_location = state->GetCodeLocation(test_name); + + typedef typename Tests::Head Head; + + // First, register the first test in 'Test' for each type in 'Types'. + TypeParameterizedTest::Register( + prefix, test_location, case_name, test_names, 0, type_names); + + // Next, recurses (at compile time) with the tail of the test list. + return TypeParameterizedTestSuite::Register(prefix, code_location, + state, case_name, + SkipComma(test_names), + type_names); + } +}; + +// The base case for the compile time recursion. +template +class TypeParameterizedTestSuite { + public: + static bool Register(const char* /*prefix*/, const CodeLocation&, + const TypedTestSuitePState* /*state*/, + const char* /*case_name*/, const char* /*test_names*/, + const std::vector& = + std::vector() /*type_names*/) { + return true; + } +}; + +// Returns the current OS stack trace as an std::string. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in +// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. +GTEST_API_ std::string GetCurrentOsStackTraceExceptTop( + UnitTest* unit_test, int skip_count); + +// Helpers for suppressing warnings on unreachable code or constant +// condition. + +// Always returns true. +GTEST_API_ bool AlwaysTrue(); + +// Always returns false. +inline bool AlwaysFalse() { return !AlwaysTrue(); } + +// Helper for suppressing false warning from Clang on a const char* +// variable declared in a conditional expression always being NULL in +// the else branch. +struct GTEST_API_ ConstCharPtr { + ConstCharPtr(const char* str) : value(str) {} + operator bool() const { return true; } + const char* value; +}; + +// Helper for declaring std::string within 'if' statement +// in pre C++17 build environment. +struct TrueWithString { + TrueWithString() = default; + explicit TrueWithString(const char* str) : value(str) {} + explicit TrueWithString(const std::string& str) : value(str) {} + explicit operator bool() const { return true; } + std::string value; +}; + +// A simple Linear Congruential Generator for generating random +// numbers with a uniform distribution. Unlike rand() and srand(), it +// doesn't use global state (and therefore can't interfere with user +// code). Unlike rand_r(), it's portable. An LCG isn't very random, +// but it's good enough for our purposes. +class GTEST_API_ Random { + public: + static const uint32_t kMaxRange = 1u << 31; + + explicit Random(uint32_t seed) : state_(seed) {} + + void Reseed(uint32_t seed) { state_ = seed; } + + // Generates a random number from [0, range). Crashes if 'range' is + // 0 or greater than kMaxRange. + uint32_t Generate(uint32_t range); + + private: + uint32_t state_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(Random); +}; + +// Turns const U&, U&, const U, and U all into U. +#define GTEST_REMOVE_REFERENCE_AND_CONST_(T) \ + typename std::remove_const::type>::type + +// HasDebugStringAndShortDebugString::value is a compile-time bool constant +// that's true if and only if T has methods DebugString() and ShortDebugString() +// that return std::string. +template +class HasDebugStringAndShortDebugString { + private: + template + static auto CheckDebugString(C*) -> typename std::is_same< + std::string, decltype(std::declval().DebugString())>::type; + template + static std::false_type CheckDebugString(...); + + template + static auto CheckShortDebugString(C*) -> typename std::is_same< + std::string, decltype(std::declval().ShortDebugString())>::type; + template + static std::false_type CheckShortDebugString(...); + + using HasDebugStringType = decltype(CheckDebugString(nullptr)); + using HasShortDebugStringType = decltype(CheckShortDebugString(nullptr)); + + public: + static constexpr bool value = + HasDebugStringType::value && HasShortDebugStringType::value; +}; + +template +constexpr bool HasDebugStringAndShortDebugString::value; + +// When the compiler sees expression IsContainerTest(0), if C is an +// STL-style container class, the first overload of IsContainerTest +// will be viable (since both C::iterator* and C::const_iterator* are +// valid types and NULL can be implicitly converted to them). It will +// be picked over the second overload as 'int' is a perfect match for +// the type of argument 0. If C::iterator or C::const_iterator is not +// a valid type, the first overload is not viable, and the second +// overload will be picked. Therefore, we can determine whether C is +// a container class by checking the type of IsContainerTest(0). +// The value of the expression is insignificant. +// +// In C++11 mode we check the existence of a const_iterator and that an +// iterator is properly implemented for the container. +// +// For pre-C++11 that we look for both C::iterator and C::const_iterator. +// The reason is that C++ injects the name of a class as a member of the +// class itself (e.g. you can refer to class iterator as either +// 'iterator' or 'iterator::iterator'). If we look for C::iterator +// only, for example, we would mistakenly think that a class named +// iterator is an STL container. +// +// Also note that the simpler approach of overloading +// IsContainerTest(typename C::const_iterator*) and +// IsContainerTest(...) doesn't work with Visual Age C++ and Sun C++. +typedef int IsContainer; +template ().begin()), + class = decltype(::std::declval().end()), + class = decltype(++::std::declval()), + class = decltype(*::std::declval()), + class = typename C::const_iterator> +IsContainer IsContainerTest(int /* dummy */) { + return 0; +} + +typedef char IsNotContainer; +template +IsNotContainer IsContainerTest(long /* dummy */) { return '\0'; } + +// Trait to detect whether a type T is a hash table. +// The heuristic used is that the type contains an inner type `hasher` and does +// not contain an inner type `reverse_iterator`. +// If the container is iterable in reverse, then order might actually matter. +template +struct IsHashTable { + private: + template + static char test(typename U::hasher*, typename U::reverse_iterator*); + template + static int test(typename U::hasher*, ...); + template + static char test(...); + + public: + static const bool value = sizeof(test(nullptr, nullptr)) == sizeof(int); +}; + +template +const bool IsHashTable::value; + +template (0)) == sizeof(IsContainer)> +struct IsRecursiveContainerImpl; + +template +struct IsRecursiveContainerImpl : public std::false_type {}; + +// Since the IsRecursiveContainerImpl depends on the IsContainerTest we need to +// obey the same inconsistencies as the IsContainerTest, namely check if +// something is a container is relying on only const_iterator in C++11 and +// is relying on both const_iterator and iterator otherwise +template +struct IsRecursiveContainerImpl { + using value_type = decltype(*std::declval()); + using type = + std::is_same::type>::type, + C>; +}; + +// IsRecursiveContainer is a unary compile-time predicate that +// evaluates whether C is a recursive container type. A recursive container +// type is a container type whose value_type is equal to the container type +// itself. An example for a recursive container type is +// boost::filesystem::path, whose iterator has a value_type that is equal to +// boost::filesystem::path. +template +struct IsRecursiveContainer : public IsRecursiveContainerImpl::type {}; + +// Utilities for native arrays. + +// ArrayEq() compares two k-dimensional native arrays using the +// elements' operator==, where k can be any integer >= 0. When k is +// 0, ArrayEq() degenerates into comparing a single pair of values. + +template +bool ArrayEq(const T* lhs, size_t size, const U* rhs); + +// This generic version is used when k is 0. +template +inline bool ArrayEq(const T& lhs, const U& rhs) { return lhs == rhs; } + +// This overload is used when k >= 1. +template +inline bool ArrayEq(const T(&lhs)[N], const U(&rhs)[N]) { + return internal::ArrayEq(lhs, N, rhs); +} + +// This helper reduces code bloat. If we instead put its logic inside +// the previous ArrayEq() function, arrays with different sizes would +// lead to different copies of the template code. +template +bool ArrayEq(const T* lhs, size_t size, const U* rhs) { + for (size_t i = 0; i != size; i++) { + if (!internal::ArrayEq(lhs[i], rhs[i])) + return false; + } + return true; +} + +// Finds the first element in the iterator range [begin, end) that +// equals elem. Element may be a native array type itself. +template +Iter ArrayAwareFind(Iter begin, Iter end, const Element& elem) { + for (Iter it = begin; it != end; ++it) { + if (internal::ArrayEq(*it, elem)) + return it; + } + return end; +} + +// CopyArray() copies a k-dimensional native array using the elements' +// operator=, where k can be any integer >= 0. When k is 0, +// CopyArray() degenerates into copying a single value. + +template +void CopyArray(const T* from, size_t size, U* to); + +// This generic version is used when k is 0. +template +inline void CopyArray(const T& from, U* to) { *to = from; } + +// This overload is used when k >= 1. +template +inline void CopyArray(const T(&from)[N], U(*to)[N]) { + internal::CopyArray(from, N, *to); +} + +// This helper reduces code bloat. If we instead put its logic inside +// the previous CopyArray() function, arrays with different sizes +// would lead to different copies of the template code. +template +void CopyArray(const T* from, size_t size, U* to) { + for (size_t i = 0; i != size; i++) { + internal::CopyArray(from[i], to + i); + } +} + +// The relation between an NativeArray object (see below) and the +// native array it represents. +// We use 2 different structs to allow non-copyable types to be used, as long +// as RelationToSourceReference() is passed. +struct RelationToSourceReference {}; +struct RelationToSourceCopy {}; + +// Adapts a native array to a read-only STL-style container. Instead +// of the complete STL container concept, this adaptor only implements +// members useful for Google Mock's container matchers. New members +// should be added as needed. To simplify the implementation, we only +// support Element being a raw type (i.e. having no top-level const or +// reference modifier). It's the client's responsibility to satisfy +// this requirement. Element can be an array type itself (hence +// multi-dimensional arrays are supported). +template +class NativeArray { + public: + // STL-style container typedefs. + typedef Element value_type; + typedef Element* iterator; + typedef const Element* const_iterator; + + // Constructs from a native array. References the source. + NativeArray(const Element* array, size_t count, RelationToSourceReference) { + InitRef(array, count); + } + + // Constructs from a native array. Copies the source. + NativeArray(const Element* array, size_t count, RelationToSourceCopy) { + InitCopy(array, count); + } + + // Copy constructor. + NativeArray(const NativeArray& rhs) { + (this->*rhs.clone_)(rhs.array_, rhs.size_); + } + + ~NativeArray() { + if (clone_ != &NativeArray::InitRef) + delete[] array_; + } + + // STL-style container methods. + size_t size() const { return size_; } + const_iterator begin() const { return array_; } + const_iterator end() const { return array_ + size_; } + bool operator==(const NativeArray& rhs) const { + return size() == rhs.size() && + ArrayEq(begin(), size(), rhs.begin()); + } + + private: + static_assert(!std::is_const::value, "Type must not be const"); + static_assert(!std::is_reference::value, + "Type must not be a reference"); + + // Initializes this object with a copy of the input. + void InitCopy(const Element* array, size_t a_size) { + Element* const copy = new Element[a_size]; + CopyArray(array, a_size, copy); + array_ = copy; + size_ = a_size; + clone_ = &NativeArray::InitCopy; + } + + // Initializes this object with a reference of the input. + void InitRef(const Element* array, size_t a_size) { + array_ = array; + size_ = a_size; + clone_ = &NativeArray::InitRef; + } + + const Element* array_; + size_t size_; + void (NativeArray::*clone_)(const Element*, size_t); +}; + +// Backport of std::index_sequence. +template +struct IndexSequence { + using type = IndexSequence; +}; + +// Double the IndexSequence, and one if plus_one is true. +template +struct DoubleSequence; +template +struct DoubleSequence, sizeofT> { + using type = IndexSequence; +}; +template +struct DoubleSequence, sizeofT> { + using type = IndexSequence; +}; + +// Backport of std::make_index_sequence. +// It uses O(ln(N)) instantiation depth. +template +struct MakeIndexSequenceImpl + : DoubleSequence::type, + N / 2>::type {}; + +template <> +struct MakeIndexSequenceImpl<0> : IndexSequence<> {}; + +template +using MakeIndexSequence = typename MakeIndexSequenceImpl::type; + +template +using IndexSequenceFor = typename MakeIndexSequence::type; + +template +struct Ignore { + Ignore(...); // NOLINT +}; + +template +struct ElemFromListImpl; +template +struct ElemFromListImpl> { + // We make Ignore a template to solve a problem with MSVC. + // A non-template Ignore would work fine with `decltype(Ignore(I))...`, but + // MSVC doesn't understand how to deal with that pack expansion. + // Use `0 * I` to have a single instantiation of Ignore. + template + static R Apply(Ignore<0 * I>..., R (*)(), ...); +}; + +template +struct ElemFromList { + using type = + decltype(ElemFromListImpl::type>::Apply( + static_cast(nullptr)...)); +}; + +struct FlatTupleConstructTag {}; + +template +class FlatTuple; + +template +struct FlatTupleElemBase; + +template +struct FlatTupleElemBase, I> { + using value_type = typename ElemFromList::type; + FlatTupleElemBase() = default; + template + explicit FlatTupleElemBase(FlatTupleConstructTag, Arg&& t) + : value(std::forward(t)) {} + value_type value; +}; + +template +struct FlatTupleBase; + +template +struct FlatTupleBase, IndexSequence> + : FlatTupleElemBase, Idx>... { + using Indices = IndexSequence; + FlatTupleBase() = default; + template + explicit FlatTupleBase(FlatTupleConstructTag, Args&&... args) + : FlatTupleElemBase, Idx>(FlatTupleConstructTag{}, + std::forward(args))... {} + + template + const typename ElemFromList::type& Get() const { + return FlatTupleElemBase, I>::value; + } + + template + typename ElemFromList::type& Get() { + return FlatTupleElemBase, I>::value; + } + + template + auto Apply(F&& f) -> decltype(std::forward(f)(this->Get()...)) { + return std::forward(f)(Get()...); + } + + template + auto Apply(F&& f) const -> decltype(std::forward(f)(this->Get()...)) { + return std::forward(f)(Get()...); + } +}; + +// Analog to std::tuple but with different tradeoffs. +// This class minimizes the template instantiation depth, thus allowing more +// elements than std::tuple would. std::tuple has been seen to require an +// instantiation depth of more than 10x the number of elements in some +// implementations. +// FlatTuple and ElemFromList are not recursive and have a fixed depth +// regardless of T... +// MakeIndexSequence, on the other hand, it is recursive but with an +// instantiation depth of O(ln(N)). +template +class FlatTuple + : private FlatTupleBase, + typename MakeIndexSequence::type> { + using Indices = typename FlatTupleBase< + FlatTuple, typename MakeIndexSequence::type>::Indices; + + public: + FlatTuple() = default; + template + explicit FlatTuple(FlatTupleConstructTag tag, Args&&... args) + : FlatTuple::FlatTupleBase(tag, std::forward(args)...) {} + + using FlatTuple::FlatTupleBase::Apply; + using FlatTuple::FlatTupleBase::Get; +}; + +// Utility functions to be called with static_assert to induce deprecation +// warnings. +GTEST_INTERNAL_DEPRECATED( + "INSTANTIATE_TEST_CASE_P is deprecated, please use " + "INSTANTIATE_TEST_SUITE_P") +constexpr bool InstantiateTestCase_P_IsDeprecated() { return true; } + +GTEST_INTERNAL_DEPRECATED( + "TYPED_TEST_CASE_P is deprecated, please use " + "TYPED_TEST_SUITE_P") +constexpr bool TypedTestCase_P_IsDeprecated() { return true; } + +GTEST_INTERNAL_DEPRECATED( + "TYPED_TEST_CASE is deprecated, please use " + "TYPED_TEST_SUITE") +constexpr bool TypedTestCaseIsDeprecated() { return true; } + +GTEST_INTERNAL_DEPRECATED( + "REGISTER_TYPED_TEST_CASE_P is deprecated, please use " + "REGISTER_TYPED_TEST_SUITE_P") +constexpr bool RegisterTypedTestCase_P_IsDeprecated() { return true; } + +GTEST_INTERNAL_DEPRECATED( + "INSTANTIATE_TYPED_TEST_CASE_P is deprecated, please use " + "INSTANTIATE_TYPED_TEST_SUITE_P") +constexpr bool InstantiateTypedTestCase_P_IsDeprecated() { return true; } + +} // namespace internal +} // namespace testing + +namespace std { +// Some standard library implementations use `struct tuple_size` and some use +// `class tuple_size`. Clang warns about the mismatch. +// https://reviews.llvm.org/D55466 +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wmismatched-tags" +#endif +template +struct tuple_size> + : std::integral_constant {}; +#ifdef __clang__ +#pragma clang diagnostic pop +#endif +} // namespace std + +#define GTEST_MESSAGE_AT_(file, line, message, result_type) \ + ::testing::internal::AssertHelper(result_type, file, line, message) \ + = ::testing::Message() + +#define GTEST_MESSAGE_(message, result_type) \ + GTEST_MESSAGE_AT_(__FILE__, __LINE__, message, result_type) + +#define GTEST_FATAL_FAILURE_(message) \ + return GTEST_MESSAGE_(message, ::testing::TestPartResult::kFatalFailure) + +#define GTEST_NONFATAL_FAILURE_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kNonFatalFailure) + +#define GTEST_SUCCESS_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kSuccess) + +#define GTEST_SKIP_(message) \ + return GTEST_MESSAGE_(message, ::testing::TestPartResult::kSkip) + +// Suppress MSVC warning 4072 (unreachable code) for the code following +// statement if it returns or throws (or doesn't return or throw in some +// situations). +// NOTE: The "else" is important to keep this expansion to prevent a top-level +// "else" from attaching to our "if". +#define GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) \ + if (::testing::internal::AlwaysTrue()) { \ + statement; \ + } else /* NOLINT */ \ + static_assert(true, "") // User must have a semicolon after expansion. + +#if GTEST_HAS_EXCEPTIONS + +namespace testing { +namespace internal { + +class NeverThrown { + public: + const char* what() const noexcept { + return "this exception should never be thrown"; + } +}; + +} // namespace internal +} // namespace testing + +#if GTEST_HAS_RTTI + +#define GTEST_EXCEPTION_TYPE_(e) ::testing::internal::GetTypeName(typeid(e)) + +#else // GTEST_HAS_RTTI + +#define GTEST_EXCEPTION_TYPE_(e) \ + std::string { "an std::exception-derived error" } + +#endif // GTEST_HAS_RTTI + +#define GTEST_TEST_THROW_CATCH_STD_EXCEPTION_(statement, expected_exception) \ + catch (typename std::conditional< \ + std::is_same::type>::type, \ + std::exception>::value, \ + const ::testing::internal::NeverThrown&, const std::exception&>::type \ + e) { \ + gtest_msg.value = "Expected: " #statement \ + " throws an exception of type " #expected_exception \ + ".\n Actual: it throws "; \ + gtest_msg.value += GTEST_EXCEPTION_TYPE_(e); \ + gtest_msg.value += " with description \""; \ + gtest_msg.value += e.what(); \ + gtest_msg.value += "\"."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } + +#else // GTEST_HAS_EXCEPTIONS + +#define GTEST_TEST_THROW_CATCH_STD_EXCEPTION_(statement, expected_exception) + +#endif // GTEST_HAS_EXCEPTIONS + +#define GTEST_TEST_THROW_(statement, expected_exception, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::TrueWithString gtest_msg{}) { \ + bool gtest_caught_expected = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } catch (expected_exception const&) { \ + gtest_caught_expected = true; \ + } \ + GTEST_TEST_THROW_CATCH_STD_EXCEPTION_(statement, expected_exception) \ + catch (...) { \ + gtest_msg.value = "Expected: " #statement \ + " throws an exception of type " #expected_exception \ + ".\n Actual: it throws a different type."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + if (!gtest_caught_expected) { \ + gtest_msg.value = "Expected: " #statement \ + " throws an exception of type " #expected_exception \ + ".\n Actual: it throws nothing."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + } else /*NOLINT*/ \ + GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__) \ + : fail(gtest_msg.value.c_str()) + +#if GTEST_HAS_EXCEPTIONS + +#define GTEST_TEST_NO_THROW_CATCH_STD_EXCEPTION_() \ + catch (std::exception const& e) { \ + gtest_msg.value = "it throws "; \ + gtest_msg.value += GTEST_EXCEPTION_TYPE_(e); \ + gtest_msg.value += " with description \""; \ + gtest_msg.value += e.what(); \ + gtest_msg.value += "\"."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__); \ + } + +#else // GTEST_HAS_EXCEPTIONS + +#define GTEST_TEST_NO_THROW_CATCH_STD_EXCEPTION_() + +#endif // GTEST_HAS_EXCEPTIONS + +#define GTEST_TEST_NO_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::TrueWithString gtest_msg{}) { \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + GTEST_TEST_NO_THROW_CATCH_STD_EXCEPTION_() \ + catch (...) { \ + gtest_msg.value = "it throws."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__): \ + fail(("Expected: " #statement " doesn't throw an exception.\n" \ + " Actual: " + gtest_msg.value).c_str()) + +#define GTEST_TEST_ANY_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + bool gtest_caught_any = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (...) { \ + gtest_caught_any = true; \ + } \ + if (!gtest_caught_any) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__): \ + fail("Expected: " #statement " throws an exception.\n" \ + " Actual: it doesn't.") + + +// Implements Boolean test assertions such as EXPECT_TRUE. expression can be +// either a boolean expression or an AssertionResult. text is a textual +// representation of expression as it was passed into the EXPECT_TRUE. +#define GTEST_TEST_BOOLEAN_(expression, text, actual, expected, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar_ = \ + ::testing::AssertionResult(expression)) \ + ; \ + else \ + fail(::testing::internal::GetBoolAssertionFailureMessage(\ + gtest_ar_, text, #actual, #expected).c_str()) + +#define GTEST_TEST_NO_FATAL_FAILURE_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + ::testing::internal::HasNewFatalFailureHelper gtest_fatal_failure_checker; \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + if (gtest_fatal_failure_checker.has_new_fatal_failure()) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__): \ + fail("Expected: " #statement " doesn't generate new fatal " \ + "failures in the current thread.\n" \ + " Actual: it does.") + +// Expands to the name of the class that implements the given test. +#define GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ + test_suite_name##_##test_name##_Test + +// Helper macro for defining tests. +#define GTEST_TEST_(test_suite_name, test_name, parent_class, parent_id) \ + static_assert(sizeof(GTEST_STRINGIFY_(test_suite_name)) > 1, \ + "test_suite_name must not be empty"); \ + static_assert(sizeof(GTEST_STRINGIFY_(test_name)) > 1, \ + "test_name must not be empty"); \ + class GTEST_TEST_CLASS_NAME_(test_suite_name, test_name) \ + : public parent_class { \ + public: \ + GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)() = default; \ + ~GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)() override = default; \ + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_suite_name, \ + test_name)); \ + GTEST_DISALLOW_MOVE_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_suite_name, \ + test_name)); \ + \ + private: \ + void TestBody() override; \ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ + }; \ + \ + ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_suite_name, \ + test_name)::test_info_ = \ + ::testing::internal::MakeAndRegisterTestInfo( \ + #test_suite_name, #test_name, nullptr, nullptr, \ + ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ + ::testing::internal::SuiteApiResolver< \ + parent_class>::GetSetUpCaseOrSuite(__FILE__, __LINE__), \ + ::testing::internal::SuiteApiResolver< \ + parent_class>::GetTearDownCaseOrSuite(__FILE__, __LINE__), \ + new ::testing::internal::TestFactoryImpl); \ + void GTEST_TEST_CLASS_NAME_(test_suite_name, test_name)::TestBody() + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-param-util.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-param-util.h new file mode 100644 index 000000000000..ff25d9950aaf --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-param-util.h @@ -0,0 +1,948 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Type and function utilities for implementing parameterized tests. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" +#include "gtest/gtest-printers.h" +#include "gtest/gtest-test-part.h" + +namespace testing { +// Input to a parameterized test name generator, describing a test parameter. +// Consists of the parameter value and the integer parameter index. +template +struct TestParamInfo { + TestParamInfo(const ParamType& a_param, size_t an_index) : + param(a_param), + index(an_index) {} + ParamType param; + size_t index; +}; + +// A builtin parameterized test name generator which returns the result of +// testing::PrintToString. +struct PrintToStringParamName { + template + std::string operator()(const TestParamInfo& info) const { + return PrintToString(info.param); + } +}; + +namespace internal { + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// Utility Functions + +// Outputs a message explaining invalid registration of different +// fixture class for the same test suite. This may happen when +// TEST_P macro is used to define two tests with the same name +// but in different namespaces. +GTEST_API_ void ReportInvalidTestSuiteType(const char* test_suite_name, + CodeLocation code_location); + +template class ParamGeneratorInterface; +template class ParamGenerator; + +// Interface for iterating over elements provided by an implementation +// of ParamGeneratorInterface. +template +class ParamIteratorInterface { + public: + virtual ~ParamIteratorInterface() {} + // A pointer to the base generator instance. + // Used only for the purposes of iterator comparison + // to make sure that two iterators belong to the same generator. + virtual const ParamGeneratorInterface* BaseGenerator() const = 0; + // Advances iterator to point to the next element + // provided by the generator. The caller is responsible + // for not calling Advance() on an iterator equal to + // BaseGenerator()->End(). + virtual void Advance() = 0; + // Clones the iterator object. Used for implementing copy semantics + // of ParamIterator. + virtual ParamIteratorInterface* Clone() const = 0; + // Dereferences the current iterator and provides (read-only) access + // to the pointed value. It is the caller's responsibility not to call + // Current() on an iterator equal to BaseGenerator()->End(). + // Used for implementing ParamGenerator::operator*(). + virtual const T* Current() const = 0; + // Determines whether the given iterator and other point to the same + // element in the sequence generated by the generator. + // Used for implementing ParamGenerator::operator==(). + virtual bool Equals(const ParamIteratorInterface& other) const = 0; +}; + +// Class iterating over elements provided by an implementation of +// ParamGeneratorInterface. It wraps ParamIteratorInterface +// and implements the const forward iterator concept. +template +class ParamIterator { + public: + typedef T value_type; + typedef const T& reference; + typedef ptrdiff_t difference_type; + + // ParamIterator assumes ownership of the impl_ pointer. + ParamIterator(const ParamIterator& other) : impl_(other.impl_->Clone()) {} + ParamIterator& operator=(const ParamIterator& other) { + if (this != &other) + impl_.reset(other.impl_->Clone()); + return *this; + } + + const T& operator*() const { return *impl_->Current(); } + const T* operator->() const { return impl_->Current(); } + // Prefix version of operator++. + ParamIterator& operator++() { + impl_->Advance(); + return *this; + } + // Postfix version of operator++. + ParamIterator operator++(int /*unused*/) { + ParamIteratorInterface* clone = impl_->Clone(); + impl_->Advance(); + return ParamIterator(clone); + } + bool operator==(const ParamIterator& other) const { + return impl_.get() == other.impl_.get() || impl_->Equals(*other.impl_); + } + bool operator!=(const ParamIterator& other) const { + return !(*this == other); + } + + private: + friend class ParamGenerator; + explicit ParamIterator(ParamIteratorInterface* impl) : impl_(impl) {} + std::unique_ptr > impl_; +}; + +// ParamGeneratorInterface is the binary interface to access generators +// defined in other translation units. +template +class ParamGeneratorInterface { + public: + typedef T ParamType; + + virtual ~ParamGeneratorInterface() {} + + // Generator interface definition + virtual ParamIteratorInterface* Begin() const = 0; + virtual ParamIteratorInterface* End() const = 0; +}; + +// Wraps ParamGeneratorInterface and provides general generator syntax +// compatible with the STL Container concept. +// This class implements copy initialization semantics and the contained +// ParamGeneratorInterface instance is shared among all copies +// of the original object. This is possible because that instance is immutable. +template +class ParamGenerator { + public: + typedef ParamIterator iterator; + + explicit ParamGenerator(ParamGeneratorInterface* impl) : impl_(impl) {} + ParamGenerator(const ParamGenerator& other) : impl_(other.impl_) {} + + ParamGenerator& operator=(const ParamGenerator& other) { + impl_ = other.impl_; + return *this; + } + + iterator begin() const { return iterator(impl_->Begin()); } + iterator end() const { return iterator(impl_->End()); } + + private: + std::shared_ptr > impl_; +}; + +// Generates values from a range of two comparable values. Can be used to +// generate sequences of user-defined types that implement operator+() and +// operator<(). +// This class is used in the Range() function. +template +class RangeGenerator : public ParamGeneratorInterface { + public: + RangeGenerator(T begin, T end, IncrementT step) + : begin_(begin), end_(end), + step_(step), end_index_(CalculateEndIndex(begin, end, step)) {} + ~RangeGenerator() override {} + + ParamIteratorInterface* Begin() const override { + return new Iterator(this, begin_, 0, step_); + } + ParamIteratorInterface* End() const override { + return new Iterator(this, end_, end_index_, step_); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, T value, int index, + IncrementT step) + : base_(base), value_(value), index_(index), step_(step) {} + ~Iterator() override {} + + const ParamGeneratorInterface* BaseGenerator() const override { + return base_; + } + void Advance() override { + value_ = static_cast(value_ + step_); + index_++; + } + ParamIteratorInterface* Clone() const override { + return new Iterator(*this); + } + const T* Current() const override { return &value_; } + bool Equals(const ParamIteratorInterface& other) const override { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const int other_index = + CheckedDowncastToActualType(&other)->index_; + return index_ == other_index; + } + + private: + Iterator(const Iterator& other) + : ParamIteratorInterface(), + base_(other.base_), value_(other.value_), index_(other.index_), + step_(other.step_) {} + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + T value_; + int index_; + const IncrementT step_; + }; // class RangeGenerator::Iterator + + static int CalculateEndIndex(const T& begin, + const T& end, + const IncrementT& step) { + int end_index = 0; + for (T i = begin; i < end; i = static_cast(i + step)) + end_index++; + return end_index; + } + + // No implementation - assignment is unsupported. + void operator=(const RangeGenerator& other); + + const T begin_; + const T end_; + const IncrementT step_; + // The index for the end() iterator. All the elements in the generated + // sequence are indexed (0-based) to aid iterator comparison. + const int end_index_; +}; // class RangeGenerator + + +// Generates values from a pair of STL-style iterators. Used in the +// ValuesIn() function. The elements are copied from the source range +// since the source can be located on the stack, and the generator +// is likely to persist beyond that stack frame. +template +class ValuesInIteratorRangeGenerator : public ParamGeneratorInterface { + public: + template + ValuesInIteratorRangeGenerator(ForwardIterator begin, ForwardIterator end) + : container_(begin, end) {} + ~ValuesInIteratorRangeGenerator() override {} + + ParamIteratorInterface* Begin() const override { + return new Iterator(this, container_.begin()); + } + ParamIteratorInterface* End() const override { + return new Iterator(this, container_.end()); + } + + private: + typedef typename ::std::vector ContainerType; + + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + typename ContainerType::const_iterator iterator) + : base_(base), iterator_(iterator) {} + ~Iterator() override {} + + const ParamGeneratorInterface* BaseGenerator() const override { + return base_; + } + void Advance() override { + ++iterator_; + value_.reset(); + } + ParamIteratorInterface* Clone() const override { + return new Iterator(*this); + } + // We need to use cached value referenced by iterator_ because *iterator_ + // can return a temporary object (and of type other then T), so just + // having "return &*iterator_;" doesn't work. + // value_ is updated here and not in Advance() because Advance() + // can advance iterator_ beyond the end of the range, and we cannot + // detect that fact. The client code, on the other hand, is + // responsible for not calling Current() on an out-of-range iterator. + const T* Current() const override { + if (value_.get() == nullptr) value_.reset(new T(*iterator_)); + return value_.get(); + } + bool Equals(const ParamIteratorInterface& other) const override { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + return iterator_ == + CheckedDowncastToActualType(&other)->iterator_; + } + + private: + Iterator(const Iterator& other) + // The explicit constructor call suppresses a false warning + // emitted by gcc when supplied with the -Wextra option. + : ParamIteratorInterface(), + base_(other.base_), + iterator_(other.iterator_) {} + + const ParamGeneratorInterface* const base_; + typename ContainerType::const_iterator iterator_; + // A cached value of *iterator_. We keep it here to allow access by + // pointer in the wrapping iterator's operator->(). + // value_ needs to be mutable to be accessed in Current(). + // Use of std::unique_ptr helps manage cached value's lifetime, + // which is bound by the lifespan of the iterator itself. + mutable std::unique_ptr value_; + }; // class ValuesInIteratorRangeGenerator::Iterator + + // No implementation - assignment is unsupported. + void operator=(const ValuesInIteratorRangeGenerator& other); + + const ContainerType container_; +}; // class ValuesInIteratorRangeGenerator + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Default parameterized test name generator, returns a string containing the +// integer test parameter index. +template +std::string DefaultParamName(const TestParamInfo& info) { + Message name_stream; + name_stream << info.index; + return name_stream.GetString(); +} + +template +void TestNotEmpty() { + static_assert(sizeof(T) == 0, "Empty arguments are not allowed."); +} +template +void TestNotEmpty(const T&) {} + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Stores a parameter value and later creates tests parameterized with that +// value. +template +class ParameterizedTestFactory : public TestFactoryBase { + public: + typedef typename TestClass::ParamType ParamType; + explicit ParameterizedTestFactory(ParamType parameter) : + parameter_(parameter) {} + Test* CreateTest() override { + TestClass::SetParam(¶meter_); + return new TestClass(); + } + + private: + const ParamType parameter_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestFactory); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// TestMetaFactoryBase is a base class for meta-factories that create +// test factories for passing into MakeAndRegisterTestInfo function. +template +class TestMetaFactoryBase { + public: + virtual ~TestMetaFactoryBase() {} + + virtual TestFactoryBase* CreateTestFactory(ParamType parameter) = 0; +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// TestMetaFactory creates test factories for passing into +// MakeAndRegisterTestInfo function. Since MakeAndRegisterTestInfo receives +// ownership of test factory pointer, same factory object cannot be passed +// into that method twice. But ParameterizedTestSuiteInfo is going to call +// it for each Test/Parameter value combination. Thus it needs meta factory +// creator class. +template +class TestMetaFactory + : public TestMetaFactoryBase { + public: + using ParamType = typename TestSuite::ParamType; + + TestMetaFactory() {} + + TestFactoryBase* CreateTestFactory(ParamType parameter) override { + return new ParameterizedTestFactory(parameter); + } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestMetaFactory); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestSuiteInfoBase is a generic interface +// to ParameterizedTestSuiteInfo classes. ParameterizedTestSuiteInfoBase +// accumulates test information provided by TEST_P macro invocations +// and generators provided by INSTANTIATE_TEST_SUITE_P macro invocations +// and uses that information to register all resulting test instances +// in RegisterTests method. The ParameterizeTestSuiteRegistry class holds +// a collection of pointers to the ParameterizedTestSuiteInfo objects +// and calls RegisterTests() on each of them when asked. +class ParameterizedTestSuiteInfoBase { + public: + virtual ~ParameterizedTestSuiteInfoBase() {} + + // Base part of test suite name for display purposes. + virtual const std::string& GetTestSuiteName() const = 0; + // Test suite id to verify identity. + virtual TypeId GetTestSuiteTypeId() const = 0; + // UnitTest class invokes this method to register tests in this + // test suite right before running them in RUN_ALL_TESTS macro. + // This method should not be called more than once on any single + // instance of a ParameterizedTestSuiteInfoBase derived class. + virtual void RegisterTests() = 0; + + protected: + ParameterizedTestSuiteInfoBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteInfoBase); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Report a the name of a test_suit as safe to ignore +// as the side effect of construction of this type. +struct GTEST_API_ MarkAsIgnored { + explicit MarkAsIgnored(const char* test_suite); +}; + +GTEST_API_ void InsertSyntheticTestCase(const std::string& name, + CodeLocation location, bool has_test_p); + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestSuiteInfo accumulates tests obtained from TEST_P +// macro invocations for a particular test suite and generators +// obtained from INSTANTIATE_TEST_SUITE_P macro invocations for that +// test suite. It registers tests with all values generated by all +// generators when asked. +template +class ParameterizedTestSuiteInfo : public ParameterizedTestSuiteInfoBase { + public: + // ParamType and GeneratorCreationFunc are private types but are required + // for declarations of public methods AddTestPattern() and + // AddTestSuiteInstantiation(). + using ParamType = typename TestSuite::ParamType; + // A function that returns an instance of appropriate generator type. + typedef ParamGenerator(GeneratorCreationFunc)(); + using ParamNameGeneratorFunc = std::string(const TestParamInfo&); + + explicit ParameterizedTestSuiteInfo(const char* name, + CodeLocation code_location) + : test_suite_name_(name), code_location_(code_location) {} + + // Test suite base name for display purposes. + const std::string& GetTestSuiteName() const override { + return test_suite_name_; + } + // Test suite id to verify identity. + TypeId GetTestSuiteTypeId() const override { return GetTypeId(); } + // TEST_P macro uses AddTestPattern() to record information + // about a single test in a LocalTestInfo structure. + // test_suite_name is the base name of the test suite (without invocation + // prefix). test_base_name is the name of an individual test without + // parameter index. For the test SequenceA/FooTest.DoBar/1 FooTest is + // test suite base name and DoBar is test base name. + void AddTestPattern(const char* test_suite_name, const char* test_base_name, + TestMetaFactoryBase* meta_factory, + CodeLocation code_location) { + tests_.push_back(std::shared_ptr(new TestInfo( + test_suite_name, test_base_name, meta_factory, code_location))); + } + // INSTANTIATE_TEST_SUITE_P macro uses AddGenerator() to record information + // about a generator. + int AddTestSuiteInstantiation(const std::string& instantiation_name, + GeneratorCreationFunc* func, + ParamNameGeneratorFunc* name_func, + const char* file, int line) { + instantiations_.push_back( + InstantiationInfo(instantiation_name, func, name_func, file, line)); + return 0; // Return value used only to run this method in namespace scope. + } + // UnitTest class invokes this method to register tests in this test suite + // right before running tests in RUN_ALL_TESTS macro. + // This method should not be called more than once on any single + // instance of a ParameterizedTestSuiteInfoBase derived class. + // UnitTest has a guard to prevent from calling this method more than once. + void RegisterTests() override { + bool generated_instantiations = false; + + for (typename TestInfoContainer::iterator test_it = tests_.begin(); + test_it != tests_.end(); ++test_it) { + std::shared_ptr test_info = *test_it; + for (typename InstantiationContainer::iterator gen_it = + instantiations_.begin(); gen_it != instantiations_.end(); + ++gen_it) { + const std::string& instantiation_name = gen_it->name; + ParamGenerator generator((*gen_it->generator)()); + ParamNameGeneratorFunc* name_func = gen_it->name_func; + const char* file = gen_it->file; + int line = gen_it->line; + + std::string test_suite_name; + if ( !instantiation_name.empty() ) + test_suite_name = instantiation_name + "/"; + test_suite_name += test_info->test_suite_base_name; + + size_t i = 0; + std::set test_param_names; + for (typename ParamGenerator::iterator param_it = + generator.begin(); + param_it != generator.end(); ++param_it, ++i) { + generated_instantiations = true; + + Message test_name_stream; + + std::string param_name = name_func( + TestParamInfo(*param_it, i)); + + GTEST_CHECK_(IsValidParamName(param_name)) + << "Parameterized test name '" << param_name + << "' is invalid, in " << file + << " line " << line << std::endl; + + GTEST_CHECK_(test_param_names.count(param_name) == 0) + << "Duplicate parameterized test name '" << param_name + << "', in " << file << " line " << line << std::endl; + + test_param_names.insert(param_name); + + if (!test_info->test_base_name.empty()) { + test_name_stream << test_info->test_base_name << "/"; + } + test_name_stream << param_name; + MakeAndRegisterTestInfo( + test_suite_name.c_str(), test_name_stream.GetString().c_str(), + nullptr, // No type parameter. + PrintToString(*param_it).c_str(), test_info->code_location, + GetTestSuiteTypeId(), + SuiteApiResolver::GetSetUpCaseOrSuite(file, line), + SuiteApiResolver::GetTearDownCaseOrSuite(file, line), + test_info->test_meta_factory->CreateTestFactory(*param_it)); + } // for param_it + } // for gen_it + } // for test_it + + if (!generated_instantiations) { + // There are no generaotrs, or they all generate nothing ... + InsertSyntheticTestCase(GetTestSuiteName(), code_location_, + !tests_.empty()); + } + } // RegisterTests + + private: + // LocalTestInfo structure keeps information about a single test registered + // with TEST_P macro. + struct TestInfo { + TestInfo(const char* a_test_suite_base_name, const char* a_test_base_name, + TestMetaFactoryBase* a_test_meta_factory, + CodeLocation a_code_location) + : test_suite_base_name(a_test_suite_base_name), + test_base_name(a_test_base_name), + test_meta_factory(a_test_meta_factory), + code_location(a_code_location) {} + + const std::string test_suite_base_name; + const std::string test_base_name; + const std::unique_ptr > test_meta_factory; + const CodeLocation code_location; + }; + using TestInfoContainer = ::std::vector >; + // Records data received from INSTANTIATE_TEST_SUITE_P macros: + // + struct InstantiationInfo { + InstantiationInfo(const std::string &name_in, + GeneratorCreationFunc* generator_in, + ParamNameGeneratorFunc* name_func_in, + const char* file_in, + int line_in) + : name(name_in), + generator(generator_in), + name_func(name_func_in), + file(file_in), + line(line_in) {} + + std::string name; + GeneratorCreationFunc* generator; + ParamNameGeneratorFunc* name_func; + const char* file; + int line; + }; + typedef ::std::vector InstantiationContainer; + + static bool IsValidParamName(const std::string& name) { + // Check for empty string + if (name.empty()) + return false; + + // Check for invalid characters + for (std::string::size_type index = 0; index < name.size(); ++index) { + if (!IsAlNum(name[index]) && name[index] != '_') + return false; + } + + return true; + } + + const std::string test_suite_name_; + CodeLocation code_location_; + TestInfoContainer tests_; + InstantiationContainer instantiations_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteInfo); +}; // class ParameterizedTestSuiteInfo + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +template +using ParameterizedTestCaseInfo = ParameterizedTestSuiteInfo; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestSuiteRegistry contains a map of +// ParameterizedTestSuiteInfoBase classes accessed by test suite names. TEST_P +// and INSTANTIATE_TEST_SUITE_P macros use it to locate their corresponding +// ParameterizedTestSuiteInfo descriptors. +class ParameterizedTestSuiteRegistry { + public: + ParameterizedTestSuiteRegistry() {} + ~ParameterizedTestSuiteRegistry() { + for (auto& test_suite_info : test_suite_infos_) { + delete test_suite_info; + } + } + + // Looks up or creates and returns a structure containing information about + // tests and instantiations of a particular test suite. + template + ParameterizedTestSuiteInfo* GetTestSuitePatternHolder( + const char* test_suite_name, CodeLocation code_location) { + ParameterizedTestSuiteInfo* typed_test_info = nullptr; + for (auto& test_suite_info : test_suite_infos_) { + if (test_suite_info->GetTestSuiteName() == test_suite_name) { + if (test_suite_info->GetTestSuiteTypeId() != GetTypeId()) { + // Complain about incorrect usage of Google Test facilities + // and terminate the program since we cannot guaranty correct + // test suite setup and tear-down in this case. + ReportInvalidTestSuiteType(test_suite_name, code_location); + posix::Abort(); + } else { + // At this point we are sure that the object we found is of the same + // type we are looking for, so we downcast it to that type + // without further checks. + typed_test_info = CheckedDowncastToActualType< + ParameterizedTestSuiteInfo >(test_suite_info); + } + break; + } + } + if (typed_test_info == nullptr) { + typed_test_info = new ParameterizedTestSuiteInfo( + test_suite_name, code_location); + test_suite_infos_.push_back(typed_test_info); + } + return typed_test_info; + } + void RegisterTests() { + for (auto& test_suite_info : test_suite_infos_) { + test_suite_info->RegisterTests(); + } + } +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + template + ParameterizedTestCaseInfo* GetTestCasePatternHolder( + const char* test_case_name, CodeLocation code_location) { + return GetTestSuitePatternHolder(test_case_name, code_location); + } + +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + private: + using TestSuiteInfoContainer = ::std::vector; + + TestSuiteInfoContainer test_suite_infos_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestSuiteRegistry); +}; + +// Keep track of what type-parameterized test suite are defined and +// where as well as which are intatiated. This allows susequently +// identifying suits that are defined but never used. +class TypeParameterizedTestSuiteRegistry { + public: + // Add a suite definition + void RegisterTestSuite(const char* test_suite_name, + CodeLocation code_location); + + // Add an instantiation of a suit. + void RegisterInstantiation(const char* test_suite_name); + + // For each suit repored as defined but not reported as instantiation, + // emit a test that reports that fact (configurably, as an error). + void CheckForInstantiations(); + + private: + struct TypeParameterizedTestSuiteInfo { + explicit TypeParameterizedTestSuiteInfo(CodeLocation c) + : code_location(c), instantiated(false) {} + + CodeLocation code_location; + bool instantiated; + }; + + std::map suites_; +}; + +} // namespace internal + +// Forward declarations of ValuesIn(), which is implemented in +// include/gtest/gtest-param-test.h. +template +internal::ParamGenerator ValuesIn( + const Container& container); + +namespace internal { +// Used in the Values() function to provide polymorphic capabilities. + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +template +class ValueArray { + public: + explicit ValueArray(Ts... v) : v_(FlatTupleConstructTag{}, std::move(v)...) {} + + template + operator ParamGenerator() const { // NOLINT + return ValuesIn(MakeVector(MakeIndexSequence())); + } + + private: + template + std::vector MakeVector(IndexSequence) const { + return std::vector{static_cast(v_.template Get())...}; + } + + FlatTuple v_; +}; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +template +class CartesianProductGenerator + : public ParamGeneratorInterface<::std::tuple> { + public: + typedef ::std::tuple ParamType; + + CartesianProductGenerator(const std::tuple...>& g) + : generators_(g) {} + ~CartesianProductGenerator() override {} + + ParamIteratorInterface* Begin() const override { + return new Iterator(this, generators_, false); + } + ParamIteratorInterface* End() const override { + return new Iterator(this, generators_, true); + } + + private: + template + class IteratorImpl; + template + class IteratorImpl> + : public ParamIteratorInterface { + public: + IteratorImpl(const ParamGeneratorInterface* base, + const std::tuple...>& generators, bool is_end) + : base_(base), + begin_(std::get(generators).begin()...), + end_(std::get(generators).end()...), + current_(is_end ? end_ : begin_) { + ComputeCurrentValue(); + } + ~IteratorImpl() override {} + + const ParamGeneratorInterface* BaseGenerator() const override { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + void Advance() override { + assert(!AtEnd()); + // Advance the last iterator. + ++std::get(current_); + // if that reaches end, propagate that up. + AdvanceIfEnd(); + ComputeCurrentValue(); + } + ParamIteratorInterface* Clone() const override { + return new IteratorImpl(*this); + } + + const ParamType* Current() const override { return current_value_.get(); } + + bool Equals(const ParamIteratorInterface& other) const override { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const IteratorImpl* typed_other = + CheckedDowncastToActualType(&other); + + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + if (AtEnd() && typed_other->AtEnd()) return true; + + bool same = true; + bool dummy[] = { + (same = same && std::get(current_) == + std::get(typed_other->current_))...}; + (void)dummy; + return same; + } + + private: + template + void AdvanceIfEnd() { + if (std::get(current_) != std::get(end_)) return; + + bool last = ThisI == 0; + if (last) { + // We are done. Nothing else to propagate. + return; + } + + constexpr size_t NextI = ThisI - (ThisI != 0); + std::get(current_) = std::get(begin_); + ++std::get(current_); + AdvanceIfEnd(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = std::make_shared(*std::get(current_)...); + } + bool AtEnd() const { + bool at_end = false; + bool dummy[] = { + (at_end = at_end || std::get(current_) == std::get(end_))...}; + (void)dummy; + return at_end; + } + + const ParamGeneratorInterface* const base_; + std::tuple::iterator...> begin_; + std::tuple::iterator...> end_; + std::tuple::iterator...> current_; + std::shared_ptr current_value_; + }; + + using Iterator = IteratorImpl::type>; + + std::tuple...> generators_; +}; + +template +class CartesianProductHolder { + public: + CartesianProductHolder(const Gen&... g) : generators_(g...) {} + template + operator ParamGenerator<::std::tuple>() const { + return ParamGenerator<::std::tuple>( + new CartesianProductGenerator(generators_)); + } + + private: + std::tuple generators_; +}; + +} // namespace internal +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port-arch.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port-arch.h new file mode 100644 index 000000000000..22bbad97eb6b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port-arch.h @@ -0,0 +1,116 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file defines the GTEST_OS_* macro. +// It is separate from gtest-port.h so that custom/gtest-port.h can include it. + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_ + +// Determines the platform on which Google Test is compiled. +#ifdef __CYGWIN__ +# define GTEST_OS_CYGWIN 1 +# elif defined(__MINGW__) || defined(__MINGW32__) || defined(__MINGW64__) +# define GTEST_OS_WINDOWS_MINGW 1 +# define GTEST_OS_WINDOWS 1 +#elif defined _WIN32 +# define GTEST_OS_WINDOWS 1 +# ifdef _WIN32_WCE +# define GTEST_OS_WINDOWS_MOBILE 1 +# elif defined(WINAPI_FAMILY) +# include +# if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) +# define GTEST_OS_WINDOWS_DESKTOP 1 +# elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_PHONE_APP) +# define GTEST_OS_WINDOWS_PHONE 1 +# elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP) +# define GTEST_OS_WINDOWS_RT 1 +# elif WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_TV_TITLE) +# define GTEST_OS_WINDOWS_PHONE 1 +# define GTEST_OS_WINDOWS_TV_TITLE 1 +# else + // WINAPI_FAMILY defined but no known partition matched. + // Default to desktop. +# define GTEST_OS_WINDOWS_DESKTOP 1 +# endif +# else +# define GTEST_OS_WINDOWS_DESKTOP 1 +# endif // _WIN32_WCE +#elif defined __OS2__ +# define GTEST_OS_OS2 1 +#elif defined __APPLE__ +# define GTEST_OS_MAC 1 +# include +# if TARGET_OS_IPHONE +# define GTEST_OS_IOS 1 +# endif +#elif defined __DragonFly__ +# define GTEST_OS_DRAGONFLY 1 +#elif defined __FreeBSD__ +# define GTEST_OS_FREEBSD 1 +#elif defined __Fuchsia__ +# define GTEST_OS_FUCHSIA 1 +#elif defined(__GNU__) +# define GTEST_OS_GNU_HURD 1 +#elif defined(__GLIBC__) && defined(__FreeBSD_kernel__) +# define GTEST_OS_GNU_KFREEBSD 1 +#elif defined __linux__ +# define GTEST_OS_LINUX 1 +# if defined __ANDROID__ +# define GTEST_OS_LINUX_ANDROID 1 +# endif +#elif defined __MVS__ +# define GTEST_OS_ZOS 1 +#elif defined(__sun) && defined(__SVR4) +# define GTEST_OS_SOLARIS 1 +#elif defined(_AIX) +# define GTEST_OS_AIX 1 +#elif defined(__hpux) +# define GTEST_OS_HPUX 1 +#elif defined __native_client__ +# define GTEST_OS_NACL 1 +#elif defined __NetBSD__ +# define GTEST_OS_NETBSD 1 +#elif defined __OpenBSD__ +# define GTEST_OS_OPENBSD 1 +#elif defined __QNX__ +# define GTEST_OS_QNX 1 +#elif defined(__HAIKU__) +#define GTEST_OS_HAIKU 1 +#elif defined ESP8266 +#define GTEST_OS_ESP8266 1 +#elif defined ESP32 +#define GTEST_OS_ESP32 1 +#elif defined(__XTENSA__) +#define GTEST_OS_XTENSA 1 +#endif // __CYGWIN__ + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_ARCH_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port.h new file mode 100644 index 000000000000..929b7090cab3 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-port.h @@ -0,0 +1,2381 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Low-level types and utilities for porting Google Test to various +// platforms. All macros ending with _ and symbols defined in an +// internal namespace are subject to change without notice. Code +// outside Google Test MUST NOT USE THEM DIRECTLY. Macros that don't +// end with _ are part of Google Test's public API and can be used by +// code outside Google Test. +// +// This file is fundamental to Google Test. All other Google Test source +// files are expected to #include this. Therefore, it cannot #include +// any other Google Test header. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ + +// Environment-describing macros +// ----------------------------- +// +// Google Test can be used in many different environments. Macros in +// this section tell Google Test what kind of environment it is being +// used in, such that Google Test can provide environment-specific +// features and implementations. +// +// Google Test tries to automatically detect the properties of its +// environment, so users usually don't need to worry about these +// macros. However, the automatic detection is not perfect. +// Sometimes it's necessary for a user to define some of the following +// macros in the build script to override Google Test's decisions. +// +// If the user doesn't define a macro in the list, Google Test will +// provide a default definition. After this header is #included, all +// macros in this list will be defined to either 1 or 0. +// +// Notes to maintainers: +// - Each macro here is a user-tweakable knob; do not grow the list +// lightly. +// - Use #if to key off these macros. Don't use #ifdef or "#if +// defined(...)", which will not work as these macros are ALWAYS +// defined. +// +// GTEST_HAS_CLONE - Define it to 1/0 to indicate that clone(2) +// is/isn't available. +// GTEST_HAS_EXCEPTIONS - Define it to 1/0 to indicate that exceptions +// are enabled. +// GTEST_HAS_POSIX_RE - Define it to 1/0 to indicate that POSIX regular +// expressions are/aren't available. +// GTEST_HAS_PTHREAD - Define it to 1/0 to indicate that +// is/isn't available. +// GTEST_HAS_RTTI - Define it to 1/0 to indicate that RTTI is/isn't +// enabled. +// GTEST_HAS_STD_WSTRING - Define it to 1/0 to indicate that +// std::wstring does/doesn't work (Google Test can +// be used where std::wstring is unavailable). +// GTEST_HAS_SEH - Define it to 1/0 to indicate whether the +// compiler supports Microsoft's "Structured +// Exception Handling". +// GTEST_HAS_STREAM_REDIRECTION +// - Define it to 1/0 to indicate whether the +// platform supports I/O stream redirection using +// dup() and dup2(). +// GTEST_LINKED_AS_SHARED_LIBRARY +// - Define to 1 when compiling tests that use +// Google Test as a shared library (known as +// DLL on Windows). +// GTEST_CREATE_SHARED_LIBRARY +// - Define to 1 when compiling Google Test itself +// as a shared library. +// GTEST_DEFAULT_DEATH_TEST_STYLE +// - The default value of --gtest_death_test_style. +// The legacy default has been "fast" in the open +// source version since 2008. The recommended value +// is "threadsafe", and can be set in +// custom/gtest-port.h. + +// Platform-indicating macros +// -------------------------- +// +// Macros indicating the platform on which Google Test is being used +// (a macro is defined to 1 if compiled on the given platform; +// otherwise UNDEFINED -- it's never defined to 0.). Google Test +// defines these macros automatically. Code outside Google Test MUST +// NOT define them. +// +// GTEST_OS_AIX - IBM AIX +// GTEST_OS_CYGWIN - Cygwin +// GTEST_OS_DRAGONFLY - DragonFlyBSD +// GTEST_OS_FREEBSD - FreeBSD +// GTEST_OS_FUCHSIA - Fuchsia +// GTEST_OS_GNU_HURD - GNU/Hurd +// GTEST_OS_GNU_KFREEBSD - GNU/kFreeBSD +// GTEST_OS_HAIKU - Haiku +// GTEST_OS_HPUX - HP-UX +// GTEST_OS_LINUX - Linux +// GTEST_OS_LINUX_ANDROID - Google Android +// GTEST_OS_MAC - Mac OS X +// GTEST_OS_IOS - iOS +// GTEST_OS_NACL - Google Native Client (NaCl) +// GTEST_OS_NETBSD - NetBSD +// GTEST_OS_OPENBSD - OpenBSD +// GTEST_OS_OS2 - OS/2 +// GTEST_OS_QNX - QNX +// GTEST_OS_SOLARIS - Sun Solaris +// GTEST_OS_WINDOWS - Windows (Desktop, MinGW, or Mobile) +// GTEST_OS_WINDOWS_DESKTOP - Windows Desktop +// GTEST_OS_WINDOWS_MINGW - MinGW +// GTEST_OS_WINDOWS_MOBILE - Windows Mobile +// GTEST_OS_WINDOWS_PHONE - Windows Phone +// GTEST_OS_WINDOWS_RT - Windows Store App/WinRT +// GTEST_OS_ZOS - z/OS +// +// Among the platforms, Cygwin, Linux, Mac OS X, and Windows have the +// most stable support. Since core members of the Google Test project +// don't have access to other platforms, support for them may be less +// stable. If you notice any problems on your platform, please notify +// googletestframework@googlegroups.com (patches for fixing them are +// even more welcome!). +// +// It is possible that none of the GTEST_OS_* macros are defined. + +// Feature-indicating macros +// ------------------------- +// +// Macros indicating which Google Test features are available (a macro +// is defined to 1 if the corresponding feature is supported; +// otherwise UNDEFINED -- it's never defined to 0.). Google Test +// defines these macros automatically. Code outside Google Test MUST +// NOT define them. +// +// These macros are public so that portable tests can be written. +// Such tests typically surround code using a feature with an #if +// which controls that code. For example: +// +// #if GTEST_HAS_DEATH_TEST +// EXPECT_DEATH(DoSomethingDeadly()); +// #endif +// +// GTEST_HAS_DEATH_TEST - death tests +// GTEST_HAS_TYPED_TEST - typed tests +// GTEST_HAS_TYPED_TEST_P - type-parameterized tests +// GTEST_IS_THREADSAFE - Google Test is thread-safe. +// GTEST_USES_POSIX_RE - enhanced POSIX regex is used. Do not confuse with +// GTEST_HAS_POSIX_RE (see above) which users can +// define themselves. +// GTEST_USES_SIMPLE_RE - our own simple regex is used; +// the above RE\b(s) are mutually exclusive. + +// Misc public macros +// ------------------ +// +// GTEST_FLAG(flag_name) - references the variable corresponding to +// the given Google Test flag. + +// Internal utilities +// ------------------ +// +// The following macros and utilities are for Google Test's INTERNAL +// use only. Code outside Google Test MUST NOT USE THEM DIRECTLY. +// +// Macros for basic C++ coding: +// GTEST_AMBIGUOUS_ELSE_BLOCKER_ - for disabling a gcc warning. +// GTEST_ATTRIBUTE_UNUSED_ - declares that a class' instances or a +// variable don't have to be used. +// GTEST_DISALLOW_COPY_AND_ASSIGN_ - disables copy ctor and operator=. +// GTEST_DISALLOW_MOVE_AND_ASSIGN_ - disables move ctor and operator=. +// GTEST_MUST_USE_RESULT_ - declares that a function's result must be used. +// GTEST_INTENTIONAL_CONST_COND_PUSH_ - start code section where MSVC C4127 is +// suppressed (constant conditional). +// GTEST_INTENTIONAL_CONST_COND_POP_ - finish code section where MSVC C4127 +// is suppressed. +// GTEST_INTERNAL_HAS_ANY - for enabling UniversalPrinter or +// UniversalPrinter specializations. +// GTEST_INTERNAL_HAS_OPTIONAL - for enabling UniversalPrinter +// or +// UniversalPrinter +// specializations. +// GTEST_INTERNAL_HAS_STRING_VIEW - for enabling Matcher or +// Matcher +// specializations. +// GTEST_INTERNAL_HAS_VARIANT - for enabling UniversalPrinter or +// UniversalPrinter +// specializations. +// +// Synchronization: +// Mutex, MutexLock, ThreadLocal, GetThreadCount() +// - synchronization primitives. +// +// Regular expressions: +// RE - a simple regular expression class using the POSIX +// Extended Regular Expression syntax on UNIX-like platforms +// or a reduced regular exception syntax on other +// platforms, including Windows. +// Logging: +// GTEST_LOG_() - logs messages at the specified severity level. +// LogToStderr() - directs all log messages to stderr. +// FlushInfoLog() - flushes informational log messages. +// +// Stdout and stderr capturing: +// CaptureStdout() - starts capturing stdout. +// GetCapturedStdout() - stops capturing stdout and returns the captured +// string. +// CaptureStderr() - starts capturing stderr. +// GetCapturedStderr() - stops capturing stderr and returns the captured +// string. +// +// Integer types: +// TypeWithSize - maps an integer to a int type. +// TimeInMillis - integers of known sizes. +// BiggestInt - the biggest signed integer type. +// +// Command-line utilities: +// GTEST_DECLARE_*() - declares a flag. +// GTEST_DEFINE_*() - defines a flag. +// GetInjectableArgvs() - returns the command line as a vector of strings. +// +// Environment variable utilities: +// GetEnv() - gets the value of an environment variable. +// BoolFromGTestEnv() - parses a bool environment variable. +// Int32FromGTestEnv() - parses an int32_t environment variable. +// StringFromGTestEnv() - parses a string environment variable. +// +// Deprecation warnings: +// GTEST_INTERNAL_DEPRECATED(message) - attribute marking a function as +// deprecated; calling a marked function +// should generate a compiler warning + +#include // for isspace, etc +#include // for ptrdiff_t +#include +#include +#include + +#include +// #include // Guarded by GTEST_IS_THREADSAFE below +#include +#include +#include +#include +#include +#include +// #include // Guarded by GTEST_IS_THREADSAFE below +#include +#include +#include + +#ifndef _WIN32_WCE +# include +# include +#endif // !_WIN32_WCE + +#if defined __APPLE__ +# include +# include +#endif + +#include "gtest/internal/custom/gtest-port.h" +#include "gtest/internal/gtest-port-arch.h" + +#if !defined(GTEST_DEV_EMAIL_) +# define GTEST_DEV_EMAIL_ "googletestframework@@googlegroups.com" +# define GTEST_FLAG_PREFIX_ "gtest_" +# define GTEST_FLAG_PREFIX_DASH_ "gtest-" +# define GTEST_FLAG_PREFIX_UPPER_ "GTEST_" +# define GTEST_NAME_ "Google Test" +# define GTEST_PROJECT_URL_ "https://github.com/google/googletest/" +#endif // !defined(GTEST_DEV_EMAIL_) + +#if !defined(GTEST_INIT_GOOGLE_TEST_NAME_) +# define GTEST_INIT_GOOGLE_TEST_NAME_ "testing::InitGoogleTest" +#endif // !defined(GTEST_INIT_GOOGLE_TEST_NAME_) + +// Determines the version of gcc that is used to compile this. +#ifdef __GNUC__ +// 40302 means version 4.3.2. +# define GTEST_GCC_VER_ \ + (__GNUC__*10000 + __GNUC_MINOR__*100 + __GNUC_PATCHLEVEL__) +#endif // __GNUC__ + +// Macros for disabling Microsoft Visual C++ warnings. +// +// GTEST_DISABLE_MSC_WARNINGS_PUSH_(4800 4385) +// /* code that triggers warnings C4800 and C4385 */ +// GTEST_DISABLE_MSC_WARNINGS_POP_() +#if defined(_MSC_VER) +# define GTEST_DISABLE_MSC_WARNINGS_PUSH_(warnings) \ + __pragma(warning(push)) \ + __pragma(warning(disable: warnings)) +# define GTEST_DISABLE_MSC_WARNINGS_POP_() \ + __pragma(warning(pop)) +#else +// Not all compilers are MSVC +# define GTEST_DISABLE_MSC_WARNINGS_PUSH_(warnings) +# define GTEST_DISABLE_MSC_WARNINGS_POP_() +#endif + +// Clang on Windows does not understand MSVC's pragma warning. +// We need clang-specific way to disable function deprecation warning. +#ifdef __clang__ +# define GTEST_DISABLE_MSC_DEPRECATED_PUSH_() \ + _Pragma("clang diagnostic push") \ + _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") \ + _Pragma("clang diagnostic ignored \"-Wdeprecated-implementations\"") +#define GTEST_DISABLE_MSC_DEPRECATED_POP_() \ + _Pragma("clang diagnostic pop") +#else +# define GTEST_DISABLE_MSC_DEPRECATED_PUSH_() \ + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4996) +# define GTEST_DISABLE_MSC_DEPRECATED_POP_() \ + GTEST_DISABLE_MSC_WARNINGS_POP_() +#endif + +// Brings in definitions for functions used in the testing::internal::posix +// namespace (read, write, close, chdir, isatty, stat). We do not currently +// use them on Windows Mobile. +#if GTEST_OS_WINDOWS +# if !GTEST_OS_WINDOWS_MOBILE +# include +# include +# endif +// In order to avoid having to include , use forward declaration +#if GTEST_OS_WINDOWS_MINGW && !defined(__MINGW64_VERSION_MAJOR) +// MinGW defined _CRITICAL_SECTION and _RTL_CRITICAL_SECTION as two +// separate (equivalent) structs, instead of using typedef +typedef struct _CRITICAL_SECTION GTEST_CRITICAL_SECTION; +#else +// Assume CRITICAL_SECTION is a typedef of _RTL_CRITICAL_SECTION. +// This assumption is verified by +// WindowsTypesTest.CRITICAL_SECTIONIs_RTL_CRITICAL_SECTION. +typedef struct _RTL_CRITICAL_SECTION GTEST_CRITICAL_SECTION; +#endif +#elif GTEST_OS_XTENSA +#include +// Xtensa toolchains define strcasecmp in the string.h header instead of +// strings.h. string.h is already included. +#else +// This assumes that non-Windows OSes provide unistd.h. For OSes where this +// is not the case, we need to include headers that provide the functions +// mentioned above. +# include +# include +#endif // GTEST_OS_WINDOWS + +#if GTEST_OS_LINUX_ANDROID +// Used to define __ANDROID_API__ matching the target NDK API level. +# include // NOLINT +#endif + +// Defines this to true if and only if Google Test can use POSIX regular +// expressions. +#ifndef GTEST_HAS_POSIX_RE +# if GTEST_OS_LINUX_ANDROID +// On Android, is only available starting with Gingerbread. +# define GTEST_HAS_POSIX_RE (__ANDROID_API__ >= 9) +# else +#define GTEST_HAS_POSIX_RE (!GTEST_OS_WINDOWS && !GTEST_OS_XTENSA) +# endif +#endif + +#if GTEST_USES_PCRE +// The appropriate headers have already been included. + +#elif GTEST_HAS_POSIX_RE + +// On some platforms, needs someone to define size_t, and +// won't compile otherwise. We can #include it here as we already +// included , which is guaranteed to define size_t through +// . +# include // NOLINT + +# define GTEST_USES_POSIX_RE 1 + +#elif GTEST_OS_WINDOWS + +// is not available on Windows. Use our own simple regex +// implementation instead. +# define GTEST_USES_SIMPLE_RE 1 + +#else + +// may not be available on this platform. Use our own +// simple regex implementation instead. +# define GTEST_USES_SIMPLE_RE 1 + +#endif // GTEST_USES_PCRE + +#ifndef GTEST_HAS_EXCEPTIONS +// The user didn't tell us whether exceptions are enabled, so we need +// to figure it out. +# if defined(_MSC_VER) && defined(_CPPUNWIND) +// MSVC defines _CPPUNWIND to 1 if and only if exceptions are enabled. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__BORLANDC__) +// C++Builder's implementation of the STL uses the _HAS_EXCEPTIONS +// macro to enable exceptions, so we'll do the same. +// Assumes that exceptions are enabled by default. +# ifndef _HAS_EXCEPTIONS +# define _HAS_EXCEPTIONS 1 +# endif // _HAS_EXCEPTIONS +# define GTEST_HAS_EXCEPTIONS _HAS_EXCEPTIONS +# elif defined(__clang__) +// clang defines __EXCEPTIONS if and only if exceptions are enabled before clang +// 220714, but if and only if cleanups are enabled after that. In Obj-C++ files, +// there can be cleanups for ObjC exceptions which also need cleanups, even if +// C++ exceptions are disabled. clang has __has_feature(cxx_exceptions) which +// checks for C++ exceptions starting at clang r206352, but which checked for +// cleanups prior to that. To reliably check for C++ exception availability with +// clang, check for +// __EXCEPTIONS && __has_feature(cxx_exceptions). +# define GTEST_HAS_EXCEPTIONS (__EXCEPTIONS && __has_feature(cxx_exceptions)) +# elif defined(__GNUC__) && __EXCEPTIONS +// gcc defines __EXCEPTIONS to 1 if and only if exceptions are enabled. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__SUNPRO_CC) +// Sun Pro CC supports exceptions. However, there is no compile-time way of +// detecting whether they are enabled or not. Therefore, we assume that +// they are enabled unless the user tells us otherwise. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__IBMCPP__) && __EXCEPTIONS +// xlC defines __EXCEPTIONS to 1 if and only if exceptions are enabled. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__HP_aCC) +// Exception handling is in effect by default in HP aCC compiler. It has to +// be turned of by +noeh compiler option if desired. +# define GTEST_HAS_EXCEPTIONS 1 +# else +// For other compilers, we assume exceptions are disabled to be +// conservative. +# define GTEST_HAS_EXCEPTIONS 0 +# endif // defined(_MSC_VER) || defined(__BORLANDC__) +#endif // GTEST_HAS_EXCEPTIONS + +#ifndef GTEST_HAS_STD_WSTRING +// The user didn't tell us whether ::std::wstring is available, so we need +// to figure it out. +// Cygwin 1.7 and below doesn't support ::std::wstring. +// Solaris' libc++ doesn't support it either. Android has +// no support for it at least as recent as Froyo (2.2). +#define GTEST_HAS_STD_WSTRING \ + (!(GTEST_OS_LINUX_ANDROID || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \ + GTEST_OS_HAIKU || GTEST_OS_ESP32 || GTEST_OS_ESP8266 || GTEST_OS_XTENSA)) + +#endif // GTEST_HAS_STD_WSTRING + +// Determines whether RTTI is available. +#ifndef GTEST_HAS_RTTI +// The user didn't tell us whether RTTI is enabled, so we need to +// figure it out. + +# ifdef _MSC_VER + +#ifdef _CPPRTTI // MSVC defines this macro if and only if RTTI is enabled. +# define GTEST_HAS_RTTI 1 +# else +# define GTEST_HAS_RTTI 0 +# endif + +// Starting with version 4.3.2, gcc defines __GXX_RTTI if and only if RTTI is +// enabled. +# elif defined(__GNUC__) + +# ifdef __GXX_RTTI +// When building against STLport with the Android NDK and with +// -frtti -fno-exceptions, the build fails at link time with undefined +// references to __cxa_bad_typeid. Note sure if STL or toolchain bug, +// so disable RTTI when detected. +# if GTEST_OS_LINUX_ANDROID && defined(_STLPORT_MAJOR) && \ + !defined(__EXCEPTIONS) +# define GTEST_HAS_RTTI 0 +# else +# define GTEST_HAS_RTTI 1 +# endif // GTEST_OS_LINUX_ANDROID && __STLPORT_MAJOR && !__EXCEPTIONS +# else +# define GTEST_HAS_RTTI 0 +# endif // __GXX_RTTI + +// Clang defines __GXX_RTTI starting with version 3.0, but its manual recommends +// using has_feature instead. has_feature(cxx_rtti) is supported since 2.7, the +// first version with C++ support. +# elif defined(__clang__) + +# define GTEST_HAS_RTTI __has_feature(cxx_rtti) + +// Starting with version 9.0 IBM Visual Age defines __RTTI_ALL__ to 1 if +// both the typeid and dynamic_cast features are present. +# elif defined(__IBMCPP__) && (__IBMCPP__ >= 900) + +# ifdef __RTTI_ALL__ +# define GTEST_HAS_RTTI 1 +# else +# define GTEST_HAS_RTTI 0 +# endif + +# else + +// For all other compilers, we assume RTTI is enabled. +# define GTEST_HAS_RTTI 1 + +# endif // _MSC_VER + +#endif // GTEST_HAS_RTTI + +// It's this header's responsibility to #include when RTTI +// is enabled. +#if GTEST_HAS_RTTI +# include +#endif + +// Determines whether Google Test can use the pthreads library. +#ifndef GTEST_HAS_PTHREAD +// The user didn't tell us explicitly, so we make reasonable assumptions about +// which platforms have pthreads support. +// +// To disable threading support in Google Test, add -DGTEST_HAS_PTHREAD=0 +// to your compiler flags. +#define GTEST_HAS_PTHREAD \ + (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_HPUX || GTEST_OS_QNX || \ + GTEST_OS_FREEBSD || GTEST_OS_NACL || GTEST_OS_NETBSD || GTEST_OS_FUCHSIA || \ + GTEST_OS_DRAGONFLY || GTEST_OS_GNU_KFREEBSD || GTEST_OS_OPENBSD || \ + GTEST_OS_HAIKU || GTEST_OS_GNU_HURD) +#endif // GTEST_HAS_PTHREAD + +#if GTEST_HAS_PTHREAD +// gtest-port.h guarantees to #include when GTEST_HAS_PTHREAD is +// true. +# include // NOLINT + +// For timespec and nanosleep, used below. +# include // NOLINT +#endif + +// Determines whether clone(2) is supported. +// Usually it will only be available on Linux, excluding +// Linux on the Itanium architecture. +// Also see http://linux.die.net/man/2/clone. +#ifndef GTEST_HAS_CLONE +// The user didn't tell us, so we need to figure it out. + +# if GTEST_OS_LINUX && !defined(__ia64__) +# if GTEST_OS_LINUX_ANDROID +// On Android, clone() became available at different API levels for each 32-bit +// architecture. +# if defined(__LP64__) || \ + (defined(__arm__) && __ANDROID_API__ >= 9) || \ + (defined(__mips__) && __ANDROID_API__ >= 12) || \ + (defined(__i386__) && __ANDROID_API__ >= 17) +# define GTEST_HAS_CLONE 1 +# else +# define GTEST_HAS_CLONE 0 +# endif +# else +# define GTEST_HAS_CLONE 1 +# endif +# else +# define GTEST_HAS_CLONE 0 +# endif // GTEST_OS_LINUX && !defined(__ia64__) + +#endif // GTEST_HAS_CLONE + +// Determines whether to support stream redirection. This is used to test +// output correctness and to implement death tests. +#ifndef GTEST_HAS_STREAM_REDIRECTION +// By default, we assume that stream redirection is supported on all +// platforms except known mobile ones. +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || \ + GTEST_OS_WINDOWS_RT || GTEST_OS_ESP8266 || GTEST_OS_XTENSA +# define GTEST_HAS_STREAM_REDIRECTION 0 +# else +# define GTEST_HAS_STREAM_REDIRECTION 1 +# endif // !GTEST_OS_WINDOWS_MOBILE +#endif // GTEST_HAS_STREAM_REDIRECTION + +// Determines whether to support death tests. +// pops up a dialog window that cannot be suppressed programmatically. +#if (GTEST_OS_LINUX || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \ + (GTEST_OS_MAC && !GTEST_OS_IOS) || \ + (GTEST_OS_WINDOWS_DESKTOP && _MSC_VER) || GTEST_OS_WINDOWS_MINGW || \ + GTEST_OS_AIX || GTEST_OS_HPUX || GTEST_OS_OPENBSD || GTEST_OS_QNX || \ + GTEST_OS_FREEBSD || GTEST_OS_NETBSD || GTEST_OS_FUCHSIA || \ + GTEST_OS_DRAGONFLY || GTEST_OS_GNU_KFREEBSD || GTEST_OS_HAIKU || \ + GTEST_OS_GNU_HURD) +# define GTEST_HAS_DEATH_TEST 1 +#endif + +// Determines whether to support type-driven tests. + +// Typed tests need and variadic macros, which GCC, VC++ 8.0, +// Sun Pro CC, IBM Visual Age, and HP aCC support. +#if defined(__GNUC__) || defined(_MSC_VER) || defined(__SUNPRO_CC) || \ + defined(__IBMCPP__) || defined(__HP_aCC) +# define GTEST_HAS_TYPED_TEST 1 +# define GTEST_HAS_TYPED_TEST_P 1 +#endif + +// Determines whether the system compiler uses UTF-16 for encoding wide strings. +#define GTEST_WIDE_STRING_USES_UTF16_ \ + (GTEST_OS_WINDOWS || GTEST_OS_CYGWIN || GTEST_OS_AIX || GTEST_OS_OS2) + +// Determines whether test results can be streamed to a socket. +#if GTEST_OS_LINUX || GTEST_OS_GNU_KFREEBSD || GTEST_OS_DRAGONFLY || \ + GTEST_OS_FREEBSD || GTEST_OS_NETBSD || GTEST_OS_OPENBSD || \ + GTEST_OS_GNU_HURD +# define GTEST_CAN_STREAM_RESULTS_ 1 +#endif + +// Defines some utility macros. + +// The GNU compiler emits a warning if nested "if" statements are followed by +// an "else" statement and braces are not used to explicitly disambiguate the +// "else" binding. This leads to problems with code like: +// +// if (gate) +// ASSERT_*(condition) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#ifdef __INTEL_COMPILER +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ +#else +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT +#endif + +// Use this annotation at the end of a struct/class definition to +// prevent the compiler from optimizing away instances that are never +// used. This is useful when all interesting logic happens inside the +// c'tor and / or d'tor. Example: +// +// struct Foo { +// Foo() { ... } +// } GTEST_ATTRIBUTE_UNUSED_; +// +// Also use it after a variable or parameter declaration to tell the +// compiler the variable/parameter does not have to be used. +#if defined(__GNUC__) && !defined(COMPILER_ICC) +# define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused)) +#elif defined(__clang__) +# if __has_attribute(unused) +# define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused)) +# endif +#endif +#ifndef GTEST_ATTRIBUTE_UNUSED_ +# define GTEST_ATTRIBUTE_UNUSED_ +#endif + +// Use this annotation before a function that takes a printf format string. +#if (defined(__GNUC__) || defined(__clang__)) && !defined(COMPILER_ICC) +# if defined(__MINGW_PRINTF_FORMAT) +// MinGW has two different printf implementations. Ensure the format macro +// matches the selected implementation. See +// https://sourceforge.net/p/mingw-w64/wiki2/gnu%20printf/. +# define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check) \ + __attribute__((__format__(__MINGW_PRINTF_FORMAT, string_index, \ + first_to_check))) +# else +# define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check) \ + __attribute__((__format__(__printf__, string_index, first_to_check))) +# endif +#else +# define GTEST_ATTRIBUTE_PRINTF_(string_index, first_to_check) +#endif + +// A macro to disallow copy constructor and operator= +// This should be used in the private: declarations for a class. +// NOLINT is for modernize-use-trailing-return-type in macro uses. +#define GTEST_DISALLOW_COPY_AND_ASSIGN_(type) \ + type(type const&) = delete; \ + type& operator=(type const&) = delete /* NOLINT */ + +// A macro to disallow move constructor and operator= +// This should be used in the private: declarations for a class. +// NOLINT is for modernize-use-trailing-return-type in macro uses. +#define GTEST_DISALLOW_MOVE_AND_ASSIGN_(type) \ + type(type&&) noexcept = delete; \ + type& operator=(type&&) noexcept = delete /* NOLINT */ + +// Tell the compiler to warn about unused return values for functions declared +// with this macro. The macro should be used on function declarations +// following the argument list: +// +// Sprocket* AllocateSprocket() GTEST_MUST_USE_RESULT_; +#if defined(__GNUC__) && !defined(COMPILER_ICC) +# define GTEST_MUST_USE_RESULT_ __attribute__ ((warn_unused_result)) +#else +# define GTEST_MUST_USE_RESULT_ +#endif // __GNUC__ && !COMPILER_ICC + +// MS C++ compiler emits warning when a conditional expression is compile time +// constant. In some contexts this warning is false positive and needs to be +// suppressed. Use the following two macros in such cases: +// +// GTEST_INTENTIONAL_CONST_COND_PUSH_() +// while (true) { +// GTEST_INTENTIONAL_CONST_COND_POP_() +// } +# define GTEST_INTENTIONAL_CONST_COND_PUSH_() \ + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4127) +# define GTEST_INTENTIONAL_CONST_COND_POP_() \ + GTEST_DISABLE_MSC_WARNINGS_POP_() + +// Determine whether the compiler supports Microsoft's Structured Exception +// Handling. This is supported by several Windows compilers but generally +// does not exist on any other system. +#ifndef GTEST_HAS_SEH +// The user didn't tell us, so we need to figure it out. + +# if defined(_MSC_VER) || defined(__BORLANDC__) +// These two compilers are known to support SEH. +# define GTEST_HAS_SEH 1 +# else +// Assume no SEH. +# define GTEST_HAS_SEH 0 +# endif + +#endif // GTEST_HAS_SEH + +#ifndef GTEST_IS_THREADSAFE + +#define GTEST_IS_THREADSAFE \ + (GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ || \ + (GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT) || \ + GTEST_HAS_PTHREAD) + +#endif // GTEST_IS_THREADSAFE + +#if GTEST_IS_THREADSAFE +// Some platforms don't support including these threading related headers. +#include // NOLINT +#include // NOLINT +#endif // GTEST_IS_THREADSAFE + +// GTEST_API_ qualifies all symbols that must be exported. The definitions below +// are guarded by #ifndef to give embedders a chance to define GTEST_API_ in +// gtest/internal/custom/gtest-port.h +#ifndef GTEST_API_ + +#ifdef _MSC_VER +# if GTEST_LINKED_AS_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllimport) +# elif GTEST_CREATE_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllexport) +# endif +#elif __GNUC__ >= 4 || defined(__clang__) +# define GTEST_API_ __attribute__((visibility ("default"))) +#endif // _MSC_VER + +#endif // GTEST_API_ + +#ifndef GTEST_API_ +# define GTEST_API_ +#endif // GTEST_API_ + +#ifndef GTEST_DEFAULT_DEATH_TEST_STYLE +# define GTEST_DEFAULT_DEATH_TEST_STYLE "fast" +#endif // GTEST_DEFAULT_DEATH_TEST_STYLE + +#ifdef __GNUC__ +// Ask the compiler to never inline a given function. +# define GTEST_NO_INLINE_ __attribute__((noinline)) +#else +# define GTEST_NO_INLINE_ +#endif + +#if defined(__clang__) +// Nested ifs to avoid triggering MSVC warning. +#if __has_attribute(disable_tail_calls) +// Ask the compiler not to perform tail call optimization inside +// the marked function. +#define GTEST_NO_TAIL_CALL_ __attribute__((disable_tail_calls)) +#endif +#elif __GNUC__ +#define GTEST_NO_TAIL_CALL_ \ + __attribute__((optimize("no-optimize-sibling-calls"))) +#else +#define GTEST_NO_TAIL_CALL_ +#endif + +// _LIBCPP_VERSION is defined by the libc++ library from the LLVM project. +#if !defined(GTEST_HAS_CXXABI_H_) +# if defined(__GLIBCXX__) || (defined(_LIBCPP_VERSION) && !defined(_MSC_VER)) +# define GTEST_HAS_CXXABI_H_ 1 +# else +# define GTEST_HAS_CXXABI_H_ 0 +# endif +#endif + +// A function level attribute to disable checking for use of uninitialized +// memory when built with MemorySanitizer. +#if defined(__clang__) +# if __has_feature(memory_sanitizer) +# define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ \ + __attribute__((no_sanitize_memory)) +# else +# define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ +# endif // __has_feature(memory_sanitizer) +#else +# define GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ +#endif // __clang__ + +// A function level attribute to disable AddressSanitizer instrumentation. +#if defined(__clang__) +# if __has_feature(address_sanitizer) +# define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ \ + __attribute__((no_sanitize_address)) +# else +# define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +# endif // __has_feature(address_sanitizer) +#else +# define GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +#endif // __clang__ + +// A function level attribute to disable HWAddressSanitizer instrumentation. +#if defined(__clang__) +# if __has_feature(hwaddress_sanitizer) +# define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ \ + __attribute__((no_sanitize("hwaddress"))) +# else +# define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +# endif // __has_feature(hwaddress_sanitizer) +#else +# define GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +#endif // __clang__ + +// A function level attribute to disable ThreadSanitizer instrumentation. +#if defined(__clang__) +# if __has_feature(thread_sanitizer) +# define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ \ + __attribute__((no_sanitize_thread)) +# else +# define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ +# endif // __has_feature(thread_sanitizer) +#else +# define GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ +#endif // __clang__ + +namespace testing { + +class Message; + +// Legacy imports for backwards compatibility. +// New code should use std:: names directly. +using std::get; +using std::make_tuple; +using std::tuple; +using std::tuple_element; +using std::tuple_size; + +namespace internal { + +// A secret type that Google Test users don't know about. It has no +// definition on purpose. Therefore it's impossible to create a +// Secret object, which is what we want. +class Secret; + +// The GTEST_COMPILE_ASSERT_ is a legacy macro used to verify that a compile +// time expression is true (in new code, use static_assert instead). For +// example, you could use it to verify the size of a static array: +// +// GTEST_COMPILE_ASSERT_(GTEST_ARRAY_SIZE_(names) == NUM_NAMES, +// names_incorrect_size); +// +// The second argument to the macro must be a valid C++ identifier. If the +// expression is false, compiler will issue an error containing this identifier. +#define GTEST_COMPILE_ASSERT_(expr, msg) static_assert(expr, #msg) + +// A helper for suppressing warnings on constant condition. It just +// returns 'condition'. +GTEST_API_ bool IsTrue(bool condition); + +// Defines RE. + +#if GTEST_USES_PCRE +// if used, PCRE is injected by custom/gtest-port.h +#elif GTEST_USES_POSIX_RE || GTEST_USES_SIMPLE_RE + +// A simple C++ wrapper for . It uses the POSIX Extended +// Regular Expression syntax. +class GTEST_API_ RE { + public: + // A copy constructor is required by the Standard to initialize object + // references from r-values. + RE(const RE& other) { Init(other.pattern()); } + + // Constructs an RE from a string. + RE(const ::std::string& regex) { Init(regex.c_str()); } // NOLINT + + RE(const char* regex) { Init(regex); } // NOLINT + ~RE(); + + // Returns the string representation of the regex. + const char* pattern() const { return pattern_; } + + // FullMatch(str, re) returns true if and only if regular expression re + // matches the entire str. + // PartialMatch(str, re) returns true if and only if regular expression re + // matches a substring of str (including str itself). + static bool FullMatch(const ::std::string& str, const RE& re) { + return FullMatch(str.c_str(), re); + } + static bool PartialMatch(const ::std::string& str, const RE& re) { + return PartialMatch(str.c_str(), re); + } + + static bool FullMatch(const char* str, const RE& re); + static bool PartialMatch(const char* str, const RE& re); + + private: + void Init(const char* regex); + const char* pattern_; + bool is_valid_; + +# if GTEST_USES_POSIX_RE + + regex_t full_regex_; // For FullMatch(). + regex_t partial_regex_; // For PartialMatch(). + +# else // GTEST_USES_SIMPLE_RE + + const char* full_pattern_; // For FullMatch(); + +# endif +}; + +#endif // GTEST_USES_PCRE + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line); + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation(const char* file, + int line); + +// Defines logging utilities: +// GTEST_LOG_(severity) - logs messages at the specified severity level. The +// message itself is streamed into the macro. +// LogToStderr() - directs all log messages to stderr. +// FlushInfoLog() - flushes informational log messages. + +enum GTestLogSeverity { + GTEST_INFO, + GTEST_WARNING, + GTEST_ERROR, + GTEST_FATAL +}; + +// Formats log entry severity, provides a stream object for streaming the +// log message, and terminates the message with a newline when going out of +// scope. +class GTEST_API_ GTestLog { + public: + GTestLog(GTestLogSeverity severity, const char* file, int line); + + // Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. + ~GTestLog(); + + ::std::ostream& GetStream() { return ::std::cerr; } + + private: + const GTestLogSeverity severity_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestLog); +}; + +#if !defined(GTEST_LOG_) + +# define GTEST_LOG_(severity) \ + ::testing::internal::GTestLog(::testing::internal::GTEST_##severity, \ + __FILE__, __LINE__).GetStream() + +inline void LogToStderr() {} +inline void FlushInfoLog() { fflush(nullptr); } + +#endif // !defined(GTEST_LOG_) + +#if !defined(GTEST_CHECK_) +// INTERNAL IMPLEMENTATION - DO NOT USE. +// +// GTEST_CHECK_ is an all-mode assert. It aborts the program if the condition +// is not satisfied. +// Synopsis: +// GTEST_CHECK_(boolean_condition); +// or +// GTEST_CHECK_(boolean_condition) << "Additional message"; +// +// This checks the condition and if the condition is not satisfied +// it prints message about the condition violation, including the +// condition itself, plus additional message streamed into it, if any, +// and then it aborts the program. It aborts the program irrespective of +// whether it is built in the debug mode or not. +# define GTEST_CHECK_(condition) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::IsTrue(condition)) \ + ; \ + else \ + GTEST_LOG_(FATAL) << "Condition " #condition " failed. " +#endif // !defined(GTEST_CHECK_) + +// An all-mode assert to verify that the given POSIX-style function +// call returns 0 (indicating success). Known limitation: this +// doesn't expand to a balanced 'if' statement, so enclose the macro +// in {} if you need to use it as the only statement in an 'if' +// branch. +#define GTEST_CHECK_POSIX_SUCCESS_(posix_call) \ + if (const int gtest_error = (posix_call)) \ + GTEST_LOG_(FATAL) << #posix_call << "failed with error " \ + << gtest_error + +// Transforms "T" into "const T&" according to standard reference collapsing +// rules (this is only needed as a backport for C++98 compilers that do not +// support reference collapsing). Specifically, it transforms: +// +// char ==> const char& +// const char ==> const char& +// char& ==> char& +// const char& ==> const char& +// +// Note that the non-const reference will not have "const" added. This is +// standard, and necessary so that "T" can always bind to "const T&". +template +struct ConstRef { typedef const T& type; }; +template +struct ConstRef { typedef T& type; }; + +// The argument T must depend on some template parameters. +#define GTEST_REFERENCE_TO_CONST_(T) \ + typename ::testing::internal::ConstRef::type + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Use ImplicitCast_ as a safe version of static_cast for upcasting in +// the type hierarchy (e.g. casting a Foo* to a SuperclassOfFoo* or a +// const Foo*). When you use ImplicitCast_, the compiler checks that +// the cast is safe. Such explicit ImplicitCast_s are necessary in +// surprisingly many situations where C++ demands an exact type match +// instead of an argument type convertible to a target type. +// +// The syntax for using ImplicitCast_ is the same as for static_cast: +// +// ImplicitCast_(expr) +// +// ImplicitCast_ would have been part of the C++ standard library, +// but the proposal was submitted too late. It will probably make +// its way into the language in the future. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., implicit_cast). The internal +// namespace alone is not enough because the function can be found by ADL. +template +inline To ImplicitCast_(To x) { return x; } + +// When you upcast (that is, cast a pointer from type Foo to type +// SuperclassOfFoo), it's fine to use ImplicitCast_<>, since upcasts +// always succeed. When you downcast (that is, cast a pointer from +// type Foo to type SubclassOfFoo), static_cast<> isn't safe, because +// how do you know the pointer is really of type SubclassOfFoo? It +// could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, +// when you downcast, you should use this macro. In debug mode, we +// use dynamic_cast<> to double-check the downcast is legal (we die +// if it's not). In normal mode, we do the efficient static_cast<> +// instead. Thus, it's important to test in debug mode to make sure +// the cast is legal! +// This is the only place in the code we should use dynamic_cast<>. +// In particular, you SHOULDN'T be using dynamic_cast<> in order to +// do RTTI (eg code like this: +// if (dynamic_cast(foo)) HandleASubclass1Object(foo); +// if (dynamic_cast(foo)) HandleASubclass2Object(foo); +// You should design the code some other way not to need this. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., down_cast). The internal +// namespace alone is not enough because the function can be found by ADL. +template // use like this: DownCast_(foo); +inline To DownCast_(From* f) { // so we only accept pointers + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + GTEST_INTENTIONAL_CONST_COND_PUSH_() + if (false) { + GTEST_INTENTIONAL_CONST_COND_POP_() + const To to = nullptr; + ::testing::internal::ImplicitCast_(to); + } + +#if GTEST_HAS_RTTI + // RTTI: debug mode only! + GTEST_CHECK_(f == nullptr || dynamic_cast(f) != nullptr); +#endif + return static_cast(f); +} + +// Downcasts the pointer of type Base to Derived. +// Derived must be a subclass of Base. The parameter MUST +// point to a class of type Derived, not any subclass of it. +// When RTTI is available, the function performs a runtime +// check to enforce this. +template +Derived* CheckedDowncastToActualType(Base* base) { +#if GTEST_HAS_RTTI + GTEST_CHECK_(typeid(*base) == typeid(Derived)); +#endif + +#if GTEST_HAS_DOWNCAST_ + return ::down_cast(base); +#elif GTEST_HAS_RTTI + return dynamic_cast(base); // NOLINT +#else + return static_cast(base); // Poor man's downcast. +#endif +} + +#if GTEST_HAS_STREAM_REDIRECTION + +// Defines the stderr capturer: +// CaptureStdout - starts capturing stdout. +// GetCapturedStdout - stops capturing stdout and returns the captured string. +// CaptureStderr - starts capturing stderr. +// GetCapturedStderr - stops capturing stderr and returns the captured string. +// +GTEST_API_ void CaptureStdout(); +GTEST_API_ std::string GetCapturedStdout(); +GTEST_API_ void CaptureStderr(); +GTEST_API_ std::string GetCapturedStderr(); + +#endif // GTEST_HAS_STREAM_REDIRECTION +// Returns the size (in bytes) of a file. +GTEST_API_ size_t GetFileSize(FILE* file); + +// Reads the entire content of a file as a string. +GTEST_API_ std::string ReadEntireFile(FILE* file); + +// All command line arguments. +GTEST_API_ std::vector GetArgvs(); + +#if GTEST_HAS_DEATH_TEST + +std::vector GetInjectableArgvs(); +// Deprecated: pass the args vector by value instead. +void SetInjectableArgvs(const std::vector* new_argvs); +void SetInjectableArgvs(const std::vector& new_argvs); +void ClearInjectableArgvs(); + +#endif // GTEST_HAS_DEATH_TEST + +// Defines synchronization primitives. +#if GTEST_IS_THREADSAFE + +# if GTEST_OS_WINDOWS +// Provides leak-safe Windows kernel handle ownership. +// Used in death tests and in threading support. +class GTEST_API_ AutoHandle { + public: + // Assume that Win32 HANDLE type is equivalent to void*. Doing so allows us to + // avoid including in this header file. Including is + // undesirable because it defines a lot of symbols and macros that tend to + // conflict with client code. This assumption is verified by + // WindowsTypesTest.HANDLEIsVoidStar. + typedef void* Handle; + AutoHandle(); + explicit AutoHandle(Handle handle); + + ~AutoHandle(); + + Handle Get() const; + void Reset(); + void Reset(Handle handle); + + private: + // Returns true if and only if the handle is a valid handle object that can be + // closed. + bool IsCloseable() const; + + Handle handle_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AutoHandle); +}; +# endif + +# if GTEST_HAS_NOTIFICATION_ +// Notification has already been imported into the namespace. +// Nothing to do here. + +# else +// Allows a controller thread to pause execution of newly created +// threads until notified. Instances of this class must be created +// and destroyed in the controller thread. +// +// This class is only for testing Google Test's own constructs. Do not +// use it in user tests, either directly or indirectly. +// TODO(b/203539622): Replace unconditionally with absl::Notification. +class GTEST_API_ Notification { + public: + Notification() : notified_(false) {} + Notification(const Notification&) = delete; + Notification& operator=(const Notification&) = delete; + + // Notifies all threads created with this notification to start. Must + // be called from the controller thread. + void Notify() { + std::lock_guard lock(mu_); + notified_ = true; + cv_.notify_all(); + } + + // Blocks until the controller thread notifies. Must be called from a test + // thread. + void WaitForNotification() { + std::unique_lock lock(mu_); + cv_.wait(lock, [this]() { return notified_; }); + } + + private: + std::mutex mu_; + std::condition_variable cv_; + bool notified_; +}; +# endif // GTEST_HAS_NOTIFICATION_ + +// On MinGW, we can have both GTEST_OS_WINDOWS and GTEST_HAS_PTHREAD +// defined, but we don't want to use MinGW's pthreads implementation, which +// has conformance problems with some versions of the POSIX standard. +# if GTEST_HAS_PTHREAD && !GTEST_OS_WINDOWS_MINGW + +// As a C-function, ThreadFuncWithCLinkage cannot be templated itself. +// Consequently, it cannot select a correct instantiation of ThreadWithParam +// in order to call its Run(). Introducing ThreadWithParamBase as a +// non-templated base class for ThreadWithParam allows us to bypass this +// problem. +class ThreadWithParamBase { + public: + virtual ~ThreadWithParamBase() {} + virtual void Run() = 0; +}; + +// pthread_create() accepts a pointer to a function type with the C linkage. +// According to the Standard (7.5/1), function types with different linkages +// are different even if they are otherwise identical. Some compilers (for +// example, SunStudio) treat them as different types. Since class methods +// cannot be defined with C-linkage we need to define a free C-function to +// pass into pthread_create(). +extern "C" inline void* ThreadFuncWithCLinkage(void* thread) { + static_cast(thread)->Run(); + return nullptr; +} + +// Helper class for testing Google Test's multi-threading constructs. +// To use it, write: +// +// void ThreadFunc(int param) { /* Do things with param */ } +// Notification thread_can_start; +// ... +// // The thread_can_start parameter is optional; you can supply NULL. +// ThreadWithParam thread(&ThreadFunc, 5, &thread_can_start); +// thread_can_start.Notify(); +// +// These classes are only for testing Google Test's own constructs. Do +// not use them in user tests, either directly or indirectly. +template +class ThreadWithParam : public ThreadWithParamBase { + public: + typedef void UserThreadFunc(T); + + ThreadWithParam(UserThreadFunc* func, T param, Notification* thread_can_start) + : func_(func), + param_(param), + thread_can_start_(thread_can_start), + finished_(false) { + ThreadWithParamBase* const base = this; + // The thread can be created only after all fields except thread_ + // have been initialized. + GTEST_CHECK_POSIX_SUCCESS_( + pthread_create(&thread_, nullptr, &ThreadFuncWithCLinkage, base)); + } + ~ThreadWithParam() override { Join(); } + + void Join() { + if (!finished_) { + GTEST_CHECK_POSIX_SUCCESS_(pthread_join(thread_, nullptr)); + finished_ = true; + } + } + + void Run() override { + if (thread_can_start_ != nullptr) thread_can_start_->WaitForNotification(); + func_(param_); + } + + private: + UserThreadFunc* const func_; // User-supplied thread function. + const T param_; // User-supplied parameter to the thread function. + // When non-NULL, used to block execution until the controller thread + // notifies. + Notification* const thread_can_start_; + bool finished_; // true if and only if we know that the thread function has + // finished. + pthread_t thread_; // The native thread object. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); +}; +# endif // !GTEST_OS_WINDOWS && GTEST_HAS_PTHREAD || + // GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ + +# if GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ +// Mutex and ThreadLocal have already been imported into the namespace. +// Nothing to do here. + +# elif GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT + +// Mutex implements mutex on Windows platforms. It is used in conjunction +// with class MutexLock: +// +// Mutex mutex; +// ... +// MutexLock lock(&mutex); // Acquires the mutex and releases it at the +// // end of the current scope. +// +// A static Mutex *must* be defined or declared using one of the following +// macros: +// GTEST_DEFINE_STATIC_MUTEX_(g_some_mutex); +// GTEST_DECLARE_STATIC_MUTEX_(g_some_mutex); +// +// (A non-static Mutex is defined/declared in the usual way). +class GTEST_API_ Mutex { + public: + enum MutexType { kStatic = 0, kDynamic = 1 }; + // We rely on kStaticMutex being 0 as it is to what the linker initializes + // type_ in static mutexes. critical_section_ will be initialized lazily + // in ThreadSafeLazyInit(). + enum StaticConstructorSelector { kStaticMutex = 0 }; + + // This constructor intentionally does nothing. It relies on type_ being + // statically initialized to 0 (effectively setting it to kStatic) and on + // ThreadSafeLazyInit() to lazily initialize the rest of the members. + explicit Mutex(StaticConstructorSelector /*dummy*/) {} + + Mutex(); + ~Mutex(); + + void Lock(); + + void Unlock(); + + // Does nothing if the current thread holds the mutex. Otherwise, crashes + // with high probability. + void AssertHeld(); + + private: + // Initializes owner_thread_id_ and critical_section_ in static mutexes. + void ThreadSafeLazyInit(); + + // Per https://blogs.msdn.microsoft.com/oldnewthing/20040223-00/?p=40503, + // we assume that 0 is an invalid value for thread IDs. + unsigned int owner_thread_id_; + + // For static mutexes, we rely on these members being initialized to zeros + // by the linker. + MutexType type_; + long critical_section_init_phase_; // NOLINT + GTEST_CRITICAL_SECTION* critical_section_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); +}; + +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::Mutex mutex + +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ + ::testing::internal::Mutex mutex(::testing::internal::Mutex::kStaticMutex) + +// We cannot name this class MutexLock because the ctor declaration would +// conflict with a macro named MutexLock, which is defined on some +// platforms. That macro is used as a defensive measure to prevent against +// inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than +// "MutexLock l(&mu)". Hence the typedef trick below. +class GTestMutexLock { + public: + explicit GTestMutexLock(Mutex* mutex) + : mutex_(mutex) { mutex_->Lock(); } + + ~GTestMutexLock() { mutex_->Unlock(); } + + private: + Mutex* const mutex_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); +}; + +typedef GTestMutexLock MutexLock; + +// Base class for ValueHolder. Allows a caller to hold and delete a value +// without knowing its type. +class ThreadLocalValueHolderBase { + public: + virtual ~ThreadLocalValueHolderBase() {} +}; + +// Provides a way for a thread to send notifications to a ThreadLocal +// regardless of its parameter type. +class ThreadLocalBase { + public: + // Creates a new ValueHolder object holding a default value passed to + // this ThreadLocal's constructor and returns it. It is the caller's + // responsibility not to call this when the ThreadLocal instance already + // has a value on the current thread. + virtual ThreadLocalValueHolderBase* NewValueForCurrentThread() const = 0; + + protected: + ThreadLocalBase() {} + virtual ~ThreadLocalBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocalBase); +}; + +// Maps a thread to a set of ThreadLocals that have values instantiated on that +// thread and notifies them when the thread exits. A ThreadLocal instance is +// expected to persist until all threads it has values on have terminated. +class GTEST_API_ ThreadLocalRegistry { + public: + // Registers thread_local_instance as having value on the current thread. + // Returns a value that can be used to identify the thread from other threads. + static ThreadLocalValueHolderBase* GetValueOnCurrentThread( + const ThreadLocalBase* thread_local_instance); + + // Invoked when a ThreadLocal instance is destroyed. + static void OnThreadLocalDestroyed( + const ThreadLocalBase* thread_local_instance); +}; + +class GTEST_API_ ThreadWithParamBase { + public: + void Join(); + + protected: + class Runnable { + public: + virtual ~Runnable() {} + virtual void Run() = 0; + }; + + ThreadWithParamBase(Runnable *runnable, Notification* thread_can_start); + virtual ~ThreadWithParamBase(); + + private: + AutoHandle thread_; +}; + +// Helper class for testing Google Test's multi-threading constructs. +template +class ThreadWithParam : public ThreadWithParamBase { + public: + typedef void UserThreadFunc(T); + + ThreadWithParam(UserThreadFunc* func, T param, Notification* thread_can_start) + : ThreadWithParamBase(new RunnableImpl(func, param), thread_can_start) { + } + virtual ~ThreadWithParam() {} + + private: + class RunnableImpl : public Runnable { + public: + RunnableImpl(UserThreadFunc* func, T param) + : func_(func), + param_(param) { + } + virtual ~RunnableImpl() {} + virtual void Run() { + func_(param_); + } + + private: + UserThreadFunc* const func_; + const T param_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(RunnableImpl); + }; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); +}; + +// Implements thread-local storage on Windows systems. +// +// // Thread 1 +// ThreadLocal tl(100); // 100 is the default value for each thread. +// +// // Thread 2 +// tl.set(150); // Changes the value for thread 2 only. +// EXPECT_EQ(150, tl.get()); +// +// // Thread 1 +// EXPECT_EQ(100, tl.get()); // In thread 1, tl has the original value. +// tl.set(200); +// EXPECT_EQ(200, tl.get()); +// +// The template type argument T must have a public copy constructor. +// In addition, the default ThreadLocal constructor requires T to have +// a public default constructor. +// +// The users of a TheadLocal instance have to make sure that all but one +// threads (including the main one) using that instance have exited before +// destroying it. Otherwise, the per-thread objects managed for them by the +// ThreadLocal instance are not guaranteed to be destroyed on all platforms. +// +// Google Test only uses global ThreadLocal objects. That means they +// will die after main() has returned. Therefore, no per-thread +// object managed by Google Test will be leaked as long as all threads +// using Google Test have exited when main() returns. +template +class ThreadLocal : public ThreadLocalBase { + public: + ThreadLocal() : default_factory_(new DefaultValueHolderFactory()) {} + explicit ThreadLocal(const T& value) + : default_factory_(new InstanceValueHolderFactory(value)) {} + + ~ThreadLocal() { ThreadLocalRegistry::OnThreadLocalDestroyed(this); } + + T* pointer() { return GetOrCreateValue(); } + const T* pointer() const { return GetOrCreateValue(); } + const T& get() const { return *pointer(); } + void set(const T& value) { *pointer() = value; } + + private: + // Holds a value of T. Can be deleted via its base class without the caller + // knowing the type of T. + class ValueHolder : public ThreadLocalValueHolderBase { + public: + ValueHolder() : value_() {} + explicit ValueHolder(const T& value) : value_(value) {} + + T* pointer() { return &value_; } + + private: + T value_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); + }; + + + T* GetOrCreateValue() const { + return static_cast( + ThreadLocalRegistry::GetValueOnCurrentThread(this))->pointer(); + } + + virtual ThreadLocalValueHolderBase* NewValueForCurrentThread() const { + return default_factory_->MakeNewHolder(); + } + + class ValueHolderFactory { + public: + ValueHolderFactory() {} + virtual ~ValueHolderFactory() {} + virtual ValueHolder* MakeNewHolder() const = 0; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolderFactory); + }; + + class DefaultValueHolderFactory : public ValueHolderFactory { + public: + DefaultValueHolderFactory() {} + ValueHolder* MakeNewHolder() const override { return new ValueHolder(); } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultValueHolderFactory); + }; + + class InstanceValueHolderFactory : public ValueHolderFactory { + public: + explicit InstanceValueHolderFactory(const T& value) : value_(value) {} + ValueHolder* MakeNewHolder() const override { + return new ValueHolder(value_); + } + + private: + const T value_; // The value for each thread. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(InstanceValueHolderFactory); + }; + + std::unique_ptr default_factory_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); +}; + +# elif GTEST_HAS_PTHREAD + +// MutexBase and Mutex implement mutex on pthreads-based platforms. +class MutexBase { + public: + // Acquires this mutex. + void Lock() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_lock(&mutex_)); + owner_ = pthread_self(); + has_owner_ = true; + } + + // Releases this mutex. + void Unlock() { + // Since the lock is being released the owner_ field should no longer be + // considered valid. We don't protect writing to has_owner_ here, as it's + // the caller's responsibility to ensure that the current thread holds the + // mutex when this is called. + has_owner_ = false; + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_unlock(&mutex_)); + } + + // Does nothing if the current thread holds the mutex. Otherwise, crashes + // with high probability. + void AssertHeld() const { + GTEST_CHECK_(has_owner_ && pthread_equal(owner_, pthread_self())) + << "The current thread is not holding the mutex @" << this; + } + + // A static mutex may be used before main() is entered. It may even + // be used before the dynamic initialization stage. Therefore we + // must be able to initialize a static mutex object at link time. + // This means MutexBase has to be a POD and its member variables + // have to be public. + public: + pthread_mutex_t mutex_; // The underlying pthread mutex. + // has_owner_ indicates whether the owner_ field below contains a valid thread + // ID and is therefore safe to inspect (e.g., to use in pthread_equal()). All + // accesses to the owner_ field should be protected by a check of this field. + // An alternative might be to memset() owner_ to all zeros, but there's no + // guarantee that a zero'd pthread_t is necessarily invalid or even different + // from pthread_self(). + bool has_owner_; + pthread_t owner_; // The thread holding the mutex. +}; + +// Forward-declares a static mutex. +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::MutexBase mutex + +// Defines and statically (i.e. at link time) initializes a static mutex. +// The initialization list here does not explicitly initialize each field, +// instead relying on default initialization for the unspecified fields. In +// particular, the owner_ field (a pthread_t) is not explicitly initialized. +// This allows initialization to work whether pthread_t is a scalar or struct. +// The flag -Wmissing-field-initializers must not be specified for this to work. +#define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ + ::testing::internal::MutexBase mutex = {PTHREAD_MUTEX_INITIALIZER, false, 0} + +// The Mutex class can only be used for mutexes created at runtime. It +// shares its API with MutexBase otherwise. +class Mutex : public MutexBase { + public: + Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_init(&mutex_, nullptr)); + has_owner_ = false; + } + ~Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_destroy(&mutex_)); + } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); +}; + +// We cannot name this class MutexLock because the ctor declaration would +// conflict with a macro named MutexLock, which is defined on some +// platforms. That macro is used as a defensive measure to prevent against +// inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than +// "MutexLock l(&mu)". Hence the typedef trick below. +class GTestMutexLock { + public: + explicit GTestMutexLock(MutexBase* mutex) + : mutex_(mutex) { mutex_->Lock(); } + + ~GTestMutexLock() { mutex_->Unlock(); } + + private: + MutexBase* const mutex_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); +}; + +typedef GTestMutexLock MutexLock; + +// Helpers for ThreadLocal. + +// pthread_key_create() requires DeleteThreadLocalValue() to have +// C-linkage. Therefore it cannot be templatized to access +// ThreadLocal. Hence the need for class +// ThreadLocalValueHolderBase. +class ThreadLocalValueHolderBase { + public: + virtual ~ThreadLocalValueHolderBase() {} +}; + +// Called by pthread to delete thread-local data stored by +// pthread_setspecific(). +extern "C" inline void DeleteThreadLocalValue(void* value_holder) { + delete static_cast(value_holder); +} + +// Implements thread-local storage on pthreads-based systems. +template +class GTEST_API_ ThreadLocal { + public: + ThreadLocal() + : key_(CreateKey()), default_factory_(new DefaultValueHolderFactory()) {} + explicit ThreadLocal(const T& value) + : key_(CreateKey()), + default_factory_(new InstanceValueHolderFactory(value)) {} + + ~ThreadLocal() { + // Destroys the managed object for the current thread, if any. + DeleteThreadLocalValue(pthread_getspecific(key_)); + + // Releases resources associated with the key. This will *not* + // delete managed objects for other threads. + GTEST_CHECK_POSIX_SUCCESS_(pthread_key_delete(key_)); + } + + T* pointer() { return GetOrCreateValue(); } + const T* pointer() const { return GetOrCreateValue(); } + const T& get() const { return *pointer(); } + void set(const T& value) { *pointer() = value; } + + private: + // Holds a value of type T. + class ValueHolder : public ThreadLocalValueHolderBase { + public: + ValueHolder() : value_() {} + explicit ValueHolder(const T& value) : value_(value) {} + + T* pointer() { return &value_; } + + private: + T value_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); + }; + + static pthread_key_t CreateKey() { + pthread_key_t key; + // When a thread exits, DeleteThreadLocalValue() will be called on + // the object managed for that thread. + GTEST_CHECK_POSIX_SUCCESS_( + pthread_key_create(&key, &DeleteThreadLocalValue)); + return key; + } + + T* GetOrCreateValue() const { + ThreadLocalValueHolderBase* const holder = + static_cast(pthread_getspecific(key_)); + if (holder != nullptr) { + return CheckedDowncastToActualType(holder)->pointer(); + } + + ValueHolder* const new_holder = default_factory_->MakeNewHolder(); + ThreadLocalValueHolderBase* const holder_base = new_holder; + GTEST_CHECK_POSIX_SUCCESS_(pthread_setspecific(key_, holder_base)); + return new_holder->pointer(); + } + + class ValueHolderFactory { + public: + ValueHolderFactory() {} + virtual ~ValueHolderFactory() {} + virtual ValueHolder* MakeNewHolder() const = 0; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolderFactory); + }; + + class DefaultValueHolderFactory : public ValueHolderFactory { + public: + DefaultValueHolderFactory() {} + ValueHolder* MakeNewHolder() const override { return new ValueHolder(); } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultValueHolderFactory); + }; + + class InstanceValueHolderFactory : public ValueHolderFactory { + public: + explicit InstanceValueHolderFactory(const T& value) : value_(value) {} + ValueHolder* MakeNewHolder() const override { + return new ValueHolder(value_); + } + + private: + const T value_; // The value for each thread. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(InstanceValueHolderFactory); + }; + + // A key pthreads uses for looking up per-thread values. + const pthread_key_t key_; + std::unique_ptr default_factory_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); +}; + +# endif // GTEST_HAS_MUTEX_AND_THREAD_LOCAL_ + +#else // GTEST_IS_THREADSAFE + +// A dummy implementation of synchronization primitives (mutex, lock, +// and thread-local variable). Necessary for compiling Google Test where +// mutex is not supported - using Google Test in multiple threads is not +// supported on such platforms. + +class Mutex { + public: + Mutex() {} + void Lock() {} + void Unlock() {} + void AssertHeld() const {} +}; + +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::Mutex mutex + +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) ::testing::internal::Mutex mutex + +// We cannot name this class MutexLock because the ctor declaration would +// conflict with a macro named MutexLock, which is defined on some +// platforms. That macro is used as a defensive measure to prevent against +// inadvertent misuses of MutexLock like "MutexLock(&mu)" rather than +// "MutexLock l(&mu)". Hence the typedef trick below. +class GTestMutexLock { + public: + explicit GTestMutexLock(Mutex*) {} // NOLINT +}; + +typedef GTestMutexLock MutexLock; + +template +class GTEST_API_ ThreadLocal { + public: + ThreadLocal() : value_() {} + explicit ThreadLocal(const T& value) : value_(value) {} + T* pointer() { return &value_; } + const T* pointer() const { return &value_; } + const T& get() const { return value_; } + void set(const T& value) { value_ = value; } + private: + T value_; +}; + +#endif // GTEST_IS_THREADSAFE + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +GTEST_API_ size_t GetThreadCount(); + +#if GTEST_OS_WINDOWS +# define GTEST_PATH_SEP_ "\\" +# define GTEST_HAS_ALT_PATH_SEP_ 1 +#else +# define GTEST_PATH_SEP_ "/" +# define GTEST_HAS_ALT_PATH_SEP_ 0 +#endif // GTEST_OS_WINDOWS + +// Utilities for char. + +// isspace(int ch) and friends accept an unsigned char or EOF. char +// may be signed, depending on the compiler (or compiler flags). +// Therefore we need to cast a char to unsigned char before calling +// isspace(), etc. + +inline bool IsAlpha(char ch) { + return isalpha(static_cast(ch)) != 0; +} +inline bool IsAlNum(char ch) { + return isalnum(static_cast(ch)) != 0; +} +inline bool IsDigit(char ch) { + return isdigit(static_cast(ch)) != 0; +} +inline bool IsLower(char ch) { + return islower(static_cast(ch)) != 0; +} +inline bool IsSpace(char ch) { + return isspace(static_cast(ch)) != 0; +} +inline bool IsUpper(char ch) { + return isupper(static_cast(ch)) != 0; +} +inline bool IsXDigit(char ch) { + return isxdigit(static_cast(ch)) != 0; +} +#ifdef __cpp_char8_t +inline bool IsXDigit(char8_t ch) { + return isxdigit(static_cast(ch)) != 0; +} +#endif +inline bool IsXDigit(char16_t ch) { + const unsigned char low_byte = static_cast(ch); + return ch == low_byte && isxdigit(low_byte) != 0; +} +inline bool IsXDigit(char32_t ch) { + const unsigned char low_byte = static_cast(ch); + return ch == low_byte && isxdigit(low_byte) != 0; +} +inline bool IsXDigit(wchar_t ch) { + const unsigned char low_byte = static_cast(ch); + return ch == low_byte && isxdigit(low_byte) != 0; +} + +inline char ToLower(char ch) { + return static_cast(tolower(static_cast(ch))); +} +inline char ToUpper(char ch) { + return static_cast(toupper(static_cast(ch))); +} + +inline std::string StripTrailingSpaces(std::string str) { + std::string::iterator it = str.end(); + while (it != str.begin() && IsSpace(*--it)) + it = str.erase(it); + return str; +} + +// The testing::internal::posix namespace holds wrappers for common +// POSIX functions. These wrappers hide the differences between +// Windows/MSVC and POSIX systems. Since some compilers define these +// standard functions as macros, the wrapper cannot have the same name +// as the wrapped function. + +namespace posix { + +// Functions with a different name on Windows. + +#if GTEST_OS_WINDOWS + +typedef struct _stat StatStruct; + +# ifdef __BORLANDC__ +inline int DoIsATTY(int fd) { return isatty(fd); } +inline int StrCaseCmp(const char* s1, const char* s2) { + return stricmp(s1, s2); +} +inline char* StrDup(const char* src) { return strdup(src); } +# else // !__BORLANDC__ +# if GTEST_OS_WINDOWS_MOBILE +inline int DoIsATTY(int /* fd */) { return 0; } +# else +inline int DoIsATTY(int fd) { return _isatty(fd); } +# endif // GTEST_OS_WINDOWS_MOBILE +inline int StrCaseCmp(const char* s1, const char* s2) { + return _stricmp(s1, s2); +} +inline char* StrDup(const char* src) { return _strdup(src); } +# endif // __BORLANDC__ + +# if GTEST_OS_WINDOWS_MOBILE +inline int FileNo(FILE* file) { return reinterpret_cast(_fileno(file)); } +// Stat(), RmDir(), and IsDir() are not needed on Windows CE at this +// time and thus not defined there. +# else +inline int FileNo(FILE* file) { return _fileno(file); } +inline int Stat(const char* path, StatStruct* buf) { return _stat(path, buf); } +inline int RmDir(const char* dir) { return _rmdir(dir); } +inline bool IsDir(const StatStruct& st) { + return (_S_IFDIR & st.st_mode) != 0; +} +# endif // GTEST_OS_WINDOWS_MOBILE + +#elif GTEST_OS_ESP8266 +typedef struct stat StatStruct; + +inline int FileNo(FILE* file) { return fileno(file); } +inline int DoIsATTY(int fd) { return isatty(fd); } +inline int Stat(const char* path, StatStruct* buf) { + // stat function not implemented on ESP8266 + return 0; +} +inline int StrCaseCmp(const char* s1, const char* s2) { + return strcasecmp(s1, s2); +} +inline char* StrDup(const char* src) { return strdup(src); } +inline int RmDir(const char* dir) { return rmdir(dir); } +inline bool IsDir(const StatStruct& st) { return S_ISDIR(st.st_mode); } + +#else + +typedef struct stat StatStruct; + +inline int FileNo(FILE* file) { return fileno(file); } +inline int DoIsATTY(int fd) { return isatty(fd); } +inline int Stat(const char* path, StatStruct* buf) { return stat(path, buf); } +inline int StrCaseCmp(const char* s1, const char* s2) { + return strcasecmp(s1, s2); +} +inline char* StrDup(const char* src) { return strdup(src); } +inline int RmDir(const char* dir) { return rmdir(dir); } +inline bool IsDir(const StatStruct& st) { return S_ISDIR(st.st_mode); } + +#endif // GTEST_OS_WINDOWS + +inline int IsATTY(int fd) { + // DoIsATTY might change errno (for example ENOTTY in case you redirect stdout + // to a file on Linux), which is unexpected, so save the previous value, and + // restore it after the call. + int savedErrno = errno; + int isAttyValue = DoIsATTY(fd); + errno = savedErrno; + + return isAttyValue; +} + +// Functions deprecated by MSVC 8.0. + +GTEST_DISABLE_MSC_DEPRECATED_PUSH_() + +// ChDir(), FReopen(), FDOpen(), Read(), Write(), Close(), and +// StrError() aren't needed on Windows CE at this time and thus not +// defined there. + +#if !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_WINDOWS_PHONE && \ + !GTEST_OS_WINDOWS_RT && !GTEST_OS_ESP8266 && !GTEST_OS_XTENSA +inline int ChDir(const char* dir) { return chdir(dir); } +#endif +inline FILE* FOpen(const char* path, const char* mode) { +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW + struct wchar_codecvt : public std::codecvt {}; + std::wstring_convert converter; + std::wstring wide_path = converter.from_bytes(path); + std::wstring wide_mode = converter.from_bytes(mode); + return _wfopen(wide_path.c_str(), wide_mode.c_str()); +#else // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW + return fopen(path, mode); +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW +} +#if !GTEST_OS_WINDOWS_MOBILE +inline FILE *FReopen(const char* path, const char* mode, FILE* stream) { + return freopen(path, mode, stream); +} +inline FILE* FDOpen(int fd, const char* mode) { return fdopen(fd, mode); } +#endif +inline int FClose(FILE* fp) { return fclose(fp); } +#if !GTEST_OS_WINDOWS_MOBILE +inline int Read(int fd, void* buf, unsigned int count) { + return static_cast(read(fd, buf, count)); +} +inline int Write(int fd, const void* buf, unsigned int count) { + return static_cast(write(fd, buf, count)); +} +inline int Close(int fd) { return close(fd); } +inline const char* StrError(int errnum) { return strerror(errnum); } +#endif +inline const char* GetEnv(const char* name) { +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || \ + GTEST_OS_WINDOWS_RT || GTEST_OS_ESP8266 || GTEST_OS_XTENSA + // We are on an embedded platform, which has no environment variables. + static_cast(name); // To prevent 'unused argument' warning. + return nullptr; +#elif defined(__BORLANDC__) || defined(__SunOS_5_8) || defined(__SunOS_5_9) + // Environment variables which we programmatically clear will be set to the + // empty string rather than unset (NULL). Handle that case. + const char* const env = getenv(name); + return (env != nullptr && env[0] != '\0') ? env : nullptr; +#else + return getenv(name); +#endif +} + +GTEST_DISABLE_MSC_DEPRECATED_POP_() + +#if GTEST_OS_WINDOWS_MOBILE +// Windows CE has no C library. The abort() function is used in +// several places in Google Test. This implementation provides a reasonable +// imitation of standard behaviour. +[[noreturn]] void Abort(); +#else +[[noreturn]] inline void Abort() { abort(); } +#endif // GTEST_OS_WINDOWS_MOBILE + +} // namespace posix + +// MSVC "deprecates" snprintf and issues warnings wherever it is used. In +// order to avoid these warnings, we need to use _snprintf or _snprintf_s on +// MSVC-based platforms. We map the GTEST_SNPRINTF_ macro to the appropriate +// function in order to achieve that. We use macro definition here because +// snprintf is a variadic function. +#if _MSC_VER && !GTEST_OS_WINDOWS_MOBILE +// MSVC 2005 and above support variadic macros. +# define GTEST_SNPRINTF_(buffer, size, format, ...) \ + _snprintf_s(buffer, size, size, format, __VA_ARGS__) +#elif defined(_MSC_VER) +// Windows CE does not define _snprintf_s +# define GTEST_SNPRINTF_ _snprintf +#else +# define GTEST_SNPRINTF_ snprintf +#endif + +// The biggest signed integer type the compiler supports. +// +// long long is guaranteed to be at least 64-bits in C++11. +using BiggestInt = long long; // NOLINT + +// The maximum number a BiggestInt can represent. +constexpr BiggestInt kMaxBiggestInt = (std::numeric_limits::max)(); + +// This template class serves as a compile-time function from size to +// type. It maps a size in bytes to a primitive type with that +// size. e.g. +// +// TypeWithSize<4>::UInt +// +// is typedef-ed to be unsigned int (unsigned integer made up of 4 +// bytes). +// +// Such functionality should belong to STL, but I cannot find it +// there. +// +// Google Test uses this class in the implementation of floating-point +// comparison. +// +// For now it only handles UInt (unsigned int) as that's all Google Test +// needs. Other types can be easily added in the future if need +// arises. +template +class TypeWithSize { + public: + // This prevents the user from using TypeWithSize with incorrect + // values of N. + using UInt = void; +}; + +// The specialization for size 4. +template <> +class TypeWithSize<4> { + public: + using Int = std::int32_t; + using UInt = std::uint32_t; +}; + +// The specialization for size 8. +template <> +class TypeWithSize<8> { + public: + using Int = std::int64_t; + using UInt = std::uint64_t; +}; + +// Integer types of known sizes. +using TimeInMillis = int64_t; // Represents time in milliseconds. + +// Utilities for command line flags and environment variables. + +// Macro for referencing flags. +#if !defined(GTEST_FLAG) +# define GTEST_FLAG(name) FLAGS_gtest_##name +#endif // !defined(GTEST_FLAG) + +#if !defined(GTEST_USE_OWN_FLAGFILE_FLAG_) +# define GTEST_USE_OWN_FLAGFILE_FLAG_ 1 +#endif // !defined(GTEST_USE_OWN_FLAGFILE_FLAG_) + +#if !defined(GTEST_DECLARE_bool_) +# define GTEST_FLAG_SAVER_ ::testing::internal::GTestFlagSaver + +// Macros for declaring flags. +#define GTEST_DECLARE_bool_(name) \ + namespace testing { \ + GTEST_API_ extern bool GTEST_FLAG(name); \ + } static_assert(true, "no-op to require trailing semicolon") +#define GTEST_DECLARE_int32_(name) \ + namespace testing { \ + GTEST_API_ extern std::int32_t GTEST_FLAG(name); \ + } static_assert(true, "no-op to require trailing semicolon") +#define GTEST_DECLARE_string_(name) \ + namespace testing { \ + GTEST_API_ extern ::std::string GTEST_FLAG(name); \ + } static_assert(true, "no-op to require trailing semicolon") + +// Macros for defining flags. +#define GTEST_DEFINE_bool_(name, default_val, doc) \ + namespace testing { \ + GTEST_API_ bool GTEST_FLAG(name) = (default_val); \ + } static_assert(true, "no-op to require trailing semicolon") +#define GTEST_DEFINE_int32_(name, default_val, doc) \ + namespace testing { \ + GTEST_API_ std::int32_t GTEST_FLAG(name) = (default_val); \ + } static_assert(true, "no-op to require trailing semicolon") +#define GTEST_DEFINE_string_(name, default_val, doc) \ + namespace testing { \ + GTEST_API_ ::std::string GTEST_FLAG(name) = (default_val); \ + } static_assert(true, "no-op to require trailing semicolon") + +#endif // !defined(GTEST_DECLARE_bool_) + +#if !defined(GTEST_FLAG_GET) +#define GTEST_FLAG_GET(name) ::testing::GTEST_FLAG(name) +#define GTEST_FLAG_SET(name, value) (void)(::testing::GTEST_FLAG(name) = value) +#endif // !defined(GTEST_FLAG_GET) + +// Thread annotations +#if !defined(GTEST_EXCLUSIVE_LOCK_REQUIRED_) +# define GTEST_EXCLUSIVE_LOCK_REQUIRED_(locks) +# define GTEST_LOCK_EXCLUDED_(locks) +#endif // !defined(GTEST_EXCLUSIVE_LOCK_REQUIRED_) + +// Parses 'str' for a 32-bit signed integer. If successful, writes the result +// to *value and returns true; otherwise leaves *value unchanged and returns +// false. +GTEST_API_ bool ParseInt32(const Message& src_text, const char* str, + int32_t* value); + +// Parses a bool/int32_t/string from the environment variable +// corresponding to the given Google Test flag. +bool BoolFromGTestEnv(const char* flag, bool default_val); +GTEST_API_ int32_t Int32FromGTestEnv(const char* flag, int32_t default_val); +std::string OutputFlagAlsoCheckEnvVar(); +const char* StringFromGTestEnv(const char* flag, const char* default_val); + +} // namespace internal +} // namespace testing + +#if !defined(GTEST_INTERNAL_DEPRECATED) + +// Internal Macro to mark an API deprecated, for googletest usage only +// Usage: class GTEST_INTERNAL_DEPRECATED(message) MyClass or +// GTEST_INTERNAL_DEPRECATED(message) myFunction(); Every usage of +// a deprecated entity will trigger a warning when compiled with +// `-Wdeprecated-declarations` option (clang, gcc, any __GNUC__ compiler). +// For msvc /W3 option will need to be used +// Note that for 'other' compilers this macro evaluates to nothing to prevent +// compilations errors. +#if defined(_MSC_VER) +#define GTEST_INTERNAL_DEPRECATED(message) __declspec(deprecated(message)) +#elif defined(__GNUC__) +#define GTEST_INTERNAL_DEPRECATED(message) __attribute__((deprecated(message))) +#else +#define GTEST_INTERNAL_DEPRECATED(message) +#endif + +#endif // !defined(GTEST_INTERNAL_DEPRECATED) + +#if GTEST_HAS_ABSL +// Always use absl::any for UniversalPrinter<> specializations if googletest +// is built with absl support. +#define GTEST_INTERNAL_HAS_ANY 1 +#include "absl/types/any.h" +namespace testing { +namespace internal { +using Any = ::absl::any; +} // namespace internal +} // namespace testing +#else +#ifdef __has_include +#if __has_include() && __cplusplus >= 201703L +// Otherwise for C++17 and higher use std::any for UniversalPrinter<> +// specializations. +#define GTEST_INTERNAL_HAS_ANY 1 +#include +namespace testing { +namespace internal { +using Any = ::std::any; +} // namespace internal +} // namespace testing +// The case where absl is configured NOT to alias std::any is not +// supported. +#endif // __has_include() && __cplusplus >= 201703L +#endif // __has_include +#endif // GTEST_HAS_ABSL + +#if GTEST_HAS_ABSL +// Always use absl::optional for UniversalPrinter<> specializations if +// googletest is built with absl support. +#define GTEST_INTERNAL_HAS_OPTIONAL 1 +#include "absl/types/optional.h" +namespace testing { +namespace internal { +template +using Optional = ::absl::optional; +inline ::absl::nullopt_t Nullopt() { return ::absl::nullopt; } +} // namespace internal +} // namespace testing +#else +#ifdef __has_include +#if __has_include() && __cplusplus >= 201703L +// Otherwise for C++17 and higher use std::optional for UniversalPrinter<> +// specializations. +#define GTEST_INTERNAL_HAS_OPTIONAL 1 +#include +namespace testing { +namespace internal { +template +using Optional = ::std::optional; +inline ::std::nullopt_t Nullopt() { return ::std::nullopt; } +} // namespace internal +} // namespace testing +// The case where absl is configured NOT to alias std::optional is not +// supported. +#endif // __has_include() && __cplusplus >= 201703L +#endif // __has_include +#endif // GTEST_HAS_ABSL + +#if GTEST_HAS_ABSL +// Always use absl::string_view for Matcher<> specializations if googletest +// is built with absl support. +# define GTEST_INTERNAL_HAS_STRING_VIEW 1 +#include "absl/strings/string_view.h" +namespace testing { +namespace internal { +using StringView = ::absl::string_view; +} // namespace internal +} // namespace testing +#else +# ifdef __has_include +# if __has_include() && __cplusplus >= 201703L +// Otherwise for C++17 and higher use std::string_view for Matcher<> +// specializations. +# define GTEST_INTERNAL_HAS_STRING_VIEW 1 +#include +namespace testing { +namespace internal { +using StringView = ::std::string_view; +} // namespace internal +} // namespace testing +// The case where absl is configured NOT to alias std::string_view is not +// supported. +# endif // __has_include() && __cplusplus >= 201703L +# endif // __has_include +#endif // GTEST_HAS_ABSL + +#if GTEST_HAS_ABSL +// Always use absl::variant for UniversalPrinter<> specializations if googletest +// is built with absl support. +#define GTEST_INTERNAL_HAS_VARIANT 1 +#include "absl/types/variant.h" +namespace testing { +namespace internal { +template +using Variant = ::absl::variant; +} // namespace internal +} // namespace testing +#else +#ifdef __has_include +#if __has_include() && __cplusplus >= 201703L +// Otherwise for C++17 and higher use std::variant for UniversalPrinter<> +// specializations. +#define GTEST_INTERNAL_HAS_VARIANT 1 +#include +namespace testing { +namespace internal { +template +using Variant = ::std::variant; +} // namespace internal +} // namespace testing +// The case where absl is configured NOT to alias std::variant is not supported. +#endif // __has_include() && __cplusplus >= 201703L +#endif // __has_include +#endif // GTEST_HAS_ABSL + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-string.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-string.h new file mode 100644 index 000000000000..4cb8e07cf9f4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-string.h @@ -0,0 +1,177 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file declares the String class and functions used internally by +// Google Test. They are subject to change without notice. They should not used +// by code external to Google Test. +// +// This header file is #included by gtest-internal.h. +// It should not be #included by other files. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ + +#ifdef __BORLANDC__ +// string.h is not guaranteed to provide strcpy on C++ Builder. +# include +#endif + +#include +#include +#include + +#include "gtest/internal/gtest-port.h" + +namespace testing { +namespace internal { + +// String - an abstract class holding static string utilities. +class GTEST_API_ String { + public: + // Static utility methods + + // Clones a 0-terminated C string, allocating memory using new. The + // caller is responsible for deleting the return value using + // delete[]. Returns the cloned string, or NULL if the input is + // NULL. + // + // This is different from strdup() in string.h, which allocates + // memory using malloc(). + static const char* CloneCString(const char* c_str); + +#if GTEST_OS_WINDOWS_MOBILE + // Windows CE does not have the 'ANSI' versions of Win32 APIs. To be + // able to pass strings to Win32 APIs on CE we need to convert them + // to 'Unicode', UTF-16. + + // Creates a UTF-16 wide string from the given ANSI string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the wide string, or NULL if the + // input is NULL. + // + // The wide string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static LPCWSTR AnsiToUtf16(const char* c_str); + + // Creates an ANSI string from the given wide string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the ANSI string, or NULL if the + // input is NULL. + // + // The returned string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static const char* Utf16ToAnsi(LPCWSTR utf16_str); +#endif + + // Compares two C strings. Returns true if and only if they have the same + // content. + // + // Unlike strcmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CStringEquals(const char* lhs, const char* rhs); + + // Converts a wide C string to a String using the UTF-8 encoding. + // NULL will be converted to "(null)". If an error occurred during + // the conversion, "(failed to convert from wide string)" is + // returned. + static std::string ShowWideCString(const wchar_t* wide_c_str); + + // Compares two wide C strings. Returns true if and only if they have the + // same content. + // + // Unlike wcscmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool WideCStringEquals(const wchar_t* lhs, const wchar_t* rhs); + + // Compares two C strings, ignoring case. Returns true if and only if + // they have the same content. + // + // Unlike strcasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CaseInsensitiveCStringEquals(const char* lhs, + const char* rhs); + + // Compares two wide C strings, ignoring case. Returns true if and only if + // they have the same content. + // + // Unlike wcscasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL wide C string, + // including the empty string. + // NB: The implementations on different platforms slightly differ. + // On windows, this method uses _wcsicmp which compares according to LC_CTYPE + // environment variable. On GNU platform this method uses wcscasecmp + // which compares according to LC_CTYPE category of the current locale. + // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the + // current locale. + static bool CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs); + + // Returns true if and only if the given string ends with the given suffix, + // ignoring case. Any string is considered to end with an empty suffix. + static bool EndsWithCaseInsensitive( + const std::string& str, const std::string& suffix); + + // Formats an int value as "%02d". + static std::string FormatIntWidth2(int value); // "%02d" for width == 2 + + // Formats an int value to given width with leading zeros. + static std::string FormatIntWidthN(int value, int width); + + // Formats an int value as "%X". + static std::string FormatHexInt(int value); + + // Formats an int value as "%X". + static std::string FormatHexUInt32(uint32_t value); + + // Formats a byte as "%02X". + static std::string FormatByte(unsigned char value); + + private: + String(); // Not meant to be instantiated. +}; // class String + +// Gets the content of the stringstream's buffer as an std::string. Each '\0' +// character in the buffer is replaced with "\\0". +GTEST_API_ std::string StringStreamToString(::std::stringstream* stream); + +} // namespace internal +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-type-util.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-type-util.h new file mode 100644 index 000000000000..665564a97ac9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/include/gtest/internal/gtest-type-util.h @@ -0,0 +1,185 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Type utilities needed for implementing typed and type-parameterized +// tests. + +// IWYU pragma: private, include "gtest/gtest.h" +// IWYU pragma: friend gtest/.* +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ +#define GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ + +#include "gtest/internal/gtest-port.h" + +// #ifdef __GNUC__ is too general here. It is possible to use gcc without using +// libstdc++ (which is where cxxabi.h comes from). +# if GTEST_HAS_CXXABI_H_ +# include +# elif defined(__HP_aCC) +# include +# endif // GTEST_HASH_CXXABI_H_ + +namespace testing { +namespace internal { + +// Canonicalizes a given name with respect to the Standard C++ Library. +// This handles removing the inline namespace within `std` that is +// used by various standard libraries (e.g., `std::__1`). Names outside +// of namespace std are returned unmodified. +inline std::string CanonicalizeForStdLibVersioning(std::string s) { + static const char prefix[] = "std::__"; + if (s.compare(0, strlen(prefix), prefix) == 0) { + std::string::size_type end = s.find("::", strlen(prefix)); + if (end != s.npos) { + // Erase everything between the initial `std` and the second `::`. + s.erase(strlen("std"), end - strlen("std")); + } + } + return s; +} + +#if GTEST_HAS_RTTI +// GetTypeName(const std::type_info&) returns a human-readable name of type T. +inline std::string GetTypeName(const std::type_info& type) { + const char* const name = type.name(); +#if GTEST_HAS_CXXABI_H_ || defined(__HP_aCC) + int status = 0; + // gcc's implementation of typeid(T).name() mangles the type name, + // so we have to demangle it. +#if GTEST_HAS_CXXABI_H_ + using abi::__cxa_demangle; +#endif // GTEST_HAS_CXXABI_H_ + char* const readable_name = __cxa_demangle(name, nullptr, nullptr, &status); + const std::string name_str(status == 0 ? readable_name : name); + free(readable_name); + return CanonicalizeForStdLibVersioning(name_str); +#else + return name; +#endif // GTEST_HAS_CXXABI_H_ || __HP_aCC +} +#endif // GTEST_HAS_RTTI + +// GetTypeName() returns a human-readable name of type T if and only if +// RTTI is enabled, otherwise it returns a dummy type name. +// NB: This function is also used in Google Mock, so don't move it inside of +// the typed-test-only section below. +template +std::string GetTypeName() { +#if GTEST_HAS_RTTI + return GetTypeName(typeid(T)); +#else + return ""; +#endif // GTEST_HAS_RTTI +} + +// A unique type indicating an empty node +struct None {}; + +# define GTEST_TEMPLATE_ template class + +// The template "selector" struct TemplateSel is used to +// represent Tmpl, which must be a class template with one type +// parameter, as a type. TemplateSel::Bind::type is defined +// as the type Tmpl. This allows us to actually instantiate the +// template "selected" by TemplateSel. +// +// This trick is necessary for simulating typedef for class templates, +// which C++ doesn't support directly. +template +struct TemplateSel { + template + struct Bind { + typedef Tmpl type; + }; +}; + +# define GTEST_BIND_(TmplSel, T) \ + TmplSel::template Bind::type + +template +struct Templates { + using Head = TemplateSel; + using Tail = Templates; +}; + +template +struct Templates { + using Head = TemplateSel; + using Tail = None; +}; + +// Tuple-like type lists +template +struct Types { + using Head = Head_; + using Tail = Types; +}; + +template +struct Types { + using Head = Head_; + using Tail = None; +}; + +// Helper metafunctions to tell apart a single type from types +// generated by ::testing::Types +template +struct ProxyTypeList { + using type = Types; +}; + +template +struct is_proxy_type_list : std::false_type {}; + +template +struct is_proxy_type_list> : std::true_type {}; + +// Generator which conditionally creates type lists. +// It recognizes if a requested type list should be created +// and prevents creating a new type list nested within another one. +template +struct GenerateTypeList { + private: + using proxy = typename std::conditional::value, T, + ProxyTypeList>::type; + + public: + using type = typename proxy::type; +}; + +} // namespace internal + +template +using Types = internal::ProxyTypeList; + +} // namespace testing + +#endif // GOOGLETEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/prime_tables.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/prime_tables.h new file mode 100644 index 000000000000..3a10352baae5 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/prime_tables.h @@ -0,0 +1,126 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +// This provides interface PrimeTable that determines whether a number is a +// prime and determines a next prime number. This interface is used +// in Google Test samples demonstrating use of parameterized tests. + +#ifndef GOOGLETEST_SAMPLES_PRIME_TABLES_H_ +#define GOOGLETEST_SAMPLES_PRIME_TABLES_H_ + +#include + +// The prime table interface. +class PrimeTable { + public: + virtual ~PrimeTable() {} + + // Returns true if and only if n is a prime number. + virtual bool IsPrime(int n) const = 0; + + // Returns the smallest prime number greater than p; or returns -1 + // if the next prime is beyond the capacity of the table. + virtual int GetNextPrime(int p) const = 0; +}; + +// Implementation #1 calculates the primes on-the-fly. +class OnTheFlyPrimeTable : public PrimeTable { + public: + bool IsPrime(int n) const override { + if (n <= 1) return false; + + for (int i = 2; i*i <= n; i++) { + // n is divisible by an integer other than 1 and itself. + if ((n % i) == 0) return false; + } + + return true; + } + + int GetNextPrime(int p) const override { + if (p < 0) return -1; + + for (int n = p + 1;; n++) { + if (IsPrime(n)) return n; + } + } +}; + +// Implementation #2 pre-calculates the primes and stores the result +// in an array. +class PreCalculatedPrimeTable : public PrimeTable { + public: + // 'max' specifies the maximum number the prime table holds. + explicit PreCalculatedPrimeTable(int max) + : is_prime_size_(max + 1), is_prime_(new bool[max + 1]) { + CalculatePrimesUpTo(max); + } + ~PreCalculatedPrimeTable() override { delete[] is_prime_; } + + bool IsPrime(int n) const override { + return 0 <= n && n < is_prime_size_ && is_prime_[n]; + } + + int GetNextPrime(int p) const override { + for (int n = p + 1; n < is_prime_size_; n++) { + if (is_prime_[n]) return n; + } + + return -1; + } + + private: + void CalculatePrimesUpTo(int max) { + ::std::fill(is_prime_, is_prime_ + is_prime_size_, true); + is_prime_[0] = is_prime_[1] = false; + + // Checks every candidate for prime number (we know that 2 is the only even + // prime). + for (int i = 2; i*i <= max; i += i%2+1) { + if (!is_prime_[i]) continue; + + // Marks all multiples of i (except i itself) as non-prime. + // We are starting here from i-th multiplier, because all smaller + // complex numbers were already marked. + for (int j = i*i; j <= max; j += i) { + is_prime_[j] = false; + } + } + } + + const int is_prime_size_; + bool* const is_prime_; + + // Disables compiler warning "assignment operator could not be generated." + void operator=(const PreCalculatedPrimeTable& rhs); +}; + +#endif // GOOGLETEST_SAMPLES_PRIME_TABLES_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.cc new file mode 100644 index 000000000000..1d4275979ff6 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.cc @@ -0,0 +1,66 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#include "sample1.h" + +// Returns n! (the factorial of n). For negative n, n! is defined to be 1. +int Factorial(int n) { + int result = 1; + for (int i = 1; i <= n; i++) { + result *= i; + } + + return result; +} + +// Returns true if and only if n is a prime number. +bool IsPrime(int n) { + // Trivial case 1: small numbers + if (n <= 1) return false; + + // Trivial case 2: even numbers + if (n % 2 == 0) return n == 2; + + // Now, we have that n is odd and n >= 3. + + // Try to divide n by every odd number i, starting from 3 + for (int i = 3; ; i += 2) { + // We only have to try i up to the square root of n + if (i > n/i) break; + + // Now, we have i <= n/i < n. + // If n is divisible by i, n is not prime. + if (n % i == 0) return false; + } + + // n has no integer factor in the range (1, n), and thus is prime. + return true; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.h new file mode 100644 index 000000000000..ba392cfbd266 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1.h @@ -0,0 +1,41 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#ifndef GOOGLETEST_SAMPLES_SAMPLE1_H_ +#define GOOGLETEST_SAMPLES_SAMPLE1_H_ + +// Returns n! (the factorial of n). For negative n, n! is defined to be 1. +int Factorial(int n); + +// Returns true if and only if n is a prime number. +bool IsPrime(int n); + +#endif // GOOGLETEST_SAMPLES_SAMPLE1_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample10_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample10_unittest.cc new file mode 100644 index 000000000000..36cdac2279ae --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample10_unittest.cc @@ -0,0 +1,139 @@ +// Copyright 2009 Google Inc. All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample shows how to use Google Test listener API to implement +// a primitive leak checker. + +#include +#include + +#include "gtest/gtest.h" +using ::testing::EmptyTestEventListener; +using ::testing::InitGoogleTest; +using ::testing::Test; +using ::testing::TestEventListeners; +using ::testing::TestInfo; +using ::testing::TestPartResult; +using ::testing::UnitTest; + +namespace { +// We will track memory used by this class. +class Water { + public: + // Normal Water declarations go here. + + // operator new and operator delete help us control water allocation. + void* operator new(size_t allocation_size) { + allocated_++; + return malloc(allocation_size); + } + + void operator delete(void* block, size_t /* allocation_size */) { + allocated_--; + free(block); + } + + static int allocated() { return allocated_; } + + private: + static int allocated_; +}; + +int Water::allocated_ = 0; + +// This event listener monitors how many Water objects are created and +// destroyed by each test, and reports a failure if a test leaks some Water +// objects. It does this by comparing the number of live Water objects at +// the beginning of a test and at the end of a test. +class LeakChecker : public EmptyTestEventListener { + private: + // Called before a test starts. + void OnTestStart(const TestInfo& /* test_info */) override { + initially_allocated_ = Water::allocated(); + } + + // Called after a test ends. + void OnTestEnd(const TestInfo& /* test_info */) override { + int difference = Water::allocated() - initially_allocated_; + + // You can generate a failure in any event handler except + // OnTestPartResult. Just use an appropriate Google Test assertion to do + // it. + EXPECT_LE(difference, 0) << "Leaked " << difference << " unit(s) of Water!"; + } + + int initially_allocated_; +}; + +TEST(ListenersTest, DoesNotLeak) { + Water* water = new Water; + delete water; +} + +// This should fail when the --check_for_leaks command line flag is +// specified. +TEST(ListenersTest, LeaksWater) { + Water* water = new Water; + EXPECT_TRUE(water != nullptr); +} +} // namespace + +int main(int argc, char **argv) { + InitGoogleTest(&argc, argv); + + bool check_for_leaks = false; + if (argc > 1 && strcmp(argv[1], "--check_for_leaks") == 0 ) + check_for_leaks = true; + else + printf("%s\n", "Run this program with --check_for_leaks to enable " + "custom leak checking in the tests."); + + // If we are given the --check_for_leaks command line flag, installs the + // leak checker. + if (check_for_leaks) { + TestEventListeners& listeners = UnitTest::GetInstance()->listeners(); + + // Adds the leak checker to the end of the test event listener list, + // after the default text output printer and the default XML report + // generator. + // + // The order is important - it ensures that failures generated in the + // leak checker's OnTestEnd() method are processed by the text and XML + // printers *before* their OnTestEnd() methods are called, such that + // they are attributed to the right test. Remember that a listener + // receives an OnXyzStart event *after* listeners preceding it in the + // list received that event, and receives an OnXyzEnd event *before* + // listeners preceding it. + // + // We don't need to worry about deleting the new listener later, as + // Google Test will do it. + listeners.Append(new LeakChecker); + } + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1_unittest.cc new file mode 100644 index 000000000000..cb08b61a59a9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample1_unittest.cc @@ -0,0 +1,151 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +// This sample shows how to write a simple unit test for a function, +// using Google C++ testing framework. +// +// Writing a unit test using Google C++ testing framework is easy as 1-2-3: + + +// Step 1. Include necessary header files such that the stuff your +// test logic needs is declared. +// +// Don't forget gtest.h, which declares the testing framework. + +#include +#include "sample1.h" +#include "gtest/gtest.h" +namespace { + +// Step 2. Use the TEST macro to define your tests. +// +// TEST has two parameters: the test case name and the test name. +// After using the macro, you should define your test logic between a +// pair of braces. You can use a bunch of macros to indicate the +// success or failure of a test. EXPECT_TRUE and EXPECT_EQ are +// examples of such macros. For a complete list, see gtest.h. +// +// +// +// In Google Test, tests are grouped into test cases. This is how we +// keep test code organized. You should put logically related tests +// into the same test case. +// +// The test case name and the test name should both be valid C++ +// identifiers. And you should not use underscore (_) in the names. +// +// Google Test guarantees that each test you define is run exactly +// once, but it makes no guarantee on the order the tests are +// executed. Therefore, you should write your tests in such a way +// that their results don't depend on their order. +// +// + + +// Tests Factorial(). + +// Tests factorial of negative numbers. +TEST(FactorialTest, Negative) { + // This test is named "Negative", and belongs to the "FactorialTest" + // test case. + EXPECT_EQ(1, Factorial(-5)); + EXPECT_EQ(1, Factorial(-1)); + EXPECT_GT(Factorial(-10), 0); + + // + // + // EXPECT_EQ(expected, actual) is the same as + // + // EXPECT_TRUE((expected) == (actual)) + // + // except that it will print both the expected value and the actual + // value when the assertion fails. This is very helpful for + // debugging. Therefore in this case EXPECT_EQ is preferred. + // + // On the other hand, EXPECT_TRUE accepts any Boolean expression, + // and is thus more general. + // + // +} + +// Tests factorial of 0. +TEST(FactorialTest, Zero) { + EXPECT_EQ(1, Factorial(0)); +} + +// Tests factorial of positive numbers. +TEST(FactorialTest, Positive) { + EXPECT_EQ(1, Factorial(1)); + EXPECT_EQ(2, Factorial(2)); + EXPECT_EQ(6, Factorial(3)); + EXPECT_EQ(40320, Factorial(8)); +} + + +// Tests IsPrime() + +// Tests negative input. +TEST(IsPrimeTest, Negative) { + // This test belongs to the IsPrimeTest test case. + + EXPECT_FALSE(IsPrime(-1)); + EXPECT_FALSE(IsPrime(-2)); + EXPECT_FALSE(IsPrime(INT_MIN)); +} + +// Tests some trivial cases. +TEST(IsPrimeTest, Trivial) { + EXPECT_FALSE(IsPrime(0)); + EXPECT_FALSE(IsPrime(1)); + EXPECT_TRUE(IsPrime(2)); + EXPECT_TRUE(IsPrime(3)); +} + +// Tests positive input. +TEST(IsPrimeTest, Positive) { + EXPECT_FALSE(IsPrime(4)); + EXPECT_TRUE(IsPrime(5)); + EXPECT_FALSE(IsPrime(6)); + EXPECT_TRUE(IsPrime(23)); +} +} // namespace + +// Step 3. Call RUN_ALL_TESTS() in main(). +// +// We do this by linking in src/gtest_main.cc file, which consists of +// a main() function which calls RUN_ALL_TESTS() for us. +// +// This runs all the tests you've defined, prints the result, and +// returns 0 if successful, or 1 otherwise. +// +// Did you notice that we didn't register the tests? The +// RUN_ALL_TESTS() macro magically knows about all the tests we +// defined. Isn't this convenient? diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.cc new file mode 100644 index 000000000000..d8e8723965a1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.cc @@ -0,0 +1,54 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#include "sample2.h" + +#include + +// Clones a 0-terminated C string, allocating memory using new. +const char* MyString::CloneCString(const char* a_c_string) { + if (a_c_string == nullptr) return nullptr; + + const size_t len = strlen(a_c_string); + char* const clone = new char[ len + 1 ]; + memcpy(clone, a_c_string, len + 1); + + return clone; +} + +// Sets the 0-terminated C string this MyString object +// represents. +void MyString::Set(const char* a_c_string) { + // Makes sure this works when c_string == c_string_ + const char* const temp = MyString::CloneCString(a_c_string); + delete[] c_string_; + c_string_ = temp; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.h new file mode 100644 index 000000000000..0f9868959d21 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2.h @@ -0,0 +1,80 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#ifndef GOOGLETEST_SAMPLES_SAMPLE2_H_ +#define GOOGLETEST_SAMPLES_SAMPLE2_H_ + +#include + + +// A simple string class. +class MyString { + private: + const char* c_string_; + const MyString& operator=(const MyString& rhs); + + public: + // Clones a 0-terminated C string, allocating memory using new. + static const char* CloneCString(const char* a_c_string); + + //////////////////////////////////////////////////////////// + // + // C'tors + + // The default c'tor constructs a NULL string. + MyString() : c_string_(nullptr) {} + + // Constructs a MyString by cloning a 0-terminated C string. + explicit MyString(const char* a_c_string) : c_string_(nullptr) { + Set(a_c_string); + } + + // Copy c'tor + MyString(const MyString& string) : c_string_(nullptr) { + Set(string.c_string_); + } + + //////////////////////////////////////////////////////////// + // + // D'tor. MyString is intended to be a final class, so the d'tor + // doesn't need to be virtual. + ~MyString() { delete[] c_string_; } + + // Gets the 0-terminated C string this MyString object represents. + const char* c_string() const { return c_string_; } + + size_t Length() const { return c_string_ == nullptr ? 0 : strlen(c_string_); } + + // Sets the 0-terminated C string this MyString object represents. + void Set(const char* c_string); +}; + +#endif // GOOGLETEST_SAMPLES_SAMPLE2_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2_unittest.cc new file mode 100644 index 000000000000..41e31c1767de --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample2_unittest.cc @@ -0,0 +1,107 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +// This sample shows how to write a more complex unit test for a class +// that has multiple member functions. +// +// Usually, it's a good idea to have one test for each method in your +// class. You don't have to do that exactly, but it helps to keep +// your tests organized. You may also throw in additional tests as +// needed. + +#include "sample2.h" +#include "gtest/gtest.h" +namespace { +// In this example, we test the MyString class (a simple string). + +// Tests the default c'tor. +TEST(MyString, DefaultConstructor) { + const MyString s; + + // Asserts that s.c_string() returns NULL. + // + // + // + // If we write NULL instead of + // + // static_cast(NULL) + // + // in this assertion, it will generate a warning on gcc 3.4. The + // reason is that EXPECT_EQ needs to know the types of its + // arguments in order to print them when it fails. Since NULL is + // #defined as 0, the compiler will use the formatter function for + // int to print it. However, gcc thinks that NULL should be used as + // a pointer, not an int, and therefore complains. + // + // The root of the problem is C++'s lack of distinction between the + // integer number 0 and the null pointer constant. Unfortunately, + // we have to live with this fact. + // + // + EXPECT_STREQ(nullptr, s.c_string()); + + EXPECT_EQ(0u, s.Length()); +} + +const char kHelloString[] = "Hello, world!"; + +// Tests the c'tor that accepts a C string. +TEST(MyString, ConstructorFromCString) { + const MyString s(kHelloString); + EXPECT_EQ(0, strcmp(s.c_string(), kHelloString)); + EXPECT_EQ(sizeof(kHelloString)/sizeof(kHelloString[0]) - 1, + s.Length()); +} + +// Tests the copy c'tor. +TEST(MyString, CopyConstructor) { + const MyString s1(kHelloString); + const MyString s2 = s1; + EXPECT_EQ(0, strcmp(s2.c_string(), kHelloString)); +} + +// Tests the Set method. +TEST(MyString, Set) { + MyString s; + + s.Set(kHelloString); + EXPECT_EQ(0, strcmp(s.c_string(), kHelloString)); + + // Set should work when the input pointer is the same as the one + // already in the MyString object. + s.Set(s.c_string()); + EXPECT_EQ(0, strcmp(s.c_string(), kHelloString)); + + // Can we set the MyString to NULL? + s.Set(nullptr); + EXPECT_STREQ(nullptr, s.c_string()); +} +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3-inl.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3-inl.h new file mode 100644 index 000000000000..659e0f0bb5df --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3-inl.h @@ -0,0 +1,172 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#ifndef GOOGLETEST_SAMPLES_SAMPLE3_INL_H_ +#define GOOGLETEST_SAMPLES_SAMPLE3_INL_H_ + +#include + + +// Queue is a simple queue implemented as a singled-linked list. +// +// The element type must support copy constructor. +template // E is the element type +class Queue; + +// QueueNode is a node in a Queue, which consists of an element of +// type E and a pointer to the next node. +template // E is the element type +class QueueNode { + friend class Queue; + + public: + // Gets the element in this node. + const E& element() const { return element_; } + + // Gets the next node in the queue. + QueueNode* next() { return next_; } + const QueueNode* next() const { return next_; } + + private: + // Creates a node with a given element value. The next pointer is + // set to NULL. + explicit QueueNode(const E& an_element) + : element_(an_element), next_(nullptr) {} + + // We disable the default assignment operator and copy c'tor. + const QueueNode& operator = (const QueueNode&); + QueueNode(const QueueNode&); + + E element_; + QueueNode* next_; +}; + +template // E is the element type. +class Queue { + public: + // Creates an empty queue. + Queue() : head_(nullptr), last_(nullptr), size_(0) {} + + // D'tor. Clears the queue. + ~Queue() { Clear(); } + + // Clears the queue. + void Clear() { + if (size_ > 0) { + // 1. Deletes every node. + QueueNode* node = head_; + QueueNode* next = node->next(); + for (; ;) { + delete node; + node = next; + if (node == nullptr) break; + next = node->next(); + } + + // 2. Resets the member variables. + head_ = last_ = nullptr; + size_ = 0; + } + } + + // Gets the number of elements. + size_t Size() const { return size_; } + + // Gets the first element of the queue, or NULL if the queue is empty. + QueueNode* Head() { return head_; } + const QueueNode* Head() const { return head_; } + + // Gets the last element of the queue, or NULL if the queue is empty. + QueueNode* Last() { return last_; } + const QueueNode* Last() const { return last_; } + + // Adds an element to the end of the queue. A copy of the element is + // created using the copy constructor, and then stored in the queue. + // Changes made to the element in the queue doesn't affect the source + // object, and vice versa. + void Enqueue(const E& element) { + QueueNode* new_node = new QueueNode(element); + + if (size_ == 0) { + head_ = last_ = new_node; + size_ = 1; + } else { + last_->next_ = new_node; + last_ = new_node; + size_++; + } + } + + // Removes the head of the queue and returns it. Returns NULL if + // the queue is empty. + E* Dequeue() { + if (size_ == 0) { + return nullptr; + } + + const QueueNode* const old_head = head_; + head_ = head_->next_; + size_--; + if (size_ == 0) { + last_ = nullptr; + } + + E* element = new E(old_head->element()); + delete old_head; + + return element; + } + + // Applies a function/functor on each element of the queue, and + // returns the result in a new queue. The original queue is not + // affected. + template + Queue* Map(F function) const { + Queue* new_queue = new Queue(); + for (const QueueNode* node = head_; node != nullptr; + node = node->next_) { + new_queue->Enqueue(function(node->element())); + } + + return new_queue; + } + + private: + QueueNode* head_; // The first node of the queue. + QueueNode* last_; // The last node of the queue. + size_t size_; // The number of elements in the queue. + + // We disallow copying a queue. + Queue(const Queue&); + const Queue& operator = (const Queue&); +}; + +#endif // GOOGLETEST_SAMPLES_SAMPLE3_INL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3_unittest.cc new file mode 100644 index 000000000000..b19416d53c95 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample3_unittest.cc @@ -0,0 +1,149 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +// In this example, we use a more advanced feature of Google Test called +// test fixture. +// +// A test fixture is a place to hold objects and functions shared by +// all tests in a test case. Using a test fixture avoids duplicating +// the test code necessary to initialize and cleanup those common +// objects for each test. It is also useful for defining sub-routines +// that your tests need to invoke a lot. +// +// +// +// The tests share the test fixture in the sense of code sharing, not +// data sharing. Each test is given its own fresh copy of the +// fixture. You cannot expect the data modified by one test to be +// passed on to another test, which is a bad idea. +// +// The reason for this design is that tests should be independent and +// repeatable. In particular, a test should not fail as the result of +// another test's failure. If one test depends on info produced by +// another test, then the two tests should really be one big test. +// +// The macros for indicating the success/failure of a test +// (EXPECT_TRUE, FAIL, etc) need to know what the current test is +// (when Google Test prints the test result, it tells you which test +// each failure belongs to). Technically, these macros invoke a +// member function of the Test class. Therefore, you cannot use them +// in a global function. That's why you should put test sub-routines +// in a test fixture. +// +// + +#include "sample3-inl.h" +#include "gtest/gtest.h" +namespace { +// To use a test fixture, derive a class from testing::Test. +class QueueTestSmpl3 : public testing::Test { + protected: // You should make the members protected s.t. they can be + // accessed from sub-classes. + + // virtual void SetUp() will be called before each test is run. You + // should define it if you need to initialize the variables. + // Otherwise, this can be skipped. + void SetUp() override { + q1_.Enqueue(1); + q2_.Enqueue(2); + q2_.Enqueue(3); + } + + // virtual void TearDown() will be called after each test is run. + // You should define it if there is cleanup work to do. Otherwise, + // you don't have to provide it. + // + // virtual void TearDown() { + // } + + // A helper function that some test uses. + static int Double(int n) { + return 2*n; + } + + // A helper function for testing Queue::Map(). + void MapTester(const Queue * q) { + // Creates a new queue, where each element is twice as big as the + // corresponding one in q. + const Queue * const new_q = q->Map(Double); + + // Verifies that the new queue has the same size as q. + ASSERT_EQ(q->Size(), new_q->Size()); + + // Verifies the relationship between the elements of the two queues. + for (const QueueNode*n1 = q->Head(), *n2 = new_q->Head(); + n1 != nullptr; n1 = n1->next(), n2 = n2->next()) { + EXPECT_EQ(2 * n1->element(), n2->element()); + } + + delete new_q; + } + + // Declares the variables your tests want to use. + Queue q0_; + Queue q1_; + Queue q2_; +}; + +// When you have a test fixture, you define a test using TEST_F +// instead of TEST. + +// Tests the default c'tor. +TEST_F(QueueTestSmpl3, DefaultConstructor) { + // You can access data in the test fixture here. + EXPECT_EQ(0u, q0_.Size()); +} + +// Tests Dequeue(). +TEST_F(QueueTestSmpl3, Dequeue) { + int * n = q0_.Dequeue(); + EXPECT_TRUE(n == nullptr); + + n = q1_.Dequeue(); + ASSERT_TRUE(n != nullptr); + EXPECT_EQ(1, *n); + EXPECT_EQ(0u, q1_.Size()); + delete n; + + n = q2_.Dequeue(); + ASSERT_TRUE(n != nullptr); + EXPECT_EQ(2, *n); + EXPECT_EQ(1u, q2_.Size()); + delete n; +} + +// Tests the Queue::Map() function. +TEST_F(QueueTestSmpl3, Map) { + MapTester(&q0_); + MapTester(&q1_); + MapTester(&q2_); +} +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.cc new file mode 100644 index 000000000000..b0ee6093b4a2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.cc @@ -0,0 +1,54 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. + +#include + +#include "sample4.h" + +// Returns the current counter value, and increments it. +int Counter::Increment() { + return counter_++; +} + +// Returns the current counter value, and decrements it. +// counter can not be less than 0, return 0 in this case +int Counter::Decrement() { + if (counter_ == 0) { + return counter_; + } else { + return counter_--; + } +} + +// Prints the current counter value to STDOUT. +void Counter::Print() const { + printf("%d", counter_); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.h new file mode 100644 index 000000000000..0c4ed92e738b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4.h @@ -0,0 +1,53 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// A sample program demonstrating using Google C++ testing framework. +#ifndef GOOGLETEST_SAMPLES_SAMPLE4_H_ +#define GOOGLETEST_SAMPLES_SAMPLE4_H_ + +// A simple monotonic counter. +class Counter { + private: + int counter_; + + public: + // Creates a counter that starts at 0. + Counter() : counter_(0) {} + + // Returns the current counter value, and increments it. + int Increment(); + + // Returns the current counter value, and decrements it. + int Decrement(); + + // Prints the current counter value to STDOUT. + void Print() const; +}; + +#endif // GOOGLETEST_SAMPLES_SAMPLE4_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4_unittest.cc new file mode 100644 index 000000000000..d5144c0d00f5 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample4_unittest.cc @@ -0,0 +1,53 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "sample4.h" +#include "gtest/gtest.h" + +namespace { +// Tests the Increment() method. + +TEST(Counter, Increment) { + Counter c; + + // Test that counter 0 returns 0 + EXPECT_EQ(0, c.Decrement()); + + // EXPECT_EQ() evaluates its arguments exactly once, so they + // can have side effects. + + EXPECT_EQ(0, c.Increment()); + EXPECT_EQ(1, c.Increment()); + EXPECT_EQ(2, c.Increment()); + + EXPECT_EQ(3, c.Decrement()); +} + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample5_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample5_unittest.cc new file mode 100644 index 000000000000..0a21dd215770 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample5_unittest.cc @@ -0,0 +1,196 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample teaches how to reuse a test fixture in multiple test +// cases by deriving sub-fixtures from it. +// +// When you define a test fixture, you specify the name of the test +// case that will use this fixture. Therefore, a test fixture can +// be used by only one test case. +// +// Sometimes, more than one test cases may want to use the same or +// slightly different test fixtures. For example, you may want to +// make sure that all tests for a GUI library don't leak important +// system resources like fonts and brushes. In Google Test, you do +// this by putting the shared logic in a super (as in "super class") +// test fixture, and then have each test case use a fixture derived +// from this super fixture. + +#include +#include +#include "gtest/gtest.h" +#include "sample1.h" +#include "sample3-inl.h" +namespace { +// In this sample, we want to ensure that every test finishes within +// ~5 seconds. If a test takes longer to run, we consider it a +// failure. +// +// We put the code for timing a test in a test fixture called +// "QuickTest". QuickTest is intended to be the super fixture that +// other fixtures derive from, therefore there is no test case with +// the name "QuickTest". This is OK. +// +// Later, we will derive multiple test fixtures from QuickTest. +class QuickTest : public testing::Test { + protected: + // Remember that SetUp() is run immediately before a test starts. + // This is a good place to record the start time. + void SetUp() override { start_time_ = time(nullptr); } + + // TearDown() is invoked immediately after a test finishes. Here we + // check if the test was too slow. + void TearDown() override { + // Gets the time when the test finishes + const time_t end_time = time(nullptr); + + // Asserts that the test took no more than ~5 seconds. Did you + // know that you can use assertions in SetUp() and TearDown() as + // well? + EXPECT_TRUE(end_time - start_time_ <= 5) << "The test took too long."; + } + + // The UTC time (in seconds) when the test starts + time_t start_time_; +}; + + +// We derive a fixture named IntegerFunctionTest from the QuickTest +// fixture. All tests using this fixture will be automatically +// required to be quick. +class IntegerFunctionTest : public QuickTest { + // We don't need any more logic than already in the QuickTest fixture. + // Therefore the body is empty. +}; + + +// Now we can write tests in the IntegerFunctionTest test case. + +// Tests Factorial() +TEST_F(IntegerFunctionTest, Factorial) { + // Tests factorial of negative numbers. + EXPECT_EQ(1, Factorial(-5)); + EXPECT_EQ(1, Factorial(-1)); + EXPECT_GT(Factorial(-10), 0); + + // Tests factorial of 0. + EXPECT_EQ(1, Factorial(0)); + + // Tests factorial of positive numbers. + EXPECT_EQ(1, Factorial(1)); + EXPECT_EQ(2, Factorial(2)); + EXPECT_EQ(6, Factorial(3)); + EXPECT_EQ(40320, Factorial(8)); +} + + +// Tests IsPrime() +TEST_F(IntegerFunctionTest, IsPrime) { + // Tests negative input. + EXPECT_FALSE(IsPrime(-1)); + EXPECT_FALSE(IsPrime(-2)); + EXPECT_FALSE(IsPrime(INT_MIN)); + + // Tests some trivial cases. + EXPECT_FALSE(IsPrime(0)); + EXPECT_FALSE(IsPrime(1)); + EXPECT_TRUE(IsPrime(2)); + EXPECT_TRUE(IsPrime(3)); + + // Tests positive input. + EXPECT_FALSE(IsPrime(4)); + EXPECT_TRUE(IsPrime(5)); + EXPECT_FALSE(IsPrime(6)); + EXPECT_TRUE(IsPrime(23)); +} + + +// The next test case (named "QueueTest") also needs to be quick, so +// we derive another fixture from QuickTest. +// +// The QueueTest test fixture has some logic and shared objects in +// addition to what's in QuickTest already. We define the additional +// stuff inside the body of the test fixture, as usual. +class QueueTest : public QuickTest { + protected: + void SetUp() override { + // First, we need to set up the super fixture (QuickTest). + QuickTest::SetUp(); + + // Second, some additional setup for this fixture. + q1_.Enqueue(1); + q2_.Enqueue(2); + q2_.Enqueue(3); + } + + // By default, TearDown() inherits the behavior of + // QuickTest::TearDown(). As we have no additional cleaning work + // for QueueTest, we omit it here. + // + // virtual void TearDown() { + // QuickTest::TearDown(); + // } + + Queue q0_; + Queue q1_; + Queue q2_; +}; + + +// Now, let's write tests using the QueueTest fixture. + +// Tests the default constructor. +TEST_F(QueueTest, DefaultConstructor) { + EXPECT_EQ(0u, q0_.Size()); +} + +// Tests Dequeue(). +TEST_F(QueueTest, Dequeue) { + int* n = q0_.Dequeue(); + EXPECT_TRUE(n == nullptr); + + n = q1_.Dequeue(); + EXPECT_TRUE(n != nullptr); + EXPECT_EQ(1, *n); + EXPECT_EQ(0u, q1_.Size()); + delete n; + + n = q2_.Dequeue(); + EXPECT_TRUE(n != nullptr); + EXPECT_EQ(2, *n); + EXPECT_EQ(1u, q2_.Size()); + delete n; +} +} // namespace +// If necessary, you can derive further test fixtures from a derived +// fixture itself. For example, you can derive another fixture from +// QueueTest. Google Test imposes no limit on how deep the hierarchy +// can be. In practice, however, you probably don't want it to be too +// deep as to be confusing. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample6_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample6_unittest.cc new file mode 100644 index 000000000000..da317eed5a0d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample6_unittest.cc @@ -0,0 +1,217 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample shows how to test common properties of multiple +// implementations of the same interface (aka interface tests). + +// The interface and its implementations are in this header. +#include "prime_tables.h" + +#include "gtest/gtest.h" +namespace { +// First, we define some factory functions for creating instances of +// the implementations. You may be able to skip this step if all your +// implementations can be constructed the same way. + +template +PrimeTable* CreatePrimeTable(); + +template <> +PrimeTable* CreatePrimeTable() { + return new OnTheFlyPrimeTable; +} + +template <> +PrimeTable* CreatePrimeTable() { + return new PreCalculatedPrimeTable(10000); +} + +// Then we define a test fixture class template. +template +class PrimeTableTest : public testing::Test { + protected: + // The ctor calls the factory function to create a prime table + // implemented by T. + PrimeTableTest() : table_(CreatePrimeTable()) {} + + ~PrimeTableTest() override { delete table_; } + + // Note that we test an implementation via the base interface + // instead of the actual implementation class. This is important + // for keeping the tests close to the real world scenario, where the + // implementation is invoked via the base interface. It avoids + // got-yas where the implementation class has a method that shadows + // a method with the same name (but slightly different argument + // types) in the base interface, for example. + PrimeTable* const table_; +}; + +using testing::Types; + +// Google Test offers two ways for reusing tests for different types. +// The first is called "typed tests". You should use it if you +// already know *all* the types you are gonna exercise when you write +// the tests. + +// To write a typed test case, first use +// +// TYPED_TEST_SUITE(TestCaseName, TypeList); +// +// to declare it and specify the type parameters. As with TEST_F, +// TestCaseName must match the test fixture name. + +// The list of types we want to test. +typedef Types Implementations; + +TYPED_TEST_SUITE(PrimeTableTest, Implementations); + +// Then use TYPED_TEST(TestCaseName, TestName) to define a typed test, +// similar to TEST_F. +TYPED_TEST(PrimeTableTest, ReturnsFalseForNonPrimes) { + // Inside the test body, you can refer to the type parameter by + // TypeParam, and refer to the fixture class by TestFixture. We + // don't need them in this example. + + // Since we are in the template world, C++ requires explicitly + // writing 'this->' when referring to members of the fixture class. + // This is something you have to learn to live with. + EXPECT_FALSE(this->table_->IsPrime(-5)); + EXPECT_FALSE(this->table_->IsPrime(0)); + EXPECT_FALSE(this->table_->IsPrime(1)); + EXPECT_FALSE(this->table_->IsPrime(4)); + EXPECT_FALSE(this->table_->IsPrime(6)); + EXPECT_FALSE(this->table_->IsPrime(100)); +} + +TYPED_TEST(PrimeTableTest, ReturnsTrueForPrimes) { + EXPECT_TRUE(this->table_->IsPrime(2)); + EXPECT_TRUE(this->table_->IsPrime(3)); + EXPECT_TRUE(this->table_->IsPrime(5)); + EXPECT_TRUE(this->table_->IsPrime(7)); + EXPECT_TRUE(this->table_->IsPrime(11)); + EXPECT_TRUE(this->table_->IsPrime(131)); +} + +TYPED_TEST(PrimeTableTest, CanGetNextPrime) { + EXPECT_EQ(2, this->table_->GetNextPrime(0)); + EXPECT_EQ(3, this->table_->GetNextPrime(2)); + EXPECT_EQ(5, this->table_->GetNextPrime(3)); + EXPECT_EQ(7, this->table_->GetNextPrime(5)); + EXPECT_EQ(11, this->table_->GetNextPrime(7)); + EXPECT_EQ(131, this->table_->GetNextPrime(128)); +} + +// That's it! Google Test will repeat each TYPED_TEST for each type +// in the type list specified in TYPED_TEST_SUITE. Sit back and be +// happy that you don't have to define them multiple times. + +using testing::Types; + +// Sometimes, however, you don't yet know all the types that you want +// to test when you write the tests. For example, if you are the +// author of an interface and expect other people to implement it, you +// might want to write a set of tests to make sure each implementation +// conforms to some basic requirements, but you don't know what +// implementations will be written in the future. +// +// How can you write the tests without committing to the type +// parameters? That's what "type-parameterized tests" can do for you. +// It is a bit more involved than typed tests, but in return you get a +// test pattern that can be reused in many contexts, which is a big +// win. Here's how you do it: + +// First, define a test fixture class template. Here we just reuse +// the PrimeTableTest fixture defined earlier: + +template +class PrimeTableTest2 : public PrimeTableTest { +}; + +// Then, declare the test case. The argument is the name of the test +// fixture, and also the name of the test case (as usual). The _P +// suffix is for "parameterized" or "pattern". +TYPED_TEST_SUITE_P(PrimeTableTest2); + +// Next, use TYPED_TEST_P(TestCaseName, TestName) to define a test, +// similar to what you do with TEST_F. +TYPED_TEST_P(PrimeTableTest2, ReturnsFalseForNonPrimes) { + EXPECT_FALSE(this->table_->IsPrime(-5)); + EXPECT_FALSE(this->table_->IsPrime(0)); + EXPECT_FALSE(this->table_->IsPrime(1)); + EXPECT_FALSE(this->table_->IsPrime(4)); + EXPECT_FALSE(this->table_->IsPrime(6)); + EXPECT_FALSE(this->table_->IsPrime(100)); +} + +TYPED_TEST_P(PrimeTableTest2, ReturnsTrueForPrimes) { + EXPECT_TRUE(this->table_->IsPrime(2)); + EXPECT_TRUE(this->table_->IsPrime(3)); + EXPECT_TRUE(this->table_->IsPrime(5)); + EXPECT_TRUE(this->table_->IsPrime(7)); + EXPECT_TRUE(this->table_->IsPrime(11)); + EXPECT_TRUE(this->table_->IsPrime(131)); +} + +TYPED_TEST_P(PrimeTableTest2, CanGetNextPrime) { + EXPECT_EQ(2, this->table_->GetNextPrime(0)); + EXPECT_EQ(3, this->table_->GetNextPrime(2)); + EXPECT_EQ(5, this->table_->GetNextPrime(3)); + EXPECT_EQ(7, this->table_->GetNextPrime(5)); + EXPECT_EQ(11, this->table_->GetNextPrime(7)); + EXPECT_EQ(131, this->table_->GetNextPrime(128)); +} + +// Type-parameterized tests involve one extra step: you have to +// enumerate the tests you defined: +REGISTER_TYPED_TEST_SUITE_P( + PrimeTableTest2, // The first argument is the test case name. + // The rest of the arguments are the test names. + ReturnsFalseForNonPrimes, ReturnsTrueForPrimes, CanGetNextPrime); + +// At this point the test pattern is done. However, you don't have +// any real test yet as you haven't said which types you want to run +// the tests with. + +// To turn the abstract test pattern into real tests, you instantiate +// it with a list of types. Usually the test pattern will be defined +// in a .h file, and anyone can #include and instantiate it. You can +// even instantiate it more than once in the same program. To tell +// different instances apart, you give each of them a name, which will +// become part of the test case name and can be used in test filters. + +// The list of types we want to test. Note that it doesn't have to be +// defined at the time we write the TYPED_TEST_P()s. +typedef Types + PrimeTableImplementations; +INSTANTIATE_TYPED_TEST_SUITE_P(OnTheFlyAndPreCalculated, // Instance name + PrimeTableTest2, // Test case name + PrimeTableImplementations); // Type list + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample7_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample7_unittest.cc new file mode 100644 index 000000000000..e0efc29e4a2a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample7_unittest.cc @@ -0,0 +1,117 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample shows how to test common properties of multiple +// implementations of an interface (aka interface tests) using +// value-parameterized tests. Each test in the test case has +// a parameter that is an interface pointer to an implementation +// tested. + +// The interface and its implementations are in this header. +#include "prime_tables.h" + +#include "gtest/gtest.h" +namespace { + +using ::testing::TestWithParam; +using ::testing::Values; + +// As a general rule, to prevent a test from affecting the tests that come +// after it, you should create and destroy the tested objects for each test +// instead of reusing them. In this sample we will define a simple factory +// function for PrimeTable objects. We will instantiate objects in test's +// SetUp() method and delete them in TearDown() method. +typedef PrimeTable* CreatePrimeTableFunc(); + +PrimeTable* CreateOnTheFlyPrimeTable() { + return new OnTheFlyPrimeTable(); +} + +template +PrimeTable* CreatePreCalculatedPrimeTable() { + return new PreCalculatedPrimeTable(max_precalculated); +} + +// Inside the test body, fixture constructor, SetUp(), and TearDown() you +// can refer to the test parameter by GetParam(). In this case, the test +// parameter is a factory function which we call in fixture's SetUp() to +// create and store an instance of PrimeTable. +class PrimeTableTestSmpl7 : public TestWithParam { + public: + ~PrimeTableTestSmpl7() override { delete table_; } + void SetUp() override { table_ = (*GetParam())(); } + void TearDown() override { + delete table_; + table_ = nullptr; + } + + protected: + PrimeTable* table_; +}; + +TEST_P(PrimeTableTestSmpl7, ReturnsFalseForNonPrimes) { + EXPECT_FALSE(table_->IsPrime(-5)); + EXPECT_FALSE(table_->IsPrime(0)); + EXPECT_FALSE(table_->IsPrime(1)); + EXPECT_FALSE(table_->IsPrime(4)); + EXPECT_FALSE(table_->IsPrime(6)); + EXPECT_FALSE(table_->IsPrime(100)); +} + +TEST_P(PrimeTableTestSmpl7, ReturnsTrueForPrimes) { + EXPECT_TRUE(table_->IsPrime(2)); + EXPECT_TRUE(table_->IsPrime(3)); + EXPECT_TRUE(table_->IsPrime(5)); + EXPECT_TRUE(table_->IsPrime(7)); + EXPECT_TRUE(table_->IsPrime(11)); + EXPECT_TRUE(table_->IsPrime(131)); +} + +TEST_P(PrimeTableTestSmpl7, CanGetNextPrime) { + EXPECT_EQ(2, table_->GetNextPrime(0)); + EXPECT_EQ(3, table_->GetNextPrime(2)); + EXPECT_EQ(5, table_->GetNextPrime(3)); + EXPECT_EQ(7, table_->GetNextPrime(5)); + EXPECT_EQ(11, table_->GetNextPrime(7)); + EXPECT_EQ(131, table_->GetNextPrime(128)); +} + +// In order to run value-parameterized tests, you need to instantiate them, +// or bind them to a list of values which will be used as test parameters. +// You can instantiate them in a different translation module, or even +// instantiate them several times. +// +// Here, we instantiate our tests with a list of two PrimeTable object +// factory functions: +INSTANTIATE_TEST_SUITE_P(OnTheFlyAndPreCalculated, PrimeTableTestSmpl7, + Values(&CreateOnTheFlyPrimeTable, + &CreatePreCalculatedPrimeTable<1000>)); + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample8_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample8_unittest.cc new file mode 100644 index 000000000000..10488b0ea450 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample8_unittest.cc @@ -0,0 +1,154 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample shows how to test code relying on some global flag variables. +// Combine() helps with generating all possible combinations of such flags, +// and each test is given one combination as a parameter. + +// Use class definitions to test from this header. +#include "prime_tables.h" + +#include "gtest/gtest.h" +namespace { + +// Suppose we want to introduce a new, improved implementation of PrimeTable +// which combines speed of PrecalcPrimeTable and versatility of +// OnTheFlyPrimeTable (see prime_tables.h). Inside it instantiates both +// PrecalcPrimeTable and OnTheFlyPrimeTable and uses the one that is more +// appropriate under the circumstances. But in low memory conditions, it can be +// told to instantiate without PrecalcPrimeTable instance at all and use only +// OnTheFlyPrimeTable. +class HybridPrimeTable : public PrimeTable { + public: + HybridPrimeTable(bool force_on_the_fly, int max_precalculated) + : on_the_fly_impl_(new OnTheFlyPrimeTable), + precalc_impl_(force_on_the_fly + ? nullptr + : new PreCalculatedPrimeTable(max_precalculated)), + max_precalculated_(max_precalculated) {} + ~HybridPrimeTable() override { + delete on_the_fly_impl_; + delete precalc_impl_; + } + + bool IsPrime(int n) const override { + if (precalc_impl_ != nullptr && n < max_precalculated_) + return precalc_impl_->IsPrime(n); + else + return on_the_fly_impl_->IsPrime(n); + } + + int GetNextPrime(int p) const override { + int next_prime = -1; + if (precalc_impl_ != nullptr && p < max_precalculated_) + next_prime = precalc_impl_->GetNextPrime(p); + + return next_prime != -1 ? next_prime : on_the_fly_impl_->GetNextPrime(p); + } + + private: + OnTheFlyPrimeTable* on_the_fly_impl_; + PreCalculatedPrimeTable* precalc_impl_; + int max_precalculated_; +}; + +using ::testing::TestWithParam; +using ::testing::Bool; +using ::testing::Values; +using ::testing::Combine; + +// To test all code paths for HybridPrimeTable we must test it with numbers +// both within and outside PreCalculatedPrimeTable's capacity and also with +// PreCalculatedPrimeTable disabled. We do this by defining fixture which will +// accept different combinations of parameters for instantiating a +// HybridPrimeTable instance. +class PrimeTableTest : public TestWithParam< ::std::tuple > { + protected: + void SetUp() override { + bool force_on_the_fly; + int max_precalculated; + std::tie(force_on_the_fly, max_precalculated) = GetParam(); + table_ = new HybridPrimeTable(force_on_the_fly, max_precalculated); + } + void TearDown() override { + delete table_; + table_ = nullptr; + } + HybridPrimeTable* table_; +}; + +TEST_P(PrimeTableTest, ReturnsFalseForNonPrimes) { + // Inside the test body, you can refer to the test parameter by GetParam(). + // In this case, the test parameter is a PrimeTable interface pointer which + // we can use directly. + // Please note that you can also save it in the fixture's SetUp() method + // or constructor and use saved copy in the tests. + + EXPECT_FALSE(table_->IsPrime(-5)); + EXPECT_FALSE(table_->IsPrime(0)); + EXPECT_FALSE(table_->IsPrime(1)); + EXPECT_FALSE(table_->IsPrime(4)); + EXPECT_FALSE(table_->IsPrime(6)); + EXPECT_FALSE(table_->IsPrime(100)); +} + +TEST_P(PrimeTableTest, ReturnsTrueForPrimes) { + EXPECT_TRUE(table_->IsPrime(2)); + EXPECT_TRUE(table_->IsPrime(3)); + EXPECT_TRUE(table_->IsPrime(5)); + EXPECT_TRUE(table_->IsPrime(7)); + EXPECT_TRUE(table_->IsPrime(11)); + EXPECT_TRUE(table_->IsPrime(131)); +} + +TEST_P(PrimeTableTest, CanGetNextPrime) { + EXPECT_EQ(2, table_->GetNextPrime(0)); + EXPECT_EQ(3, table_->GetNextPrime(2)); + EXPECT_EQ(5, table_->GetNextPrime(3)); + EXPECT_EQ(7, table_->GetNextPrime(5)); + EXPECT_EQ(11, table_->GetNextPrime(7)); + EXPECT_EQ(131, table_->GetNextPrime(128)); +} + +// In order to run value-parameterized tests, you need to instantiate them, +// or bind them to a list of values which will be used as test parameters. +// You can instantiate them in a different translation module, or even +// instantiate them several times. +// +// Here, we instantiate our tests with a list of parameters. We must combine +// all variations of the boolean flag suppressing PrecalcPrimeTable and some +// meaningful values for tests. We choose a small value (1), and a value that +// will put some of the tested numbers beyond the capability of the +// PrecalcPrimeTable instance and some inside it (10). Combine will produce all +// possible combinations. +INSTANTIATE_TEST_SUITE_P(MeaningfulTestParameters, PrimeTableTest, + Combine(Bool(), Values(1, 10))); + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample9_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample9_unittest.cc new file mode 100644 index 000000000000..0245b531dcf5 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/samples/sample9_unittest.cc @@ -0,0 +1,156 @@ +// Copyright 2009 Google Inc. All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This sample shows how to use Google Test listener API to implement +// an alternative console output and how to use the UnitTest reflection API +// to enumerate test suites and tests and to inspect their results. + +#include + +#include "gtest/gtest.h" + +using ::testing::EmptyTestEventListener; +using ::testing::InitGoogleTest; +using ::testing::Test; +using ::testing::TestSuite; +using ::testing::TestEventListeners; +using ::testing::TestInfo; +using ::testing::TestPartResult; +using ::testing::UnitTest; +namespace { +// Provides alternative output mode which produces minimal amount of +// information about tests. +class TersePrinter : public EmptyTestEventListener { + private: + // Called before any test activity starts. + void OnTestProgramStart(const UnitTest& /* unit_test */) override {} + + // Called after all test activities have ended. + void OnTestProgramEnd(const UnitTest& unit_test) override { + fprintf(stdout, "TEST %s\n", unit_test.Passed() ? "PASSED" : "FAILED"); + fflush(stdout); + } + + // Called before a test starts. + void OnTestStart(const TestInfo& test_info) override { + fprintf(stdout, + "*** Test %s.%s starting.\n", + test_info.test_suite_name(), + test_info.name()); + fflush(stdout); + } + + // Called after a failed assertion or a SUCCEED() invocation. + void OnTestPartResult(const TestPartResult& test_part_result) override { + fprintf(stdout, + "%s in %s:%d\n%s\n", + test_part_result.failed() ? "*** Failure" : "Success", + test_part_result.file_name(), + test_part_result.line_number(), + test_part_result.summary()); + fflush(stdout); + } + + // Called after a test ends. + void OnTestEnd(const TestInfo& test_info) override { + fprintf(stdout, + "*** Test %s.%s ending.\n", + test_info.test_suite_name(), + test_info.name()); + fflush(stdout); + } +}; // class TersePrinter + +TEST(CustomOutputTest, PrintsMessage) { + printf("Printing something from the test body...\n"); +} + +TEST(CustomOutputTest, Succeeds) { + SUCCEED() << "SUCCEED() has been invoked from here"; +} + +TEST(CustomOutputTest, Fails) { + EXPECT_EQ(1, 2) + << "This test fails in order to demonstrate alternative failure messages"; +} +} // namespace + +int main(int argc, char **argv) { + InitGoogleTest(&argc, argv); + + bool terse_output = false; + if (argc > 1 && strcmp(argv[1], "--terse_output") == 0 ) + terse_output = true; + else + printf("%s\n", "Run this program with --terse_output to change the way " + "it prints its output."); + + UnitTest& unit_test = *UnitTest::GetInstance(); + + // If we are given the --terse_output command line flag, suppresses the + // standard output and attaches own result printer. + if (terse_output) { + TestEventListeners& listeners = unit_test.listeners(); + + // Removes the default console output listener from the list so it will + // not receive events from Google Test and won't print any output. Since + // this operation transfers ownership of the listener to the caller we + // have to delete it as well. + delete listeners.Release(listeners.default_result_printer()); + + // Adds the custom output listener to the list. It will now receive + // events from Google Test and print the alternative output. We don't + // have to worry about deleting it since Google Test assumes ownership + // over it after adding it to the list. + listeners.Append(new TersePrinter); + } + int ret_val = RUN_ALL_TESTS(); + + // This is an example of using the UnitTest reflection API to inspect test + // results. Here we discount failures from the tests we expected to fail. + int unexpectedly_failed_tests = 0; + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + const testing::TestSuite& test_suite = *unit_test.GetTestSuite(i); + for (int j = 0; j < test_suite.total_test_count(); ++j) { + const TestInfo& test_info = *test_suite.GetTestInfo(j); + // Counts failed tests that were not meant to fail (those without + // 'Fails' in the name). + if (test_info.result()->Failed() && + strcmp(test_info.name(), "Fails") != 0) { + unexpectedly_failed_tests++; + } + } + } + + // Test that were meant to fail should not affect the test program outcome. + if (unexpectedly_failed_tests == 0) + ret_val = 0; + + return ret_val; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-all.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-all.cc new file mode 100644 index 000000000000..29eba165e4d4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-all.cc @@ -0,0 +1,49 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Google C++ Testing and Mocking Framework (Google Test) +// +// Sometimes it's desirable to build Google Test by compiling a single file. +// This file serves this purpose. + +// This line ensures that gtest.h can be compiled on its own, even +// when it's fused. +#include "gtest/gtest.h" + +// The following lines pull in the real gtest *.cc files. +#include "src/gtest.cc" +#include "src/gtest-assertion-result.cc" +#include "src/gtest-death-test.cc" +#include "src/gtest-filepath.cc" +#include "src/gtest-matchers.cc" +#include "src/gtest-port.cc" +#include "src/gtest-printers.cc" +#include "src/gtest-test-part.cc" +#include "src/gtest-typed-test.cc" diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-assertion-result.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-assertion-result.cc new file mode 100644 index 000000000000..9f90e8729849 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-assertion-result.cc @@ -0,0 +1,81 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file defines the AssertionResult type. + +#include "gtest/gtest-assertion-result.h" + +#include +#include + +#include "gtest/gtest-message.h" + +namespace testing { + +// AssertionResult constructors. +// Used in EXPECT_TRUE/FALSE(assertion_result). +AssertionResult::AssertionResult(const AssertionResult& other) + : success_(other.success_), + message_(other.message_.get() != nullptr + ? new ::std::string(*other.message_) + : static_cast< ::std::string*>(nullptr)) {} + +// Swaps two AssertionResults. +void AssertionResult::swap(AssertionResult& other) { + using std::swap; + swap(success_, other.success_); + swap(message_, other.message_); +} + +// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. +AssertionResult AssertionResult::operator!() const { + AssertionResult negation(!success_); + if (message_.get() != nullptr) negation << *message_; + return negation; +} + +// Makes a successful assertion result. +AssertionResult AssertionSuccess() { + return AssertionResult(true); +} + +// Makes a failed assertion result. +AssertionResult AssertionFailure() { + return AssertionResult(false); +} + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << message. +AssertionResult AssertionFailure(const Message& message) { + return AssertionFailure() << message; +} + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-death-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-death-test.cc new file mode 100644 index 000000000000..87d5e9b924cc --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-death-test.cc @@ -0,0 +1,1647 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// This file implements death tests. + +#include "gtest/gtest-death-test.h" + +#include +#include + +#include "gtest/internal/gtest-port.h" +#include "gtest/internal/custom/gtest.h" + +#if GTEST_HAS_DEATH_TEST + +# if GTEST_OS_MAC +# include +# endif // GTEST_OS_MAC + +# include +# include +# include + +# if GTEST_OS_LINUX +# include +# endif // GTEST_OS_LINUX + +# include + +# if GTEST_OS_WINDOWS +# include +# else +# include +# include +# endif // GTEST_OS_WINDOWS + +# if GTEST_OS_QNX +# include +# endif // GTEST_OS_QNX + +# if GTEST_OS_FUCHSIA +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# endif // GTEST_OS_FUCHSIA + +#endif // GTEST_HAS_DEATH_TEST + +#include "gtest/gtest-message.h" +#include "gtest/internal/gtest-string.h" +#include "src/gtest-internal-inl.h" + +namespace testing { + +// Constants. + +// The default death test style. +// +// This is defined in internal/gtest-port.h as "fast", but can be overridden by +// a definition in internal/custom/gtest-port.h. The recommended value, which is +// used internally at Google, is "threadsafe". +static const char kDefaultDeathTestStyle[] = GTEST_DEFAULT_DEATH_TEST_STYLE; + +} // namespace testing + +GTEST_DEFINE_string_( + death_test_style, + testing::internal::StringFromGTestEnv("death_test_style", + testing::kDefaultDeathTestStyle), + "Indicates how to run a death test in a forked child process: " + "\"threadsafe\" (child process re-executes the test binary " + "from the beginning, running only the specific death test) or " + "\"fast\" (child process runs the death test immediately " + "after forking)."); + +GTEST_DEFINE_bool_( + death_test_use_fork, + testing::internal::BoolFromGTestEnv("death_test_use_fork", false), + "Instructs to use fork()/_exit() instead of clone() in death tests. " + "Ignored and always uses fork() on POSIX systems where clone() is not " + "implemented. Useful when running under valgrind or similar tools if " + "those do not support clone(). Valgrind 3.3.1 will just fail if " + "it sees an unsupported combination of clone() flags. " + "It is not recommended to use this flag w/o valgrind though it will " + "work in 99% of the cases. Once valgrind is fixed, this flag will " + "most likely be removed."); + +GTEST_DEFINE_string_( + internal_run_death_test, "", + "Indicates the file, line number, temporal index of " + "the single death test to run, and a file descriptor to " + "which a success code may be sent, all separated by " + "the '|' characters. This flag is specified if and only if the " + "current process is a sub-process launched for running a thread-safe " + "death test. FOR INTERNAL USE ONLY."); + +namespace testing { + +#if GTEST_HAS_DEATH_TEST + +namespace internal { + +// Valid only for fast death tests. Indicates the code is running in the +// child process of a fast style death test. +# if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA +static bool g_in_fast_death_test_child = false; +# endif + +// Returns a Boolean value indicating whether the caller is currently +// executing in the context of the death test child process. Tools such as +// Valgrind heap checkers may need this to modify their behavior in death +// tests. IMPORTANT: This is an internal utility. Using it may break the +// implementation of death tests. User code MUST NOT use it. +bool InDeathTestChild() { +# if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + + // On Windows and Fuchsia, death tests are thread-safe regardless of the value + // of the death_test_style flag. + return !GTEST_FLAG_GET(internal_run_death_test).empty(); + +# else + + if (GTEST_FLAG_GET(death_test_style) == "threadsafe") + return !GTEST_FLAG_GET(internal_run_death_test).empty(); + else + return g_in_fast_death_test_child; +#endif +} + +} // namespace internal + +// ExitedWithCode constructor. +ExitedWithCode::ExitedWithCode(int exit_code) : exit_code_(exit_code) { +} + +// ExitedWithCode function-call operator. +bool ExitedWithCode::operator()(int exit_status) const { +# if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + + return exit_status == exit_code_; + +# else + + return WIFEXITED(exit_status) && WEXITSTATUS(exit_status) == exit_code_; + +# endif // GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA +} + +# if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA +// KilledBySignal constructor. +KilledBySignal::KilledBySignal(int signum) : signum_(signum) { +} + +// KilledBySignal function-call operator. +bool KilledBySignal::operator()(int exit_status) const { +# if defined(GTEST_KILLED_BY_SIGNAL_OVERRIDE_) + { + bool result; + if (GTEST_KILLED_BY_SIGNAL_OVERRIDE_(signum_, exit_status, &result)) { + return result; + } + } +# endif // defined(GTEST_KILLED_BY_SIGNAL_OVERRIDE_) + return WIFSIGNALED(exit_status) && WTERMSIG(exit_status) == signum_; +} +# endif // !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA + +namespace internal { + +// Utilities needed for death tests. + +// Generates a textual description of a given exit code, in the format +// specified by wait(2). +static std::string ExitSummary(int exit_code) { + Message m; + +# if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + + m << "Exited with exit status " << exit_code; + +# else + + if (WIFEXITED(exit_code)) { + m << "Exited with exit status " << WEXITSTATUS(exit_code); + } else if (WIFSIGNALED(exit_code)) { + m << "Terminated by signal " << WTERMSIG(exit_code); + } +# ifdef WCOREDUMP + if (WCOREDUMP(exit_code)) { + m << " (core dumped)"; + } +# endif +# endif // GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + + return m.GetString(); +} + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +bool ExitedUnsuccessfully(int exit_status) { + return !ExitedWithCode(0)(exit_status); +} + +# if !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA +// Generates a textual failure message when a death test finds more than +// one thread running, or cannot determine the number of threads, prior +// to executing the given statement. It is the responsibility of the +// caller not to pass a thread_count of 1. +static std::string DeathTestThreadWarning(size_t thread_count) { + Message msg; + msg << "Death tests use fork(), which is unsafe particularly" + << " in a threaded context. For this test, " << GTEST_NAME_ << " "; + if (thread_count == 0) { + msg << "couldn't detect the number of threads."; + } else { + msg << "detected " << thread_count << " threads."; + } + msg << " See " + "https://github.com/google/googletest/blob/master/docs/" + "advanced.md#death-tests-and-threads" + << " for more explanation and suggested solutions, especially if" + << " this is the last message you see before your test times out."; + return msg.GetString(); +} +# endif // !GTEST_OS_WINDOWS && !GTEST_OS_FUCHSIA + +// Flag characters for reporting a death test that did not die. +static const char kDeathTestLived = 'L'; +static const char kDeathTestReturned = 'R'; +static const char kDeathTestThrew = 'T'; +static const char kDeathTestInternalError = 'I'; + +#if GTEST_OS_FUCHSIA + +// File descriptor used for the pipe in the child process. +static const int kFuchsiaReadPipeFd = 3; + +#endif + +// An enumeration describing all of the possible ways that a death test can +// conclude. DIED means that the process died while executing the test +// code; LIVED means that process lived beyond the end of the test code; +// RETURNED means that the test statement attempted to execute a return +// statement, which is not allowed; THREW means that the test statement +// returned control by throwing an exception. IN_PROGRESS means the test +// has not yet concluded. +enum DeathTestOutcome { IN_PROGRESS, DIED, LIVED, RETURNED, THREW }; + +// Routine for aborting the program which is safe to call from an +// exec-style death test child process, in which case the error +// message is propagated back to the parent process. Otherwise, the +// message is simply printed to stderr. In either case, the program +// then exits with status 1. +static void DeathTestAbort(const std::string& message) { + // On a POSIX system, this function may be called from a threadsafe-style + // death test child process, which operates on a very small stack. Use + // the heap for any additional non-minuscule memory requirements. + const InternalRunDeathTestFlag* const flag = + GetUnitTestImpl()->internal_run_death_test_flag(); + if (flag != nullptr) { + FILE* parent = posix::FDOpen(flag->write_fd(), "w"); + fputc(kDeathTestInternalError, parent); + fprintf(parent, "%s", message.c_str()); + fflush(parent); + _exit(1); + } else { + fprintf(stderr, "%s", message.c_str()); + fflush(stderr); + posix::Abort(); + } +} + +// A replacement for CHECK that calls DeathTestAbort if the assertion +// fails. +# define GTEST_DEATH_TEST_CHECK_(expression) \ + do { \ + if (!::testing::internal::IsTrue(expression)) { \ + DeathTestAbort( \ + ::std::string("CHECK failed: File ") + __FILE__ + ", line " \ + + ::testing::internal::StreamableToString(__LINE__) + ": " \ + + #expression); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// This macro is similar to GTEST_DEATH_TEST_CHECK_, but it is meant for +// evaluating any system call that fulfills two conditions: it must return +// -1 on failure, and set errno to EINTR when it is interrupted and +// should be tried again. The macro expands to a loop that repeatedly +// evaluates the expression as long as it evaluates to -1 and sets +// errno to EINTR. If the expression evaluates to -1 but errno is +// something other than EINTR, DeathTestAbort is called. +# define GTEST_DEATH_TEST_CHECK_SYSCALL_(expression) \ + do { \ + int gtest_retval; \ + do { \ + gtest_retval = (expression); \ + } while (gtest_retval == -1 && errno == EINTR); \ + if (gtest_retval == -1) { \ + DeathTestAbort( \ + ::std::string("CHECK failed: File ") + __FILE__ + ", line " \ + + ::testing::internal::StreamableToString(__LINE__) + ": " \ + + #expression + " != -1"); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// Returns the message describing the last system error in errno. +std::string GetLastErrnoDescription() { + return errno == 0 ? "" : posix::StrError(errno); +} + +// This is called from a death test parent process to read a failure +// message from the death test child process and log it with the FATAL +// severity. On Windows, the message is read from a pipe handle. On other +// platforms, it is read from a file descriptor. +static void FailFromInternalError(int fd) { + Message error; + char buffer[256]; + int num_read; + + do { + while ((num_read = posix::Read(fd, buffer, 255)) > 0) { + buffer[num_read] = '\0'; + error << buffer; + } + } while (num_read == -1 && errno == EINTR); + + if (num_read == 0) { + GTEST_LOG_(FATAL) << error.GetString(); + } else { + const int last_error = errno; + GTEST_LOG_(FATAL) << "Error while reading death test internal: " + << GetLastErrnoDescription() << " [" << last_error << "]"; + } +} + +// Death test constructor. Increments the running death test count +// for the current test. +DeathTest::DeathTest() { + TestInfo* const info = GetUnitTestImpl()->current_test_info(); + if (info == nullptr) { + DeathTestAbort("Cannot run a death test outside of a TEST or " + "TEST_F construct"); + } +} + +// Creates and returns a death test by dispatching to the current +// death test factory. +bool DeathTest::Create(const char* statement, + Matcher matcher, const char* file, + int line, DeathTest** test) { + return GetUnitTestImpl()->death_test_factory()->Create( + statement, std::move(matcher), file, line, test); +} + +const char* DeathTest::LastMessage() { + return last_death_test_message_.c_str(); +} + +void DeathTest::set_last_death_test_message(const std::string& message) { + last_death_test_message_ = message; +} + +std::string DeathTest::last_death_test_message_; + +// Provides cross platform implementation for some death functionality. +class DeathTestImpl : public DeathTest { + protected: + DeathTestImpl(const char* a_statement, Matcher matcher) + : statement_(a_statement), + matcher_(std::move(matcher)), + spawned_(false), + status_(-1), + outcome_(IN_PROGRESS), + read_fd_(-1), + write_fd_(-1) {} + + // read_fd_ is expected to be closed and cleared by a derived class. + ~DeathTestImpl() override { GTEST_DEATH_TEST_CHECK_(read_fd_ == -1); } + + void Abort(AbortReason reason) override; + bool Passed(bool status_ok) override; + + const char* statement() const { return statement_; } + bool spawned() const { return spawned_; } + void set_spawned(bool is_spawned) { spawned_ = is_spawned; } + int status() const { return status_; } + void set_status(int a_status) { status_ = a_status; } + DeathTestOutcome outcome() const { return outcome_; } + void set_outcome(DeathTestOutcome an_outcome) { outcome_ = an_outcome; } + int read_fd() const { return read_fd_; } + void set_read_fd(int fd) { read_fd_ = fd; } + int write_fd() const { return write_fd_; } + void set_write_fd(int fd) { write_fd_ = fd; } + + // Called in the parent process only. Reads the result code of the death + // test child process via a pipe, interprets it to set the outcome_ + // member, and closes read_fd_. Outputs diagnostics and terminates in + // case of unexpected codes. + void ReadAndInterpretStatusByte(); + + // Returns stderr output from the child process. + virtual std::string GetErrorLogs(); + + private: + // The textual content of the code this object is testing. This class + // doesn't own this string and should not attempt to delete it. + const char* const statement_; + // A matcher that's expected to match the stderr output by the child process. + Matcher matcher_; + // True if the death test child process has been successfully spawned. + bool spawned_; + // The exit status of the child process. + int status_; + // How the death test concluded. + DeathTestOutcome outcome_; + // Descriptor to the read end of the pipe to the child process. It is + // always -1 in the child process. The child keeps its write end of the + // pipe in write_fd_. + int read_fd_; + // Descriptor to the child's write end of the pipe to the parent process. + // It is always -1 in the parent process. The parent keeps its end of the + // pipe in read_fd_. + int write_fd_; +}; + +// Called in the parent process only. Reads the result code of the death +// test child process via a pipe, interprets it to set the outcome_ +// member, and closes read_fd_. Outputs diagnostics and terminates in +// case of unexpected codes. +void DeathTestImpl::ReadAndInterpretStatusByte() { + char flag; + int bytes_read; + + // The read() here blocks until data is available (signifying the + // failure of the death test) or until the pipe is closed (signifying + // its success), so it's okay to call this in the parent before + // the child process has exited. + do { + bytes_read = posix::Read(read_fd(), &flag, 1); + } while (bytes_read == -1 && errno == EINTR); + + if (bytes_read == 0) { + set_outcome(DIED); + } else if (bytes_read == 1) { + switch (flag) { + case kDeathTestReturned: + set_outcome(RETURNED); + break; + case kDeathTestThrew: + set_outcome(THREW); + break; + case kDeathTestLived: + set_outcome(LIVED); + break; + case kDeathTestInternalError: + FailFromInternalError(read_fd()); // Does not return. + break; + default: + GTEST_LOG_(FATAL) << "Death test child process reported " + << "unexpected status byte (" + << static_cast(flag) << ")"; + } + } else { + GTEST_LOG_(FATAL) << "Read from death test child process failed: " + << GetLastErrnoDescription(); + } + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Close(read_fd())); + set_read_fd(-1); +} + +std::string DeathTestImpl::GetErrorLogs() { + return GetCapturedStderr(); +} + +// Signals that the death test code which should have exited, didn't. +// Should be called only in a death test child process. +// Writes a status byte to the child's status file descriptor, then +// calls _exit(1). +void DeathTestImpl::Abort(AbortReason reason) { + // The parent process considers the death test to be a failure if + // it finds any data in our pipe. So, here we write a single flag byte + // to the pipe, then exit. + const char status_ch = + reason == TEST_DID_NOT_DIE ? kDeathTestLived : + reason == TEST_THREW_EXCEPTION ? kDeathTestThrew : kDeathTestReturned; + + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Write(write_fd(), &status_ch, 1)); + // We are leaking the descriptor here because on some platforms (i.e., + // when built as Windows DLL), destructors of global objects will still + // run after calling _exit(). On such systems, write_fd_ will be + // indirectly closed from the destructor of UnitTestImpl, causing double + // close if it is also closed here. On debug configurations, double close + // may assert. As there are no in-process buffers to flush here, we are + // relying on the OS to close the descriptor after the process terminates + // when the destructors are not run. + _exit(1); // Exits w/o any normal exit hooks (we were supposed to crash) +} + +// Returns an indented copy of stderr output for a death test. +// This makes distinguishing death test output lines from regular log lines +// much easier. +static ::std::string FormatDeathTestOutput(const ::std::string& output) { + ::std::string ret; + for (size_t at = 0; ; ) { + const size_t line_end = output.find('\n', at); + ret += "[ DEATH ] "; + if (line_end == ::std::string::npos) { + ret += output.substr(at); + break; + } + ret += output.substr(at, line_end + 1 - at); + at = line_end + 1; + } + return ret; +} + +// Assesses the success or failure of a death test, using both private +// members which have previously been set, and one argument: +// +// Private data members: +// outcome: An enumeration describing how the death test +// concluded: DIED, LIVED, THREW, or RETURNED. The death test +// fails in the latter three cases. +// status: The exit status of the child process. On *nix, it is in the +// in the format specified by wait(2). On Windows, this is the +// value supplied to the ExitProcess() API or a numeric code +// of the exception that terminated the program. +// matcher_: A matcher that's expected to match the stderr output by the child +// process. +// +// Argument: +// status_ok: true if exit_status is acceptable in the context of +// this particular death test, which fails if it is false +// +// Returns true if and only if all of the above conditions are met. Otherwise, +// the first failing condition, in the order given above, is the one that is +// reported. Also sets the last death test message string. +bool DeathTestImpl::Passed(bool status_ok) { + if (!spawned()) + return false; + + const std::string error_message = GetErrorLogs(); + + bool success = false; + Message buffer; + + buffer << "Death test: " << statement() << "\n"; + switch (outcome()) { + case LIVED: + buffer << " Result: failed to die.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case THREW: + buffer << " Result: threw an exception.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case RETURNED: + buffer << " Result: illegal return in test statement.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case DIED: + if (status_ok) { + if (matcher_.Matches(error_message)) { + success = true; + } else { + std::ostringstream stream; + matcher_.DescribeTo(&stream); + buffer << " Result: died but not with expected error.\n" + << " Expected: " << stream.str() << "\n" + << "Actual msg:\n" + << FormatDeathTestOutput(error_message); + } + } else { + buffer << " Result: died but not with expected exit code:\n" + << " " << ExitSummary(status()) << "\n" + << "Actual msg:\n" << FormatDeathTestOutput(error_message); + } + break; + case IN_PROGRESS: + default: + GTEST_LOG_(FATAL) + << "DeathTest::Passed somehow called before conclusion of test"; + } + + DeathTest::set_last_death_test_message(buffer.GetString()); + return success; +} + +# if GTEST_OS_WINDOWS +// WindowsDeathTest implements death tests on Windows. Due to the +// specifics of starting new processes on Windows, death tests there are +// always threadsafe, and Google Test considers the +// --gtest_death_test_style=fast setting to be equivalent to +// --gtest_death_test_style=threadsafe there. +// +// A few implementation notes: Like the Linux version, the Windows +// implementation uses pipes for child-to-parent communication. But due to +// the specifics of pipes on Windows, some extra steps are required: +// +// 1. The parent creates a communication pipe and stores handles to both +// ends of it. +// 2. The parent starts the child and provides it with the information +// necessary to acquire the handle to the write end of the pipe. +// 3. The child acquires the write end of the pipe and signals the parent +// using a Windows event. +// 4. Now the parent can release the write end of the pipe on its side. If +// this is done before step 3, the object's reference count goes down to +// 0 and it is destroyed, preventing the child from acquiring it. The +// parent now has to release it, or read operations on the read end of +// the pipe will not return when the child terminates. +// 5. The parent reads child's output through the pipe (outcome code and +// any possible error messages) from the pipe, and its stderr and then +// determines whether to fail the test. +// +// Note: to distinguish Win32 API calls from the local method and function +// calls, the former are explicitly resolved in the global namespace. +// +class WindowsDeathTest : public DeathTestImpl { + public: + WindowsDeathTest(const char* a_statement, Matcher matcher, + const char* file, int line) + : DeathTestImpl(a_statement, std::move(matcher)), + file_(file), + line_(line) {} + + // All of these virtual functions are inherited from DeathTest. + virtual int Wait(); + virtual TestRole AssumeRole(); + + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; + // Handle to the write end of the pipe to the child process. + AutoHandle write_handle_; + // Child process handle. + AutoHandle child_handle_; + // Event the child process uses to signal the parent that it has + // acquired the handle to the write end of the pipe. After seeing this + // event the parent can release its own handles to make sure its + // ReadFile() calls return when the child terminates. + AutoHandle event_handle_; +}; + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int WindowsDeathTest::Wait() { + if (!spawned()) + return 0; + + // Wait until the child either signals that it has acquired the write end + // of the pipe or it dies. + const HANDLE wait_handles[2] = { child_handle_.Get(), event_handle_.Get() }; + switch (::WaitForMultipleObjects(2, + wait_handles, + FALSE, // Waits for any of the handles. + INFINITE)) { + case WAIT_OBJECT_0: + case WAIT_OBJECT_0 + 1: + break; + default: + GTEST_DEATH_TEST_CHECK_(false); // Should not get here. + } + + // The child has acquired the write end of the pipe or exited. + // We release the handle on our side and continue. + write_handle_.Reset(); + event_handle_.Reset(); + + ReadAndInterpretStatusByte(); + + // Waits for the child process to exit if it haven't already. This + // returns immediately if the child has already exited, regardless of + // whether previous calls to WaitForMultipleObjects synchronized on this + // handle or not. + GTEST_DEATH_TEST_CHECK_( + WAIT_OBJECT_0 == ::WaitForSingleObject(child_handle_.Get(), + INFINITE)); + DWORD status_code; + GTEST_DEATH_TEST_CHECK_( + ::GetExitCodeProcess(child_handle_.Get(), &status_code) != FALSE); + child_handle_.Reset(); + set_status(static_cast(status_code)); + return status(); +} + +// The AssumeRole process for a Windows death test. It creates a child +// process with the same executable as the current process to run the +// death test. The child process is given the --gtest_filter and +// --gtest_internal_run_death_test flags such that it knows to run the +// current death test only. +DeathTest::TestRole WindowsDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != nullptr) { + // ParseInternalRunDeathTestFlag() has performed all the necessary + // processing. + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + // WindowsDeathTest uses an anonymous pipe to communicate results of + // a death test. + SECURITY_ATTRIBUTES handles_are_inheritable = {sizeof(SECURITY_ATTRIBUTES), + nullptr, TRUE}; + HANDLE read_handle, write_handle; + GTEST_DEATH_TEST_CHECK_( + ::CreatePipe(&read_handle, &write_handle, &handles_are_inheritable, + 0) // Default buffer size. + != FALSE); + set_read_fd(::_open_osfhandle(reinterpret_cast(read_handle), + O_RDONLY)); + write_handle_.Reset(write_handle); + event_handle_.Reset(::CreateEvent( + &handles_are_inheritable, + TRUE, // The event will automatically reset to non-signaled state. + FALSE, // The initial state is non-signalled. + nullptr)); // The even is unnamed. + GTEST_DEATH_TEST_CHECK_(event_handle_.Get() != nullptr); + const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ + + "filter=" + info->test_suite_name() + "." + + info->name(); + const std::string internal_flag = + std::string("--") + GTEST_FLAG_PREFIX_ + + "internal_run_death_test=" + file_ + "|" + StreamableToString(line_) + + "|" + StreamableToString(death_test_index) + "|" + + StreamableToString(static_cast(::GetCurrentProcessId())) + + // size_t has the same width as pointers on both 32-bit and 64-bit + // Windows platforms. + // See http://msdn.microsoft.com/en-us/library/tcxf1dw6.aspx. + "|" + StreamableToString(reinterpret_cast(write_handle)) + "|" + + StreamableToString(reinterpret_cast(event_handle_.Get())); + + char executable_path[_MAX_PATH + 1]; // NOLINT + GTEST_DEATH_TEST_CHECK_(_MAX_PATH + 1 != ::GetModuleFileNameA(nullptr, + executable_path, + _MAX_PATH)); + + std::string command_line = + std::string(::GetCommandLineA()) + " " + filter_flag + " \"" + + internal_flag + "\""; + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // Flush the log buffers since the log streams are shared with the child. + FlushInfoLog(); + + // The child process will share the standard handles with the parent. + STARTUPINFOA startup_info; + memset(&startup_info, 0, sizeof(STARTUPINFO)); + startup_info.dwFlags = STARTF_USESTDHANDLES; + startup_info.hStdInput = ::GetStdHandle(STD_INPUT_HANDLE); + startup_info.hStdOutput = ::GetStdHandle(STD_OUTPUT_HANDLE); + startup_info.hStdError = ::GetStdHandle(STD_ERROR_HANDLE); + + PROCESS_INFORMATION process_info; + GTEST_DEATH_TEST_CHECK_( + ::CreateProcessA( + executable_path, const_cast(command_line.c_str()), + nullptr, // Returned process handle is not inheritable. + nullptr, // Returned thread handle is not inheritable. + TRUE, // Child inherits all inheritable handles (for write_handle_). + 0x0, // Default creation flags. + nullptr, // Inherit the parent's environment. + UnitTest::GetInstance()->original_working_dir(), &startup_info, + &process_info) != FALSE); + child_handle_.Reset(process_info.hProcess); + ::CloseHandle(process_info.hThread); + set_spawned(true); + return OVERSEE_TEST; +} + +# elif GTEST_OS_FUCHSIA + +class FuchsiaDeathTest : public DeathTestImpl { + public: + FuchsiaDeathTest(const char* a_statement, Matcher matcher, + const char* file, int line) + : DeathTestImpl(a_statement, std::move(matcher)), + file_(file), + line_(line) {} + + // All of these virtual functions are inherited from DeathTest. + int Wait() override; + TestRole AssumeRole() override; + std::string GetErrorLogs() override; + + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; + // The stderr data captured by the child process. + std::string captured_stderr_; + + zx::process child_process_; + zx::channel exception_channel_; + zx::socket stderr_socket_; +}; + +// Utility class for accumulating command-line arguments. +class Arguments { + public: + Arguments() { args_.push_back(nullptr); } + + ~Arguments() { + for (std::vector::iterator i = args_.begin(); i != args_.end(); + ++i) { + free(*i); + } + } + void AddArgument(const char* argument) { + args_.insert(args_.end() - 1, posix::StrDup(argument)); + } + + template + void AddArguments(const ::std::vector& arguments) { + for (typename ::std::vector::const_iterator i = arguments.begin(); + i != arguments.end(); + ++i) { + args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); + } + } + char* const* Argv() { + return &args_[0]; + } + + int size() { + return static_cast(args_.size()) - 1; + } + + private: + std::vector args_; +}; + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int FuchsiaDeathTest::Wait() { + const int kProcessKey = 0; + const int kSocketKey = 1; + const int kExceptionKey = 2; + + if (!spawned()) + return 0; + + // Create a port to wait for socket/task/exception events. + zx_status_t status_zx; + zx::port port; + status_zx = zx::port::create(0, &port); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + // Register to wait for the child process to terminate. + status_zx = child_process_.wait_async( + port, kProcessKey, ZX_PROCESS_TERMINATED, 0); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + // Register to wait for the socket to be readable or closed. + status_zx = stderr_socket_.wait_async( + port, kSocketKey, ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED, 0); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + // Register to wait for an exception. + status_zx = exception_channel_.wait_async( + port, kExceptionKey, ZX_CHANNEL_READABLE, 0); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + bool process_terminated = false; + bool socket_closed = false; + do { + zx_port_packet_t packet = {}; + status_zx = port.wait(zx::time::infinite(), &packet); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + if (packet.key == kExceptionKey) { + // Process encountered an exception. Kill it directly rather than + // letting other handlers process the event. We will get a kProcessKey + // event when the process actually terminates. + status_zx = child_process_.kill(); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + } else if (packet.key == kProcessKey) { + // Process terminated. + GTEST_DEATH_TEST_CHECK_(ZX_PKT_IS_SIGNAL_ONE(packet.type)); + GTEST_DEATH_TEST_CHECK_(packet.signal.observed & ZX_PROCESS_TERMINATED); + process_terminated = true; + } else if (packet.key == kSocketKey) { + GTEST_DEATH_TEST_CHECK_(ZX_PKT_IS_SIGNAL_ONE(packet.type)); + if (packet.signal.observed & ZX_SOCKET_READABLE) { + // Read data from the socket. + constexpr size_t kBufferSize = 1024; + do { + size_t old_length = captured_stderr_.length(); + size_t bytes_read = 0; + captured_stderr_.resize(old_length + kBufferSize); + status_zx = stderr_socket_.read( + 0, &captured_stderr_.front() + old_length, kBufferSize, + &bytes_read); + captured_stderr_.resize(old_length + bytes_read); + } while (status_zx == ZX_OK); + if (status_zx == ZX_ERR_PEER_CLOSED) { + socket_closed = true; + } else { + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_ERR_SHOULD_WAIT); + status_zx = stderr_socket_.wait_async( + port, kSocketKey, ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED, 0); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + } + } else { + GTEST_DEATH_TEST_CHECK_(packet.signal.observed & ZX_SOCKET_PEER_CLOSED); + socket_closed = true; + } + } + } while (!process_terminated && !socket_closed); + + ReadAndInterpretStatusByte(); + + zx_info_process_t buffer; + status_zx = child_process_.get_info(ZX_INFO_PROCESS, &buffer, sizeof(buffer), + nullptr, nullptr); + GTEST_DEATH_TEST_CHECK_(status_zx == ZX_OK); + + GTEST_DEATH_TEST_CHECK_(buffer.flags & ZX_INFO_PROCESS_FLAG_EXITED); + set_status(static_cast(buffer.return_code)); + return status(); +} + +// The AssumeRole process for a Fuchsia death test. It creates a child +// process with the same executable as the current process to run the +// death test. The child process is given the --gtest_filter and +// --gtest_internal_run_death_test flags such that it knows to run the +// current death test only. +DeathTest::TestRole FuchsiaDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != nullptr) { + // ParseInternalRunDeathTestFlag() has performed all the necessary + // processing. + set_write_fd(kFuchsiaReadPipeFd); + return EXECUTE_TEST; + } + + // Flush the log buffers since the log streams are shared with the child. + FlushInfoLog(); + + // Build the child process command line. + const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ + + "filter=" + info->test_suite_name() + "." + + info->name(); + const std::string internal_flag = + std::string("--") + GTEST_FLAG_PREFIX_ + kInternalRunDeathTestFlag + "=" + + file_ + "|" + + StreamableToString(line_) + "|" + + StreamableToString(death_test_index); + Arguments args; + args.AddArguments(GetInjectableArgvs()); + args.AddArgument(filter_flag.c_str()); + args.AddArgument(internal_flag.c_str()); + + // Build the pipe for communication with the child. + zx_status_t status; + zx_handle_t child_pipe_handle; + int child_pipe_fd; + status = fdio_pipe_half(&child_pipe_fd, &child_pipe_handle); + GTEST_DEATH_TEST_CHECK_(status == ZX_OK); + set_read_fd(child_pipe_fd); + + // Set the pipe handle for the child. + fdio_spawn_action_t spawn_actions[2] = {}; + fdio_spawn_action_t* add_handle_action = &spawn_actions[0]; + add_handle_action->action = FDIO_SPAWN_ACTION_ADD_HANDLE; + add_handle_action->h.id = PA_HND(PA_FD, kFuchsiaReadPipeFd); + add_handle_action->h.handle = child_pipe_handle; + + // Create a socket pair will be used to receive the child process' stderr. + zx::socket stderr_producer_socket; + status = + zx::socket::create(0, &stderr_producer_socket, &stderr_socket_); + GTEST_DEATH_TEST_CHECK_(status >= 0); + int stderr_producer_fd = -1; + status = + fdio_fd_create(stderr_producer_socket.release(), &stderr_producer_fd); + GTEST_DEATH_TEST_CHECK_(status >= 0); + + // Make the stderr socket nonblocking. + GTEST_DEATH_TEST_CHECK_(fcntl(stderr_producer_fd, F_SETFL, 0) == 0); + + fdio_spawn_action_t* add_stderr_action = &spawn_actions[1]; + add_stderr_action->action = FDIO_SPAWN_ACTION_CLONE_FD; + add_stderr_action->fd.local_fd = stderr_producer_fd; + add_stderr_action->fd.target_fd = STDERR_FILENO; + + // Create a child job. + zx_handle_t child_job = ZX_HANDLE_INVALID; + status = zx_job_create(zx_job_default(), 0, & child_job); + GTEST_DEATH_TEST_CHECK_(status == ZX_OK); + zx_policy_basic_t policy; + policy.condition = ZX_POL_NEW_ANY; + policy.policy = ZX_POL_ACTION_ALLOW; + status = zx_job_set_policy( + child_job, ZX_JOB_POL_RELATIVE, ZX_JOB_POL_BASIC, &policy, 1); + GTEST_DEATH_TEST_CHECK_(status == ZX_OK); + + // Create an exception channel attached to the |child_job|, to allow + // us to suppress the system default exception handler from firing. + status = + zx_task_create_exception_channel( + child_job, 0, exception_channel_.reset_and_get_address()); + GTEST_DEATH_TEST_CHECK_(status == ZX_OK); + + // Spawn the child process. + status = fdio_spawn_etc( + child_job, FDIO_SPAWN_CLONE_ALL, args.Argv()[0], args.Argv(), nullptr, + 2, spawn_actions, child_process_.reset_and_get_address(), nullptr); + GTEST_DEATH_TEST_CHECK_(status == ZX_OK); + + set_spawned(true); + return OVERSEE_TEST; +} + +std::string FuchsiaDeathTest::GetErrorLogs() { + return captured_stderr_; +} + +#else // We are neither on Windows, nor on Fuchsia. + +// ForkingDeathTest provides implementations for most of the abstract +// methods of the DeathTest interface. Only the AssumeRole method is +// left undefined. +class ForkingDeathTest : public DeathTestImpl { + public: + ForkingDeathTest(const char* statement, Matcher matcher); + + // All of these virtual functions are inherited from DeathTest. + int Wait() override; + + protected: + void set_child_pid(pid_t child_pid) { child_pid_ = child_pid; } + + private: + // PID of child process during death test; 0 in the child process itself. + pid_t child_pid_; +}; + +// Constructs a ForkingDeathTest. +ForkingDeathTest::ForkingDeathTest(const char* a_statement, + Matcher matcher) + : DeathTestImpl(a_statement, std::move(matcher)), child_pid_(-1) {} + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int ForkingDeathTest::Wait() { + if (!spawned()) + return 0; + + ReadAndInterpretStatusByte(); + + int status_value; + GTEST_DEATH_TEST_CHECK_SYSCALL_(waitpid(child_pid_, &status_value, 0)); + set_status(status_value); + return status_value; +} + +// A concrete death test class that forks, then immediately runs the test +// in the child process. +class NoExecDeathTest : public ForkingDeathTest { + public: + NoExecDeathTest(const char* a_statement, Matcher matcher) + : ForkingDeathTest(a_statement, std::move(matcher)) {} + TestRole AssumeRole() override; +}; + +// The AssumeRole process for a fork-and-run death test. It implements a +// straightforward fork, with a simple pipe to transmit the status byte. +DeathTest::TestRole NoExecDeathTest::AssumeRole() { + const size_t thread_count = GetThreadCount(); + if (thread_count != 1) { + GTEST_LOG_(WARNING) << DeathTestThreadWarning(thread_count); + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + + DeathTest::set_last_death_test_message(""); + CaptureStderr(); + // When we fork the process below, the log file buffers are copied, but the + // file descriptors are shared. We flush all log files here so that closing + // the file descriptors in the child process doesn't throw off the + // synchronization between descriptors and buffers in the parent process. + // This is as close to the fork as possible to avoid a race condition in case + // there are multiple threads running before the death test, and another + // thread writes to the log file. + FlushInfoLog(); + + const pid_t child_pid = fork(); + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + set_child_pid(child_pid); + if (child_pid == 0) { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[0])); + set_write_fd(pipe_fd[1]); + // Redirects all logging to stderr in the child process to prevent + // concurrent writes to the log files. We capture stderr in the parent + // process and append the child process' output to a log. + LogToStderr(); + // Event forwarding to the listeners of event listener API mush be shut + // down in death test subprocesses. + GetUnitTestImpl()->listeners()->SuppressEventForwarding(); + g_in_fast_death_test_child = true; + return EXECUTE_TEST; + } else { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; + } +} + +// A concrete death test class that forks and re-executes the main +// program from the beginning, with command-line flags set that cause +// only this specific death test to be run. +class ExecDeathTest : public ForkingDeathTest { + public: + ExecDeathTest(const char* a_statement, Matcher matcher, + const char* file, int line) + : ForkingDeathTest(a_statement, std::move(matcher)), + file_(file), + line_(line) {} + TestRole AssumeRole() override; + + private: + static ::std::vector GetArgvsForDeathTestChildProcess() { + ::std::vector args = GetInjectableArgvs(); +# if defined(GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_) + ::std::vector extra_args = + GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_(); + args.insert(args.end(), extra_args.begin(), extra_args.end()); +# endif // defined(GTEST_EXTRA_DEATH_TEST_COMMAND_LINE_ARGS_) + return args; + } + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; +}; + +// Utility class for accumulating command-line arguments. +class Arguments { + public: + Arguments() { args_.push_back(nullptr); } + + ~Arguments() { + for (std::vector::iterator i = args_.begin(); i != args_.end(); + ++i) { + free(*i); + } + } + void AddArgument(const char* argument) { + args_.insert(args_.end() - 1, posix::StrDup(argument)); + } + + template + void AddArguments(const ::std::vector& arguments) { + for (typename ::std::vector::const_iterator i = arguments.begin(); + i != arguments.end(); + ++i) { + args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); + } + } + char* const* Argv() { + return &args_[0]; + } + + private: + std::vector args_; +}; + +// A struct that encompasses the arguments to the child process of a +// threadsafe-style death test process. +struct ExecDeathTestArgs { + char* const* argv; // Command-line arguments for the child's call to exec + int close_fd; // File descriptor to close; the read end of a pipe +}; + +# if GTEST_OS_QNX +extern "C" char** environ; +# else // GTEST_OS_QNX +// The main function for a threadsafe-style death test child process. +// This function is called in a clone()-ed process and thus must avoid +// any potentially unsafe operations like malloc or libc functions. +static int ExecDeathTestChildMain(void* child_arg) { + ExecDeathTestArgs* const args = static_cast(child_arg); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(args->close_fd)); + + // We need to execute the test program in the same environment where + // it was originally invoked. Therefore we change to the original + // working directory first. + const char* const original_dir = + UnitTest::GetInstance()->original_working_dir(); + // We can safely call chdir() as it's a direct system call. + if (chdir(original_dir) != 0) { + DeathTestAbort(std::string("chdir(\"") + original_dir + "\") failed: " + + GetLastErrnoDescription()); + return EXIT_FAILURE; + } + + // We can safely call execv() as it's almost a direct system call. We + // cannot use execvp() as it's a libc function and thus potentially + // unsafe. Since execv() doesn't search the PATH, the user must + // invoke the test program via a valid path that contains at least + // one path separator. + execv(args->argv[0], args->argv); + DeathTestAbort(std::string("execv(") + args->argv[0] + ", ...) in " + + original_dir + " failed: " + + GetLastErrnoDescription()); + return EXIT_FAILURE; +} +# endif // GTEST_OS_QNX + +# if GTEST_HAS_CLONE +// Two utility routines that together determine the direction the stack +// grows. +// This could be accomplished more elegantly by a single recursive +// function, but we want to guard against the unlikely possibility of +// a smart compiler optimizing the recursion away. +// +// GTEST_NO_INLINE_ is required to prevent GCC 4.6 from inlining +// StackLowerThanAddress into StackGrowsDown, which then doesn't give +// correct answer. +static void StackLowerThanAddress(const void* ptr, + bool* result) GTEST_NO_INLINE_; +// Make sure sanitizers do not tamper with the stack here. +// Ideally, we want to use `__builtin_frame_address` instead of a local variable +// address with sanitizer disabled, but it does not work when the +// compiler optimizes the stack frame out, which happens on PowerPC targets. +// HWAddressSanitizer add a random tag to the MSB of the local variable address, +// making comparison result unpredictable. +GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +static void StackLowerThanAddress(const void* ptr, bool* result) { + int dummy = 0; + *result = std::less()(&dummy, ptr); +} + +// Make sure AddressSanitizer does not tamper with the stack here. +GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +static bool StackGrowsDown() { + int dummy = 0; + bool result; + StackLowerThanAddress(&dummy, &result); + return result; +} +# endif // GTEST_HAS_CLONE + +// Spawns a child process with the same executable as the current process in +// a thread-safe manner and instructs it to run the death test. The +// implementation uses fork(2) + exec. On systems where clone(2) is +// available, it is used instead, being slightly more thread-safe. On QNX, +// fork supports only single-threaded environments, so this function uses +// spawn(2) there instead. The function dies with an error message if +// anything goes wrong. +static pid_t ExecDeathTestSpawnChild(char* const* argv, int close_fd) { + ExecDeathTestArgs args = { argv, close_fd }; + pid_t child_pid = -1; + +# if GTEST_OS_QNX + // Obtains the current directory and sets it to be closed in the child + // process. + const int cwd_fd = open(".", O_RDONLY); + GTEST_DEATH_TEST_CHECK_(cwd_fd != -1); + GTEST_DEATH_TEST_CHECK_SYSCALL_(fcntl(cwd_fd, F_SETFD, FD_CLOEXEC)); + // We need to execute the test program in the same environment where + // it was originally invoked. Therefore we change to the original + // working directory first. + const char* const original_dir = + UnitTest::GetInstance()->original_working_dir(); + // We can safely call chdir() as it's a direct system call. + if (chdir(original_dir) != 0) { + DeathTestAbort(std::string("chdir(\"") + original_dir + "\") failed: " + + GetLastErrnoDescription()); + return EXIT_FAILURE; + } + + int fd_flags; + // Set close_fd to be closed after spawn. + GTEST_DEATH_TEST_CHECK_SYSCALL_(fd_flags = fcntl(close_fd, F_GETFD)); + GTEST_DEATH_TEST_CHECK_SYSCALL_(fcntl(close_fd, F_SETFD, + fd_flags | FD_CLOEXEC)); + struct inheritance inherit = {0}; + // spawn is a system call. + child_pid = spawn(args.argv[0], 0, nullptr, &inherit, args.argv, environ); + // Restores the current working directory. + GTEST_DEATH_TEST_CHECK_(fchdir(cwd_fd) != -1); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(cwd_fd)); + +# else // GTEST_OS_QNX +# if GTEST_OS_LINUX + // When a SIGPROF signal is received while fork() or clone() are executing, + // the process may hang. To avoid this, we ignore SIGPROF here and re-enable + // it after the call to fork()/clone() is complete. + struct sigaction saved_sigprof_action; + struct sigaction ignore_sigprof_action; + memset(&ignore_sigprof_action, 0, sizeof(ignore_sigprof_action)); + sigemptyset(&ignore_sigprof_action.sa_mask); + ignore_sigprof_action.sa_handler = SIG_IGN; + GTEST_DEATH_TEST_CHECK_SYSCALL_(sigaction( + SIGPROF, &ignore_sigprof_action, &saved_sigprof_action)); +# endif // GTEST_OS_LINUX + +# if GTEST_HAS_CLONE + const bool use_fork = GTEST_FLAG_GET(death_test_use_fork); + + if (!use_fork) { + static const bool stack_grows_down = StackGrowsDown(); + const auto stack_size = static_cast(getpagesize() * 2); + // MMAP_ANONYMOUS is not defined on Mac, so we use MAP_ANON instead. + void* const stack = mmap(nullptr, stack_size, PROT_READ | PROT_WRITE, + MAP_ANON | MAP_PRIVATE, -1, 0); + GTEST_DEATH_TEST_CHECK_(stack != MAP_FAILED); + + // Maximum stack alignment in bytes: For a downward-growing stack, this + // amount is subtracted from size of the stack space to get an address + // that is within the stack space and is aligned on all systems we care + // about. As far as I know there is no ABI with stack alignment greater + // than 64. We assume stack and stack_size already have alignment of + // kMaxStackAlignment. + const size_t kMaxStackAlignment = 64; + void* const stack_top = + static_cast(stack) + + (stack_grows_down ? stack_size - kMaxStackAlignment : 0); + GTEST_DEATH_TEST_CHECK_( + static_cast(stack_size) > kMaxStackAlignment && + reinterpret_cast(stack_top) % kMaxStackAlignment == 0); + + child_pid = clone(&ExecDeathTestChildMain, stack_top, SIGCHLD, &args); + + GTEST_DEATH_TEST_CHECK_(munmap(stack, stack_size) != -1); + } +# else + const bool use_fork = true; +# endif // GTEST_HAS_CLONE + + if (use_fork && (child_pid = fork()) == 0) { + ExecDeathTestChildMain(&args); + _exit(0); + } +# endif // GTEST_OS_QNX +# if GTEST_OS_LINUX + GTEST_DEATH_TEST_CHECK_SYSCALL_( + sigaction(SIGPROF, &saved_sigprof_action, nullptr)); +# endif // GTEST_OS_LINUX + + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + return child_pid; +} + +// The AssumeRole process for a fork-and-exec death test. It re-executes the +// main program from the beginning, setting the --gtest_filter +// and --gtest_internal_run_death_test flags to cause only the current +// death test to be re-run. +DeathTest::TestRole ExecDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != nullptr) { + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + // Clear the close-on-exec flag on the write end of the pipe, lest + // it be closed when the child process does an exec: + GTEST_DEATH_TEST_CHECK_(fcntl(pipe_fd[1], F_SETFD, 0) != -1); + + const std::string filter_flag = std::string("--") + GTEST_FLAG_PREFIX_ + + "filter=" + info->test_suite_name() + "." + + info->name(); + const std::string internal_flag = std::string("--") + GTEST_FLAG_PREFIX_ + + "internal_run_death_test=" + file_ + "|" + + StreamableToString(line_) + "|" + + StreamableToString(death_test_index) + "|" + + StreamableToString(pipe_fd[1]); + Arguments args; + args.AddArguments(GetArgvsForDeathTestChildProcess()); + args.AddArgument(filter_flag.c_str()); + args.AddArgument(internal_flag.c_str()); + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // See the comment in NoExecDeathTest::AssumeRole for why the next line + // is necessary. + FlushInfoLog(); + + const pid_t child_pid = ExecDeathTestSpawnChild(args.Argv(), pipe_fd[0]); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_child_pid(child_pid); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; +} + +# endif // !GTEST_OS_WINDOWS + +// Creates a concrete DeathTest-derived class that depends on the +// --gtest_death_test_style flag, and sets the pointer pointed to +// by the "test" argument to its address. If the test should be +// skipped, sets that pointer to NULL. Returns true, unless the +// flag is set to an invalid value. +bool DefaultDeathTestFactory::Create(const char* statement, + Matcher matcher, + const char* file, int line, + DeathTest** test) { + UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const int death_test_index = impl->current_test_info() + ->increment_death_test_count(); + + if (flag != nullptr) { + if (death_test_index > flag->index()) { + DeathTest::set_last_death_test_message( + "Death test count (" + StreamableToString(death_test_index) + + ") somehow exceeded expected maximum (" + + StreamableToString(flag->index()) + ")"); + return false; + } + + if (!(flag->file() == file && flag->line() == line && + flag->index() == death_test_index)) { + *test = nullptr; + return true; + } + } + +# if GTEST_OS_WINDOWS + + if (GTEST_FLAG_GET(death_test_style) == "threadsafe" || + GTEST_FLAG_GET(death_test_style) == "fast") { + *test = new WindowsDeathTest(statement, std::move(matcher), file, line); + } + +# elif GTEST_OS_FUCHSIA + + if (GTEST_FLAG_GET(death_test_style) == "threadsafe" || + GTEST_FLAG_GET(death_test_style) == "fast") { + *test = new FuchsiaDeathTest(statement, std::move(matcher), file, line); + } + +# else + + if (GTEST_FLAG_GET(death_test_style) == "threadsafe") { + *test = new ExecDeathTest(statement, std::move(matcher), file, line); + } else if (GTEST_FLAG_GET(death_test_style) == "fast") { + *test = new NoExecDeathTest(statement, std::move(matcher)); + } + +# endif // GTEST_OS_WINDOWS + + else { // NOLINT - this is more readable than unbalanced brackets inside #if. + DeathTest::set_last_death_test_message("Unknown death test style \"" + + GTEST_FLAG_GET(death_test_style) + + "\" encountered"); + return false; + } + + return true; +} + +# if GTEST_OS_WINDOWS +// Recreates the pipe and event handles from the provided parameters, +// signals the event, and returns a file descriptor wrapped around the pipe +// handle. This function is called in the child process only. +static int GetStatusFileDescriptor(unsigned int parent_process_id, + size_t write_handle_as_size_t, + size_t event_handle_as_size_t) { + AutoHandle parent_process_handle(::OpenProcess(PROCESS_DUP_HANDLE, + FALSE, // Non-inheritable. + parent_process_id)); + if (parent_process_handle.Get() == INVALID_HANDLE_VALUE) { + DeathTestAbort("Unable to open parent process " + + StreamableToString(parent_process_id)); + } + + GTEST_CHECK_(sizeof(HANDLE) <= sizeof(size_t)); + + const HANDLE write_handle = + reinterpret_cast(write_handle_as_size_t); + HANDLE dup_write_handle; + + // The newly initialized handle is accessible only in the parent + // process. To obtain one accessible within the child, we need to use + // DuplicateHandle. + if (!::DuplicateHandle(parent_process_handle.Get(), write_handle, + ::GetCurrentProcess(), &dup_write_handle, + 0x0, // Requested privileges ignored since + // DUPLICATE_SAME_ACCESS is used. + FALSE, // Request non-inheritable handler. + DUPLICATE_SAME_ACCESS)) { + DeathTestAbort("Unable to duplicate the pipe handle " + + StreamableToString(write_handle_as_size_t) + + " from the parent process " + + StreamableToString(parent_process_id)); + } + + const HANDLE event_handle = reinterpret_cast(event_handle_as_size_t); + HANDLE dup_event_handle; + + if (!::DuplicateHandle(parent_process_handle.Get(), event_handle, + ::GetCurrentProcess(), &dup_event_handle, + 0x0, + FALSE, + DUPLICATE_SAME_ACCESS)) { + DeathTestAbort("Unable to duplicate the event handle " + + StreamableToString(event_handle_as_size_t) + + " from the parent process " + + StreamableToString(parent_process_id)); + } + + const int write_fd = + ::_open_osfhandle(reinterpret_cast(dup_write_handle), O_APPEND); + if (write_fd == -1) { + DeathTestAbort("Unable to convert pipe handle " + + StreamableToString(write_handle_as_size_t) + + " to a file descriptor"); + } + + // Signals the parent that the write end of the pipe has been acquired + // so the parent can release its own write end. + ::SetEvent(dup_event_handle); + + return write_fd; +} +# endif // GTEST_OS_WINDOWS + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag() { + if (GTEST_FLAG_GET(internal_run_death_test) == "") return nullptr; + + // GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we + // can use it here. + int line = -1; + int index = -1; + ::std::vector< ::std::string> fields; + SplitString(GTEST_FLAG_GET(internal_run_death_test), '|', &fields); + int write_fd = -1; + +# if GTEST_OS_WINDOWS + + unsigned int parent_process_id = 0; + size_t write_handle_as_size_t = 0; + size_t event_handle_as_size_t = 0; + + if (fields.size() != 6 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &parent_process_id) + || !ParseNaturalNumber(fields[4], &write_handle_as_size_t) + || !ParseNaturalNumber(fields[5], &event_handle_as_size_t)) { + DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + + GTEST_FLAG_GET(internal_run_death_test)); + } + write_fd = GetStatusFileDescriptor(parent_process_id, + write_handle_as_size_t, + event_handle_as_size_t); + +# elif GTEST_OS_FUCHSIA + + if (fields.size() != 3 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index)) { + DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + + GTEST_FLAG_GET(internal_run_death_test)); + } + +# else + + if (fields.size() != 4 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &write_fd)) { + DeathTestAbort("Bad --gtest_internal_run_death_test flag: " + + GTEST_FLAG_GET(internal_run_death_test)); + } + +# endif // GTEST_OS_WINDOWS + + return new InternalRunDeathTestFlag(fields[0], line, index, write_fd); +} + +} // namespace internal + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-filepath.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-filepath.cc new file mode 100644 index 000000000000..0b5629401b5a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-filepath.cc @@ -0,0 +1,369 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "gtest/internal/gtest-filepath.h" + +#include +#include "gtest/internal/gtest-port.h" +#include "gtest/gtest-message.h" + +#if GTEST_OS_WINDOWS_MOBILE +# include +#elif GTEST_OS_WINDOWS +# include +# include +#else +# include +# include // Some Linux distributions define PATH_MAX here. +#endif // GTEST_OS_WINDOWS_MOBILE + +#include "gtest/internal/gtest-string.h" + +#if GTEST_OS_WINDOWS +# define GTEST_PATH_MAX_ _MAX_PATH +#elif defined(PATH_MAX) +# define GTEST_PATH_MAX_ PATH_MAX +#elif defined(_XOPEN_PATH_MAX) +# define GTEST_PATH_MAX_ _XOPEN_PATH_MAX +#else +# define GTEST_PATH_MAX_ _POSIX_PATH_MAX +#endif // GTEST_OS_WINDOWS + +namespace testing { +namespace internal { + +#if GTEST_OS_WINDOWS +// On Windows, '\\' is the standard path separator, but many tools and the +// Windows API also accept '/' as an alternate path separator. Unless otherwise +// noted, a file path can contain either kind of path separators, or a mixture +// of them. +const char kPathSeparator = '\\'; +const char kAlternatePathSeparator = '/'; +const char kAlternatePathSeparatorString[] = "/"; +# if GTEST_OS_WINDOWS_MOBILE +// Windows CE doesn't have a current directory. You should not use +// the current directory in tests on Windows CE, but this at least +// provides a reasonable fallback. +const char kCurrentDirectoryString[] = "\\"; +// Windows CE doesn't define INVALID_FILE_ATTRIBUTES +const DWORD kInvalidFileAttributes = 0xffffffff; +# else +const char kCurrentDirectoryString[] = ".\\"; +# endif // GTEST_OS_WINDOWS_MOBILE +#else +const char kPathSeparator = '/'; +const char kCurrentDirectoryString[] = "./"; +#endif // GTEST_OS_WINDOWS + +// Returns whether the given character is a valid path separator. +static bool IsPathSeparator(char c) { +#if GTEST_HAS_ALT_PATH_SEP_ + return (c == kPathSeparator) || (c == kAlternatePathSeparator); +#else + return c == kPathSeparator; +#endif +} + +// Returns the current working directory, or "" if unsuccessful. +FilePath FilePath::GetCurrentDir() { +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_PHONE || \ + GTEST_OS_WINDOWS_RT || GTEST_OS_ESP8266 || GTEST_OS_ESP32 || \ + GTEST_OS_XTENSA + // These platforms do not have a current directory, so we just return + // something reasonable. + return FilePath(kCurrentDirectoryString); +#elif GTEST_OS_WINDOWS + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + return FilePath(_getcwd(cwd, sizeof(cwd)) == nullptr ? "" : cwd); +#else + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + char* result = getcwd(cwd, sizeof(cwd)); +# if GTEST_OS_NACL + // getcwd will likely fail in NaCl due to the sandbox, so return something + // reasonable. The user may have provided a shim implementation for getcwd, + // however, so fallback only when failure is detected. + return FilePath(result == nullptr ? kCurrentDirectoryString : cwd); +# endif // GTEST_OS_NACL + return FilePath(result == nullptr ? "" : cwd); +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns a copy of the FilePath with the case-insensitive extension removed. +// Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns +// FilePath("dir/file"). If a case-insensitive extension is not +// found, returns a copy of the original FilePath. +FilePath FilePath::RemoveExtension(const char* extension) const { + const std::string dot_extension = std::string(".") + extension; + if (String::EndsWithCaseInsensitive(pathname_, dot_extension)) { + return FilePath(pathname_.substr( + 0, pathname_.length() - dot_extension.length())); + } + return *this; +} + +// Returns a pointer to the last occurrence of a valid path separator in +// the FilePath. On Windows, for example, both '/' and '\' are valid path +// separators. Returns NULL if no path separator was found. +const char* FilePath::FindLastPathSeparator() const { + const char* const last_sep = strrchr(c_str(), kPathSeparator); +#if GTEST_HAS_ALT_PATH_SEP_ + const char* const last_alt_sep = strrchr(c_str(), kAlternatePathSeparator); + // Comparing two pointers of which only one is NULL is undefined. + if (last_alt_sep != nullptr && + (last_sep == nullptr || last_alt_sep > last_sep)) { + return last_alt_sep; + } +#endif + return last_sep; +} + +// Returns a copy of the FilePath with the directory part removed. +// Example: FilePath("path/to/file").RemoveDirectoryName() returns +// FilePath("file"). If there is no directory part ("just_a_file"), it returns +// the FilePath unmodified. If there is no file part ("just_a_dir/") it +// returns an empty FilePath (""). +// On Windows platform, '\' is the path separator, otherwise it is '/'. +FilePath FilePath::RemoveDirectoryName() const { + const char* const last_sep = FindLastPathSeparator(); + return last_sep ? FilePath(last_sep + 1) : *this; +} + +// RemoveFileName returns the directory path with the filename removed. +// Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". +// If the FilePath is "a_file" or "/a_file", RemoveFileName returns +// FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does +// not have a file, like "just/a/dir/", it returns the FilePath unmodified. +// On Windows platform, '\' is the path separator, otherwise it is '/'. +FilePath FilePath::RemoveFileName() const { + const char* const last_sep = FindLastPathSeparator(); + std::string dir; + if (last_sep) { + dir = std::string(c_str(), static_cast(last_sep + 1 - c_str())); + } else { + dir = kCurrentDirectoryString; + } + return FilePath(dir); +} + +// Helper functions for naming files in a directory for xml output. + +// Given directory = "dir", base_name = "test", number = 0, +// extension = "xml", returns "dir/test.xml". If number is greater +// than zero (e.g., 12), returns "dir/test_12.xml". +// On Windows platform, uses \ as the separator rather than /. +FilePath FilePath::MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension) { + std::string file; + if (number == 0) { + file = base_name.string() + "." + extension; + } else { + file = base_name.string() + "_" + StreamableToString(number) + + "." + extension; + } + return ConcatPaths(directory, FilePath(file)); +} + +// Given directory = "dir", relative_path = "test.xml", returns "dir/test.xml". +// On Windows, uses \ as the separator rather than /. +FilePath FilePath::ConcatPaths(const FilePath& directory, + const FilePath& relative_path) { + if (directory.IsEmpty()) + return relative_path; + const FilePath dir(directory.RemoveTrailingPathSeparator()); + return FilePath(dir.string() + kPathSeparator + relative_path.string()); +} + +// Returns true if pathname describes something findable in the file-system, +// either a file, directory, or whatever. +bool FilePath::FileOrDirectoryExists() const { +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(pathname_.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + return attributes != kInvalidFileAttributes; +#else + posix::StatStruct file_stat{}; + return posix::Stat(pathname_.c_str(), &file_stat) == 0; +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns true if pathname describes a directory in the file-system +// that exists. +bool FilePath::DirectoryExists() const { + bool result = false; +#if GTEST_OS_WINDOWS + // Don't strip off trailing separator if path is a root directory on + // Windows (like "C:\\"). + const FilePath& path(IsRootDirectory() ? *this : + RemoveTrailingPathSeparator()); +#else + const FilePath& path(*this); +#endif + +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(path.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + if ((attributes != kInvalidFileAttributes) && + (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + result = true; + } +#else + posix::StatStruct file_stat{}; + result = posix::Stat(path.c_str(), &file_stat) == 0 && + posix::IsDir(file_stat); +#endif // GTEST_OS_WINDOWS_MOBILE + + return result; +} + +// Returns true if pathname describes a root directory. (Windows has one +// root directory per disk drive.) +bool FilePath::IsRootDirectory() const { +#if GTEST_OS_WINDOWS + return pathname_.length() == 3 && IsAbsolutePath(); +#else + return pathname_.length() == 1 && IsPathSeparator(pathname_.c_str()[0]); +#endif +} + +// Returns true if pathname describes an absolute path. +bool FilePath::IsAbsolutePath() const { + const char* const name = pathname_.c_str(); +#if GTEST_OS_WINDOWS + return pathname_.length() >= 3 && + ((name[0] >= 'a' && name[0] <= 'z') || + (name[0] >= 'A' && name[0] <= 'Z')) && + name[1] == ':' && + IsPathSeparator(name[2]); +#else + return IsPathSeparator(name[0]); +#endif +} + +// Returns a pathname for a file that does not currently exist. The pathname +// will be directory/base_name.extension or +// directory/base_name_.extension if directory/base_name.extension +// already exists. The number will be incremented until a pathname is found +// that does not already exist. +// Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. +// There could be a race condition if two or more processes are calling this +// function at the same time -- they could both pick the same filename. +FilePath FilePath::GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension) { + FilePath full_pathname; + int number = 0; + do { + full_pathname.Set(MakeFileName(directory, base_name, number++, extension)); + } while (full_pathname.FileOrDirectoryExists()); + return full_pathname; +} + +// Returns true if FilePath ends with a path separator, which indicates that +// it is intended to represent a directory. Returns false otherwise. +// This does NOT check that a directory (or file) actually exists. +bool FilePath::IsDirectory() const { + return !pathname_.empty() && + IsPathSeparator(pathname_.c_str()[pathname_.length() - 1]); +} + +// Create directories so that path exists. Returns true if successful or if +// the directories already exist; returns false if unable to create directories +// for any reason. +bool FilePath::CreateDirectoriesRecursively() const { + if (!this->IsDirectory()) { + return false; + } + + if (pathname_.length() == 0 || this->DirectoryExists()) { + return true; + } + + const FilePath parent(this->RemoveTrailingPathSeparator().RemoveFileName()); + return parent.CreateDirectoriesRecursively() && this->CreateFolder(); +} + +// Create the directory so that path exists. Returns true if successful or +// if the directory already exists; returns false if unable to create the +// directory for any reason, including if the parent directory does not +// exist. Not named "CreateDirectory" because that's a macro on Windows. +bool FilePath::CreateFolder() const { +#if GTEST_OS_WINDOWS_MOBILE + FilePath removed_sep(this->RemoveTrailingPathSeparator()); + LPCWSTR unicode = String::AnsiToUtf16(removed_sep.c_str()); + int result = CreateDirectory(unicode, nullptr) ? 0 : -1; + delete [] unicode; +#elif GTEST_OS_WINDOWS + int result = _mkdir(pathname_.c_str()); +#elif GTEST_OS_ESP8266 || GTEST_OS_XTENSA + // do nothing + int result = 0; +#else + int result = mkdir(pathname_.c_str(), 0777); +#endif // GTEST_OS_WINDOWS_MOBILE + + if (result == -1) { + return this->DirectoryExists(); // An error is OK if the directory exists. + } + return true; // No error. +} + +// If input name has a trailing separator character, remove it and return the +// name, otherwise return the name string unmodified. +// On Windows platform, uses \ as the separator, other platforms use /. +FilePath FilePath::RemoveTrailingPathSeparator() const { + return IsDirectory() + ? FilePath(pathname_.substr(0, pathname_.length() - 1)) + : *this; +} + +// Removes any redundant separators that might be in the pathname. +// For example, "bar///foo" becomes "bar/foo". Does not eliminate other +// redundancies that might be in a pathname involving "." or "..". +void FilePath::Normalize() { + auto out = pathname_.begin(); + + for (const char character : pathname_) { + if (!IsPathSeparator(character)) { + *(out++) = character; + } else if (out == pathname_.begin() || *std::prev(out) != kPathSeparator) { + *(out++) = kPathSeparator; + } else { + continue; + } + } + + pathname_.erase(out, pathname_.end()); +} + +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-internal-inl.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-internal-inl.h new file mode 100644 index 000000000000..3f5551d18c1a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-internal-inl.h @@ -0,0 +1,1205 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Utility functions and classes used by the Google C++ testing framework.// +// This file contains purely Google Test's internal implementation. Please +// DO NOT #INCLUDE IT IN A USER PROGRAM. + +#ifndef GOOGLETEST_SRC_GTEST_INTERNAL_INL_H_ +#define GOOGLETEST_SRC_GTEST_INTERNAL_INL_H_ + +#ifndef _WIN32_WCE +# include +#endif // !_WIN32_WCE +#include +#include // For strtoll/_strtoul64/malloc/free. +#include // For memmove. + +#include +#include +#include +#include +#include + +#include "gtest/internal/gtest-port.h" + +#if GTEST_CAN_STREAM_RESULTS_ +# include // NOLINT +# include // NOLINT +#endif + +#if GTEST_OS_WINDOWS +# include // NOLINT +#endif // GTEST_OS_WINDOWS + +#include "gtest/gtest.h" +#include "gtest/gtest-spi.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +// Declares the flags. +// +// We don't want the users to modify this flag in the code, but want +// Google Test's own unit tests to be able to access it. Therefore we +// declare it here as opposed to in gtest.h. +GTEST_DECLARE_bool_(death_test_use_fork); + +namespace testing { +namespace internal { + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +GTEST_API_ extern const TypeId kTestTypeIdInGoogleTest; + +// A valid random seed must be in [1, kMaxRandomSeed]. +const int kMaxRandomSeed = 99999; + +// g_help_flag is true if and only if the --help flag or an equivalent form +// is specified on the command line. +GTEST_API_ extern bool g_help_flag; + +// Returns the current time in milliseconds. +GTEST_API_ TimeInMillis GetTimeInMillis(); + +// Returns true if and only if Google Test should use colors in the output. +GTEST_API_ bool ShouldUseColor(bool stdout_is_tty); + +// Formats the given time in milliseconds as seconds. +GTEST_API_ std::string FormatTimeInMillisAsSeconds(TimeInMillis ms); + +// Converts the given time in milliseconds to a date string in the ISO 8601 +// format, without the timezone information. N.B.: due to the use the +// non-reentrant localtime() function, this function is not thread safe. Do +// not use it in any code that can be called from multiple threads. +GTEST_API_ std::string FormatEpochTimeInMillisAsIso8601(TimeInMillis ms); + +// Parses a string for an Int32 flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +GTEST_API_ bool ParseFlag(const char* str, const char* flag, int32_t* value); + +// Returns a random seed in range [1, kMaxRandomSeed] based on the +// given --gtest_random_seed flag value. +inline int GetRandomSeedFromFlag(int32_t random_seed_flag) { + const unsigned int raw_seed = (random_seed_flag == 0) ? + static_cast(GetTimeInMillis()) : + static_cast(random_seed_flag); + + // Normalizes the actual seed to range [1, kMaxRandomSeed] such that + // it's easy to type. + const int normalized_seed = + static_cast((raw_seed - 1U) % + static_cast(kMaxRandomSeed)) + 1; + return normalized_seed; +} + +// Returns the first valid random seed after 'seed'. The behavior is +// undefined if 'seed' is invalid. The seed after kMaxRandomSeed is +// considered to be 1. +inline int GetNextRandomSeed(int seed) { + GTEST_CHECK_(1 <= seed && seed <= kMaxRandomSeed) + << "Invalid random seed " << seed << " - must be in [1, " + << kMaxRandomSeed << "]."; + const int next_seed = seed + 1; + return (next_seed > kMaxRandomSeed) ? 1 : next_seed; +} + +// This class saves the values of all Google Test flags in its c'tor, and +// restores them in its d'tor. +class GTestFlagSaver { + public: + // The c'tor. + GTestFlagSaver() { + also_run_disabled_tests_ = GTEST_FLAG_GET(also_run_disabled_tests); + break_on_failure_ = GTEST_FLAG_GET(break_on_failure); + catch_exceptions_ = GTEST_FLAG_GET(catch_exceptions); + color_ = GTEST_FLAG_GET(color); + death_test_style_ = GTEST_FLAG_GET(death_test_style); + death_test_use_fork_ = GTEST_FLAG_GET(death_test_use_fork); + fail_fast_ = GTEST_FLAG_GET(fail_fast); + filter_ = GTEST_FLAG_GET(filter); + internal_run_death_test_ = GTEST_FLAG_GET(internal_run_death_test); + list_tests_ = GTEST_FLAG_GET(list_tests); + output_ = GTEST_FLAG_GET(output); + brief_ = GTEST_FLAG_GET(brief); + print_time_ = GTEST_FLAG_GET(print_time); + print_utf8_ = GTEST_FLAG_GET(print_utf8); + random_seed_ = GTEST_FLAG_GET(random_seed); + repeat_ = GTEST_FLAG_GET(repeat); + recreate_environments_when_repeating_ = + GTEST_FLAG_GET(recreate_environments_when_repeating); + shuffle_ = GTEST_FLAG_GET(shuffle); + stack_trace_depth_ = GTEST_FLAG_GET(stack_trace_depth); + stream_result_to_ = GTEST_FLAG_GET(stream_result_to); + throw_on_failure_ = GTEST_FLAG_GET(throw_on_failure); + } + + // The d'tor is not virtual. DO NOT INHERIT FROM THIS CLASS. + ~GTestFlagSaver() { + GTEST_FLAG_SET(also_run_disabled_tests, also_run_disabled_tests_); + GTEST_FLAG_SET(break_on_failure, break_on_failure_); + GTEST_FLAG_SET(catch_exceptions, catch_exceptions_); + GTEST_FLAG_SET(color, color_); + GTEST_FLAG_SET(death_test_style, death_test_style_); + GTEST_FLAG_SET(death_test_use_fork, death_test_use_fork_); + GTEST_FLAG_SET(filter, filter_); + GTEST_FLAG_SET(fail_fast, fail_fast_); + GTEST_FLAG_SET(internal_run_death_test, internal_run_death_test_); + GTEST_FLAG_SET(list_tests, list_tests_); + GTEST_FLAG_SET(output, output_); + GTEST_FLAG_SET(brief, brief_); + GTEST_FLAG_SET(print_time, print_time_); + GTEST_FLAG_SET(print_utf8, print_utf8_); + GTEST_FLAG_SET(random_seed, random_seed_); + GTEST_FLAG_SET(repeat, repeat_); + GTEST_FLAG_SET(recreate_environments_when_repeating, + recreate_environments_when_repeating_); + GTEST_FLAG_SET(shuffle, shuffle_); + GTEST_FLAG_SET(stack_trace_depth, stack_trace_depth_); + GTEST_FLAG_SET(stream_result_to, stream_result_to_); + GTEST_FLAG_SET(throw_on_failure, throw_on_failure_); + } + + private: + // Fields for saving the original values of flags. + bool also_run_disabled_tests_; + bool break_on_failure_; + bool catch_exceptions_; + std::string color_; + std::string death_test_style_; + bool death_test_use_fork_; + bool fail_fast_; + std::string filter_; + std::string internal_run_death_test_; + bool list_tests_; + std::string output_; + bool brief_; + bool print_time_; + bool print_utf8_; + int32_t random_seed_; + int32_t repeat_; + bool recreate_environments_when_repeating_; + bool shuffle_; + int32_t stack_trace_depth_; + std::string stream_result_to_; + bool throw_on_failure_; +} GTEST_ATTRIBUTE_UNUSED_; + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type UInt32 because wchar_t may not be +// wide enough to contain a code point. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be converted +// to "(Invalid Unicode 0xXXXXXXXX)". +GTEST_API_ std::string CodePointToUtf8(uint32_t code_point); + +// Converts a wide string to a narrow string in UTF-8 encoding. +// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from Basic Normal Plane. +GTEST_API_ std::string WideStringToUtf8(const wchar_t* str, int num_chars); + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded(); + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (e.g., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. +GTEST_API_ bool ShouldShard(const char* total_shards_str, + const char* shard_index_str, + bool in_subprocess_for_death_test); + +// Parses the environment variable var as a 32-bit integer. If it is unset, +// returns default_val. If it is not a 32-bit integer, prints an error and +// and aborts. +GTEST_API_ int32_t Int32FromEnvOrDie(const char* env_var, int32_t default_val); + +// Given the total number of shards, the shard index, and the test id, +// returns true if and only if the test should be run on this shard. The test id +// is some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +GTEST_API_ bool ShouldRunTestOnShard( + int total_shards, int shard_index, int test_id); + +// STL container utilities. + +// Returns the number of elements in the given container that satisfy +// the given predicate. +template +inline int CountIf(const Container& c, Predicate predicate) { + // Implemented as an explicit loop since std::count_if() in libCstd on + // Solaris has a non-standard signature. + int count = 0; + for (auto it = c.begin(); it != c.end(); ++it) { + if (predicate(*it)) + ++count; + } + return count; +} + +// Applies a function/functor to each element in the container. +template +void ForEach(const Container& c, Functor functor) { + std::for_each(c.begin(), c.end(), functor); +} + +// Returns the i-th element of the vector, or default_value if i is not +// in range [0, v.size()). +template +inline E GetElementOr(const std::vector& v, int i, E default_value) { + return (i < 0 || i >= static_cast(v.size())) ? default_value + : v[static_cast(i)]; +} + +// Performs an in-place shuffle of a range of the vector's elements. +// 'begin' and 'end' are element indices as an STL-style range; +// i.e. [begin, end) are shuffled, where 'end' == size() means to +// shuffle to the end of the vector. +template +void ShuffleRange(internal::Random* random, int begin, int end, + std::vector* v) { + const int size = static_cast(v->size()); + GTEST_CHECK_(0 <= begin && begin <= size) + << "Invalid shuffle range start " << begin << ": must be in range [0, " + << size << "]."; + GTEST_CHECK_(begin <= end && end <= size) + << "Invalid shuffle range finish " << end << ": must be in range [" + << begin << ", " << size << "]."; + + // Fisher-Yates shuffle, from + // http://en.wikipedia.org/wiki/Fisher-Yates_shuffle + for (int range_width = end - begin; range_width >= 2; range_width--) { + const int last_in_range = begin + range_width - 1; + const int selected = + begin + + static_cast(random->Generate(static_cast(range_width))); + std::swap((*v)[static_cast(selected)], + (*v)[static_cast(last_in_range)]); + } +} + +// Performs an in-place shuffle of the vector's elements. +template +inline void Shuffle(internal::Random* random, std::vector* v) { + ShuffleRange(random, 0, static_cast(v->size()), v); +} + +// A function for deleting an object. Handy for being used as a +// functor. +template +static void Delete(T* x) { + delete x; +} + +// A predicate that checks the key of a TestProperty against a known key. +// +// TestPropertyKeyIs is copyable. +class TestPropertyKeyIs { + public: + // Constructor. + // + // TestPropertyKeyIs has NO default constructor. + explicit TestPropertyKeyIs(const std::string& key) : key_(key) {} + + // Returns true if and only if the test name of test property matches on key_. + bool operator()(const TestProperty& test_property) const { + return test_property.key() == key_; + } + + private: + std::string key_; +}; + +// Class UnitTestOptions. +// +// This class contains functions for processing options the user +// specifies when running the tests. It has only static members. +// +// In most cases, the user can specify an option using either an +// environment variable or a command line flag. E.g. you can set the +// test filter using either GTEST_FILTER or --gtest_filter. If both +// the variable and the flag are present, the latter overrides the +// former. +class GTEST_API_ UnitTestOptions { + public: + // Functions for processing the gtest_output flag. + + // Returns the output format, or "" for normal printed output. + static std::string GetOutputFormat(); + + // Returns the absolute path of the requested output file, or the + // default (test_detail.xml in the original working directory) if + // none was explicitly specified. + static std::string GetAbsolutePathToOutputFile(); + + // Functions for processing the gtest_filter flag. + + // Returns true if and only if the user-specified filter matches the test + // suite name and the test name. + static bool FilterMatchesTest(const std::string& test_suite_name, + const std::string& test_name); + +#if GTEST_OS_WINDOWS + // Function for supporting the gtest_catch_exception flag. + + // Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the + // given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. + // This function is useful as an __except condition. + static int GTestShouldProcessSEH(DWORD exception_code); +#endif // GTEST_OS_WINDOWS + + // Returns true if "name" matches the ':' separated list of glob-style + // filters in "filter". + static bool MatchesFilter(const std::string& name, const char* filter); +}; + +// Returns the current application's name, removing directory path if that +// is present. Used by UnitTestOptions::GetOutputFile. +GTEST_API_ FilePath GetCurrentExecutableName(); + +// The role interface for getting the OS stack trace as a string. +class OsStackTraceGetterInterface { + public: + OsStackTraceGetterInterface() {} + virtual ~OsStackTraceGetterInterface() {} + + // Returns the current OS stack trace as an std::string. Parameters: + // + // max_depth - the maximum number of stack frames to be included + // in the trace. + // skip_count - the number of top frames to be skipped; doesn't count + // against max_depth. + virtual std::string CurrentStackTrace(int max_depth, int skip_count) = 0; + + // UponLeavingGTest() should be called immediately before Google Test calls + // user code. It saves some information about the current stack that + // CurrentStackTrace() will use to find and hide Google Test stack frames. + virtual void UponLeavingGTest() = 0; + + // This string is inserted in place of stack frames that are part of + // Google Test's implementation. + static const char* const kElidedFramesMarker; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetterInterface); +}; + +// A working implementation of the OsStackTraceGetterInterface interface. +class OsStackTraceGetter : public OsStackTraceGetterInterface { + public: + OsStackTraceGetter() {} + + std::string CurrentStackTrace(int max_depth, int skip_count) override; + void UponLeavingGTest() override; + + private: +#if GTEST_HAS_ABSL + Mutex mutex_; // Protects all internal state. + + // We save the stack frame below the frame that calls user code. + // We do this because the address of the frame immediately below + // the user code changes between the call to UponLeavingGTest() + // and any calls to the stack trace code from within the user code. + void* caller_frame_ = nullptr; +#endif // GTEST_HAS_ABSL + + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetter); +}; + +// Information about a Google Test trace point. +struct TraceInfo { + const char* file; + int line; + std::string message; +}; + +// This is the default global test part result reporter used in UnitTestImpl. +// This class should only be used by UnitTestImpl. +class DefaultGlobalTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultGlobalTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. Reports the test part + // result in the current test. + void ReportTestPartResult(const TestPartResult& result) override; + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultGlobalTestPartResultReporter); +}; + +// This is the default per thread test part result reporter used in +// UnitTestImpl. This class should only be used by UnitTestImpl. +class DefaultPerThreadTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultPerThreadTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. The implementation just + // delegates to the current global test part result reporter of *unit_test_. + void ReportTestPartResult(const TestPartResult& result) override; + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultPerThreadTestPartResultReporter); +}; + +// The private implementation of the UnitTest class. We don't protect +// the methods under a mutex, as this class is not accessible by a +// user and the UnitTest class that delegates work to this class does +// proper locking. +class GTEST_API_ UnitTestImpl { + public: + explicit UnitTestImpl(UnitTest* parent); + virtual ~UnitTestImpl(); + + // There are two different ways to register your own TestPartResultReporter. + // You can register your own repoter to listen either only for test results + // from the current thread or for results from all threads. + // By default, each per-thread test result repoter just passes a new + // TestPartResult to the global test result reporter, which registers the + // test part result for the currently running test. + + // Returns the global test part result reporter. + TestPartResultReporterInterface* GetGlobalTestPartResultReporter(); + + // Sets the global test part result reporter. + void SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter); + + // Returns the test part result reporter for the current thread. + TestPartResultReporterInterface* GetTestPartResultReporterForCurrentThread(); + + // Sets the test part result reporter for the current thread. + void SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter); + + // Gets the number of successful test suites. + int successful_test_suite_count() const; + + // Gets the number of failed test suites. + int failed_test_suite_count() const; + + // Gets the number of all test suites. + int total_test_suite_count() const; + + // Gets the number of all test suites that contain at least one test + // that should run. + int test_suite_to_run_count() const; + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of skipped tests. + int skipped_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests that will be reported in the XML report. + int reportable_disabled_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of tests to be printed in the XML report. + int reportable_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the time of the test program start, in ms from the start of the + // UNIX epoch. + TimeInMillis start_timestamp() const { return start_timestamp_; } + + // Gets the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns true if and only if the unit test passed (i.e. all test suites + // passed). + bool Passed() const { return !Failed(); } + + // Returns true if and only if the unit test failed (i.e. some test suite + // failed or something outside of all tests failed). + bool Failed() const { + return failed_test_suite_count() > 0 || ad_hoc_test_result()->Failed(); + } + + // Gets the i-th test suite among all the test suites. i can range from 0 to + // total_test_suite_count() - 1. If i is not in that range, returns NULL. + const TestSuite* GetTestSuite(int i) const { + const int index = GetElementOr(test_suite_indices_, i, -1); + return index < 0 ? nullptr : test_suites_[static_cast(i)]; + } + + // Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + const TestCase* GetTestCase(int i) const { return GetTestSuite(i); } +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Gets the i-th test suite among all the test suites. i can range from 0 to + // total_test_suite_count() - 1. If i is not in that range, returns NULL. + TestSuite* GetMutableSuiteCase(int i) { + const int index = GetElementOr(test_suite_indices_, i, -1); + return index < 0 ? nullptr : test_suites_[static_cast(index)]; + } + + // Provides access to the event listener list. + TestEventListeners* listeners() { return &listeners_; } + + // Returns the TestResult for the test that's currently running, or + // the TestResult for the ad hoc test if no test is running. + TestResult* current_test_result(); + + // Returns the TestResult for the ad hoc test. + const TestResult* ad_hoc_test_result() const { return &ad_hoc_test_result_; } + + // Sets the OS stack trace getter. + // + // Does nothing if the input and the current OS stack trace getter + // are the same; otherwise, deletes the old getter and makes the + // input the current getter. + void set_os_stack_trace_getter(OsStackTraceGetterInterface* getter); + + // Returns the current OS stack trace getter if it is not NULL; + // otherwise, creates an OsStackTraceGetter, makes it the current + // getter, and returns it. + OsStackTraceGetterInterface* os_stack_trace_getter(); + + // Returns the current OS stack trace as an std::string. + // + // The maximum number of stack frames to be included is specified by + // the gtest_stack_trace_depth flag. The skip_count parameter + // specifies the number of top frames to be skipped, which doesn't + // count against the number of frames to be included. + // + // For example, if Foo() calls Bar(), which in turn calls + // CurrentOsStackTraceExceptTop(1), Foo() will be included in the + // trace but Bar() and CurrentOsStackTraceExceptTop() won't. + std::string CurrentOsStackTraceExceptTop(int skip_count) + GTEST_NO_INLINE_ GTEST_NO_TAIL_CALL_; + + // Finds and returns a TestSuite with the given name. If one doesn't + // exist, creates one and returns it. + // + // Arguments: + // + // test_suite_name: name of the test suite + // type_param: the name of the test's type parameter, or NULL if + // this is not a typed or a type-parameterized test. + // set_up_tc: pointer to the function that sets up the test suite + // tear_down_tc: pointer to the function that tears down the test suite + TestSuite* GetTestSuite(const char* test_suite_name, const char* type_param, + internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc); + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + TestCase* GetTestCase(const char* test_case_name, const char* type_param, + internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc) { + return GetTestSuite(test_case_name, type_param, set_up_tc, tear_down_tc); + } +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // Adds a TestInfo to the unit test. + // + // Arguments: + // + // set_up_tc: pointer to the function that sets up the test suite + // tear_down_tc: pointer to the function that tears down the test suite + // test_info: the TestInfo object + void AddTestInfo(internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc, + TestInfo* test_info) { +#if GTEST_HAS_DEATH_TEST + // In order to support thread-safe death tests, we need to + // remember the original working directory when the test program + // was first invoked. We cannot do this in RUN_ALL_TESTS(), as + // the user may have changed the current directory before calling + // RUN_ALL_TESTS(). Therefore we capture the current directory in + // AddTestInfo(), which is called to register a TEST or TEST_F + // before main() is reached. + if (original_working_dir_.IsEmpty()) { + original_working_dir_.Set(FilePath::GetCurrentDir()); + GTEST_CHECK_(!original_working_dir_.IsEmpty()) + << "Failed to get the current working directory."; + } +#endif // GTEST_HAS_DEATH_TEST + + GetTestSuite(test_info->test_suite_name(), test_info->type_param(), + set_up_tc, tear_down_tc) + ->AddTestInfo(test_info); + } + + // Returns ParameterizedTestSuiteRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. + internal::ParameterizedTestSuiteRegistry& parameterized_test_registry() { + return parameterized_test_registry_; + } + + std::set* ignored_parameterized_test_suites() { + return &ignored_parameterized_test_suites_; + } + + // Returns TypeParameterizedTestSuiteRegistry object used to keep track of + // type-parameterized tests and instantiations of them. + internal::TypeParameterizedTestSuiteRegistry& + type_parameterized_test_registry() { + return type_parameterized_test_registry_; + } + + // Sets the TestSuite object for the test that's currently running. + void set_current_test_suite(TestSuite* a_current_test_suite) { + current_test_suite_ = a_current_test_suite; + } + + // Sets the TestInfo object for the test that's currently running. If + // current_test_info is NULL, the assertion results will be stored in + // ad_hoc_test_result_. + void set_current_test_info(TestInfo* a_current_test_info) { + current_test_info_ = a_current_test_info; + } + + // Registers all parameterized tests defined using TEST_P and + // INSTANTIATE_TEST_SUITE_P, creating regular tests for each test/parameter + // combination. This method can be called more then once; it has guards + // protecting from registering the tests more then once. If + // value-parameterized tests are disabled, RegisterParameterizedTests is + // present but does nothing. + void RegisterParameterizedTests(); + + // Runs all tests in this UnitTest object, prints the result, and + // returns true if all tests are successful. If any exception is + // thrown during a test, this test is considered to be failed, but + // the rest of the tests will still be run. + bool RunAllTests(); + + // Clears the results of all tests, except the ad hoc tests. + void ClearNonAdHocTestResult() { + ForEach(test_suites_, TestSuite::ClearTestSuiteResult); + } + + // Clears the results of ad-hoc test assertions. + void ClearAdHocTestResult() { + ad_hoc_test_result_.Clear(); + } + + // Adds a TestProperty to the current TestResult object when invoked in a + // context of a test or a test suite, or to the global property set. If the + // result already contains a property with the same key, the value will be + // updated. + void RecordProperty(const TestProperty& test_property); + + enum ReactionToSharding { + HONOR_SHARDING_PROTOCOL, + IGNORE_SHARDING_PROTOCOL + }; + + // Matches the full name of each test against the user-specified + // filter to decide whether the test should run, then records the + // result in each TestSuite and TestInfo object. + // If shard_tests == HONOR_SHARDING_PROTOCOL, further filters tests + // based on sharding variables in the environment. + // Returns the number of tests that should run. + int FilterTests(ReactionToSharding shard_tests); + + // Prints the names of the tests matching the user-specified filter flag. + void ListTestsMatchingFilter(); + + const TestSuite* current_test_suite() const { return current_test_suite_; } + TestInfo* current_test_info() { return current_test_info_; } + const TestInfo* current_test_info() const { return current_test_info_; } + + // Returns the vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector& environments() { return environments_; } + + // Getters for the per-thread Google Test trace stack. + std::vector& gtest_trace_stack() { + return *(gtest_trace_stack_.pointer()); + } + const std::vector& gtest_trace_stack() const { + return gtest_trace_stack_.get(); + } + +#if GTEST_HAS_DEATH_TEST + void InitDeathTestSubprocessControlInfo() { + internal_run_death_test_flag_.reset(ParseInternalRunDeathTestFlag()); + } + // Returns a pointer to the parsed --gtest_internal_run_death_test + // flag, or NULL if that flag was not specified. + // This information is useful only in a death test child process. + // Must not be called before a call to InitGoogleTest. + const InternalRunDeathTestFlag* internal_run_death_test_flag() const { + return internal_run_death_test_flag_.get(); + } + + // Returns a pointer to the current death test factory. + internal::DeathTestFactory* death_test_factory() { + return death_test_factory_.get(); + } + + void SuppressTestEventsIfInSubprocess(); + + friend class ReplaceDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + + // Initializes the event listener performing XML output as specified by + // UnitTestOptions. Must not be called before InitGoogleTest. + void ConfigureXmlOutput(); + +#if GTEST_CAN_STREAM_RESULTS_ + // Initializes the event listener for streaming test results to a socket. + // Must not be called before InitGoogleTest. + void ConfigureStreamingOutput(); +#endif + + // Performs initialization dependent upon flag values obtained in + // ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to + // ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest + // this function is also called from RunAllTests. Since this function can be + // called more than once, it has to be idempotent. + void PostFlagParsingInit(); + + // Gets the random seed used at the start of the current test iteration. + int random_seed() const { return random_seed_; } + + // Gets the random number generator. + internal::Random* random() { return &random_; } + + // Shuffles all test suites, and the tests within each test suite, + // making sure that death tests are still run first. + void ShuffleTests(); + + // Restores the test suites and tests to their order before the first shuffle. + void UnshuffleTests(); + + // Returns the value of GTEST_FLAG(catch_exceptions) at the moment + // UnitTest::Run() starts. + bool catch_exceptions() const { return catch_exceptions_; } + + private: + friend class ::testing::UnitTest; + + // Used by UnitTest::Run() to capture the state of + // GTEST_FLAG(catch_exceptions) at the moment it starts. + void set_catch_exceptions(bool value) { catch_exceptions_ = value; } + + // The UnitTest object that owns this implementation object. + UnitTest* const parent_; + + // The working directory when the first TEST() or TEST_F() was + // executed. + internal::FilePath original_working_dir_; + + // The default test part result reporters. + DefaultGlobalTestPartResultReporter default_global_test_part_result_reporter_; + DefaultPerThreadTestPartResultReporter + default_per_thread_test_part_result_reporter_; + + // Points to (but doesn't own) the global test part result reporter. + TestPartResultReporterInterface* global_test_part_result_repoter_; + + // Protects read and write access to global_test_part_result_reporter_. + internal::Mutex global_test_part_result_reporter_mutex_; + + // Points to (but doesn't own) the per-thread test part result reporter. + internal::ThreadLocal + per_thread_test_part_result_reporter_; + + // The vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector environments_; + + // The vector of TestSuites in their original order. It owns the + // elements in the vector. + std::vector test_suites_; + + // Provides a level of indirection for the test suite list to allow + // easy shuffling and restoring the test suite order. The i-th + // element of this vector is the index of the i-th test suite in the + // shuffled order. + std::vector test_suite_indices_; + + // ParameterizedTestRegistry object used to register value-parameterized + // tests. + internal::ParameterizedTestSuiteRegistry parameterized_test_registry_; + internal::TypeParameterizedTestSuiteRegistry + type_parameterized_test_registry_; + + // The set holding the name of parameterized + // test suites that may go uninstantiated. + std::set ignored_parameterized_test_suites_; + + // Indicates whether RegisterParameterizedTests() has been called already. + bool parameterized_tests_registered_; + + // Index of the last death test suite registered. Initially -1. + int last_death_test_suite_; + + // This points to the TestSuite for the currently running test. It + // changes as Google Test goes through one test suite after another. + // When no test is running, this is set to NULL and Google Test + // stores assertion results in ad_hoc_test_result_. Initially NULL. + TestSuite* current_test_suite_; + + // This points to the TestInfo for the currently running test. It + // changes as Google Test goes through one test after another. When + // no test is running, this is set to NULL and Google Test stores + // assertion results in ad_hoc_test_result_. Initially NULL. + TestInfo* current_test_info_; + + // Normally, a user only writes assertions inside a TEST or TEST_F, + // or inside a function called by a TEST or TEST_F. Since Google + // Test keeps track of which test is current running, it can + // associate such an assertion with the test it belongs to. + // + // If an assertion is encountered when no TEST or TEST_F is running, + // Google Test attributes the assertion result to an imaginary "ad hoc" + // test, and records the result in ad_hoc_test_result_. + TestResult ad_hoc_test_result_; + + // The list of event listeners that can be used to track events inside + // Google Test. + TestEventListeners listeners_; + + // The OS stack trace getter. Will be deleted when the UnitTest + // object is destructed. By default, an OsStackTraceGetter is used, + // but the user can set this field to use a custom getter if that is + // desired. + OsStackTraceGetterInterface* os_stack_trace_getter_; + + // True if and only if PostFlagParsingInit() has been called. + bool post_flag_parse_init_performed_; + + // The random number seed used at the beginning of the test run. + int random_seed_; + + // Our random number generator. + internal::Random random_; + + // The time of the test program start, in ms from the start of the + // UNIX epoch. + TimeInMillis start_timestamp_; + + // How long the test took to run, in milliseconds. + TimeInMillis elapsed_time_; + +#if GTEST_HAS_DEATH_TEST + // The decomposed components of the gtest_internal_run_death_test flag, + // parsed when RUN_ALL_TESTS is called. + std::unique_ptr internal_run_death_test_flag_; + std::unique_ptr death_test_factory_; +#endif // GTEST_HAS_DEATH_TEST + + // A per-thread stack of traces created by the SCOPED_TRACE() macro. + internal::ThreadLocal > gtest_trace_stack_; + + // The value of GTEST_FLAG(catch_exceptions) at the moment RunAllTests() + // starts. + bool catch_exceptions_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTestImpl); +}; // class UnitTestImpl + +// Convenience function for accessing the global UnitTest +// implementation object. +inline UnitTestImpl* GetUnitTestImpl() { + return UnitTest::GetInstance()->impl(); +} + +#if GTEST_USES_SIMPLE_RE + +// Internal helper functions for implementing the simple regular +// expression matcher. +GTEST_API_ bool IsInSet(char ch, const char* str); +GTEST_API_ bool IsAsciiDigit(char ch); +GTEST_API_ bool IsAsciiPunct(char ch); +GTEST_API_ bool IsRepeat(char ch); +GTEST_API_ bool IsAsciiWhiteSpace(char ch); +GTEST_API_ bool IsAsciiWordChar(char ch); +GTEST_API_ bool IsValidEscape(char ch); +GTEST_API_ bool AtomMatchesChar(bool escaped, char pattern, char ch); +GTEST_API_ bool ValidateRegex(const char* regex); +GTEST_API_ bool MatchRegexAtHead(const char* regex, const char* str); +GTEST_API_ bool MatchRepetitionAndRegexAtHead( + bool escaped, char ch, char repeat, const char* regex, const char* str); +GTEST_API_ bool MatchRegexAnywhere(const char* regex, const char* str); + +#endif // GTEST_USES_SIMPLE_RE + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, char** argv); +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv); + +#if GTEST_HAS_DEATH_TEST + +// Returns the message describing the last system error, regardless of the +// platform. +GTEST_API_ std::string GetLastErrnoDescription(); + +// Attempts to parse a string into a positive integer pointed to by the +// number parameter. Returns true if that is possible. +// GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we can use +// it here. +template +bool ParseNaturalNumber(const ::std::string& str, Integer* number) { + // Fail fast if the given string does not begin with a digit; + // this bypasses strtoXXX's "optional leading whitespace and plus + // or minus sign" semantics, which are undesirable here. + if (str.empty() || !IsDigit(str[0])) { + return false; + } + errno = 0; + + char* end; + // BiggestConvertible is the largest integer type that system-provided + // string-to-number conversion routines can return. + using BiggestConvertible = unsigned long long; // NOLINT + + const BiggestConvertible parsed = strtoull(str.c_str(), &end, 10); // NOLINT + const bool parse_success = *end == '\0' && errno == 0; + + GTEST_CHECK_(sizeof(Integer) <= sizeof(parsed)); + + const Integer result = static_cast(parsed); + if (parse_success && static_cast(result) == parsed) { + *number = result; + return true; + } + return false; +} +#endif // GTEST_HAS_DEATH_TEST + +// TestResult contains some private methods that should be hidden from +// Google Test user but are required for testing. This class allow our tests +// to access them. +// +// This class is supplied only for the purpose of testing Google Test's own +// constructs. Do not use it in user tests, either directly or indirectly. +class TestResultAccessor { + public: + static void RecordProperty(TestResult* test_result, + const std::string& xml_element, + const TestProperty& property) { + test_result->RecordProperty(xml_element, property); + } + + static void ClearTestPartResults(TestResult* test_result) { + test_result->ClearTestPartResults(); + } + + static const std::vector& test_part_results( + const TestResult& test_result) { + return test_result.test_part_results(); + } +}; + +#if GTEST_CAN_STREAM_RESULTS_ + +// Streams test results to the given port on the given host machine. +class StreamingListener : public EmptyTestEventListener { + public: + // Abstract base class for writing strings to a socket. + class AbstractSocketWriter { + public: + virtual ~AbstractSocketWriter() {} + + // Sends a string to the socket. + virtual void Send(const std::string& message) = 0; + + // Closes the socket. + virtual void CloseConnection() {} + + // Sends a string and a newline to the socket. + void SendLn(const std::string& message) { Send(message + "\n"); } + }; + + // Concrete class for actually writing strings to a socket. + class SocketWriter : public AbstractSocketWriter { + public: + SocketWriter(const std::string& host, const std::string& port) + : sockfd_(-1), host_name_(host), port_num_(port) { + MakeConnection(); + } + + ~SocketWriter() override { + if (sockfd_ != -1) + CloseConnection(); + } + + // Sends a string to the socket. + void Send(const std::string& message) override { + GTEST_CHECK_(sockfd_ != -1) + << "Send() can be called only when there is a connection."; + + const auto len = static_cast(message.length()); + if (write(sockfd_, message.c_str(), len) != static_cast(len)) { + GTEST_LOG_(WARNING) + << "stream_result_to: failed to stream to " + << host_name_ << ":" << port_num_; + } + } + + private: + // Creates a client socket and connects to the server. + void MakeConnection(); + + // Closes the socket. + void CloseConnection() override { + GTEST_CHECK_(sockfd_ != -1) + << "CloseConnection() can be called only when there is a connection."; + + close(sockfd_); + sockfd_ = -1; + } + + int sockfd_; // socket file descriptor + const std::string host_name_; + const std::string port_num_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(SocketWriter); + }; // class SocketWriter + + // Escapes '=', '&', '%', and '\n' characters in str as "%xx". + static std::string UrlEncode(const char* str); + + StreamingListener(const std::string& host, const std::string& port) + : socket_writer_(new SocketWriter(host, port)) { + Start(); + } + + explicit StreamingListener(AbstractSocketWriter* socket_writer) + : socket_writer_(socket_writer) { Start(); } + + void OnTestProgramStart(const UnitTest& /* unit_test */) override { + SendLn("event=TestProgramStart"); + } + + void OnTestProgramEnd(const UnitTest& unit_test) override { + // Note that Google Test current only report elapsed time for each + // test iteration, not for the entire test program. + SendLn("event=TestProgramEnd&passed=" + FormatBool(unit_test.Passed())); + + // Notify the streaming server to stop. + socket_writer_->CloseConnection(); + } + + void OnTestIterationStart(const UnitTest& /* unit_test */, + int iteration) override { + SendLn("event=TestIterationStart&iteration=" + + StreamableToString(iteration)); + } + + void OnTestIterationEnd(const UnitTest& unit_test, + int /* iteration */) override { + SendLn("event=TestIterationEnd&passed=" + + FormatBool(unit_test.Passed()) + "&elapsed_time=" + + StreamableToString(unit_test.elapsed_time()) + "ms"); + } + + // Note that "event=TestCaseStart" is a wire format and has to remain + // "case" for compatibility + void OnTestSuiteStart(const TestSuite& test_suite) override { + SendLn(std::string("event=TestCaseStart&name=") + test_suite.name()); + } + + // Note that "event=TestCaseEnd" is a wire format and has to remain + // "case" for compatibility + void OnTestSuiteEnd(const TestSuite& test_suite) override { + SendLn("event=TestCaseEnd&passed=" + FormatBool(test_suite.Passed()) + + "&elapsed_time=" + StreamableToString(test_suite.elapsed_time()) + + "ms"); + } + + void OnTestStart(const TestInfo& test_info) override { + SendLn(std::string("event=TestStart&name=") + test_info.name()); + } + + void OnTestEnd(const TestInfo& test_info) override { + SendLn("event=TestEnd&passed=" + + FormatBool((test_info.result())->Passed()) + + "&elapsed_time=" + + StreamableToString((test_info.result())->elapsed_time()) + "ms"); + } + + void OnTestPartResult(const TestPartResult& test_part_result) override { + const char* file_name = test_part_result.file_name(); + if (file_name == nullptr) file_name = ""; + SendLn("event=TestPartResult&file=" + UrlEncode(file_name) + + "&line=" + StreamableToString(test_part_result.line_number()) + + "&message=" + UrlEncode(test_part_result.message())); + } + + private: + // Sends the given message and a newline to the socket. + void SendLn(const std::string& message) { socket_writer_->SendLn(message); } + + // Called at the start of streaming to notify the receiver what + // protocol we are using. + void Start() { SendLn("gtest_streaming_protocol_version=1.0"); } + + std::string FormatBool(bool value) { return value ? "1" : "0"; } + + const std::unique_ptr socket_writer_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamingListener); +}; // class StreamingListener + +#endif // GTEST_CAN_STREAM_RESULTS_ + +} // namespace internal +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLETEST_SRC_GTEST_INTERNAL_INL_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-matchers.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-matchers.cc new file mode 100644 index 000000000000..65104ebab1ba --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-matchers.cc @@ -0,0 +1,97 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file implements just enough of the matcher interface to allow +// EXPECT_DEATH and friends to accept a matcher argument. + +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-port.h" +#include "gtest/gtest-matchers.h" + +#include + +namespace testing { + +// Constructs a matcher that matches a const std::string& whose value is +// equal to s. +Matcher::Matcher(const std::string& s) { *this = Eq(s); } + +// Constructs a matcher that matches a const std::string& whose value is +// equal to s. +Matcher::Matcher(const char* s) { + *this = Eq(std::string(s)); +} + +// Constructs a matcher that matches a std::string whose value is equal to +// s. +Matcher::Matcher(const std::string& s) { *this = Eq(s); } + +// Constructs a matcher that matches a std::string whose value is equal to +// s. +Matcher::Matcher(const char* s) { *this = Eq(std::string(s)); } + +#if GTEST_INTERNAL_HAS_STRING_VIEW +// Constructs a matcher that matches a const StringView& whose value is +// equal to s. +Matcher::Matcher(const std::string& s) { + *this = Eq(s); +} + +// Constructs a matcher that matches a const StringView& whose value is +// equal to s. +Matcher::Matcher(const char* s) { + *this = Eq(std::string(s)); +} + +// Constructs a matcher that matches a const StringView& whose value is +// equal to s. +Matcher::Matcher(internal::StringView s) { + *this = Eq(std::string(s)); +} + +// Constructs a matcher that matches a StringView whose value is equal to +// s. +Matcher::Matcher(const std::string& s) { *this = Eq(s); } + +// Constructs a matcher that matches a StringView whose value is equal to +// s. +Matcher::Matcher(const char* s) { + *this = Eq(std::string(s)); +} + +// Constructs a matcher that matches a StringView whose value is equal to +// s. +Matcher::Matcher(internal::StringView s) { + *this = Eq(std::string(s)); +} +#endif // GTEST_INTERNAL_HAS_STRING_VIEW + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-port.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-port.cc new file mode 100644 index 000000000000..d47550aecfd3 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-port.cc @@ -0,0 +1,1413 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/internal/gtest-port.h" + +#include +#include +#include +#include +#include +#include +#include + +#if GTEST_OS_WINDOWS +# include +# include +# include +# include // Used in ThreadLocal. +# ifdef _MSC_VER +# include +# endif // _MSC_VER +#else +# include +#endif // GTEST_OS_WINDOWS + +#if GTEST_OS_MAC +# include +# include +# include +#endif // GTEST_OS_MAC + +#if GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD || \ + GTEST_OS_NETBSD || GTEST_OS_OPENBSD +# include +# if GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD +# include +# endif +#endif + +#if GTEST_OS_QNX +# include +# include +# include +#endif // GTEST_OS_QNX + +#if GTEST_OS_AIX +# include +# include +#endif // GTEST_OS_AIX + +#if GTEST_OS_FUCHSIA +# include +# include +#endif // GTEST_OS_FUCHSIA + +#include "gtest/gtest-spi.h" +#include "gtest/gtest-message.h" +#include "gtest/internal/gtest-internal.h" +#include "gtest/internal/gtest-string.h" +#include "src/gtest-internal-inl.h" + +namespace testing { +namespace internal { + +#if defined(_MSC_VER) || defined(__BORLANDC__) +// MSVC and C++Builder do not provide a definition of STDERR_FILENO. +const int kStdOutFileno = 1; +const int kStdErrFileno = 2; +#else +const int kStdOutFileno = STDOUT_FILENO; +const int kStdErrFileno = STDERR_FILENO; +#endif // _MSC_VER + +#if GTEST_OS_LINUX || GTEST_OS_GNU_HURD + +namespace { +template +T ReadProcFileField(const std::string& filename, int field) { + std::string dummy; + std::ifstream file(filename.c_str()); + while (field-- > 0) { + file >> dummy; + } + T output = 0; + file >> output; + return output; +} +} // namespace + +// Returns the number of active threads, or 0 when there is an error. +size_t GetThreadCount() { + const std::string filename = + (Message() << "/proc/" << getpid() << "/stat").GetString(); + return ReadProcFileField(filename, 19); +} + +#elif GTEST_OS_MAC + +size_t GetThreadCount() { + const task_t task = mach_task_self(); + mach_msg_type_number_t thread_count; + thread_act_array_t thread_list; + const kern_return_t status = task_threads(task, &thread_list, &thread_count); + if (status == KERN_SUCCESS) { + // task_threads allocates resources in thread_list and we need to free them + // to avoid leaks. + vm_deallocate(task, + reinterpret_cast(thread_list), + sizeof(thread_t) * thread_count); + return static_cast(thread_count); + } else { + return 0; + } +} + +#elif GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD || \ + GTEST_OS_NETBSD + +#if GTEST_OS_NETBSD +#undef KERN_PROC +#define KERN_PROC KERN_PROC2 +#define kinfo_proc kinfo_proc2 +#endif + +#if GTEST_OS_DRAGONFLY +#define KP_NLWP(kp) (kp.kp_nthreads) +#elif GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD +#define KP_NLWP(kp) (kp.ki_numthreads) +#elif GTEST_OS_NETBSD +#define KP_NLWP(kp) (kp.p_nlwps) +#endif + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +size_t GetThreadCount() { + int mib[] = { + CTL_KERN, + KERN_PROC, + KERN_PROC_PID, + getpid(), +#if GTEST_OS_NETBSD + sizeof(struct kinfo_proc), + 1, +#endif + }; + u_int miblen = sizeof(mib) / sizeof(mib[0]); + struct kinfo_proc info; + size_t size = sizeof(info); + if (sysctl(mib, miblen, &info, &size, NULL, 0)) { + return 0; + } + return static_cast(KP_NLWP(info)); +} +#elif GTEST_OS_OPENBSD + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +size_t GetThreadCount() { + int mib[] = { + CTL_KERN, + KERN_PROC, + KERN_PROC_PID | KERN_PROC_SHOW_THREADS, + getpid(), + sizeof(struct kinfo_proc), + 0, + }; + u_int miblen = sizeof(mib) / sizeof(mib[0]); + + // get number of structs + size_t size; + if (sysctl(mib, miblen, NULL, &size, NULL, 0)) { + return 0; + } + + mib[5] = static_cast(size / static_cast(mib[4])); + + // populate array of structs + struct kinfo_proc info[mib[5]]; + if (sysctl(mib, miblen, &info, &size, NULL, 0)) { + return 0; + } + + // exclude empty members + size_t nthreads = 0; + for (size_t i = 0; i < size / static_cast(mib[4]); i++) { + if (info[i].p_tid != -1) + nthreads++; + } + return nthreads; +} + +#elif GTEST_OS_QNX + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +size_t GetThreadCount() { + const int fd = open("/proc/self/as", O_RDONLY); + if (fd < 0) { + return 0; + } + procfs_info process_info; + const int status = + devctl(fd, DCMD_PROC_INFO, &process_info, sizeof(process_info), nullptr); + close(fd); + if (status == EOK) { + return static_cast(process_info.num_threads); + } else { + return 0; + } +} + +#elif GTEST_OS_AIX + +size_t GetThreadCount() { + struct procentry64 entry; + pid_t pid = getpid(); + int status = getprocs64(&entry, sizeof(entry), nullptr, 0, &pid, 1); + if (status == 1) { + return entry.pi_thcount; + } else { + return 0; + } +} + +#elif GTEST_OS_FUCHSIA + +size_t GetThreadCount() { + int dummy_buffer; + size_t avail; + zx_status_t status = zx_object_get_info( + zx_process_self(), + ZX_INFO_PROCESS_THREADS, + &dummy_buffer, + 0, + nullptr, + &avail); + if (status == ZX_OK) { + return avail; + } else { + return 0; + } +} + +#else + +size_t GetThreadCount() { + // There's no portable way to detect the number of threads, so we just + // return 0 to indicate that we cannot detect it. + return 0; +} + +#endif // GTEST_OS_LINUX + +#if GTEST_IS_THREADSAFE && GTEST_OS_WINDOWS + +AutoHandle::AutoHandle() + : handle_(INVALID_HANDLE_VALUE) {} + +AutoHandle::AutoHandle(Handle handle) + : handle_(handle) {} + +AutoHandle::~AutoHandle() { + Reset(); +} + +AutoHandle::Handle AutoHandle::Get() const { + return handle_; +} + +void AutoHandle::Reset() { + Reset(INVALID_HANDLE_VALUE); +} + +void AutoHandle::Reset(HANDLE handle) { + // Resetting with the same handle we already own is invalid. + if (handle_ != handle) { + if (IsCloseable()) { + ::CloseHandle(handle_); + } + handle_ = handle; + } else { + GTEST_CHECK_(!IsCloseable()) + << "Resetting a valid handle to itself is likely a programmer error " + "and thus not allowed."; + } +} + +bool AutoHandle::IsCloseable() const { + // Different Windows APIs may use either of these values to represent an + // invalid handle. + return handle_ != nullptr && handle_ != INVALID_HANDLE_VALUE; +} + +Mutex::Mutex() + : owner_thread_id_(0), + type_(kDynamic), + critical_section_init_phase_(0), + critical_section_(new CRITICAL_SECTION) { + ::InitializeCriticalSection(critical_section_); +} + +Mutex::~Mutex() { + // Static mutexes are leaked intentionally. It is not thread-safe to try + // to clean them up. + if (type_ == kDynamic) { + ::DeleteCriticalSection(critical_section_); + delete critical_section_; + critical_section_ = nullptr; + } +} + +void Mutex::Lock() { + ThreadSafeLazyInit(); + ::EnterCriticalSection(critical_section_); + owner_thread_id_ = ::GetCurrentThreadId(); +} + +void Mutex::Unlock() { + ThreadSafeLazyInit(); + // We don't protect writing to owner_thread_id_ here, as it's the + // caller's responsibility to ensure that the current thread holds the + // mutex when this is called. + owner_thread_id_ = 0; + ::LeaveCriticalSection(critical_section_); +} + +// Does nothing if the current thread holds the mutex. Otherwise, crashes +// with high probability. +void Mutex::AssertHeld() { + ThreadSafeLazyInit(); + GTEST_CHECK_(owner_thread_id_ == ::GetCurrentThreadId()) + << "The current thread is not holding the mutex @" << this; +} + +namespace { + +#ifdef _MSC_VER +// Use the RAII idiom to flag mem allocs that are intentionally never +// deallocated. The motivation is to silence the false positive mem leaks +// that are reported by the debug version of MS's CRT which can only detect +// if an alloc is missing a matching deallocation. +// Example: +// MemoryIsNotDeallocated memory_is_not_deallocated; +// critical_section_ = new CRITICAL_SECTION; +// +class MemoryIsNotDeallocated +{ + public: + MemoryIsNotDeallocated() : old_crtdbg_flag_(0) { + old_crtdbg_flag_ = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG); + // Set heap allocation block type to _IGNORE_BLOCK so that MS debug CRT + // doesn't report mem leak if there's no matching deallocation. + (void)_CrtSetDbgFlag(old_crtdbg_flag_ & ~_CRTDBG_ALLOC_MEM_DF); + } + + ~MemoryIsNotDeallocated() { + // Restore the original _CRTDBG_ALLOC_MEM_DF flag + (void)_CrtSetDbgFlag(old_crtdbg_flag_); + } + + private: + int old_crtdbg_flag_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(MemoryIsNotDeallocated); +}; +#endif // _MSC_VER + +} // namespace + +// Initializes owner_thread_id_ and critical_section_ in static mutexes. +void Mutex::ThreadSafeLazyInit() { + // Dynamic mutexes are initialized in the constructor. + if (type_ == kStatic) { + switch ( + ::InterlockedCompareExchange(&critical_section_init_phase_, 1L, 0L)) { + case 0: + // If critical_section_init_phase_ was 0 before the exchange, we + // are the first to test it and need to perform the initialization. + owner_thread_id_ = 0; + { + // Use RAII to flag that following mem alloc is never deallocated. +#ifdef _MSC_VER + MemoryIsNotDeallocated memory_is_not_deallocated; +#endif // _MSC_VER + critical_section_ = new CRITICAL_SECTION; + } + ::InitializeCriticalSection(critical_section_); + // Updates the critical_section_init_phase_ to 2 to signal + // initialization complete. + GTEST_CHECK_(::InterlockedCompareExchange( + &critical_section_init_phase_, 2L, 1L) == + 1L); + break; + case 1: + // Somebody else is already initializing the mutex; spin until they + // are done. + while (::InterlockedCompareExchange(&critical_section_init_phase_, + 2L, + 2L) != 2L) { + // Possibly yields the rest of the thread's time slice to other + // threads. + ::Sleep(0); + } + break; + + case 2: + break; // The mutex is already initialized and ready for use. + + default: + GTEST_CHECK_(false) + << "Unexpected value of critical_section_init_phase_ " + << "while initializing a static mutex."; + } + } +} + +namespace { + +class ThreadWithParamSupport : public ThreadWithParamBase { + public: + static HANDLE CreateThread(Runnable* runnable, + Notification* thread_can_start) { + ThreadMainParam* param = new ThreadMainParam(runnable, thread_can_start); + DWORD thread_id; + HANDLE thread_handle = ::CreateThread( + nullptr, // Default security. + 0, // Default stack size. + &ThreadWithParamSupport::ThreadMain, + param, // Parameter to ThreadMainStatic + 0x0, // Default creation flags. + &thread_id); // Need a valid pointer for the call to work under Win98. + GTEST_CHECK_(thread_handle != nullptr) + << "CreateThread failed with error " << ::GetLastError() << "."; + if (thread_handle == nullptr) { + delete param; + } + return thread_handle; + } + + private: + struct ThreadMainParam { + ThreadMainParam(Runnable* runnable, Notification* thread_can_start) + : runnable_(runnable), + thread_can_start_(thread_can_start) { + } + std::unique_ptr runnable_; + // Does not own. + Notification* thread_can_start_; + }; + + static DWORD WINAPI ThreadMain(void* ptr) { + // Transfers ownership. + std::unique_ptr param(static_cast(ptr)); + if (param->thread_can_start_ != nullptr) + param->thread_can_start_->WaitForNotification(); + param->runnable_->Run(); + return 0; + } + + // Prohibit instantiation. + ThreadWithParamSupport(); + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParamSupport); +}; + +} // namespace + +ThreadWithParamBase::ThreadWithParamBase(Runnable *runnable, + Notification* thread_can_start) + : thread_(ThreadWithParamSupport::CreateThread(runnable, + thread_can_start)) { +} + +ThreadWithParamBase::~ThreadWithParamBase() { + Join(); +} + +void ThreadWithParamBase::Join() { + GTEST_CHECK_(::WaitForSingleObject(thread_.Get(), INFINITE) == WAIT_OBJECT_0) + << "Failed to join the thread with error " << ::GetLastError() << "."; +} + +// Maps a thread to a set of ThreadIdToThreadLocals that have values +// instantiated on that thread and notifies them when the thread exits. A +// ThreadLocal instance is expected to persist until all threads it has +// values on have terminated. +class ThreadLocalRegistryImpl { + public: + // Registers thread_local_instance as having value on the current thread. + // Returns a value that can be used to identify the thread from other threads. + static ThreadLocalValueHolderBase* GetValueOnCurrentThread( + const ThreadLocalBase* thread_local_instance) { +#ifdef _MSC_VER + MemoryIsNotDeallocated memory_is_not_deallocated; +#endif // _MSC_VER + DWORD current_thread = ::GetCurrentThreadId(); + MutexLock lock(&mutex_); + ThreadIdToThreadLocals* const thread_to_thread_locals = + GetThreadLocalsMapLocked(); + ThreadIdToThreadLocals::iterator thread_local_pos = + thread_to_thread_locals->find(current_thread); + if (thread_local_pos == thread_to_thread_locals->end()) { + thread_local_pos = thread_to_thread_locals->insert( + std::make_pair(current_thread, ThreadLocalValues())).first; + StartWatcherThreadFor(current_thread); + } + ThreadLocalValues& thread_local_values = thread_local_pos->second; + ThreadLocalValues::iterator value_pos = + thread_local_values.find(thread_local_instance); + if (value_pos == thread_local_values.end()) { + value_pos = + thread_local_values + .insert(std::make_pair( + thread_local_instance, + std::shared_ptr( + thread_local_instance->NewValueForCurrentThread()))) + .first; + } + return value_pos->second.get(); + } + + static void OnThreadLocalDestroyed( + const ThreadLocalBase* thread_local_instance) { + std::vector > value_holders; + // Clean up the ThreadLocalValues data structure while holding the lock, but + // defer the destruction of the ThreadLocalValueHolderBases. + { + MutexLock lock(&mutex_); + ThreadIdToThreadLocals* const thread_to_thread_locals = + GetThreadLocalsMapLocked(); + for (ThreadIdToThreadLocals::iterator it = + thread_to_thread_locals->begin(); + it != thread_to_thread_locals->end(); + ++it) { + ThreadLocalValues& thread_local_values = it->second; + ThreadLocalValues::iterator value_pos = + thread_local_values.find(thread_local_instance); + if (value_pos != thread_local_values.end()) { + value_holders.push_back(value_pos->second); + thread_local_values.erase(value_pos); + // This 'if' can only be successful at most once, so theoretically we + // could break out of the loop here, but we don't bother doing so. + } + } + } + // Outside the lock, let the destructor for 'value_holders' deallocate the + // ThreadLocalValueHolderBases. + } + + static void OnThreadExit(DWORD thread_id) { + GTEST_CHECK_(thread_id != 0) << ::GetLastError(); + std::vector > value_holders; + // Clean up the ThreadIdToThreadLocals data structure while holding the + // lock, but defer the destruction of the ThreadLocalValueHolderBases. + { + MutexLock lock(&mutex_); + ThreadIdToThreadLocals* const thread_to_thread_locals = + GetThreadLocalsMapLocked(); + ThreadIdToThreadLocals::iterator thread_local_pos = + thread_to_thread_locals->find(thread_id); + if (thread_local_pos != thread_to_thread_locals->end()) { + ThreadLocalValues& thread_local_values = thread_local_pos->second; + for (ThreadLocalValues::iterator value_pos = + thread_local_values.begin(); + value_pos != thread_local_values.end(); + ++value_pos) { + value_holders.push_back(value_pos->second); + } + thread_to_thread_locals->erase(thread_local_pos); + } + } + // Outside the lock, let the destructor for 'value_holders' deallocate the + // ThreadLocalValueHolderBases. + } + + private: + // In a particular thread, maps a ThreadLocal object to its value. + typedef std::map > + ThreadLocalValues; + // Stores all ThreadIdToThreadLocals having values in a thread, indexed by + // thread's ID. + typedef std::map ThreadIdToThreadLocals; + + // Holds the thread id and thread handle that we pass from + // StartWatcherThreadFor to WatcherThreadFunc. + typedef std::pair ThreadIdAndHandle; + + static void StartWatcherThreadFor(DWORD thread_id) { + // The returned handle will be kept in thread_map and closed by + // watcher_thread in WatcherThreadFunc. + HANDLE thread = ::OpenThread(SYNCHRONIZE | THREAD_QUERY_INFORMATION, + FALSE, + thread_id); + GTEST_CHECK_(thread != nullptr); + // We need to pass a valid thread ID pointer into CreateThread for it + // to work correctly under Win98. + DWORD watcher_thread_id; + HANDLE watcher_thread = ::CreateThread( + nullptr, // Default security. + 0, // Default stack size + &ThreadLocalRegistryImpl::WatcherThreadFunc, + reinterpret_cast(new ThreadIdAndHandle(thread_id, thread)), + CREATE_SUSPENDED, &watcher_thread_id); + GTEST_CHECK_(watcher_thread != nullptr) + << "CreateThread failed with error " << ::GetLastError() << "."; + // Give the watcher thread the same priority as ours to avoid being + // blocked by it. + ::SetThreadPriority(watcher_thread, + ::GetThreadPriority(::GetCurrentThread())); + ::ResumeThread(watcher_thread); + ::CloseHandle(watcher_thread); + } + + // Monitors exit from a given thread and notifies those + // ThreadIdToThreadLocals about thread termination. + static DWORD WINAPI WatcherThreadFunc(LPVOID param) { + const ThreadIdAndHandle* tah = + reinterpret_cast(param); + GTEST_CHECK_( + ::WaitForSingleObject(tah->second, INFINITE) == WAIT_OBJECT_0); + OnThreadExit(tah->first); + ::CloseHandle(tah->second); + delete tah; + return 0; + } + + // Returns map of thread local instances. + static ThreadIdToThreadLocals* GetThreadLocalsMapLocked() { + mutex_.AssertHeld(); +#ifdef _MSC_VER + MemoryIsNotDeallocated memory_is_not_deallocated; +#endif // _MSC_VER + static ThreadIdToThreadLocals* map = new ThreadIdToThreadLocals(); + return map; + } + + // Protects access to GetThreadLocalsMapLocked() and its return value. + static Mutex mutex_; + // Protects access to GetThreadMapLocked() and its return value. + static Mutex thread_map_mutex_; +}; + +Mutex ThreadLocalRegistryImpl::mutex_(Mutex::kStaticMutex); // NOLINT +Mutex ThreadLocalRegistryImpl::thread_map_mutex_(Mutex::kStaticMutex); // NOLINT + +ThreadLocalValueHolderBase* ThreadLocalRegistry::GetValueOnCurrentThread( + const ThreadLocalBase* thread_local_instance) { + return ThreadLocalRegistryImpl::GetValueOnCurrentThread( + thread_local_instance); +} + +void ThreadLocalRegistry::OnThreadLocalDestroyed( + const ThreadLocalBase* thread_local_instance) { + ThreadLocalRegistryImpl::OnThreadLocalDestroyed(thread_local_instance); +} + +#endif // GTEST_IS_THREADSAFE && GTEST_OS_WINDOWS + +#if GTEST_USES_POSIX_RE + +// Implements RE. Currently only needed for death tests. + +RE::~RE() { + if (is_valid_) { + // regfree'ing an invalid regex might crash because the content + // of the regex is undefined. Since the regex's are essentially + // the same, one cannot be valid (or invalid) without the other + // being so too. + regfree(&partial_regex_); + regfree(&full_regex_); + } + free(const_cast(pattern_)); +} + +// Returns true if and only if regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.full_regex_, str, 1, &match, 0) == 0; +} + +// Returns true if and only if regular expression re matches a substring of +// str (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.partial_regex_, str, 1, &match, 0) == 0; +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = posix::StrDup(regex); + + // Reserves enough bytes to hold the regular expression used for a + // full match. + const size_t full_regex_len = strlen(regex) + 10; + char* const full_pattern = new char[full_regex_len]; + + snprintf(full_pattern, full_regex_len, "^(%s)$", regex); + is_valid_ = regcomp(&full_regex_, full_pattern, REG_EXTENDED) == 0; + // We want to call regcomp(&partial_regex_, ...) even if the + // previous expression returns false. Otherwise partial_regex_ may + // not be properly initialized can may cause trouble when it's + // freed. + // + // Some implementation of POSIX regex (e.g. on at least some + // versions of Cygwin) doesn't accept the empty string as a valid + // regex. We change it to an equivalent form "()" to be safe. + if (is_valid_) { + const char* const partial_regex = (*regex == '\0') ? "()" : regex; + is_valid_ = regcomp(&partial_regex_, partial_regex, REG_EXTENDED) == 0; + } + EXPECT_TRUE(is_valid_) + << "Regular expression \"" << regex + << "\" is not a valid POSIX Extended regular expression."; + + delete[] full_pattern; +} + +#elif GTEST_USES_SIMPLE_RE + +// Returns true if and only if ch appears anywhere in str (excluding the +// terminating '\0' character). +bool IsInSet(char ch, const char* str) { + return ch != '\0' && strchr(str, ch) != nullptr; +} + +// Returns true if and only if ch belongs to the given classification. +// Unlike similar functions in , these aren't affected by the +// current locale. +bool IsAsciiDigit(char ch) { return '0' <= ch && ch <= '9'; } +bool IsAsciiPunct(char ch) { + return IsInSet(ch, "^-!\"#$%&'()*+,./:;<=>?@[\\]_`{|}~"); +} +bool IsRepeat(char ch) { return IsInSet(ch, "?*+"); } +bool IsAsciiWhiteSpace(char ch) { return IsInSet(ch, " \f\n\r\t\v"); } +bool IsAsciiWordChar(char ch) { + return ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || + ('0' <= ch && ch <= '9') || ch == '_'; +} + +// Returns true if and only if "\\c" is a supported escape sequence. +bool IsValidEscape(char c) { + return (IsAsciiPunct(c) || IsInSet(c, "dDfnrsStvwW")); +} + +// Returns true if and only if the given atom (specified by escaped and +// pattern) matches ch. The result is undefined if the atom is invalid. +bool AtomMatchesChar(bool escaped, char pattern_char, char ch) { + if (escaped) { // "\\p" where p is pattern_char. + switch (pattern_char) { + case 'd': return IsAsciiDigit(ch); + case 'D': return !IsAsciiDigit(ch); + case 'f': return ch == '\f'; + case 'n': return ch == '\n'; + case 'r': return ch == '\r'; + case 's': return IsAsciiWhiteSpace(ch); + case 'S': return !IsAsciiWhiteSpace(ch); + case 't': return ch == '\t'; + case 'v': return ch == '\v'; + case 'w': return IsAsciiWordChar(ch); + case 'W': return !IsAsciiWordChar(ch); + } + return IsAsciiPunct(pattern_char) && pattern_char == ch; + } + + return (pattern_char == '.' && ch != '\n') || pattern_char == ch; +} + +// Helper function used by ValidateRegex() to format error messages. +static std::string FormatRegexSyntaxError(const char* regex, int index) { + return (Message() << "Syntax error at index " << index + << " in simple regular expression \"" << regex << "\": ").GetString(); +} + +// Generates non-fatal failures and returns false if regex is invalid; +// otherwise returns true. +bool ValidateRegex(const char* regex) { + if (regex == nullptr) { + ADD_FAILURE() << "NULL is not a valid simple regular expression."; + return false; + } + + bool is_valid = true; + + // True if and only if ?, *, or + can follow the previous atom. + bool prev_repeatable = false; + for (int i = 0; regex[i]; i++) { + if (regex[i] == '\\') { // An escape sequence + i++; + if (regex[i] == '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "'\\' cannot appear at the end."; + return false; + } + + if (!IsValidEscape(regex[i])) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "invalid escape sequence \"\\" << regex[i] << "\"."; + is_valid = false; + } + prev_repeatable = true; + } else { // Not an escape sequence. + const char ch = regex[i]; + + if (ch == '^' && i > 0) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'^' can only appear at the beginning."; + is_valid = false; + } else if (ch == '$' && regex[i + 1] != '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'$' can only appear at the end."; + is_valid = false; + } else if (IsInSet(ch, "()[]{}|")) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' is unsupported."; + is_valid = false; + } else if (IsRepeat(ch) && !prev_repeatable) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' can only follow a repeatable token."; + is_valid = false; + } + + prev_repeatable = !IsInSet(ch, "^$?*+"); + } + } + + return is_valid; +} + +// Matches a repeated regex atom followed by a valid simple regular +// expression. The regex atom is defined as c if escaped is false, +// or \c otherwise. repeat is the repetition meta character (?, *, +// or +). The behavior is undefined if str contains too many +// characters to be indexable by size_t, in which case the test will +// probably time out anyway. We are fine with this limitation as +// std::string has it too. +bool MatchRepetitionAndRegexAtHead( + bool escaped, char c, char repeat, const char* regex, + const char* str) { + const size_t min_count = (repeat == '+') ? 1 : 0; + const size_t max_count = (repeat == '?') ? 1 : + static_cast(-1) - 1; + // We cannot call numeric_limits::max() as it conflicts with the + // max() macro on Windows. + + for (size_t i = 0; i <= max_count; ++i) { + // We know that the atom matches each of the first i characters in str. + if (i >= min_count && MatchRegexAtHead(regex, str + i)) { + // We have enough matches at the head, and the tail matches too. + // Since we only care about *whether* the pattern matches str + // (as opposed to *how* it matches), there is no need to find a + // greedy match. + return true; + } + if (str[i] == '\0' || !AtomMatchesChar(escaped, c, str[i])) + return false; + } + return false; +} + +// Returns true if and only if regex matches a prefix of str. regex must +// be a valid simple regular expression and not start with "^", or the +// result is undefined. +bool MatchRegexAtHead(const char* regex, const char* str) { + if (*regex == '\0') // An empty regex matches a prefix of anything. + return true; + + // "$" only matches the end of a string. Note that regex being + // valid guarantees that there's nothing after "$" in it. + if (*regex == '$') + return *str == '\0'; + + // Is the first thing in regex an escape sequence? + const bool escaped = *regex == '\\'; + if (escaped) + ++regex; + if (IsRepeat(regex[1])) { + // MatchRepetitionAndRegexAtHead() calls MatchRegexAtHead(), so + // here's an indirect recursion. It terminates as the regex gets + // shorter in each recursion. + return MatchRepetitionAndRegexAtHead( + escaped, regex[0], regex[1], regex + 2, str); + } else { + // regex isn't empty, isn't "$", and doesn't start with a + // repetition. We match the first atom of regex with the first + // character of str and recurse. + return (*str != '\0') && AtomMatchesChar(escaped, *regex, *str) && + MatchRegexAtHead(regex + 1, str + 1); + } +} + +// Returns true if and only if regex matches any substring of str. regex must +// be a valid simple regular expression, or the result is undefined. +// +// The algorithm is recursive, but the recursion depth doesn't exceed +// the regex length, so we won't need to worry about running out of +// stack space normally. In rare cases the time complexity can be +// exponential with respect to the regex length + the string length, +// but usually it's must faster (often close to linear). +bool MatchRegexAnywhere(const char* regex, const char* str) { + if (regex == nullptr || str == nullptr) return false; + + if (*regex == '^') + return MatchRegexAtHead(regex + 1, str); + + // A successful match can be anywhere in str. + do { + if (MatchRegexAtHead(regex, str)) + return true; + } while (*str++ != '\0'); + return false; +} + +// Implements the RE class. + +RE::~RE() { + free(const_cast(pattern_)); + free(const_cast(full_pattern_)); +} + +// Returns true if and only if regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.full_pattern_, str); +} + +// Returns true if and only if regular expression re matches a substring of +// str (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.pattern_, str); +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = full_pattern_ = nullptr; + if (regex != nullptr) { + pattern_ = posix::StrDup(regex); + } + + is_valid_ = ValidateRegex(regex); + if (!is_valid_) { + // No need to calculate the full pattern when the regex is invalid. + return; + } + + const size_t len = strlen(regex); + // Reserves enough bytes to hold the regular expression used for a + // full match: we need space to prepend a '^', append a '$', and + // terminate the string with '\0'. + char* buffer = static_cast(malloc(len + 3)); + full_pattern_ = buffer; + + if (*regex != '^') + *buffer++ = '^'; // Makes sure full_pattern_ starts with '^'. + + // We don't use snprintf or strncpy, as they trigger a warning when + // compiled with VC++ 8.0. + memcpy(buffer, regex, len); + buffer += len; + + if (len == 0 || regex[len - 1] != '$') + *buffer++ = '$'; // Makes sure full_pattern_ ends with '$'. + + *buffer = '\0'; +} + +#endif // GTEST_USES_POSIX_RE + +const char kUnknownFile[] = "unknown file"; + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line) { + const std::string file_name(file == nullptr ? kUnknownFile : file); + + if (line < 0) { + return file_name + ":"; + } +#ifdef _MSC_VER + return file_name + "(" + StreamableToString(line) + "):"; +#else + return file_name + ":" + StreamableToString(line) + ":"; +#endif // _MSC_VER +} + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +// Note that FormatCompilerIndependentFileLocation() does NOT append colon +// to the file location it produces, unlike FormatFileLocation(). +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation( + const char* file, int line) { + const std::string file_name(file == nullptr ? kUnknownFile : file); + + if (line < 0) + return file_name; + else + return file_name + ":" + StreamableToString(line); +} + +GTestLog::GTestLog(GTestLogSeverity severity, const char* file, int line) + : severity_(severity) { + const char* const marker = + severity == GTEST_INFO ? "[ INFO ]" : + severity == GTEST_WARNING ? "[WARNING]" : + severity == GTEST_ERROR ? "[ ERROR ]" : "[ FATAL ]"; + GetStream() << ::std::endl << marker << " " + << FormatFileLocation(file, line).c_str() << ": "; +} + +// Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. +GTestLog::~GTestLog() { + GetStream() << ::std::endl; + if (severity_ == GTEST_FATAL) { + fflush(stderr); + posix::Abort(); + } +} + +// Disable Microsoft deprecation warnings for POSIX functions called from +// this class (creat, dup, dup2, and close) +GTEST_DISABLE_MSC_DEPRECATED_PUSH_() + +#if GTEST_HAS_STREAM_REDIRECTION + +// Object that captures an output stream (stdout/stderr). +class CapturedStream { + public: + // The ctor redirects the stream to a temporary file. + explicit CapturedStream(int fd) : fd_(fd), uncaptured_fd_(dup(fd)) { +# if GTEST_OS_WINDOWS + char temp_dir_path[MAX_PATH + 1] = { '\0' }; // NOLINT + char temp_file_path[MAX_PATH + 1] = { '\0' }; // NOLINT + + ::GetTempPathA(sizeof(temp_dir_path), temp_dir_path); + const UINT success = ::GetTempFileNameA(temp_dir_path, + "gtest_redir", + 0, // Generate unique file name. + temp_file_path); + GTEST_CHECK_(success != 0) + << "Unable to create a temporary file in " << temp_dir_path; + const int captured_fd = creat(temp_file_path, _S_IREAD | _S_IWRITE); + GTEST_CHECK_(captured_fd != -1) << "Unable to open temporary file " + << temp_file_path; + filename_ = temp_file_path; +# else + // There's no guarantee that a test has write access to the current + // directory, so we create the temporary file in a temporary directory. + std::string name_template; + +# if GTEST_OS_LINUX_ANDROID + // Note: Android applications are expected to call the framework's + // Context.getExternalStorageDirectory() method through JNI to get + // the location of the world-writable SD Card directory. However, + // this requires a Context handle, which cannot be retrieved + // globally from native code. Doing so also precludes running the + // code as part of a regular standalone executable, which doesn't + // run in a Dalvik process (e.g. when running it through 'adb shell'). + // + // The location /data/local/tmp is directly accessible from native code. + // '/sdcard' and other variants cannot be relied on, as they are not + // guaranteed to be mounted, or may have a delay in mounting. + name_template = "/data/local/tmp/"; +# elif GTEST_OS_IOS + char user_temp_dir[PATH_MAX + 1]; + + // Documented alternative to NSTemporaryDirectory() (for obtaining creating + // a temporary directory) at + // https://developer.apple.com/library/archive/documentation/Security/Conceptual/SecureCodingGuide/Articles/RaceConditions.html#//apple_ref/doc/uid/TP40002585-SW10 + // + // _CS_DARWIN_USER_TEMP_DIR (as well as _CS_DARWIN_USER_CACHE_DIR) is not + // documented in the confstr() man page at + // https://developer.apple.com/library/archive/documentation/System/Conceptual/ManPages_iPhoneOS/man3/confstr.3.html#//apple_ref/doc/man/3/confstr + // but are still available, according to the WebKit patches at + // https://trac.webkit.org/changeset/262004/webkit + // https://trac.webkit.org/changeset/263705/webkit + // + // The confstr() implementation falls back to getenv("TMPDIR"). See + // https://opensource.apple.com/source/Libc/Libc-1439.100.3/gen/confstr.c.auto.html + ::confstr(_CS_DARWIN_USER_TEMP_DIR, user_temp_dir, sizeof(user_temp_dir)); + + name_template = user_temp_dir; + if (name_template.back() != GTEST_PATH_SEP_[0]) + name_template.push_back(GTEST_PATH_SEP_[0]); +# else + name_template = "/tmp/"; +# endif + name_template.append("gtest_captured_stream.XXXXXX"); + + // mkstemp() modifies the string bytes in place, and does not go beyond the + // string's length. This results in well-defined behavior in C++17. + // + // The const_cast is needed below C++17. The constraints on std::string + // implementations in C++11 and above make assumption behind the const_cast + // fairly safe. + const int captured_fd = ::mkstemp(const_cast(name_template.data())); + if (captured_fd == -1) { + GTEST_LOG_(WARNING) + << "Failed to create tmp file " << name_template + << " for test; does the test have access to the /tmp directory?"; + } + filename_ = std::move(name_template); +# endif // GTEST_OS_WINDOWS + fflush(nullptr); + dup2(captured_fd, fd_); + close(captured_fd); + } + + ~CapturedStream() { + remove(filename_.c_str()); + } + + std::string GetCapturedString() { + if (uncaptured_fd_ != -1) { + // Restores the original stream. + fflush(nullptr); + dup2(uncaptured_fd_, fd_); + close(uncaptured_fd_); + uncaptured_fd_ = -1; + } + + FILE* const file = posix::FOpen(filename_.c_str(), "r"); + if (file == nullptr) { + GTEST_LOG_(FATAL) << "Failed to open tmp file " << filename_ + << " for capturing stream."; + } + const std::string content = ReadEntireFile(file); + posix::FClose(file); + return content; + } + + private: + const int fd_; // A stream to capture. + int uncaptured_fd_; + // Name of the temporary file holding the stderr output. + ::std::string filename_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(CapturedStream); +}; + +GTEST_DISABLE_MSC_DEPRECATED_POP_() + +static CapturedStream* g_captured_stderr = nullptr; +static CapturedStream* g_captured_stdout = nullptr; + +// Starts capturing an output stream (stdout/stderr). +static void CaptureStream(int fd, const char* stream_name, + CapturedStream** stream) { + if (*stream != nullptr) { + GTEST_LOG_(FATAL) << "Only one " << stream_name + << " capturer can exist at a time."; + } + *stream = new CapturedStream(fd); +} + +// Stops capturing the output stream and returns the captured string. +static std::string GetCapturedStream(CapturedStream** captured_stream) { + const std::string content = (*captured_stream)->GetCapturedString(); + + delete *captured_stream; + *captured_stream = nullptr; + + return content; +} + +// Starts capturing stdout. +void CaptureStdout() { + CaptureStream(kStdOutFileno, "stdout", &g_captured_stdout); +} + +// Starts capturing stderr. +void CaptureStderr() { + CaptureStream(kStdErrFileno, "stderr", &g_captured_stderr); +} + +// Stops capturing stdout and returns the captured string. +std::string GetCapturedStdout() { + return GetCapturedStream(&g_captured_stdout); +} + +// Stops capturing stderr and returns the captured string. +std::string GetCapturedStderr() { + return GetCapturedStream(&g_captured_stderr); +} + +#endif // GTEST_HAS_STREAM_REDIRECTION + + + + + +size_t GetFileSize(FILE* file) { + fseek(file, 0, SEEK_END); + return static_cast(ftell(file)); +} + +std::string ReadEntireFile(FILE* file) { + const size_t file_size = GetFileSize(file); + char* const buffer = new char[file_size]; + + size_t bytes_last_read = 0; // # of bytes read in the last fread() + size_t bytes_read = 0; // # of bytes read so far + + fseek(file, 0, SEEK_SET); + + // Keeps reading the file until we cannot read further or the + // pre-determined file size is reached. + do { + bytes_last_read = fread(buffer+bytes_read, 1, file_size-bytes_read, file); + bytes_read += bytes_last_read; + } while (bytes_last_read > 0 && bytes_read < file_size); + + const std::string content(buffer, bytes_read); + delete[] buffer; + + return content; +} + +#if GTEST_HAS_DEATH_TEST +static const std::vector* g_injected_test_argvs = + nullptr; // Owned. + +std::vector GetInjectableArgvs() { + if (g_injected_test_argvs != nullptr) { + return *g_injected_test_argvs; + } + return GetArgvs(); +} + +void SetInjectableArgvs(const std::vector* new_argvs) { + if (g_injected_test_argvs != new_argvs) delete g_injected_test_argvs; + g_injected_test_argvs = new_argvs; +} + +void SetInjectableArgvs(const std::vector& new_argvs) { + SetInjectableArgvs( + new std::vector(new_argvs.begin(), new_argvs.end())); +} + +void ClearInjectableArgvs() { + delete g_injected_test_argvs; + g_injected_test_argvs = nullptr; +} +#endif // GTEST_HAS_DEATH_TEST + +#if GTEST_OS_WINDOWS_MOBILE +namespace posix { +void Abort() { + DebugBreak(); + TerminateProcess(GetCurrentProcess(), 1); +} +} // namespace posix +#endif // GTEST_OS_WINDOWS_MOBILE + +// Returns the name of the environment variable corresponding to the +// given flag. For example, FlagToEnvVar("foo") will return +// "GTEST_FOO" in the open-source version. +static std::string FlagToEnvVar(const char* flag) { + const std::string full_flag = + (Message() << GTEST_FLAG_PREFIX_ << flag).GetString(); + + Message env_var; + for (size_t i = 0; i != full_flag.length(); i++) { + env_var << ToUpper(full_flag.c_str()[i]); + } + + return env_var.GetString(); +} + +// Parses 'str' for a 32-bit signed integer. If successful, writes +// the result to *value and returns true; otherwise leaves *value +// unchanged and returns false. +bool ParseInt32(const Message& src_text, const char* str, int32_t* value) { + // Parses the environment variable as a decimal integer. + char* end = nullptr; + const long long_value = strtol(str, &end, 10); // NOLINT + + // Has strtol() consumed all characters in the string? + if (*end != '\0') { + // No - an invalid character was encountered. + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value \"" << str << "\".\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + // Is the parsed value in the range of an int32_t? + const auto result = static_cast(long_value); + if (long_value == LONG_MAX || long_value == LONG_MIN || + // The parsed value overflows as a long. (strtol() returns + // LONG_MAX or LONG_MIN when the input overflows.) + result != long_value + // The parsed value overflows as an int32_t. + ) { + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value " << str << ", which overflows.\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + *value = result; + return true; +} + +// Reads and returns the Boolean environment variable corresponding to +// the given flag; if it's not set, returns default_value. +// +// The value is considered true if and only if it's not "0". +bool BoolFromGTestEnv(const char* flag, bool default_value) { +#if defined(GTEST_GET_BOOL_FROM_ENV_) + return GTEST_GET_BOOL_FROM_ENV_(flag, default_value); +#else + const std::string env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + return string_value == nullptr ? default_value + : strcmp(string_value, "0") != 0; +#endif // defined(GTEST_GET_BOOL_FROM_ENV_) +} + +// Reads and returns a 32-bit integer stored in the environment +// variable corresponding to the given flag; if it isn't set or +// doesn't represent a valid 32-bit integer, returns default_value. +int32_t Int32FromGTestEnv(const char* flag, int32_t default_value) { +#if defined(GTEST_GET_INT32_FROM_ENV_) + return GTEST_GET_INT32_FROM_ENV_(flag, default_value); +#else + const std::string env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + if (string_value == nullptr) { + // The environment variable is not set. + return default_value; + } + + int32_t result = default_value; + if (!ParseInt32(Message() << "Environment variable " << env_var, + string_value, &result)) { + printf("The default value %s is used.\n", + (Message() << default_value).GetString().c_str()); + fflush(stdout); + return default_value; + } + + return result; +#endif // defined(GTEST_GET_INT32_FROM_ENV_) +} + +// As a special case for the 'output' flag, if GTEST_OUTPUT is not +// set, we look for XML_OUTPUT_FILE, which is set by the Bazel build +// system. The value of XML_OUTPUT_FILE is a filename without the +// "xml:" prefix of GTEST_OUTPUT. +// Note that this is meant to be called at the call site so it does +// not check that the flag is 'output' +// In essence this checks an env variable called XML_OUTPUT_FILE +// and if it is set we prepend "xml:" to its value, if it not set we return "" +std::string OutputFlagAlsoCheckEnvVar(){ + std::string default_value_for_output_flag = ""; + const char* xml_output_file_env = posix::GetEnv("XML_OUTPUT_FILE"); + if (nullptr != xml_output_file_env) { + default_value_for_output_flag = std::string("xml:") + xml_output_file_env; + } + return default_value_for_output_flag; +} + +// Reads and returns the string environment variable corresponding to +// the given flag; if it's not set, returns default_value. +const char* StringFromGTestEnv(const char* flag, const char* default_value) { +#if defined(GTEST_GET_STRING_FROM_ENV_) + return GTEST_GET_STRING_FROM_ENV_(flag, default_value); +#else + const std::string env_var = FlagToEnvVar(flag); + const char* const value = posix::GetEnv(env_var.c_str()); + return value == nullptr ? default_value : value; +#endif // defined(GTEST_GET_STRING_FROM_ENV_) +} + +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-printers.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-printers.cc new file mode 100644 index 000000000000..0c80ab7c1a1a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-printers.cc @@ -0,0 +1,578 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Google Test - The Google C++ Testing and Mocking Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); +// +// It uses the << operator when possible, and prints the bytes in the +// object otherwise. A user can override its behavior for a class +// type Foo by defining either operator<<(::std::ostream&, const Foo&) +// or void PrintTo(const Foo&, ::std::ostream*) in the namespace that +// defines Foo. + +#include "gtest/gtest-printers.h" + +#include + +#include +#include +#include +#include // NOLINT +#include +#include + +#include "gtest/internal/gtest-port.h" +#include "src/gtest-internal-inl.h" + +namespace testing { + +namespace { + +using ::std::ostream; + +// Prints a segment of bytes in the given object. +GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ +GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ +void PrintByteSegmentInObjectTo(const unsigned char* obj_bytes, size_t start, + size_t count, ostream* os) { + char text[5] = ""; + for (size_t i = 0; i != count; i++) { + const size_t j = start + i; + if (i != 0) { + // Organizes the bytes into groups of 2 for easy parsing by + // human. + if ((j % 2) == 0) + *os << ' '; + else + *os << '-'; + } + GTEST_SNPRINTF_(text, sizeof(text), "%02X", obj_bytes[j]); + *os << text; + } +} + +// Prints the bytes in the given value to the given ostream. +void PrintBytesInObjectToImpl(const unsigned char* obj_bytes, size_t count, + ostream* os) { + // Tells the user how big the object is. + *os << count << "-byte object <"; + + const size_t kThreshold = 132; + const size_t kChunkSize = 64; + // If the object size is bigger than kThreshold, we'll have to omit + // some details by printing only the first and the last kChunkSize + // bytes. + if (count < kThreshold) { + PrintByteSegmentInObjectTo(obj_bytes, 0, count, os); + } else { + PrintByteSegmentInObjectTo(obj_bytes, 0, kChunkSize, os); + *os << " ... "; + // Rounds up to 2-byte boundary. + const size_t resume_pos = (count - kChunkSize + 1)/2*2; + PrintByteSegmentInObjectTo(obj_bytes, resume_pos, count - resume_pos, os); + } + *os << ">"; +} + +// Helpers for widening a character to char32_t. Since the standard does not +// specify if char / wchar_t is signed or unsigned, it is important to first +// convert it to the unsigned type of the same width before widening it to +// char32_t. +template +char32_t ToChar32(CharType in) { + return static_cast( + static_cast::type>(in)); +} + +} // namespace + +namespace internal { + +// Delegates to PrintBytesInObjectToImpl() to print the bytes in the +// given object. The delegation simplifies the implementation, which +// uses the << operator and thus is easier done outside of the +// ::testing::internal namespace, which contains a << operator that +// sometimes conflicts with the one in STL. +void PrintBytesInObjectTo(const unsigned char* obj_bytes, size_t count, + ostream* os) { + PrintBytesInObjectToImpl(obj_bytes, count, os); +} + +// Depending on the value of a char (or wchar_t), we print it in one +// of three formats: +// - as is if it's a printable ASCII (e.g. 'a', '2', ' '), +// - as a hexadecimal escape sequence (e.g. '\x7F'), or +// - as a special escape sequence (e.g. '\r', '\n'). +enum CharFormat { + kAsIs, + kHexEscape, + kSpecialEscape +}; + +// Returns true if c is a printable ASCII character. We test the +// value of c directly instead of calling isprint(), which is buggy on +// Windows Mobile. +inline bool IsPrintableAscii(char32_t c) { return 0x20 <= c && c <= 0x7E; } + +// Prints c (of type char, char8_t, char16_t, char32_t, or wchar_t) as a +// character literal without the quotes, escaping it when necessary; returns how +// c was formatted. +template +static CharFormat PrintAsCharLiteralTo(Char c, ostream* os) { + const char32_t u_c = ToChar32(c); + switch (u_c) { + case L'\0': + *os << "\\0"; + break; + case L'\'': + *os << "\\'"; + break; + case L'\\': + *os << "\\\\"; + break; + case L'\a': + *os << "\\a"; + break; + case L'\b': + *os << "\\b"; + break; + case L'\f': + *os << "\\f"; + break; + case L'\n': + *os << "\\n"; + break; + case L'\r': + *os << "\\r"; + break; + case L'\t': + *os << "\\t"; + break; + case L'\v': + *os << "\\v"; + break; + default: + if (IsPrintableAscii(u_c)) { + *os << static_cast(c); + return kAsIs; + } else { + ostream::fmtflags flags = os->flags(); + *os << "\\x" << std::hex << std::uppercase << static_cast(u_c); + os->flags(flags); + return kHexEscape; + } + } + return kSpecialEscape; +} + +// Prints a char32_t c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsStringLiteralTo(char32_t c, ostream* os) { + switch (c) { + case L'\'': + *os << "'"; + return kAsIs; + case L'"': + *os << "\\\""; + return kSpecialEscape; + default: + return PrintAsCharLiteralTo(c, os); + } +} + +static const char* GetCharWidthPrefix(char) { + return ""; +} + +static const char* GetCharWidthPrefix(signed char) { + return ""; +} + +static const char* GetCharWidthPrefix(unsigned char) { + return ""; +} + +#ifdef __cpp_char8_t +static const char* GetCharWidthPrefix(char8_t) { + return "u8"; +} +#endif + +static const char* GetCharWidthPrefix(char16_t) { + return "u"; +} + +static const char* GetCharWidthPrefix(char32_t) { + return "U"; +} + +static const char* GetCharWidthPrefix(wchar_t) { + return "L"; +} + +// Prints a char c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsStringLiteralTo(char c, ostream* os) { + return PrintAsStringLiteralTo(ToChar32(c), os); +} + +#ifdef __cpp_char8_t +static CharFormat PrintAsStringLiteralTo(char8_t c, ostream* os) { + return PrintAsStringLiteralTo(ToChar32(c), os); +} +#endif + +static CharFormat PrintAsStringLiteralTo(char16_t c, ostream* os) { + return PrintAsStringLiteralTo(ToChar32(c), os); +} + +static CharFormat PrintAsStringLiteralTo(wchar_t c, ostream* os) { + return PrintAsStringLiteralTo(ToChar32(c), os); +} + +// Prints a character c (of type char, char8_t, char16_t, char32_t, or wchar_t) +// and its code. '\0' is printed as "'\\0'", other unprintable characters are +// also properly escaped using the standard C++ escape sequence. +template +void PrintCharAndCodeTo(Char c, ostream* os) { + // First, print c as a literal in the most readable form we can find. + *os << GetCharWidthPrefix(c) << "'"; + const CharFormat format = PrintAsCharLiteralTo(c, os); + *os << "'"; + + // To aid user debugging, we also print c's code in decimal, unless + // it's 0 (in which case c was printed as '\\0', making the code + // obvious). + if (c == 0) + return; + *os << " (" << static_cast(c); + + // For more convenience, we print c's code again in hexadecimal, + // unless c was already printed in the form '\x##' or the code is in + // [1, 9]. + if (format == kHexEscape || (1 <= c && c <= 9)) { + // Do nothing. + } else { + *os << ", 0x" << String::FormatHexInt(static_cast(c)); + } + *os << ")"; +} + +void PrintTo(unsigned char c, ::std::ostream* os) { PrintCharAndCodeTo(c, os); } +void PrintTo(signed char c, ::std::ostream* os) { PrintCharAndCodeTo(c, os); } + +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its code. L'\0' is printed as "L'\\0'". +void PrintTo(wchar_t wc, ostream* os) { PrintCharAndCodeTo(wc, os); } + +// TODO(dcheng): Consider making this delegate to PrintCharAndCodeTo() as well. +void PrintTo(char32_t c, ::std::ostream* os) { + *os << std::hex << "U+" << std::uppercase << std::setfill('0') << std::setw(4) + << static_cast(c); +} + +// gcc/clang __{u,}int128_t +#if defined(__SIZEOF_INT128__) +void PrintTo(__uint128_t v, ::std::ostream* os) { + if (v == 0) { + *os << "0"; + return; + } + + // Buffer large enough for ceil(log10(2^128))==39 and the null terminator + char buf[40]; + char* p = buf + sizeof(buf); + + // Some configurations have a __uint128_t, but no support for built in + // division. Do manual long division instead. + + uint64_t high = static_cast(v >> 64); + uint64_t low = static_cast(v); + + *--p = 0; + while (high != 0 || low != 0) { + uint64_t high_mod = high % 10; + high = high / 10; + // This is the long division algorithm specialized for a divisor of 10 and + // only two elements. + // Notable values: + // 2^64 / 10 == 1844674407370955161 + // 2^64 % 10 == 6 + const uint64_t carry = 6 * high_mod + low % 10; + low = low / 10 + high_mod * 1844674407370955161 + carry / 10; + + char digit = static_cast(carry % 10); + *--p = '0' + digit; + } + *os << p; +} +void PrintTo(__int128_t v, ::std::ostream* os) { + __uint128_t uv = static_cast<__uint128_t>(v); + if (v < 0) { + *os << "-"; + uv = -uv; + } + PrintTo(uv, os); +} +#endif // __SIZEOF_INT128__ + +// Prints the given array of characters to the ostream. CharType must be either +// char, char8_t, char16_t, char32_t, or wchar_t. +// The array starts at begin, the length is len, it may include '\0' characters +// and may not be NUL-terminated. +template +GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ +GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ +static CharFormat PrintCharsAsStringTo( + const CharType* begin, size_t len, ostream* os) { + const char* const quote_prefix = GetCharWidthPrefix(*begin); + *os << quote_prefix << "\""; + bool is_previous_hex = false; + CharFormat print_format = kAsIs; + for (size_t index = 0; index < len; ++index) { + const CharType cur = begin[index]; + if (is_previous_hex && IsXDigit(cur)) { + // Previous character is of '\x..' form and this character can be + // interpreted as another hexadecimal digit in its number. Break string to + // disambiguate. + *os << "\" " << quote_prefix << "\""; + } + is_previous_hex = PrintAsStringLiteralTo(cur, os) == kHexEscape; + // Remember if any characters required hex escaping. + if (is_previous_hex) { + print_format = kHexEscape; + } + } + *os << "\""; + return print_format; +} + +// Prints a (const) char/wchar_t array of 'len' elements, starting at address +// 'begin'. CharType must be either char or wchar_t. +template +GTEST_ATTRIBUTE_NO_SANITIZE_MEMORY_ +GTEST_ATTRIBUTE_NO_SANITIZE_ADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_HWADDRESS_ +GTEST_ATTRIBUTE_NO_SANITIZE_THREAD_ +static void UniversalPrintCharArray( + const CharType* begin, size_t len, ostream* os) { + // The code + // const char kFoo[] = "foo"; + // generates an array of 4, not 3, elements, with the last one being '\0'. + // + // Therefore when printing a char array, we don't print the last element if + // it's '\0', such that the output matches the string literal as it's + // written in the source code. + if (len > 0 && begin[len - 1] == '\0') { + PrintCharsAsStringTo(begin, len - 1, os); + return; + } + + // If, however, the last element in the array is not '\0', e.g. + // const char kFoo[] = { 'f', 'o', 'o' }; + // we must print the entire array. We also print a message to indicate + // that the array is not NUL-terminated. + PrintCharsAsStringTo(begin, len, os); + *os << " (no terminating NUL)"; +} + +// Prints a (const) char array of 'len' elements, starting at address 'begin'. +void UniversalPrintArray(const char* begin, size_t len, ostream* os) { + UniversalPrintCharArray(begin, len, os); +} + +#ifdef __cpp_char8_t +// Prints a (const) char8_t array of 'len' elements, starting at address +// 'begin'. +void UniversalPrintArray(const char8_t* begin, size_t len, ostream* os) { + UniversalPrintCharArray(begin, len, os); +} +#endif + +// Prints a (const) char16_t array of 'len' elements, starting at address +// 'begin'. +void UniversalPrintArray(const char16_t* begin, size_t len, ostream* os) { + UniversalPrintCharArray(begin, len, os); +} + +// Prints a (const) char32_t array of 'len' elements, starting at address +// 'begin'. +void UniversalPrintArray(const char32_t* begin, size_t len, ostream* os) { + UniversalPrintCharArray(begin, len, os); +} + +// Prints a (const) wchar_t array of 'len' elements, starting at address +// 'begin'. +void UniversalPrintArray(const wchar_t* begin, size_t len, ostream* os) { + UniversalPrintCharArray(begin, len, os); +} + +namespace { + +// Prints a null-terminated C-style string to the ostream. +template +void PrintCStringTo(const Char* s, ostream* os) { + if (s == nullptr) { + *os << "NULL"; + } else { + *os << ImplicitCast_(s) << " pointing to "; + PrintCharsAsStringTo(s, std::char_traits::length(s), os); + } +} + +} // anonymous namespace + +void PrintTo(const char* s, ostream* os) { PrintCStringTo(s, os); } + +#ifdef __cpp_char8_t +void PrintTo(const char8_t* s, ostream* os) { PrintCStringTo(s, os); } +#endif + +void PrintTo(const char16_t* s, ostream* os) { PrintCStringTo(s, os); } + +void PrintTo(const char32_t* s, ostream* os) { PrintCStringTo(s, os); } + +// MSVC compiler can be configured to define whar_t as a typedef +// of unsigned short. Defining an overload for const wchar_t* in that case +// would cause pointers to unsigned shorts be printed as wide strings, +// possibly accessing more memory than intended and causing invalid +// memory accesses. MSVC defines _NATIVE_WCHAR_T_DEFINED symbol when +// wchar_t is implemented as a native type. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Prints the given wide C string to the ostream. +void PrintTo(const wchar_t* s, ostream* os) { PrintCStringTo(s, os); } +#endif // wchar_t is native + +namespace { + +bool ContainsUnprintableControlCodes(const char* str, size_t length) { + const unsigned char *s = reinterpret_cast(str); + + for (size_t i = 0; i < length; i++) { + unsigned char ch = *s++; + if (std::iscntrl(ch)) { + switch (ch) { + case '\t': + case '\n': + case '\r': + break; + default: + return true; + } + } + } + return false; +} + +bool IsUTF8TrailByte(unsigned char t) { return 0x80 <= t && t<= 0xbf; } + +bool IsValidUTF8(const char* str, size_t length) { + const unsigned char *s = reinterpret_cast(str); + + for (size_t i = 0; i < length;) { + unsigned char lead = s[i++]; + + if (lead <= 0x7f) { + continue; // single-byte character (ASCII) 0..7F + } + if (lead < 0xc2) { + return false; // trail byte or non-shortest form + } else if (lead <= 0xdf && (i + 1) <= length && IsUTF8TrailByte(s[i])) { + ++i; // 2-byte character + } else if (0xe0 <= lead && lead <= 0xef && (i + 2) <= length && + IsUTF8TrailByte(s[i]) && + IsUTF8TrailByte(s[i + 1]) && + // check for non-shortest form and surrogate + (lead != 0xe0 || s[i] >= 0xa0) && + (lead != 0xed || s[i] < 0xa0)) { + i += 2; // 3-byte character + } else if (0xf0 <= lead && lead <= 0xf4 && (i + 3) <= length && + IsUTF8TrailByte(s[i]) && + IsUTF8TrailByte(s[i + 1]) && + IsUTF8TrailByte(s[i + 2]) && + // check for non-shortest form + (lead != 0xf0 || s[i] >= 0x90) && + (lead != 0xf4 || s[i] < 0x90)) { + i += 3; // 4-byte character + } else { + return false; + } + } + return true; +} + +void ConditionalPrintAsText(const char* str, size_t length, ostream* os) { + if (!ContainsUnprintableControlCodes(str, length) && + IsValidUTF8(str, length)) { + *os << "\n As Text: \"" << str << "\""; + } +} + +} // anonymous namespace + +void PrintStringTo(const ::std::string& s, ostream* os) { + if (PrintCharsAsStringTo(s.data(), s.size(), os) == kHexEscape) { + if (GTEST_FLAG_GET(print_utf8)) { + ConditionalPrintAsText(s.data(), s.size(), os); + } + } +} + +#ifdef __cpp_char8_t +void PrintU8StringTo(const ::std::u8string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} +#endif + +void PrintU16StringTo(const ::std::u16string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} + +void PrintU32StringTo(const ::std::u32string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} + +#if GTEST_HAS_STD_WSTRING +void PrintWideStringTo(const ::std::wstring& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_STD_WSTRING + +} // namespace internal + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-test-part.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-test-part.cc new file mode 100644 index 000000000000..a938683ceded --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-test-part.cc @@ -0,0 +1,108 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// The Google C++ Testing and Mocking Framework (Google Test) + +#include "gtest/gtest-test-part.h" + +#include "gtest/internal/gtest-port.h" +#include "src/gtest-internal-inl.h" + +namespace testing { + +using internal::GetUnitTestImpl; + +// Gets the summary of the failure message by omitting the stack trace +// in it. +std::string TestPartResult::ExtractSummary(const char* message) { + const char* const stack_trace = strstr(message, internal::kStackTraceMarker); + return stack_trace == nullptr ? message : std::string(message, stack_trace); +} + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result) { + return os << internal::FormatFileLocation(result.file_name(), + result.line_number()) + << " " + << (result.type() == TestPartResult::kSuccess + ? "Success" + : result.type() == TestPartResult::kSkip + ? "Skipped" + : result.type() == TestPartResult::kFatalFailure + ? "Fatal failure" + : "Non-fatal failure") + << ":\n" + << result.message() << std::endl; +} + +// Appends a TestPartResult to the array. +void TestPartResultArray::Append(const TestPartResult& result) { + array_.push_back(result); +} + +// Returns the TestPartResult at the given index (0-based). +const TestPartResult& TestPartResultArray::GetTestPartResult(int index) const { + if (index < 0 || index >= size()) { + printf("\nInvalid index (%d) into TestPartResultArray.\n", index); + internal::posix::Abort(); + } + + return array_[static_cast(index)]; +} + +// Returns the number of TestPartResult objects in the array. +int TestPartResultArray::size() const { + return static_cast(array_.size()); +} + +namespace internal { + +HasNewFatalFailureHelper::HasNewFatalFailureHelper() + : has_new_fatal_failure_(false), + original_reporter_(GetUnitTestImpl()-> + GetTestPartResultReporterForCurrentThread()) { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread(this); +} + +HasNewFatalFailureHelper::~HasNewFatalFailureHelper() { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread( + original_reporter_); +} + +void HasNewFatalFailureHelper::ReportTestPartResult( + const TestPartResult& result) { + if (result.fatally_failed()) + has_new_fatal_failure_ = true; + original_reporter_->ReportTestPartResult(result); +} + +} // namespace internal + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-typed-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-typed-test.cc new file mode 100644 index 000000000000..c02c3df65995 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest-typed-test.cc @@ -0,0 +1,107 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest-typed-test.h" + +#include "gtest/gtest.h" + +namespace testing { +namespace internal { + +// Skips to the first non-space char in str. Returns an empty string if str +// contains only whitespace characters. +static const char* SkipSpaces(const char* str) { + while (IsSpace(*str)) + str++; + return str; +} + +static std::vector SplitIntoTestNames(const char* src) { + std::vector name_vec; + src = SkipSpaces(src); + for (; src != nullptr; src = SkipComma(src)) { + name_vec.push_back(StripTrailingSpaces(GetPrefixUntilComma(src))); + } + return name_vec; +} + +// Verifies that registered_tests match the test names in +// registered_tests_; returns registered_tests if successful, or +// aborts the program otherwise. +const char* TypedTestSuitePState::VerifyRegisteredTestNames( + const char* test_suite_name, const char* file, int line, + const char* registered_tests) { + RegisterTypeParameterizedTestSuite(test_suite_name, CodeLocation(file, line)); + + typedef RegisteredTestsMap::const_iterator RegisteredTestIter; + registered_ = true; + + std::vector name_vec = SplitIntoTestNames(registered_tests); + + Message errors; + + std::set tests; + for (std::vector::const_iterator name_it = name_vec.begin(); + name_it != name_vec.end(); ++name_it) { + const std::string& name = *name_it; + if (tests.count(name) != 0) { + errors << "Test " << name << " is listed more than once.\n"; + continue; + } + + if (registered_tests_.count(name) != 0) { + tests.insert(name); + } else { + errors << "No test named " << name + << " can be found in this test suite.\n"; + } + } + + for (RegisteredTestIter it = registered_tests_.begin(); + it != registered_tests_.end(); + ++it) { + if (tests.count(it->first) == 0) { + errors << "You forgot to list test " << it->first << ".\n"; + } + } + + const std::string& errors_str = errors.GetString(); + if (errors_str != "") { + fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), + errors_str.c_str()); + fflush(stderr); + posix::Abort(); + } + + return registered_tests; +} + +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest.cc new file mode 100644 index 000000000000..46c3e7f7bd29 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest.cc @@ -0,0 +1,6824 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// The Google C++ Testing and Mocking Framework (Google Test) + +#include "gtest/gtest.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // NOLINT +#include +#include +#include +#include +#include +#include +#include // NOLINT +#include +#include + +#include "gtest/gtest-assertion-result.h" +#include "gtest/gtest-spi.h" +#include "gtest/internal/custom/gtest.h" + +#if GTEST_OS_LINUX + +# include // NOLINT +# include // NOLINT +# include // NOLINT +// Declares vsnprintf(). This header is not available on Windows. +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include + +#elif GTEST_OS_ZOS +# include // NOLINT + +// On z/OS we additionally need strings.h for strcasecmp. +# include // NOLINT + +#elif GTEST_OS_WINDOWS_MOBILE // We are on Windows CE. + +# include // NOLINT +# undef min + +#elif GTEST_OS_WINDOWS // We are on Windows proper. + +# include // NOLINT +# undef min + +#ifdef _MSC_VER +# include // NOLINT +#endif + +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include // NOLINT + +# if GTEST_OS_WINDOWS_MINGW +# include // NOLINT +# endif // GTEST_OS_WINDOWS_MINGW + +#else + +// cpplint thinks that the header is already included, so we want to +// silence it. +# include // NOLINT +# include // NOLINT + +#endif // GTEST_OS_LINUX + +#if GTEST_HAS_EXCEPTIONS +# include +#endif + +#if GTEST_CAN_STREAM_RESULTS_ +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include // NOLINT +#endif + +#include "src/gtest-internal-inl.h" + +#if GTEST_OS_WINDOWS +# define vsnprintf _vsnprintf +#endif // GTEST_OS_WINDOWS + +#if GTEST_OS_MAC +#ifndef GTEST_OS_IOS +#include +#endif +#endif + +#if GTEST_HAS_ABSL +#include "absl/debugging/failure_signal_handler.h" +#include "absl/debugging/stacktrace.h" +#include "absl/debugging/symbolize.h" +#include "absl/strings/str_cat.h" +#endif // GTEST_HAS_ABSL + +namespace testing { + +using internal::CountIf; +using internal::ForEach; +using internal::GetElementOr; +using internal::Shuffle; + +// Constants. + +// A test whose test suite name or test name matches this filter is +// disabled and not run. +static const char kDisableTestFilter[] = "DISABLED_*:*/DISABLED_*"; + +// A test suite whose name matches this filter is considered a death +// test suite and will be run before test suites whose name doesn't +// match this filter. +static const char kDeathTestSuiteFilter[] = "*DeathTest:*DeathTest/*"; + +// A test filter that matches everything. +static const char kUniversalFilter[] = "*"; + +// The default output format. +static const char kDefaultOutputFormat[] = "xml"; +// The default output file. +static const char kDefaultOutputFile[] = "test_detail"; + +// The environment variable name for the test shard index. +static const char kTestShardIndex[] = "GTEST_SHARD_INDEX"; +// The environment variable name for the total number of test shards. +static const char kTestTotalShards[] = "GTEST_TOTAL_SHARDS"; +// The environment variable name for the test shard status file. +static const char kTestShardStatusFile[] = "GTEST_SHARD_STATUS_FILE"; + +namespace internal { + +// The text used in failure messages to indicate the start of the +// stack trace. +const char kStackTraceMarker[] = "\nStack trace:\n"; + +// g_help_flag is true if and only if the --help flag or an equivalent form +// is specified on the command line. +bool g_help_flag = false; + +// Utility function to Open File for Writing +static FILE* OpenFileForWriting(const std::string& output_file) { + FILE* fileout = nullptr; + FilePath output_file_path(output_file); + FilePath output_dir(output_file_path.RemoveFileName()); + + if (output_dir.CreateDirectoriesRecursively()) { + fileout = posix::FOpen(output_file.c_str(), "w"); + } + if (fileout == nullptr) { + GTEST_LOG_(FATAL) << "Unable to open file \"" << output_file << "\""; + } + return fileout; +} + +} // namespace internal + +// Bazel passes in the argument to '--test_filter' via the TESTBRIDGE_TEST_ONLY +// environment variable. +static const char* GetDefaultFilter() { + const char* const testbridge_test_only = + internal::posix::GetEnv("TESTBRIDGE_TEST_ONLY"); + if (testbridge_test_only != nullptr) { + return testbridge_test_only; + } + return kUniversalFilter; +} + +// Bazel passes in the argument to '--test_runner_fail_fast' via the +// TESTBRIDGE_TEST_RUNNER_FAIL_FAST environment variable. +static bool GetDefaultFailFast() { + const char* const testbridge_test_runner_fail_fast = + internal::posix::GetEnv("TESTBRIDGE_TEST_RUNNER_FAIL_FAST"); + if (testbridge_test_runner_fail_fast != nullptr) { + return strcmp(testbridge_test_runner_fail_fast, "1") == 0; + } + return false; +} + +} // namespace testing + +GTEST_DEFINE_bool_( + fail_fast, + testing::internal::BoolFromGTestEnv("fail_fast", + testing::GetDefaultFailFast()), + "True if and only if a test failure should stop further test execution."); + +GTEST_DEFINE_bool_( + also_run_disabled_tests, + testing::internal::BoolFromGTestEnv("also_run_disabled_tests", false), + "Run disabled tests too, in addition to the tests normally being run."); + +GTEST_DEFINE_bool_( + break_on_failure, + testing::internal::BoolFromGTestEnv("break_on_failure", false), + "True if and only if a failed assertion should be a debugger " + "break-point."); + +GTEST_DEFINE_bool_(catch_exceptions, + testing::internal::BoolFromGTestEnv("catch_exceptions", + true), + "True if and only if " GTEST_NAME_ + " should catch exceptions and treat them as test failures."); + +GTEST_DEFINE_string_( + color, testing::internal::StringFromGTestEnv("color", "auto"), + "Whether to use colors in the output. Valid values: yes, no, " + "and auto. 'auto' means to use colors if the output is " + "being sent to a terminal and the TERM environment variable " + "is set to a terminal type that supports colors."); + +GTEST_DEFINE_string_( + filter, + testing::internal::StringFromGTestEnv("filter", + testing::GetDefaultFilter()), + "A colon-separated list of glob (not regex) patterns " + "for filtering the tests to run, optionally followed by a " + "'-' and a : separated list of negative patterns (tests to " + "exclude). A test is run if it matches one of the positive " + "patterns and does not match any of the negative patterns."); + +GTEST_DEFINE_bool_( + install_failure_signal_handler, + testing::internal::BoolFromGTestEnv("install_failure_signal_handler", + false), + "If true and supported on the current platform, " GTEST_NAME_ + " should " + "install a signal handler that dumps debugging information when fatal " + "signals are raised."); + +GTEST_DEFINE_bool_(list_tests, false, + "List all tests without running them."); + +// The net priority order after flag processing is thus: +// --gtest_output command line flag +// GTEST_OUTPUT environment variable +// XML_OUTPUT_FILE environment variable +// '' +GTEST_DEFINE_string_( + output, + testing::internal::StringFromGTestEnv( + "output", testing::internal::OutputFlagAlsoCheckEnvVar().c_str()), + "A format (defaults to \"xml\" but can be specified to be \"json\"), " + "optionally followed by a colon and an output file name or directory. " + "A directory is indicated by a trailing pathname separator. " + "Examples: \"xml:filename.xml\", \"xml::directoryname/\". " + "If a directory is specified, output files will be created " + "within that directory, with file-names based on the test " + "executable's name and, if necessary, made unique by adding " + "digits."); + +GTEST_DEFINE_bool_( + brief, testing::internal::BoolFromGTestEnv("brief", false), + "True if only test failures should be displayed in text output."); + +GTEST_DEFINE_bool_(print_time, + testing::internal::BoolFromGTestEnv("print_time", true), + "True if and only if " GTEST_NAME_ + " should display elapsed time in text output."); + +GTEST_DEFINE_bool_(print_utf8, + testing::internal::BoolFromGTestEnv("print_utf8", true), + "True if and only if " GTEST_NAME_ + " prints UTF8 characters as text."); + +GTEST_DEFINE_int32_( + random_seed, testing::internal::Int32FromGTestEnv("random_seed", 0), + "Random number seed to use when shuffling test orders. Must be in range " + "[1, 99999], or 0 to use a seed based on the current time."); + +GTEST_DEFINE_int32_( + repeat, testing::internal::Int32FromGTestEnv("repeat", 1), + "How many times to repeat each test. Specify a negative number " + "for repeating forever. Useful for shaking out flaky tests."); + +GTEST_DEFINE_bool_( + recreate_environments_when_repeating, + testing::internal::BoolFromGTestEnv("recreate_environments_when_repeating", + false), + "Controls whether global test environments are recreated for each repeat " + "of the tests. If set to false the global test environments are only set " + "up once, for the first iteration, and only torn down once, for the last. " + "Useful for shaking out flaky tests with stable, expensive test " + "environments. If --gtest_repeat is set to a negative number, meaning " + "there is no last run, the environments will always be recreated to avoid " + "leaks."); + +GTEST_DEFINE_bool_(show_internal_stack_frames, false, + "True if and only if " GTEST_NAME_ + " should include internal stack frames when " + "printing test failure stack traces."); + +GTEST_DEFINE_bool_(shuffle, + testing::internal::BoolFromGTestEnv("shuffle", false), + "True if and only if " GTEST_NAME_ + " should randomize tests' order on every run."); + +GTEST_DEFINE_int32_( + stack_trace_depth, + testing::internal::Int32FromGTestEnv("stack_trace_depth", + testing::kMaxStackTraceDepth), + "The maximum number of stack frames to print when an " + "assertion fails. The valid range is 0 through 100, inclusive."); + +GTEST_DEFINE_string_( + stream_result_to, + testing::internal::StringFromGTestEnv("stream_result_to", ""), + "This flag specifies the host name and the port number on which to stream " + "test results. Example: \"localhost:555\". The flag is effective only on " + "Linux."); + +GTEST_DEFINE_bool_( + throw_on_failure, + testing::internal::BoolFromGTestEnv("throw_on_failure", false), + "When this flag is specified, a failed assertion will throw an exception " + "if exceptions are enabled or exit the program with a non-zero code " + "otherwise. For use with an external test framework."); + +#if GTEST_USE_OWN_FLAGFILE_FLAG_ +GTEST_DEFINE_string_( + flagfile, testing::internal::StringFromGTestEnv("flagfile", ""), + "This flag specifies the flagfile to read command-line flags from."); +#endif // GTEST_USE_OWN_FLAGFILE_FLAG_ + +namespace testing { +namespace internal { + +// Generates a random number from [0, range), using a Linear +// Congruential Generator (LCG). Crashes if 'range' is 0 or greater +// than kMaxRange. +uint32_t Random::Generate(uint32_t range) { + // These constants are the same as are used in glibc's rand(3). + // Use wider types than necessary to prevent unsigned overflow diagnostics. + state_ = static_cast(1103515245ULL*state_ + 12345U) % kMaxRange; + + GTEST_CHECK_(range > 0) + << "Cannot generate a number in the range [0, 0)."; + GTEST_CHECK_(range <= kMaxRange) + << "Generation of a number in [0, " << range << ") was requested, " + << "but this can only generate numbers in [0, " << kMaxRange << ")."; + + // Converting via modulus introduces a bit of downward bias, but + // it's simple, and a linear congruential generator isn't too good + // to begin with. + return state_ % range; +} + +// GTestIsInitialized() returns true if and only if the user has initialized +// Google Test. Useful for catching the user mistake of not initializing +// Google Test before calling RUN_ALL_TESTS(). +static bool GTestIsInitialized() { return GetArgvs().size() > 0; } + +// Iterates over a vector of TestSuites, keeping a running sum of the +// results of calling a given int-returning method on each. +// Returns the sum. +static int SumOverTestSuiteList(const std::vector& case_list, + int (TestSuite::*method)() const) { + int sum = 0; + for (size_t i = 0; i < case_list.size(); i++) { + sum += (case_list[i]->*method)(); + } + return sum; +} + +// Returns true if and only if the test suite passed. +static bool TestSuitePassed(const TestSuite* test_suite) { + return test_suite->should_run() && test_suite->Passed(); +} + +// Returns true if and only if the test suite failed. +static bool TestSuiteFailed(const TestSuite* test_suite) { + return test_suite->should_run() && test_suite->Failed(); +} + +// Returns true if and only if test_suite contains at least one test that +// should run. +static bool ShouldRunTestSuite(const TestSuite* test_suite) { + return test_suite->should_run(); +} + +// AssertHelper constructor. +AssertHelper::AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message) + : data_(new AssertHelperData(type, file, line, message)) { +} + +AssertHelper::~AssertHelper() { + delete data_; +} + +// Message assignment, for assertion streaming support. +void AssertHelper::operator=(const Message& message) const { + UnitTest::GetInstance()-> + AddTestPartResult(data_->type, data_->file, data_->line, + AppendUserMessage(data_->message, message), + UnitTest::GetInstance()->impl() + ->CurrentOsStackTraceExceptTop(1) + // Skips the stack frame for this function itself. + ); // NOLINT +} + +namespace { + +// When TEST_P is found without a matching INSTANTIATE_TEST_SUITE_P +// to creates test cases for it, a synthetic test case is +// inserted to report ether an error or a log message. +// +// This configuration bit will likely be removed at some point. +constexpr bool kErrorOnUninstantiatedParameterizedTest = true; +constexpr bool kErrorOnUninstantiatedTypeParameterizedTest = true; + +// A test that fails at a given file/line location with a given message. +class FailureTest : public Test { + public: + explicit FailureTest(const CodeLocation& loc, std::string error_message, + bool as_error) + : loc_(loc), + error_message_(std::move(error_message)), + as_error_(as_error) {} + + void TestBody() override { + if (as_error_) { + AssertHelper(TestPartResult::kNonFatalFailure, loc_.file.c_str(), + loc_.line, "") = Message() << error_message_; + } else { + std::cout << error_message_ << std::endl; + } + } + + private: + const CodeLocation loc_; + const std::string error_message_; + const bool as_error_; +}; + + +} // namespace + +std::set* GetIgnoredParameterizedTestSuites() { + return UnitTest::GetInstance()->impl()->ignored_parameterized_test_suites(); +} + +// Add a given test_suit to the list of them allow to go un-instantiated. +MarkAsIgnored::MarkAsIgnored(const char* test_suite) { + GetIgnoredParameterizedTestSuites()->insert(test_suite); +} + +// If this parameterized test suite has no instantiations (and that +// has not been marked as okay), emit a test case reporting that. +void InsertSyntheticTestCase(const std::string& name, CodeLocation location, + bool has_test_p) { + const auto& ignored = *GetIgnoredParameterizedTestSuites(); + if (ignored.find(name) != ignored.end()) return; + + const char kMissingInstantiation[] = // + " is defined via TEST_P, but never instantiated. None of the test cases " + "will run. Either no INSTANTIATE_TEST_SUITE_P is provided or the only " + "ones provided expand to nothing." + "\n\n" + "Ideally, TEST_P definitions should only ever be included as part of " + "binaries that intend to use them. (As opposed to, for example, being " + "placed in a library that may be linked in to get other utilities.)"; + + const char kMissingTestCase[] = // + " is instantiated via INSTANTIATE_TEST_SUITE_P, but no tests are " + "defined via TEST_P . No test cases will run." + "\n\n" + "Ideally, INSTANTIATE_TEST_SUITE_P should only ever be invoked from " + "code that always depend on code that provides TEST_P. Failing to do " + "so is often an indication of dead code, e.g. the last TEST_P was " + "removed but the rest got left behind."; + + std::string message = + "Parameterized test suite " + name + + (has_test_p ? kMissingInstantiation : kMissingTestCase) + + "\n\n" + "To suppress this error for this test suite, insert the following line " + "(in a non-header) in the namespace it is defined in:" + "\n\n" + "GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(" + name + ");"; + + std::string full_name = "UninstantiatedParameterizedTestSuite<" + name + ">"; + RegisterTest( // + "GoogleTestVerification", full_name.c_str(), + nullptr, // No type parameter. + nullptr, // No value parameter. + location.file.c_str(), location.line, [message, location] { + return new FailureTest(location, message, + kErrorOnUninstantiatedParameterizedTest); + }); +} + +void RegisterTypeParameterizedTestSuite(const char* test_suite_name, + CodeLocation code_location) { + GetUnitTestImpl()->type_parameterized_test_registry().RegisterTestSuite( + test_suite_name, code_location); +} + +void RegisterTypeParameterizedTestSuiteInstantiation(const char* case_name) { + GetUnitTestImpl() + ->type_parameterized_test_registry() + .RegisterInstantiation(case_name); +} + +void TypeParameterizedTestSuiteRegistry::RegisterTestSuite( + const char* test_suite_name, CodeLocation code_location) { + suites_.emplace(std::string(test_suite_name), + TypeParameterizedTestSuiteInfo(code_location)); +} + +void TypeParameterizedTestSuiteRegistry::RegisterInstantiation( + const char* test_suite_name) { + auto it = suites_.find(std::string(test_suite_name)); + if (it != suites_.end()) { + it->second.instantiated = true; + } else { + GTEST_LOG_(ERROR) << "Unknown type parameterized test suit '" + << test_suite_name << "'"; + } +} + +void TypeParameterizedTestSuiteRegistry::CheckForInstantiations() { + const auto& ignored = *GetIgnoredParameterizedTestSuites(); + for (const auto& testcase : suites_) { + if (testcase.second.instantiated) continue; + if (ignored.find(testcase.first) != ignored.end()) continue; + + std::string message = + "Type parameterized test suite " + testcase.first + + " is defined via REGISTER_TYPED_TEST_SUITE_P, but never instantiated " + "via INSTANTIATE_TYPED_TEST_SUITE_P. None of the test cases will run." + "\n\n" + "Ideally, TYPED_TEST_P definitions should only ever be included as " + "part of binaries that intend to use them. (As opposed to, for " + "example, being placed in a library that may be linked in to get other " + "utilities.)" + "\n\n" + "To suppress this error for this test suite, insert the following line " + "(in a non-header) in the namespace it is defined in:" + "\n\n" + "GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(" + + testcase.first + ");"; + + std::string full_name = + "UninstantiatedTypeParameterizedTestSuite<" + testcase.first + ">"; + RegisterTest( // + "GoogleTestVerification", full_name.c_str(), + nullptr, // No type parameter. + nullptr, // No value parameter. + testcase.second.code_location.file.c_str(), + testcase.second.code_location.line, [message, testcase] { + return new FailureTest(testcase.second.code_location, message, + kErrorOnUninstantiatedTypeParameterizedTest); + }); + } +} + +// A copy of all command line arguments. Set by InitGoogleTest(). +static ::std::vector g_argvs; + +::std::vector GetArgvs() { +#if defined(GTEST_CUSTOM_GET_ARGVS_) + // GTEST_CUSTOM_GET_ARGVS_() may return a container of std::string or + // ::string. This code converts it to the appropriate type. + const auto& custom = GTEST_CUSTOM_GET_ARGVS_(); + return ::std::vector(custom.begin(), custom.end()); +#else // defined(GTEST_CUSTOM_GET_ARGVS_) + return g_argvs; +#endif // defined(GTEST_CUSTOM_GET_ARGVS_) +} + +// Returns the current application's name, removing directory path if that +// is present. +FilePath GetCurrentExecutableName() { + FilePath result; + +#if GTEST_OS_WINDOWS || GTEST_OS_OS2 + result.Set(FilePath(GetArgvs()[0]).RemoveExtension("exe")); +#else + result.Set(FilePath(GetArgvs()[0])); +#endif // GTEST_OS_WINDOWS + + return result.RemoveDirectoryName(); +} + +// Functions for processing the gtest_output flag. + +// Returns the output format, or "" for normal printed output. +std::string UnitTestOptions::GetOutputFormat() { + std::string s = GTEST_FLAG_GET(output); + const char* const gtest_output_flag = s.c_str(); + const char* const colon = strchr(gtest_output_flag, ':'); + return (colon == nullptr) + ? std::string(gtest_output_flag) + : std::string(gtest_output_flag, + static_cast(colon - gtest_output_flag)); +} + +// Returns the name of the requested output file, or the default if none +// was explicitly specified. +std::string UnitTestOptions::GetAbsolutePathToOutputFile() { + std::string s = GTEST_FLAG_GET(output); + const char* const gtest_output_flag = s.c_str(); + + std::string format = GetOutputFormat(); + if (format.empty()) + format = std::string(kDefaultOutputFormat); + + const char* const colon = strchr(gtest_output_flag, ':'); + if (colon == nullptr) + return internal::FilePath::MakeFileName( + internal::FilePath( + UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(kDefaultOutputFile), 0, + format.c_str()).string(); + + internal::FilePath output_name(colon + 1); + if (!output_name.IsAbsolutePath()) + output_name = internal::FilePath::ConcatPaths( + internal::FilePath(UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(colon + 1)); + + if (!output_name.IsDirectory()) + return output_name.string(); + + internal::FilePath result(internal::FilePath::GenerateUniqueFileName( + output_name, internal::GetCurrentExecutableName(), + GetOutputFormat().c_str())); + return result.string(); +} + +// Returns true if and only if the wildcard pattern matches the string. Each +// pattern consists of regular characters, single-character wildcards (?), and +// multi-character wildcards (*). +// +// This function implements a linear-time string globbing algorithm based on +// https://research.swtch.com/glob. +static bool PatternMatchesString(const std::string& name_str, + const char* pattern, const char* pattern_end) { + const char* name = name_str.c_str(); + const char* const name_begin = name; + const char* const name_end = name + name_str.size(); + + const char* pattern_next = pattern; + const char* name_next = name; + + while (pattern < pattern_end || name < name_end) { + if (pattern < pattern_end) { + switch (*pattern) { + default: // Match an ordinary character. + if (name < name_end && *name == *pattern) { + ++pattern; + ++name; + continue; + } + break; + case '?': // Match any single character. + if (name < name_end) { + ++pattern; + ++name; + continue; + } + break; + case '*': + // Match zero or more characters. Start by skipping over the wildcard + // and matching zero characters from name. If that fails, restart and + // match one more character than the last attempt. + pattern_next = pattern; + name_next = name + 1; + ++pattern; + continue; + } + } + // Failed to match a character. Restart if possible. + if (name_begin < name_next && name_next <= name_end) { + pattern = pattern_next; + name = name_next; + continue; + } + return false; + } + return true; +} + +namespace { + +class UnitTestFilter { + public: + UnitTestFilter() = default; + + // Constructs a filter from a string of patterns separated by `:`. + explicit UnitTestFilter(const std::string& filter) { + // By design "" filter matches "" string. + SplitString(filter, ':', &patterns_); + } + + // Returns true if and only if name matches at least one of the patterns in + // the filter. + bool MatchesName(const std::string& name) const { + return std::any_of(patterns_.begin(), patterns_.end(), + [&name](const std::string& pattern) { + return PatternMatchesString( + name, pattern.c_str(), + pattern.c_str() + pattern.size()); + }); + } + + private: + std::vector patterns_; +}; + +class PositiveAndNegativeUnitTestFilter { + public: + // Constructs a positive and a negative filter from a string. The string + // contains a positive filter optionally followed by a '-' character and a + // negative filter. In case only a negative filter is provided the positive + // filter will be assumed "*". + // A filter is a list of patterns separated by ':'. + explicit PositiveAndNegativeUnitTestFilter(const std::string& filter) { + std::vector positive_and_negative_filters; + + // NOTE: `SplitString` always returns a non-empty container. + SplitString(filter, '-', &positive_and_negative_filters); + const auto& positive_filter = positive_and_negative_filters.front(); + + if (positive_and_negative_filters.size() > 1) { + positive_filter_ = UnitTestFilter( + positive_filter.empty() ? kUniversalFilter : positive_filter); + + // TODO(b/214626361): Fail on multiple '-' characters + // For the moment to preserve old behavior we concatenate the rest of the + // string parts with `-` as separator to generate the negative filter. + auto negative_filter_string = positive_and_negative_filters[1]; + for (std::size_t i = 2; i < positive_and_negative_filters.size(); i++) + negative_filter_string = + negative_filter_string + '-' + positive_and_negative_filters[i]; + negative_filter_ = UnitTestFilter(negative_filter_string); + } else { + // In case we don't have a negative filter and positive filter is "" + // we do not use kUniversalFilter by design as opposed to when we have a + // negative filter. + positive_filter_ = UnitTestFilter(positive_filter); + } + } + + // Returns true if and only if test name (this is generated by appending test + // suit name and test name via a '.' character) matches the positive filter + // and does not match the negative filter. + bool MatchesTest(const std::string& test_suite_name, + const std::string& test_name) const { + return MatchesName(test_suite_name + "." + test_name); + } + + // Returns true if and only if name matches the positive filter and does not + // match the negative filter. + bool MatchesName(const std::string& name) const { + return positive_filter_.MatchesName(name) && + !negative_filter_.MatchesName(name); + } + + private: + UnitTestFilter positive_filter_; + UnitTestFilter negative_filter_; +}; +} // namespace + +bool UnitTestOptions::MatchesFilter(const std::string& name_str, + const char* filter) { + return UnitTestFilter(filter).MatchesName(name_str); +} + +// Returns true if and only if the user-specified filter matches the test +// suite name and the test name. +bool UnitTestOptions::FilterMatchesTest(const std::string& test_suite_name, + const std::string& test_name) { + // Split --gtest_filter at '-', if there is one, to separate into + // positive filter and negative filter portions + return PositiveAndNegativeUnitTestFilter(GTEST_FLAG_GET(filter)) + .MatchesTest(test_suite_name, test_name); +} + +#if GTEST_HAS_SEH +// Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the +// given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. +// This function is useful as an __except condition. +int UnitTestOptions::GTestShouldProcessSEH(DWORD exception_code) { + // Google Test should handle a SEH exception if: + // 1. the user wants it to, AND + // 2. this is not a breakpoint exception, AND + // 3. this is not a C++ exception (VC++ implements them via SEH, + // apparently). + // + // SEH exception code for C++ exceptions. + // (see http://support.microsoft.com/kb/185294 for more information). + const DWORD kCxxExceptionCode = 0xe06d7363; + + bool should_handle = true; + + if (!GTEST_FLAG_GET(catch_exceptions)) + should_handle = false; + else if (exception_code == EXCEPTION_BREAKPOINT) + should_handle = false; + else if (exception_code == kCxxExceptionCode) + should_handle = false; + + return should_handle ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH; +} +#endif // GTEST_HAS_SEH + +} // namespace internal + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. Intercepts only failures from the current thread. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + TestPartResultArray* result) + : intercept_mode_(INTERCEPT_ONLY_CURRENT_THREAD), + result_(result) { + Init(); +} + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + InterceptMode intercept_mode, TestPartResultArray* result) + : intercept_mode_(intercept_mode), + result_(result) { + Init(); +} + +void ScopedFakeTestPartResultReporter::Init() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + old_reporter_ = impl->GetGlobalTestPartResultReporter(); + impl->SetGlobalTestPartResultReporter(this); + } else { + old_reporter_ = impl->GetTestPartResultReporterForCurrentThread(); + impl->SetTestPartResultReporterForCurrentThread(this); + } +} + +// The d'tor restores the test part result reporter used by Google Test +// before. +ScopedFakeTestPartResultReporter::~ScopedFakeTestPartResultReporter() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + impl->SetGlobalTestPartResultReporter(old_reporter_); + } else { + impl->SetTestPartResultReporterForCurrentThread(old_reporter_); + } +} + +// Increments the test part result count and remembers the result. +// This method is from the TestPartResultReporterInterface interface. +void ScopedFakeTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + result_->Append(result); +} + +namespace internal { + +// Returns the type ID of ::testing::Test. We should always call this +// instead of GetTypeId< ::testing::Test>() to get the type ID of +// testing::Test. This is to work around a suspected linker bug when +// using Google Test as a framework on Mac OS X. The bug causes +// GetTypeId< ::testing::Test>() to return different values depending +// on whether the call is from the Google Test framework itself or +// from user test code. GetTestTypeId() is guaranteed to always +// return the same value, as it always calls GetTypeId<>() from the +// gtest.cc, which is within the Google Test framework. +TypeId GetTestTypeId() { + return GetTypeId(); +} + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +extern const TypeId kTestTypeIdInGoogleTest = GetTestTypeId(); + +// This predicate-formatter checks that 'results' contains a test part +// failure of the given type and that the failure message contains the +// given substring. +static AssertionResult HasOneFailure(const char* /* results_expr */, + const char* /* type_expr */, + const char* /* substr_expr */, + const TestPartResultArray& results, + TestPartResult::Type type, + const std::string& substr) { + const std::string expected(type == TestPartResult::kFatalFailure ? + "1 fatal failure" : + "1 non-fatal failure"); + Message msg; + if (results.size() != 1) { + msg << "Expected: " << expected << "\n" + << " Actual: " << results.size() << " failures"; + for (int i = 0; i < results.size(); i++) { + msg << "\n" << results.GetTestPartResult(i); + } + return AssertionFailure() << msg; + } + + const TestPartResult& r = results.GetTestPartResult(0); + if (r.type() != type) { + return AssertionFailure() << "Expected: " << expected << "\n" + << " Actual:\n" + << r; + } + + if (strstr(r.message(), substr.c_str()) == nullptr) { + return AssertionFailure() << "Expected: " << expected << " containing \"" + << substr << "\"\n" + << " Actual:\n" + << r; + } + + return AssertionSuccess(); +} + +// The constructor of SingleFailureChecker remembers where to look up +// test part results, what type of failure we expect, and what +// substring the failure message should contain. +SingleFailureChecker::SingleFailureChecker(const TestPartResultArray* results, + TestPartResult::Type type, + const std::string& substr) + : results_(results), type_(type), substr_(substr) {} + +// The destructor of SingleFailureChecker verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. +SingleFailureChecker::~SingleFailureChecker() { + EXPECT_PRED_FORMAT3(HasOneFailure, *results_, type_, substr_); +} + +DefaultGlobalTestPartResultReporter::DefaultGlobalTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultGlobalTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->current_test_result()->AddTestPartResult(result); + unit_test_->listeners()->repeater()->OnTestPartResult(result); +} + +DefaultPerThreadTestPartResultReporter::DefaultPerThreadTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultPerThreadTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->GetGlobalTestPartResultReporter()->ReportTestPartResult(result); +} + +// Returns the global test part result reporter. +TestPartResultReporterInterface* +UnitTestImpl::GetGlobalTestPartResultReporter() { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + return global_test_part_result_repoter_; +} + +// Sets the global test part result reporter. +void UnitTestImpl::SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter) { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + global_test_part_result_repoter_ = reporter; +} + +// Returns the test part result reporter for the current thread. +TestPartResultReporterInterface* +UnitTestImpl::GetTestPartResultReporterForCurrentThread() { + return per_thread_test_part_result_reporter_.get(); +} + +// Sets the test part result reporter for the current thread. +void UnitTestImpl::SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter) { + per_thread_test_part_result_reporter_.set(reporter); +} + +// Gets the number of successful test suites. +int UnitTestImpl::successful_test_suite_count() const { + return CountIf(test_suites_, TestSuitePassed); +} + +// Gets the number of failed test suites. +int UnitTestImpl::failed_test_suite_count() const { + return CountIf(test_suites_, TestSuiteFailed); +} + +// Gets the number of all test suites. +int UnitTestImpl::total_test_suite_count() const { + return static_cast(test_suites_.size()); +} + +// Gets the number of all test suites that contain at least one test +// that should run. +int UnitTestImpl::test_suite_to_run_count() const { + return CountIf(test_suites_, ShouldRunTestSuite); +} + +// Gets the number of successful tests. +int UnitTestImpl::successful_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::successful_test_count); +} + +// Gets the number of skipped tests. +int UnitTestImpl::skipped_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::skipped_test_count); +} + +// Gets the number of failed tests. +int UnitTestImpl::failed_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::failed_test_count); +} + +// Gets the number of disabled tests that will be reported in the XML report. +int UnitTestImpl::reportable_disabled_test_count() const { + return SumOverTestSuiteList(test_suites_, + &TestSuite::reportable_disabled_test_count); +} + +// Gets the number of disabled tests. +int UnitTestImpl::disabled_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::disabled_test_count); +} + +// Gets the number of tests to be printed in the XML report. +int UnitTestImpl::reportable_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::reportable_test_count); +} + +// Gets the number of all tests. +int UnitTestImpl::total_test_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::total_test_count); +} + +// Gets the number of tests that should run. +int UnitTestImpl::test_to_run_count() const { + return SumOverTestSuiteList(test_suites_, &TestSuite::test_to_run_count); +} + +// Returns the current OS stack trace as an std::string. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// CurrentOsStackTraceExceptTop(1), Foo() will be included in the +// trace but Bar() and CurrentOsStackTraceExceptTop() won't. +std::string UnitTestImpl::CurrentOsStackTraceExceptTop(int skip_count) { + return os_stack_trace_getter()->CurrentStackTrace( + static_cast(GTEST_FLAG_GET(stack_trace_depth)), skip_count + 1 + // Skips the user-specified number of frames plus this function + // itself. + ); // NOLINT +} + +// A helper class for measuring elapsed times. +class Timer { + public: + Timer() : start_(std::chrono::steady_clock::now()) {} + + // Return time elapsed in milliseconds since the timer was created. + TimeInMillis Elapsed() { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_) + .count(); + } + + private: + std::chrono::steady_clock::time_point start_; +}; + +// Returns a timestamp as milliseconds since the epoch. Note this time may jump +// around subject to adjustments by the system, to measure elapsed time use +// Timer instead. +TimeInMillis GetTimeInMillis() { + return std::chrono::duration_cast( + std::chrono::system_clock::now() - + std::chrono::system_clock::from_time_t(0)) + .count(); +} + +// Utilities + +// class String. + +#if GTEST_OS_WINDOWS_MOBILE +// Creates a UTF-16 wide string from the given ANSI string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the wide string, or NULL if the +// input is NULL. +LPCWSTR String::AnsiToUtf16(const char* ansi) { + if (!ansi) return nullptr; + const int length = strlen(ansi); + const int unicode_length = + MultiByteToWideChar(CP_ACP, 0, ansi, length, nullptr, 0); + WCHAR* unicode = new WCHAR[unicode_length + 1]; + MultiByteToWideChar(CP_ACP, 0, ansi, length, + unicode, unicode_length); + unicode[unicode_length] = 0; + return unicode; +} + +// Creates an ANSI string from the given wide string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the ANSI string, or NULL if the +// input is NULL. +const char* String::Utf16ToAnsi(LPCWSTR utf16_str) { + if (!utf16_str) return nullptr; + const int ansi_length = WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, nullptr, + 0, nullptr, nullptr); + char* ansi = new char[ansi_length + 1]; + WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, ansi, ansi_length, nullptr, + nullptr); + ansi[ansi_length] = 0; + return ansi; +} + +#endif // GTEST_OS_WINDOWS_MOBILE + +// Compares two C strings. Returns true if and only if they have the same +// content. +// +// Unlike strcmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CStringEquals(const char * lhs, const char * rhs) { + if (lhs == nullptr) return rhs == nullptr; + + if (rhs == nullptr) return false; + + return strcmp(lhs, rhs) == 0; +} + +#if GTEST_HAS_STD_WSTRING + +// Converts an array of wide chars to a narrow string using the UTF-8 +// encoding, and streams the result to the given Message object. +static void StreamWideCharsToMessage(const wchar_t* wstr, size_t length, + Message* msg) { + for (size_t i = 0; i != length; ) { // NOLINT + if (wstr[i] != L'\0') { + *msg << WideStringToUtf8(wstr + i, static_cast(length - i)); + while (i != length && wstr[i] != L'\0') + i++; + } else { + *msg << '\0'; + i++; + } + } +} + +#endif // GTEST_HAS_STD_WSTRING + +void SplitString(const ::std::string& str, char delimiter, + ::std::vector< ::std::string>* dest) { + ::std::vector< ::std::string> parsed; + ::std::string::size_type pos = 0; + while (::testing::internal::AlwaysTrue()) { + const ::std::string::size_type colon = str.find(delimiter, pos); + if (colon == ::std::string::npos) { + parsed.push_back(str.substr(pos)); + break; + } else { + parsed.push_back(str.substr(pos, colon - pos)); + pos = colon + 1; + } + } + dest->swap(parsed); +} + +} // namespace internal + +// Constructs an empty Message. +// We allocate the stringstream separately because otherwise each use of +// ASSERT/EXPECT in a procedure adds over 200 bytes to the procedure's +// stack frame leading to huge stack frames in some cases; gcc does not reuse +// the stack space. +Message::Message() : ss_(new ::std::stringstream) { + // By default, we want there to be enough precision when printing + // a double to a Message. + *ss_ << std::setprecision(std::numeric_limits::digits10 + 2); +} + +// These two overloads allow streaming a wide C string to a Message +// using the UTF-8 encoding. +Message& Message::operator <<(const wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); +} +Message& Message::operator <<(wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); +} + +#if GTEST_HAS_STD_WSTRING +// Converts the given wide string to a narrow string using the UTF-8 +// encoding, and streams the result to this Message object. +Message& Message::operator <<(const ::std::wstring& wstr) { + internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); + return *this; +} +#endif // GTEST_HAS_STD_WSTRING + +// Gets the text streamed to this object so far as an std::string. +// Each '\0' character in the buffer is replaced with "\\0". +std::string Message::GetString() const { + return internal::StringStreamToString(ss_.get()); +} + +namespace internal { + +namespace edit_distance { +std::vector CalculateOptimalEdits(const std::vector& left, + const std::vector& right) { + std::vector > costs( + left.size() + 1, std::vector(right.size() + 1)); + std::vector > best_move( + left.size() + 1, std::vector(right.size() + 1)); + + // Populate for empty right. + for (size_t l_i = 0; l_i < costs.size(); ++l_i) { + costs[l_i][0] = static_cast(l_i); + best_move[l_i][0] = kRemove; + } + // Populate for empty left. + for (size_t r_i = 1; r_i < costs[0].size(); ++r_i) { + costs[0][r_i] = static_cast(r_i); + best_move[0][r_i] = kAdd; + } + + for (size_t l_i = 0; l_i < left.size(); ++l_i) { + for (size_t r_i = 0; r_i < right.size(); ++r_i) { + if (left[l_i] == right[r_i]) { + // Found a match. Consume it. + costs[l_i + 1][r_i + 1] = costs[l_i][r_i]; + best_move[l_i + 1][r_i + 1] = kMatch; + continue; + } + + const double add = costs[l_i + 1][r_i]; + const double remove = costs[l_i][r_i + 1]; + const double replace = costs[l_i][r_i]; + if (add < remove && add < replace) { + costs[l_i + 1][r_i + 1] = add + 1; + best_move[l_i + 1][r_i + 1] = kAdd; + } else if (remove < add && remove < replace) { + costs[l_i + 1][r_i + 1] = remove + 1; + best_move[l_i + 1][r_i + 1] = kRemove; + } else { + // We make replace a little more expensive than add/remove to lower + // their priority. + costs[l_i + 1][r_i + 1] = replace + 1.00001; + best_move[l_i + 1][r_i + 1] = kReplace; + } + } + } + + // Reconstruct the best path. We do it in reverse order. + std::vector best_path; + for (size_t l_i = left.size(), r_i = right.size(); l_i > 0 || r_i > 0;) { + EditType move = best_move[l_i][r_i]; + best_path.push_back(move); + l_i -= move != kAdd; + r_i -= move != kRemove; + } + std::reverse(best_path.begin(), best_path.end()); + return best_path; +} + +namespace { + +// Helper class to convert string into ids with deduplication. +class InternalStrings { + public: + size_t GetId(const std::string& str) { + IdMap::iterator it = ids_.find(str); + if (it != ids_.end()) return it->second; + size_t id = ids_.size(); + return ids_[str] = id; + } + + private: + typedef std::map IdMap; + IdMap ids_; +}; + +} // namespace + +std::vector CalculateOptimalEdits( + const std::vector& left, + const std::vector& right) { + std::vector left_ids, right_ids; + { + InternalStrings intern_table; + for (size_t i = 0; i < left.size(); ++i) { + left_ids.push_back(intern_table.GetId(left[i])); + } + for (size_t i = 0; i < right.size(); ++i) { + right_ids.push_back(intern_table.GetId(right[i])); + } + } + return CalculateOptimalEdits(left_ids, right_ids); +} + +namespace { + +// Helper class that holds the state for one hunk and prints it out to the +// stream. +// It reorders adds/removes when possible to group all removes before all +// adds. It also adds the hunk header before printint into the stream. +class Hunk { + public: + Hunk(size_t left_start, size_t right_start) + : left_start_(left_start), + right_start_(right_start), + adds_(), + removes_(), + common_() {} + + void PushLine(char edit, const char* line) { + switch (edit) { + case ' ': + ++common_; + FlushEdits(); + hunk_.push_back(std::make_pair(' ', line)); + break; + case '-': + ++removes_; + hunk_removes_.push_back(std::make_pair('-', line)); + break; + case '+': + ++adds_; + hunk_adds_.push_back(std::make_pair('+', line)); + break; + } + } + + void PrintTo(std::ostream* os) { + PrintHeader(os); + FlushEdits(); + for (std::list >::const_iterator it = + hunk_.begin(); + it != hunk_.end(); ++it) { + *os << it->first << it->second << "\n"; + } + } + + bool has_edits() const { return adds_ || removes_; } + + private: + void FlushEdits() { + hunk_.splice(hunk_.end(), hunk_removes_); + hunk_.splice(hunk_.end(), hunk_adds_); + } + + // Print a unified diff header for one hunk. + // The format is + // "@@ -, +, @@" + // where the left/right parts are omitted if unnecessary. + void PrintHeader(std::ostream* ss) const { + *ss << "@@ "; + if (removes_) { + *ss << "-" << left_start_ << "," << (removes_ + common_); + } + if (removes_ && adds_) { + *ss << " "; + } + if (adds_) { + *ss << "+" << right_start_ << "," << (adds_ + common_); + } + *ss << " @@\n"; + } + + size_t left_start_, right_start_; + size_t adds_, removes_, common_; + std::list > hunk_, hunk_adds_, hunk_removes_; +}; + +} // namespace + +// Create a list of diff hunks in Unified diff format. +// Each hunk has a header generated by PrintHeader above plus a body with +// lines prefixed with ' ' for no change, '-' for deletion and '+' for +// addition. +// 'context' represents the desired unchanged prefix/suffix around the diff. +// If two hunks are close enough that their contexts overlap, then they are +// joined into one hunk. +std::string CreateUnifiedDiff(const std::vector& left, + const std::vector& right, + size_t context) { + const std::vector edits = CalculateOptimalEdits(left, right); + + size_t l_i = 0, r_i = 0, edit_i = 0; + std::stringstream ss; + while (edit_i < edits.size()) { + // Find first edit. + while (edit_i < edits.size() && edits[edit_i] == kMatch) { + ++l_i; + ++r_i; + ++edit_i; + } + + // Find the first line to include in the hunk. + const size_t prefix_context = std::min(l_i, context); + Hunk hunk(l_i - prefix_context + 1, r_i - prefix_context + 1); + for (size_t i = prefix_context; i > 0; --i) { + hunk.PushLine(' ', left[l_i - i].c_str()); + } + + // Iterate the edits until we found enough suffix for the hunk or the input + // is over. + size_t n_suffix = 0; + for (; edit_i < edits.size(); ++edit_i) { + if (n_suffix >= context) { + // Continue only if the next hunk is very close. + auto it = edits.begin() + static_cast(edit_i); + while (it != edits.end() && *it == kMatch) ++it; + if (it == edits.end() || + static_cast(it - edits.begin()) - edit_i >= context) { + // There is no next edit or it is too far away. + break; + } + } + + EditType edit = edits[edit_i]; + // Reset count when a non match is found. + n_suffix = edit == kMatch ? n_suffix + 1 : 0; + + if (edit == kMatch || edit == kRemove || edit == kReplace) { + hunk.PushLine(edit == kMatch ? ' ' : '-', left[l_i].c_str()); + } + if (edit == kAdd || edit == kReplace) { + hunk.PushLine('+', right[r_i].c_str()); + } + + // Advance indices, depending on edit type. + l_i += edit != kAdd; + r_i += edit != kRemove; + } + + if (!hunk.has_edits()) { + // We are done. We don't want this hunk. + break; + } + + hunk.PrintTo(&ss); + } + return ss.str(); +} + +} // namespace edit_distance + +namespace { + +// The string representation of the values received in EqFailure() are already +// escaped. Split them on escaped '\n' boundaries. Leave all other escaped +// characters the same. +std::vector SplitEscapedString(const std::string& str) { + std::vector lines; + size_t start = 0, end = str.size(); + if (end > 2 && str[0] == '"' && str[end - 1] == '"') { + ++start; + --end; + } + bool escaped = false; + for (size_t i = start; i + 1 < end; ++i) { + if (escaped) { + escaped = false; + if (str[i] == 'n') { + lines.push_back(str.substr(start, i - start - 1)); + start = i + 1; + } + } else { + escaped = str[i] == '\\'; + } + } + lines.push_back(str.substr(start, end - start)); + return lines; +} + +} // namespace + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// lhs_expression: "foo" +// rhs_expression: "bar" +// lhs_value: "5" +// rhs_value: "6" +// +// The ignoring_case parameter is true if and only if the assertion is a +// *_STRCASEEQ*. When it's true, the string "Ignoring case" will +// be inserted into the message. +AssertionResult EqFailure(const char* lhs_expression, + const char* rhs_expression, + const std::string& lhs_value, + const std::string& rhs_value, + bool ignoring_case) { + Message msg; + msg << "Expected equality of these values:"; + msg << "\n " << lhs_expression; + if (lhs_value != lhs_expression) { + msg << "\n Which is: " << lhs_value; + } + msg << "\n " << rhs_expression; + if (rhs_value != rhs_expression) { + msg << "\n Which is: " << rhs_value; + } + + if (ignoring_case) { + msg << "\nIgnoring case"; + } + + if (!lhs_value.empty() && !rhs_value.empty()) { + const std::vector lhs_lines = + SplitEscapedString(lhs_value); + const std::vector rhs_lines = + SplitEscapedString(rhs_value); + if (lhs_lines.size() > 1 || rhs_lines.size() > 1) { + msg << "\nWith diff:\n" + << edit_distance::CreateUnifiedDiff(lhs_lines, rhs_lines); + } + } + + return AssertionFailure() << msg; +} + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +std::string GetBoolAssertionFailureMessage( + const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value) { + const char* actual_message = assertion_result.message(); + Message msg; + msg << "Value of: " << expression_text + << "\n Actual: " << actual_predicate_value; + if (actual_message[0] != '\0') + msg << " (" << actual_message << ")"; + msg << "\nExpected: " << expected_predicate_value; + return msg.GetString(); +} + +// Helper function for implementing ASSERT_NEAR. +AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error) { + const double diff = fabs(val1 - val2); + if (diff <= abs_error) return AssertionSuccess(); + + // Find the value which is closest to zero. + const double min_abs = std::min(fabs(val1), fabs(val2)); + // Find the distance to the next double from that value. + const double epsilon = + nextafter(min_abs, std::numeric_limits::infinity()) - min_abs; + // Detect the case where abs_error is so small that EXPECT_NEAR is + // effectively the same as EXPECT_EQUAL, and give an informative error + // message so that the situation can be more easily understood without + // requiring exotic floating-point knowledge. + // Don't do an epsilon check if abs_error is zero because that implies + // that an equality check was actually intended. + if (!(std::isnan)(val1) && !(std::isnan)(val2) && abs_error > 0 && + abs_error < epsilon) { + return AssertionFailure() + << "The difference between " << expr1 << " and " << expr2 << " is " + << diff << ", where\n" + << expr1 << " evaluates to " << val1 << ",\n" + << expr2 << " evaluates to " << val2 << ".\nThe abs_error parameter " + << abs_error_expr << " evaluates to " << abs_error + << " which is smaller than the minimum distance between doubles for " + "numbers of this magnitude which is " + << epsilon + << ", thus making this EXPECT_NEAR check equivalent to " + "EXPECT_EQUAL. Consider using EXPECT_DOUBLE_EQ instead."; + } + return AssertionFailure() + << "The difference between " << expr1 << " and " << expr2 + << " is " << diff << ", which exceeds " << abs_error_expr << ", where\n" + << expr1 << " evaluates to " << val1 << ",\n" + << expr2 << " evaluates to " << val2 << ", and\n" + << abs_error_expr << " evaluates to " << abs_error << "."; +} + + +// Helper template for implementing FloatLE() and DoubleLE(). +template +AssertionResult FloatingPointLE(const char* expr1, + const char* expr2, + RawType val1, + RawType val2) { + // Returns success if val1 is less than val2, + if (val1 < val2) { + return AssertionSuccess(); + } + + // or if val1 is almost equal to val2. + const FloatingPoint lhs(val1), rhs(val2); + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + // Note that the above two checks will both fail if either val1 or + // val2 is NaN, as the IEEE floating-point standard requires that + // any predicate involving a NaN must return false. + + ::std::stringstream val1_ss; + val1_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << val1; + + ::std::stringstream val2_ss; + val2_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << val2; + + return AssertionFailure() + << "Expected: (" << expr1 << ") <= (" << expr2 << ")\n" + << " Actual: " << StringStreamToString(&val1_ss) << " vs " + << StringStreamToString(&val2_ss); +} + +} // namespace internal + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +namespace internal { + +// The helper function for {ASSERT|EXPECT}_STREQ. +AssertionResult CmpHelperSTREQ(const char* lhs_expression, + const char* rhs_expression, + const char* lhs, + const char* rhs) { + if (String::CStringEquals(lhs, rhs)) { + return AssertionSuccess(); + } + + return EqFailure(lhs_expression, + rhs_expression, + PrintToString(lhs), + PrintToString(rhs), + false); +} + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +AssertionResult CmpHelperSTRCASEEQ(const char* lhs_expression, + const char* rhs_expression, + const char* lhs, + const char* rhs) { + if (String::CaseInsensitiveCStringEquals(lhs, rhs)) { + return AssertionSuccess(); + } + + return EqFailure(lhs_expression, + rhs_expression, + PrintToString(lhs), + PrintToString(rhs), + true); +} + +// The helper function for {ASSERT|EXPECT}_STRNE. +AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CaseInsensitiveCStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() + << "Expected: (" << s1_expression << ") != (" + << s2_expression << ") (ignoring case), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +} // namespace internal + +namespace { + +// Helper functions for implementing IsSubString() and IsNotSubstring(). + +// This group of overloaded functions return true if and only if needle +// is a substring of haystack. NULL is considered a substring of +// itself only. + +bool IsSubstringPred(const char* needle, const char* haystack) { + if (needle == nullptr || haystack == nullptr) return needle == haystack; + + return strstr(haystack, needle) != nullptr; +} + +bool IsSubstringPred(const wchar_t* needle, const wchar_t* haystack) { + if (needle == nullptr || haystack == nullptr) return needle == haystack; + + return wcsstr(haystack, needle) != nullptr; +} + +// StringType here can be either ::std::string or ::std::wstring. +template +bool IsSubstringPred(const StringType& needle, + const StringType& haystack) { + return haystack.find(needle) != StringType::npos; +} + +// This function implements either IsSubstring() or IsNotSubstring(), +// depending on the value of the expected_to_be_substring parameter. +// StringType here can be const char*, const wchar_t*, ::std::string, +// or ::std::wstring. +template +AssertionResult IsSubstringImpl( + bool expected_to_be_substring, + const char* needle_expr, const char* haystack_expr, + const StringType& needle, const StringType& haystack) { + if (IsSubstringPred(needle, haystack) == expected_to_be_substring) + return AssertionSuccess(); + + const bool is_wide_string = sizeof(needle[0]) > 1; + const char* const begin_string_quote = is_wide_string ? "L\"" : "\""; + return AssertionFailure() + << "Value of: " << needle_expr << "\n" + << " Actual: " << begin_string_quote << needle << "\"\n" + << "Expected: " << (expected_to_be_substring ? "" : "not ") + << "a substring of " << haystack_expr << "\n" + << "Which is: " << begin_string_quote << haystack << "\""; +} + +} // namespace + +// IsSubstring() and IsNotSubstring() check whether needle is a +// substring of haystack (NULL is considered a substring of itself +// only), and return an appropriate error message when they fail. + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +#if GTEST_HAS_STD_WSTRING +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +#if GTEST_OS_WINDOWS + +namespace { + +// Helper function for IsHRESULT{SuccessFailure} predicates +AssertionResult HRESULTFailureHelper(const char* expr, + const char* expected, + long hr) { // NOLINT +# if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_WINDOWS_TV_TITLE + + // Windows CE doesn't support FormatMessage. + const char error_text[] = ""; + +# else + + // Looks up the human-readable system message for the HRESULT code + // and since we're not passing any params to FormatMessage, we don't + // want inserts expanded. + const DWORD kFlags = FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS; + const DWORD kBufSize = 4096; + // Gets the system's human readable message string for this HRESULT. + char error_text[kBufSize] = { '\0' }; + DWORD message_length = ::FormatMessageA(kFlags, + 0, // no source, we're asking system + static_cast(hr), // the error + 0, // no line width restrictions + error_text, // output buffer + kBufSize, // buf size + nullptr); // no arguments for inserts + // Trims tailing white space (FormatMessage leaves a trailing CR-LF) + for (; message_length && IsSpace(error_text[message_length - 1]); + --message_length) { + error_text[message_length - 1] = '\0'; + } + +# endif // GTEST_OS_WINDOWS_MOBILE + + const std::string error_hex("0x" + String::FormatHexInt(hr)); + return ::testing::AssertionFailure() + << "Expected: " << expr << " " << expected << ".\n" + << " Actual: " << error_hex << " " << error_text << "\n"; +} + +} // namespace + +AssertionResult IsHRESULTSuccess(const char* expr, long hr) { // NOLINT + if (SUCCEEDED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "succeeds", hr); +} + +AssertionResult IsHRESULTFailure(const char* expr, long hr) { // NOLINT + if (FAILED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "fails", hr); +} + +#endif // GTEST_OS_WINDOWS + +// Utility functions for encoding Unicode text (wide strings) in +// UTF-8. + +// A Unicode code-point can have up to 21 bits, and is encoded in UTF-8 +// like this: +// +// Code-point length Encoding +// 0 - 7 bits 0xxxxxxx +// 8 - 11 bits 110xxxxx 10xxxxxx +// 12 - 16 bits 1110xxxx 10xxxxxx 10xxxxxx +// 17 - 21 bits 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + +// The maximum code-point a one-byte UTF-8 sequence can represent. +constexpr uint32_t kMaxCodePoint1 = (static_cast(1) << 7) - 1; + +// The maximum code-point a two-byte UTF-8 sequence can represent. +constexpr uint32_t kMaxCodePoint2 = (static_cast(1) << (5 + 6)) - 1; + +// The maximum code-point a three-byte UTF-8 sequence can represent. +constexpr uint32_t kMaxCodePoint3 = (static_cast(1) << (4 + 2*6)) - 1; + +// The maximum code-point a four-byte UTF-8 sequence can represent. +constexpr uint32_t kMaxCodePoint4 = (static_cast(1) << (3 + 3*6)) - 1; + +// Chops off the n lowest bits from a bit pattern. Returns the n +// lowest bits. As a side effect, the original bit pattern will be +// shifted to the right by n bits. +inline uint32_t ChopLowBits(uint32_t* bits, int n) { + const uint32_t low_bits = *bits & ((static_cast(1) << n) - 1); + *bits >>= n; + return low_bits; +} + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type uint32_t because wchar_t may not be +// wide enough to contain a code point. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be converted +// to "(Invalid Unicode 0xXXXXXXXX)". +std::string CodePointToUtf8(uint32_t code_point) { + if (code_point > kMaxCodePoint4) { + return "(Invalid Unicode 0x" + String::FormatHexUInt32(code_point) + ")"; + } + + char str[5]; // Big enough for the largest valid code point. + if (code_point <= kMaxCodePoint1) { + str[1] = '\0'; + str[0] = static_cast(code_point); // 0xxxxxxx + } else if (code_point <= kMaxCodePoint2) { + str[2] = '\0'; + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xC0 | code_point); // 110xxxxx + } else if (code_point <= kMaxCodePoint3) { + str[3] = '\0'; + str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xE0 | code_point); // 1110xxxx + } else { // code_point <= kMaxCodePoint4 + str[4] = '\0'; + str[3] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xF0 | code_point); // 11110xxx + } + return str; +} + +// The following two functions only make sense if the system +// uses UTF-16 for wide string encoding. All supported systems +// with 16 bit wchar_t (Windows, Cygwin) do use UTF-16. + +// Determines if the arguments constitute UTF-16 surrogate pair +// and thus should be combined into a single Unicode code point +// using CreateCodePointFromUtf16SurrogatePair. +inline bool IsUtf16SurrogatePair(wchar_t first, wchar_t second) { + return sizeof(wchar_t) == 2 && + (first & 0xFC00) == 0xD800 && (second & 0xFC00) == 0xDC00; +} + +// Creates a Unicode code point from UTF16 surrogate pair. +inline uint32_t CreateCodePointFromUtf16SurrogatePair(wchar_t first, + wchar_t second) { + const auto first_u = static_cast(first); + const auto second_u = static_cast(second); + const uint32_t mask = (1 << 10) - 1; + return (sizeof(wchar_t) == 2) + ? (((first_u & mask) << 10) | (second_u & mask)) + 0x10000 + : + // This function should not be called when the condition is + // false, but we provide a sensible default in case it is. + first_u; +} + +// Converts a wide string to a narrow string in UTF-8 encoding. +// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from Basic Normal Plane. +std::string WideStringToUtf8(const wchar_t* str, int num_chars) { + if (num_chars == -1) + num_chars = static_cast(wcslen(str)); + + ::std::stringstream stream; + for (int i = 0; i < num_chars; ++i) { + uint32_t unicode_code_point; + + if (str[i] == L'\0') { + break; + } else if (i + 1 < num_chars && IsUtf16SurrogatePair(str[i], str[i + 1])) { + unicode_code_point = CreateCodePointFromUtf16SurrogatePair(str[i], + str[i + 1]); + i++; + } else { + unicode_code_point = static_cast(str[i]); + } + + stream << CodePointToUtf8(unicode_code_point); + } + return StringStreamToString(&stream); +} + +// Converts a wide C string to an std::string using the UTF-8 encoding. +// NULL will be converted to "(null)". +std::string String::ShowWideCString(const wchar_t * wide_c_str) { + if (wide_c_str == nullptr) return "(null)"; + + return internal::WideStringToUtf8(wide_c_str, -1); +} + +// Compares two wide C strings. Returns true if and only if they have the +// same content. +// +// Unlike wcscmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::WideCStringEquals(const wchar_t * lhs, const wchar_t * rhs) { + if (lhs == nullptr) return rhs == nullptr; + + if (rhs == nullptr) return false; + + return wcscmp(lhs, rhs) == 0; +} + +// Helper function for *_STREQ on wide strings. +AssertionResult CmpHelperSTREQ(const char* lhs_expression, + const char* rhs_expression, + const wchar_t* lhs, + const wchar_t* rhs) { + if (String::WideCStringEquals(lhs, rhs)) { + return AssertionSuccess(); + } + + return EqFailure(lhs_expression, + rhs_expression, + PrintToString(lhs), + PrintToString(rhs), + false); +} + +// Helper function for *_STRNE on wide strings. +AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2) { + if (!String::WideCStringEquals(s1, s2)) { + return AssertionSuccess(); + } + + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: " + << PrintToString(s1) + << " vs " << PrintToString(s2); +} + +// Compares two C strings, ignoring case. Returns true if and only if they have +// the same content. +// +// Unlike strcasecmp(), this function can handle NULL argument(s). A +// NULL C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CaseInsensitiveCStringEquals(const char * lhs, const char * rhs) { + if (lhs == nullptr) return rhs == nullptr; + if (rhs == nullptr) return false; + return posix::StrCaseCmp(lhs, rhs) == 0; +} + +// Compares two wide C strings, ignoring case. Returns true if and only if they +// have the same content. +// +// Unlike wcscasecmp(), this function can handle NULL argument(s). +// A NULL C string is considered different to any non-NULL wide C string, +// including the empty string. +// NB: The implementations on different platforms slightly differ. +// On windows, this method uses _wcsicmp which compares according to LC_CTYPE +// environment variable. On GNU platform this method uses wcscasecmp +// which compares according to LC_CTYPE category of the current locale. +// On MacOS X, it uses towlower, which also uses LC_CTYPE category of the +// current locale. +bool String::CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs) { + if (lhs == nullptr) return rhs == nullptr; + + if (rhs == nullptr) return false; + +#if GTEST_OS_WINDOWS + return _wcsicmp(lhs, rhs) == 0; +#elif GTEST_OS_LINUX && !GTEST_OS_LINUX_ANDROID + return wcscasecmp(lhs, rhs) == 0; +#else + // Android, Mac OS X and Cygwin don't define wcscasecmp. + // Other unknown OSes may not define it either. + wint_t left, right; + do { + left = towlower(static_cast(*lhs++)); + right = towlower(static_cast(*rhs++)); + } while (left && left == right); + return left == right; +#endif // OS selector +} + +// Returns true if and only if str ends with the given suffix, ignoring case. +// Any string is considered to end with an empty suffix. +bool String::EndsWithCaseInsensitive( + const std::string& str, const std::string& suffix) { + const size_t str_len = str.length(); + const size_t suffix_len = suffix.length(); + return (str_len >= suffix_len) && + CaseInsensitiveCStringEquals(str.c_str() + str_len - suffix_len, + suffix.c_str()); +} + +// Formats an int value as "%02d". +std::string String::FormatIntWidth2(int value) { + return FormatIntWidthN(value, 2); +} + +// Formats an int value to given width with leading zeros. +std::string String::FormatIntWidthN(int value, int width) { + std::stringstream ss; + ss << std::setfill('0') << std::setw(width) << value; + return ss.str(); +} + +// Formats an int value as "%X". +std::string String::FormatHexUInt32(uint32_t value) { + std::stringstream ss; + ss << std::hex << std::uppercase << value; + return ss.str(); +} + +// Formats an int value as "%X". +std::string String::FormatHexInt(int value) { + return FormatHexUInt32(static_cast(value)); +} + +// Formats a byte as "%02X". +std::string String::FormatByte(unsigned char value) { + std::stringstream ss; + ss << std::setfill('0') << std::setw(2) << std::hex << std::uppercase + << static_cast(value); + return ss.str(); +} + +// Converts the buffer in a stringstream to an std::string, converting NUL +// bytes to "\\0" along the way. +std::string StringStreamToString(::std::stringstream* ss) { + const ::std::string& str = ss->str(); + const char* const start = str.c_str(); + const char* const end = start + str.length(); + + std::string result; + result.reserve(static_cast(2 * (end - start))); + for (const char* ch = start; ch != end; ++ch) { + if (*ch == '\0') { + result += "\\0"; // Replaces NUL with "\\0"; + } else { + result += *ch; + } + } + + return result; +} + +// Appends the user-supplied message to the Google-Test-generated message. +std::string AppendUserMessage(const std::string& gtest_msg, + const Message& user_msg) { + // Appends the user message if it's non-empty. + const std::string user_msg_string = user_msg.GetString(); + if (user_msg_string.empty()) { + return gtest_msg; + } + if (gtest_msg.empty()) { + return user_msg_string; + } + return gtest_msg + "\n" + user_msg_string; +} + +} // namespace internal + +// class TestResult + +// Creates an empty TestResult. +TestResult::TestResult() + : death_test_count_(0), start_timestamp_(0), elapsed_time_(0) {} + +// D'tor. +TestResult::~TestResult() { +} + +// Returns the i-th test part result among all the results. i can +// range from 0 to total_part_count() - 1. If i is not in that range, +// aborts the program. +const TestPartResult& TestResult::GetTestPartResult(int i) const { + if (i < 0 || i >= total_part_count()) + internal::posix::Abort(); + return test_part_results_.at(static_cast(i)); +} + +// Returns the i-th test property. i can range from 0 to +// test_property_count() - 1. If i is not in that range, aborts the +// program. +const TestProperty& TestResult::GetTestProperty(int i) const { + if (i < 0 || i >= test_property_count()) + internal::posix::Abort(); + return test_properties_.at(static_cast(i)); +} + +// Clears the test part results. +void TestResult::ClearTestPartResults() { + test_part_results_.clear(); +} + +// Adds a test part result to the list. +void TestResult::AddTestPartResult(const TestPartResult& test_part_result) { + test_part_results_.push_back(test_part_result); +} + +// Adds a test property to the list. If a property with the same key as the +// supplied property is already represented, the value of this test_property +// replaces the old value for that key. +void TestResult::RecordProperty(const std::string& xml_element, + const TestProperty& test_property) { + if (!ValidateTestProperty(xml_element, test_property)) { + return; + } + internal::MutexLock lock(&test_properties_mutex_); + const std::vector::iterator property_with_matching_key = + std::find_if(test_properties_.begin(), test_properties_.end(), + internal::TestPropertyKeyIs(test_property.key())); + if (property_with_matching_key == test_properties_.end()) { + test_properties_.push_back(test_property); + return; + } + property_with_matching_key->SetValue(test_property.value()); +} + +// The list of reserved attributes used in the element of XML +// output. +static const char* const kReservedTestSuitesAttributes[] = { + "disabled", + "errors", + "failures", + "name", + "random_seed", + "tests", + "time", + "timestamp" +}; + +// The list of reserved attributes used in the element of XML +// output. +static const char* const kReservedTestSuiteAttributes[] = { + "disabled", "errors", "failures", "name", + "tests", "time", "timestamp", "skipped"}; + +// The list of reserved attributes used in the element of XML output. +static const char* const kReservedTestCaseAttributes[] = { + "classname", "name", "status", "time", "type_param", + "value_param", "file", "line"}; + +// Use a slightly different set for allowed output to ensure existing tests can +// still RecordProperty("result") or "RecordProperty(timestamp") +static const char* const kReservedOutputTestCaseAttributes[] = { + "classname", "name", "status", "time", "type_param", + "value_param", "file", "line", "result", "timestamp"}; + +template +std::vector ArrayAsVector(const char* const (&array)[kSize]) { + return std::vector(array, array + kSize); +} + +static std::vector GetReservedAttributesForElement( + const std::string& xml_element) { + if (xml_element == "testsuites") { + return ArrayAsVector(kReservedTestSuitesAttributes); + } else if (xml_element == "testsuite") { + return ArrayAsVector(kReservedTestSuiteAttributes); + } else if (xml_element == "testcase") { + return ArrayAsVector(kReservedTestCaseAttributes); + } else { + GTEST_CHECK_(false) << "Unrecognized xml_element provided: " << xml_element; + } + // This code is unreachable but some compilers may not realizes that. + return std::vector(); +} + +// TODO(jdesprez): Merge the two getReserved attributes once skip is improved +static std::vector GetReservedOutputAttributesForElement( + const std::string& xml_element) { + if (xml_element == "testsuites") { + return ArrayAsVector(kReservedTestSuitesAttributes); + } else if (xml_element == "testsuite") { + return ArrayAsVector(kReservedTestSuiteAttributes); + } else if (xml_element == "testcase") { + return ArrayAsVector(kReservedOutputTestCaseAttributes); + } else { + GTEST_CHECK_(false) << "Unrecognized xml_element provided: " << xml_element; + } + // This code is unreachable but some compilers may not realizes that. + return std::vector(); +} + +static std::string FormatWordList(const std::vector& words) { + Message word_list; + for (size_t i = 0; i < words.size(); ++i) { + if (i > 0 && words.size() > 2) { + word_list << ", "; + } + if (i == words.size() - 1) { + word_list << "and "; + } + word_list << "'" << words[i] << "'"; + } + return word_list.GetString(); +} + +static bool ValidateTestPropertyName( + const std::string& property_name, + const std::vector& reserved_names) { + if (std::find(reserved_names.begin(), reserved_names.end(), property_name) != + reserved_names.end()) { + ADD_FAILURE() << "Reserved key used in RecordProperty(): " << property_name + << " (" << FormatWordList(reserved_names) + << " are reserved by " << GTEST_NAME_ << ")"; + return false; + } + return true; +} + +// Adds a failure if the key is a reserved attribute of the element named +// xml_element. Returns true if the property is valid. +bool TestResult::ValidateTestProperty(const std::string& xml_element, + const TestProperty& test_property) { + return ValidateTestPropertyName(test_property.key(), + GetReservedAttributesForElement(xml_element)); +} + +// Clears the object. +void TestResult::Clear() { + test_part_results_.clear(); + test_properties_.clear(); + death_test_count_ = 0; + elapsed_time_ = 0; +} + +// Returns true off the test part was skipped. +static bool TestPartSkipped(const TestPartResult& result) { + return result.skipped(); +} + +// Returns true if and only if the test was skipped. +bool TestResult::Skipped() const { + return !Failed() && CountIf(test_part_results_, TestPartSkipped) > 0; +} + +// Returns true if and only if the test failed. +bool TestResult::Failed() const { + for (int i = 0; i < total_part_count(); ++i) { + if (GetTestPartResult(i).failed()) + return true; + } + return false; +} + +// Returns true if and only if the test part fatally failed. +static bool TestPartFatallyFailed(const TestPartResult& result) { + return result.fatally_failed(); +} + +// Returns true if and only if the test fatally failed. +bool TestResult::HasFatalFailure() const { + return CountIf(test_part_results_, TestPartFatallyFailed) > 0; +} + +// Returns true if and only if the test part non-fatally failed. +static bool TestPartNonfatallyFailed(const TestPartResult& result) { + return result.nonfatally_failed(); +} + +// Returns true if and only if the test has a non-fatal failure. +bool TestResult::HasNonfatalFailure() const { + return CountIf(test_part_results_, TestPartNonfatallyFailed) > 0; +} + +// Gets the number of all test parts. This is the sum of the number +// of successful test parts and the number of failed test parts. +int TestResult::total_part_count() const { + return static_cast(test_part_results_.size()); +} + +// Returns the number of the test properties. +int TestResult::test_property_count() const { + return static_cast(test_properties_.size()); +} + +// class Test + +// Creates a Test object. + +// The c'tor saves the states of all flags. +Test::Test() + : gtest_flag_saver_(new GTEST_FLAG_SAVER_) { +} + +// The d'tor restores the states of all flags. The actual work is +// done by the d'tor of the gtest_flag_saver_ field, and thus not +// visible here. +Test::~Test() { +} + +// Sets up the test fixture. +// +// A sub-class may override this. +void Test::SetUp() { +} + +// Tears down the test fixture. +// +// A sub-class may override this. +void Test::TearDown() { +} + +// Allows user supplied key value pairs to be recorded for later output. +void Test::RecordProperty(const std::string& key, const std::string& value) { + UnitTest::GetInstance()->RecordProperty(key, value); +} + +// Allows user supplied key value pairs to be recorded for later output. +void Test::RecordProperty(const std::string& key, int value) { + Message value_message; + value_message << value; + RecordProperty(key, value_message.GetString().c_str()); +} + +namespace internal { + +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const std::string& message) { + // This function is a friend of UnitTest and as such has access to + // AddTestPartResult. + UnitTest::GetInstance()->AddTestPartResult( + result_type, + nullptr, // No info about the source file where the exception occurred. + -1, // We have no info on which line caused the exception. + message, + ""); // No stack trace, either. +} + +} // namespace internal + +// Google Test requires all tests in the same test suite to use the same test +// fixture class. This function checks if the current test has the +// same fixture class as the first test in the current test suite. If +// yes, it returns true; otherwise it generates a Google Test failure and +// returns false. +bool Test::HasSameFixtureClass() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + const TestSuite* const test_suite = impl->current_test_suite(); + + // Info about the first test in the current test suite. + const TestInfo* const first_test_info = test_suite->test_info_list()[0]; + const internal::TypeId first_fixture_id = first_test_info->fixture_class_id_; + const char* const first_test_name = first_test_info->name(); + + // Info about the current test. + const TestInfo* const this_test_info = impl->current_test_info(); + const internal::TypeId this_fixture_id = this_test_info->fixture_class_id_; + const char* const this_test_name = this_test_info->name(); + + if (this_fixture_id != first_fixture_id) { + // Is the first test defined using TEST? + const bool first_is_TEST = first_fixture_id == internal::GetTestTypeId(); + // Is this test defined using TEST? + const bool this_is_TEST = this_fixture_id == internal::GetTestTypeId(); + + if (first_is_TEST || this_is_TEST) { + // Both TEST and TEST_F appear in same test suite, which is incorrect. + // Tell the user how to fix this. + + // Gets the name of the TEST and the name of the TEST_F. Note + // that first_is_TEST and this_is_TEST cannot both be true, as + // the fixture IDs are different for the two tests. + const char* const TEST_name = + first_is_TEST ? first_test_name : this_test_name; + const char* const TEST_F_name = + first_is_TEST ? this_test_name : first_test_name; + + ADD_FAILURE() + << "All tests in the same test suite must use the same test fixture\n" + << "class, so mixing TEST_F and TEST in the same test suite is\n" + << "illegal. In test suite " << this_test_info->test_suite_name() + << ",\n" + << "test " << TEST_F_name << " is defined using TEST_F but\n" + << "test " << TEST_name << " is defined using TEST. You probably\n" + << "want to change the TEST to TEST_F or move it to another test\n" + << "case."; + } else { + // Two fixture classes with the same name appear in two different + // namespaces, which is not allowed. Tell the user how to fix this. + ADD_FAILURE() + << "All tests in the same test suite must use the same test fixture\n" + << "class. However, in test suite " + << this_test_info->test_suite_name() << ",\n" + << "you defined test " << first_test_name << " and test " + << this_test_name << "\n" + << "using two different test fixture classes. This can happen if\n" + << "the two classes are from different namespaces or translation\n" + << "units and have the same name. You should probably rename one\n" + << "of the classes to put the tests into different test suites."; + } + return false; + } + + return true; +} + +#if GTEST_HAS_SEH + +// Adds an "exception thrown" fatal failure to the current test. This +// function returns its result via an output parameter pointer because VC++ +// prohibits creation of objects with destructors on stack in functions +// using __try (see error C2712). +static std::string* FormatSehExceptionMessage(DWORD exception_code, + const char* location) { + Message message; + message << "SEH exception with code 0x" << std::setbase(16) << + exception_code << std::setbase(10) << " thrown in " << location << "."; + + return new std::string(message.GetString()); +} + +#endif // GTEST_HAS_SEH + +namespace internal { + +#if GTEST_HAS_EXCEPTIONS + +// Adds an "exception thrown" fatal failure to the current test. +static std::string FormatCxxExceptionMessage(const char* description, + const char* location) { + Message message; + if (description != nullptr) { + message << "C++ exception with description \"" << description << "\""; + } else { + message << "Unknown C++ exception"; + } + message << " thrown in " << location << "."; + + return message.GetString(); +} + +static std::string PrintTestPartResultToString( + const TestPartResult& test_part_result); + +GoogleTestFailureException::GoogleTestFailureException( + const TestPartResult& failure) + : ::std::runtime_error(PrintTestPartResultToString(failure).c_str()) {} + +#endif // GTEST_HAS_EXCEPTIONS + +// We put these helper functions in the internal namespace as IBM's xlC +// compiler rejects the code if they were declared static. + +// Runs the given method and handles SEH exceptions it throws, when +// SEH is supported; returns the 0-value for type Result in case of an +// SEH exception. (Microsoft compilers cannot handle SEH and C++ +// exceptions in the same function. Therefore, we provide a separate +// wrapper function for handling SEH exceptions.) +template +Result HandleSehExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { +#if GTEST_HAS_SEH + __try { + return (object->*method)(); + } __except (internal::UnitTestOptions::GTestShouldProcessSEH( // NOLINT + GetExceptionCode())) { + // We create the exception message on the heap because VC++ prohibits + // creation of objects with destructors on stack in functions using __try + // (see error C2712). + std::string* exception_message = FormatSehExceptionMessage( + GetExceptionCode(), location); + internal::ReportFailureInUnknownLocation(TestPartResult::kFatalFailure, + *exception_message); + delete exception_message; + return static_cast(0); + } +#else + (void)location; + return (object->*method)(); +#endif // GTEST_HAS_SEH +} + +// Runs the given method and catches and reports C++ and/or SEH-style +// exceptions, if they are supported; returns the 0-value for type +// Result in case of an SEH exception. +template +Result HandleExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { + // NOTE: The user code can affect the way in which Google Test handles + // exceptions by setting GTEST_FLAG(catch_exceptions), but only before + // RUN_ALL_TESTS() starts. It is technically possible to check the flag + // after the exception is caught and either report or re-throw the + // exception based on the flag's value: + // + // try { + // // Perform the test method. + // } catch (...) { + // if (GTEST_FLAG_GET(catch_exceptions)) + // // Report the exception as failure. + // else + // throw; // Re-throws the original exception. + // } + // + // However, the purpose of this flag is to allow the program to drop into + // the debugger when the exception is thrown. On most platforms, once the + // control enters the catch block, the exception origin information is + // lost and the debugger will stop the program at the point of the + // re-throw in this function -- instead of at the point of the original + // throw statement in the code under test. For this reason, we perform + // the check early, sacrificing the ability to affect Google Test's + // exception handling in the method where the exception is thrown. + if (internal::GetUnitTestImpl()->catch_exceptions()) { +#if GTEST_HAS_EXCEPTIONS + try { + return HandleSehExceptionsInMethodIfSupported(object, method, location); + } catch (const AssertionException&) { // NOLINT + // This failure was reported already. + } catch (const internal::GoogleTestFailureException&) { // NOLINT + // This exception type can only be thrown by a failed Google + // Test assertion with the intention of letting another testing + // framework catch it. Therefore we just re-throw it. + throw; + } catch (const std::exception& e) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(e.what(), location)); + } catch (...) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(nullptr, location)); + } + return static_cast(0); +#else + return HandleSehExceptionsInMethodIfSupported(object, method, location); +#endif // GTEST_HAS_EXCEPTIONS + } else { + return (object->*method)(); + } +} + +} // namespace internal + +// Runs the test and updates the test result. +void Test::Run() { + if (!HasSameFixtureClass()) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported(this, &Test::SetUp, "SetUp()"); + // We will run the test only if SetUp() was successful and didn't call + // GTEST_SKIP(). + if (!HasFatalFailure() && !IsSkipped()) { + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TestBody, "the test body"); + } + + // However, we want to clean up as much as possible. Hence we will + // always call TearDown(), even if SetUp() or the test body has + // failed. + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TearDown, "TearDown()"); +} + +// Returns true if and only if the current test has a fatal failure. +bool Test::HasFatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()->HasFatalFailure(); +} + +// Returns true if and only if the current test has a non-fatal failure. +bool Test::HasNonfatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()-> + HasNonfatalFailure(); +} + +// Returns true if and only if the current test was skipped. +bool Test::IsSkipped() { + return internal::GetUnitTestImpl()->current_test_result()->Skipped(); +} + +// class TestInfo + +// Constructs a TestInfo object. It assumes ownership of the test factory +// object. +TestInfo::TestInfo(const std::string& a_test_suite_name, + const std::string& a_name, const char* a_type_param, + const char* a_value_param, + internal::CodeLocation a_code_location, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory) + : test_suite_name_(a_test_suite_name), + name_(a_name), + type_param_(a_type_param ? new std::string(a_type_param) : nullptr), + value_param_(a_value_param ? new std::string(a_value_param) : nullptr), + location_(a_code_location), + fixture_class_id_(fixture_class_id), + should_run_(false), + is_disabled_(false), + matches_filter_(false), + is_in_another_shard_(false), + factory_(factory), + result_() {} + +// Destructs a TestInfo object. +TestInfo::~TestInfo() { delete factory_; } + +namespace internal { + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_suite_name: name of the test suite +// name: name of the test +// type_param: the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param: text representation of the test's value parameter, +// or NULL if this is not a value-parameterized test. +// code_location: code location where the test is defined +// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test suite +// tear_down_tc: pointer to the function that tears down the test suite +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object. +TestInfo* MakeAndRegisterTestInfo( + const char* test_suite_name, const char* name, const char* type_param, + const char* value_param, CodeLocation code_location, + TypeId fixture_class_id, SetUpTestSuiteFunc set_up_tc, + TearDownTestSuiteFunc tear_down_tc, TestFactoryBase* factory) { + TestInfo* const test_info = + new TestInfo(test_suite_name, name, type_param, value_param, + code_location, fixture_class_id, factory); + GetUnitTestImpl()->AddTestInfo(set_up_tc, tear_down_tc, test_info); + return test_info; +} + +void ReportInvalidTestSuiteType(const char* test_suite_name, + CodeLocation code_location) { + Message errors; + errors + << "Attempted redefinition of test suite " << test_suite_name << ".\n" + << "All tests in the same test suite must use the same test fixture\n" + << "class. However, in test suite " << test_suite_name << ", you tried\n" + << "to define a test using a fixture class different from the one\n" + << "used earlier. This can happen if the two fixture classes are\n" + << "from different namespaces and have the same name. You should\n" + << "probably rename one of the classes to put the tests into different\n" + << "test suites."; + + GTEST_LOG_(ERROR) << FormatFileLocation(code_location.file.c_str(), + code_location.line) + << " " << errors.GetString(); +} +} // namespace internal + +namespace { + +// A predicate that checks the test name of a TestInfo against a known +// value. +// +// This is used for implementation of the TestSuite class only. We put +// it in the anonymous namespace to prevent polluting the outer +// namespace. +// +// TestNameIs is copyable. +class TestNameIs { + public: + // Constructor. + // + // TestNameIs has NO default constructor. + explicit TestNameIs(const char* name) + : name_(name) {} + + // Returns true if and only if the test name of test_info matches name_. + bool operator()(const TestInfo * test_info) const { + return test_info && test_info->name() == name_; + } + + private: + std::string name_; +}; + +} // namespace + +namespace internal { + +// This method expands all parameterized tests registered with macros TEST_P +// and INSTANTIATE_TEST_SUITE_P into regular tests and registers those. +// This will be done just once during the program runtime. +void UnitTestImpl::RegisterParameterizedTests() { + if (!parameterized_tests_registered_) { + parameterized_test_registry_.RegisterTests(); + type_parameterized_test_registry_.CheckForInstantiations(); + parameterized_tests_registered_ = true; + } +} + +} // namespace internal + +// Creates the test object, runs it, records its result, and then +// deletes it. +void TestInfo::Run() { + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + if (!should_run_) { + if (is_disabled_) repeater->OnTestDisabled(*this); + return; + } + + // Tells UnitTest where to store test result. + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_info(this); + + // Notifies the unit test event listeners that a test is about to start. + repeater->OnTestStart(*this); + result_.set_start_timestamp(internal::GetTimeInMillis()); + internal::Timer timer; + impl->os_stack_trace_getter()->UponLeavingGTest(); + + // Creates the test object. + Test* const test = internal::HandleExceptionsInMethodIfSupported( + factory_, &internal::TestFactoryBase::CreateTest, + "the test fixture's constructor"); + + // Runs the test if the constructor didn't generate a fatal failure or invoke + // GTEST_SKIP(). + // Note that the object will not be null + if (!Test::HasFatalFailure() && !Test::IsSkipped()) { + // This doesn't throw as all user code that can throw are wrapped into + // exception handling code. + test->Run(); + } + + if (test != nullptr) { + // Deletes the test object. + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + test, &Test::DeleteSelf_, "the test fixture's destructor"); + } + + result_.set_elapsed_time(timer.Elapsed()); + + // Notifies the unit test event listener that a test has just finished. + repeater->OnTestEnd(*this); + + // Tells UnitTest to stop associating assertion results to this + // test. + impl->set_current_test_info(nullptr); +} + +// Skip and records a skipped test result for this object. +void TestInfo::Skip() { + if (!should_run_) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_info(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + // Notifies the unit test event listeners that a test is about to start. + repeater->OnTestStart(*this); + + const TestPartResult test_part_result = + TestPartResult(TestPartResult::kSkip, this->file(), this->line(), ""); + impl->GetTestPartResultReporterForCurrentThread()->ReportTestPartResult( + test_part_result); + + // Notifies the unit test event listener that a test has just finished. + repeater->OnTestEnd(*this); + impl->set_current_test_info(nullptr); +} + +// class TestSuite + +// Gets the number of successful tests in this test suite. +int TestSuite::successful_test_count() const { + return CountIf(test_info_list_, TestPassed); +} + +// Gets the number of successful tests in this test suite. +int TestSuite::skipped_test_count() const { + return CountIf(test_info_list_, TestSkipped); +} + +// Gets the number of failed tests in this test suite. +int TestSuite::failed_test_count() const { + return CountIf(test_info_list_, TestFailed); +} + +// Gets the number of disabled tests that will be reported in the XML report. +int TestSuite::reportable_disabled_test_count() const { + return CountIf(test_info_list_, TestReportableDisabled); +} + +// Gets the number of disabled tests in this test suite. +int TestSuite::disabled_test_count() const { + return CountIf(test_info_list_, TestDisabled); +} + +// Gets the number of tests to be printed in the XML report. +int TestSuite::reportable_test_count() const { + return CountIf(test_info_list_, TestReportable); +} + +// Get the number of tests in this test suite that should run. +int TestSuite::test_to_run_count() const { + return CountIf(test_info_list_, ShouldRunTest); +} + +// Gets the number of all tests. +int TestSuite::total_test_count() const { + return static_cast(test_info_list_.size()); +} + +// Creates a TestSuite with the given name. +// +// Arguments: +// +// a_name: name of the test suite +// a_type_param: the name of the test suite's type parameter, or NULL if +// this is not a typed or a type-parameterized test suite. +// set_up_tc: pointer to the function that sets up the test suite +// tear_down_tc: pointer to the function that tears down the test suite +TestSuite::TestSuite(const char* a_name, const char* a_type_param, + internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc) + : name_(a_name), + type_param_(a_type_param ? new std::string(a_type_param) : nullptr), + set_up_tc_(set_up_tc), + tear_down_tc_(tear_down_tc), + should_run_(false), + start_timestamp_(0), + elapsed_time_(0) {} + +// Destructor of TestSuite. +TestSuite::~TestSuite() { + // Deletes every Test in the collection. + ForEach(test_info_list_, internal::Delete); +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +const TestInfo* TestSuite::GetTestInfo(int i) const { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? nullptr : test_info_list_[static_cast(index)]; +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +TestInfo* TestSuite::GetMutableTestInfo(int i) { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? nullptr : test_info_list_[static_cast(index)]; +} + +// Adds a test to this test suite. Will delete the test upon +// destruction of the TestSuite object. +void TestSuite::AddTestInfo(TestInfo* test_info) { + test_info_list_.push_back(test_info); + test_indices_.push_back(static_cast(test_indices_.size())); +} + +// Runs every test in this TestSuite. +void TestSuite::Run() { + if (!should_run_) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_suite(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + // Call both legacy and the new API + repeater->OnTestSuiteStart(*this); +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + repeater->OnTestCaseStart(*this); +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestSuite::RunSetUpTestSuite, "SetUpTestSuite()"); + + const bool skip_all = ad_hoc_test_result().Failed(); + + start_timestamp_ = internal::GetTimeInMillis(); + internal::Timer timer; + for (int i = 0; i < total_test_count(); i++) { + if (skip_all) { + GetMutableTestInfo(i)->Skip(); + } else { + GetMutableTestInfo(i)->Run(); + } + if (GTEST_FLAG_GET(fail_fast) && + GetMutableTestInfo(i)->result()->Failed()) { + for (int j = i + 1; j < total_test_count(); j++) { + GetMutableTestInfo(j)->Skip(); + } + break; + } + } + elapsed_time_ = timer.Elapsed(); + + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestSuite::RunTearDownTestSuite, "TearDownTestSuite()"); + + // Call both legacy and the new API + repeater->OnTestSuiteEnd(*this); +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + repeater->OnTestCaseEnd(*this); +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + impl->set_current_test_suite(nullptr); +} + +// Skips all tests under this TestSuite. +void TestSuite::Skip() { + if (!should_run_) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_suite(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + // Call both legacy and the new API + repeater->OnTestSuiteStart(*this); +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + repeater->OnTestCaseStart(*this); +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + for (int i = 0; i < total_test_count(); i++) { + GetMutableTestInfo(i)->Skip(); + } + + // Call both legacy and the new API + repeater->OnTestSuiteEnd(*this); + // Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + repeater->OnTestCaseEnd(*this); +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + impl->set_current_test_suite(nullptr); +} + +// Clears the results of all tests in this test suite. +void TestSuite::ClearResult() { + ad_hoc_test_result_.Clear(); + ForEach(test_info_list_, TestInfo::ClearTestResult); +} + +// Shuffles the tests in this test suite. +void TestSuite::ShuffleTests(internal::Random* random) { + Shuffle(random, &test_indices_); +} + +// Restores the test order to before the first shuffle. +void TestSuite::UnshuffleTests() { + for (size_t i = 0; i < test_indices_.size(); i++) { + test_indices_[i] = static_cast(i); + } +} + +// Formats a countable noun. Depending on its quantity, either the +// singular form or the plural form is used. e.g. +// +// FormatCountableNoun(1, "formula", "formuli") returns "1 formula". +// FormatCountableNoun(5, "book", "books") returns "5 books". +static std::string FormatCountableNoun(int count, + const char * singular_form, + const char * plural_form) { + return internal::StreamableToString(count) + " " + + (count == 1 ? singular_form : plural_form); +} + +// Formats the count of tests. +static std::string FormatTestCount(int test_count) { + return FormatCountableNoun(test_count, "test", "tests"); +} + +// Formats the count of test suites. +static std::string FormatTestSuiteCount(int test_suite_count) { + return FormatCountableNoun(test_suite_count, "test suite", "test suites"); +} + +// Converts a TestPartResult::Type enum to human-friendly string +// representation. Both kNonFatalFailure and kFatalFailure are translated +// to "Failure", as the user usually doesn't care about the difference +// between the two when viewing the test result. +static const char * TestPartResultTypeToString(TestPartResult::Type type) { + switch (type) { + case TestPartResult::kSkip: + return "Skipped\n"; + case TestPartResult::kSuccess: + return "Success"; + + case TestPartResult::kNonFatalFailure: + case TestPartResult::kFatalFailure: +#ifdef _MSC_VER + return "error: "; +#else + return "Failure\n"; +#endif + default: + return "Unknown result type"; + } +} + +namespace internal { +namespace { +enum class GTestColor { kDefault, kRed, kGreen, kYellow }; +} // namespace + +// Prints a TestPartResult to an std::string. +static std::string PrintTestPartResultToString( + const TestPartResult& test_part_result) { + return (Message() + << internal::FormatFileLocation(test_part_result.file_name(), + test_part_result.line_number()) + << " " << TestPartResultTypeToString(test_part_result.type()) + << test_part_result.message()).GetString(); +} + +// Prints a TestPartResult. +static void PrintTestPartResult(const TestPartResult& test_part_result) { + const std::string& result = + PrintTestPartResultToString(test_part_result); + printf("%s\n", result.c_str()); + fflush(stdout); + // If the test program runs in Visual Studio or a debugger, the + // following statements add the test part result message to the Output + // window such that the user can double-click on it to jump to the + // corresponding source code location; otherwise they do nothing. +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + // We don't call OutputDebugString*() on Windows Mobile, as printing + // to stdout is done by OutputDebugString() there already - we don't + // want the same message printed twice. + ::OutputDebugStringA(result.c_str()); + ::OutputDebugStringA("\n"); +#endif +} + +// class PrettyUnitTestResultPrinter +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE && \ + !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT && !GTEST_OS_WINDOWS_MINGW + +// Returns the character attribute for the given color. +static WORD GetColorAttribute(GTestColor color) { + switch (color) { + case GTestColor::kRed: + return FOREGROUND_RED; + case GTestColor::kGreen: + return FOREGROUND_GREEN; + case GTestColor::kYellow: + return FOREGROUND_RED | FOREGROUND_GREEN; + default: return 0; + } +} + +static int GetBitOffset(WORD color_mask) { + if (color_mask == 0) return 0; + + int bitOffset = 0; + while ((color_mask & 1) == 0) { + color_mask >>= 1; + ++bitOffset; + } + return bitOffset; +} + +static WORD GetNewColor(GTestColor color, WORD old_color_attrs) { + // Let's reuse the BG + static const WORD background_mask = BACKGROUND_BLUE | BACKGROUND_GREEN | + BACKGROUND_RED | BACKGROUND_INTENSITY; + static const WORD foreground_mask = FOREGROUND_BLUE | FOREGROUND_GREEN | + FOREGROUND_RED | FOREGROUND_INTENSITY; + const WORD existing_bg = old_color_attrs & background_mask; + + WORD new_color = + GetColorAttribute(color) | existing_bg | FOREGROUND_INTENSITY; + static const int bg_bitOffset = GetBitOffset(background_mask); + static const int fg_bitOffset = GetBitOffset(foreground_mask); + + if (((new_color & background_mask) >> bg_bitOffset) == + ((new_color & foreground_mask) >> fg_bitOffset)) { + new_color ^= FOREGROUND_INTENSITY; // invert intensity + } + return new_color; +} + +#else + +// Returns the ANSI color code for the given color. GTestColor::kDefault is +// an invalid input. +static const char* GetAnsiColorCode(GTestColor color) { + switch (color) { + case GTestColor::kRed: + return "1"; + case GTestColor::kGreen: + return "2"; + case GTestColor::kYellow: + return "3"; + default: + return nullptr; + } +} + +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + +// Returns true if and only if Google Test should use colors in the output. +bool ShouldUseColor(bool stdout_is_tty) { + std::string c = GTEST_FLAG_GET(color); + const char* const gtest_color = c.c_str(); + + if (String::CaseInsensitiveCStringEquals(gtest_color, "auto")) { +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW + // On Windows the TERM variable is usually not set, but the + // console there does support colors. + return stdout_is_tty; +#else + // On non-Windows platforms, we rely on the TERM variable. + const char* const term = posix::GetEnv("TERM"); + const bool term_supports_color = + String::CStringEquals(term, "xterm") || + String::CStringEquals(term, "xterm-color") || + String::CStringEquals(term, "xterm-256color") || + String::CStringEquals(term, "screen") || + String::CStringEquals(term, "screen-256color") || + String::CStringEquals(term, "tmux") || + String::CStringEquals(term, "tmux-256color") || + String::CStringEquals(term, "rxvt-unicode") || + String::CStringEquals(term, "rxvt-unicode-256color") || + String::CStringEquals(term, "linux") || + String::CStringEquals(term, "cygwin"); + return stdout_is_tty && term_supports_color; +#endif // GTEST_OS_WINDOWS + } + + return String::CaseInsensitiveCStringEquals(gtest_color, "yes") || + String::CaseInsensitiveCStringEquals(gtest_color, "true") || + String::CaseInsensitiveCStringEquals(gtest_color, "t") || + String::CStringEquals(gtest_color, "1"); + // We take "yes", "true", "t", and "1" as meaning "yes". If the + // value is neither one of these nor "auto", we treat it as "no" to + // be conservative. +} + +// Helpers for printing colored strings to stdout. Note that on Windows, we +// cannot simply emit special characters and have the terminal change colors. +// This routine must actually emit the characters rather than return a string +// that would be colored when printed, as can be done on Linux. + +GTEST_ATTRIBUTE_PRINTF_(2, 3) +static void ColoredPrintf(GTestColor color, const char *fmt, ...) { + va_list args; + va_start(args, fmt); + +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_ZOS || GTEST_OS_IOS || \ + GTEST_OS_WINDOWS_PHONE || GTEST_OS_WINDOWS_RT || defined(ESP_PLATFORM) + const bool use_color = AlwaysFalse(); +#else + static const bool in_color_mode = + ShouldUseColor(posix::IsATTY(posix::FileNo(stdout)) != 0); + const bool use_color = in_color_mode && (color != GTestColor::kDefault); +#endif // GTEST_OS_WINDOWS_MOBILE || GTEST_OS_ZOS + + if (!use_color) { + vprintf(fmt, args); + va_end(args); + return; + } + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE && \ + !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT && !GTEST_OS_WINDOWS_MINGW + const HANDLE stdout_handle = GetStdHandle(STD_OUTPUT_HANDLE); + + // Gets the current text color. + CONSOLE_SCREEN_BUFFER_INFO buffer_info; + GetConsoleScreenBufferInfo(stdout_handle, &buffer_info); + const WORD old_color_attrs = buffer_info.wAttributes; + const WORD new_color = GetNewColor(color, old_color_attrs); + + // We need to flush the stream buffers into the console before each + // SetConsoleTextAttribute call lest it affect the text that is already + // printed but has not yet reached the console. + fflush(stdout); + SetConsoleTextAttribute(stdout_handle, new_color); + + vprintf(fmt, args); + + fflush(stdout); + // Restores the text color. + SetConsoleTextAttribute(stdout_handle, old_color_attrs); +#else + printf("\033[0;3%sm", GetAnsiColorCode(color)); + vprintf(fmt, args); + printf("\033[m"); // Resets the terminal to default. +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + va_end(args); +} + +// Text printed in Google Test's text output and --gtest_list_tests +// output to label the type parameter and value parameter for a test. +static const char kTypeParamLabel[] = "TypeParam"; +static const char kValueParamLabel[] = "GetParam()"; + +static void PrintFullTestCommentIfPresent(const TestInfo& test_info) { + const char* const type_param = test_info.type_param(); + const char* const value_param = test_info.value_param(); + + if (type_param != nullptr || value_param != nullptr) { + printf(", where "); + if (type_param != nullptr) { + printf("%s = %s", kTypeParamLabel, type_param); + if (value_param != nullptr) printf(" and "); + } + if (value_param != nullptr) { + printf("%s = %s", kValueParamLabel, value_param); + } + } +} + +// This class implements the TestEventListener interface. +// +// Class PrettyUnitTestResultPrinter is copyable. +class PrettyUnitTestResultPrinter : public TestEventListener { + public: + PrettyUnitTestResultPrinter() {} + static void PrintTestName(const char* test_suite, const char* test) { + printf("%s.%s", test_suite, test); + } + + // The following methods override what's in the TestEventListener class. + void OnTestProgramStart(const UnitTest& /*unit_test*/) override {} + void OnTestIterationStart(const UnitTest& unit_test, int iteration) override; + void OnEnvironmentsSetUpStart(const UnitTest& unit_test) override; + void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override {} +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseStart(const TestCase& test_case) override; +#else + void OnTestSuiteStart(const TestSuite& test_suite) override; +#endif // OnTestCaseStart + + void OnTestStart(const TestInfo& test_info) override; + void OnTestDisabled(const TestInfo& test_info) override; + + void OnTestPartResult(const TestPartResult& result) override; + void OnTestEnd(const TestInfo& test_info) override; +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseEnd(const TestCase& test_case) override; +#else + void OnTestSuiteEnd(const TestSuite& test_suite) override; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnEnvironmentsTearDownStart(const UnitTest& unit_test) override; + void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override {} + void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override {} + + private: + static void PrintFailedTests(const UnitTest& unit_test); + static void PrintFailedTestSuites(const UnitTest& unit_test); + static void PrintSkippedTests(const UnitTest& unit_test); +}; + + // Fired before each iteration of tests starts. +void PrettyUnitTestResultPrinter::OnTestIterationStart( + const UnitTest& unit_test, int iteration) { + if (GTEST_FLAG_GET(repeat) != 1) + printf("\nRepeating all tests (iteration %d) . . .\n\n", iteration + 1); + + std::string f = GTEST_FLAG_GET(filter); + const char* const filter = f.c_str(); + + // Prints the filter if it's not *. This reminds the user that some + // tests may be skipped. + if (!String::CStringEquals(filter, kUniversalFilter)) { + ColoredPrintf(GTestColor::kYellow, "Note: %s filter = %s\n", GTEST_NAME_, + filter); + } + + if (internal::ShouldShard(kTestTotalShards, kTestShardIndex, false)) { + const int32_t shard_index = Int32FromEnvOrDie(kTestShardIndex, -1); + ColoredPrintf(GTestColor::kYellow, "Note: This is test shard %d of %s.\n", + static_cast(shard_index) + 1, + internal::posix::GetEnv(kTestTotalShards)); + } + + if (GTEST_FLAG_GET(shuffle)) { + ColoredPrintf(GTestColor::kYellow, + "Note: Randomizing tests' orders with a seed of %d .\n", + unit_test.random_seed()); + } + + ColoredPrintf(GTestColor::kGreen, "[==========] "); + printf("Running %s from %s.\n", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestSuiteCount(unit_test.test_suite_to_run_count()).c_str()); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnEnvironmentsSetUpStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("Global test environment set-up.\n"); + fflush(stdout); +} + +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +void PrettyUnitTestResultPrinter::OnTestCaseStart(const TestCase& test_case) { + const std::string counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("%s from %s", counts.c_str(), test_case.name()); + if (test_case.type_param() == nullptr) { + printf("\n"); + } else { + printf(", where %s = %s\n", kTypeParamLabel, test_case.type_param()); + } + fflush(stdout); +} +#else +void PrettyUnitTestResultPrinter::OnTestSuiteStart( + const TestSuite& test_suite) { + const std::string counts = + FormatCountableNoun(test_suite.test_to_run_count(), "test", "tests"); + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("%s from %s", counts.c_str(), test_suite.name()); + if (test_suite.type_param() == nullptr) { + printf("\n"); + } else { + printf(", where %s = %s\n", kTypeParamLabel, test_suite.type_param()); + } + fflush(stdout); +} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +void PrettyUnitTestResultPrinter::OnTestStart(const TestInfo& test_info) { + ColoredPrintf(GTestColor::kGreen, "[ RUN ] "); + PrintTestName(test_info.test_suite_name(), test_info.name()); + printf("\n"); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestDisabled(const TestInfo& test_info) { + ColoredPrintf(GTestColor::kYellow, "[ DISABLED ] "); + PrintTestName(test_info.test_suite_name(), test_info.name()); + printf("\n"); + fflush(stdout); +} + +// Called after an assertion failure. +void PrettyUnitTestResultPrinter::OnTestPartResult( + const TestPartResult& result) { + switch (result.type()) { + // If the test part succeeded, we don't need to do anything. + case TestPartResult::kSuccess: + return; + default: + // Print failure message from the assertion + // (e.g. expected this and got that). + PrintTestPartResult(result); + fflush(stdout); + } +} + +void PrettyUnitTestResultPrinter::OnTestEnd(const TestInfo& test_info) { + if (test_info.result()->Passed()) { + ColoredPrintf(GTestColor::kGreen, "[ OK ] "); + } else if (test_info.result()->Skipped()) { + ColoredPrintf(GTestColor::kGreen, "[ SKIPPED ] "); + } else { + ColoredPrintf(GTestColor::kRed, "[ FAILED ] "); + } + PrintTestName(test_info.test_suite_name(), test_info.name()); + if (test_info.result()->Failed()) + PrintFullTestCommentIfPresent(test_info); + + if (GTEST_FLAG_GET(print_time)) { + printf(" (%s ms)\n", internal::StreamableToString( + test_info.result()->elapsed_time()).c_str()); + } else { + printf("\n"); + } + fflush(stdout); +} + +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +void PrettyUnitTestResultPrinter::OnTestCaseEnd(const TestCase& test_case) { + if (!GTEST_FLAG_GET(print_time)) return; + + const std::string counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("%s from %s (%s ms total)\n\n", counts.c_str(), test_case.name(), + internal::StreamableToString(test_case.elapsed_time()).c_str()); + fflush(stdout); +} +#else +void PrettyUnitTestResultPrinter::OnTestSuiteEnd(const TestSuite& test_suite) { + if (!GTEST_FLAG_GET(print_time)) return; + + const std::string counts = + FormatCountableNoun(test_suite.test_to_run_count(), "test", "tests"); + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("%s from %s (%s ms total)\n\n", counts.c_str(), test_suite.name(), + internal::StreamableToString(test_suite.elapsed_time()).c_str()); + fflush(stdout); +} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +void PrettyUnitTestResultPrinter::OnEnvironmentsTearDownStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(GTestColor::kGreen, "[----------] "); + printf("Global test environment tear-down\n"); + fflush(stdout); +} + +// Internal helper for printing the list of failed tests. +void PrettyUnitTestResultPrinter::PrintFailedTests(const UnitTest& unit_test) { + const int failed_test_count = unit_test.failed_test_count(); + ColoredPrintf(GTestColor::kRed, "[ FAILED ] "); + printf("%s, listed below:\n", FormatTestCount(failed_test_count).c_str()); + + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + const TestSuite& test_suite = *unit_test.GetTestSuite(i); + if (!test_suite.should_run() || (test_suite.failed_test_count() == 0)) { + continue; + } + for (int j = 0; j < test_suite.total_test_count(); ++j) { + const TestInfo& test_info = *test_suite.GetTestInfo(j); + if (!test_info.should_run() || !test_info.result()->Failed()) { + continue; + } + ColoredPrintf(GTestColor::kRed, "[ FAILED ] "); + printf("%s.%s", test_suite.name(), test_info.name()); + PrintFullTestCommentIfPresent(test_info); + printf("\n"); + } + } + printf("\n%2d FAILED %s\n", failed_test_count, + failed_test_count == 1 ? "TEST" : "TESTS"); +} + +// Internal helper for printing the list of test suite failures not covered by +// PrintFailedTests. +void PrettyUnitTestResultPrinter::PrintFailedTestSuites( + const UnitTest& unit_test) { + int suite_failure_count = 0; + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + const TestSuite& test_suite = *unit_test.GetTestSuite(i); + if (!test_suite.should_run()) { + continue; + } + if (test_suite.ad_hoc_test_result().Failed()) { + ColoredPrintf(GTestColor::kRed, "[ FAILED ] "); + printf("%s: SetUpTestSuite or TearDownTestSuite\n", test_suite.name()); + ++suite_failure_count; + } + } + if (suite_failure_count > 0) { + printf("\n%2d FAILED TEST %s\n", suite_failure_count, + suite_failure_count == 1 ? "SUITE" : "SUITES"); + } +} + +// Internal helper for printing the list of skipped tests. +void PrettyUnitTestResultPrinter::PrintSkippedTests(const UnitTest& unit_test) { + const int skipped_test_count = unit_test.skipped_test_count(); + if (skipped_test_count == 0) { + return; + } + + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + const TestSuite& test_suite = *unit_test.GetTestSuite(i); + if (!test_suite.should_run() || (test_suite.skipped_test_count() == 0)) { + continue; + } + for (int j = 0; j < test_suite.total_test_count(); ++j) { + const TestInfo& test_info = *test_suite.GetTestInfo(j); + if (!test_info.should_run() || !test_info.result()->Skipped()) { + continue; + } + ColoredPrintf(GTestColor::kGreen, "[ SKIPPED ] "); + printf("%s.%s", test_suite.name(), test_info.name()); + printf("\n"); + } + } +} + +void PrettyUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + ColoredPrintf(GTestColor::kGreen, "[==========] "); + printf("%s from %s ran.", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestSuiteCount(unit_test.test_suite_to_run_count()).c_str()); + if (GTEST_FLAG_GET(print_time)) { + printf(" (%s ms total)", + internal::StreamableToString(unit_test.elapsed_time()).c_str()); + } + printf("\n"); + ColoredPrintf(GTestColor::kGreen, "[ PASSED ] "); + printf("%s.\n", FormatTestCount(unit_test.successful_test_count()).c_str()); + + const int skipped_test_count = unit_test.skipped_test_count(); + if (skipped_test_count > 0) { + ColoredPrintf(GTestColor::kGreen, "[ SKIPPED ] "); + printf("%s, listed below:\n", FormatTestCount(skipped_test_count).c_str()); + PrintSkippedTests(unit_test); + } + + if (!unit_test.Passed()) { + PrintFailedTests(unit_test); + PrintFailedTestSuites(unit_test); + } + + int num_disabled = unit_test.reportable_disabled_test_count(); + if (num_disabled && !GTEST_FLAG_GET(also_run_disabled_tests)) { + if (unit_test.Passed()) { + printf("\n"); // Add a spacer if no FAILURE banner is displayed. + } + ColoredPrintf(GTestColor::kYellow, " YOU HAVE %d DISABLED %s\n\n", + num_disabled, num_disabled == 1 ? "TEST" : "TESTS"); + } + // Ensure that Google Test output is printed before, e.g., heapchecker output. + fflush(stdout); +} + +// End PrettyUnitTestResultPrinter + +// This class implements the TestEventListener interface. +// +// Class BriefUnitTestResultPrinter is copyable. +class BriefUnitTestResultPrinter : public TestEventListener { + public: + BriefUnitTestResultPrinter() {} + static void PrintTestName(const char* test_suite, const char* test) { + printf("%s.%s", test_suite, test); + } + + // The following methods override what's in the TestEventListener class. + void OnTestProgramStart(const UnitTest& /*unit_test*/) override {} + void OnTestIterationStart(const UnitTest& /*unit_test*/, + int /*iteration*/) override {} + void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) override {} + void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override {} +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseStart(const TestCase& /*test_case*/) override {} +#else + void OnTestSuiteStart(const TestSuite& /*test_suite*/) override {} +#endif // OnTestCaseStart + + void OnTestStart(const TestInfo& /*test_info*/) override {} + void OnTestDisabled(const TestInfo& /*test_info*/) override {} + + void OnTestPartResult(const TestPartResult& result) override; + void OnTestEnd(const TestInfo& test_info) override; +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseEnd(const TestCase& /*test_case*/) override {} +#else + void OnTestSuiteEnd(const TestSuite& /*test_suite*/) override {} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) override {} + void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override {} + void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override {} +}; + +// Called after an assertion failure. +void BriefUnitTestResultPrinter::OnTestPartResult( + const TestPartResult& result) { + switch (result.type()) { + // If the test part succeeded, we don't need to do anything. + case TestPartResult::kSuccess: + return; + default: + // Print failure message from the assertion + // (e.g. expected this and got that). + PrintTestPartResult(result); + fflush(stdout); + } +} + +void BriefUnitTestResultPrinter::OnTestEnd(const TestInfo& test_info) { + if (test_info.result()->Failed()) { + ColoredPrintf(GTestColor::kRed, "[ FAILED ] "); + PrintTestName(test_info.test_suite_name(), test_info.name()); + PrintFullTestCommentIfPresent(test_info); + + if (GTEST_FLAG_GET(print_time)) { + printf(" (%s ms)\n", + internal::StreamableToString(test_info.result()->elapsed_time()) + .c_str()); + } else { + printf("\n"); + } + fflush(stdout); + } +} + +void BriefUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + ColoredPrintf(GTestColor::kGreen, "[==========] "); + printf("%s from %s ran.", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestSuiteCount(unit_test.test_suite_to_run_count()).c_str()); + if (GTEST_FLAG_GET(print_time)) { + printf(" (%s ms total)", + internal::StreamableToString(unit_test.elapsed_time()).c_str()); + } + printf("\n"); + ColoredPrintf(GTestColor::kGreen, "[ PASSED ] "); + printf("%s.\n", FormatTestCount(unit_test.successful_test_count()).c_str()); + + const int skipped_test_count = unit_test.skipped_test_count(); + if (skipped_test_count > 0) { + ColoredPrintf(GTestColor::kGreen, "[ SKIPPED ] "); + printf("%s.\n", FormatTestCount(skipped_test_count).c_str()); + } + + int num_disabled = unit_test.reportable_disabled_test_count(); + if (num_disabled && !GTEST_FLAG_GET(also_run_disabled_tests)) { + if (unit_test.Passed()) { + printf("\n"); // Add a spacer if no FAILURE banner is displayed. + } + ColoredPrintf(GTestColor::kYellow, " YOU HAVE %d DISABLED %s\n\n", + num_disabled, num_disabled == 1 ? "TEST" : "TESTS"); + } + // Ensure that Google Test output is printed before, e.g., heapchecker output. + fflush(stdout); +} + +// End BriefUnitTestResultPrinter + +// class TestEventRepeater +// +// This class forwards events to other event listeners. +class TestEventRepeater : public TestEventListener { + public: + TestEventRepeater() : forwarding_enabled_(true) {} + ~TestEventRepeater() override; + void Append(TestEventListener *listener); + TestEventListener* Release(TestEventListener* listener); + + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled() const { return forwarding_enabled_; } + void set_forwarding_enabled(bool enable) { forwarding_enabled_ = enable; } + + void OnTestProgramStart(const UnitTest& unit_test) override; + void OnTestIterationStart(const UnitTest& unit_test, int iteration) override; + void OnEnvironmentsSetUpStart(const UnitTest& unit_test) override; + void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) override; +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseStart(const TestSuite& parameter) override; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestSuiteStart(const TestSuite& parameter) override; + void OnTestStart(const TestInfo& test_info) override; + void OnTestDisabled(const TestInfo& test_info) override; + void OnTestPartResult(const TestPartResult& result) override; + void OnTestEnd(const TestInfo& test_info) override; +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseEnd(const TestCase& parameter) override; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestSuiteEnd(const TestSuite& parameter) override; + void OnEnvironmentsTearDownStart(const UnitTest& unit_test) override; + void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) override; + void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; + void OnTestProgramEnd(const UnitTest& unit_test) override; + + private: + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled_; + // The list of listeners that receive events. + std::vector listeners_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventRepeater); +}; + +TestEventRepeater::~TestEventRepeater() { + ForEach(listeners_, Delete); +} + +void TestEventRepeater::Append(TestEventListener *listener) { + listeners_.push_back(listener); +} + +TestEventListener* TestEventRepeater::Release(TestEventListener *listener) { + for (size_t i = 0; i < listeners_.size(); ++i) { + if (listeners_[i] == listener) { + listeners_.erase(listeners_.begin() + static_cast(i)); + return listener; + } + } + + return nullptr; +} + +// Since most methods are very similar, use macros to reduce boilerplate. +// This defines a member that forwards the call to all listeners. +#define GTEST_REPEATER_METHOD_(Name, Type) \ +void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (size_t i = 0; i < listeners_.size(); i++) { \ + listeners_[i]->Name(parameter); \ + } \ + } \ +} +// This defines a member that forwards the call to all listeners in reverse +// order. +#define GTEST_REVERSE_REPEATER_METHOD_(Name, Type) \ + void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (size_t i = listeners_.size(); i != 0; i--) { \ + listeners_[i - 1]->Name(parameter); \ + } \ + } \ + } + +GTEST_REPEATER_METHOD_(OnTestProgramStart, UnitTest) +GTEST_REPEATER_METHOD_(OnEnvironmentsSetUpStart, UnitTest) +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +GTEST_REPEATER_METHOD_(OnTestCaseStart, TestSuite) +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +GTEST_REPEATER_METHOD_(OnTestSuiteStart, TestSuite) +GTEST_REPEATER_METHOD_(OnTestStart, TestInfo) +GTEST_REPEATER_METHOD_(OnTestDisabled, TestInfo) +GTEST_REPEATER_METHOD_(OnTestPartResult, TestPartResult) +GTEST_REPEATER_METHOD_(OnEnvironmentsTearDownStart, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsSetUpEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsTearDownEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnTestEnd, TestInfo) +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +GTEST_REVERSE_REPEATER_METHOD_(OnTestCaseEnd, TestSuite) +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +GTEST_REVERSE_REPEATER_METHOD_(OnTestSuiteEnd, TestSuite) +GTEST_REVERSE_REPEATER_METHOD_(OnTestProgramEnd, UnitTest) + +#undef GTEST_REPEATER_METHOD_ +#undef GTEST_REVERSE_REPEATER_METHOD_ + +void TestEventRepeater::OnTestIterationStart(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (size_t i = 0; i < listeners_.size(); i++) { + listeners_[i]->OnTestIterationStart(unit_test, iteration); + } + } +} + +void TestEventRepeater::OnTestIterationEnd(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (size_t i = listeners_.size(); i > 0; i--) { + listeners_[i - 1]->OnTestIterationEnd(unit_test, iteration); + } + } +} + +// End TestEventRepeater + +// This class generates an XML output file. +class XmlUnitTestResultPrinter : public EmptyTestEventListener { + public: + explicit XmlUnitTestResultPrinter(const char* output_file); + + void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; + void ListTestsMatchingFilter(const std::vector& test_suites); + + // Prints an XML summary of all unit tests. + static void PrintXmlTestsList(std::ostream* stream, + const std::vector& test_suites); + + private: + // Is c a whitespace character that is normalized to a space character + // when it appears in an XML attribute value? + static bool IsNormalizableWhitespace(unsigned char c) { + return c == '\t' || c == '\n' || c == '\r'; + } + + // May c appear in a well-formed XML document? + // https://www.w3.org/TR/REC-xml/#charsets + static bool IsValidXmlCharacter(unsigned char c) { + return IsNormalizableWhitespace(c) || c >= 0x20; + } + + // Returns an XML-escaped copy of the input string str. If + // is_attribute is true, the text is meant to appear as an attribute + // value, and normalizable whitespace is preserved by replacing it + // with character references. + static std::string EscapeXml(const std::string& str, bool is_attribute); + + // Returns the given string with all characters invalid in XML removed. + static std::string RemoveInvalidXmlCharacters(const std::string& str); + + // Convenience wrapper around EscapeXml when str is an attribute value. + static std::string EscapeXmlAttribute(const std::string& str) { + return EscapeXml(str, true); + } + + // Convenience wrapper around EscapeXml when str is not an attribute value. + static std::string EscapeXmlText(const char* str) { + return EscapeXml(str, false); + } + + // Verifies that the given attribute belongs to the given element and + // streams the attribute as XML. + static void OutputXmlAttribute(std::ostream* stream, + const std::string& element_name, + const std::string& name, + const std::string& value); + + // Streams an XML CDATA section, escaping invalid CDATA sequences as needed. + static void OutputXmlCDataSection(::std::ostream* stream, const char* data); + + // Streams a test suite XML stanza containing the given test result. + // + // Requires: result.Failed() + static void OutputXmlTestSuiteForTestResult(::std::ostream* stream, + const TestResult& result); + + // Streams an XML representation of a TestResult object. + static void OutputXmlTestResult(::std::ostream* stream, + const TestResult& result); + + // Streams an XML representation of a TestInfo object. + static void OutputXmlTestInfo(::std::ostream* stream, + const char* test_suite_name, + const TestInfo& test_info); + + // Prints an XML representation of a TestSuite object + static void PrintXmlTestSuite(::std::ostream* stream, + const TestSuite& test_suite); + + // Prints an XML summary of unit_test to output stream out. + static void PrintXmlUnitTest(::std::ostream* stream, + const UnitTest& unit_test); + + // Produces a string representing the test properties in a result as space + // delimited XML attributes based on the property key="value" pairs. + // When the std::string is not empty, it includes a space at the beginning, + // to delimit this attribute from prior attributes. + static std::string TestPropertiesAsXmlAttributes(const TestResult& result); + + // Streams an XML representation of the test properties of a TestResult + // object. + static void OutputXmlTestProperties(std::ostream* stream, + const TestResult& result); + + // The output file. + const std::string output_file_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(XmlUnitTestResultPrinter); +}; + +// Creates a new XmlUnitTestResultPrinter. +XmlUnitTestResultPrinter::XmlUnitTestResultPrinter(const char* output_file) + : output_file_(output_file) { + if (output_file_.empty()) { + GTEST_LOG_(FATAL) << "XML output file may not be null"; + } +} + +// Called after the unit test ends. +void XmlUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + FILE* xmlout = OpenFileForWriting(output_file_); + std::stringstream stream; + PrintXmlUnitTest(&stream, unit_test); + fprintf(xmlout, "%s", StringStreamToString(&stream).c_str()); + fclose(xmlout); +} + +void XmlUnitTestResultPrinter::ListTestsMatchingFilter( + const std::vector& test_suites) { + FILE* xmlout = OpenFileForWriting(output_file_); + std::stringstream stream; + PrintXmlTestsList(&stream, test_suites); + fprintf(xmlout, "%s", StringStreamToString(&stream).c_str()); + fclose(xmlout); +} + +// Returns an XML-escaped copy of the input string str. If is_attribute +// is true, the text is meant to appear as an attribute value, and +// normalizable whitespace is preserved by replacing it with character +// references. +// +// Invalid XML characters in str, if any, are stripped from the output. +// It is expected that most, if not all, of the text processed by this +// module will consist of ordinary English text. +// If this module is ever modified to produce version 1.1 XML output, +// most invalid characters can be retained using character references. +std::string XmlUnitTestResultPrinter::EscapeXml( + const std::string& str, bool is_attribute) { + Message m; + + for (size_t i = 0; i < str.size(); ++i) { + const char ch = str[i]; + switch (ch) { + case '<': + m << "<"; + break; + case '>': + m << ">"; + break; + case '&': + m << "&"; + break; + case '\'': + if (is_attribute) + m << "'"; + else + m << '\''; + break; + case '"': + if (is_attribute) + m << """; + else + m << '"'; + break; + default: + if (IsValidXmlCharacter(static_cast(ch))) { + if (is_attribute && + IsNormalizableWhitespace(static_cast(ch))) + m << "&#x" << String::FormatByte(static_cast(ch)) + << ";"; + else + m << ch; + } + break; + } + } + + return m.GetString(); +} + +// Returns the given string with all characters invalid in XML removed. +// Currently invalid characters are dropped from the string. An +// alternative is to replace them with certain characters such as . or ?. +std::string XmlUnitTestResultPrinter::RemoveInvalidXmlCharacters( + const std::string& str) { + std::string output; + output.reserve(str.size()); + for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) + if (IsValidXmlCharacter(static_cast(*it))) + output.push_back(*it); + + return output; +} + +// The following routines generate an XML representation of a UnitTest +// object. +// +// This is how Google Test concepts map to the DTD: +// +// <-- corresponds to a UnitTest object +// <-- corresponds to a TestSuite object +// <-- corresponds to a TestInfo object +// ... +// ... +// ... +// <-- individual assertion failures +// +// +// + +// Formats the given time in milliseconds as seconds. +std::string FormatTimeInMillisAsSeconds(TimeInMillis ms) { + ::std::stringstream ss; + ss << (static_cast(ms) * 1e-3); + return ss.str(); +} + +static bool PortableLocaltime(time_t seconds, struct tm* out) { +#if defined(_MSC_VER) + return localtime_s(out, &seconds) == 0; +#elif defined(__MINGW32__) || defined(__MINGW64__) + // MINGW provides neither localtime_r nor localtime_s, but uses + // Windows' localtime(), which has a thread-local tm buffer. + struct tm* tm_ptr = localtime(&seconds); // NOLINT + if (tm_ptr == nullptr) return false; + *out = *tm_ptr; + return true; +#elif defined(__STDC_LIB_EXT1__) + // Uses localtime_s when available as localtime_r is only available from + // C23 standard. + return localtime_s(&seconds, out) != nullptr; +#else + return localtime_r(&seconds, out) != nullptr; +#endif +} + +// Converts the given epoch time in milliseconds to a date string in the ISO +// 8601 format, without the timezone information. +std::string FormatEpochTimeInMillisAsIso8601(TimeInMillis ms) { + struct tm time_struct; + if (!PortableLocaltime(static_cast(ms / 1000), &time_struct)) + return ""; + // YYYY-MM-DDThh:mm:ss.sss + return StreamableToString(time_struct.tm_year + 1900) + "-" + + String::FormatIntWidth2(time_struct.tm_mon + 1) + "-" + + String::FormatIntWidth2(time_struct.tm_mday) + "T" + + String::FormatIntWidth2(time_struct.tm_hour) + ":" + + String::FormatIntWidth2(time_struct.tm_min) + ":" + + String::FormatIntWidth2(time_struct.tm_sec) + "." + + String::FormatIntWidthN(static_cast(ms % 1000), 3); +} + +// Streams an XML CDATA section, escaping invalid CDATA sequences as needed. +void XmlUnitTestResultPrinter::OutputXmlCDataSection(::std::ostream* stream, + const char* data) { + const char* segment = data; + *stream << ""); + if (next_segment != nullptr) { + stream->write( + segment, static_cast(next_segment - segment)); + *stream << "]]>]]>"); + } else { + *stream << segment; + break; + } + } + *stream << "]]>"; +} + +void XmlUnitTestResultPrinter::OutputXmlAttribute( + std::ostream* stream, + const std::string& element_name, + const std::string& name, + const std::string& value) { + const std::vector& allowed_names = + GetReservedOutputAttributesForElement(element_name); + + GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != + allowed_names.end()) + << "Attribute " << name << " is not allowed for element <" << element_name + << ">."; + + *stream << " " << name << "=\"" << EscapeXmlAttribute(value) << "\""; +} + +// Streams a test suite XML stanza containing the given test result. +void XmlUnitTestResultPrinter::OutputXmlTestSuiteForTestResult( + ::std::ostream* stream, const TestResult& result) { + // Output the boilerplate for a minimal test suite with one test. + *stream << " "; + + // Output the boilerplate for a minimal test case with a single test. + *stream << " \n"; +} + +// Prints an XML representation of a TestInfo object. +void XmlUnitTestResultPrinter::OutputXmlTestInfo(::std::ostream* stream, + const char* test_suite_name, + const TestInfo& test_info) { + const TestResult& result = *test_info.result(); + const std::string kTestsuite = "testcase"; + + if (test_info.is_in_another_shard()) { + return; + } + + *stream << " \n"; + return; + } + + OutputXmlAttribute(stream, kTestsuite, "status", + test_info.should_run() ? "run" : "notrun"); + OutputXmlAttribute(stream, kTestsuite, "result", + test_info.should_run() + ? (result.Skipped() ? "skipped" : "completed") + : "suppressed"); + OutputXmlAttribute(stream, kTestsuite, "time", + FormatTimeInMillisAsSeconds(result.elapsed_time())); + OutputXmlAttribute( + stream, kTestsuite, "timestamp", + FormatEpochTimeInMillisAsIso8601(result.start_timestamp())); + OutputXmlAttribute(stream, kTestsuite, "classname", test_suite_name); + + OutputXmlTestResult(stream, result); +} + +void XmlUnitTestResultPrinter::OutputXmlTestResult(::std::ostream* stream, + const TestResult& result) { + int failures = 0; + int skips = 0; + for (int i = 0; i < result.total_part_count(); ++i) { + const TestPartResult& part = result.GetTestPartResult(i); + if (part.failed()) { + if (++failures == 1 && skips == 0) { + *stream << ">\n"; + } + const std::string location = + internal::FormatCompilerIndependentFileLocation(part.file_name(), + part.line_number()); + const std::string summary = location + "\n" + part.summary(); + *stream << " "; + const std::string detail = location + "\n" + part.message(); + OutputXmlCDataSection(stream, RemoveInvalidXmlCharacters(detail).c_str()); + *stream << "\n"; + } else if (part.skipped()) { + if (++skips == 1 && failures == 0) { + *stream << ">\n"; + } + const std::string location = + internal::FormatCompilerIndependentFileLocation(part.file_name(), + part.line_number()); + const std::string summary = location + "\n" + part.summary(); + *stream << " "; + const std::string detail = location + "\n" + part.message(); + OutputXmlCDataSection(stream, RemoveInvalidXmlCharacters(detail).c_str()); + *stream << "\n"; + } + } + + if (failures == 0 && skips == 0 && result.test_property_count() == 0) { + *stream << " />\n"; + } else { + if (failures == 0 && skips == 0) { + *stream << ">\n"; + } + OutputXmlTestProperties(stream, result); + *stream << " \n"; + } +} + +// Prints an XML representation of a TestSuite object +void XmlUnitTestResultPrinter::PrintXmlTestSuite(std::ostream* stream, + const TestSuite& test_suite) { + const std::string kTestsuite = "testsuite"; + *stream << " <" << kTestsuite; + OutputXmlAttribute(stream, kTestsuite, "name", test_suite.name()); + OutputXmlAttribute(stream, kTestsuite, "tests", + StreamableToString(test_suite.reportable_test_count())); + if (!GTEST_FLAG_GET(list_tests)) { + OutputXmlAttribute(stream, kTestsuite, "failures", + StreamableToString(test_suite.failed_test_count())); + OutputXmlAttribute( + stream, kTestsuite, "disabled", + StreamableToString(test_suite.reportable_disabled_test_count())); + OutputXmlAttribute(stream, kTestsuite, "skipped", + StreamableToString(test_suite.skipped_test_count())); + + OutputXmlAttribute(stream, kTestsuite, "errors", "0"); + + OutputXmlAttribute(stream, kTestsuite, "time", + FormatTimeInMillisAsSeconds(test_suite.elapsed_time())); + OutputXmlAttribute( + stream, kTestsuite, "timestamp", + FormatEpochTimeInMillisAsIso8601(test_suite.start_timestamp())); + *stream << TestPropertiesAsXmlAttributes(test_suite.ad_hoc_test_result()); + } + *stream << ">\n"; + for (int i = 0; i < test_suite.total_test_count(); ++i) { + if (test_suite.GetTestInfo(i)->is_reportable()) + OutputXmlTestInfo(stream, test_suite.name(), *test_suite.GetTestInfo(i)); + } + *stream << " \n"; +} + +// Prints an XML summary of unit_test to output stream out. +void XmlUnitTestResultPrinter::PrintXmlUnitTest(std::ostream* stream, + const UnitTest& unit_test) { + const std::string kTestsuites = "testsuites"; + + *stream << "\n"; + *stream << "<" << kTestsuites; + + OutputXmlAttribute(stream, kTestsuites, "tests", + StreamableToString(unit_test.reportable_test_count())); + OutputXmlAttribute(stream, kTestsuites, "failures", + StreamableToString(unit_test.failed_test_count())); + OutputXmlAttribute( + stream, kTestsuites, "disabled", + StreamableToString(unit_test.reportable_disabled_test_count())); + OutputXmlAttribute(stream, kTestsuites, "errors", "0"); + OutputXmlAttribute(stream, kTestsuites, "time", + FormatTimeInMillisAsSeconds(unit_test.elapsed_time())); + OutputXmlAttribute( + stream, kTestsuites, "timestamp", + FormatEpochTimeInMillisAsIso8601(unit_test.start_timestamp())); + + if (GTEST_FLAG_GET(shuffle)) { + OutputXmlAttribute(stream, kTestsuites, "random_seed", + StreamableToString(unit_test.random_seed())); + } + *stream << TestPropertiesAsXmlAttributes(unit_test.ad_hoc_test_result()); + + OutputXmlAttribute(stream, kTestsuites, "name", "AllTests"); + *stream << ">\n"; + + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + if (unit_test.GetTestSuite(i)->reportable_test_count() > 0) + PrintXmlTestSuite(stream, *unit_test.GetTestSuite(i)); + } + + // If there was a test failure outside of one of the test suites (like in a + // test environment) include that in the output. + if (unit_test.ad_hoc_test_result().Failed()) { + OutputXmlTestSuiteForTestResult(stream, unit_test.ad_hoc_test_result()); + } + + *stream << "\n"; +} + +void XmlUnitTestResultPrinter::PrintXmlTestsList( + std::ostream* stream, const std::vector& test_suites) { + const std::string kTestsuites = "testsuites"; + + *stream << "\n"; + *stream << "<" << kTestsuites; + + int total_tests = 0; + for (auto test_suite : test_suites) { + total_tests += test_suite->total_test_count(); + } + OutputXmlAttribute(stream, kTestsuites, "tests", + StreamableToString(total_tests)); + OutputXmlAttribute(stream, kTestsuites, "name", "AllTests"); + *stream << ">\n"; + + for (auto test_suite : test_suites) { + PrintXmlTestSuite(stream, *test_suite); + } + *stream << "\n"; +} + +// Produces a string representing the test properties in a result as space +// delimited XML attributes based on the property key="value" pairs. +std::string XmlUnitTestResultPrinter::TestPropertiesAsXmlAttributes( + const TestResult& result) { + Message attributes; + for (int i = 0; i < result.test_property_count(); ++i) { + const TestProperty& property = result.GetTestProperty(i); + attributes << " " << property.key() << "=" + << "\"" << EscapeXmlAttribute(property.value()) << "\""; + } + return attributes.GetString(); +} + +void XmlUnitTestResultPrinter::OutputXmlTestProperties( + std::ostream* stream, const TestResult& result) { + const std::string kProperties = "properties"; + const std::string kProperty = "property"; + + if (result.test_property_count() <= 0) { + return; + } + + *stream << " <" << kProperties << ">\n"; + for (int i = 0; i < result.test_property_count(); ++i) { + const TestProperty& property = result.GetTestProperty(i); + *stream << " <" << kProperty; + *stream << " name=\"" << EscapeXmlAttribute(property.key()) << "\""; + *stream << " value=\"" << EscapeXmlAttribute(property.value()) << "\""; + *stream << "/>\n"; + } + *stream << " \n"; +} + +// End XmlUnitTestResultPrinter + +// This class generates an JSON output file. +class JsonUnitTestResultPrinter : public EmptyTestEventListener { + public: + explicit JsonUnitTestResultPrinter(const char* output_file); + + void OnTestIterationEnd(const UnitTest& unit_test, int iteration) override; + + // Prints an JSON summary of all unit tests. + static void PrintJsonTestList(::std::ostream* stream, + const std::vector& test_suites); + + private: + // Returns an JSON-escaped copy of the input string str. + static std::string EscapeJson(const std::string& str); + + //// Verifies that the given attribute belongs to the given element and + //// streams the attribute as JSON. + static void OutputJsonKey(std::ostream* stream, + const std::string& element_name, + const std::string& name, + const std::string& value, + const std::string& indent, + bool comma = true); + static void OutputJsonKey(std::ostream* stream, + const std::string& element_name, + const std::string& name, + int value, + const std::string& indent, + bool comma = true); + + // Streams a test suite JSON stanza containing the given test result. + // + // Requires: result.Failed() + static void OutputJsonTestSuiteForTestResult(::std::ostream* stream, + const TestResult& result); + + // Streams a JSON representation of a TestResult object. + static void OutputJsonTestResult(::std::ostream* stream, + const TestResult& result); + + // Streams a JSON representation of a TestInfo object. + static void OutputJsonTestInfo(::std::ostream* stream, + const char* test_suite_name, + const TestInfo& test_info); + + // Prints a JSON representation of a TestSuite object + static void PrintJsonTestSuite(::std::ostream* stream, + const TestSuite& test_suite); + + // Prints a JSON summary of unit_test to output stream out. + static void PrintJsonUnitTest(::std::ostream* stream, + const UnitTest& unit_test); + + // Produces a string representing the test properties in a result as + // a JSON dictionary. + static std::string TestPropertiesAsJson(const TestResult& result, + const std::string& indent); + + // The output file. + const std::string output_file_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(JsonUnitTestResultPrinter); +}; + +// Creates a new JsonUnitTestResultPrinter. +JsonUnitTestResultPrinter::JsonUnitTestResultPrinter(const char* output_file) + : output_file_(output_file) { + if (output_file_.empty()) { + GTEST_LOG_(FATAL) << "JSON output file may not be null"; + } +} + +void JsonUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + FILE* jsonout = OpenFileForWriting(output_file_); + std::stringstream stream; + PrintJsonUnitTest(&stream, unit_test); + fprintf(jsonout, "%s", StringStreamToString(&stream).c_str()); + fclose(jsonout); +} + +// Returns an JSON-escaped copy of the input string str. +std::string JsonUnitTestResultPrinter::EscapeJson(const std::string& str) { + Message m; + + for (size_t i = 0; i < str.size(); ++i) { + const char ch = str[i]; + switch (ch) { + case '\\': + case '"': + case '/': + m << '\\' << ch; + break; + case '\b': + m << "\\b"; + break; + case '\t': + m << "\\t"; + break; + case '\n': + m << "\\n"; + break; + case '\f': + m << "\\f"; + break; + case '\r': + m << "\\r"; + break; + default: + if (ch < ' ') { + m << "\\u00" << String::FormatByte(static_cast(ch)); + } else { + m << ch; + } + break; + } + } + + return m.GetString(); +} + +// The following routines generate an JSON representation of a UnitTest +// object. + +// Formats the given time in milliseconds as seconds. +static std::string FormatTimeInMillisAsDuration(TimeInMillis ms) { + ::std::stringstream ss; + ss << (static_cast(ms) * 1e-3) << "s"; + return ss.str(); +} + +// Converts the given epoch time in milliseconds to a date string in the +// RFC3339 format, without the timezone information. +static std::string FormatEpochTimeInMillisAsRFC3339(TimeInMillis ms) { + struct tm time_struct; + if (!PortableLocaltime(static_cast(ms / 1000), &time_struct)) + return ""; + // YYYY-MM-DDThh:mm:ss + return StreamableToString(time_struct.tm_year + 1900) + "-" + + String::FormatIntWidth2(time_struct.tm_mon + 1) + "-" + + String::FormatIntWidth2(time_struct.tm_mday) + "T" + + String::FormatIntWidth2(time_struct.tm_hour) + ":" + + String::FormatIntWidth2(time_struct.tm_min) + ":" + + String::FormatIntWidth2(time_struct.tm_sec) + "Z"; +} + +static inline std::string Indent(size_t width) { + return std::string(width, ' '); +} + +void JsonUnitTestResultPrinter::OutputJsonKey( + std::ostream* stream, + const std::string& element_name, + const std::string& name, + const std::string& value, + const std::string& indent, + bool comma) { + const std::vector& allowed_names = + GetReservedOutputAttributesForElement(element_name); + + GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != + allowed_names.end()) + << "Key \"" << name << "\" is not allowed for value \"" << element_name + << "\"."; + + *stream << indent << "\"" << name << "\": \"" << EscapeJson(value) << "\""; + if (comma) + *stream << ",\n"; +} + +void JsonUnitTestResultPrinter::OutputJsonKey( + std::ostream* stream, + const std::string& element_name, + const std::string& name, + int value, + const std::string& indent, + bool comma) { + const std::vector& allowed_names = + GetReservedOutputAttributesForElement(element_name); + + GTEST_CHECK_(std::find(allowed_names.begin(), allowed_names.end(), name) != + allowed_names.end()) + << "Key \"" << name << "\" is not allowed for value \"" << element_name + << "\"."; + + *stream << indent << "\"" << name << "\": " << StreamableToString(value); + if (comma) + *stream << ",\n"; +} + +// Streams a test suite JSON stanza containing the given test result. +void JsonUnitTestResultPrinter::OutputJsonTestSuiteForTestResult( + ::std::ostream* stream, const TestResult& result) { + // Output the boilerplate for a new test suite. + *stream << Indent(4) << "{\n"; + OutputJsonKey(stream, "testsuite", "name", "NonTestSuiteFailure", Indent(6)); + OutputJsonKey(stream, "testsuite", "tests", 1, Indent(6)); + if (!GTEST_FLAG_GET(list_tests)) { + OutputJsonKey(stream, "testsuite", "failures", 1, Indent(6)); + OutputJsonKey(stream, "testsuite", "disabled", 0, Indent(6)); + OutputJsonKey(stream, "testsuite", "skipped", 0, Indent(6)); + OutputJsonKey(stream, "testsuite", "errors", 0, Indent(6)); + OutputJsonKey(stream, "testsuite", "time", + FormatTimeInMillisAsDuration(result.elapsed_time()), + Indent(6)); + OutputJsonKey(stream, "testsuite", "timestamp", + FormatEpochTimeInMillisAsRFC3339(result.start_timestamp()), + Indent(6)); + } + *stream << Indent(6) << "\"testsuite\": [\n"; + + // Output the boilerplate for a new test case. + *stream << Indent(8) << "{\n"; + OutputJsonKey(stream, "testcase", "name", "", Indent(10)); + OutputJsonKey(stream, "testcase", "status", "RUN", Indent(10)); + OutputJsonKey(stream, "testcase", "result", "COMPLETED", Indent(10)); + OutputJsonKey(stream, "testcase", "timestamp", + FormatEpochTimeInMillisAsRFC3339(result.start_timestamp()), + Indent(10)); + OutputJsonKey(stream, "testcase", "time", + FormatTimeInMillisAsDuration(result.elapsed_time()), + Indent(10)); + OutputJsonKey(stream, "testcase", "classname", "", Indent(10), false); + *stream << TestPropertiesAsJson(result, Indent(10)); + + // Output the actual test result. + OutputJsonTestResult(stream, result); + + // Finish the test suite. + *stream << "\n" << Indent(6) << "]\n" << Indent(4) << "}"; +} + +// Prints a JSON representation of a TestInfo object. +void JsonUnitTestResultPrinter::OutputJsonTestInfo(::std::ostream* stream, + const char* test_suite_name, + const TestInfo& test_info) { + const TestResult& result = *test_info.result(); + const std::string kTestsuite = "testcase"; + const std::string kIndent = Indent(10); + + *stream << Indent(8) << "{\n"; + OutputJsonKey(stream, kTestsuite, "name", test_info.name(), kIndent); + + if (test_info.value_param() != nullptr) { + OutputJsonKey(stream, kTestsuite, "value_param", test_info.value_param(), + kIndent); + } + if (test_info.type_param() != nullptr) { + OutputJsonKey(stream, kTestsuite, "type_param", test_info.type_param(), + kIndent); + } + if (GTEST_FLAG_GET(list_tests)) { + OutputJsonKey(stream, kTestsuite, "file", test_info.file(), kIndent); + OutputJsonKey(stream, kTestsuite, "line", test_info.line(), kIndent, false); + *stream << "\n" << Indent(8) << "}"; + return; + } + + OutputJsonKey(stream, kTestsuite, "status", + test_info.should_run() ? "RUN" : "NOTRUN", kIndent); + OutputJsonKey(stream, kTestsuite, "result", + test_info.should_run() + ? (result.Skipped() ? "SKIPPED" : "COMPLETED") + : "SUPPRESSED", + kIndent); + OutputJsonKey(stream, kTestsuite, "timestamp", + FormatEpochTimeInMillisAsRFC3339(result.start_timestamp()), + kIndent); + OutputJsonKey(stream, kTestsuite, "time", + FormatTimeInMillisAsDuration(result.elapsed_time()), kIndent); + OutputJsonKey(stream, kTestsuite, "classname", test_suite_name, kIndent, + false); + *stream << TestPropertiesAsJson(result, kIndent); + + OutputJsonTestResult(stream, result); +} + +void JsonUnitTestResultPrinter::OutputJsonTestResult(::std::ostream* stream, + const TestResult& result) { + const std::string kIndent = Indent(10); + + int failures = 0; + for (int i = 0; i < result.total_part_count(); ++i) { + const TestPartResult& part = result.GetTestPartResult(i); + if (part.failed()) { + *stream << ",\n"; + if (++failures == 1) { + *stream << kIndent << "\"" << "failures" << "\": [\n"; + } + const std::string location = + internal::FormatCompilerIndependentFileLocation(part.file_name(), + part.line_number()); + const std::string message = EscapeJson(location + "\n" + part.message()); + *stream << kIndent << " {\n" + << kIndent << " \"failure\": \"" << message << "\",\n" + << kIndent << " \"type\": \"\"\n" + << kIndent << " }"; + } + } + + if (failures > 0) + *stream << "\n" << kIndent << "]"; + *stream << "\n" << Indent(8) << "}"; +} + +// Prints an JSON representation of a TestSuite object +void JsonUnitTestResultPrinter::PrintJsonTestSuite( + std::ostream* stream, const TestSuite& test_suite) { + const std::string kTestsuite = "testsuite"; + const std::string kIndent = Indent(6); + + *stream << Indent(4) << "{\n"; + OutputJsonKey(stream, kTestsuite, "name", test_suite.name(), kIndent); + OutputJsonKey(stream, kTestsuite, "tests", test_suite.reportable_test_count(), + kIndent); + if (!GTEST_FLAG_GET(list_tests)) { + OutputJsonKey(stream, kTestsuite, "failures", + test_suite.failed_test_count(), kIndent); + OutputJsonKey(stream, kTestsuite, "disabled", + test_suite.reportable_disabled_test_count(), kIndent); + OutputJsonKey(stream, kTestsuite, "errors", 0, kIndent); + OutputJsonKey( + stream, kTestsuite, "timestamp", + FormatEpochTimeInMillisAsRFC3339(test_suite.start_timestamp()), + kIndent); + OutputJsonKey(stream, kTestsuite, "time", + FormatTimeInMillisAsDuration(test_suite.elapsed_time()), + kIndent, false); + *stream << TestPropertiesAsJson(test_suite.ad_hoc_test_result(), kIndent) + << ",\n"; + } + + *stream << kIndent << "\"" << kTestsuite << "\": [\n"; + + bool comma = false; + for (int i = 0; i < test_suite.total_test_count(); ++i) { + if (test_suite.GetTestInfo(i)->is_reportable()) { + if (comma) { + *stream << ",\n"; + } else { + comma = true; + } + OutputJsonTestInfo(stream, test_suite.name(), *test_suite.GetTestInfo(i)); + } + } + *stream << "\n" << kIndent << "]\n" << Indent(4) << "}"; +} + +// Prints a JSON summary of unit_test to output stream out. +void JsonUnitTestResultPrinter::PrintJsonUnitTest(std::ostream* stream, + const UnitTest& unit_test) { + const std::string kTestsuites = "testsuites"; + const std::string kIndent = Indent(2); + *stream << "{\n"; + + OutputJsonKey(stream, kTestsuites, "tests", unit_test.reportable_test_count(), + kIndent); + OutputJsonKey(stream, kTestsuites, "failures", unit_test.failed_test_count(), + kIndent); + OutputJsonKey(stream, kTestsuites, "disabled", + unit_test.reportable_disabled_test_count(), kIndent); + OutputJsonKey(stream, kTestsuites, "errors", 0, kIndent); + if (GTEST_FLAG_GET(shuffle)) { + OutputJsonKey(stream, kTestsuites, "random_seed", unit_test.random_seed(), + kIndent); + } + OutputJsonKey(stream, kTestsuites, "timestamp", + FormatEpochTimeInMillisAsRFC3339(unit_test.start_timestamp()), + kIndent); + OutputJsonKey(stream, kTestsuites, "time", + FormatTimeInMillisAsDuration(unit_test.elapsed_time()), kIndent, + false); + + *stream << TestPropertiesAsJson(unit_test.ad_hoc_test_result(), kIndent) + << ",\n"; + + OutputJsonKey(stream, kTestsuites, "name", "AllTests", kIndent); + *stream << kIndent << "\"" << kTestsuites << "\": [\n"; + + bool comma = false; + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + if (unit_test.GetTestSuite(i)->reportable_test_count() > 0) { + if (comma) { + *stream << ",\n"; + } else { + comma = true; + } + PrintJsonTestSuite(stream, *unit_test.GetTestSuite(i)); + } + } + + // If there was a test failure outside of one of the test suites (like in a + // test environment) include that in the output. + if (unit_test.ad_hoc_test_result().Failed()) { + OutputJsonTestSuiteForTestResult(stream, unit_test.ad_hoc_test_result()); + } + + *stream << "\n" << kIndent << "]\n" << "}\n"; +} + +void JsonUnitTestResultPrinter::PrintJsonTestList( + std::ostream* stream, const std::vector& test_suites) { + const std::string kTestsuites = "testsuites"; + const std::string kIndent = Indent(2); + *stream << "{\n"; + int total_tests = 0; + for (auto test_suite : test_suites) { + total_tests += test_suite->total_test_count(); + } + OutputJsonKey(stream, kTestsuites, "tests", total_tests, kIndent); + + OutputJsonKey(stream, kTestsuites, "name", "AllTests", kIndent); + *stream << kIndent << "\"" << kTestsuites << "\": [\n"; + + for (size_t i = 0; i < test_suites.size(); ++i) { + if (i != 0) { + *stream << ",\n"; + } + PrintJsonTestSuite(stream, *test_suites[i]); + } + + *stream << "\n" + << kIndent << "]\n" + << "}\n"; +} +// Produces a string representing the test properties in a result as +// a JSON dictionary. +std::string JsonUnitTestResultPrinter::TestPropertiesAsJson( + const TestResult& result, const std::string& indent) { + Message attributes; + for (int i = 0; i < result.test_property_count(); ++i) { + const TestProperty& property = result.GetTestProperty(i); + attributes << ",\n" << indent << "\"" << property.key() << "\": " + << "\"" << EscapeJson(property.value()) << "\""; + } + return attributes.GetString(); +} + +// End JsonUnitTestResultPrinter + +#if GTEST_CAN_STREAM_RESULTS_ + +// Checks if str contains '=', '&', '%' or '\n' characters. If yes, +// replaces them by "%xx" where xx is their hexadecimal value. For +// example, replaces "=" with "%3D". This algorithm is O(strlen(str)) +// in both time and space -- important as the input str may contain an +// arbitrarily long test failure message and stack trace. +std::string StreamingListener::UrlEncode(const char* str) { + std::string result; + result.reserve(strlen(str) + 1); + for (char ch = *str; ch != '\0'; ch = *++str) { + switch (ch) { + case '%': + case '=': + case '&': + case '\n': + result.append("%" + String::FormatByte(static_cast(ch))); + break; + default: + result.push_back(ch); + break; + } + } + return result; +} + +void StreamingListener::SocketWriter::MakeConnection() { + GTEST_CHECK_(sockfd_ == -1) + << "MakeConnection() can't be called when there is already a connection."; + + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; // To allow both IPv4 and IPv6 addresses. + hints.ai_socktype = SOCK_STREAM; + addrinfo* servinfo = nullptr; + + // Use the getaddrinfo() to get a linked list of IP addresses for + // the given host name. + const int error_num = getaddrinfo( + host_name_.c_str(), port_num_.c_str(), &hints, &servinfo); + if (error_num != 0) { + GTEST_LOG_(WARNING) << "stream_result_to: getaddrinfo() failed: " + << gai_strerror(error_num); + } + + // Loop through all the results and connect to the first we can. + for (addrinfo* cur_addr = servinfo; sockfd_ == -1 && cur_addr != nullptr; + cur_addr = cur_addr->ai_next) { + sockfd_ = socket( + cur_addr->ai_family, cur_addr->ai_socktype, cur_addr->ai_protocol); + if (sockfd_ != -1) { + // Connect the client socket to the server socket. + if (connect(sockfd_, cur_addr->ai_addr, cur_addr->ai_addrlen) == -1) { + close(sockfd_); + sockfd_ = -1; + } + } + } + + freeaddrinfo(servinfo); // all done with this structure + + if (sockfd_ == -1) { + GTEST_LOG_(WARNING) << "stream_result_to: failed to connect to " + << host_name_ << ":" << port_num_; + } +} + +// End of class Streaming Listener +#endif // GTEST_CAN_STREAM_RESULTS__ + +// class OsStackTraceGetter + +const char* const OsStackTraceGetterInterface::kElidedFramesMarker = + "... " GTEST_NAME_ " internal frames ..."; + +std::string OsStackTraceGetter::CurrentStackTrace(int max_depth, int skip_count) + GTEST_LOCK_EXCLUDED_(mutex_) { +#if GTEST_HAS_ABSL + std::string result; + + if (max_depth <= 0) { + return result; + } + + max_depth = std::min(max_depth, kMaxStackTraceDepth); + + std::vector raw_stack(max_depth); + // Skips the frames requested by the caller, plus this function. + const int raw_stack_size = + absl::GetStackTrace(&raw_stack[0], max_depth, skip_count + 1); + + void* caller_frame = nullptr; + { + MutexLock lock(&mutex_); + caller_frame = caller_frame_; + } + + for (int i = 0; i < raw_stack_size; ++i) { + if (raw_stack[i] == caller_frame && + !GTEST_FLAG_GET(show_internal_stack_frames)) { + // Add a marker to the trace and stop adding frames. + absl::StrAppend(&result, kElidedFramesMarker, "\n"); + break; + } + + char tmp[1024]; + const char* symbol = "(unknown)"; + if (absl::Symbolize(raw_stack[i], tmp, sizeof(tmp))) { + symbol = tmp; + } + + char line[1024]; + snprintf(line, sizeof(line), " %p: %s\n", raw_stack[i], symbol); + result += line; + } + + return result; + +#else // !GTEST_HAS_ABSL + static_cast(max_depth); + static_cast(skip_count); + return ""; +#endif // GTEST_HAS_ABSL +} + +void OsStackTraceGetter::UponLeavingGTest() GTEST_LOCK_EXCLUDED_(mutex_) { +#if GTEST_HAS_ABSL + void* caller_frame = nullptr; + if (absl::GetStackTrace(&caller_frame, 1, 3) <= 0) { + caller_frame = nullptr; + } + + MutexLock lock(&mutex_); + caller_frame_ = caller_frame; +#endif // GTEST_HAS_ABSL +} + +// A helper class that creates the premature-exit file in its +// constructor and deletes the file in its destructor. +class ScopedPrematureExitFile { + public: + explicit ScopedPrematureExitFile(const char* premature_exit_filepath) + : premature_exit_filepath_(premature_exit_filepath ? + premature_exit_filepath : "") { + // If a path to the premature-exit file is specified... + if (!premature_exit_filepath_.empty()) { + // create the file with a single "0" character in it. I/O + // errors are ignored as there's nothing better we can do and we + // don't want to fail the test because of this. + FILE* pfile = posix::FOpen(premature_exit_filepath_.c_str(), "w"); + fwrite("0", 1, 1, pfile); + fclose(pfile); + } + } + + ~ScopedPrematureExitFile() { +#if !defined GTEST_OS_ESP8266 + if (!premature_exit_filepath_.empty()) { + int retval = remove(premature_exit_filepath_.c_str()); + if (retval) { + GTEST_LOG_(ERROR) << "Failed to remove premature exit filepath \"" + << premature_exit_filepath_ << "\" with error " + << retval; + } + } +#endif + } + + private: + const std::string premature_exit_filepath_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedPrematureExitFile); +}; + +} // namespace internal + +// class TestEventListeners + +TestEventListeners::TestEventListeners() + : repeater_(new internal::TestEventRepeater()), + default_result_printer_(nullptr), + default_xml_generator_(nullptr) {} + +TestEventListeners::~TestEventListeners() { delete repeater_; } + +// Returns the standard listener responsible for the default console +// output. Can be removed from the listeners list to shut down default +// console output. Note that removing this object from the listener list +// with Release transfers its ownership to the user. +void TestEventListeners::Append(TestEventListener* listener) { + repeater_->Append(listener); +} + +// Removes the given event listener from the list and returns it. It then +// becomes the caller's responsibility to delete the listener. Returns +// NULL if the listener is not found in the list. +TestEventListener* TestEventListeners::Release(TestEventListener* listener) { + if (listener == default_result_printer_) + default_result_printer_ = nullptr; + else if (listener == default_xml_generator_) + default_xml_generator_ = nullptr; + return repeater_->Release(listener); +} + +// Returns repeater that broadcasts the TestEventListener events to all +// subscribers. +TestEventListener* TestEventListeners::repeater() { return repeater_; } + +// Sets the default_result_printer attribute to the provided listener. +// The listener is also added to the listener list and previous +// default_result_printer is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultResultPrinter(TestEventListener* listener) { + if (default_result_printer_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. + delete Release(default_result_printer_); + default_result_printer_ = listener; + if (listener != nullptr) Append(listener); + } +} + +// Sets the default_xml_generator attribute to the provided listener. The +// listener is also added to the listener list and previous +// default_xml_generator is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultXmlGenerator(TestEventListener* listener) { + if (default_xml_generator_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. + delete Release(default_xml_generator_); + default_xml_generator_ = listener; + if (listener != nullptr) Append(listener); + } +} + +// Controls whether events will be forwarded by the repeater to the +// listeners in the list. +bool TestEventListeners::EventForwardingEnabled() const { + return repeater_->forwarding_enabled(); +} + +void TestEventListeners::SuppressEventForwarding() { + repeater_->set_forwarding_enabled(false); +} + +// class UnitTest + +// Gets the singleton UnitTest object. The first time this method is +// called, a UnitTest object is constructed and returned. Consecutive +// calls will return the same object. +// +// We don't protect this under mutex_ as a user is not supposed to +// call this before main() starts, from which point on the return +// value will never change. +UnitTest* UnitTest::GetInstance() { + // CodeGear C++Builder insists on a public destructor for the + // default implementation. Use this implementation to keep good OO + // design with private destructor. + +#if defined(__BORLANDC__) + static UnitTest* const instance = new UnitTest; + return instance; +#else + static UnitTest instance; + return &instance; +#endif // defined(__BORLANDC__) +} + +// Gets the number of successful test suites. +int UnitTest::successful_test_suite_count() const { + return impl()->successful_test_suite_count(); +} + +// Gets the number of failed test suites. +int UnitTest::failed_test_suite_count() const { + return impl()->failed_test_suite_count(); +} + +// Gets the number of all test suites. +int UnitTest::total_test_suite_count() const { + return impl()->total_test_suite_count(); +} + +// Gets the number of all test suites that contain at least one test +// that should run. +int UnitTest::test_suite_to_run_count() const { + return impl()->test_suite_to_run_count(); +} + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +int UnitTest::successful_test_case_count() const { + return impl()->successful_test_suite_count(); +} +int UnitTest::failed_test_case_count() const { + return impl()->failed_test_suite_count(); +} +int UnitTest::total_test_case_count() const { + return impl()->total_test_suite_count(); +} +int UnitTest::test_case_to_run_count() const { + return impl()->test_suite_to_run_count(); +} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +// Gets the number of successful tests. +int UnitTest::successful_test_count() const { + return impl()->successful_test_count(); +} + +// Gets the number of skipped tests. +int UnitTest::skipped_test_count() const { + return impl()->skipped_test_count(); +} + +// Gets the number of failed tests. +int UnitTest::failed_test_count() const { return impl()->failed_test_count(); } + +// Gets the number of disabled tests that will be reported in the XML report. +int UnitTest::reportable_disabled_test_count() const { + return impl()->reportable_disabled_test_count(); +} + +// Gets the number of disabled tests. +int UnitTest::disabled_test_count() const { + return impl()->disabled_test_count(); +} + +// Gets the number of tests to be printed in the XML report. +int UnitTest::reportable_test_count() const { + return impl()->reportable_test_count(); +} + +// Gets the number of all tests. +int UnitTest::total_test_count() const { return impl()->total_test_count(); } + +// Gets the number of tests that should run. +int UnitTest::test_to_run_count() const { return impl()->test_to_run_count(); } + +// Gets the time of the test program start, in ms from the start of the +// UNIX epoch. +internal::TimeInMillis UnitTest::start_timestamp() const { + return impl()->start_timestamp(); +} + +// Gets the elapsed time, in milliseconds. +internal::TimeInMillis UnitTest::elapsed_time() const { + return impl()->elapsed_time(); +} + +// Returns true if and only if the unit test passed (i.e. all test suites +// passed). +bool UnitTest::Passed() const { return impl()->Passed(); } + +// Returns true if and only if the unit test failed (i.e. some test suite +// failed or something outside of all tests failed). +bool UnitTest::Failed() const { return impl()->Failed(); } + +// Gets the i-th test suite among all the test suites. i can range from 0 to +// total_test_suite_count() - 1. If i is not in that range, returns NULL. +const TestSuite* UnitTest::GetTestSuite(int i) const { + return impl()->GetTestSuite(i); +} + +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +const TestCase* UnitTest::GetTestCase(int i) const { + return impl()->GetTestCase(i); +} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +// Returns the TestResult containing information on test failures and +// properties logged outside of individual test suites. +const TestResult& UnitTest::ad_hoc_test_result() const { + return *impl()->ad_hoc_test_result(); +} + +// Gets the i-th test suite among all the test suites. i can range from 0 to +// total_test_suite_count() - 1. If i is not in that range, returns NULL. +TestSuite* UnitTest::GetMutableTestSuite(int i) { + return impl()->GetMutableSuiteCase(i); +} + +// Returns the list of event listeners that can be used to track events +// inside Google Test. +TestEventListeners& UnitTest::listeners() { + return *impl()->listeners(); +} + +// Registers and returns a global test environment. When a test +// program is run, all global test environments will be set-up in the +// order they were registered. After all tests in the program have +// finished, all global test environments will be torn-down in the +// *reverse* order they were registered. +// +// The UnitTest object takes ownership of the given environment. +// +// We don't protect this under mutex_, as we only support calling it +// from the main thread. +Environment* UnitTest::AddEnvironment(Environment* env) { + if (env == nullptr) { + return nullptr; + } + + impl_->environments().push_back(env); + return env; +} + +// Adds a TestPartResult to the current TestResult object. All Google Test +// assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) eventually call +// this to report their results. The user code should use the +// assertion macros instead of calling this directly. +void UnitTest::AddTestPartResult( + TestPartResult::Type result_type, + const char* file_name, + int line_number, + const std::string& message, + const std::string& os_stack_trace) GTEST_LOCK_EXCLUDED_(mutex_) { + Message msg; + msg << message; + + internal::MutexLock lock(&mutex_); + if (impl_->gtest_trace_stack().size() > 0) { + msg << "\n" << GTEST_NAME_ << " trace:"; + + for (size_t i = impl_->gtest_trace_stack().size(); i > 0; --i) { + const internal::TraceInfo& trace = impl_->gtest_trace_stack()[i - 1]; + msg << "\n" << internal::FormatFileLocation(trace.file, trace.line) + << " " << trace.message; + } + } + + if (os_stack_trace.c_str() != nullptr && !os_stack_trace.empty()) { + msg << internal::kStackTraceMarker << os_stack_trace; + } + + const TestPartResult result = TestPartResult( + result_type, file_name, line_number, msg.GetString().c_str()); + impl_->GetTestPartResultReporterForCurrentThread()-> + ReportTestPartResult(result); + + if (result_type != TestPartResult::kSuccess && + result_type != TestPartResult::kSkip) { + // gtest_break_on_failure takes precedence over + // gtest_throw_on_failure. This allows a user to set the latter + // in the code (perhaps in order to use Google Test assertions + // with another testing framework) and specify the former on the + // command line for debugging. + if (GTEST_FLAG_GET(break_on_failure)) { +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT + // Using DebugBreak on Windows allows gtest to still break into a debugger + // when a failure happens and both the --gtest_break_on_failure and + // the --gtest_catch_exceptions flags are specified. + DebugBreak(); +#elif (!defined(__native_client__)) && \ + ((defined(__clang__) || defined(__GNUC__)) && \ + (defined(__x86_64__) || defined(__i386__))) + // with clang/gcc we can achieve the same effect on x86 by invoking int3 + asm("int3"); +#else + // Dereference nullptr through a volatile pointer to prevent the compiler + // from removing. We use this rather than abort() or __builtin_trap() for + // portability: some debuggers don't correctly trap abort(). + *static_cast(nullptr) = 1; +#endif // GTEST_OS_WINDOWS + } else if (GTEST_FLAG_GET(throw_on_failure)) { +#if GTEST_HAS_EXCEPTIONS + throw internal::GoogleTestFailureException(result); +#else + // We cannot call abort() as it generates a pop-up in debug mode + // that cannot be suppressed in VC 7.1 or below. + exit(1); +#endif + } + } +} + +// Adds a TestProperty to the current TestResult object when invoked from +// inside a test, to current TestSuite's ad_hoc_test_result_ when invoked +// from SetUpTestSuite or TearDownTestSuite, or to the global property set +// when invoked elsewhere. If the result already contains a property with +// the same key, the value will be updated. +void UnitTest::RecordProperty(const std::string& key, + const std::string& value) { + impl_->RecordProperty(TestProperty(key, value)); +} + +// Runs all tests in this UnitTest object and prints the result. +// Returns 0 if successful, or 1 otherwise. +// +// We don't protect this under mutex_, as we only support calling it +// from the main thread. +int UnitTest::Run() { + const bool in_death_test_child_process = + GTEST_FLAG_GET(internal_run_death_test).length() > 0; + + // Google Test implements this protocol for catching that a test + // program exits before returning control to Google Test: + // + // 1. Upon start, Google Test creates a file whose absolute path + // is specified by the environment variable + // TEST_PREMATURE_EXIT_FILE. + // 2. When Google Test has finished its work, it deletes the file. + // + // This allows a test runner to set TEST_PREMATURE_EXIT_FILE before + // running a Google-Test-based test program and check the existence + // of the file at the end of the test execution to see if it has + // exited prematurely. + + // If we are in the child process of a death test, don't + // create/delete the premature exit file, as doing so is unnecessary + // and will confuse the parent process. Otherwise, create/delete + // the file upon entering/leaving this function. If the program + // somehow exits before this function has a chance to return, the + // premature-exit file will be left undeleted, causing a test runner + // that understands the premature-exit-file protocol to report the + // test as having failed. + const internal::ScopedPrematureExitFile premature_exit_file( + in_death_test_child_process + ? nullptr + : internal::posix::GetEnv("TEST_PREMATURE_EXIT_FILE")); + + // Captures the value of GTEST_FLAG(catch_exceptions). This value will be + // used for the duration of the program. + impl()->set_catch_exceptions(GTEST_FLAG_GET(catch_exceptions)); + +#if GTEST_OS_WINDOWS + // Either the user wants Google Test to catch exceptions thrown by the + // tests or this is executing in the context of death test child + // process. In either case the user does not want to see pop-up dialogs + // about crashes - they are expected. + if (impl()->catch_exceptions() || in_death_test_child_process) { +# if !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_WINDOWS_PHONE && !GTEST_OS_WINDOWS_RT + // SetErrorMode doesn't exist on CE. + SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); +# endif // !GTEST_OS_WINDOWS_MOBILE + +# if (defined(_MSC_VER) || GTEST_OS_WINDOWS_MINGW) && !GTEST_OS_WINDOWS_MOBILE + // Death test children can be terminated with _abort(). On Windows, + // _abort() can show a dialog with a warning message. This forces the + // abort message to go to stderr instead. + _set_error_mode(_OUT_TO_STDERR); +# endif + +# if defined(_MSC_VER) && !GTEST_OS_WINDOWS_MOBILE + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program. We need to suppress + // this dialog or it will pop up for every EXPECT/ASSERT_DEATH statement + // executed. Google Test will notify the user of any unexpected + // failure via stderr. + if (!GTEST_FLAG_GET(break_on_failure)) + _set_abort_behavior( + 0x0, // Clear the following flags: + _WRITE_ABORT_MSG | _CALL_REPORTFAULT); // pop-up window, core dump. + + // In debug mode, the Windows CRT can crash with an assertion over invalid + // input (e.g. passing an invalid file descriptor). The default handling + // for these assertions is to pop up a dialog and wait for user input. + // Instead ask the CRT to dump such assertions to stderr non-interactively. + if (!IsDebuggerPresent()) { + (void)_CrtSetReportMode(_CRT_ASSERT, + _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + (void)_CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + } +# endif + } +#endif // GTEST_OS_WINDOWS + + return internal::HandleExceptionsInMethodIfSupported( + impl(), + &internal::UnitTestImpl::RunAllTests, + "auxiliary test code (environments or event listeners)") ? 0 : 1; +} + +// Returns the working directory when the first TEST() or TEST_F() was +// executed. +const char* UnitTest::original_working_dir() const { + return impl_->original_working_dir_.c_str(); +} + +// Returns the TestSuite object for the test that's currently running, +// or NULL if no test is running. +const TestSuite* UnitTest::current_test_suite() const + GTEST_LOCK_EXCLUDED_(mutex_) { + internal::MutexLock lock(&mutex_); + return impl_->current_test_suite(); +} + +// Legacy API is still available but deprecated +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +const TestCase* UnitTest::current_test_case() const + GTEST_LOCK_EXCLUDED_(mutex_) { + internal::MutexLock lock(&mutex_); + return impl_->current_test_suite(); +} +#endif + +// Returns the TestInfo object for the test that's currently running, +// or NULL if no test is running. +const TestInfo* UnitTest::current_test_info() const + GTEST_LOCK_EXCLUDED_(mutex_) { + internal::MutexLock lock(&mutex_); + return impl_->current_test_info(); +} + +// Returns the random seed used at the start of the current test run. +int UnitTest::random_seed() const { return impl_->random_seed(); } + +// Returns ParameterizedTestSuiteRegistry object used to keep track of +// value-parameterized tests and instantiate and register them. +internal::ParameterizedTestSuiteRegistry& +UnitTest::parameterized_test_registry() GTEST_LOCK_EXCLUDED_(mutex_) { + return impl_->parameterized_test_registry(); +} + +// Creates an empty UnitTest. +UnitTest::UnitTest() { + impl_ = new internal::UnitTestImpl(this); +} + +// Destructor of UnitTest. +UnitTest::~UnitTest() { + delete impl_; +} + +// Pushes a trace defined by SCOPED_TRACE() on to the per-thread +// Google Test trace stack. +void UnitTest::PushGTestTrace(const internal::TraceInfo& trace) + GTEST_LOCK_EXCLUDED_(mutex_) { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().push_back(trace); +} + +// Pops a trace from the per-thread Google Test trace stack. +void UnitTest::PopGTestTrace() + GTEST_LOCK_EXCLUDED_(mutex_) { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().pop_back(); +} + +namespace internal { + +UnitTestImpl::UnitTestImpl(UnitTest* parent) + : parent_(parent), + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4355 /* using this in initializer */) + default_global_test_part_result_reporter_(this), + default_per_thread_test_part_result_reporter_(this), + GTEST_DISABLE_MSC_WARNINGS_POP_() global_test_part_result_repoter_( + &default_global_test_part_result_reporter_), + per_thread_test_part_result_reporter_( + &default_per_thread_test_part_result_reporter_), + parameterized_test_registry_(), + parameterized_tests_registered_(false), + last_death_test_suite_(-1), + current_test_suite_(nullptr), + current_test_info_(nullptr), + ad_hoc_test_result_(), + os_stack_trace_getter_(nullptr), + post_flag_parse_init_performed_(false), + random_seed_(0), // Will be overridden by the flag before first use. + random_(0), // Will be reseeded before first use. + start_timestamp_(0), + elapsed_time_(0), +#if GTEST_HAS_DEATH_TEST + death_test_factory_(new DefaultDeathTestFactory), +#endif + // Will be overridden by the flag before first use. + catch_exceptions_(false) { + listeners()->SetDefaultResultPrinter(new PrettyUnitTestResultPrinter); +} + +UnitTestImpl::~UnitTestImpl() { + // Deletes every TestSuite. + ForEach(test_suites_, internal::Delete); + + // Deletes every Environment. + ForEach(environments_, internal::Delete); + + delete os_stack_trace_getter_; +} + +// Adds a TestProperty to the current TestResult object when invoked in a +// context of a test, to current test suite's ad_hoc_test_result when invoke +// from SetUpTestSuite/TearDownTestSuite, or to the global property set +// otherwise. If the result already contains a property with the same key, +// the value will be updated. +void UnitTestImpl::RecordProperty(const TestProperty& test_property) { + std::string xml_element; + TestResult* test_result; // TestResult appropriate for property recording. + + if (current_test_info_ != nullptr) { + xml_element = "testcase"; + test_result = &(current_test_info_->result_); + } else if (current_test_suite_ != nullptr) { + xml_element = "testsuite"; + test_result = &(current_test_suite_->ad_hoc_test_result_); + } else { + xml_element = "testsuites"; + test_result = &ad_hoc_test_result_; + } + test_result->RecordProperty(xml_element, test_property); +} + +#if GTEST_HAS_DEATH_TEST +// Disables event forwarding if the control is currently in a death test +// subprocess. Must not be called before InitGoogleTest. +void UnitTestImpl::SuppressTestEventsIfInSubprocess() { + if (internal_run_death_test_flag_.get() != nullptr) + listeners()->SuppressEventForwarding(); +} +#endif // GTEST_HAS_DEATH_TEST + +// Initializes event listeners performing XML output as specified by +// UnitTestOptions. Must not be called before InitGoogleTest. +void UnitTestImpl::ConfigureXmlOutput() { + const std::string& output_format = UnitTestOptions::GetOutputFormat(); + if (output_format == "xml") { + listeners()->SetDefaultXmlGenerator(new XmlUnitTestResultPrinter( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str())); + } else if (output_format == "json") { + listeners()->SetDefaultXmlGenerator(new JsonUnitTestResultPrinter( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str())); + } else if (output_format != "") { + GTEST_LOG_(WARNING) << "WARNING: unrecognized output format \"" + << output_format << "\" ignored."; + } +} + +#if GTEST_CAN_STREAM_RESULTS_ +// Initializes event listeners for streaming test results in string form. +// Must not be called before InitGoogleTest. +void UnitTestImpl::ConfigureStreamingOutput() { + const std::string& target = GTEST_FLAG_GET(stream_result_to); + if (!target.empty()) { + const size_t pos = target.find(':'); + if (pos != std::string::npos) { + listeners()->Append(new StreamingListener(target.substr(0, pos), + target.substr(pos+1))); + } else { + GTEST_LOG_(WARNING) << "unrecognized streaming target \"" << target + << "\" ignored."; + } + } +} +#endif // GTEST_CAN_STREAM_RESULTS_ + +// Performs initialization dependent upon flag values obtained in +// ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to +// ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest +// this function is also called from RunAllTests. Since this function can be +// called more than once, it has to be idempotent. +void UnitTestImpl::PostFlagParsingInit() { + // Ensures that this function does not execute more than once. + if (!post_flag_parse_init_performed_) { + post_flag_parse_init_performed_ = true; + +#if defined(GTEST_CUSTOM_TEST_EVENT_LISTENER_) + // Register to send notifications about key process state changes. + listeners()->Append(new GTEST_CUSTOM_TEST_EVENT_LISTENER_()); +#endif // defined(GTEST_CUSTOM_TEST_EVENT_LISTENER_) + +#if GTEST_HAS_DEATH_TEST + InitDeathTestSubprocessControlInfo(); + SuppressTestEventsIfInSubprocess(); +#endif // GTEST_HAS_DEATH_TEST + + // Registers parameterized tests. This makes parameterized tests + // available to the UnitTest reflection API without running + // RUN_ALL_TESTS. + RegisterParameterizedTests(); + + // Configures listeners for XML output. This makes it possible for users + // to shut down the default XML output before invoking RUN_ALL_TESTS. + ConfigureXmlOutput(); + + if (GTEST_FLAG_GET(brief)) { + listeners()->SetDefaultResultPrinter(new BriefUnitTestResultPrinter); + } + +#if GTEST_CAN_STREAM_RESULTS_ + // Configures listeners for streaming test results to the specified server. + ConfigureStreamingOutput(); +#endif // GTEST_CAN_STREAM_RESULTS_ + +#if GTEST_HAS_ABSL + if (GTEST_FLAG_GET(install_failure_signal_handler)) { + absl::FailureSignalHandlerOptions options; + absl::InstallFailureSignalHandler(options); + } +#endif // GTEST_HAS_ABSL + } +} + +// A predicate that checks the name of a TestSuite against a known +// value. +// +// This is used for implementation of the UnitTest class only. We put +// it in the anonymous namespace to prevent polluting the outer +// namespace. +// +// TestSuiteNameIs is copyable. +class TestSuiteNameIs { + public: + // Constructor. + explicit TestSuiteNameIs(const std::string& name) : name_(name) {} + + // Returns true if and only if the name of test_suite matches name_. + bool operator()(const TestSuite* test_suite) const { + return test_suite != nullptr && + strcmp(test_suite->name(), name_.c_str()) == 0; + } + + private: + std::string name_; +}; + +// Finds and returns a TestSuite with the given name. If one doesn't +// exist, creates one and returns it. It's the CALLER'S +// RESPONSIBILITY to ensure that this function is only called WHEN THE +// TESTS ARE NOT SHUFFLED. +// +// Arguments: +// +// test_suite_name: name of the test suite +// type_param: the name of the test suite's type parameter, or NULL if +// this is not a typed or a type-parameterized test suite. +// set_up_tc: pointer to the function that sets up the test suite +// tear_down_tc: pointer to the function that tears down the test suite +TestSuite* UnitTestImpl::GetTestSuite( + const char* test_suite_name, const char* type_param, + internal::SetUpTestSuiteFunc set_up_tc, + internal::TearDownTestSuiteFunc tear_down_tc) { + // Can we find a TestSuite with the given name? + const auto test_suite = + std::find_if(test_suites_.rbegin(), test_suites_.rend(), + TestSuiteNameIs(test_suite_name)); + + if (test_suite != test_suites_.rend()) return *test_suite; + + // No. Let's create one. + auto* const new_test_suite = + new TestSuite(test_suite_name, type_param, set_up_tc, tear_down_tc); + + const UnitTestFilter death_test_suite_filter(kDeathTestSuiteFilter); + // Is this a death test suite? + if (death_test_suite_filter.MatchesName(test_suite_name)) { + // Yes. Inserts the test suite after the last death test suite + // defined so far. This only works when the test suites haven't + // been shuffled. Otherwise we may end up running a death test + // after a non-death test. + ++last_death_test_suite_; + test_suites_.insert(test_suites_.begin() + last_death_test_suite_, + new_test_suite); + } else { + // No. Appends to the end of the list. + test_suites_.push_back(new_test_suite); + } + + test_suite_indices_.push_back(static_cast(test_suite_indices_.size())); + return new_test_suite; +} + +// Helpers for setting up / tearing down the given environment. They +// are for use in the ForEach() function. +static void SetUpEnvironment(Environment* env) { env->SetUp(); } +static void TearDownEnvironment(Environment* env) { env->TearDown(); } + +// Runs all tests in this UnitTest object, prints the result, and +// returns true if all tests are successful. If any exception is +// thrown during a test, the test is considered to be failed, but the +// rest of the tests will still be run. +// +// When parameterized tests are enabled, it expands and registers +// parameterized tests first in RegisterParameterizedTests(). +// All other functions called from RunAllTests() may safely assume that +// parameterized tests are ready to be counted and run. +bool UnitTestImpl::RunAllTests() { + // True if and only if Google Test is initialized before RUN_ALL_TESTS() is + // called. + const bool gtest_is_initialized_before_run_all_tests = GTestIsInitialized(); + + // Do not run any test if the --help flag was specified. + if (g_help_flag) + return true; + + // Repeats the call to the post-flag parsing initialization in case the + // user didn't call InitGoogleTest. + PostFlagParsingInit(); + + // Even if sharding is not on, test runners may want to use the + // GTEST_SHARD_STATUS_FILE to query whether the test supports the sharding + // protocol. + internal::WriteToShardStatusFileIfNeeded(); + + // True if and only if we are in a subprocess for running a thread-safe-style + // death test. + bool in_subprocess_for_death_test = false; + +#if GTEST_HAS_DEATH_TEST + in_subprocess_for_death_test = + (internal_run_death_test_flag_.get() != nullptr); +# if defined(GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_) + if (in_subprocess_for_death_test) { + GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_(); + } +# endif // defined(GTEST_EXTRA_DEATH_TEST_CHILD_SETUP_) +#endif // GTEST_HAS_DEATH_TEST + + const bool should_shard = ShouldShard(kTestTotalShards, kTestShardIndex, + in_subprocess_for_death_test); + + // Compares the full test names with the filter to decide which + // tests to run. + const bool has_tests_to_run = FilterTests(should_shard + ? HONOR_SHARDING_PROTOCOL + : IGNORE_SHARDING_PROTOCOL) > 0; + + // Lists the tests and exits if the --gtest_list_tests flag was specified. + if (GTEST_FLAG_GET(list_tests)) { + // This must be called *after* FilterTests() has been called. + ListTestsMatchingFilter(); + return true; + } + + random_seed_ = GetRandomSeedFromFlag(GTEST_FLAG_GET(random_seed)); + + // True if and only if at least one test has failed. + bool failed = false; + + TestEventListener* repeater = listeners()->repeater(); + + start_timestamp_ = GetTimeInMillis(); + repeater->OnTestProgramStart(*parent_); + + // How many times to repeat the tests? We don't want to repeat them + // when we are inside the subprocess of a death test. + const int repeat = in_subprocess_for_death_test ? 1 : GTEST_FLAG_GET(repeat); + + // Repeats forever if the repeat count is negative. + const bool gtest_repeat_forever = repeat < 0; + + // Should test environments be set up and torn down for each repeat, or only + // set up on the first and torn down on the last iteration? If there is no + // "last" iteration because the tests will repeat forever, always recreate the + // environments to avoid leaks in case one of the environments is using + // resources that are external to this process. Without this check there would + // be no way to clean up those external resources automatically. + const bool recreate_environments_when_repeating = + GTEST_FLAG_GET(recreate_environments_when_repeating) || + gtest_repeat_forever; + + for (int i = 0; gtest_repeat_forever || i != repeat; i++) { + // We want to preserve failures generated by ad-hoc test + // assertions executed before RUN_ALL_TESTS(). + ClearNonAdHocTestResult(); + + Timer timer; + + // Shuffles test suites and tests if requested. + if (has_tests_to_run && GTEST_FLAG_GET(shuffle)) { + random()->Reseed(static_cast(random_seed_)); + // This should be done before calling OnTestIterationStart(), + // such that a test event listener can see the actual test order + // in the event. + ShuffleTests(); + } + + // Tells the unit test event listeners that the tests are about to start. + repeater->OnTestIterationStart(*parent_, i); + + // Runs each test suite if there is at least one test to run. + if (has_tests_to_run) { + // Sets up all environments beforehand. If test environments aren't + // recreated for each iteration, only do so on the first iteration. + if (i == 0 || recreate_environments_when_repeating) { + repeater->OnEnvironmentsSetUpStart(*parent_); + ForEach(environments_, SetUpEnvironment); + repeater->OnEnvironmentsSetUpEnd(*parent_); + } + + // Runs the tests only if there was no fatal failure or skip triggered + // during global set-up. + if (Test::IsSkipped()) { + // Emit diagnostics when global set-up calls skip, as it will not be + // emitted by default. + TestResult& test_result = + *internal::GetUnitTestImpl()->current_test_result(); + for (int j = 0; j < test_result.total_part_count(); ++j) { + const TestPartResult& test_part_result = + test_result.GetTestPartResult(j); + if (test_part_result.type() == TestPartResult::kSkip) { + const std::string& result = test_part_result.message(); + printf("%s\n", result.c_str()); + } + } + fflush(stdout); + } else if (!Test::HasFatalFailure()) { + for (int test_index = 0; test_index < total_test_suite_count(); + test_index++) { + GetMutableSuiteCase(test_index)->Run(); + if (GTEST_FLAG_GET(fail_fast) && + GetMutableSuiteCase(test_index)->Failed()) { + for (int j = test_index + 1; j < total_test_suite_count(); j++) { + GetMutableSuiteCase(j)->Skip(); + } + break; + } + } + } else if (Test::HasFatalFailure()) { + // If there was a fatal failure during the global setup then we know we + // aren't going to run any tests. Explicitly mark all of the tests as + // skipped to make this obvious in the output. + for (int test_index = 0; test_index < total_test_suite_count(); + test_index++) { + GetMutableSuiteCase(test_index)->Skip(); + } + } + + // Tears down all environments in reverse order afterwards. If test + // environments aren't recreated for each iteration, only do so on the + // last iteration. + if (i == repeat - 1 || recreate_environments_when_repeating) { + repeater->OnEnvironmentsTearDownStart(*parent_); + std::for_each(environments_.rbegin(), environments_.rend(), + TearDownEnvironment); + repeater->OnEnvironmentsTearDownEnd(*parent_); + } + } + + elapsed_time_ = timer.Elapsed(); + + // Tells the unit test event listener that the tests have just finished. + repeater->OnTestIterationEnd(*parent_, i); + + // Gets the result and clears it. + if (!Passed()) { + failed = true; + } + + // Restores the original test order after the iteration. This + // allows the user to quickly repro a failure that happens in the + // N-th iteration without repeating the first (N - 1) iterations. + // This is not enclosed in "if (GTEST_FLAG(shuffle)) { ... }", in + // case the user somehow changes the value of the flag somewhere + // (it's always safe to unshuffle the tests). + UnshuffleTests(); + + if (GTEST_FLAG_GET(shuffle)) { + // Picks a new random seed for each iteration. + random_seed_ = GetNextRandomSeed(random_seed_); + } + } + + repeater->OnTestProgramEnd(*parent_); + + if (!gtest_is_initialized_before_run_all_tests) { + ColoredPrintf( + GTestColor::kRed, + "\nIMPORTANT NOTICE - DO NOT IGNORE:\n" + "This test program did NOT call " GTEST_INIT_GOOGLE_TEST_NAME_ + "() before calling RUN_ALL_TESTS(). This is INVALID. Soon " GTEST_NAME_ + " will start to enforce the valid usage. " + "Please fix it ASAP, or IT WILL START TO FAIL.\n"); // NOLINT +#if GTEST_FOR_GOOGLE_ + ColoredPrintf(GTestColor::kRed, + "For more details, see http://wiki/Main/ValidGUnitMain.\n"); +#endif // GTEST_FOR_GOOGLE_ + } + + return !failed; +} + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded() { + const char* const test_shard_file = posix::GetEnv(kTestShardStatusFile); + if (test_shard_file != nullptr) { + FILE* const file = posix::FOpen(test_shard_file, "w"); + if (file == nullptr) { + ColoredPrintf(GTestColor::kRed, + "Could not write to the test shard status file \"%s\" " + "specified by the %s environment variable.\n", + test_shard_file, kTestShardStatusFile); + fflush(stdout); + exit(EXIT_FAILURE); + } + fclose(file); + } +} + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (i.e., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. +bool ShouldShard(const char* total_shards_env, + const char* shard_index_env, + bool in_subprocess_for_death_test) { + if (in_subprocess_for_death_test) { + return false; + } + + const int32_t total_shards = Int32FromEnvOrDie(total_shards_env, -1); + const int32_t shard_index = Int32FromEnvOrDie(shard_index_env, -1); + + if (total_shards == -1 && shard_index == -1) { + return false; + } else if (total_shards == -1 && shard_index != -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestShardIndex << " = " << shard_index + << ", but have left " << kTestTotalShards << " unset.\n"; + ColoredPrintf(GTestColor::kRed, "%s", msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (total_shards != -1 && shard_index == -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestTotalShards << " = " << total_shards + << ", but have left " << kTestShardIndex << " unset.\n"; + ColoredPrintf(GTestColor::kRed, "%s", msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (shard_index < 0 || shard_index >= total_shards) { + const Message msg = Message() + << "Invalid environment variables: we require 0 <= " + << kTestShardIndex << " < " << kTestTotalShards + << ", but you have " << kTestShardIndex << "=" << shard_index + << ", " << kTestTotalShards << "=" << total_shards << ".\n"; + ColoredPrintf(GTestColor::kRed, "%s", msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } + + return total_shards > 1; +} + +// Parses the environment variable var as an Int32. If it is unset, +// returns default_val. If it is not an Int32, prints an error +// and aborts. +int32_t Int32FromEnvOrDie(const char* var, int32_t default_val) { + const char* str_val = posix::GetEnv(var); + if (str_val == nullptr) { + return default_val; + } + + int32_t result; + if (!ParseInt32(Message() << "The value of environment variable " << var, + str_val, &result)) { + exit(EXIT_FAILURE); + } + return result; +} + +// Given the total number of shards, the shard index, and the test id, +// returns true if and only if the test should be run on this shard. The test id +// is some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +bool ShouldRunTestOnShard(int total_shards, int shard_index, int test_id) { + return (test_id % total_shards) == shard_index; +} + +// Compares the name of each test with the user-specified filter to +// decide whether the test should be run, then records the result in +// each TestSuite and TestInfo object. +// If shard_tests == true, further filters tests based on sharding +// variables in the environment - see +// https://github.com/google/googletest/blob/master/googletest/docs/advanced.md +// . Returns the number of tests that should run. +int UnitTestImpl::FilterTests(ReactionToSharding shard_tests) { + const int32_t total_shards = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestTotalShards, -1) : -1; + const int32_t shard_index = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestShardIndex, -1) : -1; + + const PositiveAndNegativeUnitTestFilter gtest_flag_filter( + GTEST_FLAG_GET(filter)); + const UnitTestFilter disable_test_filter(kDisableTestFilter); + // num_runnable_tests are the number of tests that will + // run across all shards (i.e., match filter and are not disabled). + // num_selected_tests are the number of tests to be run on + // this shard. + int num_runnable_tests = 0; + int num_selected_tests = 0; + for (auto* test_suite : test_suites_) { + const std::string& test_suite_name = test_suite->name(); + test_suite->set_should_run(false); + + for (size_t j = 0; j < test_suite->test_info_list().size(); j++) { + TestInfo* const test_info = test_suite->test_info_list()[j]; + const std::string test_name(test_info->name()); + // A test is disabled if test suite name or test name matches + // kDisableTestFilter. + const bool is_disabled = + disable_test_filter.MatchesName(test_suite_name) || + disable_test_filter.MatchesName(test_name); + test_info->is_disabled_ = is_disabled; + + const bool matches_filter = + gtest_flag_filter.MatchesTest(test_suite_name, test_name); + test_info->matches_filter_ = matches_filter; + + const bool is_runnable = + (GTEST_FLAG_GET(also_run_disabled_tests) || !is_disabled) && + matches_filter; + + const bool is_in_another_shard = + shard_tests != IGNORE_SHARDING_PROTOCOL && + !ShouldRunTestOnShard(total_shards, shard_index, num_runnable_tests); + test_info->is_in_another_shard_ = is_in_another_shard; + const bool is_selected = is_runnable && !is_in_another_shard; + + num_runnable_tests += is_runnable; + num_selected_tests += is_selected; + + test_info->should_run_ = is_selected; + test_suite->set_should_run(test_suite->should_run() || is_selected); + } + } + return num_selected_tests; +} + +// Prints the given C-string on a single line by replacing all '\n' +// characters with string "\\n". If the output takes more than +// max_length characters, only prints the first max_length characters +// and "...". +static void PrintOnOneLine(const char* str, int max_length) { + if (str != nullptr) { + for (int i = 0; *str != '\0'; ++str) { + if (i >= max_length) { + printf("..."); + break; + } + if (*str == '\n') { + printf("\\n"); + i += 2; + } else { + printf("%c", *str); + ++i; + } + } + } +} + +// Prints the names of the tests matching the user-specified filter flag. +void UnitTestImpl::ListTestsMatchingFilter() { + // Print at most this many characters for each type/value parameter. + const int kMaxParamLength = 250; + + for (auto* test_suite : test_suites_) { + bool printed_test_suite_name = false; + + for (size_t j = 0; j < test_suite->test_info_list().size(); j++) { + const TestInfo* const test_info = test_suite->test_info_list()[j]; + if (test_info->matches_filter_) { + if (!printed_test_suite_name) { + printed_test_suite_name = true; + printf("%s.", test_suite->name()); + if (test_suite->type_param() != nullptr) { + printf(" # %s = ", kTypeParamLabel); + // We print the type parameter on a single line to make + // the output easy to parse by a program. + PrintOnOneLine(test_suite->type_param(), kMaxParamLength); + } + printf("\n"); + } + printf(" %s", test_info->name()); + if (test_info->value_param() != nullptr) { + printf(" # %s = ", kValueParamLabel); + // We print the value parameter on a single line to make the + // output easy to parse by a program. + PrintOnOneLine(test_info->value_param(), kMaxParamLength); + } + printf("\n"); + } + } + } + fflush(stdout); + const std::string& output_format = UnitTestOptions::GetOutputFormat(); + if (output_format == "xml" || output_format == "json") { + FILE* fileout = OpenFileForWriting( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str()); + std::stringstream stream; + if (output_format == "xml") { + XmlUnitTestResultPrinter( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str()) + .PrintXmlTestsList(&stream, test_suites_); + } else if (output_format == "json") { + JsonUnitTestResultPrinter( + UnitTestOptions::GetAbsolutePathToOutputFile().c_str()) + .PrintJsonTestList(&stream, test_suites_); + } + fprintf(fileout, "%s", StringStreamToString(&stream).c_str()); + fclose(fileout); + } +} + +// Sets the OS stack trace getter. +// +// Does nothing if the input and the current OS stack trace getter are +// the same; otherwise, deletes the old getter and makes the input the +// current getter. +void UnitTestImpl::set_os_stack_trace_getter( + OsStackTraceGetterInterface* getter) { + if (os_stack_trace_getter_ != getter) { + delete os_stack_trace_getter_; + os_stack_trace_getter_ = getter; + } +} + +// Returns the current OS stack trace getter if it is not NULL; +// otherwise, creates an OsStackTraceGetter, makes it the current +// getter, and returns it. +OsStackTraceGetterInterface* UnitTestImpl::os_stack_trace_getter() { + if (os_stack_trace_getter_ == nullptr) { +#ifdef GTEST_OS_STACK_TRACE_GETTER_ + os_stack_trace_getter_ = new GTEST_OS_STACK_TRACE_GETTER_; +#else + os_stack_trace_getter_ = new OsStackTraceGetter; +#endif // GTEST_OS_STACK_TRACE_GETTER_ + } + + return os_stack_trace_getter_; +} + +// Returns the most specific TestResult currently running. +TestResult* UnitTestImpl::current_test_result() { + if (current_test_info_ != nullptr) { + return ¤t_test_info_->result_; + } + if (current_test_suite_ != nullptr) { + return ¤t_test_suite_->ad_hoc_test_result_; + } + return &ad_hoc_test_result_; +} + +// Shuffles all test suites, and the tests within each test suite, +// making sure that death tests are still run first. +void UnitTestImpl::ShuffleTests() { + // Shuffles the death test suites. + ShuffleRange(random(), 0, last_death_test_suite_ + 1, &test_suite_indices_); + + // Shuffles the non-death test suites. + ShuffleRange(random(), last_death_test_suite_ + 1, + static_cast(test_suites_.size()), &test_suite_indices_); + + // Shuffles the tests inside each test suite. + for (auto& test_suite : test_suites_) { + test_suite->ShuffleTests(random()); + } +} + +// Restores the test suites and tests to their order before the first shuffle. +void UnitTestImpl::UnshuffleTests() { + for (size_t i = 0; i < test_suites_.size(); i++) { + // Unshuffles the tests in each test suite. + test_suites_[i]->UnshuffleTests(); + // Resets the index of each test suite. + test_suite_indices_[i] = static_cast(i); + } +} + +// Returns the current OS stack trace as an std::string. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in +// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't. +GTEST_NO_INLINE_ GTEST_NO_TAIL_CALL_ std::string +GetCurrentOsStackTraceExceptTop(UnitTest* /*unit_test*/, int skip_count) { + // We pass skip_count + 1 to skip this wrapper function in addition + // to what the user really wants to skip. + return GetUnitTestImpl()->CurrentOsStackTraceExceptTop(skip_count + 1); +} + +// Used by the GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_ macro to +// suppress unreachable code warnings. +namespace { +class ClassUniqueToAlwaysTrue {}; +} + +bool IsTrue(bool condition) { return condition; } + +bool AlwaysTrue() { +#if GTEST_HAS_EXCEPTIONS + // This condition is always false so AlwaysTrue() never actually throws, + // but it makes the compiler think that it may throw. + if (IsTrue(false)) + throw ClassUniqueToAlwaysTrue(); +#endif // GTEST_HAS_EXCEPTIONS + return true; +} + +// If *pstr starts with the given prefix, modifies *pstr to be right +// past the prefix and returns true; otherwise leaves *pstr unchanged +// and returns false. None of pstr, *pstr, and prefix can be NULL. +bool SkipPrefix(const char* prefix, const char** pstr) { + const size_t prefix_len = strlen(prefix); + if (strncmp(*pstr, prefix, prefix_len) == 0) { + *pstr += prefix_len; + return true; + } + return false; +} + +// Parses a string as a command line flag. The string should have +// the format "--flag=value". When def_optional is true, the "=value" +// part can be omitted. +// +// Returns the value of the flag, or NULL if the parsing failed. +static const char* ParseFlagValue(const char* str, const char* flag_name, + bool def_optional) { + // str and flag must not be NULL. + if (str == nullptr || flag_name == nullptr) return nullptr; + + // The flag must start with "--" followed by GTEST_FLAG_PREFIX_. + const std::string flag_str = + std::string("--") + GTEST_FLAG_PREFIX_ + flag_name; + const size_t flag_len = flag_str.length(); + if (strncmp(str, flag_str.c_str(), flag_len) != 0) return nullptr; + + // Skips the flag name. + const char* flag_end = str + flag_len; + + // When def_optional is true, it's OK to not have a "=value" part. + if (def_optional && (flag_end[0] == '\0')) { + return flag_end; + } + + // If def_optional is true and there are more characters after the + // flag name, or if def_optional is false, there must be a '=' after + // the flag name. + if (flag_end[0] != '=') return nullptr; + + // Returns the string after "=". + return flag_end + 1; +} + +// Parses a string for a bool flag, in the form of either +// "--flag=value" or "--flag". +// +// In the former case, the value is taken as true as long as it does +// not start with '0', 'f', or 'F'. +// +// In the latter case, the value is taken as true. +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +static bool ParseFlag(const char* str, const char* flag_name, bool* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag_name, true); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Converts the string value to a bool. + *value = !(*value_str == '0' || *value_str == 'f' || *value_str == 'F'); + return true; +} + +// Parses a string for an int32_t flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseFlag(const char* str, const char* flag_name, int32_t* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag_name, false); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Sets *value to the value of the flag. + return ParseInt32(Message() << "The value of flag --" << flag_name, value_str, + value); +} + +// Parses a string for a string flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +template +static bool ParseFlag(const char* str, const char* flag_name, String* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag_name, false); + + // Aborts if the parsing failed. + if (value_str == nullptr) return false; + + // Sets *value to the value of the flag. + *value = value_str; + return true; +} + +// Determines whether a string has a prefix that Google Test uses for its +// flags, i.e., starts with GTEST_FLAG_PREFIX_ or GTEST_FLAG_PREFIX_DASH_. +// If Google Test detects that a command line flag has its prefix but is not +// recognized, it will print its help message. Flags starting with +// GTEST_INTERNAL_PREFIX_ followed by "internal_" are considered Google Test +// internal flags and do not trigger the help message. +static bool HasGoogleTestFlagPrefix(const char* str) { + return (SkipPrefix("--", &str) || + SkipPrefix("-", &str) || + SkipPrefix("/", &str)) && + !SkipPrefix(GTEST_FLAG_PREFIX_ "internal_", &str) && + (SkipPrefix(GTEST_FLAG_PREFIX_, &str) || + SkipPrefix(GTEST_FLAG_PREFIX_DASH_, &str)); +} + +// Prints a string containing code-encoded text. The following escape +// sequences can be used in the string to control the text color: +// +// @@ prints a single '@' character. +// @R changes the color to red. +// @G changes the color to green. +// @Y changes the color to yellow. +// @D changes to the default terminal text color. +// +static void PrintColorEncoded(const char* str) { + GTestColor color = GTestColor::kDefault; // The current color. + + // Conceptually, we split the string into segments divided by escape + // sequences. Then we print one segment at a time. At the end of + // each iteration, the str pointer advances to the beginning of the + // next segment. + for (;;) { + const char* p = strchr(str, '@'); + if (p == nullptr) { + ColoredPrintf(color, "%s", str); + return; + } + + ColoredPrintf(color, "%s", std::string(str, p).c_str()); + + const char ch = p[1]; + str = p + 2; + if (ch == '@') { + ColoredPrintf(color, "@"); + } else if (ch == 'D') { + color = GTestColor::kDefault; + } else if (ch == 'R') { + color = GTestColor::kRed; + } else if (ch == 'G') { + color = GTestColor::kGreen; + } else if (ch == 'Y') { + color = GTestColor::kYellow; + } else { + --str; + } + } +} + +static const char kColorEncodedHelpMessage[] = + "This program contains tests written using " GTEST_NAME_ + ". You can use the\n" + "following command line flags to control its behavior:\n" + "\n" + "Test Selection:\n" + " @G--" GTEST_FLAG_PREFIX_ + "list_tests@D\n" + " List the names of all tests instead of running them. The name of\n" + " TEST(Foo, Bar) is \"Foo.Bar\".\n" + " @G--" GTEST_FLAG_PREFIX_ + "filter=@YPOSITIVE_PATTERNS" + "[@G-@YNEGATIVE_PATTERNS]@D\n" + " Run only the tests whose name matches one of the positive patterns " + "but\n" + " none of the negative patterns. '?' matches any single character; " + "'*'\n" + " matches any substring; ':' separates two patterns.\n" + " @G--" GTEST_FLAG_PREFIX_ + "also_run_disabled_tests@D\n" + " Run all disabled tests too.\n" + "\n" + "Test Execution:\n" + " @G--" GTEST_FLAG_PREFIX_ + "repeat=@Y[COUNT]@D\n" + " Run the tests repeatedly; use a negative count to repeat forever.\n" + " @G--" GTEST_FLAG_PREFIX_ + "shuffle@D\n" + " Randomize tests' orders on every iteration.\n" + " @G--" GTEST_FLAG_PREFIX_ + "random_seed=@Y[NUMBER]@D\n" + " Random number seed to use for shuffling test orders (between 1 and\n" + " 99999, or 0 to use a seed based on the current time).\n" + " @G--" GTEST_FLAG_PREFIX_ + "recreate_environments_when_repeating@D\n" + " Sets up and tears down the global test environment on each repeat\n" + " of the test.\n" + "\n" + "Test Output:\n" + " @G--" GTEST_FLAG_PREFIX_ + "color=@Y(@Gyes@Y|@Gno@Y|@Gauto@Y)@D\n" + " Enable/disable colored output. The default is @Gauto@D.\n" + " @G--" GTEST_FLAG_PREFIX_ + "brief=1@D\n" + " Only print test failures.\n" + " @G--" GTEST_FLAG_PREFIX_ + "print_time=0@D\n" + " Don't print the elapsed time of each test.\n" + " @G--" GTEST_FLAG_PREFIX_ + "output=@Y(@Gjson@Y|@Gxml@Y)[@G:@YDIRECTORY_PATH@G" GTEST_PATH_SEP_ + "@Y|@G:@YFILE_PATH]@D\n" + " Generate a JSON or XML report in the given directory or with the " + "given\n" + " file name. @YFILE_PATH@D defaults to @Gtest_detail.xml@D.\n" +# if GTEST_CAN_STREAM_RESULTS_ + " @G--" GTEST_FLAG_PREFIX_ + "stream_result_to=@YHOST@G:@YPORT@D\n" + " Stream test results to the given server.\n" +# endif // GTEST_CAN_STREAM_RESULTS_ + "\n" + "Assertion Behavior:\n" +# if GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS + " @G--" GTEST_FLAG_PREFIX_ + "death_test_style=@Y(@Gfast@Y|@Gthreadsafe@Y)@D\n" + " Set the default death test style.\n" +# endif // GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS + " @G--" GTEST_FLAG_PREFIX_ + "break_on_failure@D\n" + " Turn assertion failures into debugger break-points.\n" + " @G--" GTEST_FLAG_PREFIX_ + "throw_on_failure@D\n" + " Turn assertion failures into C++ exceptions for use by an external\n" + " test framework.\n" + " @G--" GTEST_FLAG_PREFIX_ + "catch_exceptions=0@D\n" + " Do not report exceptions as test failures. Instead, allow them\n" + " to crash the program or throw a pop-up (on Windows).\n" + "\n" + "Except for @G--" GTEST_FLAG_PREFIX_ + "list_tests@D, you can alternatively set " + "the corresponding\n" + "environment variable of a flag (all letters in upper-case). For example, " + "to\n" + "disable colored text output, you can either specify " + "@G--" GTEST_FLAG_PREFIX_ + "color=no@D or set\n" + "the @G" GTEST_FLAG_PREFIX_UPPER_ + "COLOR@D environment variable to @Gno@D.\n" + "\n" + "For more information, please read the " GTEST_NAME_ + " documentation at\n" + "@G" GTEST_PROJECT_URL_ "@D. If you find a bug in " GTEST_NAME_ + "\n" + "(not one in your own code or tests), please report it to\n" + "@G<" GTEST_DEV_EMAIL_ ">@D.\n"; + +static bool ParseGoogleTestFlag(const char* const arg) { +#define GTEST_INTERNAL_PARSE_FLAG(flag_name) \ + do { \ + auto value = GTEST_FLAG_GET(flag_name); \ + if (ParseFlag(arg, #flag_name, &value)) { \ + GTEST_FLAG_SET(flag_name, value); \ + return true; \ + } \ + } while (false) + + GTEST_INTERNAL_PARSE_FLAG(also_run_disabled_tests); + GTEST_INTERNAL_PARSE_FLAG(break_on_failure); + GTEST_INTERNAL_PARSE_FLAG(catch_exceptions); + GTEST_INTERNAL_PARSE_FLAG(color); + GTEST_INTERNAL_PARSE_FLAG(death_test_style); + GTEST_INTERNAL_PARSE_FLAG(death_test_use_fork); + GTEST_INTERNAL_PARSE_FLAG(fail_fast); + GTEST_INTERNAL_PARSE_FLAG(filter); + GTEST_INTERNAL_PARSE_FLAG(internal_run_death_test); + GTEST_INTERNAL_PARSE_FLAG(list_tests); + GTEST_INTERNAL_PARSE_FLAG(output); + GTEST_INTERNAL_PARSE_FLAG(brief); + GTEST_INTERNAL_PARSE_FLAG(print_time); + GTEST_INTERNAL_PARSE_FLAG(print_utf8); + GTEST_INTERNAL_PARSE_FLAG(random_seed); + GTEST_INTERNAL_PARSE_FLAG(repeat); + GTEST_INTERNAL_PARSE_FLAG(recreate_environments_when_repeating); + GTEST_INTERNAL_PARSE_FLAG(shuffle); + GTEST_INTERNAL_PARSE_FLAG(stack_trace_depth); + GTEST_INTERNAL_PARSE_FLAG(stream_result_to); + GTEST_INTERNAL_PARSE_FLAG(throw_on_failure); + return false; +} + +#if GTEST_USE_OWN_FLAGFILE_FLAG_ +static void LoadFlagsFromFile(const std::string& path) { + FILE* flagfile = posix::FOpen(path.c_str(), "r"); + if (!flagfile) { + GTEST_LOG_(FATAL) << "Unable to open file \"" << GTEST_FLAG_GET(flagfile) + << "\""; + } + std::string contents(ReadEntireFile(flagfile)); + posix::FClose(flagfile); + std::vector lines; + SplitString(contents, '\n', &lines); + for (size_t i = 0; i < lines.size(); ++i) { + if (lines[i].empty()) + continue; + if (!ParseGoogleTestFlag(lines[i].c_str())) + g_help_flag = true; + } +} +#endif // GTEST_USE_OWN_FLAGFILE_FLAG_ + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. The type parameter CharType can be +// instantiated to either char or wchar_t. +template +void ParseGoogleTestFlagsOnlyImpl(int* argc, CharType** argv) { + std::string flagfile_value; + for (int i = 1; i < *argc; i++) { + const std::string arg_string = StreamableToString(argv[i]); + const char* const arg = arg_string.c_str(); + + using internal::ParseFlag; + + bool remove_flag = false; + if (ParseGoogleTestFlag(arg)) { + remove_flag = true; +#if GTEST_USE_OWN_FLAGFILE_FLAG_ + } else if (ParseFlag(arg, "flagfile", &flagfile_value)) { + GTEST_FLAG_SET(flagfile, flagfile_value); + LoadFlagsFromFile(flagfile_value); + remove_flag = true; +#endif // GTEST_USE_OWN_FLAGFILE_FLAG_ + } else if (arg_string == "--help" || arg_string == "-h" || + arg_string == "-?" || arg_string == "/?" || + HasGoogleTestFlagPrefix(arg)) { + // Both help flag and unrecognized Google Test flags (excluding + // internal ones) trigger help display. + g_help_flag = true; + } + + if (remove_flag) { + // Shift the remainder of the argv list left by one. Note + // that argv has (*argc + 1) elements, the last one always being + // NULL. The following loop moves the trailing NULL element as + // well. + for (int j = i; j != *argc; j++) { + argv[j] = argv[j + 1]; + } + + // Decrements the argument count. + (*argc)--; + + // We also need to decrement the iterator as we just removed + // an element. + i--; + } + } + + if (g_help_flag) { + // We print the help here instead of in RUN_ALL_TESTS(), as the + // latter may not be called at all if the user is using Google + // Test with another testing framework. + PrintColorEncoded(kColorEncodedHelpMessage); + } +} + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. +void ParseGoogleTestFlagsOnly(int* argc, char** argv) { + ParseGoogleTestFlagsOnlyImpl(argc, argv); + + // Fix the value of *_NSGetArgc() on macOS, but if and only if + // *_NSGetArgv() == argv + // Only applicable to char** version of argv +#if GTEST_OS_MAC +#ifndef GTEST_OS_IOS + if (*_NSGetArgv() == argv) { + *_NSGetArgc() = *argc; + } +#endif +#endif +} +void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv) { + ParseGoogleTestFlagsOnlyImpl(argc, argv); +} + +// The internal implementation of InitGoogleTest(). +// +// The type parameter CharType can be instantiated to either char or +// wchar_t. +template +void InitGoogleTestImpl(int* argc, CharType** argv) { + // We don't want to run the initialization code twice. + if (GTestIsInitialized()) return; + + if (*argc <= 0) return; + + g_argvs.clear(); + for (int i = 0; i != *argc; i++) { + g_argvs.push_back(StreamableToString(argv[i])); + } + +#if GTEST_HAS_ABSL + absl::InitializeSymbolizer(g_argvs[0].c_str()); +#endif // GTEST_HAS_ABSL + + ParseGoogleTestFlagsOnly(argc, argv); + GetUnitTestImpl()->PostFlagParsingInit(); +} + +} // namespace internal + +// Initializes Google Test. This must be called before calling +// RUN_ALL_TESTS(). In particular, it parses a command line for the +// flags that Google Test recognizes. Whenever a Google Test flag is +// seen, it is removed from argv, and *argc is decremented. +// +// No value is returned. Instead, the Google Test flag variables are +// updated. +// +// Calling the function for the second time has no user-visible effect. +void InitGoogleTest(int* argc, char** argv) { +#if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(argc, argv); +#else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + internal::InitGoogleTestImpl(argc, argv); +#endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) +} + +// This overloaded version can be used in Windows programs compiled in +// UNICODE mode. +void InitGoogleTest(int* argc, wchar_t** argv) { +#if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(argc, argv); +#else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + internal::InitGoogleTestImpl(argc, argv); +#endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) +} + +// This overloaded version can be used on Arduino/embedded platforms where +// there is no argc/argv. +void InitGoogleTest() { + // Since Arduino doesn't have a command line, fake out the argc/argv arguments + int argc = 1; + const auto arg0 = "dummy"; + char* argv0 = const_cast(arg0); + char** argv = &argv0; + +#if defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_(&argc, argv); +#else // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) + internal::InitGoogleTestImpl(&argc, argv); +#endif // defined(GTEST_CUSTOM_INIT_GOOGLE_TEST_FUNCTION_) +} + +std::string TempDir() { +#if defined(GTEST_CUSTOM_TEMPDIR_FUNCTION_) + return GTEST_CUSTOM_TEMPDIR_FUNCTION_(); +#elif GTEST_OS_WINDOWS_MOBILE + return "\\temp\\"; +#elif GTEST_OS_WINDOWS + const char* temp_dir = internal::posix::GetEnv("TEMP"); + if (temp_dir == nullptr || temp_dir[0] == '\0') { + return "\\temp\\"; + } else if (temp_dir[strlen(temp_dir) - 1] == '\\') { + return temp_dir; + } else { + return std::string(temp_dir) + "\\"; + } +#elif GTEST_OS_LINUX_ANDROID + const char* temp_dir = internal::posix::GetEnv("TEST_TMPDIR"); + if (temp_dir == nullptr || temp_dir[0] == '\0') { + return "/data/local/tmp/"; + } else { + return temp_dir; + } +#elif GTEST_OS_LINUX + const char* temp_dir = internal::posix::GetEnv("TEST_TMPDIR"); + if (temp_dir == nullptr || temp_dir[0] == '\0') { + return "/tmp/"; + } else { + return temp_dir; + } +#else + return "/tmp/"; +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Class ScopedTrace + +// Pushes the given source file location and message onto a per-thread +// trace stack maintained by Google Test. +void ScopedTrace::PushTrace(const char* file, int line, std::string message) { + internal::TraceInfo trace; + trace.file = file; + trace.line = line; + trace.message.swap(message); + + UnitTest::GetInstance()->PushGTestTrace(trace); +} + +// Pops the info pushed by the c'tor. +ScopedTrace::~ScopedTrace() + GTEST_LOCK_EXCLUDED_(&UnitTest::mutex_) { + UnitTest::GetInstance()->PopGTestTrace(); +} + +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest_main.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest_main.cc new file mode 100644 index 000000000000..46b27c3d7d56 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/src/gtest_main.cc @@ -0,0 +1,54 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include "gtest/gtest.h" + +#if GTEST_OS_ESP8266 || GTEST_OS_ESP32 +#if GTEST_OS_ESP8266 +extern "C" { +#endif +void setup() { + testing::InitGoogleTest(); +} + +void loop() { RUN_ALL_TESTS(); } + +#if GTEST_OS_ESP8266 +} +#endif + +#else + +GTEST_API_ int main(int argc, char **argv) { + printf("Running main() from %s\n", __FILE__); + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +#endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/BUILD.bazel b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/BUILD.bazel new file mode 100644 index 000000000000..8fd595c705bb --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/BUILD.bazel @@ -0,0 +1,590 @@ +# Copyright 2017 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Bazel BUILD for The Google C++ Testing Framework (Google Test) + +load("@rules_python//python:defs.bzl", "py_library", "py_test") + +licenses(["notice"]) + +package(default_visibility = ["//:__subpackages__"]) + +#on windows exclude gtest-tuple.h +cc_test( + name = "gtest_all_test", + size = "small", + srcs = glob( + include = [ + "gtest-*.cc", + "googletest-*.cc", + "*.h", + "googletest/include/gtest/**/*.h", + ], + exclude = [ + "gtest-unittest-api_test.cc", + "googletest/src/gtest-all.cc", + "gtest_all_test.cc", + "gtest-death-test_ex_test.cc", + "gtest-listener_test.cc", + "gtest-unittest-api_test.cc", + "googletest-param-test-test.cc", + "googletest-param-test2-test.cc", + "googletest-catch-exceptions-test_.cc", + "googletest-color-test_.cc", + "googletest-env-var-test_.cc", + "googletest-failfast-unittest_.cc", + "googletest-filter-unittest_.cc", + "googletest-global-environment-unittest_.cc", + "googletest-break-on-failure-unittest_.cc", + "googletest-listener-test.cc", + "googletest-output-test_.cc", + "googletest-list-tests-unittest_.cc", + "googletest-shuffle-test_.cc", + "googletest-setuptestsuite-test_.cc", + "googletest-uninitialized-test_.cc", + "googletest-death-test_ex_test.cc", + "googletest-param-test-test", + "googletest-throw-on-failure-test_.cc", + "googletest-param-test-invalid-name1-test_.cc", + "googletest-param-test-invalid-name2-test_.cc", + ], + ) + select({ + "//:windows": [], + "//conditions:default": [], + }), + copts = select({ + "//:windows": ["-DGTEST_USE_OWN_TR1_TUPLE=0"], + "//conditions:default": ["-DGTEST_USE_OWN_TR1_TUPLE=1"], + }) + select({ + # Ensure MSVC treats source files as UTF-8 encoded. + "//:msvc_compiler": ["-utf-8"], + "//conditions:default": [], + }), + includes = [ + "googletest", + "googletest/include", + "googletest/include/internal", + "googletest/test", + ], + linkopts = select({ + "//:qnx": [], + "//:windows": [], + "//conditions:default": ["-pthread"], + }), + deps = ["//:gtest_main"], +) + +# Tests death tests. +cc_test( + name = "googletest-death-test-test", + size = "medium", + srcs = ["googletest-death-test-test.cc"], + deps = ["//:gtest_main"], +) + +cc_test( + name = "gtest_test_macro_stack_footprint_test", + size = "small", + srcs = ["gtest_test_macro_stack_footprint_test.cc"], + deps = ["//:gtest"], +) + +#These googletest tests have their own main() +cc_test( + name = "googletest-listener-test", + size = "small", + srcs = ["googletest-listener-test.cc"], + deps = ["//:gtest_main"], +) + +cc_test( + name = "gtest-unittest-api_test", + size = "small", + srcs = [ + "gtest-unittest-api_test.cc", + ], + deps = [ + "//:gtest", + ], +) + +cc_test( + name = "googletest-param-test-test", + size = "small", + srcs = [ + "googletest-param-test-test.cc", + "googletest-param-test-test.h", + "googletest-param-test2-test.cc", + ], + deps = ["//:gtest"], +) + +cc_test( + name = "gtest_unittest", + size = "small", + srcs = ["gtest_unittest.cc"], + shard_count = 2, + deps = ["//:gtest_main"], +) + +# Py tests + +py_library( + name = "gtest_test_utils", + testonly = 1, + srcs = ["gtest_test_utils.py"], + imports = ["."], +) + +cc_binary( + name = "gtest_help_test_", + testonly = 1, + srcs = ["gtest_help_test_.cc"], + deps = ["//:gtest_main"], +) + +py_test( + name = "gtest_help_test", + size = "small", + srcs = ["gtest_help_test.py"], + data = [":gtest_help_test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-output-test_", + testonly = 1, + srcs = ["googletest-output-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-output-test", + size = "small", + srcs = ["googletest-output-test.py"], + args = select({ + "//:has_absl": [], + "//conditions:default": ["--no_stacktrace_support"], + }), + data = [ + "googletest-output-test-golden-lin.txt", + ":googletest-output-test_", + ], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-color-test_", + testonly = 1, + srcs = ["googletest-color-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-color-test", + size = "small", + srcs = ["googletest-color-test.py"], + data = [":googletest-color-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-env-var-test_", + testonly = 1, + srcs = ["googletest-env-var-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-env-var-test", + size = "medium", + srcs = ["googletest-env-var-test.py"], + data = [":googletest-env-var-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-failfast-unittest_", + testonly = 1, + srcs = ["googletest-failfast-unittest_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-failfast-unittest", + size = "medium", + srcs = ["googletest-failfast-unittest.py"], + data = [":googletest-failfast-unittest_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-filter-unittest_", + testonly = 1, + srcs = ["googletest-filter-unittest_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-filter-unittest", + size = "medium", + srcs = ["googletest-filter-unittest.py"], + data = [":googletest-filter-unittest_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-global-environment-unittest_", + testonly = 1, + srcs = ["googletest-global-environment-unittest_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-global-environment-unittest", + size = "medium", + srcs = ["googletest-global-environment-unittest.py"], + data = [":googletest-global-environment-unittest_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-break-on-failure-unittest_", + testonly = 1, + srcs = ["googletest-break-on-failure-unittest_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-break-on-failure-unittest", + size = "small", + srcs = ["googletest-break-on-failure-unittest.py"], + data = [":googletest-break-on-failure-unittest_"], + deps = [":gtest_test_utils"], +) + +cc_test( + name = "gtest_assert_by_exception_test", + size = "small", + srcs = ["gtest_assert_by_exception_test.cc"], + deps = ["//:gtest"], +) + +cc_binary( + name = "googletest-throw-on-failure-test_", + testonly = 1, + srcs = ["googletest-throw-on-failure-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-throw-on-failure-test", + size = "small", + srcs = ["googletest-throw-on-failure-test.py"], + data = [":googletest-throw-on-failure-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-list-tests-unittest_", + testonly = 1, + srcs = ["googletest-list-tests-unittest_.cc"], + deps = ["//:gtest"], +) + +cc_test( + name = "gtest_skip_test", + size = "small", + srcs = ["gtest_skip_test.cc"], + deps = ["//:gtest_main"], +) + +cc_test( + name = "gtest_skip_in_environment_setup_test", + size = "small", + srcs = ["gtest_skip_in_environment_setup_test.cc"], + deps = ["//:gtest_main"], +) + +py_test( + name = "gtest_skip_check_output_test", + size = "small", + srcs = ["gtest_skip_check_output_test.py"], + data = [":gtest_skip_test"], + deps = [":gtest_test_utils"], +) + +py_test( + name = "gtest_skip_environment_check_output_test", + size = "small", + srcs = ["gtest_skip_environment_check_output_test.py"], + data = [ + ":gtest_skip_in_environment_setup_test", + ], + deps = [":gtest_test_utils"], +) + +py_test( + name = "googletest-list-tests-unittest", + size = "small", + srcs = ["googletest-list-tests-unittest.py"], + data = [":googletest-list-tests-unittest_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-shuffle-test_", + srcs = ["googletest-shuffle-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-shuffle-test", + size = "small", + srcs = ["googletest-shuffle-test.py"], + data = [":googletest-shuffle-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-catch-exceptions-no-ex-test_", + testonly = 1, + srcs = ["googletest-catch-exceptions-test_.cc"], + deps = ["//:gtest_main"], +) + +cc_binary( + name = "googletest-catch-exceptions-ex-test_", + testonly = 1, + srcs = ["googletest-catch-exceptions-test_.cc"], + copts = ["-fexceptions"], + deps = ["//:gtest_main"], +) + +py_test( + name = "googletest-catch-exceptions-test", + size = "small", + srcs = ["googletest-catch-exceptions-test.py"], + data = [ + ":googletest-catch-exceptions-ex-test_", + ":googletest-catch-exceptions-no-ex-test_", + ], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "gtest_xml_output_unittest_", + testonly = 1, + srcs = ["gtest_xml_output_unittest_.cc"], + deps = ["//:gtest"], +) + +cc_test( + name = "gtest_no_test_unittest", + size = "small", + srcs = ["gtest_no_test_unittest.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "gtest_xml_output_unittest", + size = "small", + srcs = [ + "gtest_xml_output_unittest.py", + "gtest_xml_test_utils.py", + ], + args = select({ + "//:has_absl": [], + "//conditions:default": ["--no_stacktrace_support"], + }), + data = [ + # We invoke gtest_no_test_unittest to verify the XML output + # when the test program contains no test definition. + ":gtest_no_test_unittest", + ":gtest_xml_output_unittest_", + ], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "gtest_xml_outfile1_test_", + testonly = 1, + srcs = ["gtest_xml_outfile1_test_.cc"], + deps = ["//:gtest_main"], +) + +cc_binary( + name = "gtest_xml_outfile2_test_", + testonly = 1, + srcs = ["gtest_xml_outfile2_test_.cc"], + deps = ["//:gtest_main"], +) + +py_test( + name = "gtest_xml_outfiles_test", + size = "small", + srcs = [ + "gtest_xml_outfiles_test.py", + "gtest_xml_test_utils.py", + ], + data = [ + ":gtest_xml_outfile1_test_", + ":gtest_xml_outfile2_test_", + ], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-setuptestsuite-test_", + testonly = 1, + srcs = ["googletest-setuptestsuite-test_.cc"], + deps = ["//:gtest_main"], +) + +py_test( + name = "googletest-setuptestsuite-test", + size = "medium", + srcs = ["googletest-setuptestsuite-test.py"], + data = [":googletest-setuptestsuite-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "googletest-uninitialized-test_", + testonly = 1, + srcs = ["googletest-uninitialized-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-uninitialized-test", + size = "medium", + srcs = ["googletest-uninitialized-test.py"], + data = ["googletest-uninitialized-test_"], + deps = [":gtest_test_utils"], +) + +cc_binary( + name = "gtest_testbridge_test_", + testonly = 1, + srcs = ["gtest_testbridge_test_.cc"], + deps = ["//:gtest_main"], +) + +# Tests that filtering via testbridge works +py_test( + name = "gtest_testbridge_test", + size = "small", + srcs = ["gtest_testbridge_test.py"], + data = [":gtest_testbridge_test_"], + deps = [":gtest_test_utils"], +) + +py_test( + name = "googletest-json-outfiles-test", + size = "small", + srcs = [ + "googletest-json-outfiles-test.py", + "gtest_json_test_utils.py", + ], + data = [ + ":gtest_xml_outfile1_test_", + ":gtest_xml_outfile2_test_", + ], + deps = [":gtest_test_utils"], +) + +py_test( + name = "googletest-json-output-unittest", + size = "medium", + srcs = [ + "googletest-json-output-unittest.py", + "gtest_json_test_utils.py", + ], + args = select({ + "//:has_absl": [], + "//conditions:default": ["--no_stacktrace_support"], + }), + data = [ + # We invoke gtest_no_test_unittest to verify the JSON output + # when the test program contains no test definition. + ":gtest_no_test_unittest", + ":gtest_xml_output_unittest_", + ], + deps = [":gtest_test_utils"], +) + +# Verifies interaction of death tests and exceptions. +cc_test( + name = "googletest-death-test_ex_catch_test", + size = "medium", + srcs = ["googletest-death-test_ex_test.cc"], + copts = ["-fexceptions"], + defines = ["GTEST_ENABLE_CATCH_EXCEPTIONS_=1"], + deps = ["//:gtest"], +) + +cc_binary( + name = "googletest-param-test-invalid-name1-test_", + testonly = 1, + srcs = ["googletest-param-test-invalid-name1-test_.cc"], + deps = ["//:gtest"], +) + +cc_binary( + name = "googletest-param-test-invalid-name2-test_", + testonly = 1, + srcs = ["googletest-param-test-invalid-name2-test_.cc"], + deps = ["//:gtest"], +) + +py_test( + name = "googletest-param-test-invalid-name1-test", + size = "small", + srcs = ["googletest-param-test-invalid-name1-test.py"], + data = [":googletest-param-test-invalid-name1-test_"], + tags = [ + "no_test_msvc2015", + "no_test_msvc2017", + ], + deps = [":gtest_test_utils"], +) + +py_test( + name = "googletest-param-test-invalid-name2-test", + size = "small", + srcs = ["googletest-param-test-invalid-name2-test.py"], + data = [":googletest-param-test-invalid-name2-test_"], + tags = [ + "no_test_msvc2015", + "no_test_msvc2017", + ], + deps = [":gtest_test_utils"], +) diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest.py new file mode 100755 index 000000000000..4eafba3e6bb4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python +# +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for Google Test's break-on-failure mode. + +A user can ask Google Test to seg-fault when an assertion fails, using +either the GTEST_BREAK_ON_FAILURE environment variable or the +--gtest_break_on_failure flag. This script tests such functionality +by invoking googletest-break-on-failure-unittest_ (a program written with +Google Test) with different environments and command line flags. +""" + +import os +from googletest.test import gtest_test_utils + +# Constants. + +IS_WINDOWS = os.name == 'nt' + +# The environment variable for enabling/disabling the break-on-failure mode. +BREAK_ON_FAILURE_ENV_VAR = 'GTEST_BREAK_ON_FAILURE' + +# The command line flag for enabling/disabling the break-on-failure mode. +BREAK_ON_FAILURE_FLAG = 'gtest_break_on_failure' + +# The environment variable for enabling/disabling the throw-on-failure mode. +THROW_ON_FAILURE_ENV_VAR = 'GTEST_THROW_ON_FAILURE' + +# The environment variable for enabling/disabling the catch-exceptions mode. +CATCH_EXCEPTIONS_ENV_VAR = 'GTEST_CATCH_EXCEPTIONS' + +# Path to the googletest-break-on-failure-unittest_ program. +EXE_PATH = gtest_test_utils.GetTestExecutablePath( + 'googletest-break-on-failure-unittest_') + + +environ = gtest_test_utils.environ +SetEnvVar = gtest_test_utils.SetEnvVar + +# Tests in this file run a Google-Test-based test program and expect it +# to terminate prematurely. Therefore they are incompatible with +# the premature-exit-file protocol by design. Unset the +# premature-exit filepath to prevent Google Test from creating +# the file. +SetEnvVar(gtest_test_utils.PREMATURE_EXIT_FILE_ENV_VAR, None) + + +def Run(command): + """Runs a command; returns 1 if it was killed by a signal, or 0 otherwise.""" + + p = gtest_test_utils.Subprocess(command, env=environ) + if p.terminated_by_signal: + return 1 + else: + return 0 + + +# The tests. + + +class GTestBreakOnFailureUnitTest(gtest_test_utils.TestCase): + """Tests using the GTEST_BREAK_ON_FAILURE environment variable or + the --gtest_break_on_failure flag to turn assertion failures into + segmentation faults. + """ + + def RunAndVerify(self, env_var_value, flag_value, expect_seg_fault): + """Runs googletest-break-on-failure-unittest_ and verifies that it does + (or does not) have a seg-fault. + + Args: + env_var_value: value of the GTEST_BREAK_ON_FAILURE environment + variable; None if the variable should be unset. + flag_value: value of the --gtest_break_on_failure flag; + None if the flag should not be present. + expect_seg_fault: 1 if the program is expected to generate a seg-fault; + 0 otherwise. + """ + + SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, env_var_value) + + if env_var_value is None: + env_var_value_msg = ' is not set' + else: + env_var_value_msg = '=' + env_var_value + + if flag_value is None: + flag = '' + elif flag_value == '0': + flag = '--%s=0' % BREAK_ON_FAILURE_FLAG + else: + flag = '--%s' % BREAK_ON_FAILURE_FLAG + + command = [EXE_PATH] + if flag: + command.append(flag) + + if expect_seg_fault: + should_or_not = 'should' + else: + should_or_not = 'should not' + + has_seg_fault = Run(command) + + SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, None) + + msg = ('when %s%s, an assertion failure in "%s" %s cause a seg-fault.' % + (BREAK_ON_FAILURE_ENV_VAR, env_var_value_msg, ' '.join(command), + should_or_not)) + self.assert_(has_seg_fault == expect_seg_fault, msg) + + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" + + self.RunAndVerify(env_var_value=None, + flag_value=None, + expect_seg_fault=0) + + def testEnvVar(self): + """Tests using the GTEST_BREAK_ON_FAILURE environment variable.""" + + self.RunAndVerify(env_var_value='0', + flag_value=None, + expect_seg_fault=0) + self.RunAndVerify(env_var_value='1', + flag_value=None, + expect_seg_fault=1) + + def testFlag(self): + """Tests using the --gtest_break_on_failure flag.""" + + self.RunAndVerify(env_var_value=None, + flag_value='0', + expect_seg_fault=0) + self.RunAndVerify(env_var_value=None, + flag_value='1', + expect_seg_fault=1) + + def testFlagOverridesEnvVar(self): + """Tests that the flag overrides the environment variable.""" + + self.RunAndVerify(env_var_value='0', + flag_value='0', + expect_seg_fault=0) + self.RunAndVerify(env_var_value='0', + flag_value='1', + expect_seg_fault=1) + self.RunAndVerify(env_var_value='1', + flag_value='0', + expect_seg_fault=0) + self.RunAndVerify(env_var_value='1', + flag_value='1', + expect_seg_fault=1) + + def testBreakOnFailureOverridesThrowOnFailure(self): + """Tests that gtest_break_on_failure overrides gtest_throw_on_failure.""" + + SetEnvVar(THROW_ON_FAILURE_ENV_VAR, '1') + try: + self.RunAndVerify(env_var_value=None, + flag_value='1', + expect_seg_fault=1) + finally: + SetEnvVar(THROW_ON_FAILURE_ENV_VAR, None) + + if IS_WINDOWS: + def testCatchExceptionsDoesNotInterfere(self): + """Tests that gtest_catch_exceptions doesn't interfere.""" + + SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, '1') + try: + self.RunAndVerify(env_var_value='1', + flag_value='1', + expect_seg_fault=1) + finally: + SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, None) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest_.cc new file mode 100644 index 000000000000..f84957a2d03b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-break-on-failure-unittest_.cc @@ -0,0 +1,86 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Unit test for Google Test's break-on-failure mode. +// +// A user can ask Google Test to seg-fault when an assertion fails, using +// either the GTEST_BREAK_ON_FAILURE environment variable or the +// --gtest_break_on_failure flag. This file is used for testing such +// functionality. +// +// This program will be invoked from a Python unit test. It is +// expected to fail. Don't run it directly. + +#include "gtest/gtest.h" + +#if GTEST_OS_WINDOWS +# include +# include +#endif + +namespace { + +// A test that's expected to fail. +TEST(Foo, Bar) { + EXPECT_EQ(2, 3); +} + +#if GTEST_HAS_SEH && !GTEST_OS_WINDOWS_MOBILE +// On Windows Mobile global exception handlers are not supported. +LONG WINAPI ExitWithExceptionCode( + struct _EXCEPTION_POINTERS* exception_pointers) { + exit(exception_pointers->ExceptionRecord->ExceptionCode); +} +#endif + +} // namespace + +int main(int argc, char **argv) { +#if GTEST_OS_WINDOWS + // Suppresses display of the Windows error dialog upon encountering + // a general protection fault (segment violation). + SetErrorMode(SEM_NOGPFAULTERRORBOX | SEM_FAILCRITICALERRORS); + +# if GTEST_HAS_SEH && !GTEST_OS_WINDOWS_MOBILE + + // The default unhandled exception filter does not always exit + // with the exception code as exit code - for example it exits with + // 0 for EXCEPTION_ACCESS_VIOLATION and 1 for EXCEPTION_BREAKPOINT + // if the application is compiled in debug mode. Thus we use our own + // filter which always exits with the exception code for unhandled + // exceptions. + SetUnhandledExceptionFilter(ExitWithExceptionCode); + +# endif +#endif // GTEST_OS_WINDOWS + testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test.py new file mode 100755 index 000000000000..d38d91a62a43 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python +# +# Copyright 2010 Google Inc. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests Google Test's exception catching behavior. + +This script invokes googletest-catch-exceptions-test_ and +googletest-catch-exceptions-ex-test_ (programs written with +Google Test) and verifies their output. +""" + +from googletest.test import gtest_test_utils + +# Constants. +FLAG_PREFIX = '--gtest_' +LIST_TESTS_FLAG = FLAG_PREFIX + 'list_tests' +NO_CATCH_EXCEPTIONS_FLAG = FLAG_PREFIX + 'catch_exceptions=0' +FILTER_FLAG = FLAG_PREFIX + 'filter' + +# Path to the googletest-catch-exceptions-ex-test_ binary, compiled with +# exceptions enabled. +EX_EXE_PATH = gtest_test_utils.GetTestExecutablePath( + 'googletest-catch-exceptions-ex-test_') + +# Path to the googletest-catch-exceptions-test_ binary, compiled with +# exceptions disabled. +EXE_PATH = gtest_test_utils.GetTestExecutablePath( + 'googletest-catch-exceptions-no-ex-test_') + +environ = gtest_test_utils.environ +SetEnvVar = gtest_test_utils.SetEnvVar + +# Tests in this file run a Google-Test-based test program and expect it +# to terminate prematurely. Therefore they are incompatible with +# the premature-exit-file protocol by design. Unset the +# premature-exit filepath to prevent Google Test from creating +# the file. +SetEnvVar(gtest_test_utils.PREMATURE_EXIT_FILE_ENV_VAR, None) + +TEST_LIST = gtest_test_utils.Subprocess( + [EXE_PATH, LIST_TESTS_FLAG], env=environ).output + +SUPPORTS_SEH_EXCEPTIONS = 'ThrowsSehException' in TEST_LIST + +if SUPPORTS_SEH_EXCEPTIONS: + BINARY_OUTPUT = gtest_test_utils.Subprocess([EXE_PATH], env=environ).output + +EX_BINARY_OUTPUT = gtest_test_utils.Subprocess( + [EX_EXE_PATH], env=environ).output + + +# The tests. +if SUPPORTS_SEH_EXCEPTIONS: + # pylint:disable-msg=C6302 + class CatchSehExceptionsTest(gtest_test_utils.TestCase): + """Tests exception-catching behavior.""" + + + def TestSehExceptions(self, test_output): + self.assert_('SEH exception with code 0x2a thrown ' + 'in the test fixture\'s constructor' + in test_output) + self.assert_('SEH exception with code 0x2a thrown ' + 'in the test fixture\'s destructor' + in test_output) + self.assert_('SEH exception with code 0x2a thrown in SetUpTestSuite()' + in test_output) + self.assert_('SEH exception with code 0x2a thrown in TearDownTestSuite()' + in test_output) + self.assert_('SEH exception with code 0x2a thrown in SetUp()' + in test_output) + self.assert_('SEH exception with code 0x2a thrown in TearDown()' + in test_output) + self.assert_('SEH exception with code 0x2a thrown in the test body' + in test_output) + + def testCatchesSehExceptionsWithCxxExceptionsEnabled(self): + self.TestSehExceptions(EX_BINARY_OUTPUT) + + def testCatchesSehExceptionsWithCxxExceptionsDisabled(self): + self.TestSehExceptions(BINARY_OUTPUT) + + +class CatchCxxExceptionsTest(gtest_test_utils.TestCase): + """Tests C++ exception-catching behavior. + + Tests in this test case verify that: + * C++ exceptions are caught and logged as C++ (not SEH) exceptions + * Exception thrown affect the remainder of the test work flow in the + expected manner. + """ + + def testCatchesCxxExceptionsInFixtureConstructor(self): + self.assertTrue( + 'C++ exception with description ' + '"Standard C++ exception" thrown ' + 'in the test fixture\'s constructor' in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT) + self.assert_('unexpected' not in EX_BINARY_OUTPUT, + 'This failure belongs in this test only if ' + '"CxxExceptionInConstructorTest" (no quotes) ' + 'appears on the same line as words "called unexpectedly"') + + if ('CxxExceptionInDestructorTest.ThrowsExceptionInDestructor' in + EX_BINARY_OUTPUT): + + def testCatchesCxxExceptionsInFixtureDestructor(self): + self.assertTrue( + 'C++ exception with description ' + '"Standard C++ exception" thrown ' + 'in the test fixture\'s destructor' in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInDestructorTest::TearDownTestSuite() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + + def testCatchesCxxExceptionsInSetUpTestCase(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + ' thrown in SetUpTestSuite()' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInConstructorTest::TearDownTestSuite() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertFalse( + 'CxxExceptionInSetUpTestSuiteTest constructor ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertFalse( + 'CxxExceptionInSetUpTestSuiteTest destructor ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertFalse( + 'CxxExceptionInSetUpTestSuiteTest::SetUp() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertFalse( + 'CxxExceptionInSetUpTestSuiteTest::TearDown() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertFalse( + 'CxxExceptionInSetUpTestSuiteTest test body ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + + def testCatchesCxxExceptionsInTearDownTestCase(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + ' thrown in TearDownTestSuite()' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + + def testCatchesCxxExceptionsInSetUp(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + ' thrown in SetUp()' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInSetUpTest::TearDownTestSuite() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInSetUpTest destructor ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInSetUpTest::TearDown() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assert_('unexpected' not in EX_BINARY_OUTPUT, + 'This failure belongs in this test only if ' + '"CxxExceptionInSetUpTest" (no quotes) ' + 'appears on the same line as words "called unexpectedly"') + + def testCatchesCxxExceptionsInTearDown(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + ' thrown in TearDown()' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInTearDownTest::TearDownTestSuite() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInTearDownTest destructor ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + + def testCatchesCxxExceptionsInTestBody(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + ' thrown in the test body' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInTestBodyTest::TearDownTestSuite() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInTestBodyTest destructor ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + self.assertTrue( + 'CxxExceptionInTestBodyTest::TearDown() ' + 'called as expected.' in EX_BINARY_OUTPUT, EX_BINARY_OUTPUT) + + def testCatchesNonStdCxxExceptions(self): + self.assertTrue( + 'Unknown C++ exception thrown in the test body' in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT) + + def testUnhandledCxxExceptionsAbortTheProgram(self): + # Filters out SEH exception tests on Windows. Unhandled SEH exceptions + # cause tests to show pop-up windows there. + FITLER_OUT_SEH_TESTS_FLAG = FILTER_FLAG + '=-*Seh*' + # By default, Google Test doesn't catch the exceptions. + uncaught_exceptions_ex_binary_output = gtest_test_utils.Subprocess( + [EX_EXE_PATH, + NO_CATCH_EXCEPTIONS_FLAG, + FITLER_OUT_SEH_TESTS_FLAG], + env=environ).output + + self.assert_('Unhandled C++ exception terminating the program' + in uncaught_exceptions_ex_binary_output) + self.assert_('unexpected' not in uncaught_exceptions_ex_binary_output) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test_.cc new file mode 100644 index 000000000000..8c127d40b11d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-catch-exceptions-test_.cc @@ -0,0 +1,293 @@ +// Copyright 2010, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for Google Test itself. Tests in this file throw C++ or SEH +// exceptions, and the output is verified by +// googletest-catch-exceptions-test.py. + +#include // NOLINT +#include // For exit(). + +#include "gtest/gtest.h" + +#if GTEST_HAS_SEH +# include +#endif + +#if GTEST_HAS_EXCEPTIONS +# include // For set_terminate(). +# include +#endif + +using testing::Test; + +#if GTEST_HAS_SEH + +class SehExceptionInConstructorTest : public Test { + public: + SehExceptionInConstructorTest() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInConstructorTest, ThrowsExceptionInConstructor) {} + +class SehExceptionInDestructorTest : public Test { + public: + ~SehExceptionInDestructorTest() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInDestructorTest, ThrowsExceptionInDestructor) {} + +class SehExceptionInSetUpTestSuiteTest : public Test { + public: + static void SetUpTestSuite() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInSetUpTestSuiteTest, ThrowsExceptionInSetUpTestSuite) {} + +class SehExceptionInTearDownTestSuiteTest : public Test { + public: + static void TearDownTestSuite() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInTearDownTestSuiteTest, + ThrowsExceptionInTearDownTestSuite) {} + +class SehExceptionInSetUpTest : public Test { + protected: + virtual void SetUp() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInSetUpTest, ThrowsExceptionInSetUp) {} + +class SehExceptionInTearDownTest : public Test { + protected: + virtual void TearDown() { RaiseException(42, 0, 0, NULL); } +}; + +TEST_F(SehExceptionInTearDownTest, ThrowsExceptionInTearDown) {} + +TEST(SehExceptionTest, ThrowsSehException) { + RaiseException(42, 0, 0, NULL); +} + +#endif // GTEST_HAS_SEH + +#if GTEST_HAS_EXCEPTIONS + +class CxxExceptionInConstructorTest : public Test { + public: + CxxExceptionInConstructorTest() { + // Without this macro VC++ complains about unreachable code at the end of + // the constructor. + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_( + throw std::runtime_error("Standard C++ exception")); + } + + static void TearDownTestSuite() { + printf("%s", + "CxxExceptionInConstructorTest::TearDownTestSuite() " + "called as expected.\n"); + } + + protected: + ~CxxExceptionInConstructorTest() override { + ADD_FAILURE() << "CxxExceptionInConstructorTest destructor " + << "called unexpectedly."; + } + + void SetUp() override { + ADD_FAILURE() << "CxxExceptionInConstructorTest::SetUp() " + << "called unexpectedly."; + } + + void TearDown() override { + ADD_FAILURE() << "CxxExceptionInConstructorTest::TearDown() " + << "called unexpectedly."; + } +}; + +TEST_F(CxxExceptionInConstructorTest, ThrowsExceptionInConstructor) { + ADD_FAILURE() << "CxxExceptionInConstructorTest test body " + << "called unexpectedly."; +} + +class CxxExceptionInSetUpTestSuiteTest : public Test { + public: + CxxExceptionInSetUpTestSuiteTest() { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest constructor " + "called as expected.\n"); + } + + static void SetUpTestSuite() { + throw std::runtime_error("Standard C++ exception"); + } + + static void TearDownTestSuite() { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest::TearDownTestSuite() " + "called as expected.\n"); + } + + protected: + ~CxxExceptionInSetUpTestSuiteTest() override { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest destructor " + "called as expected.\n"); + } + + void SetUp() override { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest::SetUp() " + "called as expected.\n"); + } + + void TearDown() override { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest::TearDown() " + "called as expected.\n"); + } +}; + +TEST_F(CxxExceptionInSetUpTestSuiteTest, ThrowsExceptionInSetUpTestSuite) { + printf("%s", + "CxxExceptionInSetUpTestSuiteTest test body " + "called as expected.\n"); +} + +class CxxExceptionInTearDownTestSuiteTest : public Test { + public: + static void TearDownTestSuite() { + throw std::runtime_error("Standard C++ exception"); + } +}; + +TEST_F(CxxExceptionInTearDownTestSuiteTest, + ThrowsExceptionInTearDownTestSuite) {} + +class CxxExceptionInSetUpTest : public Test { + public: + static void TearDownTestSuite() { + printf("%s", + "CxxExceptionInSetUpTest::TearDownTestSuite() " + "called as expected.\n"); + } + + protected: + ~CxxExceptionInSetUpTest() override { + printf("%s", + "CxxExceptionInSetUpTest destructor " + "called as expected.\n"); + } + + void SetUp() override { throw std::runtime_error("Standard C++ exception"); } + + void TearDown() override { + printf("%s", + "CxxExceptionInSetUpTest::TearDown() " + "called as expected.\n"); + } +}; + +TEST_F(CxxExceptionInSetUpTest, ThrowsExceptionInSetUp) { + ADD_FAILURE() << "CxxExceptionInSetUpTest test body " + << "called unexpectedly."; +} + +class CxxExceptionInTearDownTest : public Test { + public: + static void TearDownTestSuite() { + printf("%s", + "CxxExceptionInTearDownTest::TearDownTestSuite() " + "called as expected.\n"); + } + + protected: + ~CxxExceptionInTearDownTest() override { + printf("%s", + "CxxExceptionInTearDownTest destructor " + "called as expected.\n"); + } + + void TearDown() override { + throw std::runtime_error("Standard C++ exception"); + } +}; + +TEST_F(CxxExceptionInTearDownTest, ThrowsExceptionInTearDown) {} + +class CxxExceptionInTestBodyTest : public Test { + public: + static void TearDownTestSuite() { + printf("%s", + "CxxExceptionInTestBodyTest::TearDownTestSuite() " + "called as expected.\n"); + } + + protected: + ~CxxExceptionInTestBodyTest() override { + printf("%s", + "CxxExceptionInTestBodyTest destructor " + "called as expected.\n"); + } + + void TearDown() override { + printf("%s", + "CxxExceptionInTestBodyTest::TearDown() " + "called as expected.\n"); + } +}; + +TEST_F(CxxExceptionInTestBodyTest, ThrowsStdCxxException) { + throw std::runtime_error("Standard C++ exception"); +} + +TEST(CxxExceptionTest, ThrowsNonStdCxxException) { + throw "C-string"; +} + +// This terminate handler aborts the program using exit() rather than abort(). +// This avoids showing pop-ups on Windows systems and core dumps on Unix-like +// ones. +void TerminateHandler() { + fprintf(stderr, "%s\n", "Unhandled C++ exception terminating the program."); + fflush(nullptr); + exit(3); +} + +#endif // GTEST_HAS_EXCEPTIONS + +int main(int argc, char** argv) { +#if GTEST_HAS_EXCEPTIONS + std::set_terminate(&TerminateHandler); +#endif + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test.py new file mode 100755 index 000000000000..c22752db82a3 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# +# Copyright 2008, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that Google Test correctly determines whether to use colors.""" + +import os +from googletest.test import gtest_test_utils + +IS_WINDOWS = os.name == 'nt' + +COLOR_ENV_VAR = 'GTEST_COLOR' +COLOR_FLAG = 'gtest_color' +COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-color-test_') + + +def SetEnvVar(env_var, value): + """Sets the env variable to 'value'; unsets it when 'value' is None.""" + + if value is not None: + os.environ[env_var] = value + elif env_var in os.environ: + del os.environ[env_var] + + +def UsesColor(term, color_env_var, color_flag): + """Runs googletest-color-test_ and returns its exit code.""" + + SetEnvVar('TERM', term) + SetEnvVar(COLOR_ENV_VAR, color_env_var) + + if color_flag is None: + args = [] + else: + args = ['--%s=%s' % (COLOR_FLAG, color_flag)] + p = gtest_test_utils.Subprocess([COMMAND] + args) + return not p.exited or p.exit_code + + +class GTestColorTest(gtest_test_utils.TestCase): + def testNoEnvVarNoFlag(self): + """Tests the case when there's neither GTEST_COLOR nor --gtest_color.""" + + if not IS_WINDOWS: + self.assert_(not UsesColor('dumb', None, None)) + self.assert_(not UsesColor('emacs', None, None)) + self.assert_(not UsesColor('xterm-mono', None, None)) + self.assert_(not UsesColor('unknown', None, None)) + self.assert_(not UsesColor(None, None, None)) + self.assert_(UsesColor('linux', None, None)) + self.assert_(UsesColor('cygwin', None, None)) + self.assert_(UsesColor('xterm', None, None)) + self.assert_(UsesColor('xterm-color', None, None)) + self.assert_(UsesColor('xterm-256color', None, None)) + + def testFlagOnly(self): + """Tests the case when there's --gtest_color but not GTEST_COLOR.""" + + self.assert_(not UsesColor('dumb', None, 'no')) + self.assert_(not UsesColor('xterm-color', None, 'no')) + if not IS_WINDOWS: + self.assert_(not UsesColor('emacs', None, 'auto')) + self.assert_(UsesColor('xterm', None, 'auto')) + self.assert_(UsesColor('dumb', None, 'yes')) + self.assert_(UsesColor('xterm', None, 'yes')) + + def testEnvVarOnly(self): + """Tests the case when there's GTEST_COLOR but not --gtest_color.""" + + self.assert_(not UsesColor('dumb', 'no', None)) + self.assert_(not UsesColor('xterm-color', 'no', None)) + if not IS_WINDOWS: + self.assert_(not UsesColor('dumb', 'auto', None)) + self.assert_(UsesColor('xterm-color', 'auto', None)) + self.assert_(UsesColor('dumb', 'yes', None)) + self.assert_(UsesColor('xterm-color', 'yes', None)) + + def testEnvVarAndFlag(self): + """Tests the case when there are both GTEST_COLOR and --gtest_color.""" + + self.assert_(not UsesColor('xterm-color', 'no', 'no')) + self.assert_(UsesColor('dumb', 'no', 'yes')) + self.assert_(UsesColor('xterm-color', 'no', 'auto')) + + def testAliasesOfYesAndNo(self): + """Tests using aliases in specifying --gtest_color.""" + + self.assert_(UsesColor('dumb', None, 'true')) + self.assert_(UsesColor('dumb', None, 'YES')) + self.assert_(UsesColor('dumb', None, 'T')) + self.assert_(UsesColor('dumb', None, '1')) + + self.assert_(not UsesColor('xterm', None, 'f')) + self.assert_(not UsesColor('xterm', None, 'false')) + self.assert_(not UsesColor('xterm', None, '0')) + self.assert_(not UsesColor('xterm', None, 'unknown')) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test_.cc new file mode 100644 index 000000000000..220a3a00548f --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-color-test_.cc @@ -0,0 +1,62 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// A helper program for testing how Google Test determines whether to use +// colors in the output. It prints "YES" and returns 1 if Google Test +// decides to use colors, and prints "NO" and returns 0 otherwise. + +#include + +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +using testing::internal::ShouldUseColor; + +// The purpose of this is to ensure that the UnitTest singleton is +// created before main() is entered, and thus that ShouldUseColor() +// works the same way as in a real Google-Test-based test. We don't actual +// run the TEST itself. +TEST(GTestColorTest, Dummy) { +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + if (ShouldUseColor(true)) { + // Google Test decides to use colors in the output (assuming it + // goes to a TTY). + printf("YES\n"); + return 1; + } else { + // Google Test decides not to use colors in the output. + printf("NO\n"); + return 0; + } +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test-test.cc new file mode 100644 index 000000000000..62a84b478a46 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test-test.cc @@ -0,0 +1,1528 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for death tests. + +#include "gtest/gtest-death-test.h" + +#include "gtest/gtest.h" +#include "gtest/internal/gtest-filepath.h" + +using testing::internal::AlwaysFalse; +using testing::internal::AlwaysTrue; + +#if GTEST_HAS_DEATH_TEST + +# if GTEST_OS_WINDOWS +# include // For O_BINARY +# include // For chdir(). +# include +# else +# include +# include // For waitpid. +# endif // GTEST_OS_WINDOWS + +# include +# include +# include + +# if GTEST_OS_LINUX +# include +# endif // GTEST_OS_LINUX + +# include "gtest/gtest-spi.h" +# include "src/gtest-internal-inl.h" + +namespace posix = ::testing::internal::posix; + +using testing::ContainsRegex; +using testing::Matcher; +using testing::Message; +using testing::internal::DeathTest; +using testing::internal::DeathTestFactory; +using testing::internal::FilePath; +using testing::internal::GetLastErrnoDescription; +using testing::internal::GetUnitTestImpl; +using testing::internal::InDeathTestChild; +using testing::internal::ParseNaturalNumber; + +namespace testing { +namespace internal { + +// A helper class whose objects replace the death test factory for a +// single UnitTest object during their lifetimes. +class ReplaceDeathTestFactory { + public: + explicit ReplaceDeathTestFactory(DeathTestFactory* new_factory) + : unit_test_impl_(GetUnitTestImpl()) { + old_factory_ = unit_test_impl_->death_test_factory_.release(); + unit_test_impl_->death_test_factory_.reset(new_factory); + } + + ~ReplaceDeathTestFactory() { + unit_test_impl_->death_test_factory_.release(); + unit_test_impl_->death_test_factory_.reset(old_factory_); + } + private: + // Prevents copying ReplaceDeathTestFactory objects. + ReplaceDeathTestFactory(const ReplaceDeathTestFactory&); + void operator=(const ReplaceDeathTestFactory&); + + UnitTestImpl* unit_test_impl_; + DeathTestFactory* old_factory_; +}; + +} // namespace internal +} // namespace testing + +namespace { + +void DieWithMessage(const ::std::string& message) { + fprintf(stderr, "%s", message.c_str()); + fflush(stderr); // Make sure the text is printed before the process exits. + + // We call _exit() instead of exit(), as the former is a direct + // system call and thus safer in the presence of threads. exit() + // will invoke user-defined exit-hooks, which may do dangerous + // things that conflict with death tests. + // + // Some compilers can recognize that _exit() never returns and issue the + // 'unreachable code' warning for code following this function, unless + // fooled by a fake condition. + if (AlwaysTrue()) + _exit(1); +} + +void DieInside(const ::std::string& function) { + DieWithMessage("death inside " + function + "()."); +} + +// Tests that death tests work. + +class TestForDeathTest : public testing::Test { + protected: + TestForDeathTest() : original_dir_(FilePath::GetCurrentDir()) {} + + ~TestForDeathTest() override { posix::ChDir(original_dir_.c_str()); } + + // A static member function that's expected to die. + static void StaticMemberFunction() { DieInside("StaticMemberFunction"); } + + // A method of the test fixture that may die. + void MemberFunction() { + if (should_die_) + DieInside("MemberFunction"); + } + + // True if and only if MemberFunction() should die. + bool should_die_; + const FilePath original_dir_; +}; + +// A class with a member function that may die. +class MayDie { + public: + explicit MayDie(bool should_die) : should_die_(should_die) {} + + // A member function that may die. + void MemberFunction() const { + if (should_die_) + DieInside("MayDie::MemberFunction"); + } + + private: + // True if and only if MemberFunction() should die. + bool should_die_; +}; + +// A global function that's expected to die. +void GlobalFunction() { DieInside("GlobalFunction"); } + +// A non-void function that's expected to die. +int NonVoidFunction() { + DieInside("NonVoidFunction"); + return 1; +} + +// A unary function that may die. +void DieIf(bool should_die) { + if (should_die) + DieInside("DieIf"); +} + +// A binary function that may die. +bool DieIfLessThan(int x, int y) { + if (x < y) { + DieInside("DieIfLessThan"); + } + return true; +} + +// Tests that ASSERT_DEATH can be used outside a TEST, TEST_F, or test fixture. +void DeathTestSubroutine() { + EXPECT_DEATH(GlobalFunction(), "death.*GlobalFunction"); + ASSERT_DEATH(GlobalFunction(), "death.*GlobalFunction"); +} + +// Death in dbg, not opt. +int DieInDebugElse12(int* sideeffect) { + if (sideeffect) *sideeffect = 12; + +# ifndef NDEBUG + + DieInside("DieInDebugElse12"); + +# endif // NDEBUG + + return 12; +} + +# if GTEST_OS_WINDOWS + +// Death in dbg due to Windows CRT assertion failure, not opt. +int DieInCRTDebugElse12(int* sideeffect) { + if (sideeffect) *sideeffect = 12; + + // Create an invalid fd by closing a valid one + int fdpipe[2]; + EXPECT_EQ(_pipe(fdpipe, 256, O_BINARY), 0); + EXPECT_EQ(_close(fdpipe[0]), 0); + EXPECT_EQ(_close(fdpipe[1]), 0); + + // _dup() should crash in debug mode + EXPECT_EQ(_dup(fdpipe[0]), -1); + + return 12; +} + +#endif // GTEST_OS_WINDOWS + +# if GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + +// Tests the ExitedWithCode predicate. +TEST(ExitStatusPredicateTest, ExitedWithCode) { + // On Windows, the process's exit code is the same as its exit status, + // so the predicate just compares the its input with its parameter. + EXPECT_TRUE(testing::ExitedWithCode(0)(0)); + EXPECT_TRUE(testing::ExitedWithCode(1)(1)); + EXPECT_TRUE(testing::ExitedWithCode(42)(42)); + EXPECT_FALSE(testing::ExitedWithCode(0)(1)); + EXPECT_FALSE(testing::ExitedWithCode(1)(0)); +} + +# else + +// Returns the exit status of a process that calls _exit(2) with a +// given exit code. This is a helper function for the +// ExitStatusPredicateTest test suite. +static int NormalExitStatus(int exit_code) { + pid_t child_pid = fork(); + if (child_pid == 0) { + _exit(exit_code); + } + int status; + waitpid(child_pid, &status, 0); + return status; +} + +// Returns the exit status of a process that raises a given signal. +// If the signal does not cause the process to die, then it returns +// instead the exit status of a process that exits normally with exit +// code 1. This is a helper function for the ExitStatusPredicateTest +// test suite. +static int KilledExitStatus(int signum) { + pid_t child_pid = fork(); + if (child_pid == 0) { + raise(signum); + _exit(1); + } + int status; + waitpid(child_pid, &status, 0); + return status; +} + +// Tests the ExitedWithCode predicate. +TEST(ExitStatusPredicateTest, ExitedWithCode) { + const int status0 = NormalExitStatus(0); + const int status1 = NormalExitStatus(1); + const int status42 = NormalExitStatus(42); + const testing::ExitedWithCode pred0(0); + const testing::ExitedWithCode pred1(1); + const testing::ExitedWithCode pred42(42); + EXPECT_PRED1(pred0, status0); + EXPECT_PRED1(pred1, status1); + EXPECT_PRED1(pred42, status42); + EXPECT_FALSE(pred0(status1)); + EXPECT_FALSE(pred42(status0)); + EXPECT_FALSE(pred1(status42)); +} + +// Tests the KilledBySignal predicate. +TEST(ExitStatusPredicateTest, KilledBySignal) { + const int status_segv = KilledExitStatus(SIGSEGV); + const int status_kill = KilledExitStatus(SIGKILL); + const testing::KilledBySignal pred_segv(SIGSEGV); + const testing::KilledBySignal pred_kill(SIGKILL); + EXPECT_PRED1(pred_segv, status_segv); + EXPECT_PRED1(pred_kill, status_kill); + EXPECT_FALSE(pred_segv(status_kill)); + EXPECT_FALSE(pred_kill(status_segv)); +} + +# endif // GTEST_OS_WINDOWS || GTEST_OS_FUCHSIA + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +// Tests that the death test macros expand to code which may or may not +// be followed by operator<<, and that in either case the complete text +// comprises only a single C++ statement. +TEST_F(TestForDeathTest, SingleStatement) { + if (AlwaysFalse()) + // This would fail if executed; this is a compilation test only + ASSERT_DEATH(return, ""); + + if (AlwaysTrue()) + EXPECT_DEATH(_exit(1), ""); + else + // This empty "else" branch is meant to ensure that EXPECT_DEATH + // doesn't expand into an "if" statement without an "else" + ; + + if (AlwaysFalse()) + ASSERT_DEATH(return, "") << "did not die"; + + if (AlwaysFalse()) + ; + else + EXPECT_DEATH(_exit(1), "") << 1 << 2 << 3; +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +# if GTEST_USES_PCRE + +void DieWithEmbeddedNul() { + fprintf(stderr, "Hello%cmy null world.\n", '\0'); + fflush(stderr); + _exit(1); +} + +// Tests that EXPECT_DEATH and ASSERT_DEATH work when the error +// message has a NUL character in it. +TEST_F(TestForDeathTest, EmbeddedNulInMessage) { + EXPECT_DEATH(DieWithEmbeddedNul(), "my null world"); + ASSERT_DEATH(DieWithEmbeddedNul(), "my null world"); +} + +# endif // GTEST_USES_PCRE + +// Tests that death test macros expand to code which interacts well with switch +// statements. +TEST_F(TestForDeathTest, SwitchStatement) { + // Microsoft compiler usually complains about switch statements without + // case labels. We suppress that warning for this test. + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4065) + + switch (0) + default: + ASSERT_DEATH(_exit(1), "") << "exit in default switch handler"; + + switch (0) + case 0: + EXPECT_DEATH(_exit(1), "") << "exit in switch case"; + + GTEST_DISABLE_MSC_WARNINGS_POP_() +} + +// Tests that a static member function can be used in a "fast" style +// death test. +TEST_F(TestForDeathTest, StaticMemberFunctionFastStyle) { + GTEST_FLAG_SET(death_test_style, "fast"); + ASSERT_DEATH(StaticMemberFunction(), "death.*StaticMember"); +} + +// Tests that a method of the test fixture can be used in a "fast" +// style death test. +TEST_F(TestForDeathTest, MemberFunctionFastStyle) { + GTEST_FLAG_SET(death_test_style, "fast"); + should_die_ = true; + EXPECT_DEATH(MemberFunction(), "inside.*MemberFunction"); +} + +void ChangeToRootDir() { posix::ChDir(GTEST_PATH_SEP_); } + +// Tests that death tests work even if the current directory has been +// changed. +TEST_F(TestForDeathTest, FastDeathTestInChangedDir) { + GTEST_FLAG_SET(death_test_style, "fast"); + + ChangeToRootDir(); + EXPECT_EXIT(_exit(1), testing::ExitedWithCode(1), ""); + + ChangeToRootDir(); + ASSERT_DEATH(_exit(1), ""); +} + +# if GTEST_OS_LINUX +void SigprofAction(int, siginfo_t*, void*) { /* no op */ } + +// Sets SIGPROF action and ITIMER_PROF timer (interval: 1ms). +void SetSigprofActionAndTimer() { + struct sigaction signal_action; + memset(&signal_action, 0, sizeof(signal_action)); + sigemptyset(&signal_action.sa_mask); + signal_action.sa_sigaction = SigprofAction; + signal_action.sa_flags = SA_RESTART | SA_SIGINFO; + ASSERT_EQ(0, sigaction(SIGPROF, &signal_action, nullptr)); + // timer comes second, to avoid SIGPROF premature delivery, as suggested at + // https://www.gnu.org/software/libc/manual/html_node/Setting-an-Alarm.html + struct itimerval timer; + timer.it_interval.tv_sec = 0; + timer.it_interval.tv_usec = 1; + timer.it_value = timer.it_interval; + ASSERT_EQ(0, setitimer(ITIMER_PROF, &timer, nullptr)); +} + +// Disables ITIMER_PROF timer and ignores SIGPROF signal. +void DisableSigprofActionAndTimer(struct sigaction* old_signal_action) { + struct itimerval timer; + timer.it_interval.tv_sec = 0; + timer.it_interval.tv_usec = 0; + timer.it_value = timer.it_interval; + ASSERT_EQ(0, setitimer(ITIMER_PROF, &timer, nullptr)); + struct sigaction signal_action; + memset(&signal_action, 0, sizeof(signal_action)); + sigemptyset(&signal_action.sa_mask); + signal_action.sa_handler = SIG_IGN; + ASSERT_EQ(0, sigaction(SIGPROF, &signal_action, old_signal_action)); +} + +// Tests that death tests work when SIGPROF handler and timer are set. +TEST_F(TestForDeathTest, FastSigprofActionSet) { + GTEST_FLAG_SET(death_test_style, "fast"); + SetSigprofActionAndTimer(); + EXPECT_DEATH(_exit(1), ""); + struct sigaction old_signal_action; + DisableSigprofActionAndTimer(&old_signal_action); + EXPECT_TRUE(old_signal_action.sa_sigaction == SigprofAction); +} + +TEST_F(TestForDeathTest, ThreadSafeSigprofActionSet) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + SetSigprofActionAndTimer(); + EXPECT_DEATH(_exit(1), ""); + struct sigaction old_signal_action; + DisableSigprofActionAndTimer(&old_signal_action); + EXPECT_TRUE(old_signal_action.sa_sigaction == SigprofAction); +} +# endif // GTEST_OS_LINUX + +// Repeats a representative sample of death tests in the "threadsafe" style: + +TEST_F(TestForDeathTest, StaticMemberFunctionThreadsafeStyle) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + ASSERT_DEATH(StaticMemberFunction(), "death.*StaticMember"); +} + +TEST_F(TestForDeathTest, MemberFunctionThreadsafeStyle) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + should_die_ = true; + EXPECT_DEATH(MemberFunction(), "inside.*MemberFunction"); +} + +TEST_F(TestForDeathTest, ThreadsafeDeathTestInLoop) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + + for (int i = 0; i < 3; ++i) + EXPECT_EXIT(_exit(i), testing::ExitedWithCode(i), "") << ": i = " << i; +} + +TEST_F(TestForDeathTest, ThreadsafeDeathTestInChangedDir) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + + ChangeToRootDir(); + EXPECT_EXIT(_exit(1), testing::ExitedWithCode(1), ""); + + ChangeToRootDir(); + ASSERT_DEATH(_exit(1), ""); +} + +TEST_F(TestForDeathTest, MixedStyles) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEATH(_exit(1), ""); + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_DEATH(_exit(1), ""); +} + +# if GTEST_HAS_CLONE && GTEST_HAS_PTHREAD + +bool pthread_flag; + +void SetPthreadFlag() { + pthread_flag = true; +} + +TEST_F(TestForDeathTest, DoesNotExecuteAtforkHooks) { + if (!GTEST_FLAG_GET(death_test_use_fork)) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + pthread_flag = false; + ASSERT_EQ(0, pthread_atfork(&SetPthreadFlag, nullptr, nullptr)); + ASSERT_DEATH(_exit(1), ""); + ASSERT_FALSE(pthread_flag); + } +} + +# endif // GTEST_HAS_CLONE && GTEST_HAS_PTHREAD + +// Tests that a method of another class can be used in a death test. +TEST_F(TestForDeathTest, MethodOfAnotherClass) { + const MayDie x(true); + ASSERT_DEATH(x.MemberFunction(), "MayDie\\:\\:MemberFunction"); +} + +// Tests that a global function can be used in a death test. +TEST_F(TestForDeathTest, GlobalFunction) { + EXPECT_DEATH(GlobalFunction(), "GlobalFunction"); +} + +// Tests that any value convertible to an RE works as a second +// argument to EXPECT_DEATH. +TEST_F(TestForDeathTest, AcceptsAnythingConvertibleToRE) { + static const char regex_c_str[] = "GlobalFunction"; + EXPECT_DEATH(GlobalFunction(), regex_c_str); + + const testing::internal::RE regex(regex_c_str); + EXPECT_DEATH(GlobalFunction(), regex); + +# if !GTEST_USES_PCRE + + const ::std::string regex_std_str(regex_c_str); + EXPECT_DEATH(GlobalFunction(), regex_std_str); + + // This one is tricky; a temporary pointer into another temporary. Reference + // lifetime extension of the pointer is not sufficient. + EXPECT_DEATH(GlobalFunction(), ::std::string(regex_c_str).c_str()); + +# endif // !GTEST_USES_PCRE +} + +// Tests that a non-void function can be used in a death test. +TEST_F(TestForDeathTest, NonVoidFunction) { + ASSERT_DEATH(NonVoidFunction(), "NonVoidFunction"); +} + +// Tests that functions that take parameter(s) can be used in a death test. +TEST_F(TestForDeathTest, FunctionWithParameter) { + EXPECT_DEATH(DieIf(true), "DieIf\\(\\)"); + EXPECT_DEATH(DieIfLessThan(2, 3), "DieIfLessThan"); +} + +// Tests that ASSERT_DEATH can be used outside a TEST, TEST_F, or test fixture. +TEST_F(TestForDeathTest, OutsideFixture) { + DeathTestSubroutine(); +} + +// Tests that death tests can be done inside a loop. +TEST_F(TestForDeathTest, InsideLoop) { + for (int i = 0; i < 5; i++) { + EXPECT_DEATH(DieIfLessThan(-1, i), "DieIfLessThan") << "where i == " << i; + } +} + +// Tests that a compound statement can be used in a death test. +TEST_F(TestForDeathTest, CompoundStatement) { + EXPECT_DEATH({ // NOLINT + const int x = 2; + const int y = x + 1; + DieIfLessThan(x, y); + }, + "DieIfLessThan"); +} + +// Tests that code that doesn't die causes a death test to fail. +TEST_F(TestForDeathTest, DoesNotDie) { + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(DieIf(false), "DieIf"), + "failed to die"); +} + +// Tests that a death test fails when the error message isn't expected. +TEST_F(TestForDeathTest, ErrorMessageMismatch) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_DEATH(DieIf(true), "DieIfLessThan") << "End of death test message."; + }, "died but not with expected error"); +} + +// On exit, *aborted will be true if and only if the EXPECT_DEATH() +// statement aborted the function. +void ExpectDeathTestHelper(bool* aborted) { + *aborted = true; + EXPECT_DEATH(DieIf(false), "DieIf"); // This assertion should fail. + *aborted = false; +} + +// Tests that EXPECT_DEATH doesn't abort the test on failure. +TEST_F(TestForDeathTest, EXPECT_DEATH) { + bool aborted = true; + EXPECT_NONFATAL_FAILURE(ExpectDeathTestHelper(&aborted), + "failed to die"); + EXPECT_FALSE(aborted); +} + +// Tests that ASSERT_DEATH does abort the test on failure. +TEST_F(TestForDeathTest, ASSERT_DEATH) { + static bool aborted; + EXPECT_FATAL_FAILURE({ // NOLINT + aborted = true; + ASSERT_DEATH(DieIf(false), "DieIf"); // This assertion should fail. + aborted = false; + }, "failed to die"); + EXPECT_TRUE(aborted); +} + +// Tests that EXPECT_DEATH evaluates the arguments exactly once. +TEST_F(TestForDeathTest, SingleEvaluation) { + int x = 3; + EXPECT_DEATH(DieIf((++x) == 4), "DieIf"); + + const char* regex = "DieIf"; + const char* regex_save = regex; + EXPECT_DEATH(DieIfLessThan(3, 4), regex++); + EXPECT_EQ(regex_save + 1, regex); +} + +// Tests that run-away death tests are reported as failures. +TEST_F(TestForDeathTest, RunawayIsFailure) { + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(static_cast(0), "Foo"), + "failed to die."); +} + +// Tests that death tests report executing 'return' in the statement as +// failure. +TEST_F(TestForDeathTest, ReturnIsFailure) { + EXPECT_FATAL_FAILURE(ASSERT_DEATH(return, "Bar"), + "illegal return in test statement."); +} + +// Tests that EXPECT_DEBUG_DEATH works as expected, that is, you can stream a +// message to it, and in debug mode it: +// 1. Asserts on death. +// 2. Has no side effect. +// +// And in opt mode, it: +// 1. Has side effects but does not assert. +TEST_F(TestForDeathTest, TestExpectDebugDeath) { + int sideeffect = 0; + + // Put the regex in a local variable to make sure we don't get an "unused" + // warning in opt mode. + const char* regex = "death.*DieInDebugElse12"; + + EXPECT_DEBUG_DEATH(DieInDebugElse12(&sideeffect), regex) + << "Must accept a streamed message"; + +# ifdef NDEBUG + + // Checks that the assignment occurs in opt mode (sideeffect). + EXPECT_EQ(12, sideeffect); + +# else + + // Checks that the assignment does not occur in dbg mode (no sideeffect). + EXPECT_EQ(0, sideeffect); + +# endif +} + +# if GTEST_OS_WINDOWS + +// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/crtsetreportmode +// In debug mode, the calls to _CrtSetReportMode and _CrtSetReportFile enable +// the dumping of assertions to stderr. Tests that EXPECT_DEATH works as +// expected when in CRT debug mode (compiled with /MTd or /MDd, which defines +// _DEBUG) the Windows CRT crashes the process with an assertion failure. +// 1. Asserts on death. +// 2. Has no side effect (doesn't pop up a window or wait for user input). +#ifdef _DEBUG +TEST_F(TestForDeathTest, CRTDebugDeath) { + EXPECT_DEATH(DieInCRTDebugElse12(nullptr), "dup.* : Assertion failed") + << "Must accept a streamed message"; +} +#endif // _DEBUG + +# endif // GTEST_OS_WINDOWS + +// Tests that ASSERT_DEBUG_DEATH works as expected, that is, you can stream a +// message to it, and in debug mode it: +// 1. Asserts on death. +// 2. Has no side effect. +// +// And in opt mode, it: +// 1. Has side effects but does not assert. +TEST_F(TestForDeathTest, TestAssertDebugDeath) { + int sideeffect = 0; + + ASSERT_DEBUG_DEATH(DieInDebugElse12(&sideeffect), "death.*DieInDebugElse12") + << "Must accept a streamed message"; + +# ifdef NDEBUG + + // Checks that the assignment occurs in opt mode (sideeffect). + EXPECT_EQ(12, sideeffect); + +# else + + // Checks that the assignment does not occur in dbg mode (no sideeffect). + EXPECT_EQ(0, sideeffect); + +# endif +} + +# ifndef NDEBUG + +void ExpectDebugDeathHelper(bool* aborted) { + *aborted = true; + EXPECT_DEBUG_DEATH(return, "") << "This is expected to fail."; + *aborted = false; +} + +# if GTEST_OS_WINDOWS +TEST(PopUpDeathTest, DoesNotShowPopUpOnAbort) { + printf("This test should be considered failing if it shows " + "any pop-up dialogs.\n"); + fflush(stdout); + + EXPECT_DEATH( + { + GTEST_FLAG_SET(catch_exceptions, false); + abort(); + }, + ""); +} +# endif // GTEST_OS_WINDOWS + +// Tests that EXPECT_DEBUG_DEATH in debug mode does not abort +// the function. +TEST_F(TestForDeathTest, ExpectDebugDeathDoesNotAbort) { + bool aborted = true; + EXPECT_NONFATAL_FAILURE(ExpectDebugDeathHelper(&aborted), ""); + EXPECT_FALSE(aborted); +} + +void AssertDebugDeathHelper(bool* aborted) { + *aborted = true; + GTEST_LOG_(INFO) << "Before ASSERT_DEBUG_DEATH"; + ASSERT_DEBUG_DEATH(GTEST_LOG_(INFO) << "In ASSERT_DEBUG_DEATH"; return, "") + << "This is expected to fail."; + GTEST_LOG_(INFO) << "After ASSERT_DEBUG_DEATH"; + *aborted = false; +} + +// Tests that ASSERT_DEBUG_DEATH in debug mode aborts the function on +// failure. +TEST_F(TestForDeathTest, AssertDebugDeathAborts) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts2) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts3) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts4) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts5) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts6) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts7) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts8) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts9) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +TEST_F(TestForDeathTest, AssertDebugDeathAborts10) { + static bool aborted; + aborted = false; + EXPECT_FATAL_FAILURE(AssertDebugDeathHelper(&aborted), ""); + EXPECT_TRUE(aborted); +} + +# endif // _NDEBUG + +// Tests the *_EXIT family of macros, using a variety of predicates. +static void TestExitMacros() { + EXPECT_EXIT(_exit(1), testing::ExitedWithCode(1), ""); + ASSERT_EXIT(_exit(42), testing::ExitedWithCode(42), ""); + +# if GTEST_OS_WINDOWS + + // Of all signals effects on the process exit code, only those of SIGABRT + // are documented on Windows. + // See https://msdn.microsoft.com/en-us/query-bi/m/dwwzkt4c. + EXPECT_EXIT(raise(SIGABRT), testing::ExitedWithCode(3), "") << "b_ar"; + +# elif !GTEST_OS_FUCHSIA + + // Fuchsia has no unix signals. + EXPECT_EXIT(raise(SIGKILL), testing::KilledBySignal(SIGKILL), "") << "foo"; + ASSERT_EXIT(raise(SIGUSR2), testing::KilledBySignal(SIGUSR2), "") << "bar"; + + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_EXIT(_exit(0), testing::KilledBySignal(SIGSEGV), "") + << "This failure is expected, too."; + }, "This failure is expected, too."); + +# endif // GTEST_OS_WINDOWS + + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_EXIT(raise(SIGSEGV), testing::ExitedWithCode(0), "") + << "This failure is expected."; + }, "This failure is expected."); +} + +TEST_F(TestForDeathTest, ExitMacros) { + TestExitMacros(); +} + +TEST_F(TestForDeathTest, ExitMacrosUsingFork) { + GTEST_FLAG_SET(death_test_use_fork, true); + TestExitMacros(); +} + +TEST_F(TestForDeathTest, InvalidStyle) { + GTEST_FLAG_SET(death_test_style, "rococo"); + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_DEATH(_exit(0), "") << "This failure is expected."; + }, "This failure is expected."); +} + +TEST_F(TestForDeathTest, DeathTestFailedOutput) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_NONFATAL_FAILURE( + EXPECT_DEATH(DieWithMessage("death\n"), + "expected message"), + "Actual msg:\n" + "[ DEATH ] death\n"); +} + +TEST_F(TestForDeathTest, DeathTestUnexpectedReturnOutput) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_NONFATAL_FAILURE( + EXPECT_DEATH({ + fprintf(stderr, "returning\n"); + fflush(stderr); + return; + }, ""), + " Result: illegal return in test statement.\n" + " Error msg:\n" + "[ DEATH ] returning\n"); +} + +TEST_F(TestForDeathTest, DeathTestBadExitCodeOutput) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_NONFATAL_FAILURE( + EXPECT_EXIT(DieWithMessage("exiting with rc 1\n"), + testing::ExitedWithCode(3), + "expected message"), + " Result: died but not with expected exit code:\n" + " Exited with exit status 1\n" + "Actual msg:\n" + "[ DEATH ] exiting with rc 1\n"); +} + +TEST_F(TestForDeathTest, DeathTestMultiLineMatchFail) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_NONFATAL_FAILURE( + EXPECT_DEATH(DieWithMessage("line 1\nline 2\nline 3\n"), + "line 1\nxyz\nline 3\n"), + "Actual msg:\n" + "[ DEATH ] line 1\n" + "[ DEATH ] line 2\n" + "[ DEATH ] line 3\n"); +} + +TEST_F(TestForDeathTest, DeathTestMultiLineMatchPass) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_DEATH(DieWithMessage("line 1\nline 2\nline 3\n"), + "line 1\nline 2\nline 3\n"); +} + +// A DeathTestFactory that returns MockDeathTests. +class MockDeathTestFactory : public DeathTestFactory { + public: + MockDeathTestFactory(); + bool Create(const char* statement, + testing::Matcher matcher, const char* file, + int line, DeathTest** test) override; + + // Sets the parameters for subsequent calls to Create. + void SetParameters(bool create, DeathTest::TestRole role, + int status, bool passed); + + // Accessors. + int AssumeRoleCalls() const { return assume_role_calls_; } + int WaitCalls() const { return wait_calls_; } + size_t PassedCalls() const { return passed_args_.size(); } + bool PassedArgument(int n) const { + return passed_args_[static_cast(n)]; + } + size_t AbortCalls() const { return abort_args_.size(); } + DeathTest::AbortReason AbortArgument(int n) const { + return abort_args_[static_cast(n)]; + } + bool TestDeleted() const { return test_deleted_; } + + private: + friend class MockDeathTest; + // If true, Create will return a MockDeathTest; otherwise it returns + // NULL. + bool create_; + // The value a MockDeathTest will return from its AssumeRole method. + DeathTest::TestRole role_; + // The value a MockDeathTest will return from its Wait method. + int status_; + // The value a MockDeathTest will return from its Passed method. + bool passed_; + + // Number of times AssumeRole was called. + int assume_role_calls_; + // Number of times Wait was called. + int wait_calls_; + // The arguments to the calls to Passed since the last call to + // SetParameters. + std::vector passed_args_; + // The arguments to the calls to Abort since the last call to + // SetParameters. + std::vector abort_args_; + // True if the last MockDeathTest returned by Create has been + // deleted. + bool test_deleted_; +}; + + +// A DeathTest implementation useful in testing. It returns values set +// at its creation from its various inherited DeathTest methods, and +// reports calls to those methods to its parent MockDeathTestFactory +// object. +class MockDeathTest : public DeathTest { + public: + MockDeathTest(MockDeathTestFactory *parent, + TestRole role, int status, bool passed) : + parent_(parent), role_(role), status_(status), passed_(passed) { + } + ~MockDeathTest() override { parent_->test_deleted_ = true; } + TestRole AssumeRole() override { + ++parent_->assume_role_calls_; + return role_; + } + int Wait() override { + ++parent_->wait_calls_; + return status_; + } + bool Passed(bool exit_status_ok) override { + parent_->passed_args_.push_back(exit_status_ok); + return passed_; + } + void Abort(AbortReason reason) override { + parent_->abort_args_.push_back(reason); + } + + private: + MockDeathTestFactory* const parent_; + const TestRole role_; + const int status_; + const bool passed_; +}; + + +// MockDeathTestFactory constructor. +MockDeathTestFactory::MockDeathTestFactory() + : create_(true), + role_(DeathTest::OVERSEE_TEST), + status_(0), + passed_(true), + assume_role_calls_(0), + wait_calls_(0), + passed_args_(), + abort_args_() { +} + + +// Sets the parameters for subsequent calls to Create. +void MockDeathTestFactory::SetParameters(bool create, + DeathTest::TestRole role, + int status, bool passed) { + create_ = create; + role_ = role; + status_ = status; + passed_ = passed; + + assume_role_calls_ = 0; + wait_calls_ = 0; + passed_args_.clear(); + abort_args_.clear(); +} + + +// Sets test to NULL (if create_ is false) or to the address of a new +// MockDeathTest object with parameters taken from the last call +// to SetParameters (if create_ is true). Always returns true. +bool MockDeathTestFactory::Create( + const char* /*statement*/, testing::Matcher /*matcher*/, + const char* /*file*/, int /*line*/, DeathTest** test) { + test_deleted_ = false; + if (create_) { + *test = new MockDeathTest(this, role_, status_, passed_); + } else { + *test = nullptr; + } + return true; +} + +// A test fixture for testing the logic of the GTEST_DEATH_TEST_ macro. +// It installs a MockDeathTestFactory that is used for the duration +// of the test case. +class MacroLogicDeathTest : public testing::Test { + protected: + static testing::internal::ReplaceDeathTestFactory* replacer_; + static MockDeathTestFactory* factory_; + + static void SetUpTestSuite() { + factory_ = new MockDeathTestFactory; + replacer_ = new testing::internal::ReplaceDeathTestFactory(factory_); + } + + static void TearDownTestSuite() { + delete replacer_; + replacer_ = nullptr; + delete factory_; + factory_ = nullptr; + } + + // Runs a death test that breaks the rules by returning. Such a death + // test cannot be run directly from a test routine that uses a + // MockDeathTest, or the remainder of the routine will not be executed. + static void RunReturningDeathTest(bool* flag) { + ASSERT_DEATH({ // NOLINT + *flag = true; + return; + }, ""); + } +}; + +testing::internal::ReplaceDeathTestFactory* MacroLogicDeathTest::replacer_ = + nullptr; +MockDeathTestFactory* MacroLogicDeathTest::factory_ = nullptr; + +// Test that nothing happens when the factory doesn't return a DeathTest: +TEST_F(MacroLogicDeathTest, NothingHappens) { + bool flag = false; + factory_->SetParameters(false, DeathTest::OVERSEE_TEST, 0, true); + EXPECT_DEATH(flag = true, ""); + EXPECT_FALSE(flag); + EXPECT_EQ(0, factory_->AssumeRoleCalls()); + EXPECT_EQ(0, factory_->WaitCalls()); + EXPECT_EQ(0U, factory_->PassedCalls()); + EXPECT_EQ(0U, factory_->AbortCalls()); + EXPECT_FALSE(factory_->TestDeleted()); +} + +// Test that the parent process doesn't run the death test code, +// and that the Passed method returns false when the (simulated) +// child process exits with status 0: +TEST_F(MacroLogicDeathTest, ChildExitsSuccessfully) { + bool flag = false; + factory_->SetParameters(true, DeathTest::OVERSEE_TEST, 0, true); + EXPECT_DEATH(flag = true, ""); + EXPECT_FALSE(flag); + EXPECT_EQ(1, factory_->AssumeRoleCalls()); + EXPECT_EQ(1, factory_->WaitCalls()); + ASSERT_EQ(1U, factory_->PassedCalls()); + EXPECT_FALSE(factory_->PassedArgument(0)); + EXPECT_EQ(0U, factory_->AbortCalls()); + EXPECT_TRUE(factory_->TestDeleted()); +} + +// Tests that the Passed method was given the argument "true" when +// the (simulated) child process exits with status 1: +TEST_F(MacroLogicDeathTest, ChildExitsUnsuccessfully) { + bool flag = false; + factory_->SetParameters(true, DeathTest::OVERSEE_TEST, 1, true); + EXPECT_DEATH(flag = true, ""); + EXPECT_FALSE(flag); + EXPECT_EQ(1, factory_->AssumeRoleCalls()); + EXPECT_EQ(1, factory_->WaitCalls()); + ASSERT_EQ(1U, factory_->PassedCalls()); + EXPECT_TRUE(factory_->PassedArgument(0)); + EXPECT_EQ(0U, factory_->AbortCalls()); + EXPECT_TRUE(factory_->TestDeleted()); +} + +// Tests that the (simulated) child process executes the death test +// code, and is aborted with the correct AbortReason if it +// executes a return statement. +TEST_F(MacroLogicDeathTest, ChildPerformsReturn) { + bool flag = false; + factory_->SetParameters(true, DeathTest::EXECUTE_TEST, 0, true); + RunReturningDeathTest(&flag); + EXPECT_TRUE(flag); + EXPECT_EQ(1, factory_->AssumeRoleCalls()); + EXPECT_EQ(0, factory_->WaitCalls()); + EXPECT_EQ(0U, factory_->PassedCalls()); + EXPECT_EQ(1U, factory_->AbortCalls()); + EXPECT_EQ(DeathTest::TEST_ENCOUNTERED_RETURN_STATEMENT, + factory_->AbortArgument(0)); + EXPECT_TRUE(factory_->TestDeleted()); +} + +// Tests that the (simulated) child process is aborted with the +// correct AbortReason if it does not die. +TEST_F(MacroLogicDeathTest, ChildDoesNotDie) { + bool flag = false; + factory_->SetParameters(true, DeathTest::EXECUTE_TEST, 0, true); + EXPECT_DEATH(flag = true, ""); + EXPECT_TRUE(flag); + EXPECT_EQ(1, factory_->AssumeRoleCalls()); + EXPECT_EQ(0, factory_->WaitCalls()); + EXPECT_EQ(0U, factory_->PassedCalls()); + // This time there are two calls to Abort: one since the test didn't + // die, and another from the ReturnSentinel when it's destroyed. The + // sentinel normally isn't destroyed if a test doesn't die, since + // _exit(2) is called in that case by ForkingDeathTest, but not by + // our MockDeathTest. + ASSERT_EQ(2U, factory_->AbortCalls()); + EXPECT_EQ(DeathTest::TEST_DID_NOT_DIE, + factory_->AbortArgument(0)); + EXPECT_EQ(DeathTest::TEST_ENCOUNTERED_RETURN_STATEMENT, + factory_->AbortArgument(1)); + EXPECT_TRUE(factory_->TestDeleted()); +} + +// Tests that a successful death test does not register a successful +// test part. +TEST(SuccessRegistrationDeathTest, NoSuccessPart) { + EXPECT_DEATH(_exit(1), ""); + EXPECT_EQ(0, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +TEST(StreamingAssertionsDeathTest, DeathTest) { + EXPECT_DEATH(_exit(1), "") << "unexpected failure"; + ASSERT_DEATH(_exit(1), "") << "unexpected failure"; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_DEATH(_exit(0), "") << "expected failure"; + }, "expected failure"); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_DEATH(_exit(0), "") << "expected failure"; + }, "expected failure"); +} + +// Tests that GetLastErrnoDescription returns an empty string when the +// last error is 0 and non-empty string when it is non-zero. +TEST(GetLastErrnoDescription, GetLastErrnoDescriptionWorks) { + errno = ENOENT; + EXPECT_STRNE("", GetLastErrnoDescription().c_str()); + errno = 0; + EXPECT_STREQ("", GetLastErrnoDescription().c_str()); +} + +# if GTEST_OS_WINDOWS +TEST(AutoHandleTest, AutoHandleWorks) { + HANDLE handle = ::CreateEvent(NULL, FALSE, FALSE, NULL); + ASSERT_NE(INVALID_HANDLE_VALUE, handle); + + // Tests that the AutoHandle is correctly initialized with a handle. + testing::internal::AutoHandle auto_handle(handle); + EXPECT_EQ(handle, auto_handle.Get()); + + // Tests that Reset assigns INVALID_HANDLE_VALUE. + // Note that this cannot verify whether the original handle is closed. + auto_handle.Reset(); + EXPECT_EQ(INVALID_HANDLE_VALUE, auto_handle.Get()); + + // Tests that Reset assigns the new handle. + // Note that this cannot verify whether the original handle is closed. + handle = ::CreateEvent(NULL, FALSE, FALSE, NULL); + ASSERT_NE(INVALID_HANDLE_VALUE, handle); + auto_handle.Reset(handle); + EXPECT_EQ(handle, auto_handle.Get()); + + // Tests that AutoHandle contains INVALID_HANDLE_VALUE by default. + testing::internal::AutoHandle auto_handle2; + EXPECT_EQ(INVALID_HANDLE_VALUE, auto_handle2.Get()); +} +# endif // GTEST_OS_WINDOWS + +# if GTEST_OS_WINDOWS +typedef unsigned __int64 BiggestParsable; +typedef signed __int64 BiggestSignedParsable; +# else +typedef unsigned long long BiggestParsable; +typedef signed long long BiggestSignedParsable; +# endif // GTEST_OS_WINDOWS + +// We cannot use std::numeric_limits::max() as it clashes with the +// max() macro defined by . +const BiggestParsable kBiggestParsableMax = ULLONG_MAX; +const BiggestSignedParsable kBiggestSignedParsableMax = LLONG_MAX; + +TEST(ParseNaturalNumberTest, RejectsInvalidFormat) { + BiggestParsable result = 0; + + // Rejects non-numbers. + EXPECT_FALSE(ParseNaturalNumber("non-number string", &result)); + + // Rejects numbers with whitespace prefix. + EXPECT_FALSE(ParseNaturalNumber(" 123", &result)); + + // Rejects negative numbers. + EXPECT_FALSE(ParseNaturalNumber("-123", &result)); + + // Rejects numbers starting with a plus sign. + EXPECT_FALSE(ParseNaturalNumber("+123", &result)); + errno = 0; +} + +TEST(ParseNaturalNumberTest, RejectsOverflownNumbers) { + BiggestParsable result = 0; + + EXPECT_FALSE(ParseNaturalNumber("99999999999999999999999", &result)); + + signed char char_result = 0; + EXPECT_FALSE(ParseNaturalNumber("200", &char_result)); + errno = 0; +} + +TEST(ParseNaturalNumberTest, AcceptsValidNumbers) { + BiggestParsable result = 0; + + result = 0; + ASSERT_TRUE(ParseNaturalNumber("123", &result)); + EXPECT_EQ(123U, result); + + // Check 0 as an edge case. + result = 1; + ASSERT_TRUE(ParseNaturalNumber("0", &result)); + EXPECT_EQ(0U, result); + + result = 1; + ASSERT_TRUE(ParseNaturalNumber("00000", &result)); + EXPECT_EQ(0U, result); +} + +TEST(ParseNaturalNumberTest, AcceptsTypeLimits) { + Message msg; + msg << kBiggestParsableMax; + + BiggestParsable result = 0; + EXPECT_TRUE(ParseNaturalNumber(msg.GetString(), &result)); + EXPECT_EQ(kBiggestParsableMax, result); + + Message msg2; + msg2 << kBiggestSignedParsableMax; + + BiggestSignedParsable signed_result = 0; + EXPECT_TRUE(ParseNaturalNumber(msg2.GetString(), &signed_result)); + EXPECT_EQ(kBiggestSignedParsableMax, signed_result); + + Message msg3; + msg3 << INT_MAX; + + int int_result = 0; + EXPECT_TRUE(ParseNaturalNumber(msg3.GetString(), &int_result)); + EXPECT_EQ(INT_MAX, int_result); + + Message msg4; + msg4 << UINT_MAX; + + unsigned int uint_result = 0; + EXPECT_TRUE(ParseNaturalNumber(msg4.GetString(), &uint_result)); + EXPECT_EQ(UINT_MAX, uint_result); +} + +TEST(ParseNaturalNumberTest, WorksForShorterIntegers) { + short short_result = 0; + ASSERT_TRUE(ParseNaturalNumber("123", &short_result)); + EXPECT_EQ(123, short_result); + + signed char char_result = 0; + ASSERT_TRUE(ParseNaturalNumber("123", &char_result)); + EXPECT_EQ(123, char_result); +} + +# if GTEST_OS_WINDOWS +TEST(EnvironmentTest, HandleFitsIntoSizeT) { + ASSERT_TRUE(sizeof(HANDLE) <= sizeof(size_t)); +} +# endif // GTEST_OS_WINDOWS + +// Tests that EXPECT_DEATH_IF_SUPPORTED/ASSERT_DEATH_IF_SUPPORTED trigger +// failures when death tests are available on the system. +TEST(ConditionalDeathMacrosDeathTest, ExpectsDeathWhenDeathTestsAvailable) { + EXPECT_DEATH_IF_SUPPORTED(DieInside("CondDeathTestExpectMacro"), + "death inside CondDeathTestExpectMacro"); + ASSERT_DEATH_IF_SUPPORTED(DieInside("CondDeathTestAssertMacro"), + "death inside CondDeathTestAssertMacro"); + + // Empty statement will not crash, which must trigger a failure. + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH_IF_SUPPORTED(;, ""), ""); + EXPECT_FATAL_FAILURE(ASSERT_DEATH_IF_SUPPORTED(;, ""), ""); +} + +TEST(InDeathTestChildDeathTest, ReportsDeathTestCorrectlyInFastStyle) { + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_FALSE(InDeathTestChild()); + EXPECT_DEATH({ + fprintf(stderr, InDeathTestChild() ? "Inside" : "Outside"); + fflush(stderr); + _exit(1); + }, "Inside"); +} + +TEST(InDeathTestChildDeathTest, ReportsDeathTestCorrectlyInThreadSafeStyle) { + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_FALSE(InDeathTestChild()); + EXPECT_DEATH({ + fprintf(stderr, InDeathTestChild() ? "Inside" : "Outside"); + fflush(stderr); + _exit(1); + }, "Inside"); +} + +void DieWithMessage(const char* message) { + fputs(message, stderr); + fflush(stderr); // Make sure the text is printed before the process exits. + _exit(1); +} + +TEST(MatcherDeathTest, DoesNotBreakBareRegexMatching) { + // googletest tests this, of course; here we ensure that including googlemock + // has not broken it. +#if GTEST_USES_POSIX_RE + EXPECT_DEATH(DieWithMessage("O, I die, Horatio."), "I d[aeiou]e"); +#else + EXPECT_DEATH(DieWithMessage("O, I die, Horatio."), "I di?e"); +#endif +} + +TEST(MatcherDeathTest, MonomorphicMatcherMatches) { + EXPECT_DEATH(DieWithMessage("Behind O, I am slain!"), + Matcher(ContainsRegex("I am slain"))); +} + +TEST(MatcherDeathTest, MonomorphicMatcherDoesNotMatch) { + EXPECT_NONFATAL_FAILURE( + EXPECT_DEATH( + DieWithMessage("Behind O, I am slain!"), + Matcher(ContainsRegex("Ow, I am slain"))), + "Expected: contains regular expression \"Ow, I am slain\""); +} + +TEST(MatcherDeathTest, PolymorphicMatcherMatches) { + EXPECT_DEATH(DieWithMessage("The rest is silence."), + ContainsRegex("rest is silence")); +} + +TEST(MatcherDeathTest, PolymorphicMatcherDoesNotMatch) { + EXPECT_NONFATAL_FAILURE( + EXPECT_DEATH(DieWithMessage("The rest is silence."), + ContainsRegex("rest is science")), + "Expected: contains regular expression \"rest is science\""); +} + +} // namespace + +#else // !GTEST_HAS_DEATH_TEST follows + +namespace { + +using testing::internal::CaptureStderr; +using testing::internal::GetCapturedStderr; + +// Tests that EXPECT_DEATH_IF_SUPPORTED/ASSERT_DEATH_IF_SUPPORTED are still +// defined but do not trigger failures when death tests are not available on +// the system. +TEST(ConditionalDeathMacrosTest, WarnsWhenDeathTestsNotAvailable) { + // Empty statement will not crash, but that should not trigger a failure + // when death tests are not supported. + CaptureStderr(); + EXPECT_DEATH_IF_SUPPORTED(;, ""); + std::string output = GetCapturedStderr(); + ASSERT_TRUE(NULL != strstr(output.c_str(), + "Death tests are not supported on this platform")); + ASSERT_TRUE(NULL != strstr(output.c_str(), ";")); + + // The streamed message should not be printed as there is no test failure. + CaptureStderr(); + EXPECT_DEATH_IF_SUPPORTED(;, "") << "streamed message"; + output = GetCapturedStderr(); + ASSERT_TRUE(NULL == strstr(output.c_str(), "streamed message")); + + CaptureStderr(); + ASSERT_DEATH_IF_SUPPORTED(;, ""); // NOLINT + output = GetCapturedStderr(); + ASSERT_TRUE(NULL != strstr(output.c_str(), + "Death tests are not supported on this platform")); + ASSERT_TRUE(NULL != strstr(output.c_str(), ";")); + + CaptureStderr(); + ASSERT_DEATH_IF_SUPPORTED(;, "") << "streamed message"; // NOLINT + output = GetCapturedStderr(); + ASSERT_TRUE(NULL == strstr(output.c_str(), "streamed message")); +} + +void FuncWithAssert(int* n) { + ASSERT_DEATH_IF_SUPPORTED(return;, ""); + (*n)++; +} + +// Tests that ASSERT_DEATH_IF_SUPPORTED does not return from the current +// function (as ASSERT_DEATH does) if death tests are not supported. +TEST(ConditionalDeathMacrosTest, AssertDeatDoesNotReturnhIfUnsupported) { + int n = 0; + FuncWithAssert(&n); + EXPECT_EQ(1, n); +} + +} // namespace + +#endif // !GTEST_HAS_DEATH_TEST + +namespace { + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +// Tests that the death test macros expand to code which may or may not +// be followed by operator<<, and that in either case the complete text +// comprises only a single C++ statement. +// +// The syntax should work whether death tests are available or not. +TEST(ConditionalDeathMacrosSyntaxDeathTest, SingleStatement) { + if (AlwaysFalse()) + // This would fail if executed; this is a compilation test only + ASSERT_DEATH_IF_SUPPORTED(return, ""); + + if (AlwaysTrue()) + EXPECT_DEATH_IF_SUPPORTED(_exit(1), ""); + else + // This empty "else" branch is meant to ensure that EXPECT_DEATH + // doesn't expand into an "if" statement without an "else" + ; // NOLINT + + if (AlwaysFalse()) + ASSERT_DEATH_IF_SUPPORTED(return, "") << "did not die"; + + if (AlwaysFalse()) + ; // NOLINT + else + EXPECT_DEATH_IF_SUPPORTED(_exit(1), "") << 1 << 2 << 3; +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +// Tests that conditional death test macros expand to code which interacts +// well with switch statements. +TEST(ConditionalDeathMacrosSyntaxDeathTest, SwitchStatement) { + // Microsoft compiler usually complains about switch statements without + // case labels. We suppress that warning for this test. + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4065) + + switch (0) + default: + ASSERT_DEATH_IF_SUPPORTED(_exit(1), "") + << "exit in default switch handler"; + + switch (0) + case 0: + EXPECT_DEATH_IF_SUPPORTED(_exit(1), "") << "exit in switch case"; + + GTEST_DISABLE_MSC_WARNINGS_POP_() +} + +// Tests that a test case whose name ends with "DeathTest" works fine +// on Windows. +TEST(NotADeathTest, Test) { + SUCCEED(); +} + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test_ex_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test_ex_test.cc new file mode 100644 index 000000000000..bbacc8ae88f5 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-death-test_ex_test.cc @@ -0,0 +1,92 @@ +// Copyright 2010, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests that verify interaction of exceptions and death tests. + +#include "gtest/gtest-death-test.h" +#include "gtest/gtest.h" + +#if GTEST_HAS_DEATH_TEST + +# if GTEST_HAS_SEH +# include // For RaiseException(). +# endif + +# include "gtest/gtest-spi.h" + +# if GTEST_HAS_EXCEPTIONS + +# include // For std::exception. + +// Tests that death tests report thrown exceptions as failures and that the +// exceptions do not escape death test macros. +TEST(CxxExceptionDeathTest, ExceptionIsFailure) { + try { + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(throw 1, ""), "threw an exception"); + } catch (...) { // NOLINT + FAIL() << "An exception escaped a death test macro invocation " + << "with catch_exceptions " + << (GTEST_FLAG_GET(catch_exceptions) ? "enabled" : "disabled"); + } +} + +class TestException : public std::exception { + public: + const char* what() const noexcept override { return "exceptional message"; } +}; + +TEST(CxxExceptionDeathTest, PrintsMessageForStdExceptions) { + // Verifies that the exception message is quoted in the failure text. + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(throw TestException(), ""), + "exceptional message"); + // Verifies that the location is mentioned in the failure text. + EXPECT_NONFATAL_FAILURE(EXPECT_DEATH(throw TestException(), ""), + __FILE__); +} +# endif // GTEST_HAS_EXCEPTIONS + +# if GTEST_HAS_SEH +// Tests that enabling interception of SEH exceptions with the +// catch_exceptions flag does not interfere with SEH exceptions being +// treated as death by death tests. +TEST(SehExceptionDeasTest, CatchExceptionsDoesNotInterfere) { + EXPECT_DEATH(RaiseException(42, 0x0, 0, NULL), "") + << "with catch_exceptions " + << (GTEST_FLAG_GET(catch_exceptions) ? "enabled" : "disabled"); +} +# endif + +#endif // GTEST_HAS_DEATH_TEST + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + GTEST_FLAG_SET(catch_exceptions, GTEST_ENABLE_CATCH_EXCEPTIONS_ != 0); + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test.py new file mode 100755 index 000000000000..bc4d87d93841 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# +# Copyright 2008, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that Google Test correctly parses environment variables.""" + +import os +from googletest.test import gtest_test_utils + + +IS_WINDOWS = os.name == 'nt' +IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' + +COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-env-var-test_') + +environ = os.environ.copy() + + +def AssertEq(expected, actual): + if expected != actual: + print('Expected: %s' % (expected,)) + print(' Actual: %s' % (actual,)) + raise AssertionError + + +def SetEnvVar(env_var, value): + """Sets the env variable to 'value'; unsets it when 'value' is None.""" + + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] + + +def GetFlag(flag): + """Runs googletest-env-var-test_ and returns its output.""" + + args = [COMMAND] + if flag is not None: + args += [flag] + return gtest_test_utils.Subprocess(args, env=environ).output + + +def TestFlag(flag, test_val, default_val): + """Verifies that the given flag is affected by the corresponding env var.""" + + env_var = 'GTEST_' + flag.upper() + SetEnvVar(env_var, test_val) + AssertEq(test_val, GetFlag(flag)) + SetEnvVar(env_var, None) + AssertEq(default_val, GetFlag(flag)) + + +class GTestEnvVarTest(gtest_test_utils.TestCase): + + def testEnvVarAffectsFlag(self): + """Tests that environment variable should affect the corresponding flag.""" + + TestFlag('break_on_failure', '1', '0') + TestFlag('color', 'yes', 'auto') + SetEnvVar('TESTBRIDGE_TEST_RUNNER_FAIL_FAST', None) # For 'fail_fast' test + TestFlag('fail_fast', '1', '0') + TestFlag('filter', 'FooTest.Bar', '*') + SetEnvVar('XML_OUTPUT_FILE', None) # For 'output' test + TestFlag('output', 'xml:tmp/foo.xml', '') + TestFlag('brief', '1', '0') + TestFlag('print_time', '0', '1') + TestFlag('repeat', '999', '1') + TestFlag('throw_on_failure', '1', '0') + TestFlag('death_test_style', 'threadsafe', 'fast') + TestFlag('catch_exceptions', '0', '1') + + if IS_LINUX: + TestFlag('death_test_use_fork', '1', '0') + TestFlag('stack_trace_depth', '0', '100') + + + def testXmlOutputFile(self): + """Tests that $XML_OUTPUT_FILE affects the output flag.""" + + SetEnvVar('GTEST_OUTPUT', None) + SetEnvVar('XML_OUTPUT_FILE', 'tmp/bar.xml') + AssertEq('xml:tmp/bar.xml', GetFlag('output')) + + def testXmlOutputFileOverride(self): + """Tests that $XML_OUTPUT_FILE is overridden by $GTEST_OUTPUT.""" + + SetEnvVar('GTEST_OUTPUT', 'xml:tmp/foo.xml') + SetEnvVar('XML_OUTPUT_FILE', 'tmp/bar.xml') + AssertEq('xml:tmp/foo.xml', GetFlag('output')) + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test_.cc new file mode 100644 index 000000000000..0ff015228f1e --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-env-var-test_.cc @@ -0,0 +1,132 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// A helper program for testing that Google Test parses the environment +// variables correctly. + +#include + +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +using ::std::cout; + +namespace testing { + +// The purpose of this is to make the test more realistic by ensuring +// that the UnitTest singleton is created before main() is entered. +// We don't actual run the TEST itself. +TEST(GTestEnvVarTest, Dummy) { +} + +void PrintFlag(const char* flag) { + if (strcmp(flag, "break_on_failure") == 0) { + cout << GTEST_FLAG_GET(break_on_failure); + return; + } + + if (strcmp(flag, "catch_exceptions") == 0) { + cout << GTEST_FLAG_GET(catch_exceptions); + return; + } + + if (strcmp(flag, "color") == 0) { + cout << GTEST_FLAG_GET(color); + return; + } + + if (strcmp(flag, "death_test_style") == 0) { + cout << GTEST_FLAG_GET(death_test_style); + return; + } + + if (strcmp(flag, "death_test_use_fork") == 0) { + cout << GTEST_FLAG_GET(death_test_use_fork); + return; + } + + if (strcmp(flag, "fail_fast") == 0) { + cout << GTEST_FLAG_GET(fail_fast); + return; + } + + if (strcmp(flag, "filter") == 0) { + cout << GTEST_FLAG_GET(filter); + return; + } + + if (strcmp(flag, "output") == 0) { + cout << GTEST_FLAG_GET(output); + return; + } + + if (strcmp(flag, "brief") == 0) { + cout << GTEST_FLAG_GET(brief); + return; + } + + if (strcmp(flag, "print_time") == 0) { + cout << GTEST_FLAG_GET(print_time); + return; + } + + if (strcmp(flag, "repeat") == 0) { + cout << GTEST_FLAG_GET(repeat); + return; + } + + if (strcmp(flag, "stack_trace_depth") == 0) { + cout << GTEST_FLAG_GET(stack_trace_depth); + return; + } + + if (strcmp(flag, "throw_on_failure") == 0) { + cout << GTEST_FLAG_GET(throw_on_failure); + return; + } + + cout << "Invalid flag name " << flag + << ". Valid names are break_on_failure, color, filter, etc.\n"; + exit(1); +} + +} // namespace testing + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + if (argc != 2) { + cout << "Usage: googletest-env-var-test_ NAME_OF_FLAG\n"; + return 1; + } + + testing::PrintFlag(argv[1]); + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest.py new file mode 100755 index 000000000000..1356d4f8b5b7 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python +# +# Copyright 2020 Google Inc. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for Google Test fail_fast. + +A user can specify if a Google Test program should continue test execution +after a test failure via the GTEST_FAIL_FAST environment variable or the +--gtest_fail_fast flag. The default value of the flag can also be changed +by Bazel fail fast environment variable TESTBRIDGE_TEST_RUNNER_FAIL_FAST. + +This script tests such functionality by invoking googletest-failfast-unittest_ +(a program written with Google Test) with different environments and command +line flags. +""" + +import os +from googletest.test import gtest_test_utils + +# Constants. + +# Bazel testbridge environment variable for fail fast +BAZEL_FAIL_FAST_ENV_VAR = 'TESTBRIDGE_TEST_RUNNER_FAIL_FAST' + +# The environment variable for specifying fail fast. +FAIL_FAST_ENV_VAR = 'GTEST_FAIL_FAST' + +# The command line flag for specifying fail fast. +FAIL_FAST_FLAG = 'gtest_fail_fast' + +# The command line flag to run disabled tests. +RUN_DISABLED_FLAG = 'gtest_also_run_disabled_tests' + +# The command line flag for specifying a filter. +FILTER_FLAG = 'gtest_filter' + +# Command to run the googletest-failfast-unittest_ program. +COMMAND = gtest_test_utils.GetTestExecutablePath( + 'googletest-failfast-unittest_') + +# The command line flag to tell Google Test to output the list of tests it +# will run. +LIST_TESTS_FLAG = '--gtest_list_tests' + +# Indicates whether Google Test supports death tests. +SUPPORTS_DEATH_TESTS = 'HasDeathTest' in gtest_test_utils.Subprocess( + [COMMAND, LIST_TESTS_FLAG]).output + +# Utilities. + +environ = os.environ.copy() + + +def SetEnvVar(env_var, value): + """Sets the env variable to 'value'; unsets it when 'value' is None.""" + + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] + + +def RunAndReturnOutput(test_suite=None, fail_fast=None, run_disabled=False): + """Runs the test program and returns its output.""" + + args = [] + xml_path = os.path.join(gtest_test_utils.GetTempDir(), + '.GTestFailFastUnitTest.xml') + args += ['--gtest_output=xml:' + xml_path] + if fail_fast is not None: + if isinstance(fail_fast, str): + args += ['--%s=%s' % (FAIL_FAST_FLAG, fail_fast)] + elif fail_fast: + args += ['--%s' % FAIL_FAST_FLAG] + else: + args += ['--no%s' % FAIL_FAST_FLAG] + if test_suite: + args += ['--%s=%s.*' % (FILTER_FLAG, test_suite)] + if run_disabled: + args += ['--%s' % RUN_DISABLED_FLAG] + txt_out = gtest_test_utils.Subprocess([COMMAND] + args, env=environ).output + with open(xml_path) as xml_file: + return txt_out, xml_file.read() + + +# The unit test. +class GTestFailFastUnitTest(gtest_test_utils.TestCase): + """Tests the env variable or the command line flag for fail_fast.""" + + def testDefaultBehavior(self): + """Tests the behavior of not specifying the fail_fast.""" + + txt, _ = RunAndReturnOutput() + self.assertIn('22 FAILED TEST', txt) + + def testGoogletestFlag(self): + txt, _ = RunAndReturnOutput(test_suite='HasSimpleTest', fail_fast=True) + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + + txt, _ = RunAndReturnOutput(test_suite='HasSimpleTest', fail_fast=False) + self.assertIn('4 FAILED TEST', txt) + self.assertNotIn('[ SKIPPED ]', txt) + + def testGoogletestEnvVar(self): + """Tests the behavior of specifying fail_fast via Googletest env var.""" + + try: + SetEnvVar(FAIL_FAST_ENV_VAR, '1') + txt, _ = RunAndReturnOutput('HasSimpleTest') + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + + SetEnvVar(FAIL_FAST_ENV_VAR, '0') + txt, _ = RunAndReturnOutput('HasSimpleTest') + self.assertIn('4 FAILED TEST', txt) + self.assertNotIn('[ SKIPPED ]', txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + + def testBazelEnvVar(self): + """Tests the behavior of specifying fail_fast via Bazel testbridge.""" + + try: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '1') + txt, _ = RunAndReturnOutput('HasSimpleTest') + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '0') + txt, _ = RunAndReturnOutput('HasSimpleTest') + self.assertIn('4 FAILED TEST', txt) + self.assertNotIn('[ SKIPPED ]', txt) + finally: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) + + def testFlagOverridesEnvVar(self): + """Tests precedence of flag over env var.""" + + try: + SetEnvVar(FAIL_FAST_ENV_VAR, '0') + txt, _ = RunAndReturnOutput('HasSimpleTest', True) + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + + def testGoogletestEnvVarOverridesBazelEnvVar(self): + """Tests that the Googletest native env var over Bazel testbridge.""" + + try: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '0') + SetEnvVar(FAIL_FAST_ENV_VAR, '1') + txt, _ = RunAndReturnOutput('HasSimpleTest') + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) + + def testEventListener(self): + txt, _ = RunAndReturnOutput(test_suite='HasSkipTest', fail_fast=True) + self.assertIn('1 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 3 tests', txt) + for expected_count, callback in [(1, 'OnTestSuiteStart'), + (5, 'OnTestStart'), + (5, 'OnTestEnd'), + (5, 'OnTestPartResult'), + (1, 'OnTestSuiteEnd')]: + self.assertEqual( + expected_count, txt.count(callback), + 'Expected %d calls to callback %s match count on output: %s ' % + (expected_count, callback, txt)) + + txt, _ = RunAndReturnOutput(test_suite='HasSkipTest', fail_fast=False) + self.assertIn('3 FAILED TEST', txt) + self.assertIn('[ SKIPPED ] 1 test', txt) + for expected_count, callback in [(1, 'OnTestSuiteStart'), + (5, 'OnTestStart'), + (5, 'OnTestEnd'), + (5, 'OnTestPartResult'), + (1, 'OnTestSuiteEnd')]: + self.assertEqual( + expected_count, txt.count(callback), + 'Expected %d calls to callback %s match count on output: %s ' % + (expected_count, callback, txt)) + + def assertXmlResultCount(self, result, count, xml): + self.assertEqual( + count, xml.count('result="%s"' % result), + 'Expected \'result="%s"\' match count of %s: %s ' % + (result, count, xml)) + + def assertXmlStatusCount(self, status, count, xml): + self.assertEqual( + count, xml.count('status="%s"' % status), + 'Expected \'status="%s"\' match count of %s: %s ' % + (status, count, xml)) + + def assertFailFastXmlAndTxtOutput(self, + fail_fast, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, + run_disabled=False): + """Assert XML and text output of a test execution.""" + + txt, xml = RunAndReturnOutput(test_suite, fail_fast, run_disabled) + if failure_count > 0: + self.assertIn('%s FAILED TEST' % failure_count, txt) + if suppressed_count > 0: + self.assertIn('%s DISABLED TEST' % suppressed_count, txt) + if skipped_count > 0: + self.assertIn('[ SKIPPED ] %s tests' % skipped_count, txt) + self.assertXmlStatusCount('run', + passed_count + failure_count + skipped_count, xml) + self.assertXmlStatusCount('notrun', suppressed_count, xml) + self.assertXmlResultCount('completed', passed_count + failure_count, xml) + self.assertXmlResultCount('skipped', skipped_count, xml) + self.assertXmlResultCount('suppressed', suppressed_count, xml) + + def assertFailFastBehavior(self, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, + run_disabled=False): + """Assert --fail_fast via flag.""" + + for fail_fast in ('true', '1', 't', True): + self.assertFailFastXmlAndTxtOutput(fail_fast, test_suite, passed_count, + failure_count, skipped_count, + suppressed_count, run_disabled) + + def assertNotFailFastBehavior(self, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, + run_disabled=False): + """Assert --nofail_fast via flag.""" + + for fail_fast in ('false', '0', 'f', False): + self.assertFailFastXmlAndTxtOutput(fail_fast, test_suite, passed_count, + failure_count, skipped_count, + suppressed_count, run_disabled) + + def testFlag_HasFixtureTest(self): + """Tests the behavior of fail_fast and TEST_F.""" + self.assertFailFastBehavior( + test_suite='HasFixtureTest', + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0) + self.assertNotFailFastBehavior( + test_suite='HasFixtureTest', + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0) + + def testFlag_HasSimpleTest(self): + """Tests the behavior of fail_fast and TEST.""" + self.assertFailFastBehavior( + test_suite='HasSimpleTest', + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0) + self.assertNotFailFastBehavior( + test_suite='HasSimpleTest', + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0) + + def testFlag_HasParametersTest(self): + """Tests the behavior of fail_fast and TEST_P.""" + self.assertFailFastBehavior( + test_suite='HasParametersSuite/HasParametersTest', + passed_count=0, + failure_count=1, + skipped_count=3, + suppressed_count=0) + self.assertNotFailFastBehavior( + test_suite='HasParametersSuite/HasParametersTest', + passed_count=0, + failure_count=4, + skipped_count=0, + suppressed_count=0) + + def testFlag_HasDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test cases.""" + self.assertFailFastBehavior( + test_suite='HasDisabledTest', + passed_count=1, + failure_count=1, + skipped_count=2, + suppressed_count=1, + run_disabled=False) + self.assertNotFailFastBehavior( + test_suite='HasDisabledTest', + passed_count=1, + failure_count=3, + skipped_count=0, + suppressed_count=1, + run_disabled=False) + + def testFlag_HasDisabledRunDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test cases enabled.""" + self.assertFailFastBehavior( + test_suite='HasDisabledTest', + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + run_disabled=True) + self.assertNotFailFastBehavior( + test_suite='HasDisabledTest', + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + run_disabled=True) + + def testFlag_HasDisabledSuiteTest(self): + """Tests the behavior of fail_fast and Disabled test suites.""" + self.assertFailFastBehavior( + test_suite='DISABLED_HasDisabledSuite', + passed_count=0, + failure_count=0, + skipped_count=0, + suppressed_count=5, + run_disabled=False) + self.assertNotFailFastBehavior( + test_suite='DISABLED_HasDisabledSuite', + passed_count=0, + failure_count=0, + skipped_count=0, + suppressed_count=5, + run_disabled=False) + + def testFlag_HasDisabledSuiteRunDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test suites enabled.""" + self.assertFailFastBehavior( + test_suite='DISABLED_HasDisabledSuite', + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + run_disabled=True) + self.assertNotFailFastBehavior( + test_suite='DISABLED_HasDisabledSuite', + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + run_disabled=True) + + if SUPPORTS_DEATH_TESTS: + + def testFlag_HasDeathTest(self): + """Tests the behavior of fail_fast and death tests.""" + self.assertFailFastBehavior( + test_suite='HasDeathTest', + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0) + self.assertNotFailFastBehavior( + test_suite='HasDeathTest', + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest_.cc new file mode 100644 index 000000000000..0b2c951bc008 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-failfast-unittest_.cc @@ -0,0 +1,167 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Unit test for Google Test test filters. +// +// A user can specify which test(s) in a Google Test program to run via +// either the GTEST_FILTER environment variable or the --gtest_filter +// flag. This is used for testing such functionality. +// +// The program will be invoked from a Python unit test. Don't run it +// directly. + +#include "gtest/gtest.h" + +namespace { + +// Test HasFixtureTest. + +class HasFixtureTest : public testing::Test {}; + +TEST_F(HasFixtureTest, Test0) {} + +TEST_F(HasFixtureTest, Test1) { FAIL() << "Expected failure."; } + +TEST_F(HasFixtureTest, Test2) { FAIL() << "Expected failure."; } + +TEST_F(HasFixtureTest, Test3) { FAIL() << "Expected failure."; } + +TEST_F(HasFixtureTest, Test4) { FAIL() << "Expected failure."; } + +// Test HasSimpleTest. + +TEST(HasSimpleTest, Test0) {} + +TEST(HasSimpleTest, Test1) { FAIL() << "Expected failure."; } + +TEST(HasSimpleTest, Test2) { FAIL() << "Expected failure."; } + +TEST(HasSimpleTest, Test3) { FAIL() << "Expected failure."; } + +TEST(HasSimpleTest, Test4) { FAIL() << "Expected failure."; } + +// Test HasDisabledTest. + +TEST(HasDisabledTest, Test0) {} + +TEST(HasDisabledTest, DISABLED_Test1) { FAIL() << "Expected failure."; } + +TEST(HasDisabledTest, Test2) { FAIL() << "Expected failure."; } + +TEST(HasDisabledTest, Test3) { FAIL() << "Expected failure."; } + +TEST(HasDisabledTest, Test4) { FAIL() << "Expected failure."; } + +// Test HasDeathTest + +TEST(HasDeathTest, Test0) { EXPECT_DEATH_IF_SUPPORTED(exit(1), ".*"); } + +TEST(HasDeathTest, Test1) { + EXPECT_DEATH_IF_SUPPORTED(FAIL() << "Expected failure.", ".*"); +} + +TEST(HasDeathTest, Test2) { + EXPECT_DEATH_IF_SUPPORTED(FAIL() << "Expected failure.", ".*"); +} + +TEST(HasDeathTest, Test3) { + EXPECT_DEATH_IF_SUPPORTED(FAIL() << "Expected failure.", ".*"); +} + +TEST(HasDeathTest, Test4) { + EXPECT_DEATH_IF_SUPPORTED(FAIL() << "Expected failure.", ".*"); +} + +// Test DISABLED_HasDisabledSuite + +TEST(DISABLED_HasDisabledSuite, Test0) {} + +TEST(DISABLED_HasDisabledSuite, Test1) { FAIL() << "Expected failure."; } + +TEST(DISABLED_HasDisabledSuite, Test2) { FAIL() << "Expected failure."; } + +TEST(DISABLED_HasDisabledSuite, Test3) { FAIL() << "Expected failure."; } + +TEST(DISABLED_HasDisabledSuite, Test4) { FAIL() << "Expected failure."; } + +// Test HasParametersTest + +class HasParametersTest : public testing::TestWithParam {}; + +TEST_P(HasParametersTest, Test1) { FAIL() << "Expected failure."; } + +TEST_P(HasParametersTest, Test2) { FAIL() << "Expected failure."; } + +INSTANTIATE_TEST_SUITE_P(HasParametersSuite, HasParametersTest, + testing::Values(1, 2)); + +class MyTestListener : public ::testing::EmptyTestEventListener { + void OnTestSuiteStart(const ::testing::TestSuite& test_suite) override { + printf("We are in OnTestSuiteStart of %s.\n", test_suite.name()); + } + + void OnTestStart(const ::testing::TestInfo& test_info) override { + printf("We are in OnTestStart of %s.%s.\n", test_info.test_suite_name(), + test_info.name()); + } + + void OnTestPartResult( + const ::testing::TestPartResult& test_part_result) override { + printf("We are in OnTestPartResult %s:%d.\n", test_part_result.file_name(), + test_part_result.line_number()); + } + + void OnTestEnd(const ::testing::TestInfo& test_info) override { + printf("We are in OnTestEnd of %s.%s.\n", test_info.test_suite_name(), + test_info.name()); + } + + void OnTestSuiteEnd(const ::testing::TestSuite& test_suite) override { + printf("We are in OnTestSuiteEnd of %s.\n", test_suite.name()); + } +}; + +TEST(HasSkipTest, Test0) { SUCCEED() << "Expected success."; } + +TEST(HasSkipTest, Test1) { GTEST_SKIP() << "Expected skip."; } + +TEST(HasSkipTest, Test2) { FAIL() << "Expected failure."; } + +TEST(HasSkipTest, Test3) { FAIL() << "Expected failure."; } + +TEST(HasSkipTest, Test4) { FAIL() << "Expected failure."; } + +} // namespace + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + ::testing::UnitTest::GetInstance()->listeners().Append(new MyTestListener()); + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filepath-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filepath-test.cc new file mode 100644 index 000000000000..aafad36f3fef --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filepath-test.cc @@ -0,0 +1,649 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Google Test filepath utilities +// +// This file tests classes and functions used internally by +// Google Test. They are subject to change without notice. +// +// This file is #included from gtest-internal.h. +// Do not #include this file anywhere else! + +#include "gtest/internal/gtest-filepath.h" +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +#if GTEST_OS_WINDOWS_MOBILE +# include // NOLINT +#elif GTEST_OS_WINDOWS +# include // NOLINT +#endif // GTEST_OS_WINDOWS_MOBILE + +namespace testing { +namespace internal { +namespace { + +#if GTEST_OS_WINDOWS_MOBILE + +// Windows CE doesn't have the remove C function. +int remove(const char* path) { + LPCWSTR wpath = String::AnsiToUtf16(path); + int ret = DeleteFile(wpath) ? 0 : -1; + delete [] wpath; + return ret; +} +// Windows CE doesn't have the _rmdir C function. +int _rmdir(const char* path) { + FilePath filepath(path); + LPCWSTR wpath = String::AnsiToUtf16( + filepath.RemoveTrailingPathSeparator().c_str()); + int ret = RemoveDirectory(wpath) ? 0 : -1; + delete [] wpath; + return ret; +} + +#else + +TEST(GetCurrentDirTest, ReturnsCurrentDir) { + const FilePath original_dir = FilePath::GetCurrentDir(); + EXPECT_FALSE(original_dir.IsEmpty()); + + posix::ChDir(GTEST_PATH_SEP_); + const FilePath cwd = FilePath::GetCurrentDir(); + posix::ChDir(original_dir.c_str()); + +# if GTEST_OS_WINDOWS || GTEST_OS_OS2 + + // Skips the ":". + const char* const cwd_without_drive = strchr(cwd.c_str(), ':'); + ASSERT_TRUE(cwd_without_drive != NULL); + EXPECT_STREQ(GTEST_PATH_SEP_, cwd_without_drive + 1); + +# else + + EXPECT_EQ(GTEST_PATH_SEP_, cwd.string()); + +# endif +} + +#endif // GTEST_OS_WINDOWS_MOBILE + +TEST(IsEmptyTest, ReturnsTrueForEmptyPath) { + EXPECT_TRUE(FilePath("").IsEmpty()); +} + +TEST(IsEmptyTest, ReturnsFalseForNonEmptyPath) { + EXPECT_FALSE(FilePath("a").IsEmpty()); + EXPECT_FALSE(FilePath(".").IsEmpty()); + EXPECT_FALSE(FilePath("a/b").IsEmpty()); + EXPECT_FALSE(FilePath("a\\b\\").IsEmpty()); +} + +// RemoveDirectoryName "" -> "" +TEST(RemoveDirectoryNameTest, WhenEmptyName) { + EXPECT_EQ("", FilePath("").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName "afile" -> "afile" +TEST(RemoveDirectoryNameTest, ButNoDirectory) { + EXPECT_EQ("afile", + FilePath("afile").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName "/afile" -> "afile" +TEST(RemoveDirectoryNameTest, RootFileShouldGiveFileName) { + EXPECT_EQ("afile", + FilePath(GTEST_PATH_SEP_ "afile").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName "adir/" -> "" +TEST(RemoveDirectoryNameTest, WhereThereIsNoFileName) { + EXPECT_EQ("", + FilePath("adir" GTEST_PATH_SEP_).RemoveDirectoryName().string()); +} + +// RemoveDirectoryName "adir/afile" -> "afile" +TEST(RemoveDirectoryNameTest, ShouldGiveFileName) { + EXPECT_EQ("afile", + FilePath("adir" GTEST_PATH_SEP_ "afile").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName "adir/subdir/afile" -> "afile" +TEST(RemoveDirectoryNameTest, ShouldAlsoGiveFileName) { + EXPECT_EQ("afile", + FilePath("adir" GTEST_PATH_SEP_ "subdir" GTEST_PATH_SEP_ "afile") + .RemoveDirectoryName().string()); +} + +#if GTEST_HAS_ALT_PATH_SEP_ + +// Tests that RemoveDirectoryName() works with the alternate separator +// on Windows. + +// RemoveDirectoryName("/afile") -> "afile" +TEST(RemoveDirectoryNameTest, RootFileShouldGiveFileNameForAlternateSeparator) { + EXPECT_EQ("afile", FilePath("/afile").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName("adir/") -> "" +TEST(RemoveDirectoryNameTest, WhereThereIsNoFileNameForAlternateSeparator) { + EXPECT_EQ("", FilePath("adir/").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName("adir/afile") -> "afile" +TEST(RemoveDirectoryNameTest, ShouldGiveFileNameForAlternateSeparator) { + EXPECT_EQ("afile", FilePath("adir/afile").RemoveDirectoryName().string()); +} + +// RemoveDirectoryName("adir/subdir/afile") -> "afile" +TEST(RemoveDirectoryNameTest, ShouldAlsoGiveFileNameForAlternateSeparator) { + EXPECT_EQ("afile", + FilePath("adir/subdir/afile").RemoveDirectoryName().string()); +} + +#endif + +// RemoveFileName "" -> "./" +TEST(RemoveFileNameTest, EmptyName) { +#if GTEST_OS_WINDOWS_MOBILE + // On Windows CE, we use the root as the current directory. + EXPECT_EQ(GTEST_PATH_SEP_, FilePath("").RemoveFileName().string()); +#else + EXPECT_EQ("." GTEST_PATH_SEP_, FilePath("").RemoveFileName().string()); +#endif +} + +// RemoveFileName "adir/" -> "adir/" +TEST(RemoveFileNameTest, ButNoFile) { + EXPECT_EQ("adir" GTEST_PATH_SEP_, + FilePath("adir" GTEST_PATH_SEP_).RemoveFileName().string()); +} + +// RemoveFileName "adir/afile" -> "adir/" +TEST(RemoveFileNameTest, GivesDirName) { + EXPECT_EQ("adir" GTEST_PATH_SEP_, + FilePath("adir" GTEST_PATH_SEP_ "afile").RemoveFileName().string()); +} + +// RemoveFileName "adir/subdir/afile" -> "adir/subdir/" +TEST(RemoveFileNameTest, GivesDirAndSubDirName) { + EXPECT_EQ("adir" GTEST_PATH_SEP_ "subdir" GTEST_PATH_SEP_, + FilePath("adir" GTEST_PATH_SEP_ "subdir" GTEST_PATH_SEP_ "afile") + .RemoveFileName().string()); +} + +// RemoveFileName "/afile" -> "/" +TEST(RemoveFileNameTest, GivesRootDir) { + EXPECT_EQ(GTEST_PATH_SEP_, + FilePath(GTEST_PATH_SEP_ "afile").RemoveFileName().string()); +} + +#if GTEST_HAS_ALT_PATH_SEP_ + +// Tests that RemoveFileName() works with the alternate separator on +// Windows. + +// RemoveFileName("adir/") -> "adir/" +TEST(RemoveFileNameTest, ButNoFileForAlternateSeparator) { + EXPECT_EQ("adir" GTEST_PATH_SEP_, + FilePath("adir/").RemoveFileName().string()); +} + +// RemoveFileName("adir/afile") -> "adir/" +TEST(RemoveFileNameTest, GivesDirNameForAlternateSeparator) { + EXPECT_EQ("adir" GTEST_PATH_SEP_, + FilePath("adir/afile").RemoveFileName().string()); +} + +// RemoveFileName("adir/subdir/afile") -> "adir/subdir/" +TEST(RemoveFileNameTest, GivesDirAndSubDirNameForAlternateSeparator) { + EXPECT_EQ("adir" GTEST_PATH_SEP_ "subdir" GTEST_PATH_SEP_, + FilePath("adir/subdir/afile").RemoveFileName().string()); +} + +// RemoveFileName("/afile") -> "\" +TEST(RemoveFileNameTest, GivesRootDirForAlternateSeparator) { + EXPECT_EQ(GTEST_PATH_SEP_, FilePath("/afile").RemoveFileName().string()); +} + +#endif + +TEST(MakeFileNameTest, GenerateWhenNumberIsZero) { + FilePath actual = FilePath::MakeFileName(FilePath("foo"), FilePath("bar"), + 0, "xml"); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar.xml", actual.string()); +} + +TEST(MakeFileNameTest, GenerateFileNameNumberGtZero) { + FilePath actual = FilePath::MakeFileName(FilePath("foo"), FilePath("bar"), + 12, "xml"); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar_12.xml", actual.string()); +} + +TEST(MakeFileNameTest, GenerateFileNameWithSlashNumberIsZero) { + FilePath actual = FilePath::MakeFileName(FilePath("foo" GTEST_PATH_SEP_), + FilePath("bar"), 0, "xml"); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar.xml", actual.string()); +} + +TEST(MakeFileNameTest, GenerateFileNameWithSlashNumberGtZero) { + FilePath actual = FilePath::MakeFileName(FilePath("foo" GTEST_PATH_SEP_), + FilePath("bar"), 12, "xml"); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar_12.xml", actual.string()); +} + +TEST(MakeFileNameTest, GenerateWhenNumberIsZeroAndDirIsEmpty) { + FilePath actual = FilePath::MakeFileName(FilePath(""), FilePath("bar"), + 0, "xml"); + EXPECT_EQ("bar.xml", actual.string()); +} + +TEST(MakeFileNameTest, GenerateWhenNumberIsNotZeroAndDirIsEmpty) { + FilePath actual = FilePath::MakeFileName(FilePath(""), FilePath("bar"), + 14, "xml"); + EXPECT_EQ("bar_14.xml", actual.string()); +} + +TEST(ConcatPathsTest, WorksWhenDirDoesNotEndWithPathSep) { + FilePath actual = FilePath::ConcatPaths(FilePath("foo"), + FilePath("bar.xml")); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar.xml", actual.string()); +} + +TEST(ConcatPathsTest, WorksWhenPath1EndsWithPathSep) { + FilePath actual = FilePath::ConcatPaths(FilePath("foo" GTEST_PATH_SEP_), + FilePath("bar.xml")); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar.xml", actual.string()); +} + +TEST(ConcatPathsTest, Path1BeingEmpty) { + FilePath actual = FilePath::ConcatPaths(FilePath(""), + FilePath("bar.xml")); + EXPECT_EQ("bar.xml", actual.string()); +} + +TEST(ConcatPathsTest, Path2BeingEmpty) { + FilePath actual = FilePath::ConcatPaths(FilePath("foo"), FilePath("")); + EXPECT_EQ("foo" GTEST_PATH_SEP_, actual.string()); +} + +TEST(ConcatPathsTest, BothPathBeingEmpty) { + FilePath actual = FilePath::ConcatPaths(FilePath(""), + FilePath("")); + EXPECT_EQ("", actual.string()); +} + +TEST(ConcatPathsTest, Path1ContainsPathSep) { + FilePath actual = FilePath::ConcatPaths(FilePath("foo" GTEST_PATH_SEP_ "bar"), + FilePath("foobar.xml")); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar" GTEST_PATH_SEP_ "foobar.xml", + actual.string()); +} + +TEST(ConcatPathsTest, Path2ContainsPathSep) { + FilePath actual = FilePath::ConcatPaths( + FilePath("foo" GTEST_PATH_SEP_), + FilePath("bar" GTEST_PATH_SEP_ "bar.xml")); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar" GTEST_PATH_SEP_ "bar.xml", + actual.string()); +} + +TEST(ConcatPathsTest, Path2EndsWithPathSep) { + FilePath actual = FilePath::ConcatPaths(FilePath("foo"), + FilePath("bar" GTEST_PATH_SEP_)); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar" GTEST_PATH_SEP_, actual.string()); +} + +// RemoveTrailingPathSeparator "" -> "" +TEST(RemoveTrailingPathSeparatorTest, EmptyString) { + EXPECT_EQ("", FilePath("").RemoveTrailingPathSeparator().string()); +} + +// RemoveTrailingPathSeparator "foo" -> "foo" +TEST(RemoveTrailingPathSeparatorTest, FileNoSlashString) { + EXPECT_EQ("foo", FilePath("foo").RemoveTrailingPathSeparator().string()); +} + +// RemoveTrailingPathSeparator "foo/" -> "foo" +TEST(RemoveTrailingPathSeparatorTest, ShouldRemoveTrailingSeparator) { + EXPECT_EQ("foo", + FilePath("foo" GTEST_PATH_SEP_).RemoveTrailingPathSeparator().string()); +#if GTEST_HAS_ALT_PATH_SEP_ + EXPECT_EQ("foo", FilePath("foo/").RemoveTrailingPathSeparator().string()); +#endif +} + +// RemoveTrailingPathSeparator "foo/bar/" -> "foo/bar/" +TEST(RemoveTrailingPathSeparatorTest, ShouldRemoveLastSeparator) { + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar", + FilePath("foo" GTEST_PATH_SEP_ "bar" GTEST_PATH_SEP_) + .RemoveTrailingPathSeparator().string()); +} + +// RemoveTrailingPathSeparator "foo/bar" -> "foo/bar" +TEST(RemoveTrailingPathSeparatorTest, ShouldReturnUnmodified) { + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar", + FilePath("foo" GTEST_PATH_SEP_ "bar") + .RemoveTrailingPathSeparator().string()); +} + +TEST(DirectoryTest, RootDirectoryExists) { +#if GTEST_OS_WINDOWS // We are on Windows. + char current_drive[_MAX_PATH]; // NOLINT + current_drive[0] = static_cast(_getdrive() + 'A' - 1); + current_drive[1] = ':'; + current_drive[2] = '\\'; + current_drive[3] = '\0'; + EXPECT_TRUE(FilePath(current_drive).DirectoryExists()); +#else + EXPECT_TRUE(FilePath("/").DirectoryExists()); +#endif // GTEST_OS_WINDOWS +} + +#if GTEST_OS_WINDOWS +TEST(DirectoryTest, RootOfWrongDriveDoesNotExists) { + const int saved_drive_ = _getdrive(); + // Find a drive that doesn't exist. Start with 'Z' to avoid common ones. + for (char drive = 'Z'; drive >= 'A'; drive--) + if (_chdrive(drive - 'A' + 1) == -1) { + char non_drive[_MAX_PATH]; // NOLINT + non_drive[0] = drive; + non_drive[1] = ':'; + non_drive[2] = '\\'; + non_drive[3] = '\0'; + EXPECT_FALSE(FilePath(non_drive).DirectoryExists()); + break; + } + _chdrive(saved_drive_); +} +#endif // GTEST_OS_WINDOWS + +#if !GTEST_OS_WINDOWS_MOBILE +// Windows CE _does_ consider an empty directory to exist. +TEST(DirectoryTest, EmptyPathDirectoryDoesNotExist) { + EXPECT_FALSE(FilePath("").DirectoryExists()); +} +#endif // !GTEST_OS_WINDOWS_MOBILE + +TEST(DirectoryTest, CurrentDirectoryExists) { +#if GTEST_OS_WINDOWS // We are on Windows. +# ifndef _WIN32_CE // Windows CE doesn't have a current directory. + + EXPECT_TRUE(FilePath(".").DirectoryExists()); + EXPECT_TRUE(FilePath(".\\").DirectoryExists()); + +# endif // _WIN32_CE +#else + EXPECT_TRUE(FilePath(".").DirectoryExists()); + EXPECT_TRUE(FilePath("./").DirectoryExists()); +#endif // GTEST_OS_WINDOWS +} + +// "foo/bar" == foo//bar" == "foo///bar" +TEST(NormalizeTest, MultipleConsecutiveSepaparatorsInMidstring) { + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar", + FilePath("foo" GTEST_PATH_SEP_ "bar").string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar", + FilePath("foo" GTEST_PATH_SEP_ GTEST_PATH_SEP_ "bar").string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_ "bar", + FilePath("foo" GTEST_PATH_SEP_ GTEST_PATH_SEP_ + GTEST_PATH_SEP_ "bar").string()); +} + +// "/bar" == //bar" == "///bar" +TEST(NormalizeTest, MultipleConsecutiveSepaparatorsAtStringStart) { + EXPECT_EQ(GTEST_PATH_SEP_ "bar", + FilePath(GTEST_PATH_SEP_ "bar").string()); + EXPECT_EQ(GTEST_PATH_SEP_ "bar", + FilePath(GTEST_PATH_SEP_ GTEST_PATH_SEP_ "bar").string()); + EXPECT_EQ(GTEST_PATH_SEP_ "bar", + FilePath(GTEST_PATH_SEP_ GTEST_PATH_SEP_ GTEST_PATH_SEP_ "bar").string()); +} + +// "foo/" == foo//" == "foo///" +TEST(NormalizeTest, MultipleConsecutiveSepaparatorsAtStringEnd) { + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo" GTEST_PATH_SEP_).string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo" GTEST_PATH_SEP_ GTEST_PATH_SEP_).string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo" GTEST_PATH_SEP_ GTEST_PATH_SEP_ GTEST_PATH_SEP_).string()); +} + +#if GTEST_HAS_ALT_PATH_SEP_ + +// Tests that separators at the end of the string are normalized +// regardless of their combination (e.g. "foo\" =="foo/\" == +// "foo\\/"). +TEST(NormalizeTest, MixAlternateSeparatorAtStringEnd) { + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo/").string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo" GTEST_PATH_SEP_ "/").string()); + EXPECT_EQ("foo" GTEST_PATH_SEP_, + FilePath("foo//" GTEST_PATH_SEP_).string()); +} + +#endif + +TEST(AssignmentOperatorTest, DefaultAssignedToNonDefault) { + FilePath default_path; + FilePath non_default_path("path"); + non_default_path = default_path; + EXPECT_EQ("", non_default_path.string()); + EXPECT_EQ("", default_path.string()); // RHS var is unchanged. +} + +TEST(AssignmentOperatorTest, NonDefaultAssignedToDefault) { + FilePath non_default_path("path"); + FilePath default_path; + default_path = non_default_path; + EXPECT_EQ("path", default_path.string()); + EXPECT_EQ("path", non_default_path.string()); // RHS var is unchanged. +} + +TEST(AssignmentOperatorTest, ConstAssignedToNonConst) { + const FilePath const_default_path("const_path"); + FilePath non_default_path("path"); + non_default_path = const_default_path; + EXPECT_EQ("const_path", non_default_path.string()); +} + +class DirectoryCreationTest : public Test { + protected: + void SetUp() override { + testdata_path_.Set(FilePath( + TempDir() + GetCurrentExecutableName().string() + + "_directory_creation" GTEST_PATH_SEP_ "test" GTEST_PATH_SEP_)); + testdata_file_.Set(testdata_path_.RemoveTrailingPathSeparator()); + + unique_file0_.Set(FilePath::MakeFileName(testdata_path_, FilePath("unique"), + 0, "txt")); + unique_file1_.Set(FilePath::MakeFileName(testdata_path_, FilePath("unique"), + 1, "txt")); + + remove(testdata_file_.c_str()); + remove(unique_file0_.c_str()); + remove(unique_file1_.c_str()); + posix::RmDir(testdata_path_.c_str()); + } + + void TearDown() override { + remove(testdata_file_.c_str()); + remove(unique_file0_.c_str()); + remove(unique_file1_.c_str()); + posix::RmDir(testdata_path_.c_str()); + } + + void CreateTextFile(const char* filename) { + FILE* f = posix::FOpen(filename, "w"); + fprintf(f, "text\n"); + fclose(f); + } + + // Strings representing a directory and a file, with identical paths + // except for the trailing separator character that distinquishes + // a directory named 'test' from a file named 'test'. Example names: + FilePath testdata_path_; // "/tmp/directory_creation/test/" + FilePath testdata_file_; // "/tmp/directory_creation/test" + FilePath unique_file0_; // "/tmp/directory_creation/test/unique.txt" + FilePath unique_file1_; // "/tmp/directory_creation/test/unique_1.txt" +}; + +TEST_F(DirectoryCreationTest, CreateDirectoriesRecursively) { + EXPECT_FALSE(testdata_path_.DirectoryExists()) << testdata_path_.string(); + EXPECT_TRUE(testdata_path_.CreateDirectoriesRecursively()); + EXPECT_TRUE(testdata_path_.DirectoryExists()); +} + +TEST_F(DirectoryCreationTest, CreateDirectoriesForAlreadyExistingPath) { + EXPECT_FALSE(testdata_path_.DirectoryExists()) << testdata_path_.string(); + EXPECT_TRUE(testdata_path_.CreateDirectoriesRecursively()); + // Call 'create' again... should still succeed. + EXPECT_TRUE(testdata_path_.CreateDirectoriesRecursively()); +} + +TEST_F(DirectoryCreationTest, CreateDirectoriesAndUniqueFilename) { + FilePath file_path(FilePath::GenerateUniqueFileName(testdata_path_, + FilePath("unique"), "txt")); + EXPECT_EQ(unique_file0_.string(), file_path.string()); + EXPECT_FALSE(file_path.FileOrDirectoryExists()); // file not there + + testdata_path_.CreateDirectoriesRecursively(); + EXPECT_FALSE(file_path.FileOrDirectoryExists()); // file still not there + CreateTextFile(file_path.c_str()); + EXPECT_TRUE(file_path.FileOrDirectoryExists()); + + FilePath file_path2(FilePath::GenerateUniqueFileName(testdata_path_, + FilePath("unique"), "txt")); + EXPECT_EQ(unique_file1_.string(), file_path2.string()); + EXPECT_FALSE(file_path2.FileOrDirectoryExists()); // file not there + CreateTextFile(file_path2.c_str()); + EXPECT_TRUE(file_path2.FileOrDirectoryExists()); +} + +TEST_F(DirectoryCreationTest, CreateDirectoriesFail) { + // force a failure by putting a file where we will try to create a directory. + CreateTextFile(testdata_file_.c_str()); + EXPECT_TRUE(testdata_file_.FileOrDirectoryExists()); + EXPECT_FALSE(testdata_file_.DirectoryExists()); + EXPECT_FALSE(testdata_file_.CreateDirectoriesRecursively()); +} + +TEST(NoDirectoryCreationTest, CreateNoDirectoriesForDefaultXmlFile) { + const FilePath test_detail_xml("test_detail.xml"); + EXPECT_FALSE(test_detail_xml.CreateDirectoriesRecursively()); +} + +TEST(FilePathTest, DefaultConstructor) { + FilePath fp; + EXPECT_EQ("", fp.string()); +} + +TEST(FilePathTest, CharAndCopyConstructors) { + const FilePath fp("spicy"); + EXPECT_EQ("spicy", fp.string()); + + const FilePath fp_copy(fp); + EXPECT_EQ("spicy", fp_copy.string()); +} + +TEST(FilePathTest, StringConstructor) { + const FilePath fp(std::string("cider")); + EXPECT_EQ("cider", fp.string()); +} + +TEST(FilePathTest, Set) { + const FilePath apple("apple"); + FilePath mac("mac"); + mac.Set(apple); // Implement Set() since overloading operator= is forbidden. + EXPECT_EQ("apple", mac.string()); + EXPECT_EQ("apple", apple.string()); +} + +TEST(FilePathTest, ToString) { + const FilePath file("drink"); + EXPECT_EQ("drink", file.string()); +} + +TEST(FilePathTest, RemoveExtension) { + EXPECT_EQ("app", FilePath("app.cc").RemoveExtension("cc").string()); + EXPECT_EQ("app", FilePath("app.exe").RemoveExtension("exe").string()); + EXPECT_EQ("APP", FilePath("APP.EXE").RemoveExtension("exe").string()); +} + +TEST(FilePathTest, RemoveExtensionWhenThereIsNoExtension) { + EXPECT_EQ("app", FilePath("app").RemoveExtension("exe").string()); +} + +TEST(FilePathTest, IsDirectory) { + EXPECT_FALSE(FilePath("cola").IsDirectory()); + EXPECT_TRUE(FilePath("koala" GTEST_PATH_SEP_).IsDirectory()); +#if GTEST_HAS_ALT_PATH_SEP_ + EXPECT_TRUE(FilePath("koala/").IsDirectory()); +#endif +} + +TEST(FilePathTest, IsAbsolutePath) { + EXPECT_FALSE(FilePath("is" GTEST_PATH_SEP_ "relative").IsAbsolutePath()); + EXPECT_FALSE(FilePath("").IsAbsolutePath()); +#if GTEST_OS_WINDOWS + EXPECT_TRUE(FilePath("c:\\" GTEST_PATH_SEP_ "is_not" + GTEST_PATH_SEP_ "relative").IsAbsolutePath()); + EXPECT_FALSE(FilePath("c:foo" GTEST_PATH_SEP_ "bar").IsAbsolutePath()); + EXPECT_TRUE(FilePath("c:/" GTEST_PATH_SEP_ "is_not" + GTEST_PATH_SEP_ "relative").IsAbsolutePath()); +#else + EXPECT_TRUE(FilePath(GTEST_PATH_SEP_ "is_not" GTEST_PATH_SEP_ "relative") + .IsAbsolutePath()); +#endif // GTEST_OS_WINDOWS +} + +TEST(FilePathTest, IsRootDirectory) { +#if GTEST_OS_WINDOWS + EXPECT_TRUE(FilePath("a:\\").IsRootDirectory()); + EXPECT_TRUE(FilePath("Z:/").IsRootDirectory()); + EXPECT_TRUE(FilePath("e://").IsRootDirectory()); + EXPECT_FALSE(FilePath("").IsRootDirectory()); + EXPECT_FALSE(FilePath("b:").IsRootDirectory()); + EXPECT_FALSE(FilePath("b:a").IsRootDirectory()); + EXPECT_FALSE(FilePath("8:/").IsRootDirectory()); + EXPECT_FALSE(FilePath("c|/").IsRootDirectory()); +#else + EXPECT_TRUE(FilePath("/").IsRootDirectory()); + EXPECT_TRUE(FilePath("//").IsRootDirectory()); + EXPECT_FALSE(FilePath("").IsRootDirectory()); + EXPECT_FALSE(FilePath("\\").IsRootDirectory()); + EXPECT_FALSE(FilePath("/x").IsRootDirectory()); +#endif +} + +} // namespace +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest.py new file mode 100755 index 000000000000..bd1d5a5db815 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest.py @@ -0,0 +1,639 @@ +#!/usr/bin/env python +# +# Copyright 2005 Google Inc. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for Google Test test filters. + +A user can specify which test(s) in a Google Test program to run via either +the GTEST_FILTER environment variable or the --gtest_filter flag. +This script tests such functionality by invoking +googletest-filter-unittest_ (a program written with Google Test) with different +environments and command line flags. + +Note that test sharding may also influence which tests are filtered. Therefore, +we test that here also. +""" + +import os +import re +try: + from sets import Set as set # For Python 2.3 compatibility +except ImportError: + pass +import sys +from googletest.test import gtest_test_utils + +# Constants. + +# Checks if this platform can pass empty environment variables to child +# processes. We set an env variable to an empty string and invoke a python +# script in a subprocess to print whether the variable is STILL in +# os.environ. We then use 'eval' to parse the child's output so that an +# exception is thrown if the input is anything other than 'True' nor 'False'. +CAN_PASS_EMPTY_ENV = False +if sys.executable: + os.environ['EMPTY_VAR'] = '' + child = gtest_test_utils.Subprocess( + [sys.executable, '-c', 'import os; print(\'EMPTY_VAR\' in os.environ)']) + CAN_PASS_EMPTY_ENV = eval(child.output) + + +# Check if this platform can unset environment variables in child processes. +# We set an env variable to a non-empty string, unset it, and invoke +# a python script in a subprocess to print whether the variable +# is NO LONGER in os.environ. +# We use 'eval' to parse the child's output so that an exception +# is thrown if the input is neither 'True' nor 'False'. +CAN_UNSET_ENV = False +if sys.executable: + os.environ['UNSET_VAR'] = 'X' + del os.environ['UNSET_VAR'] + child = gtest_test_utils.Subprocess( + [sys.executable, '-c', 'import os; print(\'UNSET_VAR\' not in os.environ)' + ]) + CAN_UNSET_ENV = eval(child.output) + + +# Checks if we should test with an empty filter. This doesn't +# make sense on platforms that cannot pass empty env variables (Win32) +# and on platforms that cannot unset variables (since we cannot tell +# the difference between "" and NULL -- Borland and Solaris < 5.10) +CAN_TEST_EMPTY_FILTER = (CAN_PASS_EMPTY_ENV and CAN_UNSET_ENV) + + +# The environment variable for specifying the test filters. +FILTER_ENV_VAR = 'GTEST_FILTER' + +# The environment variables for test sharding. +TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' +SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' +SHARD_STATUS_FILE_ENV_VAR = 'GTEST_SHARD_STATUS_FILE' + +# The command line flag for specifying the test filters. +FILTER_FLAG = 'gtest_filter' + +# The command line flag for including disabled tests. +ALSO_RUN_DISABLED_TESTS_FLAG = 'gtest_also_run_disabled_tests' + +# Command to run the googletest-filter-unittest_ program. +COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-filter-unittest_') + +# Regex for determining whether parameterized tests are enabled in the binary. +PARAM_TEST_REGEX = re.compile(r'/ParamTest') + +# Regex for parsing test case names from Google Test's output. +TEST_CASE_REGEX = re.compile(r'^\[\-+\] \d+ tests? from (\w+(/\w+)?)') + +# Regex for parsing test names from Google Test's output. +TEST_REGEX = re.compile(r'^\[\s*RUN\s*\].*\.(\w+(/\w+)?)') + +# The command line flag to tell Google Test to output the list of tests it +# will run. +LIST_TESTS_FLAG = '--gtest_list_tests' + +# Indicates whether Google Test supports death tests. +SUPPORTS_DEATH_TESTS = 'HasDeathTest' in gtest_test_utils.Subprocess( + [COMMAND, LIST_TESTS_FLAG]).output + +# Full names of all tests in googletest-filter-unittests_. +PARAM_TESTS = [ + 'SeqP/ParamTest.TestX/0', + 'SeqP/ParamTest.TestX/1', + 'SeqP/ParamTest.TestY/0', + 'SeqP/ParamTest.TestY/1', + 'SeqQ/ParamTest.TestX/0', + 'SeqQ/ParamTest.TestX/1', + 'SeqQ/ParamTest.TestY/0', + 'SeqQ/ParamTest.TestY/1', + ] + +DISABLED_TESTS = [ + 'BarTest.DISABLED_TestFour', + 'BarTest.DISABLED_TestFive', + 'BazTest.DISABLED_TestC', + 'DISABLED_FoobarTest.Test1', + 'DISABLED_FoobarTest.DISABLED_Test2', + 'DISABLED_FoobarbazTest.TestA', + ] + +if SUPPORTS_DEATH_TESTS: + DEATH_TESTS = [ + 'HasDeathTest.Test1', + 'HasDeathTest.Test2', + ] +else: + DEATH_TESTS = [] + +# All the non-disabled tests. +ACTIVE_TESTS = [ + 'FooTest.Abc', + 'FooTest.Xyz', + + 'BarTest.TestOne', + 'BarTest.TestTwo', + 'BarTest.TestThree', + + 'BazTest.TestOne', + 'BazTest.TestA', + 'BazTest.TestB', + ] + DEATH_TESTS + PARAM_TESTS + +param_tests_present = None + +# Utilities. + +environ = os.environ.copy() + + +def SetEnvVar(env_var, value): + """Sets the env variable to 'value'; unsets it when 'value' is None.""" + + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] + + +def RunAndReturnOutput(args = None): + """Runs the test program and returns its output.""" + + return gtest_test_utils.Subprocess([COMMAND] + (args or []), + env=environ).output + + +def RunAndExtractTestList(args = None): + """Runs the test program and returns its exit code and a list of tests run.""" + + p = gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ) + tests_run = [] + test_case = '' + test = '' + for line in p.output.split('\n'): + match = TEST_CASE_REGEX.match(line) + if match is not None: + test_case = match.group(1) + else: + match = TEST_REGEX.match(line) + if match is not None: + test = match.group(1) + tests_run.append(test_case + '.' + test) + return (tests_run, p.exit_code) + + +def InvokeWithModifiedEnv(extra_env, function, *args, **kwargs): + """Runs the given function and arguments in a modified environment.""" + try: + original_env = environ.copy() + environ.update(extra_env) + return function(*args, **kwargs) + finally: + environ.clear() + environ.update(original_env) + + +def RunWithSharding(total_shards, shard_index, command): + """Runs a test program shard and returns exit code and a list of tests run.""" + + extra_env = {SHARD_INDEX_ENV_VAR: str(shard_index), + TOTAL_SHARDS_ENV_VAR: str(total_shards)} + return InvokeWithModifiedEnv(extra_env, RunAndExtractTestList, command) + +# The unit test. + + +class GTestFilterUnitTest(gtest_test_utils.TestCase): + """Tests the env variable or the command line flag to filter tests.""" + + # Utilities. + + def AssertSetEqual(self, lhs, rhs): + """Asserts that two sets are equal.""" + + for elem in lhs: + self.assert_(elem in rhs, '%s in %s' % (elem, rhs)) + + for elem in rhs: + self.assert_(elem in lhs, '%s in %s' % (elem, lhs)) + + def AssertPartitionIsValid(self, set_var, list_of_sets): + """Asserts that list_of_sets is a valid partition of set_var.""" + + full_partition = [] + for slice_var in list_of_sets: + full_partition.extend(slice_var) + self.assertEqual(len(set_var), len(full_partition)) + self.assertEqual(set(set_var), set(full_partition)) + + def AdjustForParameterizedTests(self, tests_to_run): + """Adjust tests_to_run in case value parameterized tests are disabled.""" + + global param_tests_present + if not param_tests_present: + return list(set(tests_to_run) - set(PARAM_TESTS)) + else: + return tests_to_run + + def RunAndVerify(self, gtest_filter, tests_to_run): + """Checks that the binary runs correct set of tests for a given filter.""" + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # First, tests using the environment variable. + + # Windows removes empty variables from the environment when passing it + # to a new process. This means it is impossible to pass an empty filter + # into a process using the environment variable. However, we can still + # test the case when the variable is not supplied (i.e., gtest_filter is + # None). + # pylint: disable-msg=C6403 + if CAN_TEST_EMPTY_FILTER or gtest_filter != '': + SetEnvVar(FILTER_ENV_VAR, gtest_filter) + tests_run = RunAndExtractTestList()[0] + SetEnvVar(FILTER_ENV_VAR, None) + self.AssertSetEqual(tests_run, tests_to_run) + # pylint: enable-msg=C6403 + + # Next, tests using the command line flag. + + if gtest_filter is None: + args = [] + else: + args = ['--%s=%s' % (FILTER_FLAG, gtest_filter)] + + tests_run = RunAndExtractTestList(args)[0] + self.AssertSetEqual(tests_run, tests_to_run) + + def RunAndVerifyWithSharding(self, gtest_filter, total_shards, tests_to_run, + args=None, check_exit_0=False): + """Checks that binary runs correct tests for the given filter and shard. + + Runs all shards of googletest-filter-unittest_ with the given filter, and + verifies that the right set of tests were run. The union of tests run + on each shard should be identical to tests_to_run, without duplicates. + If check_exit_0, . + + Args: + gtest_filter: A filter to apply to the tests. + total_shards: A total number of shards to split test run into. + tests_to_run: A set of tests expected to run. + args : Arguments to pass to the to the test binary. + check_exit_0: When set to a true value, make sure that all shards + return 0. + """ + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # Windows removes empty variables from the environment when passing it + # to a new process. This means it is impossible to pass an empty filter + # into a process using the environment variable. However, we can still + # test the case when the variable is not supplied (i.e., gtest_filter is + # None). + # pylint: disable-msg=C6403 + if CAN_TEST_EMPTY_FILTER or gtest_filter != '': + SetEnvVar(FILTER_ENV_VAR, gtest_filter) + partition = [] + for i in range(0, total_shards): + (tests_run, exit_code) = RunWithSharding(total_shards, i, args) + if check_exit_0: + self.assertEqual(0, exit_code) + partition.append(tests_run) + + self.AssertPartitionIsValid(tests_to_run, partition) + SetEnvVar(FILTER_ENV_VAR, None) + # pylint: enable-msg=C6403 + + def RunAndVerifyAllowingDisabled(self, gtest_filter, tests_to_run): + """Checks that the binary runs correct set of tests for the given filter. + + Runs googletest-filter-unittest_ with the given filter, and enables + disabled tests. Verifies that the right set of tests were run. + + Args: + gtest_filter: A filter to apply to the tests. + tests_to_run: A set of tests expected to run. + """ + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # Construct the command line. + args = ['--%s' % ALSO_RUN_DISABLED_TESTS_FLAG] + if gtest_filter is not None: + args.append('--%s=%s' % (FILTER_FLAG, gtest_filter)) + + tests_run = RunAndExtractTestList(args)[0] + self.AssertSetEqual(tests_run, tests_to_run) + + def setUp(self): + """Sets up test case. + + Determines whether value-parameterized tests are enabled in the binary and + sets the flags accordingly. + """ + + global param_tests_present + if param_tests_present is None: + param_tests_present = PARAM_TEST_REGEX.search( + RunAndReturnOutput()) is not None + + def testDefaultBehavior(self): + """Tests the behavior of not specifying the filter.""" + + self.RunAndVerify(None, ACTIVE_TESTS) + + def testDefaultBehaviorWithShards(self): + """Tests the behavior without the filter, with sharding enabled.""" + + self.RunAndVerifyWithSharding(None, 1, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, 2, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) - 1, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS), ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) + 1, ACTIVE_TESTS) + + def testEmptyFilter(self): + """Tests an empty filter.""" + + self.RunAndVerify('', []) + self.RunAndVerifyWithSharding('', 1, []) + self.RunAndVerifyWithSharding('', 2, []) + + def testBadFilter(self): + """Tests a filter that matches nothing.""" + + self.RunAndVerify('BadFilter', []) + self.RunAndVerifyAllowingDisabled('BadFilter', []) + + def testFullName(self): + """Tests filtering by full name.""" + + self.RunAndVerify('FooTest.Xyz', ['FooTest.Xyz']) + self.RunAndVerifyAllowingDisabled('FooTest.Xyz', ['FooTest.Xyz']) + self.RunAndVerifyWithSharding('FooTest.Xyz', 5, ['FooTest.Xyz']) + + def testUniversalFilters(self): + """Tests filters that match everything.""" + + self.RunAndVerify('*', ACTIVE_TESTS) + self.RunAndVerify('*.*', ACTIVE_TESTS) + self.RunAndVerifyWithSharding('*.*', len(ACTIVE_TESTS) - 3, ACTIVE_TESTS) + self.RunAndVerifyAllowingDisabled('*', ACTIVE_TESTS + DISABLED_TESTS) + self.RunAndVerifyAllowingDisabled('*.*', ACTIVE_TESTS + DISABLED_TESTS) + + def testFilterByTestCase(self): + """Tests filtering by test case name.""" + + self.RunAndVerify('FooTest.*', ['FooTest.Abc', 'FooTest.Xyz']) + + BAZ_TESTS = ['BazTest.TestOne', 'BazTest.TestA', 'BazTest.TestB'] + self.RunAndVerify('BazTest.*', BAZ_TESTS) + self.RunAndVerifyAllowingDisabled('BazTest.*', + BAZ_TESTS + ['BazTest.DISABLED_TestC']) + + def testFilterByTest(self): + """Tests filtering by test name.""" + + self.RunAndVerify('*.TestOne', ['BarTest.TestOne', 'BazTest.TestOne']) + + def testFilterDisabledTests(self): + """Select only the disabled tests to run.""" + + self.RunAndVerify('DISABLED_FoobarTest.Test1', []) + self.RunAndVerifyAllowingDisabled('DISABLED_FoobarTest.Test1', + ['DISABLED_FoobarTest.Test1']) + + self.RunAndVerify('*DISABLED_*', []) + self.RunAndVerifyAllowingDisabled('*DISABLED_*', DISABLED_TESTS) + + self.RunAndVerify('*.DISABLED_*', []) + self.RunAndVerifyAllowingDisabled('*.DISABLED_*', [ + 'BarTest.DISABLED_TestFour', + 'BarTest.DISABLED_TestFive', + 'BazTest.DISABLED_TestC', + 'DISABLED_FoobarTest.DISABLED_Test2', + ]) + + self.RunAndVerify('DISABLED_*', []) + self.RunAndVerifyAllowingDisabled('DISABLED_*', [ + 'DISABLED_FoobarTest.Test1', + 'DISABLED_FoobarTest.DISABLED_Test2', + 'DISABLED_FoobarbazTest.TestA', + ]) + + def testWildcardInTestCaseName(self): + """Tests using wildcard in the test case name.""" + + self.RunAndVerify('*a*.*', [ + 'BarTest.TestOne', + 'BarTest.TestTwo', + 'BarTest.TestThree', + + 'BazTest.TestOne', + 'BazTest.TestA', + 'BazTest.TestB', ] + DEATH_TESTS + PARAM_TESTS) + + def testWildcardInTestName(self): + """Tests using wildcard in the test name.""" + + self.RunAndVerify('*.*A*', ['FooTest.Abc', 'BazTest.TestA']) + + def testFilterWithoutDot(self): + """Tests a filter that has no '.' in it.""" + + self.RunAndVerify('*z*', [ + 'FooTest.Xyz', + + 'BazTest.TestOne', + 'BazTest.TestA', + 'BazTest.TestB', + ]) + + def testTwoPatterns(self): + """Tests filters that consist of two patterns.""" + + self.RunAndVerify('Foo*.*:*A*', [ + 'FooTest.Abc', + 'FooTest.Xyz', + + 'BazTest.TestA', + ]) + + # An empty pattern + a non-empty one + self.RunAndVerify(':*A*', ['FooTest.Abc', 'BazTest.TestA']) + + def testThreePatterns(self): + """Tests filters that consist of three patterns.""" + + self.RunAndVerify('*oo*:*A*:*One', [ + 'FooTest.Abc', + 'FooTest.Xyz', + + 'BarTest.TestOne', + + 'BazTest.TestOne', + 'BazTest.TestA', + ]) + + # The 2nd pattern is empty. + self.RunAndVerify('*oo*::*One', [ + 'FooTest.Abc', + 'FooTest.Xyz', + + 'BarTest.TestOne', + + 'BazTest.TestOne', + ]) + + # The last 2 patterns are empty. + self.RunAndVerify('*oo*::', [ + 'FooTest.Abc', + 'FooTest.Xyz', + ]) + + def testNegativeFilters(self): + self.RunAndVerify('*-BazTest.TestOne', [ + 'FooTest.Abc', + 'FooTest.Xyz', + + 'BarTest.TestOne', + 'BarTest.TestTwo', + 'BarTest.TestThree', + + 'BazTest.TestA', + 'BazTest.TestB', + ] + DEATH_TESTS + PARAM_TESTS) + + self.RunAndVerify('*-FooTest.Abc:BazTest.*', [ + 'FooTest.Xyz', + + 'BarTest.TestOne', + 'BarTest.TestTwo', + 'BarTest.TestThree', + ] + DEATH_TESTS + PARAM_TESTS) + + self.RunAndVerify('BarTest.*-BarTest.TestOne', [ + 'BarTest.TestTwo', + 'BarTest.TestThree', + ]) + + # Tests without leading '*'. + self.RunAndVerify('-FooTest.Abc:FooTest.Xyz:BazTest.*', [ + 'BarTest.TestOne', + 'BarTest.TestTwo', + 'BarTest.TestThree', + ] + DEATH_TESTS + PARAM_TESTS) + + # Value parameterized tests. + self.RunAndVerify('*/*', PARAM_TESTS) + + # Value parameterized tests filtering by the sequence name. + self.RunAndVerify('SeqP/*', [ + 'SeqP/ParamTest.TestX/0', + 'SeqP/ParamTest.TestX/1', + 'SeqP/ParamTest.TestY/0', + 'SeqP/ParamTest.TestY/1', + ]) + + # Value parameterized tests filtering by the test name. + self.RunAndVerify('*/0', [ + 'SeqP/ParamTest.TestX/0', + 'SeqP/ParamTest.TestY/0', + 'SeqQ/ParamTest.TestX/0', + 'SeqQ/ParamTest.TestY/0', + ]) + + def testFlagOverridesEnvVar(self): + """Tests that the filter flag overrides the filtering env. variable.""" + + SetEnvVar(FILTER_ENV_VAR, 'Foo*') + args = ['--%s=%s' % (FILTER_FLAG, '*One')] + tests_run = RunAndExtractTestList(args)[0] + SetEnvVar(FILTER_ENV_VAR, None) + + self.AssertSetEqual(tests_run, ['BarTest.TestOne', 'BazTest.TestOne']) + + def testShardStatusFileIsCreated(self): + """Tests that the shard file is created if specified in the environment.""" + + shard_status_file = os.path.join(gtest_test_utils.GetTempDir(), + 'shard_status_file') + self.assert_(not os.path.exists(shard_status_file)) + + extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} + try: + InvokeWithModifiedEnv(extra_env, RunAndReturnOutput) + finally: + self.assert_(os.path.exists(shard_status_file)) + os.remove(shard_status_file) + + def testShardStatusFileIsCreatedWithListTests(self): + """Tests that the shard file is created with the "list_tests" flag.""" + + shard_status_file = os.path.join(gtest_test_utils.GetTempDir(), + 'shard_status_file2') + self.assert_(not os.path.exists(shard_status_file)) + + extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} + try: + output = InvokeWithModifiedEnv(extra_env, + RunAndReturnOutput, + [LIST_TESTS_FLAG]) + finally: + # This assertion ensures that Google Test enumerated the tests as + # opposed to running them. + self.assert_('[==========]' not in output, + 'Unexpected output during test enumeration.\n' + 'Please ensure that LIST_TESTS_FLAG is assigned the\n' + 'correct flag value for listing Google Test tests.') + + self.assert_(os.path.exists(shard_status_file)) + os.remove(shard_status_file) + + if SUPPORTS_DEATH_TESTS: + def testShardingWorksWithDeathTests(self): + """Tests integration with death tests and sharding.""" + + gtest_filter = 'HasDeathTest.*:SeqP/*' + expected_tests = [ + 'HasDeathTest.Test1', + 'HasDeathTest.Test2', + + 'SeqP/ParamTest.TestX/0', + 'SeqP/ParamTest.TestX/1', + 'SeqP/ParamTest.TestY/0', + 'SeqP/ParamTest.TestY/1', + ] + + for flag in ['--gtest_death_test_style=threadsafe', + '--gtest_death_test_style=fast']: + self.RunAndVerifyWithSharding(gtest_filter, 3, expected_tests, + check_exit_0=True, args=[flag]) + self.RunAndVerifyWithSharding(gtest_filter, 5, expected_tests, + check_exit_0=True, args=[flag]) + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest_.cc new file mode 100644 index 000000000000..d30ec9c78b56 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-filter-unittest_.cc @@ -0,0 +1,137 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Unit test for Google Test test filters. +// +// A user can specify which test(s) in a Google Test program to run via +// either the GTEST_FILTER environment variable or the --gtest_filter +// flag. This is used for testing such functionality. +// +// The program will be invoked from a Python unit test. Don't run it +// directly. + +#include "gtest/gtest.h" + +namespace { + +// Test case FooTest. + +class FooTest : public testing::Test { +}; + +TEST_F(FooTest, Abc) { +} + +TEST_F(FooTest, Xyz) { + FAIL() << "Expected failure."; +} + +// Test case BarTest. + +TEST(BarTest, TestOne) { +} + +TEST(BarTest, TestTwo) { +} + +TEST(BarTest, TestThree) { +} + +TEST(BarTest, DISABLED_TestFour) { + FAIL() << "Expected failure."; +} + +TEST(BarTest, DISABLED_TestFive) { + FAIL() << "Expected failure."; +} + +// Test case BazTest. + +TEST(BazTest, TestOne) { + FAIL() << "Expected failure."; +} + +TEST(BazTest, TestA) { +} + +TEST(BazTest, TestB) { +} + +TEST(BazTest, DISABLED_TestC) { + FAIL() << "Expected failure."; +} + +// Test case HasDeathTest + +TEST(HasDeathTest, Test1) { + EXPECT_DEATH_IF_SUPPORTED(exit(1), ".*"); +} + +// We need at least two death tests to make sure that the all death tests +// aren't on the first shard. +TEST(HasDeathTest, Test2) { + EXPECT_DEATH_IF_SUPPORTED(exit(1), ".*"); +} + +// Test case FoobarTest + +TEST(DISABLED_FoobarTest, Test1) { + FAIL() << "Expected failure."; +} + +TEST(DISABLED_FoobarTest, DISABLED_Test2) { + FAIL() << "Expected failure."; +} + +// Test case FoobarbazTest + +TEST(DISABLED_FoobarbazTest, TestA) { + FAIL() << "Expected failure."; +} + +class ParamTest : public testing::TestWithParam { +}; + +TEST_P(ParamTest, TestX) { +} + +TEST_P(ParamTest, TestY) { +} + +INSTANTIATE_TEST_SUITE_P(SeqP, ParamTest, testing::Values(1, 2)); +INSTANTIATE_TEST_SUITE_P(SeqQ, ParamTest, testing::Values(5, 6)); + +} // namespace + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest.py new file mode 100644 index 000000000000..265793442f97 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest.py @@ -0,0 +1,130 @@ +# Copyright 2021 Google Inc. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Unit test for Google Test's global test environment behavior. + +A user can specify a global test environment via +testing::AddGlobalTestEnvironment. Failures in the global environment should +result in all unit tests being skipped. + +This script tests such functionality by invoking +googletest-global-environment-unittest_ (a program written with Google Test). +""" + +import re +from googletest.test import gtest_test_utils + + +def RunAndReturnOutput(args=None): + """Runs the test program and returns its output.""" + + return gtest_test_utils.Subprocess([ + gtest_test_utils.GetTestExecutablePath( + 'googletest-global-environment-unittest_') + ] + (args or [])).output + + +class GTestGlobalEnvironmentUnitTest(gtest_test_utils.TestCase): + """Tests global test environment failures.""" + + def testEnvironmentSetUpFails(self): + """Tests the behavior of not specifying the fail_fast.""" + + # Run the test. + txt = RunAndReturnOutput() + + # We should see the text of the global environment setup error. + self.assertIn('Canned environment setup error', txt) + + # Our test should have been skipped due to the error, and not treated as a + # pass. + self.assertIn('[ SKIPPED ] 1 test', txt) + self.assertIn('[ PASSED ] 0 tests', txt) + + # The test case shouldn't have been run. + self.assertNotIn('Unexpected call', txt) + + def testEnvironmentSetUpAndTornDownForEachRepeat(self): + """Tests the behavior of test environments and gtest_repeat.""" + + # When --gtest_recreate_environments_when_repeating is true, the global test + # environment should be set up and torn down for each iteration. + txt = RunAndReturnOutput([ + '--gtest_repeat=2', + '--gtest_recreate_environments_when_repeating=true', + ]) + + expected_pattern = ('(.|\n)*' + r'Repeating all tests \(iteration 1\)' + '(.|\n)*' + 'Global test environment set-up.' + '(.|\n)*' + 'SomeTest.DoesFoo' + '(.|\n)*' + 'Global test environment tear-down' + '(.|\n)*' + r'Repeating all tests \(iteration 2\)' + '(.|\n)*' + 'Global test environment set-up.' + '(.|\n)*' + 'SomeTest.DoesFoo' + '(.|\n)*' + 'Global test environment tear-down' + '(.|\n)*') + self.assertRegex(txt, expected_pattern) + + def testEnvironmentSetUpAndTornDownOnce(self): + """Tests environment and --gtest_recreate_environments_when_repeating.""" + + # By default the environment should only be set up and torn down once, at + # the start and end of the test respectively. + txt = RunAndReturnOutput([ + '--gtest_repeat=2', + ]) + + expected_pattern = ('(.|\n)*' + r'Repeating all tests \(iteration 1\)' + '(.|\n)*' + 'Global test environment set-up.' + '(.|\n)*' + 'SomeTest.DoesFoo' + '(.|\n)*' + r'Repeating all tests \(iteration 2\)' + '(.|\n)*' + 'SomeTest.DoesFoo' + '(.|\n)*' + 'Global test environment tear-down' + '(.|\n)*') + self.assertRegex(txt, expected_pattern) + + self.assertEqual(len(re.findall('Global test environment set-up', txt)), 1) + self.assertEqual( + len(re.findall('Global test environment tear-down', txt)), 1) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest_.cc new file mode 100644 index 000000000000..f401b2fac25d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-global-environment-unittest_.cc @@ -0,0 +1,58 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Unit test for Google Test global test environments. +// +// The program will be invoked from a Python unit test. Don't run it +// directly. + +#include "gtest/gtest.h" + +namespace { + +// An environment that always fails in its SetUp method. +class FailingEnvironment final : public ::testing::Environment { + public: + void SetUp() override { FAIL() << "Canned environment setup error"; } +}; + +// Register the environment. +auto* const g_environment_ = + ::testing::AddGlobalTestEnvironment(new FailingEnvironment); + +// A test that doesn't actually run. +TEST(SomeTest, DoesFoo) { FAIL() << "Unexpected call"; } + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-outfiles-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-outfiles-test.py new file mode 100644 index 000000000000..db9716c2de2b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-outfiles-test.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# Copyright 2018, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for the gtest_json_output module.""" + +import json +import os +from googletest.test import gtest_json_test_utils +from googletest.test import gtest_test_utils + +GTEST_OUTPUT_SUBDIR = 'json_outfiles' +GTEST_OUTPUT_1_TEST = 'gtest_xml_outfile1_test_' +GTEST_OUTPUT_2_TEST = 'gtest_xml_outfile2_test_' + +EXPECTED_1 = { + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'name': + u'AllTests', + u'testsuites': [{ + u'name': + u'PropertyOne', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'TestSomeProperties', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'PropertyOne', + u'SetUpProp': u'1', + u'TestSomeProperty': u'1', + u'TearDownProp': u'1', + }], + }], +} + +EXPECTED_2 = { + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'name': + u'AllTests', + u'testsuites': [{ + u'name': + u'PropertyTwo', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'TestSomeProperties', + u'status': u'RUN', + u'result': u'COMPLETED', + u'timestamp': u'*', + u'time': u'*', + u'classname': u'PropertyTwo', + u'SetUpProp': u'2', + u'TestSomeProperty': u'2', + u'TearDownProp': u'2', + }], + }], +} + + +class GTestJsonOutFilesTest(gtest_test_utils.TestCase): + """Unit test for Google Test's JSON output functionality.""" + + def setUp(self): + # We want the trailing '/' that the last "" provides in os.path.join, for + # telling Google Test to create an output directory instead of a single file + # for xml output. + self.output_dir_ = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_OUTPUT_SUBDIR, '') + self.DeleteFilesAndDir() + + def tearDown(self): + self.DeleteFilesAndDir() + + def DeleteFilesAndDir(self): + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + '.json')) + except os.error: + pass + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + '.json')) + except os.error: + pass + try: + os.rmdir(self.output_dir_) + except os.error: + pass + + def testOutfile1(self): + self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_1) + + def testOutfile2(self): + self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_2) + + def _TestOutFile(self, test_name, expected): + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) + command = [gtest_prog_path, '--gtest_output=json:%s' % self.output_dir_] + p = gtest_test_utils.Subprocess(command, + working_dir=gtest_test_utils.GetTempDir()) + self.assert_(p.exited) + self.assertEquals(0, p.exit_code) + + output_file_name1 = test_name + '.json' + output_file1 = os.path.join(self.output_dir_, output_file_name1) + output_file_name2 = 'lt-' + output_file_name1 + output_file2 = os.path.join(self.output_dir_, output_file_name2) + self.assert_(os.path.isfile(output_file1) or os.path.isfile(output_file2), + output_file1) + + if os.path.isfile(output_file1): + with open(output_file1) as f: + actual = json.load(f) + else: + with open(output_file2) as f: + actual = json.load(f) + self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) + + +if __name__ == '__main__': + os.environ['GTEST_STACK_TRACE_DEPTH'] = '0' + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-output-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-output-unittest.py new file mode 100644 index 000000000000..cb31965e43a4 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-json-output-unittest.py @@ -0,0 +1,848 @@ +#!/usr/bin/env python +# Copyright 2018, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for the gtest_json_output module.""" + +import datetime +import errno +import json +import os +import re +import sys + +from googletest.test import gtest_json_test_utils +from googletest.test import gtest_test_utils + +GTEST_FILTER_FLAG = '--gtest_filter' +GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' +GTEST_OUTPUT_FLAG = '--gtest_output' +GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.json' +GTEST_PROGRAM_NAME = 'gtest_xml_output_unittest_' + +# The flag indicating stacktraces are not supported +NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' + +SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv + +if SUPPORTS_STACK_TRACES: + STACK_TRACE_TEMPLATE = '\nStack trace:\n*' +else: + STACK_TRACE_TEMPLATE = '' + +EXPECTED_NON_EMPTY = { + u'tests': + 26, + u'failures': + 5, + u'disabled': + 2, + u'errors': + 0, + u'timestamp': + u'*', + u'time': + u'*', + u'ad_hoc_property': + u'42', + u'name': + u'AllTests', + u'testsuites': [{ + u'name': + u'SuccessfulTest', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'Succeeds', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'SuccessfulTest' + }] + }, { + u'name': + u'FailedTest', + u'tests': + 1, + u'failures': + 1, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': + u'Fails', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'FailedTest', + u'failures': [{ + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Expected equality of these values:\n' + u' 1\n 2' + STACK_TRACE_TEMPLATE, + u'type': u'' + }] + }] + }, { + u'name': + u'DisabledTest', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 1, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'DISABLED_test_not_run', + u'status': u'NOTRUN', + u'result': u'SUPPRESSED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'DisabledTest' + }] + }, { + u'name': + u'SkippedTest', + u'tests': + 3, + u'failures': + 1, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'Skipped', + u'status': u'RUN', + u'result': u'SKIPPED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'SkippedTest' + }, { + u'name': u'SkippedWithMessage', + u'status': u'RUN', + u'result': u'SKIPPED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'SkippedTest' + }, { + u'name': + u'SkippedAfterFailure', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'SkippedTest', + u'failures': [{ + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Expected equality of these values:\n' + u' 1\n 2' + STACK_TRACE_TEMPLATE, + u'type': u'' + }] + }] + }, { + u'name': + u'MixedResultTest', + u'tests': + 3, + u'failures': + 1, + u'disabled': + 1, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'Succeeds', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'MixedResultTest' + }, { + u'name': + u'Fails', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'MixedResultTest', + u'failures': [{ + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Expected equality of these values:\n' + u' 1\n 2' + STACK_TRACE_TEMPLATE, + u'type': u'' + }, { + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Expected equality of these values:\n' + u' 2\n 3' + STACK_TRACE_TEMPLATE, + u'type': u'' + }] + }, { + u'name': u'DISABLED_test', + u'status': u'NOTRUN', + u'result': u'SUPPRESSED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'MixedResultTest' + }] + }, { + u'name': + u'XmlQuotingTest', + u'tests': + 1, + u'failures': + 1, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': + u'OutputsCData', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'XmlQuotingTest', + u'failures': [{ + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Failed\nXML output: ' + u'' + + STACK_TRACE_TEMPLATE, + u'type': u'' + }] + }] + }, { + u'name': + u'InvalidCharactersTest', + u'tests': + 1, + u'failures': + 1, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': + u'InvalidCharactersInMessage', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'InvalidCharactersTest', + u'failures': [{ + u'failure': u'gtest_xml_output_unittest_.cc:*\n' + u'Failed\nInvalid characters in brackets' + u' [\x01\x02]' + STACK_TRACE_TEMPLATE, + u'type': u'' + }] + }] + }, { + u'name': + u'PropertyRecordingTest', + u'tests': + 4, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'SetUpTestSuite': + u'yes', + u'TearDownTestSuite': + u'aye', + u'testsuite': [{ + u'name': u'OneProperty', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'PropertyRecordingTest', + u'key_1': u'1' + }, { + u'name': u'IntValuedProperty', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'PropertyRecordingTest', + u'key_int': u'1' + }, { + u'name': u'ThreeProperties', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'PropertyRecordingTest', + u'key_1': u'1', + u'key_2': u'2', + u'key_3': u'3' + }, { + u'name': u'TwoValuesForOneKeyUsesLastValue', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'PropertyRecordingTest', + u'key_1': u'2' + }] + }, { + u'name': + u'NoFixtureTest', + u'tests': + 3, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'RecordProperty', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'NoFixtureTest', + u'key': u'1' + }, { + u'name': u'ExternalUtilityThatCallsRecordIntValuedProperty', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'NoFixtureTest', + u'key_for_utility_int': u'1' + }, { + u'name': u'ExternalUtilityThatCallsRecordStringValuedProperty', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'NoFixtureTest', + u'key_for_utility_string': u'1' + }] + }, { + u'name': + u'TypedTest/0', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'HasTypeParamAttribute', + u'type_param': u'int', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'TypedTest/0' + }] + }, { + u'name': + u'TypedTest/1', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'HasTypeParamAttribute', + u'type_param': u'long', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'TypedTest/1' + }] + }, { + u'name': + u'Single/TypeParameterizedTestSuite/0', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'HasTypeParamAttribute', + u'type_param': u'int', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/TypeParameterizedTestSuite/0' + }] + }, { + u'name': + u'Single/TypeParameterizedTestSuite/1', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'HasTypeParamAttribute', + u'type_param': u'long', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/TypeParameterizedTestSuite/1' + }] + }, { + u'name': + u'Single/ValueParamTest', + u'tests': + 4, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'HasValueParamAttribute/0', + u'value_param': u'33', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/ValueParamTest' + }, { + u'name': u'HasValueParamAttribute/1', + u'value_param': u'42', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/ValueParamTest' + }, { + u'name': u'AnotherTestThatHasValueParamAttribute/0', + u'value_param': u'33', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/ValueParamTest' + }, { + u'name': u'AnotherTestThatHasValueParamAttribute/1', + u'value_param': u'42', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'Single/ValueParamTest' + }] + }] +} + +EXPECTED_FILTERED = { + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'name': + u'AllTests', + u'ad_hoc_property': + u'42', + u'testsuites': [{ + u'name': + u'SuccessfulTest', + u'tests': + 1, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': u'Succeeds', + u'status': u'RUN', + u'result': u'COMPLETED', + u'time': u'*', + u'timestamp': u'*', + u'classname': u'SuccessfulTest', + }] + }], +} + +EXPECTED_NO_TEST = { + u'tests': + 0, + u'failures': + 0, + u'disabled': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'name': + u'AllTests', + u'testsuites': [{ + u'name': + u'NonTestSuiteFailure', + u'tests': + 1, + u'failures': + 1, + u'disabled': + 0, + u'skipped': + 0, + u'errors': + 0, + u'time': + u'*', + u'timestamp': + u'*', + u'testsuite': [{ + u'name': + u'', + u'status': + u'RUN', + u'result': + u'COMPLETED', + u'time': + u'*', + u'timestamp': + u'*', + u'classname': + u'', + u'failures': [{ + u'failure': u'gtest_no_test_unittest.cc:*\n' + u'Expected equality of these values:\n' + u' 1\n 2' + STACK_TRACE_TEMPLATE, + u'type': u'', + }] + }] + }], +} + +GTEST_PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath(GTEST_PROGRAM_NAME) + +SUPPORTS_TYPED_TESTS = 'TypedTest' in gtest_test_utils.Subprocess( + [GTEST_PROGRAM_PATH, GTEST_LIST_TESTS_FLAG], capture_stderr=False).output + + +class GTestJsonOutputUnitTest(gtest_test_utils.TestCase): + """Unit test for Google Test's JSON output functionality. + """ + + # This test currently breaks on platforms that do not support typed and + # type-parameterized tests, so we don't run it under them. + if SUPPORTS_TYPED_TESTS: + + def testNonEmptyJsonOutput(self): + """Verifies JSON output for a Google Test binary with non-empty output. + + Runs a test program that generates a non-empty JSON output, and + tests that the JSON output is expected. + """ + self._TestJsonOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY, 1) + + def testNoTestJsonOutput(self): + """Verifies JSON output for a Google Test binary without actual tests. + + Runs a test program that generates an JSON output for a binary with no + tests, and tests that the JSON output is expected. + """ + + self._TestJsonOutput('gtest_no_test_unittest', EXPECTED_NO_TEST, 0) + + def testTimestampValue(self): + """Checks whether the timestamp attribute in the JSON output is valid. + + Runs a test program that generates an empty JSON output, and checks if + the timestamp attribute in the testsuites tag is valid. + """ + actual = self._GetJsonOutput('gtest_no_test_unittest', [], 0) + date_time_str = actual['timestamp'] + # datetime.strptime() is only available in Python 2.5+ so we have to + # parse the expected datetime manually. + match = re.match(r'(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)', date_time_str) + self.assertTrue( + re.match, + 'JSON datettime string %s has incorrect format' % date_time_str) + date_time_from_json = datetime.datetime( + year=int(match.group(1)), month=int(match.group(2)), + day=int(match.group(3)), hour=int(match.group(4)), + minute=int(match.group(5)), second=int(match.group(6))) + + time_delta = abs(datetime.datetime.now() - date_time_from_json) + # timestamp value should be near the current local time + self.assertTrue(time_delta < datetime.timedelta(seconds=600), + 'time_delta is %s' % time_delta) + + def testDefaultOutputFile(self): + """Verifies the default output file name. + + Confirms that Google Test produces an JSON output file with the expected + default name if no name is explicitly specified. + """ + output_file = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_DEFAULT_OUTPUT_FILE) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + 'gtest_no_test_unittest') + try: + os.remove(output_file) + except OSError: + e = sys.exc_info()[1] + if e.errno != errno.ENOENT: + raise + + p = gtest_test_utils.Subprocess( + [gtest_prog_path, '%s=json' % GTEST_OUTPUT_FLAG], + working_dir=gtest_test_utils.GetTempDir()) + self.assert_(p.exited) + self.assertEquals(0, p.exit_code) + self.assert_(os.path.isfile(output_file)) + + def testSuppressedJsonOutput(self): + """Verifies that no JSON output is generated. + + Tests that no JSON file is generated if the default JSON listener is + shut down before RUN_ALL_TESTS is invoked. + """ + + json_path = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_PROGRAM_NAME + 'out.json') + if os.path.isfile(json_path): + os.remove(json_path) + + command = [GTEST_PROGRAM_PATH, + '%s=json:%s' % (GTEST_OUTPUT_FLAG, json_path), + '--shut_down_xml'] + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + # p.signal is available only if p.terminated_by_signal is True. + self.assertFalse( + p.terminated_by_signal, + '%s was killed by signal %d' % (GTEST_PROGRAM_NAME, p.signal)) + else: + self.assert_(p.exited) + self.assertEquals(1, p.exit_code, + "'%s' exited with code %s, which doesn't match " + 'the expected exit code %s.' + % (command, p.exit_code, 1)) + + self.assert_(not os.path.isfile(json_path)) + + def testFilteredTestJsonOutput(self): + """Verifies JSON output when a filter is applied. + + Runs a test program that executes only some tests and verifies that + non-selected tests do not show up in the JSON output. + """ + + self._TestJsonOutput(GTEST_PROGRAM_NAME, EXPECTED_FILTERED, 0, + extra_args=['%s=SuccessfulTest.*' % GTEST_FILTER_FLAG]) + + def _GetJsonOutput(self, gtest_prog_name, extra_args, expected_exit_code): + """Returns the JSON output generated by running the program gtest_prog_name. + + Furthermore, the program's exit code must be expected_exit_code. + + Args: + gtest_prog_name: Google Test binary name. + extra_args: extra arguments to binary invocation. + expected_exit_code: program's exit code. + """ + json_path = os.path.join(gtest_test_utils.GetTempDir(), + gtest_prog_name + 'out.json') + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) + + command = ( + [gtest_prog_path, '%s=json:%s' % (GTEST_OUTPUT_FLAG, json_path)] + + extra_args + ) + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + self.assert_(False, + '%s was killed by signal %d' % (gtest_prog_name, p.signal)) + else: + self.assert_(p.exited) + self.assertEquals(expected_exit_code, p.exit_code, + "'%s' exited with code %s, which doesn't match " + 'the expected exit code %s.' + % (command, p.exit_code, expected_exit_code)) + with open(json_path) as f: + actual = json.load(f) + return actual + + def _TestJsonOutput(self, gtest_prog_name, expected, + expected_exit_code, extra_args=None): + """Checks the JSON output generated by the Google Test binary. + + Asserts that the JSON document generated by running the program + gtest_prog_name matches expected_json, a string containing another + JSON document. Furthermore, the program's exit code must be + expected_exit_code. + + Args: + gtest_prog_name: Google Test binary name. + expected: expected output. + expected_exit_code: program's exit code. + extra_args: extra arguments to binary invocation. + """ + + actual = self._GetJsonOutput(gtest_prog_name, extra_args or [], + expected_exit_code) + self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) + + +if __name__ == '__main__': + if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + + os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest.py new file mode 100755 index 000000000000..9d56883d741d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +# +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for Google Test's --gtest_list_tests flag. + +A user can ask Google Test to list all tests by specifying the +--gtest_list_tests flag. This script tests such functionality +by invoking googletest-list-tests-unittest_ (a program written with +Google Test) the command line flags. +""" + +import re +from googletest.test import gtest_test_utils + +# Constants. + +# The command line flag for enabling/disabling listing all tests. +LIST_TESTS_FLAG = 'gtest_list_tests' + +# Path to the googletest-list-tests-unittest_ program. +EXE_PATH = gtest_test_utils.GetTestExecutablePath('googletest-list-tests-unittest_') + +# The expected output when running googletest-list-tests-unittest_ with +# --gtest_list_tests +EXPECTED_OUTPUT_NO_FILTER_RE = re.compile(r"""FooDeathTest\. + Test1 +Foo\. + Bar1 + Bar2 + DISABLED_Bar3 +Abc\. + Xyz + Def +FooBar\. + Baz +FooTest\. + Test1 + DISABLED_Test2 + Test3 +TypedTest/0\. # TypeParam = (VeryLo{245}|class VeryLo{239})\.\.\. + TestA + TestB +TypedTest/1\. # TypeParam = int\s*\*( __ptr64)? + TestA + TestB +TypedTest/2\. # TypeParam = .*MyArray + TestA + TestB +My/TypeParamTest/0\. # TypeParam = (VeryLo{245}|class VeryLo{239})\.\.\. + TestA + TestB +My/TypeParamTest/1\. # TypeParam = int\s*\*( __ptr64)? + TestA + TestB +My/TypeParamTest/2\. # TypeParam = .*MyArray + TestA + TestB +MyInstantiation/ValueParamTest\. + TestA/0 # GetParam\(\) = one line + TestA/1 # GetParam\(\) = two\\nlines + TestA/2 # GetParam\(\) = a very\\nlo{241}\.\.\. + TestB/0 # GetParam\(\) = one line + TestB/1 # GetParam\(\) = two\\nlines + TestB/2 # GetParam\(\) = a very\\nlo{241}\.\.\. +""") + +# The expected output when running googletest-list-tests-unittest_ with +# --gtest_list_tests and --gtest_filter=Foo*. +EXPECTED_OUTPUT_FILTER_FOO_RE = re.compile(r"""FooDeathTest\. + Test1 +Foo\. + Bar1 + Bar2 + DISABLED_Bar3 +FooBar\. + Baz +FooTest\. + Test1 + DISABLED_Test2 + Test3 +""") + +# Utilities. + + +def Run(args): + """Runs googletest-list-tests-unittest_ and returns the list of tests printed.""" + + return gtest_test_utils.Subprocess([EXE_PATH] + args, + capture_stderr=False).output + + +# The unit test. + + +class GTestListTestsUnitTest(gtest_test_utils.TestCase): + """Tests using the --gtest_list_tests flag to list all tests.""" + + def RunAndVerify(self, flag_value, expected_output_re, other_flag): + """Runs googletest-list-tests-unittest_ and verifies that it prints + the correct tests. + + Args: + flag_value: value of the --gtest_list_tests flag; + None if the flag should not be present. + expected_output_re: regular expression that matches the expected + output after running command; + other_flag: a different flag to be passed to command + along with gtest_list_tests; + None if the flag should not be present. + """ + + if flag_value is None: + flag = '' + flag_expression = 'not set' + elif flag_value == '0': + flag = '--%s=0' % LIST_TESTS_FLAG + flag_expression = '0' + else: + flag = '--%s' % LIST_TESTS_FLAG + flag_expression = '1' + + args = [flag] + + if other_flag is not None: + args += [other_flag] + + output = Run(args) + + if expected_output_re: + self.assert_( + expected_output_re.match(output), + ('when %s is %s, the output of "%s" is "%s",\n' + 'which does not match regex "%s"' % + (LIST_TESTS_FLAG, flag_expression, ' '.join(args), output, + expected_output_re.pattern))) + else: + self.assert_( + not EXPECTED_OUTPUT_NO_FILTER_RE.match(output), + ('when %s is %s, the output of "%s" is "%s"'% + (LIST_TESTS_FLAG, flag_expression, ' '.join(args), output))) + + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" + + self.RunAndVerify(flag_value=None, + expected_output_re=None, + other_flag=None) + + def testFlag(self): + """Tests using the --gtest_list_tests flag.""" + + self.RunAndVerify(flag_value='0', + expected_output_re=None, + other_flag=None) + self.RunAndVerify(flag_value='1', + expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, + other_flag=None) + + def testOverrideNonFilterFlags(self): + """Tests that --gtest_list_tests overrides the non-filter flags.""" + + self.RunAndVerify(flag_value='1', + expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, + other_flag='--gtest_break_on_failure') + + def testWithFilterFlags(self): + """Tests that --gtest_list_tests takes into account the + --gtest_filter flag.""" + + self.RunAndVerify(flag_value='1', + expected_output_re=EXPECTED_OUTPUT_FILTER_FOO_RE, + other_flag='--gtest_filter=Foo*') + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest_.cc new file mode 100644 index 000000000000..493c6f00464a --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-list-tests-unittest_.cc @@ -0,0 +1,156 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Unit test for Google Test's --gtest_list_tests flag. +// +// A user can ask Google Test to list all tests that will run +// so that when using a filter, a user will know what +// tests to look for. The tests will not be run after listing. +// +// This program will be invoked from a Python unit test. +// Don't run it directly. + +#include "gtest/gtest.h" + +// Several different test cases and tests that will be listed. +TEST(Foo, Bar1) { +} + +TEST(Foo, Bar2) { +} + +TEST(Foo, DISABLED_Bar3) { +} + +TEST(Abc, Xyz) { +} + +TEST(Abc, Def) { +} + +TEST(FooBar, Baz) { +} + +class FooTest : public testing::Test { +}; + +TEST_F(FooTest, Test1) { +} + +TEST_F(FooTest, DISABLED_Test2) { +} + +TEST_F(FooTest, Test3) { +} + +TEST(FooDeathTest, Test1) { +} + +// A group of value-parameterized tests. + +class MyType { + public: + explicit MyType(const std::string& a_value) : value_(a_value) {} + + const std::string& value() const { return value_; } + + private: + std::string value_; +}; + +// Teaches Google Test how to print a MyType. +void PrintTo(const MyType& x, std::ostream* os) { + *os << x.value(); +} + +class ValueParamTest : public testing::TestWithParam { +}; + +TEST_P(ValueParamTest, TestA) { +} + +TEST_P(ValueParamTest, TestB) { +} + +INSTANTIATE_TEST_SUITE_P( + MyInstantiation, ValueParamTest, + testing::Values(MyType("one line"), + MyType("two\nlines"), + MyType("a very\nloooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong line"))); // NOLINT + +// A group of typed tests. + +// A deliberately long type name for testing the line-truncating +// behavior when printing a type parameter. +class VeryLoooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooooogName { // NOLINT +}; + +template +class TypedTest : public testing::Test { +}; + +template +class MyArray { +}; + +typedef testing::Types > MyTypes; + +TYPED_TEST_SUITE(TypedTest, MyTypes); + +TYPED_TEST(TypedTest, TestA) { +} + +TYPED_TEST(TypedTest, TestB) { +} + +// A group of type-parameterized tests. + +template +class TypeParamTest : public testing::Test { +}; + +TYPED_TEST_SUITE_P(TypeParamTest); + +TYPED_TEST_P(TypeParamTest, TestA) { +} + +TYPED_TEST_P(TypeParamTest, TestB) { +} + +REGISTER_TYPED_TEST_SUITE_P(TypeParamTest, TestA, TestB); + +INSTANTIATE_TYPED_TEST_SUITE_P(My, TypeParamTest, MyTypes); + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-listener-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-listener-test.cc new file mode 100644 index 000000000000..e7f9b13f51c9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-listener-test.cc @@ -0,0 +1,519 @@ +// Copyright 2009 Google Inc. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file verifies Google Test event listeners receive events at the +// right times. + +#include + +#include "gtest/gtest.h" +#include "gtest/internal/custom/gtest.h" + +using ::testing::AddGlobalTestEnvironment; +using ::testing::Environment; +using ::testing::InitGoogleTest; +using ::testing::Test; +using ::testing::TestSuite; +using ::testing::TestEventListener; +using ::testing::TestInfo; +using ::testing::TestPartResult; +using ::testing::UnitTest; + +// Used by tests to register their events. +std::vector* g_events = nullptr; + +namespace testing { +namespace internal { + +class EventRecordingListener : public TestEventListener { + public: + explicit EventRecordingListener(const char* name) : name_(name) {} + + protected: + void OnTestProgramStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnTestProgramStart")); + } + + void OnTestIterationStart(const UnitTest& /*unit_test*/, + int iteration) override { + Message message; + message << GetFullMethodName("OnTestIterationStart") + << "(" << iteration << ")"; + g_events->push_back(message.GetString()); + } + + void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsSetUpStart")); + } + + void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsSetUpEnd")); + } +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseStart(const TestCase& /*test_case*/) override { + g_events->push_back(GetFullMethodName("OnTestCaseStart")); + } +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnTestStart(const TestInfo& /*test_info*/) override { + g_events->push_back(GetFullMethodName("OnTestStart")); + } + + void OnTestPartResult(const TestPartResult& /*test_part_result*/) override { + g_events->push_back(GetFullMethodName("OnTestPartResult")); + } + + void OnTestEnd(const TestInfo& /*test_info*/) override { + g_events->push_back(GetFullMethodName("OnTestEnd")); + } + +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + void OnTestCaseEnd(const TestCase& /*test_case*/) override { + g_events->push_back(GetFullMethodName("OnTestCaseEnd")); + } +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsTearDownStart")); + } + + void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsTearDownEnd")); + } + + void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int iteration) override { + Message message; + message << GetFullMethodName("OnTestIterationEnd") + << "(" << iteration << ")"; + g_events->push_back(message.GetString()); + } + + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnTestProgramEnd")); + } + + private: + std::string GetFullMethodName(const char* name) { + return name_ + "." + name; + } + + std::string name_; +}; + +// This listener is using OnTestSuiteStart, OnTestSuiteEnd API +class EventRecordingListener2 : public TestEventListener { + public: + explicit EventRecordingListener2(const char* name) : name_(name) {} + + protected: + void OnTestProgramStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnTestProgramStart")); + } + + void OnTestIterationStart(const UnitTest& /*unit_test*/, + int iteration) override { + Message message; + message << GetFullMethodName("OnTestIterationStart") << "(" << iteration + << ")"; + g_events->push_back(message.GetString()); + } + + void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsSetUpStart")); + } + + void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsSetUpEnd")); + } + + void OnTestSuiteStart(const TestSuite& /*test_suite*/) override { + g_events->push_back(GetFullMethodName("OnTestSuiteStart")); + } + + void OnTestStart(const TestInfo& /*test_info*/) override { + g_events->push_back(GetFullMethodName("OnTestStart")); + } + + void OnTestPartResult(const TestPartResult& /*test_part_result*/) override { + g_events->push_back(GetFullMethodName("OnTestPartResult")); + } + + void OnTestEnd(const TestInfo& /*test_info*/) override { + g_events->push_back(GetFullMethodName("OnTestEnd")); + } + + void OnTestSuiteEnd(const TestSuite& /*test_suite*/) override { + g_events->push_back(GetFullMethodName("OnTestSuiteEnd")); + } + + void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsTearDownStart")); + } + + void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnEnvironmentsTearDownEnd")); + } + + void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int iteration) override { + Message message; + message << GetFullMethodName("OnTestIterationEnd") << "(" << iteration + << ")"; + g_events->push_back(message.GetString()); + } + + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override { + g_events->push_back(GetFullMethodName("OnTestProgramEnd")); + } + + private: + std::string GetFullMethodName(const char* name) { return name_ + "." + name; } + + std::string name_; +}; + +class EnvironmentInvocationCatcher : public Environment { + protected: + void SetUp() override { g_events->push_back("Environment::SetUp"); } + + void TearDown() override { g_events->push_back("Environment::TearDown"); } +}; + +class ListenerTest : public Test { + protected: + static void SetUpTestSuite() { + g_events->push_back("ListenerTest::SetUpTestSuite"); + } + + static void TearDownTestSuite() { + g_events->push_back("ListenerTest::TearDownTestSuite"); + } + + void SetUp() override { g_events->push_back("ListenerTest::SetUp"); } + + void TearDown() override { g_events->push_back("ListenerTest::TearDown"); } +}; + +TEST_F(ListenerTest, DoesFoo) { + // Test execution order within a test case is not guaranteed so we are not + // recording the test name. + g_events->push_back("ListenerTest::* Test Body"); + SUCCEED(); // Triggers OnTestPartResult. +} + +TEST_F(ListenerTest, DoesBar) { + g_events->push_back("ListenerTest::* Test Body"); + SUCCEED(); // Triggers OnTestPartResult. +} + +} // namespace internal + +} // namespace testing + +using ::testing::internal::EnvironmentInvocationCatcher; +using ::testing::internal::EventRecordingListener; +using ::testing::internal::EventRecordingListener2; + +void VerifyResults(const std::vector& data, + const char* const* expected_data, + size_t expected_data_size) { + const size_t actual_size = data.size(); + // If the following assertion fails, a new entry will be appended to + // data. Hence we save data.size() first. + EXPECT_EQ(expected_data_size, actual_size); + + // Compares the common prefix. + const size_t shorter_size = expected_data_size <= actual_size ? + expected_data_size : actual_size; + size_t i = 0; + for (; i < shorter_size; ++i) { + ASSERT_STREQ(expected_data[i], data[i].c_str()) + << "at position " << i; + } + + // Prints extra elements in the actual data. + for (; i < actual_size; ++i) { + printf(" Actual event #%lu: %s\n", + static_cast(i), data[i].c_str()); + } +} + +int main(int argc, char **argv) { + std::vector events; + g_events = &events; + InitGoogleTest(&argc, argv); + + UnitTest::GetInstance()->listeners().Append( + new EventRecordingListener("1st")); + UnitTest::GetInstance()->listeners().Append( + new EventRecordingListener("2nd")); + UnitTest::GetInstance()->listeners().Append( + new EventRecordingListener2("3rd")); + + AddGlobalTestEnvironment(new EnvironmentInvocationCatcher); + + GTEST_CHECK_(events.size() == 0) + << "AddGlobalTestEnvironment should not generate any events itself."; + + GTEST_FLAG_SET(repeat, 2); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + int ret_val = RUN_ALL_TESTS(); + +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + // The deprecated OnTestSuiteStart/OnTestCaseStart events are included + const char* const expected_events[] = {"1st.OnTestProgramStart", + "2nd.OnTestProgramStart", + "3rd.OnTestProgramStart", + "1st.OnTestIterationStart(0)", + "2nd.OnTestIterationStart(0)", + "3rd.OnTestIterationStart(0)", + "1st.OnEnvironmentsSetUpStart", + "2nd.OnEnvironmentsSetUpStart", + "3rd.OnEnvironmentsSetUpStart", + "Environment::SetUp", + "3rd.OnEnvironmentsSetUpEnd", + "2nd.OnEnvironmentsSetUpEnd", + "1st.OnEnvironmentsSetUpEnd", + "3rd.OnTestSuiteStart", + "1st.OnTestCaseStart", + "2nd.OnTestCaseStart", + "ListenerTest::SetUpTestSuite", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "ListenerTest::TearDownTestSuite", + "3rd.OnTestSuiteEnd", + "2nd.OnTestCaseEnd", + "1st.OnTestCaseEnd", + "1st.OnEnvironmentsTearDownStart", + "2nd.OnEnvironmentsTearDownStart", + "3rd.OnEnvironmentsTearDownStart", + "Environment::TearDown", + "3rd.OnEnvironmentsTearDownEnd", + "2nd.OnEnvironmentsTearDownEnd", + "1st.OnEnvironmentsTearDownEnd", + "3rd.OnTestIterationEnd(0)", + "2nd.OnTestIterationEnd(0)", + "1st.OnTestIterationEnd(0)", + "1st.OnTestIterationStart(1)", + "2nd.OnTestIterationStart(1)", + "3rd.OnTestIterationStart(1)", + "1st.OnEnvironmentsSetUpStart", + "2nd.OnEnvironmentsSetUpStart", + "3rd.OnEnvironmentsSetUpStart", + "Environment::SetUp", + "3rd.OnEnvironmentsSetUpEnd", + "2nd.OnEnvironmentsSetUpEnd", + "1st.OnEnvironmentsSetUpEnd", + "3rd.OnTestSuiteStart", + "1st.OnTestCaseStart", + "2nd.OnTestCaseStart", + "ListenerTest::SetUpTestSuite", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "ListenerTest::TearDownTestSuite", + "3rd.OnTestSuiteEnd", + "2nd.OnTestCaseEnd", + "1st.OnTestCaseEnd", + "1st.OnEnvironmentsTearDownStart", + "2nd.OnEnvironmentsTearDownStart", + "3rd.OnEnvironmentsTearDownStart", + "Environment::TearDown", + "3rd.OnEnvironmentsTearDownEnd", + "2nd.OnEnvironmentsTearDownEnd", + "1st.OnEnvironmentsTearDownEnd", + "3rd.OnTestIterationEnd(1)", + "2nd.OnTestIterationEnd(1)", + "1st.OnTestIterationEnd(1)", + "3rd.OnTestProgramEnd", + "2nd.OnTestProgramEnd", + "1st.OnTestProgramEnd"}; +#else + const char* const expected_events[] = {"1st.OnTestProgramStart", + "2nd.OnTestProgramStart", + "3rd.OnTestProgramStart", + "1st.OnTestIterationStart(0)", + "2nd.OnTestIterationStart(0)", + "3rd.OnTestIterationStart(0)", + "1st.OnEnvironmentsSetUpStart", + "2nd.OnEnvironmentsSetUpStart", + "3rd.OnEnvironmentsSetUpStart", + "Environment::SetUp", + "3rd.OnEnvironmentsSetUpEnd", + "2nd.OnEnvironmentsSetUpEnd", + "1st.OnEnvironmentsSetUpEnd", + "3rd.OnTestSuiteStart", + "ListenerTest::SetUpTestSuite", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "ListenerTest::TearDownTestSuite", + "3rd.OnTestSuiteEnd", + "1st.OnEnvironmentsTearDownStart", + "2nd.OnEnvironmentsTearDownStart", + "3rd.OnEnvironmentsTearDownStart", + "Environment::TearDown", + "3rd.OnEnvironmentsTearDownEnd", + "2nd.OnEnvironmentsTearDownEnd", + "1st.OnEnvironmentsTearDownEnd", + "3rd.OnTestIterationEnd(0)", + "2nd.OnTestIterationEnd(0)", + "1st.OnTestIterationEnd(0)", + "1st.OnTestIterationStart(1)", + "2nd.OnTestIterationStart(1)", + "3rd.OnTestIterationStart(1)", + "1st.OnEnvironmentsSetUpStart", + "2nd.OnEnvironmentsSetUpStart", + "3rd.OnEnvironmentsSetUpStart", + "Environment::SetUp", + "3rd.OnEnvironmentsSetUpEnd", + "2nd.OnEnvironmentsSetUpEnd", + "1st.OnEnvironmentsSetUpEnd", + "3rd.OnTestSuiteStart", + "ListenerTest::SetUpTestSuite", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "1st.OnTestStart", + "2nd.OnTestStart", + "3rd.OnTestStart", + "ListenerTest::SetUp", + "ListenerTest::* Test Body", + "1st.OnTestPartResult", + "2nd.OnTestPartResult", + "3rd.OnTestPartResult", + "ListenerTest::TearDown", + "3rd.OnTestEnd", + "2nd.OnTestEnd", + "1st.OnTestEnd", + "ListenerTest::TearDownTestSuite", + "3rd.OnTestSuiteEnd", + "1st.OnEnvironmentsTearDownStart", + "2nd.OnEnvironmentsTearDownStart", + "3rd.OnEnvironmentsTearDownStart", + "Environment::TearDown", + "3rd.OnEnvironmentsTearDownEnd", + "2nd.OnEnvironmentsTearDownEnd", + "1st.OnEnvironmentsTearDownEnd", + "3rd.OnTestIterationEnd(1)", + "2nd.OnTestIterationEnd(1)", + "1st.OnTestIterationEnd(1)", + "3rd.OnTestProgramEnd", + "2nd.OnTestProgramEnd", + "1st.OnTestProgramEnd"}; +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + + VerifyResults(events, + expected_events, + sizeof(expected_events)/sizeof(expected_events[0])); + + // We need to check manually for ad hoc test failures that happen after + // RUN_ALL_TESTS finishes. + if (UnitTest::GetInstance()->Failed()) + ret_val = 1; + + return ret_val; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-message-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-message-test.cc new file mode 100644 index 000000000000..962d519114e6 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-message-test.cc @@ -0,0 +1,158 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for the Message class. + +#include "gtest/gtest-message.h" + +#include "gtest/gtest.h" + +namespace { + +using ::testing::Message; + +// Tests the testing::Message class + +// Tests the default constructor. +TEST(MessageTest, DefaultConstructor) { + const Message msg; + EXPECT_EQ("", msg.GetString()); +} + +// Tests the copy constructor. +TEST(MessageTest, CopyConstructor) { + const Message msg1("Hello"); + const Message msg2(msg1); + EXPECT_EQ("Hello", msg2.GetString()); +} + +// Tests constructing a Message from a C-string. +TEST(MessageTest, ConstructsFromCString) { + Message msg("Hello"); + EXPECT_EQ("Hello", msg.GetString()); +} + +// Tests streaming a float. +TEST(MessageTest, StreamsFloat) { + const std::string s = (Message() << 1.23456F << " " << 2.34567F).GetString(); + // Both numbers should be printed with enough precision. + EXPECT_PRED_FORMAT2(testing::IsSubstring, "1.234560", s.c_str()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, " 2.345669", s.c_str()); +} + +// Tests streaming a double. +TEST(MessageTest, StreamsDouble) { + const std::string s = (Message() << 1260570880.4555497 << " " + << 1260572265.1954534).GetString(); + // Both numbers should be printed with enough precision. + EXPECT_PRED_FORMAT2(testing::IsSubstring, "1260570880.45", s.c_str()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, " 1260572265.19", s.c_str()); +} + +// Tests streaming a non-char pointer. +TEST(MessageTest, StreamsPointer) { + int n = 0; + int* p = &n; + EXPECT_NE("(null)", (Message() << p).GetString()); +} + +// Tests streaming a NULL non-char pointer. +TEST(MessageTest, StreamsNullPointer) { + int* p = nullptr; + EXPECT_EQ("(null)", (Message() << p).GetString()); +} + +// Tests streaming a C string. +TEST(MessageTest, StreamsCString) { + EXPECT_EQ("Foo", (Message() << "Foo").GetString()); +} + +// Tests streaming a NULL C string. +TEST(MessageTest, StreamsNullCString) { + char* p = nullptr; + EXPECT_EQ("(null)", (Message() << p).GetString()); +} + +// Tests streaming std::string. +TEST(MessageTest, StreamsString) { + const ::std::string str("Hello"); + EXPECT_EQ("Hello", (Message() << str).GetString()); +} + +// Tests that we can output strings containing embedded NULs. +TEST(MessageTest, StreamsStringWithEmbeddedNUL) { + const char char_array_with_nul[] = + "Here's a NUL\0 and some more string"; + const ::std::string string_with_nul(char_array_with_nul, + sizeof(char_array_with_nul) - 1); + EXPECT_EQ("Here's a NUL\\0 and some more string", + (Message() << string_with_nul).GetString()); +} + +// Tests streaming a NUL char. +TEST(MessageTest, StreamsNULChar) { + EXPECT_EQ("\\0", (Message() << '\0').GetString()); +} + +// Tests streaming int. +TEST(MessageTest, StreamsInt) { + EXPECT_EQ("123", (Message() << 123).GetString()); +} + +// Tests that basic IO manipulators (endl, ends, and flush) can be +// streamed to Message. +TEST(MessageTest, StreamsBasicIoManip) { + EXPECT_EQ("Line 1.\nA NUL char \\0 in line 2.", + (Message() << "Line 1." << std::endl + << "A NUL char " << std::ends << std::flush + << " in line 2.").GetString()); +} + +// Tests Message::GetString() +TEST(MessageTest, GetString) { + Message msg; + msg << 1 << " lamb"; + EXPECT_EQ("1 lamb", msg.GetString()); +} + +// Tests streaming a Message object to an ostream. +TEST(MessageTest, StreamsToOStream) { + Message msg("Hello"); + ::std::stringstream ss; + ss << msg; + EXPECT_EQ("Hello", testing::internal::StringStreamToString(&ss)); +} + +// Tests that a Message object doesn't take up too much stack space. +TEST(MessageTest, DoesNotTakeUpMuchStackSpace) { + EXPECT_LE(sizeof(Message), 16U); +} + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-options-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-options-test.cc new file mode 100644 index 000000000000..cd386ff23dd0 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-options-test.cc @@ -0,0 +1,219 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Google Test UnitTestOptions tests +// +// This file tests classes and functions used internally by +// Google Test. They are subject to change without notice. +// +// This file is #included from gtest.cc, to avoid changing build or +// make-files on Windows and other platforms. Do not #include this file +// anywhere else! + +#include "gtest/gtest.h" + +#if GTEST_OS_WINDOWS_MOBILE +# include +#elif GTEST_OS_WINDOWS +# include +#elif GTEST_OS_OS2 +// For strcasecmp on OS/2 +#include +#endif // GTEST_OS_WINDOWS_MOBILE + +#include "src/gtest-internal-inl.h" + +namespace testing { +namespace internal { +namespace { + +// Turns the given relative path into an absolute path. +FilePath GetAbsolutePathOf(const FilePath& relative_path) { + return FilePath::ConcatPaths(FilePath::GetCurrentDir(), relative_path); +} + +// Testing UnitTestOptions::GetOutputFormat/GetOutputFile. + +TEST(XmlOutputTest, GetOutputFormatDefault) { + GTEST_FLAG_SET(output, ""); + EXPECT_STREQ("", UnitTestOptions::GetOutputFormat().c_str()); +} + +TEST(XmlOutputTest, GetOutputFormat) { + GTEST_FLAG_SET(output, "xml:filename"); + EXPECT_STREQ("xml", UnitTestOptions::GetOutputFormat().c_str()); +} + +TEST(XmlOutputTest, GetOutputFileDefault) { + GTEST_FLAG_SET(output, ""); + EXPECT_EQ(GetAbsolutePathOf(FilePath("test_detail.xml")).string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +} + +TEST(XmlOutputTest, GetOutputFileSingleFile) { + GTEST_FLAG_SET(output, "xml:filename.abc"); + EXPECT_EQ(GetAbsolutePathOf(FilePath("filename.abc")).string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +} + +TEST(XmlOutputTest, GetOutputFileFromDirectoryPath) { + GTEST_FLAG_SET(output, "xml:path" GTEST_PATH_SEP_); + const std::string expected_output_file = + GetAbsolutePathOf( + FilePath(std::string("path") + GTEST_PATH_SEP_ + + GetCurrentExecutableName().string() + ".xml")).string(); + const std::string& output_file = + UnitTestOptions::GetAbsolutePathToOutputFile(); +#if GTEST_OS_WINDOWS + EXPECT_STRCASEEQ(expected_output_file.c_str(), output_file.c_str()); +#else + EXPECT_EQ(expected_output_file, output_file.c_str()); +#endif +} + +TEST(OutputFileHelpersTest, GetCurrentExecutableName) { + const std::string exe_str = GetCurrentExecutableName().string(); +#if GTEST_OS_WINDOWS + const bool success = + _strcmpi("googletest-options-test", exe_str.c_str()) == 0 || + _strcmpi("gtest-options-ex_test", exe_str.c_str()) == 0 || + _strcmpi("gtest_all_test", exe_str.c_str()) == 0 || + _strcmpi("gtest_dll_test", exe_str.c_str()) == 0; +#elif GTEST_OS_OS2 + const bool success = + strcasecmp("googletest-options-test", exe_str.c_str()) == 0 || + strcasecmp("gtest-options-ex_test", exe_str.c_str()) == 0 || + strcasecmp("gtest_all_test", exe_str.c_str()) == 0 || + strcasecmp("gtest_dll_test", exe_str.c_str()) == 0; +#elif GTEST_OS_FUCHSIA + const bool success = exe_str == "app"; +#else + const bool success = + exe_str == "googletest-options-test" || + exe_str == "gtest_all_test" || + exe_str == "lt-gtest_all_test" || + exe_str == "gtest_dll_test"; +#endif // GTEST_OS_WINDOWS + if (!success) + FAIL() << "GetCurrentExecutableName() returns " << exe_str; +} + +#if !GTEST_OS_FUCHSIA + +class XmlOutputChangeDirTest : public Test { + protected: + void SetUp() override { + original_working_dir_ = FilePath::GetCurrentDir(); + posix::ChDir(".."); + // This will make the test fail if run from the root directory. + EXPECT_NE(original_working_dir_.string(), + FilePath::GetCurrentDir().string()); + } + + void TearDown() override { + posix::ChDir(original_working_dir_.string().c_str()); + } + + FilePath original_working_dir_; +}; + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithDefault) { + GTEST_FLAG_SET(output, ""); + EXPECT_EQ(FilePath::ConcatPaths(original_working_dir_, + FilePath("test_detail.xml")).string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +} + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithDefaultXML) { + GTEST_FLAG_SET(output, "xml"); + EXPECT_EQ(FilePath::ConcatPaths(original_working_dir_, + FilePath("test_detail.xml")).string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +} + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithRelativeFile) { + GTEST_FLAG_SET(output, "xml:filename.abc"); + EXPECT_EQ(FilePath::ConcatPaths(original_working_dir_, + FilePath("filename.abc")).string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +} + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithRelativePath) { + GTEST_FLAG_SET(output, "xml:path" GTEST_PATH_SEP_); + const std::string expected_output_file = + FilePath::ConcatPaths( + original_working_dir_, + FilePath(std::string("path") + GTEST_PATH_SEP_ + + GetCurrentExecutableName().string() + ".xml")).string(); + const std::string& output_file = + UnitTestOptions::GetAbsolutePathToOutputFile(); +#if GTEST_OS_WINDOWS + EXPECT_STRCASEEQ(expected_output_file.c_str(), output_file.c_str()); +#else + EXPECT_EQ(expected_output_file, output_file.c_str()); +#endif +} + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithAbsoluteFile) { +#if GTEST_OS_WINDOWS + GTEST_FLAG_SET(output, "xml:c:\\tmp\\filename.abc"); + EXPECT_EQ(FilePath("c:\\tmp\\filename.abc").string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +#else + GTEST_FLAG_SET(output, "xml:/tmp/filename.abc"); + EXPECT_EQ(FilePath("/tmp/filename.abc").string(), + UnitTestOptions::GetAbsolutePathToOutputFile()); +#endif +} + +TEST_F(XmlOutputChangeDirTest, PreserveOriginalWorkingDirWithAbsolutePath) { +#if GTEST_OS_WINDOWS + const std::string path = "c:\\tmp\\"; +#else + const std::string path = "/tmp/"; +#endif + + GTEST_FLAG_SET(output, "xml:" + path); + const std::string expected_output_file = + path + GetCurrentExecutableName().string() + ".xml"; + const std::string& output_file = + UnitTestOptions::GetAbsolutePathToOutputFile(); + +#if GTEST_OS_WINDOWS + EXPECT_STRCASEEQ(expected_output_file.c_str(), output_file.c_str()); +#else + EXPECT_EQ(expected_output_file, output_file.c_str()); +#endif +} + +#endif // !GTEST_OS_FUCHSIA + +} // namespace +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test-golden-lin.txt b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test-golden-lin.txt new file mode 100644 index 000000000000..1f24fb791539 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test-golden-lin.txt @@ -0,0 +1,1196 @@ +The non-test part of the code is expected to have 2 failures. + +googletest-output-test_.cc:#: Failure +Value of: false + Actual: false +Expected: true +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 2 + 3 +Stack trace: (omitted) + +[==========] Running 89 tests from 42 test suites. +[----------] Global test environment set-up. +FooEnvironment::SetUp() called. +BarEnvironment::SetUp() called. +[----------] 1 test from ADeathTest +[ RUN ] ADeathTest.ShouldRunFirst +[ OK ] ADeathTest.ShouldRunFirst +[----------] 1 test from ATypedDeathTest/0, where TypeParam = int +[ RUN ] ATypedDeathTest/0.ShouldRunFirst +[ OK ] ATypedDeathTest/0.ShouldRunFirst +[----------] 1 test from ATypedDeathTest/1, where TypeParam = double +[ RUN ] ATypedDeathTest/1.ShouldRunFirst +[ OK ] ATypedDeathTest/1.ShouldRunFirst +[----------] 1 test from My/ATypeParamDeathTest/0, where TypeParam = int +[ RUN ] My/ATypeParamDeathTest/0.ShouldRunFirst +[ OK ] My/ATypeParamDeathTest/0.ShouldRunFirst +[----------] 1 test from My/ATypeParamDeathTest/1, where TypeParam = double +[ RUN ] My/ATypeParamDeathTest/1.ShouldRunFirst +[ OK ] My/ATypeParamDeathTest/1.ShouldRunFirst +[----------] 2 tests from PassingTest +[ RUN ] PassingTest.PassingTest1 +[ OK ] PassingTest.PassingTest1 +[ RUN ] PassingTest.PassingTest2 +[ OK ] PassingTest.PassingTest2 +[----------] 2 tests from NonfatalFailureTest +[ RUN ] NonfatalFailureTest.EscapesStringOperands +googletest-output-test_.cc:#: Failure +Expected equality of these values: + kGoldenString + Which is: "\"Line" + actual + Which is: "actual \"string\"" +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Expected equality of these values: + golden + Which is: "\"Line" + actual + Which is: "actual \"string\"" +Stack trace: (omitted) + +[ FAILED ] NonfatalFailureTest.EscapesStringOperands +[ RUN ] NonfatalFailureTest.DiffForLongStrings +googletest-output-test_.cc:#: Failure +Expected equality of these values: + golden_str + Which is: "\"Line\0 1\"\nLine 2" + "Line 2" +With diff: +@@ -1,2 @@ +-\"Line\0 1\" + Line 2 + +Stack trace: (omitted) + +[ FAILED ] NonfatalFailureTest.DiffForLongStrings +[----------] 3 tests from FatalFailureTest +[ RUN ] FatalFailureTest.FatalFailureInSubroutine +(expecting a failure that x should be 1) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + x + Which is: 2 +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.FatalFailureInSubroutine +[ RUN ] FatalFailureTest.FatalFailureInNestedSubroutine +(expecting a failure that x should be 1) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + x + Which is: 2 +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.FatalFailureInNestedSubroutine +[ RUN ] FatalFailureTest.NonfatalFailureInSubroutine +(expecting a failure on false) +googletest-output-test_.cc:#: Failure +Value of: false + Actual: false +Expected: true +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.NonfatalFailureInSubroutine +[----------] 1 test from LoggingTest +[ RUN ] LoggingTest.InterleavingLoggingAndAssertions +(expecting 2 failures on (3) >= (a[i])) +i == 0 +i == 1 +googletest-output-test_.cc:#: Failure +Expected: (3) >= (a[i]), actual: 3 vs 9 +Stack trace: (omitted) + +i == 2 +i == 3 +googletest-output-test_.cc:#: Failure +Expected: (3) >= (a[i]), actual: 3 vs 6 +Stack trace: (omitted) + +[ FAILED ] LoggingTest.InterleavingLoggingAndAssertions +[----------] 7 tests from SCOPED_TRACETest +[ RUN ] SCOPED_TRACETest.AcceptedValues +googletest-output-test_.cc:#: Failure +Failed +Just checking that all these values work fine. +Google Test trace: +googletest-output-test_.cc:#: (null) +googletest-output-test_.cc:#: 1337 +googletest-output-test_.cc:#: std::string +googletest-output-test_.cc:#: literal string +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.AcceptedValues +[ RUN ] SCOPED_TRACETest.ObeysScopes +(expected to fail) +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and shouldn't have a trace. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and should have a trace. +Google Test trace: +googletest-output-test_.cc:#: Expected trace +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and shouldn't have a trace. +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.ObeysScopes +[ RUN ] SCOPED_TRACETest.WorksInLoop +(expected to fail) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 2 + n + Which is: 1 +Google Test trace: +googletest-output-test_.cc:#: i = 1 +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + n + Which is: 2 +Google Test trace: +googletest-output-test_.cc:#: i = 2 +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.WorksInLoop +[ RUN ] SCOPED_TRACETest.WorksInSubroutine +(expected to fail) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 2 + n + Which is: 1 +Google Test trace: +googletest-output-test_.cc:#: n = 1 +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + n + Which is: 2 +Google Test trace: +googletest-output-test_.cc:#: n = 2 +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.WorksInSubroutine +[ RUN ] SCOPED_TRACETest.CanBeNested +(expected to fail) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + n + Which is: 2 +Google Test trace: +googletest-output-test_.cc:#: n = 2 +googletest-output-test_.cc:#: +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.CanBeNested +[ RUN ] SCOPED_TRACETest.CanBeRepeated +(expected to fail) +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and should contain trace point A. +Google Test trace: +googletest-output-test_.cc:#: A +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and should contain trace point A and B. +Google Test trace: +googletest-output-test_.cc:#: B +googletest-output-test_.cc:#: A +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and should contain trace point A, B, and C. +Google Test trace: +googletest-output-test_.cc:#: C +googletest-output-test_.cc:#: B +googletest-output-test_.cc:#: A +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +This failure is expected, and should contain trace point A, B, and D. +Google Test trace: +googletest-output-test_.cc:#: D +googletest-output-test_.cc:#: B +googletest-output-test_.cc:#: A +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.CanBeRepeated +[ RUN ] SCOPED_TRACETest.WorksConcurrently +(expecting 6 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected failure #1 (in thread B, only trace B alive). +Google Test trace: +googletest-output-test_.cc:#: Trace B +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #2 (in thread A, trace A & B both alive). +Google Test trace: +googletest-output-test_.cc:#: Trace A +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #3 (in thread B, trace A & B both alive). +Google Test trace: +googletest-output-test_.cc:#: Trace B +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #4 (in thread B, only trace A alive). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #5 (in thread A, only trace A alive). +Google Test trace: +googletest-output-test_.cc:#: Trace A +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #6 (in thread A, no trace alive). +Stack trace: (omitted) + +[ FAILED ] SCOPED_TRACETest.WorksConcurrently +[----------] 1 test from ScopedTraceTest +[ RUN ] ScopedTraceTest.WithExplicitFileAndLine +googletest-output-test_.cc:#: Failure +Failed +Check that the trace is attached to a particular location. +Google Test trace: +explicit_file.cc:123: expected trace message +Stack trace: (omitted) + +[ FAILED ] ScopedTraceTest.WithExplicitFileAndLine +[----------] 1 test from NonFatalFailureInFixtureConstructorTest +[ RUN ] NonFatalFailureInFixtureConstructorTest.FailureInConstructor +(expecting 5 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected failure #1, in the test fixture c'tor. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #2, in SetUp(). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #3, in the test body. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #4, in TearDown. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #5, in the test fixture d'tor. +Stack trace: (omitted) + +[ FAILED ] NonFatalFailureInFixtureConstructorTest.FailureInConstructor +[----------] 1 test from FatalFailureInFixtureConstructorTest +[ RUN ] FatalFailureInFixtureConstructorTest.FailureInConstructor +(expecting 2 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected failure #1, in the test fixture c'tor. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #2, in the test fixture d'tor. +Stack trace: (omitted) + +[ FAILED ] FatalFailureInFixtureConstructorTest.FailureInConstructor +[----------] 1 test from NonFatalFailureInSetUpTest +[ RUN ] NonFatalFailureInSetUpTest.FailureInSetUp +(expecting 4 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected failure #1, in SetUp(). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #2, in the test function. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #3, in TearDown(). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #4, in the test fixture d'tor. +Stack trace: (omitted) + +[ FAILED ] NonFatalFailureInSetUpTest.FailureInSetUp +[----------] 1 test from FatalFailureInSetUpTest +[ RUN ] FatalFailureInSetUpTest.FailureInSetUp +(expecting 3 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected failure #1, in SetUp(). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #2, in TearDown(). +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected failure #3, in the test fixture d'tor. +Stack trace: (omitted) + +[ FAILED ] FatalFailureInSetUpTest.FailureInSetUp +[----------] 1 test from AddFailureAtTest +[ RUN ] AddFailureAtTest.MessageContainsSpecifiedFileAndLineNumber +foo.cc:42: Failure +Failed +Expected nonfatal failure in foo.cc +Stack trace: (omitted) + +[ FAILED ] AddFailureAtTest.MessageContainsSpecifiedFileAndLineNumber +[----------] 1 test from GtestFailAtTest +[ RUN ] GtestFailAtTest.MessageContainsSpecifiedFileAndLineNumber +foo.cc:42: Failure +Failed +Expected fatal failure in foo.cc +Stack trace: (omitted) + +[ FAILED ] GtestFailAtTest.MessageContainsSpecifiedFileAndLineNumber +[----------] 4 tests from MixedUpTestSuiteTest +[ RUN ] MixedUpTestSuiteTest.FirstTestFromNamespaceFoo +[ OK ] MixedUpTestSuiteTest.FirstTestFromNamespaceFoo +[ RUN ] MixedUpTestSuiteTest.SecondTestFromNamespaceFoo +[ OK ] MixedUpTestSuiteTest.SecondTestFromNamespaceFoo +[ RUN ] MixedUpTestSuiteTest.ThisShouldFail +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class. However, in test suite MixedUpTestSuiteTest, +you defined test FirstTestFromNamespaceFoo and test ThisShouldFail +using two different test fixture classes. This can happen if +the two classes are from different namespaces or translation +units and have the same name. You should probably rename one +of the classes to put the tests into different test suites. +Stack trace: (omitted) + +[ FAILED ] MixedUpTestSuiteTest.ThisShouldFail +[ RUN ] MixedUpTestSuiteTest.ThisShouldFailToo +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class. However, in test suite MixedUpTestSuiteTest, +you defined test FirstTestFromNamespaceFoo and test ThisShouldFailToo +using two different test fixture classes. This can happen if +the two classes are from different namespaces or translation +units and have the same name. You should probably rename one +of the classes to put the tests into different test suites. +Stack trace: (omitted) + +[ FAILED ] MixedUpTestSuiteTest.ThisShouldFailToo +[----------] 2 tests from MixedUpTestSuiteWithSameTestNameTest +[ RUN ] MixedUpTestSuiteWithSameTestNameTest.TheSecondTestWithThisNameShouldFail +[ OK ] MixedUpTestSuiteWithSameTestNameTest.TheSecondTestWithThisNameShouldFail +[ RUN ] MixedUpTestSuiteWithSameTestNameTest.TheSecondTestWithThisNameShouldFail +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class. However, in test suite MixedUpTestSuiteWithSameTestNameTest, +you defined test TheSecondTestWithThisNameShouldFail and test TheSecondTestWithThisNameShouldFail +using two different test fixture classes. This can happen if +the two classes are from different namespaces or translation +units and have the same name. You should probably rename one +of the classes to put the tests into different test suites. +Stack trace: (omitted) + +[ FAILED ] MixedUpTestSuiteWithSameTestNameTest.TheSecondTestWithThisNameShouldFail +[----------] 2 tests from TEST_F_before_TEST_in_same_test_case +[ RUN ] TEST_F_before_TEST_in_same_test_case.DefinedUsingTEST_F +[ OK ] TEST_F_before_TEST_in_same_test_case.DefinedUsingTEST_F +[ RUN ] TEST_F_before_TEST_in_same_test_case.DefinedUsingTESTAndShouldFail +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class, so mixing TEST_F and TEST in the same test suite is +illegal. In test suite TEST_F_before_TEST_in_same_test_case, +test DefinedUsingTEST_F is defined using TEST_F but +test DefinedUsingTESTAndShouldFail is defined using TEST. You probably +want to change the TEST to TEST_F or move it to another test +case. +Stack trace: (omitted) + +[ FAILED ] TEST_F_before_TEST_in_same_test_case.DefinedUsingTESTAndShouldFail +[----------] 2 tests from TEST_before_TEST_F_in_same_test_case +[ RUN ] TEST_before_TEST_F_in_same_test_case.DefinedUsingTEST +[ OK ] TEST_before_TEST_F_in_same_test_case.DefinedUsingTEST +[ RUN ] TEST_before_TEST_F_in_same_test_case.DefinedUsingTEST_FAndShouldFail +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class, so mixing TEST_F and TEST in the same test suite is +illegal. In test suite TEST_before_TEST_F_in_same_test_case, +test DefinedUsingTEST_FAndShouldFail is defined using TEST_F but +test DefinedUsingTEST is defined using TEST. You probably +want to change the TEST to TEST_F or move it to another test +case. +Stack trace: (omitted) + +[ FAILED ] TEST_before_TEST_F_in_same_test_case.DefinedUsingTEST_FAndShouldFail +[----------] 8 tests from ExpectNonfatalFailureTest +[ RUN ] ExpectNonfatalFailureTest.CanReferenceGlobalVariables +[ OK ] ExpectNonfatalFailureTest.CanReferenceGlobalVariables +[ RUN ] ExpectNonfatalFailureTest.CanReferenceLocalVariables +[ OK ] ExpectNonfatalFailureTest.CanReferenceLocalVariables +[ RUN ] ExpectNonfatalFailureTest.SucceedsWhenThereIsOneNonfatalFailure +[ OK ] ExpectNonfatalFailureTest.SucceedsWhenThereIsOneNonfatalFailure +[ RUN ] ExpectNonfatalFailureTest.FailsWhenThereIsNoNonfatalFailure +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereIsNoNonfatalFailure +[ RUN ] ExpectNonfatalFailureTest.FailsWhenThereAreTwoNonfatalFailures +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: 2 failures +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure 1. +Stack trace: (omitted) + + +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure 2. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereAreTwoNonfatalFailures +[ RUN ] ExpectNonfatalFailureTest.FailsWhenThereIsOneFatalFailure +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereIsOneFatalFailure +[ RUN ] ExpectNonfatalFailureTest.FailsWhenStatementReturns +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenStatementReturns +[ RUN ] ExpectNonfatalFailureTest.FailsWhenStatementThrows +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenStatementThrows +[----------] 8 tests from ExpectFatalFailureTest +[ RUN ] ExpectFatalFailureTest.CanReferenceGlobalVariables +[ OK ] ExpectFatalFailureTest.CanReferenceGlobalVariables +[ RUN ] ExpectFatalFailureTest.CanReferenceLocalStaticVariables +[ OK ] ExpectFatalFailureTest.CanReferenceLocalStaticVariables +[ RUN ] ExpectFatalFailureTest.SucceedsWhenThereIsOneFatalFailure +[ OK ] ExpectFatalFailureTest.SucceedsWhenThereIsOneFatalFailure +[ RUN ] ExpectFatalFailureTest.FailsWhenThereIsNoFatalFailure +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereIsNoFatalFailure +[ RUN ] ExpectFatalFailureTest.FailsWhenThereAreTwoFatalFailures +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: 2 failures +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereAreTwoFatalFailures +[ RUN ] ExpectFatalFailureTest.FailsWhenThereIsOneNonfatalFailure +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereIsOneNonfatalFailure +[ RUN ] ExpectFatalFailureTest.FailsWhenStatementReturns +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectFatalFailureTest.FailsWhenStatementReturns +[ RUN ] ExpectFatalFailureTest.FailsWhenStatementThrows +(expecting a failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectFatalFailureTest.FailsWhenStatementThrows +[----------] 2 tests from TypedTest/0, where TypeParam = int +[ RUN ] TypedTest/0.Success +[ OK ] TypedTest/0.Success +[ RUN ] TypedTest/0.Failure +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + TypeParam() + Which is: 0 +Expected failure +Stack trace: (omitted) + +[ FAILED ] TypedTest/0.Failure, where TypeParam = int +[----------] 2 tests from TypedTestWithNames/char0, where TypeParam = char +[ RUN ] TypedTestWithNames/char0.Success +[ OK ] TypedTestWithNames/char0.Success +[ RUN ] TypedTestWithNames/char0.Failure +googletest-output-test_.cc:#: Failure +Failed +Stack trace: (omitted) + +[ FAILED ] TypedTestWithNames/char0.Failure, where TypeParam = char +[----------] 2 tests from TypedTestWithNames/int1, where TypeParam = int +[ RUN ] TypedTestWithNames/int1.Success +[ OK ] TypedTestWithNames/int1.Success +[ RUN ] TypedTestWithNames/int1.Failure +googletest-output-test_.cc:#: Failure +Failed +Stack trace: (omitted) + +[ FAILED ] TypedTestWithNames/int1.Failure, where TypeParam = int +[----------] 2 tests from Unsigned/TypedTestP/0, where TypeParam = unsigned char +[ RUN ] Unsigned/TypedTestP/0.Success +[ OK ] Unsigned/TypedTestP/0.Success +[ RUN ] Unsigned/TypedTestP/0.Failure +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1U + Which is: 1 + TypeParam() + Which is: '\0' +Expected failure +Stack trace: (omitted) + +[ FAILED ] Unsigned/TypedTestP/0.Failure, where TypeParam = unsigned char +[----------] 2 tests from Unsigned/TypedTestP/1, where TypeParam = unsigned int +[ RUN ] Unsigned/TypedTestP/1.Success +[ OK ] Unsigned/TypedTestP/1.Success +[ RUN ] Unsigned/TypedTestP/1.Failure +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1U + Which is: 1 + TypeParam() + Which is: 0 +Expected failure +Stack trace: (omitted) + +[ FAILED ] Unsigned/TypedTestP/1.Failure, where TypeParam = unsigned int +[----------] 2 tests from UnsignedCustomName/TypedTestP/unsignedChar0, where TypeParam = unsigned char +[ RUN ] UnsignedCustomName/TypedTestP/unsignedChar0.Success +[ OK ] UnsignedCustomName/TypedTestP/unsignedChar0.Success +[ RUN ] UnsignedCustomName/TypedTestP/unsignedChar0.Failure +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1U + Which is: 1 + TypeParam() + Which is: '\0' +Expected failure +Stack trace: (omitted) + +[ FAILED ] UnsignedCustomName/TypedTestP/unsignedChar0.Failure, where TypeParam = unsigned char +[----------] 2 tests from UnsignedCustomName/TypedTestP/unsignedInt1, where TypeParam = unsigned int +[ RUN ] UnsignedCustomName/TypedTestP/unsignedInt1.Success +[ OK ] UnsignedCustomName/TypedTestP/unsignedInt1.Success +[ RUN ] UnsignedCustomName/TypedTestP/unsignedInt1.Failure +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1U + Which is: 1 + TypeParam() + Which is: 0 +Expected failure +Stack trace: (omitted) + +[ FAILED ] UnsignedCustomName/TypedTestP/unsignedInt1.Failure, where TypeParam = unsigned int +[----------] 4 tests from ExpectFailureTest +[ RUN ] ExpectFailureTest.ExpectFatalFailure +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: +googletest-output-test_.cc:#: Success: +Succeeded +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure containing "Some other fatal failure expected." + Actual: +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFailureTest.ExpectFatalFailure +[ RUN ] ExpectFailureTest.ExpectNonFatalFailure +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: +googletest-output-test_.cc:#: Success: +Succeeded +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure containing "Some other non-fatal failure." + Actual: +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFailureTest.ExpectNonFatalFailure +[ RUN ] ExpectFailureTest.ExpectFatalFailureOnAllThreads +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: +googletest-output-test_.cc:#: Success: +Succeeded +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 fatal failure containing "Some other fatal failure expected." + Actual: +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFailureTest.ExpectFatalFailureOnAllThreads +[ RUN ] ExpectFailureTest.ExpectNonFatalFailureOnAllThreads +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: +googletest-output-test_.cc:#: Success: +Succeeded +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: +googletest-output-test_.cc:#: Fatal failure: +Failed +Expected fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +(expecting 1 failure) +gtest.cc:#: Failure +Expected: 1 non-fatal failure containing "Some other non-fatal failure." + Actual: +googletest-output-test_.cc:#: Non-fatal failure: +Failed +Expected non-fatal failure. +Stack trace: (omitted) + + +Stack trace: (omitted) + +[ FAILED ] ExpectFailureTest.ExpectNonFatalFailureOnAllThreads +[----------] 2 tests from ExpectFailureWithThreadsTest +[ RUN ] ExpectFailureWithThreadsTest.ExpectFatalFailure +(expecting 2 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected fatal failure. +Stack trace: (omitted) + +gtest.cc:#: Failure +Expected: 1 fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectFailureWithThreadsTest.ExpectFatalFailure +[ RUN ] ExpectFailureWithThreadsTest.ExpectNonFatalFailure +(expecting 2 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected non-fatal failure. +Stack trace: (omitted) + +gtest.cc:#: Failure +Expected: 1 non-fatal failure + Actual: 0 failures +Stack trace: (omitted) + +[ FAILED ] ExpectFailureWithThreadsTest.ExpectNonFatalFailure +[----------] 1 test from ScopedFakeTestPartResultReporterTest +[ RUN ] ScopedFakeTestPartResultReporterTest.InterceptOnlyCurrentThread +(expecting 2 failures) +googletest-output-test_.cc:#: Failure +Failed +Expected fatal failure. +Stack trace: (omitted) + +googletest-output-test_.cc:#: Failure +Failed +Expected non-fatal failure. +Stack trace: (omitted) + +[ FAILED ] ScopedFakeTestPartResultReporterTest.InterceptOnlyCurrentThread +[----------] 2 tests from DynamicFixture +DynamicFixture::SetUpTestSuite +[ RUN ] DynamicFixture.DynamicTestPass +DynamicFixture() +DynamicFixture::SetUp +DynamicFixture::TearDown +~DynamicFixture() +[ OK ] DynamicFixture.DynamicTestPass +[ RUN ] DynamicFixture.DynamicTestFail +DynamicFixture() +DynamicFixture::SetUp +googletest-output-test_.cc:#: Failure +Value of: Pass + Actual: false +Expected: true +Stack trace: (omitted) + +DynamicFixture::TearDown +~DynamicFixture() +[ FAILED ] DynamicFixture.DynamicTestFail +DynamicFixture::TearDownTestSuite +[----------] 1 test from DynamicFixtureAnotherName +DynamicFixture::SetUpTestSuite +[ RUN ] DynamicFixtureAnotherName.DynamicTestPass +DynamicFixture() +DynamicFixture::SetUp +DynamicFixture::TearDown +~DynamicFixture() +[ OK ] DynamicFixtureAnotherName.DynamicTestPass +DynamicFixture::TearDownTestSuite +[----------] 2 tests from BadDynamicFixture1 +DynamicFixture::SetUpTestSuite +[ RUN ] BadDynamicFixture1.FixtureBase +DynamicFixture() +DynamicFixture::SetUp +DynamicFixture::TearDown +~DynamicFixture() +[ OK ] BadDynamicFixture1.FixtureBase +[ RUN ] BadDynamicFixture1.TestBase +DynamicFixture() +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class, so mixing TEST_F and TEST in the same test suite is +illegal. In test suite BadDynamicFixture1, +test FixtureBase is defined using TEST_F but +test TestBase is defined using TEST. You probably +want to change the TEST to TEST_F or move it to another test +case. +Stack trace: (omitted) + +~DynamicFixture() +[ FAILED ] BadDynamicFixture1.TestBase +DynamicFixture::TearDownTestSuite +[----------] 2 tests from BadDynamicFixture2 +DynamicFixture::SetUpTestSuite +[ RUN ] BadDynamicFixture2.FixtureBase +DynamicFixture() +DynamicFixture::SetUp +DynamicFixture::TearDown +~DynamicFixture() +[ OK ] BadDynamicFixture2.FixtureBase +[ RUN ] BadDynamicFixture2.Derived +DynamicFixture() +gtest.cc:#: Failure +Failed +All tests in the same test suite must use the same test fixture +class. However, in test suite BadDynamicFixture2, +you defined test FixtureBase and test Derived +using two different test fixture classes. This can happen if +the two classes are from different namespaces or translation +units and have the same name. You should probably rename one +of the classes to put the tests into different test suites. +Stack trace: (omitted) + +~DynamicFixture() +[ FAILED ] BadDynamicFixture2.Derived +DynamicFixture::TearDownTestSuite +[----------] 1 test from TestSuiteThatFailsToSetUp +googletest-output-test_.cc:#: Failure +Value of: false + Actual: false +Expected: true +Stack trace: (omitted) + +[ RUN ] TestSuiteThatFailsToSetUp.ShouldNotRun +googletest-output-test_.cc:#: Skipped + +[ SKIPPED ] TestSuiteThatFailsToSetUp.ShouldNotRun +[----------] 1 test from PrintingFailingParams/FailingParamTest +[ RUN ] PrintingFailingParams/FailingParamTest.Fails/0 +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + GetParam() + Which is: 2 +Stack trace: (omitted) + +[ FAILED ] PrintingFailingParams/FailingParamTest.Fails/0, where GetParam() = 2 +[----------] 1 test from EmptyBasenameParamInst +[ RUN ] EmptyBasenameParamInst.Passes/0 +[ OK ] EmptyBasenameParamInst.Passes/0 +[----------] 2 tests from PrintingStrings/ParamTest +[ RUN ] PrintingStrings/ParamTest.Success/a +[ OK ] PrintingStrings/ParamTest.Success/a +[ RUN ] PrintingStrings/ParamTest.Failure/a +googletest-output-test_.cc:#: Failure +Expected equality of these values: + "b" + GetParam() + Which is: "a" +Expected failure +Stack trace: (omitted) + +[ FAILED ] PrintingStrings/ParamTest.Failure/a, where GetParam() = "a" +[----------] 3 tests from GoogleTestVerification +[ RUN ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +googletest-output-test_.cc:#: Failure +Parameterized test suite NoTests is instantiated via INSTANTIATE_TEST_SUITE_P, but no tests are defined via TEST_P . No test cases will run. + +Ideally, INSTANTIATE_TEST_SUITE_P should only ever be invoked from code that always depend on code that provides TEST_P. Failing to do so is often an indication of dead code, e.g. the last TEST_P was removed but the rest got left behind. + +To suppress this error for this test suite, insert the following line (in a non-header) in the namespace it is defined in: + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NoTests); +Stack trace: (omitted) + +[ FAILED ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +[ RUN ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +googletest-output-test_.cc:#: Failure +Parameterized test suite DetectNotInstantiatedTest is defined via TEST_P, but never instantiated. None of the test cases will run. Either no INSTANTIATE_TEST_SUITE_P is provided or the only ones provided expand to nothing. + +Ideally, TEST_P definitions should only ever be included as part of binaries that intend to use them. (As opposed to, for example, being placed in a library that may be linked in to get other utilities.) + +To suppress this error for this test suite, insert the following line (in a non-header) in the namespace it is defined in: + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(DetectNotInstantiatedTest); +Stack trace: (omitted) + +[ FAILED ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +[ RUN ] GoogleTestVerification.UninstantiatedTypeParameterizedTestSuite +googletest-output-test_.cc:#: Failure +Type parameterized test suite DetectNotInstantiatedTypesTest is defined via REGISTER_TYPED_TEST_SUITE_P, but never instantiated via INSTANTIATE_TYPED_TEST_SUITE_P. None of the test cases will run. + +Ideally, TYPED_TEST_P definitions should only ever be included as part of binaries that intend to use them. (As opposed to, for example, being placed in a library that may be linked in to get other utilities.) + +To suppress this error for this test suite, insert the following line (in a non-header) in the namespace it is defined in: + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(DetectNotInstantiatedTypesTest); +Stack trace: (omitted) + +[ FAILED ] GoogleTestVerification.UninstantiatedTypeParameterizedTestSuite +[----------] Global test environment tear-down +BarEnvironment::TearDown() called. +googletest-output-test_.cc:#: Failure +Failed +Expected non-fatal failure. +Stack trace: (omitted) + +FooEnvironment::TearDown() called. +googletest-output-test_.cc:#: Failure +Failed +Expected fatal failure. +Stack trace: (omitted) + +[==========] 89 tests from 42 test suites ran. +[ PASSED ] 31 tests. +[ SKIPPED ] 1 test, listed below: +[ SKIPPED ] TestSuiteThatFailsToSetUp.ShouldNotRun +[ FAILED ] 57 tests, listed below: +[ FAILED ] NonfatalFailureTest.EscapesStringOperands +[ FAILED ] NonfatalFailureTest.DiffForLongStrings +[ FAILED ] FatalFailureTest.FatalFailureInSubroutine +[ FAILED ] FatalFailureTest.FatalFailureInNestedSubroutine +[ FAILED ] FatalFailureTest.NonfatalFailureInSubroutine +[ FAILED ] LoggingTest.InterleavingLoggingAndAssertions +[ FAILED ] SCOPED_TRACETest.AcceptedValues +[ FAILED ] SCOPED_TRACETest.ObeysScopes +[ FAILED ] SCOPED_TRACETest.WorksInLoop +[ FAILED ] SCOPED_TRACETest.WorksInSubroutine +[ FAILED ] SCOPED_TRACETest.CanBeNested +[ FAILED ] SCOPED_TRACETest.CanBeRepeated +[ FAILED ] SCOPED_TRACETest.WorksConcurrently +[ FAILED ] ScopedTraceTest.WithExplicitFileAndLine +[ FAILED ] NonFatalFailureInFixtureConstructorTest.FailureInConstructor +[ FAILED ] FatalFailureInFixtureConstructorTest.FailureInConstructor +[ FAILED ] NonFatalFailureInSetUpTest.FailureInSetUp +[ FAILED ] FatalFailureInSetUpTest.FailureInSetUp +[ FAILED ] AddFailureAtTest.MessageContainsSpecifiedFileAndLineNumber +[ FAILED ] GtestFailAtTest.MessageContainsSpecifiedFileAndLineNumber +[ FAILED ] MixedUpTestSuiteTest.ThisShouldFail +[ FAILED ] MixedUpTestSuiteTest.ThisShouldFailToo +[ FAILED ] MixedUpTestSuiteWithSameTestNameTest.TheSecondTestWithThisNameShouldFail +[ FAILED ] TEST_F_before_TEST_in_same_test_case.DefinedUsingTESTAndShouldFail +[ FAILED ] TEST_before_TEST_F_in_same_test_case.DefinedUsingTEST_FAndShouldFail +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereIsNoNonfatalFailure +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereAreTwoNonfatalFailures +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenThereIsOneFatalFailure +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenStatementReturns +[ FAILED ] ExpectNonfatalFailureTest.FailsWhenStatementThrows +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereIsNoFatalFailure +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereAreTwoFatalFailures +[ FAILED ] ExpectFatalFailureTest.FailsWhenThereIsOneNonfatalFailure +[ FAILED ] ExpectFatalFailureTest.FailsWhenStatementReturns +[ FAILED ] ExpectFatalFailureTest.FailsWhenStatementThrows +[ FAILED ] TypedTest/0.Failure, where TypeParam = int +[ FAILED ] TypedTestWithNames/char0.Failure, where TypeParam = char +[ FAILED ] TypedTestWithNames/int1.Failure, where TypeParam = int +[ FAILED ] Unsigned/TypedTestP/0.Failure, where TypeParam = unsigned char +[ FAILED ] Unsigned/TypedTestP/1.Failure, where TypeParam = unsigned int +[ FAILED ] UnsignedCustomName/TypedTestP/unsignedChar0.Failure, where TypeParam = unsigned char +[ FAILED ] UnsignedCustomName/TypedTestP/unsignedInt1.Failure, where TypeParam = unsigned int +[ FAILED ] ExpectFailureTest.ExpectFatalFailure +[ FAILED ] ExpectFailureTest.ExpectNonFatalFailure +[ FAILED ] ExpectFailureTest.ExpectFatalFailureOnAllThreads +[ FAILED ] ExpectFailureTest.ExpectNonFatalFailureOnAllThreads +[ FAILED ] ExpectFailureWithThreadsTest.ExpectFatalFailure +[ FAILED ] ExpectFailureWithThreadsTest.ExpectNonFatalFailure +[ FAILED ] ScopedFakeTestPartResultReporterTest.InterceptOnlyCurrentThread +[ FAILED ] DynamicFixture.DynamicTestFail +[ FAILED ] BadDynamicFixture1.TestBase +[ FAILED ] BadDynamicFixture2.Derived +[ FAILED ] PrintingFailingParams/FailingParamTest.Fails/0, where GetParam() = 2 +[ FAILED ] PrintingStrings/ParamTest.Failure/a, where GetParam() = "a" +[ FAILED ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +[ FAILED ] GoogleTestVerification.UninstantiatedParameterizedTestSuite +[ FAILED ] GoogleTestVerification.UninstantiatedTypeParameterizedTestSuite + +57 FAILED TESTS +[ FAILED ] TestSuiteThatFailsToSetUp: SetUpTestSuite or TearDownTestSuite + + 1 FAILED TEST SUITE + YOU HAVE 1 DISABLED TEST + +Note: Google Test filter = FatalFailureTest.*:LoggingTest.* +[==========] Running 4 tests from 2 test suites. +[----------] Global test environment set-up. +[----------] 3 tests from FatalFailureTest +[ RUN ] FatalFailureTest.FatalFailureInSubroutine +(expecting a failure that x should be 1) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + x + Which is: 2 +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.FatalFailureInSubroutine (? ms) +[ RUN ] FatalFailureTest.FatalFailureInNestedSubroutine +(expecting a failure that x should be 1) +googletest-output-test_.cc:#: Failure +Expected equality of these values: + 1 + x + Which is: 2 +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.FatalFailureInNestedSubroutine (? ms) +[ RUN ] FatalFailureTest.NonfatalFailureInSubroutine +(expecting a failure on false) +googletest-output-test_.cc:#: Failure +Value of: false + Actual: false +Expected: true +Stack trace: (omitted) + +[ FAILED ] FatalFailureTest.NonfatalFailureInSubroutine (? ms) +[----------] 3 tests from FatalFailureTest (? ms total) + +[----------] 1 test from LoggingTest +[ RUN ] LoggingTest.InterleavingLoggingAndAssertions +(expecting 2 failures on (3) >= (a[i])) +i == 0 +i == 1 +googletest-output-test_.cc:#: Failure +Expected: (3) >= (a[i]), actual: 3 vs 9 +Stack trace: (omitted) + +i == 2 +i == 3 +googletest-output-test_.cc:#: Failure +Expected: (3) >= (a[i]), actual: 3 vs 6 +Stack trace: (omitted) + +[ FAILED ] LoggingTest.InterleavingLoggingAndAssertions (? ms) +[----------] 1 test from LoggingTest (? ms total) + +[----------] Global test environment tear-down +[==========] 4 tests from 2 test suites ran. (? ms total) +[ PASSED ] 0 tests. +[ FAILED ] 4 tests, listed below: +[ FAILED ] FatalFailureTest.FatalFailureInSubroutine +[ FAILED ] FatalFailureTest.FatalFailureInNestedSubroutine +[ FAILED ] FatalFailureTest.NonfatalFailureInSubroutine +[ FAILED ] LoggingTest.InterleavingLoggingAndAssertions + + 4 FAILED TESTS +Note: Google Test filter = *DISABLED_* +[==========] Running 1 test from 1 test suite. +[----------] Global test environment set-up. +[----------] 1 test from DisabledTestsWarningTest +[ RUN ] DisabledTestsWarningTest.DISABLED_AlsoRunDisabledTestsFlagSuppressesWarning +[ OK ] DisabledTestsWarningTest.DISABLED_AlsoRunDisabledTestsFlagSuppressesWarning +[----------] Global test environment tear-down +[==========] 1 test from 1 test suite ran. +[ PASSED ] 1 test. +Note: Google Test filter = PassingTest.* +Note: This is test shard 2 of 2. +[==========] Running 1 test from 1 test suite. +[----------] Global test environment set-up. +[----------] 1 test from PassingTest +[ RUN ] PassingTest.PassingTest2 +[ OK ] PassingTest.PassingTest2 +[----------] Global test environment tear-down +[==========] 1 test from 1 test suite ran. +[ PASSED ] 1 test. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test.py new file mode 100755 index 000000000000..ff44483331fb --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test.py @@ -0,0 +1,346 @@ +#!/usr/bin/env python +# +# Copyright 2008, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +r"""Tests the text output of Google C++ Testing and Mocking Framework. + +To update the golden file: +googletest_output_test.py --build_dir=BUILD/DIR --gengolden +where BUILD/DIR contains the built googletest-output-test_ file. +googletest_output_test.py --gengolden +googletest_output_test.py +""" + +import difflib +import os +import re +import sys +from googletest.test import gtest_test_utils + + +# The flag for generating the golden file +GENGOLDEN_FLAG = '--gengolden' +CATCH_EXCEPTIONS_ENV_VAR_NAME = 'GTEST_CATCH_EXCEPTIONS' + +# The flag indicating stacktraces are not supported +NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' + +IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' +IS_WINDOWS = os.name == 'nt' + +GOLDEN_NAME = 'googletest-output-test-golden-lin.txt' + +PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath('googletest-output-test_') + +# At least one command we exercise must not have the +# 'internal_skip_environment_and_ad_hoc_tests' argument. +COMMAND_LIST_TESTS = ({}, [PROGRAM_PATH, '--gtest_list_tests']) +COMMAND_WITH_COLOR = ({}, [PROGRAM_PATH, '--gtest_color=yes']) +COMMAND_WITH_TIME = ({}, [PROGRAM_PATH, + '--gtest_print_time', + 'internal_skip_environment_and_ad_hoc_tests', + '--gtest_filter=FatalFailureTest.*:LoggingTest.*']) +COMMAND_WITH_DISABLED = ( + {}, [PROGRAM_PATH, + '--gtest_also_run_disabled_tests', + 'internal_skip_environment_and_ad_hoc_tests', + '--gtest_filter=*DISABLED_*']) +COMMAND_WITH_SHARDING = ( + {'GTEST_SHARD_INDEX': '1', 'GTEST_TOTAL_SHARDS': '2'}, + [PROGRAM_PATH, + 'internal_skip_environment_and_ad_hoc_tests', + '--gtest_filter=PassingTest.*']) + +GOLDEN_PATH = os.path.join(gtest_test_utils.GetSourceDir(), GOLDEN_NAME) + + +def ToUnixLineEnding(s): + """Changes all Windows/Mac line endings in s to UNIX line endings.""" + + return s.replace('\r\n', '\n').replace('\r', '\n') + + +def RemoveLocations(test_output): + """Removes all file location info from a Google Test program's output. + + Args: + test_output: the output of a Google Test program. + + Returns: + output with all file location info (in the form of + 'DIRECTORY/FILE_NAME:LINE_NUMBER: 'or + 'DIRECTORY\\FILE_NAME(LINE_NUMBER): ') replaced by + 'FILE_NAME:#: '. + """ + + return re.sub(r'.*[/\\]((googletest-output-test_|gtest).cc)(\:\d+|\(\d+\))\: ', + r'\1:#: ', test_output) + + +def RemoveStackTraceDetails(output): + """Removes all stack traces from a Google Test program's output.""" + + # *? means "find the shortest string that matches". + return re.sub(r'Stack trace:(.|\n)*?\n\n', + 'Stack trace: (omitted)\n\n', output) + + +def RemoveStackTraces(output): + """Removes all traces of stack traces from a Google Test program's output.""" + + # *? means "find the shortest string that matches". + return re.sub(r'Stack trace:(.|\n)*?\n\n', '', output) + + +def RemoveTime(output): + """Removes all time information from a Google Test program's output.""" + + return re.sub(r'\(\d+ ms', '(? ms', output) + + +def RemoveTypeInfoDetails(test_output): + """Removes compiler-specific type info from Google Test program's output. + + Args: + test_output: the output of a Google Test program. + + Returns: + output with type information normalized to canonical form. + """ + + # some compilers output the name of type 'unsigned int' as 'unsigned' + return re.sub(r'unsigned int', 'unsigned', test_output) + + +def NormalizeToCurrentPlatform(test_output): + """Normalizes platform specific output details for easier comparison.""" + + if IS_WINDOWS: + # Removes the color information that is not present on Windows. + test_output = re.sub('\x1b\\[(0;3\d)?m', '', test_output) + # Changes failure message headers into the Windows format. + test_output = re.sub(r': Failure\n', r': error: ', test_output) + # Changes file(line_number) to file:line_number. + test_output = re.sub(r'((\w|\.)+)\((\d+)\):', r'\1:\3:', test_output) + + return test_output + + +def RemoveTestCounts(output): + """Removes test counts from a Google Test program's output.""" + + output = re.sub(r'\d+ tests?, listed below', + '? tests, listed below', output) + output = re.sub(r'\d+ FAILED TESTS', + '? FAILED TESTS', output) + output = re.sub(r'\d+ tests? from \d+ test cases?', + '? tests from ? test cases', output) + output = re.sub(r'\d+ tests? from ([a-zA-Z_])', + r'? tests from \1', output) + return re.sub(r'\d+ tests?\.', '? tests.', output) + + +def RemoveMatchingTests(test_output, pattern): + """Removes output of specified tests from a Google Test program's output. + + This function strips not only the beginning and the end of a test but also + all output in between. + + Args: + test_output: A string containing the test output. + pattern: A regex string that matches names of test cases or + tests to remove. + + Returns: + Contents of test_output with tests whose names match pattern removed. + """ + + test_output = re.sub( + r'.*\[ RUN \] .*%s(.|\n)*?\[( FAILED | OK )\] .*%s.*\n' % ( + pattern, pattern), + '', + test_output) + return re.sub(r'.*%s.*\n' % pattern, '', test_output) + + +def NormalizeOutput(output): + """Normalizes output (the output of googletest-output-test_.exe).""" + + output = ToUnixLineEnding(output) + output = RemoveLocations(output) + output = RemoveStackTraceDetails(output) + output = RemoveTime(output) + return output + + +def GetShellCommandOutput(env_cmd): + """Runs a command in a sub-process, and returns its output in a string. + + Args: + env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra + environment variables to set, and element 1 is a string with + the command and any flags. + + Returns: + A string with the command's combined standard and diagnostic output. + """ + + # Spawns cmd in a sub-process, and gets its standard I/O file objects. + # Set and save the environment properly. + environ = os.environ.copy() + environ.update(env_cmd[0]) + p = gtest_test_utils.Subprocess(env_cmd[1], env=environ) + + return p.output + + +def GetCommandOutput(env_cmd): + """Runs a command and returns its output with all file location + info stripped off. + + Args: + env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra + environment variables to set, and element 1 is a string with + the command and any flags. + """ + + # Disables exception pop-ups on Windows. + environ, cmdline = env_cmd + environ = dict(environ) # Ensures we are modifying a copy. + environ[CATCH_EXCEPTIONS_ENV_VAR_NAME] = '1' + return NormalizeOutput(GetShellCommandOutput((environ, cmdline))) + + +def GetOutputOfAllCommands(): + """Returns concatenated output from several representative commands.""" + + return (GetCommandOutput(COMMAND_WITH_COLOR) + + GetCommandOutput(COMMAND_WITH_TIME) + + GetCommandOutput(COMMAND_WITH_DISABLED) + + GetCommandOutput(COMMAND_WITH_SHARDING)) + + +test_list = GetShellCommandOutput(COMMAND_LIST_TESTS) +SUPPORTS_DEATH_TESTS = 'DeathTest' in test_list +SUPPORTS_TYPED_TESTS = 'TypedTest' in test_list +SUPPORTS_THREADS = 'ExpectFailureWithThreadsTest' in test_list +SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv + +CAN_GENERATE_GOLDEN_FILE = (SUPPORTS_DEATH_TESTS and + SUPPORTS_TYPED_TESTS and + SUPPORTS_THREADS and + SUPPORTS_STACK_TRACES) + +class GTestOutputTest(gtest_test_utils.TestCase): + def RemoveUnsupportedTests(self, test_output): + if not SUPPORTS_DEATH_TESTS: + test_output = RemoveMatchingTests(test_output, 'DeathTest') + if not SUPPORTS_TYPED_TESTS: + test_output = RemoveMatchingTests(test_output, 'TypedTest') + test_output = RemoveMatchingTests(test_output, 'TypedDeathTest') + test_output = RemoveMatchingTests(test_output, 'TypeParamDeathTest') + if not SUPPORTS_THREADS: + test_output = RemoveMatchingTests(test_output, + 'ExpectFailureWithThreadsTest') + test_output = RemoveMatchingTests(test_output, + 'ScopedFakeTestPartResultReporterTest') + test_output = RemoveMatchingTests(test_output, + 'WorksConcurrently') + if not SUPPORTS_STACK_TRACES: + test_output = RemoveStackTraces(test_output) + + return test_output + + def testOutput(self): + output = GetOutputOfAllCommands() + + golden_file = open(GOLDEN_PATH, 'rb') + # A mis-configured source control system can cause \r appear in EOL + # sequences when we read the golden file irrespective of an operating + # system used. Therefore, we need to strip those \r's from newlines + # unconditionally. + golden = ToUnixLineEnding(golden_file.read().decode()) + golden_file.close() + + # We want the test to pass regardless of certain features being + # supported or not. + + # We still have to remove type name specifics in all cases. + normalized_actual = RemoveTypeInfoDetails(output) + normalized_golden = RemoveTypeInfoDetails(golden) + + if CAN_GENERATE_GOLDEN_FILE: + self.assertEqual(normalized_golden, normalized_actual, + '\n'.join(difflib.unified_diff( + normalized_golden.split('\n'), + normalized_actual.split('\n'), + 'golden', 'actual'))) + else: + normalized_actual = NormalizeToCurrentPlatform( + RemoveTestCounts(normalized_actual)) + normalized_golden = NormalizeToCurrentPlatform( + RemoveTestCounts(self.RemoveUnsupportedTests(normalized_golden))) + + # This code is very handy when debugging golden file differences: + if os.getenv('DEBUG_GTEST_OUTPUT_TEST'): + open(os.path.join( + gtest_test_utils.GetSourceDir(), + '_googletest-output-test_normalized_actual.txt'), 'wb').write( + normalized_actual) + open(os.path.join( + gtest_test_utils.GetSourceDir(), + '_googletest-output-test_normalized_golden.txt'), 'wb').write( + normalized_golden) + + self.assertEqual(normalized_golden, normalized_actual) + + +if __name__ == '__main__': + if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + + if GENGOLDEN_FLAG in sys.argv: + if CAN_GENERATE_GOLDEN_FILE: + output = GetOutputOfAllCommands() + golden_file = open(GOLDEN_PATH, 'wb') + golden_file.write(output.encode()) + golden_file.close() + else: + message = ( + """Unable to write a golden file when compiled in an environment +that does not support all the required features (death tests, +typed tests, stack traces, and multiple threads). +Please build this test and generate the golden file using Blaze on Linux.""") + + sys.stderr.write(message) + sys.exit(1) + else: + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test_.cc new file mode 100644 index 000000000000..b0ad52ca3e6c --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-output-test_.cc @@ -0,0 +1,1116 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// The purpose of this file is to generate Google Test output under +// various conditions. The output will then be verified by +// googletest-output-test.py to ensure that Google Test generates the +// desired messages. Therefore, most tests in this file are MEANT TO +// FAIL. + +#include "gtest/gtest-spi.h" +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +#include + +#if _MSC_VER +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4127 /* conditional expression is constant */) +#endif // _MSC_VER + +#if GTEST_IS_THREADSAFE +using testing::ScopedFakeTestPartResultReporter; +using testing::TestPartResultArray; + +using testing::internal::Notification; +using testing::internal::ThreadWithParam; +#endif + +namespace posix = ::testing::internal::posix; + +// Tests catching fatal failures. + +// A subroutine used by the following test. +void TestEq1(int x) { + ASSERT_EQ(1, x); +} + +// This function calls a test subroutine, catches the fatal failure it +// generates, and then returns early. +void TryTestSubroutine() { + // Calls a subrountine that yields a fatal failure. + TestEq1(2); + + // Catches the fatal failure and aborts the test. + // + // The testing::Test:: prefix is necessary when calling + // HasFatalFailure() outside of a TEST, TEST_F, or test fixture. + if (testing::Test::HasFatalFailure()) return; + + // If we get here, something is wrong. + FAIL() << "This should never be reached."; +} + +TEST(PassingTest, PassingTest1) { +} + +TEST(PassingTest, PassingTest2) { +} + +// Tests that parameters of failing parameterized tests are printed in the +// failing test summary. +class FailingParamTest : public testing::TestWithParam {}; + +TEST_P(FailingParamTest, Fails) { + EXPECT_EQ(1, GetParam()); +} + +// This generates a test which will fail. Google Test is expected to print +// its parameter when it outputs the list of all failed tests. +INSTANTIATE_TEST_SUITE_P(PrintingFailingParams, + FailingParamTest, + testing::Values(2)); + +// Tests that an empty value for the test suite basename yields just +// the test name without any prior / +class EmptyBasenameParamInst : public testing::TestWithParam {}; + +TEST_P(EmptyBasenameParamInst, Passes) { EXPECT_EQ(1, GetParam()); } + +INSTANTIATE_TEST_SUITE_P(, EmptyBasenameParamInst, testing::Values(1)); + +static const char kGoldenString[] = "\"Line\0 1\"\nLine 2"; + +TEST(NonfatalFailureTest, EscapesStringOperands) { + std::string actual = "actual \"string\""; + EXPECT_EQ(kGoldenString, actual); + + const char* golden = kGoldenString; + EXPECT_EQ(golden, actual); +} + +TEST(NonfatalFailureTest, DiffForLongStrings) { + std::string golden_str(kGoldenString, sizeof(kGoldenString) - 1); + EXPECT_EQ(golden_str, "Line 2"); +} + +// Tests catching a fatal failure in a subroutine. +TEST(FatalFailureTest, FatalFailureInSubroutine) { + printf("(expecting a failure that x should be 1)\n"); + + TryTestSubroutine(); +} + +// Tests catching a fatal failure in a nested subroutine. +TEST(FatalFailureTest, FatalFailureInNestedSubroutine) { + printf("(expecting a failure that x should be 1)\n"); + + // Calls a subrountine that yields a fatal failure. + TryTestSubroutine(); + + // Catches the fatal failure and aborts the test. + // + // When calling HasFatalFailure() inside a TEST, TEST_F, or test + // fixture, the testing::Test:: prefix is not needed. + if (HasFatalFailure()) return; + + // If we get here, something is wrong. + FAIL() << "This should never be reached."; +} + +// Tests HasFatalFailure() after a failed EXPECT check. +TEST(FatalFailureTest, NonfatalFailureInSubroutine) { + printf("(expecting a failure on false)\n"); + EXPECT_TRUE(false); // Generates a nonfatal failure + ASSERT_FALSE(HasFatalFailure()); // This should succeed. +} + +// Tests interleaving user logging and Google Test assertions. +TEST(LoggingTest, InterleavingLoggingAndAssertions) { + static const int a[4] = { + 3, 9, 2, 6 + }; + + printf("(expecting 2 failures on (3) >= (a[i]))\n"); + for (int i = 0; i < static_cast(sizeof(a)/sizeof(*a)); i++) { + printf("i == %d\n", i); + EXPECT_GE(3, a[i]); + } +} + +// Tests the SCOPED_TRACE macro. + +// A helper function for testing SCOPED_TRACE. +void SubWithoutTrace(int n) { + EXPECT_EQ(1, n); + ASSERT_EQ(2, n); +} + +// Another helper function for testing SCOPED_TRACE. +void SubWithTrace(int n) { + SCOPED_TRACE(testing::Message() << "n = " << n); + + SubWithoutTrace(n); +} + +TEST(SCOPED_TRACETest, AcceptedValues) { + SCOPED_TRACE("literal string"); + SCOPED_TRACE(std::string("std::string")); + SCOPED_TRACE(1337); // streamable type + const char* null_value = nullptr; + SCOPED_TRACE(null_value); + + ADD_FAILURE() << "Just checking that all these values work fine."; +} + +// Tests that SCOPED_TRACE() obeys lexical scopes. +TEST(SCOPED_TRACETest, ObeysScopes) { + printf("(expected to fail)\n"); + + // There should be no trace before SCOPED_TRACE() is invoked. + ADD_FAILURE() << "This failure is expected, and shouldn't have a trace."; + + { + SCOPED_TRACE("Expected trace"); + // After SCOPED_TRACE(), a failure in the current scope should contain + // the trace. + ADD_FAILURE() << "This failure is expected, and should have a trace."; + } + + // Once the control leaves the scope of the SCOPED_TRACE(), there + // should be no trace again. + ADD_FAILURE() << "This failure is expected, and shouldn't have a trace."; +} + +// Tests that SCOPED_TRACE works inside a loop. +TEST(SCOPED_TRACETest, WorksInLoop) { + printf("(expected to fail)\n"); + + for (int i = 1; i <= 2; i++) { + SCOPED_TRACE(testing::Message() << "i = " << i); + + SubWithoutTrace(i); + } +} + +// Tests that SCOPED_TRACE works in a subroutine. +TEST(SCOPED_TRACETest, WorksInSubroutine) { + printf("(expected to fail)\n"); + + SubWithTrace(1); + SubWithTrace(2); +} + +// Tests that SCOPED_TRACE can be nested. +TEST(SCOPED_TRACETest, CanBeNested) { + printf("(expected to fail)\n"); + + SCOPED_TRACE(""); // A trace without a message. + + SubWithTrace(2); +} + +// Tests that multiple SCOPED_TRACEs can be used in the same scope. +TEST(SCOPED_TRACETest, CanBeRepeated) { + printf("(expected to fail)\n"); + + SCOPED_TRACE("A"); + ADD_FAILURE() + << "This failure is expected, and should contain trace point A."; + + SCOPED_TRACE("B"); + ADD_FAILURE() + << "This failure is expected, and should contain trace point A and B."; + + { + SCOPED_TRACE("C"); + ADD_FAILURE() << "This failure is expected, and should " + << "contain trace point A, B, and C."; + } + + SCOPED_TRACE("D"); + ADD_FAILURE() << "This failure is expected, and should " + << "contain trace point A, B, and D."; +} + +#if GTEST_IS_THREADSAFE +// Tests that SCOPED_TRACE()s can be used concurrently from multiple +// threads. Namely, an assertion should be affected by +// SCOPED_TRACE()s in its own thread only. + +// Here's the sequence of actions that happen in the test: +// +// Thread A (main) | Thread B (spawned) +// ===============================|================================ +// spawns thread B | +// -------------------------------+-------------------------------- +// waits for n1 | SCOPED_TRACE("Trace B"); +// | generates failure #1 +// | notifies n1 +// -------------------------------+-------------------------------- +// SCOPED_TRACE("Trace A"); | waits for n2 +// generates failure #2 | +// notifies n2 | +// -------------------------------|-------------------------------- +// waits for n3 | generates failure #3 +// | trace B dies +// | generates failure #4 +// | notifies n3 +// -------------------------------|-------------------------------- +// generates failure #5 | finishes +// trace A dies | +// generates failure #6 | +// -------------------------------|-------------------------------- +// waits for thread B to finish | + +struct CheckPoints { + Notification n1; + Notification n2; + Notification n3; +}; + +static void ThreadWithScopedTrace(CheckPoints* check_points) { + { + SCOPED_TRACE("Trace B"); + ADD_FAILURE() + << "Expected failure #1 (in thread B, only trace B alive)."; + check_points->n1.Notify(); + check_points->n2.WaitForNotification(); + + ADD_FAILURE() + << "Expected failure #3 (in thread B, trace A & B both alive)."; + } // Trace B dies here. + ADD_FAILURE() + << "Expected failure #4 (in thread B, only trace A alive)."; + check_points->n3.Notify(); +} + +TEST(SCOPED_TRACETest, WorksConcurrently) { + printf("(expecting 6 failures)\n"); + + CheckPoints check_points; + ThreadWithParam thread(&ThreadWithScopedTrace, &check_points, + nullptr); + check_points.n1.WaitForNotification(); + + { + SCOPED_TRACE("Trace A"); + ADD_FAILURE() + << "Expected failure #2 (in thread A, trace A & B both alive)."; + check_points.n2.Notify(); + check_points.n3.WaitForNotification(); + + ADD_FAILURE() + << "Expected failure #5 (in thread A, only trace A alive)."; + } // Trace A dies here. + ADD_FAILURE() + << "Expected failure #6 (in thread A, no trace alive)."; + thread.Join(); +} +#endif // GTEST_IS_THREADSAFE + +// Tests basic functionality of the ScopedTrace utility (most of its features +// are already tested in SCOPED_TRACETest). +TEST(ScopedTraceTest, WithExplicitFileAndLine) { + testing::ScopedTrace trace("explicit_file.cc", 123, "expected trace message"); + ADD_FAILURE() << "Check that the trace is attached to a particular location."; +} + +TEST(DisabledTestsWarningTest, + DISABLED_AlsoRunDisabledTestsFlagSuppressesWarning) { + // This test body is intentionally empty. Its sole purpose is for + // verifying that the --gtest_also_run_disabled_tests flag + // suppresses the "YOU HAVE 12 DISABLED TESTS" warning at the end of + // the test output. +} + +// Tests using assertions outside of TEST and TEST_F. +// +// This function creates two failures intentionally. +void AdHocTest() { + printf("The non-test part of the code is expected to have 2 failures.\n\n"); + EXPECT_TRUE(false); + EXPECT_EQ(2, 3); +} + +// Runs all TESTs, all TEST_Fs, and the ad hoc test. +int RunAllTests() { + AdHocTest(); + return RUN_ALL_TESTS(); +} + +// Tests non-fatal failures in the fixture constructor. +class NonFatalFailureInFixtureConstructorTest : public testing::Test { + protected: + NonFatalFailureInFixtureConstructorTest() { + printf("(expecting 5 failures)\n"); + ADD_FAILURE() << "Expected failure #1, in the test fixture c'tor."; + } + + ~NonFatalFailureInFixtureConstructorTest() override { + ADD_FAILURE() << "Expected failure #5, in the test fixture d'tor."; + } + + void SetUp() override { ADD_FAILURE() << "Expected failure #2, in SetUp()."; } + + void TearDown() override { + ADD_FAILURE() << "Expected failure #4, in TearDown."; + } +}; + +TEST_F(NonFatalFailureInFixtureConstructorTest, FailureInConstructor) { + ADD_FAILURE() << "Expected failure #3, in the test body."; +} + +// Tests fatal failures in the fixture constructor. +class FatalFailureInFixtureConstructorTest : public testing::Test { + protected: + FatalFailureInFixtureConstructorTest() { + printf("(expecting 2 failures)\n"); + Init(); + } + + ~FatalFailureInFixtureConstructorTest() override { + ADD_FAILURE() << "Expected failure #2, in the test fixture d'tor."; + } + + void SetUp() override { + ADD_FAILURE() << "UNEXPECTED failure in SetUp(). " + << "We should never get here, as the test fixture c'tor " + << "had a fatal failure."; + } + + void TearDown() override { + ADD_FAILURE() << "UNEXPECTED failure in TearDown(). " + << "We should never get here, as the test fixture c'tor " + << "had a fatal failure."; + } + + private: + void Init() { + FAIL() << "Expected failure #1, in the test fixture c'tor."; + } +}; + +TEST_F(FatalFailureInFixtureConstructorTest, FailureInConstructor) { + ADD_FAILURE() << "UNEXPECTED failure in the test body. " + << "We should never get here, as the test fixture c'tor " + << "had a fatal failure."; +} + +// Tests non-fatal failures in SetUp(). +class NonFatalFailureInSetUpTest : public testing::Test { + protected: + ~NonFatalFailureInSetUpTest() override { Deinit(); } + + void SetUp() override { + printf("(expecting 4 failures)\n"); + ADD_FAILURE() << "Expected failure #1, in SetUp()."; + } + + void TearDown() override { FAIL() << "Expected failure #3, in TearDown()."; } + + private: + void Deinit() { + FAIL() << "Expected failure #4, in the test fixture d'tor."; + } +}; + +TEST_F(NonFatalFailureInSetUpTest, FailureInSetUp) { + FAIL() << "Expected failure #2, in the test function."; +} + +// Tests fatal failures in SetUp(). +class FatalFailureInSetUpTest : public testing::Test { + protected: + ~FatalFailureInSetUpTest() override { Deinit(); } + + void SetUp() override { + printf("(expecting 3 failures)\n"); + FAIL() << "Expected failure #1, in SetUp()."; + } + + void TearDown() override { FAIL() << "Expected failure #2, in TearDown()."; } + + private: + void Deinit() { + FAIL() << "Expected failure #3, in the test fixture d'tor."; + } +}; + +TEST_F(FatalFailureInSetUpTest, FailureInSetUp) { + FAIL() << "UNEXPECTED failure in the test function. " + << "We should never get here, as SetUp() failed."; +} + +TEST(AddFailureAtTest, MessageContainsSpecifiedFileAndLineNumber) { + ADD_FAILURE_AT("foo.cc", 42) << "Expected nonfatal failure in foo.cc"; +} + +TEST(GtestFailAtTest, MessageContainsSpecifiedFileAndLineNumber) { + GTEST_FAIL_AT("foo.cc", 42) << "Expected fatal failure in foo.cc"; +} + +// The MixedUpTestSuiteTest test case verifies that Google Test will fail a +// test if it uses a different fixture class than what other tests in +// the same test case use. It deliberately contains two fixture +// classes with the same name but defined in different namespaces. + +// The MixedUpTestSuiteWithSameTestNameTest test case verifies that +// when the user defines two tests with the same test case name AND +// same test name (but in different namespaces), the second test will +// fail. + +namespace foo { + +class MixedUpTestSuiteTest : public testing::Test { +}; + +TEST_F(MixedUpTestSuiteTest, FirstTestFromNamespaceFoo) {} +TEST_F(MixedUpTestSuiteTest, SecondTestFromNamespaceFoo) {} + +class MixedUpTestSuiteWithSameTestNameTest : public testing::Test { +}; + +TEST_F(MixedUpTestSuiteWithSameTestNameTest, + TheSecondTestWithThisNameShouldFail) {} + +} // namespace foo + +namespace bar { + +class MixedUpTestSuiteTest : public testing::Test { +}; + +// The following two tests are expected to fail. We rely on the +// golden file to check that Google Test generates the right error message. +TEST_F(MixedUpTestSuiteTest, ThisShouldFail) {} +TEST_F(MixedUpTestSuiteTest, ThisShouldFailToo) {} + +class MixedUpTestSuiteWithSameTestNameTest : public testing::Test { +}; + +// Expected to fail. We rely on the golden file to check that Google Test +// generates the right error message. +TEST_F(MixedUpTestSuiteWithSameTestNameTest, + TheSecondTestWithThisNameShouldFail) {} + +} // namespace bar + +// The following two test cases verify that Google Test catches the user +// error of mixing TEST and TEST_F in the same test case. The first +// test case checks the scenario where TEST_F appears before TEST, and +// the second one checks where TEST appears before TEST_F. + +class TEST_F_before_TEST_in_same_test_case : public testing::Test { +}; + +TEST_F(TEST_F_before_TEST_in_same_test_case, DefinedUsingTEST_F) {} + +// Expected to fail. We rely on the golden file to check that Google Test +// generates the right error message. +TEST(TEST_F_before_TEST_in_same_test_case, DefinedUsingTESTAndShouldFail) {} + +class TEST_before_TEST_F_in_same_test_case : public testing::Test { +}; + +TEST(TEST_before_TEST_F_in_same_test_case, DefinedUsingTEST) {} + +// Expected to fail. We rely on the golden file to check that Google Test +// generates the right error message. +TEST_F(TEST_before_TEST_F_in_same_test_case, DefinedUsingTEST_FAndShouldFail) { +} + +// Used for testing EXPECT_NONFATAL_FAILURE() and EXPECT_FATAL_FAILURE(). +int global_integer = 0; + +// Tests that EXPECT_NONFATAL_FAILURE() can reference global variables. +TEST(ExpectNonfatalFailureTest, CanReferenceGlobalVariables) { + global_integer = 0; + EXPECT_NONFATAL_FAILURE({ + EXPECT_EQ(1, global_integer) << "Expected non-fatal failure."; + }, "Expected non-fatal failure."); +} + +// Tests that EXPECT_NONFATAL_FAILURE() can reference local variables +// (static or not). +TEST(ExpectNonfatalFailureTest, CanReferenceLocalVariables) { + int m = 0; + static int n; + n = 1; + EXPECT_NONFATAL_FAILURE({ + EXPECT_EQ(m, n) << "Expected non-fatal failure."; + }, "Expected non-fatal failure."); +} + +// Tests that EXPECT_NONFATAL_FAILURE() succeeds when there is exactly +// one non-fatal failure and no fatal failure. +TEST(ExpectNonfatalFailureTest, SucceedsWhenThereIsOneNonfatalFailure) { + EXPECT_NONFATAL_FAILURE({ + ADD_FAILURE() << "Expected non-fatal failure."; + }, "Expected non-fatal failure."); +} + +// Tests that EXPECT_NONFATAL_FAILURE() fails when there is no +// non-fatal failure. +TEST(ExpectNonfatalFailureTest, FailsWhenThereIsNoNonfatalFailure) { + printf("(expecting a failure)\n"); + EXPECT_NONFATAL_FAILURE({ + }, ""); +} + +// Tests that EXPECT_NONFATAL_FAILURE() fails when there are two +// non-fatal failures. +TEST(ExpectNonfatalFailureTest, FailsWhenThereAreTwoNonfatalFailures) { + printf("(expecting a failure)\n"); + EXPECT_NONFATAL_FAILURE({ + ADD_FAILURE() << "Expected non-fatal failure 1."; + ADD_FAILURE() << "Expected non-fatal failure 2."; + }, ""); +} + +// Tests that EXPECT_NONFATAL_FAILURE() fails when there is one fatal +// failure. +TEST(ExpectNonfatalFailureTest, FailsWhenThereIsOneFatalFailure) { + printf("(expecting a failure)\n"); + EXPECT_NONFATAL_FAILURE({ + FAIL() << "Expected fatal failure."; + }, ""); +} + +// Tests that EXPECT_NONFATAL_FAILURE() fails when the statement being +// tested returns. +TEST(ExpectNonfatalFailureTest, FailsWhenStatementReturns) { + printf("(expecting a failure)\n"); + EXPECT_NONFATAL_FAILURE({ + return; + }, ""); +} + +#if GTEST_HAS_EXCEPTIONS + +// Tests that EXPECT_NONFATAL_FAILURE() fails when the statement being +// tested throws. +TEST(ExpectNonfatalFailureTest, FailsWhenStatementThrows) { + printf("(expecting a failure)\n"); + try { + EXPECT_NONFATAL_FAILURE({ + throw 0; + }, ""); + } catch(int) { // NOLINT + } +} + +#endif // GTEST_HAS_EXCEPTIONS + +// Tests that EXPECT_FATAL_FAILURE() can reference global variables. +TEST(ExpectFatalFailureTest, CanReferenceGlobalVariables) { + global_integer = 0; + EXPECT_FATAL_FAILURE({ + ASSERT_EQ(1, global_integer) << "Expected fatal failure."; + }, "Expected fatal failure."); +} + +// Tests that EXPECT_FATAL_FAILURE() can reference local static +// variables. +TEST(ExpectFatalFailureTest, CanReferenceLocalStaticVariables) { + static int n; + n = 1; + EXPECT_FATAL_FAILURE({ + ASSERT_EQ(0, n) << "Expected fatal failure."; + }, "Expected fatal failure."); +} + +// Tests that EXPECT_FATAL_FAILURE() succeeds when there is exactly +// one fatal failure and no non-fatal failure. +TEST(ExpectFatalFailureTest, SucceedsWhenThereIsOneFatalFailure) { + EXPECT_FATAL_FAILURE({ + FAIL() << "Expected fatal failure."; + }, "Expected fatal failure."); +} + +// Tests that EXPECT_FATAL_FAILURE() fails when there is no fatal +// failure. +TEST(ExpectFatalFailureTest, FailsWhenThereIsNoFatalFailure) { + printf("(expecting a failure)\n"); + EXPECT_FATAL_FAILURE({ + }, ""); +} + +// A helper for generating a fatal failure. +void FatalFailure() { + FAIL() << "Expected fatal failure."; +} + +// Tests that EXPECT_FATAL_FAILURE() fails when there are two +// fatal failures. +TEST(ExpectFatalFailureTest, FailsWhenThereAreTwoFatalFailures) { + printf("(expecting a failure)\n"); + EXPECT_FATAL_FAILURE({ + FatalFailure(); + FatalFailure(); + }, ""); +} + +// Tests that EXPECT_FATAL_FAILURE() fails when there is one non-fatal +// failure. +TEST(ExpectFatalFailureTest, FailsWhenThereIsOneNonfatalFailure) { + printf("(expecting a failure)\n"); + EXPECT_FATAL_FAILURE({ + ADD_FAILURE() << "Expected non-fatal failure."; + }, ""); +} + +// Tests that EXPECT_FATAL_FAILURE() fails when the statement being +// tested returns. +TEST(ExpectFatalFailureTest, FailsWhenStatementReturns) { + printf("(expecting a failure)\n"); + EXPECT_FATAL_FAILURE({ + return; + }, ""); +} + +#if GTEST_HAS_EXCEPTIONS + +// Tests that EXPECT_FATAL_FAILURE() fails when the statement being +// tested throws. +TEST(ExpectFatalFailureTest, FailsWhenStatementThrows) { + printf("(expecting a failure)\n"); + try { + EXPECT_FATAL_FAILURE({ + throw 0; + }, ""); + } catch(int) { // NOLINT + } +} + +#endif // GTEST_HAS_EXCEPTIONS + +// This #ifdef block tests the output of value-parameterized tests. + +std::string ParamNameFunc(const testing::TestParamInfo& info) { + return info.param; +} + +class ParamTest : public testing::TestWithParam { +}; + +TEST_P(ParamTest, Success) { + EXPECT_EQ("a", GetParam()); +} + +TEST_P(ParamTest, Failure) { + EXPECT_EQ("b", GetParam()) << "Expected failure"; +} + +INSTANTIATE_TEST_SUITE_P(PrintingStrings, + ParamTest, + testing::Values(std::string("a")), + ParamNameFunc); + +// The case where a suite has INSTANTIATE_TEST_SUITE_P but not TEST_P. +using NoTests = ParamTest; +INSTANTIATE_TEST_SUITE_P(ThisIsOdd, NoTests, ::testing::Values("Hello")); + +// fails under kErrorOnUninstantiatedParameterizedTest=true +class DetectNotInstantiatedTest : public testing::TestWithParam {}; +TEST_P(DetectNotInstantiatedTest, Used) { } + +// This would make the test failure from the above go away. +// INSTANTIATE_TEST_SUITE_P(Fix, DetectNotInstantiatedTest, testing::Values(1)); + +template +class TypedTest : public testing::Test { +}; + +TYPED_TEST_SUITE(TypedTest, testing::Types); + +TYPED_TEST(TypedTest, Success) { + EXPECT_EQ(0, TypeParam()); +} + +TYPED_TEST(TypedTest, Failure) { + EXPECT_EQ(1, TypeParam()) << "Expected failure"; +} + +typedef testing::Types TypesForTestWithNames; + +template +class TypedTestWithNames : public testing::Test {}; + +class TypedTestNames { + public: + template + static std::string GetName(int i) { + if (std::is_same::value) + return std::string("char") + ::testing::PrintToString(i); + if (std::is_same::value) + return std::string("int") + ::testing::PrintToString(i); + } +}; + +TYPED_TEST_SUITE(TypedTestWithNames, TypesForTestWithNames, TypedTestNames); + +TYPED_TEST(TypedTestWithNames, Success) {} + +TYPED_TEST(TypedTestWithNames, Failure) { FAIL(); } + +template +class TypedTestP : public testing::Test { +}; + +TYPED_TEST_SUITE_P(TypedTestP); + +TYPED_TEST_P(TypedTestP, Success) { + EXPECT_EQ(0U, TypeParam()); +} + +TYPED_TEST_P(TypedTestP, Failure) { + EXPECT_EQ(1U, TypeParam()) << "Expected failure"; +} + +REGISTER_TYPED_TEST_SUITE_P(TypedTestP, Success, Failure); + +typedef testing::Types UnsignedTypes; +INSTANTIATE_TYPED_TEST_SUITE_P(Unsigned, TypedTestP, UnsignedTypes); + +class TypedTestPNames { + public: + template + static std::string GetName(int i) { + if (std::is_same::value) { + return std::string("unsignedChar") + ::testing::PrintToString(i); + } + if (std::is_same::value) { + return std::string("unsignedInt") + ::testing::PrintToString(i); + } + } +}; + +INSTANTIATE_TYPED_TEST_SUITE_P(UnsignedCustomName, TypedTestP, UnsignedTypes, + TypedTestPNames); + +template +class DetectNotInstantiatedTypesTest : public testing::Test {}; +TYPED_TEST_SUITE_P(DetectNotInstantiatedTypesTest); +TYPED_TEST_P(DetectNotInstantiatedTypesTest, Used) { + TypeParam instantiate; + (void)instantiate; +} +REGISTER_TYPED_TEST_SUITE_P(DetectNotInstantiatedTypesTest, Used); + +// kErrorOnUninstantiatedTypeParameterizedTest=true would make the above fail. +// Adding the following would make that test failure go away. +// +// typedef ::testing::Types MyTypes; +// INSTANTIATE_TYPED_TEST_SUITE_P(All, DetectNotInstantiatedTypesTest, MyTypes); + +#if GTEST_HAS_DEATH_TEST + +// We rely on the golden file to verify that tests whose test case +// name ends with DeathTest are run first. + +TEST(ADeathTest, ShouldRunFirst) { +} + +// We rely on the golden file to verify that typed tests whose test +// case name ends with DeathTest are run first. + +template +class ATypedDeathTest : public testing::Test { +}; + +typedef testing::Types NumericTypes; +TYPED_TEST_SUITE(ATypedDeathTest, NumericTypes); + +TYPED_TEST(ATypedDeathTest, ShouldRunFirst) { +} + + +// We rely on the golden file to verify that type-parameterized tests +// whose test case name ends with DeathTest are run first. + +template +class ATypeParamDeathTest : public testing::Test { +}; + +TYPED_TEST_SUITE_P(ATypeParamDeathTest); + +TYPED_TEST_P(ATypeParamDeathTest, ShouldRunFirst) { +} + +REGISTER_TYPED_TEST_SUITE_P(ATypeParamDeathTest, ShouldRunFirst); + +INSTANTIATE_TYPED_TEST_SUITE_P(My, ATypeParamDeathTest, NumericTypes); + +#endif // GTEST_HAS_DEATH_TEST + +// Tests various failure conditions of +// EXPECT_{,NON}FATAL_FAILURE{,_ON_ALL_THREADS}. +class ExpectFailureTest : public testing::Test { + public: // Must be public and not protected due to a bug in g++ 3.4.2. + enum FailureMode { + FATAL_FAILURE, + NONFATAL_FAILURE + }; + static void AddFailure(FailureMode failure) { + if (failure == FATAL_FAILURE) { + FAIL() << "Expected fatal failure."; + } else { + ADD_FAILURE() << "Expected non-fatal failure."; + } + } +}; + +TEST_F(ExpectFailureTest, ExpectFatalFailure) { + // Expected fatal failure, but succeeds. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE(SUCCEED(), "Expected fatal failure."); + // Expected fatal failure, but got a non-fatal failure. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE(AddFailure(NONFATAL_FAILURE), "Expected non-fatal " + "failure."); + // Wrong message. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE(AddFailure(FATAL_FAILURE), "Some other fatal failure " + "expected."); +} + +TEST_F(ExpectFailureTest, ExpectNonFatalFailure) { + // Expected non-fatal failure, but succeeds. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE(SUCCEED(), "Expected non-fatal failure."); + // Expected non-fatal failure, but got a fatal failure. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE(AddFailure(FATAL_FAILURE), "Expected fatal failure."); + // Wrong message. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE(AddFailure(NONFATAL_FAILURE), "Some other non-fatal " + "failure."); +} + +#if GTEST_IS_THREADSAFE + +class ExpectFailureWithThreadsTest : public ExpectFailureTest { + protected: + static void AddFailureInOtherThread(FailureMode failure) { + ThreadWithParam thread(&AddFailure, failure, nullptr); + thread.Join(); + } +}; + +TEST_F(ExpectFailureWithThreadsTest, ExpectFatalFailure) { + // We only intercept the current thread. + printf("(expecting 2 failures)\n"); + EXPECT_FATAL_FAILURE(AddFailureInOtherThread(FATAL_FAILURE), + "Expected fatal failure."); +} + +TEST_F(ExpectFailureWithThreadsTest, ExpectNonFatalFailure) { + // We only intercept the current thread. + printf("(expecting 2 failures)\n"); + EXPECT_NONFATAL_FAILURE(AddFailureInOtherThread(NONFATAL_FAILURE), + "Expected non-fatal failure."); +} + +typedef ExpectFailureWithThreadsTest ScopedFakeTestPartResultReporterTest; + +// Tests that the ScopedFakeTestPartResultReporter only catches failures from +// the current thread if it is instantiated with INTERCEPT_ONLY_CURRENT_THREAD. +TEST_F(ScopedFakeTestPartResultReporterTest, InterceptOnlyCurrentThread) { + printf("(expecting 2 failures)\n"); + TestPartResultArray results; + { + ScopedFakeTestPartResultReporter reporter( + ScopedFakeTestPartResultReporter::INTERCEPT_ONLY_CURRENT_THREAD, + &results); + AddFailureInOtherThread(FATAL_FAILURE); + AddFailureInOtherThread(NONFATAL_FAILURE); + } + // The two failures should not have been intercepted. + EXPECT_EQ(0, results.size()) << "This shouldn't fail."; +} + +#endif // GTEST_IS_THREADSAFE + +TEST_F(ExpectFailureTest, ExpectFatalFailureOnAllThreads) { + // Expected fatal failure, but succeeds. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(SUCCEED(), "Expected fatal failure."); + // Expected fatal failure, but got a non-fatal failure. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(AddFailure(NONFATAL_FAILURE), + "Expected non-fatal failure."); + // Wrong message. + printf("(expecting 1 failure)\n"); + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(AddFailure(FATAL_FAILURE), + "Some other fatal failure expected."); +} + +TEST_F(ExpectFailureTest, ExpectNonFatalFailureOnAllThreads) { + // Expected non-fatal failure, but succeeds. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(SUCCEED(), "Expected non-fatal " + "failure."); + // Expected non-fatal failure, but got a fatal failure. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(AddFailure(FATAL_FAILURE), + "Expected fatal failure."); + // Wrong message. + printf("(expecting 1 failure)\n"); + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(AddFailure(NONFATAL_FAILURE), + "Some other non-fatal failure."); +} + +class DynamicFixture : public testing::Test { + protected: + DynamicFixture() { printf("DynamicFixture()\n"); } + ~DynamicFixture() override { printf("~DynamicFixture()\n"); } + void SetUp() override { printf("DynamicFixture::SetUp\n"); } + void TearDown() override { printf("DynamicFixture::TearDown\n"); } + + static void SetUpTestSuite() { printf("DynamicFixture::SetUpTestSuite\n"); } + static void TearDownTestSuite() { + printf("DynamicFixture::TearDownTestSuite\n"); + } +}; + +template +class DynamicTest : public DynamicFixture { + public: + void TestBody() override { EXPECT_TRUE(Pass); } +}; + +auto dynamic_test = ( + // Register two tests with the same fixture correctly. + testing::RegisterTest( + "DynamicFixture", "DynamicTestPass", nullptr, nullptr, __FILE__, + __LINE__, []() -> DynamicFixture* { return new DynamicTest; }), + testing::RegisterTest( + "DynamicFixture", "DynamicTestFail", nullptr, nullptr, __FILE__, + __LINE__, []() -> DynamicFixture* { return new DynamicTest; }), + + // Register the same fixture with another name. That's fine. + testing::RegisterTest( + "DynamicFixtureAnotherName", "DynamicTestPass", nullptr, nullptr, + __FILE__, __LINE__, + []() -> DynamicFixture* { return new DynamicTest; }), + + // Register two tests with the same fixture incorrectly. + testing::RegisterTest( + "BadDynamicFixture1", "FixtureBase", nullptr, nullptr, __FILE__, + __LINE__, []() -> DynamicFixture* { return new DynamicTest; }), + testing::RegisterTest( + "BadDynamicFixture1", "TestBase", nullptr, nullptr, __FILE__, __LINE__, + []() -> testing::Test* { return new DynamicTest; }), + + // Register two tests with the same fixture incorrectly by omitting the + // return type. + testing::RegisterTest( + "BadDynamicFixture2", "FixtureBase", nullptr, nullptr, __FILE__, + __LINE__, []() -> DynamicFixture* { return new DynamicTest; }), + testing::RegisterTest("BadDynamicFixture2", "Derived", nullptr, nullptr, + __FILE__, __LINE__, + []() { return new DynamicTest; })); + +// Two test environments for testing testing::AddGlobalTestEnvironment(). + +class FooEnvironment : public testing::Environment { + public: + void SetUp() override { printf("%s", "FooEnvironment::SetUp() called.\n"); } + + void TearDown() override { + printf("%s", "FooEnvironment::TearDown() called.\n"); + FAIL() << "Expected fatal failure."; + } +}; + +class BarEnvironment : public testing::Environment { + public: + void SetUp() override { printf("%s", "BarEnvironment::SetUp() called.\n"); } + + void TearDown() override { + printf("%s", "BarEnvironment::TearDown() called.\n"); + ADD_FAILURE() << "Expected non-fatal failure."; + } +}; + +class TestSuiteThatFailsToSetUp : public testing::Test { + public: + static void SetUpTestSuite() { EXPECT_TRUE(false); } +}; +TEST_F(TestSuiteThatFailsToSetUp, ShouldNotRun) { + std::abort(); +} + +// The main function. +// +// The idea is to use Google Test to run all the tests we have defined (some +// of them are intended to fail), and then compare the test results +// with the "golden" file. +int main(int argc, char **argv) { + GTEST_FLAG_SET(print_time, false); + + // We just run the tests, knowing some of them are intended to fail. + // We will use a separate Python script to compare the output of + // this program with the golden file. + + // It's hard to test InitGoogleTest() directly, as it has many + // global side effects. The following line serves as a sanity test + // for it. + testing::InitGoogleTest(&argc, argv); + bool internal_skip_environment_and_ad_hoc_tests = + std::count(argv, argv + argc, + std::string("internal_skip_environment_and_ad_hoc_tests")) > 0; + +#if GTEST_HAS_DEATH_TEST + if (GTEST_FLAG_GET(internal_run_death_test) != "") { + // Skip the usual output capturing if we're running as the child + // process of an threadsafe-style death test. +# if GTEST_OS_WINDOWS + posix::FReopen("nul:", "w", stdout); +# else + posix::FReopen("/dev/null", "w", stdout); +# endif // GTEST_OS_WINDOWS + return RUN_ALL_TESTS(); + } +#endif // GTEST_HAS_DEATH_TEST + + if (internal_skip_environment_and_ad_hoc_tests) + return RUN_ALL_TESTS(); + + // Registers two global test environments. + // The golden file verifies that they are set up in the order they + // are registered, and torn down in the reverse order. + testing::AddGlobalTestEnvironment(new FooEnvironment); + testing::AddGlobalTestEnvironment(new BarEnvironment); +#if _MSC_VER +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4127 +#endif // _MSC_VER + return RunAllTests(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test.py new file mode 100644 index 000000000000..b8d609a700c8 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# Copyright 2015 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that Google Test warns the user when not initialized properly.""" + +from googletest.test import gtest_test_utils + +binary_name = 'googletest-param-test-invalid-name1-test_' +COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) + + +def Assert(condition): + if not condition: + raise AssertionError + + +def TestExitCodeAndOutput(command): + """Runs the given command and verifies its exit code and output.""" + + err = ('Parameterized test name \'"InvalidWithQuotes"\' is invalid') + + p = gtest_test_utils.Subprocess(command) + Assert(p.terminated_by_signal) + + # Verify the output message contains appropriate output + Assert(err in p.output) + + +class GTestParamTestInvalidName1Test(gtest_test_utils.TestCase): + + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test_.cc new file mode 100644 index 000000000000..955d699900d8 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name1-test_.cc @@ -0,0 +1,50 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest.h" + +namespace { +class DummyTest : public ::testing::TestWithParam {}; + +TEST_P(DummyTest, Dummy) { +} + +INSTANTIATE_TEST_SUITE_P(InvalidTestName, + DummyTest, + ::testing::Values("InvalidWithQuotes"), + ::testing::PrintToStringParamName()); + +} // namespace + +int main(int argc, char *argv[]) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test.py new file mode 100644 index 000000000000..d92fa065ae23 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# +# Copyright 2015 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that Google Test warns the user when not initialized properly.""" + +from googletest.test import gtest_test_utils + +binary_name = 'googletest-param-test-invalid-name2-test_' +COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) + + +def Assert(condition): + if not condition: + raise AssertionError + + +def TestExitCodeAndOutput(command): + """Runs the given command and verifies its exit code and output.""" + + err = ('Duplicate parameterized test name \'a\'') + + p = gtest_test_utils.Subprocess(command) + Assert(p.terminated_by_signal) + + # Check for appropriate output + Assert(err in p.output) + + +class GTestParamTestInvalidName2Test(gtest_test_utils.TestCase): + + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test_.cc new file mode 100644 index 000000000000..76371df54f0b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-invalid-name2-test_.cc @@ -0,0 +1,55 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest.h" + +namespace { +class DummyTest : public ::testing::TestWithParam {}; + +std::string StringParamTestSuffix( + const testing::TestParamInfo& info) { + return std::string(info.param); +} + +TEST_P(DummyTest, Dummy) { +} + +INSTANTIATE_TEST_SUITE_P(DuplicateTestNames, + DummyTest, + ::testing::Values("a", "b", "a", "c"), + StringParamTestSuffix); +} // namespace + +int main(int argc, char *argv[]) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + + diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.cc new file mode 100644 index 000000000000..023aa46d69f9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.cc @@ -0,0 +1,1119 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for Google Test itself. This file verifies that the parameter +// generators objects produce correct parameter sequences and that +// Google Test runtime instantiates correct tests from those sequences. + +#include "gtest/gtest.h" + +# include +# include +# include +# include +# include +# include +# include + +# include "src/gtest-internal-inl.h" // for UnitTestOptions +# include "test/googletest-param-test-test.h" + +using ::std::vector; +using ::std::sort; + +using ::testing::AddGlobalTestEnvironment; +using ::testing::Bool; +using ::testing::Combine; +using ::testing::Message; +using ::testing::Range; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::ValuesIn; + +using ::testing::internal::ParamGenerator; +using ::testing::internal::UnitTestOptions; + +// Prints a value to a string. +// +// FIXME: remove PrintValue() when we move matchers and +// EXPECT_THAT() from Google Mock to Google Test. At that time, we +// can write EXPECT_THAT(x, Eq(y)) to compare two tuples x and y, as +// EXPECT_THAT() and the matchers know how to print tuples. +template +::std::string PrintValue(const T& value) { + return testing::PrintToString(value); +} + +// Verifies that a sequence generated by the generator and accessed +// via the iterator object matches the expected one using Google Test +// assertions. +template +void VerifyGenerator(const ParamGenerator& generator, + const T (&expected_values)[N]) { + typename ParamGenerator::iterator it = generator.begin(); + for (size_t i = 0; i < N; ++i) { + ASSERT_FALSE(it == generator.end()) + << "At element " << i << " when accessing via an iterator " + << "created with the copy constructor.\n"; + // We cannot use EXPECT_EQ() here as the values may be tuples, + // which don't support <<. + EXPECT_TRUE(expected_values[i] == *it) + << "where i is " << i + << ", expected_values[i] is " << PrintValue(expected_values[i]) + << ", *it is " << PrintValue(*it) + << ", and 'it' is an iterator created with the copy constructor.\n"; + ++it; + } + EXPECT_TRUE(it == generator.end()) + << "At the presumed end of sequence when accessing via an iterator " + << "created with the copy constructor.\n"; + + // Test the iterator assignment. The following lines verify that + // the sequence accessed via an iterator initialized via the + // assignment operator (as opposed to a copy constructor) matches + // just the same. + it = generator.begin(); + for (size_t i = 0; i < N; ++i) { + ASSERT_FALSE(it == generator.end()) + << "At element " << i << " when accessing via an iterator " + << "created with the assignment operator.\n"; + EXPECT_TRUE(expected_values[i] == *it) + << "where i is " << i + << ", expected_values[i] is " << PrintValue(expected_values[i]) + << ", *it is " << PrintValue(*it) + << ", and 'it' is an iterator created with the copy constructor.\n"; + ++it; + } + EXPECT_TRUE(it == generator.end()) + << "At the presumed end of sequence when accessing via an iterator " + << "created with the assignment operator.\n"; +} + +template +void VerifyGeneratorIsEmpty(const ParamGenerator& generator) { + typename ParamGenerator::iterator it = generator.begin(); + EXPECT_TRUE(it == generator.end()); + + it = generator.begin(); + EXPECT_TRUE(it == generator.end()); +} + +// Generator tests. They test that each of the provided generator functions +// generates an expected sequence of values. The general test pattern +// instantiates a generator using one of the generator functions, +// checks the sequence produced by the generator using its iterator API, +// and then resets the iterator back to the beginning of the sequence +// and checks the sequence again. + +// Tests that iterators produced by generator functions conform to the +// ForwardIterator concept. +TEST(IteratorTest, ParamIteratorConformsToForwardIteratorConcept) { + const ParamGenerator gen = Range(0, 10); + ParamGenerator::iterator it = gen.begin(); + + // Verifies that iterator initialization works as expected. + ParamGenerator::iterator it2 = it; + EXPECT_TRUE(*it == *it2) << "Initialized iterators must point to the " + << "element same as its source points to"; + + // Verifies that iterator assignment works as expected. + ++it; + EXPECT_FALSE(*it == *it2); + it2 = it; + EXPECT_TRUE(*it == *it2) << "Assigned iterators must point to the " + << "element same as its source points to"; + + // Verifies that prefix operator++() returns *this. + EXPECT_EQ(&it, &(++it)) << "Result of the prefix operator++ must be " + << "refer to the original object"; + + // Verifies that the result of the postfix operator++ points to the value + // pointed to by the original iterator. + int original_value = *it; // Have to compute it outside of macro call to be + // unaffected by the parameter evaluation order. + EXPECT_EQ(original_value, *(it++)); + + // Verifies that prefix and postfix operator++() advance an iterator + // all the same. + it2 = it; + ++it; + ++it2; + EXPECT_TRUE(*it == *it2); +} + +// Tests that Range() generates the expected sequence. +TEST(RangeTest, IntRangeWithDefaultStep) { + const ParamGenerator gen = Range(0, 3); + const int expected_values[] = {0, 1, 2}; + VerifyGenerator(gen, expected_values); +} + +// Edge case. Tests that Range() generates the single element sequence +// as expected when provided with range limits that are equal. +TEST(RangeTest, IntRangeSingleValue) { + const ParamGenerator gen = Range(0, 1); + const int expected_values[] = {0}; + VerifyGenerator(gen, expected_values); +} + +// Edge case. Tests that Range() with generates empty sequence when +// supplied with an empty range. +TEST(RangeTest, IntRangeEmpty) { + const ParamGenerator gen = Range(0, 0); + VerifyGeneratorIsEmpty(gen); +} + +// Tests that Range() with custom step (greater then one) generates +// the expected sequence. +TEST(RangeTest, IntRangeWithCustomStep) { + const ParamGenerator gen = Range(0, 9, 3); + const int expected_values[] = {0, 3, 6}; + VerifyGenerator(gen, expected_values); +} + +// Tests that Range() with custom step (greater then one) generates +// the expected sequence when the last element does not fall on the +// upper range limit. Sequences generated by Range() must not have +// elements beyond the range limits. +TEST(RangeTest, IntRangeWithCustomStepOverUpperBound) { + const ParamGenerator gen = Range(0, 4, 3); + const int expected_values[] = {0, 3}; + VerifyGenerator(gen, expected_values); +} + +// Verifies that Range works with user-defined types that define +// copy constructor, operator=(), operator+(), and operator<(). +class DogAdder { + public: + explicit DogAdder(const char* a_value) : value_(a_value) {} + DogAdder(const DogAdder& other) : value_(other.value_.c_str()) {} + + DogAdder operator=(const DogAdder& other) { + if (this != &other) + value_ = other.value_; + return *this; + } + DogAdder operator+(const DogAdder& other) const { + Message msg; + msg << value_.c_str() << other.value_.c_str(); + return DogAdder(msg.GetString().c_str()); + } + bool operator<(const DogAdder& other) const { + return value_ < other.value_; + } + const std::string& value() const { return value_; } + + private: + std::string value_; +}; + +TEST(RangeTest, WorksWithACustomType) { + const ParamGenerator gen = + Range(DogAdder("cat"), DogAdder("catdogdog"), DogAdder("dog")); + ParamGenerator::iterator it = gen.begin(); + + ASSERT_FALSE(it == gen.end()); + EXPECT_STREQ("cat", it->value().c_str()); + + ASSERT_FALSE(++it == gen.end()); + EXPECT_STREQ("catdog", it->value().c_str()); + + EXPECT_TRUE(++it == gen.end()); +} + +class IntWrapper { + public: + explicit IntWrapper(int a_value) : value_(a_value) {} + IntWrapper(const IntWrapper& other) : value_(other.value_) {} + + IntWrapper operator=(const IntWrapper& other) { + value_ = other.value_; + return *this; + } + // operator+() adds a different type. + IntWrapper operator+(int other) const { return IntWrapper(value_ + other); } + bool operator<(const IntWrapper& other) const { + return value_ < other.value_; + } + int value() const { return value_; } + + private: + int value_; +}; + +TEST(RangeTest, WorksWithACustomTypeWithDifferentIncrementType) { + const ParamGenerator gen = Range(IntWrapper(0), IntWrapper(2)); + ParamGenerator::iterator it = gen.begin(); + + ASSERT_FALSE(it == gen.end()); + EXPECT_EQ(0, it->value()); + + ASSERT_FALSE(++it == gen.end()); + EXPECT_EQ(1, it->value()); + + EXPECT_TRUE(++it == gen.end()); +} + +// Tests that ValuesIn() with an array parameter generates +// the expected sequence. +TEST(ValuesInTest, ValuesInArray) { + int array[] = {3, 5, 8}; + const ParamGenerator gen = ValuesIn(array); + VerifyGenerator(gen, array); +} + +// Tests that ValuesIn() with a const array parameter generates +// the expected sequence. +TEST(ValuesInTest, ValuesInConstArray) { + const int array[] = {3, 5, 8}; + const ParamGenerator gen = ValuesIn(array); + VerifyGenerator(gen, array); +} + +// Edge case. Tests that ValuesIn() with an array parameter containing a +// single element generates the single element sequence. +TEST(ValuesInTest, ValuesInSingleElementArray) { + int array[] = {42}; + const ParamGenerator gen = ValuesIn(array); + VerifyGenerator(gen, array); +} + +// Tests that ValuesIn() generates the expected sequence for an STL +// container (vector). +TEST(ValuesInTest, ValuesInVector) { + typedef ::std::vector ContainerType; + ContainerType values; + values.push_back(3); + values.push_back(5); + values.push_back(8); + const ParamGenerator gen = ValuesIn(values); + + const int expected_values[] = {3, 5, 8}; + VerifyGenerator(gen, expected_values); +} + +// Tests that ValuesIn() generates the expected sequence. +TEST(ValuesInTest, ValuesInIteratorRange) { + typedef ::std::vector ContainerType; + ContainerType values; + values.push_back(3); + values.push_back(5); + values.push_back(8); + const ParamGenerator gen = ValuesIn(values.begin(), values.end()); + + const int expected_values[] = {3, 5, 8}; + VerifyGenerator(gen, expected_values); +} + +// Edge case. Tests that ValuesIn() provided with an iterator range specifying a +// single value generates a single-element sequence. +TEST(ValuesInTest, ValuesInSingleElementIteratorRange) { + typedef ::std::vector ContainerType; + ContainerType values; + values.push_back(42); + const ParamGenerator gen = ValuesIn(values.begin(), values.end()); + + const int expected_values[] = {42}; + VerifyGenerator(gen, expected_values); +} + +// Edge case. Tests that ValuesIn() provided with an empty iterator range +// generates an empty sequence. +TEST(ValuesInTest, ValuesInEmptyIteratorRange) { + typedef ::std::vector ContainerType; + ContainerType values; + const ParamGenerator gen = ValuesIn(values.begin(), values.end()); + + VerifyGeneratorIsEmpty(gen); +} + +// Tests that the Values() generates the expected sequence. +TEST(ValuesTest, ValuesWorks) { + const ParamGenerator gen = Values(3, 5, 8); + + const int expected_values[] = {3, 5, 8}; + VerifyGenerator(gen, expected_values); +} + +// Tests that Values() generates the expected sequences from elements of +// different types convertible to ParamGenerator's parameter type. +TEST(ValuesTest, ValuesWorksForValuesOfCompatibleTypes) { + const ParamGenerator gen = Values(3, 5.0f, 8.0); + + const double expected_values[] = {3.0, 5.0, 8.0}; + VerifyGenerator(gen, expected_values); +} + +TEST(ValuesTest, ValuesWorksForMaxLengthList) { + const ParamGenerator gen = Values( + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, + 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, + 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, + 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, + 410, 420, 430, 440, 450, 460, 470, 480, 490, 500); + + const int expected_values[] = { + 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, + 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, + 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, + 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, + 410, 420, 430, 440, 450, 460, 470, 480, 490, 500}; + VerifyGenerator(gen, expected_values); +} + +// Edge case test. Tests that single-parameter Values() generates the sequence +// with the single value. +TEST(ValuesTest, ValuesWithSingleParameter) { + const ParamGenerator gen = Values(42); + + const int expected_values[] = {42}; + VerifyGenerator(gen, expected_values); +} + +// Tests that Bool() generates sequence (false, true). +TEST(BoolTest, BoolWorks) { + const ParamGenerator gen = Bool(); + + const bool expected_values[] = {false, true}; + VerifyGenerator(gen, expected_values); +} + +// Tests that Combine() with two parameters generates the expected sequence. +TEST(CombineTest, CombineWithTwoParameters) { + const char* foo = "foo"; + const char* bar = "bar"; + const ParamGenerator > gen = + Combine(Values(foo, bar), Values(3, 4)); + + std::tuple expected_values[] = { + std::make_tuple(foo, 3), std::make_tuple(foo, 4), std::make_tuple(bar, 3), + std::make_tuple(bar, 4)}; + VerifyGenerator(gen, expected_values); +} + +// Tests that Combine() with three parameters generates the expected sequence. +TEST(CombineTest, CombineWithThreeParameters) { + const ParamGenerator > gen = + Combine(Values(0, 1), Values(3, 4), Values(5, 6)); + std::tuple expected_values[] = { + std::make_tuple(0, 3, 5), std::make_tuple(0, 3, 6), + std::make_tuple(0, 4, 5), std::make_tuple(0, 4, 6), + std::make_tuple(1, 3, 5), std::make_tuple(1, 3, 6), + std::make_tuple(1, 4, 5), std::make_tuple(1, 4, 6)}; + VerifyGenerator(gen, expected_values); +} + +// Tests that the Combine() with the first parameter generating a single value +// sequence generates a sequence with the number of elements equal to the +// number of elements in the sequence generated by the second parameter. +TEST(CombineTest, CombineWithFirstParameterSingleValue) { + const ParamGenerator > gen = + Combine(Values(42), Values(0, 1)); + + std::tuple expected_values[] = {std::make_tuple(42, 0), + std::make_tuple(42, 1)}; + VerifyGenerator(gen, expected_values); +} + +// Tests that the Combine() with the second parameter generating a single value +// sequence generates a sequence with the number of elements equal to the +// number of elements in the sequence generated by the first parameter. +TEST(CombineTest, CombineWithSecondParameterSingleValue) { + const ParamGenerator > gen = + Combine(Values(0, 1), Values(42)); + + std::tuple expected_values[] = {std::make_tuple(0, 42), + std::make_tuple(1, 42)}; + VerifyGenerator(gen, expected_values); +} + +// Tests that when the first parameter produces an empty sequence, +// Combine() produces an empty sequence, too. +TEST(CombineTest, CombineWithFirstParameterEmptyRange) { + const ParamGenerator > gen = + Combine(Range(0, 0), Values(0, 1)); + VerifyGeneratorIsEmpty(gen); +} + +// Tests that when the second parameter produces an empty sequence, +// Combine() produces an empty sequence, too. +TEST(CombineTest, CombineWithSecondParameterEmptyRange) { + const ParamGenerator > gen = + Combine(Values(0, 1), Range(1, 1)); + VerifyGeneratorIsEmpty(gen); +} + +// Edge case. Tests that combine works with the maximum number +// of parameters supported by Google Test (currently 10). +TEST(CombineTest, CombineWithMaxNumberOfParameters) { + const char* foo = "foo"; + const char* bar = "bar"; + const ParamGenerator< + std::tuple > + gen = + Combine(Values(foo, bar), Values(1), Values(2), Values(3), Values(4), + Values(5), Values(6), Values(7), Values(8), Values(9)); + + std::tuple + expected_values[] = {std::make_tuple(foo, 1, 2, 3, 4, 5, 6, 7, 8, 9), + std::make_tuple(bar, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; + VerifyGenerator(gen, expected_values); +} + +class NonDefaultConstructAssignString { + public: + NonDefaultConstructAssignString(const std::string& s) : str_(s) {} + NonDefaultConstructAssignString() = delete; + NonDefaultConstructAssignString(const NonDefaultConstructAssignString&) = + default; + NonDefaultConstructAssignString& operator=( + const NonDefaultConstructAssignString&) = delete; + ~NonDefaultConstructAssignString() = default; + + const std::string& str() const { return str_; } + + private: + std::string str_; +}; + +TEST(CombineTest, NonDefaultConstructAssign) { + const ParamGenerator > gen = + Combine(Values(0, 1), Values(NonDefaultConstructAssignString("A"), + NonDefaultConstructAssignString("B"))); + + ParamGenerator >::iterator + it = gen.begin(); + + EXPECT_EQ(0, std::get<0>(*it)); + EXPECT_EQ("A", std::get<1>(*it).str()); + ++it; + + EXPECT_EQ(0, std::get<0>(*it)); + EXPECT_EQ("B", std::get<1>(*it).str()); + ++it; + + EXPECT_EQ(1, std::get<0>(*it)); + EXPECT_EQ("A", std::get<1>(*it).str()); + ++it; + + EXPECT_EQ(1, std::get<0>(*it)); + EXPECT_EQ("B", std::get<1>(*it).str()); + ++it; + + EXPECT_TRUE(it == gen.end()); +} + + +// Tests that an generator produces correct sequence after being +// assigned from another generator. +TEST(ParamGeneratorTest, AssignmentWorks) { + ParamGenerator gen = Values(1, 2); + const ParamGenerator gen2 = Values(3, 4); + gen = gen2; + + const int expected_values[] = {3, 4}; + VerifyGenerator(gen, expected_values); +} + +// This test verifies that the tests are expanded and run as specified: +// one test per element from the sequence produced by the generator +// specified in INSTANTIATE_TEST_SUITE_P. It also verifies that the test's +// fixture constructor, SetUp(), and TearDown() have run and have been +// supplied with the correct parameters. + +// The use of environment object allows detection of the case where no test +// case functionality is run at all. In this case TearDownTestSuite will not +// be able to detect missing tests, naturally. +template +class TestGenerationEnvironment : public ::testing::Environment { + public: + static TestGenerationEnvironment* Instance() { + static TestGenerationEnvironment* instance = new TestGenerationEnvironment; + return instance; + } + + void FixtureConstructorExecuted() { fixture_constructor_count_++; } + void SetUpExecuted() { set_up_count_++; } + void TearDownExecuted() { tear_down_count_++; } + void TestBodyExecuted() { test_body_count_++; } + + void TearDown() override { + // If all MultipleTestGenerationTest tests have been de-selected + // by the filter flag, the following checks make no sense. + bool perform_check = false; + + for (int i = 0; i < kExpectedCalls; ++i) { + Message msg; + msg << "TestsExpandedAndRun/" << i; + if (UnitTestOptions::FilterMatchesTest( + "TestExpansionModule/MultipleTestGenerationTest", + msg.GetString().c_str())) { + perform_check = true; + } + } + if (perform_check) { + EXPECT_EQ(kExpectedCalls, fixture_constructor_count_) + << "Fixture constructor of ParamTestGenerationTest test case " + << "has not been run as expected."; + EXPECT_EQ(kExpectedCalls, set_up_count_) + << "Fixture SetUp method of ParamTestGenerationTest test case " + << "has not been run as expected."; + EXPECT_EQ(kExpectedCalls, tear_down_count_) + << "Fixture TearDown method of ParamTestGenerationTest test case " + << "has not been run as expected."; + EXPECT_EQ(kExpectedCalls, test_body_count_) + << "Test in ParamTestGenerationTest test case " + << "has not been run as expected."; + } + } + + private: + TestGenerationEnvironment() : fixture_constructor_count_(0), set_up_count_(0), + tear_down_count_(0), test_body_count_(0) {} + + int fixture_constructor_count_; + int set_up_count_; + int tear_down_count_; + int test_body_count_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestGenerationEnvironment); +}; + +const int test_generation_params[] = {36, 42, 72}; + +class TestGenerationTest : public TestWithParam { + public: + enum { + PARAMETER_COUNT = + sizeof(test_generation_params)/sizeof(test_generation_params[0]) + }; + + typedef TestGenerationEnvironment Environment; + + TestGenerationTest() { + Environment::Instance()->FixtureConstructorExecuted(); + current_parameter_ = GetParam(); + } + void SetUp() override { + Environment::Instance()->SetUpExecuted(); + EXPECT_EQ(current_parameter_, GetParam()); + } + void TearDown() override { + Environment::Instance()->TearDownExecuted(); + EXPECT_EQ(current_parameter_, GetParam()); + } + + static void SetUpTestSuite() { + bool all_tests_in_test_case_selected = true; + + for (int i = 0; i < PARAMETER_COUNT; ++i) { + Message test_name; + test_name << "TestsExpandedAndRun/" << i; + if ( !UnitTestOptions::FilterMatchesTest( + "TestExpansionModule/MultipleTestGenerationTest", + test_name.GetString())) { + all_tests_in_test_case_selected = false; + } + } + EXPECT_TRUE(all_tests_in_test_case_selected) + << "When running the TestGenerationTest test case all of its tests\n" + << "must be selected by the filter flag for the test case to pass.\n" + << "If not all of them are enabled, we can't reliably conclude\n" + << "that the correct number of tests have been generated."; + + collected_parameters_.clear(); + } + + static void TearDownTestSuite() { + vector expected_values(test_generation_params, + test_generation_params + PARAMETER_COUNT); + // Test execution order is not guaranteed by Google Test, + // so the order of values in collected_parameters_ can be + // different and we have to sort to compare. + sort(expected_values.begin(), expected_values.end()); + sort(collected_parameters_.begin(), collected_parameters_.end()); + + EXPECT_TRUE(collected_parameters_ == expected_values); + } + + protected: + int current_parameter_; + static vector collected_parameters_; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestGenerationTest); +}; +vector TestGenerationTest::collected_parameters_; + +TEST_P(TestGenerationTest, TestsExpandedAndRun) { + Environment::Instance()->TestBodyExecuted(); + EXPECT_EQ(current_parameter_, GetParam()); + collected_parameters_.push_back(GetParam()); +} +INSTANTIATE_TEST_SUITE_P(TestExpansionModule, TestGenerationTest, + ValuesIn(test_generation_params)); + +// This test verifies that the element sequence (third parameter of +// INSTANTIATE_TEST_SUITE_P) is evaluated in InitGoogleTest() and neither at +// the call site of INSTANTIATE_TEST_SUITE_P nor in RUN_ALL_TESTS(). For +// that, we declare param_value_ to be a static member of +// GeneratorEvaluationTest and initialize it to 0. We set it to 1 in +// main(), just before invocation of InitGoogleTest(). After calling +// InitGoogleTest(), we set the value to 2. If the sequence is evaluated +// before or after InitGoogleTest, INSTANTIATE_TEST_SUITE_P will create a +// test with parameter other than 1, and the test body will fail the +// assertion. +class GeneratorEvaluationTest : public TestWithParam { + public: + static int param_value() { return param_value_; } + static void set_param_value(int param_value) { param_value_ = param_value; } + + private: + static int param_value_; +}; +int GeneratorEvaluationTest::param_value_ = 0; + +TEST_P(GeneratorEvaluationTest, GeneratorsEvaluatedInMain) { + EXPECT_EQ(1, GetParam()); +} +INSTANTIATE_TEST_SUITE_P(GenEvalModule, GeneratorEvaluationTest, + Values(GeneratorEvaluationTest::param_value())); + +// Tests that generators defined in a different translation unit are +// functional. Generator extern_gen is defined in gtest-param-test_test2.cc. +extern ParamGenerator extern_gen; +class ExternalGeneratorTest : public TestWithParam {}; +TEST_P(ExternalGeneratorTest, ExternalGenerator) { + // Sequence produced by extern_gen contains only a single value + // which we verify here. + EXPECT_EQ(GetParam(), 33); +} +INSTANTIATE_TEST_SUITE_P(ExternalGeneratorModule, ExternalGeneratorTest, + extern_gen); + +// Tests that a parameterized test case can be defined in one translation +// unit and instantiated in another. This test will be instantiated in +// gtest-param-test_test2.cc. ExternalInstantiationTest fixture class is +// defined in gtest-param-test_test.h. +TEST_P(ExternalInstantiationTest, IsMultipleOf33) { + EXPECT_EQ(0, GetParam() % 33); +} + +// Tests that a parameterized test case can be instantiated with multiple +// generators. +class MultipleInstantiationTest : public TestWithParam {}; +TEST_P(MultipleInstantiationTest, AllowsMultipleInstances) { +} +INSTANTIATE_TEST_SUITE_P(Sequence1, MultipleInstantiationTest, Values(1, 2)); +INSTANTIATE_TEST_SUITE_P(Sequence2, MultipleInstantiationTest, Range(3, 5)); + +// Tests that a parameterized test case can be instantiated +// in multiple translation units. This test will be instantiated +// here and in gtest-param-test_test2.cc. +// InstantiationInMultipleTranslationUnitsTest fixture class +// is defined in gtest-param-test_test.h. +TEST_P(InstantiationInMultipleTranslationUnitsTest, IsMultipleOf42) { + EXPECT_EQ(0, GetParam() % 42); +} +INSTANTIATE_TEST_SUITE_P(Sequence1, InstantiationInMultipleTranslationUnitsTest, + Values(42, 42 * 2)); + +// Tests that each iteration of parameterized test runs in a separate test +// object. +class SeparateInstanceTest : public TestWithParam { + public: + SeparateInstanceTest() : count_(0) {} + + static void TearDownTestSuite() { + EXPECT_GE(global_count_, 2) + << "If some (but not all) SeparateInstanceTest tests have been " + << "filtered out this test will fail. Make sure that all " + << "GeneratorEvaluationTest are selected or de-selected together " + << "by the test filter."; + } + + protected: + int count_; + static int global_count_; +}; +int SeparateInstanceTest::global_count_ = 0; + +TEST_P(SeparateInstanceTest, TestsRunInSeparateInstances) { + EXPECT_EQ(0, count_++); + global_count_++; +} +INSTANTIATE_TEST_SUITE_P(FourElemSequence, SeparateInstanceTest, Range(1, 4)); + +// Tests that all instantiations of a test have named appropriately. Test +// defined with TEST_P(TestSuiteName, TestName) and instantiated with +// INSTANTIATE_TEST_SUITE_P(SequenceName, TestSuiteName, generator) must be +// named SequenceName/TestSuiteName.TestName/i, where i is the 0-based index of +// the sequence element used to instantiate the test. +class NamingTest : public TestWithParam {}; + +TEST_P(NamingTest, TestsReportCorrectNamesAndParameters) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + + EXPECT_STREQ("ZeroToFiveSequence/NamingTest", test_info->test_suite_name()); + + Message index_stream; + index_stream << "TestsReportCorrectNamesAndParameters/" << GetParam(); + EXPECT_STREQ(index_stream.GetString().c_str(), test_info->name()); + + EXPECT_EQ(::testing::PrintToString(GetParam()), test_info->value_param()); +} + +INSTANTIATE_TEST_SUITE_P(ZeroToFiveSequence, NamingTest, Range(0, 5)); + +// Tests that macros in test names are expanded correctly. +class MacroNamingTest : public TestWithParam {}; + +#define PREFIX_WITH_FOO(test_name) Foo##test_name +#define PREFIX_WITH_MACRO(test_name) Macro##test_name + +TEST_P(PREFIX_WITH_MACRO(NamingTest), PREFIX_WITH_FOO(SomeTestName)) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + + EXPECT_STREQ("FortyTwo/MacroNamingTest", test_info->test_suite_name()); + EXPECT_STREQ("FooSomeTestName/0", test_info->name()); +} + +INSTANTIATE_TEST_SUITE_P(FortyTwo, MacroNamingTest, Values(42)); + +// Tests the same thing for non-parametrized tests. +class MacroNamingTestNonParametrized : public ::testing::Test {}; + +TEST_F(PREFIX_WITH_MACRO(NamingTestNonParametrized), + PREFIX_WITH_FOO(SomeTestName)) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + + EXPECT_STREQ("MacroNamingTestNonParametrized", test_info->test_suite_name()); + EXPECT_STREQ("FooSomeTestName", test_info->name()); +} + +TEST(MacroNameing, LookupNames) { + std::set know_suite_names, know_test_names; + + auto ins = testing::UnitTest::GetInstance(); + int ts = 0; + while (const testing::TestSuite* suite = ins->GetTestSuite(ts++)) { + know_suite_names.insert(suite->name()); + + int ti = 0; + while (const testing::TestInfo* info = suite->GetTestInfo(ti++)) { + know_test_names.insert(std::string(suite->name()) + "." + info->name()); + } + } + + // Check that the expected form of the test suit name actually exists. + EXPECT_NE( // + know_suite_names.find("FortyTwo/MacroNamingTest"), + know_suite_names.end()); + EXPECT_NE( + know_suite_names.find("MacroNamingTestNonParametrized"), + know_suite_names.end()); + // Check that the expected form of the test name actually exists. + EXPECT_NE( // + know_test_names.find("FortyTwo/MacroNamingTest.FooSomeTestName/0"), + know_test_names.end()); + EXPECT_NE( + know_test_names.find("MacroNamingTestNonParametrized.FooSomeTestName"), + know_test_names.end()); +} + +// Tests that user supplied custom parameter names are working correctly. +// Runs the test with a builtin helper method which uses PrintToString, +// as well as a custom function and custom functor to ensure all possible +// uses work correctly. +class CustomFunctorNamingTest : public TestWithParam {}; +TEST_P(CustomFunctorNamingTest, CustomTestNames) {} + +struct CustomParamNameFunctor { + std::string operator()(const ::testing::TestParamInfo& inf) { + return inf.param; + } +}; + +INSTANTIATE_TEST_SUITE_P(CustomParamNameFunctor, CustomFunctorNamingTest, + Values(std::string("FunctorName")), + CustomParamNameFunctor()); + +INSTANTIATE_TEST_SUITE_P(AllAllowedCharacters, CustomFunctorNamingTest, + Values("abcdefghijklmnopqrstuvwxyz", + "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "01234567890_"), + CustomParamNameFunctor()); + +inline std::string CustomParamNameFunction( + const ::testing::TestParamInfo& inf) { + return inf.param; +} + +class CustomFunctionNamingTest : public TestWithParam {}; +TEST_P(CustomFunctionNamingTest, CustomTestNames) {} + +INSTANTIATE_TEST_SUITE_P(CustomParamNameFunction, CustomFunctionNamingTest, + Values(std::string("FunctionName")), + CustomParamNameFunction); + +INSTANTIATE_TEST_SUITE_P(CustomParamNameFunctionP, CustomFunctionNamingTest, + Values(std::string("FunctionNameP")), + &CustomParamNameFunction); + +// Test custom naming with a lambda + +class CustomLambdaNamingTest : public TestWithParam {}; +TEST_P(CustomLambdaNamingTest, CustomTestNames) {} + +INSTANTIATE_TEST_SUITE_P(CustomParamNameLambda, CustomLambdaNamingTest, + Values(std::string("LambdaName")), + [](const ::testing::TestParamInfo& inf) { + return inf.param; + }); + +TEST(CustomNamingTest, CheckNameRegistry) { + ::testing::UnitTest* unit_test = ::testing::UnitTest::GetInstance(); + std::set test_names; + for (int suite_num = 0; suite_num < unit_test->total_test_suite_count(); + ++suite_num) { + const ::testing::TestSuite* test_suite = unit_test->GetTestSuite(suite_num); + for (int test_num = 0; test_num < test_suite->total_test_count(); + ++test_num) { + const ::testing::TestInfo* test_info = test_suite->GetTestInfo(test_num); + test_names.insert(std::string(test_info->name())); + } + } + EXPECT_EQ(1u, test_names.count("CustomTestNames/FunctorName")); + EXPECT_EQ(1u, test_names.count("CustomTestNames/FunctionName")); + EXPECT_EQ(1u, test_names.count("CustomTestNames/FunctionNameP")); + EXPECT_EQ(1u, test_names.count("CustomTestNames/LambdaName")); +} + +// Test a numeric name to ensure PrintToStringParamName works correctly. + +class CustomIntegerNamingTest : public TestWithParam {}; + +TEST_P(CustomIntegerNamingTest, TestsReportCorrectNames) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + Message test_name_stream; + test_name_stream << "TestsReportCorrectNames/" << GetParam(); + EXPECT_STREQ(test_name_stream.GetString().c_str(), test_info->name()); +} + +INSTANTIATE_TEST_SUITE_P(PrintToString, CustomIntegerNamingTest, Range(0, 5), + ::testing::PrintToStringParamName()); + +// Test a custom struct with PrintToString. + +struct CustomStruct { + explicit CustomStruct(int value) : x(value) {} + int x; +}; + +std::ostream& operator<<(std::ostream& stream, const CustomStruct& val) { + stream << val.x; + return stream; +} + +class CustomStructNamingTest : public TestWithParam {}; + +TEST_P(CustomStructNamingTest, TestsReportCorrectNames) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + Message test_name_stream; + test_name_stream << "TestsReportCorrectNames/" << GetParam(); + EXPECT_STREQ(test_name_stream.GetString().c_str(), test_info->name()); +} + +INSTANTIATE_TEST_SUITE_P(PrintToString, CustomStructNamingTest, + Values(CustomStruct(0), CustomStruct(1)), + ::testing::PrintToStringParamName()); + +// Test that using a stateful parameter naming function works as expected. + +struct StatefulNamingFunctor { + StatefulNamingFunctor() : sum(0) {} + std::string operator()(const ::testing::TestParamInfo& info) { + int value = info.param + sum; + sum += info.param; + return ::testing::PrintToString(value); + } + int sum; +}; + +class StatefulNamingTest : public ::testing::TestWithParam { + protected: + StatefulNamingTest() : sum_(0) {} + int sum_; +}; + +TEST_P(StatefulNamingTest, TestsReportCorrectNames) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + sum_ += GetParam(); + Message test_name_stream; + test_name_stream << "TestsReportCorrectNames/" << sum_; + EXPECT_STREQ(test_name_stream.GetString().c_str(), test_info->name()); +} + +INSTANTIATE_TEST_SUITE_P(StatefulNamingFunctor, StatefulNamingTest, Range(0, 5), + StatefulNamingFunctor()); + +// Class that cannot be streamed into an ostream. It needs to be copyable +// (and, in case of MSVC, also assignable) in order to be a test parameter +// type. Its default copy constructor and assignment operator do exactly +// what we need. +class Unstreamable { + public: + explicit Unstreamable(int value) : value_(value) {} + // -Wunused-private-field: dummy accessor for `value_`. + const int& dummy_value() const { return value_; } + + private: + int value_; +}; + +class CommentTest : public TestWithParam {}; + +TEST_P(CommentTest, TestsCorrectlyReportUnstreamableParams) { + const ::testing::TestInfo* const test_info = + ::testing::UnitTest::GetInstance()->current_test_info(); + + EXPECT_EQ(::testing::PrintToString(GetParam()), test_info->value_param()); +} + +INSTANTIATE_TEST_SUITE_P(InstantiationWithComments, CommentTest, + Values(Unstreamable(1))); + +// Verify that we can create a hierarchy of test fixtures, where the base +// class fixture is not parameterized and the derived class is. In this case +// ParameterizedDerivedTest inherits from NonParameterizedBaseTest. We +// perform simple tests on both. +class NonParameterizedBaseTest : public ::testing::Test { + public: + NonParameterizedBaseTest() : n_(17) { } + protected: + int n_; +}; + +class ParameterizedDerivedTest : public NonParameterizedBaseTest, + public ::testing::WithParamInterface { + protected: + ParameterizedDerivedTest() : count_(0) { } + int count_; + static int global_count_; +}; + +int ParameterizedDerivedTest::global_count_ = 0; + +TEST_F(NonParameterizedBaseTest, FixtureIsInitialized) { + EXPECT_EQ(17, n_); +} + +TEST_P(ParameterizedDerivedTest, SeesSequence) { + EXPECT_EQ(17, n_); + EXPECT_EQ(0, count_++); + EXPECT_EQ(GetParam(), global_count_++); +} + +class ParameterizedDeathTest : public ::testing::TestWithParam { }; + +TEST_F(ParameterizedDeathTest, GetParamDiesFromTestF) { + EXPECT_DEATH_IF_SUPPORTED(GetParam(), + ".* value-parameterized test .*"); +} + +INSTANTIATE_TEST_SUITE_P(RangeZeroToFive, ParameterizedDerivedTest, + Range(0, 5)); + +// Tests param generator working with Enums +enum MyEnums { + ENUM1 = 1, + ENUM2 = 3, + ENUM3 = 8, +}; + +class MyEnumTest : public testing::TestWithParam {}; + +TEST_P(MyEnumTest, ChecksParamMoreThanZero) { EXPECT_GE(10, GetParam()); } +INSTANTIATE_TEST_SUITE_P(MyEnumTests, MyEnumTest, + ::testing::Values(ENUM1, ENUM2, 0)); + +namespace works_here { +// Never used not instantiated, this should work. +class NotUsedTest : public testing::TestWithParam {}; + +/////// +// Never used not instantiated, this should work. +template +class NotUsedTypeTest : public testing::Test {}; +TYPED_TEST_SUITE_P(NotUsedTypeTest); + +// Used but not instantiated, this would fail. but... +class NotInstantiatedTest : public testing::TestWithParam {}; +// ... we mark is as allowed. +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NotInstantiatedTest); + +TEST_P(NotInstantiatedTest, Used) { } + +using OtherName = NotInstantiatedTest; +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(OtherName); +TEST_P(OtherName, Used) { } + +// Used but not instantiated, this would fail. but... +template +class NotInstantiatedTypeTest : public testing::Test {}; +TYPED_TEST_SUITE_P(NotInstantiatedTypeTest); +// ... we mark is as allowed. +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(NotInstantiatedTypeTest); + +TYPED_TEST_P(NotInstantiatedTypeTest, Used) { } +REGISTER_TYPED_TEST_SUITE_P(NotInstantiatedTypeTest, Used); +} // namespace works_here + +int main(int argc, char **argv) { + // Used in TestGenerationTest test suite. + AddGlobalTestEnvironment(TestGenerationTest::Environment::Instance()); + // Used in GeneratorEvaluationTest test suite. Tests that the updated value + // will be picked up for instantiating tests in GeneratorEvaluationTest. + GeneratorEvaluationTest::set_param_value(1); + + ::testing::InitGoogleTest(&argc, argv); + + // Used in GeneratorEvaluationTest test suite. Tests that value updated + // here will NOT be used for instantiating tests in + // GeneratorEvaluationTest. + GeneratorEvaluationTest::set_param_value(2); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.h new file mode 100644 index 000000000000..891937538d01 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test-test.h @@ -0,0 +1,51 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This header file provides classes and functions used internally +// for testing Google Test itself. + +#ifndef GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ +#define GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ + +#include "gtest/gtest.h" + +// Test fixture for testing definition and instantiation of a test +// in separate translation units. +class ExternalInstantiationTest : public ::testing::TestWithParam { +}; + +// Test fixture for testing instantiation of a test in multiple +// translation units. +class InstantiationInMultipleTranslationUnitsTest + : public ::testing::TestWithParam { +}; + +#endif // GOOGLETEST_TEST_GOOGLETEST_PARAM_TEST_TEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test2-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test2-test.cc new file mode 100644 index 000000000000..2a29fb1d0686 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-param-test2-test.cc @@ -0,0 +1,61 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for Google Test itself. This verifies that the basic constructs of +// Google Test work. + +#include "gtest/gtest.h" +#include "test/googletest-param-test-test.h" + +using ::testing::Values; +using ::testing::internal::ParamGenerator; + +// Tests that generators defined in a different translation unit +// are functional. The test using extern_gen is defined +// in googletest-param-test-test.cc. +ParamGenerator extern_gen = Values(33); + +// Tests that a parameterized test case can be defined in one translation unit +// and instantiated in another. The test is defined in +// googletest-param-test-test.cc and ExternalInstantiationTest fixture class is +// defined in gtest-param-test_test.h. +INSTANTIATE_TEST_SUITE_P(MultiplesOf33, + ExternalInstantiationTest, + Values(33, 66)); + +// Tests that a parameterized test case can be instantiated +// in multiple translation units. Another instantiation is defined +// in googletest-param-test-test.cc and +// InstantiationInMultipleTranslationUnitsTest fixture is defined in +// gtest-param-test_test.h +INSTANTIATE_TEST_SUITE_P(Sequence2, + InstantiationInMultipleTranslationUnitsTest, + Values(42*3, 42*4, 42*5)); + diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-port-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-port-test.cc new file mode 100644 index 000000000000..b14e1f76f4b0 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-port-test.cc @@ -0,0 +1,1305 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// This file tests the internal cross-platform support utilities. +#include + +#include "gtest/internal/gtest-port.h" + +#if GTEST_OS_MAC +# include +#endif // GTEST_OS_MAC + +#include // NOLINT +#include +#include +#include // NOLINT +#include // For std::pair and std::make_pair. +#include + +#include "gtest/gtest.h" +#include "gtest/gtest-spi.h" +#include "src/gtest-internal-inl.h" + +using std::make_pair; +using std::pair; + +namespace testing { +namespace internal { + +TEST(IsXDigitTest, WorksForNarrowAscii) { + EXPECT_TRUE(IsXDigit('0')); + EXPECT_TRUE(IsXDigit('9')); + EXPECT_TRUE(IsXDigit('A')); + EXPECT_TRUE(IsXDigit('F')); + EXPECT_TRUE(IsXDigit('a')); + EXPECT_TRUE(IsXDigit('f')); + + EXPECT_FALSE(IsXDigit('-')); + EXPECT_FALSE(IsXDigit('g')); + EXPECT_FALSE(IsXDigit('G')); +} + +TEST(IsXDigitTest, ReturnsFalseForNarrowNonAscii) { + EXPECT_FALSE(IsXDigit(static_cast('\x80'))); + EXPECT_FALSE(IsXDigit(static_cast('0' | '\x80'))); +} + +TEST(IsXDigitTest, WorksForWideAscii) { + EXPECT_TRUE(IsXDigit(L'0')); + EXPECT_TRUE(IsXDigit(L'9')); + EXPECT_TRUE(IsXDigit(L'A')); + EXPECT_TRUE(IsXDigit(L'F')); + EXPECT_TRUE(IsXDigit(L'a')); + EXPECT_TRUE(IsXDigit(L'f')); + + EXPECT_FALSE(IsXDigit(L'-')); + EXPECT_FALSE(IsXDigit(L'g')); + EXPECT_FALSE(IsXDigit(L'G')); +} + +TEST(IsXDigitTest, ReturnsFalseForWideNonAscii) { + EXPECT_FALSE(IsXDigit(static_cast(0x80))); + EXPECT_FALSE(IsXDigit(static_cast(L'0' | 0x80))); + EXPECT_FALSE(IsXDigit(static_cast(L'0' | 0x100))); +} + +class Base { + public: + Base() : member_(0) {} + explicit Base(int n) : member_(n) {} + Base(const Base&) = default; + Base& operator=(const Base&) = default; + virtual ~Base() {} + int member() { return member_; } + + private: + int member_; +}; + +class Derived : public Base { + public: + explicit Derived(int n) : Base(n) {} +}; + +TEST(ImplicitCastTest, ConvertsPointers) { + Derived derived(0); + EXPECT_TRUE(&derived == ::testing::internal::ImplicitCast_(&derived)); +} + +TEST(ImplicitCastTest, CanUseInheritance) { + Derived derived(1); + Base base = ::testing::internal::ImplicitCast_(derived); + EXPECT_EQ(derived.member(), base.member()); +} + +class Castable { + public: + explicit Castable(bool* converted) : converted_(converted) {} + operator Base() { + *converted_ = true; + return Base(); + } + + private: + bool* converted_; +}; + +TEST(ImplicitCastTest, CanUseNonConstCastOperator) { + bool converted = false; + Castable castable(&converted); + Base base = ::testing::internal::ImplicitCast_(castable); + EXPECT_TRUE(converted); +} + +class ConstCastable { + public: + explicit ConstCastable(bool* converted) : converted_(converted) {} + operator Base() const { + *converted_ = true; + return Base(); + } + + private: + bool* converted_; +}; + +TEST(ImplicitCastTest, CanUseConstCastOperatorOnConstValues) { + bool converted = false; + const ConstCastable const_castable(&converted); + Base base = ::testing::internal::ImplicitCast_(const_castable); + EXPECT_TRUE(converted); +} + +class ConstAndNonConstCastable { + public: + ConstAndNonConstCastable(bool* converted, bool* const_converted) + : converted_(converted), const_converted_(const_converted) {} + operator Base() { + *converted_ = true; + return Base(); + } + operator Base() const { + *const_converted_ = true; + return Base(); + } + + private: + bool* converted_; + bool* const_converted_; +}; + +TEST(ImplicitCastTest, CanSelectBetweenConstAndNonConstCasrAppropriately) { + bool converted = false; + bool const_converted = false; + ConstAndNonConstCastable castable(&converted, &const_converted); + Base base = ::testing::internal::ImplicitCast_(castable); + EXPECT_TRUE(converted); + EXPECT_FALSE(const_converted); + + converted = false; + const_converted = false; + const ConstAndNonConstCastable const_castable(&converted, &const_converted); + base = ::testing::internal::ImplicitCast_(const_castable); + EXPECT_FALSE(converted); + EXPECT_TRUE(const_converted); +} + +class To { + public: + To(bool* converted) { *converted = true; } // NOLINT +}; + +TEST(ImplicitCastTest, CanUseImplicitConstructor) { + bool converted = false; + To to = ::testing::internal::ImplicitCast_(&converted); + (void)to; + EXPECT_TRUE(converted); +} + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +TEST(GtestCheckSyntaxTest, BehavesLikeASingleStatement) { + if (AlwaysFalse()) + GTEST_CHECK_(false) << "This should never be executed; " + "It's a compilation test only."; + + if (AlwaysTrue()) + GTEST_CHECK_(true); + else + ; // NOLINT + + if (AlwaysFalse()) + ; // NOLINT + else + GTEST_CHECK_(true) << ""; +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +TEST(GtestCheckSyntaxTest, WorksWithSwitch) { + switch (0) { + case 1: + break; + default: + GTEST_CHECK_(true); + } + + switch (0) + case 0: + GTEST_CHECK_(true) << "Check failed in switch case"; +} + +// Verifies behavior of FormatFileLocation. +TEST(FormatFileLocationTest, FormatsFileLocation) { + EXPECT_PRED_FORMAT2(IsSubstring, "foo.cc", FormatFileLocation("foo.cc", 42)); + EXPECT_PRED_FORMAT2(IsSubstring, "42", FormatFileLocation("foo.cc", 42)); +} + +TEST(FormatFileLocationTest, FormatsUnknownFile) { + EXPECT_PRED_FORMAT2(IsSubstring, "unknown file", + FormatFileLocation(nullptr, 42)); + EXPECT_PRED_FORMAT2(IsSubstring, "42", FormatFileLocation(nullptr, 42)); +} + +TEST(FormatFileLocationTest, FormatsUknownLine) { + EXPECT_EQ("foo.cc:", FormatFileLocation("foo.cc", -1)); +} + +TEST(FormatFileLocationTest, FormatsUknownFileAndLine) { + EXPECT_EQ("unknown file:", FormatFileLocation(nullptr, -1)); +} + +// Verifies behavior of FormatCompilerIndependentFileLocation. +TEST(FormatCompilerIndependentFileLocationTest, FormatsFileLocation) { + EXPECT_EQ("foo.cc:42", FormatCompilerIndependentFileLocation("foo.cc", 42)); +} + +TEST(FormatCompilerIndependentFileLocationTest, FormatsUknownFile) { + EXPECT_EQ("unknown file:42", + FormatCompilerIndependentFileLocation(nullptr, 42)); +} + +TEST(FormatCompilerIndependentFileLocationTest, FormatsUknownLine) { + EXPECT_EQ("foo.cc", FormatCompilerIndependentFileLocation("foo.cc", -1)); +} + +TEST(FormatCompilerIndependentFileLocationTest, FormatsUknownFileAndLine) { + EXPECT_EQ("unknown file", FormatCompilerIndependentFileLocation(nullptr, -1)); +} + +#if GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_QNX || GTEST_OS_FUCHSIA || \ + GTEST_OS_DRAGONFLY || GTEST_OS_FREEBSD || GTEST_OS_GNU_KFREEBSD || \ + GTEST_OS_NETBSD || GTEST_OS_OPENBSD || GTEST_OS_GNU_HURD +void* ThreadFunc(void* data) { + internal::Mutex* mutex = static_cast(data); + mutex->Lock(); + mutex->Unlock(); + return nullptr; +} + +TEST(GetThreadCountTest, ReturnsCorrectValue) { + size_t starting_count; + size_t thread_count_after_create; + size_t thread_count_after_join; + + // We can't guarantee that no other thread was created or destroyed between + // any two calls to GetThreadCount(). We make multiple attempts, hoping that + // background noise is not constant and we would see the "right" values at + // some point. + for (int attempt = 0; attempt < 20; ++attempt) { + starting_count = GetThreadCount(); + pthread_t thread_id; + + internal::Mutex mutex; + { + internal::MutexLock lock(&mutex); + pthread_attr_t attr; + ASSERT_EQ(0, pthread_attr_init(&attr)); + ASSERT_EQ(0, pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE)); + + const int status = pthread_create(&thread_id, &attr, &ThreadFunc, &mutex); + ASSERT_EQ(0, pthread_attr_destroy(&attr)); + ASSERT_EQ(0, status); + } + + thread_count_after_create = GetThreadCount(); + + void* dummy; + ASSERT_EQ(0, pthread_join(thread_id, &dummy)); + + // Join before we decide whether we need to retry the test. Retry if an + // arbitrary other thread was created or destroyed in the meantime. + if (thread_count_after_create != starting_count + 1) continue; + + // The OS may not immediately report the updated thread count after + // joining a thread, causing flakiness in this test. To counter that, we + // wait for up to .5 seconds for the OS to report the correct value. + bool thread_count_matches = false; + for (int i = 0; i < 5; ++i) { + thread_count_after_join = GetThreadCount(); + if (thread_count_after_join == starting_count) { + thread_count_matches = true; + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + // Retry if an arbitrary other thread was created or destroyed. + if (!thread_count_matches) continue; + + break; + } + + EXPECT_EQ(thread_count_after_create, starting_count + 1); + EXPECT_EQ(thread_count_after_join, starting_count); +} +#else +TEST(GetThreadCountTest, ReturnsZeroWhenUnableToCountThreads) { + EXPECT_EQ(0U, GetThreadCount()); +} +#endif // GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_QNX || GTEST_OS_FUCHSIA + +TEST(GtestCheckDeathTest, DiesWithCorrectOutputOnFailure) { + const bool a_false_condition = false; + const char regex[] = +#ifdef _MSC_VER + "googletest-port-test\\.cc\\(\\d+\\):" +#elif GTEST_USES_POSIX_RE + "googletest-port-test\\.cc:[0-9]+" +#else + "googletest-port-test\\.cc:\\d+" +#endif // _MSC_VER + ".*a_false_condition.*Extra info.*"; + + EXPECT_DEATH_IF_SUPPORTED(GTEST_CHECK_(a_false_condition) << "Extra info", + regex); +} + +#if GTEST_HAS_DEATH_TEST + +TEST(GtestCheckDeathTest, LivesSilentlyOnSuccess) { + EXPECT_EXIT({ + GTEST_CHECK_(true) << "Extra info"; + ::std::cerr << "Success\n"; + exit(0); }, + ::testing::ExitedWithCode(0), "Success"); +} + +#endif // GTEST_HAS_DEATH_TEST + +// Verifies that Google Test choose regular expression engine appropriate to +// the platform. The test will produce compiler errors in case of failure. +// For simplicity, we only cover the most important platforms here. +TEST(RegexEngineSelectionTest, SelectsCorrectRegexEngine) { +#if !GTEST_USES_PCRE +# if GTEST_HAS_POSIX_RE + + EXPECT_TRUE(GTEST_USES_POSIX_RE); + +# else + + EXPECT_TRUE(GTEST_USES_SIMPLE_RE); + +# endif +#endif // !GTEST_USES_PCRE +} + +#if GTEST_USES_POSIX_RE + +template +class RETest : public ::testing::Test {}; + +// Defines StringTypes as the list of all string types that class RE +// supports. +typedef testing::Types< ::std::string, const char*> StringTypes; + +TYPED_TEST_SUITE(RETest, StringTypes); + +// Tests RE's implicit constructors. +TYPED_TEST(RETest, ImplicitConstructorWorks) { + const RE empty(TypeParam("")); + EXPECT_STREQ("", empty.pattern()); + + const RE simple(TypeParam("hello")); + EXPECT_STREQ("hello", simple.pattern()); + + const RE normal(TypeParam(".*(\\w+)")); + EXPECT_STREQ(".*(\\w+)", normal.pattern()); +} + +// Tests that RE's constructors reject invalid regular expressions. +TYPED_TEST(RETest, RejectsInvalidRegex) { + EXPECT_NONFATAL_FAILURE({ + const RE invalid(TypeParam("?")); + }, "\"?\" is not a valid POSIX Extended regular expression."); +} + +// Tests RE::FullMatch(). +TYPED_TEST(RETest, FullMatchWorks) { + const RE empty(TypeParam("")); + EXPECT_TRUE(RE::FullMatch(TypeParam(""), empty)); + EXPECT_FALSE(RE::FullMatch(TypeParam("a"), empty)); + + const RE re(TypeParam("a.*z")); + EXPECT_TRUE(RE::FullMatch(TypeParam("az"), re)); + EXPECT_TRUE(RE::FullMatch(TypeParam("axyz"), re)); + EXPECT_FALSE(RE::FullMatch(TypeParam("baz"), re)); + EXPECT_FALSE(RE::FullMatch(TypeParam("azy"), re)); +} + +// Tests RE::PartialMatch(). +TYPED_TEST(RETest, PartialMatchWorks) { + const RE empty(TypeParam("")); + EXPECT_TRUE(RE::PartialMatch(TypeParam(""), empty)); + EXPECT_TRUE(RE::PartialMatch(TypeParam("a"), empty)); + + const RE re(TypeParam("a.*z")); + EXPECT_TRUE(RE::PartialMatch(TypeParam("az"), re)); + EXPECT_TRUE(RE::PartialMatch(TypeParam("axyz"), re)); + EXPECT_TRUE(RE::PartialMatch(TypeParam("baz"), re)); + EXPECT_TRUE(RE::PartialMatch(TypeParam("azy"), re)); + EXPECT_FALSE(RE::PartialMatch(TypeParam("zza"), re)); +} + +#elif GTEST_USES_SIMPLE_RE + +TEST(IsInSetTest, NulCharIsNotInAnySet) { + EXPECT_FALSE(IsInSet('\0', "")); + EXPECT_FALSE(IsInSet('\0', "\0")); + EXPECT_FALSE(IsInSet('\0', "a")); +} + +TEST(IsInSetTest, WorksForNonNulChars) { + EXPECT_FALSE(IsInSet('a', "Ab")); + EXPECT_FALSE(IsInSet('c', "")); + + EXPECT_TRUE(IsInSet('b', "bcd")); + EXPECT_TRUE(IsInSet('b', "ab")); +} + +TEST(IsAsciiDigitTest, IsFalseForNonDigit) { + EXPECT_FALSE(IsAsciiDigit('\0')); + EXPECT_FALSE(IsAsciiDigit(' ')); + EXPECT_FALSE(IsAsciiDigit('+')); + EXPECT_FALSE(IsAsciiDigit('-')); + EXPECT_FALSE(IsAsciiDigit('.')); + EXPECT_FALSE(IsAsciiDigit('a')); +} + +TEST(IsAsciiDigitTest, IsTrueForDigit) { + EXPECT_TRUE(IsAsciiDigit('0')); + EXPECT_TRUE(IsAsciiDigit('1')); + EXPECT_TRUE(IsAsciiDigit('5')); + EXPECT_TRUE(IsAsciiDigit('9')); +} + +TEST(IsAsciiPunctTest, IsFalseForNonPunct) { + EXPECT_FALSE(IsAsciiPunct('\0')); + EXPECT_FALSE(IsAsciiPunct(' ')); + EXPECT_FALSE(IsAsciiPunct('\n')); + EXPECT_FALSE(IsAsciiPunct('a')); + EXPECT_FALSE(IsAsciiPunct('0')); +} + +TEST(IsAsciiPunctTest, IsTrueForPunct) { + for (const char* p = "^-!\"#$%&'()*+,./:;<=>?@[\\]_`{|}~"; *p; p++) { + EXPECT_PRED1(IsAsciiPunct, *p); + } +} + +TEST(IsRepeatTest, IsFalseForNonRepeatChar) { + EXPECT_FALSE(IsRepeat('\0')); + EXPECT_FALSE(IsRepeat(' ')); + EXPECT_FALSE(IsRepeat('a')); + EXPECT_FALSE(IsRepeat('1')); + EXPECT_FALSE(IsRepeat('-')); +} + +TEST(IsRepeatTest, IsTrueForRepeatChar) { + EXPECT_TRUE(IsRepeat('?')); + EXPECT_TRUE(IsRepeat('*')); + EXPECT_TRUE(IsRepeat('+')); +} + +TEST(IsAsciiWhiteSpaceTest, IsFalseForNonWhiteSpace) { + EXPECT_FALSE(IsAsciiWhiteSpace('\0')); + EXPECT_FALSE(IsAsciiWhiteSpace('a')); + EXPECT_FALSE(IsAsciiWhiteSpace('1')); + EXPECT_FALSE(IsAsciiWhiteSpace('+')); + EXPECT_FALSE(IsAsciiWhiteSpace('_')); +} + +TEST(IsAsciiWhiteSpaceTest, IsTrueForWhiteSpace) { + EXPECT_TRUE(IsAsciiWhiteSpace(' ')); + EXPECT_TRUE(IsAsciiWhiteSpace('\n')); + EXPECT_TRUE(IsAsciiWhiteSpace('\r')); + EXPECT_TRUE(IsAsciiWhiteSpace('\t')); + EXPECT_TRUE(IsAsciiWhiteSpace('\v')); + EXPECT_TRUE(IsAsciiWhiteSpace('\f')); +} + +TEST(IsAsciiWordCharTest, IsFalseForNonWordChar) { + EXPECT_FALSE(IsAsciiWordChar('\0')); + EXPECT_FALSE(IsAsciiWordChar('+')); + EXPECT_FALSE(IsAsciiWordChar('.')); + EXPECT_FALSE(IsAsciiWordChar(' ')); + EXPECT_FALSE(IsAsciiWordChar('\n')); +} + +TEST(IsAsciiWordCharTest, IsTrueForLetter) { + EXPECT_TRUE(IsAsciiWordChar('a')); + EXPECT_TRUE(IsAsciiWordChar('b')); + EXPECT_TRUE(IsAsciiWordChar('A')); + EXPECT_TRUE(IsAsciiWordChar('Z')); +} + +TEST(IsAsciiWordCharTest, IsTrueForDigit) { + EXPECT_TRUE(IsAsciiWordChar('0')); + EXPECT_TRUE(IsAsciiWordChar('1')); + EXPECT_TRUE(IsAsciiWordChar('7')); + EXPECT_TRUE(IsAsciiWordChar('9')); +} + +TEST(IsAsciiWordCharTest, IsTrueForUnderscore) { + EXPECT_TRUE(IsAsciiWordChar('_')); +} + +TEST(IsValidEscapeTest, IsFalseForNonPrintable) { + EXPECT_FALSE(IsValidEscape('\0')); + EXPECT_FALSE(IsValidEscape('\007')); +} + +TEST(IsValidEscapeTest, IsFalseForDigit) { + EXPECT_FALSE(IsValidEscape('0')); + EXPECT_FALSE(IsValidEscape('9')); +} + +TEST(IsValidEscapeTest, IsFalseForWhiteSpace) { + EXPECT_FALSE(IsValidEscape(' ')); + EXPECT_FALSE(IsValidEscape('\n')); +} + +TEST(IsValidEscapeTest, IsFalseForSomeLetter) { + EXPECT_FALSE(IsValidEscape('a')); + EXPECT_FALSE(IsValidEscape('Z')); +} + +TEST(IsValidEscapeTest, IsTrueForPunct) { + EXPECT_TRUE(IsValidEscape('.')); + EXPECT_TRUE(IsValidEscape('-')); + EXPECT_TRUE(IsValidEscape('^')); + EXPECT_TRUE(IsValidEscape('$')); + EXPECT_TRUE(IsValidEscape('(')); + EXPECT_TRUE(IsValidEscape(']')); + EXPECT_TRUE(IsValidEscape('{')); + EXPECT_TRUE(IsValidEscape('|')); +} + +TEST(IsValidEscapeTest, IsTrueForSomeLetter) { + EXPECT_TRUE(IsValidEscape('d')); + EXPECT_TRUE(IsValidEscape('D')); + EXPECT_TRUE(IsValidEscape('s')); + EXPECT_TRUE(IsValidEscape('S')); + EXPECT_TRUE(IsValidEscape('w')); + EXPECT_TRUE(IsValidEscape('W')); +} + +TEST(AtomMatchesCharTest, EscapedPunct) { + EXPECT_FALSE(AtomMatchesChar(true, '\\', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, '\\', ' ')); + EXPECT_FALSE(AtomMatchesChar(true, '_', '.')); + EXPECT_FALSE(AtomMatchesChar(true, '.', 'a')); + + EXPECT_TRUE(AtomMatchesChar(true, '\\', '\\')); + EXPECT_TRUE(AtomMatchesChar(true, '_', '_')); + EXPECT_TRUE(AtomMatchesChar(true, '+', '+')); + EXPECT_TRUE(AtomMatchesChar(true, '.', '.')); +} + +TEST(AtomMatchesCharTest, Escaped_d) { + EXPECT_FALSE(AtomMatchesChar(true, 'd', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'd', 'a')); + EXPECT_FALSE(AtomMatchesChar(true, 'd', '.')); + + EXPECT_TRUE(AtomMatchesChar(true, 'd', '0')); + EXPECT_TRUE(AtomMatchesChar(true, 'd', '9')); +} + +TEST(AtomMatchesCharTest, Escaped_D) { + EXPECT_FALSE(AtomMatchesChar(true, 'D', '0')); + EXPECT_FALSE(AtomMatchesChar(true, 'D', '9')); + + EXPECT_TRUE(AtomMatchesChar(true, 'D', '\0')); + EXPECT_TRUE(AtomMatchesChar(true, 'D', 'a')); + EXPECT_TRUE(AtomMatchesChar(true, 'D', '-')); +} + +TEST(AtomMatchesCharTest, Escaped_s) { + EXPECT_FALSE(AtomMatchesChar(true, 's', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 's', 'a')); + EXPECT_FALSE(AtomMatchesChar(true, 's', '.')); + EXPECT_FALSE(AtomMatchesChar(true, 's', '9')); + + EXPECT_TRUE(AtomMatchesChar(true, 's', ' ')); + EXPECT_TRUE(AtomMatchesChar(true, 's', '\n')); + EXPECT_TRUE(AtomMatchesChar(true, 's', '\t')); +} + +TEST(AtomMatchesCharTest, Escaped_S) { + EXPECT_FALSE(AtomMatchesChar(true, 'S', ' ')); + EXPECT_FALSE(AtomMatchesChar(true, 'S', '\r')); + + EXPECT_TRUE(AtomMatchesChar(true, 'S', '\0')); + EXPECT_TRUE(AtomMatchesChar(true, 'S', 'a')); + EXPECT_TRUE(AtomMatchesChar(true, 'S', '9')); +} + +TEST(AtomMatchesCharTest, Escaped_w) { + EXPECT_FALSE(AtomMatchesChar(true, 'w', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'w', '+')); + EXPECT_FALSE(AtomMatchesChar(true, 'w', ' ')); + EXPECT_FALSE(AtomMatchesChar(true, 'w', '\n')); + + EXPECT_TRUE(AtomMatchesChar(true, 'w', '0')); + EXPECT_TRUE(AtomMatchesChar(true, 'w', 'b')); + EXPECT_TRUE(AtomMatchesChar(true, 'w', 'C')); + EXPECT_TRUE(AtomMatchesChar(true, 'w', '_')); +} + +TEST(AtomMatchesCharTest, Escaped_W) { + EXPECT_FALSE(AtomMatchesChar(true, 'W', 'A')); + EXPECT_FALSE(AtomMatchesChar(true, 'W', 'b')); + EXPECT_FALSE(AtomMatchesChar(true, 'W', '9')); + EXPECT_FALSE(AtomMatchesChar(true, 'W', '_')); + + EXPECT_TRUE(AtomMatchesChar(true, 'W', '\0')); + EXPECT_TRUE(AtomMatchesChar(true, 'W', '*')); + EXPECT_TRUE(AtomMatchesChar(true, 'W', '\n')); +} + +TEST(AtomMatchesCharTest, EscapedWhiteSpace) { + EXPECT_FALSE(AtomMatchesChar(true, 'f', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'f', '\n')); + EXPECT_FALSE(AtomMatchesChar(true, 'n', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'n', '\r')); + EXPECT_FALSE(AtomMatchesChar(true, 'r', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'r', 'a')); + EXPECT_FALSE(AtomMatchesChar(true, 't', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 't', 't')); + EXPECT_FALSE(AtomMatchesChar(true, 'v', '\0')); + EXPECT_FALSE(AtomMatchesChar(true, 'v', '\f')); + + EXPECT_TRUE(AtomMatchesChar(true, 'f', '\f')); + EXPECT_TRUE(AtomMatchesChar(true, 'n', '\n')); + EXPECT_TRUE(AtomMatchesChar(true, 'r', '\r')); + EXPECT_TRUE(AtomMatchesChar(true, 't', '\t')); + EXPECT_TRUE(AtomMatchesChar(true, 'v', '\v')); +} + +TEST(AtomMatchesCharTest, UnescapedDot) { + EXPECT_FALSE(AtomMatchesChar(false, '.', '\n')); + + EXPECT_TRUE(AtomMatchesChar(false, '.', '\0')); + EXPECT_TRUE(AtomMatchesChar(false, '.', '.')); + EXPECT_TRUE(AtomMatchesChar(false, '.', 'a')); + EXPECT_TRUE(AtomMatchesChar(false, '.', ' ')); +} + +TEST(AtomMatchesCharTest, UnescapedChar) { + EXPECT_FALSE(AtomMatchesChar(false, 'a', '\0')); + EXPECT_FALSE(AtomMatchesChar(false, 'a', 'b')); + EXPECT_FALSE(AtomMatchesChar(false, '$', 'a')); + + EXPECT_TRUE(AtomMatchesChar(false, '$', '$')); + EXPECT_TRUE(AtomMatchesChar(false, '5', '5')); + EXPECT_TRUE(AtomMatchesChar(false, 'Z', 'Z')); +} + +TEST(ValidateRegexTest, GeneratesFailureAndReturnsFalseForInvalid) { + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex(NULL)), + "NULL is not a valid simple regular expression"); + EXPECT_NONFATAL_FAILURE( + ASSERT_FALSE(ValidateRegex("a\\")), + "Syntax error at index 1 in simple regular expression \"a\\\": "); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("a\\")), + "'\\' cannot appear at the end"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("\\n\\")), + "'\\' cannot appear at the end"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("\\s\\hb")), + "invalid escape sequence \"\\h\""); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("^^")), + "'^' can only appear at the beginning"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex(".*^b")), + "'^' can only appear at the beginning"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("$$")), + "'$' can only appear at the end"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("^$a")), + "'$' can only appear at the end"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("a(b")), + "'(' is unsupported"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("ab)")), + "')' is unsupported"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("[ab")), + "'[' is unsupported"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("a{2")), + "'{' is unsupported"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("?")), + "'?' can only follow a repeatable token"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("^*")), + "'*' can only follow a repeatable token"); + EXPECT_NONFATAL_FAILURE(ASSERT_FALSE(ValidateRegex("5*+")), + "'+' can only follow a repeatable token"); +} + +TEST(ValidateRegexTest, ReturnsTrueForValid) { + EXPECT_TRUE(ValidateRegex("")); + EXPECT_TRUE(ValidateRegex("a")); + EXPECT_TRUE(ValidateRegex(".*")); + EXPECT_TRUE(ValidateRegex("^a_+")); + EXPECT_TRUE(ValidateRegex("^a\\t\\&?")); + EXPECT_TRUE(ValidateRegex("09*$")); + EXPECT_TRUE(ValidateRegex("^Z$")); + EXPECT_TRUE(ValidateRegex("a\\^Z\\$\\(\\)\\|\\[\\]\\{\\}")); +} + +TEST(MatchRepetitionAndRegexAtHeadTest, WorksForZeroOrOne) { + EXPECT_FALSE(MatchRepetitionAndRegexAtHead(false, 'a', '?', "a", "ba")); + // Repeating more than once. + EXPECT_FALSE(MatchRepetitionAndRegexAtHead(false, 'a', '?', "b", "aab")); + + // Repeating zero times. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, 'a', '?', "b", "ba")); + // Repeating once. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, 'a', '?', "b", "ab")); + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, '#', '?', ".", "##")); +} + +TEST(MatchRepetitionAndRegexAtHeadTest, WorksForZeroOrMany) { + EXPECT_FALSE(MatchRepetitionAndRegexAtHead(false, '.', '*', "a$", "baab")); + + // Repeating zero times. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, '.', '*', "b", "bc")); + // Repeating once. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, '.', '*', "b", "abc")); + // Repeating more than once. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(true, 'w', '*', "-", "ab_1-g")); +} + +TEST(MatchRepetitionAndRegexAtHeadTest, WorksForOneOrMany) { + EXPECT_FALSE(MatchRepetitionAndRegexAtHead(false, '.', '+', "a$", "baab")); + // Repeating zero times. + EXPECT_FALSE(MatchRepetitionAndRegexAtHead(false, '.', '+', "b", "bc")); + + // Repeating once. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(false, '.', '+', "b", "abc")); + // Repeating more than once. + EXPECT_TRUE(MatchRepetitionAndRegexAtHead(true, 'w', '+', "-", "ab_1-g")); +} + +TEST(MatchRegexAtHeadTest, ReturnsTrueForEmptyRegex) { + EXPECT_TRUE(MatchRegexAtHead("", "")); + EXPECT_TRUE(MatchRegexAtHead("", "ab")); +} + +TEST(MatchRegexAtHeadTest, WorksWhenDollarIsInRegex) { + EXPECT_FALSE(MatchRegexAtHead("$", "a")); + + EXPECT_TRUE(MatchRegexAtHead("$", "")); + EXPECT_TRUE(MatchRegexAtHead("a$", "a")); +} + +TEST(MatchRegexAtHeadTest, WorksWhenRegexStartsWithEscapeSequence) { + EXPECT_FALSE(MatchRegexAtHead("\\w", "+")); + EXPECT_FALSE(MatchRegexAtHead("\\W", "ab")); + + EXPECT_TRUE(MatchRegexAtHead("\\sa", "\nab")); + EXPECT_TRUE(MatchRegexAtHead("\\d", "1a")); +} + +TEST(MatchRegexAtHeadTest, WorksWhenRegexStartsWithRepetition) { + EXPECT_FALSE(MatchRegexAtHead(".+a", "abc")); + EXPECT_FALSE(MatchRegexAtHead("a?b", "aab")); + + EXPECT_TRUE(MatchRegexAtHead(".*a", "bc12-ab")); + EXPECT_TRUE(MatchRegexAtHead("a?b", "b")); + EXPECT_TRUE(MatchRegexAtHead("a?b", "ab")); +} + +TEST(MatchRegexAtHeadTest, + WorksWhenRegexStartsWithRepetionOfEscapeSequence) { + EXPECT_FALSE(MatchRegexAtHead("\\.+a", "abc")); + EXPECT_FALSE(MatchRegexAtHead("\\s?b", " b")); + + EXPECT_TRUE(MatchRegexAtHead("\\(*a", "((((ab")); + EXPECT_TRUE(MatchRegexAtHead("\\^?b", "^b")); + EXPECT_TRUE(MatchRegexAtHead("\\\\?b", "b")); + EXPECT_TRUE(MatchRegexAtHead("\\\\?b", "\\b")); +} + +TEST(MatchRegexAtHeadTest, MatchesSequentially) { + EXPECT_FALSE(MatchRegexAtHead("ab.*c", "acabc")); + + EXPECT_TRUE(MatchRegexAtHead("ab.*c", "ab-fsc")); +} + +TEST(MatchRegexAnywhereTest, ReturnsFalseWhenStringIsNull) { + EXPECT_FALSE(MatchRegexAnywhere("", NULL)); +} + +TEST(MatchRegexAnywhereTest, WorksWhenRegexStartsWithCaret) { + EXPECT_FALSE(MatchRegexAnywhere("^a", "ba")); + EXPECT_FALSE(MatchRegexAnywhere("^$", "a")); + + EXPECT_TRUE(MatchRegexAnywhere("^a", "ab")); + EXPECT_TRUE(MatchRegexAnywhere("^", "ab")); + EXPECT_TRUE(MatchRegexAnywhere("^$", "")); +} + +TEST(MatchRegexAnywhereTest, ReturnsFalseWhenNoMatch) { + EXPECT_FALSE(MatchRegexAnywhere("a", "bcde123")); + EXPECT_FALSE(MatchRegexAnywhere("a.+a", "--aa88888888")); +} + +TEST(MatchRegexAnywhereTest, ReturnsTrueWhenMatchingPrefix) { + EXPECT_TRUE(MatchRegexAnywhere("\\w+", "ab1_ - 5")); + EXPECT_TRUE(MatchRegexAnywhere(".*=", "=")); + EXPECT_TRUE(MatchRegexAnywhere("x.*ab?.*bc", "xaaabc")); +} + +TEST(MatchRegexAnywhereTest, ReturnsTrueWhenMatchingNonPrefix) { + EXPECT_TRUE(MatchRegexAnywhere("\\w+", "$$$ ab1_ - 5")); + EXPECT_TRUE(MatchRegexAnywhere("\\.+=", "= ...=")); +} + +// Tests RE's implicit constructors. +TEST(RETest, ImplicitConstructorWorks) { + const RE empty(""); + EXPECT_STREQ("", empty.pattern()); + + const RE simple("hello"); + EXPECT_STREQ("hello", simple.pattern()); +} + +// Tests that RE's constructors reject invalid regular expressions. +TEST(RETest, RejectsInvalidRegex) { + EXPECT_NONFATAL_FAILURE({ + const RE normal(NULL); + }, "NULL is not a valid simple regular expression"); + + EXPECT_NONFATAL_FAILURE({ + const RE normal(".*(\\w+"); + }, "'(' is unsupported"); + + EXPECT_NONFATAL_FAILURE({ + const RE invalid("^?"); + }, "'?' can only follow a repeatable token"); +} + +// Tests RE::FullMatch(). +TEST(RETest, FullMatchWorks) { + const RE empty(""); + EXPECT_TRUE(RE::FullMatch("", empty)); + EXPECT_FALSE(RE::FullMatch("a", empty)); + + const RE re1("a"); + EXPECT_TRUE(RE::FullMatch("a", re1)); + + const RE re("a.*z"); + EXPECT_TRUE(RE::FullMatch("az", re)); + EXPECT_TRUE(RE::FullMatch("axyz", re)); + EXPECT_FALSE(RE::FullMatch("baz", re)); + EXPECT_FALSE(RE::FullMatch("azy", re)); +} + +// Tests RE::PartialMatch(). +TEST(RETest, PartialMatchWorks) { + const RE empty(""); + EXPECT_TRUE(RE::PartialMatch("", empty)); + EXPECT_TRUE(RE::PartialMatch("a", empty)); + + const RE re("a.*z"); + EXPECT_TRUE(RE::PartialMatch("az", re)); + EXPECT_TRUE(RE::PartialMatch("axyz", re)); + EXPECT_TRUE(RE::PartialMatch("baz", re)); + EXPECT_TRUE(RE::PartialMatch("azy", re)); + EXPECT_FALSE(RE::PartialMatch("zza", re)); +} + +#endif // GTEST_USES_POSIX_RE + +#if !GTEST_OS_WINDOWS_MOBILE + +TEST(CaptureTest, CapturesStdout) { + CaptureStdout(); + fprintf(stdout, "abc"); + EXPECT_STREQ("abc", GetCapturedStdout().c_str()); + + CaptureStdout(); + fprintf(stdout, "def%cghi", '\0'); + EXPECT_EQ(::std::string("def\0ghi", 7), ::std::string(GetCapturedStdout())); +} + +TEST(CaptureTest, CapturesStderr) { + CaptureStderr(); + fprintf(stderr, "jkl"); + EXPECT_STREQ("jkl", GetCapturedStderr().c_str()); + + CaptureStderr(); + fprintf(stderr, "jkl%cmno", '\0'); + EXPECT_EQ(::std::string("jkl\0mno", 7), ::std::string(GetCapturedStderr())); +} + +// Tests that stdout and stderr capture don't interfere with each other. +TEST(CaptureTest, CapturesStdoutAndStderr) { + CaptureStdout(); + CaptureStderr(); + fprintf(stdout, "pqr"); + fprintf(stderr, "stu"); + EXPECT_STREQ("pqr", GetCapturedStdout().c_str()); + EXPECT_STREQ("stu", GetCapturedStderr().c_str()); +} + +TEST(CaptureDeathTest, CannotReenterStdoutCapture) { + CaptureStdout(); + EXPECT_DEATH_IF_SUPPORTED(CaptureStdout(), + "Only one stdout capturer can exist at a time"); + GetCapturedStdout(); + + // We cannot test stderr capturing using death tests as they use it + // themselves. +} + +#endif // !GTEST_OS_WINDOWS_MOBILE + +TEST(ThreadLocalTest, DefaultConstructorInitializesToDefaultValues) { + ThreadLocal t1; + EXPECT_EQ(0, t1.get()); + + ThreadLocal t2; + EXPECT_TRUE(t2.get() == nullptr); +} + +TEST(ThreadLocalTest, SingleParamConstructorInitializesToParam) { + ThreadLocal t1(123); + EXPECT_EQ(123, t1.get()); + + int i = 0; + ThreadLocal t2(&i); + EXPECT_EQ(&i, t2.get()); +} + +class NoDefaultContructor { + public: + explicit NoDefaultContructor(const char*) {} + NoDefaultContructor(const NoDefaultContructor&) {} +}; + +TEST(ThreadLocalTest, ValueDefaultContructorIsNotRequiredForParamVersion) { + ThreadLocal bar(NoDefaultContructor("foo")); + bar.pointer(); +} + +TEST(ThreadLocalTest, GetAndPointerReturnSameValue) { + ThreadLocal thread_local_string; + + EXPECT_EQ(thread_local_string.pointer(), &(thread_local_string.get())); + + // Verifies the condition still holds after calling set. + thread_local_string.set("foo"); + EXPECT_EQ(thread_local_string.pointer(), &(thread_local_string.get())); +} + +TEST(ThreadLocalTest, PointerAndConstPointerReturnSameValue) { + ThreadLocal thread_local_string; + const ThreadLocal& const_thread_local_string = + thread_local_string; + + EXPECT_EQ(thread_local_string.pointer(), const_thread_local_string.pointer()); + + thread_local_string.set("foo"); + EXPECT_EQ(thread_local_string.pointer(), const_thread_local_string.pointer()); +} + +#if GTEST_IS_THREADSAFE + +void AddTwo(int* param) { *param += 2; } + +TEST(ThreadWithParamTest, ConstructorExecutesThreadFunc) { + int i = 40; + ThreadWithParam thread(&AddTwo, &i, nullptr); + thread.Join(); + EXPECT_EQ(42, i); +} + +TEST(MutexDeathTest, AssertHeldShouldAssertWhenNotLocked) { + // AssertHeld() is flaky only in the presence of multiple threads accessing + // the lock. In this case, the test is robust. + EXPECT_DEATH_IF_SUPPORTED({ + Mutex m; + { MutexLock lock(&m); } + m.AssertHeld(); + }, + "thread .*hold"); +} + +TEST(MutexTest, AssertHeldShouldNotAssertWhenLocked) { + Mutex m; + MutexLock lock(&m); + m.AssertHeld(); +} + +class AtomicCounterWithMutex { + public: + explicit AtomicCounterWithMutex(Mutex* mutex) : + value_(0), mutex_(mutex), random_(42) {} + + void Increment() { + MutexLock lock(mutex_); + int temp = value_; + { + // We need to put up a memory barrier to prevent reads and writes to + // value_ rearranged with the call to sleep_for when observed + // from other threads. +#if GTEST_HAS_PTHREAD + // On POSIX, locking a mutex puts up a memory barrier. We cannot use + // Mutex and MutexLock here or rely on their memory barrier + // functionality as we are testing them here. + pthread_mutex_t memory_barrier_mutex; + GTEST_CHECK_POSIX_SUCCESS_( + pthread_mutex_init(&memory_barrier_mutex, nullptr)); + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_lock(&memory_barrier_mutex)); + + std::this_thread::sleep_for( + std::chrono::milliseconds(random_.Generate(30))); + + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_unlock(&memory_barrier_mutex)); + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_destroy(&memory_barrier_mutex)); +#elif GTEST_OS_WINDOWS + // On Windows, performing an interlocked access puts up a memory barrier. + volatile LONG dummy = 0; + ::InterlockedIncrement(&dummy); + std::this_thread::sleep_for( + std::chrono::milliseconds(random_.Generate(30))); + ::InterlockedIncrement(&dummy); +#else +# error "Memory barrier not implemented on this platform." +#endif // GTEST_HAS_PTHREAD + } + value_ = temp + 1; + } + int value() const { return value_; } + + private: + volatile int value_; + Mutex* const mutex_; // Protects value_. + Random random_; +}; + +void CountingThreadFunc(pair param) { + for (int i = 0; i < param.second; ++i) + param.first->Increment(); +} + +// Tests that the mutex only lets one thread at a time to lock it. +TEST(MutexTest, OnlyOneThreadCanLockAtATime) { + Mutex mutex; + AtomicCounterWithMutex locked_counter(&mutex); + + typedef ThreadWithParam > ThreadType; + const int kCycleCount = 20; + const int kThreadCount = 7; + std::unique_ptr counting_threads[kThreadCount]; + Notification threads_can_start; + // Creates and runs kThreadCount threads that increment locked_counter + // kCycleCount times each. + for (int i = 0; i < kThreadCount; ++i) { + counting_threads[i].reset(new ThreadType(&CountingThreadFunc, + make_pair(&locked_counter, + kCycleCount), + &threads_can_start)); + } + threads_can_start.Notify(); + for (int i = 0; i < kThreadCount; ++i) + counting_threads[i]->Join(); + + // If the mutex lets more than one thread to increment the counter at a + // time, they are likely to encounter a race condition and have some + // increments overwritten, resulting in the lower then expected counter + // value. + EXPECT_EQ(kCycleCount * kThreadCount, locked_counter.value()); +} + +template +void RunFromThread(void (func)(T), T param) { + ThreadWithParam thread(func, param, nullptr); + thread.Join(); +} + +void RetrieveThreadLocalValue( + pair*, std::string*> param) { + *param.second = param.first->get(); +} + +TEST(ThreadLocalTest, ParameterizedConstructorSetsDefault) { + ThreadLocal thread_local_string("foo"); + EXPECT_STREQ("foo", thread_local_string.get().c_str()); + + thread_local_string.set("bar"); + EXPECT_STREQ("bar", thread_local_string.get().c_str()); + + std::string result; + RunFromThread(&RetrieveThreadLocalValue, + make_pair(&thread_local_string, &result)); + EXPECT_STREQ("foo", result.c_str()); +} + +// Keeps track of whether of destructors being called on instances of +// DestructorTracker. On Windows, waits for the destructor call reports. +class DestructorCall { + public: + DestructorCall() { + invoked_ = false; +#if GTEST_OS_WINDOWS + wait_event_.Reset(::CreateEvent(NULL, TRUE, FALSE, NULL)); + GTEST_CHECK_(wait_event_.Get() != NULL); +#endif + } + + bool CheckDestroyed() const { +#if GTEST_OS_WINDOWS + if (::WaitForSingleObject(wait_event_.Get(), 1000) != WAIT_OBJECT_0) + return false; +#endif + return invoked_; + } + + void ReportDestroyed() { + invoked_ = true; +#if GTEST_OS_WINDOWS + ::SetEvent(wait_event_.Get()); +#endif + } + + static std::vector& List() { return *list_; } + + static void ResetList() { + for (size_t i = 0; i < list_->size(); ++i) { + delete list_->at(i); + } + list_->clear(); + } + + private: + bool invoked_; +#if GTEST_OS_WINDOWS + AutoHandle wait_event_; +#endif + static std::vector* const list_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DestructorCall); +}; + +std::vector* const DestructorCall::list_ = + new std::vector; + +// DestructorTracker keeps track of whether its instances have been +// destroyed. +class DestructorTracker { + public: + DestructorTracker() : index_(GetNewIndex()) {} + DestructorTracker(const DestructorTracker& /* rhs */) + : index_(GetNewIndex()) {} + ~DestructorTracker() { + // We never access DestructorCall::List() concurrently, so we don't need + // to protect this access with a mutex. + DestructorCall::List()[index_]->ReportDestroyed(); + } + + private: + static size_t GetNewIndex() { + DestructorCall::List().push_back(new DestructorCall); + return DestructorCall::List().size() - 1; + } + const size_t index_; +}; + +typedef ThreadLocal* ThreadParam; + +void CallThreadLocalGet(ThreadParam thread_local_param) { + thread_local_param->get(); +} + +// Tests that when a ThreadLocal object dies in a thread, it destroys +// the managed object for that thread. +TEST(ThreadLocalTest, DestroysManagedObjectForOwnThreadWhenDying) { + DestructorCall::ResetList(); + + { + ThreadLocal thread_local_tracker; + ASSERT_EQ(0U, DestructorCall::List().size()); + + // This creates another DestructorTracker object for the main thread. + thread_local_tracker.get(); + ASSERT_EQ(1U, DestructorCall::List().size()); + ASSERT_FALSE(DestructorCall::List()[0]->CheckDestroyed()); + } + + // Now thread_local_tracker has died. + ASSERT_EQ(1U, DestructorCall::List().size()); + EXPECT_TRUE(DestructorCall::List()[0]->CheckDestroyed()); + + DestructorCall::ResetList(); +} + +// Tests that when a thread exits, the thread-local object for that +// thread is destroyed. +TEST(ThreadLocalTest, DestroysManagedObjectAtThreadExit) { + DestructorCall::ResetList(); + + { + ThreadLocal thread_local_tracker; + ASSERT_EQ(0U, DestructorCall::List().size()); + + // This creates another DestructorTracker object in the new thread. + ThreadWithParam thread(&CallThreadLocalGet, + &thread_local_tracker, nullptr); + thread.Join(); + + // The thread has exited, and we should have a DestroyedTracker + // instance created for it. But it may not have been destroyed yet. + ASSERT_EQ(1U, DestructorCall::List().size()); + } + + // The thread has exited and thread_local_tracker has died. + ASSERT_EQ(1U, DestructorCall::List().size()); + EXPECT_TRUE(DestructorCall::List()[0]->CheckDestroyed()); + + DestructorCall::ResetList(); +} + +TEST(ThreadLocalTest, ThreadLocalMutationsAffectOnlyCurrentThread) { + ThreadLocal thread_local_string; + thread_local_string.set("Foo"); + EXPECT_STREQ("Foo", thread_local_string.get().c_str()); + + std::string result; + RunFromThread(&RetrieveThreadLocalValue, + make_pair(&thread_local_string, &result)); + EXPECT_TRUE(result.empty()); +} + +#endif // GTEST_IS_THREADSAFE + +#if GTEST_OS_WINDOWS +TEST(WindowsTypesTest, HANDLEIsVoidStar) { + StaticAssertTypeEq(); +} + +#if GTEST_OS_WINDOWS_MINGW && !defined(__MINGW64_VERSION_MAJOR) +TEST(WindowsTypesTest, _CRITICAL_SECTIONIs_CRITICAL_SECTION) { + StaticAssertTypeEq(); +} +#else +TEST(WindowsTypesTest, CRITICAL_SECTIONIs_RTL_CRITICAL_SECTION) { + StaticAssertTypeEq(); +} +#endif + +#endif // GTEST_OS_WINDOWS + +} // namespace internal +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-printers-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-printers-test.cc new file mode 100644 index 000000000000..0058917a2759 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-printers-test.cc @@ -0,0 +1,1991 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Google Test - The Google C++ Testing and Mocking Framework +// +// This file tests the universal value printer. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest-printers.h" +#include "gtest/gtest.h" + +// Some user-defined types for testing the universal value printer. + +// An anonymous enum type. +enum AnonymousEnum { + kAE1 = -1, + kAE2 = 1 +}; + +// An enum without a user-defined printer. +enum EnumWithoutPrinter { + kEWP1 = -2, + kEWP2 = 42 +}; + +// An enum with a << operator. +enum EnumWithStreaming { + kEWS1 = 10 +}; + +std::ostream& operator<<(std::ostream& os, EnumWithStreaming e) { + return os << (e == kEWS1 ? "kEWS1" : "invalid"); +} + +// An enum with a PrintTo() function. +enum EnumWithPrintTo { + kEWPT1 = 1 +}; + +void PrintTo(EnumWithPrintTo e, std::ostream* os) { + *os << (e == kEWPT1 ? "kEWPT1" : "invalid"); +} + +// A class implicitly convertible to BiggestInt. +class BiggestIntConvertible { + public: + operator ::testing::internal::BiggestInt() const { return 42; } +}; + +// A parent class with two child classes. The parent and one of the kids have +// stream operators. +class ParentClass {}; +class ChildClassWithStreamOperator : public ParentClass {}; +class ChildClassWithoutStreamOperator : public ParentClass {}; +static void operator<<(std::ostream& os, const ParentClass&) { + os << "ParentClass"; +} +static void operator<<(std::ostream& os, const ChildClassWithStreamOperator&) { + os << "ChildClassWithStreamOperator"; +} + +// A user-defined unprintable class template in the global namespace. +template +class UnprintableTemplateInGlobal { + public: + UnprintableTemplateInGlobal() : value_() {} + private: + T value_; +}; + +// A user-defined streamable type in the global namespace. +class StreamableInGlobal { + public: + virtual ~StreamableInGlobal() {} +}; + +inline void operator<<(::std::ostream& os, const StreamableInGlobal& /* x */) { + os << "StreamableInGlobal"; +} + +void operator<<(::std::ostream& os, const StreamableInGlobal* /* x */) { + os << "StreamableInGlobal*"; +} + +namespace foo { + +// A user-defined unprintable type in a user namespace. +class UnprintableInFoo { + public: + UnprintableInFoo() : z_(0) { memcpy(xy_, "\xEF\x12\x0\x0\x34\xAB\x0\x0", 8); } + double z() const { return z_; } + private: + char xy_[8]; + double z_; +}; + +// A user-defined printable type in a user-chosen namespace. +struct PrintableViaPrintTo { + PrintableViaPrintTo() : value() {} + int value; +}; + +void PrintTo(const PrintableViaPrintTo& x, ::std::ostream* os) { + *os << "PrintableViaPrintTo: " << x.value; +} + +// A type with a user-defined << for printing its pointer. +struct PointerPrintable { +}; + +::std::ostream& operator<<(::std::ostream& os, + const PointerPrintable* /* x */) { + return os << "PointerPrintable*"; +} + +// A user-defined printable class template in a user-chosen namespace. +template +class PrintableViaPrintToTemplate { + public: + explicit PrintableViaPrintToTemplate(const T& a_value) : value_(a_value) {} + + const T& value() const { return value_; } + private: + T value_; +}; + +template +void PrintTo(const PrintableViaPrintToTemplate& x, ::std::ostream* os) { + *os << "PrintableViaPrintToTemplate: " << x.value(); +} + +// A user-defined streamable class template in a user namespace. +template +class StreamableTemplateInFoo { + public: + StreamableTemplateInFoo() : value_() {} + + const T& value() const { return value_; } + private: + T value_; +}; + +template +inline ::std::ostream& operator<<(::std::ostream& os, + const StreamableTemplateInFoo& x) { + return os << "StreamableTemplateInFoo: " << x.value(); +} + +// A user-defined streamable type in a user namespace whose operator<< is +// templated on the type of the output stream. +struct TemplatedStreamableInFoo {}; + +template +OutputStream& operator<<(OutputStream& os, + const TemplatedStreamableInFoo& /*ts*/) { + os << "TemplatedStreamableInFoo"; + return os; +} + +// A user-defined streamable but recursively-defined container type in +// a user namespace, it mimics therefore std::filesystem::path or +// boost::filesystem::path. +class PathLike { + public: + struct iterator { + typedef PathLike value_type; + + iterator& operator++(); + PathLike& operator*(); + }; + + using value_type = char; + using const_iterator = iterator; + + PathLike() {} + + iterator begin() const { return iterator(); } + iterator end() const { return iterator(); } + + friend ::std::ostream& operator<<(::std::ostream& os, const PathLike&) { + return os << "Streamable-PathLike"; + } +}; + +} // namespace foo + +namespace testing { +namespace { +template +class Wrapper { + public: + explicit Wrapper(T&& value) : value_(std::forward(value)) {} + + const T& value() const { return value_; } + + private: + T value_; +}; + +} // namespace + +namespace internal { +template +class UniversalPrinter> { + public: + static void Print(const Wrapper& w, ::std::ostream* os) { + *os << "Wrapper("; + UniversalPrint(w.value(), os); + *os << ')'; + } +}; +} // namespace internal + + +namespace gtest_printers_test { + +using ::std::deque; +using ::std::list; +using ::std::make_pair; +using ::std::map; +using ::std::multimap; +using ::std::multiset; +using ::std::pair; +using ::std::set; +using ::std::vector; +using ::testing::PrintToString; +using ::testing::internal::FormatForComparisonFailureMessage; +using ::testing::internal::ImplicitCast_; +using ::testing::internal::NativeArray; +using ::testing::internal::RelationToSourceReference; +using ::testing::internal::Strings; +using ::testing::internal::UniversalPrint; +using ::testing::internal::UniversalPrinter; +using ::testing::internal::UniversalTersePrint; +using ::testing::internal::UniversalTersePrintTupleFieldsToStrings; + +// Prints a value to a string using the universal value printer. This +// is a helper for testing UniversalPrinter::Print() for various types. +template +std::string Print(const T& value) { + ::std::stringstream ss; + UniversalPrinter::Print(value, &ss); + return ss.str(); +} + +// Prints a value passed by reference to a string, using the universal +// value printer. This is a helper for testing +// UniversalPrinter::Print() for various types. +template +std::string PrintByRef(const T& value) { + ::std::stringstream ss; + UniversalPrinter::Print(value, &ss); + return ss.str(); +} + +// Tests printing various enum types. + +TEST(PrintEnumTest, AnonymousEnum) { + EXPECT_EQ("-1", Print(kAE1)); + EXPECT_EQ("1", Print(kAE2)); +} + +TEST(PrintEnumTest, EnumWithoutPrinter) { + EXPECT_EQ("-2", Print(kEWP1)); + EXPECT_EQ("42", Print(kEWP2)); +} + +TEST(PrintEnumTest, EnumWithStreaming) { + EXPECT_EQ("kEWS1", Print(kEWS1)); + EXPECT_EQ("invalid", Print(static_cast(0))); +} + +TEST(PrintEnumTest, EnumWithPrintTo) { + EXPECT_EQ("kEWPT1", Print(kEWPT1)); + EXPECT_EQ("invalid", Print(static_cast(0))); +} + +// Tests printing a class implicitly convertible to BiggestInt. + +TEST(PrintClassTest, BiggestIntConvertible) { + EXPECT_EQ("42", Print(BiggestIntConvertible())); +} + +// Tests printing various char types. + +// char. +TEST(PrintCharTest, PlainChar) { + EXPECT_EQ("'\\0'", Print('\0')); + EXPECT_EQ("'\\'' (39, 0x27)", Print('\'')); + EXPECT_EQ("'\"' (34, 0x22)", Print('"')); + EXPECT_EQ("'?' (63, 0x3F)", Print('?')); + EXPECT_EQ("'\\\\' (92, 0x5C)", Print('\\')); + EXPECT_EQ("'\\a' (7)", Print('\a')); + EXPECT_EQ("'\\b' (8)", Print('\b')); + EXPECT_EQ("'\\f' (12, 0xC)", Print('\f')); + EXPECT_EQ("'\\n' (10, 0xA)", Print('\n')); + EXPECT_EQ("'\\r' (13, 0xD)", Print('\r')); + EXPECT_EQ("'\\t' (9)", Print('\t')); + EXPECT_EQ("'\\v' (11, 0xB)", Print('\v')); + EXPECT_EQ("'\\x7F' (127)", Print('\x7F')); + EXPECT_EQ("'\\xFF' (255)", Print('\xFF')); + EXPECT_EQ("' ' (32, 0x20)", Print(' ')); + EXPECT_EQ("'a' (97, 0x61)", Print('a')); +} + +// signed char. +TEST(PrintCharTest, SignedChar) { + EXPECT_EQ("'\\0'", Print(static_cast('\0'))); + EXPECT_EQ("'\\xCE' (-50)", + Print(static_cast(-50))); +} + +// unsigned char. +TEST(PrintCharTest, UnsignedChar) { + EXPECT_EQ("'\\0'", Print(static_cast('\0'))); + EXPECT_EQ("'b' (98, 0x62)", + Print(static_cast('b'))); +} + +TEST(PrintCharTest, Char16) { + EXPECT_EQ("U+0041", Print(u'A')); +} + +TEST(PrintCharTest, Char32) { + EXPECT_EQ("U+0041", Print(U'A')); +} + +#ifdef __cpp_char8_t +TEST(PrintCharTest, Char8) { + EXPECT_EQ("U+0041", Print(u8'A')); +} +#endif + +// Tests printing other simple, built-in types. + +// bool. +TEST(PrintBuiltInTypeTest, Bool) { + EXPECT_EQ("false", Print(false)); + EXPECT_EQ("true", Print(true)); +} + +// wchar_t. +TEST(PrintBuiltInTypeTest, Wchar_t) { + EXPECT_EQ("L'\\0'", Print(L'\0')); + EXPECT_EQ("L'\\'' (39, 0x27)", Print(L'\'')); + EXPECT_EQ("L'\"' (34, 0x22)", Print(L'"')); + EXPECT_EQ("L'?' (63, 0x3F)", Print(L'?')); + EXPECT_EQ("L'\\\\' (92, 0x5C)", Print(L'\\')); + EXPECT_EQ("L'\\a' (7)", Print(L'\a')); + EXPECT_EQ("L'\\b' (8)", Print(L'\b')); + EXPECT_EQ("L'\\f' (12, 0xC)", Print(L'\f')); + EXPECT_EQ("L'\\n' (10, 0xA)", Print(L'\n')); + EXPECT_EQ("L'\\r' (13, 0xD)", Print(L'\r')); + EXPECT_EQ("L'\\t' (9)", Print(L'\t')); + EXPECT_EQ("L'\\v' (11, 0xB)", Print(L'\v')); + EXPECT_EQ("L'\\x7F' (127)", Print(L'\x7F')); + EXPECT_EQ("L'\\xFF' (255)", Print(L'\xFF')); + EXPECT_EQ("L' ' (32, 0x20)", Print(L' ')); + EXPECT_EQ("L'a' (97, 0x61)", Print(L'a')); + EXPECT_EQ("L'\\x576' (1398)", Print(static_cast(0x576))); + EXPECT_EQ("L'\\xC74D' (51021)", Print(static_cast(0xC74D))); +} + +// Test that int64_t provides more storage than wchar_t. +TEST(PrintTypeSizeTest, Wchar_t) { + EXPECT_LT(sizeof(wchar_t), sizeof(int64_t)); +} + +// Various integer types. +TEST(PrintBuiltInTypeTest, Integer) { + EXPECT_EQ("'\\xFF' (255)", Print(static_cast(255))); // uint8 + EXPECT_EQ("'\\x80' (-128)", Print(static_cast(-128))); // int8 + EXPECT_EQ("65535", Print(std::numeric_limits::max())); // uint16 + EXPECT_EQ("-32768", Print(std::numeric_limits::min())); // int16 + EXPECT_EQ("4294967295", + Print(std::numeric_limits::max())); // uint32 + EXPECT_EQ("-2147483648", + Print(std::numeric_limits::min())); // int32 + EXPECT_EQ("18446744073709551615", + Print(std::numeric_limits::max())); // uint64 + EXPECT_EQ("-9223372036854775808", + Print(std::numeric_limits::min())); // int64 +#ifdef __cpp_char8_t + EXPECT_EQ("U+0000", + Print(std::numeric_limits::min())); // char8_t + EXPECT_EQ("U+00FF", + Print(std::numeric_limits::max())); // char8_t +#endif + EXPECT_EQ("U+0000", + Print(std::numeric_limits::min())); // char16_t + EXPECT_EQ("U+FFFF", + Print(std::numeric_limits::max())); // char16_t + EXPECT_EQ("U+0000", + Print(std::numeric_limits::min())); // char32_t + EXPECT_EQ("U+FFFFFFFF", + Print(std::numeric_limits::max())); // char32_t +} + +// Size types. +TEST(PrintBuiltInTypeTest, Size_t) { + EXPECT_EQ("1", Print(sizeof('a'))); // size_t. +#if !GTEST_OS_WINDOWS + // Windows has no ssize_t type. + EXPECT_EQ("-2", Print(static_cast(-2))); // ssize_t. +#endif // !GTEST_OS_WINDOWS +} + +// gcc/clang __{u,}int128_t values. +#if defined(__SIZEOF_INT128__) +TEST(PrintBuiltInTypeTest, Int128) { + // Small ones + EXPECT_EQ("0", Print(__int128_t{0})); + EXPECT_EQ("0", Print(__uint128_t{0})); + EXPECT_EQ("12345", Print(__int128_t{12345})); + EXPECT_EQ("12345", Print(__uint128_t{12345})); + EXPECT_EQ("-12345", Print(__int128_t{-12345})); + + // Large ones + EXPECT_EQ("340282366920938463463374607431768211455", Print(~__uint128_t{})); + __int128_t max_128 = static_cast<__int128_t>(~__uint128_t{} / 2); + EXPECT_EQ("-170141183460469231731687303715884105728", Print(~max_128)); + EXPECT_EQ("170141183460469231731687303715884105727", Print(max_128)); +} +#endif // __SIZEOF_INT128__ + +// Floating-points. +TEST(PrintBuiltInTypeTest, FloatingPoints) { + EXPECT_EQ("1.5", Print(1.5f)); // float + EXPECT_EQ("-2.5", Print(-2.5)); // double +} + +#if GTEST_HAS_RTTI +TEST(PrintBuiltInTypeTest, TypeInfo) { + struct MyStruct {}; + auto res = Print(typeid(MyStruct{})); + // We can't guarantee that we can demangle the name, but either name should + // contain the substring "MyStruct". + EXPECT_NE(res.find("MyStruct"), res.npos) << res; +} +#endif // GTEST_HAS_RTTI + +// Since ::std::stringstream::operator<<(const void *) formats the pointer +// output differently with different compilers, we have to create the expected +// output first and use it as our expectation. +static std::string PrintPointer(const void* p) { + ::std::stringstream expected_result_stream; + expected_result_stream << p; + return expected_result_stream.str(); +} + +// Tests printing C strings. + +// const char*. +TEST(PrintCStringTest, Const) { + const char* p = "World"; + EXPECT_EQ(PrintPointer(p) + " pointing to \"World\"", Print(p)); +} + +// char*. +TEST(PrintCStringTest, NonConst) { + char p[] = "Hi"; + EXPECT_EQ(PrintPointer(p) + " pointing to \"Hi\"", + Print(static_cast(p))); +} + +// NULL C string. +TEST(PrintCStringTest, Null) { + const char* p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests that C strings are escaped properly. +TEST(PrintCStringTest, EscapesProperly) { + const char* p = "'\"?\\\a\b\f\n\r\t\v\x7F\xFF a"; + EXPECT_EQ(PrintPointer(p) + " pointing to \"'\\\"?\\\\\\a\\b\\f" + "\\n\\r\\t\\v\\x7F\\xFF a\"", + Print(p)); +} + +#ifdef __cpp_char8_t +// const char8_t*. +TEST(PrintU8StringTest, Const) { + const char8_t* p = u8"界"; + EXPECT_EQ(PrintPointer(p) + " pointing to u8\"\\xE7\\x95\\x8C\"", Print(p)); +} + +// char8_t*. +TEST(PrintU8StringTest, NonConst) { + char8_t p[] = u8"世"; + EXPECT_EQ(PrintPointer(p) + " pointing to u8\"\\xE4\\xB8\\x96\"", + Print(static_cast(p))); +} + +// NULL u8 string. +TEST(PrintU8StringTest, Null) { + const char8_t* p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests that u8 strings are escaped properly. +TEST(PrintU8StringTest, EscapesProperly) { + const char8_t* p = u8"'\"?\\\a\b\f\n\r\t\v\x7F\xFF hello 世界"; + EXPECT_EQ(PrintPointer(p) + + " pointing to u8\"'\\\"?\\\\\\a\\b\\f\\n\\r\\t\\v\\x7F\\xFF " + "hello \\xE4\\xB8\\x96\\xE7\\x95\\x8C\"", + Print(p)); +} +#endif + +// const char16_t*. +TEST(PrintU16StringTest, Const) { + const char16_t* p = u"界"; + EXPECT_EQ(PrintPointer(p) + " pointing to u\"\\x754C\"", Print(p)); +} + +// char16_t*. +TEST(PrintU16StringTest, NonConst) { + char16_t p[] = u"世"; + EXPECT_EQ(PrintPointer(p) + " pointing to u\"\\x4E16\"", + Print(static_cast(p))); +} + +// NULL u16 string. +TEST(PrintU16StringTest, Null) { + const char16_t* p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests that u16 strings are escaped properly. +TEST(PrintU16StringTest, EscapesProperly) { + const char16_t* p = u"'\"?\\\a\b\f\n\r\t\v\x7F\xFF hello 世界"; + EXPECT_EQ(PrintPointer(p) + + " pointing to u\"'\\\"?\\\\\\a\\b\\f\\n\\r\\t\\v\\x7F\\xFF " + "hello \\x4E16\\x754C\"", + Print(p)); +} + +// const char32_t*. +TEST(PrintU32StringTest, Const) { + const char32_t* p = U"🗺️"; + EXPECT_EQ(PrintPointer(p) + " pointing to U\"\\x1F5FA\\xFE0F\"", Print(p)); +} + +// char32_t*. +TEST(PrintU32StringTest, NonConst) { + char32_t p[] = U"🌌"; + EXPECT_EQ(PrintPointer(p) + " pointing to U\"\\x1F30C\"", + Print(static_cast(p))); +} + +// NULL u32 string. +TEST(PrintU32StringTest, Null) { + const char32_t* p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests that u32 strings are escaped properly. +TEST(PrintU32StringTest, EscapesProperly) { + const char32_t* p = U"'\"?\\\a\b\f\n\r\t\v\x7F\xFF hello 🗺️"; + EXPECT_EQ(PrintPointer(p) + + " pointing to U\"'\\\"?\\\\\\a\\b\\f\\n\\r\\t\\v\\x7F\\xFF " + "hello \\x1F5FA\\xFE0F\"", + Print(p)); +} + +// MSVC compiler can be configured to define whar_t as a typedef +// of unsigned short. Defining an overload for const wchar_t* in that case +// would cause pointers to unsigned shorts be printed as wide strings, +// possibly accessing more memory than intended and causing invalid +// memory accesses. MSVC defines _NATIVE_WCHAR_T_DEFINED symbol when +// wchar_t is implemented as a native type. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) + +// const wchar_t*. +TEST(PrintWideCStringTest, Const) { + const wchar_t* p = L"World"; + EXPECT_EQ(PrintPointer(p) + " pointing to L\"World\"", Print(p)); +} + +// wchar_t*. +TEST(PrintWideCStringTest, NonConst) { + wchar_t p[] = L"Hi"; + EXPECT_EQ(PrintPointer(p) + " pointing to L\"Hi\"", + Print(static_cast(p))); +} + +// NULL wide C string. +TEST(PrintWideCStringTest, Null) { + const wchar_t* p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests that wide C strings are escaped properly. +TEST(PrintWideCStringTest, EscapesProperly) { + const wchar_t s[] = {'\'', '"', '?', '\\', '\a', '\b', '\f', '\n', '\r', + '\t', '\v', 0xD3, 0x576, 0x8D3, 0xC74D, ' ', 'a', '\0'}; + EXPECT_EQ(PrintPointer(s) + " pointing to L\"'\\\"?\\\\\\a\\b\\f" + "\\n\\r\\t\\v\\xD3\\x576\\x8D3\\xC74D a\"", + Print(static_cast(s))); +} +#endif // native wchar_t + +// Tests printing pointers to other char types. + +// signed char*. +TEST(PrintCharPointerTest, SignedChar) { + signed char* p = reinterpret_cast(0x1234); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// const signed char*. +TEST(PrintCharPointerTest, ConstSignedChar) { + signed char* p = reinterpret_cast(0x1234); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// unsigned char*. +TEST(PrintCharPointerTest, UnsignedChar) { + unsigned char* p = reinterpret_cast(0x1234); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// const unsigned char*. +TEST(PrintCharPointerTest, ConstUnsignedChar) { + const unsigned char* p = reinterpret_cast(0x1234); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests printing pointers to simple, built-in types. + +// bool*. +TEST(PrintPointerToBuiltInTypeTest, Bool) { + bool* p = reinterpret_cast(0xABCD); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// void*. +TEST(PrintPointerToBuiltInTypeTest, Void) { + void* p = reinterpret_cast(0xABCD); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// const void*. +TEST(PrintPointerToBuiltInTypeTest, ConstVoid) { + const void* p = reinterpret_cast(0xABCD); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests printing pointers to pointers. +TEST(PrintPointerToPointerTest, IntPointerPointer) { + int** p = reinterpret_cast(0xABCD); + EXPECT_EQ(PrintPointer(p), Print(p)); + p = nullptr; + EXPECT_EQ("NULL", Print(p)); +} + +// Tests printing (non-member) function pointers. + +void MyFunction(int /* n */) {} + +TEST(PrintPointerTest, NonMemberFunctionPointer) { + // We cannot directly cast &MyFunction to const void* because the + // standard disallows casting between pointers to functions and + // pointers to objects, and some compilers (e.g. GCC 3.4) enforce + // this limitation. + EXPECT_EQ( + PrintPointer(reinterpret_cast( + reinterpret_cast(&MyFunction))), + Print(&MyFunction)); + int (*p)(bool) = NULL; // NOLINT + EXPECT_EQ("NULL", Print(p)); +} + +// An assertion predicate determining whether a one string is a prefix for +// another. +template +AssertionResult HasPrefix(const StringType& str, const StringType& prefix) { + if (str.find(prefix, 0) == 0) + return AssertionSuccess(); + + const bool is_wide_string = sizeof(prefix[0]) > 1; + const char* const begin_string_quote = is_wide_string ? "L\"" : "\""; + return AssertionFailure() + << begin_string_quote << prefix << "\" is not a prefix of " + << begin_string_quote << str << "\"\n"; +} + +// Tests printing member variable pointers. Although they are called +// pointers, they don't point to a location in the address space. +// Their representation is implementation-defined. Thus they will be +// printed as raw bytes. + +struct Foo { + public: + virtual ~Foo() {} + int MyMethod(char x) { return x + 1; } + virtual char MyVirtualMethod(int /* n */) { return 'a'; } + + int value; +}; + +TEST(PrintPointerTest, MemberVariablePointer) { + EXPECT_TRUE(HasPrefix(Print(&Foo::value), + Print(sizeof(&Foo::value)) + "-byte object ")); + int Foo::*p = NULL; // NOLINT + EXPECT_TRUE(HasPrefix(Print(p), + Print(sizeof(p)) + "-byte object ")); +} + +// Tests printing member function pointers. Although they are called +// pointers, they don't point to a location in the address space. +// Their representation is implementation-defined. Thus they will be +// printed as raw bytes. +TEST(PrintPointerTest, MemberFunctionPointer) { + EXPECT_TRUE(HasPrefix(Print(&Foo::MyMethod), + Print(sizeof(&Foo::MyMethod)) + "-byte object ")); + EXPECT_TRUE( + HasPrefix(Print(&Foo::MyVirtualMethod), + Print(sizeof((&Foo::MyVirtualMethod))) + "-byte object ")); + int (Foo::*p)(char) = NULL; // NOLINT + EXPECT_TRUE(HasPrefix(Print(p), + Print(sizeof(p)) + "-byte object ")); +} + +// Tests printing C arrays. + +// The difference between this and Print() is that it ensures that the +// argument is a reference to an array. +template +std::string PrintArrayHelper(T (&a)[N]) { + return Print(a); +} + +// One-dimensional array. +TEST(PrintArrayTest, OneDimensionalArray) { + int a[5] = { 1, 2, 3, 4, 5 }; + EXPECT_EQ("{ 1, 2, 3, 4, 5 }", PrintArrayHelper(a)); +} + +// Two-dimensional array. +TEST(PrintArrayTest, TwoDimensionalArray) { + int a[2][5] = { + { 1, 2, 3, 4, 5 }, + { 6, 7, 8, 9, 0 } + }; + EXPECT_EQ("{ { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 0 } }", PrintArrayHelper(a)); +} + +// Array of const elements. +TEST(PrintArrayTest, ConstArray) { + const bool a[1] = { false }; + EXPECT_EQ("{ false }", PrintArrayHelper(a)); +} + +// char array without terminating NUL. +TEST(PrintArrayTest, CharArrayWithNoTerminatingNul) { + // Array a contains '\0' in the middle and doesn't end with '\0'. + char a[] = { 'H', '\0', 'i' }; + EXPECT_EQ("\"H\\0i\" (no terminating NUL)", PrintArrayHelper(a)); +} + +// char array with terminating NUL. +TEST(PrintArrayTest, CharArrayWithTerminatingNul) { + const char a[] = "\0Hi"; + EXPECT_EQ("\"\\0Hi\"", PrintArrayHelper(a)); +} + +#ifdef __cpp_char8_t +// char_t array without terminating NUL. +TEST(PrintArrayTest, Char8ArrayWithNoTerminatingNul) { + // Array a contains '\0' in the middle and doesn't end with '\0'. + const char8_t a[] = {u8'H', u8'\0', u8'i'}; + EXPECT_EQ("u8\"H\\0i\" (no terminating NUL)", PrintArrayHelper(a)); +} + +// char8_t array with terminating NUL. +TEST(PrintArrayTest, Char8ArrayWithTerminatingNul) { + const char8_t a[] = u8"\0世界"; + EXPECT_EQ( + "u8\"\\0\\xE4\\xB8\\x96\\xE7\\x95\\x8C\"", + PrintArrayHelper(a)); +} +#endif + +// const char16_t array without terminating NUL. +TEST(PrintArrayTest, Char16ArrayWithNoTerminatingNul) { + // Array a contains '\0' in the middle and doesn't end with '\0'. + const char16_t a[] = {u'こ', u'\0', u'ん', u'に', u'ち', u'は'}; + EXPECT_EQ("u\"\\x3053\\0\\x3093\\x306B\\x3061\\x306F\" (no terminating NUL)", + PrintArrayHelper(a)); +} + +// char16_t array with terminating NUL. +TEST(PrintArrayTest, Char16ArrayWithTerminatingNul) { + const char16_t a[] = u"\0こんにちは"; + EXPECT_EQ("u\"\\0\\x3053\\x3093\\x306B\\x3061\\x306F\"", PrintArrayHelper(a)); +} + +// char32_t array without terminating NUL. +TEST(PrintArrayTest, Char32ArrayWithNoTerminatingNul) { + // Array a contains '\0' in the middle and doesn't end with '\0'. + const char32_t a[] = {U'👋', U'\0', U'🌌'}; + EXPECT_EQ("U\"\\x1F44B\\0\\x1F30C\" (no terminating NUL)", + PrintArrayHelper(a)); +} + +// char32_t array with terminating NUL. +TEST(PrintArrayTest, Char32ArrayWithTerminatingNul) { + const char32_t a[] = U"\0👋🌌"; + EXPECT_EQ("U\"\\0\\x1F44B\\x1F30C\"", PrintArrayHelper(a)); +} + +// wchar_t array without terminating NUL. +TEST(PrintArrayTest, WCharArrayWithNoTerminatingNul) { + // Array a contains '\0' in the middle and doesn't end with '\0'. + const wchar_t a[] = {L'H', L'\0', L'i'}; + EXPECT_EQ("L\"H\\0i\" (no terminating NUL)", PrintArrayHelper(a)); +} + +// wchar_t array with terminating NUL. +TEST(PrintArrayTest, WCharArrayWithTerminatingNul) { + const wchar_t a[] = L"\0Hi"; + EXPECT_EQ("L\"\\0Hi\"", PrintArrayHelper(a)); +} + +// Array of objects. +TEST(PrintArrayTest, ObjectArray) { + std::string a[3] = {"Hi", "Hello", "Ni hao"}; + EXPECT_EQ("{ \"Hi\", \"Hello\", \"Ni hao\" }", PrintArrayHelper(a)); +} + +// Array with many elements. +TEST(PrintArrayTest, BigArray) { + int a[100] = { 1, 2, 3 }; + EXPECT_EQ("{ 1, 2, 3, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0 }", + PrintArrayHelper(a)); +} + +// Tests printing ::string and ::std::string. + +// ::std::string. +TEST(PrintStringTest, StringInStdNamespace) { + const char s[] = "'\"?\\\a\b\f\n\0\r\t\v\x7F\xFF a"; + const ::std::string str(s, sizeof(s)); + EXPECT_EQ("\"'\\\"?\\\\\\a\\b\\f\\n\\0\\r\\t\\v\\x7F\\xFF a\\0\"", + Print(str)); +} + +TEST(PrintStringTest, StringAmbiguousHex) { + // "\x6BANANA" is ambiguous, it can be interpreted as starting with either of: + // '\x6', '\x6B', or '\x6BA'. + + // a hex escaping sequence following by a decimal digit + EXPECT_EQ("\"0\\x12\" \"3\"", Print(::std::string("0\x12" "3"))); + // a hex escaping sequence following by a hex digit (lower-case) + EXPECT_EQ("\"mm\\x6\" \"bananas\"", Print(::std::string("mm\x6" "bananas"))); + // a hex escaping sequence following by a hex digit (upper-case) + EXPECT_EQ("\"NOM\\x6\" \"BANANA\"", Print(::std::string("NOM\x6" "BANANA"))); + // a hex escaping sequence following by a non-xdigit + EXPECT_EQ("\"!\\x5-!\"", Print(::std::string("!\x5-!"))); +} + +// Tests printing ::std::wstring. +#if GTEST_HAS_STD_WSTRING +// ::std::wstring. +TEST(PrintWideStringTest, StringInStdNamespace) { + const wchar_t s[] = L"'\"?\\\a\b\f\n\0\r\t\v\xD3\x576\x8D3\xC74D a"; + const ::std::wstring str(s, sizeof(s)/sizeof(wchar_t)); + EXPECT_EQ("L\"'\\\"?\\\\\\a\\b\\f\\n\\0\\r\\t\\v" + "\\xD3\\x576\\x8D3\\xC74D a\\0\"", + Print(str)); +} + +TEST(PrintWideStringTest, StringAmbiguousHex) { + // same for wide strings. + EXPECT_EQ("L\"0\\x12\" L\"3\"", Print(::std::wstring(L"0\x12" L"3"))); + EXPECT_EQ("L\"mm\\x6\" L\"bananas\"", + Print(::std::wstring(L"mm\x6" L"bananas"))); + EXPECT_EQ("L\"NOM\\x6\" L\"BANANA\"", + Print(::std::wstring(L"NOM\x6" L"BANANA"))); + EXPECT_EQ("L\"!\\x5-!\"", Print(::std::wstring(L"!\x5-!"))); +} +#endif // GTEST_HAS_STD_WSTRING + +#ifdef __cpp_char8_t +TEST(PrintStringTest, U8String) { + std::u8string str = u8"Hello, 世界"; + EXPECT_EQ(str, str); // Verify EXPECT_EQ compiles with this type. + EXPECT_EQ("u8\"Hello, \\xE4\\xB8\\x96\\xE7\\x95\\x8C\"", Print(str)); +} +#endif + +TEST(PrintStringTest, U16String) { + std::u16string str = u"Hello, 世界"; + EXPECT_EQ(str, str); // Verify EXPECT_EQ compiles with this type. + EXPECT_EQ("u\"Hello, \\x4E16\\x754C\"", Print(str)); +} + +TEST(PrintStringTest, U32String) { + std::u32string str = U"Hello, 🗺️"; + EXPECT_EQ(str, str); // Verify EXPECT_EQ compiles with this type + EXPECT_EQ("U\"Hello, \\x1F5FA\\xFE0F\"", Print(str)); +} + +// Tests printing types that support generic streaming (i.e. streaming +// to std::basic_ostream for any valid Char and +// CharTraits types). + +// Tests printing a non-template type that supports generic streaming. + +class AllowsGenericStreaming {}; + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const AllowsGenericStreaming& /* a */) { + return os << "AllowsGenericStreaming"; +} + +TEST(PrintTypeWithGenericStreamingTest, NonTemplateType) { + AllowsGenericStreaming a; + EXPECT_EQ("AllowsGenericStreaming", Print(a)); +} + +// Tests printing a template type that supports generic streaming. + +template +class AllowsGenericStreamingTemplate {}; + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const AllowsGenericStreamingTemplate& /* a */) { + return os << "AllowsGenericStreamingTemplate"; +} + +TEST(PrintTypeWithGenericStreamingTest, TemplateType) { + AllowsGenericStreamingTemplate a; + EXPECT_EQ("AllowsGenericStreamingTemplate", Print(a)); +} + +// Tests printing a type that supports generic streaming and can be +// implicitly converted to another printable type. + +template +class AllowsGenericStreamingAndImplicitConversionTemplate { + public: + operator bool() const { return false; } +}; + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const AllowsGenericStreamingAndImplicitConversionTemplate& /* a */) { + return os << "AllowsGenericStreamingAndImplicitConversionTemplate"; +} + +TEST(PrintTypeWithGenericStreamingTest, TypeImplicitlyConvertible) { + AllowsGenericStreamingAndImplicitConversionTemplate a; + EXPECT_EQ("AllowsGenericStreamingAndImplicitConversionTemplate", Print(a)); +} + +#if GTEST_INTERNAL_HAS_STRING_VIEW + +// Tests printing internal::StringView. + +TEST(PrintStringViewTest, SimpleStringView) { + const internal::StringView sp = "Hello"; + EXPECT_EQ("\"Hello\"", Print(sp)); +} + +TEST(PrintStringViewTest, UnprintableCharacters) { + const char str[] = "NUL (\0) and \r\t"; + const internal::StringView sp(str, sizeof(str) - 1); + EXPECT_EQ("\"NUL (\\0) and \\r\\t\"", Print(sp)); +} + +#endif // GTEST_INTERNAL_HAS_STRING_VIEW + +// Tests printing STL containers. + +TEST(PrintStlContainerTest, EmptyDeque) { + deque empty; + EXPECT_EQ("{}", Print(empty)); +} + +TEST(PrintStlContainerTest, NonEmptyDeque) { + deque non_empty; + non_empty.push_back(1); + non_empty.push_back(3); + EXPECT_EQ("{ 1, 3 }", Print(non_empty)); +} + + +TEST(PrintStlContainerTest, OneElementHashMap) { + ::std::unordered_map map1; + map1[1] = 'a'; + EXPECT_EQ("{ (1, 'a' (97, 0x61)) }", Print(map1)); +} + +TEST(PrintStlContainerTest, HashMultiMap) { + ::std::unordered_multimap map1; + map1.insert(make_pair(5, true)); + map1.insert(make_pair(5, false)); + + // Elements of hash_multimap can be printed in any order. + const std::string result = Print(map1); + EXPECT_TRUE(result == "{ (5, true), (5, false) }" || + result == "{ (5, false), (5, true) }") + << " where Print(map1) returns \"" << result << "\"."; +} + + + +TEST(PrintStlContainerTest, HashSet) { + ::std::unordered_set set1; + set1.insert(1); + EXPECT_EQ("{ 1 }", Print(set1)); +} + +TEST(PrintStlContainerTest, HashMultiSet) { + const int kSize = 5; + int a[kSize] = { 1, 1, 2, 5, 1 }; + ::std::unordered_multiset set1(a, a + kSize); + + // Elements of hash_multiset can be printed in any order. + const std::string result = Print(set1); + const std::string expected_pattern = "{ d, d, d, d, d }"; // d means a digit. + + // Verifies the result matches the expected pattern; also extracts + // the numbers in the result. + ASSERT_EQ(expected_pattern.length(), result.length()); + std::vector numbers; + for (size_t i = 0; i != result.length(); i++) { + if (expected_pattern[i] == 'd') { + ASSERT_NE(isdigit(static_cast(result[i])), 0); + numbers.push_back(result[i] - '0'); + } else { + EXPECT_EQ(expected_pattern[i], result[i]) << " where result is " + << result; + } + } + + // Makes sure the result contains the right numbers. + std::sort(numbers.begin(), numbers.end()); + std::sort(a, a + kSize); + EXPECT_TRUE(std::equal(a, a + kSize, numbers.begin())); +} + + +TEST(PrintStlContainerTest, List) { + const std::string a[] = {"hello", "world"}; + const list strings(a, a + 2); + EXPECT_EQ("{ \"hello\", \"world\" }", Print(strings)); +} + +TEST(PrintStlContainerTest, Map) { + map map1; + map1[1] = true; + map1[5] = false; + map1[3] = true; + EXPECT_EQ("{ (1, true), (3, true), (5, false) }", Print(map1)); +} + +TEST(PrintStlContainerTest, MultiMap) { + multimap map1; + // The make_pair template function would deduce the type as + // pair here, and since the key part in a multimap has to + // be constant, without a templated ctor in the pair class (as in + // libCstd on Solaris), make_pair call would fail to compile as no + // implicit conversion is found. Thus explicit typename is used + // here instead. + map1.insert(pair(true, 0)); + map1.insert(pair(true, 1)); + map1.insert(pair(false, 2)); + EXPECT_EQ("{ (false, 2), (true, 0), (true, 1) }", Print(map1)); +} + +TEST(PrintStlContainerTest, Set) { + const unsigned int a[] = { 3, 0, 5 }; + set set1(a, a + 3); + EXPECT_EQ("{ 0, 3, 5 }", Print(set1)); +} + +TEST(PrintStlContainerTest, MultiSet) { + const int a[] = { 1, 1, 2, 5, 1 }; + multiset set1(a, a + 5); + EXPECT_EQ("{ 1, 1, 1, 2, 5 }", Print(set1)); +} + + +TEST(PrintStlContainerTest, SinglyLinkedList) { + int a[] = { 9, 2, 8 }; + const std::forward_list ints(a, a + 3); + EXPECT_EQ("{ 9, 2, 8 }", Print(ints)); +} + +TEST(PrintStlContainerTest, Pair) { + pair p(true, 5); + EXPECT_EQ("(true, 5)", Print(p)); +} + +TEST(PrintStlContainerTest, Vector) { + vector v; + v.push_back(1); + v.push_back(2); + EXPECT_EQ("{ 1, 2 }", Print(v)); +} + +TEST(PrintStlContainerTest, LongSequence) { + const int a[100] = { 1, 2, 3 }; + const vector v(a, a + 100); + EXPECT_EQ("{ 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, " + "0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... }", Print(v)); +} + +TEST(PrintStlContainerTest, NestedContainer) { + const int a1[] = { 1, 2 }; + const int a2[] = { 3, 4, 5 }; + const list l1(a1, a1 + 2); + const list l2(a2, a2 + 3); + + vector > v; + v.push_back(l1); + v.push_back(l2); + EXPECT_EQ("{ { 1, 2 }, { 3, 4, 5 } }", Print(v)); +} + +TEST(PrintStlContainerTest, OneDimensionalNativeArray) { + const int a[3] = { 1, 2, 3 }; + NativeArray b(a, 3, RelationToSourceReference()); + EXPECT_EQ("{ 1, 2, 3 }", Print(b)); +} + +TEST(PrintStlContainerTest, TwoDimensionalNativeArray) { + const int a[2][3] = { { 1, 2, 3 }, { 4, 5, 6 } }; + NativeArray b(a, 2, RelationToSourceReference()); + EXPECT_EQ("{ { 1, 2, 3 }, { 4, 5, 6 } }", Print(b)); +} + +// Tests that a class named iterator isn't treated as a container. + +struct iterator { + char x; +}; + +TEST(PrintStlContainerTest, Iterator) { + iterator it = {}; + EXPECT_EQ("1-byte object <00>", Print(it)); +} + +// Tests that a class named const_iterator isn't treated as a container. + +struct const_iterator { + char x; +}; + +TEST(PrintStlContainerTest, ConstIterator) { + const_iterator it = {}; + EXPECT_EQ("1-byte object <00>", Print(it)); +} + +// Tests printing ::std::tuples. + +// Tuples of various arities. +TEST(PrintStdTupleTest, VariousSizes) { + ::std::tuple<> t0; + EXPECT_EQ("()", Print(t0)); + + ::std::tuple t1(5); + EXPECT_EQ("(5)", Print(t1)); + + ::std::tuple t2('a', true); + EXPECT_EQ("('a' (97, 0x61), true)", Print(t2)); + + ::std::tuple t3(false, 2, 3); + EXPECT_EQ("(false, 2, 3)", Print(t3)); + + ::std::tuple t4(false, 2, 3, 4); + EXPECT_EQ("(false, 2, 3, 4)", Print(t4)); + + const char* const str = "8"; + ::std::tuple + t10(false, 'a', static_cast(3), 4, 5, 1.5F, -2.5, str, // NOLINT + nullptr, "10"); + EXPECT_EQ("(false, 'a' (97, 0x61), 3, 4, 5, 1.5, -2.5, " + PrintPointer(str) + + " pointing to \"8\", NULL, \"10\")", + Print(t10)); +} + +// Nested tuples. +TEST(PrintStdTupleTest, NestedTuple) { + ::std::tuple< ::std::tuple, char> nested( + ::std::make_tuple(5, true), 'a'); + EXPECT_EQ("((5, true), 'a' (97, 0x61))", Print(nested)); +} + +TEST(PrintNullptrT, Basic) { + EXPECT_EQ("(nullptr)", Print(nullptr)); +} + +TEST(PrintReferenceWrapper, Printable) { + int x = 5; + EXPECT_EQ("@" + PrintPointer(&x) + " 5", Print(std::ref(x))); + EXPECT_EQ("@" + PrintPointer(&x) + " 5", Print(std::cref(x))); +} + +TEST(PrintReferenceWrapper, Unprintable) { + ::foo::UnprintableInFoo up; + EXPECT_EQ( + "@" + PrintPointer(&up) + + " 16-byte object ", + Print(std::ref(up))); + EXPECT_EQ( + "@" + PrintPointer(&up) + + " 16-byte object ", + Print(std::cref(up))); +} + +// Tests printing user-defined unprintable types. + +// Unprintable types in the global namespace. +TEST(PrintUnprintableTypeTest, InGlobalNamespace) { + EXPECT_EQ("1-byte object <00>", + Print(UnprintableTemplateInGlobal())); +} + +// Unprintable types in a user namespace. +TEST(PrintUnprintableTypeTest, InUserNamespace) { + EXPECT_EQ("16-byte object ", + Print(::foo::UnprintableInFoo())); +} + +// Unprintable types are that too big to be printed completely. + +struct Big { + Big() { memset(array, 0, sizeof(array)); } + char array[257]; +}; + +TEST(PrintUnpritableTypeTest, BigObject) { + EXPECT_EQ("257-byte object <00-00 00-00 00-00 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 ... 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 " + "00-00 00-00 00-00 00-00 00-00 00-00 00-00 00-00 00>", + Print(Big())); +} + +// Tests printing user-defined streamable types. + +// Streamable types in the global namespace. +TEST(PrintStreamableTypeTest, InGlobalNamespace) { + StreamableInGlobal x; + EXPECT_EQ("StreamableInGlobal", Print(x)); + EXPECT_EQ("StreamableInGlobal*", Print(&x)); +} + +// Printable template types in a user namespace. +TEST(PrintStreamableTypeTest, TemplateTypeInUserNamespace) { + EXPECT_EQ("StreamableTemplateInFoo: 0", + Print(::foo::StreamableTemplateInFoo())); +} + +TEST(PrintStreamableTypeTest, TypeInUserNamespaceWithTemplatedStreamOperator) { + EXPECT_EQ("TemplatedStreamableInFoo", + Print(::foo::TemplatedStreamableInFoo())); +} + +TEST(PrintStreamableTypeTest, SubclassUsesSuperclassStreamOperator) { + ParentClass parent; + ChildClassWithStreamOperator child_stream; + ChildClassWithoutStreamOperator child_no_stream; + EXPECT_EQ("ParentClass", Print(parent)); + EXPECT_EQ("ChildClassWithStreamOperator", Print(child_stream)); + EXPECT_EQ("ParentClass", Print(child_no_stream)); +} + +// Tests printing a user-defined recursive container type that has a << +// operator. +TEST(PrintStreamableTypeTest, PathLikeInUserNamespace) { + ::foo::PathLike x; + EXPECT_EQ("Streamable-PathLike", Print(x)); + const ::foo::PathLike cx; + EXPECT_EQ("Streamable-PathLike", Print(cx)); +} + +// Tests printing user-defined types that have a PrintTo() function. +TEST(PrintPrintableTypeTest, InUserNamespace) { + EXPECT_EQ("PrintableViaPrintTo: 0", + Print(::foo::PrintableViaPrintTo())); +} + +// Tests printing a pointer to a user-defined type that has a << +// operator for its pointer. +TEST(PrintPrintableTypeTest, PointerInUserNamespace) { + ::foo::PointerPrintable x; + EXPECT_EQ("PointerPrintable*", Print(&x)); +} + +// Tests printing user-defined class template that have a PrintTo() function. +TEST(PrintPrintableTypeTest, TemplateInUserNamespace) { + EXPECT_EQ("PrintableViaPrintToTemplate: 5", + Print(::foo::PrintableViaPrintToTemplate(5))); +} + +// Tests that the universal printer prints both the address and the +// value of a reference. +TEST(PrintReferenceTest, PrintsAddressAndValue) { + int n = 5; + EXPECT_EQ("@" + PrintPointer(&n) + " 5", PrintByRef(n)); + + int a[2][3] = { + { 0, 1, 2 }, + { 3, 4, 5 } + }; + EXPECT_EQ("@" + PrintPointer(a) + " { { 0, 1, 2 }, { 3, 4, 5 } }", + PrintByRef(a)); + + const ::foo::UnprintableInFoo x; + EXPECT_EQ("@" + PrintPointer(&x) + " 16-byte object " + "", + PrintByRef(x)); +} + +// Tests that the universal printer prints a function pointer passed by +// reference. +TEST(PrintReferenceTest, HandlesFunctionPointer) { + void (*fp)(int n) = &MyFunction; + const std::string fp_pointer_string = + PrintPointer(reinterpret_cast(&fp)); + // We cannot directly cast &MyFunction to const void* because the + // standard disallows casting between pointers to functions and + // pointers to objects, and some compilers (e.g. GCC 3.4) enforce + // this limitation. + const std::string fp_string = PrintPointer(reinterpret_cast( + reinterpret_cast(fp))); + EXPECT_EQ("@" + fp_pointer_string + " " + fp_string, + PrintByRef(fp)); +} + +// Tests that the universal printer prints a member function pointer +// passed by reference. +TEST(PrintReferenceTest, HandlesMemberFunctionPointer) { + int (Foo::*p)(char ch) = &Foo::MyMethod; + EXPECT_TRUE(HasPrefix( + PrintByRef(p), + "@" + PrintPointer(reinterpret_cast(&p)) + " " + + Print(sizeof(p)) + "-byte object ")); + + char (Foo::*p2)(int n) = &Foo::MyVirtualMethod; + EXPECT_TRUE(HasPrefix( + PrintByRef(p2), + "@" + PrintPointer(reinterpret_cast(&p2)) + " " + + Print(sizeof(p2)) + "-byte object ")); +} + +// Tests that the universal printer prints a member variable pointer +// passed by reference. +TEST(PrintReferenceTest, HandlesMemberVariablePointer) { + int Foo::*p = &Foo::value; // NOLINT + EXPECT_TRUE(HasPrefix( + PrintByRef(p), + "@" + PrintPointer(&p) + " " + Print(sizeof(p)) + "-byte object ")); +} + +// Tests that FormatForComparisonFailureMessage(), which is used to print +// an operand in a comparison assertion (e.g. ASSERT_EQ) when the assertion +// fails, formats the operand in the desired way. + +// scalar +TEST(FormatForComparisonFailureMessageTest, WorksForScalar) { + EXPECT_STREQ("123", + FormatForComparisonFailureMessage(123, 124).c_str()); +} + +// non-char pointer +TEST(FormatForComparisonFailureMessageTest, WorksForNonCharPointer) { + int n = 0; + EXPECT_EQ(PrintPointer(&n), + FormatForComparisonFailureMessage(&n, &n).c_str()); +} + +// non-char array +TEST(FormatForComparisonFailureMessageTest, FormatsNonCharArrayAsPointer) { + // In expression 'array == x', 'array' is compared by pointer. + // Therefore we want to print an array operand as a pointer. + int n[] = { 1, 2, 3 }; + EXPECT_EQ(PrintPointer(n), + FormatForComparisonFailureMessage(n, n).c_str()); +} + +// Tests formatting a char pointer when it's compared with another pointer. +// In this case we want to print it as a raw pointer, as the comparison is by +// pointer. + +// char pointer vs pointer +TEST(FormatForComparisonFailureMessageTest, WorksForCharPointerVsPointer) { + // In expression 'p == x', where 'p' and 'x' are (const or not) char + // pointers, the operands are compared by pointer. Therefore we + // want to print 'p' as a pointer instead of a C string (we don't + // even know if it's supposed to point to a valid C string). + + // const char* + const char* s = "hello"; + EXPECT_EQ(PrintPointer(s), + FormatForComparisonFailureMessage(s, s).c_str()); + + // char* + char ch = 'a'; + EXPECT_EQ(PrintPointer(&ch), + FormatForComparisonFailureMessage(&ch, &ch).c_str()); +} + +// wchar_t pointer vs pointer +TEST(FormatForComparisonFailureMessageTest, WorksForWCharPointerVsPointer) { + // In expression 'p == x', where 'p' and 'x' are (const or not) char + // pointers, the operands are compared by pointer. Therefore we + // want to print 'p' as a pointer instead of a wide C string (we don't + // even know if it's supposed to point to a valid wide C string). + + // const wchar_t* + const wchar_t* s = L"hello"; + EXPECT_EQ(PrintPointer(s), + FormatForComparisonFailureMessage(s, s).c_str()); + + // wchar_t* + wchar_t ch = L'a'; + EXPECT_EQ(PrintPointer(&ch), + FormatForComparisonFailureMessage(&ch, &ch).c_str()); +} + +// Tests formatting a char pointer when it's compared to a string object. +// In this case we want to print the char pointer as a C string. + +// char pointer vs std::string +TEST(FormatForComparisonFailureMessageTest, WorksForCharPointerVsStdString) { + const char* s = "hello \"world"; + EXPECT_STREQ("\"hello \\\"world\"", // The string content should be escaped. + FormatForComparisonFailureMessage(s, ::std::string()).c_str()); + + // char* + char str[] = "hi\1"; + char* p = str; + EXPECT_STREQ("\"hi\\x1\"", // The string content should be escaped. + FormatForComparisonFailureMessage(p, ::std::string()).c_str()); +} + +#if GTEST_HAS_STD_WSTRING +// wchar_t pointer vs std::wstring +TEST(FormatForComparisonFailureMessageTest, WorksForWCharPointerVsStdWString) { + const wchar_t* s = L"hi \"world"; + EXPECT_STREQ("L\"hi \\\"world\"", // The string content should be escaped. + FormatForComparisonFailureMessage(s, ::std::wstring()).c_str()); + + // wchar_t* + wchar_t str[] = L"hi\1"; + wchar_t* p = str; + EXPECT_STREQ("L\"hi\\x1\"", // The string content should be escaped. + FormatForComparisonFailureMessage(p, ::std::wstring()).c_str()); +} +#endif + +// Tests formatting a char array when it's compared with a pointer or array. +// In this case we want to print the array as a row pointer, as the comparison +// is by pointer. + +// char array vs pointer +TEST(FormatForComparisonFailureMessageTest, WorksForCharArrayVsPointer) { + char str[] = "hi \"world\""; + char* p = nullptr; + EXPECT_EQ(PrintPointer(str), + FormatForComparisonFailureMessage(str, p).c_str()); +} + +// char array vs char array +TEST(FormatForComparisonFailureMessageTest, WorksForCharArrayVsCharArray) { + const char str[] = "hi \"world\""; + EXPECT_EQ(PrintPointer(str), + FormatForComparisonFailureMessage(str, str).c_str()); +} + +// wchar_t array vs pointer +TEST(FormatForComparisonFailureMessageTest, WorksForWCharArrayVsPointer) { + wchar_t str[] = L"hi \"world\""; + wchar_t* p = nullptr; + EXPECT_EQ(PrintPointer(str), + FormatForComparisonFailureMessage(str, p).c_str()); +} + +// wchar_t array vs wchar_t array +TEST(FormatForComparisonFailureMessageTest, WorksForWCharArrayVsWCharArray) { + const wchar_t str[] = L"hi \"world\""; + EXPECT_EQ(PrintPointer(str), + FormatForComparisonFailureMessage(str, str).c_str()); +} + +// Tests formatting a char array when it's compared with a string object. +// In this case we want to print the array as a C string. + +// char array vs std::string +TEST(FormatForComparisonFailureMessageTest, WorksForCharArrayVsStdString) { + const char str[] = "hi \"world\""; + EXPECT_STREQ("\"hi \\\"world\\\"\"", // The content should be escaped. + FormatForComparisonFailureMessage(str, ::std::string()).c_str()); +} + +#if GTEST_HAS_STD_WSTRING +// wchar_t array vs std::wstring +TEST(FormatForComparisonFailureMessageTest, WorksForWCharArrayVsStdWString) { + const wchar_t str[] = L"hi \"w\0rld\""; + EXPECT_STREQ( + "L\"hi \\\"w\"", // The content should be escaped. + // Embedded NUL terminates the string. + FormatForComparisonFailureMessage(str, ::std::wstring()).c_str()); +} +#endif + +// Useful for testing PrintToString(). We cannot use EXPECT_EQ() +// there as its implementation uses PrintToString(). The caller must +// ensure that 'value' has no side effect. +#define EXPECT_PRINT_TO_STRING_(value, expected_string) \ + EXPECT_TRUE(PrintToString(value) == (expected_string)) \ + << " where " #value " prints as " << (PrintToString(value)) + +TEST(PrintToStringTest, WorksForScalar) { + EXPECT_PRINT_TO_STRING_(123, "123"); +} + +TEST(PrintToStringTest, WorksForPointerToConstChar) { + const char* p = "hello"; + EXPECT_PRINT_TO_STRING_(p, "\"hello\""); +} + +TEST(PrintToStringTest, WorksForPointerToNonConstChar) { + char s[] = "hello"; + char* p = s; + EXPECT_PRINT_TO_STRING_(p, "\"hello\""); +} + +TEST(PrintToStringTest, EscapesForPointerToConstChar) { + const char* p = "hello\n"; + EXPECT_PRINT_TO_STRING_(p, "\"hello\\n\""); +} + +TEST(PrintToStringTest, EscapesForPointerToNonConstChar) { + char s[] = "hello\1"; + char* p = s; + EXPECT_PRINT_TO_STRING_(p, "\"hello\\x1\""); +} + +TEST(PrintToStringTest, WorksForArray) { + int n[3] = { 1, 2, 3 }; + EXPECT_PRINT_TO_STRING_(n, "{ 1, 2, 3 }"); +} + +TEST(PrintToStringTest, WorksForCharArray) { + char s[] = "hello"; + EXPECT_PRINT_TO_STRING_(s, "\"hello\""); +} + +TEST(PrintToStringTest, WorksForCharArrayWithEmbeddedNul) { + const char str_with_nul[] = "hello\0 world"; + EXPECT_PRINT_TO_STRING_(str_with_nul, "\"hello\\0 world\""); + + char mutable_str_with_nul[] = "hello\0 world"; + EXPECT_PRINT_TO_STRING_(mutable_str_with_nul, "\"hello\\0 world\""); +} + + TEST(PrintToStringTest, ContainsNonLatin) { + // Sanity test with valid UTF-8. Prints both in hex and as text. + std::string non_ascii_str = ::std::string("오전 4:30"); + EXPECT_PRINT_TO_STRING_(non_ascii_str, + "\"\\xEC\\x98\\xA4\\xEC\\xA0\\x84 4:30\"\n" + " As Text: \"오전 4:30\""); + non_ascii_str = ::std::string("From ä — ẑ"); + EXPECT_PRINT_TO_STRING_(non_ascii_str, + "\"From \\xC3\\xA4 \\xE2\\x80\\x94 \\xE1\\xBA\\x91\"" + "\n As Text: \"From ä — ẑ\""); +} + +TEST(IsValidUTF8Test, IllFormedUTF8) { + // The following test strings are ill-formed UTF-8 and are printed + // as hex only (or ASCII, in case of ASCII bytes) because IsValidUTF8() is + // expected to fail, thus output does not contain "As Text:". + + static const char *const kTestdata[][2] = { + // 2-byte lead byte followed by a single-byte character. + {"\xC3\x74", "\"\\xC3t\""}, + // Valid 2-byte character followed by an orphan trail byte. + {"\xC3\x84\xA4", "\"\\xC3\\x84\\xA4\""}, + // Lead byte without trail byte. + {"abc\xC3", "\"abc\\xC3\""}, + // 3-byte lead byte, single-byte character, orphan trail byte. + {"x\xE2\x70\x94", "\"x\\xE2p\\x94\""}, + // Truncated 3-byte character. + {"\xE2\x80", "\"\\xE2\\x80\""}, + // Truncated 3-byte character followed by valid 2-byte char. + {"\xE2\x80\xC3\x84", "\"\\xE2\\x80\\xC3\\x84\""}, + // Truncated 3-byte character followed by a single-byte character. + {"\xE2\x80\x7A", "\"\\xE2\\x80z\""}, + // 3-byte lead byte followed by valid 3-byte character. + {"\xE2\xE2\x80\x94", "\"\\xE2\\xE2\\x80\\x94\""}, + // 4-byte lead byte followed by valid 3-byte character. + {"\xF0\xE2\x80\x94", "\"\\xF0\\xE2\\x80\\x94\""}, + // Truncated 4-byte character. + {"\xF0\xE2\x80", "\"\\xF0\\xE2\\x80\""}, + // Invalid UTF-8 byte sequences embedded in other chars. + {"abc\xE2\x80\x94\xC3\x74xyc", "\"abc\\xE2\\x80\\x94\\xC3txyc\""}, + {"abc\xC3\x84\xE2\x80\xC3\x84xyz", + "\"abc\\xC3\\x84\\xE2\\x80\\xC3\\x84xyz\""}, + // Non-shortest UTF-8 byte sequences are also ill-formed. + // The classics: xC0, xC1 lead byte. + {"\xC0\x80", "\"\\xC0\\x80\""}, + {"\xC1\x81", "\"\\xC1\\x81\""}, + // Non-shortest sequences. + {"\xE0\x80\x80", "\"\\xE0\\x80\\x80\""}, + {"\xf0\x80\x80\x80", "\"\\xF0\\x80\\x80\\x80\""}, + // Last valid code point before surrogate range, should be printed as text, + // too. + {"\xED\x9F\xBF", "\"\\xED\\x9F\\xBF\"\n As Text: \"퟿\""}, + // Start of surrogate lead. Surrogates are not printed as text. + {"\xED\xA0\x80", "\"\\xED\\xA0\\x80\""}, + // Last non-private surrogate lead. + {"\xED\xAD\xBF", "\"\\xED\\xAD\\xBF\""}, + // First private-use surrogate lead. + {"\xED\xAE\x80", "\"\\xED\\xAE\\x80\""}, + // Last private-use surrogate lead. + {"\xED\xAF\xBF", "\"\\xED\\xAF\\xBF\""}, + // Mid-point of surrogate trail. + {"\xED\xB3\xBF", "\"\\xED\\xB3\\xBF\""}, + // First valid code point after surrogate range, should be printed as text, + // too. + {"\xEE\x80\x80", "\"\\xEE\\x80\\x80\"\n As Text: \"\""} + }; + + for (int i = 0; i < int(sizeof(kTestdata)/sizeof(kTestdata[0])); ++i) { + EXPECT_PRINT_TO_STRING_(kTestdata[i][0], kTestdata[i][1]); + } +} + +#undef EXPECT_PRINT_TO_STRING_ + +TEST(UniversalTersePrintTest, WorksForNonReference) { + ::std::stringstream ss; + UniversalTersePrint(123, &ss); + EXPECT_EQ("123", ss.str()); +} + +TEST(UniversalTersePrintTest, WorksForReference) { + const int& n = 123; + ::std::stringstream ss; + UniversalTersePrint(n, &ss); + EXPECT_EQ("123", ss.str()); +} + +TEST(UniversalTersePrintTest, WorksForCString) { + const char* s1 = "abc"; + ::std::stringstream ss1; + UniversalTersePrint(s1, &ss1); + EXPECT_EQ("\"abc\"", ss1.str()); + + char* s2 = const_cast(s1); + ::std::stringstream ss2; + UniversalTersePrint(s2, &ss2); + EXPECT_EQ("\"abc\"", ss2.str()); + + const char* s3 = nullptr; + ::std::stringstream ss3; + UniversalTersePrint(s3, &ss3); + EXPECT_EQ("NULL", ss3.str()); +} + +TEST(UniversalPrintTest, WorksForNonReference) { + ::std::stringstream ss; + UniversalPrint(123, &ss); + EXPECT_EQ("123", ss.str()); +} + +TEST(UniversalPrintTest, WorksForReference) { + const int& n = 123; + ::std::stringstream ss; + UniversalPrint(n, &ss); + EXPECT_EQ("123", ss.str()); +} + +TEST(UniversalPrintTest, WorksForPairWithConst) { + std::pair, int> p(Wrapper("abc"), 1); + ::std::stringstream ss; + UniversalPrint(p, &ss); + EXPECT_EQ("(Wrapper(\"abc\"), 1)", ss.str()); +} + +TEST(UniversalPrintTest, WorksForCString) { + const char* s1 = "abc"; + ::std::stringstream ss1; + UniversalPrint(s1, &ss1); + EXPECT_EQ(PrintPointer(s1) + " pointing to \"abc\"", std::string(ss1.str())); + + char* s2 = const_cast(s1); + ::std::stringstream ss2; + UniversalPrint(s2, &ss2); + EXPECT_EQ(PrintPointer(s2) + " pointing to \"abc\"", std::string(ss2.str())); + + const char* s3 = nullptr; + ::std::stringstream ss3; + UniversalPrint(s3, &ss3); + EXPECT_EQ("NULL", ss3.str()); +} + +TEST(UniversalPrintTest, WorksForCharArray) { + const char str[] = "\"Line\0 1\"\nLine 2"; + ::std::stringstream ss1; + UniversalPrint(str, &ss1); + EXPECT_EQ("\"\\\"Line\\0 1\\\"\\nLine 2\"", ss1.str()); + + const char mutable_str[] = "\"Line\0 1\"\nLine 2"; + ::std::stringstream ss2; + UniversalPrint(mutable_str, &ss2); + EXPECT_EQ("\"\\\"Line\\0 1\\\"\\nLine 2\"", ss2.str()); +} + +TEST(UniversalPrintTest, IncompleteType) { + struct Incomplete; + char some_object = 0; + EXPECT_EQ("(incomplete type)", + PrintToString(reinterpret_cast(some_object))); +} + +TEST(UniversalPrintTest, SmartPointers) { + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + std::unique_ptr p(new int(17)); + EXPECT_EQ("(ptr = " + PrintPointer(p.get()) + ", value = 17)", + PrintToString(p)); + std::unique_ptr p2(new int[2]); + EXPECT_EQ("(" + PrintPointer(p2.get()) + ")", PrintToString(p2)); + + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + std::shared_ptr p3(new int(1979)); + EXPECT_EQ("(ptr = " + PrintPointer(p3.get()) + ", value = 1979)", + PrintToString(p3)); +#if __cpp_lib_shared_ptr_arrays >= 201611L + std::shared_ptr p4(new int[2]); + EXPECT_EQ("(" + PrintPointer(p4.get()) + ")", PrintToString(p4)); +#endif + + // modifiers + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", + PrintToString(std::unique_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); +#if __cpp_lib_shared_ptr_arrays >= 201611L + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(nullptr)", + PrintToString(std::shared_ptr())); +#endif + + // void + EXPECT_EQ("(nullptr)", PrintToString(std::unique_ptr( + nullptr, nullptr))); + EXPECT_EQ("(" + PrintPointer(p.get()) + ")", + PrintToString( + std::unique_ptr(p.get(), [](void*) {}))); + EXPECT_EQ("(nullptr)", PrintToString(std::shared_ptr())); + EXPECT_EQ("(" + PrintPointer(p.get()) + ")", + PrintToString(std::shared_ptr(p.get(), [](void*) {}))); +} + +TEST(UniversalTersePrintTupleFieldsToStringsTestWithStd, PrintsEmptyTuple) { + Strings result = UniversalTersePrintTupleFieldsToStrings(::std::make_tuple()); + EXPECT_EQ(0u, result.size()); +} + +TEST(UniversalTersePrintTupleFieldsToStringsTestWithStd, PrintsOneTuple) { + Strings result = UniversalTersePrintTupleFieldsToStrings( + ::std::make_tuple(1)); + ASSERT_EQ(1u, result.size()); + EXPECT_EQ("1", result[0]); +} + +TEST(UniversalTersePrintTupleFieldsToStringsTestWithStd, PrintsTwoTuple) { + Strings result = UniversalTersePrintTupleFieldsToStrings( + ::std::make_tuple(1, 'a')); + ASSERT_EQ(2u, result.size()); + EXPECT_EQ("1", result[0]); + EXPECT_EQ("'a' (97, 0x61)", result[1]); +} + +TEST(UniversalTersePrintTupleFieldsToStringsTestWithStd, PrintsTersely) { + const int n = 1; + Strings result = UniversalTersePrintTupleFieldsToStrings( + ::std::tuple(n, "a")); + ASSERT_EQ(2u, result.size()); + EXPECT_EQ("1", result[0]); + EXPECT_EQ("\"a\"", result[1]); +} + +#if GTEST_INTERNAL_HAS_ANY +class PrintAnyTest : public ::testing::Test { + protected: + template + static std::string ExpectedTypeName() { +#if GTEST_HAS_RTTI + return internal::GetTypeName(); +#else + return ""; +#endif // GTEST_HAS_RTTI + } +}; + +TEST_F(PrintAnyTest, Empty) { + internal::Any any; + EXPECT_EQ("no value", PrintToString(any)); +} + +TEST_F(PrintAnyTest, NonEmpty) { + internal::Any any; + constexpr int val1 = 10; + const std::string val2 = "content"; + + any = val1; + EXPECT_EQ("value of type " + ExpectedTypeName(), PrintToString(any)); + + any = val2; + EXPECT_EQ("value of type " + ExpectedTypeName(), + PrintToString(any)); +} +#endif // GTEST_INTERNAL_HAS_ANY + +#if GTEST_INTERNAL_HAS_OPTIONAL +TEST(PrintOptionalTest, Basic) { + EXPECT_EQ("(nullopt)", PrintToString(internal::Nullopt())); + internal::Optional value; + EXPECT_EQ("(nullopt)", PrintToString(value)); + value = {7}; + EXPECT_EQ("(7)", PrintToString(value)); + EXPECT_EQ("(1.1)", PrintToString(internal::Optional{1.1})); + EXPECT_EQ("(\"A\")", PrintToString(internal::Optional{"A"})); +} +#endif // GTEST_INTERNAL_HAS_OPTIONAL + +#if GTEST_INTERNAL_HAS_VARIANT +struct NonPrintable { + unsigned char contents = 17; +}; + +TEST(PrintOneofTest, Basic) { + using Type = internal::Variant; + EXPECT_EQ("('int(index = 0)' with value 7)", PrintToString(Type(7))); + EXPECT_EQ("('StreamableInGlobal(index = 1)' with value StreamableInGlobal)", + PrintToString(Type(StreamableInGlobal{}))); + EXPECT_EQ( + "('testing::gtest_printers_test::NonPrintable(index = 2)' with value " + "1-byte object <11>)", + PrintToString(Type(NonPrintable{}))); +} +#endif // GTEST_INTERNAL_HAS_VARIANT +namespace { +class string_ref; + +/** + * This is a synthetic pointer to a fixed size string. + */ +class string_ptr { + public: + string_ptr(const char* data, size_t size) : data_(data), size_(size) {} + + string_ptr& operator++() noexcept { + data_ += size_; + return *this; + } + + string_ref operator*() const noexcept; + + private: + const char* data_; + size_t size_; +}; + +/** + * This is a synthetic reference of a fixed size string. + */ +class string_ref { + public: + string_ref(const char* data, size_t size) : data_(data), size_(size) {} + + string_ptr operator&() const noexcept { return {data_, size_}; } // NOLINT + + bool operator==(const char* s) const noexcept { + if (size_ > 0 && data_[size_ - 1] != 0) { + return std::string(data_, size_) == std::string(s); + } else { + return std::string(data_) == std::string(s); + } + } + + private: + const char* data_; + size_t size_; +}; + +string_ref string_ptr::operator*() const noexcept { return {data_, size_}; } + +TEST(string_ref, compare) { + const char* s = "alex\0davidjohn\0"; + string_ptr ptr(s, 5); + EXPECT_EQ(*ptr, "alex"); + EXPECT_TRUE(*ptr == "alex"); + ++ptr; + EXPECT_EQ(*ptr, "david"); + EXPECT_TRUE(*ptr == "david"); + ++ptr; + EXPECT_EQ(*ptr, "john"); +} + +} // namespace + +} // namespace gtest_printers_test +} // namespace testing diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test.py new file mode 100755 index 000000000000..9d1fd0295cd1 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# +# Copyright 2019, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that SetUpTestSuite and TearDownTestSuite errors are noticed.""" + +from googletest.test import gtest_test_utils + +COMMAND = gtest_test_utils.GetTestExecutablePath( + 'googletest-setuptestsuite-test_') + + +class GTestSetUpTestSuiteTest(gtest_test_utils.TestCase): + + def testSetupErrorAndTearDownError(self): + p = gtest_test_utils.Subprocess(COMMAND) + self.assertNotEqual(p.exit_code, 0, msg=p.output) + + self.assertIn( + '[ FAILED ] SetupFailTest: SetUpTestSuite or TearDownTestSuite\n' + '[ FAILED ] TearDownFailTest: SetUpTestSuite or TearDownTestSuite\n' + '\n' + ' 2 FAILED TEST SUITES\n', + p.output) + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test_.cc new file mode 100644 index 000000000000..a4bc4ef441d9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-setuptestsuite-test_.cc @@ -0,0 +1,49 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest.h" + +class SetupFailTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + ASSERT_EQ("", "SET_UP_FAIL"); + } +}; + +TEST_F(SetupFailTest, NoopPassingTest) {} + +class TearDownFailTest : public ::testing::Test { + protected: + static void TearDownTestSuite() { + ASSERT_EQ("", "TEAR_DOWN_FAIL"); + } +}; + +TEST_F(TearDownFailTest, NoopPassingTest) {} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test.py new file mode 100755 index 000000000000..9d2adc1286b2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python +# +# Copyright 2009 Google Inc. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that test shuffling works.""" + +import os +from googletest.test import gtest_test_utils + +# Command to run the googletest-shuffle-test_ program. +COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-shuffle-test_') + +# The environment variables for test sharding. +TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' +SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' + +TEST_FILTER = 'A*.A:A*.B:C*' + +ALL_TESTS = [] +ACTIVE_TESTS = [] +FILTERED_TESTS = [] +SHARDED_TESTS = [] + +SHUFFLED_ALL_TESTS = [] +SHUFFLED_ACTIVE_TESTS = [] +SHUFFLED_FILTERED_TESTS = [] +SHUFFLED_SHARDED_TESTS = [] + + +def AlsoRunDisabledTestsFlag(): + return '--gtest_also_run_disabled_tests' + + +def FilterFlag(test_filter): + return '--gtest_filter=%s' % (test_filter,) + + +def RepeatFlag(n): + return '--gtest_repeat=%s' % (n,) + + +def ShuffleFlag(): + return '--gtest_shuffle' + + +def RandomSeedFlag(n): + return '--gtest_random_seed=%s' % (n,) + + +def RunAndReturnOutput(extra_env, args): + """Runs the test program and returns its output.""" + + environ_copy = os.environ.copy() + environ_copy.update(extra_env) + + return gtest_test_utils.Subprocess([COMMAND] + args, env=environ_copy).output + + +def GetTestsForAllIterations(extra_env, args): + """Runs the test program and returns a list of test lists. + + Args: + extra_env: a map from environment variables to their values + args: command line flags to pass to googletest-shuffle-test_ + + Returns: + A list where the i-th element is the list of tests run in the i-th + test iteration. + """ + + test_iterations = [] + for line in RunAndReturnOutput(extra_env, args).split('\n'): + if line.startswith('----'): + tests = [] + test_iterations.append(tests) + elif line.strip(): + tests.append(line.strip()) # 'TestCaseName.TestName' + + return test_iterations + + +def GetTestCases(tests): + """Returns a list of test cases in the given full test names. + + Args: + tests: a list of full test names + + Returns: + A list of test cases from 'tests', in their original order. + Consecutive duplicates are removed. + """ + + test_cases = [] + for test in tests: + test_case = test.split('.')[0] + if not test_case in test_cases: + test_cases.append(test_case) + + return test_cases + + +def CalculateTestLists(): + """Calculates the list of tests run under different flags.""" + + if not ALL_TESTS: + ALL_TESTS.extend( + GetTestsForAllIterations({}, [AlsoRunDisabledTestsFlag()])[0]) + + if not ACTIVE_TESTS: + ACTIVE_TESTS.extend(GetTestsForAllIterations({}, [])[0]) + + if not FILTERED_TESTS: + FILTERED_TESTS.extend( + GetTestsForAllIterations({}, [FilterFlag(TEST_FILTER)])[0]) + + if not SHARDED_TESTS: + SHARDED_TESTS.extend( + GetTestsForAllIterations({TOTAL_SHARDS_ENV_VAR: '3', + SHARD_INDEX_ENV_VAR: '1'}, + [])[0]) + + if not SHUFFLED_ALL_TESTS: + SHUFFLED_ALL_TESTS.extend(GetTestsForAllIterations( + {}, [AlsoRunDisabledTestsFlag(), ShuffleFlag(), RandomSeedFlag(1)])[0]) + + if not SHUFFLED_ACTIVE_TESTS: + SHUFFLED_ACTIVE_TESTS.extend(GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1)])[0]) + + if not SHUFFLED_FILTERED_TESTS: + SHUFFLED_FILTERED_TESTS.extend(GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), FilterFlag(TEST_FILTER)])[0]) + + if not SHUFFLED_SHARDED_TESTS: + SHUFFLED_SHARDED_TESTS.extend( + GetTestsForAllIterations({TOTAL_SHARDS_ENV_VAR: '3', + SHARD_INDEX_ENV_VAR: '1'}, + [ShuffleFlag(), RandomSeedFlag(1)])[0]) + + +class GTestShuffleUnitTest(gtest_test_utils.TestCase): + """Tests test shuffling.""" + + def setUp(self): + CalculateTestLists() + + def testShufflePreservesNumberOfTests(self): + self.assertEqual(len(ALL_TESTS), len(SHUFFLED_ALL_TESTS)) + self.assertEqual(len(ACTIVE_TESTS), len(SHUFFLED_ACTIVE_TESTS)) + self.assertEqual(len(FILTERED_TESTS), len(SHUFFLED_FILTERED_TESTS)) + self.assertEqual(len(SHARDED_TESTS), len(SHUFFLED_SHARDED_TESTS)) + + def testShuffleChangesTestOrder(self): + self.assert_(SHUFFLED_ALL_TESTS != ALL_TESTS, SHUFFLED_ALL_TESTS) + self.assert_(SHUFFLED_ACTIVE_TESTS != ACTIVE_TESTS, SHUFFLED_ACTIVE_TESTS) + self.assert_(SHUFFLED_FILTERED_TESTS != FILTERED_TESTS, + SHUFFLED_FILTERED_TESTS) + self.assert_(SHUFFLED_SHARDED_TESTS != SHARDED_TESTS, + SHUFFLED_SHARDED_TESTS) + + def testShuffleChangesTestCaseOrder(self): + self.assert_(GetTestCases(SHUFFLED_ALL_TESTS) != GetTestCases(ALL_TESTS), + GetTestCases(SHUFFLED_ALL_TESTS)) + self.assert_( + GetTestCases(SHUFFLED_ACTIVE_TESTS) != GetTestCases(ACTIVE_TESTS), + GetTestCases(SHUFFLED_ACTIVE_TESTS)) + self.assert_( + GetTestCases(SHUFFLED_FILTERED_TESTS) != GetTestCases(FILTERED_TESTS), + GetTestCases(SHUFFLED_FILTERED_TESTS)) + self.assert_( + GetTestCases(SHUFFLED_SHARDED_TESTS) != GetTestCases(SHARDED_TESTS), + GetTestCases(SHUFFLED_SHARDED_TESTS)) + + def testShuffleDoesNotRepeatTest(self): + for test in SHUFFLED_ALL_TESTS: + self.assertEqual(1, SHUFFLED_ALL_TESTS.count(test), + '%s appears more than once' % (test,)) + for test in SHUFFLED_ACTIVE_TESTS: + self.assertEqual(1, SHUFFLED_ACTIVE_TESTS.count(test), + '%s appears more than once' % (test,)) + for test in SHUFFLED_FILTERED_TESTS: + self.assertEqual(1, SHUFFLED_FILTERED_TESTS.count(test), + '%s appears more than once' % (test,)) + for test in SHUFFLED_SHARDED_TESTS: + self.assertEqual(1, SHUFFLED_SHARDED_TESTS.count(test), + '%s appears more than once' % (test,)) + + def testShuffleDoesNotCreateNewTest(self): + for test in SHUFFLED_ALL_TESTS: + self.assert_(test in ALL_TESTS, '%s is an invalid test' % (test,)) + for test in SHUFFLED_ACTIVE_TESTS: + self.assert_(test in ACTIVE_TESTS, '%s is an invalid test' % (test,)) + for test in SHUFFLED_FILTERED_TESTS: + self.assert_(test in FILTERED_TESTS, '%s is an invalid test' % (test,)) + for test in SHUFFLED_SHARDED_TESTS: + self.assert_(test in SHARDED_TESTS, '%s is an invalid test' % (test,)) + + def testShuffleIncludesAllTests(self): + for test in ALL_TESTS: + self.assert_(test in SHUFFLED_ALL_TESTS, '%s is missing' % (test,)) + for test in ACTIVE_TESTS: + self.assert_(test in SHUFFLED_ACTIVE_TESTS, '%s is missing' % (test,)) + for test in FILTERED_TESTS: + self.assert_(test in SHUFFLED_FILTERED_TESTS, '%s is missing' % (test,)) + for test in SHARDED_TESTS: + self.assert_(test in SHUFFLED_SHARDED_TESTS, '%s is missing' % (test,)) + + def testShuffleLeavesDeathTestsAtFront(self): + non_death_test_found = False + for test in SHUFFLED_ACTIVE_TESTS: + if 'DeathTest.' in test: + self.assert_(not non_death_test_found, + '%s appears after a non-death test' % (test,)) + else: + non_death_test_found = True + + def _VerifyTestCasesDoNotInterleave(self, tests): + test_cases = [] + for test in tests: + [test_case, _] = test.split('.') + if test_cases and test_cases[-1] != test_case: + test_cases.append(test_case) + self.assertEqual(1, test_cases.count(test_case), + 'Test case %s is not grouped together in %s' % + (test_case, tests)) + + def testShuffleDoesNotInterleaveTestCases(self): + self._VerifyTestCasesDoNotInterleave(SHUFFLED_ALL_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_ACTIVE_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_FILTERED_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_SHARDED_TESTS) + + def testShuffleRestoresOrderAfterEachIteration(self): + # Get the test lists in all 3 iterations, using random seed 1, 2, + # and 3 respectively. Google Test picks a different seed in each + # iteration, and this test depends on the current implementation + # picking successive numbers. This dependency is not ideal, but + # makes the test much easier to write. + [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( + GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)])) + + # Make sure running the tests with random seed 1 gets the same + # order as in iteration 1 above. + [tests_with_seed1] = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1)]) + self.assertEqual(tests_in_iteration1, tests_with_seed1) + + # Make sure running the tests with random seed 2 gets the same + # order as in iteration 2 above. Success means that Google Test + # correctly restores the test order before re-shuffling at the + # beginning of iteration 2. + [tests_with_seed2] = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(2)]) + self.assertEqual(tests_in_iteration2, tests_with_seed2) + + # Make sure running the tests with random seed 3 gets the same + # order as in iteration 3 above. Success means that Google Test + # correctly restores the test order before re-shuffling at the + # beginning of iteration 3. + [tests_with_seed3] = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(3)]) + self.assertEqual(tests_in_iteration3, tests_with_seed3) + + def testShuffleGeneratesNewOrderInEachIteration(self): + [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( + GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)])) + + self.assert_(tests_in_iteration1 != tests_in_iteration2, + tests_in_iteration1) + self.assert_(tests_in_iteration1 != tests_in_iteration3, + tests_in_iteration1) + self.assert_(tests_in_iteration2 != tests_in_iteration3, + tests_in_iteration2) + + def testShuffleShardedTestsPreservesPartition(self): + # If we run M tests on N shards, the same M tests should be run in + # total, regardless of the random seeds used by the shards. + [tests1] = GetTestsForAllIterations({TOTAL_SHARDS_ENV_VAR: '3', + SHARD_INDEX_ENV_VAR: '0'}, + [ShuffleFlag(), RandomSeedFlag(1)]) + [tests2] = GetTestsForAllIterations({TOTAL_SHARDS_ENV_VAR: '3', + SHARD_INDEX_ENV_VAR: '1'}, + [ShuffleFlag(), RandomSeedFlag(20)]) + [tests3] = GetTestsForAllIterations({TOTAL_SHARDS_ENV_VAR: '3', + SHARD_INDEX_ENV_VAR: '2'}, + [ShuffleFlag(), RandomSeedFlag(25)]) + sorted_sharded_tests = tests1 + tests2 + tests3 + sorted_sharded_tests.sort() + sorted_active_tests = [] + sorted_active_tests.extend(ACTIVE_TESTS) + sorted_active_tests.sort() + self.assertEqual(sorted_active_tests, sorted_sharded_tests) + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test_.cc new file mode 100644 index 000000000000..4505663ae433 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-shuffle-test_.cc @@ -0,0 +1,101 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Verifies that test shuffling works. + +#include "gtest/gtest.h" + +namespace { + +using ::testing::EmptyTestEventListener; +using ::testing::InitGoogleTest; +using ::testing::Message; +using ::testing::Test; +using ::testing::TestEventListeners; +using ::testing::TestInfo; +using ::testing::UnitTest; + +// The test methods are empty, as the sole purpose of this program is +// to print the test names before/after shuffling. + +class A : public Test {}; +TEST_F(A, A) {} +TEST_F(A, B) {} + +TEST(ADeathTest, A) {} +TEST(ADeathTest, B) {} +TEST(ADeathTest, C) {} + +TEST(B, A) {} +TEST(B, B) {} +TEST(B, C) {} +TEST(B, DISABLED_D) {} +TEST(B, DISABLED_E) {} + +TEST(BDeathTest, A) {} +TEST(BDeathTest, B) {} + +TEST(C, A) {} +TEST(C, B) {} +TEST(C, C) {} +TEST(C, DISABLED_D) {} + +TEST(CDeathTest, A) {} + +TEST(DISABLED_D, A) {} +TEST(DISABLED_D, DISABLED_B) {} + +// This printer prints the full test names only, starting each test +// iteration with a "----" marker. +class TestNamePrinter : public EmptyTestEventListener { + public: + void OnTestIterationStart(const UnitTest& /* unit_test */, + int /* iteration */) override { + printf("----\n"); + } + + void OnTestStart(const TestInfo& test_info) override { + printf("%s.%s\n", test_info.test_suite_name(), test_info.name()); + } +}; + +} // namespace + +int main(int argc, char **argv) { + InitGoogleTest(&argc, argv); + + // Replaces the default printer with TestNamePrinter, which prints + // the test name only. + TestEventListeners& listeners = UnitTest::GetInstance()->listeners(); + delete listeners.Release(listeners.default_result_printer()); + listeners.Append(new TestNamePrinter); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-test-part-test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-test-part-test.cc new file mode 100644 index 000000000000..44cf7ca044b8 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-test-part-test.cc @@ -0,0 +1,230 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "gtest/gtest-test-part.h" + +#include "gtest/gtest.h" + +using testing::Message; +using testing::Test; +using testing::TestPartResult; +using testing::TestPartResultArray; + +namespace { + +// Tests the TestPartResult class. + +// The test fixture for testing TestPartResult. +class TestPartResultTest : public Test { + protected: + TestPartResultTest() + : r1_(TestPartResult::kSuccess, "foo/bar.cc", 10, "Success!"), + r2_(TestPartResult::kNonFatalFailure, "foo/bar.cc", -1, "Failure!"), + r3_(TestPartResult::kFatalFailure, nullptr, -1, "Failure!"), + r4_(TestPartResult::kSkip, "foo/bar.cc", 2, "Skipped!") {} + + TestPartResult r1_, r2_, r3_, r4_; +}; + + +TEST_F(TestPartResultTest, ConstructorWorks) { + Message message; + message << "something is terribly wrong"; + message << static_cast(testing::internal::kStackTraceMarker); + message << "some unimportant stack trace"; + + const TestPartResult result(TestPartResult::kNonFatalFailure, + "some_file.cc", + 42, + message.GetString().c_str()); + + EXPECT_EQ(TestPartResult::kNonFatalFailure, result.type()); + EXPECT_STREQ("some_file.cc", result.file_name()); + EXPECT_EQ(42, result.line_number()); + EXPECT_STREQ(message.GetString().c_str(), result.message()); + EXPECT_STREQ("something is terribly wrong", result.summary()); +} + +TEST_F(TestPartResultTest, ResultAccessorsWork) { + const TestPartResult success(TestPartResult::kSuccess, + "file.cc", + 42, + "message"); + EXPECT_TRUE(success.passed()); + EXPECT_FALSE(success.failed()); + EXPECT_FALSE(success.nonfatally_failed()); + EXPECT_FALSE(success.fatally_failed()); + EXPECT_FALSE(success.skipped()); + + const TestPartResult nonfatal_failure(TestPartResult::kNonFatalFailure, + "file.cc", + 42, + "message"); + EXPECT_FALSE(nonfatal_failure.passed()); + EXPECT_TRUE(nonfatal_failure.failed()); + EXPECT_TRUE(nonfatal_failure.nonfatally_failed()); + EXPECT_FALSE(nonfatal_failure.fatally_failed()); + EXPECT_FALSE(nonfatal_failure.skipped()); + + const TestPartResult fatal_failure(TestPartResult::kFatalFailure, + "file.cc", + 42, + "message"); + EXPECT_FALSE(fatal_failure.passed()); + EXPECT_TRUE(fatal_failure.failed()); + EXPECT_FALSE(fatal_failure.nonfatally_failed()); + EXPECT_TRUE(fatal_failure.fatally_failed()); + EXPECT_FALSE(fatal_failure.skipped()); + + const TestPartResult skip(TestPartResult::kSkip, "file.cc", 42, "message"); + EXPECT_FALSE(skip.passed()); + EXPECT_FALSE(skip.failed()); + EXPECT_FALSE(skip.nonfatally_failed()); + EXPECT_FALSE(skip.fatally_failed()); + EXPECT_TRUE(skip.skipped()); +} + +// Tests TestPartResult::type(). +TEST_F(TestPartResultTest, type) { + EXPECT_EQ(TestPartResult::kSuccess, r1_.type()); + EXPECT_EQ(TestPartResult::kNonFatalFailure, r2_.type()); + EXPECT_EQ(TestPartResult::kFatalFailure, r3_.type()); + EXPECT_EQ(TestPartResult::kSkip, r4_.type()); +} + +// Tests TestPartResult::file_name(). +TEST_F(TestPartResultTest, file_name) { + EXPECT_STREQ("foo/bar.cc", r1_.file_name()); + EXPECT_STREQ(nullptr, r3_.file_name()); + EXPECT_STREQ("foo/bar.cc", r4_.file_name()); +} + +// Tests TestPartResult::line_number(). +TEST_F(TestPartResultTest, line_number) { + EXPECT_EQ(10, r1_.line_number()); + EXPECT_EQ(-1, r2_.line_number()); + EXPECT_EQ(2, r4_.line_number()); +} + +// Tests TestPartResult::message(). +TEST_F(TestPartResultTest, message) { + EXPECT_STREQ("Success!", r1_.message()); + EXPECT_STREQ("Skipped!", r4_.message()); +} + +// Tests TestPartResult::passed(). +TEST_F(TestPartResultTest, Passed) { + EXPECT_TRUE(r1_.passed()); + EXPECT_FALSE(r2_.passed()); + EXPECT_FALSE(r3_.passed()); + EXPECT_FALSE(r4_.passed()); +} + +// Tests TestPartResult::failed(). +TEST_F(TestPartResultTest, Failed) { + EXPECT_FALSE(r1_.failed()); + EXPECT_TRUE(r2_.failed()); + EXPECT_TRUE(r3_.failed()); + EXPECT_FALSE(r4_.failed()); +} + +// Tests TestPartResult::failed(). +TEST_F(TestPartResultTest, Skipped) { + EXPECT_FALSE(r1_.skipped()); + EXPECT_FALSE(r2_.skipped()); + EXPECT_FALSE(r3_.skipped()); + EXPECT_TRUE(r4_.skipped()); +} + +// Tests TestPartResult::fatally_failed(). +TEST_F(TestPartResultTest, FatallyFailed) { + EXPECT_FALSE(r1_.fatally_failed()); + EXPECT_FALSE(r2_.fatally_failed()); + EXPECT_TRUE(r3_.fatally_failed()); + EXPECT_FALSE(r4_.fatally_failed()); +} + +// Tests TestPartResult::nonfatally_failed(). +TEST_F(TestPartResultTest, NonfatallyFailed) { + EXPECT_FALSE(r1_.nonfatally_failed()); + EXPECT_TRUE(r2_.nonfatally_failed()); + EXPECT_FALSE(r3_.nonfatally_failed()); + EXPECT_FALSE(r4_.nonfatally_failed()); +} + +// Tests the TestPartResultArray class. + +class TestPartResultArrayTest : public Test { + protected: + TestPartResultArrayTest() + : r1_(TestPartResult::kNonFatalFailure, "foo/bar.cc", -1, "Failure 1"), + r2_(TestPartResult::kFatalFailure, "foo/bar.cc", -1, "Failure 2") {} + + const TestPartResult r1_, r2_; +}; + +// Tests that TestPartResultArray initially has size 0. +TEST_F(TestPartResultArrayTest, InitialSizeIsZero) { + TestPartResultArray results; + EXPECT_EQ(0, results.size()); +} + +// Tests that TestPartResultArray contains the given TestPartResult +// after one Append() operation. +TEST_F(TestPartResultArrayTest, ContainsGivenResultAfterAppend) { + TestPartResultArray results; + results.Append(r1_); + EXPECT_EQ(1, results.size()); + EXPECT_STREQ("Failure 1", results.GetTestPartResult(0).message()); +} + +// Tests that TestPartResultArray contains the given TestPartResults +// after two Append() operations. +TEST_F(TestPartResultArrayTest, ContainsGivenResultsAfterTwoAppends) { + TestPartResultArray results; + results.Append(r1_); + results.Append(r2_); + EXPECT_EQ(2, results.size()); + EXPECT_STREQ("Failure 1", results.GetTestPartResult(0).message()); + EXPECT_STREQ("Failure 2", results.GetTestPartResult(1).message()); +} + +typedef TestPartResultArrayTest TestPartResultArrayDeathTest; + +// Tests that the program dies when GetTestPartResult() is called with +// an invalid index. +TEST_F(TestPartResultArrayDeathTest, DiesWhenIndexIsOutOfBound) { + TestPartResultArray results; + results.Append(r1_); + + EXPECT_DEATH_IF_SUPPORTED(results.GetTestPartResult(-1), ""); + EXPECT_DEATH_IF_SUPPORTED(results.GetTestPartResult(1), ""); +} + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test.py new file mode 100755 index 000000000000..772bbc5f39b9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# +# Copyright 2009, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests Google Test's throw-on-failure mode with exceptions disabled. + +This script invokes googletest-throw-on-failure-test_ (a program written with +Google Test) with different environments and command line flags. +""" + +import os +from googletest.test import gtest_test_utils + + +# Constants. + +# The command line flag for enabling/disabling the throw-on-failure mode. +THROW_ON_FAILURE = 'gtest_throw_on_failure' + +# Path to the googletest-throw-on-failure-test_ program, compiled with +# exceptions disabled. +EXE_PATH = gtest_test_utils.GetTestExecutablePath( + 'googletest-throw-on-failure-test_') + + +# Utilities. + + +def SetEnvVar(env_var, value): + """Sets an environment variable to a given value; unsets it when the + given value is None. + """ + + env_var = env_var.upper() + if value is not None: + os.environ[env_var] = value + elif env_var in os.environ: + del os.environ[env_var] + + +def Run(command): + """Runs a command; returns True/False if its exit code is/isn't 0.""" + + print('Running "%s". . .' % ' '.join(command)) + p = gtest_test_utils.Subprocess(command) + return p.exited and p.exit_code == 0 + + +# The tests. +class ThrowOnFailureTest(gtest_test_utils.TestCase): + """Tests the throw-on-failure mode.""" + + def RunAndVerify(self, env_var_value, flag_value, should_fail): + """Runs googletest-throw-on-failure-test_ and verifies that it does + (or does not) exit with a non-zero code. + + Args: + env_var_value: value of the GTEST_BREAK_ON_FAILURE environment + variable; None if the variable should be unset. + flag_value: value of the --gtest_break_on_failure flag; + None if the flag should not be present. + should_fail: True if and only if the program is expected to fail. + """ + + SetEnvVar(THROW_ON_FAILURE, env_var_value) + + if env_var_value is None: + env_var_value_msg = ' is not set' + else: + env_var_value_msg = '=' + env_var_value + + if flag_value is None: + flag = '' + elif flag_value == '0': + flag = '--%s=0' % THROW_ON_FAILURE + else: + flag = '--%s' % THROW_ON_FAILURE + + command = [EXE_PATH] + if flag: + command.append(flag) + + if should_fail: + should_or_not = 'should' + else: + should_or_not = 'should not' + + failed = not Run(command) + + SetEnvVar(THROW_ON_FAILURE, None) + + msg = ('when %s%s, an assertion failure in "%s" %s cause a non-zero ' + 'exit code.' % + (THROW_ON_FAILURE, env_var_value_msg, ' '.join(command), + should_or_not)) + self.assert_(failed == should_fail, msg) + + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" + + self.RunAndVerify(env_var_value=None, flag_value=None, should_fail=False) + + def testThrowOnFailureEnvVar(self): + """Tests using the GTEST_THROW_ON_FAILURE environment variable.""" + + self.RunAndVerify(env_var_value='0', + flag_value=None, + should_fail=False) + self.RunAndVerify(env_var_value='1', + flag_value=None, + should_fail=True) + + def testThrowOnFailureFlag(self): + """Tests using the --gtest_throw_on_failure flag.""" + + self.RunAndVerify(env_var_value=None, + flag_value='0', + should_fail=False) + self.RunAndVerify(env_var_value=None, + flag_value='1', + should_fail=True) + + def testThrowOnFailureFlagOverridesEnvVar(self): + """Tests that --gtest_throw_on_failure overrides GTEST_THROW_ON_FAILURE.""" + + self.RunAndVerify(env_var_value='0', + flag_value='0', + should_fail=False) + self.RunAndVerify(env_var_value='0', + flag_value='1', + should_fail=True) + self.RunAndVerify(env_var_value='1', + flag_value='0', + should_fail=False) + self.RunAndVerify(env_var_value='1', + flag_value='1', + should_fail=True) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test_.cc new file mode 100644 index 000000000000..83bb914c7e47 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-throw-on-failure-test_.cc @@ -0,0 +1,71 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Tests Google Test's throw-on-failure mode with exceptions disabled. +// +// This program must be compiled with exceptions disabled. It will be +// invoked by googletest-throw-on-failure-test.py, and is expected to exit +// with non-zero in the throw-on-failure mode or 0 otherwise. + +#include "gtest/gtest.h" + +#include // for fflush, fprintf, NULL, etc. +#include // for exit +#include // for set_terminate + +// This terminate handler aborts the program using exit() rather than abort(). +// This avoids showing pop-ups on Windows systems and core dumps on Unix-like +// ones. +void TerminateHandler() { + fprintf(stderr, "%s\n", "Unhandled C++ exception terminating the program."); + fflush(nullptr); + exit(1); +} + +int main(int argc, char** argv) { +#if GTEST_HAS_EXCEPTIONS + std::set_terminate(&TerminateHandler); +#endif + testing::InitGoogleTest(&argc, argv); + + // We want to ensure that people can use Google Test assertions in + // other testing frameworks, as long as they initialize Google Test + // properly and set the throw-on-failure mode. Therefore, we don't + // use Google Test's constructs for defining and running tests + // (e.g. TEST and RUN_ALL_TESTS) here. + + // In the throw-on-failure mode with exceptions disabled, this + // assertion will cause the program to exit with a non-zero code. + EXPECT_EQ(2, 3); + + // When not in the throw-on-failure mode, the control will reach + // here. + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test.py new file mode 100755 index 000000000000..73c91764a5b8 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# +# Copyright 2008, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Verifies that Google Test warns the user when not initialized properly.""" + +from googletest.test import gtest_test_utils + +COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-uninitialized-test_') + + +def Assert(condition): + if not condition: + raise AssertionError + + +def AssertEq(expected, actual): + if expected != actual: + print('Expected: %s' % (expected,)) + print(' Actual: %s' % (actual,)) + raise AssertionError + + +def TestExitCodeAndOutput(command): + """Runs the given command and verifies its exit code and output.""" + + # Verifies that 'command' exits with code 1. + p = gtest_test_utils.Subprocess(command) + if p.exited and p.exit_code == 0: + Assert('IMPORTANT NOTICE' in p.output); + Assert('InitGoogleTest' in p.output) + + +class GTestUninitializedTest(gtest_test_utils.TestCase): + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test_.cc new file mode 100644 index 000000000000..b4434d51eebf --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/googletest-uninitialized-test_.cc @@ -0,0 +1,42 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest.h" + +TEST(DummyTest, Dummy) { + // This test doesn't verify anything. We just need it to create a + // realistic stage for testing the behavior of Google Test when + // RUN_ALL_TESTS() is called without + // testing::InitGoogleTest() being called first. +} + +int main() { + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test2_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test2_test.cc new file mode 100644 index 000000000000..e83ca2e11b1b --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test2_test.cc @@ -0,0 +1,40 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include + +#include "test/gtest-typed-test_test.h" +#include "gtest/gtest.h" + +// Tests that the same type-parameterized test case can be +// instantiated in different translation units linked together. +// (ContainerTest is also instantiated in gtest-typed-test_test.cc.) +INSTANTIATE_TYPED_TEST_SUITE_P(Vector, ContainerTest, + testing::Types >); diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.cc new file mode 100644 index 000000000000..5fc678cb0d5d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.cc @@ -0,0 +1,437 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "test/gtest-typed-test_test.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +#if _MSC_VER +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4127 /* conditional expression is constant */) +#endif // _MSC_VER + +using testing::Test; + +// Used for testing that SetUpTestSuite()/TearDownTestSuite(), fixture +// ctor/dtor, and SetUp()/TearDown() work correctly in typed tests and +// type-parameterized test. +template +class CommonTest : public Test { + // For some technical reason, SetUpTestSuite() and TearDownTestSuite() + // must be public. + public: + static void SetUpTestSuite() { + shared_ = new T(5); + } + + static void TearDownTestSuite() { + delete shared_; + shared_ = nullptr; + } + + // This 'protected:' is optional. There's no harm in making all + // members of this fixture class template public. + protected: + // We used to use std::list here, but switched to std::vector since + // MSVC's doesn't compile cleanly with /W4. + typedef std::vector Vector; + typedef std::set IntSet; + + CommonTest() : value_(1) {} + + ~CommonTest() override { EXPECT_EQ(3, value_); } + + void SetUp() override { + EXPECT_EQ(1, value_); + value_++; + } + + void TearDown() override { + EXPECT_EQ(2, value_); + value_++; + } + + T value_; + static T* shared_; +}; + +template +T* CommonTest::shared_ = nullptr; + +using testing::Types; + +// Tests that SetUpTestSuite()/TearDownTestSuite(), fixture ctor/dtor, +// and SetUp()/TearDown() work correctly in typed tests + +typedef Types TwoTypes; +TYPED_TEST_SUITE(CommonTest, TwoTypes); + +TYPED_TEST(CommonTest, ValuesAreCorrect) { + // Static members of the fixture class template can be visited via + // the TestFixture:: prefix. + EXPECT_EQ(5, *TestFixture::shared_); + + // Typedefs in the fixture class template can be visited via the + // "typename TestFixture::" prefix. + typename TestFixture::Vector empty; + EXPECT_EQ(0U, empty.size()); + + typename TestFixture::IntSet empty2; + EXPECT_EQ(0U, empty2.size()); + + // Non-static members of the fixture class must be visited via + // 'this', as required by C++ for class templates. + EXPECT_EQ(2, this->value_); +} + +// The second test makes sure shared_ is not deleted after the first +// test. +TYPED_TEST(CommonTest, ValuesAreStillCorrect) { + // Static members of the fixture class template can also be visited + // via 'this'. + ASSERT_TRUE(this->shared_ != nullptr); + EXPECT_EQ(5, *this->shared_); + + // TypeParam can be used to refer to the type parameter. + EXPECT_EQ(static_cast(2), this->value_); +} + +// Tests that multiple TYPED_TEST_SUITE's can be defined in the same +// translation unit. + +template +class TypedTest1 : public Test { +}; + +// Verifies that the second argument of TYPED_TEST_SUITE can be a +// single type. +TYPED_TEST_SUITE(TypedTest1, int); +TYPED_TEST(TypedTest1, A) {} + +template +class TypedTest2 : public Test { +}; + +// Verifies that the second argument of TYPED_TEST_SUITE can be a +// Types<...> type list. +TYPED_TEST_SUITE(TypedTest2, Types); + +// This also verifies that tests from different typed test cases can +// share the same name. +TYPED_TEST(TypedTest2, A) {} + +// Tests that a typed test case can be defined in a namespace. + +namespace library1 { + +template +class NumericTest : public Test { +}; + +typedef Types NumericTypes; +TYPED_TEST_SUITE(NumericTest, NumericTypes); + +TYPED_TEST(NumericTest, DefaultIsZero) { + EXPECT_EQ(0, TypeParam()); +} + +} // namespace library1 + +// Tests that custom names work. +template +class TypedTestWithNames : public Test {}; + +class TypedTestNames { + public: + template + static std::string GetName(int i) { + if (std::is_same::value) { + return std::string("char") + ::testing::PrintToString(i); + } + if (std::is_same::value) { + return std::string("int") + ::testing::PrintToString(i); + } + } +}; + +TYPED_TEST_SUITE(TypedTestWithNames, TwoTypes, TypedTestNames); + +TYPED_TEST(TypedTestWithNames, TestSuiteName) { + if (std::is_same::value) { + EXPECT_STREQ(::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(), + "TypedTestWithNames/char0"); + } + if (std::is_same::value) { + EXPECT_STREQ(::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(), + "TypedTestWithNames/int1"); + } +} + +using testing::Types; +using testing::internal::TypedTestSuitePState; + +// Tests TypedTestSuitePState. + +class TypedTestSuitePStateTest : public Test { + protected: + void SetUp() override { + state_.AddTestName("foo.cc", 0, "FooTest", "A"); + state_.AddTestName("foo.cc", 0, "FooTest", "B"); + state_.AddTestName("foo.cc", 0, "FooTest", "C"); + } + + TypedTestSuitePState state_; +}; + +TEST_F(TypedTestSuitePStateTest, SucceedsForMatchingList) { + const char* tests = "A, B, C"; + EXPECT_EQ(tests, + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, tests)); +} + +// Makes sure that the order of the tests and spaces around the names +// don't matter. +TEST_F(TypedTestSuitePStateTest, IgnoresOrderAndSpaces) { + const char* tests = "A,C, B"; + EXPECT_EQ(tests, + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, tests)); +} + +using TypedTestSuitePStateDeathTest = TypedTestSuitePStateTest; + +TEST_F(TypedTestSuitePStateDeathTest, DetectsDuplicates) { + EXPECT_DEATH_IF_SUPPORTED( + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, "A, B, A, C"), + "foo\\.cc.1.?: Test A is listed more than once\\."); +} + +TEST_F(TypedTestSuitePStateDeathTest, DetectsExtraTest) { + EXPECT_DEATH_IF_SUPPORTED( + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, "A, B, C, D"), + "foo\\.cc.1.?: No test named D can be found in this test suite\\."); +} + +TEST_F(TypedTestSuitePStateDeathTest, DetectsMissedTest) { + EXPECT_DEATH_IF_SUPPORTED( + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, "A, C"), + "foo\\.cc.1.?: You forgot to list test B\\."); +} + +// Tests that defining a test for a parameterized test case generates +// a run-time error if the test case has been registered. +TEST_F(TypedTestSuitePStateDeathTest, DetectsTestAfterRegistration) { + state_.VerifyRegisteredTestNames("Suite", "foo.cc", 1, "A, B, C"); + EXPECT_DEATH_IF_SUPPORTED( + state_.AddTestName("foo.cc", 2, "FooTest", "D"), + "foo\\.cc.2.?: Test D must be defined before REGISTER_TYPED_TEST_SUITE_P" + "\\(FooTest, \\.\\.\\.\\)\\."); +} + +// Tests that SetUpTestSuite()/TearDownTestSuite(), fixture ctor/dtor, +// and SetUp()/TearDown() work correctly in type-parameterized tests. + +template +class DerivedTest : public CommonTest { +}; + +TYPED_TEST_SUITE_P(DerivedTest); + +TYPED_TEST_P(DerivedTest, ValuesAreCorrect) { + // Static members of the fixture class template can be visited via + // the TestFixture:: prefix. + EXPECT_EQ(5, *TestFixture::shared_); + + // Non-static members of the fixture class must be visited via + // 'this', as required by C++ for class templates. + EXPECT_EQ(2, this->value_); +} + +// The second test makes sure shared_ is not deleted after the first +// test. +TYPED_TEST_P(DerivedTest, ValuesAreStillCorrect) { + // Static members of the fixture class template can also be visited + // via 'this'. + ASSERT_TRUE(this->shared_ != nullptr); + EXPECT_EQ(5, *this->shared_); + EXPECT_EQ(2, this->value_); +} + +REGISTER_TYPED_TEST_SUITE_P(DerivedTest, + ValuesAreCorrect, ValuesAreStillCorrect); + +typedef Types MyTwoTypes; +INSTANTIATE_TYPED_TEST_SUITE_P(My, DerivedTest, MyTwoTypes); + +// Tests that custom names work with type parametrized tests. We reuse the +// TwoTypes from above here. +template +class TypeParametrizedTestWithNames : public Test {}; + +TYPED_TEST_SUITE_P(TypeParametrizedTestWithNames); + +TYPED_TEST_P(TypeParametrizedTestWithNames, TestSuiteName) { + if (std::is_same::value) { + EXPECT_STREQ(::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(), + "CustomName/TypeParametrizedTestWithNames/parChar0"); + } + if (std::is_same::value) { + EXPECT_STREQ(::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(), + "CustomName/TypeParametrizedTestWithNames/parInt1"); + } +} + +REGISTER_TYPED_TEST_SUITE_P(TypeParametrizedTestWithNames, TestSuiteName); + +class TypeParametrizedTestNames { + public: + template + static std::string GetName(int i) { + if (std::is_same::value) { + return std::string("parChar") + ::testing::PrintToString(i); + } + if (std::is_same::value) { + return std::string("parInt") + ::testing::PrintToString(i); + } + } +}; + +INSTANTIATE_TYPED_TEST_SUITE_P(CustomName, TypeParametrizedTestWithNames, + TwoTypes, TypeParametrizedTestNames); + +// Tests that multiple TYPED_TEST_SUITE_P's can be defined in the same +// translation unit. + +template +class TypedTestP1 : public Test { +}; + +TYPED_TEST_SUITE_P(TypedTestP1); + +// For testing that the code between TYPED_TEST_SUITE_P() and +// TYPED_TEST_P() is not enclosed in a namespace. +using IntAfterTypedTestSuiteP = int; + +TYPED_TEST_P(TypedTestP1, A) {} +TYPED_TEST_P(TypedTestP1, B) {} + +// For testing that the code between TYPED_TEST_P() and +// REGISTER_TYPED_TEST_SUITE_P() is not enclosed in a namespace. +using IntBeforeRegisterTypedTestSuiteP = int; + +REGISTER_TYPED_TEST_SUITE_P(TypedTestP1, A, B); + +template +class TypedTestP2 : public Test { +}; + +TYPED_TEST_SUITE_P(TypedTestP2); + +// This also verifies that tests from different type-parameterized +// test cases can share the same name. +TYPED_TEST_P(TypedTestP2, A) {} + +REGISTER_TYPED_TEST_SUITE_P(TypedTestP2, A); + +// Verifies that the code between TYPED_TEST_SUITE_P() and +// REGISTER_TYPED_TEST_SUITE_P() is not enclosed in a namespace. +IntAfterTypedTestSuiteP after = 0; +IntBeforeRegisterTypedTestSuiteP before = 0; + +// Verifies that the last argument of INSTANTIATE_TYPED_TEST_SUITE_P() +// can be either a single type or a Types<...> type list. +INSTANTIATE_TYPED_TEST_SUITE_P(Int, TypedTestP1, int); +INSTANTIATE_TYPED_TEST_SUITE_P(Int, TypedTestP2, Types); + +// Tests that the same type-parameterized test case can be +// instantiated more than once in the same translation unit. +INSTANTIATE_TYPED_TEST_SUITE_P(Double, TypedTestP2, Types); + +// Tests that the same type-parameterized test case can be +// instantiated in different translation units linked together. +// (ContainerTest is also instantiated in gtest-typed-test_test.cc.) +typedef Types, std::set > MyContainers; +INSTANTIATE_TYPED_TEST_SUITE_P(My, ContainerTest, MyContainers); + +// Tests that a type-parameterized test case can be defined and +// instantiated in a namespace. + +namespace library2 { + +template +class NumericTest : public Test { +}; + +TYPED_TEST_SUITE_P(NumericTest); + +TYPED_TEST_P(NumericTest, DefaultIsZero) { + EXPECT_EQ(0, TypeParam()); +} + +TYPED_TEST_P(NumericTest, ZeroIsLessThanOne) { + EXPECT_LT(TypeParam(0), TypeParam(1)); +} + +REGISTER_TYPED_TEST_SUITE_P(NumericTest, + DefaultIsZero, ZeroIsLessThanOne); +typedef Types NumericTypes; +INSTANTIATE_TYPED_TEST_SUITE_P(My, NumericTest, NumericTypes); + +static const char* GetTestName() { + return testing::UnitTest::GetInstance()->current_test_info()->name(); +} +// Test the stripping of space from test names +template class TrimmedTest : public Test { }; +TYPED_TEST_SUITE_P(TrimmedTest); +TYPED_TEST_P(TrimmedTest, Test1) { EXPECT_STREQ("Test1", GetTestName()); } +TYPED_TEST_P(TrimmedTest, Test2) { EXPECT_STREQ("Test2", GetTestName()); } +TYPED_TEST_P(TrimmedTest, Test3) { EXPECT_STREQ("Test3", GetTestName()); } +TYPED_TEST_P(TrimmedTest, Test4) { EXPECT_STREQ("Test4", GetTestName()); } +TYPED_TEST_P(TrimmedTest, Test5) { EXPECT_STREQ("Test5", GetTestName()); } +REGISTER_TYPED_TEST_SUITE_P( + TrimmedTest, + Test1, Test2,Test3 , Test4 ,Test5 ); // NOLINT +template struct MyPair {}; +// Be sure to try a type with a comma in its name just in case it matters. +typedef Types > TrimTypes; +INSTANTIATE_TYPED_TEST_SUITE_P(My, TrimmedTest, TrimTypes); + +} // namespace library2 + diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.h new file mode 100644 index 000000000000..8ce559c99f73 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-typed-test_test.h @@ -0,0 +1,60 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ +#define GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ + +#include "gtest/gtest.h" + +using testing::Test; + +// For testing that the same type-parameterized test case can be +// instantiated in different translation units linked together. +// ContainerTest will be instantiated in both gtest-typed-test_test.cc +// and gtest-typed-test2_test.cc. + +template +class ContainerTest : public Test { +}; + +TYPED_TEST_SUITE_P(ContainerTest); + +TYPED_TEST_P(ContainerTest, CanBeDefaultConstructed) { + TypeParam container; +} + +TYPED_TEST_P(ContainerTest, InitialSizeIsZero) { + TypeParam container; + EXPECT_EQ(0U, container.size()); +} + +REGISTER_TYPED_TEST_SUITE_P(ContainerTest, + CanBeDefaultConstructed, InitialSizeIsZero); + +#endif // GOOGLETEST_TEST_GTEST_TYPED_TEST_TEST_H_ diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-unittest-api_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-unittest-api_test.cc new file mode 100644 index 000000000000..8ef505838c70 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest-unittest-api_test.cc @@ -0,0 +1,328 @@ +// Copyright 2009 Google Inc. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// The Google C++ Testing and Mocking Framework (Google Test) +// +// This file contains tests verifying correctness of data provided via +// UnitTest's public methods. + +#include "gtest/gtest.h" + +#include // For strcmp. +#include + +using ::testing::InitGoogleTest; + +namespace testing { +namespace internal { + +template +struct LessByName { + bool operator()(const T* a, const T* b) { + return strcmp(a->name(), b->name()) < 0; + } +}; + +class UnitTestHelper { + public: + // Returns the array of pointers to all test suites sorted by the test suite + // name. The caller is responsible for deleting the array. + static TestSuite const** GetSortedTestSuites() { + UnitTest& unit_test = *UnitTest::GetInstance(); + auto const** const test_suites = new const TestSuite*[static_cast( + unit_test.total_test_suite_count())]; + + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) + test_suites[i] = unit_test.GetTestSuite(i); + + std::sort(test_suites, + test_suites + unit_test.total_test_suite_count(), + LessByName()); + return test_suites; + } + + // Returns the test suite by its name. The caller doesn't own the returned + // pointer. + static const TestSuite* FindTestSuite(const char* name) { + UnitTest& unit_test = *UnitTest::GetInstance(); + for (int i = 0; i < unit_test.total_test_suite_count(); ++i) { + const TestSuite* test_suite = unit_test.GetTestSuite(i); + if (0 == strcmp(test_suite->name(), name)) + return test_suite; + } + return nullptr; + } + + // Returns the array of pointers to all tests in a particular test suite + // sorted by the test name. The caller is responsible for deleting the + // array. + static TestInfo const** GetSortedTests(const TestSuite* test_suite) { + TestInfo const** const tests = new const TestInfo*[static_cast( + test_suite->total_test_count())]; + + for (int i = 0; i < test_suite->total_test_count(); ++i) + tests[i] = test_suite->GetTestInfo(i); + + std::sort(tests, tests + test_suite->total_test_count(), + LessByName()); + return tests; + } +}; + +template class TestSuiteWithCommentTest : public Test {}; +TYPED_TEST_SUITE(TestSuiteWithCommentTest, Types); +TYPED_TEST(TestSuiteWithCommentTest, Dummy) {} + +const int kTypedTestSuites = 1; +const int kTypedTests = 1; + +// We can only test the accessors that do not change value while tests run. +// Since tests can be run in any order, the values the accessors that track +// test execution (such as failed_test_count) can not be predicted. +TEST(ApiTest, UnitTestImmutableAccessorsWork) { + UnitTest* unit_test = UnitTest::GetInstance(); + + ASSERT_EQ(2 + kTypedTestSuites, unit_test->total_test_suite_count()); + EXPECT_EQ(1 + kTypedTestSuites, unit_test->test_suite_to_run_count()); + EXPECT_EQ(2, unit_test->disabled_test_count()); + EXPECT_EQ(5 + kTypedTests, unit_test->total_test_count()); + EXPECT_EQ(3 + kTypedTests, unit_test->test_to_run_count()); + + const TestSuite** const test_suites = UnitTestHelper::GetSortedTestSuites(); + + EXPECT_STREQ("ApiTest", test_suites[0]->name()); + EXPECT_STREQ("DISABLED_Test", test_suites[1]->name()); + EXPECT_STREQ("TestSuiteWithCommentTest/0", test_suites[2]->name()); + + delete[] test_suites; + + // The following lines initiate actions to verify certain methods in + // FinalSuccessChecker::TearDown. + + // Records a test property to verify TestResult::GetTestProperty(). + RecordProperty("key", "value"); +} + +AssertionResult IsNull(const char* str) { + if (str != nullptr) { + return testing::AssertionFailure() << "argument is " << str; + } + return AssertionSuccess(); +} + +TEST(ApiTest, TestSuiteImmutableAccessorsWork) { + const TestSuite* test_suite = UnitTestHelper::FindTestSuite("ApiTest"); + ASSERT_TRUE(test_suite != nullptr); + + EXPECT_STREQ("ApiTest", test_suite->name()); + EXPECT_TRUE(IsNull(test_suite->type_param())); + EXPECT_TRUE(test_suite->should_run()); + EXPECT_EQ(1, test_suite->disabled_test_count()); + EXPECT_EQ(3, test_suite->test_to_run_count()); + ASSERT_EQ(4, test_suite->total_test_count()); + + const TestInfo** tests = UnitTestHelper::GetSortedTests(test_suite); + + EXPECT_STREQ("DISABLED_Dummy1", tests[0]->name()); + EXPECT_STREQ("ApiTest", tests[0]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[0]->value_param())); + EXPECT_TRUE(IsNull(tests[0]->type_param())); + EXPECT_FALSE(tests[0]->should_run()); + + EXPECT_STREQ("TestSuiteDisabledAccessorsWork", tests[1]->name()); + EXPECT_STREQ("ApiTest", tests[1]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[1]->value_param())); + EXPECT_TRUE(IsNull(tests[1]->type_param())); + EXPECT_TRUE(tests[1]->should_run()); + + EXPECT_STREQ("TestSuiteImmutableAccessorsWork", tests[2]->name()); + EXPECT_STREQ("ApiTest", tests[2]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[2]->value_param())); + EXPECT_TRUE(IsNull(tests[2]->type_param())); + EXPECT_TRUE(tests[2]->should_run()); + + EXPECT_STREQ("UnitTestImmutableAccessorsWork", tests[3]->name()); + EXPECT_STREQ("ApiTest", tests[3]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[3]->value_param())); + EXPECT_TRUE(IsNull(tests[3]->type_param())); + EXPECT_TRUE(tests[3]->should_run()); + + delete[] tests; + tests = nullptr; + + test_suite = UnitTestHelper::FindTestSuite("TestSuiteWithCommentTest/0"); + ASSERT_TRUE(test_suite != nullptr); + + EXPECT_STREQ("TestSuiteWithCommentTest/0", test_suite->name()); + EXPECT_STREQ(GetTypeName>().c_str(), test_suite->type_param()); + EXPECT_TRUE(test_suite->should_run()); + EXPECT_EQ(0, test_suite->disabled_test_count()); + EXPECT_EQ(1, test_suite->test_to_run_count()); + ASSERT_EQ(1, test_suite->total_test_count()); + + tests = UnitTestHelper::GetSortedTests(test_suite); + + EXPECT_STREQ("Dummy", tests[0]->name()); + EXPECT_STREQ("TestSuiteWithCommentTest/0", tests[0]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[0]->value_param())); + EXPECT_STREQ(GetTypeName>().c_str(), tests[0]->type_param()); + EXPECT_TRUE(tests[0]->should_run()); + + delete[] tests; +} + +TEST(ApiTest, TestSuiteDisabledAccessorsWork) { + const TestSuite* test_suite = UnitTestHelper::FindTestSuite("DISABLED_Test"); + ASSERT_TRUE(test_suite != nullptr); + + EXPECT_STREQ("DISABLED_Test", test_suite->name()); + EXPECT_TRUE(IsNull(test_suite->type_param())); + EXPECT_FALSE(test_suite->should_run()); + EXPECT_EQ(1, test_suite->disabled_test_count()); + EXPECT_EQ(0, test_suite->test_to_run_count()); + ASSERT_EQ(1, test_suite->total_test_count()); + + const TestInfo* const test_info = test_suite->GetTestInfo(0); + EXPECT_STREQ("Dummy2", test_info->name()); + EXPECT_STREQ("DISABLED_Test", test_info->test_suite_name()); + EXPECT_TRUE(IsNull(test_info->value_param())); + EXPECT_TRUE(IsNull(test_info->type_param())); + EXPECT_FALSE(test_info->should_run()); +} + +// These two tests are here to provide support for testing +// test_suite_to_run_count, disabled_test_count, and test_to_run_count. +TEST(ApiTest, DISABLED_Dummy1) {} +TEST(DISABLED_Test, Dummy2) {} + +class FinalSuccessChecker : public Environment { + protected: + void TearDown() override { + UnitTest* unit_test = UnitTest::GetInstance(); + + EXPECT_EQ(1 + kTypedTestSuites, unit_test->successful_test_suite_count()); + EXPECT_EQ(3 + kTypedTests, unit_test->successful_test_count()); + EXPECT_EQ(0, unit_test->failed_test_suite_count()); + EXPECT_EQ(0, unit_test->failed_test_count()); + EXPECT_TRUE(unit_test->Passed()); + EXPECT_FALSE(unit_test->Failed()); + ASSERT_EQ(2 + kTypedTestSuites, unit_test->total_test_suite_count()); + + const TestSuite** const test_suites = UnitTestHelper::GetSortedTestSuites(); + + EXPECT_STREQ("ApiTest", test_suites[0]->name()); + EXPECT_TRUE(IsNull(test_suites[0]->type_param())); + EXPECT_TRUE(test_suites[0]->should_run()); + EXPECT_EQ(1, test_suites[0]->disabled_test_count()); + ASSERT_EQ(4, test_suites[0]->total_test_count()); + EXPECT_EQ(3, test_suites[0]->successful_test_count()); + EXPECT_EQ(0, test_suites[0]->failed_test_count()); + EXPECT_TRUE(test_suites[0]->Passed()); + EXPECT_FALSE(test_suites[0]->Failed()); + + EXPECT_STREQ("DISABLED_Test", test_suites[1]->name()); + EXPECT_TRUE(IsNull(test_suites[1]->type_param())); + EXPECT_FALSE(test_suites[1]->should_run()); + EXPECT_EQ(1, test_suites[1]->disabled_test_count()); + ASSERT_EQ(1, test_suites[1]->total_test_count()); + EXPECT_EQ(0, test_suites[1]->successful_test_count()); + EXPECT_EQ(0, test_suites[1]->failed_test_count()); + + EXPECT_STREQ("TestSuiteWithCommentTest/0", test_suites[2]->name()); + EXPECT_STREQ(GetTypeName>().c_str(), + test_suites[2]->type_param()); + EXPECT_TRUE(test_suites[2]->should_run()); + EXPECT_EQ(0, test_suites[2]->disabled_test_count()); + ASSERT_EQ(1, test_suites[2]->total_test_count()); + EXPECT_EQ(1, test_suites[2]->successful_test_count()); + EXPECT_EQ(0, test_suites[2]->failed_test_count()); + EXPECT_TRUE(test_suites[2]->Passed()); + EXPECT_FALSE(test_suites[2]->Failed()); + + const TestSuite* test_suite = UnitTestHelper::FindTestSuite("ApiTest"); + const TestInfo** tests = UnitTestHelper::GetSortedTests(test_suite); + EXPECT_STREQ("DISABLED_Dummy1", tests[0]->name()); + EXPECT_STREQ("ApiTest", tests[0]->test_suite_name()); + EXPECT_FALSE(tests[0]->should_run()); + + EXPECT_STREQ("TestSuiteDisabledAccessorsWork", tests[1]->name()); + EXPECT_STREQ("ApiTest", tests[1]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[1]->value_param())); + EXPECT_TRUE(IsNull(tests[1]->type_param())); + EXPECT_TRUE(tests[1]->should_run()); + EXPECT_TRUE(tests[1]->result()->Passed()); + EXPECT_EQ(0, tests[1]->result()->test_property_count()); + + EXPECT_STREQ("TestSuiteImmutableAccessorsWork", tests[2]->name()); + EXPECT_STREQ("ApiTest", tests[2]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[2]->value_param())); + EXPECT_TRUE(IsNull(tests[2]->type_param())); + EXPECT_TRUE(tests[2]->should_run()); + EXPECT_TRUE(tests[2]->result()->Passed()); + EXPECT_EQ(0, tests[2]->result()->test_property_count()); + + EXPECT_STREQ("UnitTestImmutableAccessorsWork", tests[3]->name()); + EXPECT_STREQ("ApiTest", tests[3]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[3]->value_param())); + EXPECT_TRUE(IsNull(tests[3]->type_param())); + EXPECT_TRUE(tests[3]->should_run()); + EXPECT_TRUE(tests[3]->result()->Passed()); + EXPECT_EQ(1, tests[3]->result()->test_property_count()); + const TestProperty& property = tests[3]->result()->GetTestProperty(0); + EXPECT_STREQ("key", property.key()); + EXPECT_STREQ("value", property.value()); + + delete[] tests; + + test_suite = UnitTestHelper::FindTestSuite("TestSuiteWithCommentTest/0"); + tests = UnitTestHelper::GetSortedTests(test_suite); + + EXPECT_STREQ("Dummy", tests[0]->name()); + EXPECT_STREQ("TestSuiteWithCommentTest/0", tests[0]->test_suite_name()); + EXPECT_TRUE(IsNull(tests[0]->value_param())); + EXPECT_STREQ(GetTypeName>().c_str(), tests[0]->type_param()); + EXPECT_TRUE(tests[0]->should_run()); + EXPECT_TRUE(tests[0]->result()->Passed()); + EXPECT_EQ(0, tests[0]->result()->test_property_count()); + + delete[] tests; + delete[] test_suites; + } +}; + +} // namespace internal +} // namespace testing + +int main(int argc, char **argv) { + InitGoogleTest(&argc, argv); + + AddGlobalTestEnvironment(new testing::internal::FinalSuccessChecker()); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_all_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_all_test.cc new file mode 100644 index 000000000000..615b29b70651 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_all_test.cc @@ -0,0 +1,46 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for Google C++ Testing and Mocking Framework (Google Test) +// +// Sometimes it's desirable to build most of Google Test's own tests +// by compiling a single file. This file serves this purpose. +#include "test/googletest-filepath-test.cc" +#include "test/googletest-message-test.cc" +#include "test/googletest-options-test.cc" +#include "test/googletest-port-test.cc" +#include "test/googletest-test-part-test.cc" +#include "test/gtest-typed-test2_test.cc" +#include "test/gtest-typed-test_test.cc" +#include "test/gtest_pred_impl_unittest.cc" +#include "test/gtest_prod_test.cc" +#include "test/gtest_skip_test.cc" +#include "test/gtest_unittest.cc" +#include "test/production.cc" diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_assert_by_exception_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_assert_by_exception_test.cc new file mode 100644 index 000000000000..ada4cb30ef68 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_assert_by_exception_test.cc @@ -0,0 +1,116 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Tests Google Test's assert-by-exception mode with exceptions enabled. + +#include "gtest/gtest.h" + +#include +#include +#include +#include + +class ThrowListener : public testing::EmptyTestEventListener { + void OnTestPartResult(const testing::TestPartResult& result) override { + if (result.type() == testing::TestPartResult::kFatalFailure) { + throw testing::AssertionException(result); + } + } +}; + +// Prints the given failure message and exits the program with +// non-zero. We use this instead of a Google Test assertion to +// indicate a failure, as the latter is been tested and cannot be +// relied on. +void Fail(const char* msg) { + printf("FAILURE: %s\n", msg); + fflush(stdout); + exit(1); +} + +static void AssertFalse() { + ASSERT_EQ(2, 3) << "Expected failure"; +} + +// Tests that an assertion failure throws a subclass of +// std::runtime_error. +TEST(Test, Test) { + // A successful assertion shouldn't throw. + try { + EXPECT_EQ(3, 3); + } catch(...) { + Fail("A successful assertion wrongfully threw."); + } + + // A successful assertion shouldn't throw. + try { + EXPECT_EQ(3, 4); + } catch(...) { + Fail("A failed non-fatal assertion wrongfully threw."); + } + + // A failed assertion should throw. + try { + AssertFalse(); + } catch(const testing::AssertionException& e) { + if (strstr(e.what(), "Expected failure") != nullptr) throw; + + printf("%s", + "A failed assertion did throw an exception of the right type, " + "but the message is incorrect. Instead of containing \"Expected " + "failure\", it is:\n"); + Fail(e.what()); + } catch(...) { + Fail("A failed assertion threw the wrong type of exception."); + } + Fail("A failed assertion should've thrown but didn't."); +} + +int kTestForContinuingTest = 0; + +TEST(Test, Test2) { + kTestForContinuingTest = 1; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::UnitTest::GetInstance()->listeners().Append(new ThrowListener); + + int result = RUN_ALL_TESTS(); + if (result == 0) { + printf("RUN_ALL_TESTS returned %d\n", result); + Fail("Expected failure instead."); + } + + if (kTestForContinuingTest == 0) { + Fail("Should have continued with other tests, but did not."); + } + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_environment_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_environment_test.cc new file mode 100644 index 000000000000..c7facf5a39ec --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_environment_test.cc @@ -0,0 +1,184 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests using global test environments. + +#include +#include +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +namespace { + +enum FailureType { + NO_FAILURE, NON_FATAL_FAILURE, FATAL_FAILURE +}; + +// For testing using global test environments. +class MyEnvironment : public testing::Environment { + public: + MyEnvironment() { Reset(); } + + // Depending on the value of failure_in_set_up_, SetUp() will + // generate a non-fatal failure, generate a fatal failure, or + // succeed. + void SetUp() override { + set_up_was_run_ = true; + + switch (failure_in_set_up_) { + case NON_FATAL_FAILURE: + ADD_FAILURE() << "Expected non-fatal failure in global set-up."; + break; + case FATAL_FAILURE: + FAIL() << "Expected fatal failure in global set-up."; + break; + default: + break; + } + } + + // Generates a non-fatal failure. + void TearDown() override { + tear_down_was_run_ = true; + ADD_FAILURE() << "Expected non-fatal failure in global tear-down."; + } + + // Resets the state of the environment s.t. it can be reused. + void Reset() { + failure_in_set_up_ = NO_FAILURE; + set_up_was_run_ = false; + tear_down_was_run_ = false; + } + + // We call this function to set the type of failure SetUp() should + // generate. + void set_failure_in_set_up(FailureType type) { + failure_in_set_up_ = type; + } + + // Was SetUp() run? + bool set_up_was_run() const { return set_up_was_run_; } + + // Was TearDown() run? + bool tear_down_was_run() const { return tear_down_was_run_; } + + private: + FailureType failure_in_set_up_; + bool set_up_was_run_; + bool tear_down_was_run_; +}; + +// Was the TEST run? +bool test_was_run; + +// The sole purpose of this TEST is to enable us to check whether it +// was run. +TEST(FooTest, Bar) { + test_was_run = true; +} + +// Prints the message and aborts the program if condition is false. +void Check(bool condition, const char* msg) { + if (!condition) { + printf("FAILED: %s\n", msg); + testing::internal::posix::Abort(); + } +} + +// Runs the tests. Return true if and only if successful. +// +// The 'failure' parameter specifies the type of failure that should +// be generated by the global set-up. +int RunAllTests(MyEnvironment* env, FailureType failure) { + env->Reset(); + env->set_failure_in_set_up(failure); + test_was_run = false; + testing::internal::GetUnitTestImpl()->ClearAdHocTestResult(); + return RUN_ALL_TESTS(); +} + +} // namespace + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + // Registers a global test environment, and verifies that the + // registration function returns its argument. + MyEnvironment* const env = new MyEnvironment; + Check(testing::AddGlobalTestEnvironment(env) == env, + "AddGlobalTestEnvironment() should return its argument."); + + // Verifies that RUN_ALL_TESTS() runs the tests when the global + // set-up is successful. + Check(RunAllTests(env, NO_FAILURE) != 0, + "RUN_ALL_TESTS() should return non-zero, as the global tear-down " + "should generate a failure."); + Check(test_was_run, + "The tests should run, as the global set-up should generate no " + "failure"); + Check(env->tear_down_was_run(), + "The global tear-down should run, as the global set-up was run."); + + // Verifies that RUN_ALL_TESTS() runs the tests when the global + // set-up generates no fatal failure. + Check(RunAllTests(env, NON_FATAL_FAILURE) != 0, + "RUN_ALL_TESTS() should return non-zero, as both the global set-up " + "and the global tear-down should generate a non-fatal failure."); + Check(test_was_run, + "The tests should run, as the global set-up should generate no " + "fatal failure."); + Check(env->tear_down_was_run(), + "The global tear-down should run, as the global set-up was run."); + + // Verifies that RUN_ALL_TESTS() runs no test when the global set-up + // generates a fatal failure. + Check(RunAllTests(env, FATAL_FAILURE) != 0, + "RUN_ALL_TESTS() should return non-zero, as the global set-up " + "should generate a fatal failure."); + Check(!test_was_run, + "The tests should not run, as the global set-up should generate " + "a fatal failure."); + Check(env->tear_down_was_run(), + "The global tear-down should run, as the global set-up was run."); + + // Verifies that RUN_ALL_TESTS() doesn't do global set-up or + // tear-down when there is no test to run. + GTEST_FLAG_SET(filter, "-*"); + Check(RunAllTests(env, NO_FAILURE) == 0, + "RUN_ALL_TESTS() should return zero, as there is no test to run."); + Check(!env->set_up_was_run(), + "The global set-up should not run, as there is no test to run."); + Check(!env->tear_down_was_run(), + "The global tear-down should not run, " + "as the global set-up was not run."); + + printf("PASS\n"); + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test.py new file mode 100755 index 000000000000..3e628ae5080e --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# +# Copyright 2009, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests the --help flag of Google C++ Testing and Mocking Framework. + +SYNOPSIS + gtest_help_test.py --build_dir=BUILD/DIR + # where BUILD/DIR contains the built gtest_help_test_ file. + gtest_help_test.py +""" + +import os +import re +from googletest.test import gtest_test_utils + + +IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' +IS_GNUHURD = os.name == 'posix' and os.uname()[0] == 'GNU' +IS_GNUKFREEBSD = os.name == 'posix' and os.uname()[0] == 'GNU/kFreeBSD' +IS_WINDOWS = os.name == 'nt' + +PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath('gtest_help_test_') +FLAG_PREFIX = '--gtest_' +DEATH_TEST_STYLE_FLAG = FLAG_PREFIX + 'death_test_style' +STREAM_RESULT_TO_FLAG = FLAG_PREFIX + 'stream_result_to' +UNKNOWN_FLAG = FLAG_PREFIX + 'unknown_flag_for_testing' +LIST_TESTS_FLAG = FLAG_PREFIX + 'list_tests' +INCORRECT_FLAG_VARIANTS = [re.sub('^--', '-', LIST_TESTS_FLAG), + re.sub('^--', '/', LIST_TESTS_FLAG), + re.sub('_', '-', LIST_TESTS_FLAG)] +INTERNAL_FLAG_FOR_TESTING = FLAG_PREFIX + 'internal_flag_for_testing' + +SUPPORTS_DEATH_TESTS = "DeathTest" in gtest_test_utils.Subprocess( + [PROGRAM_PATH, LIST_TESTS_FLAG]).output + +# The help message must match this regex. +HELP_REGEX = re.compile( + FLAG_PREFIX + r'list_tests.*' + + FLAG_PREFIX + r'filter=.*' + + FLAG_PREFIX + r'also_run_disabled_tests.*' + + FLAG_PREFIX + r'repeat=.*' + + FLAG_PREFIX + r'shuffle.*' + + FLAG_PREFIX + r'random_seed=.*' + + FLAG_PREFIX + r'color=.*' + + FLAG_PREFIX + r'brief.*' + + FLAG_PREFIX + r'print_time.*' + + FLAG_PREFIX + r'output=.*' + + FLAG_PREFIX + r'break_on_failure.*' + + FLAG_PREFIX + r'throw_on_failure.*' + + FLAG_PREFIX + r'catch_exceptions=0.*', + re.DOTALL) + + +def RunWithFlag(flag): + """Runs gtest_help_test_ with the given flag. + + Returns: + the exit code and the text output as a tuple. + Args: + flag: the command-line flag to pass to gtest_help_test_, or None. + """ + + if flag is None: + command = [PROGRAM_PATH] + else: + command = [PROGRAM_PATH, flag] + child = gtest_test_utils.Subprocess(command) + return child.exit_code, child.output + + +class GTestHelpTest(gtest_test_utils.TestCase): + """Tests the --help flag and its equivalent forms.""" + + def TestHelpFlag(self, flag): + """Verifies correct behavior when help flag is specified. + + The right message must be printed and the tests must + skipped when the given flag is specified. + + Args: + flag: A flag to pass to the binary or None. + """ + + exit_code, output = RunWithFlag(flag) + self.assertEquals(0, exit_code) + self.assert_(HELP_REGEX.search(output), output) + + if IS_LINUX or IS_GNUHURD or IS_GNUKFREEBSD: + self.assert_(STREAM_RESULT_TO_FLAG in output, output) + else: + self.assert_(STREAM_RESULT_TO_FLAG not in output, output) + + if SUPPORTS_DEATH_TESTS and not IS_WINDOWS: + self.assert_(DEATH_TEST_STYLE_FLAG in output, output) + else: + self.assert_(DEATH_TEST_STYLE_FLAG not in output, output) + + def TestNonHelpFlag(self, flag): + """Verifies correct behavior when no help flag is specified. + + Verifies that when no help flag is specified, the tests are run + and the help message is not printed. + + Args: + flag: A flag to pass to the binary or None. + """ + + exit_code, output = RunWithFlag(flag) + self.assert_(exit_code != 0) + self.assert_(not HELP_REGEX.search(output), output) + + def testPrintsHelpWithFullFlag(self): + self.TestHelpFlag('--help') + + def testPrintsHelpWithShortFlag(self): + self.TestHelpFlag('-h') + + def testPrintsHelpWithQuestionFlag(self): + self.TestHelpFlag('-?') + + def testPrintsHelpWithWindowsStyleQuestionFlag(self): + self.TestHelpFlag('/?') + + def testPrintsHelpWithUnrecognizedGoogleTestFlag(self): + self.TestHelpFlag(UNKNOWN_FLAG) + + def testPrintsHelpWithIncorrectFlagStyle(self): + for incorrect_flag in INCORRECT_FLAG_VARIANTS: + self.TestHelpFlag(incorrect_flag) + + def testRunsTestsWithoutHelpFlag(self): + """Verifies that when no help flag is specified, the tests are run + and the help message is not printed.""" + + self.TestNonHelpFlag(None) + + def testRunsTestsWithGtestInternalFlag(self): + """Verifies that the tests are run and no help message is printed when + a flag starting with Google Test prefix and 'internal_' is supplied.""" + + self.TestNonHelpFlag(INTERNAL_FLAG_FOR_TESTING) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test_.cc new file mode 100644 index 000000000000..750ae6ce95fc --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_help_test_.cc @@ -0,0 +1,45 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This program is meant to be run by gtest_help_test.py. Do not run +// it directly. + +#include "gtest/gtest.h" + +// When a help flag is specified, this program should skip the tests +// and exit with 0; otherwise the following test will be executed, +// causing this program to exit with a non-zero code. +TEST(HelpFlagTest, ShouldNotBeRun) { + ASSERT_TRUE(false) << "Tests shouldn't be run when --help is specified."; +} + +#if GTEST_HAS_DEATH_TEST +TEST(DeathTest, UsedByPythonScriptToDetectSupportForDeathTestsInThisBinary) {} +#endif diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_json_test_utils.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_json_test_utils.py new file mode 100644 index 000000000000..62bbfc288f82 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_json_test_utils.py @@ -0,0 +1,60 @@ +# Copyright 2018, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test utilities for gtest_json_output.""" + +import re + + +def normalize(obj): + """Normalize output object. + + Args: + obj: Google Test's JSON output object to normalize. + + Returns: + Normalized output without any references to transient information that may + change from run to run. + """ + def _normalize(key, value): + if key == 'time': + return re.sub(r'^\d+(\.\d+)?s$', '*', value) + elif key == 'timestamp': + return re.sub(r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ$', '*', value) + elif key == 'failure': + value = re.sub(r'^.*[/\\](.*:)\d+\n', '\\1*\n', value) + return re.sub(r'Stack trace:\n(.|\n)*', 'Stack trace:\n*', value) + else: + return normalize(value) + if isinstance(obj, dict): + return {k: _normalize(k, v) for k, v in obj.items()} + if isinstance(obj, list): + return [normalize(x) for x in obj] + else: + return obj diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest.py new file mode 100644 index 000000000000..faacf103c342 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python +# +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Unit test for Google Test's --gtest_list_tests flag. + +A user can ask Google Test to list all tests by specifying the +--gtest_list_tests flag. If output is requested, via --gtest_output=xml +or --gtest_output=json, the tests are listed, with extra information in the +output file. +This script tests such functionality by invoking gtest_list_output_unittest_ + (a program written with Google Test) the command line flags. +""" + +import os +import re +from googletest.test import gtest_test_utils + +GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' +GTEST_OUTPUT_FLAG = '--gtest_output' + +EXPECTED_XML = """<\?xml version="1.0" encoding="UTF-8"\?> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +""" + +EXPECTED_JSON = """{ + "tests": 16, + "name": "AllTests", + "testsuites": \[ + { + "name": "FooTest", + "tests": 2, + "testsuite": \[ + { + "name": "Test1", + "file": ".*gtest_list_output_unittest_.cc", + "line": 43 + }, + { + "name": "Test2", + "file": ".*gtest_list_output_unittest_.cc", + "line": 45 + } + \] + }, + { + "name": "FooTestFixture", + "tests": 2, + "testsuite": \[ + { + "name": "Test3", + "file": ".*gtest_list_output_unittest_.cc", + "line": 48 + }, + { + "name": "Test4", + "file": ".*gtest_list_output_unittest_.cc", + "line": 49 + } + \] + }, + { + "name": "TypedTest\\\\/0", + "tests": 2, + "testsuite": \[ + { + "name": "Test7", + "type_param": "int", + "file": ".*gtest_list_output_unittest_.cc", + "line": 60 + }, + { + "name": "Test8", + "type_param": "int", + "file": ".*gtest_list_output_unittest_.cc", + "line": 61 + } + \] + }, + { + "name": "TypedTest\\\\/1", + "tests": 2, + "testsuite": \[ + { + "name": "Test7", + "type_param": "bool", + "file": ".*gtest_list_output_unittest_.cc", + "line": 60 + }, + { + "name": "Test8", + "type_param": "bool", + "file": ".*gtest_list_output_unittest_.cc", + "line": 61 + } + \] + }, + { + "name": "Single\\\\/TypeParameterizedTestSuite\\\\/0", + "tests": 2, + "testsuite": \[ + { + "name": "Test9", + "type_param": "int", + "file": ".*gtest_list_output_unittest_.cc", + "line": 66 + }, + { + "name": "Test10", + "type_param": "int", + "file": ".*gtest_list_output_unittest_.cc", + "line": 67 + } + \] + }, + { + "name": "Single\\\\/TypeParameterizedTestSuite\\\\/1", + "tests": 2, + "testsuite": \[ + { + "name": "Test9", + "type_param": "bool", + "file": ".*gtest_list_output_unittest_.cc", + "line": 66 + }, + { + "name": "Test10", + "type_param": "bool", + "file": ".*gtest_list_output_unittest_.cc", + "line": 67 + } + \] + }, + { + "name": "ValueParam\\\\/ValueParamTest", + "tests": 4, + "testsuite": \[ + { + "name": "Test5\\\\/0", + "value_param": "33", + "file": ".*gtest_list_output_unittest_.cc", + "line": 52 + }, + { + "name": "Test5\\\\/1", + "value_param": "42", + "file": ".*gtest_list_output_unittest_.cc", + "line": 52 + }, + { + "name": "Test6\\\\/0", + "value_param": "33", + "file": ".*gtest_list_output_unittest_.cc", + "line": 53 + }, + { + "name": "Test6\\\\/1", + "value_param": "42", + "file": ".*gtest_list_output_unittest_.cc", + "line": 53 + } + \] + } + \] +} +""" + + +class GTestListTestsOutputUnitTest(gtest_test_utils.TestCase): + """Unit test for Google Test's list tests with output to file functionality. + """ + + def testXml(self): + """Verifies XML output for listing tests in a Google Test binary. + + Runs a test program that generates an empty XML output, and + tests that the XML output is expected. + """ + self._TestOutput('xml', EXPECTED_XML) + + def testJSON(self): + """Verifies XML output for listing tests in a Google Test binary. + + Runs a test program that generates an empty XML output, and + tests that the XML output is expected. + """ + self._TestOutput('json', EXPECTED_JSON) + + def _GetOutput(self, out_format): + file_path = os.path.join(gtest_test_utils.GetTempDir(), + 'test_out.' + out_format) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + 'gtest_list_output_unittest_') + + command = ([ + gtest_prog_path, + '%s=%s:%s' % (GTEST_OUTPUT_FLAG, out_format, file_path), + '--gtest_list_tests' + ]) + environ_copy = os.environ.copy() + p = gtest_test_utils.Subprocess( + command, env=environ_copy, working_dir=gtest_test_utils.GetTempDir()) + + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + self.assertTrue(os.path.isfile(file_path)) + with open(file_path) as f: + result = f.read() + return result + + def _TestOutput(self, test_format, expected_output): + actual = self._GetOutput(test_format) + actual_lines = actual.splitlines() + expected_lines = expected_output.splitlines() + line_count = 0 + for actual_line in actual_lines: + expected_line = expected_lines[line_count] + expected_line_re = re.compile(expected_line.strip()) + self.assertTrue( + expected_line_re.match(actual_line.strip()), + ('actual output of "%s",\n' + 'which does not match expected regex of "%s"\n' + 'on line %d' % (actual, expected_output, line_count))) + line_count = line_count + 1 + + +if __name__ == '__main__': + os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest_.cc new file mode 100644 index 000000000000..92b9d4f28eec --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_list_output_unittest_.cc @@ -0,0 +1,77 @@ +// Copyright 2018, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: david.schuldenfrei@gmail.com (David Schuldenfrei) + +// Unit test for Google Test's --gtest_list_tests and --gtest_output flag. +// +// A user can ask Google Test to list all tests that will run, +// and have the output saved in a Json/Xml file. +// The tests will not be run after listing. +// +// This program will be invoked from a Python unit test. +// Don't run it directly. + +#include "gtest/gtest.h" + +TEST(FooTest, Test1) {} + +TEST(FooTest, Test2) {} + +class FooTestFixture : public ::testing::Test {}; +TEST_F(FooTestFixture, Test3) {} +TEST_F(FooTestFixture, Test4) {} + +class ValueParamTest : public ::testing::TestWithParam {}; +TEST_P(ValueParamTest, Test5) {} +TEST_P(ValueParamTest, Test6) {} +INSTANTIATE_TEST_SUITE_P(ValueParam, ValueParamTest, ::testing::Values(33, 42)); + +template +class TypedTest : public ::testing::Test {}; +typedef testing::Types TypedTestTypes; +TYPED_TEST_SUITE(TypedTest, TypedTestTypes); +TYPED_TEST(TypedTest, Test7) {} +TYPED_TEST(TypedTest, Test8) {} + +template +class TypeParameterizedTestSuite : public ::testing::Test {}; +TYPED_TEST_SUITE_P(TypeParameterizedTestSuite); +TYPED_TEST_P(TypeParameterizedTestSuite, Test9) {} +TYPED_TEST_P(TypeParameterizedTestSuite, Test10) {} +REGISTER_TYPED_TEST_SUITE_P(TypeParameterizedTestSuite, Test9, Test10); +typedef testing::Types TypeParameterizedTestSuiteTypes; // NOLINT +INSTANTIATE_TYPED_TEST_SUITE_P(Single, TypeParameterizedTestSuite, + TypeParameterizedTestSuiteTypes); + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_main_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_main_unittest.cc new file mode 100644 index 000000000000..eddedeabe8f3 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_main_unittest.cc @@ -0,0 +1,44 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "gtest/gtest.h" + +// Tests that we don't have to define main() when we link to +// gtest_main instead of gtest. + +namespace { + +TEST(GTestMainTest, ShouldSucceed) { +} + +} // namespace + +// We are using the main() function defined in gtest_main.cc, so we +// don't define it here. diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_no_test_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_no_test_unittest.cc new file mode 100644 index 000000000000..d4f88dbfdfa6 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_no_test_unittest.cc @@ -0,0 +1,54 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Tests that a Google Test program that has no test defined can run +// successfully. + +#include "gtest/gtest.h" + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + // An ad-hoc assertion outside of all tests. + // + // This serves three purposes: + // + // 1. It verifies that an ad-hoc assertion can be executed even if + // no test is defined. + // 2. It verifies that a failed ad-hoc assertion causes the test + // program to fail. + // 3. We had a bug where the XML output won't be generated if an + // assertion is executed before RUN_ALL_TESTS() is called, even + // though --gtest_output=xml is specified. This makes sure the + // bug is fixed and doesn't regress. + EXPECT_EQ(1, 2); + + // The above EXPECT_EQ() should cause RUN_ALL_TESTS() to return non-zero. + return RUN_ALL_TESTS() ? 0 : 1; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_pred_impl_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_pred_impl_unittest.cc new file mode 100644 index 000000000000..5eeb1473798e --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_pred_impl_unittest.cc @@ -0,0 +1,2422 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file is AUTOMATICALLY GENERATED on 07/21/2021 by command +// 'gen_gtest_pred_impl.py 5'. DO NOT EDIT BY HAND! + +// Regression test for gtest_pred_impl.h +// +// This file is generated by a script and quite long. If you intend to +// learn how Google Test works by reading its unit tests, read +// gtest_unittest.cc instead. +// +// This is intended as a regression test for the Google Test predicate +// assertions. We compile it as part of the gtest_unittest target +// only to keep the implementation tidy and compact, as it is quite +// involved to set up the stage for testing Google Test using Google +// Test itself. +// +// Currently, gtest_unittest takes ~11 seconds to run in the testing +// daemon. In the future, if it grows too large and needs much more +// time to finish, we should consider separating this file into a +// stand-alone regression test. + +#include + +#include "gtest/gtest.h" +#include "gtest/gtest-spi.h" + +// A user-defined data type. +struct Bool { + explicit Bool(int val) : value(val != 0) {} + + bool operator>(int n) const { return value > Bool(n).value; } + + Bool operator+(const Bool& rhs) const { return Bool(value + rhs.value); } + + bool operator==(const Bool& rhs) const { return value == rhs.value; } + + bool value; +}; + +// Enables Bool to be used in assertions. +std::ostream& operator<<(std::ostream& os, const Bool& x) { + return os << (x.value ? "true" : "false"); +} + +// Sample functions/functors for testing unary predicate assertions. + +// A unary predicate function. +template +bool PredFunction1(T1 v1) { + return v1 > 0; +} + +// The following two functions are needed because a compiler doesn't have +// a context yet to know which template function must be instantiated. +bool PredFunction1Int(int v1) { + return v1 > 0; +} +bool PredFunction1Bool(Bool v1) { + return v1 > 0; +} + +// A unary predicate functor. +struct PredFunctor1 { + template + bool operator()(const T1& v1) { + return v1 > 0; + } +}; + +// A unary predicate-formatter function. +template +testing::AssertionResult PredFormatFunction1(const char* e1, + const T1& v1) { + if (PredFunction1(v1)) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() + << e1 + << " is expected to be positive, but evaluates to " + << v1 << "."; +} + +// A unary predicate-formatter functor. +struct PredFormatFunctor1 { + template + testing::AssertionResult operator()(const char* e1, + const T1& v1) const { + return PredFormatFunction1(e1, v1); + } +}; + +// Tests for {EXPECT|ASSERT}_PRED_FORMAT1. + +class Predicate1Test : public testing::Test { + protected: + void SetUp() override { + expected_to_finish_ = true; + finished_ = false; + n1_ = 0; + } + + void TearDown() override { + // Verifies that each of the predicate's arguments was evaluated + // exactly once. + EXPECT_EQ(1, n1_) << + "The predicate assertion didn't evaluate argument 2 " + "exactly once."; + + // Verifies that the control flow in the test function is expected. + if (expected_to_finish_ && !finished_) { + FAIL() << "The predicate assertion unexpectedly aborted the test."; + } else if (!expected_to_finish_ && finished_) { + FAIL() << "The failed predicate assertion didn't abort the test " + "as expected."; + } + } + + // true if and only if the test function is expected to run to finish. + static bool expected_to_finish_; + + // true if and only if the test function did run to finish. + static bool finished_; + + static int n1_; +}; + +bool Predicate1Test::expected_to_finish_; +bool Predicate1Test::finished_; +int Predicate1Test::n1_; + +typedef Predicate1Test EXPECT_PRED_FORMAT1Test; +typedef Predicate1Test ASSERT_PRED_FORMAT1Test; +typedef Predicate1Test EXPECT_PRED1Test; +typedef Predicate1Test ASSERT_PRED1Test; + +// Tests a successful EXPECT_PRED1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED1Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED1(PredFunction1Int, + ++n1_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED1Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED1(PredFunction1Bool, + Bool(++n1_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED1Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED1(PredFunctor1(), + ++n1_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED1Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED1(PredFunctor1(), + Bool(++n1_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED1Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED1(PredFunction1Int, + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED1Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED1(PredFunction1Bool, + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED1Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED1(PredFunctor1(), + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED1Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED1(PredFunctor1(), + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED1Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED1(PredFunction1Int, + ++n1_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED1Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED1(PredFunction1Bool, + Bool(++n1_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED1Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED1(PredFunctor1(), + ++n1_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED1Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED1(PredFunctor1(), + Bool(++n1_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED1Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED1(PredFunction1Int, + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED1Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED1(PredFunction1Bool, + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED1Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED1(PredFunctor1(), + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED1Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED1(PredFunctor1(), + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a successful EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT1(PredFormatFunction1, + ++n1_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED_FORMAT1(PredFormatFunction1, + Bool(++n1_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT1(PredFormatFunctor1(), + ++n1_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED_FORMAT1(PredFormatFunctor1(), + Bool(++n1_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT1(PredFormatFunction1, + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT1(PredFormatFunction1, + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT1(PredFormatFunctor1(), + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT1Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT1(PredFormatFunctor1(), + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT1(PredFormatFunction1, + ++n1_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED_FORMAT1(PredFormatFunction1, + Bool(++n1_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT1(PredFormatFunctor1(), + ++n1_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED_FORMAT1(PredFormatFunctor1(), + Bool(++n1_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(PredFormatFunction1, + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(PredFormatFunction1, + Bool(n1_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(PredFormatFunctor1(), + n1_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT1 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT1Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(PredFormatFunctor1(), + Bool(n1_++)); + finished_ = true; + }, ""); +} +// Sample functions/functors for testing binary predicate assertions. + +// A binary predicate function. +template +bool PredFunction2(T1 v1, T2 v2) { + return v1 + v2 > 0; +} + +// The following two functions are needed because a compiler doesn't have +// a context yet to know which template function must be instantiated. +bool PredFunction2Int(int v1, int v2) { + return v1 + v2 > 0; +} +bool PredFunction2Bool(Bool v1, Bool v2) { + return v1 + v2 > 0; +} + +// A binary predicate functor. +struct PredFunctor2 { + template + bool operator()(const T1& v1, + const T2& v2) { + return v1 + v2 > 0; + } +}; + +// A binary predicate-formatter function. +template +testing::AssertionResult PredFormatFunction2(const char* e1, + const char* e2, + const T1& v1, + const T2& v2) { + if (PredFunction2(v1, v2)) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() + << e1 << " + " << e2 + << " is expected to be positive, but evaluates to " + << v1 + v2 << "."; +} + +// A binary predicate-formatter functor. +struct PredFormatFunctor2 { + template + testing::AssertionResult operator()(const char* e1, + const char* e2, + const T1& v1, + const T2& v2) const { + return PredFormatFunction2(e1, e2, v1, v2); + } +}; + +// Tests for {EXPECT|ASSERT}_PRED_FORMAT2. + +class Predicate2Test : public testing::Test { + protected: + void SetUp() override { + expected_to_finish_ = true; + finished_ = false; + n1_ = n2_ = 0; + } + + void TearDown() override { + // Verifies that each of the predicate's arguments was evaluated + // exactly once. + EXPECT_EQ(1, n1_) << + "The predicate assertion didn't evaluate argument 2 " + "exactly once."; + EXPECT_EQ(1, n2_) << + "The predicate assertion didn't evaluate argument 3 " + "exactly once."; + + // Verifies that the control flow in the test function is expected. + if (expected_to_finish_ && !finished_) { + FAIL() << "The predicate assertion unexpectedly aborted the test."; + } else if (!expected_to_finish_ && finished_) { + FAIL() << "The failed predicate assertion didn't abort the test " + "as expected."; + } + } + + // true if and only if the test function is expected to run to finish. + static bool expected_to_finish_; + + // true if and only if the test function did run to finish. + static bool finished_; + + static int n1_; + static int n2_; +}; + +bool Predicate2Test::expected_to_finish_; +bool Predicate2Test::finished_; +int Predicate2Test::n1_; +int Predicate2Test::n2_; + +typedef Predicate2Test EXPECT_PRED_FORMAT2Test; +typedef Predicate2Test ASSERT_PRED_FORMAT2Test; +typedef Predicate2Test EXPECT_PRED2Test; +typedef Predicate2Test ASSERT_PRED2Test; + +// Tests a successful EXPECT_PRED2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED2Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED2(PredFunction2Int, + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED2Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED2(PredFunction2Bool, + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED2Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED2(PredFunctor2(), + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED2Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED2(PredFunctor2(), + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED2Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED2(PredFunction2Int, + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED2Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED2(PredFunction2Bool, + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED2Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED2(PredFunctor2(), + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED2Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED2(PredFunctor2(), + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED2Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED2(PredFunction2Int, + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED2Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED2(PredFunction2Bool, + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED2Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED2(PredFunctor2(), + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED2Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED2(PredFunctor2(), + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED2Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED2(PredFunction2Int, + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED2Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED2(PredFunction2Bool, + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED2Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED2(PredFunctor2(), + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED2Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED2(PredFunctor2(), + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a successful EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT2(PredFormatFunction2, + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED_FORMAT2(PredFormatFunction2, + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT2(PredFormatFunctor2(), + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED_FORMAT2(PredFormatFunctor2(), + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(PredFormatFunction2, + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(PredFormatFunction2, + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(PredFormatFunctor2(), + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT2Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(PredFormatFunctor2(), + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT2(PredFormatFunction2, + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED_FORMAT2(PredFormatFunction2, + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT2(PredFormatFunctor2(), + ++n1_, + ++n2_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED_FORMAT2(PredFormatFunctor2(), + Bool(++n1_), + Bool(++n2_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(PredFormatFunction2, + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(PredFormatFunction2, + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(PredFormatFunctor2(), + n1_++, + n2_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT2 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT2Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(PredFormatFunctor2(), + Bool(n1_++), + Bool(n2_++)); + finished_ = true; + }, ""); +} +// Sample functions/functors for testing ternary predicate assertions. + +// A ternary predicate function. +template +bool PredFunction3(T1 v1, T2 v2, T3 v3) { + return v1 + v2 + v3 > 0; +} + +// The following two functions are needed because a compiler doesn't have +// a context yet to know which template function must be instantiated. +bool PredFunction3Int(int v1, int v2, int v3) { + return v1 + v2 + v3 > 0; +} +bool PredFunction3Bool(Bool v1, Bool v2, Bool v3) { + return v1 + v2 + v3 > 0; +} + +// A ternary predicate functor. +struct PredFunctor3 { + template + bool operator()(const T1& v1, + const T2& v2, + const T3& v3) { + return v1 + v2 + v3 > 0; + } +}; + +// A ternary predicate-formatter function. +template +testing::AssertionResult PredFormatFunction3(const char* e1, + const char* e2, + const char* e3, + const T1& v1, + const T2& v2, + const T3& v3) { + if (PredFunction3(v1, v2, v3)) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() + << e1 << " + " << e2 << " + " << e3 + << " is expected to be positive, but evaluates to " + << v1 + v2 + v3 << "."; +} + +// A ternary predicate-formatter functor. +struct PredFormatFunctor3 { + template + testing::AssertionResult operator()(const char* e1, + const char* e2, + const char* e3, + const T1& v1, + const T2& v2, + const T3& v3) const { + return PredFormatFunction3(e1, e2, e3, v1, v2, v3); + } +}; + +// Tests for {EXPECT|ASSERT}_PRED_FORMAT3. + +class Predicate3Test : public testing::Test { + protected: + void SetUp() override { + expected_to_finish_ = true; + finished_ = false; + n1_ = n2_ = n3_ = 0; + } + + void TearDown() override { + // Verifies that each of the predicate's arguments was evaluated + // exactly once. + EXPECT_EQ(1, n1_) << + "The predicate assertion didn't evaluate argument 2 " + "exactly once."; + EXPECT_EQ(1, n2_) << + "The predicate assertion didn't evaluate argument 3 " + "exactly once."; + EXPECT_EQ(1, n3_) << + "The predicate assertion didn't evaluate argument 4 " + "exactly once."; + + // Verifies that the control flow in the test function is expected. + if (expected_to_finish_ && !finished_) { + FAIL() << "The predicate assertion unexpectedly aborted the test."; + } else if (!expected_to_finish_ && finished_) { + FAIL() << "The failed predicate assertion didn't abort the test " + "as expected."; + } + } + + // true if and only if the test function is expected to run to finish. + static bool expected_to_finish_; + + // true if and only if the test function did run to finish. + static bool finished_; + + static int n1_; + static int n2_; + static int n3_; +}; + +bool Predicate3Test::expected_to_finish_; +bool Predicate3Test::finished_; +int Predicate3Test::n1_; +int Predicate3Test::n2_; +int Predicate3Test::n3_; + +typedef Predicate3Test EXPECT_PRED_FORMAT3Test; +typedef Predicate3Test ASSERT_PRED_FORMAT3Test; +typedef Predicate3Test EXPECT_PRED3Test; +typedef Predicate3Test ASSERT_PRED3Test; + +// Tests a successful EXPECT_PRED3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED3Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED3(PredFunction3Int, + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED3Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED3(PredFunction3Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED3Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED3(PredFunctor3(), + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED3Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED3(PredFunctor3(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED3Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED3(PredFunction3Int, + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED3Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED3(PredFunction3Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED3Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED3(PredFunctor3(), + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED3Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED3(PredFunctor3(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED3Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED3(PredFunction3Int, + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED3Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED3(PredFunction3Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED3Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED3(PredFunctor3(), + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED3Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED3(PredFunctor3(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED3Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED3(PredFunction3Int, + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED3Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED3(PredFunction3Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED3Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED3(PredFunctor3(), + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED3Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED3(PredFunctor3(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a successful EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT3(PredFormatFunction3, + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED_FORMAT3(PredFormatFunction3, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT3(PredFormatFunctor3(), + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED_FORMAT3(PredFormatFunctor3(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT3(PredFormatFunction3, + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT3(PredFormatFunction3, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT3(PredFormatFunctor3(), + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT3Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT3(PredFormatFunctor3(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT3(PredFormatFunction3, + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED_FORMAT3(PredFormatFunction3, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT3(PredFormatFunctor3(), + ++n1_, + ++n2_, + ++n3_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED_FORMAT3(PredFormatFunctor3(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT3(PredFormatFunction3, + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT3(PredFormatFunction3, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT3(PredFormatFunctor3(), + n1_++, + n2_++, + n3_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT3 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT3Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT3(PredFormatFunctor3(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++)); + finished_ = true; + }, ""); +} +// Sample functions/functors for testing 4-ary predicate assertions. + +// A 4-ary predicate function. +template +bool PredFunction4(T1 v1, T2 v2, T3 v3, T4 v4) { + return v1 + v2 + v3 + v4 > 0; +} + +// The following two functions are needed because a compiler doesn't have +// a context yet to know which template function must be instantiated. +bool PredFunction4Int(int v1, int v2, int v3, int v4) { + return v1 + v2 + v3 + v4 > 0; +} +bool PredFunction4Bool(Bool v1, Bool v2, Bool v3, Bool v4) { + return v1 + v2 + v3 + v4 > 0; +} + +// A 4-ary predicate functor. +struct PredFunctor4 { + template + bool operator()(const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) { + return v1 + v2 + v3 + v4 > 0; + } +}; + +// A 4-ary predicate-formatter function. +template +testing::AssertionResult PredFormatFunction4(const char* e1, + const char* e2, + const char* e3, + const char* e4, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) { + if (PredFunction4(v1, v2, v3, v4)) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() + << e1 << " + " << e2 << " + " << e3 << " + " << e4 + << " is expected to be positive, but evaluates to " + << v1 + v2 + v3 + v4 << "."; +} + +// A 4-ary predicate-formatter functor. +struct PredFormatFunctor4 { + template + testing::AssertionResult operator()(const char* e1, + const char* e2, + const char* e3, + const char* e4, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) const { + return PredFormatFunction4(e1, e2, e3, e4, v1, v2, v3, v4); + } +}; + +// Tests for {EXPECT|ASSERT}_PRED_FORMAT4. + +class Predicate4Test : public testing::Test { + protected: + void SetUp() override { + expected_to_finish_ = true; + finished_ = false; + n1_ = n2_ = n3_ = n4_ = 0; + } + + void TearDown() override { + // Verifies that each of the predicate's arguments was evaluated + // exactly once. + EXPECT_EQ(1, n1_) << + "The predicate assertion didn't evaluate argument 2 " + "exactly once."; + EXPECT_EQ(1, n2_) << + "The predicate assertion didn't evaluate argument 3 " + "exactly once."; + EXPECT_EQ(1, n3_) << + "The predicate assertion didn't evaluate argument 4 " + "exactly once."; + EXPECT_EQ(1, n4_) << + "The predicate assertion didn't evaluate argument 5 " + "exactly once."; + + // Verifies that the control flow in the test function is expected. + if (expected_to_finish_ && !finished_) { + FAIL() << "The predicate assertion unexpectedly aborted the test."; + } else if (!expected_to_finish_ && finished_) { + FAIL() << "The failed predicate assertion didn't abort the test " + "as expected."; + } + } + + // true if and only if the test function is expected to run to finish. + static bool expected_to_finish_; + + // true if and only if the test function did run to finish. + static bool finished_; + + static int n1_; + static int n2_; + static int n3_; + static int n4_; +}; + +bool Predicate4Test::expected_to_finish_; +bool Predicate4Test::finished_; +int Predicate4Test::n1_; +int Predicate4Test::n2_; +int Predicate4Test::n3_; +int Predicate4Test::n4_; + +typedef Predicate4Test EXPECT_PRED_FORMAT4Test; +typedef Predicate4Test ASSERT_PRED_FORMAT4Test; +typedef Predicate4Test EXPECT_PRED4Test; +typedef Predicate4Test ASSERT_PRED4Test; + +// Tests a successful EXPECT_PRED4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED4Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED4(PredFunction4Int, + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED4Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED4(PredFunction4Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED4Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED4(PredFunctor4(), + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED4Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED4(PredFunctor4(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED4Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED4(PredFunction4Int, + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED4Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED4(PredFunction4Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED4Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED4(PredFunctor4(), + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED4Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED4(PredFunctor4(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED4Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED4(PredFunction4Int, + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED4Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED4(PredFunction4Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED4Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED4(PredFunctor4(), + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED4Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED4(PredFunctor4(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED4Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED4(PredFunction4Int, + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED4Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED4(PredFunction4Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED4Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED4(PredFunctor4(), + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED4Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED4(PredFunctor4(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a successful EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT4(PredFormatFunction4, + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED_FORMAT4(PredFormatFunction4, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT4(PredFormatFunctor4(), + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED_FORMAT4(PredFormatFunctor4(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(PredFormatFunction4, + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(PredFormatFunction4, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(PredFormatFunctor4(), + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT4Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(PredFormatFunctor4(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT4(PredFormatFunction4, + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED_FORMAT4(PredFormatFunction4, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT4(PredFormatFunctor4(), + ++n1_, + ++n2_, + ++n3_, + ++n4_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED_FORMAT4(PredFormatFunctor4(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT4(PredFormatFunction4, + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT4(PredFormatFunction4, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT4(PredFormatFunctor4(), + n1_++, + n2_++, + n3_++, + n4_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT4 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT4Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT4(PredFormatFunctor4(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++)); + finished_ = true; + }, ""); +} +// Sample functions/functors for testing 5-ary predicate assertions. + +// A 5-ary predicate function. +template +bool PredFunction5(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5) { + return v1 + v2 + v3 + v4 + v5 > 0; +} + +// The following two functions are needed because a compiler doesn't have +// a context yet to know which template function must be instantiated. +bool PredFunction5Int(int v1, int v2, int v3, int v4, int v5) { + return v1 + v2 + v3 + v4 + v5 > 0; +} +bool PredFunction5Bool(Bool v1, Bool v2, Bool v3, Bool v4, Bool v5) { + return v1 + v2 + v3 + v4 + v5 > 0; +} + +// A 5-ary predicate functor. +struct PredFunctor5 { + template + bool operator()(const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) { + return v1 + v2 + v3 + v4 + v5 > 0; + } +}; + +// A 5-ary predicate-formatter function. +template +testing::AssertionResult PredFormatFunction5(const char* e1, + const char* e2, + const char* e3, + const char* e4, + const char* e5, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) { + if (PredFunction5(v1, v2, v3, v4, v5)) + return testing::AssertionSuccess(); + + return testing::AssertionFailure() + << e1 << " + " << e2 << " + " << e3 << " + " << e4 << " + " << e5 + << " is expected to be positive, but evaluates to " + << v1 + v2 + v3 + v4 + v5 << "."; +} + +// A 5-ary predicate-formatter functor. +struct PredFormatFunctor5 { + template + testing::AssertionResult operator()(const char* e1, + const char* e2, + const char* e3, + const char* e4, + const char* e5, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) const { + return PredFormatFunction5(e1, e2, e3, e4, e5, v1, v2, v3, v4, v5); + } +}; + +// Tests for {EXPECT|ASSERT}_PRED_FORMAT5. + +class Predicate5Test : public testing::Test { + protected: + void SetUp() override { + expected_to_finish_ = true; + finished_ = false; + n1_ = n2_ = n3_ = n4_ = n5_ = 0; + } + + void TearDown() override { + // Verifies that each of the predicate's arguments was evaluated + // exactly once. + EXPECT_EQ(1, n1_) << + "The predicate assertion didn't evaluate argument 2 " + "exactly once."; + EXPECT_EQ(1, n2_) << + "The predicate assertion didn't evaluate argument 3 " + "exactly once."; + EXPECT_EQ(1, n3_) << + "The predicate assertion didn't evaluate argument 4 " + "exactly once."; + EXPECT_EQ(1, n4_) << + "The predicate assertion didn't evaluate argument 5 " + "exactly once."; + EXPECT_EQ(1, n5_) << + "The predicate assertion didn't evaluate argument 6 " + "exactly once."; + + // Verifies that the control flow in the test function is expected. + if (expected_to_finish_ && !finished_) { + FAIL() << "The predicate assertion unexpectedly aborted the test."; + } else if (!expected_to_finish_ && finished_) { + FAIL() << "The failed predicate assertion didn't abort the test " + "as expected."; + } + } + + // true if and only if the test function is expected to run to finish. + static bool expected_to_finish_; + + // true if and only if the test function did run to finish. + static bool finished_; + + static int n1_; + static int n2_; + static int n3_; + static int n4_; + static int n5_; +}; + +bool Predicate5Test::expected_to_finish_; +bool Predicate5Test::finished_; +int Predicate5Test::n1_; +int Predicate5Test::n2_; +int Predicate5Test::n3_; +int Predicate5Test::n4_; +int Predicate5Test::n5_; + +typedef Predicate5Test EXPECT_PRED_FORMAT5Test; +typedef Predicate5Test ASSERT_PRED_FORMAT5Test; +typedef Predicate5Test EXPECT_PRED5Test; +typedef Predicate5Test ASSERT_PRED5Test; + +// Tests a successful EXPECT_PRED5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED5Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED5(PredFunction5Int, + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED5Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED5(PredFunction5Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED5Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED5(PredFunctor5(), + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED5Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED5(PredFunctor5(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED5Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED5(PredFunction5Int, + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED5Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED5(PredFunction5Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED5Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED5(PredFunctor5(), + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED5Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED5(PredFunctor5(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED5Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED5(PredFunction5Int, + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED5Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED5(PredFunction5Bool, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED5Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED5(PredFunctor5(), + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED5Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED5(PredFunctor5(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED5Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED5(PredFunction5Int, + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED5Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED5(PredFunction5Bool, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED5Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED5(PredFunctor5(), + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED5Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED5(PredFunctor5(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a successful EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctionOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT5(PredFormatFunction5, + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctionOnUserTypeSuccess) { + EXPECT_PRED_FORMAT5(PredFormatFunction5, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctorOnBuiltInTypeSuccess) { + EXPECT_PRED_FORMAT5(PredFormatFunctor5(), + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctorOnUserTypeSuccess) { + EXPECT_PRED_FORMAT5(PredFormatFunctor5(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a failed EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctionOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT5(PredFormatFunction5, + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctionOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT5(PredFormatFunction5, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctorOnBuiltInTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT5(PredFormatFunctor5(), + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed EXPECT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(EXPECT_PRED_FORMAT5Test, FunctorOnUserTypeFailure) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT5(PredFormatFunctor5(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a successful ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctionOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT5(PredFormatFunction5, + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctionOnUserTypeSuccess) { + ASSERT_PRED_FORMAT5(PredFormatFunction5, + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctorOnBuiltInTypeSuccess) { + ASSERT_PRED_FORMAT5(PredFormatFunctor5(), + ++n1_, + ++n2_, + ++n3_, + ++n4_, + ++n5_); + finished_ = true; +} + +// Tests a successful ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctorOnUserTypeSuccess) { + ASSERT_PRED_FORMAT5(PredFormatFunctor5(), + Bool(++n1_), + Bool(++n2_), + Bool(++n3_), + Bool(++n4_), + Bool(++n5_)); + finished_ = true; +} + +// Tests a failed ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a function on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctionOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT5(PredFormatFunction5, + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a function on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctionOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT5(PredFormatFunction5, + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a built-in type (int). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctorOnBuiltInTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT5(PredFormatFunctor5(), + n1_++, + n2_++, + n3_++, + n4_++, + n5_++); + finished_ = true; + }, ""); +} + +// Tests a failed ASSERT_PRED_FORMAT5 where the +// predicate-formatter is a functor on a user-defined type (Bool). +TEST_F(ASSERT_PRED_FORMAT5Test, FunctorOnUserTypeFailure) { + expected_to_finish_ = false; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT5(PredFormatFunctor5(), + Bool(n1_++), + Bool(n2_++), + Bool(n3_++), + Bool(n4_++), + Bool(n5_++)); + finished_ = true; + }, ""); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_premature_exit_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_premature_exit_test.cc new file mode 100644 index 000000000000..1d1187eff006 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_premature_exit_test.cc @@ -0,0 +1,126 @@ +// Copyright 2013, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests that Google Test manipulates the premature-exit-detection +// file correctly. + +#include + +#include "gtest/gtest.h" + +using ::testing::InitGoogleTest; +using ::testing::Test; +using ::testing::internal::posix::GetEnv; +using ::testing::internal::posix::Stat; +using ::testing::internal::posix::StatStruct; + +namespace { + +class PrematureExitTest : public Test { + public: + // Returns true if and only if the given file exists. + static bool FileExists(const char* filepath) { + StatStruct stat; + return Stat(filepath, &stat) == 0; + } + + protected: + PrematureExitTest() { + premature_exit_file_path_ = GetEnv("TEST_PREMATURE_EXIT_FILE"); + + // Normalize NULL to "" for ease of handling. + if (premature_exit_file_path_ == nullptr) { + premature_exit_file_path_ = ""; + } + } + + // Returns true if and only if the premature-exit file exists. + bool PrematureExitFileExists() const { + return FileExists(premature_exit_file_path_); + } + + const char* premature_exit_file_path_; +}; + +typedef PrematureExitTest PrematureExitDeathTest; + +// Tests that: +// - the premature-exit file exists during the execution of a +// death test (EXPECT_DEATH*), and +// - a death test doesn't interfere with the main test process's +// handling of the premature-exit file. +TEST_F(PrematureExitDeathTest, FileExistsDuringExecutionOfDeathTest) { + if (*premature_exit_file_path_ == '\0') { + return; + } + + EXPECT_DEATH_IF_SUPPORTED({ + // If the file exists, crash the process such that the main test + // process will catch the (expected) crash and report a success; + // otherwise don't crash, which will cause the main test process + // to report that the death test has failed. + if (PrematureExitFileExists()) { + exit(1); + } + }, ""); +} + +// Tests that the premature-exit file exists during the execution of a +// normal (non-death) test. +TEST_F(PrematureExitTest, PrematureExitFileExistsDuringTestExecution) { + if (*premature_exit_file_path_ == '\0') { + return; + } + + EXPECT_TRUE(PrematureExitFileExists()) + << " file " << premature_exit_file_path_ + << " should exist during test execution, but doesn't."; +} + +} // namespace + +int main(int argc, char **argv) { + InitGoogleTest(&argc, argv); + const int exit_code = RUN_ALL_TESTS(); + + // Test that the premature-exit file is deleted upon return from + // RUN_ALL_TESTS(). + const char* const filepath = GetEnv("TEST_PREMATURE_EXIT_FILE"); + if (filepath != nullptr && *filepath != '\0') { + if (PrematureExitTest::FileExists(filepath)) { + printf( + "File %s shouldn't exist after the test program finishes, but does.", + filepath); + return 1; + } + } + + return exit_code; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_prod_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_prod_test.cc new file mode 100644 index 000000000000..ede81a0d17a0 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_prod_test.cc @@ -0,0 +1,56 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Unit test for gtest_prod.h. + +#include "production.h" +#include "gtest/gtest.h" + +// Tests that private members can be accessed from a TEST declared as +// a friend of the class. +TEST(PrivateCodeTest, CanAccessPrivateMembers) { + PrivateCode a; + EXPECT_EQ(0, a.x_); + + a.set_x(1); + EXPECT_EQ(1, a.x_); +} + +typedef testing::Test PrivateCodeFixtureTest; + +// Tests that private members can be accessed from a TEST_F declared +// as a friend of the class. +TEST_F(PrivateCodeFixtureTest, CanAccessPrivateMembers) { + PrivateCode a; + EXPECT_EQ(0, a.x_); + + a.set_x(2); + EXPECT_EQ(2, a.x_); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_repeat_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_repeat_test.cc new file mode 100644 index 000000000000..6b10048f8303 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_repeat_test.cc @@ -0,0 +1,225 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Tests the --gtest_repeat=number flag. + +#include +#include +#include "gtest/gtest.h" +#include "src/gtest-internal-inl.h" + +namespace { + +// We need this when we are testing Google Test itself and therefore +// cannot use Google Test assertions. +#define GTEST_CHECK_INT_EQ_(expected, actual) \ + do {\ + const int expected_val = (expected);\ + const int actual_val = (actual);\ + if (::testing::internal::IsTrue(expected_val != actual_val)) {\ + ::std::cout << "Value of: " #actual "\n"\ + << " Actual: " << actual_val << "\n"\ + << "Expected: " #expected "\n"\ + << "Which is: " << expected_val << "\n";\ + ::testing::internal::posix::Abort();\ + }\ + } while (::testing::internal::AlwaysFalse()) + + +// Used for verifying that global environment set-up and tear-down are +// inside the --gtest_repeat loop. + +int g_environment_set_up_count = 0; +int g_environment_tear_down_count = 0; + +class MyEnvironment : public testing::Environment { + public: + MyEnvironment() {} + void SetUp() override { g_environment_set_up_count++; } + void TearDown() override { g_environment_tear_down_count++; } +}; + +// A test that should fail. + +int g_should_fail_count = 0; + +TEST(FooTest, ShouldFail) { + g_should_fail_count++; + EXPECT_EQ(0, 1) << "Expected failure."; +} + +// A test that should pass. + +int g_should_pass_count = 0; + +TEST(FooTest, ShouldPass) { + g_should_pass_count++; +} + +// A test that contains a thread-safe death test and a fast death +// test. It should pass. + +int g_death_test_count = 0; + +TEST(BarDeathTest, ThreadSafeAndFast) { + g_death_test_count++; + + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEATH_IF_SUPPORTED(::testing::internal::posix::Abort(), ""); + + GTEST_FLAG_SET(death_test_style, "fast"); + EXPECT_DEATH_IF_SUPPORTED(::testing::internal::posix::Abort(), ""); +} + +int g_param_test_count = 0; + +const int kNumberOfParamTests = 10; + +class MyParamTest : public testing::TestWithParam {}; + +TEST_P(MyParamTest, ShouldPass) { + GTEST_CHECK_INT_EQ_(g_param_test_count % kNumberOfParamTests, GetParam()); + g_param_test_count++; +} +INSTANTIATE_TEST_SUITE_P(MyParamSequence, + MyParamTest, + testing::Range(0, kNumberOfParamTests)); + +// Resets the count for each test. +void ResetCounts() { + g_environment_set_up_count = 0; + g_environment_tear_down_count = 0; + g_should_fail_count = 0; + g_should_pass_count = 0; + g_death_test_count = 0; + g_param_test_count = 0; +} + +// Checks that the count for each test is expected. +void CheckCounts(int expected) { + GTEST_CHECK_INT_EQ_(expected, g_environment_set_up_count); + GTEST_CHECK_INT_EQ_(expected, g_environment_tear_down_count); + GTEST_CHECK_INT_EQ_(expected, g_should_fail_count); + GTEST_CHECK_INT_EQ_(expected, g_should_pass_count); + GTEST_CHECK_INT_EQ_(expected, g_death_test_count); + GTEST_CHECK_INT_EQ_(expected * kNumberOfParamTests, g_param_test_count); +} + +// Tests the behavior of Google Test when --gtest_repeat is not specified. +void TestRepeatUnspecified() { + ResetCounts(); + GTEST_CHECK_INT_EQ_(1, RUN_ALL_TESTS()); + CheckCounts(1); +} + +// Tests the behavior of Google Test when --gtest_repeat has the given value. +void TestRepeat(int repeat) { + GTEST_FLAG_SET(repeat, repeat); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + + ResetCounts(); + GTEST_CHECK_INT_EQ_(repeat > 0 ? 1 : 0, RUN_ALL_TESTS()); + CheckCounts(repeat); +} + +// Tests using --gtest_repeat when --gtest_filter specifies an empty +// set of tests. +void TestRepeatWithEmptyFilter(int repeat) { + GTEST_FLAG_SET(repeat, repeat); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + GTEST_FLAG_SET(filter, "None"); + + ResetCounts(); + GTEST_CHECK_INT_EQ_(0, RUN_ALL_TESTS()); + CheckCounts(0); +} + +// Tests using --gtest_repeat when --gtest_filter specifies a set of +// successful tests. +void TestRepeatWithFilterForSuccessfulTests(int repeat) { + GTEST_FLAG_SET(repeat, repeat); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + GTEST_FLAG_SET(filter, "*-*ShouldFail"); + + ResetCounts(); + GTEST_CHECK_INT_EQ_(0, RUN_ALL_TESTS()); + GTEST_CHECK_INT_EQ_(repeat, g_environment_set_up_count); + GTEST_CHECK_INT_EQ_(repeat, g_environment_tear_down_count); + GTEST_CHECK_INT_EQ_(0, g_should_fail_count); + GTEST_CHECK_INT_EQ_(repeat, g_should_pass_count); + GTEST_CHECK_INT_EQ_(repeat, g_death_test_count); + GTEST_CHECK_INT_EQ_(repeat * kNumberOfParamTests, g_param_test_count); +} + +// Tests using --gtest_repeat when --gtest_filter specifies a set of +// failed tests. +void TestRepeatWithFilterForFailedTests(int repeat) { + GTEST_FLAG_SET(repeat, repeat); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + GTEST_FLAG_SET(filter, "*ShouldFail"); + + ResetCounts(); + GTEST_CHECK_INT_EQ_(1, RUN_ALL_TESTS()); + GTEST_CHECK_INT_EQ_(repeat, g_environment_set_up_count); + GTEST_CHECK_INT_EQ_(repeat, g_environment_tear_down_count); + GTEST_CHECK_INT_EQ_(repeat, g_should_fail_count); + GTEST_CHECK_INT_EQ_(0, g_should_pass_count); + GTEST_CHECK_INT_EQ_(0, g_death_test_count); + GTEST_CHECK_INT_EQ_(0, g_param_test_count); +} + +} // namespace + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + testing::AddGlobalTestEnvironment(new MyEnvironment); + + TestRepeatUnspecified(); + TestRepeat(0); + TestRepeat(1); + TestRepeat(5); + + TestRepeatWithEmptyFilter(2); + TestRepeatWithEmptyFilter(3); + + TestRepeatWithFilterForSuccessfulTests(3); + + TestRepeatWithFilterForFailedTests(4); + + // It would be nice to verify that the tests indeed loop forever + // when GTEST_FLAG(repeat) is negative, but this test will be quite + // complicated to write. Since this flag is for interactive + // debugging only and doesn't affect the normal test result, such a + // test would be an overkill. + + printf("PASS\n"); + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_check_output_test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_check_output_test.py new file mode 100755 index 000000000000..1c87b44f0157 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_check_output_test.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# +# Copyright 2019 Google LLC. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests Google Test's gtest skip in environment setup behavior. + +This script invokes gtest_skip_in_environment_setup_test_ and verifies its +output. +""" + +import re + +from googletest.test import gtest_test_utils + +# Path to the gtest_skip_in_environment_setup_test binary +EXE_PATH = gtest_test_utils.GetTestExecutablePath('gtest_skip_test') + +OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output + + +# Test. +class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): + + def testSkipEntireEnvironmentTest(self): + self.assertIn('Skipped\nskipping single test\n', OUTPUT) + skip_fixture = 'Skipped\nskipping all tests for this fixture\n' + self.assertIsNotNone( + re.search(skip_fixture + '.*' + skip_fixture, OUTPUT, flags=re.DOTALL), + repr(OUTPUT)) + self.assertNotIn('FAILED', OUTPUT) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_environment_check_output_test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_environment_check_output_test.py new file mode 100755 index 000000000000..6960b11a5866 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_environment_check_output_test.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python +# +# Copyright 2019 Google LLC. All Rights Reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Tests Google Test's gtest skip in environment setup behavior. + +This script invokes gtest_skip_in_environment_setup_test_ and verifies its +output. +""" + +from googletest.test import gtest_test_utils + +# Path to the gtest_skip_in_environment_setup_test binary +EXE_PATH = gtest_test_utils.GetTestExecutablePath( + 'gtest_skip_in_environment_setup_test') + +OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output + + +# Test. +class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): + + def testSkipEntireEnvironmentTest(self): + self.assertIn('Skipping the entire environment', OUTPUT) + self.assertNotIn('FAILED', OUTPUT) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_in_environment_setup_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_in_environment_setup_test.cc new file mode 100644 index 000000000000..937231063816 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_in_environment_setup_test.cc @@ -0,0 +1,49 @@ +// Copyright 2019, Google LLC. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google LLC. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// This test verifies that skipping in the environment results in the +// testcases being skipped. + +#include +#include "gtest/gtest.h" + +class SetupEnvironment : public testing::Environment { + public: + void SetUp() override { GTEST_SKIP() << "Skipping the entire environment"; } +}; + +TEST(Test, AlwaysFails) { EXPECT_EQ(true, false); } + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + testing::AddGlobalTestEnvironment(new SetupEnvironment()); + + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_test.cc new file mode 100644 index 000000000000..4a23004cca36 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_skip_test.cc @@ -0,0 +1,55 @@ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: arseny.aprelev@gmail.com (Arseny Aprelev) +// + +#include "gtest/gtest.h" + +using ::testing::Test; + +TEST(SkipTest, DoesSkip) { + GTEST_SKIP() << "skipping single test"; + EXPECT_EQ(0, 1); +} + +class Fixture : public Test { + protected: + void SetUp() override { + GTEST_SKIP() << "skipping all tests for this fixture"; + } +}; + +TEST_F(Fixture, SkipsOneTest) { + EXPECT_EQ(5, 7); +} + +TEST_F(Fixture, SkipsAnotherTest) { + EXPECT_EQ(99, 100); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_sole_header_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_sole_header_test.cc new file mode 100644 index 000000000000..1d94ac6b3ad9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_sole_header_test.cc @@ -0,0 +1,56 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// This test verifies that it's possible to use Google Test by including +// the gtest.h header file alone. + +#include "gtest/gtest.h" + +namespace { + +void Subroutine() { + EXPECT_EQ(42, 42); +} + +TEST(NoFatalFailureTest, ExpectNoFatalFailure) { + EXPECT_NO_FATAL_FAILURE(;); + EXPECT_NO_FATAL_FAILURE(SUCCEED()); + EXPECT_NO_FATAL_FAILURE(Subroutine()); + EXPECT_NO_FATAL_FAILURE({ SUCCEED(); }); +} + +TEST(NoFatalFailureTest, AssertNoFatalFailure) { + ASSERT_NO_FATAL_FAILURE(;); + ASSERT_NO_FATAL_FAILURE(SUCCEED()); + ASSERT_NO_FATAL_FAILURE(Subroutine()); + ASSERT_NO_FATAL_FAILURE({ SUCCEED(); }); +} + +} // namespace diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_stress_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_stress_test.cc new file mode 100644 index 000000000000..843481910f03 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_stress_test.cc @@ -0,0 +1,248 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Tests that SCOPED_TRACE() and various Google Test assertions can be +// used in a large number of threads concurrently. + +#include "gtest/gtest.h" + +#include + +#include "src/gtest-internal-inl.h" + +#if GTEST_IS_THREADSAFE + +namespace testing { +namespace { + +using internal::Notification; +using internal::TestPropertyKeyIs; +using internal::ThreadWithParam; + +// In order to run tests in this file, for platforms where Google Test is +// thread safe, implement ThreadWithParam. See the description of its API +// in gtest-port.h, where it is defined for already supported platforms. + +// How many threads to create? +const int kThreadCount = 50; + +std::string IdToKey(int id, const char* suffix) { + Message key; + key << "key_" << id << "_" << suffix; + return key.GetString(); +} + +std::string IdToString(int id) { + Message id_message; + id_message << id; + return id_message.GetString(); +} + +void ExpectKeyAndValueWereRecordedForId( + const std::vector& properties, + int id, const char* suffix) { + TestPropertyKeyIs matches_key(IdToKey(id, suffix).c_str()); + const std::vector::const_iterator property = + std::find_if(properties.begin(), properties.end(), matches_key); + ASSERT_TRUE(property != properties.end()) + << "expecting " << suffix << " value for id " << id; + EXPECT_STREQ(IdToString(id).c_str(), property->value()); +} + +// Calls a large number of Google Test assertions, where exactly one of them +// will fail. +void ManyAsserts(int id) { + GTEST_LOG_(INFO) << "Thread #" << id << " running..."; + + SCOPED_TRACE(Message() << "Thread #" << id); + + for (int i = 0; i < kThreadCount; i++) { + SCOPED_TRACE(Message() << "Iteration #" << i); + + // A bunch of assertions that should succeed. + EXPECT_TRUE(true); + ASSERT_FALSE(false) << "This shouldn't fail."; + EXPECT_STREQ("a", "a"); + ASSERT_LE(5, 6); + EXPECT_EQ(i, i) << "This shouldn't fail."; + + // RecordProperty() should interact safely with other threads as well. + // The shared_key forces property updates. + Test::RecordProperty(IdToKey(id, "string").c_str(), IdToString(id).c_str()); + Test::RecordProperty(IdToKey(id, "int").c_str(), id); + Test::RecordProperty("shared_key", IdToString(id).c_str()); + + // This assertion should fail kThreadCount times per thread. It + // is for testing whether Google Test can handle failed assertions in a + // multi-threaded context. + EXPECT_LT(i, 0) << "This should always fail."; + } +} + +void CheckTestFailureCount(int expected_failures) { + const TestInfo* const info = UnitTest::GetInstance()->current_test_info(); + const TestResult* const result = info->result(); + GTEST_CHECK_(expected_failures == result->total_part_count()) + << "Logged " << result->total_part_count() << " failures " + << " vs. " << expected_failures << " expected"; +} + +// Tests using SCOPED_TRACE() and Google Test assertions in many threads +// concurrently. +TEST(StressTest, CanUseScopedTraceAndAssertionsInManyThreads) { + { + std::unique_ptr > threads[kThreadCount]; + Notification threads_can_start; + for (int i = 0; i != kThreadCount; i++) + threads[i].reset(new ThreadWithParam(&ManyAsserts, + i, + &threads_can_start)); + + threads_can_start.Notify(); + + // Blocks until all the threads are done. + for (int i = 0; i != kThreadCount; i++) + threads[i]->Join(); + } + + // Ensures that kThreadCount*kThreadCount failures have been reported. + const TestInfo* const info = UnitTest::GetInstance()->current_test_info(); + const TestResult* const result = info->result(); + + std::vector properties; + // We have no access to the TestResult's list of properties but we can + // copy them one by one. + for (int i = 0; i < result->test_property_count(); ++i) + properties.push_back(result->GetTestProperty(i)); + + EXPECT_EQ(kThreadCount * 2 + 1, result->test_property_count()) + << "String and int values recorded on each thread, " + << "as well as one shared_key"; + for (int i = 0; i < kThreadCount; ++i) { + ExpectKeyAndValueWereRecordedForId(properties, i, "string"); + ExpectKeyAndValueWereRecordedForId(properties, i, "int"); + } + CheckTestFailureCount(kThreadCount*kThreadCount); +} + +void FailingThread(bool is_fatal) { + if (is_fatal) + FAIL() << "Fatal failure in some other thread. " + << "(This failure is expected.)"; + else + ADD_FAILURE() << "Non-fatal failure in some other thread. " + << "(This failure is expected.)"; +} + +void GenerateFatalFailureInAnotherThread(bool is_fatal) { + ThreadWithParam thread(&FailingThread, is_fatal, nullptr); + thread.Join(); +} + +TEST(NoFatalFailureTest, ExpectNoFatalFailureIgnoresFailuresInOtherThreads) { + EXPECT_NO_FATAL_FAILURE(GenerateFatalFailureInAnotherThread(true)); + // We should only have one failure (the one from + // GenerateFatalFailureInAnotherThread()), since the EXPECT_NO_FATAL_FAILURE + // should succeed. + CheckTestFailureCount(1); +} + +void AssertNoFatalFailureIgnoresFailuresInOtherThreads() { + ASSERT_NO_FATAL_FAILURE(GenerateFatalFailureInAnotherThread(true)); +} +TEST(NoFatalFailureTest, AssertNoFatalFailureIgnoresFailuresInOtherThreads) { + // Using a subroutine, to make sure, that the test continues. + AssertNoFatalFailureIgnoresFailuresInOtherThreads(); + // We should only have one failure (the one from + // GenerateFatalFailureInAnotherThread()), since the EXPECT_NO_FATAL_FAILURE + // should succeed. + CheckTestFailureCount(1); +} + +TEST(FatalFailureTest, ExpectFatalFailureIgnoresFailuresInOtherThreads) { + // This statement should fail, since the current thread doesn't generate a + // fatal failure, only another one does. + EXPECT_FATAL_FAILURE(GenerateFatalFailureInAnotherThread(true), "expected"); + CheckTestFailureCount(2); +} + +TEST(FatalFailureOnAllThreadsTest, ExpectFatalFailureOnAllThreads) { + // This statement should succeed, because failures in all threads are + // considered. + EXPECT_FATAL_FAILURE_ON_ALL_THREADS( + GenerateFatalFailureInAnotherThread(true), "expected"); + CheckTestFailureCount(0); + // We need to add a failure, because main() checks that there are failures. + // But when only this test is run, we shouldn't have any failures. + ADD_FAILURE() << "This is an expected non-fatal failure."; +} + +TEST(NonFatalFailureTest, ExpectNonFatalFailureIgnoresFailuresInOtherThreads) { + // This statement should fail, since the current thread doesn't generate a + // fatal failure, only another one does. + EXPECT_NONFATAL_FAILURE(GenerateFatalFailureInAnotherThread(false), + "expected"); + CheckTestFailureCount(2); +} + +TEST(NonFatalFailureOnAllThreadsTest, ExpectNonFatalFailureOnAllThreads) { + // This statement should succeed, because failures in all threads are + // considered. + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS( + GenerateFatalFailureInAnotherThread(false), "expected"); + CheckTestFailureCount(0); + // We need to add a failure, because main() checks that there are failures, + // But when only this test is run, we shouldn't have any failures. + ADD_FAILURE() << "This is an expected non-fatal failure."; +} + +} // namespace +} // namespace testing + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + const int result = RUN_ALL_TESTS(); // Expected to fail. + GTEST_CHECK_(result == 1) << "RUN_ALL_TESTS() did not fail as expected"; + + printf("\nPASS\n"); + return 0; +} + +#else +TEST(StressTest, + DISABLED_ThreadSafetyTestsAreSkippedWhenGoogleTestIsNotThreadSafe) { +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +#endif // GTEST_IS_THREADSAFE diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_macro_stack_footprint_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_macro_stack_footprint_test.cc new file mode 100644 index 000000000000..a48db05012cc --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_macro_stack_footprint_test.cc @@ -0,0 +1,89 @@ +// Copyright 2013, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Each TEST() expands to some static registration logic. GCC puts all +// such static initialization logic for a translation unit in a common, +// internal function. Since Google's build system restricts how much +// stack space a function can use, there's a limit on how many TEST()s +// one can put in a single C++ test file. This test ensures that a large +// number of TEST()s can be defined in the same translation unit. + +#include "gtest/gtest.h" + +// This macro defines 10 dummy tests. +#define TEN_TESTS_(test_case_name) \ + TEST(test_case_name, T0) {} \ + TEST(test_case_name, T1) {} \ + TEST(test_case_name, T2) {} \ + TEST(test_case_name, T3) {} \ + TEST(test_case_name, T4) {} \ + TEST(test_case_name, T5) {} \ + TEST(test_case_name, T6) {} \ + TEST(test_case_name, T7) {} \ + TEST(test_case_name, T8) {} \ + TEST(test_case_name, T9) {} + +// This macro defines 100 dummy tests. +#define HUNDRED_TESTS_(test_case_name_prefix) \ + TEN_TESTS_(test_case_name_prefix ## 0) \ + TEN_TESTS_(test_case_name_prefix ## 1) \ + TEN_TESTS_(test_case_name_prefix ## 2) \ + TEN_TESTS_(test_case_name_prefix ## 3) \ + TEN_TESTS_(test_case_name_prefix ## 4) \ + TEN_TESTS_(test_case_name_prefix ## 5) \ + TEN_TESTS_(test_case_name_prefix ## 6) \ + TEN_TESTS_(test_case_name_prefix ## 7) \ + TEN_TESTS_(test_case_name_prefix ## 8) \ + TEN_TESTS_(test_case_name_prefix ## 9) + +// This macro defines 1000 dummy tests. +#define THOUSAND_TESTS_(test_case_name_prefix) \ + HUNDRED_TESTS_(test_case_name_prefix ## 0) \ + HUNDRED_TESTS_(test_case_name_prefix ## 1) \ + HUNDRED_TESTS_(test_case_name_prefix ## 2) \ + HUNDRED_TESTS_(test_case_name_prefix ## 3) \ + HUNDRED_TESTS_(test_case_name_prefix ## 4) \ + HUNDRED_TESTS_(test_case_name_prefix ## 5) \ + HUNDRED_TESTS_(test_case_name_prefix ## 6) \ + HUNDRED_TESTS_(test_case_name_prefix ## 7) \ + HUNDRED_TESTS_(test_case_name_prefix ## 8) \ + HUNDRED_TESTS_(test_case_name_prefix ## 9) + +// Ensures that we can define 1000 TEST()s in the same translation +// unit. +THOUSAND_TESTS_(T) + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + + // We don't actually need to run the dummy tests - the purpose is to + // ensure that they compile. + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_utils.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_utils.py new file mode 100755 index 000000000000..eecc53346c27 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_test_utils.py @@ -0,0 +1,255 @@ +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test utilities for Google C++ Testing and Mocking Framework.""" +# Suppresses the 'Import not at the top of the file' lint complaint. +# pylint: disable-msg=C6204 + +import os +import subprocess +import sys + +IS_WINDOWS = os.name == 'nt' +IS_CYGWIN = os.name == 'posix' and 'CYGWIN' in os.uname()[0] +IS_OS2 = os.name == 'os2' + +import atexit +import shutil +import tempfile +import unittest as _test_module +# pylint: enable-msg=C6204 + +GTEST_OUTPUT_VAR_NAME = 'GTEST_OUTPUT' + +# The environment variable for specifying the path to the premature-exit file. +PREMATURE_EXIT_FILE_ENV_VAR = 'TEST_PREMATURE_EXIT_FILE' + +environ = os.environ.copy() + + +def SetEnvVar(env_var, value): + """Sets/unsets an environment variable to a given value.""" + + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] + + +# Here we expose a class from a particular module, depending on the +# environment. The comment suppresses the 'Invalid variable name' lint +# complaint. +TestCase = _test_module.TestCase # pylint: disable=C6409 + +# Initially maps a flag to its default value. After +# _ParseAndStripGTestFlags() is called, maps a flag to its actual value. +_flag_map = {'source_dir': os.path.dirname(sys.argv[0]), + 'build_dir': os.path.dirname(sys.argv[0])} +_gtest_flags_are_parsed = False + + +def _ParseAndStripGTestFlags(argv): + """Parses and strips Google Test flags from argv. This is idempotent.""" + + # Suppresses the lint complaint about a global variable since we need it + # here to maintain module-wide state. + global _gtest_flags_are_parsed # pylint: disable=W0603 + if _gtest_flags_are_parsed: + return + + _gtest_flags_are_parsed = True + for flag in _flag_map: + # The environment variable overrides the default value. + if flag.upper() in os.environ: + _flag_map[flag] = os.environ[flag.upper()] + + # The command line flag overrides the environment variable. + i = 1 # Skips the program name. + while i < len(argv): + prefix = '--' + flag + '=' + if argv[i].startswith(prefix): + _flag_map[flag] = argv[i][len(prefix):] + del argv[i] + break + else: + # We don't increment i in case we just found a --gtest_* flag + # and removed it from argv. + i += 1 + + +def GetFlag(flag): + """Returns the value of the given flag.""" + + # In case GetFlag() is called before Main(), we always call + # _ParseAndStripGTestFlags() here to make sure the --gtest_* flags + # are parsed. + _ParseAndStripGTestFlags(sys.argv) + + return _flag_map[flag] + + +def GetSourceDir(): + """Returns the absolute path of the directory where the .py files are.""" + + return os.path.abspath(GetFlag('source_dir')) + + +def GetBuildDir(): + """Returns the absolute path of the directory where the test binaries are.""" + + return os.path.abspath(GetFlag('build_dir')) + + +_temp_dir = None + +def _RemoveTempDir(): + if _temp_dir: + shutil.rmtree(_temp_dir, ignore_errors=True) + +atexit.register(_RemoveTempDir) + + +def GetTempDir(): + global _temp_dir + if not _temp_dir: + _temp_dir = tempfile.mkdtemp() + return _temp_dir + + +def GetTestExecutablePath(executable_name, build_dir=None): + """Returns the absolute path of the test binary given its name. + + The function will print a message and abort the program if the resulting file + doesn't exist. + + Args: + executable_name: name of the test binary that the test script runs. + build_dir: directory where to look for executables, by default + the result of GetBuildDir(). + + Returns: + The absolute path of the test binary. + """ + + path = os.path.abspath(os.path.join(build_dir or GetBuildDir(), + executable_name)) + if (IS_WINDOWS or IS_CYGWIN or IS_OS2) and not path.endswith('.exe'): + path += '.exe' + + if not os.path.exists(path): + message = ( + 'Unable to find the test binary "%s". Please make sure to provide\n' + 'a path to the binary via the --build_dir flag or the BUILD_DIR\n' + 'environment variable.' % path) + print(message, file=sys.stderr) + sys.exit(1) + + return path + + +def GetExitStatus(exit_code): + """Returns the argument to exit(), or -1 if exit() wasn't called. + + Args: + exit_code: the result value of os.system(command). + """ + + if os.name == 'nt': + # On Windows, os.WEXITSTATUS() doesn't work and os.system() returns + # the argument to exit() directly. + return exit_code + else: + # On Unix, os.WEXITSTATUS() must be used to extract the exit status + # from the result of os.system(). + if os.WIFEXITED(exit_code): + return os.WEXITSTATUS(exit_code) + else: + return -1 + + +class Subprocess: + def __init__(self, command, working_dir=None, capture_stderr=True, env=None): + """Changes into a specified directory, if provided, and executes a command. + + Restores the old directory afterwards. + + Args: + command: The command to run, in the form of sys.argv. + working_dir: The directory to change into. + capture_stderr: Determines whether to capture stderr in the output member + or to discard it. + env: Dictionary with environment to pass to the subprocess. + + Returns: + An object that represents outcome of the executed process. It has the + following attributes: + terminated_by_signal True if and only if the child process has been + terminated by a signal. + exited True if and only if the child process exited + normally. + exit_code The code with which the child process exited. + output Child process's stdout and stderr output + combined in a string. + """ + + if capture_stderr: + stderr = subprocess.STDOUT + else: + stderr = subprocess.PIPE + + p = subprocess.Popen(command, + stdout=subprocess.PIPE, stderr=stderr, + cwd=working_dir, universal_newlines=True, env=env) + # communicate returns a tuple with the file object for the child's + # output. + self.output = p.communicate()[0] + self._return_code = p.returncode + + if bool(self._return_code & 0x80000000): + self.terminated_by_signal = True + self.exited = False + else: + self.terminated_by_signal = False + self.exited = True + self.exit_code = self._return_code + + +def Main(): + """Runs the unit test.""" + + # We must call _ParseAndStripGTestFlags() before calling + # unittest.main(). Otherwise the latter will be confused by the + # --gtest_* flags. + _ParseAndStripGTestFlags(sys.argv) + # The tested binaries should not be writing XML output files unless the + # script explicitly instructs them to. + if GTEST_OUTPUT_VAR_NAME in os.environ: + del os.environ[GTEST_OUTPUT_VAR_NAME] + + _test_module.main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test.py new file mode 100755 index 000000000000..1c2a303a8887 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# +# Copyright 2018 Google LLC. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Verifies that Google Test uses filter provided via testbridge.""" + +import os + +from googletest.test import gtest_test_utils + +binary_name = 'gtest_testbridge_test_' +COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) +TESTBRIDGE_NAME = 'TESTBRIDGE_TEST_ONLY' + + +def Assert(condition): + if not condition: + raise AssertionError + + +class GTestTestFilterTest(gtest_test_utils.TestCase): + + def testTestExecutionIsFiltered(self): + """Tests that the test filter is picked up from the testbridge env var.""" + subprocess_env = os.environ.copy() + + subprocess_env[TESTBRIDGE_NAME] = '*.TestThatSucceeds' + p = gtest_test_utils.Subprocess(COMMAND, env=subprocess_env) + + self.assertEquals(0, p.exit_code) + + Assert('filter = *.TestThatSucceeds' in p.output) + Assert('[ OK ] TestFilterTest.TestThatSucceeds' in p.output) + Assert('[ PASSED ] 1 test.' in p.output) + + +if __name__ == '__main__': + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test_.cc new file mode 100644 index 000000000000..24617b209e10 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_testbridge_test_.cc @@ -0,0 +1,43 @@ +// Copyright 2018, Google LLC. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// This program is meant to be run by gtest_test_filter_test.py. Do not run +// it directly. + +#include "gtest/gtest.h" + +// These tests are used to detect if filtering is working. Only +// 'TestThatSucceeds' should ever run. + +TEST(TestFilterTest, TestThatSucceeds) {} + +TEST(TestFilterTest, TestThatFails) { + ASSERT_TRUE(false) << "This test should never be run."; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_throw_on_failure_ex_test.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_throw_on_failure_ex_test.cc new file mode 100644 index 000000000000..aeead13feb56 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_throw_on_failure_ex_test.cc @@ -0,0 +1,90 @@ +// Copyright 2009, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Tests Google Test's throw-on-failure mode with exceptions enabled. + +#include "gtest/gtest.h" + +#include +#include +#include +#include + +// Prints the given failure message and exits the program with +// non-zero. We use this instead of a Google Test assertion to +// indicate a failure, as the latter is been tested and cannot be +// relied on. +void Fail(const char* msg) { + printf("FAILURE: %s\n", msg); + fflush(stdout); + exit(1); +} + +// Tests that an assertion failure throws a subclass of +// std::runtime_error. +void TestFailureThrowsRuntimeError() { + GTEST_FLAG_SET(throw_on_failure, true); + + // A successful assertion shouldn't throw. + try { + EXPECT_EQ(3, 3); + } catch(...) { + Fail("A successful assertion wrongfully threw."); + } + + // A failed assertion should throw a subclass of std::runtime_error. + try { + EXPECT_EQ(2, 3) << "Expected failure"; + } catch(const std::runtime_error& e) { + if (strstr(e.what(), "Expected failure") != nullptr) return; + + printf("%s", + "A failed assertion did throw an exception of the right type, " + "but the message is incorrect. Instead of containing \"Expected " + "failure\", it is:\n"); + Fail(e.what()); + } catch(...) { + Fail("A failed assertion threw the wrong type of exception."); + } + Fail("A failed assertion should've thrown but didn't."); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + + // We want to ensure that people can use Google Test assertions in + // other testing frameworks, as long as they initialize Google Test + // properly and set the thrown-on-failure mode. Therefore, we don't + // use Google Test's constructs for defining and running tests + // (e.g. TEST and RUN_ALL_TESTS) here. + + TestFailureThrowsRuntimeError(); + return 0; +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_unittest.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_unittest.cc new file mode 100644 index 000000000000..0f10df551bb2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_unittest.cc @@ -0,0 +1,7833 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// Tests for Google Test itself. This verifies that the basic constructs of +// Google Test work. + +#include "gtest/gtest.h" + +// Verifies that the command line flag variables can be accessed in +// code once "gtest.h" has been #included. +// Do not move it after other gtest #includes. +TEST(CommandLineFlagsTest, CanBeAccessedInCodeOnceGTestHIsIncluded) { + bool dummy = + GTEST_FLAG_GET(also_run_disabled_tests) || + GTEST_FLAG_GET(break_on_failure) || GTEST_FLAG_GET(catch_exceptions) || + GTEST_FLAG_GET(color) != "unknown" || GTEST_FLAG_GET(fail_fast) || + GTEST_FLAG_GET(filter) != "unknown" || GTEST_FLAG_GET(list_tests) || + GTEST_FLAG_GET(output) != "unknown" || GTEST_FLAG_GET(brief) || + GTEST_FLAG_GET(print_time) || GTEST_FLAG_GET(random_seed) || + GTEST_FLAG_GET(repeat) > 0 || + GTEST_FLAG_GET(recreate_environments_when_repeating) || + GTEST_FLAG_GET(show_internal_stack_frames) || GTEST_FLAG_GET(shuffle) || + GTEST_FLAG_GET(stack_trace_depth) > 0 || + GTEST_FLAG_GET(stream_result_to) != "unknown" || + GTEST_FLAG_GET(throw_on_failure); + EXPECT_TRUE(dummy || !dummy); // Suppresses warning that dummy is unused. +} + +#include // For INT_MAX. +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "gtest/gtest-spi.h" +#include "src/gtest-internal-inl.h" + +namespace testing { +namespace internal { + +#if GTEST_CAN_STREAM_RESULTS_ + +class StreamingListenerTest : public Test { + public: + class FakeSocketWriter : public StreamingListener::AbstractSocketWriter { + public: + // Sends a string to the socket. + void Send(const std::string& message) override { output_ += message; } + + std::string output_; + }; + + StreamingListenerTest() + : fake_sock_writer_(new FakeSocketWriter), + streamer_(fake_sock_writer_), + test_info_obj_("FooTest", "Bar", nullptr, nullptr, + CodeLocation(__FILE__, __LINE__), nullptr, nullptr) {} + + protected: + std::string* output() { return &(fake_sock_writer_->output_); } + + FakeSocketWriter* const fake_sock_writer_; + StreamingListener streamer_; + UnitTest unit_test_; + TestInfo test_info_obj_; // The name test_info_ was taken by testing::Test. +}; + +TEST_F(StreamingListenerTest, OnTestProgramEnd) { + *output() = ""; + streamer_.OnTestProgramEnd(unit_test_); + EXPECT_EQ("event=TestProgramEnd&passed=1\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestIterationEnd) { + *output() = ""; + streamer_.OnTestIterationEnd(unit_test_, 42); + EXPECT_EQ("event=TestIterationEnd&passed=1&elapsed_time=0ms\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestSuiteStart) { + *output() = ""; + streamer_.OnTestSuiteStart(TestSuite("FooTest", "Bar", nullptr, nullptr)); + EXPECT_EQ("event=TestCaseStart&name=FooTest\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestSuiteEnd) { + *output() = ""; + streamer_.OnTestSuiteEnd(TestSuite("FooTest", "Bar", nullptr, nullptr)); + EXPECT_EQ("event=TestCaseEnd&passed=1&elapsed_time=0ms\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestStart) { + *output() = ""; + streamer_.OnTestStart(test_info_obj_); + EXPECT_EQ("event=TestStart&name=Bar\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestEnd) { + *output() = ""; + streamer_.OnTestEnd(test_info_obj_); + EXPECT_EQ("event=TestEnd&passed=1&elapsed_time=0ms\n", *output()); +} + +TEST_F(StreamingListenerTest, OnTestPartResult) { + *output() = ""; + streamer_.OnTestPartResult(TestPartResult( + TestPartResult::kFatalFailure, "foo.cc", 42, "failed=\n&%")); + + // Meta characters in the failure message should be properly escaped. + EXPECT_EQ( + "event=TestPartResult&file=foo.cc&line=42&message=failed%3D%0A%26%25\n", + *output()); +} + +#endif // GTEST_CAN_STREAM_RESULTS_ + +// Provides access to otherwise private parts of the TestEventListeners class +// that are needed to test it. +class TestEventListenersAccessor { + public: + static TestEventListener* GetRepeater(TestEventListeners* listeners) { + return listeners->repeater(); + } + + static void SetDefaultResultPrinter(TestEventListeners* listeners, + TestEventListener* listener) { + listeners->SetDefaultResultPrinter(listener); + } + static void SetDefaultXmlGenerator(TestEventListeners* listeners, + TestEventListener* listener) { + listeners->SetDefaultXmlGenerator(listener); + } + + static bool EventForwardingEnabled(const TestEventListeners& listeners) { + return listeners.EventForwardingEnabled(); + } + + static void SuppressEventForwarding(TestEventListeners* listeners) { + listeners->SuppressEventForwarding(); + } +}; + +class UnitTestRecordPropertyTestHelper : public Test { + protected: + UnitTestRecordPropertyTestHelper() {} + + // Forwards to UnitTest::RecordProperty() to bypass access controls. + void UnitTestRecordProperty(const char* key, const std::string& value) { + unit_test_.RecordProperty(key, value); + } + + UnitTest unit_test_; +}; + +} // namespace internal +} // namespace testing + +using testing::AssertionFailure; +using testing::AssertionResult; +using testing::AssertionSuccess; +using testing::DoubleLE; +using testing::EmptyTestEventListener; +using testing::Environment; +using testing::FloatLE; +using testing::IsNotSubstring; +using testing::IsSubstring; +using testing::kMaxStackTraceDepth; +using testing::Message; +using testing::ScopedFakeTestPartResultReporter; +using testing::StaticAssertTypeEq; +using testing::Test; +using testing::TestEventListeners; +using testing::TestInfo; +using testing::TestPartResult; +using testing::TestPartResultArray; +using testing::TestProperty; +using testing::TestResult; +using testing::TestSuite; +using testing::TimeInMillis; +using testing::UnitTest; +using testing::internal::AlwaysFalse; +using testing::internal::AlwaysTrue; +using testing::internal::AppendUserMessage; +using testing::internal::ArrayAwareFind; +using testing::internal::ArrayEq; +using testing::internal::CodePointToUtf8; +using testing::internal::CopyArray; +using testing::internal::CountIf; +using testing::internal::EqFailure; +using testing::internal::FloatingPoint; +using testing::internal::ForEach; +using testing::internal::FormatEpochTimeInMillisAsIso8601; +using testing::internal::FormatTimeInMillisAsSeconds; +using testing::internal::GetCurrentOsStackTraceExceptTop; +using testing::internal::GetElementOr; +using testing::internal::GetNextRandomSeed; +using testing::internal::GetRandomSeedFromFlag; +using testing::internal::GetTestTypeId; +using testing::internal::GetTimeInMillis; +using testing::internal::GetTypeId; +using testing::internal::GetUnitTestImpl; +using testing::internal::GTestFlagSaver; +using testing::internal::HasDebugStringAndShortDebugString; +using testing::internal::Int32FromEnvOrDie; +using testing::internal::IsContainer; +using testing::internal::IsContainerTest; +using testing::internal::IsNotContainer; +using testing::internal::kMaxRandomSeed; +using testing::internal::kTestTypeIdInGoogleTest; +using testing::internal::NativeArray; +using testing::internal::OsStackTraceGetter; +using testing::internal::OsStackTraceGetterInterface; +using testing::internal::ParseFlag; +using testing::internal::RelationToSourceCopy; +using testing::internal::RelationToSourceReference; +using testing::internal::ShouldRunTestOnShard; +using testing::internal::ShouldShard; +using testing::internal::ShouldUseColor; +using testing::internal::Shuffle; +using testing::internal::ShuffleRange; +using testing::internal::SkipPrefix; +using testing::internal::StreamableToString; +using testing::internal::String; +using testing::internal::TestEventListenersAccessor; +using testing::internal::TestResultAccessor; +using testing::internal::UnitTestImpl; +using testing::internal::WideStringToUtf8; +using testing::internal::edit_distance::CalculateOptimalEdits; +using testing::internal::edit_distance::CreateUnifiedDiff; +using testing::internal::edit_distance::EditType; + +#if GTEST_HAS_STREAM_REDIRECTION +using testing::internal::CaptureStdout; +using testing::internal::GetCapturedStdout; +#endif + +#if GTEST_IS_THREADSAFE +using testing::internal::ThreadWithParam; +#endif + +class TestingVector : public std::vector { +}; + +::std::ostream& operator<<(::std::ostream& os, + const TestingVector& vector) { + os << "{ "; + for (size_t i = 0; i < vector.size(); i++) { + os << vector[i] << " "; + } + os << "}"; + return os; +} + +// This line tests that we can define tests in an unnamed namespace. +namespace { + +TEST(GetRandomSeedFromFlagTest, HandlesZero) { + const int seed = GetRandomSeedFromFlag(0); + EXPECT_LE(1, seed); + EXPECT_LE(seed, static_cast(kMaxRandomSeed)); +} + +TEST(GetRandomSeedFromFlagTest, PreservesValidSeed) { + EXPECT_EQ(1, GetRandomSeedFromFlag(1)); + EXPECT_EQ(2, GetRandomSeedFromFlag(2)); + EXPECT_EQ(kMaxRandomSeed - 1, GetRandomSeedFromFlag(kMaxRandomSeed - 1)); + EXPECT_EQ(static_cast(kMaxRandomSeed), + GetRandomSeedFromFlag(kMaxRandomSeed)); +} + +TEST(GetRandomSeedFromFlagTest, NormalizesInvalidSeed) { + const int seed1 = GetRandomSeedFromFlag(-1); + EXPECT_LE(1, seed1); + EXPECT_LE(seed1, static_cast(kMaxRandomSeed)); + + const int seed2 = GetRandomSeedFromFlag(kMaxRandomSeed + 1); + EXPECT_LE(1, seed2); + EXPECT_LE(seed2, static_cast(kMaxRandomSeed)); +} + +TEST(GetNextRandomSeedTest, WorksForValidInput) { + EXPECT_EQ(2, GetNextRandomSeed(1)); + EXPECT_EQ(3, GetNextRandomSeed(2)); + EXPECT_EQ(static_cast(kMaxRandomSeed), + GetNextRandomSeed(kMaxRandomSeed - 1)); + EXPECT_EQ(1, GetNextRandomSeed(kMaxRandomSeed)); + + // We deliberately don't test GetNextRandomSeed() with invalid + // inputs, as that requires death tests, which are expensive. This + // is fine as GetNextRandomSeed() is internal and has a + // straightforward definition. +} + +static void ClearCurrentTestPartResults() { + TestResultAccessor::ClearTestPartResults( + GetUnitTestImpl()->current_test_result()); +} + +// Tests GetTypeId. + +TEST(GetTypeIdTest, ReturnsSameValueForSameType) { + EXPECT_EQ(GetTypeId(), GetTypeId()); + EXPECT_EQ(GetTypeId(), GetTypeId()); +} + +class SubClassOfTest : public Test {}; +class AnotherSubClassOfTest : public Test {}; + +TEST(GetTypeIdTest, ReturnsDifferentValuesForDifferentTypes) { + EXPECT_NE(GetTypeId(), GetTypeId()); + EXPECT_NE(GetTypeId(), GetTypeId()); + EXPECT_NE(GetTypeId(), GetTestTypeId()); + EXPECT_NE(GetTypeId(), GetTestTypeId()); + EXPECT_NE(GetTypeId(), GetTestTypeId()); + EXPECT_NE(GetTypeId(), GetTypeId()); +} + +// Verifies that GetTestTypeId() returns the same value, no matter it +// is called from inside Google Test or outside of it. +TEST(GetTestTypeIdTest, ReturnsTheSameValueInsideOrOutsideOfGoogleTest) { + EXPECT_EQ(kTestTypeIdInGoogleTest, GetTestTypeId()); +} + +// Tests CanonicalizeForStdLibVersioning. + +using ::testing::internal::CanonicalizeForStdLibVersioning; + +TEST(CanonicalizeForStdLibVersioning, LeavesUnversionedNamesUnchanged) { + EXPECT_EQ("std::bind", CanonicalizeForStdLibVersioning("std::bind")); + EXPECT_EQ("std::_", CanonicalizeForStdLibVersioning("std::_")); + EXPECT_EQ("std::__foo", CanonicalizeForStdLibVersioning("std::__foo")); + EXPECT_EQ("gtl::__1::x", CanonicalizeForStdLibVersioning("gtl::__1::x")); + EXPECT_EQ("__1::x", CanonicalizeForStdLibVersioning("__1::x")); + EXPECT_EQ("::__1::x", CanonicalizeForStdLibVersioning("::__1::x")); +} + +TEST(CanonicalizeForStdLibVersioning, ElidesDoubleUnderNames) { + EXPECT_EQ("std::bind", CanonicalizeForStdLibVersioning("std::__1::bind")); + EXPECT_EQ("std::_", CanonicalizeForStdLibVersioning("std::__1::_")); + + EXPECT_EQ("std::bind", CanonicalizeForStdLibVersioning("std::__g::bind")); + EXPECT_EQ("std::_", CanonicalizeForStdLibVersioning("std::__g::_")); + + EXPECT_EQ("std::bind", + CanonicalizeForStdLibVersioning("std::__google::bind")); + EXPECT_EQ("std::_", CanonicalizeForStdLibVersioning("std::__google::_")); +} + +// Tests FormatTimeInMillisAsSeconds(). + +TEST(FormatTimeInMillisAsSecondsTest, FormatsZero) { + EXPECT_EQ("0", FormatTimeInMillisAsSeconds(0)); +} + +TEST(FormatTimeInMillisAsSecondsTest, FormatsPositiveNumber) { + EXPECT_EQ("0.003", FormatTimeInMillisAsSeconds(3)); + EXPECT_EQ("0.01", FormatTimeInMillisAsSeconds(10)); + EXPECT_EQ("0.2", FormatTimeInMillisAsSeconds(200)); + EXPECT_EQ("1.2", FormatTimeInMillisAsSeconds(1200)); + EXPECT_EQ("3", FormatTimeInMillisAsSeconds(3000)); +} + +TEST(FormatTimeInMillisAsSecondsTest, FormatsNegativeNumber) { + EXPECT_EQ("-0.003", FormatTimeInMillisAsSeconds(-3)); + EXPECT_EQ("-0.01", FormatTimeInMillisAsSeconds(-10)); + EXPECT_EQ("-0.2", FormatTimeInMillisAsSeconds(-200)); + EXPECT_EQ("-1.2", FormatTimeInMillisAsSeconds(-1200)); + EXPECT_EQ("-3", FormatTimeInMillisAsSeconds(-3000)); +} + +// Tests FormatEpochTimeInMillisAsIso8601(). The correctness of conversion +// for particular dates below was verified in Python using +// datetime.datetime.fromutctimestamp(/1000). + +// FormatEpochTimeInMillisAsIso8601 depends on the current timezone, so we +// have to set up a particular timezone to obtain predictable results. +class FormatEpochTimeInMillisAsIso8601Test : public Test { + public: + // On Cygwin, GCC doesn't allow unqualified integer literals to exceed + // 32 bits, even when 64-bit integer types are available. We have to + // force the constants to have a 64-bit type here. + static const TimeInMillis kMillisPerSec = 1000; + + private: + void SetUp() override { + saved_tz_ = nullptr; + + GTEST_DISABLE_MSC_DEPRECATED_PUSH_(/* getenv, strdup: deprecated */) + if (getenv("TZ")) + saved_tz_ = strdup(getenv("TZ")); + GTEST_DISABLE_MSC_DEPRECATED_POP_() + + // Set up the time zone for FormatEpochTimeInMillisAsIso8601 to use. We + // cannot use the local time zone because the function's output depends + // on the time zone. + SetTimeZone("UTC+00"); + } + + void TearDown() override { + SetTimeZone(saved_tz_); + free(const_cast(saved_tz_)); + saved_tz_ = nullptr; + } + + static void SetTimeZone(const char* time_zone) { + // tzset() distinguishes between the TZ variable being present and empty + // and not being present, so we have to consider the case of time_zone + // being NULL. +#if _MSC_VER || GTEST_OS_WINDOWS_MINGW + // ...Unless it's MSVC, whose standard library's _putenv doesn't + // distinguish between an empty and a missing variable. + const std::string env_var = + std::string("TZ=") + (time_zone ? time_zone : ""); + _putenv(env_var.c_str()); + GTEST_DISABLE_MSC_WARNINGS_PUSH_(4996 /* deprecated function */) + tzset(); + GTEST_DISABLE_MSC_WARNINGS_POP_() +#else +#if GTEST_OS_LINUX_ANDROID && __ANDROID_API__ < 21 + // Work around KitKat bug in tzset by setting "UTC" before setting "UTC+00". + // See https://github.com/android/ndk/issues/1604. + setenv("TZ", "UTC", 1); + tzset(); +#endif + if (time_zone) { + setenv(("TZ"), time_zone, 1); + } else { + unsetenv("TZ"); + } + tzset(); +#endif + } + + const char* saved_tz_; +}; + +const TimeInMillis FormatEpochTimeInMillisAsIso8601Test::kMillisPerSec; + +TEST_F(FormatEpochTimeInMillisAsIso8601Test, PrintsTwoDigitSegments) { + EXPECT_EQ("2011-10-31T18:52:42.000", + FormatEpochTimeInMillisAsIso8601(1320087162 * kMillisPerSec)); +} + +TEST_F(FormatEpochTimeInMillisAsIso8601Test, IncludesMillisecondsAfterDot) { + EXPECT_EQ( + "2011-10-31T18:52:42.234", + FormatEpochTimeInMillisAsIso8601(1320087162 * kMillisPerSec + 234)); +} + +TEST_F(FormatEpochTimeInMillisAsIso8601Test, PrintsLeadingZeroes) { + EXPECT_EQ("2011-09-03T05:07:02.000", + FormatEpochTimeInMillisAsIso8601(1315026422 * kMillisPerSec)); +} + +TEST_F(FormatEpochTimeInMillisAsIso8601Test, Prints24HourTime) { + EXPECT_EQ("2011-09-28T17:08:22.000", + FormatEpochTimeInMillisAsIso8601(1317229702 * kMillisPerSec)); +} + +TEST_F(FormatEpochTimeInMillisAsIso8601Test, PrintsEpochStart) { + EXPECT_EQ("1970-01-01T00:00:00.000", FormatEpochTimeInMillisAsIso8601(0)); +} + +# ifdef __BORLANDC__ +// Silences warnings: "Condition is always true", "Unreachable code" +# pragma option push -w-ccc -w-rch +# endif + +// Tests that the LHS of EXPECT_EQ or ASSERT_EQ can be used as a null literal +// when the RHS is a pointer type. +TEST(NullLiteralTest, LHSAllowsNullLiterals) { + EXPECT_EQ(0, static_cast(nullptr)); // NOLINT + ASSERT_EQ(0, static_cast(nullptr)); // NOLINT + EXPECT_EQ(NULL, static_cast(nullptr)); // NOLINT + ASSERT_EQ(NULL, static_cast(nullptr)); // NOLINT + EXPECT_EQ(nullptr, static_cast(nullptr)); + ASSERT_EQ(nullptr, static_cast(nullptr)); + + const int* const p = nullptr; + EXPECT_EQ(0, p); // NOLINT + ASSERT_EQ(0, p); // NOLINT + EXPECT_EQ(NULL, p); // NOLINT + ASSERT_EQ(NULL, p); // NOLINT + EXPECT_EQ(nullptr, p); + ASSERT_EQ(nullptr, p); +} + +struct ConvertToAll { + template + operator T() const { // NOLINT + return T(); + } +}; + +struct ConvertToPointer { + template + operator T*() const { // NOLINT + return nullptr; + } +}; + +struct ConvertToAllButNoPointers { + template ::value, int>::type = 0> + operator T() const { // NOLINT + return T(); + } +}; + +struct MyType {}; +inline bool operator==(MyType const&, MyType const&) { return true; } + +TEST(NullLiteralTest, ImplicitConversion) { + EXPECT_EQ(ConvertToPointer{}, static_cast(nullptr)); +#if !defined(__GNUC__) || defined(__clang__) + // Disabled due to GCC bug gcc.gnu.org/PR89580 + EXPECT_EQ(ConvertToAll{}, static_cast(nullptr)); +#endif + EXPECT_EQ(ConvertToAll{}, MyType{}); + EXPECT_EQ(ConvertToAllButNoPointers{}, MyType{}); +} + +#ifdef __clang__ +#pragma clang diagnostic push +#if __has_warning("-Wzero-as-null-pointer-constant") +#pragma clang diagnostic error "-Wzero-as-null-pointer-constant" +#endif +#endif + +TEST(NullLiteralTest, NoConversionNoWarning) { + // Test that gtests detection and handling of null pointer constants + // doesn't trigger a warning when '0' isn't actually used as null. + EXPECT_EQ(0, 0); + ASSERT_EQ(0, 0); +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +# ifdef __BORLANDC__ +// Restores warnings after previous "#pragma option push" suppressed them. +# pragma option pop +# endif + +// +// Tests CodePointToUtf8(). + +// Tests that the NUL character L'\0' is encoded correctly. +TEST(CodePointToUtf8Test, CanEncodeNul) { + EXPECT_EQ("", CodePointToUtf8(L'\0')); +} + +// Tests that ASCII characters are encoded correctly. +TEST(CodePointToUtf8Test, CanEncodeAscii) { + EXPECT_EQ("a", CodePointToUtf8(L'a')); + EXPECT_EQ("Z", CodePointToUtf8(L'Z')); + EXPECT_EQ("&", CodePointToUtf8(L'&')); + EXPECT_EQ("\x7F", CodePointToUtf8(L'\x7F')); +} + +// Tests that Unicode code-points that have 8 to 11 bits are encoded +// as 110xxxxx 10xxxxxx. +TEST(CodePointToUtf8Test, CanEncode8To11Bits) { + // 000 1101 0011 => 110-00011 10-010011 + EXPECT_EQ("\xC3\x93", CodePointToUtf8(L'\xD3')); + + // 101 0111 0110 => 110-10101 10-110110 + // Some compilers (e.g., GCC on MinGW) cannot handle non-ASCII codepoints + // in wide strings and wide chars. In order to accommodate them, we have to + // introduce such character constants as integers. + EXPECT_EQ("\xD5\xB6", + CodePointToUtf8(static_cast(0x576))); +} + +// Tests that Unicode code-points that have 12 to 16 bits are encoded +// as 1110xxxx 10xxxxxx 10xxxxxx. +TEST(CodePointToUtf8Test, CanEncode12To16Bits) { + // 0000 1000 1101 0011 => 1110-0000 10-100011 10-010011 + EXPECT_EQ("\xE0\xA3\x93", + CodePointToUtf8(static_cast(0x8D3))); + + // 1100 0111 0100 1101 => 1110-1100 10-011101 10-001101 + EXPECT_EQ("\xEC\x9D\x8D", + CodePointToUtf8(static_cast(0xC74D))); +} + +#if !GTEST_WIDE_STRING_USES_UTF16_ +// Tests in this group require a wchar_t to hold > 16 bits, and thus +// are skipped on Windows, and Cygwin, where a wchar_t is +// 16-bit wide. This code may not compile on those systems. + +// Tests that Unicode code-points that have 17 to 21 bits are encoded +// as 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. +TEST(CodePointToUtf8Test, CanEncode17To21Bits) { + // 0 0001 0000 1000 1101 0011 => 11110-000 10-010000 10-100011 10-010011 + EXPECT_EQ("\xF0\x90\xA3\x93", CodePointToUtf8(L'\x108D3')); + + // 0 0001 0000 0100 0000 0000 => 11110-000 10-010000 10-010000 10-000000 + EXPECT_EQ("\xF0\x90\x90\x80", CodePointToUtf8(L'\x10400')); + + // 1 0000 1000 0110 0011 0100 => 11110-100 10-001000 10-011000 10-110100 + EXPECT_EQ("\xF4\x88\x98\xB4", CodePointToUtf8(L'\x108634')); +} + +// Tests that encoding an invalid code-point generates the expected result. +TEST(CodePointToUtf8Test, CanEncodeInvalidCodePoint) { + EXPECT_EQ("(Invalid Unicode 0x1234ABCD)", CodePointToUtf8(L'\x1234ABCD')); +} + +#endif // !GTEST_WIDE_STRING_USES_UTF16_ + +// Tests WideStringToUtf8(). + +// Tests that the NUL character L'\0' is encoded correctly. +TEST(WideStringToUtf8Test, CanEncodeNul) { + EXPECT_STREQ("", WideStringToUtf8(L"", 0).c_str()); + EXPECT_STREQ("", WideStringToUtf8(L"", -1).c_str()); +} + +// Tests that ASCII strings are encoded correctly. +TEST(WideStringToUtf8Test, CanEncodeAscii) { + EXPECT_STREQ("a", WideStringToUtf8(L"a", 1).c_str()); + EXPECT_STREQ("ab", WideStringToUtf8(L"ab", 2).c_str()); + EXPECT_STREQ("a", WideStringToUtf8(L"a", -1).c_str()); + EXPECT_STREQ("ab", WideStringToUtf8(L"ab", -1).c_str()); +} + +// Tests that Unicode code-points that have 8 to 11 bits are encoded +// as 110xxxxx 10xxxxxx. +TEST(WideStringToUtf8Test, CanEncode8To11Bits) { + // 000 1101 0011 => 110-00011 10-010011 + EXPECT_STREQ("\xC3\x93", WideStringToUtf8(L"\xD3", 1).c_str()); + EXPECT_STREQ("\xC3\x93", WideStringToUtf8(L"\xD3", -1).c_str()); + + // 101 0111 0110 => 110-10101 10-110110 + const wchar_t s[] = { 0x576, '\0' }; + EXPECT_STREQ("\xD5\xB6", WideStringToUtf8(s, 1).c_str()); + EXPECT_STREQ("\xD5\xB6", WideStringToUtf8(s, -1).c_str()); +} + +// Tests that Unicode code-points that have 12 to 16 bits are encoded +// as 1110xxxx 10xxxxxx 10xxxxxx. +TEST(WideStringToUtf8Test, CanEncode12To16Bits) { + // 0000 1000 1101 0011 => 1110-0000 10-100011 10-010011 + const wchar_t s1[] = { 0x8D3, '\0' }; + EXPECT_STREQ("\xE0\xA3\x93", WideStringToUtf8(s1, 1).c_str()); + EXPECT_STREQ("\xE0\xA3\x93", WideStringToUtf8(s1, -1).c_str()); + + // 1100 0111 0100 1101 => 1110-1100 10-011101 10-001101 + const wchar_t s2[] = { 0xC74D, '\0' }; + EXPECT_STREQ("\xEC\x9D\x8D", WideStringToUtf8(s2, 1).c_str()); + EXPECT_STREQ("\xEC\x9D\x8D", WideStringToUtf8(s2, -1).c_str()); +} + +// Tests that the conversion stops when the function encounters \0 character. +TEST(WideStringToUtf8Test, StopsOnNulCharacter) { + EXPECT_STREQ("ABC", WideStringToUtf8(L"ABC\0XYZ", 100).c_str()); +} + +// Tests that the conversion stops when the function reaches the limit +// specified by the 'length' parameter. +TEST(WideStringToUtf8Test, StopsWhenLengthLimitReached) { + EXPECT_STREQ("ABC", WideStringToUtf8(L"ABCDEF", 3).c_str()); +} + +#if !GTEST_WIDE_STRING_USES_UTF16_ +// Tests that Unicode code-points that have 17 to 21 bits are encoded +// as 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx. This code may not compile +// on the systems using UTF-16 encoding. +TEST(WideStringToUtf8Test, CanEncode17To21Bits) { + // 0 0001 0000 1000 1101 0011 => 11110-000 10-010000 10-100011 10-010011 + EXPECT_STREQ("\xF0\x90\xA3\x93", WideStringToUtf8(L"\x108D3", 1).c_str()); + EXPECT_STREQ("\xF0\x90\xA3\x93", WideStringToUtf8(L"\x108D3", -1).c_str()); + + // 1 0000 1000 0110 0011 0100 => 11110-100 10-001000 10-011000 10-110100 + EXPECT_STREQ("\xF4\x88\x98\xB4", WideStringToUtf8(L"\x108634", 1).c_str()); + EXPECT_STREQ("\xF4\x88\x98\xB4", WideStringToUtf8(L"\x108634", -1).c_str()); +} + +// Tests that encoding an invalid code-point generates the expected result. +TEST(WideStringToUtf8Test, CanEncodeInvalidCodePoint) { + EXPECT_STREQ("(Invalid Unicode 0xABCDFF)", + WideStringToUtf8(L"\xABCDFF", -1).c_str()); +} +#else // !GTEST_WIDE_STRING_USES_UTF16_ +// Tests that surrogate pairs are encoded correctly on the systems using +// UTF-16 encoding in the wide strings. +TEST(WideStringToUtf8Test, CanEncodeValidUtf16SUrrogatePairs) { + const wchar_t s[] = { 0xD801, 0xDC00, '\0' }; + EXPECT_STREQ("\xF0\x90\x90\x80", WideStringToUtf8(s, -1).c_str()); +} + +// Tests that encoding an invalid UTF-16 surrogate pair +// generates the expected result. +TEST(WideStringToUtf8Test, CanEncodeInvalidUtf16SurrogatePair) { + // Leading surrogate is at the end of the string. + const wchar_t s1[] = { 0xD800, '\0' }; + EXPECT_STREQ("\xED\xA0\x80", WideStringToUtf8(s1, -1).c_str()); + // Leading surrogate is not followed by the trailing surrogate. + const wchar_t s2[] = { 0xD800, 'M', '\0' }; + EXPECT_STREQ("\xED\xA0\x80M", WideStringToUtf8(s2, -1).c_str()); + // Trailing surrogate appearas without a leading surrogate. + const wchar_t s3[] = { 0xDC00, 'P', 'Q', 'R', '\0' }; + EXPECT_STREQ("\xED\xB0\x80PQR", WideStringToUtf8(s3, -1).c_str()); +} +#endif // !GTEST_WIDE_STRING_USES_UTF16_ + +// Tests that codepoint concatenation works correctly. +#if !GTEST_WIDE_STRING_USES_UTF16_ +TEST(WideStringToUtf8Test, ConcatenatesCodepointsCorrectly) { + const wchar_t s[] = { 0x108634, 0xC74D, '\n', 0x576, 0x8D3, 0x108634, '\0'}; + EXPECT_STREQ( + "\xF4\x88\x98\xB4" + "\xEC\x9D\x8D" + "\n" + "\xD5\xB6" + "\xE0\xA3\x93" + "\xF4\x88\x98\xB4", + WideStringToUtf8(s, -1).c_str()); +} +#else +TEST(WideStringToUtf8Test, ConcatenatesCodepointsCorrectly) { + const wchar_t s[] = { 0xC74D, '\n', 0x576, 0x8D3, '\0'}; + EXPECT_STREQ( + "\xEC\x9D\x8D" "\n" "\xD5\xB6" "\xE0\xA3\x93", + WideStringToUtf8(s, -1).c_str()); +} +#endif // !GTEST_WIDE_STRING_USES_UTF16_ + +// Tests the Random class. + +TEST(RandomDeathTest, GeneratesCrashesOnInvalidRange) { + testing::internal::Random random(42); + EXPECT_DEATH_IF_SUPPORTED( + random.Generate(0), + "Cannot generate a number in the range \\[0, 0\\)"); + EXPECT_DEATH_IF_SUPPORTED( + random.Generate(testing::internal::Random::kMaxRange + 1), + "Generation of a number in \\[0, 2147483649\\) was requested, " + "but this can only generate numbers in \\[0, 2147483648\\)"); +} + +TEST(RandomTest, GeneratesNumbersWithinRange) { + constexpr uint32_t kRange = 10000; + testing::internal::Random random(12345); + for (int i = 0; i < 10; i++) { + EXPECT_LT(random.Generate(kRange), kRange) << " for iteration " << i; + } + + testing::internal::Random random2(testing::internal::Random::kMaxRange); + for (int i = 0; i < 10; i++) { + EXPECT_LT(random2.Generate(kRange), kRange) << " for iteration " << i; + } +} + +TEST(RandomTest, RepeatsWhenReseeded) { + constexpr int kSeed = 123; + constexpr int kArraySize = 10; + constexpr uint32_t kRange = 10000; + uint32_t values[kArraySize]; + + testing::internal::Random random(kSeed); + for (int i = 0; i < kArraySize; i++) { + values[i] = random.Generate(kRange); + } + + random.Reseed(kSeed); + for (int i = 0; i < kArraySize; i++) { + EXPECT_EQ(values[i], random.Generate(kRange)) << " for iteration " << i; + } +} + +// Tests STL container utilities. + +// Tests CountIf(). + +static bool IsPositive(int n) { return n > 0; } + +TEST(ContainerUtilityTest, CountIf) { + std::vector v; + EXPECT_EQ(0, CountIf(v, IsPositive)); // Works for an empty container. + + v.push_back(-1); + v.push_back(0); + EXPECT_EQ(0, CountIf(v, IsPositive)); // Works when no value satisfies. + + v.push_back(2); + v.push_back(-10); + v.push_back(10); + EXPECT_EQ(2, CountIf(v, IsPositive)); +} + +// Tests ForEach(). + +static int g_sum = 0; +static void Accumulate(int n) { g_sum += n; } + +TEST(ContainerUtilityTest, ForEach) { + std::vector v; + g_sum = 0; + ForEach(v, Accumulate); + EXPECT_EQ(0, g_sum); // Works for an empty container; + + g_sum = 0; + v.push_back(1); + ForEach(v, Accumulate); + EXPECT_EQ(1, g_sum); // Works for a container with one element. + + g_sum = 0; + v.push_back(20); + v.push_back(300); + ForEach(v, Accumulate); + EXPECT_EQ(321, g_sum); +} + +// Tests GetElementOr(). +TEST(ContainerUtilityTest, GetElementOr) { + std::vector a; + EXPECT_EQ('x', GetElementOr(a, 0, 'x')); + + a.push_back('a'); + a.push_back('b'); + EXPECT_EQ('a', GetElementOr(a, 0, 'x')); + EXPECT_EQ('b', GetElementOr(a, 1, 'x')); + EXPECT_EQ('x', GetElementOr(a, -2, 'x')); + EXPECT_EQ('x', GetElementOr(a, 2, 'x')); +} + +TEST(ContainerUtilityDeathTest, ShuffleRange) { + std::vector a; + a.push_back(0); + a.push_back(1); + a.push_back(2); + testing::internal::Random random(1); + + EXPECT_DEATH_IF_SUPPORTED( + ShuffleRange(&random, -1, 1, &a), + "Invalid shuffle range start -1: must be in range \\[0, 3\\]"); + EXPECT_DEATH_IF_SUPPORTED( + ShuffleRange(&random, 4, 4, &a), + "Invalid shuffle range start 4: must be in range \\[0, 3\\]"); + EXPECT_DEATH_IF_SUPPORTED( + ShuffleRange(&random, 3, 2, &a), + "Invalid shuffle range finish 2: must be in range \\[3, 3\\]"); + EXPECT_DEATH_IF_SUPPORTED( + ShuffleRange(&random, 3, 4, &a), + "Invalid shuffle range finish 4: must be in range \\[3, 3\\]"); +} + +class VectorShuffleTest : public Test { + protected: + static const size_t kVectorSize = 20; + + VectorShuffleTest() : random_(1) { + for (int i = 0; i < static_cast(kVectorSize); i++) { + vector_.push_back(i); + } + } + + static bool VectorIsCorrupt(const TestingVector& vector) { + if (kVectorSize != vector.size()) { + return true; + } + + bool found_in_vector[kVectorSize] = { false }; + for (size_t i = 0; i < vector.size(); i++) { + const int e = vector[i]; + if (e < 0 || e >= static_cast(kVectorSize) || found_in_vector[e]) { + return true; + } + found_in_vector[e] = true; + } + + // Vector size is correct, elements' range is correct, no + // duplicate elements. Therefore no corruption has occurred. + return false; + } + + static bool VectorIsNotCorrupt(const TestingVector& vector) { + return !VectorIsCorrupt(vector); + } + + static bool RangeIsShuffled(const TestingVector& vector, int begin, int end) { + for (int i = begin; i < end; i++) { + if (i != vector[static_cast(i)]) { + return true; + } + } + return false; + } + + static bool RangeIsUnshuffled( + const TestingVector& vector, int begin, int end) { + return !RangeIsShuffled(vector, begin, end); + } + + static bool VectorIsShuffled(const TestingVector& vector) { + return RangeIsShuffled(vector, 0, static_cast(vector.size())); + } + + static bool VectorIsUnshuffled(const TestingVector& vector) { + return !VectorIsShuffled(vector); + } + + testing::internal::Random random_; + TestingVector vector_; +}; // class VectorShuffleTest + +const size_t VectorShuffleTest::kVectorSize; + +TEST_F(VectorShuffleTest, HandlesEmptyRange) { + // Tests an empty range at the beginning... + ShuffleRange(&random_, 0, 0, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); + + // ...in the middle... + ShuffleRange(&random_, kVectorSize/2, kVectorSize/2, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); + + // ...at the end... + ShuffleRange(&random_, kVectorSize - 1, kVectorSize - 1, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); + + // ...and past the end. + ShuffleRange(&random_, kVectorSize, kVectorSize, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); +} + +TEST_F(VectorShuffleTest, HandlesRangeOfSizeOne) { + // Tests a size one range at the beginning... + ShuffleRange(&random_, 0, 1, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); + + // ...in the middle... + ShuffleRange(&random_, kVectorSize/2, kVectorSize/2 + 1, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); + + // ...and at the end. + ShuffleRange(&random_, kVectorSize - 1, kVectorSize, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsUnshuffled, vector_); +} + +// Because we use our own random number generator and a fixed seed, +// we can guarantee that the following "random" tests will succeed. + +TEST_F(VectorShuffleTest, ShufflesEntireVector) { + Shuffle(&random_, &vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + EXPECT_FALSE(VectorIsUnshuffled(vector_)) << vector_; + + // Tests the first and last elements in particular to ensure that + // there are no off-by-one problems in our shuffle algorithm. + EXPECT_NE(0, vector_[0]); + EXPECT_NE(static_cast(kVectorSize - 1), vector_[kVectorSize - 1]); +} + +TEST_F(VectorShuffleTest, ShufflesStartOfVector) { + const int kRangeSize = kVectorSize/2; + + ShuffleRange(&random_, 0, kRangeSize, &vector_); + + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + EXPECT_PRED3(RangeIsShuffled, vector_, 0, kRangeSize); + EXPECT_PRED3(RangeIsUnshuffled, vector_, kRangeSize, + static_cast(kVectorSize)); +} + +TEST_F(VectorShuffleTest, ShufflesEndOfVector) { + const int kRangeSize = kVectorSize / 2; + ShuffleRange(&random_, kRangeSize, kVectorSize, &vector_); + + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + EXPECT_PRED3(RangeIsUnshuffled, vector_, 0, kRangeSize); + EXPECT_PRED3(RangeIsShuffled, vector_, kRangeSize, + static_cast(kVectorSize)); +} + +TEST_F(VectorShuffleTest, ShufflesMiddleOfVector) { + const int kRangeSize = static_cast(kVectorSize) / 3; + ShuffleRange(&random_, kRangeSize, 2*kRangeSize, &vector_); + + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + EXPECT_PRED3(RangeIsUnshuffled, vector_, 0, kRangeSize); + EXPECT_PRED3(RangeIsShuffled, vector_, kRangeSize, 2*kRangeSize); + EXPECT_PRED3(RangeIsUnshuffled, vector_, 2 * kRangeSize, + static_cast(kVectorSize)); +} + +TEST_F(VectorShuffleTest, ShufflesRepeatably) { + TestingVector vector2; + for (size_t i = 0; i < kVectorSize; i++) { + vector2.push_back(static_cast(i)); + } + + random_.Reseed(1234); + Shuffle(&random_, &vector_); + random_.Reseed(1234); + Shuffle(&random_, &vector2); + + ASSERT_PRED1(VectorIsNotCorrupt, vector_); + ASSERT_PRED1(VectorIsNotCorrupt, vector2); + + for (size_t i = 0; i < kVectorSize; i++) { + EXPECT_EQ(vector_[i], vector2[i]) << " where i is " << i; + } +} + +// Tests the size of the AssertHelper class. + +TEST(AssertHelperTest, AssertHelperIsSmall) { + // To avoid breaking clients that use lots of assertions in one + // function, we cannot grow the size of AssertHelper. + EXPECT_LE(sizeof(testing::internal::AssertHelper), sizeof(void*)); +} + +// Tests String::EndsWithCaseInsensitive(). +TEST(StringTest, EndsWithCaseInsensitive) { + EXPECT_TRUE(String::EndsWithCaseInsensitive("foobar", "BAR")); + EXPECT_TRUE(String::EndsWithCaseInsensitive("foobaR", "bar")); + EXPECT_TRUE(String::EndsWithCaseInsensitive("foobar", "")); + EXPECT_TRUE(String::EndsWithCaseInsensitive("", "")); + + EXPECT_FALSE(String::EndsWithCaseInsensitive("Foobar", "foo")); + EXPECT_FALSE(String::EndsWithCaseInsensitive("foobar", "Foo")); + EXPECT_FALSE(String::EndsWithCaseInsensitive("", "foo")); +} + +// C++Builder's preprocessor is buggy; it fails to expand macros that +// appear in macro parameters after wide char literals. Provide an alias +// for NULL as a workaround. +static const wchar_t* const kNull = nullptr; + +// Tests String::CaseInsensitiveWideCStringEquals +TEST(StringTest, CaseInsensitiveWideCStringEquals) { + EXPECT_TRUE(String::CaseInsensitiveWideCStringEquals(nullptr, nullptr)); + EXPECT_FALSE(String::CaseInsensitiveWideCStringEquals(kNull, L"")); + EXPECT_FALSE(String::CaseInsensitiveWideCStringEquals(L"", kNull)); + EXPECT_FALSE(String::CaseInsensitiveWideCStringEquals(kNull, L"foobar")); + EXPECT_FALSE(String::CaseInsensitiveWideCStringEquals(L"foobar", kNull)); + EXPECT_TRUE(String::CaseInsensitiveWideCStringEquals(L"foobar", L"foobar")); + EXPECT_TRUE(String::CaseInsensitiveWideCStringEquals(L"foobar", L"FOOBAR")); + EXPECT_TRUE(String::CaseInsensitiveWideCStringEquals(L"FOOBAR", L"foobar")); +} + +#if GTEST_OS_WINDOWS + +// Tests String::ShowWideCString(). +TEST(StringTest, ShowWideCString) { + EXPECT_STREQ("(null)", + String::ShowWideCString(NULL).c_str()); + EXPECT_STREQ("", String::ShowWideCString(L"").c_str()); + EXPECT_STREQ("foo", String::ShowWideCString(L"foo").c_str()); +} + +# if GTEST_OS_WINDOWS_MOBILE +TEST(StringTest, AnsiAndUtf16Null) { + EXPECT_EQ(NULL, String::AnsiToUtf16(NULL)); + EXPECT_EQ(NULL, String::Utf16ToAnsi(NULL)); +} + +TEST(StringTest, AnsiAndUtf16ConvertBasic) { + const char* ansi = String::Utf16ToAnsi(L"str"); + EXPECT_STREQ("str", ansi); + delete [] ansi; + const WCHAR* utf16 = String::AnsiToUtf16("str"); + EXPECT_EQ(0, wcsncmp(L"str", utf16, 3)); + delete [] utf16; +} + +TEST(StringTest, AnsiAndUtf16ConvertPathChars) { + const char* ansi = String::Utf16ToAnsi(L".:\\ \"*?"); + EXPECT_STREQ(".:\\ \"*?", ansi); + delete [] ansi; + const WCHAR* utf16 = String::AnsiToUtf16(".:\\ \"*?"); + EXPECT_EQ(0, wcsncmp(L".:\\ \"*?", utf16, 3)); + delete [] utf16; +} +# endif // GTEST_OS_WINDOWS_MOBILE + +#endif // GTEST_OS_WINDOWS + +// Tests TestProperty construction. +TEST(TestPropertyTest, StringValue) { + TestProperty property("key", "1"); + EXPECT_STREQ("key", property.key()); + EXPECT_STREQ("1", property.value()); +} + +// Tests TestProperty replacing a value. +TEST(TestPropertyTest, ReplaceStringValue) { + TestProperty property("key", "1"); + EXPECT_STREQ("1", property.value()); + property.SetValue("2"); + EXPECT_STREQ("2", property.value()); +} + +// AddFatalFailure() and AddNonfatalFailure() must be stand-alone +// functions (i.e. their definitions cannot be inlined at the call +// sites), or C++Builder won't compile the code. +static void AddFatalFailure() { + FAIL() << "Expected fatal failure."; +} + +static void AddNonfatalFailure() { + ADD_FAILURE() << "Expected non-fatal failure."; +} + +class ScopedFakeTestPartResultReporterTest : public Test { + public: // Must be public and not protected due to a bug in g++ 3.4.2. + enum FailureMode { + FATAL_FAILURE, + NONFATAL_FAILURE + }; + static void AddFailure(FailureMode failure) { + if (failure == FATAL_FAILURE) { + AddFatalFailure(); + } else { + AddNonfatalFailure(); + } + } +}; + +// Tests that ScopedFakeTestPartResultReporter intercepts test +// failures. +TEST_F(ScopedFakeTestPartResultReporterTest, InterceptsTestFailures) { + TestPartResultArray results; + { + ScopedFakeTestPartResultReporter reporter( + ScopedFakeTestPartResultReporter::INTERCEPT_ONLY_CURRENT_THREAD, + &results); + AddFailure(NONFATAL_FAILURE); + AddFailure(FATAL_FAILURE); + } + + EXPECT_EQ(2, results.size()); + EXPECT_TRUE(results.GetTestPartResult(0).nonfatally_failed()); + EXPECT_TRUE(results.GetTestPartResult(1).fatally_failed()); +} + +TEST_F(ScopedFakeTestPartResultReporterTest, DeprecatedConstructor) { + TestPartResultArray results; + { + // Tests, that the deprecated constructor still works. + ScopedFakeTestPartResultReporter reporter(&results); + AddFailure(NONFATAL_FAILURE); + } + EXPECT_EQ(1, results.size()); +} + +#if GTEST_IS_THREADSAFE + +class ScopedFakeTestPartResultReporterWithThreadsTest + : public ScopedFakeTestPartResultReporterTest { + protected: + static void AddFailureInOtherThread(FailureMode failure) { + ThreadWithParam thread(&AddFailure, failure, nullptr); + thread.Join(); + } +}; + +TEST_F(ScopedFakeTestPartResultReporterWithThreadsTest, + InterceptsTestFailuresInAllThreads) { + TestPartResultArray results; + { + ScopedFakeTestPartResultReporter reporter( + ScopedFakeTestPartResultReporter::INTERCEPT_ALL_THREADS, &results); + AddFailure(NONFATAL_FAILURE); + AddFailure(FATAL_FAILURE); + AddFailureInOtherThread(NONFATAL_FAILURE); + AddFailureInOtherThread(FATAL_FAILURE); + } + + EXPECT_EQ(4, results.size()); + EXPECT_TRUE(results.GetTestPartResult(0).nonfatally_failed()); + EXPECT_TRUE(results.GetTestPartResult(1).fatally_failed()); + EXPECT_TRUE(results.GetTestPartResult(2).nonfatally_failed()); + EXPECT_TRUE(results.GetTestPartResult(3).fatally_failed()); +} + +#endif // GTEST_IS_THREADSAFE + +// Tests EXPECT_FATAL_FAILURE{,ON_ALL_THREADS}. Makes sure that they +// work even if the failure is generated in a called function rather than +// the current context. + +typedef ScopedFakeTestPartResultReporterTest ExpectFatalFailureTest; + +TEST_F(ExpectFatalFailureTest, CatchesFatalFaliure) { + EXPECT_FATAL_FAILURE(AddFatalFailure(), "Expected fatal failure."); +} + +TEST_F(ExpectFatalFailureTest, AcceptsStdStringObject) { + EXPECT_FATAL_FAILURE(AddFatalFailure(), + ::std::string("Expected fatal failure.")); +} + +TEST_F(ExpectFatalFailureTest, CatchesFatalFailureOnAllThreads) { + // We have another test below to verify that the macro catches fatal + // failures generated on another thread. + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(AddFatalFailure(), + "Expected fatal failure."); +} + +#ifdef __BORLANDC__ +// Silences warnings: "Condition is always true" +# pragma option push -w-ccc +#endif + +// Tests that EXPECT_FATAL_FAILURE() can be used in a non-void +// function even when the statement in it contains ASSERT_*. + +int NonVoidFunction() { + EXPECT_FATAL_FAILURE(ASSERT_TRUE(false), ""); + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(FAIL(), ""); + return 0; +} + +TEST_F(ExpectFatalFailureTest, CanBeUsedInNonVoidFunction) { + NonVoidFunction(); +} + +// Tests that EXPECT_FATAL_FAILURE(statement, ...) doesn't abort the +// current function even though 'statement' generates a fatal failure. + +void DoesNotAbortHelper(bool* aborted) { + EXPECT_FATAL_FAILURE(ASSERT_TRUE(false), ""); + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(FAIL(), ""); + + *aborted = false; +} + +#ifdef __BORLANDC__ +// Restores warnings after previous "#pragma option push" suppressed them. +# pragma option pop +#endif + +TEST_F(ExpectFatalFailureTest, DoesNotAbort) { + bool aborted = true; + DoesNotAbortHelper(&aborted); + EXPECT_FALSE(aborted); +} + +// Tests that the EXPECT_FATAL_FAILURE{,_ON_ALL_THREADS} accepts a +// statement that contains a macro which expands to code containing an +// unprotected comma. + +static int global_var = 0; +#define GTEST_USE_UNPROTECTED_COMMA_ global_var++, global_var++ + +TEST_F(ExpectFatalFailureTest, AcceptsMacroThatExpandsToUnprotectedComma) { +#ifndef __BORLANDC__ + // ICE's in C++Builder. + EXPECT_FATAL_FAILURE({ + GTEST_USE_UNPROTECTED_COMMA_; + AddFatalFailure(); + }, ""); +#endif + + EXPECT_FATAL_FAILURE_ON_ALL_THREADS({ + GTEST_USE_UNPROTECTED_COMMA_; + AddFatalFailure(); + }, ""); +} + +// Tests EXPECT_NONFATAL_FAILURE{,ON_ALL_THREADS}. + +typedef ScopedFakeTestPartResultReporterTest ExpectNonfatalFailureTest; + +TEST_F(ExpectNonfatalFailureTest, CatchesNonfatalFailure) { + EXPECT_NONFATAL_FAILURE(AddNonfatalFailure(), + "Expected non-fatal failure."); +} + +TEST_F(ExpectNonfatalFailureTest, AcceptsStdStringObject) { + EXPECT_NONFATAL_FAILURE(AddNonfatalFailure(), + ::std::string("Expected non-fatal failure.")); +} + +TEST_F(ExpectNonfatalFailureTest, CatchesNonfatalFailureOnAllThreads) { + // We have another test below to verify that the macro catches + // non-fatal failures generated on another thread. + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(AddNonfatalFailure(), + "Expected non-fatal failure."); +} + +// Tests that the EXPECT_NONFATAL_FAILURE{,_ON_ALL_THREADS} accepts a +// statement that contains a macro which expands to code containing an +// unprotected comma. +TEST_F(ExpectNonfatalFailureTest, AcceptsMacroThatExpandsToUnprotectedComma) { + EXPECT_NONFATAL_FAILURE({ + GTEST_USE_UNPROTECTED_COMMA_; + AddNonfatalFailure(); + }, ""); + + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS({ + GTEST_USE_UNPROTECTED_COMMA_; + AddNonfatalFailure(); + }, ""); +} + +#if GTEST_IS_THREADSAFE + +typedef ScopedFakeTestPartResultReporterWithThreadsTest + ExpectFailureWithThreadsTest; + +TEST_F(ExpectFailureWithThreadsTest, ExpectFatalFailureOnAllThreads) { + EXPECT_FATAL_FAILURE_ON_ALL_THREADS(AddFailureInOtherThread(FATAL_FAILURE), + "Expected fatal failure."); +} + +TEST_F(ExpectFailureWithThreadsTest, ExpectNonFatalFailureOnAllThreads) { + EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS( + AddFailureInOtherThread(NONFATAL_FAILURE), "Expected non-fatal failure."); +} + +#endif // GTEST_IS_THREADSAFE + +// Tests the TestProperty class. + +TEST(TestPropertyTest, ConstructorWorks) { + const TestProperty property("key", "value"); + EXPECT_STREQ("key", property.key()); + EXPECT_STREQ("value", property.value()); +} + +TEST(TestPropertyTest, SetValue) { + TestProperty property("key", "value_1"); + EXPECT_STREQ("key", property.key()); + property.SetValue("value_2"); + EXPECT_STREQ("key", property.key()); + EXPECT_STREQ("value_2", property.value()); +} + +// Tests the TestResult class + +// The test fixture for testing TestResult. +class TestResultTest : public Test { + protected: + typedef std::vector TPRVector; + + // We make use of 2 TestPartResult objects, + TestPartResult * pr1, * pr2; + + // ... and 3 TestResult objects. + TestResult * r0, * r1, * r2; + + void SetUp() override { + // pr1 is for success. + pr1 = new TestPartResult(TestPartResult::kSuccess, + "foo/bar.cc", + 10, + "Success!"); + + // pr2 is for fatal failure. + pr2 = new TestPartResult(TestPartResult::kFatalFailure, + "foo/bar.cc", + -1, // This line number means "unknown" + "Failure!"); + + // Creates the TestResult objects. + r0 = new TestResult(); + r1 = new TestResult(); + r2 = new TestResult(); + + // In order to test TestResult, we need to modify its internal + // state, in particular the TestPartResult vector it holds. + // test_part_results() returns a const reference to this vector. + // We cast it to a non-const object s.t. it can be modified + TPRVector* results1 = const_cast( + &TestResultAccessor::test_part_results(*r1)); + TPRVector* results2 = const_cast( + &TestResultAccessor::test_part_results(*r2)); + + // r0 is an empty TestResult. + + // r1 contains a single SUCCESS TestPartResult. + results1->push_back(*pr1); + + // r2 contains a SUCCESS, and a FAILURE. + results2->push_back(*pr1); + results2->push_back(*pr2); + } + + void TearDown() override { + delete pr1; + delete pr2; + + delete r0; + delete r1; + delete r2; + } + + // Helper that compares two TestPartResults. + static void CompareTestPartResult(const TestPartResult& expected, + const TestPartResult& actual) { + EXPECT_EQ(expected.type(), actual.type()); + EXPECT_STREQ(expected.file_name(), actual.file_name()); + EXPECT_EQ(expected.line_number(), actual.line_number()); + EXPECT_STREQ(expected.summary(), actual.summary()); + EXPECT_STREQ(expected.message(), actual.message()); + EXPECT_EQ(expected.passed(), actual.passed()); + EXPECT_EQ(expected.failed(), actual.failed()); + EXPECT_EQ(expected.nonfatally_failed(), actual.nonfatally_failed()); + EXPECT_EQ(expected.fatally_failed(), actual.fatally_failed()); + } +}; + +// Tests TestResult::total_part_count(). +TEST_F(TestResultTest, total_part_count) { + ASSERT_EQ(0, r0->total_part_count()); + ASSERT_EQ(1, r1->total_part_count()); + ASSERT_EQ(2, r2->total_part_count()); +} + +// Tests TestResult::Passed(). +TEST_F(TestResultTest, Passed) { + ASSERT_TRUE(r0->Passed()); + ASSERT_TRUE(r1->Passed()); + ASSERT_FALSE(r2->Passed()); +} + +// Tests TestResult::Failed(). +TEST_F(TestResultTest, Failed) { + ASSERT_FALSE(r0->Failed()); + ASSERT_FALSE(r1->Failed()); + ASSERT_TRUE(r2->Failed()); +} + +// Tests TestResult::GetTestPartResult(). + +typedef TestResultTest TestResultDeathTest; + +TEST_F(TestResultDeathTest, GetTestPartResult) { + CompareTestPartResult(*pr1, r2->GetTestPartResult(0)); + CompareTestPartResult(*pr2, r2->GetTestPartResult(1)); + EXPECT_DEATH_IF_SUPPORTED(r2->GetTestPartResult(2), ""); + EXPECT_DEATH_IF_SUPPORTED(r2->GetTestPartResult(-1), ""); +} + +// Tests TestResult has no properties when none are added. +TEST(TestResultPropertyTest, NoPropertiesFoundWhenNoneAreAdded) { + TestResult test_result; + ASSERT_EQ(0, test_result.test_property_count()); +} + +// Tests TestResult has the expected property when added. +TEST(TestResultPropertyTest, OnePropertyFoundWhenAdded) { + TestResult test_result; + TestProperty property("key_1", "1"); + TestResultAccessor::RecordProperty(&test_result, "testcase", property); + ASSERT_EQ(1, test_result.test_property_count()); + const TestProperty& actual_property = test_result.GetTestProperty(0); + EXPECT_STREQ("key_1", actual_property.key()); + EXPECT_STREQ("1", actual_property.value()); +} + +// Tests TestResult has multiple properties when added. +TEST(TestResultPropertyTest, MultiplePropertiesFoundWhenAdded) { + TestResult test_result; + TestProperty property_1("key_1", "1"); + TestProperty property_2("key_2", "2"); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_1); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_2); + ASSERT_EQ(2, test_result.test_property_count()); + const TestProperty& actual_property_1 = test_result.GetTestProperty(0); + EXPECT_STREQ("key_1", actual_property_1.key()); + EXPECT_STREQ("1", actual_property_1.value()); + + const TestProperty& actual_property_2 = test_result.GetTestProperty(1); + EXPECT_STREQ("key_2", actual_property_2.key()); + EXPECT_STREQ("2", actual_property_2.value()); +} + +// Tests TestResult::RecordProperty() overrides values for duplicate keys. +TEST(TestResultPropertyTest, OverridesValuesForDuplicateKeys) { + TestResult test_result; + TestProperty property_1_1("key_1", "1"); + TestProperty property_2_1("key_2", "2"); + TestProperty property_1_2("key_1", "12"); + TestProperty property_2_2("key_2", "22"); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_1_1); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_2_1); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_1_2); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_2_2); + + ASSERT_EQ(2, test_result.test_property_count()); + const TestProperty& actual_property_1 = test_result.GetTestProperty(0); + EXPECT_STREQ("key_1", actual_property_1.key()); + EXPECT_STREQ("12", actual_property_1.value()); + + const TestProperty& actual_property_2 = test_result.GetTestProperty(1); + EXPECT_STREQ("key_2", actual_property_2.key()); + EXPECT_STREQ("22", actual_property_2.value()); +} + +// Tests TestResult::GetTestProperty(). +TEST(TestResultPropertyTest, GetTestProperty) { + TestResult test_result; + TestProperty property_1("key_1", "1"); + TestProperty property_2("key_2", "2"); + TestProperty property_3("key_3", "3"); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_1); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_2); + TestResultAccessor::RecordProperty(&test_result, "testcase", property_3); + + const TestProperty& fetched_property_1 = test_result.GetTestProperty(0); + const TestProperty& fetched_property_2 = test_result.GetTestProperty(1); + const TestProperty& fetched_property_3 = test_result.GetTestProperty(2); + + EXPECT_STREQ("key_1", fetched_property_1.key()); + EXPECT_STREQ("1", fetched_property_1.value()); + + EXPECT_STREQ("key_2", fetched_property_2.key()); + EXPECT_STREQ("2", fetched_property_2.value()); + + EXPECT_STREQ("key_3", fetched_property_3.key()); + EXPECT_STREQ("3", fetched_property_3.value()); + + EXPECT_DEATH_IF_SUPPORTED(test_result.GetTestProperty(3), ""); + EXPECT_DEATH_IF_SUPPORTED(test_result.GetTestProperty(-1), ""); +} + +// Tests the Test class. +// +// It's difficult to test every public method of this class (we are +// already stretching the limit of Google Test by using it to test itself!). +// Fortunately, we don't have to do that, as we are already testing +// the functionalities of the Test class extensively by using Google Test +// alone. +// +// Therefore, this section only contains one test. + +// Tests that GTestFlagSaver works on Windows and Mac. + +class GTestFlagSaverTest : public Test { + protected: + // Saves the Google Test flags such that we can restore them later, and + // then sets them to their default values. This will be called + // before the first test in this test case is run. + static void SetUpTestSuite() { + saver_ = new GTestFlagSaver; + + GTEST_FLAG_SET(also_run_disabled_tests, false); + GTEST_FLAG_SET(break_on_failure, false); + GTEST_FLAG_SET(catch_exceptions, false); + GTEST_FLAG_SET(death_test_use_fork, false); + GTEST_FLAG_SET(color, "auto"); + GTEST_FLAG_SET(fail_fast, false); + GTEST_FLAG_SET(filter, ""); + GTEST_FLAG_SET(list_tests, false); + GTEST_FLAG_SET(output, ""); + GTEST_FLAG_SET(brief, false); + GTEST_FLAG_SET(print_time, true); + GTEST_FLAG_SET(random_seed, 0); + GTEST_FLAG_SET(repeat, 1); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + GTEST_FLAG_SET(shuffle, false); + GTEST_FLAG_SET(stack_trace_depth, kMaxStackTraceDepth); + GTEST_FLAG_SET(stream_result_to, ""); + GTEST_FLAG_SET(throw_on_failure, false); + } + + // Restores the Google Test flags that the tests have modified. This will + // be called after the last test in this test case is run. + static void TearDownTestSuite() { + delete saver_; + saver_ = nullptr; + } + + // Verifies that the Google Test flags have their default values, and then + // modifies each of them. + void VerifyAndModifyFlags() { + EXPECT_FALSE(GTEST_FLAG_GET(also_run_disabled_tests)); + EXPECT_FALSE(GTEST_FLAG_GET(break_on_failure)); + EXPECT_FALSE(GTEST_FLAG_GET(catch_exceptions)); + EXPECT_STREQ("auto", GTEST_FLAG_GET(color).c_str()); + EXPECT_FALSE(GTEST_FLAG_GET(death_test_use_fork)); + EXPECT_FALSE(GTEST_FLAG_GET(fail_fast)); + EXPECT_STREQ("", GTEST_FLAG_GET(filter).c_str()); + EXPECT_FALSE(GTEST_FLAG_GET(list_tests)); + EXPECT_STREQ("", GTEST_FLAG_GET(output).c_str()); + EXPECT_FALSE(GTEST_FLAG_GET(brief)); + EXPECT_TRUE(GTEST_FLAG_GET(print_time)); + EXPECT_EQ(0, GTEST_FLAG_GET(random_seed)); + EXPECT_EQ(1, GTEST_FLAG_GET(repeat)); + EXPECT_TRUE(GTEST_FLAG_GET(recreate_environments_when_repeating)); + EXPECT_FALSE(GTEST_FLAG_GET(shuffle)); + EXPECT_EQ(kMaxStackTraceDepth, GTEST_FLAG_GET(stack_trace_depth)); + EXPECT_STREQ("", GTEST_FLAG_GET(stream_result_to).c_str()); + EXPECT_FALSE(GTEST_FLAG_GET(throw_on_failure)); + + GTEST_FLAG_SET(also_run_disabled_tests, true); + GTEST_FLAG_SET(break_on_failure, true); + GTEST_FLAG_SET(catch_exceptions, true); + GTEST_FLAG_SET(color, "no"); + GTEST_FLAG_SET(death_test_use_fork, true); + GTEST_FLAG_SET(fail_fast, true); + GTEST_FLAG_SET(filter, "abc"); + GTEST_FLAG_SET(list_tests, true); + GTEST_FLAG_SET(output, "xml:foo.xml"); + GTEST_FLAG_SET(brief, true); + GTEST_FLAG_SET(print_time, false); + GTEST_FLAG_SET(random_seed, 1); + GTEST_FLAG_SET(repeat, 100); + GTEST_FLAG_SET(recreate_environments_when_repeating, false); + GTEST_FLAG_SET(shuffle, true); + GTEST_FLAG_SET(stack_trace_depth, 1); + GTEST_FLAG_SET(stream_result_to, "localhost:1234"); + GTEST_FLAG_SET(throw_on_failure, true); + } + + private: + // For saving Google Test flags during this test case. + static GTestFlagSaver* saver_; +}; + +GTestFlagSaver* GTestFlagSaverTest::saver_ = nullptr; + +// Google Test doesn't guarantee the order of tests. The following two +// tests are designed to work regardless of their order. + +// Modifies the Google Test flags in the test body. +TEST_F(GTestFlagSaverTest, ModifyGTestFlags) { + VerifyAndModifyFlags(); +} + +// Verifies that the Google Test flags in the body of the previous test were +// restored to their original values. +TEST_F(GTestFlagSaverTest, VerifyGTestFlags) { + VerifyAndModifyFlags(); +} + +// Sets an environment variable with the given name to the given +// value. If the value argument is "", unsets the environment +// variable. The caller must ensure that both arguments are not NULL. +static void SetEnv(const char* name, const char* value) { +#if GTEST_OS_WINDOWS_MOBILE + // Environment variables are not supported on Windows CE. + return; +#elif defined(__BORLANDC__) || defined(__SunOS_5_8) || defined(__SunOS_5_9) + // C++Builder's putenv only stores a pointer to its parameter; we have to + // ensure that the string remains valid as long as it might be needed. + // We use an std::map to do so. + static std::map added_env; + + // Because putenv stores a pointer to the string buffer, we can't delete the + // previous string (if present) until after it's replaced. + std::string *prev_env = NULL; + if (added_env.find(name) != added_env.end()) { + prev_env = added_env[name]; + } + added_env[name] = new std::string( + (Message() << name << "=" << value).GetString()); + + // The standard signature of putenv accepts a 'char*' argument. Other + // implementations, like C++Builder's, accept a 'const char*'. + // We cast away the 'const' since that would work for both variants. + putenv(const_cast(added_env[name]->c_str())); + delete prev_env; +#elif GTEST_OS_WINDOWS // If we are on Windows proper. + _putenv((Message() << name << "=" << value).GetString().c_str()); +#else + if (*value == '\0') { + unsetenv(name); + } else { + setenv(name, value, 1); + } +#endif // GTEST_OS_WINDOWS_MOBILE +} + +#if !GTEST_OS_WINDOWS_MOBILE +// Environment variables are not supported on Windows CE. + +using testing::internal::Int32FromGTestEnv; + +// Tests Int32FromGTestEnv(). + +// Tests that Int32FromGTestEnv() returns the default value when the +// environment variable is not set. +TEST(Int32FromGTestEnvTest, ReturnsDefaultWhenVariableIsNotSet) { + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", ""); + EXPECT_EQ(10, Int32FromGTestEnv("temp", 10)); +} + +# if !defined(GTEST_GET_INT32_FROM_ENV_) + +// Tests that Int32FromGTestEnv() returns the default value when the +// environment variable overflows as an Int32. +TEST(Int32FromGTestEnvTest, ReturnsDefaultWhenValueOverflows) { + printf("(expecting 2 warnings)\n"); + + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "12345678987654321"); + EXPECT_EQ(20, Int32FromGTestEnv("temp", 20)); + + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "-12345678987654321"); + EXPECT_EQ(30, Int32FromGTestEnv("temp", 30)); +} + +// Tests that Int32FromGTestEnv() returns the default value when the +// environment variable does not represent a valid decimal integer. +TEST(Int32FromGTestEnvTest, ReturnsDefaultWhenValueIsInvalid) { + printf("(expecting 2 warnings)\n"); + + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "A1"); + EXPECT_EQ(40, Int32FromGTestEnv("temp", 40)); + + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "12X"); + EXPECT_EQ(50, Int32FromGTestEnv("temp", 50)); +} + +# endif // !defined(GTEST_GET_INT32_FROM_ENV_) + +// Tests that Int32FromGTestEnv() parses and returns the value of the +// environment variable when it represents a valid decimal integer in +// the range of an Int32. +TEST(Int32FromGTestEnvTest, ParsesAndReturnsValidValue) { + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "123"); + EXPECT_EQ(123, Int32FromGTestEnv("temp", 0)); + + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "TEMP", "-321"); + EXPECT_EQ(-321, Int32FromGTestEnv("temp", 0)); +} +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Tests ParseFlag(). + +// Tests that ParseInt32Flag() returns false and doesn't change the +// output value when the flag has wrong format +TEST(ParseInt32FlagTest, ReturnsFalseForInvalidFlag) { + int32_t value = 123; + EXPECT_FALSE(ParseFlag("--a=100", "b", &value)); + EXPECT_EQ(123, value); + + EXPECT_FALSE(ParseFlag("a=100", "a", &value)); + EXPECT_EQ(123, value); +} + +// Tests that ParseFlag() returns false and doesn't change the +// output value when the flag overflows as an Int32. +TEST(ParseInt32FlagTest, ReturnsDefaultWhenValueOverflows) { + printf("(expecting 2 warnings)\n"); + + int32_t value = 123; + EXPECT_FALSE(ParseFlag("--abc=12345678987654321", "abc", &value)); + EXPECT_EQ(123, value); + + EXPECT_FALSE(ParseFlag("--abc=-12345678987654321", "abc", &value)); + EXPECT_EQ(123, value); +} + +// Tests that ParseInt32Flag() returns false and doesn't change the +// output value when the flag does not represent a valid decimal +// integer. +TEST(ParseInt32FlagTest, ReturnsDefaultWhenValueIsInvalid) { + printf("(expecting 2 warnings)\n"); + + int32_t value = 123; + EXPECT_FALSE(ParseFlag("--abc=A1", "abc", &value)); + EXPECT_EQ(123, value); + + EXPECT_FALSE(ParseFlag("--abc=12X", "abc", &value)); + EXPECT_EQ(123, value); +} + +// Tests that ParseInt32Flag() parses the value of the flag and +// returns true when the flag represents a valid decimal integer in +// the range of an Int32. +TEST(ParseInt32FlagTest, ParsesAndReturnsValidValue) { + int32_t value = 123; + EXPECT_TRUE(ParseFlag("--" GTEST_FLAG_PREFIX_ "abc=456", "abc", &value)); + EXPECT_EQ(456, value); + + EXPECT_TRUE(ParseFlag("--" GTEST_FLAG_PREFIX_ "abc=-789", "abc", &value)); + EXPECT_EQ(-789, value); +} + +// Tests that Int32FromEnvOrDie() parses the value of the var or +// returns the correct default. +// Environment variables are not supported on Windows CE. +#if !GTEST_OS_WINDOWS_MOBILE +TEST(Int32FromEnvOrDieTest, ParsesAndReturnsValidValue) { + EXPECT_EQ(333, Int32FromEnvOrDie(GTEST_FLAG_PREFIX_UPPER_ "UnsetVar", 333)); + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "UnsetVar", "123"); + EXPECT_EQ(123, Int32FromEnvOrDie(GTEST_FLAG_PREFIX_UPPER_ "UnsetVar", 333)); + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "UnsetVar", "-123"); + EXPECT_EQ(-123, Int32FromEnvOrDie(GTEST_FLAG_PREFIX_UPPER_ "UnsetVar", 333)); +} +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Tests that Int32FromEnvOrDie() aborts with an error message +// if the variable is not an int32_t. +TEST(Int32FromEnvOrDieDeathTest, AbortsOnFailure) { + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "VAR", "xxx"); + EXPECT_DEATH_IF_SUPPORTED( + Int32FromEnvOrDie(GTEST_FLAG_PREFIX_UPPER_ "VAR", 123), + ".*"); +} + +// Tests that Int32FromEnvOrDie() aborts with an error message +// if the variable cannot be represented by an int32_t. +TEST(Int32FromEnvOrDieDeathTest, AbortsOnInt32Overflow) { + SetEnv(GTEST_FLAG_PREFIX_UPPER_ "VAR", "1234567891234567891234"); + EXPECT_DEATH_IF_SUPPORTED( + Int32FromEnvOrDie(GTEST_FLAG_PREFIX_UPPER_ "VAR", 123), + ".*"); +} + +// Tests that ShouldRunTestOnShard() selects all tests +// where there is 1 shard. +TEST(ShouldRunTestOnShardTest, IsPartitionWhenThereIsOneShard) { + EXPECT_TRUE(ShouldRunTestOnShard(1, 0, 0)); + EXPECT_TRUE(ShouldRunTestOnShard(1, 0, 1)); + EXPECT_TRUE(ShouldRunTestOnShard(1, 0, 2)); + EXPECT_TRUE(ShouldRunTestOnShard(1, 0, 3)); + EXPECT_TRUE(ShouldRunTestOnShard(1, 0, 4)); +} + +class ShouldShardTest : public testing::Test { + protected: + void SetUp() override { + index_var_ = GTEST_FLAG_PREFIX_UPPER_ "INDEX"; + total_var_ = GTEST_FLAG_PREFIX_UPPER_ "TOTAL"; + } + + void TearDown() override { + SetEnv(index_var_, ""); + SetEnv(total_var_, ""); + } + + const char* index_var_; + const char* total_var_; +}; + +// Tests that sharding is disabled if neither of the environment variables +// are set. +TEST_F(ShouldShardTest, ReturnsFalseWhenNeitherEnvVarIsSet) { + SetEnv(index_var_, ""); + SetEnv(total_var_, ""); + + EXPECT_FALSE(ShouldShard(total_var_, index_var_, false)); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, true)); +} + +// Tests that sharding is not enabled if total_shards == 1. +TEST_F(ShouldShardTest, ReturnsFalseWhenTotalShardIsOne) { + SetEnv(index_var_, "0"); + SetEnv(total_var_, "1"); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, false)); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, true)); +} + +// Tests that sharding is enabled if total_shards > 1 and +// we are not in a death test subprocess. +// Environment variables are not supported on Windows CE. +#if !GTEST_OS_WINDOWS_MOBILE +TEST_F(ShouldShardTest, WorksWhenShardEnvVarsAreValid) { + SetEnv(index_var_, "4"); + SetEnv(total_var_, "22"); + EXPECT_TRUE(ShouldShard(total_var_, index_var_, false)); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, true)); + + SetEnv(index_var_, "8"); + SetEnv(total_var_, "9"); + EXPECT_TRUE(ShouldShard(total_var_, index_var_, false)); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, true)); + + SetEnv(index_var_, "0"); + SetEnv(total_var_, "9"); + EXPECT_TRUE(ShouldShard(total_var_, index_var_, false)); + EXPECT_FALSE(ShouldShard(total_var_, index_var_, true)); +} +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Tests that we exit in error if the sharding values are not valid. + +typedef ShouldShardTest ShouldShardDeathTest; + +TEST_F(ShouldShardDeathTest, AbortsWhenShardingEnvVarsAreInvalid) { + SetEnv(index_var_, "4"); + SetEnv(total_var_, "4"); + EXPECT_DEATH_IF_SUPPORTED(ShouldShard(total_var_, index_var_, false), ".*"); + + SetEnv(index_var_, "4"); + SetEnv(total_var_, "-2"); + EXPECT_DEATH_IF_SUPPORTED(ShouldShard(total_var_, index_var_, false), ".*"); + + SetEnv(index_var_, "5"); + SetEnv(total_var_, ""); + EXPECT_DEATH_IF_SUPPORTED(ShouldShard(total_var_, index_var_, false), ".*"); + + SetEnv(index_var_, ""); + SetEnv(total_var_, "5"); + EXPECT_DEATH_IF_SUPPORTED(ShouldShard(total_var_, index_var_, false), ".*"); +} + +// Tests that ShouldRunTestOnShard is a partition when 5 +// shards are used. +TEST(ShouldRunTestOnShardTest, IsPartitionWhenThereAreFiveShards) { + // Choose an arbitrary number of tests and shards. + const int num_tests = 17; + const int num_shards = 5; + + // Check partitioning: each test should be on exactly 1 shard. + for (int test_id = 0; test_id < num_tests; test_id++) { + int prev_selected_shard_index = -1; + for (int shard_index = 0; shard_index < num_shards; shard_index++) { + if (ShouldRunTestOnShard(num_shards, shard_index, test_id)) { + if (prev_selected_shard_index < 0) { + prev_selected_shard_index = shard_index; + } else { + ADD_FAILURE() << "Shard " << prev_selected_shard_index << " and " + << shard_index << " are both selected to run test " << test_id; + } + } + } + } + + // Check balance: This is not required by the sharding protocol, but is a + // desirable property for performance. + for (int shard_index = 0; shard_index < num_shards; shard_index++) { + int num_tests_on_shard = 0; + for (int test_id = 0; test_id < num_tests; test_id++) { + num_tests_on_shard += + ShouldRunTestOnShard(num_shards, shard_index, test_id); + } + EXPECT_GE(num_tests_on_shard, num_tests / num_shards); + } +} + +// For the same reason we are not explicitly testing everything in the +// Test class, there are no separate tests for the following classes +// (except for some trivial cases): +// +// TestSuite, UnitTest, UnitTestResultPrinter. +// +// Similarly, there are no separate tests for the following macros: +// +// TEST, TEST_F, RUN_ALL_TESTS + +TEST(UnitTestTest, CanGetOriginalWorkingDir) { + ASSERT_TRUE(UnitTest::GetInstance()->original_working_dir() != nullptr); + EXPECT_STRNE(UnitTest::GetInstance()->original_working_dir(), ""); +} + +TEST(UnitTestTest, ReturnsPlausibleTimestamp) { + EXPECT_LT(0, UnitTest::GetInstance()->start_timestamp()); + EXPECT_LE(UnitTest::GetInstance()->start_timestamp(), GetTimeInMillis()); +} + +// When a property using a reserved key is supplied to this function, it +// tests that a non-fatal failure is added, a fatal failure is not added, +// and that the property is not recorded. +void ExpectNonFatalFailureRecordingPropertyWithReservedKey( + const TestResult& test_result, const char* key) { + EXPECT_NONFATAL_FAILURE(Test::RecordProperty(key, "1"), "Reserved key"); + ASSERT_EQ(0, test_result.test_property_count()) << "Property for key '" << key + << "' recorded unexpectedly."; +} + +void ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + const char* key) { + const TestInfo* test_info = UnitTest::GetInstance()->current_test_info(); + ASSERT_TRUE(test_info != nullptr); + ExpectNonFatalFailureRecordingPropertyWithReservedKey(*test_info->result(), + key); +} + +void ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + const char* key) { + const testing::TestSuite* test_suite = + UnitTest::GetInstance()->current_test_suite(); + ASSERT_TRUE(test_suite != nullptr); + ExpectNonFatalFailureRecordingPropertyWithReservedKey( + test_suite->ad_hoc_test_result(), key); +} + +void ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + const char* key) { + ExpectNonFatalFailureRecordingPropertyWithReservedKey( + UnitTest::GetInstance()->ad_hoc_test_result(), key); +} + +// Tests that property recording functions in UnitTest outside of tests +// functions correctly. Creating a separate instance of UnitTest ensures it +// is in a state similar to the UnitTest's singleton's between tests. +class UnitTestRecordPropertyTest : + public testing::internal::UnitTestRecordPropertyTestHelper { + public: + static void SetUpTestSuite() { + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "disabled"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "errors"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "failures"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "name"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "tests"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTestSuite( + "time"); + + Test::RecordProperty("test_case_key_1", "1"); + + const testing::TestSuite* test_suite = + UnitTest::GetInstance()->current_test_suite(); + + ASSERT_TRUE(test_suite != nullptr); + + ASSERT_EQ(1, test_suite->ad_hoc_test_result().test_property_count()); + EXPECT_STREQ("test_case_key_1", + test_suite->ad_hoc_test_result().GetTestProperty(0).key()); + EXPECT_STREQ("1", + test_suite->ad_hoc_test_result().GetTestProperty(0).value()); + } +}; + +// Tests TestResult has the expected property when added. +TEST_F(UnitTestRecordPropertyTest, OnePropertyFoundWhenAdded) { + UnitTestRecordProperty("key_1", "1"); + + ASSERT_EQ(1, unit_test_.ad_hoc_test_result().test_property_count()); + + EXPECT_STREQ("key_1", + unit_test_.ad_hoc_test_result().GetTestProperty(0).key()); + EXPECT_STREQ("1", + unit_test_.ad_hoc_test_result().GetTestProperty(0).value()); +} + +// Tests TestResult has multiple properties when added. +TEST_F(UnitTestRecordPropertyTest, MultiplePropertiesFoundWhenAdded) { + UnitTestRecordProperty("key_1", "1"); + UnitTestRecordProperty("key_2", "2"); + + ASSERT_EQ(2, unit_test_.ad_hoc_test_result().test_property_count()); + + EXPECT_STREQ("key_1", + unit_test_.ad_hoc_test_result().GetTestProperty(0).key()); + EXPECT_STREQ("1", unit_test_.ad_hoc_test_result().GetTestProperty(0).value()); + + EXPECT_STREQ("key_2", + unit_test_.ad_hoc_test_result().GetTestProperty(1).key()); + EXPECT_STREQ("2", unit_test_.ad_hoc_test_result().GetTestProperty(1).value()); +} + +// Tests TestResult::RecordProperty() overrides values for duplicate keys. +TEST_F(UnitTestRecordPropertyTest, OverridesValuesForDuplicateKeys) { + UnitTestRecordProperty("key_1", "1"); + UnitTestRecordProperty("key_2", "2"); + UnitTestRecordProperty("key_1", "12"); + UnitTestRecordProperty("key_2", "22"); + + ASSERT_EQ(2, unit_test_.ad_hoc_test_result().test_property_count()); + + EXPECT_STREQ("key_1", + unit_test_.ad_hoc_test_result().GetTestProperty(0).key()); + EXPECT_STREQ("12", + unit_test_.ad_hoc_test_result().GetTestProperty(0).value()); + + EXPECT_STREQ("key_2", + unit_test_.ad_hoc_test_result().GetTestProperty(1).key()); + EXPECT_STREQ("22", + unit_test_.ad_hoc_test_result().GetTestProperty(1).value()); +} + +TEST_F(UnitTestRecordPropertyTest, + AddFailureInsideTestsWhenUsingTestSuiteReservedKeys) { + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "name"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "value_param"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "type_param"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "status"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "time"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyForCurrentTest( + "classname"); +} + +TEST_F(UnitTestRecordPropertyTest, + AddRecordWithReservedKeysGeneratesCorrectPropertyList) { + EXPECT_NONFATAL_FAILURE( + Test::RecordProperty("name", "1"), + "'classname', 'name', 'status', 'time', 'type_param', 'value_param'," + " 'file', and 'line' are reserved"); +} + +class UnitTestRecordPropertyTestEnvironment : public Environment { + public: + void TearDown() override { + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "tests"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "failures"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "disabled"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "errors"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "name"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "timestamp"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "time"); + ExpectNonFatalFailureRecordingPropertyWithReservedKeyOutsideOfTestSuite( + "random_seed"); + } +}; + +// This will test property recording outside of any test or test case. +static Environment* record_property_env GTEST_ATTRIBUTE_UNUSED_ = + AddGlobalTestEnvironment(new UnitTestRecordPropertyTestEnvironment); + +// This group of tests is for predicate assertions (ASSERT_PRED*, etc) +// of various arities. They do not attempt to be exhaustive. Rather, +// view them as smoke tests that can be easily reviewed and verified. +// A more complete set of tests for predicate assertions can be found +// in gtest_pred_impl_unittest.cc. + +// First, some predicates and predicate-formatters needed by the tests. + +// Returns true if and only if the argument is an even number. +bool IsEven(int n) { + return (n % 2) == 0; +} + +// A functor that returns true if and only if the argument is an even number. +struct IsEvenFunctor { + bool operator()(int n) { return IsEven(n); } +}; + +// A predicate-formatter function that asserts the argument is an even +// number. +AssertionResult AssertIsEven(const char* expr, int n) { + if (IsEven(n)) { + return AssertionSuccess(); + } + + Message msg; + msg << expr << " evaluates to " << n << ", which is not even."; + return AssertionFailure(msg); +} + +// A predicate function that returns AssertionResult for use in +// EXPECT/ASSERT_TRUE/FALSE. +AssertionResult ResultIsEven(int n) { + if (IsEven(n)) + return AssertionSuccess() << n << " is even"; + else + return AssertionFailure() << n << " is odd"; +} + +// A predicate function that returns AssertionResult but gives no +// explanation why it succeeds. Needed for testing that +// EXPECT/ASSERT_FALSE handles such functions correctly. +AssertionResult ResultIsEvenNoExplanation(int n) { + if (IsEven(n)) + return AssertionSuccess(); + else + return AssertionFailure() << n << " is odd"; +} + +// A predicate-formatter functor that asserts the argument is an even +// number. +struct AssertIsEvenFunctor { + AssertionResult operator()(const char* expr, int n) { + return AssertIsEven(expr, n); + } +}; + +// Returns true if and only if the sum of the arguments is an even number. +bool SumIsEven2(int n1, int n2) { + return IsEven(n1 + n2); +} + +// A functor that returns true if and only if the sum of the arguments is an +// even number. +struct SumIsEven3Functor { + bool operator()(int n1, int n2, int n3) { + return IsEven(n1 + n2 + n3); + } +}; + +// A predicate-formatter function that asserts the sum of the +// arguments is an even number. +AssertionResult AssertSumIsEven4( + const char* e1, const char* e2, const char* e3, const char* e4, + int n1, int n2, int n3, int n4) { + const int sum = n1 + n2 + n3 + n4; + if (IsEven(sum)) { + return AssertionSuccess(); + } + + Message msg; + msg << e1 << " + " << e2 << " + " << e3 << " + " << e4 + << " (" << n1 << " + " << n2 << " + " << n3 << " + " << n4 + << ") evaluates to " << sum << ", which is not even."; + return AssertionFailure(msg); +} + +// A predicate-formatter functor that asserts the sum of the arguments +// is an even number. +struct AssertSumIsEven5Functor { + AssertionResult operator()( + const char* e1, const char* e2, const char* e3, const char* e4, + const char* e5, int n1, int n2, int n3, int n4, int n5) { + const int sum = n1 + n2 + n3 + n4 + n5; + if (IsEven(sum)) { + return AssertionSuccess(); + } + + Message msg; + msg << e1 << " + " << e2 << " + " << e3 << " + " << e4 << " + " << e5 + << " (" + << n1 << " + " << n2 << " + " << n3 << " + " << n4 << " + " << n5 + << ") evaluates to " << sum << ", which is not even."; + return AssertionFailure(msg); + } +}; + + +// Tests unary predicate assertions. + +// Tests unary predicate assertions that don't use a custom formatter. +TEST(Pred1Test, WithoutFormat) { + // Success cases. + EXPECT_PRED1(IsEvenFunctor(), 2) << "This failure is UNEXPECTED!"; + ASSERT_PRED1(IsEven, 4); + + // Failure cases. + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED1(IsEven, 5) << "This failure is expected."; + }, "This failure is expected."); + EXPECT_FATAL_FAILURE(ASSERT_PRED1(IsEvenFunctor(), 5), + "evaluates to false"); +} + +// Tests unary predicate assertions that use a custom formatter. +TEST(Pred1Test, WithFormat) { + // Success cases. + EXPECT_PRED_FORMAT1(AssertIsEven, 2); + ASSERT_PRED_FORMAT1(AssertIsEvenFunctor(), 4) + << "This failure is UNEXPECTED!"; + + // Failure cases. + const int n = 5; + EXPECT_NONFATAL_FAILURE(EXPECT_PRED_FORMAT1(AssertIsEvenFunctor(), n), + "n evaluates to 5, which is not even."); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(AssertIsEven, 5) << "This failure is expected."; + }, "This failure is expected."); +} + +// Tests that unary predicate assertions evaluates their arguments +// exactly once. +TEST(Pred1Test, SingleEvaluationOnFailure) { + // A success case. + static int n = 0; + EXPECT_PRED1(IsEven, n++); + EXPECT_EQ(1, n) << "The argument is not evaluated exactly once."; + + // A failure case. + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT1(AssertIsEvenFunctor(), n++) + << "This failure is expected."; + }, "This failure is expected."); + EXPECT_EQ(2, n) << "The argument is not evaluated exactly once."; +} + + +// Tests predicate assertions whose arity is >= 2. + +// Tests predicate assertions that don't use a custom formatter. +TEST(PredTest, WithoutFormat) { + // Success cases. + ASSERT_PRED2(SumIsEven2, 2, 4) << "This failure is UNEXPECTED!"; + EXPECT_PRED3(SumIsEven3Functor(), 4, 6, 8); + + // Failure cases. + const int n1 = 1; + const int n2 = 2; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED2(SumIsEven2, n1, n2) << "This failure is expected."; + }, "This failure is expected."); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED3(SumIsEven3Functor(), 1, 2, 4); + }, "evaluates to false"); +} + +// Tests predicate assertions that use a custom formatter. +TEST(PredTest, WithFormat) { + // Success cases. + ASSERT_PRED_FORMAT4(AssertSumIsEven4, 4, 6, 8, 10) << + "This failure is UNEXPECTED!"; + EXPECT_PRED_FORMAT5(AssertSumIsEven5Functor(), 2, 4, 6, 8, 10); + + // Failure cases. + const int n1 = 1; + const int n2 = 2; + const int n3 = 4; + const int n4 = 6; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(AssertSumIsEven4, n1, n2, n3, n4); + }, "evaluates to 13, which is not even."); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT5(AssertSumIsEven5Functor(), 1, 2, 4, 6, 8) + << "This failure is expected."; + }, "This failure is expected."); +} + +// Tests that predicate assertions evaluates their arguments +// exactly once. +TEST(PredTest, SingleEvaluationOnFailure) { + // A success case. + int n1 = 0; + int n2 = 0; + EXPECT_PRED2(SumIsEven2, n1++, n2++); + EXPECT_EQ(1, n1) << "Argument 1 is not evaluated exactly once."; + EXPECT_EQ(1, n2) << "Argument 2 is not evaluated exactly once."; + + // Another success case. + n1 = n2 = 0; + int n3 = 0; + int n4 = 0; + int n5 = 0; + ASSERT_PRED_FORMAT5(AssertSumIsEven5Functor(), + n1++, n2++, n3++, n4++, n5++) + << "This failure is UNEXPECTED!"; + EXPECT_EQ(1, n1) << "Argument 1 is not evaluated exactly once."; + EXPECT_EQ(1, n2) << "Argument 2 is not evaluated exactly once."; + EXPECT_EQ(1, n3) << "Argument 3 is not evaluated exactly once."; + EXPECT_EQ(1, n4) << "Argument 4 is not evaluated exactly once."; + EXPECT_EQ(1, n5) << "Argument 5 is not evaluated exactly once."; + + // A failure case. + n1 = n2 = n3 = 0; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED3(SumIsEven3Functor(), ++n1, n2++, n3++) + << "This failure is expected."; + }, "This failure is expected."); + EXPECT_EQ(1, n1) << "Argument 1 is not evaluated exactly once."; + EXPECT_EQ(1, n2) << "Argument 2 is not evaluated exactly once."; + EXPECT_EQ(1, n3) << "Argument 3 is not evaluated exactly once."; + + // Another failure case. + n1 = n2 = n3 = n4 = 0; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT4(AssertSumIsEven4, ++n1, n2++, n3++, n4++); + }, "evaluates to 1, which is not even."); + EXPECT_EQ(1, n1) << "Argument 1 is not evaluated exactly once."; + EXPECT_EQ(1, n2) << "Argument 2 is not evaluated exactly once."; + EXPECT_EQ(1, n3) << "Argument 3 is not evaluated exactly once."; + EXPECT_EQ(1, n4) << "Argument 4 is not evaluated exactly once."; +} + +// Test predicate assertions for sets +TEST(PredTest, ExpectPredEvalFailure) { + std::set set_a = {2, 1, 3, 4, 5}; + std::set set_b = {0, 4, 8}; + const auto compare_sets = [] (std::set, std::set) { return false; }; + EXPECT_NONFATAL_FAILURE( + EXPECT_PRED2(compare_sets, set_a, set_b), + "compare_sets(set_a, set_b) evaluates to false, where\nset_a evaluates " + "to { 1, 2, 3, 4, 5 }\nset_b evaluates to { 0, 4, 8 }"); +} + +// Some helper functions for testing using overloaded/template +// functions with ASSERT_PREDn and EXPECT_PREDn. + +bool IsPositive(double x) { + return x > 0; +} + +template +bool IsNegative(T x) { + return x < 0; +} + +template +bool GreaterThan(T1 x1, T2 x2) { + return x1 > x2; +} + +// Tests that overloaded functions can be used in *_PRED* as long as +// their types are explicitly specified. +TEST(PredicateAssertionTest, AcceptsOverloadedFunction) { + // C++Builder requires C-style casts rather than static_cast. + EXPECT_PRED1((bool (*)(int))(IsPositive), 5); // NOLINT + ASSERT_PRED1((bool (*)(double))(IsPositive), 6.0); // NOLINT +} + +// Tests that template functions can be used in *_PRED* as long as +// their types are explicitly specified. +TEST(PredicateAssertionTest, AcceptsTemplateFunction) { + EXPECT_PRED1(IsNegative, -5); + // Makes sure that we can handle templates with more than one + // parameter. + ASSERT_PRED2((GreaterThan), 5, 0); +} + + +// Some helper functions for testing using overloaded/template +// functions with ASSERT_PRED_FORMATn and EXPECT_PRED_FORMATn. + +AssertionResult IsPositiveFormat(const char* /* expr */, int n) { + return n > 0 ? AssertionSuccess() : + AssertionFailure(Message() << "Failure"); +} + +AssertionResult IsPositiveFormat(const char* /* expr */, double x) { + return x > 0 ? AssertionSuccess() : + AssertionFailure(Message() << "Failure"); +} + +template +AssertionResult IsNegativeFormat(const char* /* expr */, T x) { + return x < 0 ? AssertionSuccess() : + AssertionFailure(Message() << "Failure"); +} + +template +AssertionResult EqualsFormat(const char* /* expr1 */, const char* /* expr2 */, + const T1& x1, const T2& x2) { + return x1 == x2 ? AssertionSuccess() : + AssertionFailure(Message() << "Failure"); +} + +// Tests that overloaded functions can be used in *_PRED_FORMAT* +// without explicitly specifying their types. +TEST(PredicateFormatAssertionTest, AcceptsOverloadedFunction) { + EXPECT_PRED_FORMAT1(IsPositiveFormat, 5); + ASSERT_PRED_FORMAT1(IsPositiveFormat, 6.0); +} + +// Tests that template functions can be used in *_PRED_FORMAT* without +// explicitly specifying their types. +TEST(PredicateFormatAssertionTest, AcceptsTemplateFunction) { + EXPECT_PRED_FORMAT1(IsNegativeFormat, -5); + ASSERT_PRED_FORMAT2(EqualsFormat, 3, 3); +} + + +// Tests string assertions. + +// Tests ASSERT_STREQ with non-NULL arguments. +TEST(StringAssertionTest, ASSERT_STREQ) { + const char * const p1 = "good"; + ASSERT_STREQ(p1, p1); + + // Let p2 have the same content as p1, but be at a different address. + const char p2[] = "good"; + ASSERT_STREQ(p1, p2); + + EXPECT_FATAL_FAILURE(ASSERT_STREQ("bad", "good"), + " \"bad\"\n \"good\""); +} + +// Tests ASSERT_STREQ with NULL arguments. +TEST(StringAssertionTest, ASSERT_STREQ_Null) { + ASSERT_STREQ(static_cast(nullptr), nullptr); + EXPECT_FATAL_FAILURE(ASSERT_STREQ(nullptr, "non-null"), "non-null"); +} + +// Tests ASSERT_STREQ with NULL arguments. +TEST(StringAssertionTest, ASSERT_STREQ_Null2) { + EXPECT_FATAL_FAILURE(ASSERT_STREQ("non-null", nullptr), "non-null"); +} + +// Tests ASSERT_STRNE. +TEST(StringAssertionTest, ASSERT_STRNE) { + ASSERT_STRNE("hi", "Hi"); + ASSERT_STRNE("Hi", nullptr); + ASSERT_STRNE(nullptr, "Hi"); + ASSERT_STRNE("", nullptr); + ASSERT_STRNE(nullptr, ""); + ASSERT_STRNE("", "Hi"); + ASSERT_STRNE("Hi", ""); + EXPECT_FATAL_FAILURE(ASSERT_STRNE("Hi", "Hi"), + "\"Hi\" vs \"Hi\""); +} + +// Tests ASSERT_STRCASEEQ. +TEST(StringAssertionTest, ASSERT_STRCASEEQ) { + ASSERT_STRCASEEQ("hi", "Hi"); + ASSERT_STRCASEEQ(static_cast(nullptr), nullptr); + + ASSERT_STRCASEEQ("", ""); + EXPECT_FATAL_FAILURE(ASSERT_STRCASEEQ("Hi", "hi2"), + "Ignoring case"); +} + +// Tests ASSERT_STRCASENE. +TEST(StringAssertionTest, ASSERT_STRCASENE) { + ASSERT_STRCASENE("hi1", "Hi2"); + ASSERT_STRCASENE("Hi", nullptr); + ASSERT_STRCASENE(nullptr, "Hi"); + ASSERT_STRCASENE("", nullptr); + ASSERT_STRCASENE(nullptr, ""); + ASSERT_STRCASENE("", "Hi"); + ASSERT_STRCASENE("Hi", ""); + EXPECT_FATAL_FAILURE(ASSERT_STRCASENE("Hi", "hi"), + "(ignoring case)"); +} + +// Tests *_STREQ on wide strings. +TEST(StringAssertionTest, STREQ_Wide) { + // NULL strings. + ASSERT_STREQ(static_cast(nullptr), nullptr); + + // Empty strings. + ASSERT_STREQ(L"", L""); + + // Non-null vs NULL. + EXPECT_NONFATAL_FAILURE(EXPECT_STREQ(L"non-null", nullptr), "non-null"); + + // Equal strings. + EXPECT_STREQ(L"Hi", L"Hi"); + + // Unequal strings. + EXPECT_NONFATAL_FAILURE(EXPECT_STREQ(L"abc", L"Abc"), + "Abc"); + + // Strings containing wide characters. + EXPECT_NONFATAL_FAILURE(EXPECT_STREQ(L"abc\x8119", L"abc\x8120"), + "abc"); + + // The streaming variation. + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_STREQ(L"abc\x8119", L"abc\x8121") << "Expected failure"; + }, "Expected failure"); +} + +// Tests *_STRNE on wide strings. +TEST(StringAssertionTest, STRNE_Wide) { + // NULL strings. + EXPECT_NONFATAL_FAILURE( + { // NOLINT + EXPECT_STRNE(static_cast(nullptr), nullptr); + }, + ""); + + // Empty strings. + EXPECT_NONFATAL_FAILURE(EXPECT_STRNE(L"", L""), + "L\"\""); + + // Non-null vs NULL. + ASSERT_STRNE(L"non-null", nullptr); + + // Equal strings. + EXPECT_NONFATAL_FAILURE(EXPECT_STRNE(L"Hi", L"Hi"), + "L\"Hi\""); + + // Unequal strings. + EXPECT_STRNE(L"abc", L"Abc"); + + // Strings containing wide characters. + EXPECT_NONFATAL_FAILURE(EXPECT_STRNE(L"abc\x8119", L"abc\x8119"), + "abc"); + + // The streaming variation. + ASSERT_STRNE(L"abc\x8119", L"abc\x8120") << "This shouldn't happen"; +} + +// Tests for ::testing::IsSubstring(). + +// Tests that IsSubstring() returns the correct result when the input +// argument type is const char*. +TEST(IsSubstringTest, ReturnsCorrectResultForCString) { + EXPECT_FALSE(IsSubstring("", "", nullptr, "a")); + EXPECT_FALSE(IsSubstring("", "", "b", nullptr)); + EXPECT_FALSE(IsSubstring("", "", "needle", "haystack")); + + EXPECT_TRUE(IsSubstring("", "", static_cast(nullptr), nullptr)); + EXPECT_TRUE(IsSubstring("", "", "needle", "two needles")); +} + +// Tests that IsSubstring() returns the correct result when the input +// argument type is const wchar_t*. +TEST(IsSubstringTest, ReturnsCorrectResultForWideCString) { + EXPECT_FALSE(IsSubstring("", "", kNull, L"a")); + EXPECT_FALSE(IsSubstring("", "", L"b", kNull)); + EXPECT_FALSE(IsSubstring("", "", L"needle", L"haystack")); + + EXPECT_TRUE( + IsSubstring("", "", static_cast(nullptr), nullptr)); + EXPECT_TRUE(IsSubstring("", "", L"needle", L"two needles")); +} + +// Tests that IsSubstring() generates the correct message when the input +// argument type is const char*. +TEST(IsSubstringTest, GeneratesCorrectMessageForCString) { + EXPECT_STREQ("Value of: needle_expr\n" + " Actual: \"needle\"\n" + "Expected: a substring of haystack_expr\n" + "Which is: \"haystack\"", + IsSubstring("needle_expr", "haystack_expr", + "needle", "haystack").failure_message()); +} + +// Tests that IsSubstring returns the correct result when the input +// argument type is ::std::string. +TEST(IsSubstringTest, ReturnsCorrectResultsForStdString) { + EXPECT_TRUE(IsSubstring("", "", std::string("hello"), "ahellob")); + EXPECT_FALSE(IsSubstring("", "", "hello", std::string("world"))); +} + +#if GTEST_HAS_STD_WSTRING +// Tests that IsSubstring returns the correct result when the input +// argument type is ::std::wstring. +TEST(IsSubstringTest, ReturnsCorrectResultForStdWstring) { + EXPECT_TRUE(IsSubstring("", "", ::std::wstring(L"needle"), L"two needles")); + EXPECT_FALSE(IsSubstring("", "", L"needle", ::std::wstring(L"haystack"))); +} + +// Tests that IsSubstring() generates the correct message when the input +// argument type is ::std::wstring. +TEST(IsSubstringTest, GeneratesCorrectMessageForWstring) { + EXPECT_STREQ("Value of: needle_expr\n" + " Actual: L\"needle\"\n" + "Expected: a substring of haystack_expr\n" + "Which is: L\"haystack\"", + IsSubstring( + "needle_expr", "haystack_expr", + ::std::wstring(L"needle"), L"haystack").failure_message()); +} + +#endif // GTEST_HAS_STD_WSTRING + +// Tests for ::testing::IsNotSubstring(). + +// Tests that IsNotSubstring() returns the correct result when the input +// argument type is const char*. +TEST(IsNotSubstringTest, ReturnsCorrectResultForCString) { + EXPECT_TRUE(IsNotSubstring("", "", "needle", "haystack")); + EXPECT_FALSE(IsNotSubstring("", "", "needle", "two needles")); +} + +// Tests that IsNotSubstring() returns the correct result when the input +// argument type is const wchar_t*. +TEST(IsNotSubstringTest, ReturnsCorrectResultForWideCString) { + EXPECT_TRUE(IsNotSubstring("", "", L"needle", L"haystack")); + EXPECT_FALSE(IsNotSubstring("", "", L"needle", L"two needles")); +} + +// Tests that IsNotSubstring() generates the correct message when the input +// argument type is const wchar_t*. +TEST(IsNotSubstringTest, GeneratesCorrectMessageForWideCString) { + EXPECT_STREQ("Value of: needle_expr\n" + " Actual: L\"needle\"\n" + "Expected: not a substring of haystack_expr\n" + "Which is: L\"two needles\"", + IsNotSubstring( + "needle_expr", "haystack_expr", + L"needle", L"two needles").failure_message()); +} + +// Tests that IsNotSubstring returns the correct result when the input +// argument type is ::std::string. +TEST(IsNotSubstringTest, ReturnsCorrectResultsForStdString) { + EXPECT_FALSE(IsNotSubstring("", "", std::string("hello"), "ahellob")); + EXPECT_TRUE(IsNotSubstring("", "", "hello", std::string("world"))); +} + +// Tests that IsNotSubstring() generates the correct message when the input +// argument type is ::std::string. +TEST(IsNotSubstringTest, GeneratesCorrectMessageForStdString) { + EXPECT_STREQ("Value of: needle_expr\n" + " Actual: \"needle\"\n" + "Expected: not a substring of haystack_expr\n" + "Which is: \"two needles\"", + IsNotSubstring( + "needle_expr", "haystack_expr", + ::std::string("needle"), "two needles").failure_message()); +} + +#if GTEST_HAS_STD_WSTRING + +// Tests that IsNotSubstring returns the correct result when the input +// argument type is ::std::wstring. +TEST(IsNotSubstringTest, ReturnsCorrectResultForStdWstring) { + EXPECT_FALSE( + IsNotSubstring("", "", ::std::wstring(L"needle"), L"two needles")); + EXPECT_TRUE(IsNotSubstring("", "", L"needle", ::std::wstring(L"haystack"))); +} + +#endif // GTEST_HAS_STD_WSTRING + +// Tests floating-point assertions. + +template +class FloatingPointTest : public Test { + protected: + // Pre-calculated numbers to be used by the tests. + struct TestValues { + RawType close_to_positive_zero; + RawType close_to_negative_zero; + RawType further_from_negative_zero; + + RawType close_to_one; + RawType further_from_one; + + RawType infinity; + RawType close_to_infinity; + RawType further_from_infinity; + + RawType nan1; + RawType nan2; + }; + + typedef typename testing::internal::FloatingPoint Floating; + typedef typename Floating::Bits Bits; + + void SetUp() override { + const uint32_t max_ulps = Floating::kMaxUlps; + + // The bits that represent 0.0. + const Bits zero_bits = Floating(0).bits(); + + // Makes some numbers close to 0.0. + values_.close_to_positive_zero = Floating::ReinterpretBits( + zero_bits + max_ulps/2); + values_.close_to_negative_zero = -Floating::ReinterpretBits( + zero_bits + max_ulps - max_ulps/2); + values_.further_from_negative_zero = -Floating::ReinterpretBits( + zero_bits + max_ulps + 1 - max_ulps/2); + + // The bits that represent 1.0. + const Bits one_bits = Floating(1).bits(); + + // Makes some numbers close to 1.0. + values_.close_to_one = Floating::ReinterpretBits(one_bits + max_ulps); + values_.further_from_one = Floating::ReinterpretBits( + one_bits + max_ulps + 1); + + // +infinity. + values_.infinity = Floating::Infinity(); + + // The bits that represent +infinity. + const Bits infinity_bits = Floating(values_.infinity).bits(); + + // Makes some numbers close to infinity. + values_.close_to_infinity = Floating::ReinterpretBits( + infinity_bits - max_ulps); + values_.further_from_infinity = Floating::ReinterpretBits( + infinity_bits - max_ulps - 1); + + // Makes some NAN's. Sets the most significant bit of the fraction so that + // our NaN's are quiet; trying to process a signaling NaN would raise an + // exception if our environment enables floating point exceptions. + values_.nan1 = Floating::ReinterpretBits(Floating::kExponentBitMask + | (static_cast(1) << (Floating::kFractionBitCount - 1)) | 1); + values_.nan2 = Floating::ReinterpretBits(Floating::kExponentBitMask + | (static_cast(1) << (Floating::kFractionBitCount - 1)) | 200); + } + + void TestSize() { + EXPECT_EQ(sizeof(RawType), sizeof(Bits)); + } + + static TestValues values_; +}; + +template +typename FloatingPointTest::TestValues + FloatingPointTest::values_; + +// Instantiates FloatingPointTest for testing *_FLOAT_EQ. +typedef FloatingPointTest FloatTest; + +// Tests that the size of Float::Bits matches the size of float. +TEST_F(FloatTest, Size) { + TestSize(); +} + +// Tests comparing with +0 and -0. +TEST_F(FloatTest, Zeros) { + EXPECT_FLOAT_EQ(0.0, -0.0); + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(-0.0, 1.0), + "1.0"); + EXPECT_FATAL_FAILURE(ASSERT_FLOAT_EQ(0.0, 1.5), + "1.5"); +} + +// Tests comparing numbers close to 0. +// +// This ensures that *_FLOAT_EQ handles the sign correctly and no +// overflow occurs when comparing numbers whose absolute value is very +// small. +TEST_F(FloatTest, AlmostZeros) { + // In C++Builder, names within local classes (such as used by + // EXPECT_FATAL_FAILURE) cannot be resolved against static members of the + // scoping class. Use a static local alias as a workaround. + // We use the assignment syntax since some compilers, like Sun Studio, + // don't allow initializing references using construction syntax + // (parentheses). + static const FloatTest::TestValues& v = this->values_; + + EXPECT_FLOAT_EQ(0.0, v.close_to_positive_zero); + EXPECT_FLOAT_EQ(-0.0, v.close_to_negative_zero); + EXPECT_FLOAT_EQ(v.close_to_positive_zero, v.close_to_negative_zero); + + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_FLOAT_EQ(v.close_to_positive_zero, + v.further_from_negative_zero); + }, "v.further_from_negative_zero"); +} + +// Tests comparing numbers close to each other. +TEST_F(FloatTest, SmallDiff) { + EXPECT_FLOAT_EQ(1.0, values_.close_to_one); + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(1.0, values_.further_from_one), + "values_.further_from_one"); +} + +// Tests comparing numbers far apart. +TEST_F(FloatTest, LargeDiff) { + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(2.5, 3.0), + "3.0"); +} + +// Tests comparing with infinity. +// +// This ensures that no overflow occurs when comparing numbers whose +// absolute value is very large. +TEST_F(FloatTest, Infinity) { + EXPECT_FLOAT_EQ(values_.infinity, values_.close_to_infinity); + EXPECT_FLOAT_EQ(-values_.infinity, -values_.close_to_infinity); + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(values_.infinity, -values_.infinity), + "-values_.infinity"); + + // This is interesting as the representations of infinity and nan1 + // are only 1 DLP apart. + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(values_.infinity, values_.nan1), + "values_.nan1"); +} + +// Tests that comparing with NAN always returns false. +TEST_F(FloatTest, NaN) { + // In C++Builder, names within local classes (such as used by + // EXPECT_FATAL_FAILURE) cannot be resolved against static members of the + // scoping class. Use a static local alias as a workaround. + // We use the assignment syntax since some compilers, like Sun Studio, + // don't allow initializing references using construction syntax + // (parentheses). + static const FloatTest::TestValues& v = this->values_; + + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(v.nan1, v.nan1), + "v.nan1"); + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(v.nan1, v.nan2), + "v.nan2"); + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(1.0, v.nan1), + "v.nan1"); + + EXPECT_FATAL_FAILURE(ASSERT_FLOAT_EQ(v.nan1, v.infinity), + "v.infinity"); +} + +// Tests that *_FLOAT_EQ are reflexive. +TEST_F(FloatTest, Reflexive) { + EXPECT_FLOAT_EQ(0.0, 0.0); + EXPECT_FLOAT_EQ(1.0, 1.0); + ASSERT_FLOAT_EQ(values_.infinity, values_.infinity); +} + +// Tests that *_FLOAT_EQ are commutative. +TEST_F(FloatTest, Commutative) { + // We already tested EXPECT_FLOAT_EQ(1.0, values_.close_to_one). + EXPECT_FLOAT_EQ(values_.close_to_one, 1.0); + + // We already tested EXPECT_FLOAT_EQ(1.0, values_.further_from_one). + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(values_.further_from_one, 1.0), + "1.0"); +} + +// Tests EXPECT_NEAR. +TEST_F(FloatTest, EXPECT_NEAR) { + EXPECT_NEAR(-1.0f, -1.1f, 0.2f); + EXPECT_NEAR(2.0f, 3.0f, 1.0f); + EXPECT_NONFATAL_FAILURE(EXPECT_NEAR(1.0f, 1.5f, 0.25f), // NOLINT + "The difference between 1.0f and 1.5f is 0.5, " + "which exceeds 0.25f"); +} + +// Tests ASSERT_NEAR. +TEST_F(FloatTest, ASSERT_NEAR) { + ASSERT_NEAR(-1.0f, -1.1f, 0.2f); + ASSERT_NEAR(2.0f, 3.0f, 1.0f); + EXPECT_FATAL_FAILURE(ASSERT_NEAR(1.0f, 1.5f, 0.25f), // NOLINT + "The difference between 1.0f and 1.5f is 0.5, " + "which exceeds 0.25f"); +} + +// Tests the cases where FloatLE() should succeed. +TEST_F(FloatTest, FloatLESucceeds) { + EXPECT_PRED_FORMAT2(FloatLE, 1.0f, 2.0f); // When val1 < val2, + ASSERT_PRED_FORMAT2(FloatLE, 1.0f, 1.0f); // val1 == val2, + + // or when val1 is greater than, but almost equals to, val2. + EXPECT_PRED_FORMAT2(FloatLE, values_.close_to_positive_zero, 0.0f); +} + +// Tests the cases where FloatLE() should fail. +TEST_F(FloatTest, FloatLEFails) { + // When val1 is greater than val2 by a large margin, + EXPECT_NONFATAL_FAILURE(EXPECT_PRED_FORMAT2(FloatLE, 2.0f, 1.0f), + "(2.0f) <= (1.0f)"); + + // or by a small yet non-negligible margin, + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(FloatLE, values_.further_from_one, 1.0f); + }, "(values_.further_from_one) <= (1.0f)"); + + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(FloatLE, values_.nan1, values_.infinity); + }, "(values_.nan1) <= (values_.infinity)"); + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(FloatLE, -values_.infinity, values_.nan1); + }, "(-values_.infinity) <= (values_.nan1)"); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(FloatLE, values_.nan1, values_.nan1); + }, "(values_.nan1) <= (values_.nan1)"); +} + +// Instantiates FloatingPointTest for testing *_DOUBLE_EQ. +typedef FloatingPointTest DoubleTest; + +// Tests that the size of Double::Bits matches the size of double. +TEST_F(DoubleTest, Size) { + TestSize(); +} + +// Tests comparing with +0 and -0. +TEST_F(DoubleTest, Zeros) { + EXPECT_DOUBLE_EQ(0.0, -0.0); + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(-0.0, 1.0), + "1.0"); + EXPECT_FATAL_FAILURE(ASSERT_DOUBLE_EQ(0.0, 1.0), + "1.0"); +} + +// Tests comparing numbers close to 0. +// +// This ensures that *_DOUBLE_EQ handles the sign correctly and no +// overflow occurs when comparing numbers whose absolute value is very +// small. +TEST_F(DoubleTest, AlmostZeros) { + // In C++Builder, names within local classes (such as used by + // EXPECT_FATAL_FAILURE) cannot be resolved against static members of the + // scoping class. Use a static local alias as a workaround. + // We use the assignment syntax since some compilers, like Sun Studio, + // don't allow initializing references using construction syntax + // (parentheses). + static const DoubleTest::TestValues& v = this->values_; + + EXPECT_DOUBLE_EQ(0.0, v.close_to_positive_zero); + EXPECT_DOUBLE_EQ(-0.0, v.close_to_negative_zero); + EXPECT_DOUBLE_EQ(v.close_to_positive_zero, v.close_to_negative_zero); + + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_DOUBLE_EQ(v.close_to_positive_zero, + v.further_from_negative_zero); + }, "v.further_from_negative_zero"); +} + +// Tests comparing numbers close to each other. +TEST_F(DoubleTest, SmallDiff) { + EXPECT_DOUBLE_EQ(1.0, values_.close_to_one); + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(1.0, values_.further_from_one), + "values_.further_from_one"); +} + +// Tests comparing numbers far apart. +TEST_F(DoubleTest, LargeDiff) { + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(2.0, 3.0), + "3.0"); +} + +// Tests comparing with infinity. +// +// This ensures that no overflow occurs when comparing numbers whose +// absolute value is very large. +TEST_F(DoubleTest, Infinity) { + EXPECT_DOUBLE_EQ(values_.infinity, values_.close_to_infinity); + EXPECT_DOUBLE_EQ(-values_.infinity, -values_.close_to_infinity); + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(values_.infinity, -values_.infinity), + "-values_.infinity"); + + // This is interesting as the representations of infinity_ and nan1_ + // are only 1 DLP apart. + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(values_.infinity, values_.nan1), + "values_.nan1"); +} + +// Tests that comparing with NAN always returns false. +TEST_F(DoubleTest, NaN) { + static const DoubleTest::TestValues& v = this->values_; + + // Nokia's STLport crashes if we try to output infinity or NaN. + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(v.nan1, v.nan1), + "v.nan1"); + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(v.nan1, v.nan2), "v.nan2"); + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(1.0, v.nan1), "v.nan1"); + EXPECT_FATAL_FAILURE(ASSERT_DOUBLE_EQ(v.nan1, v.infinity), + "v.infinity"); +} + +// Tests that *_DOUBLE_EQ are reflexive. +TEST_F(DoubleTest, Reflexive) { + EXPECT_DOUBLE_EQ(0.0, 0.0); + EXPECT_DOUBLE_EQ(1.0, 1.0); + ASSERT_DOUBLE_EQ(values_.infinity, values_.infinity); +} + +// Tests that *_DOUBLE_EQ are commutative. +TEST_F(DoubleTest, Commutative) { + // We already tested EXPECT_DOUBLE_EQ(1.0, values_.close_to_one). + EXPECT_DOUBLE_EQ(values_.close_to_one, 1.0); + + // We already tested EXPECT_DOUBLE_EQ(1.0, values_.further_from_one). + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(values_.further_from_one, 1.0), + "1.0"); +} + +// Tests EXPECT_NEAR. +TEST_F(DoubleTest, EXPECT_NEAR) { + EXPECT_NEAR(-1.0, -1.1, 0.2); + EXPECT_NEAR(2.0, 3.0, 1.0); + EXPECT_NONFATAL_FAILURE(EXPECT_NEAR(1.0, 1.5, 0.25), // NOLINT + "The difference between 1.0 and 1.5 is 0.5, " + "which exceeds 0.25"); + // At this magnitude adjacent doubles are 512.0 apart, so this triggers a + // slightly different failure reporting path. + EXPECT_NONFATAL_FAILURE( + EXPECT_NEAR(4.2934311416234112e+18, 4.2934311416234107e+18, 1.0), + "The abs_error parameter 1.0 evaluates to 1 which is smaller than the " + "minimum distance between doubles for numbers of this magnitude which is " + "512"); +} + +// Tests ASSERT_NEAR. +TEST_F(DoubleTest, ASSERT_NEAR) { + ASSERT_NEAR(-1.0, -1.1, 0.2); + ASSERT_NEAR(2.0, 3.0, 1.0); + EXPECT_FATAL_FAILURE(ASSERT_NEAR(1.0, 1.5, 0.25), // NOLINT + "The difference between 1.0 and 1.5 is 0.5, " + "which exceeds 0.25"); +} + +// Tests the cases where DoubleLE() should succeed. +TEST_F(DoubleTest, DoubleLESucceeds) { + EXPECT_PRED_FORMAT2(DoubleLE, 1.0, 2.0); // When val1 < val2, + ASSERT_PRED_FORMAT2(DoubleLE, 1.0, 1.0); // val1 == val2, + + // or when val1 is greater than, but almost equals to, val2. + EXPECT_PRED_FORMAT2(DoubleLE, values_.close_to_positive_zero, 0.0); +} + +// Tests the cases where DoubleLE() should fail. +TEST_F(DoubleTest, DoubleLEFails) { + // When val1 is greater than val2 by a large margin, + EXPECT_NONFATAL_FAILURE(EXPECT_PRED_FORMAT2(DoubleLE, 2.0, 1.0), + "(2.0) <= (1.0)"); + + // or by a small yet non-negligible margin, + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(DoubleLE, values_.further_from_one, 1.0); + }, "(values_.further_from_one) <= (1.0)"); + + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(DoubleLE, values_.nan1, values_.infinity); + }, "(values_.nan1) <= (values_.infinity)"); + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_PRED_FORMAT2(DoubleLE, -values_.infinity, values_.nan1); + }, " (-values_.infinity) <= (values_.nan1)"); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_PRED_FORMAT2(DoubleLE, values_.nan1, values_.nan1); + }, "(values_.nan1) <= (values_.nan1)"); +} + + +// Verifies that a test or test case whose name starts with DISABLED_ is +// not run. + +// A test whose name starts with DISABLED_. +// Should not run. +TEST(DisabledTest, DISABLED_TestShouldNotRun) { + FAIL() << "Unexpected failure: Disabled test should not be run."; +} + +// A test whose name does not start with DISABLED_. +// Should run. +TEST(DisabledTest, NotDISABLED_TestShouldRun) { + EXPECT_EQ(1, 1); +} + +// A test case whose name starts with DISABLED_. +// Should not run. +TEST(DISABLED_TestSuite, TestShouldNotRun) { + FAIL() << "Unexpected failure: Test in disabled test case should not be run."; +} + +// A test case and test whose names start with DISABLED_. +// Should not run. +TEST(DISABLED_TestSuite, DISABLED_TestShouldNotRun) { + FAIL() << "Unexpected failure: Test in disabled test case should not be run."; +} + +// Check that when all tests in a test case are disabled, SetUpTestSuite() and +// TearDownTestSuite() are not called. +class DisabledTestsTest : public Test { + protected: + static void SetUpTestSuite() { + FAIL() << "Unexpected failure: All tests disabled in test case. " + "SetUpTestSuite() should not be called."; + } + + static void TearDownTestSuite() { + FAIL() << "Unexpected failure: All tests disabled in test case. " + "TearDownTestSuite() should not be called."; + } +}; + +TEST_F(DisabledTestsTest, DISABLED_TestShouldNotRun_1) { + FAIL() << "Unexpected failure: Disabled test should not be run."; +} + +TEST_F(DisabledTestsTest, DISABLED_TestShouldNotRun_2) { + FAIL() << "Unexpected failure: Disabled test should not be run."; +} + +// Tests that disabled typed tests aren't run. + +template +class TypedTest : public Test { +}; + +typedef testing::Types NumericTypes; +TYPED_TEST_SUITE(TypedTest, NumericTypes); + +TYPED_TEST(TypedTest, DISABLED_ShouldNotRun) { + FAIL() << "Unexpected failure: Disabled typed test should not run."; +} + +template +class DISABLED_TypedTest : public Test { +}; + +TYPED_TEST_SUITE(DISABLED_TypedTest, NumericTypes); + +TYPED_TEST(DISABLED_TypedTest, ShouldNotRun) { + FAIL() << "Unexpected failure: Disabled typed test should not run."; +} + +// Tests that disabled type-parameterized tests aren't run. + +template +class TypedTestP : public Test { +}; + +TYPED_TEST_SUITE_P(TypedTestP); + +TYPED_TEST_P(TypedTestP, DISABLED_ShouldNotRun) { + FAIL() << "Unexpected failure: " + << "Disabled type-parameterized test should not run."; +} + +REGISTER_TYPED_TEST_SUITE_P(TypedTestP, DISABLED_ShouldNotRun); + +INSTANTIATE_TYPED_TEST_SUITE_P(My, TypedTestP, NumericTypes); + +template +class DISABLED_TypedTestP : public Test { +}; + +TYPED_TEST_SUITE_P(DISABLED_TypedTestP); + +TYPED_TEST_P(DISABLED_TypedTestP, ShouldNotRun) { + FAIL() << "Unexpected failure: " + << "Disabled type-parameterized test should not run."; +} + +REGISTER_TYPED_TEST_SUITE_P(DISABLED_TypedTestP, ShouldNotRun); + +INSTANTIATE_TYPED_TEST_SUITE_P(My, DISABLED_TypedTestP, NumericTypes); + +// Tests that assertion macros evaluate their arguments exactly once. + +class SingleEvaluationTest : public Test { + public: // Must be public and not protected due to a bug in g++ 3.4.2. + // This helper function is needed by the FailedASSERT_STREQ test + // below. It's public to work around C++Builder's bug with scoping local + // classes. + static void CompareAndIncrementCharPtrs() { + ASSERT_STREQ(p1_++, p2_++); + } + + // This helper function is needed by the FailedASSERT_NE test below. It's + // public to work around C++Builder's bug with scoping local classes. + static void CompareAndIncrementInts() { + ASSERT_NE(a_++, b_++); + } + + protected: + SingleEvaluationTest() { + p1_ = s1_; + p2_ = s2_; + a_ = 0; + b_ = 0; + } + + static const char* const s1_; + static const char* const s2_; + static const char* p1_; + static const char* p2_; + + static int a_; + static int b_; +}; + +const char* const SingleEvaluationTest::s1_ = "01234"; +const char* const SingleEvaluationTest::s2_ = "abcde"; +const char* SingleEvaluationTest::p1_; +const char* SingleEvaluationTest::p2_; +int SingleEvaluationTest::a_; +int SingleEvaluationTest::b_; + +// Tests that when ASSERT_STREQ fails, it evaluates its arguments +// exactly once. +TEST_F(SingleEvaluationTest, FailedASSERT_STREQ) { + EXPECT_FATAL_FAILURE(SingleEvaluationTest::CompareAndIncrementCharPtrs(), + "p2_++"); + EXPECT_EQ(s1_ + 1, p1_); + EXPECT_EQ(s2_ + 1, p2_); +} + +// Tests that string assertion arguments are evaluated exactly once. +TEST_F(SingleEvaluationTest, ASSERT_STR) { + // successful EXPECT_STRNE + EXPECT_STRNE(p1_++, p2_++); + EXPECT_EQ(s1_ + 1, p1_); + EXPECT_EQ(s2_ + 1, p2_); + + // failed EXPECT_STRCASEEQ + EXPECT_NONFATAL_FAILURE(EXPECT_STRCASEEQ(p1_++, p2_++), + "Ignoring case"); + EXPECT_EQ(s1_ + 2, p1_); + EXPECT_EQ(s2_ + 2, p2_); +} + +// Tests that when ASSERT_NE fails, it evaluates its arguments exactly +// once. +TEST_F(SingleEvaluationTest, FailedASSERT_NE) { + EXPECT_FATAL_FAILURE(SingleEvaluationTest::CompareAndIncrementInts(), + "(a_++) != (b_++)"); + EXPECT_EQ(1, a_); + EXPECT_EQ(1, b_); +} + +// Tests that assertion arguments are evaluated exactly once. +TEST_F(SingleEvaluationTest, OtherCases) { + // successful EXPECT_TRUE + EXPECT_TRUE(0 == a_++); // NOLINT + EXPECT_EQ(1, a_); + + // failed EXPECT_TRUE + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(-1 == a_++), "-1 == a_++"); + EXPECT_EQ(2, a_); + + // successful EXPECT_GT + EXPECT_GT(a_++, b_++); + EXPECT_EQ(3, a_); + EXPECT_EQ(1, b_); + + // failed EXPECT_LT + EXPECT_NONFATAL_FAILURE(EXPECT_LT(a_++, b_++), "(a_++) < (b_++)"); + EXPECT_EQ(4, a_); + EXPECT_EQ(2, b_); + + // successful ASSERT_TRUE + ASSERT_TRUE(0 < a_++); // NOLINT + EXPECT_EQ(5, a_); + + // successful ASSERT_GT + ASSERT_GT(a_++, b_++); + EXPECT_EQ(6, a_); + EXPECT_EQ(3, b_); +} + +#if GTEST_HAS_EXCEPTIONS + +#if GTEST_HAS_RTTI + +#ifdef _MSC_VER +#define ERROR_DESC "class std::runtime_error" +#else +#define ERROR_DESC "std::runtime_error" +#endif + +#else // GTEST_HAS_RTTI + +#define ERROR_DESC "an std::exception-derived error" + +#endif // GTEST_HAS_RTTI + +void ThrowAnInteger() { + throw 1; +} +void ThrowRuntimeError(const char* what) { + throw std::runtime_error(what); +} + +// Tests that assertion arguments are evaluated exactly once. +TEST_F(SingleEvaluationTest, ExceptionTests) { + // successful EXPECT_THROW + EXPECT_THROW({ // NOLINT + a_++; + ThrowAnInteger(); + }, int); + EXPECT_EQ(1, a_); + + // failed EXPECT_THROW, throws different + EXPECT_NONFATAL_FAILURE(EXPECT_THROW({ // NOLINT + a_++; + ThrowAnInteger(); + }, bool), "throws a different type"); + EXPECT_EQ(2, a_); + + // failed EXPECT_THROW, throws runtime error + EXPECT_NONFATAL_FAILURE(EXPECT_THROW({ // NOLINT + a_++; + ThrowRuntimeError("A description"); + }, bool), "throws " ERROR_DESC " with description \"A description\""); + EXPECT_EQ(3, a_); + + // failed EXPECT_THROW, throws nothing + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(a_++, bool), "throws nothing"); + EXPECT_EQ(4, a_); + + // successful EXPECT_NO_THROW + EXPECT_NO_THROW(a_++); + EXPECT_EQ(5, a_); + + // failed EXPECT_NO_THROW + EXPECT_NONFATAL_FAILURE(EXPECT_NO_THROW({ // NOLINT + a_++; + ThrowAnInteger(); + }), "it throws"); + EXPECT_EQ(6, a_); + + // successful EXPECT_ANY_THROW + EXPECT_ANY_THROW({ // NOLINT + a_++; + ThrowAnInteger(); + }); + EXPECT_EQ(7, a_); + + // failed EXPECT_ANY_THROW + EXPECT_NONFATAL_FAILURE(EXPECT_ANY_THROW(a_++), "it doesn't"); + EXPECT_EQ(8, a_); +} + +#endif // GTEST_HAS_EXCEPTIONS + +// Tests {ASSERT|EXPECT}_NO_FATAL_FAILURE. +class NoFatalFailureTest : public Test { + protected: + void Succeeds() {} + void FailsNonFatal() { + ADD_FAILURE() << "some non-fatal failure"; + } + void Fails() { + FAIL() << "some fatal failure"; + } + + void DoAssertNoFatalFailureOnFails() { + ASSERT_NO_FATAL_FAILURE(Fails()); + ADD_FAILURE() << "should not reach here."; + } + + void DoExpectNoFatalFailureOnFails() { + EXPECT_NO_FATAL_FAILURE(Fails()); + ADD_FAILURE() << "other failure"; + } +}; + +TEST_F(NoFatalFailureTest, NoFailure) { + EXPECT_NO_FATAL_FAILURE(Succeeds()); + ASSERT_NO_FATAL_FAILURE(Succeeds()); +} + +TEST_F(NoFatalFailureTest, NonFatalIsNoFailure) { + EXPECT_NONFATAL_FAILURE( + EXPECT_NO_FATAL_FAILURE(FailsNonFatal()), + "some non-fatal failure"); + EXPECT_NONFATAL_FAILURE( + ASSERT_NO_FATAL_FAILURE(FailsNonFatal()), + "some non-fatal failure"); +} + +TEST_F(NoFatalFailureTest, AssertNoFatalFailureOnFatalFailure) { + TestPartResultArray gtest_failures; + { + ScopedFakeTestPartResultReporter gtest_reporter(>est_failures); + DoAssertNoFatalFailureOnFails(); + } + ASSERT_EQ(2, gtest_failures.size()); + EXPECT_EQ(TestPartResult::kFatalFailure, + gtest_failures.GetTestPartResult(0).type()); + EXPECT_EQ(TestPartResult::kFatalFailure, + gtest_failures.GetTestPartResult(1).type()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "some fatal failure", + gtest_failures.GetTestPartResult(0).message()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "it does", + gtest_failures.GetTestPartResult(1).message()); +} + +TEST_F(NoFatalFailureTest, ExpectNoFatalFailureOnFatalFailure) { + TestPartResultArray gtest_failures; + { + ScopedFakeTestPartResultReporter gtest_reporter(>est_failures); + DoExpectNoFatalFailureOnFails(); + } + ASSERT_EQ(3, gtest_failures.size()); + EXPECT_EQ(TestPartResult::kFatalFailure, + gtest_failures.GetTestPartResult(0).type()); + EXPECT_EQ(TestPartResult::kNonFatalFailure, + gtest_failures.GetTestPartResult(1).type()); + EXPECT_EQ(TestPartResult::kNonFatalFailure, + gtest_failures.GetTestPartResult(2).type()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "some fatal failure", + gtest_failures.GetTestPartResult(0).message()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "it does", + gtest_failures.GetTestPartResult(1).message()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "other failure", + gtest_failures.GetTestPartResult(2).message()); +} + +TEST_F(NoFatalFailureTest, MessageIsStreamable) { + TestPartResultArray gtest_failures; + { + ScopedFakeTestPartResultReporter gtest_reporter(>est_failures); + EXPECT_NO_FATAL_FAILURE(FAIL() << "foo") << "my message"; + } + ASSERT_EQ(2, gtest_failures.size()); + EXPECT_EQ(TestPartResult::kNonFatalFailure, + gtest_failures.GetTestPartResult(0).type()); + EXPECT_EQ(TestPartResult::kNonFatalFailure, + gtest_failures.GetTestPartResult(1).type()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "foo", + gtest_failures.GetTestPartResult(0).message()); + EXPECT_PRED_FORMAT2(testing::IsSubstring, "my message", + gtest_failures.GetTestPartResult(1).message()); +} + +// Tests non-string assertions. + +std::string EditsToString(const std::vector& edits) { + std::string out; + for (size_t i = 0; i < edits.size(); ++i) { + static const char kEdits[] = " +-/"; + out.append(1, kEdits[edits[i]]); + } + return out; +} + +std::vector CharsToIndices(const std::string& str) { + std::vector out; + for (size_t i = 0; i < str.size(); ++i) { + out.push_back(static_cast(str[i])); + } + return out; +} + +std::vector CharsToLines(const std::string& str) { + std::vector out; + for (size_t i = 0; i < str.size(); ++i) { + out.push_back(str.substr(i, 1)); + } + return out; +} + +TEST(EditDistance, TestSuites) { + struct Case { + int line; + const char* left; + const char* right; + const char* expected_edits; + const char* expected_diff; + }; + static const Case kCases[] = { + // No change. + {__LINE__, "A", "A", " ", ""}, + {__LINE__, "ABCDE", "ABCDE", " ", ""}, + // Simple adds. + {__LINE__, "X", "XA", " +", "@@ +1,2 @@\n X\n+A\n"}, + {__LINE__, "X", "XABCD", " ++++", "@@ +1,5 @@\n X\n+A\n+B\n+C\n+D\n"}, + // Simple removes. + {__LINE__, "XA", "X", " -", "@@ -1,2 @@\n X\n-A\n"}, + {__LINE__, "XABCD", "X", " ----", "@@ -1,5 @@\n X\n-A\n-B\n-C\n-D\n"}, + // Simple replaces. + {__LINE__, "A", "a", "/", "@@ -1,1 +1,1 @@\n-A\n+a\n"}, + {__LINE__, "ABCD", "abcd", "////", + "@@ -1,4 +1,4 @@\n-A\n-B\n-C\n-D\n+a\n+b\n+c\n+d\n"}, + // Path finding. + {__LINE__, "ABCDEFGH", "ABXEGH1", " -/ - +", + "@@ -1,8 +1,7 @@\n A\n B\n-C\n-D\n+X\n E\n-F\n G\n H\n+1\n"}, + {__LINE__, "AAAABCCCC", "ABABCDCDC", "- / + / ", + "@@ -1,9 +1,9 @@\n-A\n A\n-A\n+B\n A\n B\n C\n+D\n C\n-C\n+D\n C\n"}, + {__LINE__, "ABCDE", "BCDCD", "- +/", + "@@ -1,5 +1,5 @@\n-A\n B\n C\n D\n-E\n+C\n+D\n"}, + {__LINE__, "ABCDEFGHIJKL", "BCDCDEFGJKLJK", "- ++ -- ++", + "@@ -1,4 +1,5 @@\n-A\n B\n+C\n+D\n C\n D\n" + "@@ -6,7 +7,7 @@\n F\n G\n-H\n-I\n J\n K\n L\n+J\n+K\n"}, + {}}; + for (const Case* c = kCases; c->left; ++c) { + EXPECT_TRUE(c->expected_edits == + EditsToString(CalculateOptimalEdits(CharsToIndices(c->left), + CharsToIndices(c->right)))) + << "Left <" << c->left << "> Right <" << c->right << "> Edits <" + << EditsToString(CalculateOptimalEdits( + CharsToIndices(c->left), CharsToIndices(c->right))) << ">"; + EXPECT_TRUE(c->expected_diff == CreateUnifiedDiff(CharsToLines(c->left), + CharsToLines(c->right))) + << "Left <" << c->left << "> Right <" << c->right << "> Diff <" + << CreateUnifiedDiff(CharsToLines(c->left), CharsToLines(c->right)) + << ">"; + } +} + +// Tests EqFailure(), used for implementing *EQ* assertions. +TEST(AssertionTest, EqFailure) { + const std::string foo_val("5"), bar_val("6"); + const std::string msg1( + EqFailure("foo", "bar", foo_val, bar_val, false) + .failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " foo\n" + " Which is: 5\n" + " bar\n" + " Which is: 6", + msg1.c_str()); + + const std::string msg2( + EqFailure("foo", "6", foo_val, bar_val, false) + .failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " foo\n" + " Which is: 5\n" + " 6", + msg2.c_str()); + + const std::string msg3( + EqFailure("5", "bar", foo_val, bar_val, false) + .failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " 5\n" + " bar\n" + " Which is: 6", + msg3.c_str()); + + const std::string msg4( + EqFailure("5", "6", foo_val, bar_val, false).failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " 5\n" + " 6", + msg4.c_str()); + + const std::string msg5( + EqFailure("foo", "bar", + std::string("\"x\""), std::string("\"y\""), + true).failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " foo\n" + " Which is: \"x\"\n" + " bar\n" + " Which is: \"y\"\n" + "Ignoring case", + msg5.c_str()); +} + +TEST(AssertionTest, EqFailureWithDiff) { + const std::string left( + "1\\n2XXX\\n3\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12XXX\\n13\\n14\\n15"); + const std::string right( + "1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n11\\n12\\n13\\n14"); + const std::string msg1( + EqFailure("left", "right", left, right, false).failure_message()); + EXPECT_STREQ( + "Expected equality of these values:\n" + " left\n" + " Which is: " + "1\\n2XXX\\n3\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12XXX\\n13\\n14\\n15\n" + " right\n" + " Which is: 1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n11\\n12\\n13\\n14\n" + "With diff:\n@@ -1,5 +1,6 @@\n 1\n-2XXX\n+2\n 3\n+4\n 5\n 6\n" + "@@ -7,8 +8,6 @@\n 8\n 9\n-10\n 11\n-12XXX\n+12\n 13\n 14\n-15\n", + msg1.c_str()); +} + +// Tests AppendUserMessage(), used for implementing the *EQ* macros. +TEST(AssertionTest, AppendUserMessage) { + const std::string foo("foo"); + + Message msg; + EXPECT_STREQ("foo", + AppendUserMessage(foo, msg).c_str()); + + msg << "bar"; + EXPECT_STREQ("foo\nbar", + AppendUserMessage(foo, msg).c_str()); +} + +#ifdef __BORLANDC__ +// Silences warnings: "Condition is always true", "Unreachable code" +# pragma option push -w-ccc -w-rch +#endif + +// Tests ASSERT_TRUE. +TEST(AssertionTest, ASSERT_TRUE) { + ASSERT_TRUE(2 > 1); // NOLINT + EXPECT_FATAL_FAILURE(ASSERT_TRUE(2 < 1), + "2 < 1"); +} + +// Tests ASSERT_TRUE(predicate) for predicates returning AssertionResult. +TEST(AssertionTest, AssertTrueWithAssertionResult) { + ASSERT_TRUE(ResultIsEven(2)); +#ifndef __BORLANDC__ + // ICE's in C++Builder. + EXPECT_FATAL_FAILURE(ASSERT_TRUE(ResultIsEven(3)), + "Value of: ResultIsEven(3)\n" + " Actual: false (3 is odd)\n" + "Expected: true"); +#endif + ASSERT_TRUE(ResultIsEvenNoExplanation(2)); + EXPECT_FATAL_FAILURE(ASSERT_TRUE(ResultIsEvenNoExplanation(3)), + "Value of: ResultIsEvenNoExplanation(3)\n" + " Actual: false (3 is odd)\n" + "Expected: true"); +} + +// Tests ASSERT_FALSE. +TEST(AssertionTest, ASSERT_FALSE) { + ASSERT_FALSE(2 < 1); // NOLINT + EXPECT_FATAL_FAILURE(ASSERT_FALSE(2 > 1), + "Value of: 2 > 1\n" + " Actual: true\n" + "Expected: false"); +} + +// Tests ASSERT_FALSE(predicate) for predicates returning AssertionResult. +TEST(AssertionTest, AssertFalseWithAssertionResult) { + ASSERT_FALSE(ResultIsEven(3)); +#ifndef __BORLANDC__ + // ICE's in C++Builder. + EXPECT_FATAL_FAILURE(ASSERT_FALSE(ResultIsEven(2)), + "Value of: ResultIsEven(2)\n" + " Actual: true (2 is even)\n" + "Expected: false"); +#endif + ASSERT_FALSE(ResultIsEvenNoExplanation(3)); + EXPECT_FATAL_FAILURE(ASSERT_FALSE(ResultIsEvenNoExplanation(2)), + "Value of: ResultIsEvenNoExplanation(2)\n" + " Actual: true\n" + "Expected: false"); +} + +#ifdef __BORLANDC__ +// Restores warnings after previous "#pragma option push" suppressed them +# pragma option pop +#endif + +// Tests using ASSERT_EQ on double values. The purpose is to make +// sure that the specialization we did for integer and anonymous enums +// isn't used for double arguments. +TEST(ExpectTest, ASSERT_EQ_Double) { + // A success. + ASSERT_EQ(5.6, 5.6); + + // A failure. + EXPECT_FATAL_FAILURE(ASSERT_EQ(5.1, 5.2), + "5.1"); +} + +// Tests ASSERT_EQ. +TEST(AssertionTest, ASSERT_EQ) { + ASSERT_EQ(5, 2 + 3); + EXPECT_FATAL_FAILURE(ASSERT_EQ(5, 2*3), + "Expected equality of these values:\n" + " 5\n" + " 2*3\n" + " Which is: 6"); +} + +// Tests ASSERT_EQ(NULL, pointer). +TEST(AssertionTest, ASSERT_EQ_NULL) { + // A success. + const char* p = nullptr; + ASSERT_EQ(nullptr, p); + + // A failure. + static int n = 0; + EXPECT_FATAL_FAILURE(ASSERT_EQ(nullptr, &n), " &n\n Which is:"); +} + +// Tests ASSERT_EQ(0, non_pointer). Since the literal 0 can be +// treated as a null pointer by the compiler, we need to make sure +// that ASSERT_EQ(0, non_pointer) isn't interpreted by Google Test as +// ASSERT_EQ(static_cast(NULL), non_pointer). +TEST(ExpectTest, ASSERT_EQ_0) { + int n = 0; + + // A success. + ASSERT_EQ(0, n); + + // A failure. + EXPECT_FATAL_FAILURE(ASSERT_EQ(0, 5.6), + " 0\n 5.6"); +} + +// Tests ASSERT_NE. +TEST(AssertionTest, ASSERT_NE) { + ASSERT_NE(6, 7); + EXPECT_FATAL_FAILURE(ASSERT_NE('a', 'a'), + "Expected: ('a') != ('a'), " + "actual: 'a' (97, 0x61) vs 'a' (97, 0x61)"); +} + +// Tests ASSERT_LE. +TEST(AssertionTest, ASSERT_LE) { + ASSERT_LE(2, 3); + ASSERT_LE(2, 2); + EXPECT_FATAL_FAILURE(ASSERT_LE(2, 0), + "Expected: (2) <= (0), actual: 2 vs 0"); +} + +// Tests ASSERT_LT. +TEST(AssertionTest, ASSERT_LT) { + ASSERT_LT(2, 3); + EXPECT_FATAL_FAILURE(ASSERT_LT(2, 2), + "Expected: (2) < (2), actual: 2 vs 2"); +} + +// Tests ASSERT_GE. +TEST(AssertionTest, ASSERT_GE) { + ASSERT_GE(2, 1); + ASSERT_GE(2, 2); + EXPECT_FATAL_FAILURE(ASSERT_GE(2, 3), + "Expected: (2) >= (3), actual: 2 vs 3"); +} + +// Tests ASSERT_GT. +TEST(AssertionTest, ASSERT_GT) { + ASSERT_GT(2, 1); + EXPECT_FATAL_FAILURE(ASSERT_GT(2, 2), + "Expected: (2) > (2), actual: 2 vs 2"); +} + +#if GTEST_HAS_EXCEPTIONS + +void ThrowNothing() {} + +// Tests ASSERT_THROW. +TEST(AssertionTest, ASSERT_THROW) { + ASSERT_THROW(ThrowAnInteger(), int); + +# ifndef __BORLANDC__ + + // ICE's in C++Builder 2007 and 2009. + EXPECT_FATAL_FAILURE( + ASSERT_THROW(ThrowAnInteger(), bool), + "Expected: ThrowAnInteger() throws an exception of type bool.\n" + " Actual: it throws a different type."); + EXPECT_FATAL_FAILURE( + ASSERT_THROW(ThrowRuntimeError("A description"), std::logic_error), + "Expected: ThrowRuntimeError(\"A description\") " + "throws an exception of type std::logic_error.\n " + "Actual: it throws " ERROR_DESC " " + "with description \"A description\"."); +# endif + + EXPECT_FATAL_FAILURE( + ASSERT_THROW(ThrowNothing(), bool), + "Expected: ThrowNothing() throws an exception of type bool.\n" + " Actual: it throws nothing."); +} + +// Tests ASSERT_NO_THROW. +TEST(AssertionTest, ASSERT_NO_THROW) { + ASSERT_NO_THROW(ThrowNothing()); + EXPECT_FATAL_FAILURE(ASSERT_NO_THROW(ThrowAnInteger()), + "Expected: ThrowAnInteger() doesn't throw an exception." + "\n Actual: it throws."); + EXPECT_FATAL_FAILURE(ASSERT_NO_THROW(ThrowRuntimeError("A description")), + "Expected: ThrowRuntimeError(\"A description\") " + "doesn't throw an exception.\n " + "Actual: it throws " ERROR_DESC " " + "with description \"A description\"."); +} + +// Tests ASSERT_ANY_THROW. +TEST(AssertionTest, ASSERT_ANY_THROW) { + ASSERT_ANY_THROW(ThrowAnInteger()); + EXPECT_FATAL_FAILURE( + ASSERT_ANY_THROW(ThrowNothing()), + "Expected: ThrowNothing() throws an exception.\n" + " Actual: it doesn't."); +} + +#endif // GTEST_HAS_EXCEPTIONS + +// Makes sure we deal with the precedence of <<. This test should +// compile. +TEST(AssertionTest, AssertPrecedence) { + ASSERT_EQ(1 < 2, true); + bool false_value = false; + ASSERT_EQ(true && false_value, false); +} + +// A subroutine used by the following test. +void TestEq1(int x) { + ASSERT_EQ(1, x); +} + +// Tests calling a test subroutine that's not part of a fixture. +TEST(AssertionTest, NonFixtureSubroutine) { + EXPECT_FATAL_FAILURE(TestEq1(2), + " x\n Which is: 2"); +} + +// An uncopyable class. +class Uncopyable { + public: + explicit Uncopyable(int a_value) : value_(a_value) {} + + int value() const { return value_; } + bool operator==(const Uncopyable& rhs) const { + return value() == rhs.value(); + } + private: + // This constructor deliberately has no implementation, as we don't + // want this class to be copyable. + Uncopyable(const Uncopyable&); // NOLINT + + int value_; +}; + +::std::ostream& operator<<(::std::ostream& os, const Uncopyable& value) { + return os << value.value(); +} + + +bool IsPositiveUncopyable(const Uncopyable& x) { + return x.value() > 0; +} + +// A subroutine used by the following test. +void TestAssertNonPositive() { + Uncopyable y(-1); + ASSERT_PRED1(IsPositiveUncopyable, y); +} +// A subroutine used by the following test. +void TestAssertEqualsUncopyable() { + Uncopyable x(5); + Uncopyable y(-1); + ASSERT_EQ(x, y); +} + +// Tests that uncopyable objects can be used in assertions. +TEST(AssertionTest, AssertWorksWithUncopyableObject) { + Uncopyable x(5); + ASSERT_PRED1(IsPositiveUncopyable, x); + ASSERT_EQ(x, x); + EXPECT_FATAL_FAILURE(TestAssertNonPositive(), + "IsPositiveUncopyable(y) evaluates to false, where\ny evaluates to -1"); + EXPECT_FATAL_FAILURE(TestAssertEqualsUncopyable(), + "Expected equality of these values:\n" + " x\n Which is: 5\n y\n Which is: -1"); +} + +// Tests that uncopyable objects can be used in expects. +TEST(AssertionTest, ExpectWorksWithUncopyableObject) { + Uncopyable x(5); + EXPECT_PRED1(IsPositiveUncopyable, x); + Uncopyable y(-1); + EXPECT_NONFATAL_FAILURE(EXPECT_PRED1(IsPositiveUncopyable, y), + "IsPositiveUncopyable(y) evaluates to false, where\ny evaluates to -1"); + EXPECT_EQ(x, x); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(x, y), + "Expected equality of these values:\n" + " x\n Which is: 5\n y\n Which is: -1"); +} + +enum NamedEnum { + kE1 = 0, + kE2 = 1 +}; + +TEST(AssertionTest, NamedEnum) { + EXPECT_EQ(kE1, kE1); + EXPECT_LT(kE1, kE2); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(kE1, kE2), "Which is: 0"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(kE1, kE2), "Which is: 1"); +} + +// Sun Studio and HP aCC2reject this code. +#if !defined(__SUNPRO_CC) && !defined(__HP_aCC) + +// Tests using assertions with anonymous enums. +enum { + kCaseA = -1, + +# if GTEST_OS_LINUX + + // We want to test the case where the size of the anonymous enum is + // larger than sizeof(int), to make sure our implementation of the + // assertions doesn't truncate the enums. However, MSVC + // (incorrectly) doesn't allow an enum value to exceed the range of + // an int, so this has to be conditionally compiled. + // + // On Linux, kCaseB and kCaseA have the same value when truncated to + // int size. We want to test whether this will confuse the + // assertions. + kCaseB = testing::internal::kMaxBiggestInt, + +# else + + kCaseB = INT_MAX, + +# endif // GTEST_OS_LINUX + + kCaseC = 42 +}; + +TEST(AssertionTest, AnonymousEnum) { +# if GTEST_OS_LINUX + + EXPECT_EQ(static_cast(kCaseA), static_cast(kCaseB)); + +# endif // GTEST_OS_LINUX + + EXPECT_EQ(kCaseA, kCaseA); + EXPECT_NE(kCaseA, kCaseB); + EXPECT_LT(kCaseA, kCaseB); + EXPECT_LE(kCaseA, kCaseB); + EXPECT_GT(kCaseB, kCaseA); + EXPECT_GE(kCaseA, kCaseA); + EXPECT_NONFATAL_FAILURE(EXPECT_GE(kCaseA, kCaseB), + "(kCaseA) >= (kCaseB)"); + EXPECT_NONFATAL_FAILURE(EXPECT_GE(kCaseA, kCaseC), + "-1 vs 42"); + + ASSERT_EQ(kCaseA, kCaseA); + ASSERT_NE(kCaseA, kCaseB); + ASSERT_LT(kCaseA, kCaseB); + ASSERT_LE(kCaseA, kCaseB); + ASSERT_GT(kCaseB, kCaseA); + ASSERT_GE(kCaseA, kCaseA); + +# ifndef __BORLANDC__ + + // ICE's in C++Builder. + EXPECT_FATAL_FAILURE(ASSERT_EQ(kCaseA, kCaseB), + " kCaseB\n Which is: "); + EXPECT_FATAL_FAILURE(ASSERT_EQ(kCaseA, kCaseC), + "\n Which is: 42"); +# endif + + EXPECT_FATAL_FAILURE(ASSERT_EQ(kCaseA, kCaseC), + "\n Which is: -1"); +} + +#endif // !GTEST_OS_MAC && !defined(__SUNPRO_CC) + +#if GTEST_OS_WINDOWS + +static HRESULT UnexpectedHRESULTFailure() { + return E_UNEXPECTED; +} + +static HRESULT OkHRESULTSuccess() { + return S_OK; +} + +static HRESULT FalseHRESULTSuccess() { + return S_FALSE; +} + +// HRESULT assertion tests test both zero and non-zero +// success codes as well as failure message for each. +// +// Windows CE doesn't support message texts. +TEST(HRESULTAssertionTest, EXPECT_HRESULT_SUCCEEDED) { + EXPECT_HRESULT_SUCCEEDED(S_OK); + EXPECT_HRESULT_SUCCEEDED(S_FALSE); + + EXPECT_NONFATAL_FAILURE(EXPECT_HRESULT_SUCCEEDED(UnexpectedHRESULTFailure()), + "Expected: (UnexpectedHRESULTFailure()) succeeds.\n" + " Actual: 0x8000FFFF"); +} + +TEST(HRESULTAssertionTest, ASSERT_HRESULT_SUCCEEDED) { + ASSERT_HRESULT_SUCCEEDED(S_OK); + ASSERT_HRESULT_SUCCEEDED(S_FALSE); + + EXPECT_FATAL_FAILURE(ASSERT_HRESULT_SUCCEEDED(UnexpectedHRESULTFailure()), + "Expected: (UnexpectedHRESULTFailure()) succeeds.\n" + " Actual: 0x8000FFFF"); +} + +TEST(HRESULTAssertionTest, EXPECT_HRESULT_FAILED) { + EXPECT_HRESULT_FAILED(E_UNEXPECTED); + + EXPECT_NONFATAL_FAILURE(EXPECT_HRESULT_FAILED(OkHRESULTSuccess()), + "Expected: (OkHRESULTSuccess()) fails.\n" + " Actual: 0x0"); + EXPECT_NONFATAL_FAILURE(EXPECT_HRESULT_FAILED(FalseHRESULTSuccess()), + "Expected: (FalseHRESULTSuccess()) fails.\n" + " Actual: 0x1"); +} + +TEST(HRESULTAssertionTest, ASSERT_HRESULT_FAILED) { + ASSERT_HRESULT_FAILED(E_UNEXPECTED); + +# ifndef __BORLANDC__ + + // ICE's in C++Builder 2007 and 2009. + EXPECT_FATAL_FAILURE(ASSERT_HRESULT_FAILED(OkHRESULTSuccess()), + "Expected: (OkHRESULTSuccess()) fails.\n" + " Actual: 0x0"); +# endif + + EXPECT_FATAL_FAILURE(ASSERT_HRESULT_FAILED(FalseHRESULTSuccess()), + "Expected: (FalseHRESULTSuccess()) fails.\n" + " Actual: 0x1"); +} + +// Tests that streaming to the HRESULT macros works. +TEST(HRESULTAssertionTest, Streaming) { + EXPECT_HRESULT_SUCCEEDED(S_OK) << "unexpected failure"; + ASSERT_HRESULT_SUCCEEDED(S_OK) << "unexpected failure"; + EXPECT_HRESULT_FAILED(E_UNEXPECTED) << "unexpected failure"; + ASSERT_HRESULT_FAILED(E_UNEXPECTED) << "unexpected failure"; + + EXPECT_NONFATAL_FAILURE( + EXPECT_HRESULT_SUCCEEDED(E_UNEXPECTED) << "expected failure", + "expected failure"); + +# ifndef __BORLANDC__ + + // ICE's in C++Builder 2007 and 2009. + EXPECT_FATAL_FAILURE( + ASSERT_HRESULT_SUCCEEDED(E_UNEXPECTED) << "expected failure", + "expected failure"); +# endif + + EXPECT_NONFATAL_FAILURE( + EXPECT_HRESULT_FAILED(S_OK) << "expected failure", + "expected failure"); + + EXPECT_FATAL_FAILURE( + ASSERT_HRESULT_FAILED(S_OK) << "expected failure", + "expected failure"); +} + +#endif // GTEST_OS_WINDOWS + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +// Tests that the assertion macros behave like single statements. +TEST(AssertionSyntaxTest, BasicAssertionsBehavesLikeSingleStatement) { + if (AlwaysFalse()) + ASSERT_TRUE(false) << "This should never be executed; " + "It's a compilation test only."; + + if (AlwaysTrue()) + EXPECT_FALSE(false); + else + ; // NOLINT + + if (AlwaysFalse()) + ASSERT_LT(1, 3); + + if (AlwaysFalse()) + ; // NOLINT + else + EXPECT_GT(3, 2) << ""; +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#if GTEST_HAS_EXCEPTIONS +// Tests that the compiler will not complain about unreachable code in the +// EXPECT_THROW/EXPECT_ANY_THROW/EXPECT_NO_THROW macros. +TEST(ExpectThrowTest, DoesNotGenerateUnreachableCodeWarning) { + int n = 0; + + EXPECT_THROW(throw 1, int); + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(n++, int), ""); + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(throw 1, const char*), ""); + EXPECT_NO_THROW(n++); + EXPECT_NONFATAL_FAILURE(EXPECT_NO_THROW(throw 1), ""); + EXPECT_ANY_THROW(throw 1); + EXPECT_NONFATAL_FAILURE(EXPECT_ANY_THROW(n++), ""); +} + +TEST(ExpectThrowTest, DoesNotGenerateDuplicateCatchClauseWarning) { + EXPECT_THROW(throw std::exception(), std::exception); +} + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +TEST(AssertionSyntaxTest, ExceptionAssertionsBehavesLikeSingleStatement) { + if (AlwaysFalse()) + EXPECT_THROW(ThrowNothing(), bool); + + if (AlwaysTrue()) + EXPECT_THROW(ThrowAnInteger(), int); + else + ; // NOLINT + + if (AlwaysFalse()) + EXPECT_NO_THROW(ThrowAnInteger()); + + if (AlwaysTrue()) + EXPECT_NO_THROW(ThrowNothing()); + else + ; // NOLINT + + if (AlwaysFalse()) + EXPECT_ANY_THROW(ThrowNothing()); + + if (AlwaysTrue()) + EXPECT_ANY_THROW(ThrowAnInteger()); + else + ; // NOLINT +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#endif // GTEST_HAS_EXCEPTIONS + +// The following code intentionally tests a suboptimal syntax. +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdangling-else" +#pragma GCC diagnostic ignored "-Wempty-body" +#pragma GCC diagnostic ignored "-Wpragmas" +#endif +TEST(AssertionSyntaxTest, NoFatalFailureAssertionsBehavesLikeSingleStatement) { + if (AlwaysFalse()) + EXPECT_NO_FATAL_FAILURE(FAIL()) << "This should never be executed. " + << "It's a compilation test only."; + else + ; // NOLINT + + if (AlwaysFalse()) + ASSERT_NO_FATAL_FAILURE(FAIL()) << ""; + else + ; // NOLINT + + if (AlwaysTrue()) + EXPECT_NO_FATAL_FAILURE(SUCCEED()); + else + ; // NOLINT + + if (AlwaysFalse()) + ; // NOLINT + else + ASSERT_NO_FATAL_FAILURE(SUCCEED()); +} +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +// Tests that the assertion macros work well with switch statements. +TEST(AssertionSyntaxTest, WorksWithSwitch) { + switch (0) { + case 1: + break; + default: + ASSERT_TRUE(true); + } + + switch (0) + case 0: + EXPECT_FALSE(false) << "EXPECT_FALSE failed in switch case"; + + // Binary assertions are implemented using a different code path + // than the Boolean assertions. Hence we test them separately. + switch (0) { + case 1: + default: + ASSERT_EQ(1, 1) << "ASSERT_EQ failed in default switch handler"; + } + + switch (0) + case 0: + EXPECT_NE(1, 2); +} + +#if GTEST_HAS_EXCEPTIONS + +void ThrowAString() { + throw "std::string"; +} + +// Test that the exception assertion macros compile and work with const +// type qualifier. +TEST(AssertionSyntaxTest, WorksWithConst) { + ASSERT_THROW(ThrowAString(), const char*); + + EXPECT_THROW(ThrowAString(), const char*); +} + +#endif // GTEST_HAS_EXCEPTIONS + +} // namespace + +namespace testing { + +// Tests that Google Test tracks SUCCEED*. +TEST(SuccessfulAssertionTest, SUCCEED) { + SUCCEED(); + SUCCEED() << "OK"; + EXPECT_EQ(2, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +// Tests that Google Test doesn't track successful EXPECT_*. +TEST(SuccessfulAssertionTest, EXPECT) { + EXPECT_TRUE(true); + EXPECT_EQ(0, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +// Tests that Google Test doesn't track successful EXPECT_STR*. +TEST(SuccessfulAssertionTest, EXPECT_STR) { + EXPECT_STREQ("", ""); + EXPECT_EQ(0, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +// Tests that Google Test doesn't track successful ASSERT_*. +TEST(SuccessfulAssertionTest, ASSERT) { + ASSERT_TRUE(true); + EXPECT_EQ(0, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +// Tests that Google Test doesn't track successful ASSERT_STR*. +TEST(SuccessfulAssertionTest, ASSERT_STR) { + ASSERT_STREQ("", ""); + EXPECT_EQ(0, GetUnitTestImpl()->current_test_result()->total_part_count()); +} + +} // namespace testing + +namespace { + +// Tests the message streaming variation of assertions. + +TEST(AssertionWithMessageTest, EXPECT) { + EXPECT_EQ(1, 1) << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_NE(1, 1) << "Expected failure #1.", + "Expected failure #1"); + EXPECT_LE(1, 2) << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_LT(1, 0) << "Expected failure #2.", + "Expected failure #2."); + EXPECT_GE(1, 0) << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_GT(1, 2) << "Expected failure #3.", + "Expected failure #3."); + + EXPECT_STREQ("1", "1") << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_STRNE("1", "1") << "Expected failure #4.", + "Expected failure #4."); + EXPECT_STRCASEEQ("a", "A") << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_STRCASENE("a", "A") << "Expected failure #5.", + "Expected failure #5."); + + EXPECT_FLOAT_EQ(1, 1) << "This should succeed."; + EXPECT_NONFATAL_FAILURE(EXPECT_DOUBLE_EQ(1, 1.2) << "Expected failure #6.", + "Expected failure #6."); + EXPECT_NEAR(1, 1.1, 0.2) << "This should succeed."; +} + +TEST(AssertionWithMessageTest, ASSERT) { + ASSERT_EQ(1, 1) << "This should succeed."; + ASSERT_NE(1, 2) << "This should succeed."; + ASSERT_LE(1, 2) << "This should succeed."; + ASSERT_LT(1, 2) << "This should succeed."; + ASSERT_GE(1, 0) << "This should succeed."; + EXPECT_FATAL_FAILURE(ASSERT_GT(1, 2) << "Expected failure.", + "Expected failure."); +} + +TEST(AssertionWithMessageTest, ASSERT_STR) { + ASSERT_STREQ("1", "1") << "This should succeed."; + ASSERT_STRNE("1", "2") << "This should succeed."; + ASSERT_STRCASEEQ("a", "A") << "This should succeed."; + EXPECT_FATAL_FAILURE(ASSERT_STRCASENE("a", "A") << "Expected failure.", + "Expected failure."); +} + +TEST(AssertionWithMessageTest, ASSERT_FLOATING) { + ASSERT_FLOAT_EQ(1, 1) << "This should succeed."; + ASSERT_DOUBLE_EQ(1, 1) << "This should succeed."; + EXPECT_FATAL_FAILURE(ASSERT_NEAR(1, 1.2, 0.1) << "Expect failure.", // NOLINT + "Expect failure."); +} + +// Tests using ASSERT_FALSE with a streamed message. +TEST(AssertionWithMessageTest, ASSERT_FALSE) { + ASSERT_FALSE(false) << "This shouldn't fail."; + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_FALSE(true) << "Expected failure: " << 2 << " > " << 1 + << " evaluates to " << true; + }, "Expected failure"); +} + +// Tests using FAIL with a streamed message. +TEST(AssertionWithMessageTest, FAIL) { + EXPECT_FATAL_FAILURE(FAIL() << 0, + "0"); +} + +// Tests using SUCCEED with a streamed message. +TEST(AssertionWithMessageTest, SUCCEED) { + SUCCEED() << "Success == " << 1; +} + +// Tests using ASSERT_TRUE with a streamed message. +TEST(AssertionWithMessageTest, ASSERT_TRUE) { + ASSERT_TRUE(true) << "This should succeed."; + ASSERT_TRUE(true) << true; + EXPECT_FATAL_FAILURE( + { // NOLINT + ASSERT_TRUE(false) << static_cast(nullptr) + << static_cast(nullptr); + }, + "(null)(null)"); +} + +#if GTEST_OS_WINDOWS +// Tests using wide strings in assertion messages. +TEST(AssertionWithMessageTest, WideStringMessage) { + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_TRUE(false) << L"This failure is expected.\x8119"; + }, "This failure is expected."); + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_EQ(1, 2) << "This failure is " + << L"expected too.\x8120"; + }, "This failure is expected too."); +} +#endif // GTEST_OS_WINDOWS + +// Tests EXPECT_TRUE. +TEST(ExpectTest, EXPECT_TRUE) { + EXPECT_TRUE(true) << "Intentional success"; + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(false) << "Intentional failure #1.", + "Intentional failure #1."); + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(false) << "Intentional failure #2.", + "Intentional failure #2."); + EXPECT_TRUE(2 > 1); // NOLINT + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(2 < 1), + "Value of: 2 < 1\n" + " Actual: false\n" + "Expected: true"); + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(2 > 3), + "2 > 3"); +} + +// Tests EXPECT_TRUE(predicate) for predicates returning AssertionResult. +TEST(ExpectTest, ExpectTrueWithAssertionResult) { + EXPECT_TRUE(ResultIsEven(2)); + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(ResultIsEven(3)), + "Value of: ResultIsEven(3)\n" + " Actual: false (3 is odd)\n" + "Expected: true"); + EXPECT_TRUE(ResultIsEvenNoExplanation(2)); + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(ResultIsEvenNoExplanation(3)), + "Value of: ResultIsEvenNoExplanation(3)\n" + " Actual: false (3 is odd)\n" + "Expected: true"); +} + +// Tests EXPECT_FALSE with a streamed message. +TEST(ExpectTest, EXPECT_FALSE) { + EXPECT_FALSE(2 < 1); // NOLINT + EXPECT_FALSE(false) << "Intentional success"; + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(true) << "Intentional failure #1.", + "Intentional failure #1."); + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(true) << "Intentional failure #2.", + "Intentional failure #2."); + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(2 > 1), + "Value of: 2 > 1\n" + " Actual: true\n" + "Expected: false"); + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(2 < 3), + "2 < 3"); +} + +// Tests EXPECT_FALSE(predicate) for predicates returning AssertionResult. +TEST(ExpectTest, ExpectFalseWithAssertionResult) { + EXPECT_FALSE(ResultIsEven(3)); + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(ResultIsEven(2)), + "Value of: ResultIsEven(2)\n" + " Actual: true (2 is even)\n" + "Expected: false"); + EXPECT_FALSE(ResultIsEvenNoExplanation(3)); + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(ResultIsEvenNoExplanation(2)), + "Value of: ResultIsEvenNoExplanation(2)\n" + " Actual: true\n" + "Expected: false"); +} + +#ifdef __BORLANDC__ +// Restores warnings after previous "#pragma option push" suppressed them +# pragma option pop +#endif + +// Tests EXPECT_EQ. +TEST(ExpectTest, EXPECT_EQ) { + EXPECT_EQ(5, 2 + 3); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(5, 2*3), + "Expected equality of these values:\n" + " 5\n" + " 2*3\n" + " Which is: 6"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(5, 2 - 3), + "2 - 3"); +} + +// Tests using EXPECT_EQ on double values. The purpose is to make +// sure that the specialization we did for integer and anonymous enums +// isn't used for double arguments. +TEST(ExpectTest, EXPECT_EQ_Double) { + // A success. + EXPECT_EQ(5.6, 5.6); + + // A failure. + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(5.1, 5.2), + "5.1"); +} + +// Tests EXPECT_EQ(NULL, pointer). +TEST(ExpectTest, EXPECT_EQ_NULL) { + // A success. + const char* p = nullptr; + EXPECT_EQ(nullptr, p); + + // A failure. + int n = 0; + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(nullptr, &n), " &n\n Which is:"); +} + +// Tests EXPECT_EQ(0, non_pointer). Since the literal 0 can be +// treated as a null pointer by the compiler, we need to make sure +// that EXPECT_EQ(0, non_pointer) isn't interpreted by Google Test as +// EXPECT_EQ(static_cast(NULL), non_pointer). +TEST(ExpectTest, EXPECT_EQ_0) { + int n = 0; + + // A success. + EXPECT_EQ(0, n); + + // A failure. + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(0, 5.6), + " 0\n 5.6"); +} + +// Tests EXPECT_NE. +TEST(ExpectTest, EXPECT_NE) { + EXPECT_NE(6, 7); + + EXPECT_NONFATAL_FAILURE(EXPECT_NE('a', 'a'), + "Expected: ('a') != ('a'), " + "actual: 'a' (97, 0x61) vs 'a' (97, 0x61)"); + EXPECT_NONFATAL_FAILURE(EXPECT_NE(2, 2), + "2"); + char* const p0 = nullptr; + EXPECT_NONFATAL_FAILURE(EXPECT_NE(p0, p0), + "p0"); + // Only way to get the Nokia compiler to compile the cast + // is to have a separate void* variable first. Putting + // the two casts on the same line doesn't work, neither does + // a direct C-style to char*. + void* pv1 = (void*)0x1234; // NOLINT + char* const p1 = reinterpret_cast(pv1); + EXPECT_NONFATAL_FAILURE(EXPECT_NE(p1, p1), + "p1"); +} + +// Tests EXPECT_LE. +TEST(ExpectTest, EXPECT_LE) { + EXPECT_LE(2, 3); + EXPECT_LE(2, 2); + EXPECT_NONFATAL_FAILURE(EXPECT_LE(2, 0), + "Expected: (2) <= (0), actual: 2 vs 0"); + EXPECT_NONFATAL_FAILURE(EXPECT_LE(1.1, 0.9), + "(1.1) <= (0.9)"); +} + +// Tests EXPECT_LT. +TEST(ExpectTest, EXPECT_LT) { + EXPECT_LT(2, 3); + EXPECT_NONFATAL_FAILURE(EXPECT_LT(2, 2), + "Expected: (2) < (2), actual: 2 vs 2"); + EXPECT_NONFATAL_FAILURE(EXPECT_LT(2, 1), + "(2) < (1)"); +} + +// Tests EXPECT_GE. +TEST(ExpectTest, EXPECT_GE) { + EXPECT_GE(2, 1); + EXPECT_GE(2, 2); + EXPECT_NONFATAL_FAILURE(EXPECT_GE(2, 3), + "Expected: (2) >= (3), actual: 2 vs 3"); + EXPECT_NONFATAL_FAILURE(EXPECT_GE(0.9, 1.1), + "(0.9) >= (1.1)"); +} + +// Tests EXPECT_GT. +TEST(ExpectTest, EXPECT_GT) { + EXPECT_GT(2, 1); + EXPECT_NONFATAL_FAILURE(EXPECT_GT(2, 2), + "Expected: (2) > (2), actual: 2 vs 2"); + EXPECT_NONFATAL_FAILURE(EXPECT_GT(2, 3), + "(2) > (3)"); +} + +#if GTEST_HAS_EXCEPTIONS + +// Tests EXPECT_THROW. +TEST(ExpectTest, EXPECT_THROW) { + EXPECT_THROW(ThrowAnInteger(), int); + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(ThrowAnInteger(), bool), + "Expected: ThrowAnInteger() throws an exception of " + "type bool.\n Actual: it throws a different type."); + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(ThrowRuntimeError("A description"), + std::logic_error), + "Expected: ThrowRuntimeError(\"A description\") " + "throws an exception of type std::logic_error.\n " + "Actual: it throws " ERROR_DESC " " + "with description \"A description\"."); + EXPECT_NONFATAL_FAILURE( + EXPECT_THROW(ThrowNothing(), bool), + "Expected: ThrowNothing() throws an exception of type bool.\n" + " Actual: it throws nothing."); +} + +// Tests EXPECT_NO_THROW. +TEST(ExpectTest, EXPECT_NO_THROW) { + EXPECT_NO_THROW(ThrowNothing()); + EXPECT_NONFATAL_FAILURE(EXPECT_NO_THROW(ThrowAnInteger()), + "Expected: ThrowAnInteger() doesn't throw an " + "exception.\n Actual: it throws."); + EXPECT_NONFATAL_FAILURE(EXPECT_NO_THROW(ThrowRuntimeError("A description")), + "Expected: ThrowRuntimeError(\"A description\") " + "doesn't throw an exception.\n " + "Actual: it throws " ERROR_DESC " " + "with description \"A description\"."); +} + +// Tests EXPECT_ANY_THROW. +TEST(ExpectTest, EXPECT_ANY_THROW) { + EXPECT_ANY_THROW(ThrowAnInteger()); + EXPECT_NONFATAL_FAILURE( + EXPECT_ANY_THROW(ThrowNothing()), + "Expected: ThrowNothing() throws an exception.\n" + " Actual: it doesn't."); +} + +#endif // GTEST_HAS_EXCEPTIONS + +// Make sure we deal with the precedence of <<. +TEST(ExpectTest, ExpectPrecedence) { + EXPECT_EQ(1 < 2, true); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(true, true && false), + " true && false\n Which is: false"); +} + + +// Tests the StreamableToString() function. + +// Tests using StreamableToString() on a scalar. +TEST(StreamableToStringTest, Scalar) { + EXPECT_STREQ("5", StreamableToString(5).c_str()); +} + +// Tests using StreamableToString() on a non-char pointer. +TEST(StreamableToStringTest, Pointer) { + int n = 0; + int* p = &n; + EXPECT_STRNE("(null)", StreamableToString(p).c_str()); +} + +// Tests using StreamableToString() on a NULL non-char pointer. +TEST(StreamableToStringTest, NullPointer) { + int* p = nullptr; + EXPECT_STREQ("(null)", StreamableToString(p).c_str()); +} + +// Tests using StreamableToString() on a C string. +TEST(StreamableToStringTest, CString) { + EXPECT_STREQ("Foo", StreamableToString("Foo").c_str()); +} + +// Tests using StreamableToString() on a NULL C string. +TEST(StreamableToStringTest, NullCString) { + char* p = nullptr; + EXPECT_STREQ("(null)", StreamableToString(p).c_str()); +} + +// Tests using streamable values as assertion messages. + +// Tests using std::string as an assertion message. +TEST(StreamableTest, string) { + static const std::string str( + "This failure message is a std::string, and is expected."); + EXPECT_FATAL_FAILURE(FAIL() << str, + str.c_str()); +} + +// Tests that we can output strings containing embedded NULs. +// Limited to Linux because we can only do this with std::string's. +TEST(StreamableTest, stringWithEmbeddedNUL) { + static const char char_array_with_nul[] = + "Here's a NUL\0 and some more string"; + static const std::string string_with_nul(char_array_with_nul, + sizeof(char_array_with_nul) + - 1); // drops the trailing NUL + EXPECT_FATAL_FAILURE(FAIL() << string_with_nul, + "Here's a NUL\\0 and some more string"); +} + +// Tests that we can output a NUL char. +TEST(StreamableTest, NULChar) { + EXPECT_FATAL_FAILURE({ // NOLINT + FAIL() << "A NUL" << '\0' << " and some more string"; + }, "A NUL\\0 and some more string"); +} + +// Tests using int as an assertion message. +TEST(StreamableTest, int) { + EXPECT_FATAL_FAILURE(FAIL() << 900913, + "900913"); +} + +// Tests using NULL char pointer as an assertion message. +// +// In MSVC, streaming a NULL char * causes access violation. Google Test +// implemented a workaround (substituting "(null)" for NULL). This +// tests whether the workaround works. +TEST(StreamableTest, NullCharPtr) { + EXPECT_FATAL_FAILURE(FAIL() << static_cast(nullptr), "(null)"); +} + +// Tests that basic IO manipulators (endl, ends, and flush) can be +// streamed to testing::Message. +TEST(StreamableTest, BasicIoManip) { + EXPECT_FATAL_FAILURE({ // NOLINT + FAIL() << "Line 1." << std::endl + << "A NUL char " << std::ends << std::flush << " in line 2."; + }, "Line 1.\nA NUL char \\0 in line 2."); +} + +// Tests the macros that haven't been covered so far. + +void AddFailureHelper(bool* aborted) { + *aborted = true; + ADD_FAILURE() << "Intentional failure."; + *aborted = false; +} + +// Tests ADD_FAILURE. +TEST(MacroTest, ADD_FAILURE) { + bool aborted = true; + EXPECT_NONFATAL_FAILURE(AddFailureHelper(&aborted), + "Intentional failure."); + EXPECT_FALSE(aborted); +} + +// Tests ADD_FAILURE_AT. +TEST(MacroTest, ADD_FAILURE_AT) { + // Verifies that ADD_FAILURE_AT does generate a nonfatal failure and + // the failure message contains the user-streamed part. + EXPECT_NONFATAL_FAILURE(ADD_FAILURE_AT("foo.cc", 42) << "Wrong!", "Wrong!"); + + // Verifies that the user-streamed part is optional. + EXPECT_NONFATAL_FAILURE(ADD_FAILURE_AT("foo.cc", 42), "Failed"); + + // Unfortunately, we cannot verify that the failure message contains + // the right file path and line number the same way, as + // EXPECT_NONFATAL_FAILURE() doesn't get to see the file path and + // line number. Instead, we do that in googletest-output-test_.cc. +} + +// Tests FAIL. +TEST(MacroTest, FAIL) { + EXPECT_FATAL_FAILURE(FAIL(), + "Failed"); + EXPECT_FATAL_FAILURE(FAIL() << "Intentional failure.", + "Intentional failure."); +} + +// Tests GTEST_FAIL_AT. +TEST(MacroTest, GTEST_FAIL_AT) { + // Verifies that GTEST_FAIL_AT does generate a fatal failure and + // the failure message contains the user-streamed part. + EXPECT_FATAL_FAILURE(GTEST_FAIL_AT("foo.cc", 42) << "Wrong!", "Wrong!"); + + // Verifies that the user-streamed part is optional. + EXPECT_FATAL_FAILURE(GTEST_FAIL_AT("foo.cc", 42), "Failed"); + + // See the ADD_FAIL_AT test above to see how we test that the failure message + // contains the right filename and line number -- the same applies here. +} + +// Tests SUCCEED +TEST(MacroTest, SUCCEED) { + SUCCEED(); + SUCCEED() << "Explicit success."; +} + +// Tests for EXPECT_EQ() and ASSERT_EQ(). +// +// These tests fail *intentionally*, s.t. the failure messages can be +// generated and tested. +// +// We have different tests for different argument types. + +// Tests using bool values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, Bool) { + EXPECT_EQ(true, true); + EXPECT_FATAL_FAILURE({ + bool false_value = false; + ASSERT_EQ(false_value, true); + }, " false_value\n Which is: false\n true"); +} + +// Tests using int values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, Int) { + ASSERT_EQ(32, 32); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(32, 33), + " 32\n 33"); +} + +// Tests using time_t values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, Time_T) { + EXPECT_EQ(static_cast(0), + static_cast(0)); + EXPECT_FATAL_FAILURE(ASSERT_EQ(static_cast(0), + static_cast(1234)), + "1234"); +} + +// Tests using char values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, Char) { + ASSERT_EQ('z', 'z'); + const char ch = 'b'; + EXPECT_NONFATAL_FAILURE(EXPECT_EQ('\0', ch), + " ch\n Which is: 'b'"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ('a', ch), + " ch\n Which is: 'b'"); +} + +// Tests using wchar_t values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, WideChar) { + EXPECT_EQ(L'b', L'b'); + + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(L'\0', L'x'), + "Expected equality of these values:\n" + " L'\0'\n" + " Which is: L'\0' (0, 0x0)\n" + " L'x'\n" + " Which is: L'x' (120, 0x78)"); + + static wchar_t wchar; + wchar = L'b'; + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(L'a', wchar), + "wchar"); + wchar = 0x8119; + EXPECT_FATAL_FAILURE(ASSERT_EQ(static_cast(0x8120), wchar), + " wchar\n Which is: L'"); +} + +// Tests using ::std::string values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, StdString) { + // Compares a const char* to an std::string that has identical + // content. + ASSERT_EQ("Test", ::std::string("Test")); + + // Compares two identical std::strings. + static const ::std::string str1("A * in the middle"); + static const ::std::string str2(str1); + EXPECT_EQ(str1, str2); + + // Compares a const char* to an std::string that has different + // content + EXPECT_NONFATAL_FAILURE(EXPECT_EQ("Test", ::std::string("test")), + "\"test\""); + + // Compares an std::string to a char* that has different content. + char* const p1 = const_cast("foo"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(::std::string("bar"), p1), + "p1"); + + // Compares two std::strings that have different contents, one of + // which having a NUL character in the middle. This should fail. + static ::std::string str3(str1); + str3.at(2) = '\0'; + EXPECT_FATAL_FAILURE(ASSERT_EQ(str1, str3), + " str3\n Which is: \"A \\0 in the middle\""); +} + +#if GTEST_HAS_STD_WSTRING + +// Tests using ::std::wstring values in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, StdWideString) { + // Compares two identical std::wstrings. + const ::std::wstring wstr1(L"A * in the middle"); + const ::std::wstring wstr2(wstr1); + ASSERT_EQ(wstr1, wstr2); + + // Compares an std::wstring to a const wchar_t* that has identical + // content. + const wchar_t kTestX8119[] = { 'T', 'e', 's', 't', 0x8119, '\0' }; + EXPECT_EQ(::std::wstring(kTestX8119), kTestX8119); + + // Compares an std::wstring to a const wchar_t* that has different + // content. + const wchar_t kTestX8120[] = { 'T', 'e', 's', 't', 0x8120, '\0' }; + EXPECT_NONFATAL_FAILURE({ // NOLINT + EXPECT_EQ(::std::wstring(kTestX8119), kTestX8120); + }, "kTestX8120"); + + // Compares two std::wstrings that have different contents, one of + // which having a NUL character in the middle. + ::std::wstring wstr3(wstr1); + wstr3.at(2) = L'\0'; + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(wstr1, wstr3), + "wstr3"); + + // Compares a wchar_t* to an std::wstring that has different + // content. + EXPECT_FATAL_FAILURE({ // NOLINT + ASSERT_EQ(const_cast(L"foo"), ::std::wstring(L"bar")); + }, ""); +} + +#endif // GTEST_HAS_STD_WSTRING + +// Tests using char pointers in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, CharPointer) { + char* const p0 = nullptr; + // Only way to get the Nokia compiler to compile the cast + // is to have a separate void* variable first. Putting + // the two casts on the same line doesn't work, neither does + // a direct C-style to char*. + void* pv1 = (void*)0x1234; // NOLINT + void* pv2 = (void*)0xABC0; // NOLINT + char* const p1 = reinterpret_cast(pv1); + char* const p2 = reinterpret_cast(pv2); + ASSERT_EQ(p1, p1); + + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(p0, p2), + " p2\n Which is:"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(p1, p2), + " p2\n Which is:"); + EXPECT_FATAL_FAILURE(ASSERT_EQ(reinterpret_cast(0x1234), + reinterpret_cast(0xABC0)), + "ABC0"); +} + +// Tests using wchar_t pointers in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, WideCharPointer) { + wchar_t* const p0 = nullptr; + // Only way to get the Nokia compiler to compile the cast + // is to have a separate void* variable first. Putting + // the two casts on the same line doesn't work, neither does + // a direct C-style to char*. + void* pv1 = (void*)0x1234; // NOLINT + void* pv2 = (void*)0xABC0; // NOLINT + wchar_t* const p1 = reinterpret_cast(pv1); + wchar_t* const p2 = reinterpret_cast(pv2); + EXPECT_EQ(p0, p0); + + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(p0, p2), + " p2\n Which is:"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(p1, p2), + " p2\n Which is:"); + void* pv3 = (void*)0x1234; // NOLINT + void* pv4 = (void*)0xABC0; // NOLINT + const wchar_t* p3 = reinterpret_cast(pv3); + const wchar_t* p4 = reinterpret_cast(pv4); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(p3, p4), + "p4"); +} + +// Tests using other types of pointers in {EXPECT|ASSERT}_EQ. +TEST(EqAssertionTest, OtherPointer) { + ASSERT_EQ(static_cast(nullptr), static_cast(nullptr)); + EXPECT_FATAL_FAILURE(ASSERT_EQ(static_cast(nullptr), + reinterpret_cast(0x1234)), + "0x1234"); +} + +// A class that supports binary comparison operators but not streaming. +class UnprintableChar { + public: + explicit UnprintableChar(char ch) : char_(ch) {} + + bool operator==(const UnprintableChar& rhs) const { + return char_ == rhs.char_; + } + bool operator!=(const UnprintableChar& rhs) const { + return char_ != rhs.char_; + } + bool operator<(const UnprintableChar& rhs) const { + return char_ < rhs.char_; + } + bool operator<=(const UnprintableChar& rhs) const { + return char_ <= rhs.char_; + } + bool operator>(const UnprintableChar& rhs) const { + return char_ > rhs.char_; + } + bool operator>=(const UnprintableChar& rhs) const { + return char_ >= rhs.char_; + } + + private: + char char_; +}; + +// Tests that ASSERT_EQ() and friends don't require the arguments to +// be printable. +TEST(ComparisonAssertionTest, AcceptsUnprintableArgs) { + const UnprintableChar x('x'), y('y'); + ASSERT_EQ(x, x); + EXPECT_NE(x, y); + ASSERT_LT(x, y); + EXPECT_LE(x, y); + ASSERT_GT(y, x); + EXPECT_GE(x, x); + + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(x, y), "1-byte object <78>"); + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(x, y), "1-byte object <79>"); + EXPECT_NONFATAL_FAILURE(EXPECT_LT(y, y), "1-byte object <79>"); + EXPECT_NONFATAL_FAILURE(EXPECT_GT(x, y), "1-byte object <78>"); + EXPECT_NONFATAL_FAILURE(EXPECT_GT(x, y), "1-byte object <79>"); + + // Code tested by EXPECT_FATAL_FAILURE cannot reference local + // variables, so we have to write UnprintableChar('x') instead of x. +#ifndef __BORLANDC__ + // ICE's in C++Builder. + EXPECT_FATAL_FAILURE(ASSERT_NE(UnprintableChar('x'), UnprintableChar('x')), + "1-byte object <78>"); + EXPECT_FATAL_FAILURE(ASSERT_LE(UnprintableChar('y'), UnprintableChar('x')), + "1-byte object <78>"); +#endif + EXPECT_FATAL_FAILURE(ASSERT_LE(UnprintableChar('y'), UnprintableChar('x')), + "1-byte object <79>"); + EXPECT_FATAL_FAILURE(ASSERT_GE(UnprintableChar('x'), UnprintableChar('y')), + "1-byte object <78>"); + EXPECT_FATAL_FAILURE(ASSERT_GE(UnprintableChar('x'), UnprintableChar('y')), + "1-byte object <79>"); +} + +// Tests the FRIEND_TEST macro. + +// This class has a private member we want to test. We will test it +// both in a TEST and in a TEST_F. +class Foo { + public: + Foo() {} + + private: + int Bar() const { return 1; } + + // Declares the friend tests that can access the private member + // Bar(). + FRIEND_TEST(FRIEND_TEST_Test, TEST); + FRIEND_TEST(FRIEND_TEST_Test2, TEST_F); +}; + +// Tests that the FRIEND_TEST declaration allows a TEST to access a +// class's private members. This should compile. +TEST(FRIEND_TEST_Test, TEST) { + ASSERT_EQ(1, Foo().Bar()); +} + +// The fixture needed to test using FRIEND_TEST with TEST_F. +class FRIEND_TEST_Test2 : public Test { + protected: + Foo foo; +}; + +// Tests that the FRIEND_TEST declaration allows a TEST_F to access a +// class's private members. This should compile. +TEST_F(FRIEND_TEST_Test2, TEST_F) { + ASSERT_EQ(1, foo.Bar()); +} + +// Tests the life cycle of Test objects. + +// The test fixture for testing the life cycle of Test objects. +// +// This class counts the number of live test objects that uses this +// fixture. +class TestLifeCycleTest : public Test { + protected: + // Constructor. Increments the number of test objects that uses + // this fixture. + TestLifeCycleTest() { count_++; } + + // Destructor. Decrements the number of test objects that uses this + // fixture. + ~TestLifeCycleTest() override { count_--; } + + // Returns the number of live test objects that uses this fixture. + int count() const { return count_; } + + private: + static int count_; +}; + +int TestLifeCycleTest::count_ = 0; + +// Tests the life cycle of test objects. +TEST_F(TestLifeCycleTest, Test1) { + // There should be only one test object in this test case that's + // currently alive. + ASSERT_EQ(1, count()); +} + +// Tests the life cycle of test objects. +TEST_F(TestLifeCycleTest, Test2) { + // After Test1 is done and Test2 is started, there should still be + // only one live test object, as the object for Test1 should've been + // deleted. + ASSERT_EQ(1, count()); +} + +} // namespace + +// Tests that the copy constructor works when it is NOT optimized away by +// the compiler. +TEST(AssertionResultTest, CopyConstructorWorksWhenNotOptimied) { + // Checks that the copy constructor doesn't try to dereference NULL pointers + // in the source object. + AssertionResult r1 = AssertionSuccess(); + AssertionResult r2 = r1; + // The following line is added to prevent the compiler from optimizing + // away the constructor call. + r1 << "abc"; + + AssertionResult r3 = r1; + EXPECT_EQ(static_cast(r3), static_cast(r1)); + EXPECT_STREQ("abc", r1.message()); +} + +// Tests that AssertionSuccess and AssertionFailure construct +// AssertionResult objects as expected. +TEST(AssertionResultTest, ConstructionWorks) { + AssertionResult r1 = AssertionSuccess(); + EXPECT_TRUE(r1); + EXPECT_STREQ("", r1.message()); + + AssertionResult r2 = AssertionSuccess() << "abc"; + EXPECT_TRUE(r2); + EXPECT_STREQ("abc", r2.message()); + + AssertionResult r3 = AssertionFailure(); + EXPECT_FALSE(r3); + EXPECT_STREQ("", r3.message()); + + AssertionResult r4 = AssertionFailure() << "def"; + EXPECT_FALSE(r4); + EXPECT_STREQ("def", r4.message()); + + AssertionResult r5 = AssertionFailure(Message() << "ghi"); + EXPECT_FALSE(r5); + EXPECT_STREQ("ghi", r5.message()); +} + +// Tests that the negation flips the predicate result but keeps the message. +TEST(AssertionResultTest, NegationWorks) { + AssertionResult r1 = AssertionSuccess() << "abc"; + EXPECT_FALSE(!r1); + EXPECT_STREQ("abc", (!r1).message()); + + AssertionResult r2 = AssertionFailure() << "def"; + EXPECT_TRUE(!r2); + EXPECT_STREQ("def", (!r2).message()); +} + +TEST(AssertionResultTest, StreamingWorks) { + AssertionResult r = AssertionSuccess(); + r << "abc" << 'd' << 0 << true; + EXPECT_STREQ("abcd0true", r.message()); +} + +TEST(AssertionResultTest, CanStreamOstreamManipulators) { + AssertionResult r = AssertionSuccess(); + r << "Data" << std::endl << std::flush << std::ends << "Will be visible"; + EXPECT_STREQ("Data\n\\0Will be visible", r.message()); +} + +// The next test uses explicit conversion operators + +TEST(AssertionResultTest, ConstructibleFromContextuallyConvertibleToBool) { + struct ExplicitlyConvertibleToBool { + explicit operator bool() const { return value; } + bool value; + }; + ExplicitlyConvertibleToBool v1 = {false}; + ExplicitlyConvertibleToBool v2 = {true}; + EXPECT_FALSE(v1); + EXPECT_TRUE(v2); +} + +struct ConvertibleToAssertionResult { + operator AssertionResult() const { return AssertionResult(true); } +}; + +TEST(AssertionResultTest, ConstructibleFromImplicitlyConvertible) { + ConvertibleToAssertionResult obj; + EXPECT_TRUE(obj); +} + +// Tests streaming a user type whose definition and operator << are +// both in the global namespace. +class Base { + public: + explicit Base(int an_x) : x_(an_x) {} + int x() const { return x_; } + private: + int x_; +}; +std::ostream& operator<<(std::ostream& os, + const Base& val) { + return os << val.x(); +} +std::ostream& operator<<(std::ostream& os, + const Base* pointer) { + return os << "(" << pointer->x() << ")"; +} + +TEST(MessageTest, CanStreamUserTypeInGlobalNameSpace) { + Message msg; + Base a(1); + + msg << a << &a; // Uses ::operator<<. + EXPECT_STREQ("1(1)", msg.GetString().c_str()); +} + +// Tests streaming a user type whose definition and operator<< are +// both in an unnamed namespace. +namespace { +class MyTypeInUnnamedNameSpace : public Base { + public: + explicit MyTypeInUnnamedNameSpace(int an_x): Base(an_x) {} +}; +std::ostream& operator<<(std::ostream& os, + const MyTypeInUnnamedNameSpace& val) { + return os << val.x(); +} +std::ostream& operator<<(std::ostream& os, + const MyTypeInUnnamedNameSpace* pointer) { + return os << "(" << pointer->x() << ")"; +} +} // namespace + +TEST(MessageTest, CanStreamUserTypeInUnnamedNameSpace) { + Message msg; + MyTypeInUnnamedNameSpace a(1); + + msg << a << &a; // Uses ::operator<<. + EXPECT_STREQ("1(1)", msg.GetString().c_str()); +} + +// Tests streaming a user type whose definition and operator<< are +// both in a user namespace. +namespace namespace1 { +class MyTypeInNameSpace1 : public Base { + public: + explicit MyTypeInNameSpace1(int an_x): Base(an_x) {} +}; +std::ostream& operator<<(std::ostream& os, + const MyTypeInNameSpace1& val) { + return os << val.x(); +} +std::ostream& operator<<(std::ostream& os, + const MyTypeInNameSpace1* pointer) { + return os << "(" << pointer->x() << ")"; +} +} // namespace namespace1 + +TEST(MessageTest, CanStreamUserTypeInUserNameSpace) { + Message msg; + namespace1::MyTypeInNameSpace1 a(1); + + msg << a << &a; // Uses namespace1::operator<<. + EXPECT_STREQ("1(1)", msg.GetString().c_str()); +} + +// Tests streaming a user type whose definition is in a user namespace +// but whose operator<< is in the global namespace. +namespace namespace2 { +class MyTypeInNameSpace2 : public ::Base { + public: + explicit MyTypeInNameSpace2(int an_x): Base(an_x) {} +}; +} // namespace namespace2 +std::ostream& operator<<(std::ostream& os, + const namespace2::MyTypeInNameSpace2& val) { + return os << val.x(); +} +std::ostream& operator<<(std::ostream& os, + const namespace2::MyTypeInNameSpace2* pointer) { + return os << "(" << pointer->x() << ")"; +} + +TEST(MessageTest, CanStreamUserTypeInUserNameSpaceWithStreamOperatorInGlobal) { + Message msg; + namespace2::MyTypeInNameSpace2 a(1); + + msg << a << &a; // Uses ::operator<<. + EXPECT_STREQ("1(1)", msg.GetString().c_str()); +} + +// Tests streaming NULL pointers to testing::Message. +TEST(MessageTest, NullPointers) { + Message msg; + char* const p1 = nullptr; + unsigned char* const p2 = nullptr; + int* p3 = nullptr; + double* p4 = nullptr; + bool* p5 = nullptr; + Message* p6 = nullptr; + + msg << p1 << p2 << p3 << p4 << p5 << p6; + ASSERT_STREQ("(null)(null)(null)(null)(null)(null)", + msg.GetString().c_str()); +} + +// Tests streaming wide strings to testing::Message. +TEST(MessageTest, WideStrings) { + // Streams a NULL of type const wchar_t*. + const wchar_t* const_wstr = nullptr; + EXPECT_STREQ("(null)", + (Message() << const_wstr).GetString().c_str()); + + // Streams a NULL of type wchar_t*. + wchar_t* wstr = nullptr; + EXPECT_STREQ("(null)", + (Message() << wstr).GetString().c_str()); + + // Streams a non-NULL of type const wchar_t*. + const_wstr = L"abc\x8119"; + EXPECT_STREQ("abc\xe8\x84\x99", + (Message() << const_wstr).GetString().c_str()); + + // Streams a non-NULL of type wchar_t*. + wstr = const_cast(const_wstr); + EXPECT_STREQ("abc\xe8\x84\x99", + (Message() << wstr).GetString().c_str()); +} + + +// This line tests that we can define tests in the testing namespace. +namespace testing { + +// Tests the TestInfo class. + +class TestInfoTest : public Test { + protected: + static const TestInfo* GetTestInfo(const char* test_name) { + const TestSuite* const test_suite = + GetUnitTestImpl()->GetTestSuite("TestInfoTest", "", nullptr, nullptr); + + for (int i = 0; i < test_suite->total_test_count(); ++i) { + const TestInfo* const test_info = test_suite->GetTestInfo(i); + if (strcmp(test_name, test_info->name()) == 0) + return test_info; + } + return nullptr; + } + + static const TestResult* GetTestResult( + const TestInfo* test_info) { + return test_info->result(); + } +}; + +// Tests TestInfo::test_case_name() and TestInfo::name(). +TEST_F(TestInfoTest, Names) { + const TestInfo* const test_info = GetTestInfo("Names"); + + ASSERT_STREQ("TestInfoTest", test_info->test_suite_name()); + ASSERT_STREQ("Names", test_info->name()); +} + +// Tests TestInfo::result(). +TEST_F(TestInfoTest, result) { + const TestInfo* const test_info = GetTestInfo("result"); + + // Initially, there is no TestPartResult for this test. + ASSERT_EQ(0, GetTestResult(test_info)->total_part_count()); + + // After the previous assertion, there is still none. + ASSERT_EQ(0, GetTestResult(test_info)->total_part_count()); +} + +#define VERIFY_CODE_LOCATION \ + const int expected_line = __LINE__ - 1; \ + const TestInfo* const test_info = GetUnitTestImpl()->current_test_info(); \ + ASSERT_TRUE(test_info); \ + EXPECT_STREQ(__FILE__, test_info->file()); \ + EXPECT_EQ(expected_line, test_info->line()) + +TEST(CodeLocationForTEST, Verify) { + VERIFY_CODE_LOCATION; +} + +class CodeLocationForTESTF : public Test { +}; + +TEST_F(CodeLocationForTESTF, Verify) { + VERIFY_CODE_LOCATION; +} + +class CodeLocationForTESTP : public TestWithParam { +}; + +TEST_P(CodeLocationForTESTP, Verify) { + VERIFY_CODE_LOCATION; +} + +INSTANTIATE_TEST_SUITE_P(, CodeLocationForTESTP, Values(0)); + +template +class CodeLocationForTYPEDTEST : public Test { +}; + +TYPED_TEST_SUITE(CodeLocationForTYPEDTEST, int); + +TYPED_TEST(CodeLocationForTYPEDTEST, Verify) { + VERIFY_CODE_LOCATION; +} + +template +class CodeLocationForTYPEDTESTP : public Test { +}; + +TYPED_TEST_SUITE_P(CodeLocationForTYPEDTESTP); + +TYPED_TEST_P(CodeLocationForTYPEDTESTP, Verify) { + VERIFY_CODE_LOCATION; +} + +REGISTER_TYPED_TEST_SUITE_P(CodeLocationForTYPEDTESTP, Verify); + +INSTANTIATE_TYPED_TEST_SUITE_P(My, CodeLocationForTYPEDTESTP, int); + +#undef VERIFY_CODE_LOCATION + +// Tests setting up and tearing down a test case. +// Legacy API is deprecated but still available +#ifndef GTEST_REMOVE_LEGACY_TEST_CASEAPI_ +class SetUpTestCaseTest : public Test { + protected: + // This will be called once before the first test in this test case + // is run. + static void SetUpTestCase() { + printf("Setting up the test case . . .\n"); + + // Initializes some shared resource. In this simple example, we + // just create a C string. More complex stuff can be done if + // desired. + shared_resource_ = "123"; + + // Increments the number of test cases that have been set up. + counter_++; + + // SetUpTestCase() should be called only once. + EXPECT_EQ(1, counter_); + } + + // This will be called once after the last test in this test case is + // run. + static void TearDownTestCase() { + printf("Tearing down the test case . . .\n"); + + // Decrements the number of test cases that have been set up. + counter_--; + + // TearDownTestCase() should be called only once. + EXPECT_EQ(0, counter_); + + // Cleans up the shared resource. + shared_resource_ = nullptr; + } + + // This will be called before each test in this test case. + void SetUp() override { + // SetUpTestCase() should be called only once, so counter_ should + // always be 1. + EXPECT_EQ(1, counter_); + } + + // Number of test cases that have been set up. + static int counter_; + + // Some resource to be shared by all tests in this test case. + static const char* shared_resource_; +}; + +int SetUpTestCaseTest::counter_ = 0; +const char* SetUpTestCaseTest::shared_resource_ = nullptr; + +// A test that uses the shared resource. +TEST_F(SetUpTestCaseTest, Test1) { EXPECT_STRNE(nullptr, shared_resource_); } + +// Another test that uses the shared resource. +TEST_F(SetUpTestCaseTest, Test2) { + EXPECT_STREQ("123", shared_resource_); +} +#endif // GTEST_REMOVE_LEGACY_TEST_CASEAPI_ + +// Tests SetupTestSuite/TearDown TestSuite +class SetUpTestSuiteTest : public Test { + protected: + // This will be called once before the first test in this test case + // is run. + static void SetUpTestSuite() { + printf("Setting up the test suite . . .\n"); + + // Initializes some shared resource. In this simple example, we + // just create a C string. More complex stuff can be done if + // desired. + shared_resource_ = "123"; + + // Increments the number of test cases that have been set up. + counter_++; + + // SetUpTestSuite() should be called only once. + EXPECT_EQ(1, counter_); + } + + // This will be called once after the last test in this test case is + // run. + static void TearDownTestSuite() { + printf("Tearing down the test suite . . .\n"); + + // Decrements the number of test suites that have been set up. + counter_--; + + // TearDownTestSuite() should be called only once. + EXPECT_EQ(0, counter_); + + // Cleans up the shared resource. + shared_resource_ = nullptr; + } + + // This will be called before each test in this test case. + void SetUp() override { + // SetUpTestSuite() should be called only once, so counter_ should + // always be 1. + EXPECT_EQ(1, counter_); + } + + // Number of test suites that have been set up. + static int counter_; + + // Some resource to be shared by all tests in this test case. + static const char* shared_resource_; +}; + +int SetUpTestSuiteTest::counter_ = 0; +const char* SetUpTestSuiteTest::shared_resource_ = nullptr; + +// A test that uses the shared resource. +TEST_F(SetUpTestSuiteTest, TestSetupTestSuite1) { + EXPECT_STRNE(nullptr, shared_resource_); +} + +// Another test that uses the shared resource. +TEST_F(SetUpTestSuiteTest, TestSetupTestSuite2) { + EXPECT_STREQ("123", shared_resource_); +} + +// The ParseFlagsTest test case tests ParseGoogleTestFlagsOnly. + +// The Flags struct stores a copy of all Google Test flags. +struct Flags { + // Constructs a Flags struct where each flag has its default value. + Flags() + : also_run_disabled_tests(false), + break_on_failure(false), + catch_exceptions(false), + death_test_use_fork(false), + fail_fast(false), + filter(""), + list_tests(false), + output(""), + brief(false), + print_time(true), + random_seed(0), + repeat(1), + recreate_environments_when_repeating(true), + shuffle(false), + stack_trace_depth(kMaxStackTraceDepth), + stream_result_to(""), + throw_on_failure(false) {} + + // Factory methods. + + // Creates a Flags struct where the gtest_also_run_disabled_tests flag has + // the given value. + static Flags AlsoRunDisabledTests(bool also_run_disabled_tests) { + Flags flags; + flags.also_run_disabled_tests = also_run_disabled_tests; + return flags; + } + + // Creates a Flags struct where the gtest_break_on_failure flag has + // the given value. + static Flags BreakOnFailure(bool break_on_failure) { + Flags flags; + flags.break_on_failure = break_on_failure; + return flags; + } + + // Creates a Flags struct where the gtest_catch_exceptions flag has + // the given value. + static Flags CatchExceptions(bool catch_exceptions) { + Flags flags; + flags.catch_exceptions = catch_exceptions; + return flags; + } + + // Creates a Flags struct where the gtest_death_test_use_fork flag has + // the given value. + static Flags DeathTestUseFork(bool death_test_use_fork) { + Flags flags; + flags.death_test_use_fork = death_test_use_fork; + return flags; + } + + // Creates a Flags struct where the gtest_fail_fast flag has + // the given value. + static Flags FailFast(bool fail_fast) { + Flags flags; + flags.fail_fast = fail_fast; + return flags; + } + + // Creates a Flags struct where the gtest_filter flag has the given + // value. + static Flags Filter(const char* filter) { + Flags flags; + flags.filter = filter; + return flags; + } + + // Creates a Flags struct where the gtest_list_tests flag has the + // given value. + static Flags ListTests(bool list_tests) { + Flags flags; + flags.list_tests = list_tests; + return flags; + } + + // Creates a Flags struct where the gtest_output flag has the given + // value. + static Flags Output(const char* output) { + Flags flags; + flags.output = output; + return flags; + } + + // Creates a Flags struct where the gtest_brief flag has the given + // value. + static Flags Brief(bool brief) { + Flags flags; + flags.brief = brief; + return flags; + } + + // Creates a Flags struct where the gtest_print_time flag has the given + // value. + static Flags PrintTime(bool print_time) { + Flags flags; + flags.print_time = print_time; + return flags; + } + + // Creates a Flags struct where the gtest_random_seed flag has the given + // value. + static Flags RandomSeed(int32_t random_seed) { + Flags flags; + flags.random_seed = random_seed; + return flags; + } + + // Creates a Flags struct where the gtest_repeat flag has the given + // value. + static Flags Repeat(int32_t repeat) { + Flags flags; + flags.repeat = repeat; + return flags; + } + + // Creates a Flags struct where the gtest_recreate_environments_when_repeating + // flag has the given value. + static Flags RecreateEnvironmentsWhenRepeating( + bool recreate_environments_when_repeating) { + Flags flags; + flags.recreate_environments_when_repeating = + recreate_environments_when_repeating; + return flags; + } + + // Creates a Flags struct where the gtest_shuffle flag has the given + // value. + static Flags Shuffle(bool shuffle) { + Flags flags; + flags.shuffle = shuffle; + return flags; + } + + // Creates a Flags struct where the GTEST_FLAG(stack_trace_depth) flag has + // the given value. + static Flags StackTraceDepth(int32_t stack_trace_depth) { + Flags flags; + flags.stack_trace_depth = stack_trace_depth; + return flags; + } + + // Creates a Flags struct where the GTEST_FLAG(stream_result_to) flag has + // the given value. + static Flags StreamResultTo(const char* stream_result_to) { + Flags flags; + flags.stream_result_to = stream_result_to; + return flags; + } + + // Creates a Flags struct where the gtest_throw_on_failure flag has + // the given value. + static Flags ThrowOnFailure(bool throw_on_failure) { + Flags flags; + flags.throw_on_failure = throw_on_failure; + return flags; + } + + // These fields store the flag values. + bool also_run_disabled_tests; + bool break_on_failure; + bool catch_exceptions; + bool death_test_use_fork; + bool fail_fast; + const char* filter; + bool list_tests; + const char* output; + bool brief; + bool print_time; + int32_t random_seed; + int32_t repeat; + bool recreate_environments_when_repeating; + bool shuffle; + int32_t stack_trace_depth; + const char* stream_result_to; + bool throw_on_failure; +}; + +// Fixture for testing ParseGoogleTestFlagsOnly(). +class ParseFlagsTest : public Test { + protected: + // Clears the flags before each test. + void SetUp() override { + GTEST_FLAG_SET(also_run_disabled_tests, false); + GTEST_FLAG_SET(break_on_failure, false); + GTEST_FLAG_SET(catch_exceptions, false); + GTEST_FLAG_SET(death_test_use_fork, false); + GTEST_FLAG_SET(fail_fast, false); + GTEST_FLAG_SET(filter, ""); + GTEST_FLAG_SET(list_tests, false); + GTEST_FLAG_SET(output, ""); + GTEST_FLAG_SET(brief, false); + GTEST_FLAG_SET(print_time, true); + GTEST_FLAG_SET(random_seed, 0); + GTEST_FLAG_SET(repeat, 1); + GTEST_FLAG_SET(recreate_environments_when_repeating, true); + GTEST_FLAG_SET(shuffle, false); + GTEST_FLAG_SET(stack_trace_depth, kMaxStackTraceDepth); + GTEST_FLAG_SET(stream_result_to, ""); + GTEST_FLAG_SET(throw_on_failure, false); + } + + // Asserts that two narrow or wide string arrays are equal. + template + static void AssertStringArrayEq(int size1, CharType** array1, int size2, + CharType** array2) { + ASSERT_EQ(size1, size2) << " Array sizes different."; + + for (int i = 0; i != size1; i++) { + ASSERT_STREQ(array1[i], array2[i]) << " where i == " << i; + } + } + + // Verifies that the flag values match the expected values. + static void CheckFlags(const Flags& expected) { + EXPECT_EQ(expected.also_run_disabled_tests, + GTEST_FLAG_GET(also_run_disabled_tests)); + EXPECT_EQ(expected.break_on_failure, GTEST_FLAG_GET(break_on_failure)); + EXPECT_EQ(expected.catch_exceptions, GTEST_FLAG_GET(catch_exceptions)); + EXPECT_EQ(expected.death_test_use_fork, + GTEST_FLAG_GET(death_test_use_fork)); + EXPECT_EQ(expected.fail_fast, GTEST_FLAG_GET(fail_fast)); + EXPECT_STREQ(expected.filter, GTEST_FLAG_GET(filter).c_str()); + EXPECT_EQ(expected.list_tests, GTEST_FLAG_GET(list_tests)); + EXPECT_STREQ(expected.output, GTEST_FLAG_GET(output).c_str()); + EXPECT_EQ(expected.brief, GTEST_FLAG_GET(brief)); + EXPECT_EQ(expected.print_time, GTEST_FLAG_GET(print_time)); + EXPECT_EQ(expected.random_seed, GTEST_FLAG_GET(random_seed)); + EXPECT_EQ(expected.repeat, GTEST_FLAG_GET(repeat)); + EXPECT_EQ(expected.recreate_environments_when_repeating, + GTEST_FLAG_GET(recreate_environments_when_repeating)); + EXPECT_EQ(expected.shuffle, GTEST_FLAG_GET(shuffle)); + EXPECT_EQ(expected.stack_trace_depth, GTEST_FLAG_GET(stack_trace_depth)); + EXPECT_STREQ(expected.stream_result_to, + GTEST_FLAG_GET(stream_result_to).c_str()); + EXPECT_EQ(expected.throw_on_failure, GTEST_FLAG_GET(throw_on_failure)); + } + + // Parses a command line (specified by argc1 and argv1), then + // verifies that the flag values are expected and that the + // recognized flags are removed from the command line. + template + static void TestParsingFlags(int argc1, const CharType** argv1, + int argc2, const CharType** argv2, + const Flags& expected, bool should_print_help) { + const bool saved_help_flag = ::testing::internal::g_help_flag; + ::testing::internal::g_help_flag = false; + +# if GTEST_HAS_STREAM_REDIRECTION + CaptureStdout(); +# endif + + // Parses the command line. + internal::ParseGoogleTestFlagsOnly(&argc1, const_cast(argv1)); + +# if GTEST_HAS_STREAM_REDIRECTION + const std::string captured_stdout = GetCapturedStdout(); +# endif + + // Verifies the flag values. + CheckFlags(expected); + + // Verifies that the recognized flags are removed from the command + // line. + AssertStringArrayEq(argc1 + 1, argv1, argc2 + 1, argv2); + + // ParseGoogleTestFlagsOnly should neither set g_help_flag nor print the + // help message for the flags it recognizes. + EXPECT_EQ(should_print_help, ::testing::internal::g_help_flag); + +# if GTEST_HAS_STREAM_REDIRECTION + const char* const expected_help_fragment = + "This program contains tests written using"; + if (should_print_help) { + EXPECT_PRED_FORMAT2(IsSubstring, expected_help_fragment, captured_stdout); + } else { + EXPECT_PRED_FORMAT2(IsNotSubstring, + expected_help_fragment, captured_stdout); + } +# endif // GTEST_HAS_STREAM_REDIRECTION + + ::testing::internal::g_help_flag = saved_help_flag; + } + + // This macro wraps TestParsingFlags s.t. the user doesn't need + // to specify the array sizes. + +# define GTEST_TEST_PARSING_FLAGS_(argv1, argv2, expected, should_print_help) \ + TestParsingFlags(sizeof(argv1)/sizeof(*argv1) - 1, argv1, \ + sizeof(argv2)/sizeof(*argv2) - 1, argv2, \ + expected, should_print_help) +}; + +// Tests parsing an empty command line. +TEST_F(ParseFlagsTest, Empty) { + const char* argv[] = {nullptr}; + + const char* argv2[] = {nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags(), false); +} + +// Tests parsing a command line that has no flag. +TEST_F(ParseFlagsTest, NoFlag) { + const char* argv[] = {"foo.exe", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags(), false); +} + +// Tests parsing --gtest_fail_fast. +TEST_F(ParseFlagsTest, FailFast) { + const char* argv[] = {"foo.exe", "--gtest_fail_fast", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::FailFast(true), false); +} + +// Tests parsing a bad --gtest_filter flag. +TEST_F(ParseFlagsTest, FilterBad) { + const char* argv[] = {"foo.exe", "--gtest_filter", nullptr}; + + const char* argv2[] = {"foo.exe", "--gtest_filter", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Filter(""), true); +} + +// Tests parsing an empty --gtest_filter flag. +TEST_F(ParseFlagsTest, FilterEmpty) { + const char* argv[] = {"foo.exe", "--gtest_filter=", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Filter(""), false); +} + +// Tests parsing a non-empty --gtest_filter flag. +TEST_F(ParseFlagsTest, FilterNonEmpty) { + const char* argv[] = {"foo.exe", "--gtest_filter=abc", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Filter("abc"), false); +} + +// Tests parsing --gtest_break_on_failure. +TEST_F(ParseFlagsTest, BreakOnFailureWithoutValue) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::BreakOnFailure(true), false); +} + +// Tests parsing --gtest_break_on_failure=0. +TEST_F(ParseFlagsTest, BreakOnFailureFalse_0) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::BreakOnFailure(false), false); +} + +// Tests parsing --gtest_break_on_failure=f. +TEST_F(ParseFlagsTest, BreakOnFailureFalse_f) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure=f", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::BreakOnFailure(false), false); +} + +// Tests parsing --gtest_break_on_failure=F. +TEST_F(ParseFlagsTest, BreakOnFailureFalse_F) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure=F", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::BreakOnFailure(false), false); +} + +// Tests parsing a --gtest_break_on_failure flag that has a "true" +// definition. +TEST_F(ParseFlagsTest, BreakOnFailureTrue) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::BreakOnFailure(true), false); +} + +// Tests parsing --gtest_catch_exceptions. +TEST_F(ParseFlagsTest, CatchExceptions) { + const char* argv[] = {"foo.exe", "--gtest_catch_exceptions", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::CatchExceptions(true), false); +} + +// Tests parsing --gtest_death_test_use_fork. +TEST_F(ParseFlagsTest, DeathTestUseFork) { + const char* argv[] = {"foo.exe", "--gtest_death_test_use_fork", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::DeathTestUseFork(true), false); +} + +// Tests having the same flag twice with different values. The +// expected behavior is that the one coming last takes precedence. +TEST_F(ParseFlagsTest, DuplicatedFlags) { + const char* argv[] = {"foo.exe", "--gtest_filter=a", "--gtest_filter=b", + nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Filter("b"), false); +} + +// Tests having an unrecognized flag on the command line. +TEST_F(ParseFlagsTest, UnrecognizedFlag) { + const char* argv[] = {"foo.exe", "--gtest_break_on_failure", + "bar", // Unrecognized by Google Test. + "--gtest_filter=b", nullptr}; + + const char* argv2[] = {"foo.exe", "bar", nullptr}; + + Flags flags; + flags.break_on_failure = true; + flags.filter = "b"; + GTEST_TEST_PARSING_FLAGS_(argv, argv2, flags, false); +} + +// Tests having a --gtest_list_tests flag +TEST_F(ParseFlagsTest, ListTestsFlag) { + const char* argv[] = {"foo.exe", "--gtest_list_tests", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ListTests(true), false); +} + +// Tests having a --gtest_list_tests flag with a "true" value +TEST_F(ParseFlagsTest, ListTestsTrue) { + const char* argv[] = {"foo.exe", "--gtest_list_tests=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ListTests(true), false); +} + +// Tests having a --gtest_list_tests flag with a "false" value +TEST_F(ParseFlagsTest, ListTestsFalse) { + const char* argv[] = {"foo.exe", "--gtest_list_tests=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ListTests(false), false); +} + +// Tests parsing --gtest_list_tests=f. +TEST_F(ParseFlagsTest, ListTestsFalse_f) { + const char* argv[] = {"foo.exe", "--gtest_list_tests=f", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ListTests(false), false); +} + +// Tests parsing --gtest_list_tests=F. +TEST_F(ParseFlagsTest, ListTestsFalse_F) { + const char* argv[] = {"foo.exe", "--gtest_list_tests=F", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ListTests(false), false); +} + +// Tests parsing --gtest_output (invalid). +TEST_F(ParseFlagsTest, OutputEmpty) { + const char* argv[] = {"foo.exe", "--gtest_output", nullptr}; + + const char* argv2[] = {"foo.exe", "--gtest_output", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags(), true); +} + +// Tests parsing --gtest_output=xml +TEST_F(ParseFlagsTest, OutputXml) { + const char* argv[] = {"foo.exe", "--gtest_output=xml", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Output("xml"), false); +} + +// Tests parsing --gtest_output=xml:file +TEST_F(ParseFlagsTest, OutputXmlFile) { + const char* argv[] = {"foo.exe", "--gtest_output=xml:file", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Output("xml:file"), false); +} + +// Tests parsing --gtest_output=xml:directory/path/ +TEST_F(ParseFlagsTest, OutputXmlDirectory) { + const char* argv[] = {"foo.exe", "--gtest_output=xml:directory/path/", + nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, + Flags::Output("xml:directory/path/"), false); +} + +// Tests having a --gtest_brief flag +TEST_F(ParseFlagsTest, BriefFlag) { + const char* argv[] = {"foo.exe", "--gtest_brief", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Brief(true), false); +} + +// Tests having a --gtest_brief flag with a "true" value +TEST_F(ParseFlagsTest, BriefFlagTrue) { + const char* argv[] = {"foo.exe", "--gtest_brief=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Brief(true), false); +} + +// Tests having a --gtest_brief flag with a "false" value +TEST_F(ParseFlagsTest, BriefFlagFalse) { + const char* argv[] = {"foo.exe", "--gtest_brief=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Brief(false), false); +} + +// Tests having a --gtest_print_time flag +TEST_F(ParseFlagsTest, PrintTimeFlag) { + const char* argv[] = {"foo.exe", "--gtest_print_time", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::PrintTime(true), false); +} + +// Tests having a --gtest_print_time flag with a "true" value +TEST_F(ParseFlagsTest, PrintTimeTrue) { + const char* argv[] = {"foo.exe", "--gtest_print_time=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::PrintTime(true), false); +} + +// Tests having a --gtest_print_time flag with a "false" value +TEST_F(ParseFlagsTest, PrintTimeFalse) { + const char* argv[] = {"foo.exe", "--gtest_print_time=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::PrintTime(false), false); +} + +// Tests parsing --gtest_print_time=f. +TEST_F(ParseFlagsTest, PrintTimeFalse_f) { + const char* argv[] = {"foo.exe", "--gtest_print_time=f", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::PrintTime(false), false); +} + +// Tests parsing --gtest_print_time=F. +TEST_F(ParseFlagsTest, PrintTimeFalse_F) { + const char* argv[] = {"foo.exe", "--gtest_print_time=F", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::PrintTime(false), false); +} + +// Tests parsing --gtest_random_seed=number +TEST_F(ParseFlagsTest, RandomSeed) { + const char* argv[] = {"foo.exe", "--gtest_random_seed=1000", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::RandomSeed(1000), false); +} + +// Tests parsing --gtest_repeat=number +TEST_F(ParseFlagsTest, Repeat) { + const char* argv[] = {"foo.exe", "--gtest_repeat=1000", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Repeat(1000), false); +} + +// Tests parsing --gtest_recreate_environments_when_repeating +TEST_F(ParseFlagsTest, RecreateEnvironmentsWhenRepeating) { + const char* argv[] = { + "foo.exe", + "--gtest_recreate_environments_when_repeating=0", + nullptr, + }; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_( + argv, argv2, Flags::RecreateEnvironmentsWhenRepeating(false), false); +} + +// Tests having a --gtest_also_run_disabled_tests flag +TEST_F(ParseFlagsTest, AlsoRunDisabledTestsFlag) { + const char* argv[] = {"foo.exe", "--gtest_also_run_disabled_tests", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::AlsoRunDisabledTests(true), + false); +} + +// Tests having a --gtest_also_run_disabled_tests flag with a "true" value +TEST_F(ParseFlagsTest, AlsoRunDisabledTestsTrue) { + const char* argv[] = {"foo.exe", "--gtest_also_run_disabled_tests=1", + nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::AlsoRunDisabledTests(true), + false); +} + +// Tests having a --gtest_also_run_disabled_tests flag with a "false" value +TEST_F(ParseFlagsTest, AlsoRunDisabledTestsFalse) { + const char* argv[] = {"foo.exe", "--gtest_also_run_disabled_tests=0", + nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::AlsoRunDisabledTests(false), + false); +} + +// Tests parsing --gtest_shuffle. +TEST_F(ParseFlagsTest, ShuffleWithoutValue) { + const char* argv[] = {"foo.exe", "--gtest_shuffle", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Shuffle(true), false); +} + +// Tests parsing --gtest_shuffle=0. +TEST_F(ParseFlagsTest, ShuffleFalse_0) { + const char* argv[] = {"foo.exe", "--gtest_shuffle=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Shuffle(false), false); +} + +// Tests parsing a --gtest_shuffle flag that has a "true" definition. +TEST_F(ParseFlagsTest, ShuffleTrue) { + const char* argv[] = {"foo.exe", "--gtest_shuffle=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Shuffle(true), false); +} + +// Tests parsing --gtest_stack_trace_depth=number. +TEST_F(ParseFlagsTest, StackTraceDepth) { + const char* argv[] = {"foo.exe", "--gtest_stack_trace_depth=5", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::StackTraceDepth(5), false); +} + +TEST_F(ParseFlagsTest, StreamResultTo) { + const char* argv[] = {"foo.exe", "--gtest_stream_result_to=localhost:1234", + nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_( + argv, argv2, Flags::StreamResultTo("localhost:1234"), false); +} + +// Tests parsing --gtest_throw_on_failure. +TEST_F(ParseFlagsTest, ThrowOnFailureWithoutValue) { + const char* argv[] = {"foo.exe", "--gtest_throw_on_failure", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ThrowOnFailure(true), false); +} + +// Tests parsing --gtest_throw_on_failure=0. +TEST_F(ParseFlagsTest, ThrowOnFailureFalse_0) { + const char* argv[] = {"foo.exe", "--gtest_throw_on_failure=0", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ThrowOnFailure(false), false); +} + +// Tests parsing a --gtest_throw_on_failure flag that has a "true" +// definition. +TEST_F(ParseFlagsTest, ThrowOnFailureTrue) { + const char* argv[] = {"foo.exe", "--gtest_throw_on_failure=1", nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::ThrowOnFailure(true), false); +} + +# if GTEST_OS_WINDOWS +// Tests parsing wide strings. +TEST_F(ParseFlagsTest, WideStrings) { + const wchar_t* argv[] = { + L"foo.exe", + L"--gtest_filter=Foo*", + L"--gtest_list_tests=1", + L"--gtest_break_on_failure", + L"--non_gtest_flag", + NULL + }; + + const wchar_t* argv2[] = { + L"foo.exe", + L"--non_gtest_flag", + NULL + }; + + Flags expected_flags; + expected_flags.break_on_failure = true; + expected_flags.filter = "Foo*"; + expected_flags.list_tests = true; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, expected_flags, false); +} +# endif // GTEST_OS_WINDOWS + +#if GTEST_USE_OWN_FLAGFILE_FLAG_ +class FlagfileTest : public ParseFlagsTest { + public: + void SetUp() override { + ParseFlagsTest::SetUp(); + + testdata_path_.Set(internal::FilePath( + testing::TempDir() + internal::GetCurrentExecutableName().string() + + "_flagfile_test")); + testing::internal::posix::RmDir(testdata_path_.c_str()); + EXPECT_TRUE(testdata_path_.CreateFolder()); + } + + void TearDown() override { + testing::internal::posix::RmDir(testdata_path_.c_str()); + ParseFlagsTest::TearDown(); + } + + internal::FilePath CreateFlagfile(const char* contents) { + internal::FilePath file_path(internal::FilePath::GenerateUniqueFileName( + testdata_path_, internal::FilePath("unique"), "txt")); + FILE* f = testing::internal::posix::FOpen(file_path.c_str(), "w"); + fprintf(f, "%s", contents); + fclose(f); + return file_path; + } + + private: + internal::FilePath testdata_path_; +}; + +// Tests an empty flagfile. +TEST_F(FlagfileTest, Empty) { + internal::FilePath flagfile_path(CreateFlagfile("")); + std::string flagfile_flag = + std::string("--" GTEST_FLAG_PREFIX_ "flagfile=") + flagfile_path.c_str(); + + const char* argv[] = {"foo.exe", flagfile_flag.c_str(), nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags(), false); +} + +// Tests passing a non-empty --gtest_filter flag via --gtest_flagfile. +TEST_F(FlagfileTest, FilterNonEmpty) { + internal::FilePath flagfile_path(CreateFlagfile( + "--" GTEST_FLAG_PREFIX_ "filter=abc")); + std::string flagfile_flag = + std::string("--" GTEST_FLAG_PREFIX_ "flagfile=") + flagfile_path.c_str(); + + const char* argv[] = {"foo.exe", flagfile_flag.c_str(), nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, Flags::Filter("abc"), false); +} + +// Tests passing several flags via --gtest_flagfile. +TEST_F(FlagfileTest, SeveralFlags) { + internal::FilePath flagfile_path(CreateFlagfile( + "--" GTEST_FLAG_PREFIX_ "filter=abc\n" + "--" GTEST_FLAG_PREFIX_ "break_on_failure\n" + "--" GTEST_FLAG_PREFIX_ "list_tests")); + std::string flagfile_flag = + std::string("--" GTEST_FLAG_PREFIX_ "flagfile=") + flagfile_path.c_str(); + + const char* argv[] = {"foo.exe", flagfile_flag.c_str(), nullptr}; + + const char* argv2[] = {"foo.exe", nullptr}; + + Flags expected_flags; + expected_flags.break_on_failure = true; + expected_flags.filter = "abc"; + expected_flags.list_tests = true; + + GTEST_TEST_PARSING_FLAGS_(argv, argv2, expected_flags, false); +} +#endif // GTEST_USE_OWN_FLAGFILE_FLAG_ + +// Tests current_test_info() in UnitTest. +class CurrentTestInfoTest : public Test { + protected: + // Tests that current_test_info() returns NULL before the first test in + // the test case is run. + static void SetUpTestSuite() { + // There should be no tests running at this point. + const TestInfo* test_info = + UnitTest::GetInstance()->current_test_info(); + EXPECT_TRUE(test_info == nullptr) + << "There should be no tests running at this point."; + } + + // Tests that current_test_info() returns NULL after the last test in + // the test case has run. + static void TearDownTestSuite() { + const TestInfo* test_info = + UnitTest::GetInstance()->current_test_info(); + EXPECT_TRUE(test_info == nullptr) + << "There should be no tests running at this point."; + } +}; + +// Tests that current_test_info() returns TestInfo for currently running +// test by checking the expected test name against the actual one. +TEST_F(CurrentTestInfoTest, WorksForFirstTestInATestSuite) { + const TestInfo* test_info = + UnitTest::GetInstance()->current_test_info(); + ASSERT_TRUE(nullptr != test_info) + << "There is a test running so we should have a valid TestInfo."; + EXPECT_STREQ("CurrentTestInfoTest", test_info->test_suite_name()) + << "Expected the name of the currently running test suite."; + EXPECT_STREQ("WorksForFirstTestInATestSuite", test_info->name()) + << "Expected the name of the currently running test."; +} + +// Tests that current_test_info() returns TestInfo for currently running +// test by checking the expected test name against the actual one. We +// use this test to see that the TestInfo object actually changed from +// the previous invocation. +TEST_F(CurrentTestInfoTest, WorksForSecondTestInATestSuite) { + const TestInfo* test_info = + UnitTest::GetInstance()->current_test_info(); + ASSERT_TRUE(nullptr != test_info) + << "There is a test running so we should have a valid TestInfo."; + EXPECT_STREQ("CurrentTestInfoTest", test_info->test_suite_name()) + << "Expected the name of the currently running test suite."; + EXPECT_STREQ("WorksForSecondTestInATestSuite", test_info->name()) + << "Expected the name of the currently running test."; +} + +} // namespace testing + + +// These two lines test that we can define tests in a namespace that +// has the name "testing" and is nested in another namespace. +namespace my_namespace { +namespace testing { + +// Makes sure that TEST knows to use ::testing::Test instead of +// ::my_namespace::testing::Test. +class Test {}; + +// Makes sure that an assertion knows to use ::testing::Message instead of +// ::my_namespace::testing::Message. +class Message {}; + +// Makes sure that an assertion knows to use +// ::testing::AssertionResult instead of +// ::my_namespace::testing::AssertionResult. +class AssertionResult {}; + +// Tests that an assertion that should succeed works as expected. +TEST(NestedTestingNamespaceTest, Success) { + EXPECT_EQ(1, 1) << "This shouldn't fail."; +} + +// Tests that an assertion that should fail works as expected. +TEST(NestedTestingNamespaceTest, Failure) { + EXPECT_FATAL_FAILURE(FAIL() << "This failure is expected.", + "This failure is expected."); +} + +} // namespace testing +} // namespace my_namespace + +// Tests that one can call superclass SetUp and TearDown methods-- +// that is, that they are not private. +// No tests are based on this fixture; the test "passes" if it compiles +// successfully. +class ProtectedFixtureMethodsTest : public Test { + protected: + void SetUp() override { Test::SetUp(); } + void TearDown() override { Test::TearDown(); } +}; + +// StreamingAssertionsTest tests the streaming versions of a representative +// sample of assertions. +TEST(StreamingAssertionsTest, Unconditional) { + SUCCEED() << "expected success"; + EXPECT_NONFATAL_FAILURE(ADD_FAILURE() << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(FAIL() << "expected failure", + "expected failure"); +} + +#ifdef __BORLANDC__ +// Silences warnings: "Condition is always true", "Unreachable code" +# pragma option push -w-ccc -w-rch +#endif + +TEST(StreamingAssertionsTest, Truth) { + EXPECT_TRUE(true) << "unexpected failure"; + ASSERT_TRUE(true) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_TRUE(false) << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_TRUE(false) << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, Truth2) { + EXPECT_FALSE(false) << "unexpected failure"; + ASSERT_FALSE(false) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_FALSE(true) << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_FALSE(true) << "expected failure", + "expected failure"); +} + +#ifdef __BORLANDC__ +// Restores warnings after previous "#pragma option push" suppressed them +# pragma option pop +#endif + +TEST(StreamingAssertionsTest, IntegerEquals) { + EXPECT_EQ(1, 1) << "unexpected failure"; + ASSERT_EQ(1, 1) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_EQ(1, 2) << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_EQ(1, 2) << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, IntegerLessThan) { + EXPECT_LT(1, 2) << "unexpected failure"; + ASSERT_LT(1, 2) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_LT(2, 1) << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_LT(2, 1) << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, StringsEqual) { + EXPECT_STREQ("foo", "foo") << "unexpected failure"; + ASSERT_STREQ("foo", "foo") << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_STREQ("foo", "bar") << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_STREQ("foo", "bar") << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, StringsNotEqual) { + EXPECT_STRNE("foo", "bar") << "unexpected failure"; + ASSERT_STRNE("foo", "bar") << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_STRNE("foo", "foo") << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_STRNE("foo", "foo") << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, StringsEqualIgnoringCase) { + EXPECT_STRCASEEQ("foo", "FOO") << "unexpected failure"; + ASSERT_STRCASEEQ("foo", "FOO") << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_STRCASEEQ("foo", "bar") << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_STRCASEEQ("foo", "bar") << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, StringNotEqualIgnoringCase) { + EXPECT_STRCASENE("foo", "bar") << "unexpected failure"; + ASSERT_STRCASENE("foo", "bar") << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_STRCASENE("foo", "FOO") << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_STRCASENE("bar", "BAR") << "expected failure", + "expected failure"); +} + +TEST(StreamingAssertionsTest, FloatingPointEquals) { + EXPECT_FLOAT_EQ(1.0, 1.0) << "unexpected failure"; + ASSERT_FLOAT_EQ(1.0, 1.0) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_FLOAT_EQ(0.0, 1.0) << "expected failure", + "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_FLOAT_EQ(0.0, 1.0) << "expected failure", + "expected failure"); +} + +#if GTEST_HAS_EXCEPTIONS + +TEST(StreamingAssertionsTest, Throw) { + EXPECT_THROW(ThrowAnInteger(), int) << "unexpected failure"; + ASSERT_THROW(ThrowAnInteger(), int) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_THROW(ThrowAnInteger(), bool) << + "expected failure", "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_THROW(ThrowAnInteger(), bool) << + "expected failure", "expected failure"); +} + +TEST(StreamingAssertionsTest, NoThrow) { + EXPECT_NO_THROW(ThrowNothing()) << "unexpected failure"; + ASSERT_NO_THROW(ThrowNothing()) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_NO_THROW(ThrowAnInteger()) << + "expected failure", "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_NO_THROW(ThrowAnInteger()) << + "expected failure", "expected failure"); +} + +TEST(StreamingAssertionsTest, AnyThrow) { + EXPECT_ANY_THROW(ThrowAnInteger()) << "unexpected failure"; + ASSERT_ANY_THROW(ThrowAnInteger()) << "unexpected failure"; + EXPECT_NONFATAL_FAILURE(EXPECT_ANY_THROW(ThrowNothing()) << + "expected failure", "expected failure"); + EXPECT_FATAL_FAILURE(ASSERT_ANY_THROW(ThrowNothing()) << + "expected failure", "expected failure"); +} + +#endif // GTEST_HAS_EXCEPTIONS + +// Tests that Google Test correctly decides whether to use colors in the output. + +TEST(ColoredOutputTest, UsesColorsWhenGTestColorFlagIsYes) { + GTEST_FLAG_SET(color, "yes"); + + SetEnv("TERM", "xterm"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + EXPECT_TRUE(ShouldUseColor(false)); // Stdout is not a TTY. + + SetEnv("TERM", "dumb"); // TERM doesn't support colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + EXPECT_TRUE(ShouldUseColor(false)); // Stdout is not a TTY. +} + +TEST(ColoredOutputTest, UsesColorsWhenGTestColorFlagIsAliasOfYes) { + SetEnv("TERM", "dumb"); // TERM doesn't support colors. + + GTEST_FLAG_SET(color, "True"); + EXPECT_TRUE(ShouldUseColor(false)); // Stdout is not a TTY. + + GTEST_FLAG_SET(color, "t"); + EXPECT_TRUE(ShouldUseColor(false)); // Stdout is not a TTY. + + GTEST_FLAG_SET(color, "1"); + EXPECT_TRUE(ShouldUseColor(false)); // Stdout is not a TTY. +} + +TEST(ColoredOutputTest, UsesNoColorWhenGTestColorFlagIsNo) { + GTEST_FLAG_SET(color, "no"); + + SetEnv("TERM", "xterm"); // TERM supports colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + EXPECT_FALSE(ShouldUseColor(false)); // Stdout is not a TTY. + + SetEnv("TERM", "dumb"); // TERM doesn't support colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + EXPECT_FALSE(ShouldUseColor(false)); // Stdout is not a TTY. +} + +TEST(ColoredOutputTest, UsesNoColorWhenGTestColorFlagIsInvalid) { + SetEnv("TERM", "xterm"); // TERM supports colors. + + GTEST_FLAG_SET(color, "F"); + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + GTEST_FLAG_SET(color, "0"); + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + GTEST_FLAG_SET(color, "unknown"); + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. +} + +TEST(ColoredOutputTest, UsesColorsWhenStdoutIsTty) { + GTEST_FLAG_SET(color, "auto"); + + SetEnv("TERM", "xterm"); // TERM supports colors. + EXPECT_FALSE(ShouldUseColor(false)); // Stdout is not a TTY. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. +} + +TEST(ColoredOutputTest, UsesColorsWhenTermSupportsColors) { + GTEST_FLAG_SET(color, "auto"); + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MINGW + // On Windows, we ignore the TERM variable as it's usually not set. + + SetEnv("TERM", "dumb"); + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", ""); + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "xterm"); + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. +#else + // On non-Windows platforms, we rely on TERM to determine if the + // terminal supports colors. + + SetEnv("TERM", "dumb"); // TERM doesn't support colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "emacs"); // TERM doesn't support colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "vt100"); // TERM doesn't support colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "xterm-mono"); // TERM doesn't support colors. + EXPECT_FALSE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "xterm"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "xterm-color"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "xterm-256color"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "screen"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "screen-256color"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "tmux"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "tmux-256color"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "rxvt-unicode"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "rxvt-unicode-256color"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "linux"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. + + SetEnv("TERM", "cygwin"); // TERM supports colors. + EXPECT_TRUE(ShouldUseColor(true)); // Stdout is a TTY. +#endif // GTEST_OS_WINDOWS +} + +// Verifies that StaticAssertTypeEq works in a namespace scope. + +static bool dummy1 GTEST_ATTRIBUTE_UNUSED_ = StaticAssertTypeEq(); +static bool dummy2 GTEST_ATTRIBUTE_UNUSED_ = + StaticAssertTypeEq(); + +// Verifies that StaticAssertTypeEq works in a class. + +template +class StaticAssertTypeEqTestHelper { + public: + StaticAssertTypeEqTestHelper() { StaticAssertTypeEq(); } +}; + +TEST(StaticAssertTypeEqTest, WorksInClass) { + StaticAssertTypeEqTestHelper(); +} + +// Verifies that StaticAssertTypeEq works inside a function. + +typedef int IntAlias; + +TEST(StaticAssertTypeEqTest, CompilesForEqualTypes) { + StaticAssertTypeEq(); + StaticAssertTypeEq(); +} + +TEST(HasNonfatalFailureTest, ReturnsFalseWhenThereIsNoFailure) { + EXPECT_FALSE(HasNonfatalFailure()); +} + +static void FailFatally() { FAIL(); } + +TEST(HasNonfatalFailureTest, ReturnsFalseWhenThereIsOnlyFatalFailure) { + FailFatally(); + const bool has_nonfatal_failure = HasNonfatalFailure(); + ClearCurrentTestPartResults(); + EXPECT_FALSE(has_nonfatal_failure); +} + +TEST(HasNonfatalFailureTest, ReturnsTrueWhenThereIsNonfatalFailure) { + ADD_FAILURE(); + const bool has_nonfatal_failure = HasNonfatalFailure(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_nonfatal_failure); +} + +TEST(HasNonfatalFailureTest, ReturnsTrueWhenThereAreFatalAndNonfatalFailures) { + FailFatally(); + ADD_FAILURE(); + const bool has_nonfatal_failure = HasNonfatalFailure(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_nonfatal_failure); +} + +// A wrapper for calling HasNonfatalFailure outside of a test body. +static bool HasNonfatalFailureHelper() { + return testing::Test::HasNonfatalFailure(); +} + +TEST(HasNonfatalFailureTest, WorksOutsideOfTestBody) { + EXPECT_FALSE(HasNonfatalFailureHelper()); +} + +TEST(HasNonfatalFailureTest, WorksOutsideOfTestBody2) { + ADD_FAILURE(); + const bool has_nonfatal_failure = HasNonfatalFailureHelper(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_nonfatal_failure); +} + +TEST(HasFailureTest, ReturnsFalseWhenThereIsNoFailure) { + EXPECT_FALSE(HasFailure()); +} + +TEST(HasFailureTest, ReturnsTrueWhenThereIsFatalFailure) { + FailFatally(); + const bool has_failure = HasFailure(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_failure); +} + +TEST(HasFailureTest, ReturnsTrueWhenThereIsNonfatalFailure) { + ADD_FAILURE(); + const bool has_failure = HasFailure(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_failure); +} + +TEST(HasFailureTest, ReturnsTrueWhenThereAreFatalAndNonfatalFailures) { + FailFatally(); + ADD_FAILURE(); + const bool has_failure = HasFailure(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_failure); +} + +// A wrapper for calling HasFailure outside of a test body. +static bool HasFailureHelper() { return testing::Test::HasFailure(); } + +TEST(HasFailureTest, WorksOutsideOfTestBody) { + EXPECT_FALSE(HasFailureHelper()); +} + +TEST(HasFailureTest, WorksOutsideOfTestBody2) { + ADD_FAILURE(); + const bool has_failure = HasFailureHelper(); + ClearCurrentTestPartResults(); + EXPECT_TRUE(has_failure); +} + +class TestListener : public EmptyTestEventListener { + public: + TestListener() : on_start_counter_(nullptr), is_destroyed_(nullptr) {} + TestListener(int* on_start_counter, bool* is_destroyed) + : on_start_counter_(on_start_counter), + is_destroyed_(is_destroyed) {} + + ~TestListener() override { + if (is_destroyed_) + *is_destroyed_ = true; + } + + protected: + void OnTestProgramStart(const UnitTest& /*unit_test*/) override { + if (on_start_counter_ != nullptr) (*on_start_counter_)++; + } + + private: + int* on_start_counter_; + bool* is_destroyed_; +}; + +// Tests the constructor. +TEST(TestEventListenersTest, ConstructionWorks) { + TestEventListeners listeners; + + EXPECT_TRUE(TestEventListenersAccessor::GetRepeater(&listeners) != nullptr); + EXPECT_TRUE(listeners.default_result_printer() == nullptr); + EXPECT_TRUE(listeners.default_xml_generator() == nullptr); +} + +// Tests that the TestEventListeners destructor deletes all the listeners it +// owns. +TEST(TestEventListenersTest, DestructionWorks) { + bool default_result_printer_is_destroyed = false; + bool default_xml_printer_is_destroyed = false; + bool extra_listener_is_destroyed = false; + TestListener* default_result_printer = + new TestListener(nullptr, &default_result_printer_is_destroyed); + TestListener* default_xml_printer = + new TestListener(nullptr, &default_xml_printer_is_destroyed); + TestListener* extra_listener = + new TestListener(nullptr, &extra_listener_is_destroyed); + + { + TestEventListeners listeners; + TestEventListenersAccessor::SetDefaultResultPrinter(&listeners, + default_result_printer); + TestEventListenersAccessor::SetDefaultXmlGenerator(&listeners, + default_xml_printer); + listeners.Append(extra_listener); + } + EXPECT_TRUE(default_result_printer_is_destroyed); + EXPECT_TRUE(default_xml_printer_is_destroyed); + EXPECT_TRUE(extra_listener_is_destroyed); +} + +// Tests that a listener Append'ed to a TestEventListeners list starts +// receiving events. +TEST(TestEventListenersTest, Append) { + int on_start_counter = 0; + bool is_destroyed = false; + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + { + TestEventListeners listeners; + listeners.Append(listener); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(1, on_start_counter); + } + EXPECT_TRUE(is_destroyed); +} + +// Tests that listeners receive events in the order they were appended to +// the list, except for *End requests, which must be received in the reverse +// order. +class SequenceTestingListener : public EmptyTestEventListener { + public: + SequenceTestingListener(std::vector* vector, const char* id) + : vector_(vector), id_(id) {} + + protected: + void OnTestProgramStart(const UnitTest& /*unit_test*/) override { + vector_->push_back(GetEventDescription("OnTestProgramStart")); + } + + void OnTestProgramEnd(const UnitTest& /*unit_test*/) override { + vector_->push_back(GetEventDescription("OnTestProgramEnd")); + } + + void OnTestIterationStart(const UnitTest& /*unit_test*/, + int /*iteration*/) override { + vector_->push_back(GetEventDescription("OnTestIterationStart")); + } + + void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int /*iteration*/) override { + vector_->push_back(GetEventDescription("OnTestIterationEnd")); + } + + private: + std::string GetEventDescription(const char* method) { + Message message; + message << id_ << "." << method; + return message.GetString(); + } + + std::vector* vector_; + const char* const id_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(SequenceTestingListener); +}; + +TEST(EventListenerTest, AppendKeepsOrder) { + std::vector vec; + TestEventListeners listeners; + listeners.Append(new SequenceTestingListener(&vec, "1st")); + listeners.Append(new SequenceTestingListener(&vec, "2nd")); + listeners.Append(new SequenceTestingListener(&vec, "3rd")); + + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + ASSERT_EQ(3U, vec.size()); + EXPECT_STREQ("1st.OnTestProgramStart", vec[0].c_str()); + EXPECT_STREQ("2nd.OnTestProgramStart", vec[1].c_str()); + EXPECT_STREQ("3rd.OnTestProgramStart", vec[2].c_str()); + + vec.clear(); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramEnd( + *UnitTest::GetInstance()); + ASSERT_EQ(3U, vec.size()); + EXPECT_STREQ("3rd.OnTestProgramEnd", vec[0].c_str()); + EXPECT_STREQ("2nd.OnTestProgramEnd", vec[1].c_str()); + EXPECT_STREQ("1st.OnTestProgramEnd", vec[2].c_str()); + + vec.clear(); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestIterationStart( + *UnitTest::GetInstance(), 0); + ASSERT_EQ(3U, vec.size()); + EXPECT_STREQ("1st.OnTestIterationStart", vec[0].c_str()); + EXPECT_STREQ("2nd.OnTestIterationStart", vec[1].c_str()); + EXPECT_STREQ("3rd.OnTestIterationStart", vec[2].c_str()); + + vec.clear(); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestIterationEnd( + *UnitTest::GetInstance(), 0); + ASSERT_EQ(3U, vec.size()); + EXPECT_STREQ("3rd.OnTestIterationEnd", vec[0].c_str()); + EXPECT_STREQ("2nd.OnTestIterationEnd", vec[1].c_str()); + EXPECT_STREQ("1st.OnTestIterationEnd", vec[2].c_str()); +} + +// Tests that a listener removed from a TestEventListeners list stops receiving +// events and is not deleted when the list is destroyed. +TEST(TestEventListenersTest, Release) { + int on_start_counter = 0; + bool is_destroyed = false; + // Although Append passes the ownership of this object to the list, + // the following calls release it, and we need to delete it before the + // test ends. + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + { + TestEventListeners listeners; + listeners.Append(listener); + EXPECT_EQ(listener, listeners.Release(listener)); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_TRUE(listeners.Release(listener) == nullptr); + } + EXPECT_EQ(0, on_start_counter); + EXPECT_FALSE(is_destroyed); + delete listener; +} + +// Tests that no events are forwarded when event forwarding is disabled. +TEST(EventListenerTest, SuppressEventForwarding) { + int on_start_counter = 0; + TestListener* listener = new TestListener(&on_start_counter, nullptr); + + TestEventListeners listeners; + listeners.Append(listener); + ASSERT_TRUE(TestEventListenersAccessor::EventForwardingEnabled(listeners)); + TestEventListenersAccessor::SuppressEventForwarding(&listeners); + ASSERT_FALSE(TestEventListenersAccessor::EventForwardingEnabled(listeners)); + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(0, on_start_counter); +} + +// Tests that events generated by Google Test are not forwarded in +// death test subprocesses. +TEST(EventListenerDeathTest, EventsNotForwardedInDeathTestSubprecesses) { + EXPECT_DEATH_IF_SUPPORTED({ + GTEST_CHECK_(TestEventListenersAccessor::EventForwardingEnabled( + *GetUnitTestImpl()->listeners())) << "expected failure";}, + "expected failure"); +} + +// Tests that a listener installed via SetDefaultResultPrinter() starts +// receiving events and is returned via default_result_printer() and that +// the previous default_result_printer is removed from the list and deleted. +TEST(EventListenerTest, default_result_printer) { + int on_start_counter = 0; + bool is_destroyed = false; + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + + TestEventListeners listeners; + TestEventListenersAccessor::SetDefaultResultPrinter(&listeners, listener); + + EXPECT_EQ(listener, listeners.default_result_printer()); + + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + + EXPECT_EQ(1, on_start_counter); + + // Replacing default_result_printer with something else should remove it + // from the list and destroy it. + TestEventListenersAccessor::SetDefaultResultPrinter(&listeners, nullptr); + + EXPECT_TRUE(listeners.default_result_printer() == nullptr); + EXPECT_TRUE(is_destroyed); + + // After broadcasting an event the counter is still the same, indicating + // the listener is not in the list anymore. + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(1, on_start_counter); +} + +// Tests that the default_result_printer listener stops receiving events +// when removed via Release and that is not owned by the list anymore. +TEST(EventListenerTest, RemovingDefaultResultPrinterWorks) { + int on_start_counter = 0; + bool is_destroyed = false; + // Although Append passes the ownership of this object to the list, + // the following calls release it, and we need to delete it before the + // test ends. + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + { + TestEventListeners listeners; + TestEventListenersAccessor::SetDefaultResultPrinter(&listeners, listener); + + EXPECT_EQ(listener, listeners.Release(listener)); + EXPECT_TRUE(listeners.default_result_printer() == nullptr); + EXPECT_FALSE(is_destroyed); + + // Broadcasting events now should not affect default_result_printer. + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(0, on_start_counter); + } + // Destroying the list should not affect the listener now, too. + EXPECT_FALSE(is_destroyed); + delete listener; +} + +// Tests that a listener installed via SetDefaultXmlGenerator() starts +// receiving events and is returned via default_xml_generator() and that +// the previous default_xml_generator is removed from the list and deleted. +TEST(EventListenerTest, default_xml_generator) { + int on_start_counter = 0; + bool is_destroyed = false; + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + + TestEventListeners listeners; + TestEventListenersAccessor::SetDefaultXmlGenerator(&listeners, listener); + + EXPECT_EQ(listener, listeners.default_xml_generator()); + + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + + EXPECT_EQ(1, on_start_counter); + + // Replacing default_xml_generator with something else should remove it + // from the list and destroy it. + TestEventListenersAccessor::SetDefaultXmlGenerator(&listeners, nullptr); + + EXPECT_TRUE(listeners.default_xml_generator() == nullptr); + EXPECT_TRUE(is_destroyed); + + // After broadcasting an event the counter is still the same, indicating + // the listener is not in the list anymore. + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(1, on_start_counter); +} + +// Tests that the default_xml_generator listener stops receiving events +// when removed via Release and that is not owned by the list anymore. +TEST(EventListenerTest, RemovingDefaultXmlGeneratorWorks) { + int on_start_counter = 0; + bool is_destroyed = false; + // Although Append passes the ownership of this object to the list, + // the following calls release it, and we need to delete it before the + // test ends. + TestListener* listener = new TestListener(&on_start_counter, &is_destroyed); + { + TestEventListeners listeners; + TestEventListenersAccessor::SetDefaultXmlGenerator(&listeners, listener); + + EXPECT_EQ(listener, listeners.Release(listener)); + EXPECT_TRUE(listeners.default_xml_generator() == nullptr); + EXPECT_FALSE(is_destroyed); + + // Broadcasting events now should not affect default_xml_generator. + TestEventListenersAccessor::GetRepeater(&listeners)->OnTestProgramStart( + *UnitTest::GetInstance()); + EXPECT_EQ(0, on_start_counter); + } + // Destroying the list should not affect the listener now, too. + EXPECT_FALSE(is_destroyed); + delete listener; +} + +// Sanity tests to ensure that the alternative, verbose spellings of +// some of the macros work. We don't test them thoroughly as that +// would be quite involved. Since their implementations are +// straightforward, and they are rarely used, we'll just rely on the +// users to tell us when they are broken. +GTEST_TEST(AlternativeNameTest, Works) { // GTEST_TEST is the same as TEST. + GTEST_SUCCEED() << "OK"; // GTEST_SUCCEED is the same as SUCCEED. + + // GTEST_FAIL is the same as FAIL. + EXPECT_FATAL_FAILURE(GTEST_FAIL() << "An expected failure", + "An expected failure"); + + // GTEST_ASSERT_XY is the same as ASSERT_XY. + + GTEST_ASSERT_EQ(0, 0); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_EQ(0, 1) << "An expected failure", + "An expected failure"); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_EQ(1, 0) << "An expected failure", + "An expected failure"); + + GTEST_ASSERT_NE(0, 1); + GTEST_ASSERT_NE(1, 0); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_NE(0, 0) << "An expected failure", + "An expected failure"); + + GTEST_ASSERT_LE(0, 0); + GTEST_ASSERT_LE(0, 1); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_LE(1, 0) << "An expected failure", + "An expected failure"); + + GTEST_ASSERT_LT(0, 1); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_LT(0, 0) << "An expected failure", + "An expected failure"); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_LT(1, 0) << "An expected failure", + "An expected failure"); + + GTEST_ASSERT_GE(0, 0); + GTEST_ASSERT_GE(1, 0); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_GE(0, 1) << "An expected failure", + "An expected failure"); + + GTEST_ASSERT_GT(1, 0); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_GT(0, 1) << "An expected failure", + "An expected failure"); + EXPECT_FATAL_FAILURE(GTEST_ASSERT_GT(1, 1) << "An expected failure", + "An expected failure"); +} + +// Tests for internal utilities necessary for implementation of the universal +// printing. + +class ConversionHelperBase {}; +class ConversionHelperDerived : public ConversionHelperBase {}; + +struct HasDebugStringMethods { + std::string DebugString() const { return ""; } + std::string ShortDebugString() const { return ""; } +}; + +struct InheritsDebugStringMethods : public HasDebugStringMethods {}; + +struct WrongTypeDebugStringMethod { + std::string DebugString() const { return ""; } + int ShortDebugString() const { return 1; } +}; + +struct NotConstDebugStringMethod { + std::string DebugString() { return ""; } + std::string ShortDebugString() const { return ""; } +}; + +struct MissingDebugStringMethod { + std::string DebugString() { return ""; } +}; + +struct IncompleteType; + +// Tests that HasDebugStringAndShortDebugString::value is a compile-time +// constant. +TEST(HasDebugStringAndShortDebugStringTest, ValueIsCompileTimeConstant) { + GTEST_COMPILE_ASSERT_( + HasDebugStringAndShortDebugString::value, + const_true); + GTEST_COMPILE_ASSERT_( + HasDebugStringAndShortDebugString::value, + const_true); + GTEST_COMPILE_ASSERT_(HasDebugStringAndShortDebugString< + const InheritsDebugStringMethods>::value, + const_true); + GTEST_COMPILE_ASSERT_( + !HasDebugStringAndShortDebugString::value, + const_false); + GTEST_COMPILE_ASSERT_( + !HasDebugStringAndShortDebugString::value, + const_false); + GTEST_COMPILE_ASSERT_( + !HasDebugStringAndShortDebugString::value, + const_false); + GTEST_COMPILE_ASSERT_( + !HasDebugStringAndShortDebugString::value, const_false); + GTEST_COMPILE_ASSERT_(!HasDebugStringAndShortDebugString::value, + const_false); +} + +// Tests that HasDebugStringAndShortDebugString::value is true when T has +// needed methods. +TEST(HasDebugStringAndShortDebugStringTest, + ValueIsTrueWhenTypeHasDebugStringAndShortDebugString) { + EXPECT_TRUE( + HasDebugStringAndShortDebugString::value); +} + +// Tests that HasDebugStringAndShortDebugString::value is false when T +// doesn't have needed methods. +TEST(HasDebugStringAndShortDebugStringTest, + ValueIsFalseWhenTypeIsNotAProtocolMessage) { + EXPECT_FALSE(HasDebugStringAndShortDebugString::value); + EXPECT_FALSE( + HasDebugStringAndShortDebugString::value); +} + +// Tests GTEST_REMOVE_REFERENCE_AND_CONST_. + +template +void TestGTestRemoveReferenceAndConst() { + static_assert(std::is_same::value, + "GTEST_REMOVE_REFERENCE_AND_CONST_ failed."); +} + +TEST(RemoveReferenceToConstTest, Works) { + TestGTestRemoveReferenceAndConst(); + TestGTestRemoveReferenceAndConst(); + TestGTestRemoveReferenceAndConst(); + TestGTestRemoveReferenceAndConst(); + TestGTestRemoveReferenceAndConst(); +} + +// Tests GTEST_REFERENCE_TO_CONST_. + +template +void TestGTestReferenceToConst() { + static_assert(std::is_same::value, + "GTEST_REFERENCE_TO_CONST_ failed."); +} + +TEST(GTestReferenceToConstTest, Works) { + TestGTestReferenceToConst(); + TestGTestReferenceToConst(); + TestGTestReferenceToConst(); + TestGTestReferenceToConst(); +} + + +// Tests IsContainerTest. + +class NonContainer {}; + +TEST(IsContainerTestTest, WorksForNonContainer) { + EXPECT_EQ(sizeof(IsNotContainer), sizeof(IsContainerTest(0))); + EXPECT_EQ(sizeof(IsNotContainer), sizeof(IsContainerTest(0))); + EXPECT_EQ(sizeof(IsNotContainer), sizeof(IsContainerTest(0))); +} + +TEST(IsContainerTestTest, WorksForContainer) { + EXPECT_EQ(sizeof(IsContainer), + sizeof(IsContainerTest >(0))); + EXPECT_EQ(sizeof(IsContainer), + sizeof(IsContainerTest >(0))); +} + +struct ConstOnlyContainerWithPointerIterator { + using const_iterator = int*; + const_iterator begin() const; + const_iterator end() const; +}; + +struct ConstOnlyContainerWithClassIterator { + struct const_iterator { + const int& operator*() const; + const_iterator& operator++(/* pre-increment */); + }; + const_iterator begin() const; + const_iterator end() const; +}; + +TEST(IsContainerTestTest, ConstOnlyContainer) { + EXPECT_EQ(sizeof(IsContainer), + sizeof(IsContainerTest(0))); + EXPECT_EQ(sizeof(IsContainer), + sizeof(IsContainerTest(0))); +} + +// Tests IsHashTable. +struct AHashTable { + typedef void hasher; +}; +struct NotReallyAHashTable { + typedef void hasher; + typedef void reverse_iterator; +}; +TEST(IsHashTable, Basic) { + EXPECT_TRUE(testing::internal::IsHashTable::value); + EXPECT_FALSE(testing::internal::IsHashTable::value); + EXPECT_FALSE(testing::internal::IsHashTable>::value); + EXPECT_TRUE(testing::internal::IsHashTable>::value); +} + +// Tests ArrayEq(). + +TEST(ArrayEqTest, WorksForDegeneratedArrays) { + EXPECT_TRUE(ArrayEq(5, 5L)); + EXPECT_FALSE(ArrayEq('a', 0)); +} + +TEST(ArrayEqTest, WorksForOneDimensionalArrays) { + // Note that a and b are distinct but compatible types. + const int a[] = { 0, 1 }; + long b[] = { 0, 1 }; + EXPECT_TRUE(ArrayEq(a, b)); + EXPECT_TRUE(ArrayEq(a, 2, b)); + + b[0] = 2; + EXPECT_FALSE(ArrayEq(a, b)); + EXPECT_FALSE(ArrayEq(a, 1, b)); +} + +TEST(ArrayEqTest, WorksForTwoDimensionalArrays) { + const char a[][3] = { "hi", "lo" }; + const char b[][3] = { "hi", "lo" }; + const char c[][3] = { "hi", "li" }; + + EXPECT_TRUE(ArrayEq(a, b)); + EXPECT_TRUE(ArrayEq(a, 2, b)); + + EXPECT_FALSE(ArrayEq(a, c)); + EXPECT_FALSE(ArrayEq(a, 2, c)); +} + +// Tests ArrayAwareFind(). + +TEST(ArrayAwareFindTest, WorksForOneDimensionalArray) { + const char a[] = "hello"; + EXPECT_EQ(a + 4, ArrayAwareFind(a, a + 5, 'o')); + EXPECT_EQ(a + 5, ArrayAwareFind(a, a + 5, 'x')); +} + +TEST(ArrayAwareFindTest, WorksForTwoDimensionalArray) { + int a[][2] = { { 0, 1 }, { 2, 3 }, { 4, 5 } }; + const int b[2] = { 2, 3 }; + EXPECT_EQ(a + 1, ArrayAwareFind(a, a + 3, b)); + + const int c[2] = { 6, 7 }; + EXPECT_EQ(a + 3, ArrayAwareFind(a, a + 3, c)); +} + +// Tests CopyArray(). + +TEST(CopyArrayTest, WorksForDegeneratedArrays) { + int n = 0; + CopyArray('a', &n); + EXPECT_EQ('a', n); +} + +TEST(CopyArrayTest, WorksForOneDimensionalArrays) { + const char a[3] = "hi"; + int b[3]; +#ifndef __BORLANDC__ // C++Builder cannot compile some array size deductions. + CopyArray(a, &b); + EXPECT_TRUE(ArrayEq(a, b)); +#endif + + int c[3]; + CopyArray(a, 3, c); + EXPECT_TRUE(ArrayEq(a, c)); +} + +TEST(CopyArrayTest, WorksForTwoDimensionalArrays) { + const int a[2][3] = { { 0, 1, 2 }, { 3, 4, 5 } }; + int b[2][3]; +#ifndef __BORLANDC__ // C++Builder cannot compile some array size deductions. + CopyArray(a, &b); + EXPECT_TRUE(ArrayEq(a, b)); +#endif + + int c[2][3]; + CopyArray(a, 2, c); + EXPECT_TRUE(ArrayEq(a, c)); +} + +// Tests NativeArray. + +TEST(NativeArrayTest, ConstructorFromArrayWorks) { + const int a[3] = { 0, 1, 2 }; + NativeArray na(a, 3, RelationToSourceReference()); + EXPECT_EQ(3U, na.size()); + EXPECT_EQ(a, na.begin()); +} + +TEST(NativeArrayTest, CreatesAndDeletesCopyOfArrayWhenAskedTo) { + typedef int Array[2]; + Array* a = new Array[1]; + (*a)[0] = 0; + (*a)[1] = 1; + NativeArray na(*a, 2, RelationToSourceCopy()); + EXPECT_NE(*a, na.begin()); + delete[] a; + EXPECT_EQ(0, na.begin()[0]); + EXPECT_EQ(1, na.begin()[1]); + + // We rely on the heap checker to verify that na deletes the copy of + // array. +} + +TEST(NativeArrayTest, TypeMembersAreCorrect) { + StaticAssertTypeEq::value_type>(); + StaticAssertTypeEq::value_type>(); + + StaticAssertTypeEq::const_iterator>(); + StaticAssertTypeEq::const_iterator>(); +} + +TEST(NativeArrayTest, MethodsWork) { + const int a[3] = { 0, 1, 2 }; + NativeArray na(a, 3, RelationToSourceCopy()); + ASSERT_EQ(3U, na.size()); + EXPECT_EQ(3, na.end() - na.begin()); + + NativeArray::const_iterator it = na.begin(); + EXPECT_EQ(0, *it); + ++it; + EXPECT_EQ(1, *it); + it++; + EXPECT_EQ(2, *it); + ++it; + EXPECT_EQ(na.end(), it); + + EXPECT_TRUE(na == na); + + NativeArray na2(a, 3, RelationToSourceReference()); + EXPECT_TRUE(na == na2); + + const int b1[3] = { 0, 1, 1 }; + const int b2[4] = { 0, 1, 2, 3 }; + EXPECT_FALSE(na == NativeArray(b1, 3, RelationToSourceReference())); + EXPECT_FALSE(na == NativeArray(b2, 4, RelationToSourceCopy())); +} + +TEST(NativeArrayTest, WorksForTwoDimensionalArray) { + const char a[2][3] = { "hi", "lo" }; + NativeArray na(a, 2, RelationToSourceReference()); + ASSERT_EQ(2U, na.size()); + EXPECT_EQ(a, na.begin()); +} + +// IndexSequence +TEST(IndexSequence, MakeIndexSequence) { + using testing::internal::IndexSequence; + using testing::internal::MakeIndexSequence; + EXPECT_TRUE( + (std::is_same, MakeIndexSequence<0>::type>::value)); + EXPECT_TRUE( + (std::is_same, MakeIndexSequence<1>::type>::value)); + EXPECT_TRUE( + (std::is_same, MakeIndexSequence<2>::type>::value)); + EXPECT_TRUE(( + std::is_same, MakeIndexSequence<3>::type>::value)); + EXPECT_TRUE( + (std::is_base_of, MakeIndexSequence<3>>::value)); +} + +// ElemFromList +TEST(ElemFromList, Basic) { + using testing::internal::ElemFromList; + EXPECT_TRUE( + (std::is_same::type>::value)); + EXPECT_TRUE( + (std::is_same::type>::value)); + EXPECT_TRUE( + (std::is_same::type>::value)); + EXPECT_TRUE(( + std::is_same::type>::value)); +} + +// FlatTuple +TEST(FlatTuple, Basic) { + using testing::internal::FlatTuple; + + FlatTuple tuple = {}; + EXPECT_EQ(0, tuple.Get<0>()); + EXPECT_EQ(0.0, tuple.Get<1>()); + EXPECT_EQ(nullptr, tuple.Get<2>()); + + tuple = FlatTuple( + testing::internal::FlatTupleConstructTag{}, 7, 3.2, "Foo"); + EXPECT_EQ(7, tuple.Get<0>()); + EXPECT_EQ(3.2, tuple.Get<1>()); + EXPECT_EQ(std::string("Foo"), tuple.Get<2>()); + + tuple.Get<1>() = 5.1; + EXPECT_EQ(5.1, tuple.Get<1>()); +} + +namespace { +std::string AddIntToString(int i, const std::string& s) { + return s + std::to_string(i); +} +} // namespace + +TEST(FlatTuple, Apply) { + using testing::internal::FlatTuple; + + FlatTuple tuple{testing::internal::FlatTupleConstructTag{}, + 5, "Hello"}; + + // Lambda. + EXPECT_TRUE(tuple.Apply([](int i, const std::string& s) -> bool { + return i == static_cast(s.size()); + })); + + // Function. + EXPECT_EQ(tuple.Apply(AddIntToString), "Hello5"); + + // Mutating operations. + tuple.Apply([](int& i, std::string& s) { + ++i; + s += s; + }); + EXPECT_EQ(tuple.Get<0>(), 6); + EXPECT_EQ(tuple.Get<1>(), "HelloHello"); +} + +struct ConstructionCounting { + ConstructionCounting() { ++default_ctor_calls; } + ~ConstructionCounting() { ++dtor_calls; } + ConstructionCounting(const ConstructionCounting&) { ++copy_ctor_calls; } + ConstructionCounting(ConstructionCounting&&) noexcept { ++move_ctor_calls; } + ConstructionCounting& operator=(const ConstructionCounting&) { + ++copy_assignment_calls; + return *this; + } + ConstructionCounting& operator=(ConstructionCounting&&) noexcept { + ++move_assignment_calls; + return *this; + } + + static void Reset() { + default_ctor_calls = 0; + dtor_calls = 0; + copy_ctor_calls = 0; + move_ctor_calls = 0; + copy_assignment_calls = 0; + move_assignment_calls = 0; + } + + static int default_ctor_calls; + static int dtor_calls; + static int copy_ctor_calls; + static int move_ctor_calls; + static int copy_assignment_calls; + static int move_assignment_calls; +}; + +int ConstructionCounting::default_ctor_calls = 0; +int ConstructionCounting::dtor_calls = 0; +int ConstructionCounting::copy_ctor_calls = 0; +int ConstructionCounting::move_ctor_calls = 0; +int ConstructionCounting::copy_assignment_calls = 0; +int ConstructionCounting::move_assignment_calls = 0; + +TEST(FlatTuple, ConstructorCalls) { + using testing::internal::FlatTuple; + + // Default construction. + ConstructionCounting::Reset(); + { FlatTuple tuple; } + EXPECT_EQ(ConstructionCounting::default_ctor_calls, 1); + EXPECT_EQ(ConstructionCounting::dtor_calls, 1); + EXPECT_EQ(ConstructionCounting::copy_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::move_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::copy_assignment_calls, 0); + EXPECT_EQ(ConstructionCounting::move_assignment_calls, 0); + + // Copy construction. + ConstructionCounting::Reset(); + { + ConstructionCounting elem; + FlatTuple tuple{ + testing::internal::FlatTupleConstructTag{}, elem}; + } + EXPECT_EQ(ConstructionCounting::default_ctor_calls, 1); + EXPECT_EQ(ConstructionCounting::dtor_calls, 2); + EXPECT_EQ(ConstructionCounting::copy_ctor_calls, 1); + EXPECT_EQ(ConstructionCounting::move_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::copy_assignment_calls, 0); + EXPECT_EQ(ConstructionCounting::move_assignment_calls, 0); + + // Move construction. + ConstructionCounting::Reset(); + { + FlatTuple tuple{ + testing::internal::FlatTupleConstructTag{}, ConstructionCounting{}}; + } + EXPECT_EQ(ConstructionCounting::default_ctor_calls, 1); + EXPECT_EQ(ConstructionCounting::dtor_calls, 2); + EXPECT_EQ(ConstructionCounting::copy_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::move_ctor_calls, 1); + EXPECT_EQ(ConstructionCounting::copy_assignment_calls, 0); + EXPECT_EQ(ConstructionCounting::move_assignment_calls, 0); + + // Copy assignment. + // TODO(ofats): it should be testing assignment operator of FlatTuple, not its + // elements + ConstructionCounting::Reset(); + { + FlatTuple tuple; + ConstructionCounting elem; + tuple.Get<0>() = elem; + } + EXPECT_EQ(ConstructionCounting::default_ctor_calls, 2); + EXPECT_EQ(ConstructionCounting::dtor_calls, 2); + EXPECT_EQ(ConstructionCounting::copy_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::move_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::copy_assignment_calls, 1); + EXPECT_EQ(ConstructionCounting::move_assignment_calls, 0); + + // Move assignment. + // TODO(ofats): it should be testing assignment operator of FlatTuple, not its + // elements + ConstructionCounting::Reset(); + { + FlatTuple tuple; + tuple.Get<0>() = ConstructionCounting{}; + } + EXPECT_EQ(ConstructionCounting::default_ctor_calls, 2); + EXPECT_EQ(ConstructionCounting::dtor_calls, 2); + EXPECT_EQ(ConstructionCounting::copy_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::move_ctor_calls, 0); + EXPECT_EQ(ConstructionCounting::copy_assignment_calls, 0); + EXPECT_EQ(ConstructionCounting::move_assignment_calls, 1); + + ConstructionCounting::Reset(); +} + +TEST(FlatTuple, ManyTypes) { + using testing::internal::FlatTuple; + + // Instantiate FlatTuple with 257 ints. + // Tests show that we can do it with thousands of elements, but very long + // compile times makes it unusuitable for this test. +#define GTEST_FLAT_TUPLE_INT8 int, int, int, int, int, int, int, int, +#define GTEST_FLAT_TUPLE_INT16 GTEST_FLAT_TUPLE_INT8 GTEST_FLAT_TUPLE_INT8 +#define GTEST_FLAT_TUPLE_INT32 GTEST_FLAT_TUPLE_INT16 GTEST_FLAT_TUPLE_INT16 +#define GTEST_FLAT_TUPLE_INT64 GTEST_FLAT_TUPLE_INT32 GTEST_FLAT_TUPLE_INT32 +#define GTEST_FLAT_TUPLE_INT128 GTEST_FLAT_TUPLE_INT64 GTEST_FLAT_TUPLE_INT64 +#define GTEST_FLAT_TUPLE_INT256 GTEST_FLAT_TUPLE_INT128 GTEST_FLAT_TUPLE_INT128 + + // Let's make sure that we can have a very long list of types without blowing + // up the template instantiation depth. + FlatTuple tuple; + + tuple.Get<0>() = 7; + tuple.Get<99>() = 17; + tuple.Get<256>() = 1000; + EXPECT_EQ(7, tuple.Get<0>()); + EXPECT_EQ(17, tuple.Get<99>()); + EXPECT_EQ(1000, tuple.Get<256>()); +} + +// Tests SkipPrefix(). + +TEST(SkipPrefixTest, SkipsWhenPrefixMatches) { + const char* const str = "hello"; + + const char* p = str; + EXPECT_TRUE(SkipPrefix("", &p)); + EXPECT_EQ(str, p); + + p = str; + EXPECT_TRUE(SkipPrefix("hell", &p)); + EXPECT_EQ(str + 4, p); +} + +TEST(SkipPrefixTest, DoesNotSkipWhenPrefixDoesNotMatch) { + const char* const str = "world"; + + const char* p = str; + EXPECT_FALSE(SkipPrefix("W", &p)); + EXPECT_EQ(str, p); + + p = str; + EXPECT_FALSE(SkipPrefix("world!", &p)); + EXPECT_EQ(str, p); +} + +// Tests ad_hoc_test_result(). +TEST(AdHocTestResultTest, AdHocTestResultForUnitTestDoesNotShowFailure) { + const testing::TestResult& test_result = + testing::UnitTest::GetInstance()->ad_hoc_test_result(); + EXPECT_FALSE(test_result.Failed()); +} + +class DynamicUnitTestFixture : public testing::Test {}; + +class DynamicTest : public DynamicUnitTestFixture { + void TestBody() override { EXPECT_TRUE(true); } +}; + +auto* dynamic_test = testing::RegisterTest( + "DynamicUnitTestFixture", "DynamicTest", "TYPE", "VALUE", __FILE__, + __LINE__, []() -> DynamicUnitTestFixture* { return new DynamicTest; }); + +TEST(RegisterTest, WasRegistered) { + auto* unittest = testing::UnitTest::GetInstance(); + for (int i = 0; i < unittest->total_test_suite_count(); ++i) { + auto* tests = unittest->GetTestSuite(i); + if (tests->name() != std::string("DynamicUnitTestFixture")) continue; + for (int j = 0; j < tests->total_test_count(); ++j) { + if (tests->GetTestInfo(j)->name() != std::string("DynamicTest")) continue; + // Found it. + EXPECT_STREQ(tests->GetTestInfo(j)->value_param(), "VALUE"); + EXPECT_STREQ(tests->GetTestInfo(j)->type_param(), "TYPE"); + return; + } + } + + FAIL() << "Didn't find the test!"; +} + +// Test that the pattern globbing algorithm is linear. If not, this test should +// time out. +TEST(PatternGlobbingTest, MatchesFilterLinearRuntime) { + std::string name(100, 'a'); // Construct the string (a^100)b + name.push_back('b'); + + std::string pattern; // Construct the string ((a*)^100)b + for (int i = 0; i < 100; ++i) { + pattern.append("a*"); + } + pattern.push_back('b'); + + EXPECT_TRUE( + testing::internal::UnitTestOptions::MatchesFilter(name, pattern.c_str())); +} + +TEST(PatternGlobbingTest, MatchesFilterWithMultiplePatterns) { + const std::string name = "aaaa"; + EXPECT_TRUE(testing::internal::UnitTestOptions::MatchesFilter(name, "a*")); + EXPECT_TRUE(testing::internal::UnitTestOptions::MatchesFilter(name, "a*:")); + EXPECT_FALSE(testing::internal::UnitTestOptions::MatchesFilter(name, "ab")); + EXPECT_FALSE(testing::internal::UnitTestOptions::MatchesFilter(name, "ab:")); + EXPECT_TRUE(testing::internal::UnitTestOptions::MatchesFilter(name, "ab:a*")); +} + +TEST(PatternGlobbingTest, MatchesFilterEdgeCases) { + EXPECT_FALSE(testing::internal::UnitTestOptions::MatchesFilter("", "*a")); + EXPECT_TRUE(testing::internal::UnitTestOptions::MatchesFilter("", "*")); + EXPECT_FALSE(testing::internal::UnitTestOptions::MatchesFilter("a", "")); + EXPECT_TRUE(testing::internal::UnitTestOptions::MatchesFilter("", "")); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile1_test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile1_test_.cc new file mode 100644 index 000000000000..19aa252a3010 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile1_test_.cc @@ -0,0 +1,43 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// gtest_xml_outfile1_test_ writes some xml via TestProperty used by +// gtest_xml_outfiles_test.py + +#include "gtest/gtest.h" + +class PropertyOne : public testing::Test { + protected: + void SetUp() override { RecordProperty("SetUpProp", 1); } + void TearDown() override { RecordProperty("TearDownProp", 1); } +}; + +TEST_F(PropertyOne, TestSomeProperties) { + RecordProperty("TestSomeProperty", 1); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile2_test_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile2_test_.cc new file mode 100644 index 000000000000..f9a2a6e9846d --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfile2_test_.cc @@ -0,0 +1,43 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// gtest_xml_outfile2_test_ writes some xml via TestProperty used by +// gtest_xml_outfiles_test.py + +#include "gtest/gtest.h" + +class PropertyTwo : public testing::Test { + protected: + void SetUp() override { RecordProperty("SetUpProp", 2); } + void TearDown() override { RecordProperty("TearDownProp", 2); } +}; + +TEST_F(PropertyTwo, TestSomeProperties) { + RecordProperty("TestSomeProperty", 2); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfiles_test.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfiles_test.py new file mode 100755 index 000000000000..916bdf4de484 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_outfiles_test.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# +# Copyright 2008, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for the gtest_xml_output module.""" + +import os +from xml.dom import minidom, Node +from googletest.test import gtest_test_utils +from googletest.test import gtest_xml_test_utils + +GTEST_OUTPUT_SUBDIR = "xml_outfiles" +GTEST_OUTPUT_1_TEST = "gtest_xml_outfile1_test_" +GTEST_OUTPUT_2_TEST = "gtest_xml_outfile2_test_" + +EXPECTED_XML_1 = """ + + + + + + + + + + + +""" + +EXPECTED_XML_2 = """ + + + + + + + + + + + +""" + + +class GTestXMLOutFilesTest(gtest_xml_test_utils.GTestXMLTestCase): + """Unit test for Google Test's XML output functionality.""" + + def setUp(self): + # We want the trailing '/' that the last "" provides in os.path.join, for + # telling Google Test to create an output directory instead of a single file + # for xml output. + self.output_dir_ = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_OUTPUT_SUBDIR, "") + self.DeleteFilesAndDir() + + def tearDown(self): + self.DeleteFilesAndDir() + + def DeleteFilesAndDir(self): + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + ".xml")) + except os.error: + pass + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + ".xml")) + except os.error: + pass + try: + os.rmdir(self.output_dir_) + except os.error: + pass + + def testOutfile1(self): + self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_XML_1) + + def testOutfile2(self): + self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_XML_2) + + def _TestOutFile(self, test_name, expected_xml): + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) + command = [gtest_prog_path, "--gtest_output=xml:%s" % self.output_dir_] + p = gtest_test_utils.Subprocess(command, + working_dir=gtest_test_utils.GetTempDir()) + self.assert_(p.exited) + self.assertEquals(0, p.exit_code) + + output_file_name1 = test_name + ".xml" + output_file1 = os.path.join(self.output_dir_, output_file_name1) + output_file_name2 = 'lt-' + output_file_name1 + output_file2 = os.path.join(self.output_dir_, output_file_name2) + self.assert_(os.path.isfile(output_file1) or os.path.isfile(output_file2), + output_file1) + + expected = minidom.parseString(expected_xml) + if os.path.isfile(output_file1): + actual = minidom.parse(output_file1) + else: + actual = minidom.parse(output_file2) + self.NormalizeXml(actual.documentElement) + self.AssertEquivalentNodes(expected.documentElement, + actual.documentElement) + expected.unlink() + actual.unlink() + + +if __name__ == "__main__": + os.environ["GTEST_STACK_TRACE_DEPTH"] = "0" + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest.py new file mode 100755 index 000000000000..f0b0c3b90645 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python +# +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test for the gtest_xml_output module""" + +import datetime +import errno +import os +import re +import sys +from xml.dom import minidom, Node + +from googletest.test import gtest_test_utils +from googletest.test import gtest_xml_test_utils + +GTEST_FILTER_FLAG = '--gtest_filter' +GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' +GTEST_OUTPUT_FLAG = '--gtest_output' +GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.xml' +GTEST_PROGRAM_NAME = 'gtest_xml_output_unittest_' + +# The flag indicating stacktraces are not supported +NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' + +# The environment variables for test sharding. +TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' +SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' +SHARD_STATUS_FILE_ENV_VAR = 'GTEST_SHARD_STATUS_FILE' + +SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv + +if SUPPORTS_STACK_TRACES: + STACK_TRACE_TEMPLATE = '\nStack trace:\n*' +else: + STACK_TRACE_TEMPLATE = '' + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + +EXPECTED_NON_EMPTY_XML = """ + + + + + + + + + + + + + + + + + + + + ]]>%(stack)s]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +""" % { + 'stack': STACK_TRACE_TEMPLATE +} + +EXPECTED_FILTERED_TEST_XML = """ + + + + +""" + +EXPECTED_SHARDED_TEST_XML = """ + + + + + + + + + + + + + + +""" + +EXPECTED_NO_TEST_XML = """ + + + + + + +""" % { + 'stack': STACK_TRACE_TEMPLATE +} + +GTEST_PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath(GTEST_PROGRAM_NAME) + +SUPPORTS_TYPED_TESTS = 'TypedTest' in gtest_test_utils.Subprocess( + [GTEST_PROGRAM_PATH, GTEST_LIST_TESTS_FLAG], capture_stderr=False).output + + +class GTestXMLOutputUnitTest(gtest_xml_test_utils.GTestXMLTestCase): + """ + Unit test for Google Test's XML output functionality. + """ + + # This test currently breaks on platforms that do not support typed and + # type-parameterized tests, so we don't run it under them. + if SUPPORTS_TYPED_TESTS: + def testNonEmptyXmlOutput(self): + """ + Runs a test program that generates a non-empty XML output, and + tests that the XML output is expected. + """ + self._TestXmlOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY_XML, 1) + + def testNoTestXmlOutput(self): + """Verifies XML output for a Google Test binary without actual tests. + + Runs a test program that generates an XML output for a binary without tests, + and tests that the XML output is expected. + """ + + self._TestXmlOutput('gtest_no_test_unittest', EXPECTED_NO_TEST_XML, 0) + + def testTimestampValue(self): + """Checks whether the timestamp attribute in the XML output is valid. + + Runs a test program that generates an empty XML output, and checks if + the timestamp attribute in the testsuites tag is valid. + """ + actual = self._GetXmlOutput('gtest_no_test_unittest', [], {}, 0) + date_time_str = actual.documentElement.getAttributeNode('timestamp').value + # datetime.strptime() is only available in Python 2.5+ so we have to + # parse the expected datetime manually. + match = re.match(r'(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)', date_time_str) + self.assertTrue( + re.match, + 'XML datettime string %s has incorrect format' % date_time_str) + date_time_from_xml = datetime.datetime( + year=int(match.group(1)), month=int(match.group(2)), + day=int(match.group(3)), hour=int(match.group(4)), + minute=int(match.group(5)), second=int(match.group(6))) + + time_delta = abs(datetime.datetime.now() - date_time_from_xml) + # timestamp value should be near the current local time + self.assertTrue(time_delta < datetime.timedelta(seconds=600), + 'time_delta is %s' % time_delta) + actual.unlink() + + def testDefaultOutputFile(self): + """ + Confirms that Google Test produces an XML output file with the expected + default name if no name is explicitly specified. + """ + output_file = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_DEFAULT_OUTPUT_FILE) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + 'gtest_no_test_unittest') + try: + os.remove(output_file) + except OSError: + e = sys.exc_info()[1] + if e.errno != errno.ENOENT: + raise + + p = gtest_test_utils.Subprocess( + [gtest_prog_path, '%s=xml' % GTEST_OUTPUT_FLAG], + working_dir=gtest_test_utils.GetTempDir()) + self.assert_(p.exited) + self.assertEquals(0, p.exit_code) + self.assert_(os.path.isfile(output_file)) + + def testSuppressedXmlOutput(self): + """ + Tests that no XML file is generated if the default XML listener is + shut down before RUN_ALL_TESTS is invoked. + """ + + xml_path = os.path.join(gtest_test_utils.GetTempDir(), + GTEST_PROGRAM_NAME + 'out.xml') + if os.path.isfile(xml_path): + os.remove(xml_path) + + command = [GTEST_PROGRAM_PATH, + '%s=xml:%s' % (GTEST_OUTPUT_FLAG, xml_path), + '--shut_down_xml'] + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + # p.signal is available only if p.terminated_by_signal is True. + self.assertFalse( + p.terminated_by_signal, + '%s was killed by signal %d' % (GTEST_PROGRAM_NAME, p.signal)) + else: + self.assert_(p.exited) + self.assertEquals(1, p.exit_code, + "'%s' exited with code %s, which doesn't match " + 'the expected exit code %s.' + % (command, p.exit_code, 1)) + + self.assert_(not os.path.isfile(xml_path)) + + def testFilteredTestXmlOutput(self): + """Verifies XML output when a filter is applied. + + Runs a test program that executes only some tests and verifies that + non-selected tests do not show up in the XML output. + """ + + self._TestXmlOutput(GTEST_PROGRAM_NAME, EXPECTED_FILTERED_TEST_XML, 0, + extra_args=['%s=SuccessfulTest.*' % GTEST_FILTER_FLAG]) + + def testShardedTestXmlOutput(self): + """Verifies XML output when run using multiple shards. + + Runs a test program that executes only one shard and verifies that tests + from other shards do not show up in the XML output. + """ + + self._TestXmlOutput( + GTEST_PROGRAM_NAME, + EXPECTED_SHARDED_TEST_XML, + 0, + extra_env={SHARD_INDEX_ENV_VAR: '0', + TOTAL_SHARDS_ENV_VAR: '10'}) + + def _GetXmlOutput(self, gtest_prog_name, extra_args, extra_env, + expected_exit_code): + """ + Returns the xml output generated by running the program gtest_prog_name. + Furthermore, the program's exit code must be expected_exit_code. + """ + xml_path = os.path.join(gtest_test_utils.GetTempDir(), + gtest_prog_name + 'out.xml') + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) + + command = ([gtest_prog_path, '%s=xml:%s' % (GTEST_OUTPUT_FLAG, xml_path)] + + extra_args) + environ_copy = os.environ.copy() + if extra_env: + environ_copy.update(extra_env) + p = gtest_test_utils.Subprocess(command, env=environ_copy) + + if p.terminated_by_signal: + self.assert_(False, + '%s was killed by signal %d' % (gtest_prog_name, p.signal)) + else: + self.assert_(p.exited) + self.assertEquals(expected_exit_code, p.exit_code, + "'%s' exited with code %s, which doesn't match " + 'the expected exit code %s.' + % (command, p.exit_code, expected_exit_code)) + actual = minidom.parse(xml_path) + return actual + + def _TestXmlOutput(self, gtest_prog_name, expected_xml, + expected_exit_code, extra_args=None, extra_env=None): + """ + Asserts that the XML document generated by running the program + gtest_prog_name matches expected_xml, a string containing another + XML document. Furthermore, the program's exit code must be + expected_exit_code. + """ + + actual = self._GetXmlOutput(gtest_prog_name, extra_args or [], + extra_env or {}, expected_exit_code) + expected = minidom.parseString(expected_xml) + self.NormalizeXml(actual.documentElement) + self.AssertEquivalentNodes(expected.documentElement, + actual.documentElement) + expected.unlink() + actual.unlink() + + +if __name__ == '__main__': + os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' + gtest_test_utils.Main() diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest_.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest_.cc new file mode 100644 index 000000000000..c0036aaef9df --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_output_unittest_.cc @@ -0,0 +1,193 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Unit test for Google Test XML output. +// +// A user can specify XML output in a Google Test program to run via +// either the GTEST_OUTPUT environment variable or the --gtest_output +// flag. This is used for testing such functionality. +// +// This program will be invoked from a Python unit test. Don't run it +// directly. + +#include "gtest/gtest.h" + +using ::testing::InitGoogleTest; +using ::testing::TestEventListeners; +using ::testing::TestWithParam; +using ::testing::UnitTest; +using ::testing::Test; +using ::testing::Values; + +class SuccessfulTest : public Test { +}; + +TEST_F(SuccessfulTest, Succeeds) { + SUCCEED() << "This is a success."; + ASSERT_EQ(1, 1); +} + +class FailedTest : public Test { +}; + +TEST_F(FailedTest, Fails) { + ASSERT_EQ(1, 2); +} + +class DisabledTest : public Test { +}; + +TEST_F(DisabledTest, DISABLED_test_not_run) { + FAIL() << "Unexpected failure: Disabled test should not be run"; +} + +class SkippedTest : public Test { +}; + +TEST_F(SkippedTest, Skipped) { + GTEST_SKIP(); +} + +TEST_F(SkippedTest, SkippedWithMessage) { + GTEST_SKIP() << "It is good practice to tell why you skip a test."; +} + +TEST_F(SkippedTest, SkippedAfterFailure) { + EXPECT_EQ(1, 2); + GTEST_SKIP() << "It is good practice to tell why you skip a test."; +} + +TEST(MixedResultTest, Succeeds) { + EXPECT_EQ(1, 1); + ASSERT_EQ(1, 1); +} + +TEST(MixedResultTest, Fails) { + EXPECT_EQ(1, 2); + ASSERT_EQ(2, 3); +} + +TEST(MixedResultTest, DISABLED_test) { + FAIL() << "Unexpected failure: Disabled test should not be run"; +} + +TEST(XmlQuotingTest, OutputsCData) { + FAIL() << "XML output: " + ""; +} + +// Helps to test that invalid characters produced by test code do not make +// it into the XML file. +TEST(InvalidCharactersTest, InvalidCharactersInMessage) { + FAIL() << "Invalid characters in brackets [\x1\x2]"; +} + +class PropertyRecordingTest : public Test { + public: + static void SetUpTestSuite() { RecordProperty("SetUpTestSuite", "yes"); } + static void TearDownTestSuite() { + RecordProperty("TearDownTestSuite", "aye"); + } +}; + +TEST_F(PropertyRecordingTest, OneProperty) { + RecordProperty("key_1", "1"); +} + +TEST_F(PropertyRecordingTest, IntValuedProperty) { + RecordProperty("key_int", 1); +} + +TEST_F(PropertyRecordingTest, ThreeProperties) { + RecordProperty("key_1", "1"); + RecordProperty("key_2", "2"); + RecordProperty("key_3", "3"); +} + +TEST_F(PropertyRecordingTest, TwoValuesForOneKeyUsesLastValue) { + RecordProperty("key_1", "1"); + RecordProperty("key_1", "2"); +} + +TEST(NoFixtureTest, RecordProperty) { + RecordProperty("key", "1"); +} + +void ExternalUtilityThatCallsRecordProperty(const std::string& key, int value) { + testing::Test::RecordProperty(key, value); +} + +void ExternalUtilityThatCallsRecordProperty(const std::string& key, + const std::string& value) { + testing::Test::RecordProperty(key, value); +} + +TEST(NoFixtureTest, ExternalUtilityThatCallsRecordIntValuedProperty) { + ExternalUtilityThatCallsRecordProperty("key_for_utility_int", 1); +} + +TEST(NoFixtureTest, ExternalUtilityThatCallsRecordStringValuedProperty) { + ExternalUtilityThatCallsRecordProperty("key_for_utility_string", "1"); +} + +// Verifies that the test parameter value is output in the 'value_param' +// XML attribute for value-parameterized tests. +class ValueParamTest : public TestWithParam {}; +TEST_P(ValueParamTest, HasValueParamAttribute) {} +TEST_P(ValueParamTest, AnotherTestThatHasValueParamAttribute) {} +INSTANTIATE_TEST_SUITE_P(Single, ValueParamTest, Values(33, 42)); + +// Verifies that the type parameter name is output in the 'type_param' +// XML attribute for typed tests. +template class TypedTest : public Test {}; +typedef testing::Types TypedTestTypes; +TYPED_TEST_SUITE(TypedTest, TypedTestTypes); +TYPED_TEST(TypedTest, HasTypeParamAttribute) {} + +// Verifies that the type parameter name is output in the 'type_param' +// XML attribute for type-parameterized tests. +template +class TypeParameterizedTestSuite : public Test {}; +TYPED_TEST_SUITE_P(TypeParameterizedTestSuite); +TYPED_TEST_P(TypeParameterizedTestSuite, HasTypeParamAttribute) {} +REGISTER_TYPED_TEST_SUITE_P(TypeParameterizedTestSuite, HasTypeParamAttribute); +typedef testing::Types TypeParameterizedTestSuiteTypes; // NOLINT +INSTANTIATE_TYPED_TEST_SUITE_P(Single, TypeParameterizedTestSuite, + TypeParameterizedTestSuiteTypes); + +int main(int argc, char** argv) { + InitGoogleTest(&argc, argv); + + if (argc > 1 && strcmp(argv[1], "--shut_down_xml") == 0) { + TestEventListeners& listeners = UnitTest::GetInstance()->listeners(); + delete listeners.Release(listeners.default_xml_generator()); + } + testing::Test::RecordProperty("ad_hoc_property", "42"); + return RUN_ALL_TESTS(); +} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_test_utils.py b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_test_utils.py new file mode 100755 index 000000000000..50c6e7dca8a9 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/gtest_xml_test_utils.py @@ -0,0 +1,197 @@ +# Copyright 2006, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unit test utilities for gtest_xml_output""" + +import re +from xml.dom import minidom, Node +from googletest.test import gtest_test_utils + +GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.xml' + +class GTestXMLTestCase(gtest_test_utils.TestCase): + """ + Base class for tests of Google Test's XML output functionality. + """ + + + def AssertEquivalentNodes(self, expected_node, actual_node): + """ + Asserts that actual_node (a DOM node object) is equivalent to + expected_node (another DOM node object), in that either both of + them are CDATA nodes and have the same value, or both are DOM + elements and actual_node meets all of the following conditions: + + * It has the same tag name as expected_node. + * It has the same set of attributes as expected_node, each with + the same value as the corresponding attribute of expected_node. + Exceptions are any attribute named "time", which needs only be + convertible to a floating-point number and any attribute named + "type_param" which only has to be non-empty. + * It has an equivalent set of child nodes (including elements and + CDATA sections) as expected_node. Note that we ignore the + order of the children as they are not guaranteed to be in any + particular order. + """ + + if expected_node.nodeType == Node.CDATA_SECTION_NODE: + self.assertEquals(Node.CDATA_SECTION_NODE, actual_node.nodeType) + self.assertEquals(expected_node.nodeValue, actual_node.nodeValue) + return + + self.assertEquals(Node.ELEMENT_NODE, actual_node.nodeType) + self.assertEquals(Node.ELEMENT_NODE, expected_node.nodeType) + self.assertEquals(expected_node.tagName, actual_node.tagName) + + expected_attributes = expected_node.attributes + actual_attributes = actual_node.attributes + self.assertEquals( + expected_attributes.length, actual_attributes.length, + 'attribute numbers differ in element %s:\nExpected: %r\nActual: %r' % ( + actual_node.tagName, expected_attributes.keys(), + actual_attributes.keys())) + for i in range(expected_attributes.length): + expected_attr = expected_attributes.item(i) + actual_attr = actual_attributes.get(expected_attr.name) + self.assert_( + actual_attr is not None, + 'expected attribute %s not found in element %s' % + (expected_attr.name, actual_node.tagName)) + self.assertEquals( + expected_attr.value, actual_attr.value, + ' values of attribute %s in element %s differ: %s vs %s' % + (expected_attr.name, actual_node.tagName, + expected_attr.value, actual_attr.value)) + + expected_children = self._GetChildren(expected_node) + actual_children = self._GetChildren(actual_node) + self.assertEquals( + len(expected_children), len(actual_children), + 'number of child elements differ in element ' + actual_node.tagName) + for child_id, child in expected_children.items(): + self.assert_(child_id in actual_children, + '<%s> is not in <%s> (in element %s)' % + (child_id, actual_children, actual_node.tagName)) + self.AssertEquivalentNodes(child, actual_children[child_id]) + + identifying_attribute = { + 'testsuites': 'name', + 'testsuite': 'name', + 'testcase': 'name', + 'failure': 'message', + 'skipped': 'message', + 'property': 'name', + } + + def _GetChildren(self, element): + """ + Fetches all of the child nodes of element, a DOM Element object. + Returns them as the values of a dictionary keyed by the IDs of the + children. For , , , and + elements, the ID is the value of their "name" attribute; for + elements, it is the value of the "message" attribute; for + elements, it is the value of their parent's "name" attribute plus the + literal string "properties"; CDATA sections and non-whitespace + text nodes are concatenated into a single CDATA section with ID + "detail". An exception is raised if any element other than the above + four is encountered, if two child elements with the same identifying + attributes are encountered, or if any other type of node is encountered. + """ + + children = {} + for child in element.childNodes: + if child.nodeType == Node.ELEMENT_NODE: + if child.tagName == 'properties': + self.assert_(child.parentNode is not None, + 'Encountered element without a parent') + child_id = child.parentNode.getAttribute('name') + '-properties' + else: + self.assert_(child.tagName in self.identifying_attribute, + 'Encountered unknown element <%s>' % child.tagName) + child_id = child.getAttribute( + self.identifying_attribute[child.tagName]) + self.assert_(child_id not in children) + children[child_id] = child + elif child.nodeType in [Node.TEXT_NODE, Node.CDATA_SECTION_NODE]: + if 'detail' not in children: + if (child.nodeType == Node.CDATA_SECTION_NODE or + not child.nodeValue.isspace()): + children['detail'] = child.ownerDocument.createCDATASection( + child.nodeValue) + else: + children['detail'].nodeValue += child.nodeValue + else: + self.fail('Encountered unexpected node type %d' % child.nodeType) + return children + + def NormalizeXml(self, element): + """ + Normalizes Google Test's XML output to eliminate references to transient + information that may change from run to run. + + * The "time" attribute of , and + elements is replaced with a single asterisk, if it contains + only digit characters. + * The "timestamp" attribute of elements is replaced with a + single asterisk, if it contains a valid ISO8601 datetime value. + * The "type_param" attribute of elements is replaced with a + single asterisk (if it sn non-empty) as it is the type name returned + by the compiler and is platform dependent. + * The line info reported in the first line of the "message" + attribute and CDATA section of elements is replaced with the + file's basename and a single asterisk for the line number. + * The directory names in file paths are removed. + * The stack traces are removed. + """ + + if element.tagName in ('testsuites', 'testsuite', 'testcase'): + timestamp = element.getAttributeNode('timestamp') + timestamp.value = re.sub(r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d\d\d$', + '*', timestamp.value) + if element.tagName in ('testsuites', 'testsuite', 'testcase'): + time = element.getAttributeNode('time') + time.value = re.sub(r'^\d+(\.\d+)?$', '*', time.value) + type_param = element.getAttributeNode('type_param') + if type_param and type_param.value: + type_param.value = '*' + elif element.tagName == 'failure' or element.tagName == 'skipped': + source_line_pat = r'^.*[/\\](.*:)\d+\n' + # Replaces the source line information with a normalized form. + message = element.getAttributeNode('message') + message.value = re.sub(source_line_pat, '\\1*\n', message.value) + for child in element.childNodes: + if child.nodeType == Node.CDATA_SECTION_NODE: + # Replaces the source line information with a normalized form. + cdata = re.sub(source_line_pat, '\\1*\n', child.nodeValue) + # Removes the actual stack trace. + child.nodeValue = re.sub(r'Stack trace:\n(.|\n)*', + 'Stack trace:\n*', cdata) + for child in element.childNodes: + if child.nodeType == Node.ELEMENT_NODE: + self.NormalizeXml(child) diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.cc b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.cc new file mode 100644 index 000000000000..0f69f6dbd2e2 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.cc @@ -0,0 +1,35 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// This is part of the unit test for gtest_prod.h. + +#include "production.h" + +PrivateCode::PrivateCode() : x_(0) {} diff --git a/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.h b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.h new file mode 100644 index 000000000000..41a547225423 --- /dev/null +++ b/packages/shylu/shylu_node/tacho/unit-test/googletest/test/production.h @@ -0,0 +1,54 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// This is part of the unit test for gtest_prod.h. + +#ifndef GOOGLETEST_TEST_PRODUCTION_H_ +#define GOOGLETEST_TEST_PRODUCTION_H_ + +#include "gtest/gtest_prod.h" + +class PrivateCode { + public: + // Declares a friend test that does not use a fixture. + FRIEND_TEST(PrivateCodeTest, CanAccessPrivateMembers); + + // Declares a friend test that uses a fixture. + FRIEND_TEST(PrivateCodeFixtureTest, CanAccessPrivateMembers); + + PrivateCode(); + + int x() const { return x_; } + private: + void set_x(int an_x) { x_ = an_x; } + int x_; +}; + +#endif // GOOGLETEST_TEST_PRODUCTION_H_