Skip to content

Commit

Permalink
Implementing router-dealer pattern with custom acknowledgments with zmq
Browse files Browse the repository at this point in the history
 - dispatcher now sends messages in chunks
 - dispatcher always waits for acknowledgment from the evaluator
 - removing websockets, no more wait_for_evaluator
 - Set up encryption with CURVE
 - each dealer (client, dispatcher) will get a unique name
  • Loading branch information
xjules committed Nov 19, 2024
1 parent 03cfa25 commit feac786
Show file tree
Hide file tree
Showing 10 changed files with 252 additions and 377 deletions.
152 changes: 72 additions & 80 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from __future__ import annotations

import asyncio
import logging
import ssl
from typing import Any, AnyStr, Optional, Union
import uuid
from typing import Any, Optional, Union

import zmq
import zmq.asyncio
from typing_extensions import Self
from websockets.client import WebSocketClientProtocol, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
ConnectionClosedOK,
InvalidHandshake,
InvalidURI,
)

from _ert.async_utils import new_event_loop

Expand All @@ -35,18 +31,18 @@ def __enter__(self) -> Self:
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
if self.websocket is not None:
self.loop.run_until_complete(self.websocket.close())
self.loop.close()
self.socket.close()
self.context.term()

async def __aenter__(self) -> "Client":
async def __aenter__(self) -> Self:
return self

async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
if self.websocket is not None:
await self.websocket.close()
self.socket.close()
self.context.term()
self.loop.close()

def __init__(
self,
Expand All @@ -55,84 +51,80 @@ def __init__(
cert: Optional[Union[str, bytes]] = None,
max_retries: Optional[int] = None,
timeout_multiplier: Optional[int] = None,
dealer_name: str | None = None,
) -> None:
if max_retries is None:
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self.url = url
self.token = token
self._extra_headers = Headers()

# Set up ZeroMQ context and socket
self.context = zmq.asyncio.Context() # type: ignore
self.socket = self.context.socket(zmq.DEALER)
if dealer_name is None:
dispatch_id = f"dispatch-{uuid.uuid4().hex[:8]}"
else:
dispatch_id = dealer_name
self.socket.setsockopt_string(zmq.IDENTITY, dispatch_id)
if token is not None:
self._extra_headers["token"] = token

# Mimics the behavior of the ssl argument when connection to
# websockets. If none is specified it will deduce based on the url,
# if True it will enforce TLS, and if you want to use self signed
# certificates you need to pass an ssl_context with the certificate
# loaded.
self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None
if cert is not None:
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ssl_context.load_verify_locations(cadata=cert)
elif url.startswith("wss"):
self._ssl_context = True
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")
self.socket.connect(url)

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: Optional[WebSocketClientProtocol] = None
self.loop = new_event_loop()

async def get_websocket(self) -> WebSocketClientProtocol:
    """Open a websocket connection to ``self.url`` and return it.

    The connection is made with ``self._ssl_context`` and
    ``self._extra_headers``. A single value, ``self.CONNECTION_TIMEOUT``,
    is used for the open, close and ping timeouts as well as for the
    ping interval.
    """
    limit = self.CONNECTION_TIMEOUT
    websocket = await connect(
        self.url,
        ssl=self._ssl_context,
        extra_headers=self._extra_headers,
        open_timeout=limit,
        ping_timeout=limit,
        ping_interval=limit,
        close_timeout=limit,
    )
    return websocket

async def _send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
async def reconnect(self):
"""Connect to the server, retrying with a linearly growing delay."""
retries = self._max_retries
while retries > 0:
try:
if self.websocket is None:
self.websocket = await self.get_websocket()
await self.websocket.send(msg)
return
except ConnectionClosedOK as exception:
_error_msg = (
f"Connection closed received from the server {self.url}! "
f" Exception from {type(exception)}: {exception!s}"
self.socket.connect(self.url)
break
except zmq.ZMQError as e:
logger.warning(f"Failed to connect to {self.url}: {e}")
retries -= 1
if retries == 0:
raise e
# Linear backoff: the delay grows by timeout_multiplier per failed attempt
sleep_time = self._timeout_multiplier * (self._max_retries - retries)
await asyncio.sleep(sleep_time)

def send(self, messages: str | list[str]) -> None:
    """Blocking counterpart of ``send_async``.

    Runs ``self.send_async(messages)`` to completion on ``self.loop``.
    """
    pending = self.send_async(messages)
    self.loop.run_until_complete(pending)

async def send_async(self, messages: str | list[str]) -> None:
if isinstance(messages, str):
messages = [messages]
retries = 0
max_retries = 5
while retries < max_retries:
try:
logger.debug(f"sending messages: {messages}")
await self.socket.send_multipart(
[b""] + [message.encode("utf-8") for message in messages]
)
raise ClientConnectionClosedOK(_error_msg) from exception
except (
InvalidHandshake,
InvalidURI,
OSError,
asyncio.TimeoutError,
) as exception:
if retry == self._max_retries:
_error_msg = (
f"Not able to establish the "
f"websocket connection {self.url}! Max retries reached!"
" Check for firewall issues."
f" Exception from {type(exception)}: {exception!s}"
try:
_, ack = await asyncio.wait_for(
self.socket.recv_multipart(), timeout=3
)
raise ClientConnectionError(_error_msg) from exception
except ConnectionClosedError as exception:
if retry == self._max_retries:
_error_msg = (
f"Not been able to send the event"
f" to {self.url}! Max retries reached!"
f" Exception from {type(exception)}: {exception!s}"
logger.debug(f"Got acknowledgment: {ack}")
if ack.decode() == "ACK":
break
logger.warning(
"Got acknowledgment but not the expected message. Resending"
)
raise ClientConnectionError(_error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
retries += 1
except asyncio.TimeoutError:
logger.warning(
"Failed to get acknowledgment on the message. Resending"
)
retries += 1
except zmq.ZMQError as e:
logger.warning(f"Failed to send message from {e} reconnecting ...")
await self.reconnect()
39 changes: 22 additions & 17 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import queue
import threading
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Final, Union
Expand All @@ -18,8 +19,6 @@
)
from _ert.forward_model_runner.client import (
Client,
ClientConnectionClosedOK,
ClientConnectionError,
)
from _ert.forward_model_runner.reporting.base import Reporter
from _ert.forward_model_runner.reporting.message import (
Expand Down Expand Up @@ -90,7 +89,8 @@ def _event_publisher(self):
token=self._token,
cert=self._cert,
) as client:
event = None
events = []
last_sent_time = time.time()
while True:
with self._timestamp_lock:
if (
Expand All @@ -99,23 +99,28 @@ def _event_publisher(self):
):
self._timeout_timestamp = None
break
if event is None:
# if we successfully sent the event we can proceed
# to next one

try:
event = self._event_queue.get()
logger.debug(f"Got event for zmq: {event}")
if event is self._sentinel:
if events:
logger.debug(f"Got event class for zmq: {events}")
client.send(events)
events.clear()
break
try:
client.send(event_to_json(event))
event = None
except ClientConnectionError as exception:
# Possible intermittent failure, we retry sending the event
logger.error(str(exception))
except ClientConnectionClosedOK as exception:
# The receiving end has closed the connection, we stop
# sending events
logger.debug(str(exception))
break
events.append(event_to_json(event))

current_time = time.time()
if current_time - last_sent_time >= 2:
if events:
logger.debug(f"Got event class for zmq: {events}")
client.send(events)
events.clear()
last_sent_time = current_time
except Exception as e:
logger.error(f"Failed to send event: {e}")
raise

def report(self, msg):
self._statemachine.transition(msg)
Expand Down
20 changes: 6 additions & 14 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,8 @@
from ert.run_arg import RunArg
from ert.scheduler import Scheduler, create_driver

from ._wait_for_evaluator import wait_for_evaluator
from .config import EvaluatorServerConfig
from .snapshot import (
EnsembleSnapshot,
FMStepSnapshot,
RealizationSnapshot,
)
from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot
from .state import (
ENSEMBLE_STATE_CANCELLED,
ENSEMBLE_STATE_FAILED,
Expand Down Expand Up @@ -122,6 +117,7 @@ def __post_init__(self) -> None:
self._config: Optional[EvaluatorServerConfig] = None
self.snapshot: EnsembleSnapshot = self._create_snapshot()
self.status = self.snapshot.status
self._client: Client | None = None
if self.snapshot.status:
self._status_tracker = _EnsembleStateTracker(self.snapshot.status)
else:
Expand Down Expand Up @@ -205,7 +201,7 @@ async def send_event(
retries: int = 10,
) -> None:
async with Client(url, token, cert, max_retries=retries) as client:
await client._send(event_to_json(event))
await client.send_async(event_to_json(event))

def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]:
def event_builder(status: str) -> Event:
Expand All @@ -230,16 +226,12 @@ async def evaluate(
ce_unary_send_method_name,
partialmethod(
self.__class__.send_event,
self._config.dispatch_uri,
self._config.get_connection_info().router_uri,
token=self._config.token,
cert=self._config.cert,
),
)
await wait_for_evaluator(
base_url=self._config.url,
token=self._config.token,
cert=self._config.cert,
)

await self._evaluate_inner(
event_unary_send=getattr(self, ce_unary_send_method_name),
scheduler_queue=scheduler_queue,
Expand Down Expand Up @@ -282,7 +274,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
max_running=self._queue_config.max_running,
submit_sleep=self._queue_config.submit_sleep,
ens_id=self.id_,
ee_uri=self._config.dispatch_uri,
ee_uri=self._config.get_connection_info().router_uri,
ee_cert=self._config.cert,
ee_token=self._config.token,
)
Expand Down
60 changes: 0 additions & 60 deletions src/ert/ensemble_evaluator/_wait_for_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import asyncio
import logging
import ssl
import time
from typing import Optional, Union

import aiohttp

logger = logging.getLogger(__name__)

WAIT_FOR_EVALUATOR_TIMEOUT = 60
Expand All @@ -17,59 +13,3 @@ def get_ssl_context(cert: Optional[Union[str, bytes]]) -> Union[ssl.SSLContext,
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_verify_locations(cadata=cert)
return ssl_context


async def attempt_connection(
    url: str,
    token: Optional[str] = None,
    cert: Optional[Union[str, bytes]] = None,
    connection_timeout: float = 2,
) -> None:
    """Issue one GET request to ``url`` and check the response status.

    Args:
        url: Address to request.
        token: When not ``None``, sent in a ``token`` request header.
        cert: Passed to ``get_ssl_context`` to build the ``ssl`` argument.
        connection_timeout: Used as the ``connect`` value of the
            ``aiohttp.ClientTimeout`` applied to the request.

    Any exception raised by the request or by ``raise_for_status()``
    propagates to the caller.
    """
    request_headers = {"token": token} if token is not None else {}
    request_timeout = aiohttp.ClientTimeout(connect=connection_timeout)
    async with aiohttp.ClientSession() as session:
        async with session.request(
            method="get",
            url=url,
            ssl=get_ssl_context(cert),
            headers=request_headers,
            timeout=request_timeout,
        ) as response:
            response.raise_for_status()


async def wait_for_evaluator(
    base_url: str,
    token: Optional[str] = None,
    cert: Optional[Union[str, bytes]] = None,
    healthcheck_endpoint: str = "/healthcheck",
    timeout: Optional[float] = None,
    connection_timeout: float = 2,
) -> None:
    """Poll the healthcheck URL until a connection attempt succeeds.

    ``attempt_connection`` is called against
    ``base_url + healthcheck_endpoint`` until one call returns without
    raising, or until ``timeout`` seconds have passed
    (``WAIT_FOR_EVALUATOR_TIMEOUT`` when ``timeout`` is ``None``).

    After each ``aiohttp.ClientError`` the delay is doubled (starting
    from 0.2 s, so the first wait is 0.4 s) and capped at 5.0 s. The
    actual sleep is never longer than the time left in the window plus
    0.1 s. Once the window has elapsed, one final attempt is made and
    any exception it raises propagates to the caller.
    """
    window = WAIT_FOR_EVALUATOR_TIMEOUT if timeout is None else timeout
    probe = {
        "url": base_url + healthcheck_endpoint,
        "token": token,
        "cert": cert,
        "connection_timeout": connection_timeout,
    }
    began = time.time()
    backoff, backoff_cap = 0.2, 5.0
    while time.time() - began < window:
        try:
            await attempt_connection(**probe)
        except aiohttp.ClientError:
            backoff = min(backoff_cap, backoff * 2)
            time_left = max(0, window - (time.time() - began) + 0.1)
            await asyncio.sleep(min(backoff, time_left))
        else:
            return

    # We have timed out, but we make one last attempt to ensure that
    # we have tried to connect at both ends of the time window
    await attempt_connection(**probe)
Loading

0 comments on commit feac786

Please sign in to comment.