Skip to content

Commit

Permalink
Fix P2P-based joins with explicit npartitions (#8470)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Jan 22, 2024
1 parent 33b2c72 commit a6ea9f4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
5 changes: 3 additions & 2 deletions distributed/shuffle/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@ def _cull_dependencies(
"""
deps = {}
parts_out = parts_out or self._keys_to_parts(keys)
- keys = {(self.name_input_left, i) for i in range(self.npartitions)}
- keys |= {(self.name_input_right, i) for i in range(self.npartitions)}
+ keys = {(self.name_input_left, i) for i in range(self.n_partitions_left)}
+ keys |= {(self.name_input_right, i) for i in range(self.n_partitions_right)}
# Protect against mutations later on with frozenset
keys = frozenset(keys)
for part in parts_out:
Expand Down Expand Up @@ -352,6 +352,7 @@ def cull(self, keys: Iterable[str], all_keys: Any) -> tuple[HashJoinP2PLayer, di
parameter.
"""
parts_out = self._keys_to_parts(keys)

culled_deps = self._cull_dependencies(keys, parts_out=parts_out)
if parts_out != set(self.parts_out):
culled_layer = self._cull(parts_out)
Expand Down
14 changes: 14 additions & 0 deletions distributed/shuffle/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,20 @@ async def test_index_merge_p2p(c, s, a, b, how):
)


@pytest.mark.parametrize("npartitions", [4, 5, 10, 20])
@gen_cluster(client=True)
async def test_merge_with_npartitions(c, s, a, b, npartitions):
    """Check a P2P merge that is given an explicit ``npartitions``.

    The two inputs are built from the same pandas frame but split into
    different partition counts (10 on the left, 5 on the right). The
    parametrized ``npartitions`` values are below, equal to, and above
    those input counts. The computed result must match the plain pandas
    merge of the frame with itself, ignoring the index.
    """
    frame = pd.DataFrame({"a": [1, 2, 3, 4] * 10, "b": 1})
    # Reference answer computed eagerly with pandas.
    expected = frame.merge(frame)

    # Deliberately mismatched partition counts on the two sides.
    left = dd.from_pandas(frame, npartitions=10)
    right = dd.from_pandas(frame, npartitions=5)

    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
        joined = left.merge(right, npartitions=npartitions)
        result = await c.compute(joined)
    assert_eq(result, expected, check_index=False)


class LimitedGetOrCreateShuffleRunManager(_ShuffleRunManager):
seen: set[ShuffleId]
block_get_or_create: asyncio.Event
Expand Down

0 comments on commit a6ea9f4

Please sign in to comment.