Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TYP: Overload concat #41184

Merged
merged 15 commits on
Aug 22, 2021
3 changes: 2 additions & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,10 +2310,11 @@ def describe(self):
counts = self.value_counts(dropna=False)
freqs = counts / counts.sum()

from pandas import Index
from pandas.core.reshape.concat import concat

result = concat([counts, freqs], axis=1)
result.columns = ["counts", "freqs"]
result.columns = Index(["counts", "freqs"])
Copy link
Member

@MarcoGorelli MarcoGorelli Aug 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to import `Index` at or after line 2311 for this to work.

result.index.name = "categories"

return result
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5755,7 +5755,8 @@ def astype(
# GH 19920: retain column metadata after concat
result = concat(results, axis=1, copy=False)
result.columns = self.columns
return result
# https://github.com/python/mypy/issues/8354
return cast(FrameOrSeries, result)

@final
def copy(self: FrameOrSeries, deep: bool_t = True) -> FrameOrSeries:
Expand Down Expand Up @@ -6118,7 +6119,8 @@ def convert_dtypes(
for col_name, col in self.items()
]
if len(results) > 0:
return concat(results, axis=1, copy=False)
# https://github.com/python/mypy/issues/8354
return cast(FrameOrSeries, concat(results, axis=1, copy=False))
else:
return self.copy()

Expand Down
8 changes: 2 additions & 6 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,7 @@ def _aggregate_multiple_funcs(self, arg) -> DataFrame:
res_df = concat(
results.values(), axis=1, keys=[key.label for key in results.keys()]
)
# error: Incompatible return value type (got "Union[DataFrame, Series]",
# expected "DataFrame")
return res_df # type: ignore[return-value]
return res_df

indexed_output = {key.position: val for key, val in results.items()}
output = self.obj._constructor_expanddim(indexed_output, index=None)
Expand Down Expand Up @@ -547,9 +545,7 @@ def _transform_general(self, func: Callable, *args, **kwargs) -> Series:
result = self.obj._constructor(dtype=np.float64)

result.name = self.obj.name
# error: Incompatible return value type (got "Union[DataFrame, Series]",
# expected "Series")
return result # type: ignore[return-value]
return result

def _can_use_transform_fast(self, result) -> bool:
return True
Expand Down
78 changes: 64 additions & 14 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
TYPE_CHECKING,
Hashable,
Iterable,
Literal,
Mapping,
cast,
overload,
)

import numpy as np

from pandas._typing import Axis
from pandas.util._decorators import (
cache_readonly,
deprecate_nonkeyword_arguments,
Expand Down Expand Up @@ -57,25 +59,73 @@
@overload
def concat(
objs: Iterable[DataFrame] | Mapping[Hashable, DataFrame],
axis=0,
join: str = "outer",
ignore_index: bool = False,
keys=None,
levels=None,
names=None,
verify_integrity: bool = False,
sort: bool = False,
copy: bool = True,
axis: Literal[0, "index"] = ...,
join: str = ...,
ignore_index: bool = ...,
keys=...,
levels=...,
names=...,
verify_integrity: bool = ...,
sort: bool = ...,
copy: bool = ...,
) -> DataFrame:
...


@overload
def concat(
objs: Iterable[Series] | Mapping[Hashable, Series],
axis: Literal[0, "index"] = ...,
join: str = ...,
ignore_index: bool = ...,
keys=...,
levels=...,
names=...,
verify_integrity: bool = ...,
sort: bool = ...,
copy: bool = ...,
) -> Series:
...


@overload
def concat(
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
axis=0,
join: str = "outer",
ignore_index: bool = False,
axis: Literal[0, "index"] = ...,
join: str = ...,
ignore_index: bool = ...,
keys=...,
levels=...,
names=...,
verify_integrity: bool = ...,
sort: bool = ...,
copy: bool = ...,
) -> DataFrame | Series:
...


@overload
def concat(
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
axis: Literal[1, "columns"],
join: str = ...,
ignore_index: bool = ...,
keys=...,
levels=...,
names=...,
verify_integrity: bool = ...,
sort: bool = ...,
copy: bool = ...,
) -> DataFrame:
...


@overload
def concat(
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
axis: Axis = ...,
join: str = ...,
ignore_index: bool = ...,
keys=None,
levels=None,
names=None,
Expand All @@ -89,8 +139,8 @@ def concat(
@deprecate_nonkeyword_arguments(version=None, allowed_args=["objs"])
def concat(
objs: Iterable[NDFrame] | Mapping[Hashable, NDFrame],
axis=0,
join="outer",
axis: Axis = 0,
join: str = "outer",
ignore_index: bool = False,
keys=None,
levels=None,
Expand Down
12 changes: 3 additions & 9 deletions pandas/core/reshape/melt.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import re
from typing import (
TYPE_CHECKING,
cast,
)
from typing import TYPE_CHECKING
import warnings

import numpy as np
Expand Down Expand Up @@ -34,10 +31,7 @@
from pandas.core.tools.numeric import to_numeric

if TYPE_CHECKING:
from pandas import (
DataFrame,
Series,
)
from pandas import DataFrame


@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
Expand Down Expand Up @@ -136,7 +130,7 @@ def melt(
for col in id_vars:
id_data = frame.pop(col)
if is_extension_array_dtype(id_data):
id_data = cast("Series", concat([id_data] * K, ignore_index=True))
id_data = concat([id_data] * K, ignore_index=True)
else:
id_data = np.tile(id_data._values, K)
mdata[col] = id_data
Expand Down
10 changes: 2 additions & 8 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import itertools
from typing import (
TYPE_CHECKING,
cast,
)
from typing import TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -1059,10 +1056,7 @@ def get_empty_frame(data) -> DataFrame:
)
sparse_series.append(Series(data=sarr, index=index, name=col))

out = concat(sparse_series, axis=1, copy=False)
# TODO: overload concat with Literal for axis
out = cast(DataFrame, out)
return out
return concat(sparse_series, axis=1, copy=False)

else:
# take on axis=1 + transpose to ensure ndarray layout is column-major
Expand Down