Commit 867a853
Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 5089345
Showing 88 changed files with 4,183 additions and 418 deletions.
@@ -0,0 +1,39 @@
import torch
from .utils import *
import deepspeed.utils as utils

supported_torch_version = False

# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The process-group (PG) API in torch versions older than 1.8 is different, so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if the torch version is 1.8+.

if older_torch():
    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm.
    # NOTE: We can't call torch.distributed directly here. The current workaround is to import the needed functions right before calling them.
    supported_torch_version = False
    from torch.distributed import *

    def get_world_group():
        return group.WORLD

    def get_global_rank(group, group_rank):
        from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)

    def allgather_fn(output_tensor, input_tensor, group, async_op):
        from torch.distributed import all_gather, get_world_size
        from torch import chunk
        output_tensors = list(chunk(output_tensor, get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)

    def reduce_scatter_fn(output_tensor, input_tensor, group):
        from torch.distributed import reduce_scatter, get_world_size
        from torch import chunk
        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
        return reduce_scatter(output_tensor, input_tensor_lst, group=group)

else:
    supported_torch_version = True
    from .comm import *
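As a point of reference, here is a minimal, self-contained sketch of how collectives in the style of the allgather_fn/reduce_scatter_fn fallbacks above are typically driven. The setup (process count, tensor sizes, device selection) is illustrative and not part of this commit; it assumes the script is launched by something that sets the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables.

import torch
import torch.distributed as dist

# Illustrative driver, not part of this commit: requires a launcher that provides
# MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE (e.g. torch.distributed.launch).
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)  # assumes one GPU per rank on a single node

# allgather_fn-style call: the flat output tensor is chunked into one view per
# rank, and all_gather writes each rank's contribution into its view in place.
input_tensor = torch.full((4, ), float(rank), device='cuda')
output_tensor = torch.empty(4 * world_size, device='cuda')
dist.all_gather(list(torch.chunk(output_tensor, world_size)), input_tensor)

# reduce_scatter_fn-style call: the flat input tensor is chunked, and each rank
# receives the elementwise sum of its own chunk.
rs_input = torch.ones(4 * world_size, device='cuda')
rs_output = torch.empty(4, device='cuda')
dist.reduce_scatter(rs_output, list(torch.chunk(rs_input, world_size)))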
@@ -0,0 +1,40 @@
''' DeepSpeed Communication Backend
# DS Backends -- Direct C/Ops
    - NCCL -- future default
    - MPI -- extra performance
    - RCCL -- maybe for AMD
    - GLOO -- N/A -- use via torch
# via torch.distributed
    - T-NCCL -- default -- will work for AMD as well
    - T-GLOO -- choose for cpu/testing without GPUs
    - T-MPI -- works but not commonly used
'''
''' DS Backend can be the base class
    -- NcclBackend, MpiBackend, and TorchBackend are the main subclasses we expect for now
'''


class Backend(object):
    def __init__(self, name='backend', rank=0, size=1):
        self.name = name
        # The world size and rank of the world process group
        self.world_group = None
        self.world_size = size
        self.world_rank = rank
        # Single process group (pg) implementation for now, but keep a list for the future
        self.process_groups = []
        self.initialized = False

    def is_initialized(self):
        return self.initialized

    def new_group(self):
        # create a new pg and add it to the pg list
        pass

    def init_process_group(self):
        # subclasses will initialize this fully:
        # - initialize a default world process group and add it to the pg list
        self.initialized = True
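To show how this base class is meant to be extended, here is a hypothetical minimal subclass that delegates to torch.distributed. The class name, constructor arguments, and init_method default are illustrative assumptions; this is not the TorchBackend implementation added elsewhere in this commit.

import torch.distributed as dist


class TorchBackendSketch(Backend):
    # Hypothetical subclass for illustration only.
    def __init__(self, backend_name='nccl', init_method='env://', rank=0, size=1):
        super().__init__(name='torch', rank=rank, size=size)
        self.backend_name = backend_name
        self.init_method = init_method

    def init_process_group(self):
        # Initialize the default (world) process group via torch.distributed
        # and record it as the first entry in the process-group list.
        if not dist.is_initialized():
            dist.init_process_group(backend=self.backend_name,
                                    init_method=self.init_method)
        self.world_group = dist.group.WORLD
        self.world_size = dist.get_world_size()
        self.world_rank = dist.get_rank()
        self.process_groups.append(self.world_group)
        self.initialized = True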
@@ -0,0 +1,183 @@
import torch
import time
import argparse
import os
import deepspeed

# dist is global so that functions can set it to switch between deepspeed and torch
dist = None

DEBUG = False


def print_rank_0(message):
    if dist.get_rank() == 0:
        print(message)


def collective_fn(collective, input, output, async_op):
    if collective == "alltoall":
        dist.all_to_all_single(output, input, async_op=async_op)
    elif collective == "allreduce":
        dist.all_reduce(input, async_op=async_op)
    else:
        print_rank_0(f"collective {collective} not supported yet")
        exit(0)


def get_bw(collective, size, duration, n):
    tput = 0
    busbw = 0
    if collective == "alltoall":
        tput = (size / duration) * 8
        busbw = (size / duration) * ((n - 1) / n) * 8
    elif collective == "allreduce":
        tput = (size * 2 / duration) * 8
        busbw = (size / duration) * (2 * (n - 1) / n) * 8
    else:
        print_rank_0("wrong collective specified")
        exit(0)
    return tput, busbw
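# Illustrative arithmetic for the formulas above (assumed numbers, not from a real
# run): an allreduce of size = 1e9 bytes taking duration = 0.1 s across n = 8 ranks
# gives tput = (1e9 * 2 / 0.1) * 8 = 160e9 bits/s (160 Gbps) and
# busbw = (1e9 / 0.1) * (2 * (8 - 1) / 8) * 8 = 140e9 bits/s (140 Gbps); the
# (n - 1) / n factor rescales throughput to the bandwidth each link must sustain.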


def timed_benchmark(input, output, args, collective):
    dist.barrier()
    torch.cuda.synchronize()

    # Warmup, establish connections, etc.
    for i in range(args.warmup):
        collective_fn(collective, input, output, async_op=args.async_op)

    dist.barrier()
    torch.cuda.synchronize()

    # Time the actual collective args.trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        collective_fn(collective, input, output, async_op=args.async_op)
    torch.cuda.synchronize()
    duration = time.perf_counter() - pre

    # Maintain and clean performance data
    duration = duration / args.trials
    size = int(input.shape[0]) * 4  # float32 elements are 4 bytes each
    n = dist.get_world_size()
    tput, busbw = get_bw(collective, size, duration, n)

    duration_ms = duration * 1e3
    duration_us = duration * 1e6

    desc = f'{input.shape[0]}x{4}'

    if args.bw_unit == 'Gbps':
        tput = f'{tput / 1e9:.3f}'
        busbw = f'{busbw / 1e9:.3f}'
    elif args.bw_unit == 'GBps':
        tput = f'{tput / 8 / 1e9:.3f}'
        busbw = f'{busbw / 8 / 1e9:.3f}'

    if duration_us < 1e3:
        duration = f'{duration_us:.3f} us'
    else:
        duration = f'{duration_ms:.3f} ms'

    print_rank_0(f"{size:<20} {desc:25s} {duration:20s} {tput:20s} {busbw:20s}")


def test_correctness(input, output, args, collective):
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()

    for i in range(world_size):
        if i == global_rank:
            print(f"Before {collective} input at rank {global_rank}: {input}")
        dist.barrier()

    collective_fn(collective, input, output, async_op=args.async_op)

    torch.cuda.synchronize()
    dist.barrier()

    for i in range(world_size):
        if i == global_rank:
            print(f"{collective} results at rank {global_rank}: {output}")
        dist.barrier()


def init_distributed(backend):
    global dist
    import torch.distributed as dist
    deepspeed.init_distributed(dist_backend=backend)
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)


def init_deepspeed_comm():
    # TODO: Add code to initialize the DS comm backend
    pass


def init_processes(local_rank, args, backend='nccl'):
    if backend == 'deepspeed':
        init_deepspeed_comm()
    elif backend == 'nccl':
        init_distributed(backend)

    N = dist.get_world_size()

    M_LIST = []
    for x in (2**p for p in range(1, args.maxsize)):
        M_LIST.append(x)

    # List of benchmarks
    collectives = ['alltoall', 'allreduce']

    # Run all collectives
    for collective in collectives:
        world_size = dist.get_world_size()
        dist.barrier()

        # Prepare benchmark header
        tput = f'Throughput ({args.bw_unit})'
        busbw = f'BusBW ({args.bw_unit})'

        header = f"\n---- Performance of {collective} on {dist.get_world_size()} devices ---------------------------------------------------------\n"
        header += f"{'Size (Bytes)':20s} {'Description':25s} {'Duration':20s} {tput:20s} {busbw:20s}\n"
        header += "----------------------------------------------------------------------------------------------------"

        print_rank_0(header)

        # Loop over various tensor sizes for each collective
        for M in M_LIST:
            global_rank = dist.get_rank()
            mat = torch.ones(N, M, dtype=torch.float32).cuda(local_rank)
            torch.cuda.synchronize()

            if collective == 'alltoall':
                # Check needed for alltoall only
                assert mat.numel() % world_size == 0, f"tensor cannot be divided in {world_size} chunks"

            input = ((mat.mul_(float(global_rank))).view(-1))
            output = (mat.clone().view(-1))

            timed_benchmark(input, output, args, collective)

            global DEBUG
            if DEBUG:
                test_correctness(input, output, args, collective)

        dist.barrier()
        print_rank_0("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    parser.add_argument("--trials", type=int, default=5)
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--maxsize", type=int, default=24)
    parser.add_argument("--async-op", action="store_true")
    parser.add_argument("--bw-unit", type=str, default='Gbps')
    args = parser.parse_args()
    rank = args.local_rank
    init_processes(local_rank=rank, args=args)
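For reference, a script like this is normally run under a distributed launcher so that LOCAL_RANK and --local_rank are provided, e.g. `deepspeed --num_gpus=8 <path to this script> --trials 10 --warmup 5 --bw-unit GBps`. The exact script path is not shown in this diff and the launcher choice is an assumption; any launcher that exports LOCAL_RANK and passes --local_rank (which init_distributed and the argument parser above expect) would work the same way.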