Skip to content

Commit

Permalink
Raise and avoid data loss if meta provided to P2P shuffle is wrong (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Feb 22, 2024
1 parent e6fe6f2 commit 1211e79
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 8 deletions.
3 changes: 3 additions & 0 deletions distributed/shuffle/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ def create_new_run(
participating_workers=set(worker_for.values()),
)

def validate_data(self, data: Any) -> None:
    """Check a payload before it is handed to the shuffle.

    This default implementation accepts every payload and does nothing.
    Specs that can detect malformed input are expected to override it and
    raise to reject the payload.

    Parameters
    ----------
    data
        The payload that is about to be shuffled.
    """
    return None

@abc.abstractmethod
def create_run_on_worker(
self,
Expand Down
5 changes: 3 additions & 2 deletions distributed/shuffle/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from distributed.metrics import context_meter, thread_time
from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.utils import Deadline, empty_context, log_errors, nbytes
Expand Down Expand Up @@ -201,13 +202,13 @@ def read(self, id: str) -> Any:
context_meter.digest_metric("p2p-disk-read", 1, "count")
context_meter.digest_metric("p2p-disk-read", size, "bytes")
except FileNotFoundError:
raise KeyError(id)
raise DataUnavailable(id)

if data:
self.bytes_read += size
return data
else:
raise KeyError(id)
raise DataUnavailable(id)

async def close(self) -> None:
await super().close()
Expand Down
4 changes: 4 additions & 0 deletions distributed/shuffle/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@

class ShuffleClosedError(RuntimeError):
    # NOTE(review): presumably signals an operation on an already-closed
    # shuffle; the raise sites are not visible here -- confirm against callers.
    ...


class DataUnavailable(Exception):
    """Signal that the requested data is not present in the buffer.

    Deliberately derives from ``Exception`` rather than ``KeyError`` so that
    handlers catching ``KeyError`` do not intercept it.
    """
6 changes: 5 additions & 1 deletion distributed/shuffle/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dask.sizeof import sizeof

from distributed.shuffle._buffer import ShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils import log_errors

Expand All @@ -30,7 +31,10 @@ def read(self, id: str) -> Any:
if not self._inputs_done:
raise RuntimeError("Tried to read from file before done.")

shards = self._shards.pop(id) # Raises KeyError
try:
shards = self._shards.pop(id) # Raises KeyError
except KeyError:
raise DataUnavailable(id)
self.bytes_read += sum(map(sizeof, shards))
# Don't keep the serialized and the deserialized shards
# in memory at the same time
Expand Down
7 changes: 6 additions & 1 deletion distributed/shuffle/_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
handle_transfer_errors,
handle_unpack_errors,
)
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof
Expand Down Expand Up @@ -527,7 +528,7 @@ def _get_output_partition(
try:
data = self._read_from_disk((partition_id,))
return convert_shards(data, self.meta)
except KeyError:
except DataUnavailable:
return self.meta.copy()

def _get_assigned_worker(self, id: int) -> str:
Expand All @@ -554,6 +555,10 @@ def output_partitions(self) -> Generator[int, None, None]:
def pick_worker(self, partition: int, workers: Sequence[str]) -> str:
return _get_worker_for_range_sharding(self.npartitions, partition, workers)

def validate_data(self, data: pd.DataFrame) -> None:
    """Reject a partition whose columns do not match the spec's ``meta``.

    Column labels are compared as sets, so neither column order nor
    duplicated labels influence the outcome.

    Parameters
    ----------
    data
        The dataframe partition that is about to be shuffled.

    Raises
    ------
    ValueError
        If the set of column labels of ``data`` differs from that of
        ``self.meta``.
    """
    actual_columns = set(data.columns)
    expected_columns = set(self.meta.columns)
    if actual_columns == expected_columns:
        return
    raise ValueError(f"Expected {self.meta.columns=} to match {data.columns=}.")

def create_run_on_worker(
self,
run_id: int,
Expand Down
1 change: 1 addition & 0 deletions distributed/shuffle/_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def add_partition(
spec: ShuffleSpec,
**kwargs: Any,
) -> int:
spec.validate_data(data)
shuffle_run = self.get_or_create_shuffle(spec)
return shuffle_run.add_partition(
data=data,
Expand Down
5 changes: 3 additions & 2 deletions distributed/shuffle/tests/test_disk_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.utils_test import gen_test

Expand All @@ -32,7 +33,7 @@ async def test_basic(tmp_path):
x = mf.read("x")
y = mf.read("y")

with pytest.raises(KeyError):
with pytest.raises(DataUnavailable):
mf.read("z")

assert x == b"0" * 2000
Expand All @@ -57,7 +58,7 @@ async def test_read_before_flush(tmp_path):

await mf.flush()
assert mf.read("1") == b"foo"
with pytest.raises(KeyError):
with pytest.raises(DataUnavailable):
mf.read(2)


Expand Down
5 changes: 3 additions & 2 deletions distributed/shuffle/tests/test_memory_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest

from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.utils_test import gen_test

Expand All @@ -21,7 +22,7 @@ async def test_basic():
x = mf.read("x")
y = mf.read("y")

with pytest.raises(KeyError):
with pytest.raises(DataUnavailable):
mf.read("z")

assert x == [b"0" * 1000] * 2
Expand All @@ -42,7 +43,7 @@ async def test_read_before_flush():

await mf.flush()
assert mf.read("1") == [b"foo"]
with pytest.raises(KeyError):
with pytest.raises(DataUnavailable):
mf.read("2")


Expand Down
20 changes: 20 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2749,3 +2749,23 @@ async def test_drop_duplicates_stable_ordering(c, s, a, b, keep, disk):
)
expected = expected.drop_duplicates(subset=["name"], keep=keep)
dd.assert_eq(result, expected)


@gen_cluster(client=True)
async def test_wrong_meta_provided(c, s, a, b):
    """Shuffling must fail loudly when the declared meta is wrong.

    Regression test for https://github.com/dask/distributed/issues/8519.
    Every partition only has column ``a`` while the declared meta also lists
    ``b``; computing the shuffle is expected to raise a ``RuntimeError``
    whose cause is a ``ValueError`` mentioning the meta.
    """

    @dask.delayed
    def data_gen():
        return pd.DataFrame({"a": range(10)})

    # The same delayed object is reused for both partitions.
    partitions = [data_gen()] * 2
    # Deliberately wrong: the generated frames have no column "b".
    wrong_meta = [("a", int), ("b", int)]
    ddf = dd.from_delayed(partitions, meta=wrong_meta, verify_meta=False)

    with raises_with_cause(RuntimeError, r"shuffling \w* failed", ValueError, "meta"):
        shuffled = ddf.shuffle(on="a")
        future = c.compute(shuffled)
        await c.gather(future)

0 comments on commit 1211e79

Please sign in to comment.