Skip to content

Commit

Permalink
feat: Basic support for SCXML test suit
Browse files Browse the repository at this point in the history
  • Loading branch information
fgmacedo committed Nov 22, 2024
1 parent 9b55852 commit ecc7957
Show file tree
Hide file tree
Showing 15 changed files with 532 additions and 25 deletions.
28 changes: 28 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,31 @@ def pytest_ignore_collect(collection_path, path, config):

if "django_project" in str(path):
return True


# @pytest.fixture(autouse=True, scope="module")
# def mock_dot_write(request):
# """
# This fixture avoids updating files while executing tests
# """

# def open_effect(
# filename,
# mode="r",
# *args,
# **kwargs,
# ):
# if mode in ("r", "rt", "rb"):
# return open(filename, mode, *args, **kwargs)
# elif filename.startswith("/tmp/"):
# return open(filename, mode, *args, **kwargs)
# elif "b" in mode:
# return io.BytesIO()
# else:
# return io.StringIO()

# # using global mock instead of the fixture mocker due to the ScopeMismatch
# # this fixture is module scoped and mocker is function scoped
# with mock.patch("pydot.core.io.open", spec=True) as m:
# m.side_effect = open_effect
# yield m
18 changes: 17 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dev = [
"pytest-mock >=3.10.0",
"pytest-benchmark >=4.0.0",
"pytest-asyncio",
"pydot",
"django >=5.0.8; python_version >='3.10'",
"pytest-django >=4.8.0; python_version >'3.8'",
"Sphinx; python_version >'3.8'",
Expand All @@ -51,6 +52,7 @@ dev = [
"sphinx-autobuild; python_version >'3.8'",
"furo >=2024.5.6; python_version >'3.8'",
"sphinx-copybutton >=0.5.2; python_version >'3.8'",
"pdbr>=0.8.9; python_version >='3.8'",
]

[build-system]
Expand All @@ -61,7 +63,21 @@ build-backend = "hatchling.build"
packages = ["statemachine/"]

[tool.pytest.ini_options]
addopts = "--ignore=docs/conf.py --ignore=docs/auto_examples/ --ignore=docs/_build/ --ignore=tests/examples/ --cov --cov-config .coveragerc --doctest-glob='*.md' --doctest-modules --doctest-continue-on-failure --benchmark-autosave --benchmark-group-by=name"
addopts = [
"--ignore=docs/conf.py",
"--ignore=docs/auto_examples/",
"--ignore=docs/_build/",
"--ignore=tests/examples/",
"--cov",
"--cov-config",
".coveragerc",
"--doctest-glob=*.md",
"--doctest-modules",
"--doctest-continue-on-failure",
"--benchmark-autosave",
"--benchmark-group-by=name",
"--pdbcls=pdbr:RichPdb",
]
doctest_optionflags = "ELLIPSIS IGNORE_EXCEPTION_DETAIL NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL"
asyncio_mode = "auto"
markers = ["""slow: marks tests as slow (deselect with '-m "not slow"')"""]
Expand Down
11 changes: 6 additions & 5 deletions statemachine/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import IntEnum
from enum import IntFlag
from enum import auto
from functools import partial
from inspect import isawaitable
from typing import TYPE_CHECKING
from typing import Callable
Expand Down Expand Up @@ -89,10 +90,10 @@ def __init__(
self.attr_name: str = func and func.fget and func.fget.__name__ or ""
elif callable(func):
self.reference = SpecReference.CALLABLE
self.is_bounded = hasattr(func, "__self__")
self.attr_name = (
func.__name__ if not self.is_event or self.is_bounded else f"_{func.__name__}_"
)
is_partial = isinstance(func, partial)
self.is_bounded = is_partial or hasattr(func, "__self__")
name = func.func.__name__ if is_partial else func.__name__
self.attr_name = name if not self.is_event or self.is_bounded else f"_{name}_"
if not self.is_bounded:
func.attr_name = self.attr_name
func.is_event = is_event
Expand All @@ -110,7 +111,7 @@ def __repr__(self):
return f"{type(self).__name__}({self.func!r}, is_convention={self.is_convention!r})"

def __str__(self):
name = getattr(self.func, "__name__", self.func)
name = self.attr_name
if self.expected_value is False:
name = f"!{name}"
return name
Expand Down
2 changes: 1 addition & 1 deletion statemachine/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _search_callable(self, spec):
yield listener.build_key(spec.attr_name), partial(callable_method, func)
return

yield f"{spec.attr_name}@None", partial(callable_method, spec.func)
yield f"{spec.attr_name}-{id(spec.func)}@None", partial(callable_method, spec.func)

def search_name(self, name):
for listener in self.items:
Expand Down
64 changes: 64 additions & 0 deletions statemachine/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Dict

from ..factory import StateMachineMetaclass
from ..state import State
from ..statemachine import StateMachine
from ..transition_list import TransitionList


def create_machine_class_from_definition(
name: str, definition: dict, **extra_kwargs
) -> StateMachine:
"""
Creates a StateMachine class from a dictionary definition, using the StateMachineMetaclass.
Example usage with a traffic light machine:
>>> machine = create_machine_class_from_definition(
... "TrafficLightMachine",
... {
... "states": {
... "green": {"initial": True},
... "yellow": {},
... "red": {},
... },
... "events": {
... "change": [
... {"from": "green", "to": "yellow"},
... {"from": "yellow", "to": "red"},
... {"from": "red", "to": "green"},
... ]
... },
... }
... )
"""

states_instances = {
state_id: State(**state_kwargs) for state_id, state_kwargs in definition["states"].items()
}

events: Dict[str, TransitionList] = {}
for event_name, transitions in definition["events"].items():
for transition_data in transitions:
source = states_instances[transition_data["from"]]
target = states_instances[transition_data["to"]]

transition = source.to(
target,
event=event_name,
cond=transition_data.get("cond"),
unless=transition_data.get("unless"),
on=transition_data.get("on"),
before=transition_data.get("before"),
after=transition_data.get("after"),
)

if event_name in events:
events[event_name] |= transition
elif event_name is not None:
events[event_name] = transition

attrs_mapper = {**extra_kwargs, **states_instances, **events}

return StateMachineMetaclass(name, (StateMachine,), attrs_mapper) # type: ignore[return-value]
126 changes: 126 additions & 0 deletions statemachine/io/scxml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Simple SCXML parser that converts SCXML documents to state machine definitions.
"""

import xml.etree.ElementTree as ET
from functools import partial
from typing import Any
from typing import Dict
from typing import List

from statemachine.statemachine import StateMachine


def send_event(machine: StateMachine, event_to_send: str) -> None:
machine.send(event_to_send)


def assign(model, location, expr):
pass

Check warning on line 19 in statemachine/io/scxml.py

View check run for this annotation

Codecov / codecov/patch

statemachine/io/scxml.py#L19

Added line #L19 was not covered by tests


def strip_namespaces(tree):
"""Remove all namespaces from tags and attributes in place.
Leaves only the local names in the subtree.
"""
for el in tree.iter():
tag = el.tag
if tag and isinstance(tag, str) and tag[0] == "{":
el.tag = tag.partition("}")[2]
attrib = el.attrib
if attrib:
for name, value in list(attrib.items()):
if name and isinstance(name, str) and name[0] == "{":
del attrib[name]
attrib[name.partition("}")[2]] = value

Check warning on line 36 in statemachine/io/scxml.py

View check run for this annotation

Codecov / codecov/patch

statemachine/io/scxml.py#L35-L36

Added lines #L35 - L36 were not covered by tests


def parse_scxml(scxml_content: str) -> Dict[str, Any]: # noqa: C901
"""
Parse SCXML content and return a dictionary definition compatible with
create_machine_class_from_definition.
The returned dictionary has the format:
{
"states": {
"state_id": {"initial": True},
...
},
"events": {
"event_name": [
{"from": "source_state", "to": "target_state"},
...
]
}
}
"""
# Parse XML content
root = ET.fromstring(scxml_content)
strip_namespaces(root)

# Find the scxml element (it might be the root or a child)
scxml = root if "scxml" in root.tag else root.find(".//scxml")
if scxml is None:
raise ValueError("No scxml element found in document")

Check warning on line 65 in statemachine/io/scxml.py

View check run for this annotation

Codecov / codecov/patch

statemachine/io/scxml.py#L65

Added line #L65 was not covered by tests

# Get initial state from scxml element
initial_state = scxml.get("initial")

# Build states dictionary
states = {}
events: Dict[str, List[Dict[str, str]]] = {}

def _parse_state(state_elem, final=False): # noqa: C901
state_id = state_elem.get("id")
if not state_id:
raise ValueError("All states must have an id")

Check warning on line 77 in statemachine/io/scxml.py

View check run for this annotation

Codecov / codecov/patch

statemachine/io/scxml.py#L77

Added line #L77 was not covered by tests

# Mark as initial if specified
states[state_id] = {"initial": state_id == initial_state, "final": final}

# Process transitions
for trans_elem in state_elem.findall("transition"):
event = trans_elem.get("event") or None
target = trans_elem.get("target")

if target:
if event not in events:
events[event] = []

if target not in states:
states[target] = {}

events[event].append(
{
"from": state_id,
"to": target,
}
)

for onentry_elem in state_elem.findall("onentry"):
for raise_elem in onentry_elem.findall("raise"):
event = raise_elem.get("event")
if event:
state = states[state_id]
if "enter" not in state:
state["enter"] = []
state["enter"].append(partial(send_event, event_to_send=event))

# First pass: collect all states and mark initial
for state_elem in scxml.findall(".//state"):
_parse_state(state_elem)

# Second pass: collect final states
for state_elem in scxml.findall(".//final"):
_parse_state(state_elem, final=True)

# If no initial state was specified, mark the first state as initial
if not initial_state and states:
first_state = next(iter(states))
states[first_state]["initial"] = True

Check warning on line 121 in statemachine/io/scxml.py

View check run for this annotation

Codecov / codecov/patch

statemachine/io/scxml.py#L120-L121

Added lines #L120 - L121 were not covered by tests

return {
"states": states,
"events": events,
}
42 changes: 40 additions & 2 deletions statemachine/spec_parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import operator
import re
from typing import Callable

Expand Down Expand Up @@ -40,6 +41,23 @@ def decorated(*args, **kwargs) -> bool:
return decorated


def build_custom_operator(operator) -> Callable:
def custom_comparator(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))

return decorated

return custom_comparator


def build_constant(constant) -> Callable:
def decorated(*args, **kwargs):
return constant

return decorated


def custom_or(left: Callable, right: Callable) -> Callable:
def decorated(*args, **kwargs) -> bool:
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
Expand All @@ -49,7 +67,7 @@ def decorated(*args, **kwargs) -> bool:
return decorated


def build_expression(node, variable_hook, operator_mapping):
def build_expression(node, variable_hook, operator_mapping): # noqa: C901
if isinstance(node, ast.BoolOp):
# Handle `and` / `or` operations
operator_fn = operator_mapping[type(node.op)]
Expand All @@ -58,13 +76,23 @@ def build_expression(node, variable_hook, operator_mapping):
right_expr = build_expression(right, variable_hook, operator_mapping)
left_expr = operator_fn(left_expr, right_expr)
return left_expr
elif isinstance(node, ast.Compare):
operator_fn = operator_mapping[type(node.ops[0])]
left_expr = build_expression(node.left, variable_hook, operator_mapping)
for right in node.comparators:
right_expr = build_expression(right, variable_hook, operator_mapping)
left_expr = operator_fn(left_expr, right_expr)
return left_expr
elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
# Handle `not` operation
operand_expr = build_expression(node.operand, variable_hook, operator_mapping)
return operator_mapping[type(node.op)](operand_expr)
elif isinstance(node, ast.Name):
# Handle variables by calling the variable_hook
return variable_hook(node.id)
elif isinstance(node, ast.Constant):
# Handle constants by returning the value
return build_constant(node.value)
else:
raise ValueError(f"Unsupported expression structure: {node.__class__.__name__}")

Expand All @@ -80,4 +108,14 @@ def parse_boolean_expr(expr, variable_hook, operator_mapping):
return build_expression(tree.body, variable_hook, operator_mapping)


operator_mapping = {ast.Or: custom_or, ast.And: custom_and, ast.Not: custom_not}
operator_mapping = {
ast.Or: custom_or,
ast.And: custom_and,
ast.Not: custom_not,
ast.GtE: build_custom_operator(operator.ge),
ast.Gt: build_custom_operator(operator.gt),
ast.LtE: build_custom_operator(operator.le),
ast.Lt: build_custom_operator(operator.lt),
ast.Eq: build_custom_operator(operator.eq),
ast.NotEq: build_custom_operator(operator.ne),
}
Loading

0 comments on commit ecc7957

Please sign in to comment.