Skip to content

Commit

Permalink
share inflight registry between PartitionedParameterCoordinators (#3462)
Browse files Browse the repository at this point in the history
* share inflight registry between PartitionedParameterCoordinators

* bind registry to model

* make InflightParamRegistry standalone

* fix format

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
HeyangQin and tjruwase authored May 15, 2023
1 parent 9f4a876 commit 4716b0f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
7 changes: 6 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.runtime.zero.partition_parameters import _init_external_params
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, InflightParamRegistry, iter_params
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator

Expand Down Expand Up @@ -244,6 +244,10 @@ def __init__(self,
self._max_available_parameters_in_numel = int(max_live_parameters)
self.__allgather_stream = get_accelerator().Stream() if overlap_comm else get_accelerator().default_stream()

if not hasattr(module, "ds_inflight_param_registry"):
module.ds_inflight_param_registry = InflightParamRegistry()
self.__inflight_param_registry = module.ds_inflight_param_registry

self.forward_hooks = []
self.backward_hooks = []
self.setup_zero_stage3_hooks()
Expand All @@ -270,6 +274,7 @@ def get_param_coordinator(self, training):
max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
max_available_parameters_in_numel=self._max_available_parameters_in_numel,
allgather_stream=self.__allgather_stream,
inflight_param_registry=self.__inflight_param_registry,
prefetch_nvme=self.offload_device == OffloadDeviceEnum.nvme,
)

Expand Down
24 changes: 13 additions & 11 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@ class ZeRoTraceMode(Enum):
INVALID = 3


class PartitionedParameterCoordinator:
"""Handles partitioning and gathering of parameters."""
class InflightParamRegistry(UserDict):
    """Dictionary of in-flight parameters, keyed by parameter.

    Each value is the handle associated with that parameter. Insertion is
    guarded: a parameter may be registered only once, and only while its
    ``ds_status`` is ``ZeroParamStatus.INFLIGHT``.
    """

    def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None:
        """Register ``handle`` for ``param``.

        Raises:
            RuntimeError: if ``param`` is already registered, or if its
                ``ds_status`` is not ``ZeroParamStatus.INFLIGHT``.
        """
        # Pick the failure message first (duplicate check takes precedence);
        # a valid insertion stores the handle and returns before the raise.
        if param in self.data:
            problem = f"{param.ds_summary()} already in registry"
        elif param.ds_status != ZeroParamStatus.INFLIGHT:
            problem = f"attempted to add non-inflight parameter to registry {param.ds_summary()}"
        else:
            self.data[param] = handle
            return
        raise RuntimeError(problem)

class __InflightParamRegistry(UserDict):
    """Registry of parameters that are currently in flight.

    Maps a parameter to its handle. A parameter cannot be added twice, and
    it must have ``ds_status == ZeroParamStatus.INFLIGHT`` when added.
    """

    def __setitem__(self, param: Parameter, handle: AllGatherCoalescedHandle) -> None:
        """Store ``handle`` under ``param`` after validating the insertion.

        Raises:
            RuntimeError: on a duplicate parameter or a parameter whose
                status is not in-flight.
        """
        entries = self.data

        # Guard 1: each parameter may appear at most once.
        if param in entries:
            raise RuntimeError(f"{param.ds_summary()} already in registry")

        # Guard 2: only parameters marked in-flight are accepted.
        status = param.ds_status
        if status != ZeroParamStatus.INFLIGHT:
            raise RuntimeError(f"attempted to add non-inflight parameter to registry {param.ds_summary()}")

        entries[param] = handle
class PartitionedParameterCoordinator:
"""Handles partitioning and gathering of parameters."""

@dataclass
class __ParamInTrace:
Expand All @@ -64,10 +65,11 @@ def __init__(
max_reuse_distance_in_numel: int,
max_available_parameters_in_numel: int,
allgather_stream: get_accelerator().Stream,
inflight_param_registry: InflightParamRegistry,
prefetch_nvme: bool = False,
) -> None:
# mapping of param -> handle for each param that is currently in flight
self.__inflight_param_registry = __class__.__InflightParamRegistry()
self.__inflight_param_registry = inflight_param_registry
# keeps track of the number of submodules invoked so far.
self.__step_id: int = 0
# network tracing mode
Expand Down

0 comments on commit 4716b0f

Please sign in to comment.