Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: evaluator module #192

Merged
merged 7 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 115 additions & 6 deletions flag_engine/segments/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
import operator
import re
import typing
from contextlib import suppress
from functools import wraps

import semver

from flag_engine.environments.models import EnvironmentModel
from flag_engine.identities.models import IdentityModel
from flag_engine.identities.traits.models import TraitModel
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments import constants
from flag_engine.segments.models import (
SegmentConditionModel,
SegmentModel,
SegmentRuleModel,
)
from flag_engine.segments.types import ConditionOperator
from flag_engine.utils.hashing import get_hashed_percentage_for_object_ids

from ..environments.models import EnvironmentModel
from ..identities.traits.models import TraitModel
from . import constants
from .models import SegmentConditionModel, SegmentModel, SegmentRuleModel
from flag_engine.utils.semver import is_semver
from flag_engine.utils.types import get_casting_function


def get_identity_segments(
Expand Down Expand Up @@ -79,6 +92,7 @@ def _traits_match_segment_condition(
identity_id: typing.Union[int, str],
) -> bool:
if condition.operator == constants.PERCENTAGE_SPLIT:
assert condition.value
float_value = float(condition.value)
return (
get_hashed_percentage_for_object_ids([segment_id, identity_id])
Expand All @@ -95,4 +109,99 @@ def _traits_match_segment_condition(
if condition.operator == constants.IS_SET:
return trait is not None

return condition.matches_trait_value(trait.trait_value) if trait else False
return _matches_trait_value(condition, trait.trait_value) if trait else False


def _matches_trait_value(
condition: SegmentConditionModel,
trait_value: TraitValue,
) -> bool:
if match_func := MATCH_FUNCS_BY_OPERATOR.get(condition.operator):
return match_func(condition.value, trait_value)

return False


def _evaluate_not_contains(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
return isinstance(trait_value, str) and str(segment_value) not in trait_value


def _evaluate_regex(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
return (
trait_value is not None
and re.compile(str(segment_value)).match(str(trait_value)) is not None
)


def _evaluate_modulo(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
if not isinstance(trait_value, (int, float)):
return False

if segment_value is None:
return False

try:
divisor_part, remainder_part = segment_value.split("|")
divisor = float(divisor_part)
remainder = float(remainder_part)
except ValueError:
return False

return trait_value % divisor == remainder


def _evaluate_in(segment_value: typing.Optional[str], trait_value: TraitValue) -> bool:
if segment_value:
if isinstance(trait_value, str):
return trait_value in segment_value.split(",")
if isinstance(trait_value, int) and not any(
trait_value is x for x in (False, True)
):
return str(trait_value) in segment_value.split(",")
return False


def _trait_value_typed(
func: typing.Callable[..., bool],
) -> typing.Callable[[typing.Optional[str], TraitValue], bool]:
@wraps(func)
def inner(
segment_value: typing.Optional[str],
trait_value: TraitValue,
) -> bool:
with suppress(TypeError, ValueError):
if isinstance(trait_value, str) and is_semver(segment_value):
trait_value = semver.VersionInfo.parse(
trait_value,
)
match_value = get_casting_function(trait_value)(segment_value)
return func(trait_value, match_value)
return False

return inner


MATCH_FUNCS_BY_OPERATOR: typing.Dict[
ConditionOperator, typing.Callable[[typing.Optional[str], TraitValue], bool]
] = {
constants.NOT_CONTAINS: _evaluate_not_contains,
constants.REGEX: _evaluate_regex,
constants.MODULO: _evaluate_modulo,
constants.IN: _evaluate_in,
constants.EQUAL: _trait_value_typed(operator.eq),
constants.GREATER_THAN: _trait_value_typed(operator.gt),
constants.GREATER_THAN_INCLUSIVE: _trait_value_typed(operator.ge),
constants.LESS_THAN: _trait_value_typed(operator.lt),
constants.LESS_THAN_INCLUSIVE: _trait_value_typed(operator.le),
constants.NOT_EQUAL: _trait_value_typed(operator.ne),
constants.CONTAINS: _trait_value_typed(operator.contains),
}
59 changes: 0 additions & 59 deletions flag_engine/segments/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import re
import typing
from contextlib import suppress

import semver
from pydantic import BaseModel, Field

from flag_engine.features.models import FeatureStateModel
from flag_engine.segments import constants
from flag_engine.segments.types import ConditionOperator, RuleType
from flag_engine.utils.semver import is_semver
from flag_engine.utils.types import get_casting_function


class SegmentConditionModel(BaseModel):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_EXCEPTION_OPERATOR_METHODS can be removed now, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Done.

Expand All @@ -24,60 +19,6 @@ class SegmentConditionModel(BaseModel):
value: typing.Optional[str] = None
property_: typing.Optional[str] = None

def matches_trait_value(self, trait_value: typing.Any) -> bool:
# TODO: move this logic to the evaluator module
with suppress(ValueError):
if type(self.value) is str and is_semver(self.value):
trait_value = semver.VersionInfo.parse(trait_value)
if self.operator in self._EXCEPTION_OPERATOR_METHODS:
evaluator_function = getattr(
self, self._EXCEPTION_OPERATOR_METHODS.get(self.operator)
)
return evaluator_function(trait_value)

matching_function_name = {
constants.EQUAL: "__eq__",
constants.GREATER_THAN: "__gt__",
constants.GREATER_THAN_INCLUSIVE: "__ge__",
constants.LESS_THAN: "__lt__",
constants.LESS_THAN_INCLUSIVE: "__le__",
constants.NOT_EQUAL: "__ne__",
constants.CONTAINS: "__contains__",
}.get(self.operator)
matching_function = getattr(
trait_value, matching_function_name, lambda v: False
)
to_same_type_as_trait_value = get_casting_function(trait_value)
return matching_function(to_same_type_as_trait_value(self.value))

return False

def evaluate_not_contains(self, trait_value: typing.Iterable) -> bool:
return self.value not in trait_value

def evaluate_regex(self, trait_value: str) -> bool:
return (
trait_value is not None
and re.compile(str(self.value)).match(str(trait_value)) is not None
)

def evaluate_modulo(self, trait_value: typing.Union[str, int, float, bool]) -> bool:
if type(trait_value) not in (int, float):
return False
try:
divisor, remainder = self.value.split("|")
divisor = float(divisor)
remainder = float(remainder)
except ValueError:
return False
return trait_value % divisor == remainder

def evaluate_in(self, trait_value) -> bool:
try:
return str(trait_value) in self.value.split(",")
except AttributeError:
return False


class SegmentRuleModel(BaseModel):
type: RuleType
Expand Down
Loading