Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Apr 8, 2024
1 parent b683d7f commit 3cf049e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 36 deletions.
32 changes: 12 additions & 20 deletions msccl/language/mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,35 +173,27 @@ def _get_buffer_index(self, remote_rank, buffer, index):
return buffer, self.prog.buffers[remote_rank][buffer].instance_size()
return buffer, index

def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False):
self.prog.check_buffer_exists(dst, buffer)
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot put to the same rank"
assert self.rank != dst, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)

# Direct put
assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)

self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type)

def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
self.prog.check_buffer_exists(dst, buffer)
sender = self.rank
receiver = dst
assert sender != receiver, "Cannot put to the same rank"
buffer, index = self._get_buffer_index(dst, buffer, index)
if use_packet:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet)
self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none)
else:
self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type)

# Direct put
assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}"
dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size)
def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm):
self._put(dst, buffer, index, sendtb, chan_type)

self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size)
self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type)
self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none)
self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none)
def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm):
return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True)

def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm):
self.prog.check_buffer_exists(src, buffer)
Expand Down
43 changes: 27 additions & 16 deletions msccl/language/rank_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def same_buf_src(op1, op2):
def same_chan_type(op1, op2):
return op1.channel_type == op2.channel_type

# TODO:(binyli): Need to treat it as base class. For MSCCLPP/MSCCL implement different methods
class InstructionDAG:
def __init__(self, num_ranks, buffers):
self.num_ranks = num_ranks
Expand Down Expand Up @@ -148,15 +149,13 @@ def add_copy_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False):

# InstructionDAG - adds a redduce node
def add_reduce(self, rank, send_ref, recv_ref, tb, ch):
tb_step = self._get_tb_step(rank, tb)
op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step)
op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch)
dstbuffer = recv_ref.buffer
dstindex = recv_ref.index
srcbuffer = send_ref.buffer
srcindex = send_ref.index
size = recv_ref.size
prev_ops = []
op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step))
# Sending part of reduce
self._read(rank, srcbuffer, srcindex, size, op)
# Reduce part of copy
Expand Down Expand Up @@ -192,18 +191,32 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch):
return op

# InstructionDAG - adds a put node
def add_put(self, rank, send_ref, recv_ref, tb, ch_type):
def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet = False):
tb_step = self._get_tb_step(rank, tb)
op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
self._read(rank, buffer, index, size, op)
return op

def add_put_packet(self, rank, send_ref, recv_ref, tb, ch_type):
tb_step = self._get_tb_step(rank, tb)
op = Op(Instruction.put_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step)
if use_packet:
op = Op(
Instruction.put_packet,
rank,
send_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
else:
op = Op(
Instruction.put,
rank,
send_ref,
recv_ref,
next=set(),
prev=set(),
tb=tb,
channel_type=ch_type,
step=tb_step,
)
buffer = send_ref.buffer
index = send_ref.index
size = send_ref.size
Expand Down Expand Up @@ -544,7 +557,6 @@ def dfs(op, cs):
if op.inst == Instruction.start:
dfs(op,-2) # Start instructions should start at -1


# Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed
# Try and replace operations with pipelined ops like receive copy send (rcs)
# or receive reduce send (rrs) and receive reduce copy send (rrcs)
Expand Down Expand Up @@ -666,7 +678,6 @@ def lower_tbs(self):
gpus.append(Gpu(rank, list(lowered_tbs.values())))
return gpus


# Automatically replicates the algorithm instance number of times
# interleaved sets the replication policy
# if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ...
Expand Down

0 comments on commit 3cf049e

Please sign in to comment.