fix: blip does not require sharding
ChristopherCho committed Aug 6, 2024
1 parent 0121445 · commit 414040f
Showing 1 changed file with 16 additions and 9 deletions.
vllm/model_executor/models/blip2.py (16 additions, 9 deletions)
@@ -646,16 +646,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)
             use_default_weight_loading = False
-            for (param_name, weight_name,
-                 shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                param = params_dict[name.replace(weight_name, param_name)]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
+            if "vision" in name:
+                if self.vision_model is not None:
+                    # We only do sharding for language model and
+                    # not vision model for now.
+                    use_default_weight_loading = True
             else:
-                use_default_weight_loading = True
+                for (param_name, weight_name,
+                     shard_id) in stacked_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    param = params_dict[name.replace(weight_name, param_name)]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    use_default_weight_loading = True
+
             if use_default_weight_loading:
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
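Note (not part of the commit): the change routes any weight whose name contains "vision" through the default weight-loading path, while the language-model weights keep using the stacked-parameter (sharded) loaders from stacked_params_mapping. A minimal standalone sketch of the difference between the two paths is below; the helpers toy_default_load and toy_stacked_qkv_load are illustrative only and are not vllm APIs.

# Illustrative sketch only; these helpers are hypothetical, not vllm code.
import torch


def toy_default_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Default path: the checkpoint tensor maps one-to-one onto the model
    # parameter, so it is copied in unchanged. After this commit, BLIP
    # vision weights take this path.
    assert param.shape == loaded_weight.shape
    param.copy_(loaded_weight)


def toy_stacked_qkv_load(qkv_param: torch.Tensor, loaded_weight: torch.Tensor,
                         shard_id: str, hidden_size: int) -> None:
    # Stacked/sharded path: separate q_proj / k_proj / v_proj checkpoint
    # tensors are packed into one fused qkv parameter, so each one is
    # written into its own slice. Language-model weights keep this path.
    offset = {"q": 0, "k": hidden_size, "v": 2 * hidden_size}[shard_id]
    qkv_param[offset:offset + hidden_size].copy_(loaded_weight)


if __name__ == "__main__":
    hidden_size = 4
    vision_param = torch.zeros(hidden_size, hidden_size)
    toy_default_load(vision_param, torch.ones(hidden_size, hidden_size))

    qkv_param = torch.zeros(3 * hidden_size, hidden_size)
    toy_stacked_qkv_load(qkv_param, torch.full((hidden_size, hidden_size), 2.0),
                         "k", hidden_size)

In the actual diff, the vision branch sets use_default_weight_loading, and the final getattr falls back to a plain copy when the parameter does not define a specialized weight_loader.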
