Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly handle scalar indices in Index.__getitem__ #12955

Merged
merged 4 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions python/cudf/cudf/core/cut.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2023, NVIDIA CORPORATION.

from collections import abc

Expand Down Expand Up @@ -279,12 +279,8 @@ def cut(
if labels is not None:
if labels is not ordered and len(set(labels)) != len(labels):
# when we have duplicate labels and ordered is False, we
# should allow duplicate categories. The categories are
# returned in order
new_data = [interval_labels[i][0] for i in index_labels.values]
return cudf.CategoricalIndex(
new_data, categories=sorted(set(labels)), ordered=False
)
# should allow duplicate categories.
return interval_labels[index_labels]
wence- marked this conversation as resolved.
Show resolved Hide resolved

col = build_categorical_column(
categories=interval_labels,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ def __repr__(self):
@_cudf_nvtx_annotate
def __getitem__(self, index):
res = self._get_elements_from_column(index)
if not isinstance(index, int):
if isinstance(res, ColumnBase):
wence- marked this conversation as resolved.
Show resolved Hide resolved
res = as_index(res)
res.name = self.name
return res
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,3 +2886,20 @@ def test_index_to_pandas_nullable(data, expected_dtype):
expected = pd.Index(data, dtype=expected_dtype)

assert_eq(pi, expected)


class TestIndexScalarGetItem:
    """Check that integer-scalar indexing of an Index returns an element.

    Both the Python ``int`` and NumPy integer scalar types are exercised
    as the index.
    """

    @pytest.fixture(
        params=[range(1, 10, 2), [1, 2, 3], ["a", "b", "c"], [1.5, 2.5, 3.5]]
    )
    def index_values(self, request):
        """Source values: a range plus integer, string and float lists."""
        return request.param

    @pytest.fixture(params=[int, np.int8, np.int32, np.int64])
    def i(self, request):
        """Position 1 built with each of the integer scalar types."""
        return request.param(1)

    def test_scalar_getitem(self, index_values, i):
        index = cudf.Index(index_values)
        result = index[i]

        assert not isinstance(result, cudf.Index)
        # The type check alone passes for any non-Index object, so also
        # pin the value to the element at the same position in the input.
        assert result == index_values[i]
wence- marked this conversation as resolved.
Show resolved Hide resolved