Skip to content

Commit

Permalink
data_needed exclusively contains tasks in fetch state (#6481)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jun 1, 2022
1 parent b5ef418 commit a341432
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 291 deletions.
114 changes: 114 additions & 0 deletions distributed/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import heapq
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Hashable)


# TODO change to UserDict[K, V] (requires Python >=3.9)
class LRU(UserDict):
"""Limited size mapping, evicting the least recently looked-up key when full"""

def __init__(self, maxsize: float):
super().__init__()
self.data = OrderedDict()
self.maxsize = maxsize

def __getitem__(self, key):
value = super().__getitem__(key)
cast(OrderedDict, self.data).move_to_end(key)
return value

def __setitem__(self, key, value):
if len(self) >= self.maxsize:
cast(OrderedDict, self.data).popitem(last=False)
super().__setitem__(key, value)


class HeapSet(MutableSet[T]):
"""A set-like where the `pop` method returns the smallest item, as sorted by an
arbitrary key function. Ties are broken by oldest first.
Values must be compatible with :mod:`weakref`.
"""

__slots__ = ("key", "_data", "_heap", "_inc")
key: Callable[[T], Any]
_data: set[T]
_inc: int
_heap: list[tuple[Any, int, weakref.ref[T]]]

def __init__(self, *, key: Callable[[T], Any]):
# FIXME https://github.com/python/mypy/issues/708
self.key = key # type: ignore
self._data = set()
self._inc = 0
self._heap = []

def __repr__(self) -> str:
return f"<{type(self).__name__}: {len(self)} items>"

def __contains__(self, value: object) -> bool:
return value in self._data

def __len__(self) -> int:
return len(self._data)

def add(self, value: T) -> None:
if value in self._data:
return
k = self.key(value) # type: ignore
vref = weakref.ref(value)
heapq.heappush(self._heap, (k, self._inc, vref))
self._data.add(value)
self._inc += 1

def discard(self, value: T) -> None:
self._data.discard(value)
if not self._data:
self._heap.clear()

def peek(self) -> T:
"""Get the smallest element without removing it"""
if not self._data:
raise KeyError("peek into empty set")
while True:
value = self._heap[0][2]()
if value in self._data:
return value
heapq.heappop(self._heap)

def pop(self) -> T:
if not self._data:
raise KeyError("pop from an empty set")
while True:
_, _, vref = heapq.heappop(self._heap)
value = vref()
if value in self._data:
self._data.discard(value)
return value

def __iter__(self) -> Iterator[T]:
"""Iterate over all elements. This is a O(n) operation which returns the
elements in pseudo-random order.
"""
return iter(self._data)

def sorted(self) -> Iterator[T]:
"""Iterate over all elements. This is a O(n*logn) operation which returns the
elements in order, from smallest to largest according to the key and insertion
order.
"""
for _, _, vref in sorted(self._heap):
value = vref()
if value in self._data:
yield value

def clear(self) -> None:
self._data.clear()
self._heap.clear()
150 changes: 150 additions & 0 deletions distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from distributed.collections import LRU, HeapSet


def test_lru():
l = LRU(maxsize=3)
l["a"] = 1
l["b"] = 2
l["c"] = 3
assert list(l.keys()) == ["a", "b", "c"]

# Use "a" and ensure it becomes the most recently used item
l["a"]
assert list(l.keys()) == ["b", "c", "a"]

# Ensure maxsize is respected
l["d"] = 4
assert len(l) == 3
assert list(l.keys()) == ["c", "a", "d"]


def test_heapset():
class C:
def __init__(self, k, i):
self.k = k
self.i = i

def __hash__(self):
return hash(self.k)

def __eq__(self, other):
return isinstance(other, C) and other.k == self.k

heap = HeapSet(key=lambda c: c.i)

cx = C("x", 2)
cy = C("y", 1)
cz = C("z", 3)
cw = C("w", 4)
heap.add(cx)
heap.add(cy)
heap.add(cz)
heap.add(cw)
heap.add(C("x", 0)) # Ignored; x already in heap
assert len(heap) == 4
assert repr(heap) == "<HeapSet: 4 items>"

assert cx in heap
assert cy in heap
assert cz in heap
assert cw in heap

heap_sorted = heap.sorted()
# iteration does not empty heap
assert len(heap) == 4
assert next(heap_sorted) is cy
assert next(heap_sorted) is cx
assert next(heap_sorted) is cz
assert next(heap_sorted) is cw
with pytest.raises(StopIteration):
next(heap_sorted)

assert set(heap) == {cx, cy, cz, cw}

assert heap.peek() is cy
assert heap.pop() is cy
assert cx in heap
assert cy not in heap
assert cz in heap
assert cw in heap

assert heap.peek() is cx
assert heap.pop() is cx
assert heap.pop() is cz
assert heap.pop() is cw
assert not heap
with pytest.raises(KeyError):
heap.pop()
with pytest.raises(KeyError):
heap.peek()

# Test out-of-order discard
heap.add(cx)
heap.add(cy)
heap.add(cz)
heap.add(cw)
assert heap.peek() is cy

heap.remove(cy)
assert cy not in heap
with pytest.raises(KeyError):
heap.remove(cy)

heap.discard(cw)
assert cw not in heap
heap.discard(cw)

assert len(heap) == 2
assert list(heap.sorted()) == [cx, cz]
# cy is at the top of heap._heap, but is skipped
assert heap.peek() is cx
assert heap.pop() is cx
assert heap.peek() is cz
assert heap.pop() is cz
# heap._heap is not empty
assert not heap
with pytest.raises(KeyError):
heap.peek()
with pytest.raises(KeyError):
heap.pop()
assert list(heap.sorted()) == []

# Test clear()
heap.add(cx)
heap.clear()
assert not heap
heap.add(cx)
assert cx in heap
# Test discard last element
heap.discard(cx)
assert not heap
heap.add(cx)
assert cx in heap

# Test resilience to failure in key()
bad_key = C("bad_key", 0)
del bad_key.i
with pytest.raises(AttributeError):
heap.add(bad_key)
assert len(heap) == 1
assert set(heap) == {cx}

# Test resilience to failure in weakref.ref()
class D:
__slots__ = ("i",)

def __init__(self, i):
self.i = i

with pytest.raises(TypeError):
heap.add(D("bad_weakref", 2))
assert len(heap) == 1
assert set(heap) == {cx}

# Test resilience to key() returning non-sortable output
with pytest.raises(TypeError):
heap.add(C("unsortable_key", None))
assert len(heap) == 1
assert set(heap) == {cx}
19 changes: 0 additions & 19 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from distributed.compatibility import MACOS, WINDOWS
from distributed.metrics import time
from distributed.utils import (
LRU,
All,
Log,
Logs,
Expand Down Expand Up @@ -594,24 +593,6 @@ def test_parse_ports():
parse_ports("100.5")


def test_lru():

l = LRU(maxsize=3)
l["a"] = 1
l["b"] = 2
l["c"] = 3
assert list(l.keys()) == ["a", "b", "c"]

# Use "a" and ensure it becomes the most recently used item
l["a"]
assert list(l.keys()) == ["b", "c", "a"]

# Ensure maxsize is respected
l["d"] = 4
assert len(l) == 3
assert list(l.keys()) == ["c", "a", "d"]


@gen_test()
async def test_offload():
assert (await offload(inc, 1)) == 2
Expand Down
75 changes: 0 additions & 75 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,81 +2589,6 @@ def __reduce__(self):
assert "return lambda: 1 / 0, ()" in logvalue


@gen_cluster(client=True)
async def test_gather_dep_exception_one_task(c, s, a, b):
"""Ensure an exception in a single task does not tear down an entire batch of gather_dep
See also https://github.com/dask/distributed/issues/5152
See also test_gather_dep_exception_one_task_2
"""
fut = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, 2, workers=[a.address], key="f2")
fut3 = c.submit(inc, 3, workers=[a.address], key="f3")

import asyncio

event = asyncio.Event()
write_queue = asyncio.Queue()
b.rpc = _LockedCommPool(b.rpc, write_event=event, write_queue=write_queue)
b.rpc.remove(a.address)

def sink(a, b, *args):
return a + b

res1 = c.submit(sink, fut, fut2, fut3, workers=[b.address])
res2 = c.submit(sink, fut, fut2, workers=[b.address])

# Wait until we're sure the worker is attempting to fetch the data
while True:
peer_addr, msg = await write_queue.get()
if peer_addr == a.address and msg["op"] == "get_data":
break

# Provoke an "impossible transition exception"
# By choosing a state which doesn't exist we're not running into validation
# errors and the state machine should raise if we want to transition from
# fetch to memory

b.validate = False
b.tasks[fut3.key].state = "fetch"
event.set()

assert await res1 == 5
assert await res2 == 5

del res1, res2, fut, fut2
fut3.release()

while a.tasks and b.tasks:
await asyncio.sleep(0.1)


@gen_cluster(client=True)
async def test_gather_dep_exception_one_task_2(c, s, a, b):
"""Ensure an exception in a single task does not tear down an entire batch of gather_dep
The below triggers an fetch->memory transition
See also https://github.com/dask/distributed/issues/5152
See also test_gather_dep_exception_one_task
"""
# This test does not trigger the condition reliably but is a very easy case
# which should function correctly regardless

fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[b.address], key="f2")

while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
await asyncio.sleep(0)

s.handle_missing_data(
key="f1", worker=b.address, errant_worker=a.address, stimulus_id="test"
)

await fut2


@gen_cluster(client=True)
async def test_acquire_replicas(c, s, a, b):
fut = c.submit(inc, 1, workers=[a.address])
Expand Down
Loading

0 comments on commit a341432

Please sign in to comment.