From b085085121fe23b5bb1fea8f2939ab46598a88e6 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Feb 2023 19:40:43 -0500 Subject: [PATCH 1/2] Adding util to get the device id for a pointer address and associated test --- cpp/include/raft/util/cudart_utils.hpp | 27 ++++++++++++++++++++++++++ cpp/test/util/cudart_utils.cpp | 14 +++++++++++++ 2 files changed, 41 insertions(+) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 1c9793eb0a..b6764b7c36 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -332,6 +333,32 @@ void print_vector(const char* variable_name, const T* ptr, size_t componentsCoun } /** @} */ +/** + * Returns the id of the device for which the pointer is located + * @param p pointer to check + * @return id of device for which pointer is located, otherwise -1. + */ +template +int get_device_for_address(const T* p) +{ + if (!p) { return -1; } + + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + return -1; + } + + // memoryType is deprecated for CUDA 10.0+ + if (att.type == cudaMemoryTypeDevice) { + return att.device; + } else { + return -1; + } +} + /** helper method to get max usable shared mem per block parameter */ inline int getSharedMemPerBlock() { diff --git a/cpp/test/util/cudart_utils.cpp b/cpp/test/util/cudart_utils.cpp index e6b1aa9676..57cd2ff9b0 100644 --- a/cpp/test/util/cudart_utils.cpp +++ b/cpp/test/util/cudart_utils.cpp @@ -14,7 +14,11 @@ * limitations under the License. */ +#include #include +#include + +#include #include @@ -84,4 +88,14 @@ TEST(Raft, Utils) << "expected regex:'" << re_exp << "'"; } +TEST(Raft, GetDeviceForAddress) +{ + device_resources handle; + std::vector h(1); + ASSERT_EQ(-1, raft::get_device_for_address(h.data())); + + rmm::device_uvector d(1, handle.get_stream()); + ASSERT_EQ(0, raft::get_device_for_address(d.data())); +} + } // namespace raft From 4b8629c5ceccd8222d1f527503ae2638fb68a048 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 22 Feb 2023 20:34:32 -0500 Subject: [PATCH 2/2] Removing logger include --- cpp/include/raft/util/cudart_utils.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index b6764b7c36..0feb188ad8 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,7 +25,6 @@ #pragma once #include -#include #include #include #include