Skip to content

Commit

Permalink
new algorithms for allgather
Browse files Browse the repository at this point in the history
  • Loading branch information
Saeed Maleki committed Apr 28, 2023
1 parent c59f276 commit db171e0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
42 changes: 42 additions & 0 deletions examples/mscclang/allgather_a100_pcie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllGather

# Allpairs allgather for A100
def allgather_allpairs(gpus, instances, protocol):
size = gpus
chunksperloop = 1
topology = fully_connected(gpus)
collective = AllGather(size, chunksperloop, True)

with MSCCLProgram("allgather_hierarchical", topology, collective, instances, protocol=protocol,
interleaved_replication=True, dependence_nop=True):
for chnk in range(2):
for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, 0)
c.copy(r + 1 - 2 * chnk, Buffer.output, r)
for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, 0)
c.copy((r+2) % size, Buffer.output, r)
for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.output, (r+2) % size)
c.copy(r + 1 - 2 * chnk, Buffer.output, (r+2) % size)

XML()
Check()


parser = argparse.ArgumentParser()
parser.add_argument('num_gpus', type=int, help ='number of gpus')
parser.add_argument('instances', type=int, help='number of instances')
parser.add_argument('--protocol', type=str, default='LL128', choices=['Simple', 'LL', 'LL128'], help ='NCCL protocol. Default: Simple')
args = parser.parse_args()

allgather_allpairs(args.num_gpus, args.instances, args.protocol)
10 changes: 5 additions & 5 deletions examples/mscclang/allreduce_a100_ncv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,27 @@ def allreduce_allpairs(gpus, instances, protocol):
topology = fully_connected(size)
collective = AllReduce(size, chunksperloop, True)
with MSCCLProgram("allreduce_ncv4", topology, collective, instances, protocol=protocol,
interleaved_replication=False, dependence_nop=True):
interleaved_replication=False, threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):
for chnk in range(chunksperloop):
for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, chnk)
c.reduce(chunk(r + 1 - 2 * chnk, Buffer.input, chnk))
c.reduce(chunk(r + 1 - 2 * chnk, Buffer.input, chnk), sendtb=0, recvtb=0, ch=0)

for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, chnk)
c.copy((r+2) % size, 'scratch', chnk)
c.copy((r+2) % size, 'scratch', chnk, sendtb=1, recvtb=1, ch=0)

for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, chnk)
c.reduce(chunk(r, 'scratch', chnk))
c.reduce(chunk(r, 'scratch', chnk), sendtb=1, recvtb=1, ch=0)

for r in range(size):
if ((r % 2) == chnk):
c = chunk(r, Buffer.input, chnk)
c.copy(r + 1 - 2 * chnk, Buffer.input, chnk)
c.copy(r + 1 - 2 * chnk, Buffer.input, chnk, sendtb=2, recvtb=2, ch=1)

XML()
Check()
Expand Down

0 comments on commit db171e0

Please sign in to comment.