Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid double wrapping of all accelerate.prepare objects #1554

Closed
wants to merge 8 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,6 +1197,9 @@ def prepare(self, *args, device_placement=None):
if self.distributed_type == DistributedType.FSDP and model_count == 1 and optimizer_present:
result = self._prepare_fsdp(*result)

for item in result:
setattr(item, "_is_accelerate_prepared", True)

return result if len(result) > 1 else result[0]

def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
Expand Down Expand Up @@ -1691,6 +1694,11 @@ def prepare_data_loader(self, data_loader: torch.utils.data.DataLoader, device_p
>>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True)
```
"""
# Ensure we can't double wrap a DataLoader due to `find_batch_size`
if not getattr(data_loader, "_accelerator_prepared", False):
if data_loader not in self._dataloaders:
self._dataloaders.append(data_loader)
return data_loader
if device_placement is None:
device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
prepared_data_loader = prepare_data_loader(
Expand Down Expand Up @@ -1729,6 +1737,11 @@ def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=N
>>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
```
"""
# Ensure we can't double wrap an optimizer due to `find_batch_size`
if not getattr(optimizer, "_accelerator_prepared", False):
if optimizer not in self._optimizers:
self._optimizers.append(optimizer)
return optimizer
if device_placement is None:
device_placement = self.device_placement
optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
Expand Down Expand Up @@ -1756,6 +1769,11 @@ def prepare_scheduler(self, scheduler: LRScheduler):
>>> scheduler = accelerator.prepare_scheduler(scheduler)
```
"""
# Ensure we can't double wrap a scheduler due to `find_batch_size`
if not getattr(scheduler, "_accelerator_prepared", False):
if scheduler not in self._schedulers:
self._schedulers.append(scheduler)
return scheduler
# We try to find the optimizer associated with `scheduler`, the default is the full list.
optimizer = self._optimizers
for opt in self._optimizers:
Expand Down Expand Up @@ -2546,7 +2564,7 @@ def load_state(self, input_dir: str, **load_model_func_kwargs):
def free_memory(self):
"""
Will release all references to the internal objects stored and call the garbage collector. You should call this
method between two trainings with different models/optimizers.
method between two trainings with different models/optimizers. Also will reset `Accelerator.step` to 0.

Example:

Expand All @@ -2565,6 +2583,7 @@ def free_memory(self):
self._models = []
self._dataloaders = []
self.deepspeed_engine_wrapped = None
self.step = 0
release_memory()

def clear(self):
Expand Down