Commit 11f047c

Add mark steps to prevent oom in static moe op (#65)
jkaniecki authored Jun 24, 2024
1 parent ad890f1 commit 11f047c
Showing 1 changed file with 3 additions and 1 deletion.
vllm/hpu/ops.py (3 additions, 1 deletion):

@@ -125,7 +125,6 @@ def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
-@hpu_utils.with_mark_steps
 def static_fused_moe(hidden_states, w1, w2, score, topk):
     B, D = hidden_states.shape
     num_experts = w1.shape[0]
@@ -142,12 +141,15 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
     padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
     padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
 
+    htorch.core.mark_step()
+
     for expert_idx in range(num_experts):
         padded_weight = padded_weights[expert_idx]
         current_state_static = hidden_states.reshape(-1, D)
         w_output = silu_and_mul_wrapper(torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
         w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
         current_hidden_states_static = w_output * padded_weight
         final_hidden_states += current_hidden_states_static
+        htorch.core.mark_step()
 
     return final_hidden_states.view(-1, D)
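For context on why this prevents OOM: in Habana's lazy execution mode, operations queue into a single accumulated graph until htorch.core.mark_step() is called. The removed @hpu_utils.with_mark_steps decorator presumably marked only at function entry and exit, so the entire expert loop and all its intermediate tensors lived in one graph; marking before the loop and again after each iteration bounds the live graph to a single expert at a time. Below is a minimal sketch of the pattern, not taken from this repository — the tensors, shapes, and loop body are illustrative, and it assumes an HPU device with the habana_frameworks package installed:

import torch
import habana_frameworks.torch as htorch  # provides htorch.core.mark_step()

device = torch.device("hpu")
x = torch.randn(8, 16, device=device)          # illustrative activations
experts = [torch.randn(16, 16, device=device)  # illustrative expert weights
           for _ in range(4)]
acc = torch.zeros(8, 16, device=device)

htorch.core.mark_step()  # flush whatever was queued before the loop

for w in experts:
    acc = acc + torch.relu(x @ w)
    # Cut the lazy graph after every expert so each iteration is compiled
    # and executed on its own, keeping peak graph/tensor memory bounded.
    htorch.core.mark_step()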
