Skip to content

Commit

Permalink
resolving zero3 init when using accelerate config with Trainer (huggingface#25227)
Browse files Browse the repository at this point in the history

* resolving zero3 init when using accelerate config with Trainer

* refactor

* fix

* fix import
  • Loading branch information
pacman100 authored and blbadger committed Nov 8, 2023
1 parent 947918f commit 9ad2ab6
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,12 @@ def __post_init__(self):

os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")):
# Accelerate DeepSpeed Plugin
from accelerate.utils import DeepSpeedPlugin

self.deepspeed_plugin = DeepSpeedPlugin()
self.deepspeed_plugin.set_deepspeed_weakref()

if self.push_to_hub_token is not None:
warnings.warn(
Expand Down

0 comments on commit 9ad2ab6

Please sign in to comment.