From 7658e21474f6fd08da8a91ccbc7f87870d050c7d Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 11:38:09 -0400
Subject: [PATCH 1/8] Add step reset to free memory

---
 src/accelerate/accelerator.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index f4e6080c1b9..77c72eba6a4 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -2546,7 +2546,7 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):
     def free_memory(self):
         """
         Will release all references to the internal objects stored and call the garbage collector. You should call this
-        method between two trainings with different models/optimizers.
+        method between two trainings with different models/optimizers. Also will reset `Accelerator.step` to 0.
 
         Example:
 
@@ -2565,6 +2565,7 @@ def free_memory(self):
         self._models = []
         self._dataloaders = []
         self.deepspeed_engine_wrapped = None
+        self.step = 0
         release_memory()
 
     def clear(self):

From 03e6d67d65b7124489af49f583a0b9a0d0e068a9 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 11:46:49 -0400
Subject: [PATCH 2/8] Check if not Accelerated Optimizer

---
 src/accelerate/optimizer.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py
index d5eeef99a02..fac16a0b008 100644
--- a/src/accelerate/optimizer.py
+++ b/src/accelerate/optimizer.py
@@ -53,21 +53,22 @@ class AcceleratedOptimizer(torch.optim.Optimizer):
     """
 
     def __init__(self, optimizer, device_placement=True, scaler=None):
-        self.optimizer = optimizer
-        self.scaler = scaler
-        self.accelerator_state = AcceleratorState()
-        self.gradient_state = GradientState()
-        self.device_placement = device_placement
-        self._is_overflow = False
-
-        # Handle device placement
-        if device_placement:
-            state_dict = self.optimizer.state_dict()
-            if self.accelerator_state.distributed_type == DistributedType.TPU:
-                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
-            else:
-                state_dict = move_to_device(state_dict, self.accelerator_state.device)
-            self.optimizer.load_state_dict(state_dict)
+        if not isinstance(optimizer, AcceleratedOptimizer):
+            self.optimizer = optimizer
+            self.scaler = scaler
+            self.accelerator_state = AcceleratorState()
+            self.gradient_state = GradientState()
+            self.device_placement = device_placement
+            self._is_overflow = False
+
+            # Handle device placement
+            if device_placement:
+                state_dict = self.optimizer.state_dict()
+                if self.accelerator_state.distributed_type == DistributedType.TPU:
+                    xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+                else:
+                    state_dict = move_to_device(state_dict, self.accelerator_state.device)
+                self.optimizer.load_state_dict(state_dict)
 
     @property
     def state(self):
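Below is a minimal sketch, not taken from the patches, of the back-to-back training scenario that PATCH 1/8 targets; the tiny model and optimizer are illustrative placeholders, and exactly when the internal step counter advances is left vague on purpose.

```python
# Sketch: two trainings with one Accelerator instance (assumes `accelerate` and `torch` are installed).
import torch
from accelerate import Accelerator

accelerator = Accelerator()

def build():
    # Placeholder objects; any real model/optimizer would do.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, optimizer

model, optimizer = accelerator.prepare(*build())
# ... first training loop runs here; the internal step counter advances during training ...

accelerator.free_memory()  # drops internal references; with PATCH 1/8 applied, accelerator.step is reset to 0

model, optimizer = accelerator.prepare(*build())
# ... second training loop starts from step 0 instead of inheriting the previous counter ...
```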
From 8ad286d6f89e2b246bda39064f366d5a751707d3 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 11:56:28 -0400
Subject: [PATCH 3/8] Continue

---
 src/accelerate/optimizer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py
index fac16a0b008..d9a9fe811c1 100644
--- a/src/accelerate/optimizer.py
+++ b/src/accelerate/optimizer.py
@@ -69,6 +69,8 @@ def __init__(self, optimizer, device_placement=True, scaler=None):
                 else:
                     state_dict = move_to_device(state_dict, self.accelerator_state.device)
                 self.optimizer.load_state_dict(state_dict)
+        else:
+            self.optimizer = optimizer.optimizer
 
     @property
     def state(self):

From 0a1ed716dafa77f1aec5db4258076bb7cb5bb066 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 12:07:07 -0400
Subject: [PATCH 4/8] Another try

---
 src/accelerate/accelerator.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 77c72eba6a4..aa7a3dc6069 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1729,6 +1729,10 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
         >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
         ```
         """
+        if isinstance(optimizer, AcceleratedOptimizer):
+            if optimizer not in self._optimizers:
+                self._optimizers.append(optimizer)
+            return optimizer
         if device_placement is None:
             device_placement = self.device_placement
         optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
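As a hedged sketch (not from the diff) of the double-wrapping case PATCH 3/8 and PATCH 4/8 guard against at this point in the series, the snippet below passes an already-prepared optimizer back through `prepare_optimizer`; the model and optimizer are placeholders.

```python
# Sketch: preparing the same optimizer twice (assumes `accelerate` and `torch` are installed).
import torch
from accelerate import Accelerator

accelerator = Accelerator()
optimizer = torch.optim.AdamW(torch.nn.Linear(4, 2).parameters(), lr=1e-3)

wrapped = accelerator.prepare_optimizer(optimizer)
# With PATCH 4/8 applied, a second call returns the same AcceleratedOptimizer
# instead of nesting a wrapper inside a wrapper.
wrapped_again = accelerator.prepare_optimizer(wrapped)
assert wrapped_again is wrapped
```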
From 8e13cf56b217cff45ab792d931e909f6cd37343e Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 12:42:23 -0400
Subject: [PATCH 5/8] Check the rest

---
 src/accelerate/accelerator.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index aa7a3dc6069..1b328a2bac3 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -32,7 +32,7 @@
 import torch.utils.hooks as hooks
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
-from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
+from .data_loader import DataLoaderDispatcher, DataLoaderShard, prepare_data_loader, skip_first_batches
 from .logging import get_logger
 from .optimizer import AcceleratedOptimizer
 from .scheduler import AcceleratedScheduler
@@ -120,6 +120,8 @@
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
+    from .data_loader import MpDeviceLoaderWrapper
+
 
 try:
     from torch.optim.lr_scheduler import LRScheduler
@@ -1691,6 +1693,13 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_p
         >>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True)
         ```
         """
+        # Ensure we can't double wrap a DataLoader due to `find_batch_size`
+        if isinstance(data_loader, (DataLoaderDispatcher, DataLoaderShard)) or (
+            is_tpu_available() and isinstance(data_loader, MpDeviceLoaderWrapper)
+        ):
+            if data_loader not in self._dataloaders:
+                self._dataloaders.append(data_loader)
+            return data_loader
         if device_placement is None:
             device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
         prepared_data_loader = prepare_data_loader(
@@ -1729,6 +1738,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
         >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
         ```
         """
+        # Ensure we can't double wrap an optimizer due to `find_batch_size`
         if isinstance(optimizer, AcceleratedOptimizer):
             if optimizer not in self._optimizers:
                 self._optimizers.append(optimizer)
             return optimizer
@@ -1760,6 +1770,11 @@ def prepare_scheduler(self, scheduler: LRScheduler):
         >>> scheduler = accelerator.prepare_scheduler(scheduler)
         ```
         """
+        # Ensure we can't double wrap a scheduler due to `find_batch_size`
+        if isinstance(scheduler, AcceleratedScheduler):
+            if scheduler not in self._schedulers:
+                self._schedulers.append(scheduler)
+            return scheduler
         # We try to find the optimizer associated with `scheduler`, the default is the full list.
         optimizer = self._optimizers
         for opt in self._optimizers:

From 5e6d34ac64069d7898d264e2978eec31ee9bf3ee Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 12:43:16 -0400
Subject: [PATCH 6/8] Try with just check on init

---
 src/accelerate/optimizer.py | 33 +++++++++++++++------------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py
index d9a9fe811c1..d5eeef99a02 100644
--- a/src/accelerate/optimizer.py
+++ b/src/accelerate/optimizer.py
@@ -53,24 +53,21 @@ class AcceleratedOptimizer(torch.optim.Optimizer):
     """
 
     def __init__(self, optimizer, device_placement=True, scaler=None):
-        if not isinstance(optimizer, AcceleratedOptimizer):
-            self.optimizer = optimizer
-            self.scaler = scaler
-            self.accelerator_state = AcceleratorState()
-            self.gradient_state = GradientState()
-            self.device_placement = device_placement
-            self._is_overflow = False
-
-            # Handle device placement
-            if device_placement:
-                state_dict = self.optimizer.state_dict()
-                if self.accelerator_state.distributed_type == DistributedType.TPU:
-                    xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
-                else:
-                    state_dict = move_to_device(state_dict, self.accelerator_state.device)
-                self.optimizer.load_state_dict(state_dict)
-        else:
-            self.optimizer = optimizer.optimizer
+        self.optimizer = optimizer
+        self.scaler = scaler
+        self.accelerator_state = AcceleratorState()
+        self.gradient_state = GradientState()
+        self.device_placement = device_placement
+        self._is_overflow = False
+
+        # Handle device placement
+        if device_placement:
+            state_dict = self.optimizer.state_dict()
+            if self.accelerator_state.distributed_type == DistributedType.TPU:
+                xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
+            else:
+                state_dict = move_to_device(state_dict, self.accelerator_state.device)
+            self.optimizer.load_state_dict(state_dict)
 
     @property
     def state(self):
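The DataLoader guard added in PATCH 5/8 covers the analogous case; the sketch below (not from the diff, with a placeholder dataset) shows the usage it is meant to make safe, assuming a plain PyTorch DataLoader as input.

```python
# Sketch: re-preparing an already-prepared DataLoader (assumes `accelerate` and `torch` are installed).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
dataloader = DataLoader(dataset, batch_size=2)

prepared = accelerator.prepare_data_loader(dataloader)
# With the guard from PATCH 5/8, preparing the result again returns it as-is
# (it is already a DataLoaderShard/DataLoaderDispatcher) rather than wrapping it a second time.
prepared_again = accelerator.prepare_data_loader(prepared)
```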
From 01c73718eb318549e800c6fb64b029345c721d9a Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 13:00:32 -0400
Subject: [PATCH 7/8] Change logic based on review

---
 src/accelerate/accelerator.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 1b328a2bac3..ee07584194c 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -32,7 +32,7 @@
 import torch.utils.hooks as hooks
 
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
-from .data_loader import DataLoaderDispatcher, DataLoaderShard, prepare_data_loader, skip_first_batches
+from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger
 from .optimizer import AcceleratedOptimizer
 from .scheduler import AcceleratedScheduler
@@ -120,8 +120,6 @@
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
-    from .data_loader import MpDeviceLoaderWrapper
-
 
 try:
     from torch.optim.lr_scheduler import LRScheduler
@@ -1199,6 +1197,9 @@ def prepare(self, *args, device_placement=None):
         if self.distributed_type == DistributedType.FSDP and model_count == 1 and optimizer_present:
             result = self._prepare_fsdp(*result)
 
+        for item in result:
+            setattr(item, "_is_accelerate_prepared", True)
+
         return result if len(result) > 1 else result[0]
 
     def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
@@ -1694,12 +1695,10 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_p
         ```
         """
         # Ensure we can't double wrap a DataLoader due to `find_batch_size`
-        if isinstance(data_loader, (DataLoaderDispatcher, DataLoaderShard)) or (
-            is_tpu_available() and isinstance(data_loader, MpDeviceLoaderWrapper)
-        ):
+        if not getattr(data_loader, "_accelerator_prepared", False):
             if data_loader not in self._dataloaders:
                 self._dataloaders.append(data_loader)
-            return data_loader
+                return data_loader
         if device_placement is None:
             device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
         prepared_data_loader = prepare_data_loader(
@@ -1739,7 +1738,7 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
         ```
         """
         # Ensure we can't double wrap an optimizer due to `find_batch_size`
-        if isinstance(optimizer, AcceleratedOptimizer):
+        if not getattr(optimizer, "_accelerator_prepared", False):
             if optimizer not in self._optimizers:
                 self._optimizers.append(optimizer)
             return optimizer
@@ -1771,7 +1770,7 @@ def prepare_scheduler(self, scheduler: LRScheduler):
         ```
         """
         # Ensure we can't double wrap a scheduler due to `find_batch_size`
-        if isinstance(scheduler, AcceleratedScheduler):
+        if not getattr(scheduler, "_accelerator_prepared", False):
            if scheduler not in self._schedulers:
                 self._schedulers.append(scheduler)
             return scheduler

From 31774bf528e9963491224209583f167d5c73b2c1 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 7 Jun 2023 13:02:28 -0400
Subject: [PATCH 8/8] Update
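As a simplified, self-contained sketch of the attribute-flag pattern PATCH 7/8 moves toward (mark objects once in `prepare()` and short-circuit repeat calls), the toy class below is illustrative only; it is not the Accelerator implementation, and its names are placeholders.

```python
# Standalone illustration of flag-based idempotent "prepare" logic.
import torch

class TinyPreparer:
    def __init__(self):
        self._optimizers = []

    def prepare_optimizer(self, optimizer):
        # If this object already went through prepare once, track it but do not re-wrap it.
        if getattr(optimizer, "_is_accelerate_prepared", False):
            if optimizer not in self._optimizers:
                self._optimizers.append(optimizer)
            return optimizer
        # First time through: "wrap" (a no-op here) and mark the object as prepared.
        optimizer._is_accelerate_prepared = True
        self._optimizers.append(optimizer)
        return optimizer

preparer = TinyPreparer()
opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
assert preparer.prepare_optimizer(opt) is preparer.prepare_optimizer(opt)
assert len(preparer._optimizers) == 1
```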