Skip to content

Commit

Permalink
Merge branch 'main' into setuptools
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 29, 2022
2 parents 2c15902 + ed48736 commit a7896e9
Show file tree
Hide file tree
Showing 21 changed files with 99 additions and 121 deletions.
6 changes: 3 additions & 3 deletions continuous_integration/recipes/distributed/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ outputs:
track_features: # [cython_enabled]
- cythonized-scheduler # [cython_enabled]
entry_points:
- dask-scheduler = distributed.cli.dask_scheduler:go
- dask-ssh = distributed.cli.dask_ssh:go
- dask-worker = distributed.cli.dask_worker:go
- dask-scheduler = distributed.cli.dask_scheduler:main
- dask-ssh = distributed.cli.dask_ssh:main
- dask-worker = distributed.cli.dask_worker:main
requirements:
build:
- {{ compiler('c') }} # [cython_enabled]
Expand Down
9 changes: 2 additions & 7 deletions distributed/cli/dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tornado.ioloop import IOLoop

from distributed import Scheduler
from distributed.cli.utils import check_python_3, install_signal_handlers
from distributed.cli.utils import install_signal_handlers
from distributed.preloading import validate_preload_argv
from distributed.proctitle import (
enable_proctitle_on_children,
Expand Down Expand Up @@ -212,10 +212,5 @@ async def run():
logger.info("End scheduler at %r", scheduler.address)


def go():
    """Run the click unicode-environment check, then hand off to ``main``."""
    check_python_3()
    main()


if __name__ == "__main__":
go() # pragma: no cover
main() # pragma: no cover
8 changes: 1 addition & 7 deletions distributed/cli/dask_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import click

from distributed.cli.utils import check_python_3
from distributed.deploy.old_ssh import SSHCluster

logger = logging.getLogger("distributed.dask_ssh")
Expand Down Expand Up @@ -223,10 +222,5 @@ def main(
print("[ dask-ssh ]: Remote processes have been terminated. Exiting.")


def go():
    """Run the click unicode-environment check, then hand off to ``main``."""
    check_python_3()
    main()


if __name__ == "__main__":
go() # pragma: no cover
main() # pragma: no cover
9 changes: 2 additions & 7 deletions distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dask.system import CPU_COUNT

from distributed import Nanny
from distributed.cli.utils import check_python_3, install_signal_handlers
from distributed.cli.utils import install_signal_handlers
from distributed.comm import get_address_host_port
from distributed.deploy.utils import nprocesses_nthreads
from distributed.preloading import validate_preload_argv
Expand Down Expand Up @@ -486,10 +486,5 @@ async def run():
logger.info("End worker")


def go():
    """Run the click unicode-environment check, then hand off to ``main``."""
    check_python_3()
    main()


if __name__ == "__main__":
go() # pragma: no cover
main() # pragma: no cover
46 changes: 0 additions & 46 deletions distributed/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,5 @@
import click
from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop

CLICK_VERSION = parse_version(click.__version__)

py3_err_msg = """
Warning: Your terminal does not set locales.
If you use unicode text inputs for command line options then this may cause
undesired behavior. This is rare.
If you don't use unicode characters in command line options then you can safely
ignore this message. This is the common case.
You can support unicode inputs by specifying encoding environment variables,
though exact solutions may depend on your system:
$ export LC_ALL=C.UTF-8
$ export LANG=C.UTF-8
For more information see: http://click.pocoo.org/5/python3/
""".lstrip()


def check_python_3():
    """Ensures that the environment is good for unicode on Python 3.

    click's internal environment verifier is replaced with a no-op on
    ``click.core``, and the original verifier from ``click._unicodefun`` is
    then invoked once by hand.  If it raises ``TypeError`` or ``RuntimeError``
    the ``py3_err_msg`` warning is printed to stderr instead of propagating
    the exception.

    On click releases that no longer ship ``click._unicodefun`` there is
    nothing to verify, so the function returns without doing the manual check.
    """
    # https://github.com/pallets/click/issues/448#issuecomment-246029304
    import click.core

    # TODO: Remove use of internal click functions
    # The private verifier has a different name before and after click 8.0.
    if CLICK_VERSION < parse_version("8.0.0"):
        verifier_name = "_verify_python3_env"
    else:
        verifier_name = "_verify_python_env"

    # Disable click's own call to the verifier.
    setattr(click.core, verifier_name, lambda: None)

    try:
        from click import _unicodefun
    except ImportError:
        # The private ``_unicodefun`` module is absent (click >= 8.1); without
        # this guard every CLI entry point would crash on import.
        return

    try:
        getattr(_unicodefun, verifier_name)()
    except (TypeError, RuntimeError):
        click.echo(py3_err_msg, err=True)


def install_signal_handlers(loop=None, cleanup=None):
"""
Expand Down
10 changes: 7 additions & 3 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@

_current_client = ContextVar("_current_client", default=None)

DEFAULT_EXTENSIONS = [PubSubClientExtension]
DEFAULT_EXTENSIONS = {
"pubsub": PubSubClientExtension,
}

# Placeholder used in the get_dataset function(s)
NO_DEFAULT_PLACEHOLDER = "_no_default_"

Expand Down Expand Up @@ -928,8 +931,9 @@ def __init__(
server=self,
)

for ext in extensions:
ext(self)
self.extensions = {
name: extension(self) for name, extension in extensions.items()
}

preload = dask.config.get("distributed.client.preload")
preload_argv = dask.config.get("distributed.client.preload-argv")
Expand Down
2 changes: 0 additions & 2 deletions distributed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def __init__(self, scheduler):
}
)

self.scheduler.extensions["events"] = self

async def event_wait(self, name=None, timeout=None):
"""Wait until the event is set to true.
Returns false, when this did not happen in the given time
Expand Down
2 changes: 0 additions & 2 deletions distributed/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def __init__(self, scheduler):
{"lock_acquire": self.acquire, "lock_release": self.release}
)

self.scheduler.extensions["locks"] = self

async def acquire(self, name=None, id=None, timeout=None):
with log_errors():
if isinstance(name, list):
Expand Down
2 changes: 0 additions & 2 deletions distributed/multi_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def __init__(self, scheduler):
{"multi_lock_acquire": self.acquire, "multi_lock_release": self.release}
)

self.scheduler.extensions["multi_locks"] = self

def _request_locks(self, locks: list[str], id: Hashable, num_locks: int) -> bool:
"""Request locks
Expand Down
1 change: 0 additions & 1 deletion distributed/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(self, scheduler):
}

self.scheduler.handlers.update(handlers)
self.scheduler.extensions["publish"] = self

def put(self, keys=None, data=None, name=None, override=False, client=None):
with log_errors():
Expand Down
3 changes: 0 additions & 3 deletions distributed/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def __init__(self, scheduler):
}
)

self.scheduler.extensions["pubsub"] = self

def add_publisher(self, name=None, worker=None):
logger.debug("Add publisher: %s %s", name, worker)
self.publishers[name].add(worker)
Expand Down Expand Up @@ -178,7 +176,6 @@ def __init__(self, client):
self.client._stream_handlers.update({"pubsub-msg": self.handle_message})

self.subscribers = defaultdict(weakref.WeakSet)
self.client.extensions["pubsub"] = self # TODO: circular reference

async def handle_message(self, name=None, msg=None):
for sub in self.subscribers[name]:
Expand Down
2 changes: 0 additions & 2 deletions distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def __init__(self, scheduler):
{"queue-future-release": self.future_release, "queue_release": self.release}
)

self.scheduler.extensions["queues"] = self

def create(self, name=None, client=None, maxsize=0):
logger.debug(f"Queue name: {name}")
if name not in self.queues:
Expand Down
1 change: 0 additions & 1 deletion distributed/recreate_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self, scheduler):
self.scheduler = scheduler
self.scheduler.handlers["get_runspec"] = self.get_runspec
self.scheduler.handlers["get_error_cause"] = self.get_error_cause
self.scheduler.extensions["replay-tasks"] = self

def _process_key(self, key):
if isinstance(key, list):
Expand Down
44 changes: 26 additions & 18 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,20 @@ def nogil(func):
Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))
)

DEFAULT_EXTENSIONS = [
LockExtension,
MultiLockExtension,
PublishExtension,
ReplayTaskScheduler,
QueueExtension,
VariableExtension,
PubSubSchedulerExtension,
SemaphoreExtension,
EventExtension,
ActiveMemoryManagerExtension,
MemorySamplerExtension,
]
DEFAULT_EXTENSIONS = {
"locks": LockExtension,
"multi_locks": MultiLockExtension,
"publish": PublishExtension,
"replay-tasks": ReplayTaskScheduler,
"queues": QueueExtension,
"variables": VariableExtension,
"pubsub": PubSubSchedulerExtension,
"semaphores": SemaphoreExtension,
"events": EventExtension,
"amm": ActiveMemoryManagerExtension,
"memory_sampler": MemorySamplerExtension,
"stealing": WorkStealing,
}

ALL_TASK_STATES = declare(
set, {"released", "waiting", "no-worker", "processing", "erred", "memory"}
Expand Down Expand Up @@ -4015,11 +4016,13 @@ def __init__(
self.periodic_callbacks["idle-timeout"] = pc

if extensions is None:
extensions = list(DEFAULT_EXTENSIONS)
if dask.config.get("distributed.scheduler.work-stealing"):
extensions.append(WorkStealing)
for ext in extensions:
ext(self)
extensions = DEFAULT_EXTENSIONS.copy()
if not dask.config.get("distributed.scheduler.work-stealing"):
if "stealing" in extensions:
del extensions["stealing"]

for name, extension in extensions.items():
self.extensions[name] = extension(self)

setproctitle("dask-scheduler [not started]")
Scheduler._instances.add(self)
Expand Down Expand Up @@ -4330,6 +4333,7 @@ def heartbeat_worker(
host_info: dict = None,
metrics: dict,
executing: dict = None,
extensions: dict = None,
):
parent: SchedulerState = cast(SchedulerState, self)
address = self.coerce_address(address, resolve_address)
Expand Down Expand Up @@ -4417,6 +4421,10 @@ def heartbeat_worker(
if resources:
self.add_resources(worker=address, resources=resources)

if extensions:
for name, data in extensions.items():
self.extensions[name].heartbeat(ws, data)

return {
"status": "OK",
"time": local_now,
Expand Down
2 changes: 0 additions & 2 deletions distributed/semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ def __init__(self, scheduler):
}
)

self.scheduler.extensions["semaphores"] = self

# {metric_name: {semaphore_name: metric}}
self.metrics = {
"acquire_total": defaultdict(int), # counter
Expand Down
1 change: 0 additions & 1 deletion distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(self, scheduler):
)
# `callback_time` is in milliseconds
self.scheduler.add_plugin(self)
self.scheduler.extensions["stealing"] = self
self.scheduler.events["stealing"] = deque(maxlen=100000)
self.count = 0
# { task state: <stealing info dict> }
Expand Down
37 changes: 37 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,43 @@ async def test_Worker__to_dict(c, s, a):
assert d["data"] == ["x"]


@gen_cluster(nthreads=[])
async def test_extension_methods(s):
    """Check the heartbeat/close hooks of worker and scheduler extensions.

    The test expects that the payload returned by a worker extension's
    ``heartbeat`` is handed to the scheduler extension registered under the
    same name, and that the worker extension's ``close`` has run once the
    worker context manager exits.
    """
    heartbeat_seen = False
    extension_closed = False

    class WorkerExtension:
        """Worker-side extension emitting a fixed heartbeat payload."""

        def __init__(self, worker):
            pass

        def heartbeat(self):
            return {"data": 123}

        async def close(self):
            nonlocal extension_closed
            extension_closed = True

    class SchedulerExtension:
        """Scheduler-side extension validating the payload it receives."""

        def __init__(self, scheduler):
            self.scheduler = scheduler

        def heartbeat(self, ws, data: dict):
            nonlocal heartbeat_seen
            assert ws in self.scheduler.workers.values()
            assert data == {"data": 123}
            heartbeat_seen = True

    # Register the scheduler half by hand under the same name the worker uses.
    s.extensions["test"] = SchedulerExtension(s)

    async with Worker(s.address, extensions={"test": WorkerExtension}) as w:
        assert not extension_closed
        await w.heartbeat()
        assert heartbeat_seen

    # Leaving the ``async with`` block must have closed the worker extension.
    assert extension_closed


@gen_cluster()
async def test_benchmark_hardware(s, a, b):
sizes = ["1 kiB", "10 kiB"]
Expand Down
6 changes: 1 addition & 5 deletions distributed/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,9 @@ def test_secede_without_stealing_issue_1262():
Tests that seceding works with the Stealing extension disabled
https://github.com/dask/distributed/issues/1262
"""

# turn off all extensions
extensions = []

# run the loop as an inner function so all workers are closed
# and exceptions can be examined
@gen_cluster(client=True, scheduler_kwargs={"extensions": extensions})
@gen_cluster(client=True, scheduler_kwargs={"extensions": {}})
async def secede_test(c, s, a, b):
def func(x):
with worker_client() as wc:
Expand Down
2 changes: 0 additions & 2 deletions distributed/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ def __init__(self, scheduler):
self.scheduler.stream_handlers["variable-future-release"] = self.future_release
self.scheduler.stream_handlers["variable_delete"] = self.delete

self.scheduler.extensions["variables"] = self

async def set(self, name=None, key=None, data=None, client=None):
if key is not None:
record = {"type": "Future", "value": key}
Expand Down
Loading

0 comments on commit a7896e9

Please sign in to comment.