Skip to content

Commit

Permalink
IWF-397: Add channel sizes data (#66)
Browse files Browse the repository at this point in the history
* IWF-397: Add channel sizes data

* IWF-397: Lint

* IWF-397: Add signal_channel_types to Communication
  • Loading branch information
lwolczynski authored Dec 10, 2024
1 parent ec9fa94 commit c3bd12a
Show file tree
Hide file tree
Showing 12 changed files with 692 additions and 302 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ git submodule update --remote --merge
This project uses [openapi-python-client](https://github.com/openapi-generators/openapi-python-client) to generate an API client from the IDL. To update the generated client:
```bash
mkdir iwf/iwf_api/iwf_api
cd iwf && poetry run openapi-python-client update --path ../iwf-idl/iwf-sdk.yaml --config .openapi-python-client-config.yaml
cd .. && cp -R iwf/iwf_api/iwf_api/* iwf/iwf_api && rm -R iwf/iwf_api/iwf_api && poetry update
```
Then run `cd .. && cp -R iwf/iwf_api/iwf_api/* iwf/iwf_api && rm -R iwf/iwf_api/iwf_api && poetry update` to
The last command will:
* Fix the api package path
* Update the local path dependency.
#### Linting
Expand Down
2 changes: 1 addition & 1 deletion iwf-idl
Submodule iwf-idl updated 2 files
+13 −0 iwf-sdk.yaml
+13 −0 iwf.yaml
77 changes: 71 additions & 6 deletions iwf/communication.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
from typing import Any, Optional, Union

from iwf.errors import WorkflowDefinitionError
from iwf.iwf_api.models import EncodedObject, InterStateChannelPublishing
from iwf.iwf_api.models import (
EncodedObject,
InterStateChannelPublishing,
WorkflowWorkerRpcRequestInternalChannelInfos,
WorkflowWorkerRpcRequestSignalChannelInfos,
)
from iwf.object_encoder import ObjectEncoder
from iwf.state_movement import StateMovement


class Communication:
_type_store: dict[str, Optional[type]]
_internal_channel_type_store: dict[str, Optional[type]]
_signal_channel_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]
_state_movements: list[StateMovement]
_internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos]
_signal_channel_infos: Optional[WorkflowWorkerRpcRequestSignalChannelInfos]

def __init__(
self, type_store: dict[str, Optional[type]], object_encoder: ObjectEncoder
self,
internal_channel_type_store: dict[str, Optional[type]],
signal_channel_type_store: dict[str, Optional[type]],
object_encoder: ObjectEncoder,
internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos],
signal_channel_infos: Optional[WorkflowWorkerRpcRequestSignalChannelInfos],
):
self._object_encoder = object_encoder
self._type_store = type_store
self._internal_channel_type_store = internal_channel_type_store
self._signal_channel_type_store = signal_channel_type_store
self._to_publish_internal_channel = {}
self._state_movements = []
self._internal_channel_infos = internal_channel_infos
self._signal_channel_infos = signal_channel_infos

def trigger_state_execution(self, state: Union[str, type], state_input: Any = None):
"""
Expand All @@ -31,10 +47,10 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
self._state_movements.append(movement)

def publish_to_internal_channel(self, channel_name: str, value: Any = None):
registered_type = self._type_store.get(channel_name)
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._type_store.items():
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t

Expand Down Expand Up @@ -66,3 +82,52 @@ def get_to_publishing_internal_channel(self) -> list[InterStateChannelPublishing

def get_to_trigger_state_movements(self) -> list[StateMovement]:
return self._state_movements

def get_internal_channel_size(self, channel_name):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t

if registered_type is None:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)

if (
self._internal_channel_infos is not None
and channel_name in self._internal_channel_infos
):
server_channel_size = self._internal_channel_infos[channel_name].size
else:
server_channel_size = 0

if channel_name in self._to_publish_internal_channel:
buffer_channel_size = len(self._to_publish_internal_channel[channel_name])
else:
buffer_channel_size = 0

return server_channel_size + buffer_channel_size

def get_signal_channel_size(self, channel_name):
registered_type = self._signal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._signal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t

if registered_type is None:
raise WorkflowDefinitionError(
f"SignalChannel channel_name is not defined {channel_name}"
)

if (
self._signal_channel_infos is not None
and channel_name in self._signal_channel_infos
):
return self._signal_channel_infos[channel_name].size
else:
return 0
10 changes: 10 additions & 0 deletions iwf/iwf_api/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Contains all the data models used in inputs/outputs """

from .channel_info import ChannelInfo
from .channel_request_status import ChannelRequestStatus
from .command_combination import CommandCombination
from .command_request import CommandRequest
Expand Down Expand Up @@ -79,9 +80,16 @@
WorkflowWaitForStateCompletionResponse,
)
from .workflow_worker_rpc_request import WorkflowWorkerRpcRequest
from .workflow_worker_rpc_request_internal_channel_infos import (
WorkflowWorkerRpcRequestInternalChannelInfos,
)
from .workflow_worker_rpc_request_signal_channel_infos import (
WorkflowWorkerRpcRequestSignalChannelInfos,
)
from .workflow_worker_rpc_response import WorkflowWorkerRpcResponse

__all__ = (
"ChannelInfo",
"ChannelRequestStatus",
"CommandCombination",
"CommandRequest",
Expand Down Expand Up @@ -157,5 +165,7 @@
"WorkflowWaitForStateCompletionRequest",
"WorkflowWaitForStateCompletionResponse",
"WorkflowWorkerRpcRequest",
"WorkflowWorkerRpcRequestInternalChannelInfos",
"WorkflowWorkerRpcRequestSignalChannelInfos",
"WorkflowWorkerRpcResponse",
)
57 changes: 57 additions & 0 deletions iwf/iwf_api/models/channel_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Any, Dict, List, Type, TypeVar, Union

import attr

from ..types import UNSET, Unset

T = TypeVar("T", bound="ChannelInfo")


@attr.s(auto_attribs=True)
class ChannelInfo:
"""
Attributes:
size (Union[Unset, int]):
"""

size: Union[Unset, int] = UNSET
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
size = self.size

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update({})
if size is not UNSET:
field_dict["size"] = size

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
d = src_dict.copy()
size = d.pop("size", UNSET)

channel_info = cls(
size=size,
)

channel_info.additional_properties = d
return channel_info

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> Any:
return self.additional_properties[key]

def __setitem__(self, key: str, value: Any) -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
56 changes: 56 additions & 0 deletions iwf/iwf_api/models/workflow_worker_rpc_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from ..models.encoded_object import EncodedObject
from ..models.key_value import KeyValue
from ..models.search_attribute import SearchAttribute
from ..models.workflow_worker_rpc_request_internal_channel_infos import (
WorkflowWorkerRpcRequestInternalChannelInfos,
)
from ..models.workflow_worker_rpc_request_signal_channel_infos import (
WorkflowWorkerRpcRequestSignalChannelInfos,
)


T = TypeVar("T", bound="WorkflowWorkerRpcRequest")
Expand All @@ -24,6 +30,8 @@ class WorkflowWorkerRpcRequest:
input_ (Union[Unset, EncodedObject]):
search_attributes (Union[Unset, List['SearchAttribute']]):
data_attributes (Union[Unset, List['KeyValue']]):
signal_channel_infos (Union[Unset, WorkflowWorkerRpcRequestSignalChannelInfos]):
internal_channel_infos (Union[Unset, WorkflowWorkerRpcRequestInternalChannelInfos]):
"""

context: "Context"
Expand All @@ -32,6 +40,12 @@ class WorkflowWorkerRpcRequest:
input_: Union[Unset, "EncodedObject"] = UNSET
search_attributes: Union[Unset, List["SearchAttribute"]] = UNSET
data_attributes: Union[Unset, List["KeyValue"]] = UNSET
signal_channel_infos: Union[Unset, "WorkflowWorkerRpcRequestSignalChannelInfos"] = (
UNSET
)
internal_channel_infos: Union[
Unset, "WorkflowWorkerRpcRequestInternalChannelInfos"
] = UNSET
additional_properties: Dict[str, Any] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -59,6 +73,14 @@ def to_dict(self) -> Dict[str, Any]:

data_attributes.append(data_attributes_item)

signal_channel_infos: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.signal_channel_infos, Unset):
signal_channel_infos = self.signal_channel_infos.to_dict()

internal_channel_infos: Union[Unset, Dict[str, Any]] = UNSET
if not isinstance(self.internal_channel_infos, Unset):
internal_channel_infos = self.internal_channel_infos.to_dict()

field_dict: Dict[str, Any] = {}
field_dict.update(self.additional_properties)
field_dict.update(
Expand All @@ -74,6 +96,10 @@ def to_dict(self) -> Dict[str, Any]:
field_dict["searchAttributes"] = search_attributes
if data_attributes is not UNSET:
field_dict["dataAttributes"] = data_attributes
if signal_channel_infos is not UNSET:
field_dict["signalChannelInfos"] = signal_channel_infos
if internal_channel_infos is not UNSET:
field_dict["internalChannelInfos"] = internal_channel_infos

return field_dict

Expand All @@ -83,6 +109,12 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
from ..models.encoded_object import EncodedObject
from ..models.key_value import KeyValue
from ..models.search_attribute import SearchAttribute
from ..models.workflow_worker_rpc_request_internal_channel_infos import (
WorkflowWorkerRpcRequestInternalChannelInfos,
)
from ..models.workflow_worker_rpc_request_signal_channel_infos import (
WorkflowWorkerRpcRequestSignalChannelInfos,
)

d = src_dict.copy()
context = Context.from_dict(d.pop("context"))
Expand Down Expand Up @@ -114,13 +146,37 @@ def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:

data_attributes.append(data_attributes_item)

_signal_channel_infos = d.pop("signalChannelInfos", UNSET)
signal_channel_infos: Union[Unset, WorkflowWorkerRpcRequestSignalChannelInfos]
if isinstance(_signal_channel_infos, Unset):
signal_channel_infos = UNSET
else:
signal_channel_infos = WorkflowWorkerRpcRequestSignalChannelInfos.from_dict(
_signal_channel_infos
)

_internal_channel_infos = d.pop("internalChannelInfos", UNSET)
internal_channel_infos: Union[
Unset, WorkflowWorkerRpcRequestInternalChannelInfos
]
if isinstance(_internal_channel_infos, Unset):
internal_channel_infos = UNSET
else:
internal_channel_infos = (
WorkflowWorkerRpcRequestInternalChannelInfos.from_dict(
_internal_channel_infos
)
)

workflow_worker_rpc_request = cls(
context=context,
workflow_type=workflow_type,
rpc_name=rpc_name,
input_=input_,
search_attributes=search_attributes,
data_attributes=data_attributes,
signal_channel_infos=signal_channel_infos,
internal_channel_infos=internal_channel_infos,
)

workflow_worker_rpc_request.additional_properties = d
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar

import attr

if TYPE_CHECKING:
from ..models.channel_info import ChannelInfo


T = TypeVar("T", bound="WorkflowWorkerRpcRequestInternalChannelInfos")


@attr.s(auto_attribs=True)
class WorkflowWorkerRpcRequestInternalChannelInfos:
""" """

additional_properties: Dict[str, "ChannelInfo"] = attr.ib(init=False, factory=dict)

def to_dict(self) -> Dict[str, Any]:
pass

field_dict: Dict[str, Any] = {}
for prop_name, prop in self.additional_properties.items():
field_dict[prop_name] = prop.to_dict()

field_dict.update({})

return field_dict

@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
from ..models.channel_info import ChannelInfo

d = src_dict.copy()
workflow_worker_rpc_request_internal_channel_infos = cls()

additional_properties = {}
for prop_name, prop_dict in d.items():
additional_property = ChannelInfo.from_dict(prop_dict)

additional_properties[prop_name] = additional_property

workflow_worker_rpc_request_internal_channel_infos.additional_properties = (
additional_properties
)
return workflow_worker_rpc_request_internal_channel_infos

@property
def additional_keys(self) -> List[str]:
return list(self.additional_properties.keys())

def __getitem__(self, key: str) -> "ChannelInfo":
return self.additional_properties[key]

def __setitem__(self, key: str, value: "ChannelInfo") -> None:
self.additional_properties[key] = value

def __delitem__(self, key: str) -> None:
del self.additional_properties[key]

def __contains__(self, key: str) -> bool:
return key in self.additional_properties
Loading

0 comments on commit c3bd12a

Please sign in to comment.