diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index dc413f18a74..00a32921f18 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -154,7 +154,7 @@ jobs: # Increase this value to reset cache if # continuous_integration/environment-${{ matrix.environment }}.yaml has not # changed. See also same variable in .pre-commit-config.yaml - CACHE_NUMBER: 1 + CACHE_NUMBER: 2 id: cache - name: Update environment diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index c7cf8ac69e8..a7ec037e27c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -29,6 +29,8 @@ distributed: default-task-durations: # How long we expect function names to run ("1h", "1s") (helps for long tasks) rechunk-split: 1us split-shuffle: 1us + split-taskshuffle: 1us + split-stage: 1us validate: False # Check scheduler state at every step for debugging dashboard: status: diff --git a/distributed/stealing.py b/distributed/stealing.py index 36765149aff..952aa55b1e1 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -542,4 +542,8 @@ def _get_thief( return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) -fast_tasks = {"split-shuffle"} +fast_tasks = { + k + for k, v in dask.config.get("distributed.scheduler.default-task-durations").items() + if parse_timedelta(v) <= 0.001 +} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 636efcd9e48..a15be821319 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2783,12 +2783,11 @@ async def test_default_task_duration_splits(c, s, a, b): await wait(fut) split_prefix = [pre for pre in s.task_prefixes.keys() if "split" in pre] - assert len(split_prefix) == 1 - split_prefix = split_prefix[0] - default_time = parse_timedelta( - dask.config.get("distributed.scheduler.default-task-durations")[split_prefix] - ) - assert default_time <= 1e-6 + assert split_prefix + default_times = dask.config.get("distributed.scheduler.default-task-durations") + for p in split_prefix: + default_time = parse_timedelta(default_times[p]) + assert default_time <= 1e-6 @gen_test() diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index dc25d8fb1d3..539923dd1cc 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1027,10 +1027,10 @@ async def test_blocklist_shuffle_split(c, s, a, b): while not s.tasks: await asyncio.sleep(0.005) - prefixes = set(s.task_prefixes.keys()) + from distributed.stealing import fast_tasks - blocked = fast_tasks & prefixes + blocked = fast_tasks & s.task_prefixes.keys() assert blocked assert any(["split" in prefix for prefix in blocked])