forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Elastic training support (deepspeedai#602)
Co-authored-by: Samyam Rajbhandari <[email protected]>
Showing
16 changed files
with
883 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/usr/bin/env python | ||
|
||
import argparse | ||
import json | ||
|
||
import deepspeed | ||
from deepspeed.elasticity import compute_elastic_config | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-c', '--config', type=str, help="DeepSpeed config json") | ||
parser.add_argument('-w', '--world-size', type=int, default=0, help="Intended/current world size") | ||
args = parser.parse_args() | ||
ds_config = json.load(open(args.config, 'r')) | ||
|
||
ds_version = deepspeed.__version__ | ||
|
||
elastic_config = ds_config['elasticity'] | ||
print('------------------------------------------') | ||
print("Elasticity config:") | ||
print('------------------------------------------') | ||
print(json.dumps(elastic_config, indent=4, sort_keys=True)) | ||
|
||
if args.world_size > 0: | ||
final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version, world_size=args.world_size) | ||
print('------------------------------------------') | ||
print(f"Calculated results for world size {args.world_size}:") | ||
print('------------------------------------------') | ||
print(f'final_batch_size .... {final_batch_size}') | ||
print(f'valid_gpus .......... {valid_gpus}') | ||
print(f'micro_batch_size .... {micro_batch_size}') | ||
else: | ||
final_batch_size, valid_gpus = compute_elastic_config(ds_config=ds_config, target_deepspeed_version=ds_version) | ||
print('------------------------------------------') | ||
print("Calculated results:") | ||
print('------------------------------------------') | ||
print(f'final_batch_size .... {final_batch_size}') | ||
print(f'valid_gpus .......... {valid_gpus}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .elasticity import compute_elastic_config, elasticity_enabled, ensure_immutable_elastic_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
Copyright 2020 The Microsoft DeepSpeed Team | ||
""" | ||
|
||
import json | ||
from .constants import * | ||
|
||
|
||
class ElasticityError(Exception): | ||
""" | ||
Base exception for all elasticity related errors | ||
""" | ||
pass | ||
|
||
|
||
class ElasticityConfigError(ElasticityError): | ||
""" | ||
Elasticity configuration error | ||
""" | ||
pass | ||
|
||
|
||
class ElasticityIncompatibleWorldSize(ElasticityError): | ||
""" | ||
Attempting to run a world size that is incompatible with a given elastic config | ||
""" | ||
pass | ||
|
||
|
||
class ElasticityConfig: | ||
""" | ||
Elastic config object, constructed from a param dictionary that only contains elastic | ||
config parameters, example below: | ||
If elasticity is enabled, user must specify (at least) max_train_batch_size | ||
and micro_batch_sizes. | ||
{ | ||
"enabled": true, | ||
"max_train_batch_size": 2000, | ||
"micro_batch_sizes": [2,4,6], | ||
"min_gpus": 1, | ||
"max_gpus" : 10000 | ||
"min_time": 20 | ||
"ignore_non_elastic_batch_info": false | ||
"version": 0.1 | ||
} | ||
""" | ||
def __init__(self, param_dict): | ||
self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT) | ||
if self.enabled: | ||
if MAX_ACCEPTABLE_BATCH_SIZE in param_dict: | ||
self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE] | ||
else: | ||
raise ElasticityConfigError( | ||
f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}") | ||
if MICRO_BATCHES in param_dict: | ||
self.micro_batches = param_dict[MICRO_BATCHES] | ||
else: | ||
raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}") | ||
else: | ||
self.max_acceptable_batch_size = param_dict.get( | ||
MAX_ACCEPTABLE_BATCH_SIZE, | ||
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT) | ||
self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT) | ||
self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT) | ||
self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT) | ||
self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT) | ||
self.version = param_dict.get(VERSION, VERSION_DEFAULT) | ||
self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH, | ||
PREFER_LARGER_BATCH_DEFAULT) | ||
self.ignore_non_elastic_batch_info = param_dict.get( | ||
IGNORE_NON_ELASTIC_BATCH_INFO, | ||
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) | ||
|
||
def repr(self): | ||
return self.__dict__ | ||
|
||
def __repr__(self): | ||
return json.dumps(self.__dict__, sort_keys=True, indent=4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Copyright 2020 The Microsoft DeepSpeed Team | ||
""" | ||
|
||
######################################### | ||
# Elasticity | ||
######################################### | ||
''' Elasticity Utility in DeepSpeed can be used to create highly elastic jobs compatible | ||
with a large number of GPUs. For elastic jobs, DeepSpeed will provide a batch size that | ||
can support a large number of GPUs based on the user specified parameters | ||
''' | ||
FORMAT = ''' | ||
Elasticity should be enabled as: | ||
"elasticity": { | ||
"enabled": true, | ||
"max_train_batch_size": 2000, | ||
"micro_batch_sizes": [2,4,6], | ||
"min_gpus": 1, | ||
"max_gpus" : 10000 | ||
"min_time": 20, | ||
"prefer_larger_batch": true, | ||
"ignore_non_elastic_batch_info": false, | ||
"version": 0.1 | ||
} | ||
''' | ||
|
||
ELASTICITY = 'elasticity' | ||
|
||
# Current elasticity version | ||
LATEST_ELASTICITY_VERSION = 0.1 | ||
|
||
ENABLED = 'enabled' | ||
ENABLED_DEFAULT = False | ||
|
||
# Max acceptable train_batch_size | ||
MAX_ACCEPTABLE_BATCH_SIZE = 'max_train_batch_size' | ||
MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT = 2000 | ||
|
||
# Acceptable micro batch sizes, same as train_micro_batch_size_per_gpu | ||
MICRO_BATCHES = 'micro_batch_sizes' | ||
MICRO_BATCHES_DEFAULT = [2, 4, 6] | ||
|
||
# Min/max of GPUs to search over | ||
MIN_GPUS = 'min_gpus' | ||
MIN_GPUS_DEFAULT = 1 | ||
MAX_GPUS = 'max_gpus' | ||
MAX_GPUS_DEFAULT = 10000 | ||
|
||
# Minimum running time (minutes) before the scheduler will scale us | ||
MIN_TIME = "min_time" | ||
MIN_TIME_DEFAULT = "20" | ||
|
||
# When finding a suitable batch size, attempt to find one that is closest | ||
# to the max train batch size given. | ||
PREFER_LARGER_BATCH = 'prefer_larger_batch' | ||
PREFER_LARGER_BATCH_DEFAULT = True | ||
|
||
# In order to reduce confusion, if elastic mode is enabled we | ||
# require (via assert) that no batch info is set outside of the | ||
# elastic config. You can turn off this assert via this config | ||
# but keep in mind that all batch info defined outside the | ||
# elastic mode *will be ignored*. | ||
IGNORE_NON_ELASTIC_BATCH_INFO = 'ignore_non_elastic_batch_info' | ||
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT = False | ||
|
||
# Version of elastic logic to use | ||
VERSION = "version" | ||
VERSION_DEFAULT = LATEST_ELASTICITY_VERSION | ||
|
||
# Minimum deepspeed version to use elasticity | ||
MINIMUM_DEEPSPEED_VERSION = "0.3.8" | ||
|
||
# Environment variable storing elastic config from resource scheduler | ||
DEEPSPEED_ELASTICITY_CONFIG = "DEEPSPEED_ELASTICITY_CONFIG" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,334 @@ | ||
""" | ||
Copyright 2020 The Microsoft DeepSpeed Team | ||
""" | ||
import os | ||
import re | ||
import json | ||
import numpy as np | ||
|
||
from .config import ElasticityConfig, ElasticityConfigError, ElasticityError, \ | ||
ElasticityIncompatibleWorldSize | ||
from .constants import ELASTICITY, ENABLED, ENABLED_DEFAULT, LATEST_ELASTICITY_VERSION, \ | ||
MINIMUM_DEEPSPEED_VERSION, IGNORE_NON_ELASTIC_BATCH_INFO, \ | ||
IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT, DEEPSPEED_ELASTICITY_CONFIG | ||
from ..git_version_info import version as __version__ | ||
from ..utils import logger | ||
|
||
# Thirty eight smallest highly composite numbers. The list should | ||
# be enough to support up to 720K batch size. | ||
HCN_LIST = [ | ||
1, | ||
2, | ||
4, | ||
6, | ||
12, | ||
24, | ||
36, | ||
48, | ||
60, | ||
120, | ||
180, | ||
240, | ||
360, | ||
720, | ||
840, | ||
1260, | ||
1680, | ||
2520, | ||
5040, | ||
7560, | ||
10080, | ||
15120, | ||
20160, | ||
25200, | ||
27720, | ||
45360, | ||
50400, | ||
55440, | ||
83160, | ||
110880, | ||
166320, | ||
221760, | ||
277200, | ||
332640, | ||
498960, | ||
554400, | ||
665280, | ||
720720 | ||
] | ||
|
||
|
||
def get_candidate_batch_sizes(base_list, max_acceptable_batch_size): | ||
candidate_batch_size = [] | ||
|
||
#brute force is fine here. We are working with very small lists | ||
for base in base_list: | ||
batch_size = base | ||
for hcn in HCN_LIST: | ||
new_batch_size = base * hcn | ||
if new_batch_size > max_acceptable_batch_size: | ||
break | ||
batch_size = new_batch_size | ||
candidate_batch_size.append(batch_size) | ||
return list(set(candidate_batch_size)) | ||
|
||
|
||
def get_valid_gpus(batch_size, micro_batches, min_valid_gpus, max_valid_gpus): | ||
valid_gpus = [] | ||
for micro_batch in micro_batches: | ||
if batch_size % micro_batch == 0: | ||
|
||
max_gpus = batch_size // micro_batch | ||
if max_gpus >= min_valid_gpus and max_gpus <= max_valid_gpus: | ||
valid_gpus.append(max_gpus) | ||
|
||
for i in range(1, max_gpus // 2 + 1): | ||
if max_gpus % i == 0: | ||
if i >= min_valid_gpus and i <= max_valid_gpus: | ||
valid_gpus.append(i) | ||
valid_gpus = set(valid_gpus) | ||
valid_gpus = sorted(list(valid_gpus)) | ||
return valid_gpus | ||
|
||
|
||
def get_best_candidates(candidate_batch_sizes, | ||
micro_batches, | ||
min_gpus, | ||
max_gpus, | ||
prefer_larger): | ||
|
||
max_valid_gpus = 0 | ||
valid_gpus = None | ||
final_batch_size = int(min(micro_batches)) | ||
|
||
for batch_size in candidate_batch_sizes: | ||
|
||
current_valid_gpus = get_valid_gpus(batch_size, | ||
micro_batches, | ||
min_gpus, | ||
max_gpus) | ||
|
||
if (len(current_valid_gpus) > max_valid_gpus | ||
or (len(current_valid_gpus) == max_valid_gpus and | ||
((prefer_larger and batch_size > final_batch_size) or | ||
(not prefer_larger and batch_size < final_batch_size)))): | ||
max_valid_gpus = len(current_valid_gpus) | ||
valid_gpus = current_valid_gpus | ||
final_batch_size = batch_size | ||
|
||
return final_batch_size, valid_gpus | ||
|
||
|
||
def _get_compatible_gpus_v01(micro_batches, | ||
max_acceptable_batch_size, | ||
min_gpus=None, | ||
max_gpus=None, | ||
prefer_larger=True): | ||
'''We use two heuristics to compute the batch size | ||
1. We use the Lowest Common Multiple of the micro-batches | ||
as the base batch size and scale it by a HCN such that the result is | ||
the largest batch size less than the max_acceptable batch size | ||
2. We use each of the micro batches as a base and scale it | ||
by a HCN such that the result is the largest batch size less than the | ||
max_acceptable batch size. | ||
We then use brute force to count the number of compatible GPU count for | ||
each of the aforementioned cases, and return the batch size with the most number of | ||
compatible GPU counts in the min-max GPU range if provided, other wise | ||
we return the batch size with the most number of total compatible GPU counts. | ||
Returns: | ||
final_batch_size | ||
valid_gpus | ||
''' | ||
|
||
if min_gpus is None: | ||
min_gpus = int(1) | ||
|
||
if max_gpus is None: | ||
max_gpus = int(max_acceptable_batch_size / min(micro_batches)) | ||
|
||
assert all(mb <= max_acceptable_batch_size for mb in micro_batches ), \ | ||
f"All micro batches must be less than \ | ||
or equal to max_acceptable_batch_size: {max_acceptable_batch_size}" | ||
|
||
lcm = np.lcm.reduce(micro_batches) | ||
|
||
base_list = [] | ||
base_list.extend(micro_batches) | ||
base_list.append(lcm) | ||
|
||
candidate_batch_sizes = get_candidate_batch_sizes(base_list, | ||
max_acceptable_batch_size) | ||
|
||
final_batch_size, valid_gpus = get_best_candidates( | ||
candidate_batch_sizes, | ||
micro_batches, | ||
min_gpus, | ||
max_gpus, | ||
prefer_larger) | ||
|
||
return final_batch_size, valid_gpus | ||
|
||
|
||
def _parse_version(version_str): | ||
'''Parse a version string and extract the major and minor versions (and possibly patch version).''' | ||
matched = re.search('^(\d+)\.(\d+)\.(\d+)', version_str) | ||
if matched: | ||
return int(matched.group(1)), int(matched.group(2)), int(matched.group(3)) | ||
else: | ||
matched = re.search('^(\d+)\.(\d+)', version_str) | ||
assert matched != None, "Unable to parse version number, expecting" \ | ||
f"major.minor[.patch] format but received {version_str}" | ||
return int(matched.group(1)), int(matched.group(2)), 0 | ||
|
||
|
||
def _compatible_ds_version_check(target_deepspeed_version: str): | ||
min_major, min_minor, min_patch = _parse_version(MINIMUM_DEEPSPEED_VERSION) | ||
trg_major, trg_minor, trg_patch = _parse_version(target_deepspeed_version) | ||
|
||
err_str = f"Target deepspeed version of {target_deepspeed_version} is not compatible " \ | ||
f"with minimum version {MINIMUM_DEEPSPEED_VERSION} supporting elasticity." | ||
if trg_major < min_major: | ||
raise ElasticityError(err_str) | ||
if trg_minor < min_minor: | ||
raise ElasticityError(err_str) | ||
if trg_patch < min_patch: | ||
raise ElasticityError(err_str) | ||
return True | ||
|
||
|
||
def elasticity_enabled(ds_config: dict): | ||
if ELASTICITY not in ds_config: | ||
return False | ||
return ds_config[ELASTICITY].get(ENABLED, ENABLED_DEFAULT) | ||
|
||
|
||
def ensure_immutable_elastic_config(runtime_elastic_config_dict: dict): | ||
""" | ||
Ensure the resource scheduler saw the same elastic config we are using at runtime | ||
""" | ||
if DEEPSPEED_ELASTICITY_CONFIG in os.environ: | ||
scheduler_elastic_config_dict = json.loads( | ||
os.environ[DEEPSPEED_ELASTICITY_CONFIG]) | ||
scheduler_elastic_config = ElasticityConfig(scheduler_elastic_config_dict) | ||
runtime_elastic_config = ElasticityConfig(runtime_elastic_config_dict) | ||
err_str = "Elastic config '{}={}' seen by resource scheduler does not match config passed to runtime {}={}" | ||
if runtime_elastic_config.max_acceptable_batch_size != scheduler_elastic_config.max_acceptable_batch_size: | ||
raise ElasticityConfigError( | ||
err_str.format('max_acceptable_batch_size', | ||
scheduler_elastic_config.max_acceptable_batch_size, | ||
'max_acceptable_batch_size', | ||
runtime_elastic_config.max_acceptable_batch_size)) | ||
if runtime_elastic_config.micro_batches != scheduler_elastic_config.micro_batches: | ||
raise ElasticityConfigError( | ||
err_str.format('micro_batches', | ||
scheduler_elastic_config.micro_batches, | ||
'micro_batches', | ||
runtime_elastic_config.micro_batches)) | ||
if runtime_elastic_config.version != scheduler_elastic_config.version: | ||
raise ElasticityConfigError( | ||
err_str.format('version', | ||
scheduler_elastic_config.version, | ||
'version', | ||
runtime_elastic_config.version)) | ||
else: | ||
logger.warning("Unable to find DEEPSPEED_ELASTICITY_CONFIG environment variable, cannot " \ | ||
"guarantee resource scheduler will scale this job using compatible GPU counts.") | ||
|
||
|
||
def compute_elastic_config(ds_config: dict, target_deepspeed_version: str, world_size=0): | ||
"""Core deepspeed elasticity API. Given an elastic config (similar to the example below) | ||
DeepSpeed will compute a total train batch size corresponding valid GPU count list that | ||
provides a high level of elasticity. Elasticity in this case means we are safe to scale | ||
the training job up/down across the GPU count list *without* any negative impacts on | ||
training convergence. This is achievable primarily due to DeepSpeed's gradient accumulation | ||
feature which allows us to decompose a global training batch size into: | ||
micro-batch-size * gradient-accumulation-steps * world-size. | ||
"elasticity": { | ||
"enabled": true, | ||
"max_train_batch_size": 2000, | ||
"micro_batch_sizes": [2,4,6], | ||
"min_gpus": 1, | ||
"max_gpus" : 10000 | ||
"min_time": 20 | ||
"version": 0.1 | ||
} | ||
Intended to be called both by scheduling infrastructure and deepspeed runtime. | ||
For the same `ds_config` we should return deterministic results. | ||
Args: | ||
ds_config (dict): DeepSpeed config dictionary/json | ||
target_deepspeed_version (str): When called from scheduling | ||
infrastructure we want to ensure that the target deepspeed version is | ||
compatible with the elasticity version used in the backend. | ||
world_size (int, optional): Intended/current world size, will do some sanity | ||
checks to ensure world size is actually valid with the config. | ||
Raises: | ||
ElasticityConfigError: Missing required elasticity config or elasticity disabled | ||
ElasticityError: If target deepspeed version is not compatible with current version | ||
Returns: | ||
final_batch_size (int): total batch size used for training | ||
valid_gpus (list(int)): list of valid GPU counts with this config | ||
micro_batch_size (int, optional): if world_size is provided will return | ||
specific micro batch size | ||
""" | ||
if not isinstance(ds_config, dict): | ||
raise ValueError("Expected ds_config to be a dictionary but received " \ | ||
f"a {type(ds_config)}, containing: {ds_config}") | ||
|
||
if ELASTICITY not in ds_config: | ||
raise ElasticityConfigError(f"'{ELASTICITY}' is missing from config json," \ | ||
" please add it if running an elastic training job.") | ||
|
||
elastic_config_dict = ds_config[ELASTICITY] | ||
if not elastic_config_dict.get(ENABLED, ENABLED_DEFAULT): | ||
raise ElasticityConfigError("Elasticity is disabled, please enable it " \ | ||
"('enabled':true) if running an elastic training job.") | ||
|
||
elastic_config = ElasticityConfig(elastic_config_dict) | ||
|
||
if float(elastic_config.version) > LATEST_ELASTICITY_VERSION: | ||
raise ElasticityConfigError("Attempting to run elasticity version " \ | ||
f"{elastic_config.version} but runtime only supports up " \ | ||
f"to {LATEST_ELASTICITY_VERSION}") | ||
|
||
# Ensure target deepspeed version works with intended elasticity version | ||
if not _compatible_ds_version_check(target_deepspeed_version): | ||
raise ElasticityError("Unable to run elasticity on target deepspeed version of" \ | ||
f" {target_deepspeed_version}, currently {__version__}") | ||
|
||
if float(elastic_config.version) == 0.1: | ||
final_batch_size, valid_gpus = _get_compatible_gpus_v01( | ||
micro_batches=elastic_config.micro_batches, | ||
max_acceptable_batch_size=elastic_config.max_acceptable_batch_size, | ||
min_gpus=elastic_config.min_gpus, | ||
max_gpus=elastic_config.max_gpus, | ||
prefer_larger=elastic_config.prefer_larger_batch_size) | ||
# ensure batch size is int dtype | ||
final_batch_size = int(final_batch_size) | ||
else: | ||
raise NotImplementedError( | ||
f"Unable to find elastic logic for version: {elastic_config.version}") | ||
|
||
if world_size > 0: | ||
if world_size not in valid_gpus: | ||
raise ElasticityIncompatibleWorldSize(f"World size ({world_size}) is not valid " \ | ||
f"with the current list of valid GPU counts: {valid_gpus}") | ||
|
||
# Pick largest valid micro batch size | ||
micro_batch_size = None | ||
for mbsz in sorted(list(set(elastic_config.micro_batches)), reverse=True): | ||
if final_batch_size // world_size % mbsz == 0: | ||
micro_batch_size = mbsz | ||
break | ||
assert micro_batch_size is not None, "Unable to find divisible micro batch size" \ | ||
f" world_size={world_size}, final_batch_size={final_batch_size}, and " \ | ||
f" micro_batches={elastic_config.micro_batches}." | ||
return final_batch_size, valid_gpus, micro_batch_size | ||
|
||
return final_batch_size, valid_gpus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ torchvision>=0.4.0 | |
tqdm | ||
tensorboardX==1.8 | ||
ninja | ||
numpy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
import pytest | ||
import deepspeed | ||
from common import distributed_test | ||
from deepspeed.git_version_info import version as ds_version | ||
from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict | ||
|
||
base_ds_config = { | ||
"elasticity": { | ||
"enabled": True, | ||
"max_train_batch_size": 10000, | ||
"micro_batch_sizes": [8, | ||
12, | ||
16, | ||
17], | ||
"min_gpus": 32, | ||
"max_gpus": 1500, | ||
"min_time": 20, | ||
"version": 0.1 | ||
} | ||
} | ||
|
||
|
||
def test_basic_10k(): | ||
ds_config = base_ds_config.copy() | ||
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
for gpu_num in valid_gpus: | ||
assert final_batch_size % gpu_num == 0, f"Batch {final_batch_size} is not divisible by GPU count {gpu_num}" | ||
batch_per_gpu = final_batch_size // gpu_num | ||
found_valid_mbsize = False | ||
|
||
for mb in ds_config['elasticity']['micro_batch_sizes']: | ||
if batch_per_gpu % mb == 0: | ||
found_valid_mb = True | ||
break | ||
assert found_valid_mb, "No valid mb found" | ||
|
||
assert len(valid_gpus) == 23 | ||
assert final_batch_size == 9792 | ||
|
||
|
||
def test_old_version(): | ||
ds_config = base_ds_config.copy() | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version="0.2") | ||
|
||
|
||
def test_disabled(): | ||
ds_config = base_ds_config.copy() | ||
ds_config['elasticity']['enabled'] = False | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
final_batch_size, valid_gpus = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
|
||
def test_valid_world_size(): | ||
ds_config = base_ds_config.copy() | ||
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version=ds_version, | ||
world_size=64) | ||
assert mbsize == 17 | ||
|
||
|
||
def test_invalid_world_size(): | ||
ds_config = base_ds_config.copy() | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityIncompatibleWorldSize): | ||
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version=ds_version, | ||
world_size=128) | ||
|
||
|
||
def test_future_elastic_version(): | ||
ds_config = base_ds_config.copy() | ||
ds_config['elasticity']['version'] = '0.2' | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
|
||
def test_missing_max_batch(): | ||
ds_config = base_ds_config.copy() | ||
del ds_config['elasticity']['max_train_batch_size'] | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
|
||
def test_missing_micro_batch(): | ||
ds_config = base_ds_config.copy() | ||
del ds_config['elasticity']['micro_batch_sizes'] | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
|
||
def test_empty_config(): | ||
ds_config = {"elasticity": {"enabled": True}} | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
deepspeed.elasticity.compute_elastic_config(ds_config=ds_config, | ||
target_deepspeed_version=ds_version) | ||
|
||
|
||
def test_proper_mbsz(): | ||
ds_config = base_ds_config.copy() | ||
ds_config["elasticity"]["max_train_batch_size"] = 32 | ||
ds_config["elasticity"]["micro_batch_sizes"] = [1, 2, 3, 7] | ||
ds_config["elasticity"]["min_gpus"] = 1 | ||
final_batch_size, valid_gpus, mbsize = deepspeed.elasticity.compute_elastic_config( | ||
ds_config=ds_config, | ||
target_deepspeed_version=ds_version, | ||
world_size=7) | ||
assert mbsize == 3 | ||
|
||
|
||
def test_non_elastic_batch_params(tmpdir): | ||
config_dict = { | ||
"train_batch_size": 2, | ||
"steps_per_print": 1, | ||
"optimizer": { | ||
"type": "Lamb", | ||
"params": { | ||
"lr": 0.00015 | ||
} | ||
}, | ||
"gradient_clipping": 1.0, | ||
"elasticity": { | ||
"enabled": True, | ||
"max_train_batch_size": 4, | ||
"micro_batch_sizes": [1, | ||
2, | ||
3, | ||
4], | ||
"min_gpus": 1, | ||
"max_gpus": 4, | ||
"min_time": 20, | ||
"version": 0.1 | ||
} | ||
} | ||
args = args_from_dict(tmpdir, config_dict) | ||
hidden_dim = 10 | ||
|
||
model = SimpleModel(hidden_dim, empty_grad=False) | ||
|
||
@distributed_test(world_size=[1, 2]) | ||
def _test_elastic(args, model, hidden_dim): | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
model, _, _,_ = deepspeed.initialize(args=args, | ||
model=model, | ||
model_parameters=model.parameters()) | ||
|
||
_test_elastic(args=args, model=model, hidden_dim=hidden_dim) | ||
|
||
|
||
def test_non_elastic_batch_params_w_override(tmpdir): | ||
config_dict = { | ||
"train_batch_size": 2, | ||
"steps_per_print": 1, | ||
"optimizer": { | ||
"type": "Lamb", | ||
"params": { | ||
"lr": 0.00015 | ||
} | ||
}, | ||
"gradient_clipping": 1.0, | ||
"elasticity": { | ||
"enabled": True, | ||
"max_train_batch_size": 4, | ||
"micro_batch_sizes": [1, | ||
2, | ||
3, | ||
4], | ||
"min_gpus": 1, | ||
"max_gpus": 4, | ||
"min_time": 20, | ||
"version": 0.1, | ||
"ignore_non_elastic_batch_info": True | ||
} | ||
} | ||
args = args_from_dict(tmpdir, config_dict) | ||
hidden_dim = 10 | ||
|
||
model = SimpleModel(hidden_dim, empty_grad=False) | ||
|
||
@distributed_test(world_size=[1, 2]) | ||
def _test_elastic(args, model, hidden_dim): | ||
model, _, _,_ = deepspeed.initialize(args=args, | ||
model=model, | ||
model_parameters=model.parameters()) | ||
|
||
_test_elastic(args=args, model=model, hidden_dim=hidden_dim) | ||
|
||
|
||
def test_elastic_config_changed(tmpdir): | ||
config_dict = { | ||
"train_batch_size": 2, | ||
"steps_per_print": 1, | ||
"optimizer": { | ||
"type": "Lamb", | ||
"params": { | ||
"lr": 0.00015 | ||
} | ||
}, | ||
"gradient_clipping": 1.0, | ||
"elasticity": { | ||
"enabled": True, | ||
"max_train_batch_size": 4, | ||
"micro_batch_sizes": [1, | ||
2, | ||
3, | ||
4], | ||
"min_gpus": 1, | ||
"max_gpus": 4, | ||
"min_time": 20, | ||
"version": 0.1, | ||
"ignore_non_elastic_batch_info": True | ||
} | ||
} | ||
import json, os | ||
scheduler_elastic_config = config_dict.copy() | ||
scheduler_elastic_config["elasticity"]["max_train_batch_size"] = 27 | ||
os.environ['DEEPSPEED_ELASTICITY_CONFIG'] = json.dumps(scheduler_elastic_config) | ||
args = args_from_dict(tmpdir, config_dict) | ||
hidden_dim = 10 | ||
|
||
model = SimpleModel(hidden_dim, empty_grad=False) | ||
|
||
@distributed_test(world_size=[1, 2]) | ||
def _test_elastic(args, model, hidden_dim): | ||
with pytest.raises(deepspeed.elasticity.config.ElasticityError): | ||
model, _, _,_ = deepspeed.initialize(args=args, | ||
model=model, | ||
model_parameters=model.parameters()) | ||
|
||
_test_elastic(args=args, model=model, hidden_dim=hidden_dim) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
0.3.8 | ||
0.3.9 |