From a2ecf3190dad78e43f77d6132d835ef22dfb3382 Mon Sep 17 00:00:00 2001 From: Dmitry Chigarev <62142979+dchigarev@users.noreply.github.com> Date: Mon, 15 Feb 2021 15:33:28 +0300 Subject: [PATCH] FIX-#2362: fix key handling in 'Series.__setitem__' (#2731) Signed-off-by: Dmitry Chigarev --- modin/pandas/series.py | 15 ++++++++++----- modin/pandas/test/test_series.py | 26 +++++++++++++++++++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/modin/pandas/series.py b/modin/pandas/series.py index ecbd8d0b78a..febd84e3055 100644 --- a/modin/pandas/series.py +++ b/modin/pandas/series.py @@ -285,11 +285,16 @@ def __round__(self, decimals=0): ) def __setitem__(self, key, value): - if key not in self.keys(): - raise KeyError(key) - self._create_or_update_from_compiler( - self._query_compiler.setitem(1, key, value), inplace=True - ) + if isinstance(key, slice) and ( + isinstance(key.start, int) or isinstance(key.stop, int) + ): + # There could be two type of slices: + # - Location based slice (1:5) + # - Labels based slice ("a":"e") + # For location based slice we're going to `iloc`, since `loc` can't manage it. + self.iloc[key] = value + else: + self.loc[key] = value def __sub__(self, right): return self.sub(right) diff --git a/modin/pandas/test/test_series.py b/modin/pandas/test/test_series.py index 49246ae80ee..4dccd3bc5fd 100644 --- a/modin/pandas/test/test_series.py +++ b/modin/pandas/test/test_series.py @@ -180,13 +180,13 @@ def inter_df_math_helper_one_side(modin_series, pandas_series, op): pass -def create_test_series(vals, sort=False): +def create_test_series(vals, sort=False, **kwargs): if isinstance(vals, dict): - modin_series = pd.Series(vals[next(iter(vals.keys()))]) - pandas_series = pandas.Series(vals[next(iter(vals.keys()))]) + modin_series = pd.Series(vals[next(iter(vals.keys()))], **kwargs) + pandas_series = pandas.Series(vals[next(iter(vals.keys()))], **kwargs) else: - modin_series = pd.Series(vals) - pandas_series = pandas.Series(vals) + modin_series = pd.Series(vals, **kwargs) + pandas_series = pandas.Series(vals, **kwargs) if sort: modin_series = modin_series.sort_values().reset_index(drop=True) pandas_series = pandas_series.sort_values().reset_index(drop=True) @@ -526,6 +526,22 @@ def test___setitem__(data): df_equals(modin_series, pandas_series) +@pytest.mark.parametrize( + "key", + [ + pytest.param(slice(1, 3), id="numeric_slice"), + pytest.param(slice("a", "c"), id="index_based_slice"), + pytest.param(["a", "c", "e"], id="list_of_labels"), + pytest.param([True, False, True, False, True], id="boolean_mask"), + ], +) +def test___setitem___non_hashable(key): + md_sr, pd_sr = create_test_series([1, 2, 3, 4, 5], index=["a", "b", "c", "d", "e"]) + md_sr[key] = 10 + pd_sr[key] = 10 + df_equals(md_sr, pd_sr) + + @pytest.mark.parametrize("data", test_data_values, ids=test_data_keys) def test___sizeof__(data): modin_series, pandas_series = create_test_series(data)