Skip to content

Commit

Permalink
Support fixtures and pytest.mark.parametrize with gen_cluster (#4958
Browse files Browse the repository at this point in the history
)

Support fixtures and `pytest.mark.parametrize` with `gen_cluster` (#4958)
  • Loading branch information
gjoseph92 authored Jun 23, 2021
1 parent ac35e0f commit 9f4165a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
42 changes: 42 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import pathlib
import socket
import threading
from contextlib import contextmanager
Expand Down Expand Up @@ -45,6 +46,47 @@ async def test_gen_cluster(c, s, a, b):
assert await c.submit(lambda: 123) == 123


@gen_cluster(client=True)
async def test_gen_cluster_pytest_fixture(c, s, a, b, tmp_path):
assert isinstance(tmp_path, pathlib.Path)
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@gen_cluster(client=True)
async def test_gen_cluster_parametrized(c, s, a, b, foo):
assert foo is True
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@pytest.mark.parametrize("bar", ["a", "b"])
@gen_cluster(client=True)
async def test_gen_cluster_multi_parametrized(c, s, a, b, foo, bar):
assert foo is True
assert bar in ("a", "b")
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in [a, b]:
assert isinstance(w, Worker)


@pytest.mark.parametrize("foo", [True])
@gen_cluster(client=True)
async def test_gen_cluster_parametrized_variadic_workers(c, s, *workers, foo):
assert foo is True
assert isinstance(c, Client)
assert isinstance(s, Scheduler)
for w in workers:
assert isinstance(w, Worker)


@gen_cluster(
client=True,
Worker=Nanny,
Expand Down
29 changes: 27 additions & 2 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import functools
import gc
import inspect
import io
import itertools
import logging
Expand Down Expand Up @@ -861,6 +862,15 @@ def gen_cluster(
async def test_foo(scheduler, worker1, worker2):
await ... # use tornado coroutines
@pytest.mark.parametrize("param", [1, 2, 3])
@gen_cluster()
async def test_foo(scheduler, worker1, worker2, param):
await ... # use tornado coroutines
@gen_cluster()
async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture_b):
await ... # use tornado coroutines
See also:
start
end
Expand All @@ -877,7 +887,7 @@ def _(func):
if not iscoroutinefunction(func):
func = gen.coroutine(func)

def test_func():
def test_func(*outer_args, **kwargs):
result = None
workers = []
with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:
Expand Down Expand Up @@ -919,7 +929,7 @@ async def coro():
)
args = [c] + args
try:
future = func(*args)
future = func(*args, *outer_args, **kwargs)
if timeout:
future = asyncio.wait_for(future, timeout)
result = await future
Expand Down Expand Up @@ -979,6 +989,21 @@ def get_unclosed():

return result

# Patch the signature so pytest can inject fixtures
orig_sig = inspect.signature(func)
args = [None] * (1 + len(nthreads)) # scheduler, *workers
if client:
args.insert(0, None)

bound = orig_sig.bind_partial(*args)
test_func.__signature__ = orig_sig.replace(
parameters=[
p
for name, p in orig_sig.parameters.items()
if name not in bound.arguments
]
)

return test_func

return _
Expand Down

0 comments on commit 9f4165a

Please sign in to comment.