Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Nov 30, 2024
1 parent 8295d56 commit f1239a2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

from cuda.core.experimental import utils
from cuda.core.experimental._device import Device
from cuda.core.experimental._event import EventOptions
from cuda.core.experimental._launcher import LaunchConfig, launch
from cuda.core.experimental._program import Program
from cuda.core.experimental._stream import Stream, StreamOptions
from cuda.core.experimental import utils
28 changes: 9 additions & 19 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ def convert_strides_to_counts(strides, itemsize):


@pytest.mark.parametrize(
"in_arr,", (
"in_arr,",
(
np.empty(3, dtype=np.int32),
np.empty((6, 6), dtype=np.float64)[::2, ::2],
np.empty((3, 4), order='F'),
)
np.empty((3, 4), order="F"),
),
)
class TestViewCPU:

def test_viewable_cpu(self, in_arr):

@viewable((0,))
def my_func(arr):
# stream_ptr=-1 means "the consumer does not care"
Expand All @@ -49,8 +48,7 @@ def _check_view(self, view, in_arr):
assert isinstance(view, StridedMemoryView)
assert view.ptr == in_arr.ctypes.data
assert view.shape == in_arr.shape
strides_in_counts = convert_strides_to_counts(
in_arr.strides, in_arr.dtype.itemsize)
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
if in_arr.flags.c_contiguous:
assert view.strides is None
else:
Expand All @@ -68,7 +66,7 @@ def gpu_array_samples():
samples += [
(cp.empty(3, dtype=cp.complex64), None),
(cp.empty((6, 6), dtype=cp.float64)[::2, ::2], True),
(cp.empty((3, 4), order='F'), True),
(cp.empty((3, 4), order="F"), True),
]
# Numba's device_array is the only known array container that does not
# support DLPack (so that we get to test the CAI coverage).
Expand All @@ -88,13 +86,8 @@ def gpu_array_ptr(arr):
assert False, f"{arr=}"


@pytest.mark.parametrize(
"in_arr,stream", (
*gpu_array_samples(),
)
)
@pytest.mark.parametrize("in_arr,stream", (*gpu_array_samples(),))
class TestViewGPU:

def test_viewable_gpu(self, in_arr, stream):
# TODO: use the device fixture?
dev = Device()
Expand All @@ -116,17 +109,14 @@ def test_strided_memory_view_cpu(self, in_arr, stream):
# This is the consumer stream
s = dev.create_stream() if stream else None

view = StridedMemoryView(
in_arr,
stream_ptr=s.handle if s else -1)
view = StridedMemoryView(in_arr, stream_ptr=s.handle if s else -1)
self._check_view(view, in_arr, dev)

def _check_view(self, view, in_arr, dev):
assert isinstance(view, StridedMemoryView)
assert view.ptr == gpu_array_ptr(in_arr)
assert view.shape == in_arr.shape
strides_in_counts = convert_strides_to_counts(
in_arr.strides, in_arr.dtype.itemsize)
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
if in_arr.flags["C_CONTIGUOUS"]:
assert view.strides in (None, strides_in_counts)
else:
Expand Down

0 comments on commit f1239a2

Please sign in to comment.