Skip to content

Commit

Permalink
Test retry Client send
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 8, 2024
1 parent 9cce793 commit e618038
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def _send(self, message: str, retries: Optional[int] = None) -> None:

backoff = 1
retries = retries or self.DEFAULT_MAX_RETRIES
while retries > 0:
while retries >= 0:
try:
await self.socket.send_multipart([b"", message.encode("utf-8")])
try:
Expand Down
57 changes: 28 additions & 29 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import queue

import pytest

from _ert.forward_model_runner.client import Client, ClientConnectionError
from tests.ert.utils import async_mock_zmq_server
from tests.ert.utils import mock_zmq_thread


@pytest.mark.integration_test
Expand All @@ -19,38 +19,37 @@ def test_invalid_server():
pass


async def test_successful_sending(unused_tcp_port):
def test_successful_sending(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
messages = []
server_started = asyncio.Event()

server_task = asyncio.create_task(
async_mock_zmq_server(messages, unused_tcp_port, server_started)
)
await server_started.wait()
messages_c1 = ["test_1", "test_2", "test_3"]
async with Client(url) as c1:
for message in messages_c1:
await c1._send(message)

await server_task
with mock_zmq_thread(unused_tcp_port, messages):
messages_c1 = ["test_1", "test_2", "test_3"]
with Client(url) as c1:
for message in messages_c1:
c1.send(message)

for msg in messages_c1:
assert msg in messages


async def test_retry(unused_tcp_port):
pass
# host = "localhost"
# url = f"tcp://{host}:{unused_tcp_port}"
# messages = []
# server_started = asyncio.Event()

# server_task = asyncio.create_task(
# async_mock_zmq_server(messages, unused_tcp_port, server_started)
# )

# messages_c1 = ["test_1", "test_2", "test_3"]

# TODO write test for retry!
def test_retry(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
messages = []
signal_queue = queue.Queue()
signal_queue.put(2)
client_connection_error_set = False
with mock_zmq_thread(unused_tcp_port, messages, signal_queue):
messages_c1 = ["test_1", "test_2", "test_3"]
with Client(url, ack_timeout=1) as c1:
for message in messages_c1:
try:
c1.send(message, retries=2)
except ClientConnectionError:
client_connection_error_set = True
signal_queue.put(0)
assert client_connection_error_set
assert messages.count("test_1") == 3
assert messages.count("test_2") == 1
assert messages.count("test_3") == 1
8 changes: 5 additions & 3 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ async def _handler(router_socket):
if signal_queue:
with contextlib.suppress(queue.Empty):
signal_value = signal_queue.get(timeout=0.1)

print(f"{dealer=} {frame=} {signal_value=}")
if frame in [b"CONNECT", b"DISCONNECT"] or signal_value != 1:
if frame in [b"CONNECT", b"DISCONNECT"] or signal_value == 0:
await router_socket.send_multipart([dealer, b"", b"ACK"])
if frame not in [b"CONNECT", b"DISCONNECT"]:
messages.append(frame.decode("utf-8"))
if frame not in [b"CONNECT", b"DISCONNECT"] and signal_value != 1:
messages.append(frame.decode("utf-8"))

except asyncio.CancelledError:
break

Expand Down

0 comments on commit e618038

Please sign in to comment.