From 2e3bbe81dbc269504e84b035621938dc193406bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E7=AB=8B=E4=B8=9A=EF=BC=88Chris=20Fu=EF=BC=89?= <17433201@qq.com> Date: Thu, 30 Mar 2023 22:54:24 +0800 Subject: [PATCH] Fix result type of `Series.sum(...)` (#604) * Fix result type of `Series.sum(...)` * Fix note comment in `test_types_sum()` * Make tests in `test_types_sum()` a bit stricter --- pandas-stubs/core/series.pyi | 29 ++++++++++++++++++++++++++++- tests/test_series.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 32ad5cb5..71f532f9 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -70,7 +70,10 @@ from pandas.core.window.rolling import ( Rolling, Window, ) -from typing_extensions import TypeAlias +from typing_extensions import ( + Never, + TypeAlias, +) import xarray as xr from pandas._libs.interval import Interval @@ -1811,6 +1814,30 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): fill_value: float | None = ..., axis: AxisIndex | None = ..., ) -> Series[S1]: ... + # ignore needed because of mypy, for using `Never` as type-var. + @overload + def sum( + self: Series[Never], # type: ignore[type-var] + axis: AxisIndex | None = ..., + skipna: _bool | None = ..., + level: None = ..., + numeric_only: _bool = ..., + min_count: int = ..., + **kwargs, + ) -> Any: ... + # ignore needed because of mypy, for overlapping overloads + # between `Series[bool]` and `Series[int]`. + @overload + def sum( # type: ignore[misc] + self: Series[bool], + axis: AxisIndex | None = ..., + skipna: _bool | None = ..., + level: None = ..., + numeric_only: _bool = ..., + min_count: int = ..., + **kwargs, + ) -> int: ... + @overload def sum( self: Series[S1], axis: AxisIndex | None = ..., diff --git a/tests/test_series.py b/tests/test_series.py index 934c40c4..1bfb5e88 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -332,6 +332,36 @@ def test_types_sum() -> None: s.sum(numeric_only=False) s.sum(min_count=4) + # Note: + # 1. Return types of `series.groupby(...).sum(...)` are NOT tested + # (waiting for stubs). + # 2. Runtime return types of `series.sum(min_count=...)` are NOT + # tested (because of potential `nan`s). + + s0 = assert_type(pd.Series([1, 2, 3, np.nan]), "pd.Series") + check(assert_type(s0.sum(), "Any"), np.float64) + check(assert_type(s0.sum(skipna=False), "Any"), np.float64) + check(assert_type(s0.sum(numeric_only=False), "Any"), np.float64) + assert_type(s0.sum(min_count=4), "Any") + + s1 = assert_type(pd.Series([False, True], dtype=bool), "pd.Series[bool]") + check(assert_type(s1.sum(), "int"), np.int64) + check(assert_type(s1.sum(skipna=False), "int"), np.int64) + check(assert_type(s1.sum(numeric_only=False), "int"), np.int64) + assert_type(s1.sum(min_count=4), "int") + + s2 = assert_type(pd.Series([0, 1], dtype=int), "pd.Series[int]") + check(assert_type(s2.sum(), "int"), np.int64) + check(assert_type(s2.sum(skipna=False), "int"), np.int64) + check(assert_type(s2.sum(numeric_only=False), "int"), np.int64) + assert_type(s2.sum(min_count=4), "int") + + s3 = assert_type(pd.Series([1, 2, 3, np.nan], dtype=float), "pd.Series[float]") + check(assert_type(s3.sum(), "float"), np.float64) + check(assert_type(s3.sum(skipna=False), "float"), np.float64) + check(assert_type(s3.sum(numeric_only=False), "float"), np.float64) + assert_type(s3.sum(min_count=4), "float") + def test_types_cumsum() -> None: s = pd.Series([1, 2, 3, np.nan])