[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order (#11528)

Signed-off-by: ElizaWszola <[email protected]>
Co-authored-by: ElizaWszola <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
3 people authored Jan 23, 2025
1 parent 2cbeeda commit eb5cb5e
Showing 8 changed files with 243 additions and 148 deletions.
94 changes: 60 additions & 34 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):

@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
- hidden_size: int, intermediate_size: int,
+ hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError

@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

def create_weights(self, layer: torch.nn.Module, num_experts: int,
- hidden_size: int, intermediate_size: int,
+ hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Fused gate_up_proj (column parallel)
- w13_weight = torch.nn.Parameter(torch.empty(num_experts,
- 2 * intermediate_size,
- hidden_size,
- dtype=params_dtype),
+ w13_weight = torch.nn.Parameter(torch.empty(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ hidden_size,
+ dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)

# down_proj (row parallel)
- w2_weight = torch.nn.Parameter(torch.empty(num_experts,
- hidden_size,
- intermediate_size,
- dtype=params_dtype),
+ w2_weight = torch.nn.Parameter(torch.empty(
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition,
+ dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
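
For reference, a small sketch (not part of the commit) of how the renamed intermediate_size_per_partition argument relates to the weight shapes allocated above; the concrete sizes and the even split across tensor-parallel ranks are assumptions.

import torch

# Assumed toy sizes; intermediate_size is taken to divide evenly by tp_size.
num_experts, hidden_size, intermediate_size, tp_size = 4, 64, 256, 4
intermediate_size_per_partition = intermediate_size // tp_size  # 64

# Fused gate_up_proj (column parallel): the output dim is the sharded one.
w13_weight = torch.empty(num_experts, 2 * intermediate_size_per_partition, hidden_size)
# down_proj (row parallel): the input dim is the sharded one.
w2_weight = torch.empty(num_experts, hidden_size, intermediate_size_per_partition)

assert w13_weight.shape == (4, 128, 64)
assert w2_weight.shape == (4, 64, 64)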
@@ -289,13 +291,20 @@ def __init__(
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None

- self.quant_method.create_weights(
- layer=self,
- num_experts=num_experts,
- hidden_size=hidden_size,
- intermediate_size=self.intermediate_size_per_partition,
- params_dtype=params_dtype,
- weight_loader=self.weight_loader)
+ moe_quant_params = {
+ "num_experts": num_experts,
+ "hidden_size": hidden_size,
+ "intermediate_size_per_partition":
+ self.intermediate_size_per_partition,
+ "params_dtype": params_dtype,
+ "weight_loader": self.weight_loader,
+ }
+ # need full intermediate size pre-sharding for WNA16 act order
+ if (self.quant_method.__class__.__name__ ==
+ "CompressedTensorsWNA16MoEMethod"):
+ moe_quant_params["intermediate_size_full"] = intermediate_size
+
+ self.quant_method.create_weights(layer=self, **moe_quant_params)

def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
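
To make the new kwargs-dict pattern concrete, here is a toy sketch. It is not the real CompressedTensorsWNA16MoEMethod (whose definition is not in this diff); the class, attribute names, and sizes are all illustrative stand-ins for a method whose act-order scales need the full, pre-sharding intermediate size.

class ToyWNA16MoEMethod:
    # Hypothetical stand-in; accepts the optional pre-sharding size.
    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, params_dtype,
                       weight_loader, intermediate_size_full=None):
        # Act-order group scales span the full intermediate size when given.
        layer["scale_cols"] = (intermediate_size_full
                               if intermediate_size_full is not None
                               else intermediate_size_per_partition)

layer, method = {}, ToyWNA16MoEMethod()
moe_quant_params = {"num_experts": 4, "hidden_size": 64,
                    "intermediate_size_per_partition": 16,
                    "params_dtype": "float16", "weight_loader": None}
if method.__class__.__name__ == "ToyWNA16MoEMethod":  # mirrors the name check above
    moe_quant_params["intermediate_size_full"] = 64
method.create_weights(layer=layer, **moe_quant_params)
assert layer["scale_cols"] == 64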
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

- def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+ def _load_model_weight_or_group_weight_scale(self,
+ shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
- tp_rank: int):
- # Load grouped weight scales for group quantization
- # or model weights
+ tp_rank: int,
+ load_full_w2: bool = False):
+ """
+ Load grouped weight scales for group quantization or model weights
+ :param shard_dim: dimension to shard
+ :param expert_data: parameter for a particular expert
+ :param shard_id: either w1, w2, or w3
+ :param loaded_weight: checkpoint weight to load into the param
+ :param tp_rank: tensor parallel rank
+ :param load_full_w2: whether or not the w2 loaded should be sharded.
+ """
if shard_id == "w2":
- self._load_w2(shard_id=shard_id,
- shard_dim=shard_dim,
+ # In the case where we have actorder/g_idx, we do not partition the
+ # w2 scales, as indicated by `load_full` argument, for all tp cases
+ self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
- tp_rank=tp_rank)
+ tp_rank=tp_rank,
+ load_full=load_full_w2)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
@@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

- def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
- shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+ def _load_w2(self,
+ expert_data: torch.Tensor,
+ shard_dim: int,
+ loaded_weight: torch.Tensor,
+ tp_rank: int,
+ load_full: bool = False):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
- loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
- shard_size)
+ if not load_full:
+ loaded_weight = loaded_weight.narrow(shard_dim,
+ shard_size * tp_rank,
+ shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
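
A standalone sketch of the narrow-versus-full behaviour introduced here, using made-up shapes rather than real checkpoint tensors:

import torch

def load_w2(expert_data, loaded_weight, shard_dim, tp_rank, load_full=False):
    # Mirrors the logic above: shard the checkpoint tensor unless load_full is set.
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

checkpoint = torch.arange(8.0).reshape(1, 8)          # toy w2 scale row
sharded = torch.empty(1, 4)
load_w2(sharded, checkpoint, shard_dim=1, tp_rank=1)  # copies columns 4..7
full = torch.empty(1, 8)
load_w2(full, checkpoint, shard_dim=1, tp_rank=1, load_full=True)  # copies all 8
assert sharded.tolist() == [[4.0, 5.0, 6.0, 7.0]]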

@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

if shard_id == "w2":
- self._load_w2(shard_id=shard_id,
- shard_dim=shard_dim,
+ self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
@@ -416,19 +441,19 @@ def weight_loader(self, param: torch.nn.Parameter,
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
- # dimension intermediate_size is used.
+ # dimension intermediate_size_per_partition is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
- # should be whatever dimension intermediate_size is
+ # should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
- shard_dim = ~shard_dim
+ shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
@@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
- tp_rank=tp_rank)
+ tp_rank=tp_rank,
+ load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
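
The `getattr(param, "load_full_w2", False)` lookup above implies the flag travels on the parameter itself. A minimal sketch of that pattern follows; the attribute name is taken from this diff, the tensor shape and the way the flag is set are assumptions.

import torch

w2_scales = torch.nn.Parameter(torch.empty(2, 3, 4), requires_grad=False)
w2_scales.load_full_w2 = True   # e.g. attached via set_weight_attrs by a quant method
assert getattr(w2_scales, "load_full_w2", False) is True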
35 changes: 19 additions & 16 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module, num_experts: int,
- hidden_size: int, intermediate_size: int,
+ hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
extra_weight_attrs.update({
"is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
FusedMoeWeightScaleSupported.GROUP.value,
})

- w13_qweight = Parameter(torch.empty(num_experts,
- hidden_size,
- 2 * intermediate_size //
- self.quant_config.pack_factor,
- dtype=torch.int32),
- requires_grad=False)
+ w13_qweight = Parameter(
+ torch.empty(num_experts,
+ hidden_size,
+ 2 * intermediate_size_per_partition //
+ self.quant_config.pack_factor,
+ dtype=torch.int32),
+ requires_grad=False)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)

w2_qweight = Parameter(torch.empty(num_experts,
- intermediate_size,
+ intermediate_size_per_partition,
hidden_size //
self.quant_config.pack_factor,
dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
set_weight_attrs(w2_qweight, extra_weight_attrs)

num_groups_w13 = hidden_size // self.quant_config.group_size
- num_groups_w2 = intermediate_size // self.quant_config.group_size
+ num_groups_w2 = (intermediate_size_per_partition //
+ self.quant_config.group_size)

# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
w13_scales = Parameter(torch.empty(num_experts,
num_groups_w13,
- intermediate_size * 2,
+ intermediate_size_per_partition * 2,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,

# WEIGHT_ZERO_POINT
# Allocate 2 zero points for w1 and w3 respectively.
- w13_qzeros = Parameter(torch.empty(num_experts,
- num_groups_w13,
- 2 * intermediate_size //
- self.quant_config.pack_factor,
- dtype=torch.int32),
- requires_grad=False)
+ w13_qzeros = Parameter(
+ torch.empty(num_experts,
+ num_groups_w13,
+ 2 * intermediate_size_per_partition //
+ self.quant_config.pack_factor,
+ dtype=torch.int32),
+ requires_grad=False)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
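
As a sanity check on the packed shapes above: for 4-bit AWQ weights, pack_factor is 32 // 4 = 8, so the packed dimension shrinks by 8x and the group counts follow from group_size. The sizes below are assumed, not taken from any real model.

num_experts, hidden_size = 8, 4096
intermediate_size_per_partition, group_size, pack_factor = 3584, 128, 8

w13_qweight_shape = (num_experts, hidden_size,
                     2 * intermediate_size_per_partition // pack_factor)
w2_qweight_shape = (num_experts, intermediate_size_per_partition,
                    hidden_size // pack_factor)
num_groups_w13 = hidden_size // group_size                     # 32
num_groups_w2 = intermediate_size_per_partition // group_size  # 28
print(w13_qweight_shape, w2_qweight_shape)
# (8, 4096, 896) (8, 3584, 512)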

(Diffs for the remaining 6 changed files are not shown here.)
