Skip to content

Commit

Permalink
Removed stream-synced helpers, refactored names, added domains
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Dec 16, 2021
1 parent 029c259 commit d8357b2
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 93 deletions.
59 changes: 32 additions & 27 deletions cpp/include/raft/common/detail/nvtx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@

#include <rmm/cuda_stream_view.hpp>

namespace raft::common::detail {
namespace raft::common::nvtx::detail {

#ifdef NVTX_ENABLED

#include <nvToolsExt.h>
#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <nvToolsExt.h>
#include <string>
#include <type_traits>
#include <unordered_map>

/**
Expand Down Expand Up @@ -134,8 +135,25 @@ inline auto generate_next_color(const std::string& tag) -> uint32_t
return rgb;
}

static inline nvtxDomainHandle_t domain = nvtxDomainCreateA("application");
template <typename Domain, typename = Domain>
struct domain_store {
/* If `Domain::name` does not exist, this default instance is used and throws the error. */
static_assert(sizeof(Domain) != sizeof(Domain),
"Type used to identify a domain must contain a static member 'char const* name'");
static inline nvtxDomainHandle_t const kValue = nullptr;
};

template <typename Domain>
struct domain_store<
Domain,
/* Check if there exists `Domain::name` */
std::enable_if_t<
std::is_same<char const*, typename std::decay<decltype(Domain::name)>::type>::value,
Domain>> {
static inline nvtxDomainHandle_t const kValue = nvtxDomainCreateA(Domain::name);
};

template <typename Domain>
inline void push_range_name(const char* name)
{
nvtxEventAttributes_t event_attrib = {0};
Expand All @@ -145,54 +163,41 @@ inline void push_range_name(const char* name)
event_attrib.color = generate_next_color(name);
event_attrib.messageType = NVTX_MESSAGE_TYPE_ASCII;
event_attrib.message.ascii = name;
nvtxDomainRangePushEx(domain, &event_attrib);
nvtxDomainRangePushEx(domain_store<Domain>::kValue, &event_attrib);
}

template <typename... Args>
template <typename Domain, typename... Args>
inline void push_range(const char* format, Args... args)
{
if constexpr (sizeof...(args) > 0) {
int length = std::snprintf(nullptr, 0, format, args...);
assert(length >= 0);
std::vector<char> buf(length + 1);
std::snprintf(buf.data(), length + 1, format, args...);
push_range_name(buf.data());
push_range_name<Domain>(buf.data());
} else {
push_range_name(format);
push_range_name<Domain>(format);
}
}

template <typename... Args>
inline void push_range(rmm::cuda_stream_view stream, const char* format, Args... args)
template <typename Domain>
inline void pop_range()
{
stream.synchronize();
push_range(format, args...);
}

inline void pop_range() { nvtxDomainRangePop(domain); }

inline void pop_range(rmm::cuda_stream_view stream)
{
stream.synchronize();
pop_range();
nvtxDomainRangePop(domain_store<Domain>::kValue);
}

#else // NVTX_ENABLED

template <typename... Args>
template <typename Domain, typename... Args>
inline void push_range(const char* format, Args... args)
{
}

template <typename... Args>
inline void push_range(rmm::cuda_stream_view stream, const char* format, Args... args)
template <typename Domain>
inline void pop_range()
{
}

inline void pop_range() {}

inline void pop_range(rmm::cuda_stream_view stream) {}

#endif // NVTX_ENABLED

} // namespace raft::common::detail
} // namespace raft::common::nvtx::detail
106 changes: 50 additions & 56 deletions cpp/include/raft/common/nvtx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,93 +16,87 @@

#pragma once

#include <optional>
#include "detail/nvtx.hpp"
#include <optional>

namespace raft::common::nvtx {

namespace domain {

/** The default NVTX domain. */
struct app {
static constexpr char const* name{"application"};
};

/** This NVTX domain is supposed to be used within raft. */
struct raft {
static constexpr char const* name{"raft"};
};

namespace raft::common {
} // namespace domain

/**
* @brief Push a named nvtx range
* @brief Push a named NVTX range.
*
* @tparam Domain optional struct that defines the NVTX domain message;
* You can create a new domain with a custom message as follows:
* `struct custom_domain { static constexpr char const* name{"custom message"}; }`
* NB: make sure to use the same domain for `push_range` and `pop_range`.
* @param format range name format (accepts printf-style arguments)
* @param args the arguments for the printf-style formatting
*/
template <typename... Args>
inline void push_nvtx_range(const char* format, Args... args)
template <typename Domain = domain::app, typename... Args>
inline void push_range(const char* format, Args... args)
{
detail::push_range(format, args...);
detail::push_range<Domain, Args...>(format, args...);
}

/**
* @brief Synchronize CUDA stream and push a named nvtx range
* @param format range name format (accepts printf-style arguments)
* @param args the arguments for the printf-style formatting
* @param stream stream to synchronize
* @brief Pop the latest range.
*
* @tparam Domain optional struct that defines the NVTX domain message;
* You can create a new domain with a custom message as follows:
* `struct custom_domain { static constexpr char const* name{"custom message"}; }`
* NB: make sure to use the same domain for `push_range` and `pop_range`.
*/
template <typename... Args>
inline void push_nvtx_range(rmm::cuda_stream_view stream, const char* format, Args... args)
template <typename Domain = domain::app>
inline void pop_range()
{
detail::push_range(stream, format, args...);
detail::pop_range<Domain>();
}

/** Pop the latest range */
inline void pop_nvtx_range() { detail::pop_range(); }

/**
* @brief Synchronize CUDA stream and pop the latest nvtx range
* @param stream stream to synchronize
* @brief Push a named NVTX range that would be popped at the end of the object lifetime.
*
* @tparam Domain optional struct that defines the NVTX domain message;
* You can create a new domain with a custom message as follows:
* `struct custom_domain { static constexpr char const* name{"custom message"}; }`
*/
inline void pop_nvtx_range(rmm::cuda_stream_view stream) { detail::pop_range(stream); }

/** Push a named nvtx range that would be popped at the end of the object lifetime. */
class nvtx_range {
private:
std::optional<rmm::cuda_stream_view> stream_maybe_;

template <typename Domain = domain::app>
class range {
public:
/**
* Synchronize CUDA stream and push a named nvtx range
* At the end of the object lifetime, synchronize again and pop the range.
*
* @param stream stream to synchronize
* @param format range name format (accepts printf-style arguments)
* @param args the arguments for the printf-style formatting
*/
template <typename... Args>
explicit nvtx_range(rmm::cuda_stream_view stream, const char* format, Args... args)
: stream_maybe_(std::make_optional(stream))
{
push_nvtx_range(stream, format, args...);
}

/**
* Push a named nvtx range.
* Push a named NVTX range.
* At the end of the object lifetime, pop the range back.
*
* @param format range name format (accepts printf-style arguments)
* @param args the arguments for the printf-style formatting
*/
template <typename... Args>
explicit nvtx_range(const char* format, Args... args) : stream_maybe_(std::nullopt)
explicit range(const char* format, Args... args)
{
push_nvtx_range(format, args...);
push_range<Domain, Args...>(format, args...);
}

~nvtx_range()
{
if (stream_maybe_.has_value()) {
pop_nvtx_range(*stream_maybe_);
} else {
pop_nvtx_range();
}
}
~range() { pop_range<Domain>(); }

/* This object is not meant to be touched. */
nvtx_range(const nvtx_range&) = delete;
nvtx_range(nvtx_range&&) = delete;
auto operator=(const nvtx_range&) -> nvtx_range& = delete;
auto operator=(nvtx_range&&) -> nvtx_range& = delete;
range(const range&) = delete;
range(range&&) = delete;
auto operator=(const range&) -> range& = delete;
auto operator=(range&&) -> range& = delete;
static auto operator new(std::size_t) -> void* = delete;
static auto operator new[](std::size_t) -> void* = delete;
};

} // namespace raft::common
} // namespace raft::common::nvtx
9 changes: 6 additions & 3 deletions cpp/include/raft/linalg/svd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ void svdQR(const raft::handle_t& handle,
bool gen_right_vec,
cudaStream_t stream)
{
common::nvtx_range fun_scope("raft::linalg::svdQR(%d, %d)", n_rows, n_cols);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"raft::linalg::svdQR(%d, %d)", n_rows, n_cols);
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cublasHandle_t cublasH = handle.get_cublas_handle();

Expand Down Expand Up @@ -142,7 +143,8 @@ void svdEig(const raft::handle_t& handle,
bool gen_left_vec,
cudaStream_t stream)
{
common::nvtx_range fun_scope("raft::linalg::svdEig(%d, %d)", n_rows, n_cols);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"raft::linalg::svdEig(%d, %d)", n_rows, n_cols);
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
cublasHandle_t cublasH = handle.get_cublas_handle();

Expand Down Expand Up @@ -221,7 +223,8 @@ void svdJacobi(const raft::handle_t& handle,
int max_sweeps,
cudaStream_t stream)
{
common::nvtx_range fun_scope("raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"raft::linalg::svdJacobi(%d, %d)", n_rows, n_cols);
cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();

gesvdjInfo_t gesvdj_params = NULL;
Expand Down
2 changes: 1 addition & 1 deletion cpp/test/distance/distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
void SetUp() override
{
auto testInfo = testing::UnitTest::GetInstance()->current_test_info();
common::nvtx_range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name());
common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name());

raft::random::Rng r(params.seed);
int m = params.m;
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/eigen_solvers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace raft {

TEST(Raft, EigenSolvers)
{
common::nvtx_range fun_scope("test::EigenSolvers");
common::nvtx::range fun_scope("test::EigenSolvers");
using namespace matrix;
using index_type = int;
using value_type = double;
Expand Down Expand Up @@ -69,7 +69,7 @@ TEST(Raft, EigenSolvers)

TEST(Raft, SpectralSolvers)
{
common::nvtx_range fun_scope("test::SpectralSolvers");
common::nvtx::range fun_scope("test::SpectralSolvers");
using namespace matrix;
using index_type = int;
using value_type = double;
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/nvtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ class NvtxNextColorTest : public ::testing::Test {
const std::string temp1 = "foo";
const std::string temp2 = "bar";

diff_string_diff_color =
common::detail::generate_next_color(temp1) != common::detail::generate_next_color(temp2);
same_string_same_color =
common::detail::generate_next_color(temp1) == common::detail::generate_next_color(temp1);
diff_string_diff_color = common::nvtx::detail::generate_next_color(temp1) !=
common::nvtx::detail::generate_next_color(temp2);
same_string_same_color = common::nvtx::detail::generate_next_color(temp1) ==
common::nvtx::detail::generate_next_color(temp1);
}
void TearDown() {}
bool diff_string_diff_color = false;
Expand Down

0 comments on commit d8357b2

Please sign in to comment.