diff --git a/torchpack/distributed/comm.py b/torchpack/distributed/comm.py index 43c523f..c5a1820 100644 --- a/torchpack/distributed/comm.py +++ b/torchpack/distributed/comm.py @@ -1,51 +1,92 @@ import pickle -from typing import Any, List +from typing import Any, List, Optional import torch import torch.distributed from . import context -__all__ = ['allreduce', 'allgather', 'barrier'] +__all__ = ['broadcast', 'allgather', 'allreduce', 'barrier'] -def allreduce(data: Any, reduction: str = 'sum') -> Any: - data = allgather(data) - if reduction == 'sum': - return sum(data) +def _serialize(obj: Any) -> torch.Tensor: + buffer = pickle.dumps(obj) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).cuda() + return tensor + + +def _deserialize(tensor: torch.Tensor, size: Optional[int] = None) -> Any: + buffer = tensor.cpu().numpy().tobytes() + if size is not None: + buffer = buffer[:size] + obj = pickle.loads(buffer) + return obj + + +def broadcast(obj: Any, src: int = 0) -> Any: + world_size = context.size() + if world_size == 1: + return obj + + # serialize + if context.rank() == src: + tensor = _serialize(obj) + + # broadcast the tensor size + if context.rank() == src: + size = torch.LongTensor([tensor.numel()]).cuda() else: - raise NotImplementedError(reduction) + size = torch.LongTensor([0]).cuda() + torch.distributed.broadcast(size, src=src) + + # broadcast the tensor + if context.rank() != src: + tensor = torch.ByteTensor(size=(size.item(),)).cuda() + torch.distributed.broadcast(tensor, src=src) + + # deserialize + if context.rank() != src: + obj = _deserialize(tensor) + return obj -def allgather(data: Any) -> List[Any]: +def allgather(obj: Any) -> List[Any]: world_size = context.size() if world_size == 1: - return [data] + return [obj] - # serialized to a tensor - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).cuda() + # serialize + tensor = _serialize(obj) - # obtain tensor size of each rank + # gather the tensor size local_size = torch.LongTensor([tensor.numel()]).cuda() sizes = [torch.LongTensor([0]).cuda() for _ in range(world_size)] torch.distributed.all_gather(sizes, local_size) sizes = [int(size.item()) for size in sizes] max_size = max(sizes) - # receiving tensors from all ranks + # gather the tensor tensors = [torch.ByteTensor(size=(max_size,)).cuda() for _ in sizes] if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).cuda() tensor = torch.cat((tensor, padding), dim=0) torch.distributed.all_gather(tensors, tensor) - data = [] + # deserialize + objs = [] for size, tensor in zip(sizes, tensors): - buffer = tensor.cpu().numpy().tobytes()[:size] - data.append(pickle.loads(buffer)) - return data + obj = _deserialize(tensor, size=size) + objs.append(obj) + return objs + + +def allreduce(obj: Any, reduction: str = 'sum') -> Any: + objs = allgather(obj) + if reduction == 'sum': + return sum(objs) + else: + raise NotImplementedError(reduction) def barrier() -> None: