Skip to content

Commit

Permalink
feat: Improve the syntax for conditions with multiple predicates (#3427)
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Jul 18, 2024
1 parent f8c9776 commit bdc747d
Show file tree
Hide file tree
Showing 13 changed files with 1,225 additions and 151 deletions.
14 changes: 6 additions & 8 deletions altair/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
# ruff: noqa
__version__ = "5.4.0dev"

from typing import Any

# Necessary as mypy would see expr as the module alt.expr although due to how
# the imports are set up it is expr in the alt.expr module
expr: Any


# The content of __all__ is automatically written by
# tools/update_init_file.py. Do not modify directly.
__all__ = [
Expand Down Expand Up @@ -54,6 +47,7 @@
"BrushConfig",
"CalculateTransform",
"Categorical",
"ChainedWhen",
"Chart",
"ChartDataType",
"ChartType",
Expand Down Expand Up @@ -488,6 +482,7 @@
"TextDef",
"TextDirection",
"TextValue",
"Then",
"Theta",
"Theta2",
"Theta2Datum",
Expand Down Expand Up @@ -565,6 +560,7 @@
"VegaLiteSchema",
"ViewBackground",
"ViewConfig",
"When",
"WindowEventType",
"WindowFieldDef",
"WindowOnlyOp",
Expand Down Expand Up @@ -622,7 +618,6 @@
"load_ipython_extension",
"load_schema",
"mixins",
"overload",
"param",
"parse_shorthand",
"renderers",
Expand All @@ -645,6 +640,7 @@
"vconcat",
"vegalite",
"vegalite_compilers",
"when",
"with_property_setters",
]

Expand All @@ -654,7 +650,9 @@ def __dir__():


from altair.vegalite import *
from altair.vegalite.v5.schema.core import Dict
from altair.jupyter import JupyterChart
from altair.expr import expr
from altair.utils import AltairDeprecationWarning


Expand Down
2 changes: 2 additions & 0 deletions altair/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
update_nested,
display_traceback,
SchemaBase,
SHORTHAND_KEYS,
)
from .html import spec_to_html
from .plugin_registry import PluginRegistry
Expand All @@ -16,6 +17,7 @@


__all__ = (
"SHORTHAND_KEYS",
"AltairDeprecationWarning",
"Optional",
"PluginRegistry",
Expand Down
47 changes: 23 additions & 24 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,7 @@
import sys
import traceback
import warnings
from typing import (
Callable,
TypeVar,
Any,
Iterator,
cast,
Literal,
Protocol,
TYPE_CHECKING,
runtime_checkable,
)
from typing import Callable, TypeVar, Any, Iterator, cast, Literal, TYPE_CHECKING
from itertools import groupby
from operator import itemgetter

Expand All @@ -33,6 +23,10 @@

from altair.utils.schemapi import SchemaBase, Undefined

if sys.version_info >= (3, 12):
from typing import runtime_checkable, Protocol
else:
from typing_extensions import runtime_checkable, Protocol
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand Down Expand Up @@ -199,6 +193,22 @@ def __dataframe__(
"utcsecondsmilliseconds",
]

VALID_TYPECODES = list(itertools.chain(iter(TYPECODE_MAP), iter(INV_TYPECODE_MAP)))

SHORTHAND_UNITS = {
"field": "(?P<field>.*)",
"type": "(?P<type>{})".format("|".join(VALID_TYPECODES)),
"agg_count": "(?P<aggregate>count)",
"op_count": "(?P<op>count)",
"aggregate": "(?P<aggregate>{})".format("|".join(AGGREGATES)),
"window_op": "(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
"timeUnit": "(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
}

SHORTHAND_KEYS: frozenset[Literal["field", "aggregate", "type", "timeUnit"]] = (
frozenset(("field", "aggregate", "type", "timeUnit"))
)


def infer_vegalite_type_for_pandas(
data: object,
Expand Down Expand Up @@ -577,18 +587,6 @@ def parse_shorthand(
if not shorthand:
return {}

valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP)

units = {
"field": "(?P<field>.*)",
"type": "(?P<type>{})".format("|".join(valid_typecodes)),
"agg_count": "(?P<aggregate>count)",
"op_count": "(?P<op>count)",
"aggregate": "(?P<aggregate>{})".format("|".join(AGGREGATES)),
"window_op": "(?P<op>{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)),
"timeUnit": "(?P<timeUnit>{})".format("|".join(TIMEUNITS)),
}

patterns = []

if parse_aggregates:
Expand All @@ -606,7 +604,8 @@ def parse_shorthand(
patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns)))

regexps = (
re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns
re.compile(r"\A" + p.format(**SHORTHAND_UNITS) + r"\Z", re.DOTALL)
for p in patterns
)

# find matches depending on valid fields passed
Expand Down
Loading

0 comments on commit bdc747d

Please sign in to comment.