Skip to content

Commit

Permalink
chore: refactor isinstance_or_issubclass to allow tuples (#1820)
Browse files Browse the repository at this point in the history
chore: refactor isinstance_or_issubclass
  • Loading branch information
FBruzzesi authored Jan 17, 2025
1 parent 70bc5c6 commit c94476c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
7 changes: 3 additions & 4 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def narwhals_to_native_dtype(
return pyspark_types.StringType()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return pyspark_types.BooleanType()
if any(isinstance_or_issubclass(dtype, t) for t in [dtypes.Date, dtypes.Datetime]):
if isinstance_or_issubclass(dtype, (dtypes.Date, dtypes.Datetime)):
msg = "Converting to Date or Datetime dtype is not supported yet"
raise NotImplementedError(msg)
if isinstance_or_issubclass(dtype, dtypes.List): # pragma: no cover
Expand All @@ -98,9 +98,8 @@ def narwhals_to_native_dtype(
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
msg = "Converting to Array dtype is not supported yet"
raise NotImplementedError(msg)
if any(
isinstance_or_issubclass(dtype, t)
for t in [dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8]
if isinstance_or_issubclass(
dtype, (dtypes.UInt64, dtypes.UInt32, dtypes.UInt16, dtypes.UInt8)
): # pragma: no cover
msg = "Unsigned integer types are not supported by PySpark"
raise UnsupportedDTypeError(msg)
Expand Down
10 changes: 6 additions & 4 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,14 @@ def parse_version(version: str) -> tuple[int, ...]:
return tuple(int(re.sub(r"\D", "", str(v))) for v in version.split("."))


def isinstance_or_issubclass(obj: Any, cls: Any) -> bool:
def isinstance_or_issubclass(obj_or_cls: object | type, cls_or_tuple: Any) -> bool:
from narwhals.dtypes import DType

if isinstance(obj, DType):
return isinstance(obj, cls)
return isinstance(obj, cls) or (isinstance(obj, type) and issubclass(obj, cls))
if isinstance(obj_or_cls, DType):
return isinstance(obj_or_cls, cls_or_tuple)
return isinstance(obj_or_cls, cls_or_tuple) or (
isinstance(obj_or_cls, type) and issubclass(obj_or_cls, cls_or_tuple)
)


def validate_laziness(items: Iterable[Any]) -> None:
Expand Down

0 comments on commit c94476c

Please sign in to comment.