diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 54a28db1b58..bac5eae6b34 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -236,13 +236,11 @@ Spilling consists of two components: - A spill manager that tracks all instances of `SpillableBuffer` and spills them on demand. A global spill manager is used throughout cudf when spilling is enabled, which makes `as_buffer()` return `SpillableBuffer` instead of the default `Buffer` instances. -Accessing `Buffer.ptr`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.ptr`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. +Accessing `Buffer.get_ptr(...)`, we get the device memory pointer of the buffer. This is unproblematic in the case of `Buffer` but what happens when accessing `SpillableBuffer.get_ptr(...)`, which might have spilled its device memory. In this case, `SpillableBuffer` needs to unspill the memory before returning its device memory pointer. Furthermore, while this device memory pointer is being used (or could be used), `SpillableBuffer` cannot spill its memory back to host memory because doing so would invalidate the device pointer. To address this, we mark the `SpillableBuffer` as unspillable, we say that the buffer has been _exposed_. This can either be permanent if the device pointer is exposed to external projects or temporary while `libcudf` accesses the device memory. -The `SpillableBuffer.get_ptr()` returns the device pointer of the buffer memory just like `.ptr` but if given an instance of `SpillLock`, the buffer is only unspillable as long as the instance of `SpillLock` is alive. - -For convenience, one can use the decorator/context `acquire_spill_lock` to associate a `SpillLock` with a lifetime bound to the context automatically. +The `SpillableBuffer.get_ptr(...)` returns the device pointer of the buffer memory but if called within an `acquire_spill_lock` decorator/context, the buffer is only marked unspillable while running within the decorator/context. #### Statistics cuDF supports spilling statistics, which can be very useful for performance profiling and to identify code that renders buffers unspillable. diff --git a/python/cudf/cudf/_lib/column.pyi b/python/cudf/cudf/_lib/column.pyi index 612f3cdf95a..013cba3ae03 100644 --- a/python/cudf/cudf/_lib/column.pyi +++ b/python/cudf/cudf/_lib/column.pyi @@ -52,8 +52,6 @@ class Column: @property def base_mask(self) -> Optional[Buffer]: ... @property - def base_mask_ptr(self) -> int: ... - @property def mask(self) -> Optional[Buffer]: ... @property def mask_ptr(self) -> int: ... diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index a5d72193049..11b4a900896 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -101,7 +101,7 @@ cdef class Column: if self.data is None: return 0 else: - return self.data.ptr + return self.data.get_ptr(mode="write") def set_base_data(self, value): if value is not None and not isinstance(value, Buffer): @@ -124,13 +124,6 @@ cdef class Column: def base_mask(self): return self._base_mask - @property - def base_mask_ptr(self): - if self.base_mask is None: - return 0 - else: - return self.base_mask.ptr - @property def mask(self): if self._mask is None: @@ -145,7 +138,7 @@ cdef class Column: if self.mask is None: return 0 else: - return self.mask.ptr + return self.mask.get_ptr(mode="write") def set_base_mask(self, value): """ @@ -206,7 +199,7 @@ cdef class Column: elif hasattr(value, "__cuda_array_interface__"): if value.__cuda_array_interface__["typestr"] not in ("|i1", "|u1"): if isinstance(value, Column): - value = value.data_array_view + value = value.data_array_view(mode="write") value = cp.asarray(value).view('|u1') mask = as_buffer(value) if mask.size < required_num_bytes: @@ -329,10 +322,10 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr( + mode="write") + ) cdef Column child_column if col.base_children: @@ -341,7 +334,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="write") + ) else: mask = NULL @@ -387,10 +382,8 @@ cdef class Column: if col.base_data is None: data = NULL - elif isinstance(col.base_data, SpillableBuffer): - data = (col.base_data).get_ptr() else: - data = (col.base_data.ptr) + data = (col.base_data.get_ptr(mode="read")) cdef Column child_column if col.base_children: @@ -399,7 +392,9 @@ cdef class Column: cdef libcudf_types.bitmask_type* mask if self.nullable: - mask = (self.base_mask_ptr) + mask = ( + self.base_mask.get_ptr(mode="read") + ) else: mask = NULL @@ -549,7 +544,8 @@ cdef class Column: f"{data_owner} is spilled, which invalidates " f"the exposed data_ptr ({hex(data_ptr)})" ) - data_owner.ptr # accessing the pointer marks it exposed. + # accessing the pointer marks it exposed permanently. + data_owner.mark_exposed() else: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, size=0) diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index c01709322ed..6a53586396f 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. import pickle @@ -765,7 +765,7 @@ cdef class _CPackedColumns: gpu_data = Buffer.deserialize(header["data"], frames) dbuf = DeviceBuffer( - ptr=gpu_data.ptr, + ptr=gpu_data.get_ptr(mode="write"), size=gpu_data.nbytes ) diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index 6f17dbab86c..a0a8279b213 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -62,7 +62,9 @@ def mask_to_bools(object mask_buffer, size_type begin_bit, size_type end_bit): if not isinstance(mask_buffer, cudf.core.buffer.Buffer): raise TypeError("mask_buffer is not an instance of " "cudf.core.buffer.Buffer") - cdef bitmask_type* bit_mask = (mask_buffer.ptr) + cdef bitmask_type* bit_mask = ( + mask_buffer.get_ptr(mode="read") + ) cdef unique_ptr[column] result with nogil: diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index ebc4d76b6a0..71f48e0ab0c 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -176,7 +176,9 @@ def _getitem(self, offset: int, size: int) -> Buffer: """ return self._from_device_memory( cuda_array_interface_wrapper( - ptr=self.ptr + offset, size=size, owner=self.owner + ptr=self.get_ptr(mode="read") + offset, + size=size, + owner=self.owner, ) ) @@ -202,11 +204,6 @@ def nbytes(self) -> int: """Size of the buffer in bytes.""" return self._size - @property - def ptr(self) -> int: - """Device pointer to the start of the buffer.""" - return self._ptr - @property def owner(self) -> Any: """Object owning the memory of the buffer.""" @@ -215,18 +212,74 @@ def owner(self) -> Any: @property def __cuda_array_interface__(self) -> Mapping: """Implementation of the CUDA Array Interface.""" + return self._get_cuda_array_interface(readonly=False) + + def _get_cuda_array_interface(self, readonly=False): + """Helper function to create a CUDA Array Interface. + + Parameters + ---------- + readonly : bool, default False + If True, returns a CUDA Array Interface with + readonly flag set to True. + If False, returns a CUDA Array Interface with + readonly flag set to False. + + Returns + ------- + dict + """ return { - "data": (self.ptr, False), + "data": ( + self.get_ptr(mode="read" if readonly else "write"), + readonly, + ), "shape": (self.size,), "strides": None, "typestr": "|u1", "version": 0, } + @property + def _readonly_proxy_cai_obj(self): + """ + Returns a proxy object with a read-only CUDA Array Interface. + """ + return cuda_array_interface_wrapper( + ptr=self.get_ptr(mode="read"), + size=self.size, + owner=self, + readonly=True, + typestr="|u1", + version=0, + ) + + def get_ptr(self, *, mode) -> int: + """Device pointer to the start of the buffer. + + Parameters + ---------- + mode : str + Supported values are {"read", "write"} + If "write", the data pointed to may be modified + by the caller. If "read", the data pointed to + must not be modified by the caller. + Failure to fulfill this contract will cause + incorrect behavior. + + + See Also + -------- + SpillableBuffer.get_ptr + """ + return self._ptr + def memoryview(self) -> memoryview: """Read-only access to the buffer through host memory.""" host_buf = host_memory_allocation(self.size) - rmm._lib.device_buffer.copy_ptr_to_host(self.ptr, host_buf) + rmm._lib.device_buffer.copy_ptr_to_host( + self.get_ptr(mode="read"), host_buf + ) return memoryview(host_buf).toreadonly() def serialize(self) -> Tuple[dict, list]: diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py index 7ca85a307bf..2064c1fd133 100644 --- a/python/cudf/cudf/core/buffer/spillable_buffer.py +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -145,7 +145,7 @@ def __len__(self): def __getitem__(self, i): if i == 0: - return self._buf.ptr + return self._buf.get_ptr(mode="write") elif i == 1: return False raise IndexError("tuple index out of range") @@ -359,7 +359,7 @@ def spill_lock(self, spill_lock: SpillLock) -> None: self.spill(target="gpu") self._spill_locks.add(spill_lock) - def get_ptr(self) -> int: + def get_ptr(self, *, mode) -> int: """Get a device pointer to the memory of the buffer. If this is called within an `acquire_spill_lock` context, @@ -369,8 +369,8 @@ def get_ptr(self) -> int: If this is *not* called within a `acquire_spill_lock` context, this buffer is marked as unspillable permanently. - Return - ------ + Returns + ------- int The device pointer as an integer """ @@ -409,18 +409,6 @@ def memory_info(self) -> Tuple[int, int, str]: ).__array_interface__["data"][0] return (ptr, self.nbytes, self._ptr_desc["type"]) - @property - def ptr(self) -> int: - """Access the memory directly - - Notice, this will mark the buffer as "exposed" and make - it unspillable permanently. - - Consider using `.get_ptr()` instead. - """ - self.mark_exposed() - return self._ptr - @property def owner(self) -> Any: return self._owner @@ -559,12 +547,12 @@ def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: self._owner = base self.lock = base.lock - @property - def ptr(self) -> int: - return self._base.ptr + self._offset - - def get_ptr(self) -> int: - return self._base.get_ptr() + self._offset + def get_ptr(self, *, mode) -> int: + """ + A passthrough method to `SpillableBuffer.get_ptr` + with factoring in the `offset`. + """ + return self._base.get_ptr(mode=mode) + self._offset def _getitem(self, offset: int, size: int) -> Buffer: return SpillableBufferSlice( diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index ef9f515fff7..af21d7545ee 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -956,9 +956,10 @@ def clip(self, lo: ScalarLike, hi: ScalarLike) -> "column.ColumnBase": self.astype(self.categories.dtype).clip(lo, hi).astype(self.dtype) ) - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: - return self.codes.data_array_view + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: + return self.codes.data_array_view(mode=mode) def unique(self, preserve_order=False) -> CategoricalColumn: codes = self.as_numerical.unique(preserve_order=preserve_order) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 69319e2f775..2f4d9e28314 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -64,7 +64,7 @@ ) from cudf.core._compat import PANDAS_GE_150 from cudf.core.abc import Serializable -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import Buffer, acquire_spill_lock, as_buffer from cudf.core.dtypes import ( CategoricalDtype, IntervalDtype, @@ -113,19 +113,71 @@ def as_frame(self) -> "cudf.core.frame.Frame": {None: self.copy(deep=False)} ) - @property - def data_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def data_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the data as a device array object + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". + + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.data).view(self.dtype) + if self.data is not None: + if mode == "read": + obj = self.data._readonly_proxy_cai_obj + elif mode == "write": + obj = self.data + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(self.dtype) - @property - def mask_array_view(self) -> "cuda.devicearray.DeviceNDArray": + def mask_array_view( + self, *, mode="write" + ) -> "cuda.devicearray.DeviceNDArray": """ View the mask as a device array + + Parameters + ---------- + mode : str, default 'write' + Supported values are {'read', 'write'} + If 'write' is passed, a device array object + with readonly flag set to False in CAI is returned. + If 'read' is passed, a device array object + with readonly flag set to True in CAI is returned. + This also means, If the caller wishes to modify + the data returned through this view, they must + pass mode="write", else pass mode="read". + + Returns + ------- + numba.cuda.cudadrv.devicearray.DeviceNDArray """ - return cuda.as_cuda_array(self.mask).view(mask_dtype) + if self.mask is not None: + if mode == "read": + obj = self.mask._readonly_proxy_cai_obj + elif mode == "write": + obj = self.mask + else: + raise ValueError(f"Unsupported mode: {mode}") + else: + obj = None + return cuda.as_cuda_array(obj).view(mask_dtype) def __len__(self) -> int: return self.size @@ -163,7 +215,8 @@ def values_host(self) -> "np.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return self.data_array_view.copy_to_host() + with acquire_spill_lock(): + return self.data_array_view(mode="read").copy_to_host() @property def values(self) -> "cupy.ndarray": @@ -176,7 +229,7 @@ def values(self) -> "cupy.ndarray": if self.has_nulls(): raise ValueError("Column must have no nulls.") - return cupy.asarray(self.data_array_view) + return cupy.asarray(self.data_array_view(mode="write")) def find_and_replace( self: T, @@ -363,7 +416,7 @@ def nullmask(self) -> Buffer: """The gpu buffer for the null-mask""" if not self.nullable: raise ValueError("Column has no null mask") - return self.mask_array_view + return self.mask_array_view(mode="read") def copy(self: T, deep: bool = True) -> T: """Columns are immutable, so a deep copy produces a copy of the diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 7943135afe1..8ee3b6e15b6 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -35,7 +35,11 @@ is_number, is_scalar, ) -from cudf.core.buffer import Buffer, cuda_array_interface_wrapper +from cudf.core.buffer import ( + Buffer, + acquire_spill_lock, + cuda_array_interface_wrapper, +) from cudf.core.column import ( ColumnBase, as_column, @@ -110,8 +114,8 @@ def __contains__(self, item: ScalarLike) -> bool: # Handles improper item types # Fails if item is of type None, so the handler. try: - if np.can_cast(item, self.data_array_view.dtype): - item = self.data_array_view.dtype.type(item) + if np.can_cast(item, self.dtype): + item = self.dtype.type(item) else: return False except (TypeError, ValueError): @@ -564,6 +568,7 @@ def fillna( return super(NumericalColumn, col).fillna(fill_value, method) + @acquire_spill_lock() def _find_value( self, value: ScalarLike, closest: bool, find: Callable, compare: str ) -> int: @@ -573,14 +578,14 @@ def _find_value( found = 0 if len(self): found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, ) if found == -1: if self.is_monotonic_increasing and closest: found = find( - self.data_array_view, + self.data_array_view(mode="read"), value, mask=self.mask, compare=compare, diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 4ca3a9ff04d..9c30585a541 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -5395,8 +5395,9 @@ def base_size(self) -> int: else: return self.base_children[0].size - 1 - @property - def data_array_view(self) -> cuda.devicearray.DeviceNDArray: + def data_array_view( + self, *, mode="write" + ) -> cuda.devicearray.DeviceNDArray: raise ValueError("Cannot get an array view of a StringColumn") def to_arrow(self) -> pa.Array: diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index b7d1724a342..e7979fa4d27 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -13,7 +13,7 @@ from cudf import _lib as libcudf from cudf._typing import ColumnBinaryOperand, DatetimeLikeScalar, Dtype from cudf.api.types import is_scalar, is_timedelta64_dtype -from cudf.core.buffer import Buffer +from cudf.core.buffer import Buffer, acquire_spill_lock from cudf.core.column import ColumnBase, column, string from cudf.utils.dtypes import np_to_pa_dtype from cudf.utils.utils import _fillna_natwise @@ -125,11 +125,16 @@ def values(self): "TimeDelta Arrays is not yet implemented in cudf" ) + @acquire_spill_lock() def to_arrow(self) -> pa.Array: mask = None if self.nullable: - mask = pa.py_buffer(self.mask_array_view.copy_to_host()) - data = pa.py_buffer(self.as_numerical.data_array_view.copy_to_host()) + mask = pa.py_buffer( + self.mask_array_view(mode="read").copy_to_host() + ) + data = pa.py_buffer( + self.as_numerical.data_array_view(mode="read").copy_to_host() + ) pa_dtype = np_to_pa_dtype(self.dtype) return pa.Array.from_buffers( type=pa_dtype, diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index b38d3048ed7..2090906380e 100644 --- a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. import enum from collections import abc @@ -89,7 +89,7 @@ def ptr(self) -> int: """ Pointer to start of the buffer as an integer. """ - return self._buf.ptr + return self._buf.get_ptr(mode="write") def __dlpack__(self): # DLPack not implemented in NumPy yet, so leave it out here. diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 32764c6c2f0..8b508eac324 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. from __future__ import annotations @@ -1448,7 +1448,7 @@ def searchsorted( # Return result as cupy array if the values is non-scalar # If values is scalar, result is expected to be scalar. - result = cupy.asarray(outcol.data_array_view) + result = cupy.asarray(outcol.data_array_view(mode="read")) if scalar_flag: return result[0].item() else: diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index 6526ba1e7c3..c8016786be9 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -49,6 +49,7 @@ is_scalar, ) from cudf.core._base_index import BaseIndex +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ColumnBase, as_column, full from cudf.core.column_accessor import ColumnAccessor from cudf.core.dtypes import ListDtype @@ -2105,6 +2106,7 @@ def add_suffix(self, suffix): Use `Series.add_suffix` or `DataFrame.add_suffix`" ) + @acquire_spill_lock() @_cudf_nvtx_annotate def _apply(self, func, kernel_getter, *args, **kwargs): """Apply `func` across the rows of the frame.""" diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index faad5275abd..1c697a2d824 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -39,6 +39,7 @@ is_struct_dtype, ) from cudf.core.abc import Serializable +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import ( ColumnBase, DatetimeColumn, @@ -4855,6 +4856,7 @@ def _align_indices(series_list, how="outer", allow_non_unique=False): return result +@acquire_spill_lock() @_cudf_nvtx_annotate def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): r"""Returns a boolean array where two arrays are equal within a tolerance. @@ -4959,10 +4961,10 @@ def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): index = as_index(a.index) a_col = column.as_column(a) - a_array = cupy.asarray(a_col.data_array_view) + a_array = cupy.asarray(a_col.data_array_view(mode="read")) b_col = column.as_column(b) - b_array = cupy.asarray(b_col.data_array_view) + b_array = cupy.asarray(b_col.data_array_view(mode="read")) result = cupy.isclose( a=a_array, b=b_array, rtol=rtol, atol=atol, equal_nan=equal_nan diff --git a/python/cudf/cudf/core/window/rolling.py b/python/cudf/cudf/core/window/rolling.py index fb1cafa5625..cac4774400a 100644 --- a/python/cudf/cudf/core/window/rolling.py +++ b/python/cudf/cudf/core/window/rolling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION +# Copyright (c) 2020-2023, NVIDIA CORPORATION import itertools @@ -11,6 +11,7 @@ from cudf.api.types import is_integer, is_number from cudf.core import column from cudf.core._compat import PANDAS_GE_150 +from cudf.core.buffer import acquire_spill_lock from cudf.core.column.column import as_column from cudf.core.mixins import Reducible from cudf.utils import cudautils @@ -487,9 +488,11 @@ def _window_to_window_sizes(self, window): if is_integer(window): return window else: - return cudautils.window_sizes_from_offset( - self.obj.index._values.data_array_view, window - ) + with acquire_spill_lock(): + return cudautils.window_sizes_from_offset( + self.obj.index._values.data_array_view(mode="write"), + window, + ) def __repr__(self): return "{} [window={},min_periods={},center={}]".format( @@ -524,16 +527,17 @@ def __init__(self, groupby, window, min_periods=None, center=False): super().__init__(obj, window, min_periods=min_periods, center=center) + @acquire_spill_lock() def _window_to_window_sizes(self, window): if is_integer(window): return cudautils.grouped_window_sizes_from_offset( - column.arange(len(self.obj)).data_array_view, + column.arange(len(self.obj)).data_array_view(mode="read"), self._group_starts, window, ) else: return cudautils.grouped_window_sizes_from_offset( - self.obj.index._values.data_array_view, + self.obj.index._values.data_array_view(mode="read"), self._group_starts, window, ) diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index cbaf47a4c68..fb4daba1209 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -336,7 +336,7 @@ def assert_column_memory_eq( """ def get_ptr(x) -> int: - return x.ptr if x else 0 + return x.get_ptr(mode="read") if x else 0 assert get_ptr(lhs.base_data) == get_ptr(rhs.base_data) assert get_ptr(lhs.base_mask) == get_ptr(rhs.base_mask) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index df7152d53a6..1c9e7475080 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. import cupy as cp import pytest @@ -52,7 +52,7 @@ def test_buffer_creation_from_any(): ary = cp.arange(arr_len) b = as_buffer(ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size with pytest.raises( @@ -62,7 +62,7 @@ def test_buffer_creation_from_any(): b = as_buffer(ary.data.ptr, size=ary.nbytes, owner=ary, exposed=True) assert isinstance(b, Buffer) - assert ary.data.ptr == b.ptr + assert ary.data.ptr == b.get_ptr(mode="read") assert ary.nbytes == b.size assert b.owner.owner is ary diff --git a/python/cudf/cudf/tests/test_column.py b/python/cudf/cudf/tests/test_column.py index 75b82baf2e8..7d113bbb9e2 100644 --- a/python/cudf/cudf/tests/test_column.py +++ b/python/cudf/cudf/tests/test_column.py @@ -285,8 +285,8 @@ def test_column_view_valid_numeric_to_numeric(data, from_dtype, to_dtype): expect = pd.Series(cpu_data_view, dtype=cpu_data_view.dtype) got = cudf.Series(gpu_data_view, dtype=gpu_data_view.dtype) - gpu_ptr = gpu_data.data.ptr - assert gpu_ptr == got._column.data.ptr + gpu_ptr = gpu_data.data.get_ptr(mode="read") + assert gpu_ptr == got._column.data.get_ptr(mode="read") assert_eq(expect, got) diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 0e0b0a37255..65e24c7c704 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1868,7 +1868,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s1.buffers()[0]).view("u1")[0], - gs1._column.mask_array_view.copy_to_host().view("u1")[0], + gs1._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s1, gs1.to_arrow()) @@ -1879,7 +1879,7 @@ def test_to_from_arrow_nulls(data_type): # number of bytes, so only check the first byte in this case np.testing.assert_array_equal( np.asarray(s2.buffers()[0]).view("u1")[0], - gs2._column.mask_array_view.copy_to_host().view("u1")[0], + gs2._column.mask_array_view(mode="read").copy_to_host().view("u1")[0], ) assert pa.Array.equals(s2, gs2.to_arrow()) @@ -2659,11 +2659,11 @@ def query_GPU_memory(note=""): cudaDF = cudaDF[boolmask] assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col0"].index._values.data_array_view.device_ctypes_pointer ) assert ( - cudaDF.index._values.data_array_view.device_ctypes_pointer + cudaDF.index._values.data_array_view(mode="read").device_ctypes_pointer == cudaDF["col1"].index._values.data_array_view.device_ctypes_pointer ) diff --git a/python/cudf/cudf/tests/test_dataframe_copy.py b/python/cudf/cudf/tests/test_dataframe_copy.py index 1a9098c70db..85e994bd733 100644 --- a/python/cudf/cudf/tests/test_dataframe_copy.py +++ b/python/cudf/cudf/tests/test_dataframe_copy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. from copy import copy, deepcopy import numpy as np @@ -160,7 +160,7 @@ def test_kernel_deep_copy(): cdf = gdf.copy(deep=True) sr = gdf["b"] - add_one[1, len(sr)](sr._column.data_array_view) + add_one[1, len(sr)](sr._column.data_array_view(mode="write")) assert not gdf.to_string().split() == cdf.to_string().split() diff --git a/python/cudf/cudf/tests/test_df_protocol.py b/python/cudf/cudf/tests/test_df_protocol.py index 0981e850c10..7dbca90ab03 100644 --- a/python/cudf/cudf/tests/test_df_protocol.py +++ b/python/cudf/cudf/tests/test_df_protocol.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. from typing import Any, Tuple @@ -41,7 +41,7 @@ def assert_buffer_equal(buffer_and_dtype: Tuple[_CuDFBuffer, Any], cudfcol): if dtype[0] != _DtypeKind.BOOL: array_from_dlpack = cp.from_dlpack(buf.__dlpack__()).get() - col_array = cp.asarray(cudfcol.data_array_view).get() + col_array = cp.asarray(cudfcol.data_array_view(mode="read")).get() assert_eq( array_from_dlpack[non_null_idxs.to_numpy()].flatten(), col_array[non_null_idxs.to_numpy()].flatten(), diff --git a/python/cudf/cudf/tests/test_multiindex.py b/python/cudf/cudf/tests/test_multiindex.py index d27d6732226..3e1f001e7d1 100644 --- a/python/cudf/cudf/tests/test_multiindex.py +++ b/python/cudf/cudf/tests/test_multiindex.py @@ -804,8 +804,8 @@ def test_multiindex_copy_deep(data, deep): lchildren = reduce(operator.add, lchildren) rchildren = reduce(operator.add, rchildren) - lptrs = [child.base_data.ptr for child in lchildren] - rptrs = [child.base_data.ptr for child in rchildren] + lptrs = [child.base_data.get_ptr(mode="read") for child in lchildren] + rptrs = [child.base_data.get_ptr(mode="read") for child in rchildren] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) @@ -814,20 +814,36 @@ def test_multiindex_copy_deep(data, deep): mi2 = mi1.copy(deep=deep) # Assert ._levels identity - lptrs = [lv._data._data[None].base_data.ptr for lv in mi1._levels] - rptrs = [lv._data._data[None].base_data.ptr for lv in mi2._levels] + lptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi1._levels + ] + rptrs = [ + lv._data._data[None].base_data.get_ptr(mode="read") + for lv in mi2._levels + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._codes identity - lptrs = [c.base_data.ptr for _, c in mi1._codes._data.items()] - rptrs = [c.base_data.ptr for _, c in mi2._codes._data.items()] + lptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi1._codes._data.items() + ] + rptrs = [ + c.base_data.get_ptr(mode="read") + for _, c in mi2._codes._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) # Assert ._data identity - lptrs = [d.base_data.ptr for _, d in mi1._data.items()] - rptrs = [d.base_data.ptr for _, d in mi2._data.items()] + lptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi1._data.items() + ] + rptrs = [ + d.base_data.get_ptr(mode="read") for _, d in mi2._data.items() + ] assert all((x == y) == same_ref for x, y in zip(lptrs, rptrs)) diff --git a/python/cudf/cudf/tests/test_pack.py b/python/cudf/cudf/tests/test_pack.py index b6bda7ef5fa..9972071122e 100644 --- a/python/cudf/cudf/tests/test_pack.py +++ b/python/cudf/cudf/tests/test_pack.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -129,7 +129,9 @@ def assert_packed_frame_unique_pointers(df): for col in df: if df._data[col].data: - assert df._data[col].data.ptr != unpacked._data[col].data.ptr + assert df._data[col].data.get_ptr(mode="read") != unpacked._data[ + col + ].data.get_ptr(mode="read") def test_packed_dataframe_unique_pointers_numeric(): diff --git a/python/cudf/cudf/tests/test_repr.py b/python/cudf/cudf/tests/test_repr.py index 5ba0bec3dc4..bae0fde6463 100644 --- a/python/cudf/cudf/tests/test_repr.py +++ b/python/cudf/cudf/tests/test_repr.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. import textwrap @@ -31,7 +31,7 @@ def test_null_series(nrows, dtype): sr[np.random.choice([False, True], size=size)] = None if dtype != "category" and cudf.dtype(dtype).kind in {"u", "i"}: ps = pd.Series( - sr._column.data_array_view.copy_to_host(), + sr._column.data_array_view(mode="read").copy_to_host(), dtype=np_dtypes_to_pandas_dtypes.get( cudf.dtype(dtype), cudf.dtype(dtype) ), diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py index bafe51b62ec..88ce908aa5f 100644 --- a/python/cudf/cudf/tests/test_spilling.py +++ b/python/cudf/cudf/tests/test_spilling.py @@ -119,7 +119,7 @@ def test_spillable_buffer(manager: SpillManager): buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) assert isinstance(buf, SpillableBuffer) assert buf.spillable - buf.ptr # Expose pointer + buf.mark_exposed() assert buf.exposed assert not buf.spillable buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) @@ -137,7 +137,6 @@ def test_spillable_buffer(manager: SpillManager): @pytest.mark.parametrize( "attribute", [ - "ptr", "get_ptr", "memoryview", "is_spilled", @@ -210,7 +209,7 @@ def test_spilling_buffer(manager: SpillManager): buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) buf.spill(target="cpu") assert buf.is_spilled - buf.ptr # Expose pointer and trigger unspill + buf.mark_exposed() # Expose pointer and trigger unspill assert not buf.is_spilled with pytest.raises(ValueError, match="unspillable buffer"): buf.spill(target="cpu") @@ -378,10 +377,10 @@ def test_get_ptr(manager: SpillManager, target): assert buf.spillable assert len(buf._spill_locks) == 0 with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable with acquire_spill_lock(): - buf.get_ptr() + buf.get_ptr(mode="read") assert not buf.spillable assert not buf.spillable assert buf.spillable @@ -501,7 +500,7 @@ def test_serialize_cuda_dataframe(manager: SpillManager): assert len(buf._base._spill_locks) == 1 assert len(frames) == 1 assert isinstance(frames[0], Buffer) - assert frames[0].ptr == buf.ptr + assert frames[0].get_ptr(mode="read") == buf.get_ptr(mode="read") frames[0] = cupy.array(frames[0], copy=True) df2 = protocol.deserialize(header, frames) @@ -557,18 +556,20 @@ def test_as_buffer_of_spillable_buffer(manager: SpillManager): b3 = as_buffer(b1.memory_info()[0], size=b1.size, owner=b1) with acquire_spill_lock(): - b3 = as_buffer(b1.get_ptr(), size=b1.size, owner=b1) + b3 = as_buffer(b1.get_ptr(mode="read"), size=b1.size, owner=b1) assert isinstance(b3, SpillableBufferSlice) assert b3.owner is b1 b4 = as_buffer( - b1.ptr + data.itemsize, size=b1.size - data.itemsize, owner=b3 + b1.get_ptr(mode="write") + data.itemsize, + size=b1.size - data.itemsize, + owner=b3, ) assert isinstance(b4, SpillableBufferSlice) assert b4.owner is b1 assert all(cupy.array(b4.memoryview()) == data[1:]) - b5 = as_buffer(b4.ptr, size=b4.size - 1, owner=b4) + b5 = as_buffer(b4.get_ptr(mode="write"), size=b4.size - 1, owner=b4) assert isinstance(b5, SpillableBufferSlice) assert b5.owner is b1 assert all(cupy.array(b5.memoryview()) == data[1:-1]) @@ -623,7 +624,7 @@ def test_statistics_expose(manager: SpillManager): ] # Expose the first buffer - buffers[0].ptr + buffers[0].mark_exposed() assert len(manager.statistics.exposes) == 1 stat = list(manager.statistics.exposes.values())[0] assert stat.count == 1 @@ -632,7 +633,7 @@ def test_statistics_expose(manager: SpillManager): # Expose all 10 buffers for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() # The rest of the ptr accesses should accumulate to a single stat # because they resolve to the same traceback. @@ -652,7 +653,7 @@ def test_statistics_expose(manager: SpillManager): # Expose the new buffers and check that they are counted as spilled for i in range(10): - buffers[i].ptr + buffers[i].mark_exposed() assert len(manager.statistics.exposes) == 3 stat = list(manager.statistics.exposes.values())[2] assert stat.count == 10 diff --git a/python/cudf/cudf/utils/applyutils.py b/python/cudf/cudf/utils/applyutils.py index 89331b933a8..7e998413642 100644 --- a/python/cudf/cudf/utils/applyutils.py +++ b/python/cudf/cudf/utils/applyutils.py @@ -1,13 +1,15 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import functools from typing import Any, Dict +import cupy as cp from numba import cuda from numba.core.utils import pysignature import cudf from cudf import _lib as libcudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column from cudf.utils import utils from cudf.utils.docutils import docfmt_partial @@ -139,21 +141,25 @@ def __init__( self.cache_key = cache_key self.kernel = self.compile(func, sig.parameters.keys(), kwargs.keys()) + @acquire_spill_lock() def run(self, df, **launch_params): # Get input columns if isinstance(self.incols, dict): inputs = { - v: df[k]._column.data_array_view + v: df[k]._column.data_array_view(mode="read") for (k, v) in self.incols.items() } else: - inputs = {k: df[k]._column.data_array_view for k in self.incols} + inputs = { + k: df[k]._column.data_array_view(mode="read") + for k in self.incols + } # Allocate output columns outputs = {} for k, dt in self.outcols.items(): outputs[k] = column.column_empty( len(df), dt, False - ).data_array_view + ).data_array_view(mode="write") # Bind argument args = {} for dct in [inputs, outputs, self.kwargs]: @@ -174,7 +180,7 @@ def run(self, df, **launch_params): ) if out_mask is not None: outdf._data[k] = outdf[k]._column.set_mask( - out_mask.data_array_view + out_mask.data_array_view(mode="write") ) return outdf @@ -213,11 +219,12 @@ def launch_kernel(self, df, args, chunks, blkct=None, tpb=None): def normalize_chunks(self, size, chunks): if isinstance(chunks, int): # *chunks* is the chunksize - return column.arange(0, size, chunks).data_array_view + return cuda.as_cuda_array( + cp.arange(start=0, stop=size, step=chunks) + ).view("int64") else: # *chunks* is an array of chunk leading offset - chunks = column.as_column(chunks) - return chunks.data_array_view + return cuda.as_cuda_array(cp.asarray(chunks)).view("int64") def _make_row_wise_kernel(func, argnames, extras): diff --git a/python/cudf/cudf/utils/queryutils.py b/python/cudf/cudf/utils/queryutils.py index 25b3d517e1c..4ce89b526d6 100644 --- a/python/cudf/cudf/utils/queryutils.py +++ b/python/cudf/cudf/utils/queryutils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2022, NVIDIA CORPORATION. +# Copyright (c) 2018-2023, NVIDIA CORPORATION. import ast import datetime @@ -8,6 +8,7 @@ from numba import cuda import cudf +from cudf.core.buffer import acquire_spill_lock from cudf.core.column import column_empty from cudf.utils import applyutils from cudf.utils.dtypes import ( @@ -191,6 +192,7 @@ def _add_prefix(arg): return kernel +@acquire_spill_lock() def query_execute(df, expr, callenv): """Compile & execute the query expression @@ -220,7 +222,7 @@ def query_execute(df, expr, callenv): "or bool dtypes." ) - colarrays = [col.data_array_view for col in colarrays] + colarrays = [col.data_array_view(mode="read") for col in colarrays] kernel = compiled["kernel"] # process env args diff --git a/python/strings_udf/strings_udf/_typing.py b/python/strings_udf/strings_udf/_typing.py index 99e4046b0b3..80deb881ec8 100644 --- a/python/strings_udf/strings_udf/_typing.py +++ b/python/strings_udf/strings_udf/_typing.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import operator @@ -12,6 +12,8 @@ from numba.cuda.cudadecl import registry as cuda_decl_registry from numba.cuda.cudadrv import nvvm +import rmm + data_layout = nvvm.data_layout # libcudf size_type @@ -112,7 +114,9 @@ def prepare_args(self, ty, val, **kwargs): if isinstance(ty, types.CPointer) and isinstance( ty.dtype, (StringView, UDFString) ): - return types.uint64, val.ptr + return types.uint64, val.ptr if isinstance( + val, rmm._lib.device_buffer.DeviceBuffer + ) else val.get_ptr(mode="read") else: return ty, val