diff --git a/continuous_integration/recipes/distributed/meta.yaml b/continuous_integration/recipes/distributed/meta.yaml index 8d8011b5a40..8222c1a07fd 100644 --- a/continuous_integration/recipes/distributed/meta.yaml +++ b/continuous_integration/recipes/distributed/meta.yaml @@ -47,9 +47,9 @@ outputs: track_features: # [cython_enabled] - cythonized-scheduler # [cython_enabled] entry_points: - - dask-scheduler = distributed.cli.dask_scheduler:go - - dask-ssh = distributed.cli.dask_ssh:go - - dask-worker = distributed.cli.dask_worker:go + - dask-scheduler = distributed.cli.dask_scheduler:main + - dask-ssh = distributed.cli.dask_ssh:main + - dask-worker = distributed.cli.dask_worker:main requirements: build: - {{ compiler('c') }} # [cython_enabled] diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index d4b9657740f..c7a18517876 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -10,7 +10,7 @@ from tornado.ioloop import IOLoop from distributed import Scheduler -from distributed.cli.utils import check_python_3, install_signal_handlers +from distributed.cli.utils import install_signal_handlers from distributed.preloading import validate_preload_argv from distributed.proctitle import ( enable_proctitle_on_children, @@ -212,10 +212,5 @@ async def run(): logger.info("End scheduler at %r", scheduler.address) -def go(): - check_python_3() - main() - - if __name__ == "__main__": - go() # pragma: no cover + main() # pragma: no cover diff --git a/distributed/cli/dask_ssh.py b/distributed/cli/dask_ssh.py index bcc1c6ee439..baa8cc04af3 100755 --- a/distributed/cli/dask_ssh.py +++ b/distributed/cli/dask_ssh.py @@ -5,7 +5,6 @@ import click -from distributed.cli.utils import check_python_3 from distributed.deploy.old_ssh import SSHCluster logger = logging.getLogger("distributed.dask_ssh") @@ -223,10 +222,5 @@ def main( print("[ dask-ssh ]: Remote processes have been terminated. Exiting.") -def go(): - check_python_3() - main() - - if __name__ == "__main__": - go() # pragma: no cover + main() # pragma: no cover diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 762bf9e46c8..f13672a4d10 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -16,7 +16,7 @@ from dask.system import CPU_COUNT from distributed import Nanny -from distributed.cli.utils import check_python_3, install_signal_handlers +from distributed.cli.utils import install_signal_handlers from distributed.comm import get_address_host_port from distributed.deploy.utils import nprocesses_nthreads from distributed.preloading import validate_preload_argv @@ -486,10 +486,5 @@ async def run(): logger.info("End worker") -def go(): - check_python_3() - main() - - if __name__ == "__main__": - go() # pragma: no cover + main() # pragma: no cover diff --git a/distributed/cli/utils.py b/distributed/cli/utils.py index 472ddbbf681..a0191edaa90 100644 --- a/distributed/cli/utils.py +++ b/distributed/cli/utils.py @@ -1,51 +1,5 @@ -import click -from packaging.version import parse as parse_version from tornado.ioloop import IOLoop -CLICK_VERSION = parse_version(click.__version__) - -py3_err_msg = """ -Warning: Your terminal does not set locales. - -If you use unicode text inputs for command line options then this may cause -undesired behavior. This is rare. - -If you don't use unicode characters in command line options then you can safely -ignore this message. This is the common case. - -You can support unicode inputs by specifying encoding environment variables, -though exact solutions may depend on your system: - - $ export LC_ALL=C.UTF-8 - $ export LANG=C.UTF-8 - -For more information see: http://click.pocoo.org/5/python3/ -""".lstrip() - - -def check_python_3(): - """Ensures that the environment is good for unicode on Python 3.""" - # https://github.com/pallets/click/issues/448#issuecomment-246029304 - import click.core - - # TODO: Remove use of internal click functions - if CLICK_VERSION < parse_version("8.0.0"): - click.core._verify_python3_env = lambda: None - else: - click.core._verify_python_env = lambda: None - - try: - from click import _unicodefun - - if CLICK_VERSION < parse_version("8.0.0"): - _unicodefun._verify_python3_env() - else: - _unicodefun._verify_python_env() - except (TypeError, RuntimeError): - import click - - click.echo(py3_err_msg, err=True) - def install_signal_handlers(loop=None, cleanup=None): """ diff --git a/distributed/client.py b/distributed/client.py index 49afd3792e3..f68570f7762 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -115,7 +115,10 @@ _current_client = ContextVar("_current_client", default=None) -DEFAULT_EXTENSIONS = [PubSubClientExtension] +DEFAULT_EXTENSIONS = { + "pubsub": PubSubClientExtension, +} + # Placeholder used in the get_dataset function(s) NO_DEFAULT_PLACEHOLDER = "_no_default_" @@ -928,8 +931,9 @@ def __init__( server=self, ) - for ext in extensions: - ext(self) + self.extensions = { + name: extension(self) for name, extension in extensions.items() + } preload = dask.config.get("distributed.client.preload") preload_argv = dask.config.get("distributed.client.preload-argv") diff --git a/distributed/event.py b/distributed/event.py index 0765e003158..037d171a030 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -58,8 +58,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["events"] = self - async def event_wait(self, name=None, timeout=None): """Wait until the event is set to true. Returns false, when this did not happen in the given time diff --git a/distributed/lock.py b/distributed/lock.py index 5830e2de94b..22e3de5e223 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -30,8 +30,6 @@ def __init__(self, scheduler): {"lock_acquire": self.acquire, "lock_release": self.release} ) - self.scheduler.extensions["locks"] = self - async def acquire(self, name=None, id=None, timeout=None): with log_errors(): if isinstance(name, list): diff --git a/distributed/multi_lock.py b/distributed/multi_lock.py index 31b2e6ebbdb..7907f44ecfc 100644 --- a/distributed/multi_lock.py +++ b/distributed/multi_lock.py @@ -46,8 +46,6 @@ def __init__(self, scheduler): {"multi_lock_acquire": self.acquire, "multi_lock_release": self.release} ) - self.scheduler.extensions["multi_locks"] = self - def _request_locks(self, locks: list[str], id: Hashable, num_locks: int) -> bool: """Request locks diff --git a/distributed/publish.py b/distributed/publish.py index 63772519376..161b025bbc0 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -26,7 +26,6 @@ def __init__(self, scheduler): } self.scheduler.handlers.update(handlers) - self.scheduler.extensions["publish"] = self def put(self, keys=None, data=None, name=None, override=False, client=None): with log_errors(): diff --git a/distributed/pubsub.py b/distributed/pubsub.py index f1cbc62e531..f575439c3a0 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -34,8 +34,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["pubsub"] = self - def add_publisher(self, name=None, worker=None): logger.debug("Add publisher: %s %s", name, worker) self.publishers[name].add(worker) @@ -178,7 +176,6 @@ def __init__(self, client): self.client._stream_handlers.update({"pubsub-msg": self.handle_message}) self.subscribers = defaultdict(weakref.WeakSet) - self.client.extensions["pubsub"] = self # TODO: circular reference async def handle_message(self, name=None, msg=None): for sub in self.subscribers[name]: diff --git a/distributed/queues.py b/distributed/queues.py index c29c4f1ab2c..3dc563b3a52 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -42,8 +42,6 @@ def __init__(self, scheduler): {"queue-future-release": self.future_release, "queue_release": self.release} ) - self.scheduler.extensions["queues"] = self - def create(self, name=None, client=None, maxsize=0): logger.debug(f"Queue name: {name}") if name not in self.queues: diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index 82b72092b43..8bf2d74912d 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -23,7 +23,6 @@ def __init__(self, scheduler): self.scheduler = scheduler self.scheduler.handlers["get_runspec"] = self.get_runspec self.scheduler.handlers["get_error_cause"] = self.get_error_cause - self.scheduler.extensions["replay-tasks"] = self def _process_key(self, key): if isinstance(key, list): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5f178b0d3a0..849ac7b3aa2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -174,19 +174,20 @@ def nogil(func): Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) ) -DEFAULT_EXTENSIONS = [ - LockExtension, - MultiLockExtension, - PublishExtension, - ReplayTaskScheduler, - QueueExtension, - VariableExtension, - PubSubSchedulerExtension, - SemaphoreExtension, - EventExtension, - ActiveMemoryManagerExtension, - MemorySamplerExtension, -] +DEFAULT_EXTENSIONS = { + "locks": LockExtension, + "multi_locks": MultiLockExtension, + "publish": PublishExtension, + "replay-tasks": ReplayTaskScheduler, + "queues": QueueExtension, + "variables": VariableExtension, + "pubsub": PubSubSchedulerExtension, + "semaphores": SemaphoreExtension, + "events": EventExtension, + "amm": ActiveMemoryManagerExtension, + "memory_sampler": MemorySamplerExtension, + "stealing": WorkStealing, +} ALL_TASK_STATES = declare( set, {"released", "waiting", "no-worker", "processing", "erred", "memory"} @@ -4015,11 +4016,13 @@ def __init__( self.periodic_callbacks["idle-timeout"] = pc if extensions is None: - extensions = list(DEFAULT_EXTENSIONS) - if dask.config.get("distributed.scheduler.work-stealing"): - extensions.append(WorkStealing) - for ext in extensions: - ext(self) + extensions = DEFAULT_EXTENSIONS.copy() + if not dask.config.get("distributed.scheduler.work-stealing"): + if "stealing" in extensions: + del extensions["stealing"] + + for name, extension in extensions.items(): + self.extensions[name] = extension(self) setproctitle("dask-scheduler [not started]") Scheduler._instances.add(self) @@ -4330,6 +4333,7 @@ def heartbeat_worker( host_info: dict = None, metrics: dict, executing: dict = None, + extensions: dict = None, ): parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) @@ -4417,6 +4421,10 @@ def heartbeat_worker( if resources: self.add_resources(worker=address, resources=resources) + if extensions: + for name, data in extensions.items(): + self.extensions[name].heartbeat(ws, data) + return { "status": "OK", "time": local_now, diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 9e7abd872c0..d288462b706 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -69,8 +69,6 @@ def __init__(self, scheduler): } ) - self.scheduler.extensions["semaphores"] = self - # {metric_name: {semaphore_name: metric}} self.metrics = { "acquire_total": defaultdict(int), # counter diff --git a/distributed/stealing.py b/distributed/stealing.py index 8fb7e14a19f..54ef0098c63 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -71,7 +71,6 @@ def __init__(self, scheduler): ) # `callback_time` is in milliseconds self.scheduler.add_plugin(self) - self.scheduler.extensions["stealing"] = self self.scheduler.events["stealing"] = deque(maxlen=100000) self.count = 0 # { task state: } diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 646717fec03..d3a0ce8b28f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3324,6 +3324,43 @@ async def test_Worker__to_dict(c, s, a): assert d["data"] == ["x"] +@gen_cluster(nthreads=[]) +async def test_extension_methods(s): + flag = False + shutdown = False + + class WorkerExtension: + def __init__(self, worker): + pass + + def heartbeat(self): + return {"data": 123} + + async def close(self): + nonlocal shutdown + shutdown = True + + class SchedulerExtension: + def __init__(self, scheduler): + self.scheduler = scheduler + pass + + def heartbeat(self, ws, data: dict): + nonlocal flag + assert ws in self.scheduler.workers.values() + assert data == {"data": 123} + flag = True + + s.extensions["test"] = SchedulerExtension(s) + + async with Worker(s.address, extensions={"test": WorkerExtension}) as w: + assert not shutdown + await w.heartbeat() + assert flag + + assert shutdown + + @gen_cluster() async def test_benchmark_hardware(s, a, b): sizes = ["1 kiB", "10 kiB"] diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index ecdfa8fd003..468bd90d463 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -254,13 +254,9 @@ def test_secede_without_stealing_issue_1262(): Tests that seceding works with the Stealing extension disabled https://github.com/dask/distributed/issues/1262 """ - - # turn off all extensions - extensions = [] - # run the loop as an inner function so all workers are closed # and exceptions can be examined - @gen_cluster(client=True, scheduler_kwargs={"extensions": extensions}) + @gen_cluster(client=True, scheduler_kwargs={"extensions": {}}) async def secede_test(c, s, a, b): def func(x): with worker_client() as wc: diff --git a/distributed/variable.py b/distributed/variable.py index a27abc3ab85..143df9e4153 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -40,8 +40,6 @@ def __init__(self, scheduler): self.scheduler.stream_handlers["variable-future-release"] = self.future_release self.scheduler.stream_handlers["variable_delete"] = self.delete - self.scheduler.extensions["variables"] = self - async def set(self, name=None, key=None, data=None, client=None): if key is not None: record = {"type": "Future", "value": key} diff --git a/distributed/worker.py b/distributed/worker.py index 434700e937b..c607bb5afa1 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -141,7 +141,10 @@ no_value = "--no-value-sentinel--" -DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension, ShuffleWorkerExtension] +DEFAULT_EXTENSIONS: dict[str, type] = { + "pubsub": PubSubWorkerExtension, + "shuffle": ShuffleWorkerExtension, +} DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} @@ -437,7 +440,7 @@ def __init__( security: Security | dict[str, Any] | None = None, contact_address: str | None = None, heartbeat_interval: Any = "1s", - extensions: list[type] | None = None, + extensions: dict[str, type] | None = None, metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS, startup_information: Mapping[ str, Callable[[Worker], Any] @@ -790,8 +793,9 @@ def __init__( if extensions is None: extensions = DEFAULT_EXTENSIONS - for ext in extensions: - ext(self) + self.extensions = { + name: extension(self) for name, extension in extensions.items() + } self.memory_manager = WorkerMemoryManager( self, @@ -1132,6 +1136,11 @@ async def heartbeat(self): for key in self.active_keys if key in self.tasks }, + extensions={ + name: extension.heartbeat() + for name, extension in self.extensions.items() + if hasattr(extension, "heartbeat") + }, ) end = time() middle = (start + end) / 2 @@ -1414,6 +1423,10 @@ async def close( for preload in self.preloads: await preload.teardown() + for extension in self.extensions.values(): + if hasattr(extension, "close"): + await extension.close() + if nanny and self.nanny: with self.rpc(self.nanny) as r: await r.close_gracefully() diff --git a/setup.py b/setup.py index 999c0a8b9d6..0f57c525795 100755 --- a/setup.py +++ b/setup.py @@ -105,9 +105,9 @@ ], entry_points=""" [console_scripts] - dask-ssh=distributed.cli.dask_ssh:go - dask-scheduler=distributed.cli.dask_scheduler:go - dask-worker=distributed.cli.dask_worker:go + dask-ssh=distributed.cli.dask_ssh:main + dask-scheduler=distributed.cli.dask_scheduler:main + dask-worker=distributed.cli.dask_worker:main """, # https://mypy.readthedocs.io/en/latest/installed_packages.html zip_safe=False,