Skip to content

Commit

Permalink
Merge branch 'main' into WSMR/batched_send
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 1, 2022
2 parents c5ea5cd + 715d7be commit cf50bd0
Show file tree
Hide file tree
Showing 5 changed files with 364 additions and 119 deletions.
60 changes: 31 additions & 29 deletions distributed/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,53 +185,56 @@ def test_identifier():


def test_watch():
stop_called = threading.Event()
watch_thread = None
start = time()

def stop():
if not stop_called.is_set(): # Run setup code
nonlocal watch_thread
nonlocal start
watch_thread = threading.current_thread()
start = time()
stop_called.set()
return time() > start + 0.500

start_threads = threading.active_count()

log = watch(interval="10ms", cycle="50ms", stop=stop)

start = time() # wait until thread starts up
while threading.active_count() <= start_threads:
assert time() < start + 2
sleep(0.01)

stop_called.wait(2)
sleep(0.5)
assert 1 < len(log) < 10

start = time()
while threading.active_count() > start_threads:
assert time() < start + 2
sleep(0.01)
watch_thread.join(2)


def test_watch_requires_lock_to_run():
start = time()

def stop_lock():
return time() > start + 0.600
stop_profiling_called = threading.Event()
profiling_thread = None

def stop_profile():
def stop_profiling():
if not stop_profiling_called.is_set(): # Run setup code
nonlocal profiling_thread
nonlocal start
profiling_thread = threading.current_thread()
start = time()
stop_profiling_called.set()
return time() > start + 0.500

def hold_lock(stop):
release_lock = threading.Event()

def block_lock():
with lock:
while not stop():
sleep(0.1)
release_lock.wait()

start_threads = threading.active_count()

# Hog the lock over the entire duration of watch
thread = threading.Thread(
target=hold_lock, name="Hold Lock", kwargs={"stop": stop_lock}
)
thread.daemon = True
thread.start()
# Block the lock over the entire duration of watch
blocking_thread = threading.Thread(target=block_lock, name="Block Lock")
blocking_thread.daemon = True
blocking_thread.start()

log = watch(interval="10ms", cycle="50ms", stop=stop_profile)
log = watch(interval="10ms", cycle="50ms", stop=stop_profiling)

start = time() # wait until thread starts up
while threading.active_count() < start_threads + 2:
Expand All @@ -240,11 +243,10 @@ def hold_lock(stop):

sleep(0.5)
assert len(log) == 0
release_lock.set()

start = time()
while threading.active_count() > start_threads:
assert time() < start + 2
sleep(0.01)
profiling_thread.join(2)
blocking_thread.join(2)


@dataclasses.dataclass(frozen=True)
Expand Down
28 changes: 3 additions & 25 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils_test import (
BlockedGatherDep,
BlockedGetData,
TaskStateMetadataPlugin,
_LockedCommPool,
assert_story,
Expand Down Expand Up @@ -2618,7 +2620,7 @@ def sink(a, b, *args):
if peer_addr == a.address and msg["op"] == "get_data":
break

# Provoke an "impossible transision exception"
# Provoke an "impossible transition exception"
# By choosing a state which doesn't exist we're not running into validation
# errors and the state machine should raise if we want to transition from
# fetch to memory
Expand Down Expand Up @@ -3137,30 +3139,6 @@ async def test_task_flight_compute_oserror(c, s, a, b):
assert_story(sum_story, expected_sum_story, strict=True)


class BlockedGatherDep(Worker):
def __init__(self, *args, **kwargs):
self.in_gather_dep = asyncio.Event()
self.block_gather_dep = asyncio.Event()
super().__init__(*args, **kwargs)

async def gather_dep(self, *args, **kwargs):
self.in_gather_dep.set()
await self.block_gather_dep.wait()
return await super().gather_dep(*args, **kwargs)


class BlockedGetData(Worker):
def __init__(self, *args, **kwargs):
self.in_get_data = asyncio.Event()
self.block_get_data = asyncio.Event()
super().__init__(*args, **kwargs)

async def get_data(self, comm, *args, **kwargs):
self.in_get_data.set()
await self.block_get_data.wait()
return await super().get_data(comm, *args, **kwargs)


@gen_cluster(client=True, nthreads=[])
async def test_gather_dep_cancelled_rescheduled(c, s):
"""At time of writing, the gather_dep implementation filtered tasks again
Expand Down
Loading

0 comments on commit cf50bd0

Please sign in to comment.