From e09a193900a9db08e19403e6389f43017469078d Mon Sep 17 00:00:00 2001 From: Ben Mangold <48798074+benedikt-mangold@users.noreply.github.com> Date: Thu, 20 Apr 2023 17:06:16 +0200 Subject: [PATCH] TYP: Narrow down types of arguments (DataFrame) (#52752) * Specify method in reindex for class dataframe Specify parser in to_xml of class dataframe Update doc string to_orc in class dataframe Specify engine in to_parquet in class dataframe * undo changes and adding None as an optional argument type for validate argument of join and merge method Change byteorder argument typing for to_stata method to literal, added definition in pandas/_typing.py Change if_exists argument typing for to_gbq method to literal, added definition in pandas/_typing.py Change orient argument typing for from_dict method to literal, added definition in pandas/_typing.py Change how argument typing for to_timestamp method to literal, added definition in pandas/_typing.py Change validate argument typing for merge and join methods to literal, added definition in pandas/_typing.py Change na_action arguments typing for applymap method to literal, added definition in pandas/_typing.py Change join and errors arguments typing for update method to literal, added definition in pandas/_typing.py Change keep argument typing for nlargest and nsmallest to literal, added definition in pandas/_typing.py Specify the kind and na_position more precisely in sort_values, reusing type definitions in pandas/_typing.py * removing None from literal and adding it to the argument of applymap * adding reindex literal to super class NDFrame as it violates the Liskov substitution principle otherwise * adding literal to missing.py * ignore type for orient in from_dict method of frame due to mypy error * pulling main and resolving merge conflict --------- Co-authored-by: Patrick 
Schleiter <4884221+pschleiter@users.noreply.github.com> --- pandas/_typing.py | 43 +++++++++++++++++++++++++++++++++ pandas/core/frame.py | 55 +++++++++++++++++++++++++++--------------- pandas/core/generic.py | 3 ++- pandas/core/missing.py | 3 ++- pandas/core/series.py | 3 ++- 5 files changed, 84 insertions(+), 23 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index a99dc584f64ca..e162f7f1662ee 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -132,6 +132,8 @@ ] Timezone = Union[str, tzinfo] +ToTimestampHow = Literal["s", "e", "start", "end"] + # NDFrameT is stricter and ensures that the same subclass of NDFrame always is # used. E.g. `def func(a: NDFrameT) -> NDFrameT: ...` means that if a # Series is passed into a function, a Series is always returned and if a DataFrame is @@ -361,6 +363,9 @@ def closed(self) -> bool: SortKind = Literal["quicksort", "mergesort", "heapsort", "stable"] NaPosition = Literal["first", "last"] +# Arguments for nsmalles and n_largest +NsmallestNlargestKeep = Literal["first", "last", "all"] + # quantile interpolation QuantileInterpolation = Literal["linear", "lower", "higher", "midpoint", "nearest"] @@ -372,9 +377,32 @@ def closed(self) -> bool: # merge MergeHow = Literal["left", "right", "inner", "outer", "cross"] +MergeValidate = Literal[ + "one_to_one", + "1:1", + "one_to_many", + "1:m", + "many_to_one", + "m:1", + "many_to_many", + "m:m", +] # join JoinHow = Literal["left", "right", "inner", "outer"] +JoinValidate = Literal[ + "one_to_one", + "1:1", + "one_to_many", + "1:m", + "many_to_one", + "m:1", + "many_to_many", + "m:m", +] + +# reindex +ReindexMethod = Union[FillnaOptions, Literal["nearest"]] MatplotlibColor = Union[str, Sequence[float]] TimeGrouperOrigin = Union[ @@ -400,3 +428,18 @@ def closed(self) -> bool: "backslashreplace", "namereplace", ] + +# update +UpdateJoin = Literal["left"] + +# applymap +NaAction = Literal["ignore"] + +# from_dict +FromDictOrient = Literal["columns", "index", "tight"] + 
+# to_gbc +ToGbqIfexist = Literal["fail", "replace", "append"] + +# to_stata +ToStataByteorder = Literal[">", "<", "little", "big"] diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 5341b87c39676..bd298b8d723b8 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -219,23 +219,34 @@ FloatFormatType, FormattersType, Frequency, + FromDictOrient, IgnoreRaise, IndexKeyFunc, IndexLabel, + JoinValidate, Level, MergeHow, + MergeValidate, + NaAction, NaPosition, + NsmallestNlargestKeep, PythonFuncType, QuantileInterpolation, ReadBuffer, + ReindexMethod, Renamer, Scalar, Self, SortKind, StorageOptions, Suffixes, + ToGbqIfexist, + ToStataByteorder, + ToTimestampHow, + UpdateJoin, ValueKeyFunc, WriteBuffer, + XMLParsers, npt, ) @@ -1637,7 +1648,7 @@ def __rmatmul__(self, other) -> DataFrame: def from_dict( cls, data: dict, - orient: str = "columns", + orient: FromDictOrient = "columns", dtype: Dtype | None = None, columns: Axes | None = None, ) -> DataFrame: @@ -1724,7 +1735,7 @@ def from_dict( c 2 4 """ index = None - orient = orient.lower() + orient = orient.lower() # type: ignore[assignment] if orient == "index": if len(data) > 0: # TODO speed up Series case @@ -1981,7 +1992,7 @@ def to_gbq( project_id: str | None = None, chunksize: int | None = None, reauth: bool = False, - if_exists: str = "fail", + if_exists: ToGbqIfexist = "fail", auth_local_webserver: bool = True, table_schema: list[dict[str, str]] | None = None, location: str | None = None, @@ -2535,7 +2546,7 @@ def to_stata( *, convert_dates: dict[Hashable, str] | None = None, write_index: bool = True, - byteorder: str | None = None, + byteorder: ToStataByteorder | None = None, time_stamp: datetime.datetime | None = None, data_label: str | None = None, variable_labels: dict[Hashable, str] | None = None, @@ -2763,7 +2774,7 @@ def to_markdown( def to_parquet( self, path: None = ..., - engine: str = ..., + engine: Literal["auto", "pyarrow", "fastparquet"] = ..., compression: str | None = ..., 
index: bool | None = ..., partition_cols: list[str] | None = ..., @@ -2776,7 +2787,7 @@ def to_parquet( def to_parquet( self, path: FilePath | WriteBuffer[bytes], - engine: str = ..., + engine: Literal["auto", "pyarrow", "fastparquet"] = ..., compression: str | None = ..., index: bool | None = ..., partition_cols: list[str] | None = ..., @@ -2789,7 +2800,7 @@ def to_parquet( def to_parquet( self, path: FilePath | WriteBuffer[bytes] | None = None, - engine: str = "auto", + engine: Literal["auto", "pyarrow", "fastparquet"] = "auto", compression: str | None = "snappy", index: bool | None = None, partition_cols: list[str] | None = None, @@ -2919,7 +2930,7 @@ def to_orc( we refer to objects with a write() method, such as a file handle (e.g. via builtin open function). If path is None, a bytes object is returned. - engine : str, default 'pyarrow' + engine : {'pyarrow'}, default 'pyarrow' ORC library to use. Pyarrow must be >= 7.0.0. index : bool, optional If ``True``, include the dataframe's index(es) in the file output. 
@@ -3155,7 +3166,7 @@ def to_xml( encoding: str = "utf-8", xml_declaration: bool | None = True, pretty_print: bool | None = True, - parser: str | None = "lxml", + parser: XMLParsers | None = "lxml", stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None, compression: CompressionOptions = "infer", storage_options: StorageOptions = None, @@ -4988,7 +4999,7 @@ def reindex( index=None, columns=None, axis: Axis | None = None, - method: str | None = None, + method: ReindexMethod | None = None, copy: bool | None = None, level: Level | None = None, fill_value: Scalar | None = np.nan, @@ -6521,8 +6532,8 @@ def sort_values( axis: Axis = ..., ascending=..., inplace: Literal[False] = ..., - kind: str = ..., - na_position: str = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., ignore_index: bool = ..., key: ValueKeyFunc = ..., ) -> DataFrame: @@ -7077,7 +7088,9 @@ def value_counts( return counts - def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: + def nlargest( + self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first" + ) -> DataFrame: """ Return the first `n` rows ordered by `columns` in descending order. @@ -7184,7 +7197,9 @@ def nlargest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFram """ return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest() - def nsmallest(self, n: int, columns: IndexLabel, keep: str = "first") -> DataFrame: + def nsmallest( + self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first" + ) -> DataFrame: """ Return the first `n` rows ordered by `columns` in ascending order. @@ -8348,10 +8363,10 @@ def combiner(x, y): def update( self, other, - join: str = "left", + join: UpdateJoin = "left", overwrite: bool = True, filter_func=None, - errors: str = "ignore", + errors: IgnoreRaise = "ignore", ) -> None: """ Modify in place using non-NA values from another DataFrame. 
@@ -9857,7 +9872,7 @@ def infer(x): return self.apply(infer).__finalize__(self, "map") def applymap( - self, func: PythonFuncType, na_action: str | None = None, **kwargs + self, func: PythonFuncType, na_action: NaAction | None = None, **kwargs ) -> DataFrame: """ Apply a function to a Dataframe elementwise. @@ -9969,7 +9984,7 @@ def join( lsuffix: str = "", rsuffix: str = "", sort: bool = False, - validate: str | None = None, + validate: JoinValidate | None = None, ) -> DataFrame: """ Join columns of another DataFrame. @@ -10211,7 +10226,7 @@ def merge( suffixes: Suffixes = ("_x", "_y"), copy: bool | None = None, indicator: str | bool = False, - validate: str | None = None, + validate: MergeValidate | None = None, ) -> DataFrame: from pandas.core.reshape.merge import merge @@ -11506,7 +11521,7 @@ def quantile( def to_timestamp( self, freq: Frequency | None = None, - how: str = "start", + how: ToTimestampHow = "start", axis: Axis = 0, copy: bool | None = None, ) -> DataFrame: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 9a9db0486e4a7..f1c39281ca8aa 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -70,6 +70,7 @@ NDFrameT, OpenFileErrors, RandomState, + ReindexMethod, Renamer, Scalar, Self, @@ -5154,7 +5155,7 @@ def reindex( index=None, columns=None, axis: Axis | None = None, - method: str | None = None, + method: ReindexMethod | None = None, copy: bool_t | None = None, level: Level | None = None, fill_value: Scalar | None = np.nan, diff --git a/pandas/core/missing.py b/pandas/core/missing.py index aaed431f890d3..585ad50ad9069 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -25,6 +25,7 @@ Axis, AxisInt, F, + ReindexMethod, npt, ) from pandas.compat._optional import import_optional_dependency @@ -949,7 +950,7 @@ def get_fill_func(method, ndim: int = 1): return {"pad": _pad_2d, "backfill": _backfill_2d}[method] -def clean_reindex_fill_method(method) -> str | None: +def clean_reindex_fill_method(method) -> 
ReindexMethod | None: return clean_fill_method(method, allow_nearest=True) diff --git a/pandas/core/series.py b/pandas/core/series.py index fbfbcbdacafc5..aa7d108746630 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -167,6 +167,7 @@ NumpySorter, NumpyValueArrayLike, QuantileInterpolation, + ReindexMethod, Renamer, Scalar, Self, @@ -4718,7 +4719,7 @@ def reindex( # type: ignore[override] index=None, *, axis: Axis | None = None, - method: str | None = None, + method: ReindexMethod | None = None, copy: bool | None = None, level: Level | None = None, fill_value: Scalar | None = None,