diff --git a/lenskit/lenskit/data/query.py b/lenskit/lenskit/data/query.py index a3685880c..c37fc80f1 100644 --- a/lenskit/lenskit/data/query.py +++ b/lenskit/lenskit/data/query.py @@ -71,7 +71,7 @@ def create(cls, data: QueryInput) -> RecQuery: assert_never(f"invalid type {type(data)}") -QueryInput: TypeAlias = RecQuery | EntityId | ItemList | np.integer | None +QueryInput: TypeAlias = RecQuery | EntityId | ItemList | None """ Types that can be converted to a query by :meth:`RecQuery.create`. """ diff --git a/lenskit/lenskit/data/types.py b/lenskit/lenskit/data/types.py index a9f0824f2..6bcb2192e 100644 --- a/lenskit/lenskit/data/types.py +++ b/lenskit/lenskit/data/types.py @@ -20,7 +20,7 @@ FeedbackType: TypeAlias = Literal["explicit", "implicit"] "Types of feedback supported." -EntityId: TypeAlias = int | str | bytes +EntityId: TypeAlias = int | str | bytes | np.integer[Any] | np.string_ "Allowable entity identifier types." NPEntityId: TypeAlias = np.integer[Any] | np.str_ | np.bytes_ | np.object_ "Allowable entity identifier types (NumPy version)" diff --git a/lenskit/lenskit/pipeline/types.py b/lenskit/lenskit/pipeline/types.py index f532e1296..0e8d312a7 100644 --- a/lenskit/lenskit/pipeline/types.py +++ b/lenskit/lenskit/pipeline/types.py @@ -10,7 +10,7 @@ import re import warnings from importlib import import_module -from types import GenericAlias, NoneType +from types import GenericAlias, NoneType, UnionType from typing import ( # type: ignore Generic, Protocol, @@ -128,7 +128,8 @@ def is_compatible_data(obj: object, *targets: type | TypeVar) -> bool: except TypeError: pass - if get_origin(target) == Union: + origin = get_origin(target) + if origin == UnionType or origin == Union: types = get_args(target) if is_compatible_data(obj, *types): return True diff --git a/lenskit/tests/pipeline/test_types.py b/lenskit/tests/pipeline/test_types.py index 42798015b..0e21e1684 100644 --- a/lenskit/tests/pipeline/test_types.py +++ b/lenskit/tests/pipeline/test_types.py @@ -12,7 +12,7 @@ from collections.abc import Iterable, Sequence from pathlib import Path from types import NoneType -from typing import TypeVar +from typing import Any, TypeVar import numpy as np import pandas as pd @@ -102,6 +102,14 @@ def test_numpy_typecheck(): assert not is_compatible_data(np.arange(10), NDArray[np.float64]) +def test_numpy_scalar_typecheck(): + assert is_compatible_data(np.int32(4270), np.integer[Any]) + + +def test_numpy_scalar_typecheck2(): + assert is_compatible_data(np.int32(4270), np.integer[Any] | int) + + def test_pandas_typecheck(): assert is_compatible_data(pd.Series(["a", "b"]), ArrayLike)