From f61649949fac17a734c0e4dda40c19da0c9834f0 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 7 Nov 2022 12:30:56 +0100 Subject: [PATCH] Porting spillabe buffer and manager from #11553 --- ci/gpu/build.sh | 4 + .../source/developer_guide/library_design.md | 9 +- python/cudf/cudf/_lib/binaryop.pyx | 3 + python/cudf/cudf/_lib/column.pxd | 6 +- python/cudf/cudf/_lib/column.pyx | 104 +++- python/cudf/cudf/_lib/copying.pyx | 21 +- python/cudf/cudf/_lib/groupby.pyx | 6 +- python/cudf/cudf/_lib/transform.pyx | 6 +- python/cudf/cudf/_lib/transpose.pyx | 6 +- python/cudf/cudf/_lib/unary.pyx | 7 + python/cudf/cudf/core/buffer/__init__.py | 3 +- python/cudf/cudf/core/buffer/buffer.py | 22 + python/cudf/cudf/core/buffer/spill_manager.py | 306 +++++++++++ .../cudf/cudf/core/buffer/spillable_buffer.py | 473 ++++++++++++++++++ python/cudf/cudf/core/buffer/utils.py | 65 ++- python/cudf/cudf/core/column/column.py | 4 +- python/cudf/cudf/core/column/decimal.py | 4 +- python/cudf/cudf/core/df_protocol.py | 22 +- python/cudf/cudf/core/groupby/groupby.py | 4 + python/cudf/cudf/options.py | 71 +++ python/cudf/cudf/tests/conftest.py | 14 + python/cudf/cudf/tests/test_buffer.py | 12 +- python/cudf/cudf/tests/test_groupby.py | 6 +- python/cudf/cudf/tests/test_spilling.py | 464 +++++++++++++++++ python/cudf/cudf/utils/utils.py | 2 +- .../strings_udf/_lib/cudf_jit_udf.pyx | 2 +- 26 files changed, 1602 insertions(+), 44 deletions(-) create mode 100644 python/cudf/cudf/core/buffer/spill_manager.py create mode 100644 python/cudf/cudf/core/buffer/spillable_buffer.py create mode 100644 python/cudf/cudf/tests/test_spilling.py diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 500c3bdbcc5..516d369f5d9 100755 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -282,6 +282,10 @@ conda list gpuci_logger "Python py.test for cuDF" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" --junitxml="$WORKSPACE/junit-cudf.xml" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests +gpuci_logger "Python py.tests for cuDF with spilling (CUDF_SPILL_DEVICE_LIMIT=1)" +# Due to time concerns, we only run a limited set of tests +CUDF_SPILL=on CUDF_SPILL_DEVICE_LIMIT=1 py.test -n 8 --cache-clear --basetemp="$WORKSPACE/cudf-cuda-tmp" --ignore="$WORKSPACE/python/cudf/cudf/benchmarks" -v --cov-config="$WORKSPACE/python/cudf/.coveragerc" --cov-append --cov=cudf --cov-report=xml:"$WORKSPACE/python/cudf/cudf-coverage.xml" --cov-report term --dist=loadscope tests/test_binops.py tests/test_dataframe.py tests/test_buffer.py tests/test_onehot.py tests/test_reshape.py + cd "$WORKSPACE/python/dask_cudf" gpuci_logger "Python py.test for dask-cudf" py.test -n 8 --cache-clear --basetemp="$WORKSPACE/dask-cudf-cuda-tmp" --junitxml="$WORKSPACE/junit-dask-cudf.xml" -v --cov-config=.coveragerc --cov=dask_cudf --cov-report=xml:"$WORKSPACE/python/dask_cudf/dask-cudf-coverage.xml" --cov-report term dask_cudf diff --git a/docs/cudf/source/developer_guide/library_design.md b/docs/cudf/source/developer_guide/library_design.md index 2f0fb5d86fc..be233edf200 100644 --- a/docs/cudf/source/developer_guide/library_design.md +++ b/docs/cudf/source/developer_guide/library_design.md @@ -203,7 +203,6 @@ For instance, all numerical types (floats and ints of different widths) are all ### Buffer - `Column`s are in turn composed of one or more `Buffer`s. A `Buffer` represents a single, contiguous, device memory allocation owned by another object. A `Buffer` constructed from a preexisting device memory allocation (such as a CuPy array) will view that memory. @@ -212,6 +211,14 @@ Conversely, when constructed from a host object, The data is then copied from the host object into the newly allocated device memory. You can read more about [device memory allocation with RMM here](https://github.com/rapidsai/rmm). + +### Spilling to host memory + +Setting the environment variable `CUDF_SPILL=on` enables automatic spilling (and "unspilling") of buffers from +device to host to enable out-of-memory computation, i.e., computing on objects that occupy more memory than is +available on the GPU. + + ## The Cython layer The lowest level of cuDF is its interaction with `libcudf` via Cython. diff --git a/python/cudf/cudf/_lib/binaryop.pyx b/python/cudf/cudf/_lib/binaryop.pyx index 995fdc7e315..9455565a74f 100644 --- a/python/cudf/cudf/_lib/binaryop.pyx +++ b/python/cudf/cudf/_lib/binaryop.pyx @@ -22,6 +22,7 @@ from cudf._lib.cpp.types cimport data_type, type_id from cudf._lib.types cimport dtype_to_data_type, underlying_type_t_type_id from cudf.api.types import is_scalar, is_string_dtype +from cudf.core.buffer import with_spill_lock cimport cudf._lib.cpp.binaryop as cpp_binaryop from cudf._lib.cpp.binaryop cimport binary_operator @@ -156,6 +157,7 @@ cdef binaryop_s_v(DeviceScalar lhs, Column rhs, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def binaryop(lhs, rhs, op, dtype): """ Dispatches a binary op call to the appropriate libcudf function: @@ -203,6 +205,7 @@ def binaryop(lhs, rhs, op, dtype): return result +@with_spill_lock() def binaryop_udf(Column lhs, Column rhs, udf_ptx, dtype): """ Apply a user-defined binary operator (a UDF) defined in `udf_ptx` on diff --git a/python/cudf/cudf/_lib/column.pxd b/python/cudf/cudf/_lib/column.pxd index 2df958466c6..f8f851bfe0f 100644 --- a/python/cudf/cudf/_lib/column.pxd +++ b/python/cudf/cudf/_lib/column.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -28,7 +28,9 @@ cdef class Column: cdef mutable_column_view mutable_view(self) except * @staticmethod - cdef Column from_unique_ptr(unique_ptr[column] c_col) + cdef Column from_unique_ptr( + unique_ptr[column] c_col, bint data_ptr_exposed=* + ) @staticmethod cdef Column from_column_view(column_view, object) diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 918d786fb83..9e5b62ab404 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -8,7 +8,14 @@ import rmm import cudf import cudf._lib as libcudf from cudf.api.types import is_categorical_dtype -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import ( + Buffer, + SpillableBuffer, + SpillLock, + as_buffer, + get_spill_lock, + with_spill_lock, +) from cpython.buffer cimport PyObject_CheckBuffer from libc.stdint cimport uintptr_t @@ -95,7 +102,11 @@ cdef class Column: if self._data is None: start = self.offset * self.dtype.itemsize end = start + self.size * self.dtype.itemsize - self._data = self.base_data[start:end] + if start == 0 and end == self.base_data.size: + # `data` spans all of `base_data` + self._data = self.base_data + else: + self._data = self.base_data[start:end] return self._data @property @@ -249,7 +260,8 @@ cdef class Column: @property def null_count(self): if self._null_count is None: - self._null_count = self.compute_null_count() + with with_spill_lock(): + self._null_count = self.compute_null_count() return self._null_count @property @@ -381,7 +393,14 @@ cdef class Column: cdef vector[column_view] children cdef void* data - data = (col.base_data_ptr) + if col.base_data is None: + data = NULL + elif isinstance(col.base_data, SpillableBuffer): + data = (col.base_data).get_ptr( + spill_lock=get_spill_lock() + ) + else: + data = (col.base_data.ptr) cdef Column child_column if col.base_children: @@ -406,7 +425,16 @@ cdef class Column: children) @staticmethod - cdef Column from_unique_ptr(unique_ptr[column] c_col): + cdef Column from_unique_ptr( + unique_ptr[column] c_col, bint data_ptr_exposed=False + ): + """Create a Column from a column + + Typically, this is called on the result of a libcudf operation. + If the data of the libcudf result has been exposed, set + `data_ptr_exposed=True` to expose the memory of the returned Column + as well. + """ cdef column_view view = c_col.get()[0].view() cdef libcudf_types.type_id tid = view.type().id() cdef libcudf_types.data_type c_dtype @@ -431,20 +459,30 @@ cdef class Column: # After call to release(), c_col is unusable cdef column_contents contents = move(c_col.get()[0].release()) - data = DeviceBuffer.c_from_unique_ptr(move(contents.data)) - data = as_buffer(data) + data = as_buffer( + DeviceBuffer.c_from_unique_ptr(move(contents.data)), + exposed=data_ptr_exposed + ) if null_count > 0: - mask = DeviceBuffer.c_from_unique_ptr(move(contents.null_mask)) - mask = as_buffer(mask) + mask = as_buffer( + DeviceBuffer.c_from_unique_ptr(move(contents.null_mask)), + exposed=data_ptr_exposed + ) else: mask = None cdef vector[unique_ptr[column]] c_children = move(contents.children) - children = () + children = [] if c_children.size() != 0: - children = tuple(Column.from_unique_ptr(move(c_children[i])) - for i in range(c_children.size())) + # Because of a bug in Cython, we cannot set the optional + # `data_ptr_exposed` argument within a comprehension. + for i in range(c_children.size()): + child = Column.from_unique_ptr( + move(c_children[i]), + data_ptr_exposed=data_ptr_exposed + ) + children.append(child) return cudf.core.column.build_column( data, @@ -452,7 +490,7 @@ cdef class Column: mask=mask, size=size, null_count=null_count, - children=children + children=tuple(children) ) @staticmethod @@ -474,6 +512,7 @@ cdef class Column: size = cv.size() offset = cv.offset() dtype = dtype_from_column_view(cv) + dtype_itemsize = dtype.itemsize if hasattr(dtype, "itemsize") else 1 data_ptr = (cv.head[void]()) data = None @@ -484,19 +523,45 @@ cdef class Column: data_owner = owner.base_data mask_owner = mask_owner.base_mask base_size = owner.base_size - + base_nbytes = base_size * dtype_itemsize if data_ptr: if data_owner is None: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, - size=(size+offset) * dtype.itemsize) + size=(size+offset) * dtype_itemsize) ) + elif ( + # This is an optimization to avoid creating a new + # SpillableBuffer that represent the same memory + # as the owner. + column_owner and + isinstance(data_owner, SpillableBuffer) and + # We have to make sure that `data_owner` is already spill + # locked and that its pointer is the same as `data_ptr` + # _without_ exposing the buffer permanently. + not data_owner.spillable and + data_owner.get_ptr(spill_lock=SpillLock()) == data_ptr and + data_owner.size == base_nbytes + ): + data = data_owner else: + # At this point we don't know the relationship between data_ptr + # and data_owner thus we mark both of them exposed. + # TODO: try to discover their relationship and create a + # SpillableBufferSlice instead. data = as_buffer( - data=data_ptr, - size=(base_size) * dtype.itemsize, - owner=data_owner + data_ptr, + size=base_nbytes, + owner=data_owner, + exposed=True, ) + if isinstance(data_owner, SpillableBuffer): + if data_owner.is_spilled: + raise ValueError( + f"{data_owner} is spilled, which invalidates " + f"the exposed data_ptr ({hex(data_ptr)})" + ) + data_owner.ptr # accessing the pointer marks it exposed. else: data = as_buffer( rmm.DeviceBuffer(ptr=data_ptr, size=0) @@ -538,7 +603,8 @@ cdef class Column: mask = as_buffer( data=mask_ptr, size=bitmask_allocation_size_bytes(base_size), - owner=mask_owner + owner=mask_owner, + exposed=True ) if cv.has_nulls(): diff --git a/python/cudf/cudf/_lib/copying.pyx b/python/cudf/cudf/_lib/copying.pyx index d9a7a5b8754..7cd811caa26 100644 --- a/python/cudf/cudf/_lib/copying.pyx +++ b/python/cudf/cudf/_lib/copying.pyx @@ -12,7 +12,7 @@ from libcpp.vector cimport vector from rmm._lib.device_buffer cimport DeviceBuffer import cudf -from cudf.core.buffer import Buffer, as_buffer +from cudf.core.buffer import Buffer, as_buffer, with_spill_lock from cudf._lib.column cimport Column @@ -64,6 +64,7 @@ def _gather_map_is_valid( return gm_min >= -nrows and gm_max < nrows +@with_spill_lock() def copy_column(Column input_column): """ Deep copies a column @@ -132,6 +133,7 @@ def _copy_range(Column input_column, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def copy_range(Column input_column, Column target_column, size_type input_begin, @@ -164,6 +166,7 @@ def copy_range(Column input_column, input_begin, input_end, target_begin) +@with_spill_lock() def gather( list columns, Column gather_map, @@ -231,6 +234,7 @@ cdef scatter_column(list source_columns, return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def scatter(list sources, Column scatter_map, list target_columns, bool bounds_check=True): """ @@ -271,6 +275,7 @@ def scatter(list sources, Column scatter_map, list target_columns, ) +@with_spill_lock() def column_empty_like(Column input_column): cdef column_view input_column_view = input_column.view() @@ -282,6 +287,7 @@ def column_empty_like(Column input_column): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def column_allocate_like(Column input_column, size=None): cdef size_type c_size = 0 @@ -306,6 +312,7 @@ def column_allocate_like(Column input_column, size=None): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def columns_empty_like(list input_columns): cdef table_view input_table_view = table_view_from_columns(input_columns) cdef unique_ptr[table] c_result @@ -316,6 +323,7 @@ def columns_empty_like(list input_columns): return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def column_slice(Column input_column, object indices): cdef column_view input_column_view = input_column.view() @@ -345,6 +353,7 @@ def column_slice(Column input_column, object indices): return result +@with_spill_lock() def columns_slice(list input_columns, list indices): """ Given a list of input columns, return columns sliced by ``indices``. @@ -371,6 +380,7 @@ def columns_slice(list input_columns, list indices): ] +@with_spill_lock() def column_split(Column input_column, object splits): cdef column_view input_column_view = input_column.view() @@ -402,6 +412,7 @@ def column_split(Column input_column, object splits): return result +@with_spill_lock() def columns_split(list input_columns, object splits): cdef table_view input_table_view = table_view_from_columns(input_columns) @@ -508,6 +519,7 @@ def _copy_if_else_scalar_scalar(DeviceScalar lhs, return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def copy_if_else(object lhs, object rhs, Column boolean_mask): if isinstance(lhs, Column): @@ -575,6 +587,7 @@ def _boolean_mask_scatter_scalar(list input_scalars, list target_columns, return columns_from_unique_ptr(move(c_result)) +@with_spill_lock() def boolean_mask_scatter(list input_, list target_columns, Column boolean_mask): """Copy the target columns, replacing masked rows with input data. @@ -607,6 +620,7 @@ def boolean_mask_scatter(list input_, list target_columns, ) +@with_spill_lock() def shift(Column input, int offset, object fill_value=None): cdef DeviceScalar fill @@ -643,6 +657,7 @@ def shift(Column input, int offset, object fill_value=None): return Column.from_unique_ptr(move(c_output)) +@with_spill_lock() def get_element(Column input_column, size_type index): cdef column_view col_view = input_column.view() @@ -657,6 +672,7 @@ def get_element(Column input_column, size_type index): ) +@with_spill_lock() def segmented_gather(Column source_column, Column gather_map): cdef shared_ptr[lists_column_view] source_LCV = ( make_shared[lists_column_view](source_column.view()) @@ -724,7 +740,8 @@ cdef class _CPackedColumns: gpu_data = as_buffer( data=self.gpu_data_ptr, size=self.gpu_data_size, - owner=self + owner=self, + exposed=True ) data_header, data_frames = gpu_data.serialize() header["data"] = data_header diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx index e6fbefaeee9..bea39c06387 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -10,6 +10,7 @@ from cudf.api.types import ( is_string_dtype, is_struct_dtype, ) +from cudf.core.buffer import with_spill_lock from libcpp cimport bool from libcpp.memory cimport unique_ptr @@ -86,13 +87,16 @@ cdef class GroupBy: def __cinit__(self, list keys, bool dropna=True, *args, **kwargs): cdef libcudf_types.null_policy c_null_handling + cdef table_view keys_view if dropna: c_null_handling = libcudf_types.null_policy.EXCLUDE else: c_null_handling = libcudf_types.null_policy.INCLUDE - cdef table_view keys_view = table_view_from_columns(keys) + with with_spill_lock() as spill_lock: + keys_view = table_view_from_columns(keys) + self._spill_lock = spill_lock with nogil: self.c_obj.reset( diff --git a/python/cudf/cudf/_lib/transform.pyx b/python/cudf/cudf/_lib/transform.pyx index b95bce0db58..1fa68282c3d 100644 --- a/python/cudf/cudf/_lib/transform.pyx +++ b/python/cudf/cudf/_lib/transform.pyx @@ -146,7 +146,11 @@ def one_hot_encode(Column input_column, Column categories): libcudf_transform.one_hot_encode(c_view_input, c_view_categories) ) - owner = Column.from_unique_ptr(move(c_result.first)) + # Notice, the data pointer of `owner` has been exposed + # through `c_result.second` at this point. + owner = Column.from_unique_ptr( + move(c_result.first), data_ptr_exposed=True + ) pylist_categories = categories.to_arrow().to_pylist() encodings, _ = data_from_table_view( diff --git a/python/cudf/cudf/_lib/transpose.pyx b/python/cudf/cudf/_lib/transpose.pyx index b9eea6169bd..51e49b1f27a 100644 --- a/python/cudf/cudf/_lib/transpose.pyx +++ b/python/cudf/cudf/_lib/transpose.pyx @@ -20,7 +20,11 @@ def transpose(list source_columns): with nogil: c_result = move(cpp_transpose(c_input)) - result_owner = Column.from_unique_ptr(move(c_result.first)) + # Notice, the data pointer of `result_owner` has been exposed + # through `c_result.second` at this point. + result_owner = Column.from_unique_ptr( + move(c_result.first), data_ptr_exposed=True + ) return columns_from_table_view( c_result.second, owners=[result_owner] * c_result.second.num_columns() diff --git a/python/cudf/cudf/_lib/unary.pyx b/python/cudf/cudf/_lib/unary.pyx index 52f0a804b2a..b1f5e3bd101 100644 --- a/python/cudf/cudf/_lib/unary.pyx +++ b/python/cudf/cudf/_lib/unary.pyx @@ -3,6 +3,7 @@ from enum import IntEnum from cudf.api.types import is_decimal_dtype +from cudf.core.buffer import with_spill_lock from libcpp.memory cimport unique_ptr from libcpp.utility cimport move @@ -43,6 +44,7 @@ class UnaryOp(IntEnum): NOT = unary_operator.NOT +@with_spill_lock() def unary_operation(Column input, object op): cdef column_view c_input = input.view() cdef unary_operator c_op = ( @@ -60,6 +62,7 @@ def unary_operation(Column input, object op): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_null(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -70,6 +73,7 @@ def is_null(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_valid(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -80,6 +84,7 @@ def is_valid(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def cast(Column input, object dtype=np.float64): cdef column_view c_input = input.view() cdef data_type c_dtype = dtype_to_data_type(dtype) @@ -95,6 +100,7 @@ def cast(Column input, object dtype=np.float64): return result +@with_spill_lock() def is_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result @@ -105,6 +111,7 @@ def is_nan(Column input): return Column.from_unique_ptr(move(c_result)) +@with_spill_lock() def is_non_nan(Column input): cdef column_view c_input = input.view() cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/core/buffer/__init__.py b/python/cudf/cudf/core/buffer/__init__.py index a73bc69ffb5..044f2fa0478 100644 --- a/python/cudf/cudf/core/buffer/__init__.py +++ b/python/cudf/cudf/core/buffer/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2022, NVIDIA CORPORATION. from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper -from cudf.core.buffer.utils import as_buffer +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.core.buffer.utils import as_buffer, get_spill_lock, with_spill_lock diff --git a/python/cudf/cudf/core/buffer/buffer.py b/python/cudf/cudf/core/buffer/buffer.py index 73e589ebb8e..29534ab5529 100644 --- a/python/cudf/cudf/core/buffer/buffer.py +++ b/python/cudf/cudf/core/buffer/buffer.py @@ -148,6 +148,28 @@ def _from_host_memory(cls: Type[T], data: Any) -> T: # Create from device memory return cls._from_device_memory(buf) + @classmethod + def _from_any_memory(cls: Type[T], data: Any) -> T: + """Create a Buffer from device or host memory + + If data exposes `__cuda_array_interface__`, we deligate to the + `_from_device_memory` constructor otherwise `_from_host_memory`. + + Parameters + ---------- + data : Any + An object that represens device or host memory. + + Returns + ------- + Buffer + Buffer representing `data`. + """ + + if hasattr(data, "__cuda_array_interface__"): + return cls._from_device_memory(data) + return cls._from_host_memory(data) + def _getitem(self, offset: int, size: int) -> Buffer: """ Sub-classes can overwrite this to implement __getitem__ diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py new file mode 100644 index 00000000000..821cae54128 --- /dev/null +++ b/python/cudf/cudf/core/buffer/spill_manager.py @@ -0,0 +1,306 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +from __future__ import annotations + +import gc +import io +import threading +import traceback +import warnings +import weakref +from typing import List, Optional, Tuple + +import rmm.mr + +from cudf.core.buffer.spillable_buffer import SpillableBuffer +from cudf.options import get_option +from cudf.utils.string import format_bytes + + +def get_traceback() -> str: + """Pretty print current traceback to a string""" + with io.StringIO() as f: + traceback.print_stack(file=f) + f.seek(0) + return f.read() + + +def get_rmm_memory_resource_stack( + mr: rmm.mr.DeviceMemoryResource, +) -> List[rmm.mr.DeviceMemoryResource]: + """Get the RMM resource stack + + Parameters + ---------- + mr : rmm.mr.DeviceMemoryResource + Top of the resource stack + + Return + ------ + list + List of RMM resources + """ + + if hasattr(mr, "upstream_mr"): + return [mr] + get_rmm_memory_resource_stack(mr.upstream_mr) + return [mr] + + +class SpillManager: + """Manager of spillable buffers. + + This class implements tracking of all known spillable buffers, on-demand + spilling of said buffers, and (optionally) maintains a memory usage limit. + + When `spill_on_demand=True`, the manager registers an RMM out-of-memory + error handler, which will spill spillable buffers in order to free up + memory. + + When `device_memory_limit=True`, the manager will try keep the device + memory usage below the specified limit by spilling of spillable buffers + continuously, which will introduce a modest overhead. + + Parameters + ---------- + spill_on_demand : bool + Enable spill on demand. The global manager sets this to the value of + `CUDF_SPILL_ON_DEMAND` or False. + device_memory_limit: int, optional + If not None, this is the device memory limit in bytes that triggers + device to host spilling. The global manager sets this to the value + of `CUDF_SPILL_DEVICE_LIMIT` or None. + """ + + _base_buffers: weakref.WeakValueDictionary[int, SpillableBuffer] + + def __init__( + self, + *, + spill_on_demand: bool = False, + device_memory_limit: int = None, + ) -> None: + self._lock = threading.Lock() + self._base_buffers = weakref.WeakValueDictionary() + self._id_counter = 0 + self._spill_on_demand = spill_on_demand + self._device_memory_limit = device_memory_limit + + if self._spill_on_demand: + # Set the RMM out-of-memory handle if not already set + mr = rmm.mr.get_current_device_resource() + if all( + not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) + for m in get_rmm_memory_resource_stack(mr) + ): + rmm.mr.set_current_device_resource( + rmm.mr.FailureCallbackResourceAdaptor( + mr, self._out_of_memory_handle + ) + ) + + def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool: + """Try to handle an out-of-memory error by spilling + + This can by used as the callback function to RMM's + `FailureCallbackResourceAdaptor` + + Parameters + ---------- + nbytes : int + Number of bytes to try to spill. + retry_once : bool, optional + If True, call `gc.collect()` and retry once. + + Return + ------ + bool + True if any buffers were freed otherwise False. + + Warning + ------- + In order to avoid deadlock, this function should not lock + already locked buffers. + """ + + # Keep spilling until `nbytes` been spilled + total_spilled = 0 + while total_spilled < nbytes: + spilled = self.spill_device_memory() + if spilled == 0: + break # No more to spill! + total_spilled += spilled + + if total_spilled > 0: + return True # Ask RMM to retry the allocation + + if retry_once: + # Let's collect garbage and try one more time + gc.collect() + return self._out_of_memory_handle(nbytes, retry_once=False) + + # TODO: write to log instead of stdout + print( + f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes " + "failed, spill-on-demand couldn't find any device memory to " + f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}" + ) + return False # Since we didn't find anything to spill, we give up + + def add(self, buffer: SpillableBuffer) -> None: + """Add buffer to the set of managed buffers + + The manager keeps a weak reference to the buffer + + Parameters + ---------- + buffer : SpillableBuffer + The buffer to manage + """ + if buffer.size > 0 and not buffer.exposed: + with self._lock: + self._base_buffers[self._id_counter] = buffer + self._id_counter += 1 + self.spill_to_device_limit() + + def base_buffers( + self, order_by_access_time: bool = False + ) -> Tuple[SpillableBuffer, ...]: + """Get all managed buffers + + Parameters + ---------- + order_by_access_time : bool, optional + Order the buffer by access time (ascending order) + + Return + ------ + tuple + Tuple of buffers + """ + with self._lock: + ret = tuple(self._base_buffers.values()) + if order_by_access_time: + ret = tuple(sorted(ret, key=lambda b: b.last_accessed)) + return ret + + def spill_device_memory(self) -> int: + """Try to spill device memory + + This function is safe to call doing spill-on-demand + since it does not lock buffers already locked. + + Return + ------ + int + Number of bytes spilled. + """ + for buf in self.base_buffers(order_by_access_time=True): + if buf.lock.acquire(blocking=False): + try: + if not buf.is_spilled and buf.spillable: + buf.__spill__(target="cpu") + return buf.size + finally: + buf.lock.release() + return 0 + + def spill_to_device_limit(self, device_limit: int = None) -> int: + """Spill until device limit + + Notice, by default this is a no-op. + + Parameters + ---------- + device_limit : int, optional + Limit in bytes. If None, the value of the environment variable + `CUDF_SPILL_DEVICE_LIMIT` is used. If this is not set, the method + does nothing and returns 0. + + Return + ------ + int + The number of bytes spilled. + """ + limit = ( + self._device_memory_limit if device_limit is None else device_limit + ) + if limit is None: + return 0 + ret = 0 + while True: + unspilled = sum( + buf.size for buf in self.base_buffers() if not buf.is_spilled + ) + if unspilled < limit: + break + nbytes = self.spill_device_memory() + if nbytes == 0: + break # No more to spill + ret += nbytes + return ret + + def lookup_address_range( # TODO: remove, only for debugging + self, ptr: int, size: int + ) -> List[SpillableBuffer]: + ret = [] + for buf in self.base_buffers(): + if buf.is_overlapping(ptr, size): + ret.append(buf) + return ret + + def __repr__(self) -> str: + spilled = sum( + buf.size for buf in self.base_buffers() if buf.is_spilled + ) + unspilled = sum( + buf.size for buf in self.base_buffers() if not buf.is_spilled + ) + unspillable = 0 + for buf in self.base_buffers(): + if not (buf.is_spilled or buf.spillable): + unspillable += buf.size + unspillable_ratio = unspillable / unspilled if unspilled else 0 + + return ( + f"" + ) + + +# The global manager has three states: +# - Uninitialized +# - Initialized to None (spilling disabled) +# - Initialized to a SpillManager instance (spilling enabled) +_global_manager_uninitialized: bool = True +_global_manager: Optional[SpillManager] = None + + +def set_global_manager(manager: Optional[SpillManager]) -> None: + """Set the global manager, which if None disables spilling""" + + global _global_manager, _global_manager_uninitialized + if _global_manager is not None: + gc.collect() + base_buffers = _global_manager.base_buffers() + if len(base_buffers) > 0: + warnings.warn(f"overwriting non-empty manager: {base_buffers}") + + _global_manager = manager + _global_manager_uninitialized = False + + +def get_global_manager() -> Optional[SpillManager]: + """Get the global manager or None if spilling is disabled""" + global _global_manager_uninitialized + if _global_manager_uninitialized: + manager = None + if get_option("spill"): + manager = SpillManager( + spill_on_demand=get_option("spill_on_demand"), + device_memory_limit=get_option("spill_device_limit"), + ) + set_global_manager(manager) + return _global_manager diff --git a/python/cudf/cudf/core/buffer/spillable_buffer.py b/python/cudf/cudf/core/buffer/spillable_buffer.py new file mode 100644 index 00000000000..50ac6b4e653 --- /dev/null +++ b/python/cudf/cudf/core/buffer/spillable_buffer.py @@ -0,0 +1,473 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +from __future__ import annotations + +import collections.abc +import pickle +import time +import weakref +from threading import RLock +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar + +import numpy + +import rmm + +from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper +from cudf.utils.string import format_bytes + +if TYPE_CHECKING: + from cudf.core.buffer.spill_manager import SpillManager + + +T = TypeVar("T", bound="SpillableBuffer") + + +class SpillLock: + pass + + +class DelayedPointerTuple(collections.abc.Sequence): + """ + A delayed version of the "data" field in __cuda_array_interface__. + + The idea is to delay the access to `Buffer.ptr` until the user + actually accesses the data pointer. + + For instance, in many cases __cuda_array_interface__ is accessed + only to determine whether an object is a CUDA object or not. + + TODO: this doesn't support libraries such as PyTorch that declare + the tuple of __cuda_array_interface__["data"] in Cython. In such + cases, Cython will raise an error because DelayedPointerTuple + isn't a "real" tuple. + """ + + def __init__(self, buffer) -> None: + self._buf = buffer + + def __len__(self): + return 2 + + def __getitem__(self, i): + if i == 0: + return self._buf.ptr + elif i == 1: + return False + raise IndexError("tuple index out of range") + + +class SpillableBuffer(Buffer): + """A spillable buffer that implements DeviceBufferLike. + + This buffer supports spilling the represented data to host memory. + Spilling can be done manually by calling `.__spill__(target="cpu")` but + usually the associated spilling manager triggers spilling based on current + device memory usage see `cudf.core.buffer.spill_manager.SpillManager`. + Unspill is triggered automatically when accessing the data of the buffer. + + The buffer might not be spillable, which is based on the "expose" status + of the buffer. We say that the buffer has been exposed if the device + pointer (integer or void*) has been accessed outside of SpillableBuffer. + In this case, we cannot invalidate the device pointer by moving the data + to host. + + A buffer can be exposed permanently at creation or by accessing the `.ptr` + property. To avoid this, one can use `.get_ptr()` instead, which support + exposing the buffer temporarily. + + Use the factory function `as_buffer` to create a SpillableBuffer instance. + """ + + _lock: RLock + _spill_locks: weakref.WeakSet + _last_accessed: float + _ptr_desc: Dict[str, Any] + _exposed: bool + _manager: SpillManager + + def _finalize_init(self, ptr_desc: Dict[str, Any], exposed: bool) -> None: + from cudf.core.buffer.spill_manager import get_global_manager + + self._lock = RLock() + self._spill_locks = weakref.WeakSet() + self._last_accessed = time.monotonic() + self._ptr_desc = ptr_desc + self._exposed = exposed + manager = get_global_manager() + if manager is None: + raise ValueError( + f"cannot create {self.__class__} with a global spill manager" + ) + + if self._ptr: + # TODO: run the following asserts in "debug mode" or not at all. + # Assert that any buffers `data` may refer to has been exposed + # already. If this is not the case, it means that somewhere we + # are accessing a buffer's device pointer without marking it as + # exposed, which would be a bug. + bases = manager.lookup_address_range(self._ptr, self._size) + assert all(b.exposed for b in bases) + # Assert that if `data` refers to any existing base buffers, it + # must itself be exposed. + assert len(bases) == 0 or exposed + + self._manager = manager + self._manager.add(self) + + @classmethod + def _from_device_memory( + cls: Type[T], data: Any, *, exposed: bool = False + ) -> T: + """Create a spillabe buffer from device memory. + + No data is being copied. + + Parameters + ---------- + data : device-buffer-like + An object implementing the CUDA Array Interface. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). + + Returns + ------- + SpillableBuffer + Buffer representing the same device memory as `data` + """ + ret = super(SpillableBuffer, cls)._from_device_memory(data) + ret._finalize_init(ptr_desc={"type": "gpu"}, exposed=exposed) + return ret + + @classmethod + def _from_host_memory(cls: Type[T], data: Any) -> T: + """Create a spillabe buffer from host memory. + + Data must implement `__array_interface__`, the buffer protocol, and/or + be convertible to a buffer object using `numpy.array()` + + The new buffer is marked as spilled to host memory already. + + Raises ValueError if array isn't C-contiguous. + + Parameters + ---------- + data : Any + An object that represens host memory. + + Returns + ------- + SpillableBuffer + Buffer representing a copy of `data`. + """ + + # Convert to a memoryview using numpy array, this will not copy data + # in most cases. + data = memoryview(numpy.array(data, copy=False, subok=True)) + if not data.c_contiguous: + raise ValueError("Buffer data must be C-contiguous") + + # Create an already spilled buffer + ret = cls.__new__(cls) + ret._owner = None + ret._ptr = 0 + ret._size = data.nbytes + ret._finalize_init( + ptr_desc={"type": "cpu", "memoryview": data}, exposed=False + ) + return ret + + @property + def lock(self) -> RLock: + return self._lock + + @property + def is_spilled(self) -> bool: + return self._ptr_desc["type"] != "gpu" + + def __spill__(self, target: str = "cpu") -> None: + """Spill or un-spill this buffer in-place + + Parameters + ---------- + target : str + The target of the spilling. + """ + + with self._lock: + ptr_type = self._ptr_desc["type"] + if ptr_type == target: + return + + if not self.spillable: + raise ValueError( + f"Cannot in-place move an unspillable buffer: {self}" + ) + + if (ptr_type, target) == ("gpu", "cpu"): + host_mem = memoryview(bytearray(self.size)) + rmm._lib.device_buffer.copy_ptr_to_host(self._ptr, host_mem) + self._ptr_desc["memoryview"] = host_mem + self._ptr = 0 + self._owner = None + elif (ptr_type, target) == ("cpu", "gpu"): + # Notice, this operation is prone to deadlock because the RMM + # allocation might trigger spilling-on-demand which in turn + # trigger a new call to this buffer's `__spill__()`. + # Therefore, it is important that spilling-on-demand doesn't + # tries to unspill an already locked buffer! + dev_mem = rmm.DeviceBuffer.to_device( + self._ptr_desc.pop("memoryview") + ) + self._ptr = dev_mem.ptr + self._owner = dev_mem + assert self._size == dev_mem.size + else: + # TODO: support moving to disk + raise ValueError(f"Unknown target: {target}") + self._ptr_desc["type"] = target + + @property + def ptr(self) -> int: + """Access the memory directly + + Notice, this will mark the buffer as "exposed" and make + it unspillable permanently. + + Consider using `.get_ptr()` instead. + """ + + self._manager.spill_to_device_limit() + with self._lock: + self.__spill__(target="gpu") + self._exposed = True + self._last_accessed = time.monotonic() + return self._ptr + + def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: + if spill_lock is None: + spill_lock = SpillLock() + with self._lock: + self.__spill__(target="gpu") + self._spill_locks.add(spill_lock) + return spill_lock + + def get_ptr(self, spill_lock: SpillLock = None) -> int: + """Get a device pointer to the memory of the buffer. + + If spill_lock is not None, a reference to this buffer is added + to spill_lock, which disable spilling of this buffer while + spill_lock is alive. + + Parameters + ---------- + spill_lock : SpillLock, optional + Adding a reference of this buffer to the spill lock. + + Return + ------ + int + The device pointer as an integer + """ + + if spill_lock is None: + return self.ptr # expose the buffer permanently + + self.spill_lock(spill_lock) + self._last_accessed = time.monotonic() + return self._ptr + + @property + def owner(self) -> Any: + return self._owner + + @property + def exposed(self) -> bool: + return self._exposed + + @property + def spillable(self) -> bool: + return not self._exposed and len(self._spill_locks) == 0 + + @property + def size(self) -> int: + return self._size + + @property + def nbytes(self) -> int: + return self._size + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @property + def __cuda_array_interface__(self) -> dict: + return { + "data": DelayedPointerTuple(self), + "shape": (self.size,), + "strides": None, + "typestr": "|u1", + "version": 0, + } + + def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: + size = self._size if size is None else size + with self._lock: + if self.spillable: + self.__spill__(target="cpu") + return self._ptr_desc["memoryview"][offset : offset + size] + else: + assert self._ptr_desc["type"] == "gpu" + ret = memoryview(bytearray(size)) + rmm._lib.device_buffer.copy_ptr_to_host( + self._ptr + offset, ret + ) + return ret + + def _getitem(self, offset: int, size: int) -> Buffer: + return SpillableBufferSlice(base=self, offset=offset, size=size) + + def serialize(self) -> Tuple[dict, list]: + """Serialize the Buffer + + Normally, we would use `[self]` as the frames. This would work but + also mean that `self` becomes exposed permanently if the frames are + later accessed through `__cuda_array_interface__`, which is exactly + what libraries like Dask+UCX would do when communicating! + + The sound solution is to modify Dask et al. so that they access the + frames through `.get_ptr()` and holds on to the `spill_lock` until + the frame has been transferred. However, until this adaptation we + use a hack where the frame is a `Buffer` with a `spill_lock` as the + owner, which makes `self` unspillable while the frame is alive but + doesn't expose `self` when `__cuda_array_interface__` is accessed. + + Warning, this hack means that the returned frame must be copied before + given to `.deserialize()`, otherwise we would have a `Buffer` pointing + to memory already owned by an existing `SpillableBuffer`. + """ + header: Dict[Any, Any] + frames: List[Buffer | memoryview] + with self._lock: + header = {} + header["type-serialized"] = pickle.dumps(self.__class__) + header["frame_count"] = 1 + if self.is_spilled: + frames = [self.memoryview()] + else: + # TODO: Use `frames=[self]` instead of this hack, see doc above + spill_lock = SpillLock() + ptr = self.get_ptr(spill_lock=spill_lock) + frames = [ + Buffer._from_device_memory( + cuda_array_interface_wrapper( + ptr=ptr, + size=self.size, + owner=(self._owner, spill_lock), + ) + ) + ] + return header, frames + + def is_overlapping(self, ptr: int, size: int): + with self._lock: + return ( + not self.is_spilled + and (ptr + size) > self._ptr + and (self._ptr + self._size) > ptr + ) + + def __repr__(self) -> str: + if self._ptr_desc["type"] != "gpu": + ptr_info = str(self._ptr_desc) + else: + ptr_info = str(hex(self._ptr)) + return ( + f"" + ) + + +class SpillableBufferSlice(SpillableBuffer): + """A slice of a spillable buffer + + This buffer applies the slicing and then delegates all + operations to its base buffer. + + Parameters + ---------- + base : SpillableBuffer + The base of the view + offset : int + Memory offset into the base buffer + size : int + Size of the view (in bytes) + """ + + def __init__(self, base: SpillableBuffer, offset: int, size: int) -> None: + if size < 0: + raise ValueError("size cannot be negative") + if offset < 0: + raise ValueError("offset cannot be negative") + if offset + size > base.size: + raise ValueError( + "offset+size cannot be greater than the size of base" + ) + self._base = base + self._offset = offset + self._size = size + self._owner = base + self._lock = base.lock + + @property + def ptr(self) -> int: + return self._base.ptr + self._offset + + def get_ptr(self, spill_lock: SpillLock = None) -> int: + return self._base.get_ptr(spill_lock=spill_lock) + self._offset + + def _getitem(self, offset: int, size: int) -> Buffer: + return SpillableBufferSlice( + base=self._base, offset=offset + self._offset, size=size + ) + + @classmethod + def deserialize(cls, header: dict, frames: list): + # TODO: because of the hack in `SpillableBuffer.serialize()` where + # frames are of type `Buffer`, we always deserialize as if they are + # `SpillableBufferbuffer`. In the future, we should be able to + # deserialize into `SpillableBufferSlice` when the frames hasn't been + # copied. + return SpillableBuffer.deserialize(header, frames) + + def memoryview(self, *, offset: int = 0, size: int = None) -> memoryview: + size = self._size if size is None else size + return self._base.memoryview(offset=self._offset + offset, size=size) + + def __repr__(self) -> str: + return ( + f" None: + return self._base.__spill__(target=target) + + @property + def is_spilled(self) -> bool: + return self._base.is_spilled + + @property + def exposed(self) -> bool: + return self._base.exposed + + @property + def spillable(self) -> bool: + return self._base.spillable + + def spill_lock(self, spill_lock: SpillLock = None) -> SpillLock: + return self._base.spill_lock(spill_lock=spill_lock) diff --git a/python/cudf/cudf/core/buffer/utils.py b/python/cudf/cudf/core/buffer/utils.py index 5e017c4bc92..3da1d610ca1 100644 --- a/python/cudf/cudf/core/buffer/utils.py +++ b/python/cudf/cudf/core/buffer/utils.py @@ -2,9 +2,13 @@ from __future__ import annotations -from typing import Any, Union +import threading +from contextlib import ContextDecorator +from typing import Any, Dict, Optional, Tuple, Union from cudf.core.buffer.buffer import Buffer, cuda_array_interface_wrapper +from cudf.core.buffer.spill_manager import get_global_manager +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock def as_buffer( @@ -12,6 +16,7 @@ def as_buffer( *, size: int = None, owner: object = None, + exposed: bool = False, ) -> Buffer: """Factory function to wrap `data` in a Buffer object. @@ -37,6 +42,10 @@ def as_buffer( owner : object, optional Python object to which the lifetime of the memory allocation is tied. A reference to this object is kept in the returned Buffer. + exposed : bool, optional + Mark the buffer as permanently exposed (unspillable). This is ignored + unless spilling is enabled and the data represents device memory, see + SpillableBuffer. Return ------ @@ -62,6 +71,60 @@ def as_buffer( "`data` is a buffer-like or array-like object" ) + if get_global_manager() is not None: + if hasattr(data, "__cuda_array_interface__"): + return SpillableBuffer._from_device_memory(data, exposed=exposed) + if exposed: + raise ValueError("cannot created exposed host memory") + return SpillableBuffer._from_host_memory(data) + if hasattr(data, "__cuda_array_interface__"): return Buffer._from_device_memory(data) return Buffer._from_host_memory(data) + + +_thread_spill_locks: Dict[int, Tuple[Optional[SpillLock], int]] = {} + + +def _push_thread_spill_lock() -> None: + _id = threading.get_ident() + spill_lock, count = _thread_spill_locks.get(_id, (None, 0)) + if spill_lock is None: + spill_lock = SpillLock() + _thread_spill_locks[_id] = (spill_lock, count + 1) + + +def _pop_thread_spill_lock() -> None: + _id = threading.get_ident() + spill_lock, count = _thread_spill_locks[_id] + if count == 1: + spill_lock = None + _thread_spill_locks[_id] = (spill_lock, count - 1) + + +class with_spill_lock(ContextDecorator): + """Decorator and context to set spill lock automatically. + + All calls to `get_spill_lock()` within the decorated function or context + will return a spill lock with a lifetime bound to the function or context. + """ + + def __enter__(self) -> Optional[SpillLock]: + _push_thread_spill_lock() + return get_spill_lock() + + def __exit__(self, *exc): + _pop_thread_spill_lock() + + +def get_spill_lock() -> Union[SpillLock, None]: + """Return a spill lock within the context of `with_spill_lock` or None + + Returns None, if spilling is disabled. + """ + + if get_global_manager() is None: + return None + _id = threading.get_ident() + spill_lock, _ = _thread_spill_locks.get(_id, (None, 0)) + return spill_lock diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 6c17b492f8a..d16df7ea1c0 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -1764,7 +1764,7 @@ def as_column( ): arbitrary = cupy.ascontiguousarray(arbitrary) - data = as_buffer(arbitrary) + data = as_buffer(arbitrary, exposed=True) col = build_column(data, dtype=current_dtype, mask=mask) if dtype is not None: @@ -2221,7 +2221,7 @@ def _mask_from_cuda_array_interface_desc(obj) -> Union[Buffer, None]: typecode = typestr[1] if typecode == "t": mask_size = bitmask_allocation_size_bytes(nelem) - mask = as_buffer(data=ptr, size=mask_size, owner=obj) + mask = as_buffer(data=ptr, size=mask_size, owner=obj, exposed=True) elif typecode == "b": col = as_column(mask) mask = bools_to_mask(col) diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 5ee9024a0d8..77ca3f9688b 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -203,7 +203,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int32")) data_32 = data_128[::4].copy() return cls( - data=as_buffer(data_32.view("uint8")), + data=as_buffer(data_32.view("uint8"), exposed=True), size=len(data), dtype=dtype, offset=data.offset, @@ -290,7 +290,7 @@ def from_arrow(cls, data: pa.Array): data_128 = cp.array(np.frombuffer(data.buffers()[1]).view("int64")) data_64 = data_128[::2].copy() return cls( - data=as_buffer(data_64.view("uint8")), + data=as_buffer(data_64.view("uint8"), exposed=True), size=len(data), dtype=dtype, offset=data.offset, diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index b29fc41e5b4..b38d3048ed7 100644 --- a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -721,7 +721,9 @@ def _protocol_to_cudf_column_numeric( _dbuffer, _ddtype = buffers["data"] _check_buffer_is_on_gpu(_dbuffer) cudfcol_num = build_column( - as_buffer(data=_dbuffer.ptr, size=_dbuffer.bufsize, owner=None), + as_buffer( + data=_dbuffer.ptr, size=_dbuffer.bufsize, owner=None, exposed=True + ), protocol_dtype_to_cupy_dtype(_ddtype), ) return _set_missing_values(col, cudfcol_num), buffers @@ -751,7 +753,11 @@ def _set_missing_values( valid_mask = protocol_col.get_buffers()["validity"] if valid_mask is not None: bitmask = cp.asarray( - as_buffer(data=valid_mask[0].ptr, size=valid_mask[0].bufsize), + as_buffer( + data=valid_mask[0].ptr, + size=valid_mask[0].bufsize, + exposed=True, + ), cp.bool8, ) cudf_col[~bitmask] = None @@ -790,7 +796,9 @@ def _protocol_to_cudf_column_categorical( _check_buffer_is_on_gpu(codes_buffer) cdtype = protocol_dtype_to_cupy_dtype(codes_dtype) codes = build_column( - as_buffer(data=codes_buffer.ptr, size=codes_buffer.bufsize), + as_buffer( + data=codes_buffer.ptr, size=codes_buffer.bufsize, exposed=True + ), cdtype, ) @@ -822,7 +830,9 @@ def _protocol_to_cudf_column_string( data_buffer, data_dtype = buffers["data"] _check_buffer_is_on_gpu(data_buffer) encoded_string = build_column( - as_buffer(data=data_buffer.ptr, size=data_buffer.bufsize), + as_buffer( + data=data_buffer.ptr, size=data_buffer.bufsize, exposed=True + ), protocol_dtype_to_cupy_dtype(data_dtype), ) @@ -832,7 +842,9 @@ def _protocol_to_cudf_column_string( offset_buffer, offset_dtype = buffers["offsets"] _check_buffer_is_on_gpu(offset_buffer) offsets = build_column( - as_buffer(data=offset_buffer.ptr, size=offset_buffer.bufsize), + as_buffer( + data=offset_buffer.ptr, size=offset_buffer.bufsize, exposed=True + ), protocol_dtype_to_cupy_dtype(offset_dtype), ) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index e4ea59c1f15..371c0566166 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -259,6 +259,10 @@ def __init__( else: self.grouping = _Grouping(obj, by, level) + self._groupby = libgroupby.GroupBy( + [*self.grouping.keys._columns], dropna=self._dropna + ) + def __iter__(self): group_names, offsets, _, grouped_values = self._grouped() if isinstance(group_names, cudf.BaseIndex): diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py index 7f6a6f10e25..4a0a0437e00 100644 --- a/python/cudf/cudf/options.py +++ b/python/cudf/cudf/options.py @@ -1,5 +1,6 @@ # Copyright (c) 2022, NVIDIA CORPORATION. +import os import textwrap from collections.abc import Container from dataclasses import dataclass @@ -17,6 +18,26 @@ class Option: _OPTIONS: Dict[str, Option] = {} +def _env_get_int(name, default): + try: + return int(os.getenv(name, default)) + except (ValueError, TypeError): + return default + + +def _env_get_bool(name, default): + env = os.getenv(name) + if env is None: + return default + as_a_int = _env_get_int(name, None) + env = env.lower().strip() + if env == "true" or env == "on" or as_a_int: + return True + if env == "false" or env == "off" or as_a_int == 0: + return False + return default + + def _register_option( name: str, default_value: Any, description: str, validator: Callable ): @@ -129,6 +150,16 @@ def _validator(val): return _validator +def _integer_and_none_validator(val): + try: + if val is None or int(val): + return + except ValueError: + raise ValueError( + f"{val} is not a valid option. " f"Must be an integer or None." + ) + + _register_option( "default_integer_bitwidth", None, @@ -163,3 +194,43 @@ def _validator(val): ), _make_contains_validator([None, 32, 64]), ) + + +_register_option( + "spill", + _env_get_bool("CUDF_SPILL", False), + textwrap.dedent( + """ + Enables spilling. + \tValid values are True or False. Default is False. + """ + ), + _make_contains_validator([False, True]), +) + +_register_option( + "spill_on_demand", + _env_get_bool("CUDF_SPILL_ON_DEMAND", True), + textwrap.dedent( + """ + Enables spilling on demand using an RMM out-of-memory error handler. + This has no effect if spilling is disabled, see the "spill" option. + \tValid values are True or False. Default is True. + """ + ), + _make_contains_validator([False, True]), +) + +_register_option( + "spill_device_limit", + _env_get_int("CUDF_SPILL_DEVICE_LIMIT", None), + textwrap.dedent( + """ + Enforce a device memory limit in bytes. + This has no effect if spilling is disabled, see the "spill" option. + \tValid values are any positive integer or None (disabled). + \tDefault is None. + """ + ), + _integer_and_none_validator, +) diff --git a/python/cudf/cudf/tests/conftest.py b/python/cudf/cudf/tests/conftest.py index 258b628305d..bf565ed9a47 100644 --- a/python/cudf/cudf/tests/conftest.py +++ b/python/cudf/cudf/tests/conftest.py @@ -158,3 +158,17 @@ def default_float_bitwidth(request): cudf.set_option("default_float_bitwidth", request.param) yield request.param cudf.set_option("default_float_bitwidth", old_default) + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + """Hook to make result information available in fixtures + + See + """ + outcome = yield + rep = outcome.get_result() + + # Set a report attribute for each phase of a call, which can + # be "setup", "call", "teardown" + setattr(item, "report", {rep.when: rep}) diff --git a/python/cudf/cudf/tests/test_buffer.py b/python/cudf/cudf/tests/test_buffer.py index 5ed5750f29b..6ff715db761 100644 --- a/python/cudf/cudf/tests/test_buffer.py +++ b/python/cudf/cudf/tests/test_buffer.py @@ -48,15 +48,21 @@ def test_buffer_from_cuda_iface_dtype(data, dtype): def test_buffer_creation_from_any(): ary = cp.arange(arr_len) - b = as_buffer(ary) + b = as_buffer(ary, exposed=True) assert isinstance(b, Buffer) - assert ary.__cuda_array_interface__["data"][0] == b.ptr + assert ary.data.ptr == b.ptr assert ary.nbytes == b.size with pytest.raises( ValueError, match="size must be specified when `data` is an integer" ): - as_buffer(42) + as_buffer(ary.data.ptr) + + b = as_buffer(ary.data.ptr, size=ary.nbytes, owner=ary, exposed=True) + assert isinstance(b, Buffer) + assert ary.data.ptr == b.ptr + assert ary.nbytes == b.size + assert b.owner.owner is ary @pytest.mark.parametrize( diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index b00e31115c9..3898db1c9fa 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -1456,7 +1456,11 @@ def test_groupby_attribute_error(): class TestGroupBy(cudf.core.groupby.GroupBy): @property def _groupby(self): - raise AttributeError("Test error message") + raise AttributeError(err_msg) + + @_groupby.setter + def _groupby(self, _): + pass a = cudf.DataFrame({"a": [1, 2], "b": [2, 3]}) gb = TestGroupBy(a, a["a"]) diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py new file mode 100644 index 00000000000..b38d90f3178 --- /dev/null +++ b/python/cudf/cudf/tests/test_spilling.py @@ -0,0 +1,464 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. + +import importlib +import random +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from typing import Tuple + +import cupy +import numpy as np +import pandas +import pandas.testing +import pytest + +import rmm + +import cudf +import cudf.core.buffer.spill_manager +import cudf.options +from cudf.core.abc import Serializable +from cudf.core.buffer import Buffer, as_buffer, get_spill_lock, with_spill_lock +from cudf.core.buffer.spill_manager import ( + SpillManager, + get_global_manager, + get_rmm_memory_resource_stack, + set_global_manager, +) +from cudf.core.buffer.spillable_buffer import SpillableBuffer, SpillLock +from cudf.testing._utils import assert_eq + + +def gen_df(target="gpu") -> cudf.DataFrame: + ret = cudf.DataFrame({"a": [1, 2, 3]}) + if target != "gpu": + gen_df.buffer(ret).__spill__(target=target) + return ret + + +gen_df.buffer = lambda df: df._data._data["a"].data +gen_df.is_spilled = lambda df: gen_df.buffer(df).is_spilled +gen_df.is_spillable = lambda df: gen_df.buffer(df).spillable +gen_df.buffer_size = gen_df.buffer(gen_df()).size + + +def spilled_and_unspilled(manager: SpillManager) -> Tuple[int, int]: + """Get bytes spilled and unspilled known by the manager""" + spilled = sum(buf.size for buf in manager.base_buffers() if buf.is_spilled) + unspilled = sum( + buf.size for buf in manager.base_buffers() if not buf.is_spilled + ) + return spilled, unspilled + + +@pytest.fixture +def manager(request): + """Fixture to enable and make a spilling manager availabe""" + kwargs = dict(getattr(request, "param", {})) + with warnings.catch_warnings(): + warnings.simplefilter("error") + set_global_manager(manager=SpillManager(**kwargs)) + yield get_global_manager() + # Retrieving the test result using the `pytest_runtest_makereport` + # hook from conftest.py + if request.node.report["call"].failed: + # Ignore `overwriting non-empty manager` errors when + # test is failing. + warnings.simplefilter("ignore") + set_global_manager(manager=None) + + +def test_spillable_buffer(manager: SpillManager): + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + assert isinstance(buf, SpillableBuffer) + assert buf.spillable + buf.ptr # Expose pointer + assert buf.exposed + assert not buf.spillable + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + # Notice, accessing `__cuda_array_interface__` itself doesn't + # expose the pointer, only accessing the "data" field exposes + # the pointer. + iface = buf.__cuda_array_interface__ + assert not buf.exposed + assert buf.spillable + iface["data"][0] # Expose pointer + assert buf.exposed + assert not buf.spillable + + +@pytest.mark.parametrize( + "attribute", + [ + "ptr", + "get_ptr", + "memoryview", + "is_spilled", + "exposed", + "spillable", + "spill_lock", + "__spill__", + ], +) +def test_spillable_buffer_view_attributes(manager: SpillManager, attribute): + base = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + view = base[:] + attr_base = getattr(base, attribute) + attr_view = getattr(view, attribute) + if callable(attr_view): + pass + else: + assert attr_base == attr_view + + +def test_from_pandas(manager: SpillManager): + pdf1 = pandas.DataFrame({"x": [1, 2, 3]}) + df = cudf.from_pandas(pdf1) + assert df._data._data["x"].data.spillable + pdf2 = df.to_pandas() + pandas.testing.assert_frame_equal(pdf1, pdf2) + + +def test_creations(manager: SpillManager): + df = cudf.datasets.timeseries() + assert isinstance(df._data._data["x"].data, SpillableBuffer) + assert df._data._data["x"].data.spillable + df = cudf.DataFrame({"x": [1, 2, 3]}) + assert df._data._data["x"].data.spillable + df = cudf.datasets.randomdata(10) + assert df._data._data["x"].data.spillable + + +def test_spillable_df_groupby(manager: SpillManager): + df = cudf.DataFrame({"x": [1, 1, 1]}) + gb = df.groupby("x") + # `gb` holds a reference to the device memory, which makes + # the buffer unspillable + assert len(df._data._data["x"].data._spill_locks) == 1 + assert not df._data._data["x"].data.spillable + del gb + assert df._data._data["x"].data.spillable + + +def test_spilling_buffer(manager: SpillManager): + buf = as_buffer(rmm.DeviceBuffer(size=10), exposed=False) + buf.__spill__(target="cpu") + assert buf.is_spilled + buf.ptr # Expose pointer and trigger unspill + assert not buf.is_spilled + with pytest.raises(ValueError, match="unspillable buffer"): + buf.__spill__(target="cpu") + + +def test_environment_variables(monkeypatch): + def reload_options(): + # In order to enabling monkey patching of the environment variables + # mark the global manager as uninitialized. + set_global_manager(None) + cudf.core.buffer.spill_manager._global_manager_uninitialized = True + importlib.reload(cudf.options) + + monkeypatch.setenv("CUDF_SPILL_ON_DEMAND", "off") + monkeypatch.setenv("CUDF_SPILL", "off") + reload_options() + assert get_global_manager() is None + + monkeypatch.setenv("CUDF_SPILL", "on") + reload_options() + manager = get_global_manager() + assert isinstance(manager, SpillManager) + assert manager._spill_on_demand is False + assert manager._device_memory_limit is None + + monkeypatch.setenv("CUDF_SPILL_DEVICE_LIMIT", "1000") + reload_options() + manager = get_global_manager() + assert isinstance(manager, SpillManager) + assert manager._device_memory_limit == 1000 + + +def test_spill_device_memory(manager: SpillManager): + df = gen_df() + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size) + manager.spill_device_memory() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size, 0) + del df + assert spilled_and_unspilled(manager) == (0, 0) + df1 = gen_df() + df2 = gen_df() + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert gen_df.is_spilled(df2) + df3 = df1 + df2 + assert not gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + assert not gen_df.is_spilled(df3) + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + assert not gen_df.is_spilled(df3) + df2.abs() # Should change the access time + manager.spill_device_memory() + assert gen_df.is_spilled(df1) + assert not gen_df.is_spilled(df2) + assert gen_df.is_spilled(df3) + + +def test_spill_to_device_limit(manager: SpillManager): + df1 = gen_df() + df2 = gen_df() + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size * 2) + manager.spill_to_device_limit(device_limit=0) + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + df3 = df1 + df2 + manager.spill_to_device_limit(device_limit=0) + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 3, 0) + assert gen_df.is_spilled(df1) + assert gen_df.is_spilled(df2) + assert gen_df.is_spilled(df3) + + +@pytest.mark.parametrize( + "manager", [{"device_memory_limit": 0}], indirect=True +) +def test_zero_device_limit(manager: SpillManager): + assert manager._device_memory_limit == 0 + df1 = gen_df() + df2 = gen_df() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + df1 + df2 + # Notice, while performing the addintion both df1 and df2 are unspillable + assert spilled_and_unspilled(manager) == (0, gen_df.buffer_size * 2) + manager.spill_to_device_limit() + assert spilled_and_unspilled(manager) == (gen_df.buffer_size * 2, 0) + + +def test_lookup_address_range(manager: SpillManager): + df = gen_df() + buf = gen_df.buffer(df) + buffers = manager.base_buffers() + assert len(buffers) == 1 + (buf,) = buffers + assert gen_df.buffer(df) is buf + assert manager.lookup_address_range(buf.ptr, buf.size)[0] is buf + assert manager.lookup_address_range(buf.ptr + 1, buf.size - 1)[0] is buf + assert manager.lookup_address_range(buf.ptr + 1, buf.size + 1)[0] is buf + assert manager.lookup_address_range(buf.ptr - 1, buf.size - 1)[0] is buf + assert manager.lookup_address_range(buf.ptr - 1, buf.size + 1)[0] is buf + assert not manager.lookup_address_range(buf.ptr + buf.size, buf.size) + assert not manager.lookup_address_range(buf.ptr - buf.size, buf.size) + + +def test_external_memory_never_spills(manager): + """ + Test that external data, i.e., data not managed by RMM, + is never spilled + """ + + cupy.cuda.set_allocator() # uses default allocator + + a = cupy.asarray([1, 2, 3]) + s = cudf.Series(a) + assert len(manager.base_buffers()) == 0 + assert not s._data[None].data.spillable + + +def test_spilling_df_views(manager): + df = gen_df(target="cpu") + assert gen_df.is_spilled(df) + df_view = df.loc[1:] + assert gen_df.is_spillable(df_view) + assert gen_df.is_spillable(df) + + +def test_modify_spilled_views(manager): + df = gen_df() + df_view = df.iloc[1:] + buf = gen_df.buffer(df) + buf.__spill__(target="cpu") + + # modify the spilled df and check that the changes are reflected + # in the view + df.iloc[1:] = 0 + assert_eq(df_view, df.iloc[1:]) + + # now, modify the view and check that the changes are reflected in + # the df + df_view.iloc[:] = -1 + assert_eq(df_view, df.iloc[1:]) + + +def test_ptr_restricted(manager: SpillManager): + buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False) + assert buf.spillable + assert len(buf._spill_locks) == 0 + slock1 = SpillLock() + buf.get_ptr(spill_lock=slock1) + assert not buf.spillable + assert len(buf._spill_locks) == 1 + slock2 = buf.spill_lock() + buf.get_ptr(spill_lock=slock2) + assert not buf.spillable + assert len(buf._spill_locks) == 2 + del slock1 + assert len(buf._spill_locks) == 1 + del slock2 + assert len(buf._spill_locks) == 0 + assert buf.spillable + + +def test_get_spill_lock(manager: SpillManager): + @with_spill_lock() + def f(sleep=False, nest=0): + if sleep: + time.sleep(random.random() / 100) + if nest: + return f(nest=nest - 1) + return get_spill_lock() + + assert get_spill_lock() is None + slock = f() + assert isinstance(slock, SpillLock) + assert get_spill_lock() is None + slock = f(nest=2) + assert isinstance(slock, SpillLock) + assert get_spill_lock() is None + + with ThreadPoolExecutor(max_workers=2) as executor: + futures_with_spill_lock = [] + futures_without_spill_lock = [] + for _ in range(100): + futures_with_spill_lock.append( + executor.submit(f, sleep=True, nest=1) + ) + futures_without_spill_lock.append( + executor.submit(f, sleep=True, nest=1) + ) + all(isinstance(f.result(), SpillLock) for f in futures_with_spill_lock) + all(f is None for f in futures_without_spill_lock) + + +def test_get_spill_lock_no_manager(): + """When spilling is disabled, get_spill_lock() should return None always""" + + @with_spill_lock() + def f(): + return get_spill_lock() + + assert get_spill_lock() is None + assert f() is None + + +@pytest.mark.parametrize("target", ["gpu", "cpu"]) +@pytest.mark.parametrize("view", [None, slice(0, 2), slice(1, 3)]) +def test_serialize_device(manager, target, view): + df1 = gen_df() + if view is not None: + df1 = df1.iloc[view] + gen_df.buffer(df1).__spill__(target=target) + + header, frames = df1.device_serialize() + assert len(frames) == 1 + if target == "gpu": + assert isinstance(frames[0], Buffer) + assert not gen_df.is_spilled(df1) + assert not gen_df.is_spillable(df1) + frames[0] = cupy.array(frames[0], copy=True) + else: + assert isinstance(frames[0], memoryview) + assert gen_df.is_spilled(df1) + assert gen_df.is_spillable(df1) + + df2 = Serializable.device_deserialize(header, frames) + assert_eq(df1, df2) + + +@pytest.mark.parametrize("target", ["gpu", "cpu"]) +@pytest.mark.parametrize("view", [None, slice(0, 2), slice(1, 3)]) +def test_serialize_host(manager, target, view): + df1 = gen_df() + if view is not None: + df1 = df1.iloc[view] + gen_df.buffer(df1).__spill__(target=target) + + # Unspilled df becomes spilled after host serialization + header, frames = df1.host_serialize() + assert all(isinstance(f, memoryview) for f in frames) + df2 = Serializable.host_deserialize(header, frames) + assert gen_df.is_spilled(df2) + assert_eq(df1, df2) + + +def test_serialize_dask_dataframe(manager: SpillManager): + protocol = pytest.importorskip("distributed.protocol") + + df1 = gen_df(target="gpu") + header, frames = protocol.serialize( + df1, serializers=("dask",), on_error="raise" + ) + buf: SpillableBuffer = gen_df.buffer(df1) + assert len(frames) == 1 + assert isinstance(frames[0], memoryview) + # Check that the memoryview and frames is the same memory + assert ( + np.array(buf.memoryview()).__array_interface__["data"] + == np.array(frames[0]).__array_interface__["data"] + ) + + df2 = protocol.deserialize(header, frames) + assert gen_df.is_spilled(df2) + assert_eq(df1, df2) + + +def test_serialize_cuda_dataframe(manager: SpillManager): + protocol = pytest.importorskip("distributed.protocol") + + df1 = gen_df(target="gpu") + header, frames = protocol.serialize( + df1, serializers=("cuda",), on_error="raise" + ) + buf: SpillableBuffer = gen_df.buffer(df1) + assert len(buf._spill_locks) == 1 + assert len(frames) == 1 + assert isinstance(frames[0], Buffer) + assert frames[0].ptr == buf.ptr + + frames[0] = cupy.array(frames[0], copy=True) + df2 = protocol.deserialize(header, frames) + assert_eq(df1, df2) + + +def test_get_rmm_memory_resource_stack(): + mr1 = rmm.mr.get_current_device_resource() + assert all( + not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor) + for m in get_rmm_memory_resource_stack(mr1) + ) + + mr2 = rmm.mr.FailureCallbackResourceAdaptor(mr1, lambda x: False) + assert get_rmm_memory_resource_stack(mr2)[0] is mr2 + assert get_rmm_memory_resource_stack(mr2)[1] is mr1 + + mr3 = rmm.mr.FixedSizeMemoryResource(mr2) + assert get_rmm_memory_resource_stack(mr3)[0] is mr3 + assert get_rmm_memory_resource_stack(mr3)[1] is mr2 + assert get_rmm_memory_resource_stack(mr3)[2] is mr1 + + mr4 = rmm.mr.FailureCallbackResourceAdaptor(mr3, lambda x: False) + assert get_rmm_memory_resource_stack(mr4)[0] is mr4 + assert get_rmm_memory_resource_stack(mr4)[1] is mr3 + assert get_rmm_memory_resource_stack(mr4)[2] is mr2 + assert get_rmm_memory_resource_stack(mr4)[3] is mr1 + + +def test_df_transpose(manager: SpillManager): + df1 = cudf.DataFrame({"x": [1, 2]}) + df2 = df1.transpose() + # For now, all buffers are marked as exposed + assert df1._data._data["x"].data.exposed + assert df2._data._data[0].data.exposed + assert df2._data._data[1].data.exposed diff --git a/python/cudf/cudf/utils/utils.py b/python/cudf/cudf/utils/utils.py index c5f4629483a..65a86484207 100644 --- a/python/cudf/cudf/utils/utils.py +++ b/python/cudf/cudf/utils/utils.py @@ -294,7 +294,7 @@ def pa_mask_buffer_to_mask(mask_buf, size): dbuf = rmm.DeviceBuffer(size=mask_size) dbuf.copy_from_host(np.asarray(mask_buf).view("u1")) return as_buffer(dbuf) - return as_buffer(mask_buf) + return as_buffer(mask_buf, exposed=True) def _isnat(val): diff --git a/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx b/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx index 4fc9e473fa3..bf459f22c16 100644 --- a/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx +++ b/python/strings_udf/strings_udf/_lib/cudf_jit_udf.pyx @@ -24,7 +24,7 @@ def column_to_string_view_array(Column strings_col): c_buffer = move(cpp_to_string_view_array(input_view)) device_buffer = DeviceBuffer.c_from_unique_ptr(move(c_buffer)) - return as_buffer(device_buffer) + return as_buffer(device_buffer, exposed=True) def column_from_udf_string_array(DeviceBuffer d_buffer):