From 030a750fc56e03fb863ffe48268672d84d4cdc7b Mon Sep 17 00:00:00 2001 From: Saeed Maleki Date: Fri, 28 Apr 2023 05:18:18 +0000 Subject: [PATCH] clean ups --- examples/mscclang/allgather_a100_pcie.py | 6 +++--- examples/mscclang/allgather_allpairs.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/mscclang/allgather_a100_pcie.py b/examples/mscclang/allgather_a100_pcie.py index 06d9ae3..29e7cad 100644 --- a/examples/mscclang/allgather_a100_pcie.py +++ b/examples/mscclang/allgather_a100_pcie.py @@ -6,8 +6,8 @@ from msccl.topologies import * from msccl.language.collectives import AllGather -# Allpairs allgather for A100 -def allgather_allpairs(gpus, instances, protocol): +# Hierarchical allgather for A100 +def allgather_hier(gpus, instances, protocol): size = gpus chunksperloop = 1 topology = fully_connected(gpus) @@ -39,4 +39,4 @@ def allgather_allpairs(gpus, instances, protocol): parser.add_argument('--protocol', type=str, default='LL128', choices=['Simple', 'LL', 'LL128'], help ='NCCL protocol. Default: Simple') args = parser.parse_args() -allgather_allpairs(args.num_gpus, args.instances, args.protocol) \ No newline at end of file +allgather_hier(args.num_gpus, args.instances, args.protocol) \ No newline at end of file diff --git a/examples/mscclang/allgather_allpairs.py b/examples/mscclang/allgather_allpairs.py index fdfb202..fdd926d 100644 --- a/examples/mscclang/allgather_allpairs.py +++ b/examples/mscclang/allgather_allpairs.py @@ -10,7 +10,7 @@ def allgather_allpairs(gpus, instances, protocol): size = gpus topology = fully_connected(gpus) - collective = AllGather(size, size, True) + collective = AllGather(size, 1, True) with MSCCLProgram(f"allgather_allpairs", topology, collective, instances, protocol=protocol, threadblock_policy=ThreadblockPolicy.manual): @@ -20,7 +20,7 @@ def allgather_allpairs(gpus, instances, protocol): for r2 in range(gpus): if r1 != r2: index = 0 - c = chunk(r1, Buffer.input, index, size) + c = chunk(r1, Buffer.input, index, 1) c.copy(r2, Buffer.input, index, sendtb=r2, recvtb=r1) XML() Check()