Support Mixtral quantization using HQT
dudilester committed Jul 7, 2024
1 parent 60df235 commit f4f3437
Showing 3 changed files with 64 additions and 67 deletions.
106 changes: 47 additions & 59 deletions vllm/hpu/ops.py
@@ -89,65 +89,53 @@ def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
    return out


def static_fused_moe(hidden_states, w1, w2, score, topk):
    B, D = hidden_states.shape
    num_experts = w1.shape[0]
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights,
                                                   topk,
                                                   dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(hidden_states.dtype)
    final_hidden_states = torch.zeros((1, B, D),
                                      dtype=hidden_states.dtype,
                                      device=hidden_states.device)
    padded_weights = torch.zeros((B, num_experts),
                                 dtype=hidden_states.dtype,
                                 device=hidden_states.device)
    padded_weights.scatter_(-1, selected_experts, routing_weights)
    padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
    padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

    htorch.core.mark_step()

    for expert_idx in range(num_experts):
        padded_weight = padded_weights[expert_idx]
        current_state_static = hidden_states.reshape(-1, D)
        w_output = silu_and_mul_wrapper(
            torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
        w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
        current_hidden_states_static = w_output * padded_weight
        final_hidden_states += current_hidden_states_static

    return final_hidden_states.view(-1, D)


class MoeMatmul(nn.Module):
    def __init__(self):
        super().__init__()

    def set_weight(self, w):
        self.weight = w

    def calc(self, state, expert_id, w):
        self.weight = w[expert_id].transpose(0, 1)
        return self.forward(state)

    def forward(self, state):
        return torch.matmul(state, self.weight)


class StaticFusedMOE(nn.Module):
    def __init__(self, num_total_experts):
        super().__init__()
        self.w13_list = nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)])
        self.w2_list = nn.ModuleList(
            [MoeMatmul() for _ in range(num_total_experts)])

    def forward(self, hidden_states, w1, w2, score, topk):
        B, D = hidden_states.shape
        num_experts = w1.shape[0]
        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(routing_weights, topk, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (1, B, D), dtype=hidden_states.dtype, device=hidden_states.device
        )
        padded_weights = torch.zeros(
            (B, num_experts), dtype=hidden_states.dtype, device=hidden_states.device
        )
        padded_weights.scatter_(-1, selected_experts, routing_weights)
        padded_weights = padded_weights.reshape(-1, B, w1.shape[0])
        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
        htorch.core.mark_step()

        for expert_idx in range(num_experts):
            padded_weight = padded_weights[expert_idx]
            current_state_static = hidden_states.reshape(-1, D)
            w_output = silu_and_mul_wrapper(
                self.w13_list[expert_idx].calc(current_state_static, expert_idx, w1))
            w_output = self.w2_list[expert_idx].calc(w_output, expert_idx, w2)
            current_hidden_states_static = w_output * padded_weight
            final_hidden_states += current_hidden_states_static
            htorch.core.mark_step()

        return final_hidden_states.view(-1, D)


@hpu_utils.with_mark_steps
def prompt_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    query_heads = query.size(1)
    kv_heads = key.size(1)
    if query_heads != kv_heads:
        query = query.unflatten(1, (kv_heads, -1))
        key = key.unflatten(1, (kv_heads, 1))
        value = value.unflatten(1, (kv_heads, 1))
        attn_bias = attn_bias.unsqueeze(2)
    attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
    if attn_bias is not None:
        attn_weights.add_(attn_bias)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    attn_weights = torch.matmul(attn_weights, value)
    if query_heads != kv_heads:
        attn_weights = attn_weights.flatten(1, 2)
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights
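The net effect in vllm/hpu/ops.py is that the HPU MoE path is reorganized around nn.Module objects: each expert's fused gate/up (w13) and down (w2) projection now lives in its own MoeMatmul submodule of StaticFusedMOE instead of a bare torch.matmul inside a free function. A plausible reading, given the commit title, is that module-level quantization flows such as HQT work by wrapping or replacing nn.Module instances, so each expert matmul needs to be addressable as a module. Below is a minimal CPU-only sketch of that point; TinyStaticFusedMOE and the printed module names are illustrative assumptions, not part of the commit.

    import torch
    import torch.nn as nn

    class MoeMatmul(nn.Module):
        def set_weight(self, w):
            self.weight = w

        def forward(self, state):
            return torch.matmul(state, self.weight)

    class TinyStaticFusedMOE(nn.Module):
        # Hypothetical, stripped-down stand-in for StaticFusedMOE.
        def __init__(self, num_total_experts):
            super().__init__()
            self.w13_list = nn.ModuleList(
                [MoeMatmul() for _ in range(num_total_experts)])
            self.w2_list = nn.ModuleList(
                [MoeMatmul() for _ in range(num_total_experts)])

    moe = TinyStaticFusedMOE(num_total_experts=4)
    # Every expert matmul is now an addressable submodule (e.g. "w13_list.3"),
    # which gives a module-patching quantization tool something to hook onto;
    # a torch.matmul buried inside a plain function is invisible to such tooling.
    print([name for name, m in moe.named_modules() if isinstance(m, MoeMatmul)])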
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
@@ -24,7 +24,7 @@ def get_model_architecture(
    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    if (model_config.quantization is not None
            and model_config.quantization != "fp8"
            and model_config.quantization not in ["fp8", "hqt"]
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

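The effect of this change: a Mixtral checkpoint whose quantization method is "hqt" now keeps the regular MixtralForCausalLM architecture (and with it the HPU StaticFusedMOE path above), just like "fp8", instead of being rerouted to QuantMixtralForCausalLM. A standalone sketch of the resulting behaviour; resolve_architectures is a hypothetical helper, while the real check sits inside get_model_architecture and reads ModelConfig.

    def resolve_architectures(architectures, quantization):
        # Quantized Mixtral normally falls back to QuantMixtralForCausalLM;
        # fp8, and with this commit hqt, keep the default fused-MoE model class.
        if (quantization is not None
                and quantization not in ["fp8", "hqt"]
                and "MixtralForCausalLM" in architectures):
            return ["QuantMixtralForCausalLM"]
        return architectures

    print(resolve_architectures(["MixtralForCausalLM"], "hqt"))  # ['MixtralForCausalLM']
    print(resolve_architectures(["MixtralForCausalLM"], "awq"))  # ['QuantMixtralForCausalLM']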
23 changes: 16 additions & 7 deletions vllm/model_executor/models/mixtral.py
@@ -55,7 +55,7 @@
from vllm.utils import is_hpu, print_warning_once

if is_hpu():
    from vllm.hpu.ops import static_fused_moe
    from vllm.hpu.ops import StaticFusedMOE

from .interfaces import SupportsLoRA

@@ -87,6 +87,9 @@ def __init__(
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        if is_hpu():
            self.static_fused_moe = StaticFusedMOE(self.num_total_experts)

        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -180,11 +183,16 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
            if is_hpu():
                self.static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
            param_data[expert_id, shard_size:2 * shard_size, :] = loaded_weight[shard, :]
            if is_hpu():
                self.static_fused_moe.w13_list[expert_id].set_weight(param_data[expert_id])
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
            if is_hpu():
                self.static_fused_moe.w2_list[expert_id].set_weight(param_data[expert_id])

        # Loading scales
        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
@@ -278,10 +286,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)

        if is_hpu():
            final_hidden_states = static_fused_moe(hidden_states,
                                                   self.w13_weight,
                                                   self.w2_weight,
                                                   router_logits, self.top_k)
            final_hidden_states = self.static_fused_moe(hidden_states,
                                                        self.w13_weight,
                                                        self.w2_weight,
                                                        router_logits,
                                                        self.top_k)
        else:
            final_hidden_states = fused_moe(hidden_states,
                                            self.w13_weight,
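Taken together, the intended end-to-end flow is presumably the usual vLLM entry point on a Gaudi/HPU build with quantization set to "hqt". A hedged usage sketch follows; the model name, prompt, and sampling settings are illustrative, and HQT itself additionally needs Habana's quantization toolkit configured outside of vLLM.

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="hqt")
    outputs = llm.generate(["What does mixture-of-experts routing do?"],
                           SamplingParams(temperature=0.0, max_tokens=64))
    print(outputs[0].outputs[0].text)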
