Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concatenate flag to .compute() #1138

Merged
merged 2 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def persist(self, fuse=True, **kwargs):
out = self.optimize(fuse=fuse)
return DaskMethodsMixin.persist(out, **kwargs)

def compute(self, fuse=True, **kwargs):
def compute(self, fuse=True, concatenate=True, **kwargs):
"""Compute this DataFrame.

This turns a lazy Dask DataFrame into an in-memory pandas DataFrame.
Expand All @@ -463,6 +463,10 @@ def compute(self, fuse=True, **kwargs):
Whether to fuse the expression tree before computing. Fusing significantly
reduces the number of tasks and improves performance. It shouldn't be
disabled unless absolutely necessary.
concatenate : bool, default True
Whether to concatenate all partitions into a single one before computing.
Concatenating enables more powerful optimizations but it also incurs additional
data transfer cost. Generally, it should be enabled.
kwargs
Extra keywords to forward to the base compute function.

Expand All @@ -471,7 +475,7 @@ def compute(self, fuse=True, **kwargs):
dask.compute
"""
out = self
if not isinstance(out, Scalar):
if not isinstance(out, Scalar) and concatenate:
out = out.repartition(npartitions=1)
out = out.optimize(fuse=fuse)
return DaskMethodsMixin.compute(out, **kwargs)
Expand Down
41 changes: 39 additions & 2 deletions dask_expr/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

distributed = pytest.importorskip("distributed")

from distributed import Client, LocalCluster
from distributed import Client, LocalCluster, SchedulerPlugin
from distributed.shuffle._core import id_from_key
from distributed.utils_test import cleanup # noqa F401
from distributed.utils_test import client as c # noqa F401
from distributed.utils_test import gen_cluster
from distributed.utils_test import gen_cluster, loop, loop_in_thread # noqa F401

import dask_expr as dx

Expand Down Expand Up @@ -456,3 +457,39 @@ def test_respect_context_shuffle(df, pdf, func):
with dask.config.set({"dataframe.shuffle.method": "tasks"}):
result = q.optimize(fuse=False)
assert len([x for x in result.walk() if isinstance(x, P2PShuffle)]) > 0


@pytest.mark.parametrize("concatenate", [True, False])
def test_compute_concatenates(loop, concatenate):
    """Check that ``compute(concatenate=...)`` controls the repartition step.

    A scheduler plugin inspects the keys of each graph it is notified about
    and records whether any tuple key's first element starts with
    ``"repartitiontofewer"``. The recorded flag must equal ``concatenate``.
    """
    pdf = pd.DataFrame({"a": np.random.randint(1, 100, (100,)), "b": 1})
    df = from_pandas(pdf, npartitions=10)

    class Plugin(SchedulerPlugin):
        """Record whether a repartition-to-fewer task group was submitted."""

        def start(self, *args, **kwargs):
            # Initialise the flag here so every later graph update can OR
            # its own observation into it.
            self.repartition_in_tasks = False

        def update_graph(
            self,
            scheduler,
            *,
            client,
            keys,
            tasks,
            annotations,
            priority,
            dependencies,
            **kwargs,
        ):
            # Only tuple keys whose first element is a string are considered;
            # that first element is treated as the task group name.
            group_names = (
                key[0]
                for key in dependencies
                if isinstance(key, tuple) and isinstance(key[0], str)
            )
            for name in group_names:
                self.repartition_in_tasks |= name.startswith("repartitiontofewer")

    with Client(loop=loop) as client:
        client.register_plugin(Plugin(), name="tracker")
        df.compute(fuse=False, concatenate=concatenate)
        tracker = client.cluster.scheduler.plugins["tracker"]
        assert tracker.repartition_in_tasks is concatenate
Loading