Skip to content

Commit

Permalink
test readonly with numpy; rename use_stream parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Dec 4, 2024
1 parent 6af4da3 commit 16fc9f6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
1 change: 1 addition & 0 deletions cuda_core/examples/strided_memory_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
# Below, as a user we want to perform the said in-place operation on either CPU
# or GPU, by calling the corresponding function implemented "elsewhere" (done above).


# We assume the 0-th argument supports either DLPack or CUDA Array Interface (both
# of which are supported by StridedMemoryView).
@args_viewable_as_strided_memory((0,))
Expand Down
24 changes: 17 additions & 7 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ def convert_strides_to_counts(strides, itemsize):
np.empty(3, dtype=np.int32),
np.empty((6, 6), dtype=np.float64)[::2, ::2],
np.empty((3, 4), order="F"),
np.empty((), dtype=np.float16),
# readonly is fixed recently (numpy/numpy#26501)
pytest.param(
np.frombuffer(b""),
marks=pytest.mark.skipif(
tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+"
),
),
),
)
class TestViewCPU:
Expand Down Expand Up @@ -57,22 +65,23 @@ def _check_view(self, view, in_arr):
assert view.device_id == -1
assert view.is_device_accessible is False
assert view.exporting_obj is in_arr
assert view.readonly is not in_arr.flags.writeable


def gpu_array_samples():
# TODO: this function would initialize the device at test collection time
samples = []
if cp is not None:
samples += [
(cp.empty(3, dtype=cp.complex64), None),
(cp.empty(3, dtype=cp.complex64), False),
(cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True),
(cp.empty((3, 4), order="F"), True),
]
# Numba's device_array is the only known array container that does not
# support DLPack (so that we get to test the CAI coverage).
if numba_cuda is not None:
samples += [
(numba_cuda.device_array((2,), dtype=np.int8), None),
(numba_cuda.device_array((2,), dtype=np.int8), False),
(numba_cuda.device_array((4, 2), dtype=np.float32), True),
]
return samples
Expand All @@ -86,14 +95,14 @@ def gpu_array_ptr(arr):
raise NotImplementedError(f"{arr=}")


@pytest.mark.parametrize("in_arr,stream", (*gpu_array_samples(),))
@pytest.mark.parametrize("in_arr,use_stream", (*gpu_array_samples(),))
class TestViewGPU:
def test_args_viewable_as_strided_memory_gpu(self, in_arr, stream):
def test_args_viewable_as_strided_memory_gpu(self, in_arr, use_stream):
# TODO: use the device fixture?
dev = Device()
dev.set_current()
# This is the consumer stream
s = dev.create_stream() if stream else None
s = dev.create_stream() if use_stream else None

@args_viewable_as_strided_memory((0,))
def my_func(arr):
Expand All @@ -102,12 +111,12 @@ def my_func(arr):

my_func(in_arr)

def test_strided_memory_view_cpu(self, in_arr, stream):
def test_strided_memory_view_cpu(self, in_arr, use_stream):
# TODO: use the device fixture?
dev = Device()
dev.set_current()
# This is the consumer stream
s = dev.create_stream() if stream else None
s = dev.create_stream() if use_stream else None

view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
self._check_view(view, in_arr, dev)
Expand All @@ -125,3 +134,4 @@ def _check_view(self, view, in_arr, dev):
assert view.device_id == dev.device_id
assert view.is_device_accessible is True
assert view.exporting_obj is in_arr
# can't test view.readonly with CuPy or Numba...

0 comments on commit 16fc9f6

Please sign in to comment.