Commit
[Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models (vllm-project#5460)

Signed-off-by: Thomas Parnell <[email protected]>
tdoublep authored and jimpang committed Jul 8, 2024
1 parent c0d00cd commit de6975c
Showing 1 changed file with 7 additions and 1 deletion.
vllm/model_executor/models/gpt_bigcode.py (8 changes: 7 additions & 1 deletion)
@@ -299,4 +299,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
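Context on the change: GPTBigCode's fused c_attn projection is backed by vLLM's QKVParallelLinear, whose weight_loader accepts a shard id ('q', 'k', or 'v') so it can place each piece of a split checkpoint tensor into the fused parameter. An FP8 checkpoint stores a single input_scale/weight_scale for the whole fused c_attn tensor, so the fix replays that one value into each shard. Below is a minimal sketch of the idea; ToyQKVParam and shard_scales are hypothetical stand-ins for illustration, not vLLM APIs.

import torch

class ToyQKVParam:
    """Toy stand-in for a QKVParallelLinear scale parameter that keeps
    one scale value per logical shard ('q', 'k', 'v')."""

    def __init__(self) -> None:
        self.shard_scales: dict[str, torch.Tensor] = {}

    def weight_loader(self, param: "ToyQKVParam",
                      loaded_weight: torch.Tensor,
                      shard_id: str) -> None:
        # The real loader writes into the right slice of a fused tensor;
        # here we just record the checkpoint scale under each shard id.
        self.shard_scales[shard_id] = loaded_weight.clone()

# The checkpoint holds one scale for the fused c_attn tensor, so the
# commit replays it once per shard, exactly as in the diff above:
param = ToyQKVParam()
loaded_weight = torch.tensor(0.0125)  # example per-tensor FP8 scale
for shard_id in ('q', 'k', 'v'):
    param.weight_loader(param, loaded_weight, shard_id)

print(param.shard_scales)  # the same scale is recorded for q, k and v

Calling the loader once per shard reuses the fused layer's existing shard-aware loading path instead of special-casing scalar parameters, which is presumably why the TODO notes this could later move into the fp8 linear method itself.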
