From f33cb8d26230530dc0d04dae78697b19893b60ee Mon Sep 17 00:00:00 2001 From: Lukas Plank Date: Thu, 25 Jul 2024 13:20:29 +0200 Subject: [PATCH] fix: Type hint cleanup closes: #10 --- rdfproxy/adapter.py | 20 ++++++++++++++------ rdfproxy/utils/_types.py | 8 ++++---- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py index dd557b8..518b9ff 100644 --- a/rdfproxy/adapter.py +++ b/rdfproxy/adapter.py @@ -1,17 +1,19 @@ """SPARQLModelAdapter class for QueryResult to Pydantic model conversions.""" -from collections.abc import Sequence +from collections.abc import Iterable +from typing import cast from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper from pydantic import BaseModel -from rdfproxy.utils._types import _TModelConstructorCallable + +from rdfproxy.utils._types import _TModelConstructorCallable, _TModelInstance from rdfproxy.utils.utils import ( get_bindings_from_query_result, instantiate_model_from_kwargs, ) -class SPARQLModelAdapter[ModelType: BaseModel]: +class SPARQLModelAdapter: """Adapter/Mapper for QueryResult to Pydantic model conversions.""" def __init__(self, sparql_wrapper: SPARQLWrapper) -> None: @@ -20,19 +22,25 @@ def __init__(self, sparql_wrapper: SPARQLWrapper) -> None: if self.sparql_wrapper.returnFormat != "json": self.sparql_wrapper.setReturnFormat(JSON) - def __call__(self, query: str, model_constructor) -> Sequence[ModelType]: + def __call__( + self, + query: str, + model_constructor: type[_TModelInstance] | _TModelConstructorCallable, + ) -> Iterable[_TModelInstance]: self.sparql_wrapper.setQuery(query) query_result: QueryResult = self.sparql_wrapper.query() if isinstance(model_constructor, type(BaseModel)): + model_constructor = cast(type[_TModelInstance], model_constructor) + bindings = get_bindings_from_query_result(query_result) - models: list[ModelType] = [ + models: list[_TModelInstance] = [ instantiate_model_from_kwargs(model_constructor, **binding) for binding in bindings ] elif isinstance(model_constructor, _TModelConstructorCallable): - models: list[ModelType] = model_constructor(query_result) + models: Iterable[_TModelInstance] = model_constructor(query_result) else: raise TypeError( diff --git a/rdfproxy/utils/_types.py b/rdfproxy/utils/_types.py index eb9ce90..99299f9 100644 --- a/rdfproxy/utils/_types.py +++ b/rdfproxy/utils/_types.py @@ -1,17 +1,17 @@ """Type definitions for rdfproxy.""" +from collections.abc import Iterable from typing import Annotated, Protocol, TypeVar, runtime_checkable +from SPARQLWrapper import QueryResult from pydantic import BaseModel -_TModelInstance: Annotated[TypeVar, "Type defintion for Pydantic model instances."] = ( - TypeVar("_TModelInstance", bound=BaseModel) -) +_TModelInstance = TypeVar("_TModelInstance", bound=BaseModel) @runtime_checkable class _TModelConstructorCallable[ModelType: BaseModel](Protocol): """Callback protocol for model constructor callables.""" - def __call__(self, **kwargs) -> ModelType: ... + def __call__(self, query_result: QueryResult) -> Iterable[ModelType]: ...