From 3e9f913114e5a1e9d85e0b059612eb4ff224edfe Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 9 Apr 2024 09:30:41 +0200 Subject: [PATCH] address feedback: raise informative error message --- python/pyarrow/array.pxi | 2 +- python/pyarrow/includes/libarrow.pxd | 3 +++ python/pyarrow/lib.pyx | 16 +++++++++++----- python/pyarrow/table.pxi | 2 +- python/pyarrow/tests/test_cffi.py | 21 +++++++++++++++++++++ 5 files changed, 37 insertions(+), 7 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index b521eed5ec75a..5923ef45bba53 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1839,7 +1839,7 @@ cdef class Array(_PandasConvertible): void* c_type_ptr shared_ptr[CArray] c_array - if c_device_array.device_type == 2: + if c_device_array.device_type == ARROW_DEVICE_CUDA: _ensure_cuda_loaded() c_type = pyarrow_unwrap_data_type(type) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e09d3543a0daa..3e2d0ed209555 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2916,6 +2916,9 @@ cdef extern from "arrow/c/abi.h": cdef struct ArrowArrayStream: void (*release)(ArrowArrayStream*) noexcept nogil + ctypedef int32_t ArrowDeviceType + cdef ArrowDeviceType ARROW_DEVICE_CUDA + cdef struct ArrowDeviceArray: ArrowArray array int64_t device_id diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx index 39b13acb466d7..4937ebe3c29b9 100644 --- a/python/pyarrow/lib.pyx +++ b/python/pyarrow/lib.pyx @@ -125,7 +125,7 @@ UnionMode_DENSE = _UnionMode_DENSE __pc = None __pac = None -__cuda_loaded = False +__cuda_loaded = None def _pc(): @@ -148,12 +148,18 @@ def _ensure_cuda_loaded(): # Try importing the cuda module to ensure libarrow_cuda gets loaded # to register the CUDA device for the C Data Interface import global __cuda_loaded - if not __cuda_loaded: + if __cuda_loaded is None: try: import pyarrow.cuda # no-cython-lint - except 
ImportError: - pass - __cuda_loaded = True + __cuda_loaded = True + except ImportError as exc: + __cuda_loaded = str(exc) + + if __cuda_loaded is not True: + raise ImportError( + "Trying to import data on a CUDA device, but PyArrow is not built with " + f"CUDA support.\n(importing 'pyarrow.cuda' resulted in \"{__cuda_loaded}\")." + ) def _gdb_test_session(): diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 2112ddefa4bc5..42be5fad1a61b 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3656,7 +3656,7 @@ cdef class RecordBatch(_Tabular): void* c_schema_ptr shared_ptr[CRecordBatch] c_batch - if c_device_array.device_type == 2: + if c_device_array.device_type == ARROW_DEVICE_CUDA: _ensure_cuda_loaded() c_schema = pyarrow_unwrap_schema(schema) diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py index f8b2ea15d31ad..745e430a1217b 100644 --- a/python/pyarrow/tests/test_cffi.py +++ b/python/pyarrow/tests/test_cffi.py @@ -697,3 +697,24 @@ def test_roundtrip_chunked_array_capsule_requested_schema(): requested_capsule = requested_type.__arrow_c_schema__() with pytest.raises(NotImplementedError): chunked.__arrow_c_stream__(requested_capsule) + + +def test_import_device_no_cuda(): + try: + import pyarrow.cuda # noqa + except ImportError: + pass + else: + pytest.skip("pyarrow.cuda is available") + + c_array = ffi.new("struct ArrowDeviceArray*") + ptr_array = int(ffi.cast("uintptr_t", c_array)) + arr = pa.array([1, 2, 3], type=pa.int64()) + arr._export_to_c_device(ptr_array) + + # patch the device type of the struct, this results in an invalid ArrowDeviceArray + # but this is just to test we raise an error before actually importing buffers + c_array.device_type = 2 # ARROW_DEVICE_CUDA + + with pytest.raises(ImportError, match="Trying to import data on a CUDA device"): + pa.Array._import_from_c_device(ptr_array, arr.type)