diff --git a/msccl/language/mscclpp.py b/msccl/language/mscclpp.py index b6e3133..599f068 100644 --- a/msccl/language/mscclpp.py +++ b/msccl/language/mscclpp.py @@ -173,35 +173,27 @@ def _get_buffer_index(self, remote_rank, buffer, index): return buffer, self.prog.buffers[remote_rank][buffer].instance_size() return buffer, index - def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + def _put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm, use_packet=False): self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, "Cannot put to the same rank" + assert self.rank != dst, "Cannot put to the same rank" buffer, index = self._get_buffer_index(dst, buffer, index) # Direct put assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put(sender, self, dst_chunkref, sendtb, chan_type) - - def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): - self.prog.check_buffer_exists(dst, buffer) - sender = self.rank - receiver = dst - assert sender != receiver, "Cannot put to the same rank" - buffer, index = self._get_buffer_index(dst, buffer, index) + if use_packet: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type, use_packet) + self.prog.instr_dag.add_signal(self.rank, self, dst_chunkref, -1, ChannelType.none) + self.prog.instr_dag.add_wait(dst, dst_chunkref, self, -1, ChannelType.none) + else: + self.prog.instr_dag.add_put(self.rank, self, dst_chunkref, sendtb, chan_type) - # Direct put - assert self.prog.topo.link(self.rank, dst) or dst == self.rank, f"No link from {self.rank} to {dst}" - dst_chunkref = self.prog.get_ref(dst, buffer, index, self.size) + def put(self, dst, buffer=None, index=-1, sendtb=-1, chan_type=ChannelType.sm): + self._put(dst, buffer, index, sendtb, chan_type) - self.prog.apply_send(self.rank, self.buffer, self.index, dst, buffer, index, self.size) - self.prog.instr_dag.add_put_packet(sender, self, dst_chunkref, sendtb, channel_type) - self.prog.instr_dag.add_signal(sender, self, dst_chunkref, -1, ChannelType.none) - self.prog.instr_dag.add_wait(receiver, dst_chunkref, self, -1, ChannelType.none) + def put_packet(self, dst, buffer=None, index=-1, sendtb=-1, channel_type=ChannelType.sm): + return self._put(dst, buffer, index, sendtb, channel_type, use_packet=True) def get(self, src, buffer=None, index=-1, recvtb=-1, chan_type=ChannelType.sm): self.prog.check_buffer_exists(src, buffer) diff --git a/msccl/language/rank_dag.py b/msccl/language/rank_dag.py index c576be5..2c0134e 100755 --- a/msccl/language/rank_dag.py +++ b/msccl/language/rank_dag.py @@ -47,6 +47,7 @@ def same_buf_src(op1, op2): def same_chan_type(op1, op2): return op1.channel_type == op2.channel_type +# TODO:(binyli): Need to treat it as base class. For MSCCLPP/MSCCL implement different methods class InstructionDAG: def __init__(self, num_ranks, buffers): self.num_ranks = num_ranks @@ -148,15 +149,13 @@ def add_copy_mscclpp(self, rank, send_ref, recv_ref, tb, use_packet = False): # InstructionDAG - adds a redduce node def add_reduce(self, rank, send_ref, recv_ref, tb, ch): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch, step=tb_step) + op = Op(Instruction.reduce, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel=ch) dstbuffer = recv_ref.buffer dstindex = recv_ref.index srcbuffer = send_ref.buffer srcindex = send_ref.index size = recv_ref.size prev_ops = [] - op.srcs.append((ChunkRef(send_ref.rank, send_ref.buffer, send_ref.index, send_ref.size), tb_step)) # Sending part of reduce self._read(rank, srcbuffer, srcindex, size, op) # Reduce part of copy @@ -192,18 +191,32 @@ def add_send(self, rank, send_ref, recv_ref, tb, ch): return op # InstructionDAG - adds a put node - def add_put(self, rank, send_ref, recv_ref, tb, ch_type): + def add_put(self, rank, send_ref, recv_ref, tb, ch_type, use_packet = False): tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.put, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) - buffer = send_ref.buffer - index = send_ref.index - size = send_ref.size - self._read(rank, buffer, index, size, op) - return op - - def add_put_packet(self, rank, send_ref, recv_ref, tb, ch_type): - tb_step = self._get_tb_step(rank, tb) - op = Op(Instruction.put_packet, rank, send_ref, recv_ref, next=set(), prev=set(), tb=tb, channel_type=ch_type, step=tb_step) + if use_packet: + op = Op( + Instruction.put_packet, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) + else: + op = Op( + Instruction.put, + rank, + send_ref, + recv_ref, + next=set(), + prev=set(), + tb=tb, + channel_type=ch_type, + step=tb_step, + ) buffer = send_ref.buffer index = send_ref.index size = send_ref.size @@ -544,7 +557,6 @@ def dfs(op, cs): if op.inst == Instruction.start: dfs(op,-2) # Start instructions should start at -1 - # Given the set of operations that operate over a particular slot (rank, buffer, idx) fixed # Try and replace operations with pipelined ops like receive copy send (rcs) # or receive reduce send (rrs) and receive reduce copy send (rrcs) @@ -666,7 +678,6 @@ def lower_tbs(self): gpus.append(Gpu(rank, list(lowered_tbs.values()))) return gpus - # Automatically replicates the algorithm instance number of times # interleaved sets the replication policy # if True chunks are split as: ChunkA ChunkB -> ChunkA0 ChunkA1 .. ChunkB0 ChunkB1 ...