Skip to content

Commit

Permalink
feat: implement model_bool hook for controlling model truthiness
Browse files Browse the repository at this point in the history
Closes #110.
  • Loading branch information
lu-pl committed Oct 24, 2024
1 parent 319de5a commit 01ea313
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
7 changes: 5 additions & 2 deletions rdfproxy/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from typing import Any, Generic, get_args

from pydantic import BaseModel
from rdfproxy.utils._types import _TModelInstance
from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance
from rdfproxy.utils.utils import (
_collect_values_from_bindings,
_get_group_by,
_get_key_from_metadata,
_is_list_basemodel_type,
_is_list_type,
get_model_bool_predicate,
)


Expand All @@ -29,10 +30,12 @@ def get_models(self) -> list[_TModelInstance]:
def _get_unique_models(self, model, bindings):
"""Call the mapping logic and collect unique and non-empty models."""
models = []
model_bool_predicate: ModelBoolPredicate = get_model_bool_predicate(model)

for _bindings in bindings:
_model = model(**dict(self._generate_binding_pairs(model, **_bindings)))

if any(_model.model_dump().values()) and (_model not in models):
if model_bool_predicate(_model) and (_model not in models):
models.append(_model)

return models
Expand Down
13 changes: 12 additions & 1 deletion rdfproxy/utils/_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Type definitions for rdfproxy."""

from typing import TypeVar
from collections.abc import Iterable
from typing import Protocol, TypeAlias, TypeVar, runtime_checkable

from pydantic import BaseModel

Expand All @@ -27,3 +28,13 @@ class Person(BaseModel):
"""

...


@runtime_checkable
class ModelBoolPredicate(Protocol):
"""Type for model_bool predicate functions."""

def __call__(self, model: BaseModel) -> bool: ...


_TModelBoolValue: TypeAlias = ModelBoolPredicate | str | Iterable[str]
39 changes: 38 additions & 1 deletion rdfproxy/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
MissingModelConfigException,
UnboundGroupingKeyException,
)
from rdfproxy.utils._types import SPARQLBinding
from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue


def _is_type(obj: type | None, _type: type) -> bool:
Expand Down Expand Up @@ -70,3 +70,40 @@ def _get_group_by(model: type[BaseModel], kwargs: dict) -> str:
f"Applicable grouping keys: {', '.join(kwargs.keys())}."
)
return group_by


def default_model_bool_predicate(model: BaseModel) -> bool:
"""Default predicate for determining model truthiness.
Adheres to rdfproxy.utils._types.ModelBoolPredicate.
"""
return any(dict(model).values())


def _get_model_bool_predicate_from_config_value(
model_bool_value: _TModelBoolValue,
) -> ModelBoolPredicate:
"""Get a model_bool predicate function given the value of the model_bool config setting."""
match model_bool_value:
case ModelBoolPredicate():
return model_bool_value
case str():
return lambda model: bool(dict(model)[model_bool_value])
case Iterable():
return lambda model: all(map(lambda k: dict(model)[k], model_bool_value))
case _:
raise TypeError(
"Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n"
f"Received {type(model_bool_value)}"
)


def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate:
"""Get the applicable model_bool predicate function given a model."""
model_bool_predicate = (
default_model_bool_predicate
if (model_bool_value := model.model_config.get("model_bool", None)) is None
else _get_model_bool_predicate_from_config_value(model_bool_value)
)

return model_bool_predicate

0 comments on commit 01ea313

Please sign in to comment.