diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 1c9793eb0a..0feb188ad8 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -332,6 +332,32 @@ void print_vector(const char* variable_name, const T* ptr, size_t componentsCoun } /** @} */ +/** + * Returns the id of the device for which the pointer is located + * @param p pointer to check + * @return id of device for which pointer is located, otherwise -1. + */ +template +int get_device_for_address(const T* p) +{ + if (!p) { return -1; } + + cudaPointerAttributes att; + cudaError_t err = cudaPointerGetAttributes(&att, p); + if (err == cudaErrorInvalidValue) { + // Make sure the current thread error status has been reset + err = cudaGetLastError(); + return -1; + } + + // memoryType is deprecated for CUDA 10.0+ + if (att.type == cudaMemoryTypeDevice) { + return att.device; + } else { + return -1; + } +} + /** helper method to get max usable shared mem per block parameter */ inline int getSharedMemPerBlock() { diff --git a/cpp/test/util/cudart_utils.cpp b/cpp/test/util/cudart_utils.cpp index e6b1aa9676..57cd2ff9b0 100644 --- a/cpp/test/util/cudart_utils.cpp +++ b/cpp/test/util/cudart_utils.cpp @@ -14,7 +14,11 @@ * limitations under the License. */ +#include #include +#include + +#include #include @@ -84,4 +88,14 @@ TEST(Raft, Utils) << "expected regex:'" << re_exp << "'"; } +TEST(Raft, GetDeviceForAddress) +{ + device_resources handle; + std::vector h(1); + ASSERT_EQ(-1, raft::get_device_for_address(h.data())); + + rmm::device_uvector d(1, handle.get_stream()); + ASSERT_EQ(0, raft::get_device_for_address(d.data())); +} + } // namespace raft