From d2428b316b4010a0fe205e7c8ff8572878779f6e Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Thu, 4 Jul 2024 09:01:13 -0400 Subject: [PATCH] Fix bug --- src/accelerate/utils/memory.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/accelerate/utils/memory.py b/src/accelerate/utils/memory.py index 97af2726d13..adfb0e1e083 100644 --- a/src/accelerate/utils/memory.py +++ b/src/accelerate/utils/memory.py @@ -26,8 +26,13 @@ from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available -def clear_device_cache(): - gc.collect() +def clear_device_cache(garbage_collection=False): + """ + Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that + this is a *considerable* slowdown and should be used sparingly. + """ + if garbage_collection: + gc.collect() if is_xpu_available(): torch.xpu.empty_cache() @@ -67,7 +72,7 @@ def release_memory(*objects): objects = list(objects) for i in range(len(objects)): objects[i] = None - clear_device_cache() + clear_device_cache(garbage_collection=True) return objects @@ -123,7 +128,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i def decorator(*args, **kwargs): nonlocal batch_size - clear_device_cache() + clear_device_cache(garbage_collection=True) params = list(inspect.signature(function).parameters.keys()) # Guard against user error if len(params) < (len(args) + 1): @@ -139,7 +144,7 @@ def decorator(*args, **kwargs): return function(batch_size, *args, **kwargs) except Exception as e: if should_reduce_batch_size(e): - clear_device_cache() + clear_device_cache(garbage_collection=True) batch_size //= 2 else: raise