Skip to content

Commit

Permalink
Fix result type of Series.sum(...) (#604)
Browse files Browse the repository at this point in the history
* Fix result type of `Series.sum(...)`

* Fix note comment in `test_types_sum()`

* Make tests in `test_types_sum()` a bit stricter
  • Loading branch information
Azureblade3808 authored Mar 30, 2023
1 parent e4731ae commit 2e3bbe8
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
29 changes: 28 additions & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ from pandas.core.window.rolling import (
Rolling,
Window,
)
from typing_extensions import TypeAlias
from typing_extensions import (
Never,
TypeAlias,
)
import xarray as xr

from pandas._libs.interval import Interval
Expand Down Expand Up @@ -1811,6 +1814,30 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
fill_value: float | None = ...,
axis: AxisIndex | None = ...,
) -> Series[S1]: ...
# ignore needed because of mypy, for using `Never` as type-var.
@overload
def sum(
self: Series[Never], # type: ignore[type-var]
axis: AxisIndex | None = ...,
skipna: _bool | None = ...,
level: None = ...,
numeric_only: _bool = ...,
min_count: int = ...,
**kwargs,
) -> Any: ...
# ignore needed because of mypy, for overlapping overloads
# between `Series[bool]` and `Series[int]`.
@overload
def sum( # type: ignore[misc]
self: Series[bool],
axis: AxisIndex | None = ...,
skipna: _bool | None = ...,
level: None = ...,
numeric_only: _bool = ...,
min_count: int = ...,
**kwargs,
) -> int: ...
@overload
def sum(
self: Series[S1],
axis: AxisIndex | None = ...,
Expand Down
30 changes: 30 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,36 @@ def test_types_sum() -> None:
s.sum(numeric_only=False)
s.sum(min_count=4)

# Note:
# 1. Return types of `series.groupby(...).sum(...)` are NOT tested
# (waiting for stubs).
# 2. Runtime return types of `series.sum(min_count=...)` are NOT
# tested (because of potential `nan`s).

s0 = assert_type(pd.Series([1, 2, 3, np.nan]), "pd.Series")
check(assert_type(s0.sum(), "Any"), np.float64)
check(assert_type(s0.sum(skipna=False), "Any"), np.float64)
check(assert_type(s0.sum(numeric_only=False), "Any"), np.float64)
assert_type(s0.sum(min_count=4), "Any")

s1 = assert_type(pd.Series([False, True], dtype=bool), "pd.Series[bool]")
check(assert_type(s1.sum(), "int"), np.int64)
check(assert_type(s1.sum(skipna=False), "int"), np.int64)
check(assert_type(s1.sum(numeric_only=False), "int"), np.int64)
assert_type(s1.sum(min_count=4), "int")

s2 = assert_type(pd.Series([0, 1], dtype=int), "pd.Series[int]")
check(assert_type(s2.sum(), "int"), np.int64)
check(assert_type(s2.sum(skipna=False), "int"), np.int64)
check(assert_type(s2.sum(numeric_only=False), "int"), np.int64)
assert_type(s2.sum(min_count=4), "int")

s3 = assert_type(pd.Series([1, 2, 3, np.nan], dtype=float), "pd.Series[float]")
check(assert_type(s3.sum(), "float"), np.float64)
check(assert_type(s3.sum(skipna=False), "float"), np.float64)
check(assert_type(s3.sum(numeric_only=False), "float"), np.float64)
assert_type(s3.sum(min_count=4), "float")


def test_types_cumsum() -> None:
s = pd.Series([1, 2, 3, np.nan])
Expand Down

0 comments on commit 2e3bbe8

Please sign in to comment.