Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ds nvme #34444

Merged
merged 14 commits on Nov 21, 2024
26 changes: 22 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@

_init_weights = True
_is_quantized = False
_is_ds_init_called = False


def is_fsdp_enabled():
Expand Down Expand Up @@ -226,6 +227,19 @@ def set_quantized_state():
_is_quantized = False


# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
# This issue occurs with ZeRO stage 3 when using NVMe offloading.
# For more details, refer to issue #34429.
@contextmanager
def set_zero3_state():
    """Mark a DeepSpeed ZeRO-3 ``zero.Init`` as currently in progress.

    While the context is active, the module-level flag ``_is_ds_init_called``
    is set to ``True`` so that nested model constructions can detect the
    ongoing init and skip re-entering ``deepspeed.zero.Init`` (re-entry
    causes parameter-pinning errors under ZeRO stage 3 with NVMe offloading;
    see issue #34429).

    The flag is always restored to ``False`` on exit, even if the wrapped
    code raises, so a failed init cannot leave the skip flag stuck on.
    """
    global _is_ds_init_called
    _is_ds_init_called = True
    try:
        yield
    finally:
        # finally (not a plain trailing statement) guarantees the reset
        # runs on exceptions raised inside the `with` body as well.
        _is_ds_init_called = False


def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
Expand Down Expand Up @@ -1473,13 +1487,14 @@ def _from_config(cls, config, **kwargs):
torch_dtype=torch_dtype,
)

if is_deepspeed_zero3_enabled() and not _is_quantized:
if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)

else:
Expand Down Expand Up @@ -4026,11 +4041,14 @@ def from_pretrained(
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized:
if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
if not is_accelerate_available():
raise ImportError(
Expand Down
Loading