Skip to content

Commit

Permalink
fix: qlora
Browse files Browse the repository at this point in the history
  • Loading branch information
00INDEX committed Aug 1, 2023
1 parent 6630e9d commit 1871bcb
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
17 changes: 8 additions & 9 deletions collie/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
其余 ``kwargs`` 的内容会用于设置 :class:`.CollieConfig` 的内容。
"""

process_exclusion = kwargs.pop("process_exclusion", False)
if dist.is_initialized() and process_exclusion:
logger.warning(
Expand Down Expand Up @@ -301,9 +300,7 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
)
# load state dict
state_dict = {}
- if not is_zero3_enabled(config) or env.dp_rank == 0 \
- or config.low_cpu_mem_usage or config.quantization_config.load_in_8bit \
- or getattr(config.quantization_config, "load_in_4bit", False):
+ if not is_zero3_enabled(config) or env.dp_rank == 0:
state_dict = cls.load_parallel_state_dict(
path=model_path_or_name, config=config,
process_exclusion=process_exclusion, **kwargs
Expand All @@ -314,7 +311,7 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
state_dict[key_pp] = state_dict.pop(key)
# load checkpoint and dispatch
for name, param in model.named_parameters():
- if name not in state_dict.keys():
+ if name not in state_dict.keys() and (not is_zero3_enabled(config) or env.dp_rank == 0):
logger.rank_zero_warning(f"Missing key: {name}!")
continue
contexts = []
Expand All @@ -331,17 +328,19 @@ def from_pretrained(cls, model_path_or_name: str, config: Union[CollieConfig, st
module=model,
tensor_name=name,
device="cpu" if param.device == torch.device("meta") else param.device,
- value=state_dict[name].data
+ value=state_dict.get(name, torch.empty_like(param.data).to(param.dtype)).data
)
else:
if param.device == torch.device("meta"):
set_module_tensor_to_device(
module=model, tensor_name=name, device="cpu" if param.device == torch.device("meta") else param.device,
- value=state_dict[name].data, dtype=config.model_config.torch_dtype
+ value=state_dict.get(name, torch.empty_like(param.data).to(param.dtype)).data,
+ dtype=config.model_config.torch_dtype
)
else:
- assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
- param.data = state_dict[name].data.to(config.model_config.torch_dtype).to(param.device)
+ if name in state_dict:
+ assert param.data.shape == state_dict[name].data.shape, f"The shape of the parameter corresponding to the `{name}` does not match: {param.data.shape} vs {state_dict[name].data.shape}"
+ param.data = state_dict.get(name, torch.empty_like(param.data).to(param.dtype)).data.to(config.model_config.torch_dtype).to(param.device)
if config.peft_config.peft_type is not None:
model = get_peft_model(model, config.peft_config)
model.print_trainable_parameters()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ wandb
pandas
psutil
accelerate>=0.20.3
- bitsandbytes>=0.39.0
+ bitsandbytes>=0.41.0
scipy
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="collie-lm",
- version="1.0.3",
+ version="1.0.4",
description="CoLLiE: Collaborative Tuning of Large Language Models in an Efficient Way",
author="OpenLMLab",
author_email="[email protected]",
Expand Down

0 comments on commit 1871bcb

Please sign in to comment.