Skip to content

Commit

Permalink
explicit-comms shuffle: now register externals
Browse files: browse the repository at this point in the history
  • Loading branch information
madsbk committed Feb 19, 2021
1 parent d440329 commit bc67216
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import dask.dataframe
import distributed
from dask.base import compute_as_if_collection, tokenize
from dask.dataframe.core import DataFrame, _concat
from dask.dataframe.core import DataFrame, _concat as dd_concat
from dask.dataframe.shuffle import shuffle_group
from dask.delayed import delayed
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize

from ...proxify_host_file import ProxifyHostFile
from .. import comms


Expand Down Expand Up @@ -46,6 +47,7 @@ def sort_in_parts(
rank_to_out_part_ids: Dict[int, List[int]],
ignore_index: bool,
concat_dfs_of_same_output_partition: bool,
concat,
) -> Dict[int, List[List[DataFrame]]]:
""" Sort the list of grouped dataframes in `in_parts`
Expand Down Expand Up @@ -96,7 +98,7 @@ def sort_in_parts(
for i in range(len(rank_to_out_parts_list[rank])):
if len(rank_to_out_parts_list[rank][i]) > 1:
rank_to_out_parts_list[rank][i] = [
_concat(
concat(
rank_to_out_parts_list[rank][i], ignore_index=ignore_index
)
]
Expand Down Expand Up @@ -144,11 +146,30 @@ async def local_shuffle(
eps = s["eps"]
assert s["rank"] in workers

try:
hostfile = first(iter(in_parts[0].values()))._obj_pxy.get(
"hostfile", lambda: None
)()
except AttributeError:
hostfile = None

if isinstance(hostfile, ProxifyHostFile):

def concat(args, ignore_index=False):
    """Concatenate `args`, handing the result to `hostfile.add_external`.

    When `args` holds fewer than two items nothing is concatenated and
    the first item is returned unchanged (it is not passed through
    `add_external`).  Otherwise the items are merged with `dd_concat`
    and the value returned by `hostfile.add_external(...)` is returned.

    NOTE(review): presumably `add_external` makes the host file aware of
    the newly created dataframe -- confirm against ProxifyHostFile.
    """
    if len(args) >= 2:
        merged = dd_concat(args, ignore_index=ignore_index)
        return hostfile.add_external(merged)
    return args[0]

else:
concat = dd_concat

rank_to_out_parts_list = sort_in_parts(
in_parts,
rank_to_out_part_ids,
ignore_index,
concat_dfs_of_same_output_partition=True,
concat=concat,
)

# Communicate all the dataframe-partitions all-to-all. The result is
Expand Down Expand Up @@ -176,7 +197,7 @@ async def local_shuffle(
dfs.extend(rank_to_out_parts_list[myrank][i])
rank_to_out_parts_list[myrank][i] = None
if len(dfs) > 1:
ret.append(_concat(dfs, ignore_index=ignore_index))
ret.append(concat(dfs, ignore_index=ignore_index))
else:
ret.append(dfs[0])
return ret
Expand Down

0 comments on commit bc67216

Please sign in to comment.