Skip to content

Commit

Permalink
apacheGH-40698: [C++] Create registry for Devices to map DeviceType t…
Browse files Browse the repository at this point in the history
…o MemoryManager in C Device Data import (apache#40699)

### Rationale for this change

Follow-up on apache#39980 (comment)

Right now, the user of `ImportDeviceArray` or `ImportDeviceRecordBatch` needs to provide a `DeviceMemoryMapper` mapping the device type and id to a MemoryManager. We provide a default implementation of that mapper that just knows about the default CPU memory manager (and there is another implementation in `arrow::cuda`, but you need to explicitly pass that to the import function)

To make this easier, this PR adds a registry such that default device mappers can be added separately.

### What changes are included in this PR?

This PR adds two new public functions to register device types (`RegisterDeviceMemoryManager`) and retrieve the mapper from the registry (`GetDeviceMemoryManager`).

Further, it provides a `RegisterCUDADevice` to optionally register the CUDA devices (by default only CPU device is registered).

### Are these changes tested?

### Are there any user-facing changes?

* GitHub Issue: apache#40698

Lead-authored-by: Joris Van den Bossche <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
jorisvandenbossche and pitrou authored Mar 27, 2024
1 parent aae2557 commit a407a6b
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 26 deletions.
13 changes: 13 additions & 0 deletions cpp/src/arrow/buffer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,4 +1023,17 @@ TEST(TestBufferConcatenation, EmptyBuffer) {
AssertMyBufferEqual(*result, contents);
}

TEST(TestDeviceRegistry, Basics) {
// Test the error cases for the device registry

// CPU is already registered
ASSERT_RAISES(KeyError,
RegisterDeviceMapper(DeviceAllocationType::kCPU, [](int64_t device_id) {
return default_cpu_memory_manager();
}));

// VPI is not registered
ASSERT_RAISES(KeyError, GetDeviceMapper(DeviceAllocationType::kVPI));
}

} // namespace arrow
11 changes: 5 additions & 6 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1967,12 +1967,11 @@ Result<std::shared_ptr<RecordBatch>> ImportRecordBatch(struct ArrowArray* array,
return ImportRecordBatch(array, *maybe_schema);
}

Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
int64_t device_id) {
if (device_type != ARROW_DEVICE_CPU) {
return Status::NotImplemented("Only importing data on CPU is supported");
}
return default_cpu_memory_manager();
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
ArrowDeviceType device_type, int64_t device_id) {
ARROW_ASSIGN_OR_RAISE(auto mapper,
GetDeviceMapper(static_cast<DeviceAllocationType>(device_type)));
return mapper(device_id);
}

Result<std::shared_ptr<Array>> ImportDeviceArray(struct ArrowDeviceArray* array,
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/arrow/c/bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ using DeviceMemoryMapper =
std::function<Result<std::shared_ptr<MemoryManager>>(ArrowDeviceType, int64_t)>;

ARROW_EXPORT
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType device_type,
int64_t device_id);
Result<std::shared_ptr<MemoryManager>> DefaultDeviceMemoryMapper(
ArrowDeviceType device_type, int64_t device_id);

/// \brief EXPERIMENTAL: Import C++ device array from the C data interface.
///
Expand All @@ -236,7 +236,7 @@ Result<std::shared_ptr<MemoryManager>> DefaultDeviceMapper(ArrowDeviceType devic
ARROW_EXPORT
Result<std::shared_ptr<Array>> ImportDeviceArray(
struct ArrowDeviceArray* array, std::shared_ptr<DataType> type,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ device array and its type from the C data interface.
///
Expand All @@ -253,7 +253,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
ARROW_EXPORT
Result<std::shared_ptr<Array>> ImportDeviceArray(
struct ArrowDeviceArray* array, struct ArrowSchema* type,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device from the C data
/// interface.
Expand All @@ -271,7 +271,7 @@ Result<std::shared_ptr<Array>> ImportDeviceArray(
ARROW_EXPORT
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
struct ArrowDeviceArray* array, std::shared_ptr<Schema> schema,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device and its schema
/// from the C data interface.
Expand All @@ -291,7 +291,7 @@ Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
ARROW_EXPORT
Result<std::shared_ptr<RecordBatch>> ImportDeviceRecordBatch(
struct ArrowDeviceArray* array, struct ArrowSchema* schema,
const DeviceMemoryMapper& mapper = DefaultDeviceMapper);
const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);

/// @}

Expand Down
63 changes: 63 additions & 0 deletions cpp/src/arrow/device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "arrow/device.h"

#include <cstring>
#include <mutex>
#include <unordered_map>
#include <utility>

#include "arrow/array.h"
Expand Down Expand Up @@ -268,4 +270,65 @@ std::shared_ptr<MemoryManager> CPUDevice::default_memory_manager() {
return default_cpu_memory_manager();
}

namespace {

class DeviceMapperRegistryImpl {
public:
DeviceMapperRegistryImpl() {}

Status RegisterDevice(DeviceAllocationType device_type, DeviceMapper memory_mapper) {
std::lock_guard<std::mutex> lock(lock_);
auto [_, inserted] = registry_.try_emplace(device_type, std::move(memory_mapper));
if (!inserted) {
return Status::KeyError("Device type ", static_cast<int>(device_type),
" is already registered");
}
return Status::OK();
}

Result<DeviceMapper> GetMapper(DeviceAllocationType device_type) {
std::lock_guard<std::mutex> lock(lock_);
auto it = registry_.find(device_type);
if (it == registry_.end()) {
return Status::KeyError("Device type ", static_cast<int>(device_type),
"is not registered");
}
return it->second;
}

private:
std::mutex lock_;
std::unordered_map<DeviceAllocationType, DeviceMapper> registry_;
};

Result<std::shared_ptr<MemoryManager>> DefaultCPUDeviceMapper(int64_t device_id) {
return default_cpu_memory_manager();
}

static std::unique_ptr<DeviceMapperRegistryImpl> CreateDeviceRegistry() {
auto registry = std::make_unique<DeviceMapperRegistryImpl>();

// Always register the CPU device
DCHECK_OK(registry->RegisterDevice(DeviceAllocationType::kCPU, DefaultCPUDeviceMapper));

return registry;
}

DeviceMapperRegistryImpl* GetDeviceRegistry() {
static auto g_registry = CreateDeviceRegistry();
return g_registry.get();
}

} // namespace

Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper) {
auto registry = GetDeviceRegistry();
return registry->RegisterDevice(device_type, std::move(mapper));
}

Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type) {
auto registry = GetDeviceRegistry();
return registry->GetMapper(device_type);
}

} // namespace arrow
28 changes: 28 additions & 0 deletions cpp/src/arrow/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,32 @@ class ARROW_EXPORT CPUMemoryManager : public MemoryManager {
ARROW_EXPORT
std::shared_ptr<MemoryManager> default_cpu_memory_manager();

using DeviceMapper =
std::function<Result<std::shared_ptr<MemoryManager>>(int64_t device_id)>;

/// \brief Register a function to retrieve a MemoryManager for a Device type
///
/// This registers the device type globally. A specific device type can only
/// be registered once. This method is thread-safe.
///
/// Currently, this registry is only used for importing data through the C Device
/// Data Interface (for the default Device to MemoryManager mapper in
/// arrow::ImportDeviceArray/ImportDeviceRecordBatch).
///
/// \param[in] device_type the device type for which to register a MemoryManager
/// \param[in] mapper function that takes a device id and returns the appropriate
/// MemoryManager for the registered device type and given device id
/// \return Status
ARROW_EXPORT
Status RegisterDeviceMapper(DeviceAllocationType device_type, DeviceMapper mapper);

/// \brief Get the registered function to retrieve a MemoryManager for the
/// given Device type
///
/// \param[in] device_type the device type
/// \return function that takes a device id and returns the appropriate
/// MemoryManager for the registered device type and given device id
ARROW_EXPORT
Result<DeviceMapper> GetDeviceMapper(DeviceAllocationType device_type);

} // namespace arrow
19 changes: 19 additions & 0 deletions cpp/src/arrow/gpu/cuda_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cuda.h>

#include "arrow/buffer.h"
#include "arrow/device.h"
#include "arrow/io/memory.h"
#include "arrow/memory_pool.h"
#include "arrow/status.h"
Expand Down Expand Up @@ -501,5 +502,23 @@ Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType devic
}
}

namespace {

Result<std::shared_ptr<MemoryManager>> DefaultCUDADeviceMapper(int64_t device_id) {
ARROW_ASSIGN_OR_RAISE(auto device, arrow::cuda::CudaDevice::Make(device_id));
return device->default_memory_manager();
}

bool RegisterCUDADeviceInternal() {
DCHECK_OK(RegisterDeviceMapper(DeviceAllocationType::kCUDA, DefaultCUDADeviceMapper));
// TODO add the CUDA_HOST and CUDA_MANAGED allocation types when they are supported in
// the CudaDevice
return true;
}

static auto cuda_registered = RegisterCUDADeviceInternal();

} // namespace

} // namespace cuda
} // namespace arrow
4 changes: 3 additions & 1 deletion cpp/src/arrow/gpu/cuda_memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,9 @@ Result<uintptr_t> GetDeviceAddress(const uint8_t* cpu_data,
ARROW_EXPORT
Result<uint8_t*> GetHostAddress(uintptr_t device_ptr);

ARROW_EXPORT
ARROW_DEPRECATED(
"Deprecated in 16.0.0. The CUDA device is registered by default, and you can use "
"arrow::DefaultDeviceMapper instead.")
Result<std::shared_ptr<MemoryManager>> DefaultMemoryMapper(ArrowDeviceType device_type,
int64_t device_id);

Expand Down
15 changes: 2 additions & 13 deletions cpp/src/arrow/gpu/cuda_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -716,17 +716,6 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
public:
using ArrayFactory = std::function<Result<std::shared_ptr<Array>>()>;

static Result<std::shared_ptr<MemoryManager>> DeviceMapper(ArrowDeviceType type,
int64_t id) {
if (type != ARROW_DEVICE_CUDA) {
return Status::NotImplemented("should only be CUDA device");
}

ARROW_ASSIGN_OR_RAISE(auto manager, cuda::CudaDeviceManager::Instance());
ARROW_ASSIGN_OR_RAISE(auto device, manager->GetDevice(id));
return device->default_memory_manager();
}

static ArrayFactory JSONArrayFactory(std::shared_ptr<DataType> type, const char* json) {
return [=]() { return ArrayFromJSON(type, json); };
}
Expand Down Expand Up @@ -759,7 +748,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {

std::shared_ptr<Array> device_array_roundtripped;
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ImportDeviceArray(&c_array, &c_schema));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

Expand All @@ -779,7 +768,7 @@ class TestCudaDeviceArrayRoundtrip : public ::testing::Test {
ASSERT_OK(ExportDeviceArray(*device_array, sync, &c_array, &c_schema));
device_array_roundtripped.reset();
ASSERT_OK_AND_ASSIGN(device_array_roundtripped,
ImportDeviceArray(&c_array, &c_schema, DeviceMapper));
ImportDeviceArray(&c_array, &c_schema));
ASSERT_TRUE(ArrowSchemaIsReleased(&c_schema));
ASSERT_TRUE(ArrowArrayIsReleased(&c_array.array));

Expand Down

0 comments on commit a407a6b

Please sign in to comment.