diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index bc581dc2215e..a4d79697a111 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2101,16 +2101,17 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: return grad_dict - def _fp32_state_allgather(self, param, fp32_state): - reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(), + def _fp32_state_allgather(self, param, fp32_state_partition): + reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(), dtype=torch.float32, device=param.device).flatten() my_rank = dist.get_rank(group=self.dp_process_group) partitions = [ reduce_buffer.narrow(0, - fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count) + fp32_state_partition.numel() * i, fp32_state_partition.numel()) + for i in range(self.partition_count) ] - partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False) + partitions[my_rank].data.copy_(fp32_state_partition.data, non_blocking=False) dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group) @@ -2125,19 +2126,16 @@ def get_fp32_grad_for_param(self, param) -> Tensor: if self.offload_optimizer: group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] - fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, - num_elements).to(device=param.device) + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements) else: fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float() return self._fp32_state_allgather(param, fp32_grad) - def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: - if not param.requires_grad: - return None - + def _get_fp32_opt_state_partition(self, param, optim_state_key=None): if not get_accelerator().is_synchronized_device(): self.reduce_and_partition_stream.synchronize() + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] if self._swappable_optimizer_subgroup(group_idx): @@ -2145,16 +2143,41 @@ def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: fp32_param = self.fp32_partitioned_groups_flat[group_idx] if optim_state_key is None: - fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device) + fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements) else: - fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow( - 0, dest_offset, num_elements).to(device=param.device) + fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements) + + return fp32_opt_state, group_idx + def get_full_hp_param(self, param, optim_state_key=None) -> Tensor: + if not param.requires_grad: + return None + + fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) hp_param = self._fp32_state_allgather(param, fp32_opt_state) + if self._swappable_optimizer_subgroup(group_idx): self._optimizer_states_and_gradient_swap_out(group_idx) + return hp_param + def set_full_hp_param(self, value, param, optim_state_key=None): + if not param.requires_grad: + return + + assert value.numel( + ) == param.ds_numel, f" Number of elements do not match: {value.numel()} != {param.ds_numel}" + + fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key) + my_rank = dist.get_rank(group=self.dp_process_group) + value_partition = value.flatten().narrow(0, + fp32_opt_state_partition.numel() * my_rank, + fp32_opt_state_partition.numel()) + fp32_opt_state_partition.data.copy_(value_partition.data) + + if self._swappable_optimizer_subgroup(group_idx): + self._optimizer_states_and_gradient_swap_out(group_idx) + @instrument_w_nvtx def _partition_all_parameters(self): self.parameter_offload.partition_all_parameters() diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 4e7f6b61d075..b6668b5ff5ce 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -12,6 +12,8 @@ # TODO: Move tensor fragment and mixed precision to zero utils from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state +from .tensor_fragment import set_full_hp_param +from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state from .mixed_precision_linkage import link_hp_params from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py index ecc29e930954..b1afa8f00aa3 100644 --- a/deepspeed/utils/mixed_precision_linkage.py +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -5,6 +5,7 @@ import types from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping +from deepspeed.utils import set_full_hp_param def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, @@ -27,6 +28,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr lp_param._dp_group = dp_group lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param) lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param) + lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param) # lp_param overlaps with partition if both are true # 1) current_offset < partition_end, diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index ef09edd3da1e..18e373799ab7 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -45,22 +45,31 @@ def get_hp_fragment_address(self): def get_optim_state_keys(self): return list(self.optim_fragment.keys()) + def get_hp_fragment(self, optim_state_key=None): + if optim_state_key is None: + return self.hp_fragment + return self.get_optim_state_fragment(optim_state_key) + def get_full_hp_param(self, optim_state_key=None): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() if self._hp_mapping is not None: lp_frag_address = self._hp_mapping.lp_fragment_address reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel) - if optim_state_key is None: - hp_fragment = self._hp_mapping.hp_fragment - else: - hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key) - + hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key) reduce_fragment.data.copy_(hp_fragment.data) dist.all_reduce(reduce_buffer, group=self._dp_group) return reduce_buffer.reshape_as(self) +def set_full_hp_param(self, value, optim_state_key=None): + if self._hp_mapping is not None: + lp_frag_address = self._hp_mapping.lp_fragment_address + value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel) + hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key) + hp_fragment.data.copy_(value_fragment.data) + + def get_full_hp_grad(self): reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten() if self._hp_mapping is not None: @@ -105,11 +114,28 @@ def safe_get_full_fp32_param(param): return None +def safe_set_full_fp32_param(param, value): + """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_full_hp_param(value, param) + + # ZeRO stage 1, 2, and bf16_optimizer params + if hasattr(param, '_hp_mapping'): + param.set_full_hp_param(value) + + def safe_get_full_optimizer_state(param, optim_state_key): """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter. Args: param (``torch.nn.Parameter``): A model parameter + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) """ # ZeRO stage 3 param if hasattr(param, 'ds_id'): @@ -121,6 +147,23 @@ def safe_get_full_optimizer_state(param, optim_state_key): return None +def safe_set_full_optimizer_state(param, value, optim_state_key): + """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter. + + Args: + param (``torch.nn.Parameter``): A model parameter + value (``torch.Tensor``): New value + optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer) + """ + # ZeRO stage 3 param + if hasattr(param, 'ds_id'): + param._z3_optimizer.set_full_hp_param(value, param, optim_state_key) + + # ZeRO stage 1, 2, and bf16_optimizer params + if hasattr(param, '_hp_mapping'): + param.set_full_hp_param(value, optim_state_key) + + # TODO: Figure out the correct return dtype def safe_get_full_grad(param): """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter. @@ -142,6 +185,9 @@ def safe_get_full_grad(param): return None +# TODO: Implement API for setting ZeRO partitioned gradients + + def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, param_group_index, partition_start, partition_size, optimizer_state_dict): lp_end = lp_param.numel() + lp_start diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst index 333b29ed98d8..56a7987dc496 100644 --- a/docs/code-docs/source/zero3.rst +++ b/docs/code-docs/source/zero3.rst @@ -376,6 +376,35 @@ These routines can be used in a training loop as shown in the following snippet. optimizer.step() + +Modifying Partitioned States +---------------------------- + +Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states. + +.. autofunction:: deepspeed.utils.safe_set_full_fp32_param + +.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state + + +These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet. + +.. code-block:: python + + [...] + from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state + # Here is an example to zero all the fp32 parameters and optimizer states. + for n, lp in model.named_parameters(): + # Assume zero stage 1 or 2, since stage 3 requires a gather to assemble lp + zero_tensor = torch.zeros_like(lp) + + hp = safe_set_full_fp32_param(lp, zero_tensor) + exp_avg = safe_get_full_optimizer_state(lp, zero_tensor, "exp_avg") + exp_avg_sq = safe_get_full_optimizer_state(lp, zero_tensor, "exp_avg_sq") + + [...] + + GPU Memory Management --------------------- diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index 475502561418..63d05ab6d352 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -8,14 +8,19 @@ import torch from unit.common import DistributedTest -from unit.simple_model import random_dataloader +from unit.simple_model import random_dataloader, SimpleModel from unit.util import bf16_required_version_check import deepspeed from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state +from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.aio import AsyncIOBuilder +WEIGHT_KEY = 'weight' +FIRST_ORDER_KEY = 'exp_avg' +SECOND_ORDER_KEY = 'exp_avg_sq' + def validate_full_tensors(model): for _, lp in model.named_parameters(): @@ -73,7 +78,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype): @pytest.mark.parametrize('frozen_weights', [True, False]) -class TestTensorFragment(DistributedTest): +class TestTensorFragmentGet(DistributedTest): # Need multiple gpus to test possible hanging world_size = 2 reuse_dist_env = True @@ -150,3 +155,104 @@ def test_bf16_fragments(self, frozen_weights): hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16) + + +def create_random_values(model, key_list, group): + param_values = {} + for n, lp in model.named_parameters(): + param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape + param_values[n] = {} + for key in key_list: + rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device) + dist.broadcast(rand_value, src=0, group=group) + param_values[n][key] = rand_value + return param_values + + +def set_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, value_tensor in value_dict[n].items(): + if key == WEIGHT_KEY: + safe_set_full_fp32_param(lp, value_tensor) + else: + safe_set_full_optimizer_state(lp, value_tensor, key) + + +def validate_param_values_with_dict(model, value_dict): + for n, lp in model.named_parameters(): + for key, expected_tensor in value_dict[n].items(): + if key == WEIGHT_KEY: + actual_tensor = safe_get_full_fp32_param(lp) + else: + actual_tensor = safe_get_full_optimizer_state(lp, key) + assert torch.equal(expected_tensor, actual_tensor) + + +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) +class TestTensorFragmentUpdate(DistributedTest): + # Need multiple gpus to test possible hanging + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) + def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype): + + if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if offload_device == OffloadDeviceEnum.nvme: + if zero_stage != 3: + pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": zero_stage, + } + } + + if offload_device == OffloadDeviceEnum.cpu: + config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device} + elif offload_device == OffloadDeviceEnum.nvme: + config_dict["zero_optimization"]["offload_optimizer"] = { + "device": offload_device, + "nvme_path": str(tmpdir) + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + hidden_dim = 128 + if zero_stage == 3: + config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=4) + else: + model = SimpleModel(hidden_dim, nlayers=4) + + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + world = dist.get_world_size() + group = dist.new_group(ranks=list(range(world))) + + dist.barrier() + optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] + optim_state_values = create_random_values(model, optim_keys, group) + set_param_values_with_dict(model, optim_state_values) + validate_param_values_with_dict(model, optim_state_values) + + # Needed in ZeRO 3. Not doing so can leak memory. + model.destroy()