diff --git a/cpp/include/raft/common/detail/nvtx.hpp b/cpp/include/raft/common/detail/nvtx.hpp index f2993f0ec6..4cef7c07bc 100644 --- a/cpp/include/raft/common/detail/nvtx.hpp +++ b/cpp/include/raft/common/detail/nvtx.hpp @@ -18,15 +18,16 @@ #include -namespace raft::common::detail { +namespace raft::common::nvtx::detail { #ifdef NVTX_ENABLED -#include #include #include #include +#include #include +#include #include /** @@ -134,8 +135,25 @@ inline auto generate_next_color(const std::string& tag) -> uint32_t return rgb; } -static inline nvtxDomainHandle_t domain = nvtxDomainCreateA("application"); +template +struct domain_store { + /* If `Domain::name` does not exist, this default instance is used and throws the error. */ + static_assert(sizeof(Domain) != sizeof(Domain), + "Type used to identify a domain must contain a static member 'char const* name'"); + static inline nvtxDomainHandle_t const kValue = nullptr; +}; + +template +struct domain_store< + Domain, + /* Check if there exists `Domain::name` */ + std::enable_if_t< + std::is_same::type>::value, + Domain>> { + static inline nvtxDomainHandle_t const kValue = nvtxDomainCreateA(Domain::name); +}; +template inline void push_range_name(const char* name) { nvtxEventAttributes_t event_attrib = {0}; @@ -145,10 +163,10 @@ inline void push_range_name(const char* name) event_attrib.color = generate_next_color(name); event_attrib.messageType = NVTX_MESSAGE_TYPE_ASCII; event_attrib.message.ascii = name; - nvtxDomainRangePushEx(domain, &event_attrib); + nvtxDomainRangePushEx(domain_store::kValue, &event_attrib); } -template +template inline void push_range(const char* format, Args... args) { if constexpr (sizeof...(args) > 0) { @@ -156,43 +174,30 @@ inline void push_range(const char* format, Args... args) assert(length >= 0); std::vector buf(length + 1); std::snprintf(buf.data(), length + 1, format, args...); - push_range_name(buf.data()); + push_range_name(buf.data()); } else { - push_range_name(format); + push_range_name(format); } } -template -inline void push_range(rmm::cuda_stream_view stream, const char* format, Args... args) +template +inline void pop_range() { - stream.synchronize(); - push_range(format, args...); -} - -inline void pop_range() { nvtxDomainRangePop(domain); } - -inline void pop_range(rmm::cuda_stream_view stream) -{ - stream.synchronize(); - pop_range(); + nvtxDomainRangePop(domain_store::kValue); } #else // NVTX_ENABLED -template +template inline void push_range(const char* format, Args... args) { } -template -inline void push_range(rmm::cuda_stream_view stream, const char* format, Args... args) +template +inline void pop_range() { } -inline void pop_range() {} - -inline void pop_range(rmm::cuda_stream_view stream) {} - #endif // NVTX_ENABLED -} // namespace raft::common::detail +} // namespace raft::common::nvtx::detail diff --git a/cpp/include/raft/common/nvtx.hpp b/cpp/include/raft/common/nvtx.hpp index 5ce1c2960c..35f5802a5b 100644 --- a/cpp/include/raft/common/nvtx.hpp +++ b/cpp/include/raft/common/nvtx.hpp @@ -16,93 +16,87 @@ #pragma once -#include #include "detail/nvtx.hpp" +#include + +namespace raft::common::nvtx { + +namespace domain { + +/** The default NVTX domain. */ +struct app { + static constexpr char const* name{"application"}; +}; + +/** This NVTX domain is supposed to be used within raft. */ +struct raft { + static constexpr char const* name{"raft"}; +}; -namespace raft::common { +} // namespace domain /** - * @brief Push a named nvtx range + * @brief Push a named NVTX range. + * + * @tparam Domain optional struct that defines the NVTX domain message; + * You can create a new domain with a custom message as follows: + * `struct custom_domain { static constexpr char const* name{"custom message"}; }` + * NB: make sure to use the same domain for `push_range` and `pop_range`. * @param format range name format (accepts printf-style arguments) * @param args the arguments for the printf-style formatting */ -template -inline void push_nvtx_range(const char* format, Args... args) +template +inline void push_range(const char* format, Args... args) { - detail::push_range(format, args...); + detail::push_range(format, args...); } /** - * @brief Synchronize CUDA stream and push a named nvtx range - * @param format range name format (accepts printf-style arguments) - * @param args the arguments for the printf-style formatting - * @param stream stream to synchronize + * @brief Pop the latest range. + * + * @tparam Domain optional struct that defines the NVTX domain message; + * You can create a new domain with a custom message as follows: + * `struct custom_domain { static constexpr char const* name{"custom message"}; }` + * NB: make sure to use the same domain for `push_range` and `pop_range`. */ -template -inline void push_nvtx_range(rmm::cuda_stream_view stream, const char* format, Args... args) +template +inline void pop_range() { - detail::push_range(stream, format, args...); + detail::pop_range(); } -/** Pop the latest range */ -inline void pop_nvtx_range() { detail::pop_range(); } - /** - * @brief Synchronize CUDA stream and pop the latest nvtx range - * @param stream stream to synchronize + * @brief Push a named NVTX range that would be popped at the end of the object lifetime. + * + * @tparam Domain optional struct that defines the NVTX domain message; + * You can create a new domain with a custom message as follows: + * `struct custom_domain { static constexpr char const* name{"custom message"}; }` */ -inline void pop_nvtx_range(rmm::cuda_stream_view stream) { detail::pop_range(stream); } - -/** Push a named nvtx range that would be popped at the end of the object lifetime. */ -class nvtx_range { - private: - std::optional stream_maybe_; - +template +class range { public: /** - * Synchronize CUDA stream and push a named nvtx range - * At the end of the object lifetime, synchronize again and pop the range. - * - * @param stream stream to synchronize - * @param format range name format (accepts printf-style arguments) - * @param args the arguments for the printf-style formatting - */ - template - explicit nvtx_range(rmm::cuda_stream_view stream, const char* format, Args... args) - : stream_maybe_(std::make_optional(stream)) - { - push_nvtx_range(stream, format, args...); - } - - /** - * Push a named nvtx range. + * Push a named NVTX range. * At the end of the object lifetime, pop the range back. * * @param format range name format (accepts printf-style arguments) * @param args the arguments for the printf-style formatting */ template - explicit nvtx_range(const char* format, Args... args) : stream_maybe_(std::nullopt) + explicit range(const char* format, Args... args) { - push_nvtx_range(format, args...); + push_range(format, args...); } - ~nvtx_range() - { - if (stream_maybe_.has_value()) { - pop_nvtx_range(*stream_maybe_); - } else { - pop_nvtx_range(); - } - } + ~range() { pop_range(); } /* This object is not meant to be touched. */ - nvtx_range(const nvtx_range&) = delete; - nvtx_range(nvtx_range&&) = delete; - auto operator=(const nvtx_range&) -> nvtx_range& = delete; - auto operator=(nvtx_range&&) -> nvtx_range& = delete; + range(const range&) = delete; + range(range&&) = delete; + auto operator=(const range&) -> range& = delete; + auto operator=(range&&) -> range& = delete; static auto operator new(std::size_t) -> void* = delete; static auto operator new[](std::size_t) -> void* = delete; }; -} // namespace raft::common +} // namespace raft::common::nvtx diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index c08c776095..2afae788a1 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -64,7 +64,8 @@ void svdQR(const raft::handle_t& handle, bool gen_right_vec, cudaStream_t stream) { - common::nvtx_range fun_scope("raft::linalg::svdQR(%d, %d)", n_rows, n_cols); + common::nvtx::range fun_scope( + "raft::linalg::svdQR(%d, %d)", n_rows, n_cols); cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); cublasHandle_t cublasH = handle.get_cublas_handle(); @@ -142,7 +143,8 @@ void svdEig(const raft::handle_t& handle, bool gen_left_vec, cudaStream_t stream) { - common::nvtx_range fun_scope("raft::linalg::svdEig(%d, %d)", n_rows, n_cols); + common::nvtx::range fun_scope( + "raft::linalg::svdEig(%d, %d)", n_rows, n_cols); cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); cublasHandle_t cublasH = handle.get_cublas_handle(); @@ -221,7 +223,8 @@ void svdJacobi(const raft::handle_t& handle, int max_sweeps, cudaStream_t stream) { - common::nvtx_range fun_scope("raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols); + common::nvtx::range fun_scope( + "raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols); cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); gesvdjInfo_t gesvdj_params = NULL; diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 102c18963b..475202137b 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -418,7 +418,7 @@ class DistanceTest : public ::testing::TestWithParam> { void SetUp() override { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); - common::nvtx_range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); raft::random::Rng r(params.seed); int m = params.m; diff --git a/cpp/test/eigen_solvers.cu b/cpp/test/eigen_solvers.cu index e6bba8d3d8..f898d11d2e 100644 --- a/cpp/test/eigen_solvers.cu +++ b/cpp/test/eigen_solvers.cu @@ -28,7 +28,7 @@ namespace raft { TEST(Raft, EigenSolvers) { - common::nvtx_range fun_scope("test::EigenSolvers"); + common::nvtx::range fun_scope("test::EigenSolvers"); using namespace matrix; using index_type = int; using value_type = double; @@ -69,7 +69,7 @@ TEST(Raft, EigenSolvers) TEST(Raft, SpectralSolvers) { - common::nvtx_range fun_scope("test::SpectralSolvers"); + common::nvtx::range fun_scope("test::SpectralSolvers"); using namespace matrix; using index_type = int; using value_type = double; diff --git a/cpp/test/nvtx.cpp b/cpp/test/nvtx.cpp index 9b43828c0c..81f692a215 100644 --- a/cpp/test/nvtx.cpp +++ b/cpp/test/nvtx.cpp @@ -30,10 +30,10 @@ class NvtxNextColorTest : public ::testing::Test { const std::string temp1 = "foo"; const std::string temp2 = "bar"; - diff_string_diff_color = - common::detail::generate_next_color(temp1) != common::detail::generate_next_color(temp2); - same_string_same_color = - common::detail::generate_next_color(temp1) == common::detail::generate_next_color(temp1); + diff_string_diff_color = common::nvtx::detail::generate_next_color(temp1) != + common::nvtx::detail::generate_next_color(temp2); + same_string_same_color = common::nvtx::detail::generate_next_color(temp1) == + common::nvtx::detail::generate_next_color(temp1); } void TearDown() {} bool diff_string_diff_color = false;