From 44f6a95f3ff8c5375b3494fd3e563fe3d058a9cc Mon Sep 17 00:00:00 2001 From: David Justo Date: Fri, 24 Sep 2021 10:41:35 -0700 Subject: [PATCH] Allow try-catching Entity exceptions in orchestrators (#324) --- .../models/TaskOrchestrationExecutor.py | 8 ++- .../models/entities/ResponseMessage.py | 6 +- tests/orchestrator/orchestrator_test_utils.py | 5 +- tests/orchestrator/test_entity.py | 62 ++++++++++++++++++- tests/test_utils/ContextBuilder.py | 7 ++- 5 files changed, 77 insertions(+), 11 deletions(-) diff --git a/azure/durable_functions/models/TaskOrchestrationExecutor.py b/azure/durable_functions/models/TaskOrchestrationExecutor.py index 4119d76d..11ea74bb 100644 --- a/azure/durable_functions/models/TaskOrchestrationExecutor.py +++ b/azure/durable_functions/models/TaskOrchestrationExecutor.py @@ -180,8 +180,12 @@ def parse_history_event(directive_result): # retrieve result new_value = parse_history_event(event) if task._api_name == "CallEntityAction": - new_value = ResponseMessage.from_dict(new_value) - new_value = json.loads(new_value.result) + event_payload = ResponseMessage.from_dict(new_value) + new_value = json.loads(event_payload.result) + + if event_payload.is_exception: + new_value = Exception(new_value) + is_success = False else: # generate exception new_value = Exception(f"{event.Reason} \n {event.Details}") diff --git a/azure/durable_functions/models/entities/ResponseMessage.py b/azure/durable_functions/models/entities/ResponseMessage.py index ffd58985..0b8b35dc 100644 --- a/azure/durable_functions/models/entities/ResponseMessage.py +++ b/azure/durable_functions/models/entities/ResponseMessage.py @@ -7,7 +7,7 @@ class ResponseMessage: Specifies the response of an entity, as processed by the durable-extension. """ - def __init__(self, result: str): + def __init__(self, result: str, is_exception: bool = False): """Instantiate a ResponseMessage. Specifies the response of an entity, as processed by the durable-extension. @@ -18,6 +18,7 @@ def __init__(self, result: str): The result provided by the entity """ self.result = result + self.is_exception = is_exception # TODO: JS has an additional exceptionType field, but does not use it @classmethod @@ -34,5 +35,6 @@ def from_dict(cls, d: Dict[str, Any]) -> 'ResponseMessage': ResponseMessage: The ResponseMessage built from the provided dictionary """ - result = cls(d["result"]) + is_error = "exceptionType" in d.keys() + result = cls(d["result"], is_error) return result diff --git a/tests/orchestrator/orchestrator_test_utils.py b/tests/orchestrator/orchestrator_test_utils.py index f7fcd07c..5c77151c 100644 --- a/tests/orchestrator/orchestrator_test_utils.py +++ b/tests/orchestrator/orchestrator_test_utils.py @@ -30,8 +30,9 @@ def assert_entity_state_equals(expected, result): assert_attribute_equal(expected, result, "signals") def assert_results_are_equal(expected: Dict[str, Any], result: Dict[str, Any]) -> bool: - assert_attribute_equal(expected, result, "result") - assert_attribute_equal(expected, result, "isError") + for (payload_expected, payload_result) in zip(expected, result): + assert_attribute_equal(payload_expected, payload_result, "result") + assert_attribute_equal(payload_expected, payload_result, "isError") def assert_attribute_equal(expected, result, attribute): if attribute in expected: diff --git a/tests/orchestrator/test_entity.py b/tests/orchestrator/test_entity.py index e4b07ae7..83a7d25a 100644 --- a/tests/orchestrator/test_entity.py +++ b/tests/orchestrator/test_entity.py @@ -1,6 +1,6 @@ from azure.durable_functions.models.ReplaySchema import ReplaySchema from .orchestrator_test_utils \ - import assert_orchestration_state_equals, get_orchestration_state_result, assert_valid_schema, \ + import assert_orchestration_state_equals, assert_results_are_equal, get_orchestration_state_result, assert_valid_schema, \ get_entity_state_result, assert_entity_state_equals from tests.test_utils.ContextBuilder import ContextBuilder from tests.test_utils.EntityContextBuilder import EntityContextBuilder @@ -23,6 +23,14 @@ def generator_function_call_entity(context): outputs.append(x) return outputs +def generator_function_catch_entity_exception(context): + entityId = df.EntityId("Counter", "myCounter") + try: + yield context.call_entity(entityId, "add", 3) + return "No exception thrown" + except: + return "Exception thrown" + def generator_function_signal_entity(context): outputs = [] entityId = df.EntityId("Counter", "myCounter") @@ -53,6 +61,29 @@ def counter_entity_function(context): context.set_state(current_value) context.set_result(result) +def counter_entity_function_raises_exception(context): + raise Exception("boom!") + +def test_entity_raises_exception(): + # Create input batch + batch = [] + add_to_batch(batch, name="get") + context_builder = EntityContextBuilder(batch=batch) + + # Run the entity, get observed result + result = get_entity_state_result( + context_builder, + counter_entity_function_raises_exception, + ) + + # Construct expected result + expected_state = entity_base_expected_state() + apply_operation(expected_state, result="boom!", state=None, is_error=True) + expected = expected_state.to_json() + + # Ensure expectation matches observed behavior + #assert_valid_schema(result) + assert_entity_state_equals(expected, result) def test_entity_signal_then_call(): """Tests that a simple counter entity outputs the correct value @@ -161,11 +192,11 @@ def add_signal_entity_action(state: OrchestratorState, id_: df.EntityId, op: str state.actions.append([action]) def add_call_entity_completed_events( - context_builder: ContextBuilder, op: str, instance_id=str, input_=None, event_id=0): + context_builder: ContextBuilder, op: str, instance_id=str, input_=None, event_id=0, is_error=False): context_builder.add_event_sent_event(instance_id, event_id) context_builder.add_orchestrator_completed_event() context_builder.add_orchestrator_started_event() - context_builder.add_event_raised_event(name="0000", id_=0, input_=input_, is_entity=True) + context_builder.add_event_raised_event(name="0000", id_=0, input_=input_, is_entity=True, is_error=is_error) def test_call_entity_sent(): context_builder = ContextBuilder('test_simple_function') @@ -233,4 +264,29 @@ def test_call_entity_raised(): #assert_valid_schema(result) + assert_orchestration_state_equals(expected, result) + +def test_call_entity_catch_exception(): + entityId = df.EntityId("Counter", "myCounter") + context_builder = ContextBuilder('catch exceptions') + add_call_entity_completed_events( + context_builder, + "add", + df.EntityId.get_scheduler_id(entityId), + input_="I am an error!", + event_id=0, + is_error=True + ) + + result = get_orchestration_state_result( + context_builder, generator_function_catch_entity_exception) + + expected_state = base_expected_state( + "Exception thrown" + ) + + add_call_entity_action(expected_state, entityId, "add", 3) + expected_state._is_done = True + expected = expected_state.to_json() + assert_orchestration_state_equals(expected, result) \ No newline at end of file diff --git a/tests/test_utils/ContextBuilder.py b/tests/test_utils/ContextBuilder.py index 7cc7273f..a70f2808 100644 --- a/tests/test_utils/ContextBuilder.py +++ b/tests/test_utils/ContextBuilder.py @@ -125,11 +125,14 @@ def add_execution_started_event( event.Input = input_ self.history_events.append(event) - def add_event_raised_event(self, name:str, id_: int, input_=None, timestamp=None, is_entity=False): + def add_event_raised_event(self, name:str, id_: int, input_=None, timestamp=None, is_entity=False, is_error = False): event = self.get_base_event(HistoryEventType.EVENT_RAISED, id_=id_, timestamp=timestamp) event.Name = name if is_entity: - event.Input = json.dumps({ "result": json.dumps(input_) }) + if is_error: + event.Input = json.dumps({ "result": json.dumps(input_), "exceptionType": "True" }) + else: + event.Input = json.dumps({ "result": json.dumps(input_) }) else: event.Input = input_ # event.timestamp = timestamp