Migrate HFDeepSpeedConfig from trfrs to accelerate #432

Merged
3 changes: 2 additions & 1 deletion README.md
@@ -241,4 +241,5 @@ pip install accelerate
- multi-GPU on several nodes (machines)
- TPU
- FP16 with native AMP (apex on the roadmap)
- DeepSpeed support (experimental)
- DeepSpeed support (Experimental)
- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental)
23 changes: 3 additions & 20 deletions src/accelerate/accelerator.py
@@ -19,7 +19,6 @@
import sys
import warnings
from contextlib import contextmanager
from copy import deepcopy
from typing import List, Optional, Union

import torch
@@ -50,7 +49,6 @@
is_deepspeed_available,
is_torch_version,
is_tpu_available,
is_transformers_available,
pad_across_processes,
reduce,
save,
@@ -183,25 +181,10 @@ def __init__(
raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
if compare_versions("deepspeed", "<", "0.6.5"):
raise ImportError("DeepSpeed version must be >= 0.6.5. Please update DeepSpeed.")
if os.environ.get("DEEPSPEED_ZERO3_INIT", "false") == "true" or deepspeed_plugin.zero3_init_flag:
if not is_transformers_available():
raise Exception(
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
"Please run `pip install transformers`."
)
from transformers.deepspeed import HfDeepSpeedConfig

ds_config = deepcopy(deepspeed_plugin.deepspeed_config)
del ds_config["train_batch_size"]
ds_config.update({"train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1})
mixed_precision = (
os.environ.get("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
if mixed_precision == "fp16":
ds_config.update({"fp16": {"enabled": True}})
elif mixed_precision == "bf16":
ds_config.update({"bf16": {"enabled": True}})
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
mixed_precision = os.environ.get("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
deepspeed_plugin.set_mixed_precision(mixed_precision)
deepspeed_plugin.set_deepspeed_weakref()

if os.environ.get("USE_FSDP", "false") == "true" or isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
if is_torch_version("<", "1.12.0.dev20220418+cu113"):
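With the inline `HfDeepSpeedConfig` handling removed, the `Accelerator` now only resolves the mixed-precision setting and delegates the DeepSpeed-specific work to the plugin. A minimal caller-side sketch of the resulting flow, assuming DeepSpeed is installed and the process runs in a DeepSpeed-enabled environment (e.g. started via `accelerate launch`):

```python
# Sketch only, not an authoritative example: the two delegated calls shown in the
# diff above replace the inline HfDeepSpeedConfig construction that used to live here.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)

# Internally the Accelerator now effectively runs:
#   deepspeed_plugin.set_mixed_precision("fp16")  # injects {"fp16": {"enabled": True}} if absent
#   deepspeed_plugin.set_deepspeed_weakref()      # keeps an HfDeepSpeedConfig alive for zero.Init()
```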
2 changes: 1 addition & 1 deletion src/accelerate/commands/launch.py
@@ -353,7 +353,7 @@ def deepspeed_launcher(args):
current_env["USE_DEEPSPEED"] = "true"
current_env["DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage)
current_env["GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps)
current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping)
current_env["GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower()
current_env["DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower()
current_env["DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower()
current_env["DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower()
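These settings reach the worker processes as environment-variable strings, which is why `str(args.gradient_clipping).lower()` matters: an unset value now arrives as the literal `"none"` rather than `"None"`. A small illustration of the read-back side, mirroring the variable names above:

```python
# Illustration only: parsing the lower-cased strings exported by deepspeed_launcher.
import os

zero3_init_flag = os.environ.get("DEEPSPEED_ZERO3_INIT", "false") == "true"
gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))

gradient_clipping = os.environ.get("GRADIENT_CLIPPING", "none")
gradient_clipping = None if gradient_clipping == "none" else float(gradient_clipping)

offload_optimizer_device = os.environ.get("DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")
```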
15 changes: 0 additions & 15 deletions src/accelerate/state.py
@@ -108,21 +108,6 @@ def __init__(
self.device = torch.device("cuda", self.local_process_index)
torch.cuda.set_device(self.device)
self.mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config
mixed_precision = (
parse_choice_from_env("MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision
)
if (
mixed_precision == "fp16"
and "fp16" not in deepspeed_plugin.deepspeed_config
and "bf16" not in deepspeed_plugin.deepspeed_config
):
deepspeed_plugin.deepspeed_config.update({"fp16": {"enabled": True}})
elif (
mixed_precision == "bf16"
and "fp16" not in deepspeed_plugin.deepspeed_config
and "bf16" not in deepspeed_plugin.deepspeed_config
):
deepspeed_plugin.deepspeed_config.update({"bf16": {"enabled": True}})
self.deepspeed_plugin = deepspeed_plugin
elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu:
self.distributed_type = DistributedType.MULTI_GPU
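The precision-injection logic removed here is not dropped; it moves into `DeepSpeedPlugin.set_mixed_precision` (see the `dataclasses.py` changes below), and a precision section already present in the user's DeepSpeed config still takes precedence. A hedged sketch of that behaviour, assuming a plugin built from keyword arguments rather than a config file:

```python
# Sketch only: the plugin injects a precision block unless fp16/bf16 is already defined.
from accelerate.utils import DeepSpeedPlugin

plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1, gradient_clipping=1.0)
plugin.set_mixed_precision("bf16")
assert plugin.deepspeed_config["bf16"] == {"enabled": True}

plugin.set_mixed_precision("fp16")  # no-op: a bf16 section is already present
assert "fp16" not in plugin.deepspeed_config
```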
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
@@ -83,6 +83,7 @@
DeepSpeedSchedulerWrapper,
DummyOptim,
DummyScheduler,
HfDeepSpeedConfig,
)

from .launch import PrepareForLaunch, get_launch_prefix
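With this re-export, user code can import the config wrapper from Accelerate itself instead of `transformers.deepspeed`:

```python
# The new import location introduced by this PR.
from accelerate.utils import HfDeepSpeedConfig
```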
94 changes: 62 additions & 32 deletions src/accelerate/utils/dataclasses.py
@@ -19,14 +19,12 @@
import copy
import enum
import functools
import io
import json
import os
import typing
import warnings
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Callable, Iterable, Optional
from typing import Any, Callable, Iterable, Optional

import torch

@@ -215,7 +213,12 @@ class DeepSpeedPlugin:
This plugin is used to integrate DeepSpeed.
"""

config_file: str = field(default=None, metadata={"help": "Path to the DeepSpeed config file."})
hf_ds_config: Any = field(
default=None,
metadata={
"help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`."
},
)
gradient_accumulation_steps: int = field(
default=None, metadata={"help": "Number of steps to accumulate gradients before updating optimizer states"}
)
@@ -249,17 +252,23 @@ class DeepSpeedPlugin:
)

def __post_init__(self):
if self.config_file is None:
self.config_file = os.environ.get("DEEPSPEED_CONFIG_FILE", "none")
if self.config_file != "none":
with io.open(self.config_file, "r", encoding="utf-8") as f:
self.deepspeed_config = json.load(f)
if "gradient_accumulation_steps" not in self.deepspeed_config:
self.deepspeed_config["gradient_accumulation_steps"] = 1
elif self.deepspeed_config["gradient_accumulation_steps"] == "auto":
raise ValueError("gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config file.")
if "zero_optimization" not in self.deepspeed_config:
raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config file.")
from .deepspeed import HfDeepSpeedConfig

if self.hf_ds_config is None:
self.hf_ds_config = os.environ.get("DEEPSPEED_CONFIG_FILE", "none")
if (
isinstance(self.hf_ds_config, dict)
or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none")
or isinstance(self.hf_ds_config, HfDeepSpeedConfig)
):
if not isinstance(self.hf_ds_config, HfDeepSpeedConfig):
self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config)
if "gradient_accumulation_steps" not in self.hf_ds_config.config:
self.hf_ds_config.config["gradient_accumulation_steps"] = 1
elif self.hf_ds_config.config["gradient_accumulation_steps"] == "auto":
raise ValueError("gradient_accumulation_steps cannot be set to 'auto' in the DeepSpeed config.")
if "zero_optimization" not in self.hf_ds_config.config:
raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.")
else:
if self.gradient_accumulation_steps is None:
self.gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))
@@ -281,8 +290,9 @@ def __post_init__(self):
if self.zero3_save_16bit_model is None:
self.zero3_save_16bit_model = os.environ.get("DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true"

self.deepspeed_config = {
config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"zero_optimization": {
"stage": self.zero_stage,
@@ -296,29 +306,18 @@
},
}
if self.gradient_clipping:
self.deepspeed_config["gradient_clipping"] = self.gradient_clipping
config["gradient_clipping"] = self.gradient_clipping
self.hf_ds_config = HfDeepSpeedConfig(config)
self.deepspeed_config = self.hf_ds_config.config
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
if self.zero3_init_flag is None:
self.zero3_init_flag = os.environ.get("DEEPSPEED_ZERO3_INIT", "false") == "true"
if self.zero3_init_flag and self.deepspeed_config["zero_optimization"]["stage"] != 3:
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
self.zero3_init_flag = False

def find_config_node(self, ds_key_long):
config = self.deepspeed_config

# find the config node of interest if it exists
nodes = ds_key_long.split(".")
ds_key = nodes.pop()
for node in nodes:
config = config.get(node)
if config is None:
return None, ds_key

return config, ds_key

def fill_match(self, ds_key_long, mismatches, must_match=True, **kwargs):
config, ds_key = self.find_config_node(ds_key_long)
config, ds_key = self.hf_ds_config.find_config_node(ds_key_long)
if config is None:
return

@@ -360,6 +359,37 @@ def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must
f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
)

def set_mixed_precision(self, mixed_precision):
ds_config = self.deepspeed_config
if mixed_precision == "fp16" and "fp16" not in ds_config and "bf16" not in ds_config:
ds_config.update({"fp16": {"enabled": True}})
elif mixed_precision == "bf16" and "fp16" not in ds_config and "bf16" not in ds_config:
ds_config.update({"bf16": {"enabled": True}})

def set_deepspeed_weakref(self):
from .imports import is_transformers_available

if self.zero3_init_flag:
if not is_transformers_available():
raise Exception(
"When `zero3_init_flag` is set, it requires Transformers to be installed. "
"Please run `pip install transformers`."
)
ds_config = copy.deepcopy(self.deepspeed_config)
if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
ds_config["gradient_accumulation_steps"] = 1
if (
"train_micro_batch_size_per_gpu" not in ds_config
or ds_config["train_micro_batch_size_per_gpu"] == "auto"
):
ds_config["train_micro_batch_size_per_gpu"] = 1
if ds_config["train_batch_size"] == "auto":
del ds_config["train_batch_size"]

from transformers.deepspeed import HfDeepSpeedConfig

self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa


@dataclass
class FullyShardedDataParallelPlugin:
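Taken together, the new `hf_ds_config` field accepts a file path, a plain dict, or a pre-built `HfDeepSpeedConfig`, and `__post_init__` normalizes all three to the same wrapper. A hedged usage sketch (the dict below and the `ds_config.json` path are illustrative placeholders):

```python
# Sketch only: three interchangeable ways to hand a DeepSpeed config to the plugin.
from accelerate.utils import DeepSpeedPlugin, HfDeepSpeedConfig

ds_dict = {
    "train_batch_size": "auto",
    "gradient_accumulation_steps": 1,
    "zero_optimization": {"stage": 3},
}

plugin_from_dict = DeepSpeedPlugin(hf_ds_config=ds_dict)
plugin_from_obj = DeepSpeedPlugin(hf_ds_config=HfDeepSpeedConfig(ds_dict))
# plugin_from_file = DeepSpeedPlugin(hf_ds_config="ds_config.json")  # hypothetical path on disk

assert plugin_from_dict.hf_ds_config.is_zero3()
```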
125 changes: 124 additions & 1 deletion src/accelerate/utils/deepspeed.py
@@ -12,9 +12,132 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from accelerate.scheduler import AcceleratedScheduler
import io
import json
from copy import deepcopy

from ..optimizer import AcceleratedOptimizer
from ..scheduler import AcceleratedScheduler


class HfDeepSpeedConfig:
"""
This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.

A `weakref` of this object is stored in the module's globals so that the config can be accessed from places where
the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). It is therefore
important that this object remains alive while the program is still running.

[`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
the DeepSpeed configuration is not modified in any way.

Args:
config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.

"""

def __init__(self, config_file_or_dict):

if isinstance(config_file_or_dict, dict):
# Don't modify the user's data should they want to reuse it (e.g. in tests), because once we
# modify it, it will not be accepted here again, since the `auto` values would have been overridden
config = deepcopy(config_file_or_dict)
elif isinstance(config_file_or_dict, str):
with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
config = json.load(f)
else:
raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict")
self.config = config

# zero stage - this is done as early as possible, before model is created, to allow
# ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
# during ``zero.Init()`` which needs to know the dtype, and some other hparams.
self._stage = self.get_value("zero_optimization.stage", -1)

# offload
self._offload = False
if self.is_zero2() or self.is_zero3():
offload_devices_valid = set(["cpu", "nvme"])
offload_devices = set(
[
self.get_value("zero_optimization.offload_optimizer.device"),
self.get_value("zero_optimization.offload_param.device"),
]
)
if len(offload_devices & offload_devices_valid) > 0:
self._offload = True

def find_config_node(self, ds_key_long):
config = self.config

# find the config node of interest if it exists
nodes = ds_key_long.split(".")
ds_key = nodes.pop()
for node in nodes:
config = config.get(node)
if config is None:
return None, ds_key

return config, ds_key

def get_value(self, ds_key_long, default=None):
"""
Returns the set value or `default` if no value is set
"""
config, ds_key = self.find_config_node(ds_key_long)
if config is None:
return default
return config.get(ds_key, default)

def del_config_sub_tree(self, ds_key_long, must_exist=False):
"""
Deletes a sub-section of the config file if it's found.

Unless `must_exist` is `True`, the section doesn't have to exist.
"""
config = self.config

# find the config node of interest if it exists
nodes = ds_key_long.split(".")
for node in nodes:
parent_config = config
config = config.get(node)
if config is None:
if must_exist:
raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
else:
return

# if found remove it
if parent_config is not None:
parent_config.pop(node)

def is_true(self, ds_key_long):
"""
Returns `True`/`False` only if the value is set, always `False` otherwise. So use this method to ask the very
specific question of whether the value is set to `True` (and it's not set to `False` or isn't set).

"""
value = self.get_value(ds_key_long)
return False if value is None else bool(value)

def is_false(self, ds_key_long):
"""
Returns `True`/`False` only if the value is set, always `False` otherwise. So use this method to ask the very
specific question of whether the value is set to `False` (and it's not set to `True` or isn't set).
"""
value = self.get_value(ds_key_long)
return False if value is None else not bool(value)

def is_zero2(self):
return self._stage == 2

def is_zero3(self):
return self._stage == 3

def is_offload(self):
return self._offload


class DeepSpeedEngineWrapper:
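A short usage sketch of the query helpers defined above; the config dict is illustrative only:

```python
# Sketch only: querying a DeepSpeed config with HfDeepSpeedConfig's dotted-key helpers.
from accelerate.utils import HfDeepSpeedConfig

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu"},
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": True},
}

hf_ds_config = HfDeepSpeedConfig(ds_config)

assert hf_ds_config.get_value("zero_optimization.stage") == 3
assert hf_ds_config.is_zero3()
assert hf_ds_config.is_offload()  # offload_param.device is "cpu"
assert hf_ds_config.is_true("zero_optimization.stage3_gather_16bit_weights_on_model_save")

hf_ds_config.del_config_sub_tree("bf16")  # removes the whole "bf16" section
assert hf_ds_config.get_value("bf16.enabled") is None
```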