enable cpu/xpu support for the benchmarking suite
louie-tsai committed Jun 11, 2024
1 parent 75df1d7 commit 9a97faf
Showing 9 changed files with 118 additions and 43 deletions.
14 changes: 13 additions & 1 deletion benchmarks/communication/README.md
@@ -1,6 +1,6 @@
# The DeepSpeed Communication Benchmarking Suite

The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) and [NCCL Tests](https://github.com/NVIDIA/nccl-tests) in that users can:
The intent of these benchmarks is to measure communication latency/bw of deepspeed and/or pytorch distributed communication operations at the Python layer. These benchmarks are complementary to C-level comms benchmarks like [OSU Micro-Benchmarks](https://mvapich.cse.ohio-state.edu/benchmarks/) , [NCCL Tests](https://github.com/NVIDIA/nccl-tests) and [oneCCL Benchmark](https://oneapi-src.github.io/oneCCL/benchmark.html) in that users can:
- Easily debug which layer of the communication software stack a hang or performance degradation originates from.
- Measure the expected communication performance of either DeepSpeed comms or pure PyTorch distributed.

@@ -77,6 +77,18 @@ Finally, users can choose specific communication operations to run in `run_all.p
deepspeed run_all.py --scan --all-reduce --all-to-all --broadcast
</pre>

## CPU and Other Accelerator Support
These benchmarks also support other devices, such as Intel CPUs and GPUs, via oneCCL.
To run any of the Python scripts on an Intel CPU, append the argument `--device cpu`.
For example, to run with a single large message size on an Intel CPU:
<pre>
deepspeed all_reduce.py --device cpu
</pre>

Similarly, to run with a single large message size on an Intel GPU:
<pre>
deepspeed all_reduce.py --device xpu
</pre>
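
The same flag applies to the combined runner; for example (a usage sketch, assuming `run_all.py` forwards `--device` to each benchmark exactly like the individual scripts):
<pre>
deepspeed run_all.py --scan --device cpu
</pre>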

# Adding Communication Benchmarks

25 changes: 18 additions & 7 deletions benchmarks/communication/all_gather.py
@@ -16,7 +16,10 @@


# Run all_gather and print metrics
def timed_all_gather(input, output, start_event, end_event, args):
def timed_all_gather(input, output, start_event, end_event, args, device):
if device == "cpu":
print_rank_0("No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist

@@ -53,7 +56,7 @@ def timed_all_gather(input, output, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_gather(local_rank, args):
def run_all_gather(local_rank, args, device):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -64,8 +67,15 @@ def run_all_gather(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
@@ -96,7 +106,7 @@ def run_all_gather(local_rank, args):
else:
raise e
sync_all()
timed_all_gather(input, output, start_event, end_event, args)
timed_all_gather(input, output, start_event, end_event, args, device)
else:
# all_gather_into_tensor saves memory
if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
@@ -130,11 +140,12 @@ def run_all_gather(local_rank, args):
raise e

sync_all()
timed_all_gather(input, output, start_event, end_event, args)
timed_all_gather(input, output, start_event, end_event, args, device)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
device = args.device
init_processes(local_rank=rank, args=args)
run_all_gather(local_rank=rank, args=args)
run_all_gather(local_rank=rank, args=args, device=device)
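
Each benchmark script in this commit repeats the same per-device branch when creating timing events. A possible consolidation (a sketch only, not what this commit does) would route event creation through DeepSpeed's accelerator abstraction, assuming `get_accelerator().Event` maps to the backend's event class as it does for the CUDA and XPU accelerators:
<pre>
# Sketch only: device-agnostic event creation via DeepSpeed's accelerator API.
# Assumes get_accelerator().Event accepts enable_timing on CUDA/XPU backends;
# `device` is the string parsed from the new --device argument.
from deepspeed.accelerator import get_accelerator

if device == "cpu":
    # Events are unused on CPU: the timed_* functions above return early there.
    start_event, end_event = None, None
else:
    start_event = get_accelerator().Event(enable_timing=True)
    end_event = get_accelerator().Event(enable_timing=True)
</pre>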
25 changes: 18 additions & 7 deletions benchmarks/communication/all_reduce.py
@@ -14,7 +14,10 @@
from deepspeed.accelerator import get_accelerator


def timed_all_reduce(input, start_event, end_event, args):
def timed_all_reduce(input, start_event, end_event, args, device):
if device == "cpu":
print_rank_0("No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -48,7 +51,7 @@ def timed_all_reduce(input, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_reduce(local_rank, args):
def run_all_reduce(local_rank, args, device):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_all_reduce(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
@@ -86,7 +96,7 @@ def run_all_reduce(local_rank, args):
else:
raise e
sync_all()
timed_all_reduce(input, start_event, end_event, args)
timed_all_reduce(input, start_event, end_event, args, device)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
@@ -108,11 +118,12 @@ def run_all_reduce(local_rank, args):
else:
raise e
sync_all()
timed_all_reduce(input, start_event, end_event, args)
timed_all_reduce(input, start_event, end_event, args, device)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
device = args.device
init_processes(local_rank=rank, args=args)
run_all_reduce(local_rank=rank, args=args)
run_all_reduce(local_rank=rank, args=args, device=device)
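
On CPU the commit currently skips timing altogether ("No Event support on CPU to measure time for now"). A hedged alternative is to measure wall-clock time around the collective instead of using device events; the sketch below mirrors the structure of `timed_all_reduce`, reuses `sync_all()` from the suite's utilities, and assumes the standard `args.warmups`, `args.trials`, and `args.async_op` options:
<pre>
# Sketch only (not part of this commit): wall-clock timing for CPU runs.
import time
import torch.distributed as dist

def timed_all_reduce_cpu(input, args):
    sync_all()
    for _ in range(args.warmups):              # warm-up iterations, not timed
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()
    start = time.perf_counter()
    for _ in range(args.trials):               # timed iterations
        dist.all_reduce(input, async_op=args.async_op)
    sync_all()                                 # barrier before reading the clock
    return (time.perf_counter() - start) / args.trials
</pre>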
25 changes: 18 additions & 7 deletions benchmarks/communication/all_to_all.py
@@ -14,7 +14,10 @@
from deepspeed.accelerator import get_accelerator


def timed_all_to_all(input, output, start_event, end_event, args):
def timed_all_to_all(input, output, start_event, end_event, args, device):
if device == "cpu":
print_rank_0("No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -48,7 +51,7 @@ def timed_all_to_all(input, output, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_all_to_all(local_rank, args):
def run_all_to_all(local_rank, args, device):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -59,8 +62,15 @@ def run_all_to_all(local_rank, args):
# Prepare benchmark header
print_header(args, 'all_to_all')

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
@@ -87,7 +97,7 @@ def run_all_to_all(local_rank, args):
else:
raise e
sync_all()
timed_all_to_all(input, output, start_event, end_event, args)
timed_all_to_all(input, output, start_event, end_event, args, device)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
elements_per_gpu = max_numel(comm_op='all_to_all',
@@ -122,7 +132,7 @@ def run_all_to_all(local_rank, args):
print(f"Before AllToAll Input List at rank {global_rank}: {input}")
dist.barrier()

timed_all_to_all(input, output, start_event, end_event, args)
timed_all_to_all(input, output, start_event, end_event, args, device)

if args.debug:
for i in range(world_size):
@@ -134,5 +144,6 @@ def run_all_to_all(local_rank, args):
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
device = args.device
init_processes(local_rank=rank, args=args)
run_all_to_all(local_rank=rank, args=args)
run_all_to_all(local_rank=rank, args=args, device=device)
25 changes: 18 additions & 7 deletions benchmarks/communication/broadcast.py
@@ -14,7 +14,10 @@
from deepspeed.accelerator import get_accelerator


def timed_broadcast(input, start_event, end_event, args):
def timed_broadcast(input, start_event, end_event, args, device):
if device == "cpu":
print_rank_0("No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -48,7 +51,7 @@ def timed_broadcast(input, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_broadcast(local_rank, args):
def run_broadcast(local_rank, args, device):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -60,8 +63,15 @@ def run_broadcast(local_rank, args):
world_size = dist.get_world_size()
global_rank = dist.get_rank()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
M_LIST = []
@@ -86,7 +96,7 @@ def run_broadcast(local_rank, args):
else:
raise e
sync_all()
timed_broadcast(input, start_event, end_event, args)
timed_broadcast(input, start_event, end_event, args, device)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so we double mem_factor
@@ -106,11 +116,12 @@ def run_broadcast(local_rank, args):
sync_all()
return
sync_all()
timed_broadcast(input, start_event, end_event, args)
timed_broadcast(input, start_event, end_event, args, device)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
device = args.device
init_processes(local_rank=rank, args=args)
run_broadcast(local_rank=rank, args=args)
run_broadcast(local_rank=rank, args=args, device=device)
1 change: 1 addition & 0 deletions benchmarks/communication/constants.py
@@ -12,4 +12,5 @@
DEFAULT_UNIT = 'Gbps'
DEFAULT_DIST = 'deepspeed'
DEFAULT_MAXSIZE = 24
DEFAULT_DEVICE = 'cuda'
TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
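
The new `DEFAULT_DEVICE` constant implies that the ninth changed file (not expanded in this view) adds a matching `--device` option to `benchmark_parser()`. A hypothetical sketch of that argument, with the option names assumed rather than taken from the diff:
<pre>
# Hypothetical sketch of the parser change implied by DEFAULT_DEVICE;
# the actual utils diff is collapsed ("Loading") in this commit view.
parser.add_argument("--device",
                    type=str,
                    default=DEFAULT_DEVICE,
                    help="Device backend to benchmark, e.g. cuda, cpu, or xpu")
</pre>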
25 changes: 18 additions & 7 deletions benchmarks/communication/pt2pt.py
@@ -14,7 +14,10 @@
from deepspeed.accelerator import get_accelerator


def timed_pt2pt(input, start_event, end_event, args):
def timed_pt2pt(input, start_event, end_event, args, device):
if device == "cpu":
print_rank_0("No Event support on CPU to measure time for now")
return
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -67,7 +70,7 @@ def timed_pt2pt(input, start_event, end_event, args):
print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")


def run_pt2pt(local_rank, args):
def run_pt2pt(local_rank, args, device):
if args.dist == 'torch':
import torch.distributed as dist
elif args.dist == 'deepspeed':
@@ -78,8 +81,15 @@ def run_pt2pt(local_rank, args):
global_rank = dist.get_rank()
world_size = dist.get_world_size()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
if device == "xpu":
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
elif device == "cpu":
start_event = torch.cpu.Event()
end_event = torch.cpu.Event()
else:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

if args.scan:
# Create list of message sizes
@@ -105,7 +115,7 @@ def run_pt2pt(local_rank, args):
else:
raise e
sync_all()
timed_pt2pt(input, start_event, end_event, args)
timed_pt2pt(input, start_event, end_event, args, device)
else:
# Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
# Don't need output tensor, so double mem_factor
@@ -125,11 +135,12 @@ def run_pt2pt(local_rank, args):
sync_all()
return
sync_all()
timed_pt2pt(input, start_event, end_event, args)
timed_pt2pt(input, start_event, end_event, args, device)


if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
device = args.device
init_processes(local_rank=rank, args=args)
run_pt2pt(local_rank=rank, args=args)
run_pt2pt(local_rank=rank, args=args, device=device)
15 changes: 8 additions & 7 deletions benchmarks/communication/run_all.py
@@ -18,7 +18,7 @@


# For importing
def main(args, rank):
def main(args, rank, device):

init_processes(local_rank=rank, args=args)

@@ -39,19 +39,20 @@ def main(args, rank):

for comm_op in ops_to_run:
if comm_op == 'all_reduce':
run_all_reduce(local_rank=rank, args=args)
run_all_reduce(local_rank=rank, args=args, device=device)
if comm_op == 'all_gather':
run_all_gather(local_rank=rank, args=args)
run_all_gather(local_rank=rank, args=args, device=device)
if comm_op == 'all_to_all':
run_all_to_all(local_rank=rank, args=args)
run_all_to_all(local_rank=rank, args=args, device=device)
if comm_op == 'pt2pt':
run_pt2pt(local_rank=rank, args=args)
run_pt2pt(local_rank=rank, args=args, device=device)
if comm_op == 'broadcast':
run_broadcast(local_rank=rank, args=args)
run_broadcast(local_rank=rank, args=args, device=device)


# For directly calling benchmark
if __name__ == "__main__":
args = benchmark_parser().parse_args()
rank = args.local_rank
main(args, rank)
device = args.device
main(args, rank, device=device)
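
Since `main()` is exposed "For importing", a usage sketch of driving the whole suite programmatically with the new device argument (import paths and flag names are assumed from the repository layout and the README, not from this diff):
<pre>
# Sketch only: calling the benchmark suite from another script.
from benchmarks.communication.utils import benchmark_parser
from benchmarks.communication.run_all import main

args = benchmark_parser().parse_args(["--scan", "--all-reduce", "--device", "cpu"])
main(args, rank=args.local_rank, device=args.device)
</pre>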
Loading… (the ninth changed file is not expanded in this view)
