Skip to content

Commit

Permalink
Update for python 3.9 (#13)
Browse files Browse the repository at this point in the history
Also added some more testing around domain models
  • Loading branch information
hmvp authored May 10, 2021
1 parent 706ef48 commit 33e0271
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.0.14
current_version = 0.0.15
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

[Unreleased]

- Fixed database intialisation
- Project scaffolding
- TLC for python 3.9
2 changes: 1 addition & 1 deletion orchestrator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

"""This is the orchestrator workflow engine."""

__version__ = "0.0.14"
__version__ = "0.0.15"

from orchestrator.app import OrchestratorCore
from orchestrator.settings import app_settings, oauth2_settings
Expand Down
29 changes: 13 additions & 16 deletions orchestrator/domain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic import BaseModel, Field, ValidationError
from pydantic.main import ModelMetaclass
from pydantic.types import ConstrainedList
from pydantic.typing import get_args, get_origin
from sqlalchemy import and_
from sqlalchemy.orm import selectinload

Expand All @@ -49,8 +50,8 @@ def _is_constrained_list_type(type: Type) -> bool:
except Exception:

# Strip generic arguments, it still might be a subclass
if hasattr(type, "__origin__"):
return _is_constrained_list_type(type.__origin__)
if get_origin(type):
return _is_constrained_list_type(get_origin(type))
else:
return False

Expand Down Expand Up @@ -91,7 +92,7 @@ def __init_subclass__(
# Check if child subscription instance models conform to the same lifecycle
for product_block_field_name, product_block_field_type in cls._product_block_fields_.items():
if is_list_type(product_block_field_type) or is_optional_type(product_block_field_type):
product_block_field_type = product_block_field_type.__args__[0]
product_block_field_type = get_args(product_block_field_type)[0]

if lifecycle:
for lifecycle_status in lifecycle:
Expand Down Expand Up @@ -149,7 +150,7 @@ def _init_instances(

if is_list_type(product_block_field_type):
if _is_constrained_list_type(product_block_field_type):
product_block_model = product_block_field_type.__args__[0]
product_block_model = get_args(product_block_field_type)[0]
default_value = product_block_field_type()
# if constrainedlist has minimum, return that minimum else empty list
if product_block_field_type.min_items:
Expand Down Expand Up @@ -232,7 +233,7 @@ def domain_filter(instance: SubscriptionInstanceTable) -> bool:
else:
product_block_model_list = instances[product_block_field_name]

product_block_model = product_block_field_type.__args__[0]
product_block_model = get_args(product_block_field_type)[0]
instance_list: List[SubscriptionInstanceTable] = list(
filter(
filter_func, flatten(grouped_instances.get(name, []) for name in product_block_model.__names__)
Expand All @@ -247,7 +248,7 @@ def domain_filter(instance: SubscriptionInstanceTable) -> bool:
else:
product_block_model = product_block_field_type
if is_optional_type(product_block_field_type):
product_block_model = product_block_model.__args__[0]
product_block_model = get_args(product_block_model)[0]

instance = only(
list(
Expand Down Expand Up @@ -426,19 +427,19 @@ def _load_instances_values(cls, instance_values: List[SubscriptionInstanceValueT
"""
instance_values_dict: State = {}
list_field_names = []
list_field_names = set()

# Set default values
for field_name, field_type in cls._non_product_block_fields_.items():
# Ensure that empty lists are handled OK
if is_list_type(field_type):
instance_values_dict[field_name] = []
list_field_names.append(field_name)
list_field_names.add(field_name)

for siv in instance_values:
# check the type of the siv in the instance and act accordingly: only lists and scalar values supported
resource_type_name = siv.resource_type.resource_type
if is_list_type(cls._non_product_block_fields_[resource_type_name]):
if resource_type_name in list_field_names:
instance_values_dict[resource_type_name].append(siv.value)
else:
instance_values_dict[resource_type_name] = siv.value
Expand Down Expand Up @@ -760,7 +761,7 @@ def find_product_block_in(cls: Type[DomainModel]) -> List[ProductBlockModel]:
product_blocks_in_model = []
for product_block_field_type in cls._product_block_fields_.values():
if is_list_type(product_block_field_type) or is_optional_type(product_block_field_type):
product_block_model = product_block_field_type.__args__[0]
product_block_model = get_args(product_block_field_type)[0]
else:
product_block_model = product_block_field_type

Expand Down Expand Up @@ -1003,12 +1004,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
# This makes a lot of assuptions about the internals of `typing`
if "__orig_bases__" in cls.__dict__ and cls.__dict__["__orig_bases__"]:
generic_base_cls = cls.__dict__["__orig_bases__"][0]
if (
not hasattr(generic_base_cls, "item_type")
and hasattr(generic_base_cls, "__args__")
and generic_base_cls.__args__
):
cls.item_type = generic_base_cls.__args__[0]
if not hasattr(generic_base_cls, "item_type") and get_args(generic_base_cls):
cls.item_type = get_args(generic_base_cls)[0]

# Make sure __args__ is set
cls.__args__ = (cls.item_type,)
27 changes: 14 additions & 13 deletions orchestrator/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, TypedDict, TypeVar, Union

from pydantic import BaseModel
from pydantic.typing import get_args, get_origin

UUIDstr = str
State = Dict[str, Any]
Expand Down Expand Up @@ -102,10 +103,10 @@ def is_of_type(t: Any, test_type: Any) -> bool:
"""

if (
hasattr(t, "__origin__")
and hasattr(test_type, "__origin__")
and t.__origin__ is test_type.__origin__
and t.__args__ == test_type.__args__
get_origin(t)
and get_origin(test_type)
and get_origin(t) is get_origin(test_type)
and get_args(t) == get_args(test_type)
):
return True

Expand Down Expand Up @@ -148,16 +149,16 @@ def is_list_type(t: Any, test_type: Optional[type] = None) -> bool:
>>> is_list_type(Literal[1,2,3])
False
"""
if hasattr(t, "__origin__"):
if get_origin(t):
if is_optional_type(t):
for arg in t.__args__:
for arg in get_args(t):
if is_list_type(arg, test_type):
return True
elif t.__origin__ == Literal:
elif get_origin(t) == Literal: # type:ignore
return False # Literal cannot contain lists see pep 586
elif issubclass(t.__origin__, list):
if test_type and t.__args__:
return is_of_type(t.__args__[0], test_type)
elif issubclass(get_origin(t), list):
if test_type and get_args(t):
return is_of_type(get_args(t)[0], test_type)
else:
return True

Expand All @@ -184,9 +185,9 @@ def is_optional_type(t: Any, test_type: Optional[type] = None) -> bool:
>>> is_optional_type(int)
False
"""
if hasattr(t, "__origin__"):
if t.__origin__ == Union and len(t.__args__) == 2:
for arg in t.__args__:
if get_origin(t):
if get_origin(t) == Union and len(get_args(t)) == 2 and None.__class__ in get_args(t): # type:ignore
for arg in get_args(t):
if arg is None.__class__:
continue

Expand Down
16 changes: 9 additions & 7 deletions orchestrator/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import Any, Callable, List, Optional, Tuple, Union, cast
from uuid import UUID

from pydantic.typing import get_args

from orchestrator.domain.base import SubscriptionModel
from orchestrator.types import (
FormGenerator,
Expand Down Expand Up @@ -142,9 +144,7 @@ def _build_arguments(func: Union[StepFunc, InputStepFunc], state: State) -> List
Domain models are retrieved from the DB (after `subscription_id` lookup in the state). Everything else is
retrieved from the state.
One exception: if a domain model is requested, but no key (variable name) is found for it in the state, it is
interpreted as a request to instantiate it on behalf of the step function. To do so it does lookup `product` and
`customer` values (both UUIDs) in the state.
For domain models only ``Optional`` and ``List`` are supported as container types. Union, Dict and others are not supported
Args:
func: step function to inspect for requested arguments
Expand Down Expand Up @@ -191,15 +191,15 @@ def _build_arguments(func: Union[StepFunc, InputStepFunc], state: State) -> List
subscription_ids = map(_get_sub_id, state.get(name, []))
subscriptions = [
# Actual type is first argument from list type
param.annotation.__args__[0].from_subscription(subscription_id)
get_args(param.annotation)[0].from_subscription(subscription_id)
for subscription_id in subscription_ids
]
arguments.append(subscriptions)
elif is_optional_type(param.annotation, SubscriptionModel):
subscription_id = _get_sub_id(state.get(name))
if subscription_id:
# Actual type is first argument from union type
sub_mod = param.annotation.__args__[0].from_subscription(subscription_id)
# Actual type is first argument from optional type
sub_mod = get_args(param.annotation)[0].from_subscription(subscription_id)
arguments.append(sub_mod)
else:
arguments.append(None)
Expand Down Expand Up @@ -245,7 +245,7 @@ def load_initial_state_for_modify(organisation: UUID, subscription_id: UUID) ->
and passed as values to the step function. The dict `new_state` returned by the step function will be merged with
that of the original `state` dict and returned as the final result.
It knows how to deal with Optional parameters. Eg, given::
It knows how to deal with parameters that have a default. Eg, given::
@inject_args
def do_stuff_with_saps(subscription_id: UUID, sap1: Dict, sap2: Optional[Dict] = None) -> State:
Expand Down Expand Up @@ -284,6 +284,8 @@ def do_stuff(light_path: Sn8LightPath) -> State:
present in the state. This will not work for more than one domain model. Eg. you can't request two domain
models to be created as we will not know to which of the two domain models `product` is applicable to.
Also supported is wrapping a domain model in ``Optional`` or ``List``. Other types are not supported.
Args:
func: a step function with parameters (that should be keys into the state dict, except for optional ones)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ requires = [
"oauth2-lib~=1.0.4"
]
description-file = "README.md"
requires-python = ">=3.6,<3.9"
requires-python = ">=3.6,<3.10"

[tool.flit.metadata.urls]
Documentation = "https://workfloworchestrator.org/"
Expand Down
16 changes: 11 additions & 5 deletions test/unit_tests/api/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def long_running_step():

@workflow("Long Running Workflow")
def long_running_workflow_py():
return init >> long_running_step >> done
return init >> long_running_step >> long_running_step >> done

with WorkflowInstanceForTests(long_running_workflow_py, "long_running_workflow_py"):

Expand Down Expand Up @@ -143,23 +143,29 @@ def test_long_running_pause(test_client, long_running_workflow):
assert response.json()["global_status"] == "PAUSED"

response = test_client.get(f"api/processes/{pid}")
assert len(response.json()["steps"]) == 3
assert len(response.json()["steps"]) == 4
assert response.json()["current_state"]["done"] is True
# assume ordered steplist
assert response.json()["steps"][2]["status"] == "pending"
assert response.json()["steps"][3]["status"] == "pending"

response = test_client.put("/api/settings/status", json={"global_lock": False})

# Make sure it started again
time.sleep(1)

assert response.json()["global_lock"] is False
assert response.json()["running_processes"] == 1
assert response.json()["global_status"] == "RUNNING"

# Let it finish
# Let it finish after second lock step
with test_condition:
test_condition.notify_all()
time.sleep(1)

response = test_client.get(f"api/processes/{pid}")
assert HTTPStatus.OK == response.status_code
# assume ordered steplist
assert response.json()["steps"][2]["status"] == "complete"
assert response.json()["steps"][3]["status"] == "complete"

app_settings.TESTING = True

Expand Down
Loading

0 comments on commit 33e0271

Please sign in to comment.