Skip to content

Commit

Permalink
Merge branch 'main' into PytatoKeyBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Feb 6, 2024
2 parents e680934 + dc1421b commit 6238242
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
4 changes: 3 additions & 1 deletion pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv,
DistributedRecv, DistributedSend,
DistributedSendRefHolder,
make_distributed_send_ref_holder,
staple_distributed_send)
from pytato.distributed.partition import (
find_distributed_partition, DistributedGraphPart, DistributedGraphPartition)
Expand Down Expand Up @@ -161,7 +162,8 @@ def set_debug_enabled(flag: bool) -> None:
"trace_call",

"make_distributed_recv", "make_distributed_send", "DistributedRecv",
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",
"DistributedSend", "make_distributed_send_ref_holder",
"staple_distributed_send", "DistributedSendRefHolder",

"DistributedGraphPart",
"DistributedGraphPartition",
Expand Down
34 changes: 22 additions & 12 deletions pytato/distributed/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
These functions aid in creating communication nodes:
.. autofunction:: make_distributed_send
.. autofunction:: make_distributed_send_ref_holder
.. autofunction:: staple_distributed_send
.. autofunction:: make_distributed_recv
For completeness, individual (non-held/"stapled") :class:`DistributedSend` nodes
can be made via this function:
.. autofunction:: make_distributed_send
Redirections for the documentation tool
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -64,7 +61,7 @@

from pytato.array import (
Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT,
_get_default_axes, ConvertibleToShape, normalize_shape)
_get_default_axes, _get_default_tags, ConvertibleToShape, normalize_shape)

CommTagType = Hashable

Expand Down Expand Up @@ -223,7 +220,18 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp
DistributedSend:
"""Make a :class:`DistributedSend` object."""
return DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag,
tags=send_tags)
tags=(send_tags | _get_default_tags()))


def make_distributed_send_ref_holder(
send: DistributedSend,
passthrough_data: Array,
tags: FrozenSet[Tag] = frozenset()
) -> DistributedSendRefHolder:
"""Make a :class:`DistributedSendRefHolder` object."""
return DistributedSendRefHolder(
send=send, passthrough_data=passthrough_data,
tags=(tags | _get_default_tags()))


def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType,
Expand All @@ -233,10 +241,12 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT
DistributedSendRefHolder:
"""Make a :class:`DistributedSend` object wrapped in a
:class:`DistributedSendRefHolder` object."""
return DistributedSendRefHolder(
send=DistributedSend(data=sent_data, dest_rank=dest_rank,
comm_tag=comm_tag, tags=send_tags),
passthrough_data=stapled_to, tags=ref_holder_tags)
return make_distributed_send_ref_holder(
send=make_distributed_send(
sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag,
send_tags=send_tags),
passthrough_data=stapled_to,
tags=ref_holder_tags)


def make_distributed_recv(src_rank: int, comm_tag: CommTagType,
Expand All @@ -253,7 +263,7 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType,
dtype = np.dtype(dtype)
return DistributedRecv(
src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype,
tags=tags, axes=axes)
axes=axes, tags=(tags | _get_default_tags()))

# }}}

Expand Down

0 comments on commit 6238242

Please sign in to comment.