Skip to content

Commit

Permalink
Adding util to get the device id for a pointer address (#1297)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1297
  • Loading branch information
cjnolet authored Feb 23, 2023
1 parent bbfb869 commit ad768f4
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
26 changes: 26 additions & 0 deletions cpp/include/raft/util/cudart_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,32 @@ void print_vector(const char* variable_name, const T* ptr, size_t componentsCoun
}
/** @} */

/**
* Returns the id of the device for which the pointer is located
* @param p pointer to check
* @return id of device for which pointer is located, otherwise -1.
*/
template <typename T>
int get_device_for_address(const T* p)
{
if (!p) { return -1; }

cudaPointerAttributes att;
cudaError_t err = cudaPointerGetAttributes(&att, p);
if (err == cudaErrorInvalidValue) {
// Make sure the current thread error status has been reset
err = cudaGetLastError();
return -1;
}

// memoryType is deprecated for CUDA 10.0+
if (att.type == cudaMemoryTypeDevice) {
return att.device;
} else {
return -1;
}
}

/** helper method to get max usable shared mem per block parameter */
inline int getSharedMemPerBlock()
{
Expand Down
14 changes: 14 additions & 0 deletions cpp/test/util/cudart_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
* limitations under the License.
*/

#include <raft/core/device_resources.hpp>
#include <raft/util/cudart_utils.hpp>
#include <vector>

#include <rmm/device_uvector.hpp>

#include <gtest/gtest.h>

Expand Down Expand Up @@ -84,4 +88,14 @@ TEST(Raft, Utils)
<< "expected regex:'" << re_exp << "'";
}

TEST(Raft, GetDeviceForAddress)
{
device_resources handle;
std::vector<int> h(1);
ASSERT_EQ(-1, raft::get_device_for_address(h.data()));

rmm::device_uvector<int> d(1, handle.get_stream());
ASSERT_EQ(0, raft::get_device_for_address(d.data()));
}

} // namespace raft

0 comments on commit ad768f4

Please sign in to comment.