Skip to content

Commit

Permalink
feat: optional last
Browse files Browse the repository at this point in the history
  • Loading branch information
mkundu1 committed Dec 3, 2024
1 parent 7b27d14 commit 40dcbf6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions doc/source/user_guide/events.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The following code triggers a callback at the end of every iteration.
>>>
>>> callback_id = solver.events.register_callback(SolverEvent.ITERATION_ENDED, on_iteration_ended)
The general signature of the callback function is ``cb(<optional arguments>, session, event_info)``, where ``session`` is the session instance
The general signature of the callback function is ``cb(session, event_info, <optional arguments>)``, where ``session`` is the session instance
and ``event_info`` instance holds information about the event. The event information classes for each event are documented in the
API reference of the :obj:`~ansys.fluent.core.streaming_services.events_streaming` module. See the callback function
``on_case_loaded_with_args()`` in the below examples for an example of how to pass optional arguments to the callback
Expand Down Expand Up @@ -74,7 +74,7 @@ Examples
>>> def on_case_loaded(session, event_info: CaseLoadedEventInfo):
>>> print("Case loaded. Index = ", event_info.index)
>>>
>>> def on_case_loaded_with_args(x, y, session, event_info: CaseLoadedEventInfo):
>>> def on_case_loaded_with_args(session, event_info: CaseLoadedEventInfo, x, y):
>>> print(f"Case loaded with {x}, {y}. Index = ", event_info.index)
>>>
>>> callback = meshing.events.register_callback(MeshingEvent.CASE_LOADED, on_case_loaded)
Expand Down
25 changes: 14 additions & 11 deletions src/ansys/fluent/core/streaming_services/events_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,23 +442,26 @@ def _process_streaming(

@staticmethod
def _make_callback_to_call(callback: Callable, args, kwargs):
old_style = "session_id" in inspect.signature(callback).parameters
if old_style:
params = inspect.signature(callback).parameters
if "session_id" in params:
warnings.warn(
"Update event callback function signatures"
" substituting 'session' for 'session_id'.",
PyFluentDeprecationWarning,
)
fn = partial(callback, *args, **kwargs)
return (
(
lambda session, event_info: fn(
session_id=session.id, event_info=event_info
)
return lambda session, event_info: callback(
*args, session_id=session.id, event_info=event_info, **kwargs
)
else:
positional_args = [
p
for p in params
if p not in kwargs and p not in ("session", "event_info")
]
kwargs.update(dict(zip(positional_args, args)))
return lambda session, event_info: callback(
session=session, event_info=event_info, **kwargs
)
if old_style
else fn
)

def register_callback(
self,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_events_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def on_case_loaded(session, event_info):

on_case_loaded.loaded = False

def on_case_loaded_with_args(x, y, session, event_info):
def on_case_loaded_with_args_optional_first(x, y, session, event_info):
on_case_loaded_with_args_optional_first.state = dict(x=x, y=y)

on_case_loaded_with_args_optional_first.state = None

def on_case_loaded_with_args(session, event_info, x, y):
on_case_loaded_with_args.state = dict(x=x, y=y)

on_case_loaded_with_args.state = None
Expand All @@ -43,6 +48,10 @@ def on_case_loaded_with_args(x, y, session, event_info):

solver.events.register_callback(SolverEvent.CASE_LOADED, on_case_loaded)

solver.events.register_callback(
SolverEvent.CASE_LOADED, on_case_loaded_with_args_optional_first, 12, y=42
)

solver.events.register_callback(
SolverEvent.CASE_LOADED, on_case_loaded_with_args, 12, y=42
)
Expand All @@ -54,6 +63,7 @@ def on_case_loaded_with_args(x, y, session, event_info):
assert not on_case_loaded_old.loaded
assert not on_case_loaded.loaded
assert not on_case_loaded_old_with_args.state
assert not on_case_loaded_with_args_optional_first.state
assert not on_case_loaded_with_args.state

try:
Expand All @@ -64,6 +74,7 @@ def on_case_loaded_with_args(x, y, session, event_info):
assert on_case_loaded_old.loaded
assert on_case_loaded.loaded
assert on_case_loaded_old_with_args.state == dict(x=12, y=42)
assert on_case_loaded_with_args_optional_first.state == dict(x=12, y=42)
assert on_case_loaded_with_args.state == dict(x=12, y=42)


Expand Down

0 comments on commit 40dcbf6

Please sign in to comment.