Provide API to lock SHArP tree for distributed adam within nodes.
alpha0422 committed Apr 25, 2024
1 parent 394f401 commit d00598e
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions nemo/core/optim/distributed_adam.py
@@ -35,8 +35,61 @@
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8

from nemo.utils import str_to_dtype
from nemo.utils import logging
from nemo.utils.te_utils import is_float8tensor

_distribute_within_nodes_pgs = {}

def create_distribute_within_nodes_pgs():
"""Create process groups for distributing with nodes.
User can reuse this function to reorder communicators for SHArP.
"""
global _distribute_within_nodes_pgs
assert torch.distributed.is_initialized()
if _distribute_within_nodes_pgs:
return _distribute_within_nodes_pgs

world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices

if nodes * devices != world_size:
logging.warning("Expected all nodes have the same amout of devices, disable distribute_within_nodes.")
return {}

node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)
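# For example, with nodes=2 and devices=4 (world_size=8):
#   distributed (intra-node) groups: [0, 1, 2, 3] and [4, 5, 6, 7]
#   redundant (inter-node) groups: [0, 4], [1, 5], [2, 6], [3, 7]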

# To re-order the SHArP communicator right after distributed init,
# we have to expose redundant_process_group to the user.
# The user has to invoke an allreduce through redundant_process_group
# before all other communicators to lock the SHArP tree.
_distribute_within_nodes_pgs = {
'world_size': world_size,
'rank': rank,
'devices': devices,
'nodes': nodes,
'node_id': node_id,
'device_id': device_id,
'distributed_process_group': distributed_pgs[node_id],
'redundant_process_group': redundant_pgs[device_id],
}
return _distribute_within_nodes_pgs
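# A minimal usage sketch (assuming torch.distributed is initialized with the
# NCCL backend and one GPU per rank): issue the first allreduce over
# redundant_process_group before any other communicator so the SHArP tree is
# locked to it, e.g.
#
#     pgs = create_distribute_within_nodes_pgs()
#     if pgs:
#         buf = torch.ones(1, device='cuda')
#         torch.distributed.all_reduce(buf, group=pgs['redundant_process_group'])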


class MegatronDistributedFusedAdam(DistributedFusedAdam):
"""Adam optimizer with ZeRO algorithm
@@ -78,27 +78,12 @@ def __init__(
kwargs['distributed_process_group'] = self_groups[rank]
kwargs['redundant_process_group'] = kwargs['process_group']
elif distribute_within_nodes:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
devices = torch.cuda.device_count()
nodes = world_size // devices
assert nodes * devices == world_size, "Expected all nodes to have the same number of devices."
node_id = rank // devices
device_id = rank % devices

distributed_pgs = []
for i in range(nodes):
ranks = [i * devices + j for j in range(devices)]
pg = torch.distributed.new_group(ranks=ranks)
distributed_pgs.append(pg)
kwargs['distributed_process_group'] = distributed_pgs[node_id]

redundant_pgs = []
for i in range(devices):
ranks = [i + j * devices for j in range(nodes)]
pg = torch.distributed.new_group(ranks=ranks)
redundant_pgs.append(pg)
kwargs['redundant_process_group'] = redundant_pgs[device_id]
dist_pg_infos = create_distribute_within_nodes_pgs()
if dist_pg_infos:
kwargs['distributed_process_group'] = dist_pg_infos['distributed_process_group']
kwargs['redundant_process_group'] = dist_pg_infos['redundant_process_group']
global _distribute_within_nodes_pgs
_distribute_within_nodes_pgs = {}

# Make sure dtypes are of the right type
for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):

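For reference, a minimal end-to-end sketch of the intended call sequence (assumptions: NCCL backend, one GPU per rank, and that MegatronDistributedFusedAdam accepts standard Adam keyword arguments such as lr; the Linear model is only a placeholder):

import torch
import torch.distributed

from nemo.core.optim.distributed_adam import (
    MegatronDistributedFusedAdam,
    create_distribute_within_nodes_pgs,
)

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# Create the intra-/inter-node groups first, then lock the SHArP tree by
# issuing the first allreduce over the redundant (inter-node) process group.
pgs = create_distribute_within_nodes_pgs()
if pgs:
    buf = torch.ones(1, device="cuda")
    torch.distributed.all_reduce(buf, group=pgs["redundant_process_group"])

# The optimizer's distribute_within_nodes branch reuses the cached groups
# created above and then clears the module-level cache.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = MegatronDistributedFusedAdam(
    model.parameters(), lr=1e-4, distribute_within_nodes=True,
)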