Skip to content

Commit

Permalink
Correctly handle scalar indices in Index.__getitem__ (#12955)
Browse files Browse the repository at this point in the history
Checking `isinstance(index, int)` is not sufficient, because the index may be a NumPy integer scalar (for example `np.int64`), for which this check is False. Instead, invert the condition: check whether the return value of `_get_elements_from_column` is a Column, and only in that case convert it so that we get an Index back.

Closes #12954.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #12955
  • Loading branch information
wence- authored Mar 21, 2023
1 parent 832dd27 commit f567cf5
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
10 changes: 3 additions & 7 deletions python/cudf/cudf/core/cut.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION.
# Copyright (c) 2021-2023, NVIDIA CORPORATION.

from collections import abc

Expand Down Expand Up @@ -279,12 +279,8 @@ def cut(
if labels is not None:
if labels is not ordered and len(set(labels)) != len(labels):
# when we have duplicate labels and ordered is False, we
# should allow duplicate categories. The categories are
# returned in order
new_data = [interval_labels[i][0] for i in index_labels.values]
return cudf.CategoricalIndex(
new_data, categories=sorted(set(labels)), ordered=False
)
# should allow duplicate categories.
return interval_labels[index_labels]

col = build_categorical_column(
categories=interval_labels,
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,7 +1403,7 @@ def __repr__(self):
@_cudf_nvtx_annotate
def __getitem__(self, index):
# NOTE: this is a rendered diff hunk of Index.__getitem__. Leading
# indentation was lost, and the removed and added condition lines appear
# back to back without +/- markers.
#
# Look up `index` via the column helper; the result is either a column or
# a single element (presumably a scalar -- confirm against the helper).
res = self._get_elements_from_column(index)
# Old condition (the line this commit deletes, per the commit message):
# it keyed on the Python type of `index`, so a NumPy integer scalar failed
# the `int` check and was not treated as a scalar lookup.
if not isinstance(index, int):
# New condition (the line this commit adds): it keys on the result instead.
# Only a ColumnBase result is converted back to an Index and given this
# index's name; any other result is returned unchanged.
if isinstance(res, ColumnBase):
res = as_index(res)
res.name = self.name
return res
Expand Down
19 changes: 19 additions & 0 deletions python/cudf/cudf/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,3 +2886,22 @@ def test_index_to_pandas_nullable(data, expected_dtype):
expected = pd.Index(data, dtype=expected_dtype)

assert_eq(pi, expected)


class TestIndexScalarGetItem:
    """Scalar ``Index.__getitem__`` with Python and NumPy integer positions."""

    @pytest.fixture(
        params=[
            range(1, 10, 2),
            [1, 2, 3],
            ["a", "b", "c"],
            [1.5, 2.5, 3.5],
        ]
    )
    def index_values(self, request):
        """Parametrized input sequences: a range, ints, strings and floats."""
        return request.param

    @pytest.fixture(params=[int, np.int8, np.int32, np.int64])
    def i(self, request):
        """Position 1 expressed as each parametrized integer type."""
        integer_type = request.param
        return integer_type(1)

    def test_scalar_getitem(self, index_values, i):
        """A scalar position yields the element itself, not an Index."""
        idx = cudf.Index(index_values)

        # The lookup must not be wrapped back into an Index ...
        assert not isinstance(idx[i], cudf.Index)
        # ... and must match the element at the same position in the input.
        assert idx[i] == index_values[i]
        # The index still compares equal to its pandas conversion.
        assert_eq(idx, idx.to_pandas())

0 comments on commit f567cf5

Please sign in to comment.