From 18e6e6dcdab0e3e41a233cb4d9dc626f99bb0644 Mon Sep 17 00:00:00 2001 From: Mainak Kundu Date: Mon, 2 Dec 2024 22:54:40 -0500 Subject: [PATCH] feat: optional last --- doc/source/user_guide/events.rst | 6 ++--- .../streaming_services/events_streaming.py | 25 +++++++++++-------- tests/test_events_manager.py | 13 +++++++++- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/doc/source/user_guide/events.rst b/doc/source/user_guide/events.rst index a8afb9b835d..1bd18e53a0b 100644 --- a/doc/source/user_guide/events.rst +++ b/doc/source/user_guide/events.rst @@ -25,10 +25,10 @@ The following code triggers a callback at the end of every iteration. >>> >>> callback_id = solver.events.register_callback(SolverEvent.ITERATION_ENDED, on_iteration_ended) -The general signature of the callback function is ``cb(, session, event_info)``, where ``session`` is the session instance +The general signature of the callback function is ``cb(session, event_info, )``, where ``session`` is the session instance and ``event_info`` instance holds information about the event. The event information classes for each event are documented in the API reference of the :obj:`~ansys.fluent.core.streaming_services.events_streaming` module. See the callback function -``on_case_loaded_with_args()`` in the below examples for an example of how to pass optional arguments to the callback +``on_case_loaded_with_args()`` in the below examples for an example of how to pass additional arguments to the callback function. @@ -74,7 +74,7 @@ Examples >>> def on_case_loaded(session, event_info: CaseLoadedEventInfo): >>> print("Case loaded. Index = ", event_info.index) >>> - >>> def on_case_loaded_with_args(x, y, session, event_info: CaseLoadedEventInfo): + >>> def on_case_loaded_with_args(session, event_info: CaseLoadedEventInfo, x, y): >>> print(f"Case loaded with {x}, {y}. Index = ", event_info.index) >>> >>> callback = meshing.events.register_callback(MeshingEvent.CASE_LOADED, on_case_loaded) diff --git a/src/ansys/fluent/core/streaming_services/events_streaming.py b/src/ansys/fluent/core/streaming_services/events_streaming.py index b5ce2637e77..0c48c15511f 100644 --- a/src/ansys/fluent/core/streaming_services/events_streaming.py +++ b/src/ansys/fluent/core/streaming_services/events_streaming.py @@ -442,23 +442,26 @@ def _process_streaming( @staticmethod def _make_callback_to_call(callback: Callable, args, kwargs): - old_style = "session_id" in inspect.signature(callback).parameters - if old_style: + params = inspect.signature(callback).parameters + if "session_id" in params: warnings.warn( "Update event callback function signatures" " substituting 'session' for 'session_id'.", PyFluentDeprecationWarning, ) - fn = partial(callback, *args, **kwargs) - return ( - ( - lambda session, event_info: fn( - session_id=session.id, event_info=event_info - ) + return lambda session, event_info: callback( + *args, session_id=session.id, event_info=event_info, **kwargs + ) + else: + positional_args = [ + p + for p in params + if p not in kwargs and p not in ("session", "event_info") + ] + kwargs.update(dict(zip(positional_args, args))) + return lambda session, event_info: callback( + session=session, event_info=event_info, **kwargs ) - if old_style - else fn - ) def register_callback( self, diff --git a/tests/test_events_manager.py b/tests/test_events_manager.py index 925ed39de3e..fea1b35b067 100644 --- a/tests/test_events_manager.py +++ b/tests/test_events_manager.py @@ -28,7 +28,12 @@ def on_case_loaded(session, event_info): on_case_loaded.loaded = False - def on_case_loaded_with_args(x, y, session, event_info): + def on_case_loaded_with_args_optional_first(x, y, session, event_info): + on_case_loaded_with_args_optional_first.state = dict(x=x, y=y) + + on_case_loaded_with_args_optional_first.state = None + + def on_case_loaded_with_args(session, event_info, x, y): on_case_loaded_with_args.state = dict(x=x, y=y) on_case_loaded_with_args.state = None @@ -43,6 +48,10 @@ def on_case_loaded_with_args(x, y, session, event_info): solver.events.register_callback(SolverEvent.CASE_LOADED, on_case_loaded) + solver.events.register_callback( + SolverEvent.CASE_LOADED, on_case_loaded_with_args_optional_first, 12, y=42 + ) + solver.events.register_callback( SolverEvent.CASE_LOADED, on_case_loaded_with_args, 12, y=42 ) @@ -54,6 +63,7 @@ def on_case_loaded_with_args(x, y, session, event_info): assert not on_case_loaded_old.loaded assert not on_case_loaded.loaded assert not on_case_loaded_old_with_args.state + assert not on_case_loaded_with_args_optional_first.state assert not on_case_loaded_with_args.state try: @@ -64,6 +74,7 @@ def on_case_loaded_with_args(x, y, session, event_info): assert on_case_loaded_old.loaded assert on_case_loaded.loaded assert on_case_loaded_old_with_args.state == dict(x=12, y=42) + assert on_case_loaded_with_args_optional_first.state == dict(x=12, y=42) assert on_case_loaded_with_args.state == dict(x=12, y=42)