Skip to content

Commit

Permalink
[python] Add bindings for memory manager and device to Context class
Browse files Browse the repository at this point in the history
  • Loading branch information
anjakefala committed Jul 18, 2024
1 parent a137687 commit d5b4b9d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/pyarrow/_cuda.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ cdef class Context(_Weakrefable):
cudabuf = GetResultValue(self.context.get().Allocate(nbytes))
return pyarrow_wrap_cudabuffer(cudabuf)

def memory_manager(self):
return MemoryManager.wrap(self.context.get().memory_manager())

def device(self):
return Device.wrap(self.context.get().device())

def foreign_buffer(self, address, size, base=None):
"""
Create device buffer from address and size as a view.
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/includes/libarrow_cuda.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ cdef extern from "arrow/gpu/cuda_api.h" namespace "arrow::cuda" nogil:
const void* handle() const
int device_number() const
CResult[uintptr_t] GetDeviceAddress(uintptr_t addr)
shared_ptr[CDevice] device() const
shared_ptr[CMemoryManager] memory_manager() const

cdef cppclass CCudaIpcMemHandle" arrow::cuda::CudaIpcMemHandle":
@staticmethod
Expand Down
12 changes: 12 additions & 0 deletions python/pyarrow/tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def test_Context():
assert global_context.device_number == 0
assert global_context1.device_number == cuda.Context.get_num_devices() - 1

mm = global_context.memory_manager()
assert not mm.is_cpu
assert "<pyarrow.MemoryManager device: CudaDevice" in repr(mm)

dev = global_context.device()
assert dev == mm.device

assert not dev.is_cpu
assert dev.device_id == 0
assert dev.device_type == pa.DeviceAllocationType.CUDA
assert "<pyarrow.Device: CudaDevice" in repr(dev)

with pytest.raises(ValueError,
match=("device_number argument must "
"be non-negative less than")):
Expand Down

0 comments on commit d5b4b9d

Please sign in to comment.