Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement model_bool hook for controlling model truthiness #113

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import Any, Generic, get_args

from pydantic import BaseModel
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
_is_list_basemodel_type,
_is_list_type,
get_model_bool_predicate,
)


Expand All @@ -29,10 +30,12 @@ def get_models(self) -> list[_TModelInstance]:
def _get_unique_models(self, model, bindings):
"""Call the mapping logic and collect unique and non-empty models."""
models = []
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(model)

for _bindings in bindings:
_model = model(**dict(self._generate_binding_pairs(model, **_bindings)))

if any(_model.model_dump().values()) and (_model not in models):
if model_bool_predicate(_model) and (_model not in models):
models.append(_model)

return models
Expand Down
13 changes: 12 additions & 1 deletion rdfproxy/utils/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Type definitions for rdfproxy."""

from typing import Protocol, TypeVar
from collections.abc import Iterable
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable

from pydantic import BaseModel

Expand Down Expand Up @@ -31,3 +32,13 @@ class Person(BaseModel):
"""

...


@runtime_checkable
class ModelBoolPredicate(Protocol):
    """Type for model_bool predicate functions.

    Structural type for a callable that takes a model instance and returns
    a bool deciding whether that model counts as truthy.

    The protocol is decorated with ``runtime_checkable`` so it can be used in
    ``isinstance`` checks and in ``match`` class patterns. Such runtime checks
    only test for the presence of ``__call__``, not for its signature.
    """

    # Call signature a conforming predicate has to provide.
    def __call__(self, model: BaseModel) -> bool: ...


# Accepted value types for the model_bool config setting: a predicate callable,
# a single str, or an iterable of str.
_TModelBoolValue: TypeAlias = ModelBoolPredicate | str | Iterable[str]
48 changes: 46 additions & 2 deletions rdfproxy/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""SPARQL/FastAPI utils."""

from collections.abc import Callable, Iterable
from typing import Any, get_args, get_origin
from typing import Any, TypeGuard, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo
from rdfproxy.utils._exceptions import (
MissingModelConfigException,
UnboundGroupingKeyException,
)
from rdfproxy.utils._types import SPARQLBinding
from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue


def _is_type(obj: type | None, _type: type) -> bool:
Expand Down Expand Up @@ -70,3 +70,47 @@ def _get_group_by(model: type[BaseModel], kwargs: dict) -> str:
f"Applicable grouping keys: {', '.join(kwargs.keys())}."
)
return group_by


def default_model_bool_predicate(model: BaseModel) -> bool:
    """Report whether a model has at least one truthy field value.

    This is the fallback predicate for determining model truthiness and
    adheres to rdfproxy.utils._types.ModelBoolPredicate.
    """
    for field_value in dict(model).values():
        if field_value:
            return True
    return False


def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]:
return (not isinstance(iterable, str)) and all(
map(lambda i: isinstance(i, str), iterable)
)


def _get_model_bool_predicate_from_config_value(
model_bool_value: _TModelBoolValue,
) -> ModelBoolPredicate:
"""Get a model_bool predicate function given the value of the model_bool config setting."""
match model_bool_value:
case ModelBoolPredicate():
return model_bool_value
case str():
return lambda model: bool(dict(model)[model_bool_value])
case model_bool_value if _is_iterable_of_str(model_bool_value):
return lambda model: all(map(lambda k: dict(model)[k], model_bool_value))
case _:
raise TypeError(
"Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n"
f"Received {type(model_bool_value)}"
)


def get_model_bool_predicate(
    model: type[BaseModel] | BaseModel,
) -> ModelBoolPredicate:
    """Get the applicable model_bool predicate function given a model.

    Args:
        model: A model class or a model instance; only its ``model_config``
            is read.

    Returns:
        The predicate derived from the "model_bool" config setting, or
        default_model_bool_predicate if the setting is absent or None.
    """
    model_bool_value = model.model_config.get("model_bool")

    if model_bool_value is None:
        return default_model_bool_predicate

    return _get_model_bool_predicate_from_config_value(model_bool_value)
114 changes: 114 additions & 0 deletions tests/data/parameters/model_bindings_mapper_model_bool_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Parameters for testing ModelBindingsMapper with the model_bool config option.

The tests cover all cases discussed in https://github.com/acdh-oeaw/rdfproxy/issues/110.
"""

from pydantic import BaseModel, ConfigDict, Field, create_model
from tests.utils._types import ModelBindingsMapperParameter


# Shared input rows for all parameters below: one row for parent "x", two for
# parent "y" and one for parent "z"; "name" and "child" are partly None.
bindings = [
    {"parent": "x", "child": "c", "name": "foo"},
    {"parent": "y", "child": "d", "name": None},
    {"parent": "y", "child": "e", "name": None},
    {"parent": "z", "child": None, "name": None},
]


# Child model without a model_bool setting.
class Child1(BaseModel):
    name: str | None = None


# Child model whose model_bool predicate always returns True; has both a
# "name" and a "child" field.
class Child2(BaseModel):
    model_config = ConfigDict(model_bool=lambda model: True)
    name: str | None = None
    child: str | None = None


# Child model whose model_bool predicate always returns True; has only a
# "name" field.
class Child3(BaseModel):
    model_config = ConfigDict(model_bool=lambda model: True)
    name: str | None = None


# Child model with model_bool set to the field name "child".
class Child4(BaseModel):
    model_config = ConfigDict(model_bool="child")
    name: str | None = None
    child: str | None = None


# Child model with model_bool set to the field name "child", where the
# "child" field itself is declared with exclude=True.
class Child5(BaseModel):
    model_config = ConfigDict(model_bool="child")
    name: str | None = None
    child: str | None = Field(default=None, exclude=True)


def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]:
    """Build a "Parent" model class that holds a list of the given child model.

    The generated model has a required str field "parent", a required
    "children" field typed as a list of `child`, and is configured with
    group_by="parent".
    """
    parent_fields = {
        "parent": (str, ...),
        "children": (list[child], ...),
    }

    return create_model(
        "Parent",
        __config__=ConfigDict(group_by="parent"),
        **parent_fields,
    )


# One parameter set per child model; every set maps the shared `bindings`
# onto a generated Parent model and lists the expected model dumps.
parent_child_parameters = [
    # Child1: no model_bool setting; only the child with a name is expected.
    ModelBindingsMapperParameter(
        model=_create_parent_with_child(Child1),
        bindings=bindings,
        expected=[
            {"parent": "x", "children": [{"name": "foo"}]},
            {"parent": "y", "children": []},
            {"parent": "z", "children": []},
        ],
    ),
    # Child2: always-True predicate; every child is expected, even all-None.
    ModelBindingsMapperParameter(
        model=_create_parent_with_child(Child2),
        bindings=bindings,
        expected=[
            {"parent": "x", "children": [{"name": "foo", "child": "c"}]},
            {
                "parent": "y",
                "children": [
                    {"name": None, "child": "d"},
                    {"name": None, "child": "e"},
                ],
            },
            {"parent": "z", "children": [{"name": None, "child": None}]},
        ],
    ),
    # Child3: always-True predicate, name field only; parent "y" is expected
    # to hold a single child although two binding rows belong to it.
    ModelBindingsMapperParameter(
        model=_create_parent_with_child(Child3),
        bindings=bindings,
        expected=[
            {"parent": "x", "children": [{"name": "foo"}]},
            {"parent": "y", "children": [{"name": None}]},
            {"parent": "z", "children": [{"name": None}]},
        ],
    ),
    # Child4: model_bool="child"; children are expected wherever "child" is set.
    ModelBindingsMapperParameter(
        model=_create_parent_with_child(Child4),
        bindings=bindings,
        expected=[
            {"parent": "x", "children": [{"name": "foo", "child": "c"}]},
            {
                "parent": "y",
                "children": [
                    {"name": None, "child": "d"},
                    {"name": None, "child": "e"},
                ],
            },
            {"parent": "z", "children": []},
        ],
    ),
    # Child5: model_bool="child" with "child" declared exclude=True; the
    # expected dumps contain no "child" key, yet parent "y" keeps two children.
    ModelBindingsMapperParameter(
        model=_create_parent_with_child(Child5),
        bindings=bindings,
        expected=[
            {"parent": "x", "children": [{"name": "foo"}]},
            {"parent": "y", "children": [{"name": None}, {"name": None}]},
            {"parent": "z", "children": []},
        ],
    ),
]
24 changes: 24 additions & 0 deletions tests/tests_mapper/test_model_bindings_mapper_model_bool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Pytest entry point for testing rdfproxy.mapper.ModelBindingsMapper with model_flag config."""

import pytest

from pydantic import BaseModel
from rdfproxy.mapper import ModelBindingsMapper
from tests.data.parameters.model_bindings_mapper_model_bool_parameters import (
parent_child_parameters,
)


@pytest.mark.parametrize(
    ["model", "bindings", "expected"],
    parent_child_parameters,
)
def test_basic_model_bindings_mapper(model, bindings, expected):
    """Test rdfproxy.ModelBindingsMapper with the model_bool config.

    Given a model and a set of bindings, run the ModelBindingsMapper logic
    and compare the dumped result models against the expected shape.
    """
    dumped_models = [
        instance.model_dump()
        for instance in ModelBindingsMapper(model, *bindings).get_models()
    ]

    assert dumped_models == expected
Loading