-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
engine.py
executable file
·3760 lines (3134 loc) · 173 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import re
import stat
import torch
import hashlib
from collections import defaultdict, OrderedDict, deque
from shutil import copyfile
import gc
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from contextlib import contextmanager
from typing import Callable, Dict, Union, Iterable, Container
import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.utils import see_memory_usage, DummyOptim
from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.linear.optimized_linear import LoRAOptimizedLinear
from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \
ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \
MUSGD_OPTIMIZER, LION_OPTIMIZER
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
DATA_PARALLEL_GROUP, GLOBAL_RANK
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.compression import compression_scheduler
from deepspeed.compression.constants import \
WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \
WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \
WEIGHT_QUANTIZE_ENABLED, \
WEIGHT_QUANTIZE_GROUPS, \
WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \
WEIGHT_QUANTIZE_CHANGE_RATIO, \
WEIGHT_QUANTIZE_TYPE, \
WEIGHT_QUANTIZE_ROUNDING, \
WEIGHT_QUANTIZE_VERBOSE, \
WEIGHT_QUANTIZE_KERNEL
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS
from deepspeed.runtime.sparse_tensor import SparseTensor
from deepspeed.runtime import lr_schedules
from deepspeed.utils import groups
from deepspeed.utils import logger, log_dist, instrument_w_nvtx
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
STEP_MICRO_TIMER, \
FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \
STEP_GLOBAL_TIMER
from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names
from deepspeed.monitor.monitor import MonitorMaster
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.utils import clip_grad_norm_
from deepspeed.runtime.eigenvalue import Eigenvalue
from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \
DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \
CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \
RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \
RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \
RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY
from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler
from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler
from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict
from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from .pipe.module import PipelineModule
from .utils import get_ma_status
from .compiler import is_compile_supported
from ..ops.adam import FusedAdam
from ..moe.sharded_moe import TopKGate, MOELayer
from ..moe.layer import MoE
from ..moe.utils import is_moe_param, configure_moe_param_groups
from ..git_version_info import version
from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
from deepspeed.utils.logging import print_json_dist, print_configuration
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.config import DtypeEnum
# Size threshold for memory-optimized allreduce. It is not referenced in this part of
# the file; units are presumably elements -- confirm at the call site.
MEMORY_OPT_ALLREDUCE_SIZE = 500_000_000

# Type of a user-supplied optimizer factory: it receives either an iterable of
# parameters or a dict describing parameter groups, and returns a torch Optimizer.
DeepSpeedOptimizerCallable = Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]

# Type of a user-supplied LR-scheduler factory: it receives the optimizer and
# returns an _LRScheduler.
DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]
# Optional NVIDIA apex dependency: APEX_INSTALLED records whether `apex` and its
# `amp` module could be imported.
try:
    import apex
    from apex import amp
    APEX_INSTALLED = True
except ImportError:
    # Fail silently so we don't spam logs unnecessarily if user isn't using amp
    APEX_INSTALLED = False
def split_half_float_double_sparse(tensors):
    """Group tensors into per-dtype buckets, keeping sparse and dense tensors apart.

    Despite the name, the buckets follow whatever dtypes the current accelerator
    reports from ``supported_dtypes()``, in that order.

    Args:
        tensors: iterable of tensors and/or ``SparseTensor`` objects. It is
            materialized into a list first, so one-shot iterables (generators)
            are handled correctly.

    Returns:
        tuple: ``(sparse_tensor_buckets, dense_tensor_buckets)``. Each is a list of
        ``(dtype, [tensors...])`` pairs. Dtypes with no tensors of that kind are
        omitted, and tensors keep their input order inside each bucket.

    Raises:
        AssertionError: if any tensor has a dtype the accelerator does not support.
    """
    supported_types = get_accelerator().supported_dtypes()
    # The input is traversed twice (validation, then bucketing). A generator would
    # be exhausted by the first pass and silently produce empty buckets.
    tensors = list(tensors)
    for t in tensors:
        assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}"

    # Single pass over the tensors instead of one full pass per supported dtype.
    sparse_by_dtype = defaultdict(list)
    dense_by_dtype = defaultdict(list)
    for t in tensors:
        target = sparse_by_dtype if isinstance(t, SparseTensor) else dense_by_dtype
        target[t.dtype].append(t)

    # Emit buckets in supported_types order, skipping dtypes with no tensors.
    sparse_tensor_buckets = [(dtype, sparse_by_dtype[dtype]) for dtype in supported_types if dtype in sparse_by_dtype]
    dense_tensor_buckets = [(dtype, dense_by_dtype[dtype]) for dtype in supported_types if dtype in dense_by_dtype]
    return sparse_tensor_buckets, dense_tensor_buckets
class EngineTimers(object):
    r"""Timer-name bookkeeping for DeepSpeedEngine wall-clock instrumentation.

    For each training phase (forward, backward, backward-inner, backward-reduce,
    step), the matching ``*_timers`` attribute lists the timer names for that phase.
    The micro-step timer comes first (when ``enable_micro_timers``), followed by the
    global timer (when ``enable_global_timers``). ``micro_timers`` and
    ``global_timers`` hold every enabled name of each kind, in phase order.
    """
    def __init__(self, enable_micro_timers, enable_global_timers):
        if enable_micro_timers:
            micro_names = [
                FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER,
                STEP_MICRO_TIMER
            ]
        else:
            micro_names = []

        if enable_global_timers:
            global_names = [
                FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER,
                STEP_GLOBAL_TIMER
            ]
        else:
            global_names = []

        # One list per phase. Micro names are appended before global names.
        per_phase = [[] for _ in range(5)]
        for enabled_names in (micro_names, global_names):
            for phase_list, timer_name in zip(per_phase, enabled_names):
                phase_list.append(timer_name)

        (self.forward_timers, self.backward_timers, self.backward_inner_timers, self.backward_reduce_timers,
         self.step_timers) = per_phase
        self.micro_timers = micro_names
        self.global_timers = global_names
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training."""
def __init__(self,
             args,
             model,
             optimizer=None,
             model_parameters=None,
             training_data=None,
             lr_scheduler=None,
             mpu=None,
             dist_init_required=None,
             collate_fn=None,
             config=None,
             config_class=None,
             mesh_device=None,
             dont_change_device=False):
    """Build the engine: validate config, wrap the model, then set up optimizer, scheduler and helpers.

    Args:
        args: command-line style arguments. They are passed to ``_do_args_sanity_check``,
            ``_configure_with_arguments`` and ``_set_distributed_vars``.
        model: the client ``torch.nn.Module`` to wrap. A ``PipelineModule`` sets
            ``self.pipeline_parallelism``.
        optimizer: optional client optimizer, stored as ``client_optimizer``. When it is
            falsy and the config names no optimizer, ZeRO or bf16 may still build one.
        model_parameters: parameters to optimize. Defaults to ``self.module.parameters()``
            and is always converted to a list.
        training_data: optional dataset. If truthy, a dataloader is built via ``deepspeed_io``.
        lr_scheduler: optional client LR scheduler, stored as ``client_lr_scheduler``.
        mpu: optional model-parallel unit. Combined with elasticity, it is rejected
            unless ``is_elastic_model_parallel_supported()`` holds.
        dist_init_required: forwarded to ``_configure_checkpointing``.
        collate_fn: stored as ``self.collate_fn`` (presumably used by ``deepspeed_io``).
        config: raw config, stored as ``self.config``.
        config_class: parsed config object, stored as ``self._config``. All accessor
            methods read from it.
        mesh_device: optional device mesh. It is also published as ``groups.mesh_device``.
        dont_change_device: stored as ``self.dont_change_device``.
    """
    super(DeepSpeedEngine, self).__init__()
    # --- Plain state; most of it is filled in by the configure steps below ---
    self.dont_change_device = dont_change_device
    self.client_optimizer = optimizer
    self.client_lr_scheduler = lr_scheduler
    self.training_data = training_data
    self.collate_fn = collate_fn
    self.mpu = mpu
    self.all_to_all_group = None
    self.data_parallel_group = None
    self.global_steps = 0
    self.global_samples = 0
    self.micro_steps = 0
    self.skipped_steps = 0
    self.gradient_average = True
    self.warn_unscaled_loss = True
    self.config = config
    self._config = config_class
    self.loaded_checkpoint_mp_world_size = None
    self.loaded_checkpoint_dp_world_size = None
    self.enable_backward_allreduce = True
    self.inside_no_sync_ctxt = False
    self.progressive_layer_drop = None
    self.eigenvalue = None
    self.block_eigenvalue = None
    self.gas_boundary_ctr = 0
    self.dist_backend = get_accelerator().communication_backend_name()
    self.has_moe_layers = False
    self.num_experts = []
    self.gate_modules = []
    self.moe_layers = []
    self._step_applied = False
    # Cached global grad norm; stays None until a norm is recorded.
    self._global_grad_norm = None
    self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend.
    self.checkpoint_engine = None
    self._is_gradient_accumulation_boundary = None
    self.scale_wrt_gas = None
    self.losses = None
    self.mesh_device = mesh_device
    # for debug purposes - can then debug print: debug_get_module_name(module)
    debug_extract_module_and_param_names(model)
    if self.mesh_device:
        groups.mesh_device = self.mesh_device
    # --- Config validation ---
    self._do_args_sanity_check(args)
    self._configure_with_arguments(args, mpu)
    self._do_sanity_check()
    see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown())
    if mpu is not None:
        if self.elasticity_enabled():
            if not self.is_elastic_model_parallel_supported():
                # Inside this branch elasticity is enabled, so the assert always fails:
                # it serves as an unconditional error for unsupported elastic MP.
                assert not self.elasticity_enabled(), ("Elasticity is not currently supported"
                                                       " with model parallelism.")
    # --- Distributed setup ---
    self._set_distributed_vars(args)
    dist.configure(self._config)
    self.monitor = MonitorMaster(self._config.monitor_config)
    see_memory_usage(
        f"DeepSpeed Engine: Before configure distributed model",
        force=self.memory_breakdown(),
    )
    self.pipeline_parallelism = isinstance(model, PipelineModule)
    # Configure distributed model
    self._configure_distributed_model(model)
    # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
    self.param_names = {param: name for name, param in model.named_parameters()}
    self._get_model_parameters()
    see_memory_usage(f"DeepSpeed Engine: After configure distributed model")
    # Configure wall clock timers
    self.timers = SynchronizedWallClockTimer()
    # Throughput timer
    self.tput_timer = ThroughputTimer(self._config.timers_config,
                                      batch_size=self.train_batch_size(),
                                      steps_per_output=self.steps_per_print(),
                                      monitor_memory=False)
    log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0])
    if self.flops_profiler_enabled():
        self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor())
    if training_data:
        self.training_dataloader = self.deepspeed_io(training_data)
    else:
        self.training_dataloader = None
    # Configure optimizer and scheduler
    self.optimizer = None
    self.basic_optimizer = None
    self.lr_scheduler = None
    has_optimizer = False
    # optimizer_name() covers both a client optimizer and one named in the config.
    if optimizer or self.optimizer_name():
        has_optimizer = True
    # If no parameters given by init default to module parameters
    if model_parameters is None:
        model_parameters = self.module.parameters()
    # Convert model parameters from generator to list
    if not isinstance(model_parameters, list):
        model_parameters = list(model_parameters)
    if has_optimizer:
        self._configure_optimizer(optimizer, model_parameters)
        self._configure_lr_scheduler()
        self._report_progress(0)
    elif self.zero_optimization():
        # no optim selected but zero is enabled
        self.optimizer = self._configure_zero_optimizer(optimizer=None)
    elif self.bfloat16_enabled():
        self.optimizer = self._configure_bf16_optimizer(optimizer=None)
    # Hook optimizer for snip_momentum pruning
    # NOTE(review): if none of the branches above created an optimizer,
    # self.optimizer is still None here and the assignment below raises AttributeError.
    if hasattr(model, 'pruners'):
        from ..compression.helper import rewrite_optimizer_step
        self.optimizer.pruners = model.pruners
        rewrite_optimizer_step(self.optimizer)
    # Bookkeeping for sparse support
    self.sparse_tensor_module_names = set()
    # if self.sparse_gradients_enabled():
    # Embedding/EmbeddingBag weights are recorded for sparse-tensor conversion when
    # sparse gradients are enabled (the flag is re-checked for every module).
    for name, module in self.module.named_modules():
        if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled():
            self.sparse_tensor_module_names.add(name + ".weight")
            logger.info("Will convert {} to sparse tensor during training".format(name))
    self._optimized_linear_offload_setup()
    self.save_non_zero_checkpoint = False
    self.save_zero_checkpoint = False
    # Checkpointing is configured unless self.optimizer is a DeepSpeedZeRoOffload.
    if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
        self._configure_checkpointing(dist_init_required)
    # --- Optional training features driven by the config ---
    if self.eigenvalue_enabled():
        self.eigenvalue = self._configure_eigenvalue()
    if self.pld_enabled():
        self.progressive_layer_drop = self._configure_progressive_layer_drop()
    if self.curriculum_enabled_legacy():
        self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy()
    if self.random_ltd_enabled():
        # The random-LTD config dict is updated in place with the batch sizes.
        random_ltd_config = self.random_ltd_config()
        random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size()
        random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu()
        self.random_ltd_scheduler = self._configure_random_ltd_scheduler(random_ltd_config)
    # Engine timers
    # Global timers are also enabled when the flops profiler is on.
    self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(),
                                      enable_global_timers=self.wall_clock_breakdown()
                                      or self.flops_profiler_enabled())
    if self.global_rank == 0:
        self._config.print("DeepSpeedEngine configuration")
        if self.dump_state():
            print_configuration(self, "DeepSpeedEngine")
    # Use torch (un)flatten ops
    self.flatten = _flatten_dense_tensors
    self.unflatten = _unflatten_dense_tensors
    self._is_compiled = False
def _optimized_linear_offload_setup(self):
    """Detect LoRAOptimizedLinear modules and apply their optimizer-offload ratio.

    Sets ``optimized_linear_lora_enabled`` if any such module exists, and
    ``optimized_linear_base_weight_sharding`` if any of them has ``zero_shards > 1``.
    All LoRA modules must agree on ``lora_config.offload_ratio``.

    When the shared ratio is not None, parameters carrying ``ds_optim_param`` are
    offloaded in ``named_parameters()`` order. Offloading continues until the
    offloaded element count reaches ``ratio * total`` eligible elements. The
    remaining eligible parameters get ``ds_offload = False``.

    Raises:
        AssertionError: if LoRA modules disagree on their offload ratio.
    """
    self.optimized_linear_base_weight_sharding = False
    self.optimized_linear_lora_enabled = False
    offload_ratio = None
    for _, module in self.module.named_modules():
        if not isinstance(module, LoRAOptimizedLinear):
            continue
        module_ratio = module.lora_config.offload_ratio
        # offload_ratio used to be reset to None right before this check, so the
        # consistency assertion could never fire. Compare against the ratio of the
        # first LoRA module instead; the enabled flag marks that one was seen.
        if self.optimized_linear_lora_enabled:
            assert offload_ratio == module_ratio, \
                "all lora_config offload ratios should be the same across the model"
        self.optimized_linear_lora_enabled = True
        offload_ratio = module_ratio
        if module.zero_shards > 1:
            # set attr so checkpoint saving can handle BWS properly
            self.optimized_linear_base_weight_sharding = True

    if offload_ratio is None:
        # Nothing enabled, do nothing
        return

    eligible = [p for _, p in self.module.named_parameters() if hasattr(p, 'ds_optim_param')]
    total_params = sum(p.numel() for p in eligible)
    offload_limit = total_params * offload_ratio
    logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params')

    total_offloaded = 0
    for p in eligible:
        # Greedy: keep offloading while under the limit. The last offloaded
        # parameter may push the total past offload_limit.
        if total_offloaded < offload_limit:
            total_offloaded += p.numel()
            p.ds_offload = True
            p.offload()
        else:
            p.ds_offload = False
def destroy(self):
    """Tear the engine down.

    Calls ``destroy()`` on the optimizer when one is set and exposes that method,
    then clears the debug module/parameter names registered during ``__init__``.
    """
    optimizer = self.optimizer
    has_destroy_hook = optimizer is not None and hasattr(optimizer, 'destroy')
    if has_destroy_hook:
        optimizer.destroy()
    debug_clear_module_and_param_names()
def _get_model_parameters(self):
    """Record model parameter counts for autotuning when model-info profiling is on.

    Every parameter of ``self.module`` is counted. Parameters already partitioned by
    ZeRO-3 (those carrying ``ds_tensor``) contribute ``ds_numel``; the rest contribute
    ``numel()``. Trainable (``requires_grad``) parameters are also counted separately.
    Rank 0 stores both totals, multiplied by ``mp_world_size``, in
    ``self.autotuning_model_info``. The local total is then logged.
    """
    if not self.autotuning_profile_model_info():
        return

    self.autotuning_model_info = {}
    total_count = 0
    trainable_count = 0
    for param in self.module.parameters():
        # User code may call deepspeed.zero.Init() before deepspeed.initialize(), so a
        # parameter can already be ZeRO-3 partitioned; its full size is then ds_numel.
        size = param.ds_numel if hasattr(param, "ds_tensor") else param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size

    if self.global_rank == 0:
        info = self.autotuning_model_info
        info["num_params"] = total_count * self.mp_world_size
        info["trainable_num_params"] = trainable_count * self.mp_world_size

    logger.info(f"model parameter = {total_count}")
def get_batch_info(self):
    """Get all training batch related settings.

    Returns:
        tuple: ``(train_batch_size, train_micro_batch_size_per_gpu, gradient_accumulation_steps)``

        train_batch_size (int): The effective training batch size. This is the amount of data
            samples that leads to one step of model update.
        train_micro_batch_size_per_gpu (int): Batch size to be processed by one GPU in one
            step (without gradient accumulation).
        gradient_accumulation_steps (int): Number of training steps to accumulate gradients
            before averaging and applying them.
    """
    # These three are accessor methods. Call them so the documented int values are
    # returned, rather than the bound methods themselves.
    return (
        self.train_batch_size(),
        self.train_micro_batch_size_per_gpu(),
        self.gradient_accumulation_steps(),
    )
def set_train_batch_size(self, train_batch_size):
    """Adjust the global batch size by increasing or decreasing the number of
    micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
    (i.e., ``train_micro_batch_size_per_gpu``) is not changed.

    Args:
        train_batch_size (int): The new global batch size for training.

    Raises:
        ValueError: if ``train_batch_size`` is not positive, or is not divisible by
            ``train_micro_batch_size_per_gpu * dp_world_size``. The config is left
            untouched in both cases.
    """
    micro_batch_size = self.train_micro_batch_size_per_gpu()
    samples_per_micro_step = micro_batch_size * self.dp_world_size
    # Zero or negative sizes pass the divisibility test (e.g. -8 % 8 == 0) but would
    # produce zero or negative accumulation steps.
    if train_batch_size <= 0:
        raise ValueError(f'Train batch size must be positive, got {train_batch_size}')
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError(f'Train batch size must be divisible by micro-batch data parallelism '
                         f'(train_batch_size={train_batch_size}, '
                         f'train_micro_batch_size_per_gpu={micro_batch_size}, '
                         f'dp_world_size={self.dp_world_size})')
    new_gas = train_batch_size // samples_per_micro_step
    # overwrite config
    self._config.train_batch_size = train_batch_size
    self._config.gradient_accumulation_steps = new_gas
def set_train_micro_batch_size(self, micro_batch_size):
    """Change the per-GPU micro-batch size while keeping gradient accumulation fixed.

    The global batch size is recomputed as
    ``micro_batch_size * gradient_accumulation_steps * dp_world_size``.
    Both values are written back into the config.

    Args:
        micro_batch_size (int): The new micro batch size for training.
    """
    cfg = self._config
    cfg.train_batch_size = micro_batch_size * cfg.gradient_accumulation_steps * self.dp_world_size
    cfg.train_micro_batch_size_per_gpu = micro_batch_size
def set_data_post_process_func(self, post_process_func):
    """Install ``post_process_func`` on the training dataloader; no-op without one."""
    dataloader = self.training_dataloader
    if dataloader is None:
        return
    dataloader.post_process_func = post_process_func

def set_custom_curriculum_learning_schedule(self, schedule_func_dict):
    """Forward a custom curriculum schedule to the dataloader's data sampler.

    This is a no-op when there is no training dataloader or curriculum learning is
    disabled.
    """
    if self.training_dataloader is None or not self.curriculum_learning_enabled():
        return
    self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict)
def get_global_grad_norm(self) -> float:
    """Return the most recently cached global 2-norm of the gradients.

    The value is whatever is stored in ``self._global_grad_norm``. Per the original
    contract, it is reused until the next ``step()`` recomputes it. With model
    parallelism the norm is meant to be global across model-parallel ranks.

    .. note::
        This accessor only reads a cached attribute and performs no communication.
        Any collective over ``mpu.get_model_parallel_group()`` happens where the norm
        is computed (presumably during ``step()``; not visible here).

    Returns:
        float: norm, or ``None`` if no norm has been recorded yet.
        ``__init__`` initializes the cache to ``None``.
    """
    return self._global_grad_norm
def __getattr__(self, name):
    """Fallback attribute lookup that passes through to the wrapped model.

    Python only calls this after normal lookup on the engine has failed. The method
    tries two sources in order:

    * Names listed by ``dir(self)`` are retried on the engine itself.
    * Otherwise, if the wrapped model is stored in the instance ``__dict__`` under
      ``'module'``, names listed by ``dir(module)`` are fetched from it.

    Raises:
        AttributeError: if neither the engine nor the wrapped model lists ``name``.
    """
    _module = {}
    # Read the wrapped model straight from __dict__; going through self.module here
    # could re-enter __getattr__. NOTE(review): when 'module' is absent, _module is
    # an empty dict, so dict method names (e.g. 'keys') would resolve against {}.
    if "module" in self.__dict__:
        _module = self.__dict__['module']
    # NOTE(review): normal lookup has already failed when we get here. If `name` is
    # still listed by dir(self), getattr(self, name) re-enters this method and can
    # recurse until RecursionError. Examples: a property raising AttributeError, or
    # an nn.Module-registered name that Module.__dir__ lists but that is not
    # reachable via __dict__ (assumed torch behaviour -- confirm). dir() is also
    # recomputed on every call.
    if name in dir(self):
        return getattr(self, name)
    elif name in dir(_module):
        return getattr(_module, name)
    else:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def checkpoint_tag_validation_enabled(self):
    """Return whether checkpoint tag validation is enabled in the config."""
    return self._config.checkpoint_tag_validation_enabled

def checkpoint_tag_validation_fail(self):
    """Return the config's ``checkpoint_tag_validation_fail`` setting."""
    return self._config.checkpoint_tag_validation_fail

def elasticity_enabled(self):
    """Return whether elastic training is enabled in the config."""
    return self._config.elasticity_enabled

def is_elastic_model_parallel_supported(self):
    """Return whether the elastic model-parallel size divides ``num_gpus_per_node``.

    Returns ``False`` when elasticity is disabled. Previously that path fell off the
    end of the function and implicitly returned ``None``.
    """
    if not self.elasticity_enabled():
        return False
    # Add code for finding number of GPUs per node automatically
    return self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0
def pld_enabled(self):
    """Whether progressive layer drop (PLD) is enabled."""
    return self._config.pld_enabled

def pld_params(self):
    """The PLD parameter mapping from the config."""
    return self._config.pld_params

def pld_theta(self):
    """The PLD theta entry of ``pld_params()``."""
    params = self.pld_params()
    return params[PLD_THETA]

def pld_gamma(self):
    """The PLD gamma entry of ``pld_params()``."""
    params = self.pld_params()
    return params[PLD_GAMMA]

def eigenvalue_enabled(self):
    """Whether eigenvalue computation is enabled."""
    return self._config.eigenvalue_enabled

def eigenvalue_verbose(self):
    """The config's ``eigenvalue_verbose`` flag."""
    return self._config.eigenvalue_verbose

def eigenvalue_max_iter(self):
    """The config's ``eigenvalue_max_iter`` setting."""
    return self._config.eigenvalue_max_iter

def eigenvalue_tol(self):
    """The config's ``eigenvalue_tol`` setting."""
    return self._config.eigenvalue_tol

def eigenvalue_stability(self):
    """The config's ``eigenvalue_stability`` setting."""
    return self._config.eigenvalue_stability

def eigenvalue_gas_boundary_resolution(self):
    """The config's ``eigenvalue_gas_boundary_resolution`` setting."""
    return self._config.eigenvalue_gas_boundary_resolution

def eigenvalue_layer_name(self):
    """The config's ``eigenvalue_layer_name`` setting."""
    return self._config.eigenvalue_layer_name

def eigenvalue_layer_num(self):
    """The config's ``eigenvalue_layer_num`` setting."""
    return self._config.eigenvalue_layer_num

def curriculum_enabled_legacy(self):
    """Whether the legacy curriculum-learning path is enabled."""
    return self._config.curriculum_enabled_legacy

def curriculum_params_legacy(self):
    """The legacy curriculum-learning parameters."""
    return self._config.curriculum_params_legacy
def data_efficiency_enabled(self):
    """Whether the data-efficiency library is enabled."""
    return self._config.data_efficiency_enabled

def data_efficiency_config(self):
    """The full data-efficiency config mapping."""
    return self._config.data_efficiency_config

def data_sampling_enabled(self):
    """Whether data sampling is enabled inside the data-efficiency config."""
    sampling = self._config.data_efficiency_config[DATA_SAMPLING]
    return sampling[DATA_SAMPLING_ENABLED]

def data_sampling_config(self):
    """The data-sampling section of the data-efficiency config."""
    return self._config.data_efficiency_config[DATA_SAMPLING]

def curriculum_learning_enabled(self):
    """Whether curriculum learning (under data sampling) is enabled."""
    curriculum = self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING]
    return curriculum[CURRICULUM_LEARNING_ENABLED]

def curriculum_learning_config(self):
    """The curriculum-learning section (under data sampling)."""
    sampling = self._config.data_efficiency_config[DATA_SAMPLING]
    return sampling[CURRICULUM_LEARNING]

def random_ltd_enabled(self):
    """Whether random layer-token drop (under data routing) is enabled."""
    ltd = self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD]
    return ltd[RANDOM_LTD_ENABLED]

def random_ltd_config(self):
    """The random-LTD section (under data routing)."""
    routing = self._config.data_efficiency_config[DATA_ROUTING]
    return routing[RANDOM_LTD]
def random_ltd_initialize(self):
    """Bind the random-LTD config and scheduler to the model's RandomLayerTokenDrop layers.

    The ids in ``RANDOM_LTD_LAYER_ID`` are consumed in ascending order while walking
    ``self.module.named_modules()``. Each RandomLayerTokenDrop whose name contains
    the next pending id is initialized with the config, ``self.random_ltd_scheduler``
    and its running index.

    Raises:
        ValueError: if the number of initialized layers differs from
            ``RANDOM_LTD_LAYER_NUM``, or if the layer-token LR schedule is enabled
            (not supported yet).
    """
    assert self.random_ltd_enabled()
    ltd_config = self.random_ltd_config()
    pending_ids = deque(sorted(ltd_config[RANDOM_LTD_LAYER_ID]))
    initialized = 0
    for module_name, submodule in self.module.named_modules():
        if not isinstance(submodule, RandomLayerTokenDrop):
            continue
        # NOTE(review): plain substring test. Any other digits in the module path can
        # match the pending id (e.g. id 1 inside "...11..."); confirm the naming
        # scheme makes this unambiguous.
        if pending_ids and str(pending_ids[0]) in module_name:
            submodule.init_config(ltd_config, self.random_ltd_scheduler, initialized)
            pending_ids.popleft()
            initialized += 1

    if ltd_config[RANDOM_LTD_LAYER_NUM] != initialized:
        raise ValueError(f'random_ltd_layer_num {ltd_config[RANDOM_LTD_LAYER_NUM]} must be equivalent to the len of random_ltd_layer_id {initialized}')

    if ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]:
        assert self.client_lr_scheduler is None
        # TODO: a layer-token decay LR schedule is planned but not implemented.
        raise ValueError(f'not yet support')
def get_sequence_parallel_group(self):
    """The sequence-parallel process group stored on the engine."""
    return self.seq_parallel_group

def wall_clock_breakdown(self):
    """Whether per-phase wall-clock timing is enabled."""
    return self._config.wall_clock_breakdown

def flops_profiler_enabled(self):
    """Whether the flops profiler runs: enabled explicitly, or implied by autotuning."""
    explicitly_enabled = self._config.flops_profiler_config.enabled
    return explicitly_enabled or self.autotuning_enabled()

def flops_profiler_recompute_fwd_factor(self):
    """The profiler's forward-recompute factor."""
    return self._config.flops_profiler_config.recompute_fwd_factor

def flops_profiler_profile_step(self):
    """Step at which to profile; autotuning overrides it with its start step."""
    if self._config.autotuning_config.enabled:
        return self.autotuning_start_profile_step()
    return self._config.flops_profiler_config.profile_step

def flops_profiler_module_depth(self):
    """The profiler's module-depth setting."""
    return self._config.flops_profiler_config.module_depth

def flops_profiler_top_modules(self):
    """The profiler's top-modules setting."""
    return self._config.flops_profiler_config.top_modules

def flops_profiler_detailed(self):
    """Whether to print the detailed profile; always False under autotuning."""
    autotuning_on = self._config.autotuning_config.enabled
    return False if autotuning_on else self._config.flops_profiler_config.detailed

def flops_profiler_output_file(self):
    """The profiler's output file setting."""
    return self._config.flops_profiler_config.output_file

def memory_breakdown(self):
    """Whether memory usage reporting is forced on."""
    return self._config.memory_breakdown
def autotuning_enabled(self):
    """Whether autotuning is enabled."""
    return self._config.autotuning_config.enabled

def autotuning_start_profile_step(self):
    """Autotuning's profiling start step."""
    return self._config.autotuning_config.start_profile_step

def autotuning_end_profile_step(self):
    """Autotuning's profiling end step."""
    return self._config.autotuning_config.end_profile_step

def autotuning_metric_path(self):
    """Configured metric path, or ``autotuning_metric.json`` in the current directory."""
    return self._config.autotuning_config.metric_path or os.path.join(os.getcwd(), "autotuning_metric.json")

def autotuning_model_info_path(self):
    """Configured model-info path, or ``autotuning_model_info.json`` in the current directory."""
    return self._config.autotuning_config.model_info_path or os.path.join(os.getcwd(), "autotuning_model_info.json")

def autotuning_metric(self):
    """The metric autotuning optimizes."""
    return self._config.autotuning_config.metric

def autotuning_profile_model_info(self):
    """Truthy when autotuning is on and its model_info requests profiling.

    Like the ``and`` chain it mirrors, this returns the first falsy operand, or
    ``model_info.get("profile", False)`` when all operands are truthy.
    """
    enabled = self.autotuning_enabled()
    if not enabled:
        return enabled
    model_info = self._config.autotuning_config.model_info
    return model_info and model_info.get("profile", False)
def sparse_gradients_enabled(self):
    """Whether sparse gradients are enabled."""
    return self._config.sparse_gradients_enabled

def train_batch_size(self):
    """The effective global training batch size."""
    return self._config.train_batch_size

def train_micro_batch_size_per_gpu(self):
    """The per-GPU micro-batch size."""
    return self._config.train_micro_batch_size_per_gpu

def optimizer_name(self):
    """Class name of the client optimizer if one was given, else the config's optimizer name."""
    client = self.client_optimizer
    if client:
        return client.__class__.__name__
    return self._config.optimizer_name

def optimizer_params(self):
    """The configured optimizer parameters."""
    return self._config.optimizer_params

def optimizer_legacy_fusion(self):
    """The config's ``optimizer_legacy_fusion`` setting."""
    return self._config.optimizer_legacy_fusion

def scheduler_name(self):
    """The configured LR scheduler name."""
    return self._config.scheduler_name

def scheduler_params(self):
    """The configured LR scheduler parameters."""
    return self._config.scheduler_params
def quantize_training(self):
    """Return the shared weight-quantization settings as a 9-tuple.

    All values come from ``compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]``,
    in this order: in-forward enabled, quantize enabled, groups, fp16 mixed quantize,
    change ratio, type, rounding, verbose, kernel.
    """
    shared = self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS]
    ordered_keys = (
        WEIGHT_QUANTIZE_IN_FORWARD_ENABLED,
        WEIGHT_QUANTIZE_ENABLED,
        WEIGHT_QUANTIZE_GROUPS,
        WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE,
        WEIGHT_QUANTIZE_CHANGE_RATIO,
        WEIGHT_QUANTIZE_TYPE,
        WEIGHT_QUANTIZE_ROUNDING,
        WEIGHT_QUANTIZE_VERBOSE,
        WEIGHT_QUANTIZE_KERNEL,
    )
    return tuple(shared[key] for key in ordered_keys)
def zero_optimization(self):
    """Whether ZeRO is enabled."""
    return self._config.zero_enabled

def zero_allow_untested_optimizer(self):
    """The config's ``zero_allow_untested_optimizer`` flag."""
    return self._config.zero_allow_untested_optimizer

def zero_force_ds_cpu_optimizer(self):
    """The config's ``zero_force_ds_cpu_optimizer`` flag."""
    return self._config.zero_force_ds_cpu_optimizer

def zero_reduce_scatter(self):
    """The ZeRO ``reduce_scatter`` setting."""
    return self._config.zero_config.reduce_scatter

def zero_overlap_comm(self):
    """The ZeRO ``overlap_comm`` setting."""
    return self._config.zero_config.overlap_comm

def zero_offload_optimizer(self):
    """The ZeRO optimizer-offload config (may be None)."""
    return self._config.zero_config.offload_optimizer

def zero_offload_param(self):
    """The ZeRO parameter-offload config."""
    return self._config.zero_config.offload_param

def zero_use_cpu_optimizer(self):
    """True when optimizer state is offloaded to CPU or NVMe."""
    offload = self._config.zero_config.offload_optimizer
    if offload is None:
        return False
    return offload.device in (OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme)

def zero_cpu_offload(self):
    """True when optimizer state is offloaded to CPU specifically."""
    offload = self._config.zero_config.offload_optimizer
    if offload is None:
        return False
    return offload.device == OffloadDeviceEnum.cpu

def zero_partial_offload(self):
    """Optimizer-offload ratio; 1.0 when the offload config has no ``ratio``."""
    return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0)

def zero_sub_group_size(self):
    """The ZeRO ``sub_group_size`` setting."""
    return self._config.zero_config.sub_group_size

def zero_optimization_stage(self):
    """The configured ZeRO stage."""
    return self._config.zero_optimization_stage
def mics_shard_size(self):
    """The configured MiCS shard size."""
    return self._config.mics_shard_size

def zero_reduce_bucket_size(self):
    """The ZeRO ``reduce_bucket_size`` setting."""
    return self._config.zero_config.reduce_bucket_size

def zero_multi_rank_bucket_allreduce(self):
    """The ZeRO ``use_multi_rank_bucket_allreduce`` setting."""
    return self._config.zero_config.use_multi_rank_bucket_allreduce

def zero_allgather_bucket_size(self):
    """The ZeRO ``allgather_bucket_size`` setting."""
    return self._config.zero_config.allgather_bucket_size

def zero_optimization_partition_gradients(self):
    """True when the ZeRO stage partitions gradients (stage >= gradients)."""
    stage = self.zero_optimization_stage()
    return stage >= ZeroStageEnum.gradients

def zero_optimization_partition_weights(self):
    """True when the ZeRO stage partitions weights (stage >= weights)."""
    stage = self.zero_optimization_stage()
    return stage >= ZeroStageEnum.weights

def is_first_weights_partition_group(self):
    """Whether this rank belongs to the first weights-partition group.

    With a positive MiCS shard size, the first ``mics_shard_size()`` global ranks
    form that group. With a negative shard size, every rank qualifies when weights
    are partitioned. A shard size of 0 yields False.
    """
    shard_size = self.mics_shard_size()
    if shard_size > 0 and self.global_rank < shard_size:
        return True
    return bool(shard_size < 0 and self.zero_optimization_partition_weights())
def zero_contiguous_gradients(self):
return self._config.zero_config.contiguous_gradients
def zero_load_from_fp32_weights(self):
return self._config.zero_config.load_from_fp32_weights
def zero_elastic_checkpoint(self):
return self._config.zero_config.elastic_checkpoint
def zero_has_nvme_offload(self):
if not hasattr(self.optimizer, "swap_optimizer"):
return False
return self.optimizer.swap_optimizer or self.optimizer.params_in_nvme_and_cpu
def zero_max_live_parameters(self):
return self._config.zero_config.max_live_parameters
def zero_max_reuse_distance(self):
    """Return the ZeRO ``max_reuse_distance`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.max_reuse_distance
def zero_prefetch_bucket_size(self):
    """Return the ZeRO ``prefetch_bucket_size`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.prefetch_bucket_size
def zero_module_granularity_threshold(self):
    """Return the ZeRO ``module_granularity_threshold`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.module_granularity_threshold
def zero_param_persistence_threshold(self):
    """Return the ZeRO ``param_persistence_threshold`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.param_persistence_threshold
def zero_model_persistence_threshold(self):
    """Return the ZeRO ``model_persistence_threshold`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.model_persistence_threshold
def zero_gather_16bit_weights_on_model_save(self):
    """Return the ZeRO ``gather_16bit_weights_on_model_save`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.gather_16bit_weights_on_model_save
def zero_grad_hooks(self):
    """Return the ZeRO ``grad_hooks`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.grad_hooks
def zero_legacy_stage1(self):
    """Return the ZeRO ``legacy_stage1`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.legacy_stage1
def zero_ignore_unused_parameters(self):
    """Return the ZeRO ``ignore_unused_parameters`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.ignore_unused_parameters
def graph_harvesting(self):
    """Return the engine config's ``graph_harvesting`` flag."""
    cfg = self._config
    return cfg.graph_harvesting
def fp16_enabled(self):
    """Return the engine config's ``fp16_enabled`` flag."""
    cfg = self._config
    return cfg.fp16_enabled
def bfloat16_enabled(self):
    """Return the engine config's ``bfloat16_enabled`` flag."""
    cfg = self._config
    return cfg.bfloat16_enabled
def fp16_master_weights_and_gradients(self):
    """Return the engine config's ``fp16_master_weights_and_gradients`` flag."""
    cfg = self._config
    return cfg.fp16_master_weights_and_gradients
def amp_enabled(self):
    """Return the engine config's ``amp_enabled`` flag."""
    cfg = self._config
    return cfg.amp_enabled
def amp_params(self):
    """Return the engine config's ``amp_params`` value."""
    cfg = self._config
    return cfg.amp_params
def fp16_auto_cast(self):
    """Return the engine config's ``fp16_auto_cast`` flag."""
    cfg = self._config
    return cfg.fp16_auto_cast
def loss_scale(self):
    """Return the engine config's ``loss_scale`` value."""
    cfg = self._config
    return cfg.loss_scale
def gradient_accumulation_steps(self):
    """Return the engine config's ``gradient_accumulation_steps`` value."""
    cfg = self._config
    return cfg.gradient_accumulation_steps
def use_node_local_storage(self):
    """Return the engine config's ``use_node_local_storage`` flag."""
    cfg = self._config
    return cfg.use_node_local_storage
def load_universal_checkpoint(self):
    """Return the engine config's ``load_universal_checkpoint`` flag."""
    cfg = self._config
    return cfg.load_universal_checkpoint
@property
def communication_data_type(self):
    """Dtype used for gradient communication.

    An explicitly configured ``communication_data_type`` wins. Otherwise the
    dtype follows the training precision: fp16 first, then bf16, else fp32.
    """
    configured = self._config.communication_data_type
    if configured is not None:
        return configured
    if self.fp16_enabled():
        return torch.float16
    return torch.bfloat16 if self.bfloat16_enabled() else torch.float32

@communication_data_type.setter
def communication_data_type(self, value):
    """Store an explicit communication dtype on the engine config."""
    self._config.communication_data_type = value
def postscale_gradients(self):
    """Return True unless ``prescale_gradients`` is enabled in the config."""
    prescale = self._config.prescale_gradients
    return not prescale
def gradient_predivide_factor(self):
    """Return the engine config's ``gradient_predivide_factor`` value."""
    cfg = self._config
    return cfg.gradient_predivide_factor
def steps_per_print(self):
    """Return the engine config's ``steps_per_print`` value."""
    cfg = self._config
    return cfg.steps_per_print
def zero_allgather_partitions(self):
    """Return the ZeRO ``allgather_partitions`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.allgather_partitions
def zero_round_robin_gradients(self):
    """Return the ZeRO ``round_robin_gradients`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.round_robin_gradients
def zero_hpz_partition_size(self):
    """Return the ZeRO ``zero_hpz_partition_size`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.zero_hpz_partition_size
def zero_quantized_weights(self):
    """Return the ZeRO ``zero_quantized_weights`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.zero_quantized_weights
def zero_quantized_nontrainable_weights(self):
    """Return the ZeRO ``zero_quantized_nontrainable_weights`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.zero_quantized_nontrainable_weights
def zero_quantized_gradients(self):
    """Return the ZeRO ``zero_quantized_gradients`` flag."""
    zero_cfg = self._config.zero_config
    return zero_cfg.zero_quantized_gradients
def zeropp_loco_param(self):
    """Return the ZeRO ``zeropp_loco_param`` setting."""
    zero_cfg = self._config.zero_config
    return zero_cfg.zeropp_loco_param
def dump_state(self):
    """Return the engine config's ``dump_state`` flag."""
    cfg = self._config
    return cfg.dump_state
def gradient_clipping(self):
    """Return the engine config's ``gradient_clipping`` value."""
    cfg = self._config
    return cfg.gradient_clipping
def dynamic_loss_scale(self):
    """Return True when the configured ``loss_scale`` equals 0.

    A zero ``loss_scale`` is what this method reports as dynamic loss scaling.
    """
    configured_scale = self._config.loss_scale
    return configured_scale == 0
def initial_dynamic_scale(self):
    """Return the engine config's ``initial_dynamic_scale`` value."""
    cfg = self._config
    return cfg.initial_dynamic_scale
def dynamic_loss_scale_args(self):
    """Return the engine config's ``dynamic_loss_scale_args`` value."""
    cfg = self._config
    return cfg.dynamic_loss_scale_args
def swap_tensor_config(self):
    """Return the engine config's ``swap_tensor_config`` section."""
    cfg = self._config
    return cfg.swap_tensor_config
def aio_config(self):
    """Return the engine config's ``aio_config`` section."""
    cfg = self._config
    return cfg.aio_config
def get_data_types(self):
    """Return ``(model_dtype, grad_accum_dtype)`` for this engine.

    ``model_dtype`` is fp16 if fp16 is enabled, else bf16 if bf16 is enabled,
    else fp32. ``grad_accum_dtype`` comes from the config's
    ``grad_accum_dtype`` when set (converted through ``DtypeEnum``).
    Otherwise it is fp32 for bf16 models without ZeRO, and ``model_dtype``
    in every other case.
    """
    if self.fp16_enabled():
        model_dtype = torch.float16
    elif self.bfloat16_enabled():
        model_dtype = torch.bfloat16
    else:
        model_dtype = torch.float32

    requested = self._config.grad_accum_dtype
    if requested is not None:
        grad_accum_dtype = DtypeEnum(requested).value
    elif model_dtype == torch.bfloat16 and not self.zero_optimization():
        # bf16 training without ZeRO accumulates gradients in fp32.
        grad_accum_dtype = torch.float32
    else:
        grad_accum_dtype = model_dtype
    return model_dtype, grad_accum_dtype
def _optimizer_has_ckpt_event_prologue(self):
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue')
def _optimizer_has_ckpt_event_epilogue(self):
return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')
def _configure_lr_scheduler(self):
    """Resolve ``self.lr_scheduler`` and log the choice on rank 0.

    Resolution order:
      * no (or falsy) client scheduler -> build one from the JSON config via
        ``_scheduler_from_config(self.optimizer)``;
      * a callable client scheduler -> call it with ``self.basic_optimizer``;
      * any other client scheduler object -> use it as-is.
    """
    client_sched = self.client_lr_scheduler
    if not client_sched:
        # load lr scheduler from json configuration if lr scheduler is not defined and passed in
        lr_scheduler = self._scheduler_from_config(self.optimizer)
        log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
        self.lr_scheduler = lr_scheduler
    elif isinstance(client_sched, Callable):
        log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
        self.lr_scheduler = client_sched(self.basic_optimizer)
    else:
        log_dist('DeepSpeed using client LR scheduler', ranks=[0])
        self.lr_scheduler = client_sched
    log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
def _configure_checkpointing(self, dist_init_required):
self.checkpoint_engine = TorchCheckpointEngine()
if self._config is not None and self._config.nebula_config.enabled:
try:
from deepspeed.runtime.checkpoint_engine.nebula_checkpoint_engine import \
NebulaCheckpointEngine
self.checkpoint_engine = NebulaCheckpointEngine(config_params=self._config.nebula_config)
except ImportError as err:
logger.error(f"No torch_nebula was found! Will fall back to torch.save. Details: {err}")
self.checkpoint_engine = TorchCheckpointEngine()
dp_rank = groups._get_sequence_data_parallel_rank()
rank = self.local_rank if self.use_node_local_storage() else dp_rank
# only the first data parallel process needs to store the model checkpoint
# if you want to use node local storage this must be done by rank 0 on each
# node
self.save_non_zero_checkpoint = (rank == 0) or (self.zero_optimization_partition_weights()
and self.is_first_weights_partition_group())
if self.zero_optimization() or self.bfloat16_enabled():