[core][compiled graphs] Add CPU-based NCCL communicator for development #48440
Merged

Commits (49):
df0f227
Initial work
tfsingh b043233
Undo sorting
tfsingh b7c01c2
used list of ranks as barrier_key instead of type of operation + numb…
anyadontfly 00f0f4e
Initial response to review
tfsingh 2528f86
Merge branch 'master' into py-ts/cpu-nccl
tfsingh fc3e810
Condense code
tfsingh 1da7d1d
Add todo
tfsingh fa00f9d
added test for CPUCommunicator, changed communicator key for allreduc…
anyadontfly 095c307
Fix (some) lint errors
tfsingh d2bffa5
Changes during meeting
tfsingh 04707ed
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 26b387b
Tests passing, but incorrect
tfsingh e12b500
Add working test
tfsingh df16818
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 5e12e22
Reset conftest.py
tfsingh 0d7b2a4
Add newline
tfsingh 932c4e5
added allreduce test, used Actor.options(get_if_exists=True) to repla…
anyadontfly ddcdaf8
minor changes
anyadontfly 5c9a67f
Remove time
tfsingh dfa8257
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 4f99dd2
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 6c4aed5
used a slightly modified start_mock_nccl() to satisfy dependency requi…
49b57e2
Merge branch 'master' into py-ts/cpu-nccl
tfsingh dd3e8e4
Small changes
tfsingh c3d0f88
lint changes
anyadontfly f3baf27
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 9ff534b
Merge branch 'master' into py-ts/cpu-nccl
tfsingh ed3b27f
Respond to review
tfsingh 289eb6c
no longer depend on mock nccl, test on wrong shape not passing
anyadontfly affb790
change name from cpu_nccl_group to cpu_communicator
anyadontfly ec7eba5
Fix allreduce test
tfsingh 9a3569c
Remove p2p ops on CPUCommunicator
tfsingh 9de82d3
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 857c2de
reformat code
anyadontfly c6f47d0
Move import inside class
tfsingh 81fddec
Unify codepaths
tfsingh c211427
Rename file
tfsingh fbf1388
Small fixes
tfsingh e11cfe7
clean up, lint and format
anyadontfly ca6fda8
added get_device_type func for custom nccl groups in tests, CPUCommun…
anyadontfly 944f522
Merge branch 'master' into py-ts/cpu-nccl
tfsingh 5dada98
Swap torch
tfsingh 3b614dd
lint fix
anyadontfly fa7dc1e
Merge branch 'master' into py-ts/cpu-nccl
tfsingh a401eff
rename nccl_group to communicator
anyadontfly d7fca94
Merge branch 'master' into py-ts/cpu-nccl
tfsingh c30e02a
small fix
anyadontfly 42a7171
small fix
anyadontfly 89e3f43
Merge branch 'master' into py-ts/cpu-nccl
tfsingh
@@ -0,0 +1,157 @@
import asyncio
from collections import defaultdict
from typing import Optional, Tuple, List
from unittest import mock

import torch

import ray
import ray.experimental.channel as ray_channel
from ray.experimental.channel.gpu_communicator import TorchTensorAllocator, ReduceOp


@ray.remote(num_cpus=0)
class Barrier:
    """
    Barrier that blocks the given number of actors until all actors have
    reached the barrier. This is used to mock out blocking NCCL ops.
    """

    def __init__(self, num_actors=2):
        self.num_actors = num_actors
        self.condition = asyncio.Condition()
        # Buffer for the data that is "sent" between the actors, each entry is
        # one p2p op.
        self.data = {}
        # Buffer for the tensors gathered for each collective op.
        self.collective_data = defaultdict(list)
        # Buffer for the number of actors seen, each entry is one op.
        self.num_actors_seen = defaultdict(int)

    async def wait(self, op_id: int, data=None):
        """
        Wait at barrier until all actors have sent `op_id`. One actor should
        provide `data`, and this value will be returned by this method for all
        other actors.
        """
        async with self.condition:
            if data is not None:
                assert op_id not in self.data, (self.data, self.num_actors_seen)
                self.data[op_id] = data
            self.num_actors_seen[op_id] += 1

            if self.num_actors_seen[op_id] == self.num_actors:
                # Wake up all tasks waiting on this condition.
                self.condition.notify_all()
            else:
                await self.condition.wait_for(
                    lambda: self.num_actors_seen[op_id] == self.num_actors
                )

            if data is None:
                data = self.data[op_id]

            return data

    async def wait_collective(self, op_id: int, data: torch.Tensor, op: ReduceOp):
        """
        Wait at the barrier until all actors have sent `op_id` and `data`.
        Once data from all actors is received, execute the collective `op`
        on the barrier actor and return the result.
        """
        async with self.condition:
            self.collective_data[op_id].append(data)
            self.num_actors_seen[op_id] += 1

            if self.num_actors_seen[op_id] == self.num_actors:
                # Apply the collective operation across all gathered tensors.
                result = self._apply_op(op, self.collective_data[op_id])
                self.collective_data[op_id] = result
                self.condition.notify_all()
            else:
                await self.condition.wait_for(
                    lambda: self.num_actors_seen[op_id] == self.num_actors
                )

            # Return the result to all actors.
            return self.collective_data[op_id]

    def _apply_op(self, op: ReduceOp, tensors: List[torch.Tensor]) -> torch.Tensor:
        """Apply the specified reduction operation across a list of tensors."""
        result = tensors[0].clone()
        if op == ReduceOp.SUM:
            for tensor in tensors[1:]:
                result += tensor
        elif op == ReduceOp.PRODUCT:
            for tensor in tensors[1:]:
                result *= tensor
        elif op == ReduceOp.MAX:
            for tensor in tensors[1:]:
                result = torch.max(result, tensor)
        elif op == ReduceOp.MIN:
            for tensor in tensors[1:]:
                result = torch.min(result, tensor)
        elif op == ReduceOp.AVG:
            result = sum(tensors) / len(tensors)
        else:
            # Reserve a place for future ops to be added.
            assert False, "current operation not supported"

        return result


class CPUNcclGroup(ray_channel.nccl_group._NcclGroup):
    """
    Mock the internal _NcclGroup to use a barrier actor instead of a NCCL group
    for communication.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # We use the op index to synchronize the sender and receiver at the
        # barrier.
        self.num_ops = defaultdict(int)
        self.barriers = set()

    def send(self, tensor: torch.Tensor, peer_rank: int):
        # "Send" the tensor to the barrier actor.
        barrier_key = f"barrier-{self.get_self_rank()}-{peer_rank}"
        barrier = ray.get_actor(name=barrier_key)
        self.barriers.add(barrier)
        ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor))
        self.num_ops[barrier_key] += 1

    def recv(
        self,
        shape: Tuple[int],
        dtype: torch.dtype,
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ):
        # "Receive" the tensor from the barrier actor.
        barrier_key = f"barrier-{peer_rank}-{self.get_self_rank()}"
        barrier = ray.get_actor(name=barrier_key)
        self.barriers.add(barrier)
        received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key]))
        assert (
            allocator is not None
        ), "torch tensor allocator is required for CPUNcclGroup"
        buf = allocator(shape, dtype)
        buf[:] = received_tensor[:]
        self.num_ops[barrier_key] += 1
        return buf

    def allreduce(
        self,
        send_buf: torch.Tensor,
        recv_buf: torch.Tensor,
        op: ReduceOp = ReduceOp.SUM,
    ):
        # Different collective ops can share the same barrier as long as the
        # participants are the same. This assumes every rank in the group
        # (0 through world_size - 1) participates in the collective.
        all_ranks = range(self.get_world_size())
        barrier_key = "barrier-" + "-".join(map(str, all_ranks))
        barrier = ray.get_actor(name=barrier_key)
        self.barriers.add(barrier)

        result = ray.get(
            barrier.wait_collective.remote(self.num_ops[barrier_key], send_buf, op)
        )

        assert recv_buf is not None, "Receiving buffer required for CPUNcclGroup"
        recv_buf[:] = result[:]
        self.num_ops[barrier_key] += 1

    def destroy(self) -> None:
        for barrier in self.barriers:
            ray.kill(barrier)
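For orientation, the sketch below shows how the Barrier actor above can stand in for a blocking point-to-point exchange, roughly what CPUNcclGroup.send and recv do under the hood. The MockSender and MockReceiver actors, the placeholder import, and the ray.init()/ray.kill() bookkeeping are illustrative assumptions, not part of this PR; the barrier name follows the barrier-{sender_rank}-{receiver_rank} convention used above.

import ray
import torch

# Placeholder import: Barrier is the actor class defined in the file above.
# The module path is not shown in this diff.
# from <this_module> import Barrier


@ray.remote
class MockSender:
    def send(self, barrier_name: str, tensor: torch.Tensor):
        barrier = ray.get_actor(name=barrier_name)
        # The sender supplies the data for op 0 and blocks until the
        # receiver has also reached the barrier.
        ray.get(barrier.wait.remote(0, tensor))


@ray.remote
class MockReceiver:
    def recv(self, barrier_name: str):
        barrier = ray.get_actor(name=barrier_name)
        # No data is supplied, so wait() returns the tensor "sent" above.
        return ray.get(barrier.wait.remote(0))


ray.init()
# The name matches the barrier-{sender_rank}-{receiver_rank} convention.
barrier = Barrier.options(name="barrier-0-1").remote(num_actors=2)
sender, receiver = MockSender.remote(), MockReceiver.remote()
sender.send.remote("barrier-0-1", torch.ones(3))
assert torch.equal(ray.get(receiver.recv.remote("barrier-0-1")), torch.ones(3))
ray.kill(barrier)

The collective path works the same way, except that every participating rank calls wait_collective with the same op index on a barrier named after the full list of ranks, and the barrier applies _apply_op once all tensors have arrived.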
We’re also noticing some build issues because this dependency isn’t found — is there another place we need to declare this file relies on torch?