Skip to content

Commit

Permalink
Rewrites column.__setitem__, Use boolean_mask_scatter (#10202)
Browse files Browse the repository at this point in the history
closes #8667 

This PR rewrites `column.__setitem__` and calls `boolean_mask_scatter` if keys and values meet some criteria.
Benchmark shows in low-order problem size (10K ish), there are 30% speed up for aligned values and 10% for unaligned values. Note standard deviation of the unaligned case is quite high after refactor. For larger problem sizes performance is rather unaffected.

Benchmarks:
<details>

<summary>10K</summary>

```
---------------------------------------------------------------- benchmark 'boolean_mask_col_aligned': 2 tests -----------------------------------------------------------------
Name (time in ms)                                      Min               Max              Mean            StdDev            Median               IQR            Outliers  Rounds
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_col_aligned] (afte)     1.2809 (1.0)      1.6781 (1.0)      1.3064 (1.0)      0.0364 (1.0)      1.2996 (1.0)      0.0137 (1.0)         22;34     761
column_setitem[boolean_mask_col_aligned] (befo)     1.7024 (1.33)     2.3863 (1.42)     1.7270 (1.32)     0.0523 (1.43)     1.7187 (1.32)     0.0138 (1.01)        20;31     563
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------- benchmark 'boolean_mask_col_unaligned': 2 tests --------------------------------------------------------------------------
Name (time in us)                                            Min                   Max                  Mean             StdDev                Median                IQR            Outliers  Rounds
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_col_unaligned] (afte)       972.3390 (1.0)      1,520.7559 (1.29)     1,008.6033 (1.0)      77.0920 (9.45)       984.0354 (1.0)      13.2429 (1.45)       83;132     984
column_setitem[boolean_mask_col_unaligned] (befo)     1,106.3821 (1.14)     1,179.5689 (1.0)      1,118.7759 (1.11)      8.1539 (1.0)      1,116.2200 (1.13)      9.1530 (1.0)        149;30     874
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------- benchmark 'boolean_mask_scalar': 2 tests ---------------------------------------------------------------------
Name (time in us)                                   Min                 Max                Mean             StdDev              Median               IQR            Outliers  Rounds
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_scalar] (afte)     532.2532 (1.0)      605.4689 (1.0)      542.2607 (1.0)      11.1111 (1.10)     537.9058 (1.0)      5.2921 (1.0)       178;213    1461
column_setitem[boolean_mask_scalar] (befo)     770.1530 (1.45)     863.1549 (1.43)     781.4038 (1.44)     10.0834 (1.0)      778.3370 (1.45)     7.0044 (1.32)       114;90    1219
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------- benchmark 'integer_scatter_map_col': 2 tests ----------------------------------------------------------------
Name (time in ms)                                     Min               Max              Mean            StdDev            Median               IQR            Outliers  Rounds
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[integer_scatter_map_col] (afte)     1.4785 (1.0)      1.9170 (1.21)     1.5176 (1.01)     0.0438 (3.91)     1.5098 (1.00)     0.0171 (1.37)        18;26     644
column_setitem[integer_scatter_map_col] (befo)     1.4882 (1.01)     1.5802 (1.0)      1.5084 (1.0)      0.0112 (1.0)      1.5068 (1.0)      0.0124 (1.0)        140;22     650
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------- benchmark 'integer_scatter_map_scalar': 2 tests ----------------------------------------------------------------------
Name (time in us)                                          Min                   Max                Mean             StdDev              Median               IQR            Outliers  Rounds
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[integer_scatter_map_scalar] (afte)     878.2479 (1.0)      1,343.0519 (1.39)     897.1496 (1.01)     27.5109 (3.23)     892.9770 (1.00)     7.8208 (1.0)         29;60    1074
column_setitem[integer_scatter_map_scalar] (befo)     879.3280 (1.00)       966.9410 (1.0)      890.8573 (1.0)       8.5287 (1.0)      888.7504 (1.0)      8.4981 (1.09)       171;50    1086
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------- benchmark 'stride-1_slice_col': 2 tests -----------------------------------------------------------------------
Name (time in us)                                  Min                   Max                Mean             StdDev              Median                IQR            Outliers  Rounds
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[stride-1_slice_col] (afte)     752.6411 (1.0)        852.3620 (1.0)      775.1790 (1.0)      10.9726 (1.0)      772.6796 (1.0)      14.1604 (1.32)       245;23    1152
column_setitem[stride-1_slice_col] (befo)     974.8179 (1.30)     1,307.2360 (1.53)     991.1287 (1.28)     25.8696 (2.36)     985.5330 (1.28)     10.6919 (1.0)         30;51     763
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------- benchmark 'stride-1_slice_scalar': 2 tests -------------------------------------------------------------------
Name (time in us)                                    Min                 Max               Mean            StdDev             Median               IQR            Outliers  Rounds
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[stride-1_slice_scalar] (afte)     87.3711 (1.16)     134.7861 (1.0)      89.7566 (1.15)     2.4861 (1.0)      89.5601 (1.16)     1.6061 (1.13)        95;87    2106
column_setitem[stride-1_slice_scalar] (befo)     75.3789 (1.0)      136.7659 (1.01)     78.0297 (1.0)      4.8186 (1.94)     77.0842 (1.0)      1.4230 (1.0)       122;184    2403
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------- benchmark 'stride-2_slice_col': 2 tests ----------------------------------------------------------------------
Name (time in us)                                  Min                 Max                Mean             StdDev              Median                IQR            Outliers  Rounds
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[stride-2_slice_col] (afte)     684.7882 (1.02)     972.5131 (1.01)     712.5983 (1.04)     54.4972 (3.04)     693.8996 (1.02)     10.3808 (1.18)      109;163    1338
column_setitem[stride-2_slice_col] (befo)     672.3758 (1.0)      964.4001 (1.0)      684.2917 (1.0)      17.9163 (1.0)      679.5955 (1.0)       8.7875 (1.0)        85;106    1368
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------- benchmark 'stride-2_slice_scalar': 2 tests --------------------------------------------------------------------
Name (time in us)                                     Min                 Max                Mean            StdDev              Median               IQR            Outliers  Rounds
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
column_setitem[stride-2_slice_scalar] (afte)     302.0421 (1.04)     374.3470 (1.0)      307.5677 (1.04)     4.3768 (1.0)      306.4690 (1.04)     2.4854 (1.17)      258;253    2532
column_setitem[stride-2_slice_scalar] (befo)     290.4800 (1.0)      378.1999 (1.01)     295.3950 (1.0)      4.6778 (1.07)     294.1729 (1.0)      2.1292 (1.0)       273;324    2977
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
</details>

<details>

<summary> 1M </summary>

```
------------------------------ benchmark 'boolean_mask_col_aligned': 2 tests ------------------------------
Name (time in ms)                                       Min                Max               Mean          
-----------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_col_aligned] (afte)     75.3847 (1.0)      79.0559 (1.0)      76.2878 (1.0)    
column_setitem[boolean_mask_col_aligned] (befo)     76.3708 (1.01)     79.7394 (1.01)     77.0892 (1.01)   
-----------------------------------------------------------------------------------------------------------

------------------------------ benchmark 'boolean_mask_col_unaligned': 2 tests ------------------------------
Name (time in ms)                                         Min                Max               Mean          
-------------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_col_unaligned] (afte)     46.5199 (1.0)      48.3434 (1.0)      46.9222 (1.0)    
column_setitem[boolean_mask_col_unaligned] (befo)     46.6314 (1.00)     48.5938 (1.01)     47.1492 (1.00)   
-------------------------------------------------------------------------------------------------------------

------------------------------ benchmark 'boolean_mask_scalar': 2 tests ------------------------------
Name (time in ms)                                  Min                Max               Mean          
------------------------------------------------------------------------------------------------------
column_setitem[boolean_mask_scalar] (afte)     17.0548 (1.0)      17.8006 (1.0)      17.5430 (1.0)    
column_setitem[boolean_mask_scalar] (befo)     18.4329 (1.08)     18.6918 (1.05)     18.5073 (1.05)   
------------------------------------------------------------------------------------------------------

-------------------------------- benchmark 'integer_scatter_map_col': 2 tests -------------------------------
Name (time in ms)                                       Min                 Max                Mean          
-------------------------------------------------------------------------------------------------------------
column_setitem[integer_scatter_map_col] (afte)     115.7189 (1.01)     120.0585 (1.0)      116.6452 (1.0)    
column_setitem[integer_scatter_map_col] (befo)     114.7481 (1.0)      122.7263 (1.02)     117.5000 (1.01)   
-------------------------------------------------------------------------------------------------------------

------------------------------ benchmark 'integer_scatter_map_scalar': 2 tests ------------------------------
Name (time in ms)                                         Min                Max               Mean          
-------------------------------------------------------------------------------------------------------------
column_setitem[integer_scatter_map_scalar] (afte)     57.9951 (1.0)      62.2284 (1.0)      59.8864 (1.0)    
column_setitem[integer_scatter_map_scalar] (befo)     60.9071 (1.05)     62.2952 (1.00)     61.6422 (1.03)   
-------------------------------------------------------------------------------------------------------------

------------------------------ benchmark 'stride-1_slice_col': 2 tests ------------------------------
Name (time in ms)                                 Min                Max               Mean          
-----------------------------------------------------------------------------------------------------
column_setitem[stride-1_slice_col] (afte)     56.9203 (1.0)      58.0924 (1.0)      57.4940 (1.0)    
column_setitem[stride-1_slice_col] (befo)     58.1888 (1.02)     59.2996 (1.02)     58.5722 (1.02)   
-----------------------------------------------------------------------------------------------------

-------------------------------- benchmark 'stride-1_slice_scalar': 2 tests -------------------------------
Name (time in us)                                     Min                 Max                Mean          
-----------------------------------------------------------------------------------------------------------
column_setitem[stride-1_slice_scalar] (afte)     287.1200 (1.08)     515.4130 (1.24)     298.1191 (1.09)   
column_setitem[stride-1_slice_scalar] (befo)     265.6982 (1.0)      415.9641 (1.0)      273.5206 (1.0)    
-----------------------------------------------------------------------------------------------------------

------------------------------ benchmark 'stride-2_slice_col': 2 tests ------------------------------
Name (time in ms)                                 Min                Max               Mean          
-----------------------------------------------------------------------------------------------------
column_setitem[stride-2_slice_col] (afte)     29.1543 (1.01)     31.0217 (1.00)     29.6341 (1.00)   
column_setitem[stride-2_slice_col] (befo)     28.9718 (1.0)      30.8824 (1.0)      29.5045 (1.0)    
-----------------------------------------------------------------------------------------------------

--------------------------------- benchmark 'stride-2_slice_scalar': 2 tests --------------------------------
Name (time in us)                                     Min                   Max                Mean          
-------------------------------------------------------------------------------------------------------------
column_setitem[stride-2_slice_scalar] (afte)     780.7089 (1.00)     1,407.5749 (1.0)      817.0583 (1.0)    
column_setitem[stride-2_slice_scalar] (befo)     777.0571 (1.0)      2,036.4628 (1.45)     832.0608 (1.02)   
-------------------------------------------------------------------------------------------------------------
```
</details>

Authors:
  - Michael Wang (https://github.com/isVoid)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #10202
  • Loading branch information
isVoid authored Feb 23, 2022
1 parent 496f452 commit a72479f
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 99 deletions.
56 changes: 30 additions & 26 deletions python/cudf/cudf/_lib/copying.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

import pickle

Expand Down Expand Up @@ -537,11 +537,11 @@ def copy_if_else(object lhs, object rhs, Column boolean_mask):
as_device_scalar(lhs), as_device_scalar(rhs), boolean_mask)


def _boolean_mask_scatter_table(input_table, target_table,
Column boolean_mask):
def _boolean_mask_scatter_columns(list input_columns, list target_columns,
Column boolean_mask):

cdef table_view input_table_view = table_view_from_columns(input_table)
cdef table_view target_table_view = table_view_from_columns(target_table)
cdef table_view input_table_view = table_view_from_columns(input_columns)
cdef table_view target_table_view = table_view_from_columns(target_columns)
cdef column_view boolean_mask_view = boolean_mask.view()

cdef unique_ptr[table] c_result
Expand All @@ -555,14 +555,10 @@ def _boolean_mask_scatter_table(input_table, target_table,
)
)

return data_from_unique_ptr(
move(c_result),
column_names=target_table._column_names,
index_names=target_table._index._column_names
)
return columns_from_unique_ptr(move(c_result))


def _boolean_mask_scatter_scalar(list input_scalars, target_table,
def _boolean_mask_scatter_scalar(list input_scalars, list target_columns,
Column boolean_mask):

cdef vector[reference_wrapper[constscalar]] input_scalar_vector
Expand All @@ -571,7 +567,7 @@ def _boolean_mask_scatter_scalar(list input_scalars, target_table,
for scl in input_scalars:
input_scalar_vector.push_back(reference_wrapper[constscalar](
scl.get_raw_ptr()[0]))
cdef table_view target_table_view = table_view_from_columns(target_table)
cdef table_view target_table_view = table_view_from_columns(target_columns)
cdef column_view boolean_mask_view = boolean_mask.view()

cdef unique_ptr[table] c_result
Expand All @@ -585,29 +581,37 @@ def _boolean_mask_scatter_scalar(list input_scalars, target_table,
)
)

return data_from_unique_ptr(
move(c_result),
column_names=target_table._column_names,
index_names=target_table._index._column_names
)
return columns_from_unique_ptr(move(c_result))


# TODO: This function is currently unused but should be used in
# ColumnBase.__setitem__, see https://github.com/rapidsai/cudf/issues/8667.
def boolean_mask_scatter(object input, target_table,
def boolean_mask_scatter(list input_, list target_columns,
Column boolean_mask):
"""Copy the target columns, replacing masked rows with input data.
The ``input_`` data can be a list of columns or as a list of scalars.
A list of input columns will be used to replace corresponding rows in the
target columns for which the boolean mask is ``True``. For the nth ``True``
in the boolean mask, the nth row in ``input_`` is used to replace. A list
of input scalars will replace all rows in the target columns for which the
boolean mask is ``True``.
"""
if len(input_) != len(target_columns):
raise ValueError("Mismatched number of input and target columns.")

if len(input_) == 0:
return []

if isinstance(input, cudf.core.frame.Frame):
return _boolean_mask_scatter_table(
input,
target_table,
if isinstance(input_[0], Column):
return _boolean_mask_scatter_columns(
input_,
target_columns,
boolean_mask
)
else:
scalar_list = [as_device_scalar(i) for i in input]
scalar_list = [as_device_scalar(i) for i in input_]
return _boolean_mask_scatter_scalar(
scalar_list,
target_table,
target_columns,
boolean_mask
)

Expand Down
178 changes: 110 additions & 68 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@
from cudf.utils.utils import NotIterable, mask_dtype

T = TypeVar("T", bound="ColumnBase")
# TODO: This workaround allows type hints for `slice`, since `slice` is a
# method in ColumnBase.
Slice = TypeVar("Slice", bound=slice)


class ColumnBase(Column, Serializable, NotIterable):
Expand Down Expand Up @@ -317,22 +320,24 @@ def _fill(
if end <= begin or begin >= self.size:
return self if inplace else self.copy()

fill_scalar = as_device_scalar(fill_value, self.dtype)
# Constructing a cuDF scalar can cut unnecessary DtoH copy if
# the scalar is None when calling `is_valid`.
slr = cudf.Scalar(fill_value, dtype=self.dtype)

if not inplace:
return libcudf.filling.fill(self, begin, end, fill_scalar)
return libcudf.filling.fill(self, begin, end, slr.device_value)

if is_string_dtype(self.dtype):
return self._mimic_inplace(
libcudf.filling.fill(self, begin, end, fill_scalar),
libcudf.filling.fill(self, begin, end, slr.device_value),
inplace=True,
)

if fill_value is None and not self.nullable:
if not slr.is_valid() and not self.nullable:
mask = create_null_mask(self.size, state=MaskState.ALL_VALID)
self.set_base_mask(mask)

libcudf.filling.fill_in_place(self, begin, end, fill_scalar)
libcudf.filling.fill_in_place(self, begin, end, slr.device_value)

return self

Expand Down Expand Up @@ -491,82 +496,119 @@ def __getitem__(self, arg) -> Union[ScalarLike, ColumnBase]:

def __setitem__(self, key: Any, value: Any):
"""
Set the value of self[key] to value.
Set the value of ``self[key]`` to ``value``.
If value and self are of different types,
value is coerced to self.dtype
If ``value`` and ``self`` are of different types, ``value`` is coerced
to ``self.dtype``. Assumes ``self`` and ``value`` are index-aligned.
"""

# Normalize value to scalar/column
value_normalized = (
cudf.Scalar(value, dtype=self.dtype)
if is_scalar(value)
else as_column(value, dtype=self.dtype)
)

out: Optional[ColumnBase] # If None, no need to perform mimic inplace.
if isinstance(key, slice):
key_start, key_stop, key_stride = key.indices(len(self))
if key_start < 0:
key_start = key_start + len(self)
if key_stop < 0:
key_stop = key_stop + len(self)
if key_start >= key_stop:
return self.copy()
if (key_stride is None or key_stride == 1) and is_scalar(value):
return self._fill(value, key_start, key_stop, inplace=True)
if key_stride != 1 or key_stride is not None or is_scalar(value):
key = arange(
start=key_start,
stop=key_stop,
step=key_stride,
dtype=cudf.dtype(np.int32),
)
nelem = len(key)
else:
nelem = abs(key_stop - key_start)
out = self._scatter_by_slice(key, value_normalized)
else:
key = as_column(key)
if is_bool_dtype(key.dtype):
if not len(key) == len(self):
raise ValueError(
"Boolean mask must be of same length as column"
)
key = arange(len(self))[key]
if hasattr(value, "__len__") and len(value) == len(self):
value = as_column(value)[key]
nelem = len(key)
if not isinstance(key, cudf.core.column.NumericalColumn):
raise ValueError(f"Invalid scatter map type {key.dtype}.")
out = self._scatter_by_column(key, value_normalized)

if is_scalar(value):
value = cudf.Scalar(value, dtype=self.dtype)
else:
if len(value) != nelem:
msg = (
f"Size mismatch: cannot set value "
f"of size {len(value)} to indexing result of size "
f"{nelem}"
if out:
self._mimic_inplace(out, inplace=True)

def _scatter_by_slice(
self, key: Slice, value: Union[cudf.core.scalar.Scalar, ColumnBase]
) -> Optional[ColumnBase]:
"""If this function returns None, it's either a no-op (slice is empty),
or the inplace replacement is already performed (fill-in-place).
"""
start, stop, step = key.indices(len(self))
if start >= stop:
return None
num_keys = (stop - start) // step

self._check_scatter_key_length(num_keys, value)

if step == 1:
if isinstance(value, cudf.core.scalar.Scalar):
return self._fill(value, start, stop, inplace=True)
else:
return libcudf.copying.copy_range(
value, self, 0, num_keys, start, stop, False
)
raise ValueError(msg)
value = as_column(value).astype(self.dtype)

if (
isinstance(key, slice)
and (key_stride == 1 or key_stride is None)
and not is_scalar(value)
):
# step != 1, create a scatter map with arange
scatter_map = arange(
start=start, stop=stop, step=step, dtype=cudf.dtype(np.int32),
)

out = libcudf.copying.copy_range(
value, self, 0, nelem, key_start, key_stop, False
)
return self._scatter_by_column(scatter_map, value)

def _scatter_by_column(
self,
key: cudf.core.column.NumericalColumn,
value: Union[cudf.core.scalar.Scalar, ColumnBase],
) -> ColumnBase:
if is_bool_dtype(key.dtype):
# `key` is boolean mask
if len(key) != len(self):
raise ValueError(
"Boolean mask must be of same length as column"
)
if isinstance(value, ColumnBase) and len(self) == len(value):
# Both value and key are aligned to self. Thus, the values
# corresponding to the false values in key should be
# ignored.
value = value.apply_boolean_mask(key)
# After applying boolean mask, the length of value equals
# the number of elements to scatter, we can skip computing
# the sum of ``key`` below.
num_keys = len(value)
else:
# Compute the number of element to scatter by summing all
# `True`s in the boolean mask.
num_keys = key.sum()
else:
try:
if not isinstance(key, Column):
key = as_column(key)
if not is_scalar(value) and not isinstance(value, Column):
value = as_column(value)
out = libcudf.copying.scatter(
# `key` is integer scatter map
num_keys = len(key)

self._check_scatter_key_length(num_keys, value)

try:
if is_bool_dtype(key.dtype):
return libcudf.copying.boolean_mask_scatter(
[value], [self], key
)[0]._with_type_metadata(self.dtype)
else:
return libcudf.copying.scatter(
value, key, self
)._with_type_metadata(self.dtype)
except RuntimeError as e:
if "out of bounds" in str(e):
raise IndexError(
f"index out of bounds for column of size {len(self)}"
) from e
raise
except RuntimeError as e:
if "out of bounds" in str(e):
raise IndexError(
f"index out of bounds for column of size {len(self)}"
) from e
raise

self._mimic_inplace(out, inplace=True)
def _check_scatter_key_length(
self, num_keys: int, value: Union[cudf.core.scalar.Scalar, ColumnBase]
):
"""`num_keys` is the number of keys to scatter. Should equal to the
number of rows in ``value`` if ``value`` is a column.
"""
if isinstance(value, ColumnBase):
if len(value) != num_keys:
msg = (
f"Size mismatch: cannot set value "
f"of size {len(value)} to indexing result of size "
f"{num_keys}"
)
raise ValueError(msg)

def fillna(
self: T,
Expand Down Expand Up @@ -2247,7 +2289,7 @@ def arange(
stop: Union[int, float] = None,
step: Union[int, float] = 1,
dtype=None,
) -> ColumnBase:
) -> cudf.core.column.NumericalColumn:
"""
Returns a column with evenly spaced values within a given interval.
Expand Down
2 changes: 2 additions & 0 deletions python/cudf/cudf/core/df_protocol.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

import collections
import enum
from typing import (
Expand Down
16 changes: 12 additions & 4 deletions python/cudf/cudf/tests/test_df_protocol.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from typing import Any, Tuple

import cupy as cp
Expand Down Expand Up @@ -27,13 +29,19 @@ def assert_buffer_equal(buffer_and_dtype: Tuple[_CuDFBuffer, Any], cudfcol):
)
# check that non null values are the equals as nulls are represented
# by sentinel values in the buffer.
non_null_idxs = cudf.Series(cudfcol) != cudf.NA
# FIXME: In gh-10202 some minimal fixes were added to unblock CI. But
# currently only non-null values are compared, null positions are
# unchecked.
non_null_idxs = ~cudf.Series(cudfcol).isna()
assert_eq(col_from_buf[non_null_idxs], cudfcol[non_null_idxs])

if dtype[0] != _DtypeKind.BOOL:
array_from_dlpack = cp.fromDlpack(buf.__dlpack__())
col_array = cp.asarray(cudfcol.data_array_view)
assert_eq(array_from_dlpack.flatten(), col_array.flatten())
array_from_dlpack = cp.fromDlpack(buf.__dlpack__()).get()
col_array = cp.asarray(cudfcol.data_array_view).get()
assert_eq(
array_from_dlpack[non_null_idxs.to_numpy()].flatten(),
col_array[non_null_idxs.to_numpy()].flatten(),
)
else:
pytest.raises(TypeError, buf.__dlpack__)

Expand Down
Loading

0 comments on commit a72479f

Please sign in to comment.