Skip to content

Commit

Permalink
Search attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Jun 8, 2022
1 parent 8dad0fe commit edbd7dd
Show file tree
Hide file tree
Showing 16 changed files with 579 additions and 419 deletions.
13 changes: 12 additions & 1 deletion temporalio/bridge/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Worker using SDK Core."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, List
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, List, Mapping

import google.protobuf.internal.containers

Expand Down Expand Up @@ -366,3 +366,14 @@ async def encode_completion(
)
for val in command.start_child_workflow_execution.memo.values():
await _encode_bridge_payload(val, codec)


def encode_search_attributes(
attrs: temporalio.common.SearchAttributes,
payloads: Mapping[str, temporalio.bridge.proto.common.Payload],
) -> None:
"""Encode search attributes as bridge payloads."""
for k, vals in attrs.items():
payloads[k].CopyFrom(
to_bridge_payload(temporalio.converter.encode_search_attribute_values(vals))
)
29 changes: 14 additions & 15 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ async def start_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -234,7 +234,7 @@ async def start_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -259,7 +259,7 @@ async def start_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -283,7 +283,7 @@ async def start_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -305,7 +305,7 @@ async def start_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand Down Expand Up @@ -388,7 +388,7 @@ async def execute_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -413,7 +413,7 @@ async def execute_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -438,7 +438,7 @@ async def execute_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -462,7 +462,7 @@ async def execute_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand All @@ -484,7 +484,7 @@ async def execute_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy] = None,
cron_schedule: str = "",
memo: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[Mapping[str, Any]] = None,
search_attributes: Optional[temporalio.common.SearchAttributes] = None,
header: Optional[Mapping[str, Any]] = None,
start_signal: Optional[str] = None,
start_signal_args: Iterable[Any] = [],
Expand Down Expand Up @@ -1297,7 +1297,7 @@ class StartWorkflowInput:
retry_policy: Optional[temporalio.common.RetryPolicy]
cron_schedule: str
memo: Optional[Mapping[str, Any]]
search_attributes: Optional[Mapping[str, Any]]
search_attributes: Optional[temporalio.common.SearchAttributes]
header: Optional[Mapping[str, Any]]
start_signal: Optional[str]
start_signal_args: Iterable[Any]
Expand Down Expand Up @@ -1527,10 +1527,9 @@ async def start_workflow(
for k, v in input.memo.items():
req.memo.fields[k] = (await self._client.data_converter.encode([v]))[0]
if input.search_attributes is not None:
for k, v in input.search_attributes.items():
req.search_attributes.indexed_fields[k] = (
await self._client.data_converter.encode([v])
)[0]
temporalio.converter.encode_search_attributes(
input.search_attributes, req.search_attributes
)
if input.header is not None:
for k, v in input.header.items():
req.header.fields[k] = (await self._client.data_converter.encode([v]))[
Expand Down
14 changes: 12 additions & 2 deletions temporalio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import timedelta
from datetime import datetime, timedelta
from enum import IntEnum
from typing import Any, Iterable, Optional
from typing import Any, Iterable, List, Mapping, Optional, Union

from typing_extensions import TypeAlias

import temporalio.api.common.v1
import temporalio.api.enums.v1
Expand Down Expand Up @@ -106,6 +108,14 @@ class QueryRejectCondition(IntEnum):
"""See :py:attr:`temporalio.api.enums.v1.QueryRejectCondition.QUERY_REJECT_CONDITION_NOT_COMPLETED_CLEANLY`."""


SearchAttributeValue: TypeAlias = Union[str, int, float, bool, datetime]

# We choose to make this a list instead of an iterable so we can catch if people
# are not sending lists each time but maybe accidentally sending a string (which
# is iterable)
SearchAttributes: TypeAlias = Mapping[str, List[SearchAttributeValue]]


# Should be set as the "arg" argument for _arg_or_args checks where the argument
# is unset. This is different than None which is a legitimate argument.
_arg_unset = object()
Expand Down
64 changes: 64 additions & 0 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type

import dacite
Expand All @@ -16,6 +17,7 @@
import google.protobuf.symbol_database

import temporalio.api.common.v1
import temporalio.common


class PayloadConverter(ABC):
Expand Down Expand Up @@ -594,6 +596,68 @@ def default() -> DataConverter:
return _default


def encode_search_attributes(
attrs: temporalio.common.SearchAttributes,
api: temporalio.api.common.v1.SearchAttributes,
) -> None:
"""Convert search attributes into an API message.
Args:
attrs: Search attributes to convert.
api: API message to set converted attributes on.
"""
for k, v in attrs.items():
api.indexed_fields[k].CopyFrom(encode_search_attribute_values(v))


def encode_search_attribute_values(
vals: List[temporalio.common.SearchAttributeValue],
) -> temporalio.api.common.v1.Payload:
"""Convert search attribute values into a payload.
Args:
vals: List of values to convert.
"""
if not isinstance(vals, list):
raise TypeError("Search attribute values must be lists")
# Convert dates to strings
safe_vals = []
for v in vals:
if isinstance(v, datetime):
if v.tzinfo is None:
raise ValueError(
"Timezone must be present on all search attribute dates"
)
v = v.isoformat()
safe_vals.append(v)
return default().payload_converter.to_payloads([safe_vals])[0]


def decode_search_attributes(
api: temporalio.api.common.v1.SearchAttributes,
) -> temporalio.common.SearchAttributes:
"""Decode API search attributes to values.
Args:
api: API message with search attribute values to convert.
Returns:
Converted search attribute values.
"""
conv = default().payload_converter
ret = {}
for k, v in api.indexed_fields.items():
val = conv.from_payloads([v])[0]
# If a value did not come back as a list, make it a single-item list
if not isinstance(val, list):
val = [val]
# Convert each item to datetime if necessary
if v.metadata.get("type") == b"Datetime":
val = [datetime.fromisoformat(v) for v in val]
ret[k] = val
return ret


class _FunctionTypeLookup:
def __init__(self, type_hint_eval_str: bool) -> None:
# Keyed by callable __qualname__, value is optional arg types and
Expand Down
4 changes: 2 additions & 2 deletions temporalio/worker/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class ContinueAsNewInput:
run_timeout: Optional[timedelta]
task_timeout: Optional[timedelta]
memo: Optional[Mapping[str, Any]]
search_attributes: Optional[Mapping[str, Any]]
search_attributes: Optional[temporalio.common.SearchAttributes]
# The types may be absent
arg_types: Optional[List[Type]]

Expand Down Expand Up @@ -203,7 +203,7 @@ class StartChildWorkflowInput:
retry_policy: Optional[temporalio.common.RetryPolicy]
cron_schedule: str
memo: Optional[Mapping[str, Any]]
search_attributes: Optional[Mapping[str, Any]]
search_attributes: Optional[temporalio.common.SearchAttributes]
# The types may be absent
arg_types: Optional[List[Type]]
ret_type: Optional[Type]
Expand Down
3 changes: 3 additions & 0 deletions temporalio/worker/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ async def _create_workflow_instance(
run_timeout=start.workflow_run_timeout.ToTimedelta()
if start.HasField("workflow_run_timeout")
else None,
search_attributes=temporalio.converter.decode_search_attributes(
start.search_attributes
),
start_time=act.timestamp.ToDatetime().replace(tzinfo=timezone.utc),
task_queue=self._task_queue,
task_timeout=start.workflow_task_timeout.ToTimedelta(),
Expand Down
41 changes: 23 additions & 18 deletions temporalio/worker/workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Iterable,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Sequence,
Expand Down Expand Up @@ -625,7 +626,7 @@ def workflow_continue_as_new(
run_timeout: Optional[timedelta],
task_timeout: Optional[timedelta],
memo: Optional[Mapping[str, Any]],
search_attributes: Optional[Mapping[str, Any]],
search_attributes: Optional[temporalio.common.SearchAttributes],
) -> NoReturn:
# Use definition if callable
name: Optional[str] = None
Expand Down Expand Up @@ -811,7 +812,7 @@ async def workflow_start_child_workflow(
retry_policy: Optional[temporalio.common.RetryPolicy],
cron_schedule: str,
memo: Optional[Mapping[str, Any]],
search_attributes: Optional[Mapping[str, Any]],
search_attributes: Optional[temporalio.common.SearchAttributes],
) -> temporalio.workflow.ChildWorkflowHandle[Any, Any]:
# Use definition if callable
name: str
Expand Down Expand Up @@ -889,6 +890,20 @@ def workflow_start_local_activity(
)
)

def workflow_upsert_search_attributes(
self, attributes: temporalio.common.SearchAttributes
) -> None:
v = self._add_command().upsert_workflow_search_attributes_command_attributes
v.seq = self._next_seq("upsert_search_attributes")
temporalio.bridge.worker.encode_search_attributes(
attributes, v.search_attributes
)
# Update the keys in the existing dictionary. We keep exact values sent
# in instead of any kind of normalization. This means empty lists remain
# as empty lists which matches what the server does. We know this is
# mutable, so we can cast it as such.
cast(MutableMapping, self._info.search_attributes).update(attributes)

async def workflow_wait_condition(
self, fn: Callable[[], bool], *, timeout: Optional[float] = None
) -> None:
Expand Down Expand Up @@ -1665,15 +1680,9 @@ def _apply_start_command(
self._instance._payload_converter.to_payloads([val])[0]
)
if self._input.search_attributes:
for k, val in self._input.search_attributes.items():
v.search_attributes[k] = temporalio.bridge.worker.to_bridge_payload(
# We have to use the default data converter for this
(
temporalio.converter.default().payload_converter.to_payloads(
[val]
)
)[0]
)
temporalio.bridge.worker.encode_search_attributes(
self._input.search_attributes, v.search_attributes
)
v.cancellation_type = cast(
"temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType",
int(self._input.cancellation_type),
Expand Down Expand Up @@ -1786,10 +1795,6 @@ def _apply_command(
self._instance._payload_converter.to_payloads([val])[0]
)
if self._input.search_attributes:
for k, val in self._input.search_attributes.items():
v.search_attributes[k] = temporalio.bridge.worker.to_bridge_payload(
# We have to use the default data converter for this
temporalio.converter.default().payload_converter.to_payloads([val])[
0
]
)
temporalio.bridge.worker.encode_search_attributes(
self._input.search_attributes, v.search_attributes
)
Loading

0 comments on commit edbd7dd

Please sign in to comment.