[SPARK-49035][PYTHON] Eliminate TypeVar ColumnOrName_
### What changes were proposed in this pull request?
Eliminate the TypeVar `ColumnOrName_`, replacing `List["ColumnOrName_"]` annotations with `Sequence["ColumnOrName"]`.

### Why are the changes needed?
To unify the usage of `ColumnOrName` across the Python type hints. The bound TypeVar existed only to work around the invariance of `List`; since `Sequence` is covariant in its element type, `Sequence["ColumnOrName"]` accepts the same arguments directly.
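Below is a minimal sketch of the variance issue the TypeVar was working around (illustrative only, not part of the patch; the `Column` stub and function names are made up):

```python
from typing import List, Sequence, TypeVar, Union

class Column: ...

ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)

# List is invariant: a list[Column] is not a List[Union[Column, str]],
# so this annotation rejects perfectly valid calls.
def take_list(cols: List[ColumnOrName]) -> None: ...

# Old workaround: a bound TypeVar lets the checker infer List[Column]
# or List[str] per call site, at the cost of an extra type variable.
def take_typevar_list(cols: List[ColumnOrName_]) -> None: ...

# New form: Sequence is covariant, so Sequence[ColumnOrName] accepts
# list[Column], list[str], and mixed lists directly.
def take_sequence(cols: Sequence[ColumnOrName]) -> None: ...

columns: List[Column] = [Column(), Column()]
take_list(columns)          # mypy error (List invariance)
take_typevar_list(columns)  # OK via TypeVar inference
take_sequence(columns)      # OK via covariance, no TypeVar needed
```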

### Does this PR introduce _any_ user-facing change?
No, this is an internal change.

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47512 from zhengruifeng/hint_CoN_.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Jul 28, 2024
1 parent 80223bb commit 112a52d
Showing 8 changed files with 50 additions and 43 deletions.
python/pyspark/sql/_typing.pyi (1 change: 0 additions & 1 deletion)

@@ -38,7 +38,6 @@ import pyspark.sql.types
 from pyspark.sql.column import Column
 
 ColumnOrName = Union[Column, str]
-ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
 ColumnOrNameOrOrdinal = Union[Column, str, int]
 DecimalLiteral = decimal.Decimal
 DateTimeLiteral = Union[datetime.datetime, datetime.date]
python/pyspark/sql/classic/window.py (18 changes: 11 additions & 7 deletions)

@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import sys
-from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
+from typing import cast, Iterable, Sequence, Tuple, TYPE_CHECKING, Union
 
 from pyspark.sql.window import (
     Window as ParentWindow,
@@ -25,13 +25,15 @@
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
-    from pyspark.sql._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql._typing import ColumnOrName
 
 
 __all__ = ["Window", "WindowSpec"]
 
 
-def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> "JavaObject":
+def _to_java_cols(
+    cols: Tuple[Union["ColumnOrName", Sequence["ColumnOrName"]], ...]
+) -> "JavaObject":
     from pyspark.sql.classic.column import _to_seq, _to_java_column
 
     if len(cols) == 1 and isinstance(cols[0], list):
@@ -42,7 +44,7 @@ def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]
 
 class Window(ParentWindow):
     @staticmethod
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         from py4j.java_gateway import JVMView
 
         sc = get_active_spark_context()
@@ -52,7 +54,7 @@ def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWi
         return WindowSpec(jspec)
 
     @staticmethod
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         from py4j.java_gateway import JVMView
 
         sc = get_active_spark_context()
@@ -99,10 +101,12 @@ def __new__(cls, jspec: "JavaObject") -> "WindowSpec":
     def __init__(self, jspec: "JavaObject") -> None:
         self._jspec = jspec
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(
+        self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]
+    ) -> ParentWindowSpec:
         return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols)))
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return WindowSpec(self._jspec.orderBy(_to_java_cols(cols)))
 
     def rowsBetween(self, start: int, end: int) -> ParentWindowSpec:
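For context, `_to_java_cols` unwraps a single list argument, so both call styles below reach the same code path; `Sequence["ColumnOrName"]` now covers both without a TypeVar. A usage sketch, assuming a local Spark session (the data and column names are made up):

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 2), ("b", 3)], ["dept", "salary"])

# Varargs of column names ...
w1 = Window.partitionBy("dept").orderBy("salary")
# ... or a single list of ColumnOrName values, unwrapped by _to_java_cols.
w2 = Window.partitionBy(["dept"]).orderBy([F.col("salary")])

df.withColumn("rank", F.rank().over(w1)).show()
df.withColumn("rank", F.rank().over(w2)).show()
```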
python/pyspark/sql/connect/_typing.py (3 changes: 1 addition & 2 deletions)

@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from types import FunctionType
-from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple, TypeVar
+from typing import Any, Callable, Iterable, Union, Optional, NewType, Protocol, Tuple
 import datetime
 import decimal
 
@@ -28,7 +28,6 @@
 
 
 ColumnOrName = Union[Column, str]
-ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
 
 ColumnOrNameOrOrdinal = Union[Column, str, int]
 
python/pyspark/sql/connect/functions/builtin.py (11 changes: 7 additions & 4 deletions)

@@ -28,6 +28,7 @@
     TYPE_CHECKING,
     Union,
     List,
+    Sequence,
     overload,
     Optional,
     Tuple,
@@ -1607,7 +1608,9 @@ def reduce(
 reduce.__doc__ = pysparkfuncs.reduce.__doc__
 
 
-def array(*cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
+def array(
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
+) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
     return _invoke_function_over_columns("array", *cols)  # type: ignore[arg-type]
@@ -1778,7 +1781,7 @@ def concat(*cols: "ColumnOrName") -> Column:
 
 
 def create_map(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
@@ -1977,7 +1980,7 @@ def json_tuple(col: "ColumnOrName", *fields: str) -> Column:
 
 
 def map_concat(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
@@ -2251,7 +2254,7 @@ def sort_array(col: "ColumnOrName", asc: bool = True) -> Column:
 
 
 def struct(
-    *cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
+    *cols: Union["ColumnOrName", Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]
 ) -> Column:
     if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
         cols = cols[0]  # type: ignore[assignment]
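The same unwrapping applies in these Connect helpers: a lone list, set, or tuple is expanded into varargs before the function is invoked. A sketch, assuming a running session (data and names are illustrative):

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2)], ["a", "b"])

# Each pair is equivalent: the lone list is unwrapped into varargs.
df.select(F.array("a", "b")).show()
df.select(F.array(["a", "b"])).show()

df.select(F.struct(F.col("a"), F.col("b"))).show()
df.select(F.struct([F.col("a"), F.col("b")])).show()
```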
python/pyspark/sql/connect/window.py (14 changes: 8 additions & 6 deletions)

@@ -29,12 +29,12 @@
 from pyspark.sql.connect.functions import builtin as F
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql.connect._typing import ColumnOrName
 
 __all__ = ["Window", "WindowSpec"]
 
 
-def _to_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> List[Column]:
+def _to_cols(cols: Tuple[Union["ColumnOrName", Sequence["ColumnOrName"]], ...]) -> List[Column]:
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
     return [F._to_col(c) for c in cast(Iterable["ColumnOrName"], cols)]
@@ -84,14 +84,16 @@ def __init__(
         self._orderSpec = orderSpec
         self._frame = frame
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(
+        self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]
+    ) -> ParentWindowSpec:
         return WindowSpec(
             partitionSpec=[c._expr for c in _to_cols(cols)],  # type: ignore[misc]
             orderSpec=self._orderSpec,
             frame=self._frame,
         )
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return WindowSpec(
             partitionSpec=self._partitionSpec,
             orderSpec=[cast(SortOrder, F._sort_col(c)._expr) for c in _to_cols(cols)],
@@ -139,11 +141,11 @@ class Window(ParentWindow):
     _spec = WindowSpec(partitionSpec=[], orderSpec=[], frame=None)
 
     @staticmethod
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
        return Window._spec.partitionBy(*cols)
 
     @staticmethod
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> ParentWindowSpec:
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> ParentWindowSpec:
         return Window._spec.orderBy(*cols)
 
     @staticmethod
python/pyspark/sql/functions/builtin.py (18 changes: 9 additions & 9 deletions)

@@ -29,6 +29,7 @@
     Callable,
     Dict,
     List,
+    Sequence,
     Iterable,
     overload,
     Optional,
@@ -66,7 +67,6 @@
     from pyspark import SparkContext
     from pyspark.sql._typing import (
         ColumnOrName,
-        ColumnOrName_,
         DataTypeOrString,
         UserDefinedFunctionLike,
     )
@@ -6875,13 +6875,13 @@ def struct(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def struct(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def struct(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def struct(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """Creates a new struct column.
 
@@ -13694,13 +13694,13 @@ def create_map(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def create_map(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def create_map(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def create_map(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Map function: Creates a new map column from an even number of input columns or
@@ -13861,13 +13861,13 @@ def array(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def array(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def array(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def array(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Collection function: Creates a new array column from the input columns or column names.
@@ -18283,13 +18283,13 @@ def map_concat(*cols: "ColumnOrName") -> Column:
 
 
 @overload
-def map_concat(__cols: Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]) -> Column:
+def map_concat(__cols: Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]) -> Column:
     ...
 
 
 @_try_remote_functions
 def map_concat(
-    *cols: Union["ColumnOrName", Union[List["ColumnOrName_"], Tuple["ColumnOrName_", ...]]]
+    *cols: Union["ColumnOrName", Union[Sequence["ColumnOrName"], Tuple["ColumnOrName", ...]]]
 ) -> Column:
     """
     Map function: Returns the union of all given maps.
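One detail worth noting (an explanatory aside, not from the patch): in these overloads the sequence parameter is named `__cols`, and type checkers treat parameters whose names begin with two underscores as positional-only per PEP 484. That is why the mypy notes in the typing tests below render the second variant with a trailing `/`. A condensed sketch of the overload shape:

```python
from typing import Sequence, Tuple, Union, overload

class Column: ...

ColumnOrName = Union[Column, str]

@overload
def struct(*cols: ColumnOrName) -> Column: ...
@overload
def struct(__cols: Union[Sequence[ColumnOrName], Tuple[ColumnOrName, ...]]) -> Column: ...
def struct(*cols):  # runtime: unwrap a lone sequence, as the real helper does
    if len(cols) == 1 and isinstance(cols[0], (list, set, tuple)):
        cols = cols[0]
    return Column()
```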
python/pyspark/sql/tests/typing/test_functions.yml (16 changes: 8 additions & 8 deletions)

@@ -70,32 +70,32 @@
 main:29: error: No overload variant of "array" matches argument types "list[Column]", "list[Column]" [call-overload]
 main:29: note: Possible overload variants:
 main:29: note: def array(*cols: Union[Column, str]) -> Column
-main:29: note: def [ColumnOrName_] array(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:29: note: def array(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:30: error: No overload variant of "create_map" matches argument types "list[Column]", "list[Column]" [call-overload]
 main:30: note: Possible overload variants:
 main:30: note: def create_map(*cols: Union[Column, str]) -> Column
-main:30: note: def [ColumnOrName_] create_map(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:30: note: def create_map(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:31: error: No overload variant of "map_concat" matches argument types "list[Column]", "list[Column]" [call-overload]
 main:31: note: Possible overload variants:
 main:31: note: def map_concat(*cols: Union[Column, str]) -> Column
-main:31: note: def [ColumnOrName_] map_concat(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:31: note: def map_concat(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:32: error: No overload variant of "struct" matches argument types "list[str]", "list[str]" [call-overload]
 main:32: note: Possible overload variants:
 main:32: note: def struct(*cols: Union[Column, str]) -> Column
-main:32: note: def [ColumnOrName_] struct(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:32: note: def struct(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:33: error: No overload variant of "array" matches argument types "list[str]", "list[str]" [call-overload]
 main:33: note: Possible overload variants:
 main:33: note: def array(*cols: Union[Column, str]) -> Column
-main:33: note: def [ColumnOrName_] array(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:33: note: def array(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:34: error: No overload variant of "create_map" matches argument types "list[str]", "list[str]" [call-overload]
 main:34: note: Possible overload variants:
 main:34: note: def create_map(*cols: Union[Column, str]) -> Column
-main:34: note: def [ColumnOrName_] create_map(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:34: note: def create_map(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:35: error: No overload variant of "map_concat" matches argument types "list[str]", "list[str]" [call-overload]
 main:35: note: Possible overload variants:
 main:35: note: def map_concat(*cols: Union[Column, str]) -> Column
-main:35: note: def [ColumnOrName_] map_concat(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:35: note: def map_concat(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
 main:36: error: No overload variant of "struct" matches argument types "list[str]", "list[str]" [call-overload]
 main:36: note: Possible overload variants:
 main:36: note: def struct(*cols: Union[Column, str]) -> Column
-main:36: note: def [ColumnOrName_] struct(Union[list[ColumnOrName_], tuple[ColumnOrName_, ...]], /) -> Column
+main:36: note: def struct(Union[Sequence[Union[Column, str]], tuple[Union[Column, str], ...]], /) -> Column
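The expected output above corresponds to calls along these lines (a reconstruction from the error messages, not the yml's actual input block):

```python
from pyspark.sql import functions as F

a, b = F.col("a"), F.col("b")

F.array([a, b])         # OK: matches the single-sequence overload
F.array([a], [b])       # main:29: no overload matches "list[Column]", "list[Column]"
F.struct(["x"], ["y"])  # main:32: no overload matches "list[str]", "list[str]"
```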
python/pyspark/sql/window.py (12 changes: 6 additions & 6 deletions)

@@ -18,7 +18,7 @@
 # mypy: disable-error-code="empty-body"
 
 import sys
-from typing import List, TYPE_CHECKING, Union
+from typing import Sequence, TYPE_CHECKING, Union
 
 from pyspark.sql.utils import dispatch_window_method
 from pyspark.util import (
@@ -28,7 +28,7 @@
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
-    from pyspark.sql._typing import ColumnOrName, ColumnOrName_
+    from pyspark.sql._typing import ColumnOrName
 
 __all__ = ["Window", "WindowSpec"]
 
@@ -68,7 +68,7 @@ class Window:
 
     @staticmethod
     @dispatch_window_method
-    def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def partitionBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the partitioning defined.
@@ -121,7 +121,7 @@ def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowS
 
     @staticmethod
     @dispatch_window_method
-    def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def orderBy(*cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Creates a :class:`WindowSpec` with the ordering defined.
@@ -348,7 +348,7 @@ def __new__(cls, jspec: "JavaObject") -> "WindowSpec":
 
         return WindowSpec.__new__(WindowSpec, jspec)
 
-    def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def partitionBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Defines the partitioning columns in a :class:`WindowSpec`.
@@ -361,7 +361,7 @@ def partitionBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "W
         """
         ...
 
-    def orderBy(self, *cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec":
+    def orderBy(self, *cols: Union["ColumnOrName", Sequence["ColumnOrName"]]) -> "WindowSpec":
         """
         Defines the ordering columns in a :class:`WindowSpec`.
