Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Fix Slicing issue with categorical column in DataFrame #4683

Merged
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@
- PR #4630 Remove dangling reference to RMM exec policy in drop duplicates tests.
- PR #4625 Fix hash-based repartition bug in dask_cudf
- PR #4662 Fix to handle `keep_index` in `partition_by_hash`
- PR #4683 Fix Slicing issue with categorical column in DataFrame
- PR #4676 Fix bug in `_shuffle_group` for repartition


Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/_libxx/column.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ cdef class Column:

cdef libcudf_types.type_id tid = np_to_cudf_types[np.dtype(data_dtype)]
cdef libcudf_types.data_type dtype = libcudf_types.data_type(tid)
cdef libcudf_types.size_type offset = self.offset
cdef libcudf_types.size_type offset = col.offset
cdef vector[mutable_column_view] children
cdef void* data

Expand Down Expand Up @@ -353,7 +353,7 @@ cdef class Column:
data_dtype = col.dtype
cdef libcudf_types.type_id tid = np_to_cudf_types[np.dtype(data_dtype)]
cdef libcudf_types.data_type dtype = libcudf_types.data_type(tid)
cdef libcudf_types.size_type offset = self.offset
cdef libcudf_types.size_type offset = col.offset
cdef vector[column_view] children
cdef void* data

Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ def children(self):
codes_column = column.build_column(
data=codes_column.base_data,
dtype=codes_column.dtype,
mask=codes_column.base_mask,
mask=self.base_mask,
size=self.size,
offset=self.offset + codes_column.offset,
offset=self.offset,
)
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
self._children = (codes_column,)
return self._children
Expand All @@ -347,7 +347,7 @@ def categories(self, value):
@property
def codes(self):
if self._codes is None:
self._codes = self.children[0].set_mask(self.mask)
self._codes = self.children[0]
return self._codes

@property
Expand Down
18 changes: 10 additions & 8 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,13 @@ def __getitem__(self, arg):

if is_categorical_dtype(self):
codes = self.codes[arg]
return build_column(
data=None,
dtype=self.dtype,
mask=codes.mask,
children=(codes,),
return build_categorical_column(
categories=self.categories,
codes=as_column(codes.base_data, dtype=codes.dtype),
mask=codes.base_mask,
ordered=self.ordered,
size=codes.size,
offset=codes.offset,
)

start, stop, stride = arg.indices(len(self))
Expand Down Expand Up @@ -547,7 +549,7 @@ def __setitem__(self, key, value):
if is_categorical_dtype(value.dtype):
out = build_categorical_column(
categories=value.categories,
codes=out,
codes=as_column(out.base_data, dtype=out.dtype),
mask=out.base_mask,
size=out.size,
offset=out.offset,
Expand All @@ -558,14 +560,14 @@ def __setitem__(self, key, value):
if is_scalar(value):
input = self
if is_categorical_dtype(self.dtype):
input = self.codes
input = self.children[0]
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved

out = input.as_frame()._scatter(key, [value])._as_column()

if is_categorical_dtype(self.dtype):
out = build_categorical_column(
categories=self.categories,
codes=out,
codes=as_column(out.base_data, dtype=out.dtype),
mask=out.base_mask,
size=out.size,
offset=out.offset,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def _copy_categories(self, other, include_index=True):
):
self._data[name] = build_categorical_column(
categories=other_col.categories,
codes=col,
codes=as_column(col.base_data, dtype=col.dtype),
mask=col.base_mask,
ordered=other_col.ordered,
size=col.size,
Expand Down
10 changes: 10 additions & 0 deletions python/cudf/cudf/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,13 @@ def test_categorical_remove_categories(pd_str_cat, inplace):
cd_sr_1 = cd_sr.cat.remove_categories(["a", "d"], inplace=inplace)

raises.match("removals must all be in old categories")


def test_categorical_dataframe_slice_copy():
    """Check that copying a sliced categorical DataFrame matches pandas.

    Builds a single-column categorical frame in pandas, mirrors it into a
    cudf ``DataFrame``, drops the first row from both via ``[1:]``, copies
    each slice, and compares the two results.
    """
    categorical = pd.Series(["a", "b", "z"], dtype="category")
    pdf = pd.DataFrame({"g": categorical})
    gdf = DataFrame.from_pandas(pdf)

    # Slice first, then copy: the copy must honour the slice's offset.
    expected = pdf[1:].copy()
    actual = gdf[1:].copy()

    assert_eq(expected, actual)