forked from dask/distributed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'AMM/staging' into AMM/RetireWorker
- Loading branch information
Showing
10 changed files
with
1,274 additions
and
111 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
try: | ||
import pandas | ||
except ImportError: | ||
SHUFFLE_AVAILABLE = False | ||
else: | ||
del pandas | ||
SHUFFLE_AVAILABLE = True | ||
|
||
from .shuffle import rearrange_by_column_p2p | ||
from .shuffle_extension import ShuffleId, ShuffleMetadata, ShuffleWorkerExtension | ||
|
||
__all__ = [ | ||
"SHUFFLE_AVAILABLE", | ||
"rearrange_by_column_p2p", | ||
"ShuffleId", | ||
"ShuffleMetadata", | ||
"ShuffleWorkerExtension", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from dask.base import tokenize | ||
from dask.dataframe import DataFrame | ||
from dask.delayed import Delayed, delayed | ||
from dask.highlevelgraph import HighLevelGraph | ||
|
||
from .shuffle_extension import NewShuffleMetadata, ShuffleId, ShuffleWorkerExtension | ||
|
||
if TYPE_CHECKING: | ||
import pandas as pd | ||
|
||
|
||
def get_ext() -> ShuffleWorkerExtension: | ||
from distributed import get_worker | ||
|
||
try: | ||
worker = get_worker() | ||
except ValueError as e: | ||
raise RuntimeError( | ||
"`shuffle='p2p'` requires Dask's distributed scheduler. This task is not running on a Worker; " | ||
"please confirm that you've created a distributed Client and are submitting this computation through it." | ||
) from e | ||
extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle") | ||
if not extension: | ||
raise RuntimeError( | ||
f"The worker {worker.address} does not have a ShuffleExtension. " | ||
"Is pandas installed on the worker?" | ||
) | ||
return extension | ||
|
||
|
||
def shuffle_setup(metadata: NewShuffleMetadata) -> None: | ||
get_ext().create_shuffle(metadata) | ||
|
||
|
||
def shuffle_transfer(input: pd.DataFrame, id: ShuffleId, setup=None) -> None: | ||
get_ext().add_partition(input, id) | ||
|
||
|
||
def shuffle_unpack(id: ShuffleId, output_partition: int, barrier=None) -> pd.DataFrame: | ||
return get_ext().get_output_partition(id, output_partition) | ||
|
||
|
||
def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None: | ||
get_ext().barrier(id) | ||
|
||
|
||
def rearrange_by_column_p2p( | ||
df: DataFrame, | ||
column: str, | ||
npartitions: int | None = None, | ||
): | ||
npartitions = npartitions or df.npartitions | ||
token = tokenize(df, column, npartitions) | ||
|
||
setup = delayed(shuffle_setup, pure=True)( | ||
NewShuffleMetadata( | ||
ShuffleId(token), | ||
df._meta, | ||
column, | ||
npartitions, | ||
) | ||
) | ||
|
||
transferred = df.map_partitions( | ||
shuffle_transfer, | ||
token, | ||
setup, | ||
meta=df, | ||
enforce_metadata=False, | ||
transform_divisions=False, | ||
) | ||
|
||
barrier_key = "shuffle-barrier-" + token | ||
barrier_dsk = {barrier_key: (shuffle_barrier, token, transferred.__dask_keys__())} | ||
barrier = Delayed( | ||
barrier_key, | ||
HighLevelGraph.from_collections( | ||
barrier_key, barrier_dsk, dependencies=[transferred] | ||
), | ||
) | ||
|
||
name = "shuffle-unpack-" + token | ||
dsk = { | ||
(name, i): (shuffle_unpack, token, i, barrier_key) for i in range(npartitions) | ||
} | ||
# TODO use this blockwise (https://github.com/coiled/oss-engineering/issues/49) | ||
# Changes task names, so breaks setting worker restrictions at the moment. | ||
# Also maybe would be nice if the `DataFrameIOLayer` interface supported this? | ||
# dsk = blockwise( | ||
# shuffle_unpack, | ||
# name, | ||
# "i", | ||
# token, | ||
# None, | ||
# BlockwiseDepDict({(i,): i for i in range(npartitions)}), | ||
# "i", | ||
# barrier_key, | ||
# None, | ||
# numblocks={}, | ||
# ) | ||
|
||
return DataFrame( | ||
HighLevelGraph.from_collections(name, dsk, [barrier]), | ||
name, | ||
df._meta, | ||
[None] * (npartitions + 1), | ||
) |
Oops, something went wrong.