From f567cf5b100ba71934d08ac3c1c9648ce09594e1 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 21 Mar 2023 21:40:34 +0000 Subject: [PATCH] Correctly handle scalar indices in `Index.__getitem__` (#12955) It is not sufficient to check for isinstance(i, int) since the index may be a numpy type for which this check is False. Instead, invert the condition and check if the return value from _get_elements_from_column is a Column, in which case we should get an Index back. Closes #12954. Authors: - Lawrence Mitchell (https://github.com/wence-) - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/12955 --- python/cudf/cudf/core/cut.py | 10 +++------- python/cudf/cudf/core/index.py | 2 +- python/cudf/cudf/tests/test_index.py | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/python/cudf/cudf/core/cut.py b/python/cudf/cudf/core/cut.py index 6590cf2940d..ccf730c91fb 100644 --- a/python/cudf/cudf/core/cut.py +++ b/python/cudf/cudf/core/cut.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2022, NVIDIA CORPORATION. +# Copyright (c) 2021-2023, NVIDIA CORPORATION. from collections import abc @@ -279,12 +279,8 @@ def cut( if labels is not None: if labels is not ordered and len(set(labels)) != len(labels): # when we have duplicate labels and ordered is False, we - # should allow duplicate categories. The categories are - # returned in order - new_data = [interval_labels[i][0] for i in index_labels.values] - return cudf.CategoricalIndex( - new_data, categories=sorted(set(labels)), ordered=False - ) + # should allow duplicate categories. + return interval_labels[index_labels] col = build_categorical_column( categories=interval_labels, diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 413e005b798..d1408fec160 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -1403,7 +1403,7 @@ def __repr__(self): @_cudf_nvtx_annotate def __getitem__(self, index): res = self._get_elements_from_column(index) - if not isinstance(index, int): + if isinstance(res, ColumnBase): res = as_index(res) res.name = self.name return res diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index d043b917251..0b0c5fba7fa 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -2886,3 +2886,22 @@ def test_index_to_pandas_nullable(data, expected_dtype): expected = pd.Index(data, dtype=expected_dtype) assert_eq(pi, expected) + + +class TestIndexScalarGetItem: + @pytest.fixture( + params=[range(1, 10, 2), [1, 2, 3], ["a", "b", "c"], [1.5, 2.5, 3.5]] + ) + def index_values(self, request): + return request.param + + @pytest.fixture(params=[int, np.int8, np.int32, np.int64]) + def i(self, request): + return request.param(1) + + def test_scalar_getitem(self, index_values, i): + index = cudf.Index(index_values) + + assert not isinstance(index[i], cudf.Index) + assert index[i] == index_values[i] + assert_eq(index, index.to_pandas())