Support Mixtral quantization using HQT
dudilester committed Jul 7, 2024
1 parent ca1dbf6 commit 87d95ad
Showing 3 changed files with 65 additions and 38 deletions.
75 changes: 48 additions & 27 deletions vllm/hpu/ops.py
@@ -120,32 +120,53 @@ def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
-def static_fused_moe(hidden_states, w1, w2, score, topk):
-    B, D = hidden_states.shape
-    num_experts = w1.shape[0]
-    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-    routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
-    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-    routing_weights = routing_weights.to(hidden_states.dtype)
-    final_hidden_states = torch.zeros(
-        (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device
-    )
-    padded_weights = torch.zeros(
-        (B, num_experts), dtype=hidden_states.dtype, device=hidden_states.device
-    )
-    padded_weights.scatter_(-1, selected_experts, routing_weights)
-    padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
-    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
-
-    htorch.core.mark_step()
-
-    for expert_idx in range(num_experts):
-        padded_weight = padded_weights[expert_idx]
-        current_state_static = hidden_states.reshape(-1, D)
-        w_output = silu_and_mul_wrapper(torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
-        w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
-        current_hidden_states_static = w_output * padded_weight
-        final_hidden_states += current_hidden_states_static
-
-    return final_hidden_states.view(-1, D)
+class MoeMatmul(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def set_weight(self, w):
+        self.weight = w.transpose(0, 1)
+
+    def calc(self, state, expert_id, w):
+        self.weight = w[expert_id].transpose(0, 1)
+        return self.forward(state)
+
+    def forward(self, state):
+        return torch.matmul(state, self.weight)
+
+
+class StaticFusedMOE(nn.Module):
+    def __init__(self, num_total_experts):
+        super().__init__()
+        self.w13_list = nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
+        self.w2_list = nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
+
+    def forward(self, hidden_states, w1, w2, score, topk):
+        B, D = hidden_states.shape
+        num_experts = w1.shape[0]
+        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        routing_weights = routing_weights.to(hidden_states.dtype)
+        final_hidden_states = torch.zeros(
+            (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        padded_weights = torch.zeros(
+            (B, num_experts), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+        padded_weights.scatter_(-1, selected_experts, routing_weights)
+        padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
+        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
+        htorch.core.mark_step()
+
+        for expert_idx in range(num_experts):
+            padded_weight = padded_weights[expert_idx]
+            current_state_static = hidden_states.reshape(-1, D)
+            w_output = silu_and_mul_wrapper(self.w13_list[expert_idx].calc(current_state_static, expert_idx, w1))
+            w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
+            current_hidden_states_static = w_output * padded_weight
+            final_hidden_states += current_hidden_states_static
+            htorch.core.mark_step()
+
+        return final_hidden_states.view(-1, D)
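
The old free function did the per-expert GEMMs with bare torch.matmul calls; wrapping each one in a MoeMatmul nn.Module gives a module-swapping quantizer such as HQT a handle it can patch with a quantized equivalent, while the routing math stays unchanged. Below is a usage sketch that is not part of the commit: it assumes an HPU build of vLLM (vllm.hpu.ops imports habana_frameworks.torch) and Mixtral-8x7B-like shapes, and all sizes and tensor values are illustrative.

# Hypothetical usage sketch; not taken from the commit.
import torch
from vllm.hpu.ops import StaticFusedMOE

num_experts, hidden_dim, intermediate_dim, top_k = 8, 4096, 14336, 2
moe = StaticFusedMOE(num_experts).to("hpu")

hidden_states = torch.randn(4, hidden_dim, dtype=torch.bfloat16, device="hpu")
w13 = torch.randn(num_experts, 2 * intermediate_dim, hidden_dim,
                  dtype=torch.bfloat16, device="hpu")
w2 = torch.randn(num_experts, hidden_dim, intermediate_dim,
                 dtype=torch.bfloat16, device="hpu")
router_logits = torch.randn(4, num_experts, dtype=torch.bfloat16, device="hpu")

# forward() still consumes the stacked expert weights; the per-expert
# MoeMatmul submodules are what a module-swapping quantizer can replace.
out = moe(hidden_states, w13, w2, router_logits, top_k)
print(out.shape)  # torch.Size([4, 4096])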
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
     if (model_config.quantization is not None
-            and model_config.quantization != "fp8"
+            and model_config.quantization not in ["fp8", "hqt"]
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
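
Net effect of this one-line change: HQT joins fp8 in keeping the stock MixtralForCausalLM architecture (and therefore the fused-MoE path above) instead of being remapped to the serial QuantMixtralForCausalLM. A standalone sketch of the rule; the helper name is made up for illustration.

# Standalone illustration of the architecture-remap rule above.
def resolve_architectures(quantization, architectures):
    if (quantization is not None
            and quantization not in ["fp8", "hqt"]
            and "MixtralForCausalLM" in architectures):
        return ["QuantMixtralForCausalLM"]
    return architectures

print(resolve_architectures("awq", ["MixtralForCausalLM"]))  # ['QuantMixtralForCausalLM']
print(resolve_architectures("hqt", ["MixtralForCausalLM"]))  # ['MixtralForCausalLM']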

26 changes: 16 additions & 10 deletions vllm/model_executor/models/mixtral.py
@@ -51,10 +51,8 @@
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once, is_hpu
 
-if is_hpu():
-    from vllm.hpu.ops import static_fused_moe
-
+from vllm.hpu.ops import StaticFusedMOE
 
 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -83,6 +81,9 @@ def __init__(
         self.intermediate_size = intermediate_size // self.tp_size
         self.quant_config = quant_config
 
+        if is_hpu():
+            self.static_fused_moe = StaticFusedMOE(self.num_total_experts)
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -173,11 +174,16 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+            if is_hpu():
+                self.static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            param_data[expert_id, shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            if is_hpu():
+                self.static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+            if is_hpu():
+                self.static_fused_moe.w2_list[expert_id].set_weight(param_data[expert_id])
         if "act_scale" in weight_name or "weight_scale" in weight_name:
             param_data[expert_id] = loaded_weight
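
On HPU, weight loading now does double duty: each sharded slice is still written into the stacked w13_weight/w2_weight parameters, and the same slice is also handed to the matching MoeMatmul via set_weight, which stores it transposed so the module can be matmul'd (or quantized) directly. A small illustration, not from the commit, assuming the Habana PyTorch bridge is installed so vllm.hpu.ops imports; the shapes are illustrative.

# Illustration only: what set_weight leaves on one expert's module.
import torch
from vllm.hpu.ops import MoeMatmul

expert = MoeMatmul()
w13_slice = torch.randn(2 * 14336, 4096)   # stacked [w1; w3] rows for one expert
expert.set_weight(w13_slice)               # stored as its transpose: (4096, 28672)

x = torch.randn(4, 4096)
print(expert(x).shape)                     # torch.Size([4, 28672])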

@@ -232,11 +238,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, _ = self.gate(hidden_states)
 
         if is_hpu():
-            final_hidden_states = static_fused_moe(hidden_states,
-                                                   self.w13_weight,
-                                                   self.w2_weight,
-                                                   router_logits,
-                                                   self.top_k)
+            final_hidden_states = self.static_fused_moe(hidden_states,
+                                                        self.w13_weight,
+                                                        self.w2_weight,
+                                                        router_logits,
+                                                        self.top_k)
         else:
             final_hidden_states = fused_moe(hidden_states,
                                             self.w13_weight,
