Skip to content

Commit

Permalink
number_distributed_tags: non-set, non-sorted numbering
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Nov 14, 2023
1 parent a8f380f commit d0adcda
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pytato/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
.. class:: CommTagType
A type representing a communication tag. Communication tags must be
hashable and totally ordered (and hence comparable).
hashable.
.. class:: ShapeType
Expand Down
46 changes: 13 additions & 33 deletions pytato/distributed/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""


from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar
from typing import TYPE_CHECKING, Tuple, TypeVar

from pytato.distributed.partition import DistributedGraphPartition

Expand Down Expand Up @@ -62,53 +62,33 @@ def number_distributed_tags(
This is a potentially heavyweight MPI-collective operation on
*mpi_communicator*.
.. note::
This function requires that symbolic tags are comparable.
"""
tags = frozenset({
from pytools import flatten

tags = tuple([
recv.comm_tag
for part in partition.parts.values()
for recv in part.name_to_recv_node.values()
} | {
] + [
send.comm_tag
for part in partition.parts.values()
for sends in part.name_to_send_nodes.values()
for send in sends})

from mpi4py import MPI

def set_union(
set_a: FrozenSet[T], set_b: FrozenSet[T],
mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]:
assert mpi_data_type is None
assert isinstance(set_a, frozenset)
assert isinstance(set_b, frozenset)

return set_a | set_b
for send in sends])

root_rank = 0

set_union_mpi_op = MPI.Op.Create(
# type ignore reason: mpi4py misdeclares op functions as returning
# None.
set_union, # type: ignore[arg-type]
commute=True)
try:
all_tags = mpi_communicator.reduce(
tags, set_union_mpi_op, root=root_rank)
finally:
set_union_mpi_op.Free()
all_tags = mpi_communicator.gather(tags, root=root_rank)

if mpi_communicator.rank == root_rank:
sym_tag_to_int_tag = {}
next_tag = base_tag
assert isinstance(all_tags, frozenset)
assert isinstance(all_tags, list)
assert len(all_tags) == mpi_communicator.size

for sym_tag in sorted(all_tags):
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1
for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call]
if sym_tag not in sym_tag_to_int_tag:
sym_tag_to_int_tag[sym_tag] = next_tag
next_tag += 1

mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank)
else:
Expand Down
10 changes: 8 additions & 2 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory):
ntests = 10
for i in range(ntests):
seed = 120 + i
print(f"Step {i} {seed}")
print(f"Step {i} {seed=}")

# {{{ compute value with communication

Expand All @@ -278,7 +278,13 @@ def gen_comm(rdagc):

nonlocal comm_tag
comm_tag += 1
tag = (comm_tag, _RandomDAGTag) # noqa: B023

if comm_tag % 5 == 1:
tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag]))
elif comm_tag % 5 == 2:
tag = (comm_tag, (_RandomDAGTag,))
else:
tag = (comm_tag, _RandomDAGTag) # noqa: B023

inner = make_random_dag(rdagc)
return pt.staple_distributed_send(
Expand Down

0 comments on commit d0adcda

Please sign in to comment.