Skip to content

Commit

Permalink
feat: redesign when-then-otherwise to account for condition restric…
Browse files Browse the repository at this point in the history
…tions

- Utilise `TypedDict.__extra_items__` https://peps.python.org/pep-0728/
- Use generics where knowledge of the internal structure was lacking
- Make `Then` a subclass of `SchemaBase`
- Introduce some informative errors for complex structural issues
- Update/add `@overload`s
  • Loading branch information
dangotbanned committed Jul 6, 2024
1 parent 6f06b80 commit 1292962
Showing 1 changed file with 167 additions and 24 deletions.
191 changes: 167 additions & 24 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,11 +581,72 @@ def _condition_to_selection(
return selection


class _Conditions(TypedDict, total=False):
condition: Required[list[_ConditionType]]
class _ConditionClosed(TypedDict, closed=True, total=False): # type: ignore[call-arg]
# https://peps.python.org/pep-0728/
# Parameter {"param", "value", "empty"}
# Predicate {"test", "value"}
empty: Optional[bool]
param: Parameter | str
test: _TestPredicateType
value: Any


class _ConditionExtra(TypedDict, closed=True, total=False): # type: ignore[call-arg]
# https://peps.python.org/pep-0728/
# Likely a Field predicate
empty: Optional[bool]
param: Parameter | str
test: _TestPredicateType
value: Any
__extra_items__: _StatementOrLiteralType


_Condition: TypeAlias = _ConditionExtra
"""A singular, non-chainable condition produced by ``.when()``."""

_Conditions: TypeAlias = t.List[_ConditionClosed]
"""Chainable conditions produced by ``.when()`` and ``Then.when()``."""

_C = TypeVar("_C", _Conditions, _Condition)


class _Conditional(TypedDict, t.Generic[_C], total=False):
condition: Required[_C]
value: Any


class _Value(TypedDict, total=False):
value: Required[Any]


def _reveal_parsed_shorthand(obj: Map, /) -> dict[str, Any]:
# Helper for producing a useful error message
short = {"field", "aggregate", "type", "timeUnit"}
return {k: v for k, v in obj.items() if k in short}


def _is_extra(*objs: Any, kwargs: Map) -> Iterator[bool]:
extra_keys = {"field", "aggregate", "type", "timeUnit"}
for el in objs:
if isinstance(el, (core.SchemaBase, t.Mapping)):
item = el.to_dict(validate=False) if isinstance(el, core.SchemaBase) else el
yield not (item.keys() - kwargs.keys()).isdisjoint(extra_keys)
else:
continue


def _is_condition_extra(
obj: Any, *objs: Any, kwargs: Map, lit: bool
) -> TypeIs[_Condition]:
# NOTE: Short circuits on the first conflict.
# 1 - originated from parse_shorthand
# 2 - Used a wrapper or `dict` directly, including `extra_keys`
if isinstance(obj, str):
return not lit
else:
return any(_is_extra(obj, *objs, kwargs=kwargs))


def _parse_when_constraints(
constraints: dict[str, Any], /
) -> Iterator[BinaryExpression]:
Expand Down Expand Up @@ -674,7 +735,7 @@ def _parse_literal(val: Any, *, str_as_lit: bool) -> dict[str, Any]:
if isinstance(val, str) and not str_as_lit:
return utils.parse_shorthand(val)
elif _is_one_or_seq_literal_value(val):
return value(val)
return {"value": val}
else:
msg = f"Expected one or more literal values, but got: {type(val).__name__!r}"
raise TypeError(msg)
Expand All @@ -696,12 +757,12 @@ def _parse_then(

def _parse_otherwise(
statement: _StatementOrLiteralType,
conditions: _Conditions,
conditions: _Conditional[Any],
kwargs: dict[str, Any],
*,
lit: bool,
) -> SchemaBase | _Conditions:
selection: SchemaBase | _Conditions
) -> SchemaBase | _Conditional[Any]:
selection: SchemaBase | _Conditional[Any]
if isinstance(statement, core.SchemaBase):
selection = statement.copy()
conditions.update(**kwargs) # type: ignore[call-arg]
Expand All @@ -724,8 +785,8 @@ def _when_then(
kwargs: dict[str, Any],
*,
lit: bool,
) -> _ConditionType:
condition = _deepcopy(self._condition)
) -> _ConditionClosed | _Condition:
condition: Any = _deepcopy(self._condition)
then = _parse_then(statement, kwargs, lit=lit)
condition.update(then)
return condition
Expand All @@ -746,13 +807,31 @@ class When(_BaseWhen):
def __init__(self, condition: _ConditionType, /) -> None:
self._condition = condition

@overload
def then(
self, statement: str, *, str_as_lit: Literal[False] = ..., **kwargs: Any
) -> Then[_Condition]: ...
@overload
def then(
self, statement: str, *, str_as_lit: Literal[True], **kwargs: Any
) -> Then[_Conditions]: ...
@overload
def then(
self,
statement: _Value | _LiteralNumeric | Sequence[_LiteralNumeric],
**kwargs: Any,
) -> Then[_Conditions]: ...
@overload
def then(
self, statement: dict[str, Any] | SchemaBase, **kwargs: Any
) -> Then[Any]: ...
def then(
self,
statement: _StatementOrLiteralType,
*,
str_as_lit: bool = False,
**kwargs: Any,
) -> Then:
) -> Then[Any]:
"""Attach a statement to this predicate.
Parameters
Expand All @@ -769,10 +848,13 @@ def then(
:class:`Then`
"""
condition = self._when_then(statement, kwargs, lit=str_as_lit)
return Then(_Conditions({"condition": [condition]}))
if _is_condition_extra(condition, statement, kwargs=kwargs, lit=str_as_lit):
return Then(_Conditional(condition=condition))
else:
return Then(_Conditional(condition=[condition]))


class Then:
class Then(core.SchemaBase, t.Generic[_C]):
"""Utility class for ``when-then-otherwise`` conditions.
Represents the state after calling :func:`.when().then()`.
Expand All @@ -787,8 +869,11 @@ class Then:
`polars.expr.whenthen <https://github.com/pola-rs/polars/blob/b85c5e0502ca99c77742ee25ba177e6cd11cf100/py-polars/polars/expr/whenthen.py>`__
"""

def __init__(self, conditions: _Conditions, /) -> None:
self._conditions = conditions
_schema = {"type": "object"}

def __init__(self, conditions: _Conditional[_C], /) -> None:
super().__init__(**conditions)
self.condition: _C

@overload
def otherwise(
Expand All @@ -797,18 +882,24 @@ def otherwise(
@overload
def otherwise(
self,
statement: dict[str, Any] | _OneOrSeqLiteralValue,
statement: _Value | _LiteralNumeric | Sequence[_LiteralNumeric],
**kwargs: Any,
) -> _Conditional[_Conditions]: ...
@overload
def otherwise(
self,
statement: Map | _OneOrSeqLiteralValue,
*,
str_as_lit: bool = ...,
**kwargs: Any,
) -> _Conditions: ...
) -> _Conditional[Any]: ...
def otherwise(
self,
statement: _StatementOrLiteralType,
*,
str_as_lit: bool = False,
**kwargs: Any,
) -> SchemaBase | _Conditions:
) -> SchemaBase | _Conditional[Any]:
"""Finalize the condition with a default value.
Parameters
Expand All @@ -823,7 +914,35 @@ def otherwise(
**kwargs
Additional keyword args are added to the resulting ``dict``.
"""
return _parse_otherwise(statement, self.to_dict(), kwargs, lit=str_as_lit)
conditions: _Conditional[Any]
is_extra = functools.partial(_is_condition_extra, kwargs=kwargs, lit=str_as_lit)
if is_extra(self.condition, statement):
current = self.condition
if isinstance(current, list) and len(current) == 1:
# This case is guaranteed to have come from `When` and not `ChainedWhen`
# The `list` isn't needed if we complete the condition here
conditions = _Conditional(condition=current[0])
elif isinstance(current, dict):
if not is_extra(statement):
conditions = self.to_dict()
else:
cond = _reveal_parsed_shorthand(current)
msg = (
f"Only one field may be used within a condition.\n"
f"Shorthand {statement!r} would conflict with {cond!r}\n\n"
f"Pass `str_as_lit=True` if {statement!r} is not shorthand."
)
raise TypeError(msg)
else:
# Generic message to cover less trivial cases
msg = (
f"Chained conditions cannot be mixed with field conditions.\n"
f"{self!r}\n\n{statement!r}"
)
raise TypeError(msg)
else:
conditions = self.to_dict()
return _parse_otherwise(statement, conditions, kwargs, lit=str_as_lit)

def when(
self,
Expand Down Expand Up @@ -858,10 +977,29 @@ def when(
A partial state which requires calling :meth:`ChainedWhen.then()` to finish the condition.
"""
condition = _parse_when(predicate, *more_predicates, empty=empty, **constraints)
return ChainedWhen(condition, self.to_dict())
conditions = self.to_dict()
current = conditions["condition"]
if isinstance(current, list):
conditions = t.cast(_Conditional[_Conditions], conditions)
return ChainedWhen(condition, conditions)
elif isinstance(current, dict):
cond = _reveal_parsed_shorthand(current)
msg = (
f"Chained conditions cannot be mixed with field conditions.\n"
f"Additional conditions would conflict with {cond!r}\n\n"
f"Must finalize by calling `.otherwise()`."
)
raise TypeError(msg)
else:
msg = (
f"The internal structure has been modified.\n"
f"{type(current).__name__!r} found, but only `dict | list` are valid."
)
raise NotImplementedError(msg)

def to_dict(self, *args: Any, **kwargs: Any) -> _Conditions:
return self._conditions.copy()
def to_dict(self, *args, **kwds) -> _Conditional[_C]: # type: ignore[override]
m = super().to_dict(*args, **kwds)
return _Conditional(condition=m["condition"])


class ChainedWhen(_BaseWhen):
Expand All @@ -877,7 +1015,12 @@ class ChainedWhen(_BaseWhen):
`polars.expr.whenthen <https://github.com/pola-rs/polars/blob/b85c5e0502ca99c77742ee25ba177e6cd11cf100/py-polars/polars/expr/whenthen.py>`__
"""

def __init__(self, condition: _ConditionType, conditions: _Conditions, /) -> None:
def __init__(
self,
condition: _ConditionType,
conditions: _Conditional[_Conditions],
/,
) -> None:
self._condition = condition
self._conditions = conditions

Expand All @@ -887,7 +1030,7 @@ def then(
*,
str_as_lit: bool = False,
**kwargs: Any,
) -> Then:
) -> Then[_Conditions]:
"""Attach a statement to this predicate.
Parameters
Expand Down Expand Up @@ -989,9 +1132,9 @@ def when(
# Top-Level Functions


def value(value, **kwargs) -> dict[str, Any]:
def value(value, **kwargs) -> _Value:
"""Specify a value for use in an encoding"""
return dict(value=value, **kwargs)
return _Value(value=value, **kwargs) # type: ignore[typeddict-item]


def param(
Expand Down

0 comments on commit 1292962

Please sign in to comment.