From f55107bf5bb4ee458b2887a9e94e14323668a37d Mon Sep 17 00:00:00 2001
From: galipremsagar
Date: Wed, 15 May 2024 18:59:17 +0000
Subject: [PATCH] Fix DatetimeIndex.loc

Route every DatetimeIndex label slice that has at least one bound
through the datetime branch of _get_label_range_or_mask; the branch
used to be taken only when is_monotonic_increasing was False.

- Pass the converted timestamps to get_loc directly instead of their
  to_datetime64() form.
- For a slice with a single bound, return a comparison mask whose
  direction follows the index ordering (>= / <= for an increasing
  index, flipped for a decreasing one) and defer to find_label_range
  when the index is neither.
- Add tests that compare open-ended .loc slices against pandas for
  increasing, decreasing and unordered datetime indexes.

---
 python/cudf/cudf/core/indexed_frame.py  | 20 +++++--
 python/cudf/cudf/tests/test_indexing.py | 74 +++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 6 deletions(-)

diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py
index 904cd0c69c2..ec922b6e24c 100644
--- a/python/cudf/cudf/core/indexed_frame.py
+++ b/python/cudf/cudf/core/indexed_frame.py
@@ -195,7 +195,6 @@ def _get_label_range_or_mask(index, start, stop, step):
     if (
         not (start is None and stop is None)
         and type(index) is cudf.core.index.DatetimeIndex
-        and index.is_monotonic_increasing is False
     ):
         start = pd.to_datetime(start)
         stop = pd.to_datetime(stop)
@@ -206,8 +205,8 @@ def _get_label_range_or_mask(index, start, stop, step):
                 # when we have a non-monotonic datetime index, return
                 # values in the slice defined by index_of(start) and
                 # index_of(end)
-                start_loc = index.get_loc(start.to_datetime64())
-                stop_loc = index.get_loc(stop.to_datetime64()) + 1
+                start_loc = index.get_loc(start)
+                stop_loc = index.get_loc(stop) + 1
                 return slice(start_loc, stop_loc)
             else:
                 raise KeyError(
@@ -215,10 +214,19 @@ def _get_label_range_or_mask(index, start, stop, step):
                     "DatetimeIndexes with non-existing keys is not allowed.",
                 )
         elif start is not None:
-            boolean_mask = index >= start
+            if index.is_monotonic_increasing:
+                return index >= start
+            elif index.is_monotonic_decreasing:
+                return index <= start
+            else:
+                return index.find_label_range(slice(start, stop, step))
         else:
-            boolean_mask = index <= stop
-        return boolean_mask
+            if index.is_monotonic_increasing:
+                return index <= stop
+            elif index.is_monotonic_decreasing:
+                return index >= stop
+            else:
+                return index.find_label_range(slice(start, stop, step))
     else:
         return index.find_label_range(slice(start, stop, step))
 
diff --git a/python/cudf/cudf/tests/test_indexing.py b/python/cudf/cudf/tests/test_indexing.py
index f49b9b02076..fcea945f595 100644
--- a/python/cudf/cudf/tests/test_indexing.py
+++ b/python/cudf/cudf/tests/test_indexing.py
@@ -2264,3 +2264,77 @@ def test_loc_setitem_empty_dataframe():
     gdf.loc[["index_1"], "new_col"] = "A"
 
     assert_eq(pdf, gdf)
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        [15, 14, 12, 10, 1],
+        [1, 10, 12, 14, 15],
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar",
+    [
+        1,
+        10,
+        15,
+        14,
+        0,
+        2,
+    ],
+)
+def test_loc_datetime_monotonic_with_ts(data, scalar):
+    gdf = cudf.DataFrame(
+        {"a": [1, 1, 1, 2, 2], "b": [1, 2, 3, 4, 5]},
+        index=cudf.Index(data, dtype="datetime64[ns]"),
+    )
+    pdf = gdf.to_pandas()
+
+    i = pd.Timestamp(scalar)
+
+    actual = gdf.loc[i:]
+    expected = pdf.loc[i:]
+
+    assert_eq(actual, expected)
+
+    actual = gdf.loc[:i]
+    expected = pdf.loc[:i]
+
+    assert_eq(actual, expected)
+
+
+@pytest.mark.parametrize("data", [[15, 14, 3, 10, 1]])
+@pytest.mark.parametrize("scalar", [1, 10, 15, 14, 0, 2])
+def test_loc_datetime_random_with_ts(data, scalar):
+    gdf = cudf.DataFrame(
+        {"a": [1, 1, 1, 2, 2], "b": [1, 2, 3, 4, 5]},
+        index=cudf.Index(data, dtype="datetime64[ns]"),
+    )
+    pdf = gdf.to_pandas()
+
+    i = pd.Timestamp(scalar)
+
+    if i not in pdf.index:
+        assert_exceptions_equal(
+            lambda: pdf.loc[i:],
+            lambda: gdf.loc[i:],
+            lfunc_args_and_kwargs=([],),
+            rfunc_args_and_kwargs=([],),
+        )
+        assert_exceptions_equal(
+            lambda: pdf.loc[:i],
+            lambda: gdf.loc[:i],
+            lfunc_args_and_kwargs=([],),
+            rfunc_args_and_kwargs=([],),
+        )
+    else:
+        actual = gdf.loc[i:]
+        expected = pdf.loc[i:]
+
+        assert_eq(actual, expected)
+
+        actual = gdf.loc[:i]
+        expected = pdf.loc[:i]
+
+        assert_eq(actual, expected)