Merge branch 'graphable_token_dispatch' into 'main'
Make the MoE token dispatcher CUDA-graph-able when token drop and padding are enabled

See merge request ADLR/megatron-lm!2426
jaredcasper committed Jan 10, 2025
2 parents 3046e33 + f27a04f commit 726da58
Showing 2 changed files with 97 additions and 21 deletions.
65 changes: 55 additions & 10 deletions megatron/core/transformer/moe/moe_utils.py
@@ -192,29 +192,48 @@ def set_loss_scale(scale: torch.Tensor):
         MoEAuxLossAutoScaler.main_loss_backward_scale = scale
 
 
-def permute(tokens, routing_map, num_out_tokens: int = None):
+def permute(tokens, routing_map, num_out_tokens: int = None, drop_and_pad: bool = False):
     """Permute the tokens and probs based on the mask.
     Tokens with the same designated expert will be grouped together.
     The shape of mask is [tokens, num_experts]; it indicates which experts were selected
     by each token.
+
+    When drop_and_pad=True, the number of non-zeros in each column of routing_map equals the
+    expert capacity. This function exploits this property to use only ops that are compatible
+    with CUDA graphs.
 
     Args:
         tokens (torch.Tensor): The input token tensor, [num_tokens, hidden].
         routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
         num_out_tokens (int, optional): The number of output tokens. If None, it's set to
                                         the number of input tokens.
+        drop_and_pad (bool, optional): Whether the token dispatcher uses token-drop
+                                       and pads the number of tokens to the expert capacity.
+                                       If set to True, routing_map has a fixed number of
+                                       non-zeros in each column.
     """
     num_tokens, hidden = tokens.shape
     num_experts = routing_map.shape[1]
-
-    # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
-    routing_map = routing_map.bool().T.contiguous()
-
-    # Create a dense expert-to-token mapping from the sparse token-to-expert mapping
-    token_indices = (
-        torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
-    )
-    sorted_indices = token_indices.masked_select(routing_map)
+    if drop_and_pad and not (num_out_tokens is None):
+        capacity = num_out_tokens // num_experts
+        assert not routing_map.requires_grad
+        # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
+        routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
+        # use argsort to put the indices of all non-zeros at the front of each row
+        # and keep only the first `capacity` indices
+        sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
+            :, :capacity
+        ].contiguous()
+        # flatten from [num_experts, capacity] to 1D
+        sorted_indices = sorted_indices.view(-1)
+    else:
+        # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
+        routing_map = routing_map.bool().T.contiguous()
+
+        # Create a dense expert-to-token mapping from the sparse token-to-expert mapping
+        token_indices = (
+            torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
+        )
+        sorted_indices = token_indices.masked_select(routing_map)
 
     # use the mapping to permute the tokens
     permuted_input = tokens.index_select(0, sorted_indices)
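The drop_and_pad branch above replaces the data-dependent masked_select with fixed-shape ops. As a quick illustration (a toy sketch with made-up sizes, not part of the commit): when every expert column of the routing map holds exactly `capacity` non-zeros, a stable descending argsort of the transposed int8 mask puts those non-zero positions first in every row, so slicing the first `capacity` columns yields a [num_experts, capacity] index tensor whose shape never depends on the routing outcome.

import torch

# Toy illustration (made-up sizes, not from the commit): 6 tokens, 3 experts,
# capacity 2, so every expert column of the routing map has exactly 2 non-zeros.
num_tokens, hidden, num_experts, capacity = 6, 4, 3, 2
tokens = torch.randn(num_tokens, hidden)
routing_map = torch.tensor(
    [
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 1],
        [0, 0, 1],
        [0, 1, 0],
        [0, 0, 0],  # dropped token: routed to no expert
    ],
    dtype=torch.bool,
)

# Drop-and-pad path: only fixed-shape ops, no data-dependent masked_select.
mask_T = routing_map.to(torch.int8).T.contiguous()      # [num_experts, num_tokens]
sorted_indices = mask_T.argsort(dim=-1, descending=True, stable=True)[:, :capacity]
sorted_indices = sorted_indices.reshape(-1)             # [num_experts * capacity]
permuted = tokens.index_select(0, sorted_indices)       # [num_experts * capacity, hidden]

print(sorted_indices.tolist())  # [0, 2, 1, 4, 2, 3] -- `capacity` token indices per expert
print(permuted.shape)           # torch.Size([6, 4]); depends only on capacity, not on routing

Because every shape here is fixed by num_tokens, num_experts, and capacity, the permutation can be replayed for new routing decisions without any CPU-side shape computation, which is what makes it CUDA-graph friendly.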
@@ -228,18 +247,27 @@ def unpermute(
     restore_shape: torch.Size,
     probs: torch.Tensor = None,
     routing_map: torch.Tensor = None,
+    drop_and_pad: bool = False,
 ):
     """
     Restore the original order of tokens after permutation. If probs are provided, it
     will also apply them to the tokens before restoring the order.
+
+    When drop_and_pad=True, the tensors have the following properties:
+      - In routing_map, the number of non-zeros in each column equals the expert capacity.
+      - The size of sorted_indices equals num_experts * capacity; each chunk of `capacity`
+        entries holds the indices of the tokens routed to one expert.
+    This function exploits these properties to use only ops that are compatible with CUDA graphs.
 
     Args:
         permuted_tokens (torch.Tensor): The permuted token tensor.
         sorted_indices (torch.Tensor): The indices used to sort the tokens.
         restore_shape (torch.Size): The shape of the unpermuted tensor.
         probs (torch.Tensor, optional): The unpermuted probs tensor.
         routing_map (torch.Tensor, optional): Token to expert mapping, shape
             [num_tokens, num_experts].
+        drop_and_pad (bool, optional): Whether the token dispatcher uses token-drop
+                                       and pads the number of tokens to the expert capacity.
 
     Returns:
         torch.Tensor: The tokens restored to their original order.
@@ -248,7 +276,24 @@ def unpermute(

     if probs is not None:
         assert routing_map is not None, "Mask must be provided to permute the probs."
-        permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
+        if drop_and_pad:
+            num_experts = routing_map.size(1)
+            num_permuted_tokens = sorted_indices.size(0)
+            capacity = num_permuted_tokens // num_experts
+            num_unpermuted_tokens = probs.size(0)
+
+            # [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
+            probs_T_1D = probs.T.contiguous().view(-1)
+
+            # get 1D indices of the probs selected by routing_map
+            indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1)
+            indices_dim1 = sorted_indices.view(num_experts, capacity)
+            indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
+
+            # get probs from indices
+            permuted_probs = probs_T_1D.index_select(0, indices_1D)
+        else:
+            permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
         permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
 
     # Create an output tensor filled with zeros
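To see why the flattened-index gather in unpermute matches the original masked_select path, here is a small self-checking sketch (toy sizes and assumed values, not from the repository). Flattening probs.T to 1D puts the prob of token t for expert e at position e * num_tokens + t, which is exactly the index built from indices_dim0 * num_unpermuted_tokens + indices_dim1.

import torch

# Toy check (not from the commit): with exactly `capacity` non-zeros per expert
# column, gathering probs via flattened 1D indices matches the masked_select path.
num_tokens, num_experts, capacity = 6, 3, 2
probs = torch.rand(num_tokens, num_experts)
routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
routing_map[[0, 2], 0] = True
routing_map[[1, 4], 1] = True
routing_map[[2, 3], 2] = True

# sorted_indices as produced by the drop-and-pad permute: [num_experts * capacity]
sorted_indices = (
    routing_map.to(torch.int8).T.argsort(dim=-1, descending=True, stable=True)[:, :capacity].reshape(-1)
)

# CUDA-graph-friendly path: pure index arithmetic, fixed shapes.
probs_T_1D = probs.T.contiguous().view(-1)                 # [num_experts * num_tokens]
indices_dim0 = torch.arange(num_experts).unsqueeze(-1)     # per-expert row offsets
indices_dim1 = sorted_indices.view(num_experts, capacity)  # token indices per expert
indices_1D = (indices_dim0 * num_tokens + indices_dim1).view(-1)
permuted_probs = probs_T_1D.index_select(0, indices_1D)

# Dynamic-shape reference path used when drop_and_pad is False.
reference = probs.T.contiguous().masked_select(routing_map.T.contiguous())
assert torch.equal(permuted_probs, reference)

index_select over a precomputed index tensor has a fixed output size, whereas masked_select produces a data-dependent shape; that difference is what makes the drop_and_pad branch capturable.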
53 changes: 42 additions & 11 deletions megatron/core/transformer/moe/token_dispatcher.py
@@ -493,7 +493,10 @@ def token_permutation(
         if self.cuda_sync_point == "before_permutation_1":
             torch.cuda.current_stream().synchronize()
         permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
-            hidden_states, routing_map, num_out_tokens=self.num_out_tokens
+            hidden_states,
+            routing_map,
+            num_out_tokens=self.num_out_tokens,
+            drop_and_pad=self.drop_and_pad,
         )
 
         # Perform expert parallel AlltoAll communication
@@ -516,11 +519,25 @@ def token_permutation(

         # Permutation 2: Sort tokens by local expert.
         if self.num_local_experts > 1:
-            global_input_tokens = sort_chunks_by_idxs(
-                global_input_tokens,
-                self.num_global_tokens_per_local_expert_cpu.ravel(),
-                self.sort_input_by_local_experts,
-            )
+            if self.drop_and_pad:
+                # Example:
+                global_input_tokens = (
+                    global_input_tokens.view(
+                        self.tp_size * self.ep_size,
+                        self.num_local_experts,
+                        self.capacity,
+                        *global_input_tokens.size()[1:]
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
+            else:
+                global_input_tokens = sort_chunks_by_idxs(
+                    global_input_tokens,
+                    self.num_global_tokens_per_local_expert_cpu.ravel(),
+                    self.sort_input_by_local_experts,
+                )
 
         if self.cuda_sync_point == "before_finish":
             torch.cuda.current_stream().synchronize()
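The drop_and_pad branch in token_permutation swaps sort_chunks_by_idxs, which needs per-chunk sizes from the CPU, for a pure view/transpose/flatten, relying on every (rank, local expert) chunk holding exactly `capacity` tokens. The sketch below is a toy check, not repository code: `tp_ep` stands in for self.tp_size * self.ep_size, and the explicit torch.cat loop stands in for sort_chunks_by_idxs under equal chunk sizes.

import torch

# Toy check: with drop-and-pad, regrouping tokens by local expert is a
# fixed-shape view/transpose/flatten instead of a chunk sort driven by CPU sizes.
tp_ep, num_local_experts, capacity, hidden = 4, 2, 3, 8  # tp_size * ep_size = 4
global_input_tokens = torch.randn(tp_ep * num_local_experts * capacity, hidden)

# Fixed-shape path mirroring the commit's drop_and_pad branch.
regrouped = (
    global_input_tokens.view(tp_ep, num_local_experts, capacity, hidden)
    .transpose(0, 1)
    .contiguous()
    .flatten(start_dim=0, end_dim=2)
)

# Reference: explicitly gather the chunks in expert-major order, mimicking what
# sort_chunks_by_idxs achieves when every chunk has the same size.
chunks = global_input_tokens.view(tp_ep, num_local_experts, capacity, hidden)
reference = torch.cat(
    [chunks[rank, expert] for expert in range(num_local_experts) for rank in range(tp_ep)],
    dim=0,
)
assert torch.equal(regrouped, reference)

The mirrored reshuffle in token_unpermutation views the tensor with num_local_experts as the leading dimension and transposes back, undoing this regrouping with the same fixed shapes.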
@@ -551,11 +568,24 @@ def token_unpermutation(

         # Unpermutation 2: Unsort tokens by local expert.
         if self.num_local_experts > 1:
-            hidden_states = sort_chunks_by_idxs(
-                hidden_states,
-                self.num_global_tokens_per_local_expert_cpu.T.ravel(),
-                self.restore_output_by_local_experts,
-            )
+            if self.drop_and_pad:
+                hidden_states = (
+                    hidden_states.view(
+                        self.num_local_experts,
+                        self.tp_size * self.ep_size,
+                        self.capacity,
+                        *hidden_states.size()[1:]
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
+            else:
+                hidden_states = sort_chunks_by_idxs(
+                    hidden_states,
+                    self.num_global_tokens_per_local_expert_cpu.T.ravel(),
+                    self.restore_output_by_local_experts,
+                )
 
         if self.tp_size > 1:
             hidden_states = reduce_scatter_to_sequence_parallel_region(
@@ -582,6 +612,7 @@ def token_unpermutation(
             restore_shape=self.hidden_shape_before_permute,
             probs=self.probs,
             routing_map=self.routing_map,
+            drop_and_pad=self.drop_and_pad,
         )
 
         # Reshape the output tensor
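Putting the pieces together, the reason for keeping every shape fixed is that the dispatch math can then be captured into a CUDA graph and replayed. The sketch below follows the standard PyTorch capture-then-replay pattern around an argsort-plus-index_select permutation; the sizes and the fixed_shape_permute helper are illustrative assumptions, not code from this merge request, and a GPU is required.

import torch

# Illustrative capture/replay sketch (assumed sizes; `fixed_shape_permute` is a
# hypothetical helper, not part of this merge request).
assert torch.cuda.is_available()
num_tokens, hidden, num_experts, capacity = 1024, 512, 8, 128
tokens = torch.randn(num_tokens, hidden, device="cuda")
routing_map = torch.zeros(num_experts, num_tokens, dtype=torch.int8, device="cuda")

def fixed_shape_permute():
    # argsort + index_select produce the same output shape for any routing decision
    sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
        :, :capacity
    ].reshape(-1)
    return tokens.index_select(0, sorted_indices)

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        out = fixed_shape_permute()
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    out = fixed_shape_permute()

# Replay with new data: overwrite the static input tensors in place, then replay.
tokens.copy_(torch.randn_like(tokens))
routing_map.copy_((torch.rand_like(routing_map, dtype=torch.float) > 0.5).to(torch.int8))
graph.replay()  # `out` now holds the permutation for the new inputs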
