Skip to content

Commit

Permalink
Remove support for batching messages on client side
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 7, 2024
1 parent ac14c13 commit 786b4c4
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 97 deletions.
20 changes: 6 additions & 14 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ async def connect(self) -> None:
self.term()
raise

def send(
self, messages: str | list[str], max_retries: int = DEFAULT_MAX_RETRIES
) -> None:
self.loop.run_until_complete(self._send(messages, max_retries))
def send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None:
self.loop.run_until_complete(self._send(message, max_retries))

async def process_message(self, msg: str) -> None:
pass
Expand All @@ -124,28 +122,22 @@ async def _receiver(self) -> None:
await asyncio.sleep(1)
self.socket.connect(self.url)

async def _send(
self, messages: str | list[str], max_retries: int = DEFAULT_MAX_RETRIES
) -> None:
async def _send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None:
self._ack_event.clear()
if isinstance(messages, str):
messages = [messages]

backoff = 1

while max_retries > 0:
try:
await self.socket.send_multipart(
[b""] + [message.encode("utf-8") for message in messages]
)
await self.socket.send_multipart([b"", message.encode("utf-8")])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._connection_timeout
)
return
except asyncio.TimeoutError:
logger.warning(
f"{self.dealer_id} failed to get acknowledgment on the {messages}. Resending."
f"{self.dealer_id} failed to get acknowledgment on the {message}. Resending."
)
except zmq.ZMQError as exc:
logger.debug(
Expand All @@ -164,5 +156,5 @@ async def _send(
backoff = min(backoff * 2, 10) # Exponential backoff

raise ClientConnectionError(
f"{self.dealer_id} Failed to send {messages=} after retries."
f"{self.dealer_id} Failed to send {message=} after retries."
)
32 changes: 0 additions & 32 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,38 +100,6 @@ async def publisher():

asyncio.run(publisher())

# def _event_publisher(self):
# logger.debug("Publishing event.")
# with Client(
# url=self._evaluator_url,
# token=self._token,
# cert=self._cert,
# ) as client:
# events = []
# last_sent_time = time.time()
# while not self._done:
# try:
# event = self._event_queue.get()
# if event is self._sentinel:
# self._done = True
# if events:
# client.send(events)
# events.clear()
# break
# events.append(event_to_json(event))

# current_time = time.time()
# if current_time - last_sent_time >= 1:
# if events:
# client.send(events)
# events.clear()
# last_sent_time = current_time
# except ClientConnectionError as e:
# logger.error(f"Failed to send event: {e}")
# except Exception as e:
# logger.error(f"Error while sending event: {e}")
# raise

def report(self, msg):
self._statemachine.transition(msg)

Expand Down
100 changes: 49 additions & 51 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,70 +202,68 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
def ensemble(self) -> Ensemble:
return self._ensemble

async def handle_client(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict,
ensemble=self.ensemble.id_,
)
await self._router_socket.send_multipart(
[dealer, b"", event_to_json(event).encode()]
async def handle_client(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
snapshot=current_snapshot_dict,
ensemble=self.ensemble.id_,
)
await self._router_socket.send_multipart(
[dealer, b"", event_to_json(event).encode()]
)
elif raw_msg == "DISCONNECT":
self._clients_connected.discard(dealer)
if not self._clients_connected:
self._clients_empty.set()
else:
event = event_from_json(raw_msg)
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()

async def handle_dispatch(self, dealer: bytes, frame: bytes) -> None:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
self._dispatchers_connected.add(dealer)
self._dispatchers_connected.clear()
elif raw_msg == "DISCONNECT":
self._dispatchers_connected.discard(dealer)
if not self._dispatchers_connected:
self._dispatchers_empty.set()
else:
event = dispatch_event_from_json(raw_msg)
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
f"{event.ensemble}. "
f"Ignoring since I am {self.ensemble.id_}"
)
elif raw_msg == "DISCONNECT":
self._clients_connected.discard(dealer)
if not self._clients_connected:
self._clients_empty.set()
else:
event = event_from_json(raw_msg)
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()

async def handle_dispatch(self, dealer: bytes, frames: list[bytes]) -> None:
for frame in frames:
raw_msg = frame.decode("utf-8")
if raw_msg == "CONNECT":
self._dispatchers_connected.add(dealer)
self._dispatchers_connected.clear()
elif raw_msg == "DISCONNECT":
self._dispatchers_connected.discard(dealer)
if not self._dispatchers_connected:
self._dispatchers_empty.set()
return
if type(event) is ForwardModelStepChecksum:
await self.forward_checksum(event)
else:
event = dispatch_event_from_json(raw_msg)
if event.ensemble != self.ensemble.id_:
logger.info(
"Got event from evaluator "
f"{event.ensemble}. "
f"Ignoring since I am {self.ensemble.id_}"
)
continue
if type(event) is ForwardModelStepChecksum:
await self.forward_checksum(event)
else:
await self._events.put(event)
await self._events.put(event)

async def listen_for_messages(self) -> None:
await self._server_started.wait()
while True:
try:
dealer, _, *frames = await self._router_socket.recv_multipart()
dealer, _, frame = await self._router_socket.recv_multipart()
# print(f"GOT MESSAGE {frames=} from {dealer=}")
await self._router_socket.send_multipart([dealer, b"", b"ACK"])
sender = dealer.decode("utf-8")
if sender.startswith("client"):
await self.handle_client(dealer, frames)
await self.handle_client(dealer, frame)
elif sender.startswith("dispatch"):
# await self._router_socket.send_multipart([dealer, b"", b"ACK"])
await self.handle_dispatch(dealer, frames)
await self.handle_dispatch(dealer, frame)
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")
except zmq.error.ZMQError as e:
Expand Down

0 comments on commit 786b4c4

Please sign in to comment.