Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Classes for event info #3531

Merged
merged 4 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/ansys/fluent/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,7 @@
from ansys.fluent.core.search import search # noqa: F401
from ansys.fluent.core.services.batch_ops import BatchOps # noqa: F401
from ansys.fluent.core.session import BaseSession as Fluent # noqa: F401
from ansys.fluent.core.streaming_services.events_streaming import ( # noqa: F401
Event,
MeshingEvent,
SolverEvent,
)
from ansys.fluent.core.streaming_services.events_streaming import * # noqa: F401, F403
from ansys.fluent.core.utils import fldoc, get_examples_download_dir
from ansys.fluent.core.utils.fluent_version import FluentVersion # noqa: F401
from ansys.fluent.core.utils.setup_for_fluent import setup_for_fluent # noqa: F401
Expand Down
314 changes: 304 additions & 10 deletions src/ansys/fluent/core/streaming_services/events_streaming.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,47 @@
"""Module for events management."""

from dataclasses import dataclass, field, fields
from enum import Enum
from functools import partial
import inspect
import logging
from typing import Callable, Generic, Literal, Type, TypeVar
import warnings

from google.protobuf.json_format import MessageToDict

from ansys.api.fluent.v0 import events_pb2 as EventsProtoModule
from ansys.fluent.core.exceptions import InvalidArgument
from ansys.fluent.core.streaming_services.streaming import StreamingService
from ansys.fluent.core.warnings import PyFluentDeprecationWarning

__all__ = [
"Event",
"SolverEvent",
"MeshingEvent",
"TimestepStartedEventInfo",
"TimestepEndedEventInfo",
"IterationEndedEventInfo",
"CalculationsStartedEventInfo",
"CalculationsEndedEventInfo",
"CalculationsPausedEventInfo",
"CalculationsResumedEventInfo",
"AboutToLoadCaseEventInfo",
"CaseLoadedEventInfo",
"AboutToLoadDataEventInfo",
"DataLoadedEventInfo",
"AboutToInitializeSolutionEventInfo",
"SolutionInitializedEventInfo",
"ReportDefinitionUpdatedEventInfo",
"ReportPlotSetUpdatedEventInfo",
"ResidualPlotUpdatedEventInfo",
"SettingsClearedEventInfo",
"SolutionPausedEventInfo",
"ProgressUpdatedEventInfo",
"SolverTimeEstimateUpdatedEventInfo",
"FatalErrorEventInfo",
]

network_logger = logging.getLogger("pyfluent.networking")


Expand Down Expand Up @@ -70,6 +100,256 @@ def _missing_(cls, value: str):
return _missing_for_events(cls, value)


class EventInfoBase:
"""Base class for event information classes."""

derived_classes = {}

def __init_subclass__(cls, event, **kwargs):
super().__init_subclass__(**kwargs)
cls.derived_classes[event] = cls

def __post_init__(self):
for f in fields(self):
# Cast to the correct type
setattr(self, f.name, f.type(getattr(self, f.name)))

def __getattr__(self, name):
for f in fields(self):
if f.metadata.get("deprecated_name") == name:
warnings.warn(
f"'{name}' is deprecated. Use '{f.name}' instead.",
PyFluentDeprecationWarning,
)
return getattr(self, f.name)
return self.__getattribute__(name)


@dataclass
class TimestepStartedEventInfo(EventInfoBase, event=SolverEvent.TIMESTEP_STARTED):
"""Information about the event triggered when a timestep is started.
Attributes
----------
index : int
Timestep index.
size : float
Timestep size.
"""

index: int
size: float


@dataclass
class TimestepEndedEventInfo(EventInfoBase, event=SolverEvent.TIMESTEP_ENDED):
"""Information about the event triggered when a timestep is ended.
Attributes
----------
index : int
Timestep index.
size : float
Timestep size.
"""

index: int
size: float


@dataclass
class IterationEndedEventInfo(EventInfoBase, event=SolverEvent.ITERATION_ENDED):
"""Information about the event triggered when an iteration is ended.
Attributes
----------
index : int
Iteration index.
"""

index: int


class CalculationsStartedEventInfo(
EventInfoBase, event=SolverEvent.CALCULATIONS_STARTED
):
"""Information about the event triggered when calculations are started."""


class CalculationsEndedEventInfo(EventInfoBase, event=SolverEvent.CALCULATIONS_ENDED):
"""Information about the event triggered when calculations are ended."""


class CalculationsPausedEventInfo(EventInfoBase, event=SolverEvent.CALCULATIONS_PAUSED):
"""Information about the event triggered when calculations are paused."""


class CalculationsResumedEventInfo(
EventInfoBase, event=SolverEvent.CALCULATIONS_RESUMED
):
"""Information about the event triggered when calculations are resumed."""


@dataclass
class AboutToLoadCaseEventInfo(EventInfoBase, event=SolverEvent.ABOUT_TO_LOAD_CASE):
"""Information about the event triggered just before a case file is loaded.
Attributes
----------
case_file_name : str
Case filename.
"""

case_file_name: str = field(metadata=dict(deprecated_name="casefilepath"))


@dataclass
class CaseLoadedEventInfo(EventInfoBase, event=SolverEvent.CASE_LOADED):
"""Information about the event triggered after a case file is loaded.
Attributes
----------
case_file_name : str
Case filename.
"""

case_file_name: str = field(metadata=dict(deprecated_name="casefilepath"))


@dataclass
class AboutToLoadDataEventInfo(EventInfoBase, event=SolverEvent.ABOUT_TO_LOAD_DATA):
"""Information about the event triggered just before a data file is loaded.
Attributes
----------
data_file_name : str
Data filename.
"""

data_file_name: str = field(metadata=dict(deprecated_name="datafilepath"))


@dataclass
class DataLoadedEventInfo(EventInfoBase, event=SolverEvent.DATA_LOADED):
"""Information about the event triggered after a data file is loaded.
Attributes
----------
data_file_name : str
Data filename.
"""

data_file_name: str = field(metadata=dict(deprecated_name="datafilepath"))


class AboutToInitializeSolutionEventInfo(
EventInfoBase, event=SolverEvent.ABOUT_TO_INITIALIZE_SOLUTION
):
"""Information about the event triggered just before solution is initialized."""


class SolutionInitializedEventInfo(
EventInfoBase, event=SolverEvent.SOLUTION_INITIALIZED
):
"""Information about the event triggered after solution is initialized."""


@dataclass
class ReportDefinitionUpdatedEventInfo(
EventInfoBase, event=SolverEvent.REPORT_DEFINITION_UPDATED
):
"""Information about the event triggered when a report definition is updated.
Attributes
----------
report_name : str
Report name.
"""

report_name: str = field(metadata=dict(deprecated_name="reportdefinitionname"))


@dataclass
class ReportPlotSetUpdatedEventInfo(
EventInfoBase, event=SolverEvent.REPORT_PLOT_SET_UPDATED
):
"""Information about the event triggered when a report plot set is updated.
Attributes
----------
plot_set_name : str
Plot set name.
"""

plot_set_name: str = field(metadata=dict(deprecated_name="plotsetname"))


class ResidualPlotUpdatedEventInfo(
EventInfoBase, event=SolverEvent.RESIDUAL_PLOT_UPDATED
):
"""Information about the event triggered when residual plots are updated."""


class SettingsClearedEventInfo(EventInfoBase, event=SolverEvent.SETTINGS_CLEARED):
"""Information about the event triggered when settings are cleared."""


@dataclass
class SolutionPausedEventInfo(EventInfoBase, event=SolverEvent.SOLUTION_PAUSED):
"""Information about the event triggered when solution is paused.
Attributes
----------
level : str
Level of the pause event.
index : int
Index of the pause event.
"""

level: str
index: int


@dataclass
class ProgressUpdatedEventInfo(EventInfoBase, event=SolverEvent.PROGRESS_UPDATED):
"""Information about the event triggered when progress is updated.
Attributes
----------
message : str
Progress message.
percentage : int
Progress percentage.
"""

message: str
percentage: int = field(metadata=dict(deprecated_name="percentComplete"))


@dataclass
class SolverTimeEstimateUpdatedEventInfo(
EventInfoBase, event=SolverEvent.SOLVER_TIME_ESTIMATE_UPDATED
):
"""Information about the event triggered when solver time estimate is updated.
Attributes
----------
hours : float
Hours of solver time estimate.
minutes : float
Minutes of solver time estimate.
seconds : float
Seconds of solver time estimate.
"""

hours: float
minutes: float
seconds: float


@dataclass
class FatalErrorEventInfo(EventInfoBase, event=SolverEvent.FATAL_ERROR):
"""Information about the event triggered when a fatal error occurs.
Attributes
----------
message : str
Error message.
error_code : int
Error code.
"""

message: str
error_code: int = field(metadata=dict(deprecated_name="errorCode"))


TEvent = TypeVar("TEvent")


Expand Down Expand Up @@ -100,6 +380,18 @@ def __init__(
self._session = session
self._sync_event_ids = {}

def _construct_event_info(
self, response: EventsProtoModule.BeginStreamingResponse, event: TEvent
):
event_info_msg = getattr(response, event.value.lower())
event_info_dict = MessageToDict(
event_info_msg, including_default_value_fields=True
)
solver_event = SolverEvent(event.value)
event_info_cls = EventInfoBase.derived_classes.get(solver_event)
# Key names can be different, but their order is the same
return event_info_cls(*event_info_dict.values())

def _process_streaming(
self, service, id, stream_begin_method, started_evt, *args, **kwargs
):
Expand Down Expand Up @@ -129,7 +421,7 @@ def _process_streaming(
for callback in callbacks_map.values():
callback(
session=self._session,
event_info=getattr(response, event_name.value.lower()),
event_info=self._construct_event_info(response, event_name),
)
except StopIteration:
break
Expand Down Expand Up @@ -247,7 +539,7 @@ def _register_solution_event_sync_callback(
callback_id: str,
callback: Callable,
) -> tuple[Literal[SolverEvent.SOLUTION_PAUSED], Callable]:
unique_id = self._session.scheme_eval.scheme_eval(
unique_id: int = self._session.scheme_eval.scheme_eval(
f"""
(let
((ids
Expand Down Expand Up @@ -277,14 +569,16 @@ def _register_solution_event_sync_callback(
"""
)

def on_pause(session, event_info: EventsProtoModule.AutoPauseEvent):
if unique_id == event_info.level:
event_info_cls = (
EventsProtoModule.TimestepEndedEvent
if event_type == SolverEvent.TIMESTEP_ENDED
else EventsProtoModule.IterationEndedEvent
)
event_info = event_info_cls(index=event_info.index)
def on_pause(session, event_info: SolutionPausedEventInfo):
if unique_id == int(event_info.level):
if event_type == SolverEvent.ITERATION_ENDED:
event_info = IterationEndedEventInfo(index=event_info.index)
else:
event_info = TimestepEndedEventInfo(
# TODO: Timestep size is currently not available
mkundu1 marked this conversation as resolved.
Show resolved Hide resolved
index=event_info.index,
size=0,
)
try:
callback(session, event_info)
except Exception as e:
Expand Down
Loading