diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py index a7ba2c3..001a349 100644 --- a/rdfproxy/mapper.py +++ b/rdfproxy/mapper.py @@ -4,13 +4,14 @@ from typing import Any, Generic, get_args from pydantic import BaseModel -from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance from rdfproxy.utils.utils import ( _collect_values_from_bindings, _get_group_by, _get_key_from_metadata, _is_list_basemodel_type, _is_list_type, + get_model_bool_predicate, ) @@ -29,10 +30,12 @@ def get_models(self) -> list[_TModelInstance]: def _get_unique_models(self, model, bindings): """Call the mapping logic and collect unique and non-empty models.""" models = [] + model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(model) + for _bindings in bindings: _model = model(**dict(self._generate_binding_pairs(model, **_bindings))) - if any(_model.model_dump().values()) and (_model not in models): + if model_bool_predicate(_model) and (_model not in models): models.append(_model) return models diff --git a/rdfproxy/utils/_types.py b/rdfproxy/utils/_types.py index f079f50..93b59d1 100644 --- a/rdfproxy/utils/_types.py +++ b/rdfproxy/utils/_types.py @@ -1,6 +1,7 @@ """Type definitions for rdfproxy.""" -from typing import Protocol, TypeVar +from collections.abc import Iterable +from typing import Protocol, TypeAlias, TypeVar, runtime_checkable from pydantic import BaseModel @@ -31,3 +32,13 @@ class Person(BaseModel): """ ... + + +@runtime_checkable +class ModelBoolPredicate(Protocol): + """Type for model_bool predicate functions.""" + + def __call__(self, model: BaseModel) -> bool: ... + + +_TModelBoolValue: TypeAlias = ModelBoolPredicate | str | Iterable[str] diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index 1eb5c8b..0615d7d 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -1,7 +1,7 @@ """SPARQL/FastAPI utils.""" from collections.abc import Callable, Iterable -from typing import Any, get_args, get_origin +from typing import Any, TypeGuard, get_args, get_origin from pydantic import BaseModel from pydantic.fields import FieldInfo @@ -9,7 +9,7 @@ MissingModelConfigException, UnboundGroupingKeyException, ) -from rdfproxy.utils._types import SPARQLBinding +from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue def _is_type(obj: type | None, _type: type) -> bool: @@ -70,3 +70,47 @@ def _get_group_by(model: type[BaseModel], kwargs: dict) -> str: f"Applicable grouping keys: {', '.join(kwargs.keys())}." ) return group_by + + +def default_model_bool_predicate(model: BaseModel) -> bool: + """Default predicate for determining model truthiness. + + Adheres to rdfproxy.utils._types.ModelBoolPredicate. + """ + return any(dict(model).values()) + + +def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: + return (not isinstance(iterable, str)) and all( + map(lambda i: isinstance(i, str), iterable) + ) + + +def _get_model_bool_predicate_from_config_value( + model_bool_value: _TModelBoolValue, +) -> ModelBoolPredicate: + """Get a model_bool predicate function given the value of the model_bool config setting.""" + match model_bool_value: + case ModelBoolPredicate(): + return model_bool_value + case str(): + return lambda model: bool(dict(model)[model_bool_value]) + case model_bool_value if _is_iterable_of_str(model_bool_value): + return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) + case _: + raise TypeError( + "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" + f"Received {type(model_bool_value)}" + ) + + +def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: + """Get the applicable model_bool predicate function given a model.""" + if (model_bool_value := model.model_config.get("model_bool", None)) is None: + model_bool_predicate = default_model_bool_predicate + else: + model_bool_predicate = _get_model_bool_predicate_from_config_value( + model_bool_value + ) + + return model_bool_predicate diff --git a/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py b/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py new file mode 100644 index 0000000..2b31d5a --- /dev/null +++ b/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py @@ -0,0 +1,114 @@ +"""Parameters for testing ModelBindingsMapper with the model_bool config option. + +The test cover all cases discussed in https://github.com/acdh-oeaw/rdfproxy/issues/110. +""" + +from pydantic import BaseModel, ConfigDict, Field, create_model +from tests.utils._types import ModelBindingsMapperParameter + + +bindings = [ + {"parent": "x", "child": "c", "name": "foo"}, + {"parent": "y", "child": "d", "name": None}, + {"parent": "y", "child": "e", "name": None}, + {"parent": "z", "child": None, "name": None}, +] + + +class Child1(BaseModel): + name: str | None = None + + +class Child2(BaseModel): + model_config = ConfigDict(model_bool=lambda model: True) + name: str | None = None + child: str | None = None + + +class Child3(BaseModel): + model_config = ConfigDict(model_bool=lambda model: True) + name: str | None = None + + +class Child4(BaseModel): + model_config = ConfigDict(model_bool="child") + name: str | None = None + child: str | None = None + + +class Child5(BaseModel): + model_config = ConfigDict(model_bool="child") + name: str | None = None + child: str | None = Field(default=None, exclude=True) + + +def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]: + model = create_model( + "Parent", + parent=(str, ...), + children=(list[child], ...), + __config__=ConfigDict(group_by="parent"), + ) + + return model + + +parent_child_parameters = [ + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child1), + bindings=bindings, + expected=[ + {"parent": "x", "children": [{"name": "foo"}]}, + {"parent": "y", "children": []}, + {"parent": "z", "children": []}, + ], + ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child2), + bindings=bindings, + expected=[ + {"parent": "x", "children": [{"name": "foo", "child": "c"}]}, + { + "parent": "y", + "children": [ + {"name": None, "child": "d"}, + {"name": None, "child": "e"}, + ], + }, + {"parent": "z", "children": [{"name": None, "child": None}]}, + ], + ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child3), + bindings=bindings, + expected=[ + {"parent": "x", "children": [{"name": "foo"}]}, + {"parent": "y", "children": [{"name": None}]}, + {"parent": "z", "children": [{"name": None}]}, + ], + ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child4), + bindings=bindings, + expected=[ + {"parent": "x", "children": [{"name": "foo", "child": "c"}]}, + { + "parent": "y", + "children": [ + {"name": None, "child": "d"}, + {"name": None, "child": "e"}, + ], + }, + {"parent": "z", "children": []}, + ], + ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child5), + bindings=bindings, + expected=[ + {"parent": "x", "children": [{"name": "foo"}]}, + {"parent": "y", "children": [{"name": None}, {"name": None}]}, + {"parent": "z", "children": []}, + ], + ), +] diff --git a/tests/tests_mapper/test_model_bindings_mapper_model_bool.py b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py new file mode 100644 index 0000000..f9bde75 --- /dev/null +++ b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py @@ -0,0 +1,24 @@ +"""Pytest entry point for testing rdfproxy.mapper.ModelBindingsMapper with model_flag config.""" + +import pytest + +from pydantic import BaseModel +from rdfproxy.mapper import ModelBindingsMapper +from tests.data.parameters.model_bindings_mapper_model_bool_parameters import ( + parent_child_parameters, +) + + +@pytest.mark.parametrize( + ["model", "bindings", "expected"], + parent_child_parameters, +) +def test_basic_model_bindings_mapper(model, bindings, expected): + """Test for rdfproxy.ModelBindingsMapper with model_bool config.. + + Given a model and a set of bindings, run the BindingsModelMapper logic + and compare the result against the expected shape. + """ + mapper: ModelBindingsMapper = ModelBindingsMapper(model, *bindings) + models: list[BaseModel] = mapper.get_models() + assert [model.model_dump() for model in models] == expected