From 4716b0f7691f6e58842fc522b48a54090b784ca8 Mon Sep 17 00:00:00 2001 From: Heyang Qin Date: Mon, 15 May 2023 06:38:23 -0700 Subject: [PATCH] share inflight registry between PartitionedParameterCoordinators (#3462) * share inflight registry between PartitionedParameterCoordinator * bound registry to model * make InflightParamRegistry standalone * fix format --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/parameter_offload.py | 7 +++++- .../zero/partitioned_param_coordinator.py | 24 ++++++++++--------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 55beff336740..f0ed5013a3ea 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -10,7 +10,7 @@ from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params +from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params from deepspeed import comm as dist from deepspeed.accelerator import get_accelerator @@ -244,6 +244,10 @@ def __init__(self, self._max_available_parameters_in_numel = int(max_live_parameters) self.__allgather_stream = get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream() + if not hasattr(module, "ds_inflight_param_registry"): + module.ds_inflight_param_registry = InflightParamRegistry() + self.__inflight_param_registry = module.ds_inflight_param_registry + self.forward_hooks = [] self.backward_hooks = [] self.setup_zero_stage3_hooks() @@ -270,6 +274,7 @@ def get_param_coordinator(self, training): max_reuse_distance_in_numel=self._max_reuse_distance_in_numel, max_available_parameters_in_numel=self._max_available_parameters_in_numel, allgather_stream=self.__allgather_stream, + inflight_param_registry=self.__inflight_param_registry, prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme, ) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 949c54f5e806..ff2cfff8f8c0 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -40,18 +40,19 @@ class ZeRoTraceMode(Enum): INVALID = 3 -class PartitionedParameterCoordinator: - """Handles partitioning and gathering of parameters.""" +class InflightParamRegistry(UserDict): + """registry for parameters in flight""" + + def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}") + self.data[param] = handle - class __InflightParamRegistry(UserDict): - """registry for parameters in flight""" - def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None: - if param in self.data: - raise RuntimeError(f"{param.ds_summary()} already in registry") - if param.ds_status != ZeroParamStatus.INFLIGHT: - raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}") - self.data[param] = handle +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" @dataclass class __ParamInTrace: @@ -64,10 +65,11 @@ def __init__( max_reuse_distance_in_numel: int, max_available_parameters_in_numel: int, allgather_stream: get_accelerator().Stream, + inflight_param_registry: InflightParamRegistry, prefetch_nvme: bool = False, ) -> None: # mapping of param -> handle for each param that is currently in flight - self.__inflight_param_registry = __class__.__InflightParamRegistry() + self.__inflight_param_registry = inflight_param_registry # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode