diff --git a/python/pyspark/sql/_typing.pyi b/python/pyspark/sql/_typing.pyi
index b696eea7293fb..4969268939adf 100644
--- a/python/pyspark/sql/_typing.pyi
+++ b/python/pyspark/sql/_typing.pyi
@@ -38,7 +38,6 @@ import pyspark.sql.types
 from pyspark.sql.column import Column
 
 ColumnOrName = Union[Column, str]
-ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
 ColumnOrNameOrOrdinal = Union[Column, str, int]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
diff --git a/python/pyspark/sql/classic/window.py b/python/pyspark/sql/classic/window.py
index b5c528eec10a1..63e9a337c0c2e 100644
--- a/python/pyspark/sql/classic/window.py
+++ b/python/pyspark/sql/classic/window.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import sys
-from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
+from typing import cast, Iterable, Sequence, Tuple, TYPE_CHECKING, Union
 
 from pyspark.sql.window import (
     Window as ParentWindow,
@@ -25,13 +25,15 @@
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
-    from pyspark.sql._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql._typing import ColumnOrName
 
 
 __all__ = ["Window", "WindowSpec"]
 
 
-def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> "JavaObject":
+def _to_java_cols(
+    cols: Tuple[Union["ColumnOrName", Sequence["ColumnOrName"]], ...]
+) -> "JavaObject":
     from pyspark.sql.classic.column import _to_seq, _to_java_column
 
     if len(cols) == 1 and isinstance(cols[0], list):
@@ -42,7 +44,7 @@ def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]
 
 class Window(ParentWindow):
     @staticmethod
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         from py4j.java_gateway import JVMView
 
         sc = get_active_spark_context()
@@ -52,7 +54,7 @@ def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWi
         return WindowSpec(jspec)
 
     @staticmethod
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         from py4j.java_gateway import JVMView
 
         sc = get_active_spark_context()
@@ -99,10 +101,12 @@ def __new__(cls, jspec: "JavaObject") -> "WindowSpec":
     def __init__(self, jspec: "JavaObject") -> None:
         self._jspec = jspec
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(
+        self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]
+    ) -> ParentWindowSpec:
         return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))
 
     def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 806476af1eb60..efb3e0e8eb507 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from types import FunctionType
-from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple, TypeVar
+from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple
 
 import datetime
 import decimal
@@ -28,7 +28,6 @@
 
 
 ColumnOrName = Union[Column, str]
-ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
 ColumnOrNameOrOrdinal = Union[Column, str, int]
 
 
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index b27e8cd58513b..e727dd3b28a27 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -28,6 +28,7 @@
     TYPE_CHECKING,
     Union,
     List,
+    Sequence,
     overload,
     Optional,
     Tuple,
@@ -1607,7 +1608,9 @@ def reduce(
 reduce.__doc__ = pysparkfuncs.reduce.__doc__
 
 
-def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
+def array(
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
     return _invoke_function_over_columns("array", *cols)  # type: ignore[arg-type]
@@ -1778,7 +1781,7 @@ def concat(*cols: "ColumnOrName") -> Column:
 
 
 def create_map(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
@@ -1977,7 +1980,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column:
 
 
 def map_concat(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
@@ -2251,7 +2254,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
 
 
 def struct(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py
index cbca6886060cf..b1bf080ded315 100644
--- a/python/pyspark/sql/connect/window.py
+++ b/python/pyspark/sql/connect/window.py
@@ -29,12 +29,12 @@
 from pyspark.sql.connect.functions import builtin as F
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql.connect._typing import ColumnOrName
 
 __all__ = ["Window", "WindowSpec"]
 
 
-def _to_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> List[Column]:
+def _to_cols(cols: Tuple[Union["ColumnOrName", Sequence["ColumnOrName"]], ...]) -> List[Column]:
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
     return [F._to_col(c) for c in cast(Iterable["ColumnOrName"], cols)]
@@ -84,14 +84,16 @@ def __init__(
         self._orderSpec = orderSpec
         self._frame = frame
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(
+        self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]
+    ) -> ParentWindowSpec:
         return WindowSpec(
             partitionSpec=[c._expr for c in _to_cols(cols)],  # type: ignore[misc]
             orderSpec=self._orderSpec,
             frame=self._frame,
         )
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return WindowSpec(
             partitionSpec=self._partitionSpec,
             orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)],
@@ -139,11 +141,11 @@ class Window(ParentWindow):
     _spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)
 
     @staticmethod
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return Window._spec.partitionBy(*cols)
 
     @staticmethod
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return Window._spec.orderBy(*cols)
 
     @staticmethod
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 1c7427750baa1..2828c0b46f161 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -29,6 +29,7 @@
     Callable,
     Dict,
     List,
+    Sequence,
     Iterable,
     overload,
     Optional,
@@ -66,7 +67,6 @@
     from pyspark import SparkContext
     from pyspark.sql._typing import (
         ColumnOrName,
-        ColumnOrName_,
         DataTypeOrString,
         UserDefinedFunctionLike,
     )
@@ -6875,13 +6875,13 @@ def struct(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def struct(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def struct(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def struct(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """Creates a new struct column.
 
@@ -13694,13 +13694,13 @@ def create_map(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def create_map(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def create_map(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def create_map(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Map function: Creates a new map column from an even number of input columns or
@@ -13861,13 +13861,13 @@ def array(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def array(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def array(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def array(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Collection function: Creates a new array column from the input columns or column names.
@@ -18283,13 +18283,13 @@ def map_concat(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def map_concat(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def map_concat(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def map_concat(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Map function: Returns the union of all given maps.
diff --git a/python/pyspark/sql/tests/typing/test_functions.yml b/python/pyspark/sql/tests/typing/test_functions.yml
index d699bf01876ff..3f29c0dc17443 100644
--- a/python/pyspark/sql/tests/typing/test_functions.yml
+++ b/python/pyspark/sql/tests/typing/test_functions.yml
@@ -70,32 +70,32 @@
     main:29: error: No overload variant of "array" matches argument types "list[Column]", "list[Column]" [call-overload]
     main:29: note: Possible overload variants:
     main:29: note: def array(*cols: Union[Column, str]) -> Column
-    main:29: note: def [ColumnOrName_] array(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:29: note: def array(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:30: error: No overload variant of "create_map" matches argument types "list[Column]", "list[Column]" [call-overload]
     main:30: note: Possible overload variants:
     main:30: note: def create_map(*cols: Union[Column, str]) -> Column
-    main:30: note: def [ColumnOrName_] create_map(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:30: note: def create_map(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:31: error: No overload variant of "map_concat" matches argument types "list[Column]", "list[Column]" [call-overload]
     main:31: note: Possible overload variants:
     main:31: note: def map_concat(*cols: Union[Column, str]) -> Column
-    main:31: note: def [ColumnOrName_] map_concat(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:31: note: def map_concat(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:32: error: No overload variant of "struct" matches argument types "list[str]", "list[str]" [call-overload]
     main:32: note: Possible overload variants:
     main:32: note: def struct(*cols: Union[Column, str]) -> Column
-    main:32: note: def [ColumnOrName_] struct(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:32: note: def struct(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:33: error: No overload variant of "array" matches argument types "list[str]", "list[str]" [call-overload]
     main:33: note: Possible overload variants:
     main:33: note: def array(*cols: Union[Column, str]) -> Column
-    main:33: note: def [ColumnOrName_] array(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:33: note: def array(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:34: error: No overload variant of "create_map" matches argument types "list[str]", "list[str]" [call-overload]
     main:34: note: Possible overload variants:
     main:34: note: def create_map(*cols: Union[Column, str]) -> Column
-    main:34: note: def [ColumnOrName_] create_map(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:34: note: def create_map(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:35: error: No overload variant of "map_concat" matches argument types "list[str]", "list[str]" [call-overload]
     main:35: note: Possible overload variants:
     main:35: note: def map_concat(*cols: Union[Column, str]) -> Column
-    main:35: note: def [ColumnOrName_] map_concat(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:35: note: def map_concat(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
     main:36: error: No overload variant of "struct" matches argument types "list[str]", "list[str]" [call-overload]
     main:36: note: Possible overload variants:
     main:36: note: def struct(*cols: Union[Column, str]) -> Column
-    main:36: note: def [ColumnOrName_] struct(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+    main:36: note: def struct(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 22c9f697acde3..0c2cf4f666164 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -18,7 +18,7 @@
 # mypy: disable-error-code="empty-body"
 
 import sys
-from typing import List, TYPE_CHECKING, Union
+from typing import Sequence, TYPE_CHECKING, Union
 
 from pyspark.sql.utils import dispatch_window_method
 from pyspark.util import (
@@ -28,7 +28,7 @@
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
-    from pyspark.sql._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql._typing import ColumnOrName
 
 __all__ = ["Window", "WindowSpec"]
 
@@ -68,7 +68,7 @@ class Window:
 
     @staticmethod
     @dispatch_window_method
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the partitioning defined.
 
@@ -121,7 +121,7 @@ def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowS
 
     @staticmethod
     @dispatch_window_method
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the ordering defined.
 
@@ -348,7 +348,7 @@ def __new__(cls, jspec: "JavaObject") -> "WindowSpec":
 
         return WindowSpec.__new__(WindowSpec, jspec)
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def partitionBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Defines the partitioning columns in a :class:`WindowSpec`.
 
@@ -361,7 +361,7 @@ def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "W
         """
         ...
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Defines the ordering columns in a :class:`WindowSpec`.
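
The hunks above change only type annotations: each List["ColumnOrName_"] (a TypeVar bound to Union[Column, str]) becomes Sequence["ColumnOrName"], and the ColumnOrName_ TypeVar itself is removed from both _typing modules; the runtime code that unpacks a single list argument is untouched. Below is a minimal usage sketch of the calling style these annotations describe. It is illustrative only: the SparkSession, the DataFrame df, and its column names are assumptions made for the example, not part of the patch.

# Illustrative sketch (not from the patch): exercises the single-sequence
# calling style covered by the Sequence["ColumnOrName"] annotations above.
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [("a", "x", 1), ("a", "y", 2), ("b", "z", 3)],
    ["dept", "name", "salary"],
)

# A single list argument of column names or Column objects; the functions
# unpack it themselves at runtime, exactly as before this change.
w = Window.partitionBy(["dept"]).orderBy([F.col("salary")])

df.select(
    F.array([F.col("dept"), F.col("name")]).alias("arr"),  # list of Columns
    F.struct(["dept", "salary"]).alias("st"),  # list of column names
    F.row_number().over(w).alias("rn"),
).show()

Because Sequence is covariant, list[str], list[Column], and mixed lists all match Sequence[Union[Column, str]] under the updated stubs without binding a TypeVar; the two-list calls in the mypy test expectations above are still rejected, only the wording of the reported overload variants changes.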