diff --git a/cpp/include/cudf_test/testing_main.hpp b/cpp/include/cudf_test/testing_main.hpp index 88e3088d794..ecac761f7cb 100644 --- a/cpp/include/cudf_test/testing_main.hpp +++ b/cpp/include/cudf_test/testing_main.hpp @@ -145,6 +145,32 @@ inline auto parse_cudf_test_opts(int argc, char** argv) } } +/** + * @brief Sets up stream mode memory resource adaptor + * + * The resource adaptor is only set as the current device resource if the + * stream mode is enabled. + * + * The caller must keep the return object alive for the life of the test runs. + * + * @param cmd_opts Command line options returned by parse_cudf_test_opts + * @return Memory resource adaptor + */ +inline auto make_stream_mode_adaptor(cxxopts::ParseResult const& cmd_opts) +{ + auto resource = rmm::mr::get_current_device_resource(); + auto const stream_mode = cmd_opts["stream_mode"].as(); + auto const stream_error_mode = cmd_opts["stream_error_mode"].as(); + auto const error_on_invalid_stream = (stream_error_mode == "error"); + auto const check_default_stream = (stream_mode == "new_cudf_default"); + auto adaptor = + make_stream_checking_resource_adaptor(resource, error_on_invalid_stream, check_default_stream); + if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) { + rmm::mr::set_current_device_resource(&adaptor); + } + return adaptor; +} + /** * @brief Macro that defines main function for gtest programs that use rmm * @@ -155,25 +181,14 @@ inline auto parse_cudf_test_opts(int argc, char** argv) * function parses the command line to customize test behavior, like the * allocation mode used for creating the default memory resource. */ -#define CUDF_TEST_PROGRAM_MAIN() \ - int main(int argc, char** argv) \ - { \ - ::testing::InitGoogleTest(&argc, argv); \ - auto const cmd_opts = parse_cudf_test_opts(argc, argv); \ - auto const rmm_mode = cmd_opts["rmm_mode"].as(); \ - auto resource = cudf::test::create_memory_resource(rmm_mode); \ - rmm::mr::set_current_device_resource(resource.get()); \ - \ - auto const stream_mode = cmd_opts["stream_mode"].as(); \ - if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) { \ - auto const stream_error_mode = cmd_opts["stream_error_mode"].as(); \ - auto const error_on_invalid_stream = (stream_error_mode == "error"); \ - auto const check_default_stream = (stream_mode == "new_cudf_default"); \ - auto adaptor = make_stream_checking_resource_adaptor( \ - resource.get(), error_on_invalid_stream, check_default_stream); \ - rmm::mr::set_current_device_resource(&adaptor); \ - return RUN_ALL_TESTS(); \ - } \ - \ - return RUN_ALL_TESTS(); \ +#define CUDF_TEST_PROGRAM_MAIN() \ + int main(int argc, char** argv) \ + { \ + ::testing::InitGoogleTest(&argc, argv); \ + auto const cmd_opts = parse_cudf_test_opts(argc, argv); \ + auto const rmm_mode = cmd_opts["rmm_mode"].as(); \ + auto resource = cudf::test::create_memory_resource(rmm_mode); \ + rmm::mr::set_current_device_resource(resource.get()); \ + auto adaptor = make_stream_mode_adaptor(cmd_opts); \ + return RUN_ALL_TESTS(); \ } diff --git a/cpp/tests/error/error_handling_test.cu b/cpp/tests/error/error_handling_test.cu index 674d2e0a6ea..46d01ec14ff 100644 --- a/cpp/tests/error/error_handling_test.cu +++ b/cpp/tests/error/error_handling_test.cu @@ -128,17 +128,7 @@ TEST(DebugAssert, cudf_assert_true) int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); - auto const cmd_opts = parse_cudf_test_opts(argc, argv); - auto const stream_mode = cmd_opts["stream_mode"].as(); - if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) { - auto resource = rmm::mr::get_current_device_resource(); - auto const stream_error_mode = cmd_opts["stream_error_mode"].as(); - auto const error_on_invalid_stream = (stream_error_mode == "error"); - auto const check_default_stream = (stream_mode == "new_cudf_default"); - auto adaptor = make_stream_checking_resource_adaptor( - resource, error_on_invalid_stream, check_default_stream); - rmm::mr::set_current_device_resource(&adaptor); - return RUN_ALL_TESTS(); - } + auto const cmd_opts = parse_cudf_test_opts(argc, argv); + auto adaptor = make_stream_mode_adaptor(cmd_opts); return RUN_ALL_TESTS(); }