Skip to content

Commit

Permalink
TYP: Narrow down types of arguments (DataFrame) (#52752)
Browse files Browse the repository at this point in the history
* Specify method in reindex for class dataframe

Specify parser in to_xml of class dataframe

Update doc string to_orc in class dataframe

Specify engine in to_parquet in class dataframe

* undo changes and adding None as an optional argument type for validate argument of join and merge method

Change byteorder argument typing for to_stata method to literal, added definition in pandas/_typing.py

Change if_exists argument typing for to_gbq method to literal, added definition in pandas/_typing.py

Change orient argument typing for from_dict method to literal, added definition in pandas/_typing.py

Change how argument typing for to_timestamp method to literal, added definition in pandas/_typing.py

Change validate argument typing for merge and join methods to literal, added definition in pandas/_typing.py

Change na_action arguments typing for applymap method to literal, added definition in pandas/_typing.py

Change join and errors arguments typing for update method to literal, added definition in pandas/_typing.py

Change keep argument typing for nlargest and nsmallest to literal, added definition in pandas/_typing.py

Specify the kind and na_position more precisely in sort_values, reusing type definitions in pandas/_typing.py

* removing none from literal and adding it to the argument of applymap

* adding reindex literal to super class NDFrame as it violates the Liskov substitution principle otherwise

* adding reindex literal to super class NDFrame as it violates the Liskov substitution principle otherwise

* adding literal to missing.py

* ignore type for orient in from_dict method of frame due to mypy error

* pulling main and resolving merge conflict

---------

Co-authored-by: Patrick Schleiter <[email protected]>
  • Loading branch information
benedikt-mangold and pschleiter authored Apr 20, 2023
1 parent dd2f0d2 commit e09a193
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 23 deletions.
43 changes: 43 additions & 0 deletions pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@
]
Timezone = Union[str, tzinfo]

ToTimestampHow = Literal["s", "e", "start", "end"]

# NDFrameT is stricter and ensures that the same subclass of NDFrame always is
# used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a
# Series is passed into a function, a Series is always returned and if a DataFrame is
Expand Down Expand Up @@ -361,6 +363,9 @@ def closed(self) -> bool:
SortKind = Literal["quicksort", "mergesort", "heapsort", "stable"]
NaPosition = Literal["first", "last"]

# Arguments for nsmalles and n_largest
NsmallestNlargestKeep = Literal["first", "last", "all"]

# quantile interpolation
QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"]

Expand All @@ -372,9 +377,32 @@ def closed(self) -> bool:

# merge
MergeHow = Literal["left", "right", "inner", "outer", "cross"]
MergeValidate = Literal[
"one_to_one",
"1:1",
"one_to_many",
"1:m",
"many_to_one",
"m:1",
"many_to_many",
"m:m",
]

# join
JoinHow = Literal["left", "right", "inner", "outer"]
JoinValidate = Literal[
"one_to_one",
"1:1",
"one_to_many",
"1:m",
"many_to_one",
"m:1",
"many_to_many",
"m:m",
]

# reindex
ReindexMethod = Union[FillnaOptions, Literal["nearest"]]

MatplotlibColor = Union[str, Sequence[float]]
TimeGrouperOrigin = Union[
Expand All @@ -400,3 +428,18 @@ def closed(self) -> bool:
"backslashreplace",
"namereplace",
]

# update
UpdateJoin = Literal["left"]

# applymap
NaAction = Literal["ignore"]

# from_dict
FromDictOrient = Literal["columns", "index", "tight"]

# to_gbc
ToGbqIfexist = Literal["fail", "replace", "append"]

# to_stata
ToStataByteorder = Literal[">", "<", "little", "big"]
55 changes: 35 additions & 20 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,23 +219,34 @@
FloatFormatType,
FormattersType,
Frequency,
FromDictOrient,
IgnoreRaise,
IndexKeyFunc,
IndexLabel,
JoinValidate,
Level,
MergeHow,
MergeValidate,
NaAction,
NaPosition,
NsmallestNlargestKeep,
PythonFuncType,
QuantileInterpolation,
ReadBuffer,
ReindexMethod,
Renamer,
Scalar,
Self,
SortKind,
StorageOptions,
Suffixes,
ToGbqIfexist,
ToStataByteorder,
ToTimestampHow,
UpdateJoin,
ValueKeyFunc,
WriteBuffer,
XMLParsers,
npt,
)

Expand Down Expand Up @@ -1637,7 +1648,7 @@ def __rmatmul__(self, other) -> DataFrame:
def from_dict(
cls,
data: dict,
orient: str = "columns",
orient: FromDictOrient = "columns",
dtype: Dtype | None = None,
columns: Axes | None = None,
) -> DataFrame:
Expand Down Expand Up @@ -1724,7 +1735,7 @@ def from_dict(
c 2 4
"""
index = None
orient = orient.lower()
orient = orient.lower() # type: ignore[assignment]
if orient == "index":
if len(data) > 0:
# TODO speed up Series case
Expand Down Expand Up @@ -1981,7 +1992,7 @@ def to_gbq(
project_id: str | None = None,
chunksize: int | None = None,
reauth: bool = False,
if_exists: str = "fail",
if_exists: ToGbqIfexist = "fail",
auth_local_webserver: bool = True,
table_schema: list[dict[str, str]] | None = None,
location: str | None = None,
Expand Down Expand Up @@ -2535,7 +2546,7 @@ def to_stata(
*,
convert_dates: dict[Hashable, str] | None = None,
write_index: bool = True,
byteorder: str | None = None,
byteorder: ToStataByteorder | None = None,
time_stamp: datetime.datetime | None = None,
data_label: str | None = None,
variable_labels: dict[Hashable, str] | None = None,
Expand Down Expand Up @@ -2763,7 +2774,7 @@ def to_markdown(
def to_parquet(
self,
path: None = ...,
engine: str = ...,
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
compression: str | None = ...,
index: bool | None = ...,
partition_cols: list[str] | None = ...,
Expand All @@ -2776,7 +2787,7 @@ def to_parquet(
def to_parquet(
self,
path: FilePath | WriteBuffer[bytes],
engine: str = ...,
engine: Literal["auto", "pyarrow", "fastparquet"] = ...,
compression: str | None = ...,
index: bool | None = ...,
partition_cols: list[str] | None = ...,
Expand All @@ -2789,7 +2800,7 @@ def to_parquet(
def to_parquet(
self,
path: FilePath | WriteBuffer[bytes] | None = None,
engine: str = "auto",
engine: Literal["auto", "pyarrow", "fastparquet"] = "auto",
compression: str | None = "snappy",
index: bool | None = None,
partition_cols: list[str] | None = None,
Expand Down Expand Up @@ -2919,7 +2930,7 @@ def to_orc(
we refer to objects with a write() method, such as a file handle
(e.g. via builtin open function). If path is None,
a bytes object is returned.
engine : str, default 'pyarrow'
engine : {'pyarrow'}, default 'pyarrow'
ORC library to use. Pyarrow must be >= 7.0.0.
index : bool, optional
If ``True``, include the dataframe's index(es) in the file output.
Expand Down Expand Up @@ -3155,7 +3166,7 @@ def to_xml(
encoding: str = "utf-8",
xml_declaration: bool | None = True,
pretty_print: bool | None = True,
parser: str | None = "lxml",
parser: XMLParsers | None = "lxml",
stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None,
compression: CompressionOptions = "infer",
storage_options: StorageOptions = None,
Expand Down Expand Up @@ -4988,7 +4999,7 @@ def reindex(
index=None,
columns=None,
axis: Axis | None = None,
method: str | None = None,
method: ReindexMethod | None = None,
copy: bool | None = None,
level: Level | None = None,
fill_value: Scalar | None = np.nan,
Expand Down Expand Up @@ -6521,8 +6532,8 @@ def sort_values(
axis: Axis = ...,
ascending=...,
inplace: Literal[False] = ...,
kind: str = ...,
na_position: str = ...,
kind: SortKind = ...,
na_position: NaPosition = ...,
ignore_index: bool = ...,
key: ValueKeyFunc = ...,
) -> DataFrame:
Expand Down Expand Up @@ -7077,7 +7088,9 @@ def value_counts(

return counts

def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
def nlargest(
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
) -> DataFrame:
"""
Return the first `n` rows ordered by `columns` in descending order.
Expand Down Expand Up @@ -7184,7 +7197,9 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram
"""
return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest()

def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame:
def nsmallest(
self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first"
) -> DataFrame:
"""
Return the first `n` rows ordered by `columns` in ascending order.
Expand Down Expand Up @@ -8348,10 +8363,10 @@ def combiner(x, y):
def update(
self,
other,
join: str = "left",
join: UpdateJoin = "left",
overwrite: bool = True,
filter_func=None,
errors: str = "ignore",
errors: IgnoreRaise = "ignore",
) -> None:
"""
Modify in place using non-NA values from another DataFrame.
Expand Down Expand Up @@ -9857,7 +9872,7 @@ def infer(x):
return self.apply(infer).__finalize__(self, "map")

def applymap(
self, func: PythonFuncType, na_action: str | None = None, **kwargs
self, func: PythonFuncType, na_action: NaAction | None = None, **kwargs
) -> DataFrame:
"""
Apply a function to a Dataframe elementwise.
Expand Down Expand Up @@ -9969,7 +9984,7 @@ def join(
lsuffix: str = "",
rsuffix: str = "",
sort: bool = False,
validate: str | None = None,
validate: JoinValidate | None = None,
) -> DataFrame:
"""
Join columns of another DataFrame.
Expand Down Expand Up @@ -10211,7 +10226,7 @@ def merge(
suffixes: Suffixes = ("_x", "_y"),
copy: bool | None = None,
indicator: str | bool = False,
validate: str | None = None,
validate: MergeValidate | None = None,
) -> DataFrame:
from pandas.core.reshape.merge import merge

Expand Down Expand Up @@ -11506,7 +11521,7 @@ def quantile(
def to_timestamp(
self,
freq: Frequency | None = None,
how: str = "start",
how: ToTimestampHow = "start",
axis: Axis = 0,
copy: bool | None = None,
) -> DataFrame:
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
NDFrameT,
OpenFileErrors,
RandomState,
ReindexMethod,
Renamer,
Scalar,
Self,
Expand Down Expand Up @@ -5154,7 +5155,7 @@ def reindex(
index=None,
columns=None,
axis: Axis | None = None,
method: str | None = None,
method: ReindexMethod | None = None,
copy: bool_t | None = None,
level: Level | None = None,
fill_value: Scalar | None = np.nan,
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Axis,
AxisInt,
F,
ReindexMethod,
npt,
)
from pandas.compat._optional import import_optional_dependency
Expand Down Expand Up @@ -949,7 +950,7 @@ def get_fill_func(method, ndim: int = 1):
return {"pad": _pad_2d, "backfill": _backfill_2d}[method]


def clean_reindex_fill_method(method) -> str | None:
def clean_reindex_fill_method(method) -> ReindexMethod | None:
return clean_fill_method(method, allow_nearest=True)


Expand Down
3 changes: 2 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
NumpySorter,
NumpyValueArrayLike,
QuantileInterpolation,
ReindexMethod,
Renamer,
Scalar,
Self,
Expand Down Expand Up @@ -4718,7 +4719,7 @@ def reindex( # type: ignore[override]
index=None,
*,
axis: Axis | None = None,
method: str | None = None,
method: ReindexMethod | None = None,
copy: bool | None = None,
level: Level | None = None,
fill_value: Scalar | None = None,
Expand Down

0 comments on commit e09a193

Please sign in to comment.