Fix slowdown on init with device_map="auto" #2914

Merged (1 commit) on Jul 4, 2024
src/accelerate/utils/memory.py (15 changes: 10 additions & 5 deletions)
```diff
@@ -26,8 +26,13 @@
 from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available


-def clear_device_cache():
-    gc.collect()
+def clear_device_cache(garbage_collection=False):
+    """
+    Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
+    this is a *considerable* slowdown and should be used sparingly.
+    """
+    if garbage_collection:
+        gc.collect()

     if is_xpu_available():
         torch.xpu.empty_cache()
```
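The slowdown in the PR title comes from `gc.collect()` running on every cache flush during `device_map="auto"` initialization; gating it behind a `garbage_collection` flag keeps the default path cheap. A minimal sketch of the two call modes, assuming `clear_device_cache` is importable from `accelerate.utils.memory` as the diff's file path suggests (the timing harness is illustrative, not from the PR):

```python
import time

from accelerate.utils.memory import clear_device_cache

# Default path: only torch.{backend}.empty_cache() for the detected device,
# no garbage collection, so repeated calls during init stay cheap.
clear_device_cache()

# Opt-in path: additionally runs a full gc.collect() sweep, which the new
# docstring warns is a considerable slowdown to be used sparingly.
start = time.perf_counter()
clear_device_cache(garbage_collection=True)
print(f"clear_device_cache with gc.collect(): {time.perf_counter() - start:.4f}s")
```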
```diff
@@ -67,7 +72,7 @@ def release_memory(*objects):
     objects = list(objects)
     for i in range(len(objects)):
         objects[i] = None
-    clear_device_cache()
+    clear_device_cache(garbage_collection=True)
     return objects
```
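`release_memory` keeps the full collection pass, since it is the explicit "free everything now" entry point. A small usage sketch, assuming the public `accelerate.utils` re-export:

```python
import torch

from accelerate.utils import release_memory

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters())

# Drops the Python references, runs gc.collect(), and empties the device
# cache; reassigning captures the returned None placeholders.
model, optimizer = release_memory(model, optimizer)
```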


```diff
@@ -123,7 +128,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i

     def decorator(*args, **kwargs):
         nonlocal batch_size
-        clear_device_cache()
+        clear_device_cache(garbage_collection=True)
         params = list(inspect.signature(function).parameters.keys())
         # Guard against user error
         if len(params) < (len(args) + 1):
```
```diff
@@ -139,7 +144,7 @@ def decorator(*args, **kwargs):
                 return function(batch_size, *args, **kwargs)
             except Exception as e:
                 if should_reduce_batch_size(e):
-                    clear_device_cache()
+                    clear_device_cache(garbage_collection=True)
                     batch_size //= 2
                 else:
                     raise
```
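Both call sites inside `find_executable_batch_size` keep `garbage_collection=True`, which fits: the decorator only clears the cache once at the start and when it is about to retry after an out-of-memory error, so paying for a full collection there is worthwhile. A sketch of typical decorator usage (the training body is a placeholder):

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # Placeholder body: a CUDA OOM raised here would trigger
    # clear_device_cache(garbage_collection=True) and halve batch_size.
    print(f"trying batch_size={batch_size}")

train()  # the decorator injects batch_size; don't pass it yourself
```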