This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Add dist.broadcast()
Zhijian Liu committed Oct 13, 2021
1 parent 2873f68 commit c0e2bcf
Showing 1 changed file with 60 additions and 19 deletions.
79 changes: 60 additions & 19 deletions torchpack/distributed/comm.py
@@ -1,51 +1,92 @@
 import pickle
-from typing import Any, List
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed
 
 from . import context
 
-__all__ = ['allreduce', 'allgather', 'barrier']
+__all__ = ['broadcast', 'allgather', 'allreduce', 'barrier']
 
 
-def allreduce(data: Any, reduction: str = 'sum') -> Any:
-    data = allgather(data)
-    if reduction == 'sum':
-        return sum(data)
+def _serialize(obj: Any) -> torch.Tensor:
+    buffer = pickle.dumps(obj)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).cuda()
+    return tensor
+
+
+def _deserialize(tensor: torch.Tensor, size: Optional[int] = None) -> Any:
+    buffer = tensor.cpu().numpy().tobytes()
+    if size is not None:
+        buffer = buffer[:size]
+    obj = pickle.loads(buffer)
+    return obj
+
+
+def broadcast(obj: Any, src: int = 0) -> Any:
+    world_size = context.size()
+    if world_size == 1:
+        return obj
+
+    # serialize
+    if context.rank() == src:
+        tensor = _serialize(obj)
+
+    # broadcast the tensor size
+    if context.rank() == src:
+        size = torch.LongTensor([tensor.numel()]).cuda()
     else:
-        raise NotImplementedError(reduction)
+        size = torch.LongTensor([0]).cuda()
+    torch.distributed.broadcast(size, src=src)
+
+    # broadcast the tensor
+    if context.rank() != src:
+        tensor = torch.ByteTensor(size=(size.item(),)).cuda()
+    torch.distributed.broadcast(tensor, src=src)
+
+    # deserialize
+    if context.rank() != src:
+        obj = _deserialize(tensor)
+    return obj
 
 
-def allgather(data: Any) -> List[Any]:
+def allgather(obj: Any) -> List[Any]:
     world_size = context.size()
     if world_size == 1:
-        return [data]
+        return [obj]
 
-    # serialized to a tensor
-    buffer = pickle.dumps(data)
-    storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).cuda()
+    # serialize
+    tensor = _serialize(obj)
 
-    # obtain tensor size of each rank
+    # gather the tensor size
     local_size = torch.LongTensor([tensor.numel()]).cuda()
     sizes = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
     torch.distributed.all_gather(sizes, local_size)
     sizes = [int(size.item()) for size in sizes]
     max_size = max(sizes)
 
-    # receiving tensors from all ranks
+    # gather the tensor
     tensors = [torch.ByteTensor(size=(max_size,)).cuda() for _ in sizes]
     if local_size != max_size:
         padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
         tensor = torch.cat((tensor, padding), dim=0)
     torch.distributed.all_gather(tensors, tensor)
 
-    data = []
+    # deserialize
+    objs = []
     for size, tensor in zip(sizes, tensors):
-        buffer = tensor.cpu().numpy().tobytes()[:size]
-        data.append(pickle.loads(buffer))
-    return data
+        obj = _deserialize(tensor, size=size)
+        objs.append(obj)
+    return objs
+
+
+def allreduce(obj: Any, reduction: str = 'sum') -> Any:
+    objs = allgather(obj)
+    if reduction == 'sum':
+        return sum(objs)
+    else:
+        raise NotImplementedError(reduction)
 
 
 def barrier() -> None:
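
A minimal usage sketch of the new API (not part of this commit), assuming one process per GPU with the default torch.distributed process group already initialized by the launcher; the import of `comm` and `context` follows the file path and the `from . import context` line above, while the attribute names used on `context` are only those the diff itself calls.

# Hypothetical example, not from the commit: run with one CUDA process per rank.
from torchpack.distributed import comm, context

# Rank `src` holds an arbitrary picklable object; every rank receives a copy.
config = {'lr': 0.1, 'epochs': 90} if context.rank() == 0 else None
config = comm.broadcast(config, src=0)

# Each rank contributes one value; allgather returns the per-rank list and
# allreduce sums it ('sum' is the only reduction implemented here).
local_metric = 0.5 * context.rank()
all_metrics = comm.allgather(local_metric)
total_metric = comm.allreduce(local_metric, reduction='sum')

Because the payload is pickled into a ByteTensor of unknown length, broadcast() first broadcasts the tensor's size so non-source ranks can allocate a matching receive buffer, analogous to the pad-to-max-size step allgather() already used.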