This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Move .cuda back to main functions
Zhijian Liu committed Oct 13, 2021
1 parent c0e2bcf commit 03d7cae
Showing 1 changed file with 6 additions and 6 deletions.
torchpack/distributed/comm.py: 12 changes (6 additions & 6 deletions)
@@ -12,12 +12,12 @@
 def _serialize(obj: Any) -> torch.Tensor:
     buffer = pickle.dumps(obj)
     storage = torch.ByteStorage.from_buffer(buffer)
-    tensor = torch.ByteTensor(storage).cuda()
+    tensor = torch.ByteTensor(storage)
     return tensor
 
 
 def _deserialize(tensor: torch.Tensor, size: Optional[int] = None) -> Any:
-    buffer = tensor.cpu().numpy().tobytes()
+    buffer = tensor.numpy().tobytes()
     if size is not None:
         buffer = buffer[:size]
     obj = pickle.loads(buffer)
@@ -31,7 +31,7 @@ def broadcast(obj: Any, src: int = 0) -> Any:
 
     # serialize
     if context.rank() == src:
-        tensor = _serialize(obj)
+        tensor = _serialize(obj).cuda()
 
     # broadcast the tensor size
     if context.rank() == src:
@@ -47,7 +47,7 @@ def broadcast(obj: Any, src: int = 0) -> Any:
 
     # deserialize
     if context.rank() != src:
-        obj = _deserialize(tensor)
+        obj = _deserialize(tensor.cpu())
     return obj
 
 
@@ -57,7 +57,7 @@ def allgather(obj: Any) -> List[Any]:
         return [obj]
 
     # serialize
-    tensor = _serialize(obj)
+    tensor = _serialize(obj).cuda()
 
     # gather the tensor size
     local_size = torch.LongTensor([tensor.numel()]).cuda()
@@ -76,7 +76,7 @@
 
     # deserialize
     objs = []
     for size, tensor in zip(sizes, tensors):
-        obj = _deserialize(tensor, size=size)
+        obj = _deserialize(tensor.cpu(), size=size)
         objs.append(obj)
     return objs

