Skip to content

Commit

Permalink
address feedback: raise informative error message
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Apr 9, 2024
1 parent 820bb47 commit 3e9f913
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,7 @@ cdef class Array(_PandasConvertible):
void* c_type_ptr
shared_ptr[CArray] c_array

if c_device_array.device_type == 2:
if c_device_array.device_type == ARROW_DEVICE_CUDA:
_ensure_cuda_loaded()

c_type = pyarrow_unwrap_data_type(type)
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,9 @@ cdef extern from "arrow/c/abi.h":
cdef struct ArrowArrayStream:
void (*release)(ArrowArrayStream*) noexcept nogil

ctypedef int32_t ArrowDeviceType
cdef ArrowDeviceType ARROW_DEVICE_CUDA

cdef struct ArrowDeviceArray:
ArrowArray array
int64_t device_id
Expand Down
16 changes: 11 additions & 5 deletions python/pyarrow/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ UnionMode_DENSE = _UnionMode_DENSE

__pc = None
__pac = None
__cuda_loaded = False
__cuda_loaded = None


def _pc():
Expand All @@ -148,12 +148,18 @@ def _ensure_cuda_loaded():
# Try importing the cuda module to ensure libarrow_cuda gets loaded
# to register the CUDA device for the C Data Interface import
global __cuda_loaded
if not __cuda_loaded:
if __cuda_loaded is None:
try:
import pyarrow.cuda # no-cython-lint
except ImportError:
pass
__cuda_loaded = True
__cuda_loaded = True
except ImportError as exc:
__cuda_loaded = str(exc)

if __cuda_loaded is not True:
raise ImportError(
"Trying to import data on a CUDA device, but PyArrow is not built with "
f"CUDA support.\n(importing 'pyarrow.cuda' resulted in \"{__cuda_loaded}\")."
)


def _gdb_test_session():
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3656,7 +3656,7 @@ cdef class RecordBatch(_Tabular):
void* c_schema_ptr
shared_ptr[CRecordBatch] c_batch

if c_device_array.device_type == 2:
if c_device_array.device_type == ARROW_DEVICE_CUDA:
_ensure_cuda_loaded()

c_schema = pyarrow_unwrap_schema(schema)
Expand Down
21 changes: 21 additions & 0 deletions python/pyarrow/tests/test_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,24 @@ def test_roundtrip_chunked_array_capsule_requested_schema():
requested_capsule = requested_type.__arrow_c_schema__()
with pytest.raises(NotImplementedError):
chunked.__arrow_c_stream__(requested_capsule)


def test_import_device_no_cuda():
try:
import pyarrow.cuda # noqa
except ImportError:
pass
else:
pytest.skip("pyarrow.cuda is available")

c_array = ffi.new("struct ArrowDeviceArray*")
ptr_array = int(ffi.cast("uintptr_t", c_array))
arr = pa.array([1, 2, 3], type=pa.int64())
arr._export_to_c_device(ptr_array)

# patch the device type of the struct, this results in an invalid ArrowDeviceArray
# but this is just to test we raise am error before actually importing buffers
c_array.device_type = 2 # ARROW_DEVICE_CUDA

with pytest.raises(ImportError, match="Trying to import data on a CUDA device"):
pa.Array._import_from_c_device(ptr_array, arr.type)

0 comments on commit 3e9f913

Please sign in to comment.