[Feature] Add AvoidOOM to avoid OOM (#7434)
* [Feature] Add AvoidOOM to avoid OOM
* support multiple outputs
* add docs in FAQ
* fix logic
* add a tutorial on using AvoidOOM as a decorator
* add a unit test for cast_tensor_type
* assorted minor fixes
1 parent d18cdb1 · commit 7b03639
Showing 5 changed files with 356 additions and 1 deletion.
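For orientation before the diff: the commit exposes a module-level `AvoidCUDAOOM` instance whose `retry_if_cuda_oom` method wraps any stateless, tensor-in/tensor-out callable. A minimal usage sketch, based on the docstring examples below (the wrapped function and shapes are illustrative, and a CUDA device is assumed):

    import torch
    from mmdet.utils.memory import AvoidCUDAOOM

    x = torch.rand(20, 20).cuda()

    # wrapper form: retries after empty_cache(), then FP16, then CPU
    out = AvoidCUDAOOM.retry_if_cuda_oom(torch.mm)(x, x.transpose(1, 0))

    # decorator form
    @AvoidCUDAOOM.retry_if_cuda_oom
    def pairwise_scores(a, b):
        return torch.mm(a, b.transpose(1, 0))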
@@ -0,0 +1,214 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections import abc
from contextlib import contextmanager
from functools import wraps

import torch

from mmdet.utils import get_root_logger


def cast_tensor_type(inputs, src_type=None, dst_type=None):
    """Recursively convert Tensor in inputs from ``src_type`` to ``dst_type``.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype | torch.device): Source type.
        dst_type (torch.dtype | torch.device): Destination type.

    Returns:
        The same container type as ``inputs``, with all contained Tensors
        converted to ``dst_type``.
    """
    assert dst_type is not None
    if isinstance(inputs, torch.Tensor):
        if isinstance(dst_type, torch.device):
            # convert Tensor to dst_device
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'device') and \
                    (inputs.device == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
        else:
            # convert Tensor to dst_dtype
            if hasattr(inputs, 'to') and \
                    hasattr(inputs, 'dtype') and \
                    (inputs.dtype == src_type or src_type is None):
                return inputs.to(dst_type)
            else:
                return inputs
    # we need to ensure that the type of inputs to be cast is the same
    # as the argument `src_type`.
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type=src_type, dst_type=dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type=src_type, dst_type=dst_type)
            for item in inputs)
    # TODO: Currently not supported
    # elif isinstance(inputs, InstanceData):
    #     for key, value in inputs.items():
    #         inputs[key] = cast_tensor_type(
    #             value, src_type=src_type, dst_type=dst_type)
    #     return inputs
    else:
        return inputs


@contextmanager
def _ignore_torch_cuda_oom():
    """A context manager which ignores CUDA OOM exceptions from PyTorch.

    Code is modified from
    <https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py>  # noqa: E501
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the error message string may change across PyTorch versions
        if 'CUDA out of memory. ' in str(e):
            pass
        else:
            raise


class AvoidOOM:
    """Try to convert inputs to FP16 or CPU if a PyTorch CUDA Out-of-Memory
    error is encountered. It will do the following steps:

        1. First retry after calling `torch.cuda.empty_cache()`.
        2. If that still fails, it will then retry by converting inputs
           to FP16.
        3. If that still fails, it will try to convert inputs to CPU.
           In this case, it expects the function to dispatch to a
           CPU implementation.

    Args:
        to_cpu (bool): Whether to convert inputs to CPU if a CUDA OOM
            error is still raised after the FP16 retry. This will slow
            down the code significantly. Defaults to True.
        test (bool): Skip `_ignore_torch_cuda_oom` so that the fallback
            paths can be exercised with lightweight data; only used in
            unit tests. Defaults to False.

    Examples:
        >>> from mmdet.utils.memory import AvoidOOM
        >>> AvoidCUDAOOM = AvoidOOM()
        >>> output = AvoidCUDAOOM.retry_if_cuda_oom(
        >>>     some_torch_function)(input1, input2)
        >>> # To use as a decorator
        >>> # from mmdet.utils import AvoidCUDAOOM
        >>> @AvoidCUDAOOM.retry_if_cuda_oom
        >>> def function(*args, **kwargs):
        >>>     return None

    Note:
        1. The output may be on CPU even if inputs are on GPU. Processing
           on CPU will slow down the code significantly.
        2. When converting inputs to CPU, it will only look at each argument
           and check if it has `.device` and `.to` for conversion. Nested
           structures of tensors are not supported.
        3. Since the function might be called more than once, it has to be
           stateless.
    """

    def __init__(self, to_cpu=True, test=False):
        self.logger = get_root_logger()
        self.to_cpu = to_cpu
        self.test = test

    def retry_if_cuda_oom(self, func):
        """Makes a function retry itself after encountering PyTorch's CUDA
        OOM error.

        The implementation is adapted from
        https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/memory.py

        Args:
            func: a stateless callable that takes tensor-like objects
                as arguments.

        Returns:
            func: a callable which retries `func` if OOM is encountered.
        """  # noqa: W605

        @wraps(func)
        def wrapped(*args, **kwargs):

            # raw function
            if not self.test:
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

                # Clear cache and retry
                torch.cuda.empty_cache()
                with _ignore_torch_cuda_oom():
                    return func(*args, **kwargs)

            # get the dtype and device of the first tensor
            dtype, device = None, None
            values = args + tuple(kwargs.values())
            for value in values:
                if isinstance(value, torch.Tensor):
                    dtype = value.dtype
                    device = value.device
                    break
            if dtype is None or device is None:
                raise ValueError('There is no tensor in the inputs, '
                                 'cannot get dtype and device.')

            # Convert to FP16
            fp16_args = cast_tensor_type(args, dst_type=torch.half)
            fp16_kwargs = cast_tensor_type(kwargs, dst_type=torch.half)
            self.logger.info(f'Attempting to copy inputs of {str(func)} '
                             f'to FP16 due to CUDA OOM')

            # the output dtype will be the same as the dtype of the
            # first input tensor
            with _ignore_torch_cuda_oom():
                output = func(*fp16_args, **fp16_kwargs)
                output = cast_tensor_type(
                    output, src_type=torch.half, dst_type=dtype)
                if not self.test:
                    return output
            self.logger.info('Using FP16 still met CUDA OOM')

            # Try on CPU. This will slow down the code significantly,
            # therefore print a notice.
            if self.to_cpu:
                self.logger.info(f'Attempting to copy inputs of {str(func)} '
                                 f'to CPU due to CUDA OOM')
                cpu_device = torch.empty(0).device
                cpu_args = cast_tensor_type(args, dst_type=cpu_device)
                cpu_kwargs = cast_tensor_type(kwargs, dst_type=cpu_device)

                # try to convert outputs back to GPU
                with _ignore_torch_cuda_oom():
                    self.logger.info(f'Convert outputs to GPU '
                                     f'(device={device})')
                    output = func(*cpu_args, **cpu_kwargs)
                    output = cast_tensor_type(
                        output, src_type=cpu_device, dst_type=device)
                    return output

                warnings.warn('Cannot convert output to GPU due to CUDA OOM, '
                              'the output is now on CPU, which might cause '
                              'errors if the output needs to interact with '
                              'GPU data in subsequent operations')
                self.logger.info('Cannot convert output to GPU due to '
                                 'CUDA OOM, the output is on CPU now.')

                return func(*cpu_args, **cpu_kwargs)
            else:
                # may still get CUDA OOM error
                return func(*args, **kwargs)

        return wrapped


# To use AvoidOOM as a decorator
AvoidCUDAOOM = AvoidOOM()
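A quick sketch of how `cast_tensor_type` recurses through containers (the values here are illustrative; this runs on CPU):

    import torch
    from mmdet.utils.memory import cast_tensor_type

    data = {'feat': torch.rand(2, 3), 'pair': [torch.rand(2), torch.rand(3)]}
    half = cast_tensor_type(data, dst_type=torch.half)
    assert half['feat'].dtype == torch.half  # tensors converted recursively
    assert isinstance(half['pair'], list)    # container types are preserved

    # with src_type given, only tensors whose dtype matches src_type are cast
    back = cast_tensor_type(half, src_type=torch.half, dst_type=torch.float32)
    assert back['pair'][0].dtype == torch.float32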
@@ -0,0 +1,98 @@
import numpy as np
import pytest
import torch

from mmdet.utils import AvoidOOM
from mmdet.utils.memory import cast_tensor_type


def test_avoidoom():
    tensor = torch.from_numpy(np.random.random((20, 20)))
    if torch.cuda.is_available():
        tensor = tensor.cuda()
        # get the default result
        default_result = torch.mm(tensor, tensor.transpose(1, 0))

        # when no OOM error occurs
        AvoidCudaOOM = AvoidOOM()
        result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
                                                          tensor.transpose(
                                                              1, 0))
        assert default_result.device == result.device and \
               default_result.dtype == result.dtype and \
               torch.equal(default_result, result)

        # calculate with fp16 and convert back to the source dtype
        AvoidCudaOOM = AvoidOOM(test=True)
        result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
                                                          tensor.transpose(
                                                              1, 0))
        assert default_result.device == result.device and \
               default_result.dtype == result.dtype and \
               torch.allclose(default_result, result, 1e-3)

        # calculate on cpu and convert back to the source device
        AvoidCudaOOM = AvoidOOM(test=True)
        result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
                                                          tensor.transpose(
                                                              1, 0))
        assert result.dtype == default_result.dtype and \
               result.device == default_result.device and \
               torch.allclose(default_result, result)

        # do not fall back to cpu; the outputs keep the device and
        # dtype of the inputs
        AvoidCudaOOM = AvoidOOM(test=True, to_cpu=False)
        result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
                                                          tensor.transpose(
                                                              1, 0))
        assert result.dtype == default_result.dtype and \
               result.device == default_result.device

    else:
        default_result = torch.mm(tensor, tensor.transpose(1, 0))
        AvoidCudaOOM = AvoidOOM()
        result = AvoidCudaOOM.retry_if_cuda_oom(torch.mm)(tensor,
                                                          tensor.transpose(
                                                              1, 0))
        assert default_result.device == result.device and \
               default_result.dtype == result.dtype and \
               torch.equal(default_result, result)


def test_cast_tensor_type():
    inputs = torch.rand(10)
    if torch.cuda.is_available():
        inputs = inputs.cuda()
    with pytest.raises(AssertionError):
        cast_tensor_type(inputs, src_type=None, dst_type=None)
    # input is a float
    out = cast_tensor_type(10., dst_type=torch.half)
    assert out == 10. and isinstance(out, float)
    # convert Tensor to fp16 and re-convert to fp32
    fp16_out = cast_tensor_type(inputs, dst_type=torch.half)
    assert fp16_out.dtype == torch.half
    fp32_out = cast_tensor_type(fp16_out, dst_type=torch.float32)
    assert fp32_out.dtype == torch.float32

    # input is a list
    list_input = [inputs, inputs]
    list_outs = cast_tensor_type(list_input, dst_type=torch.half)
    assert len(list_outs) == len(list_input) and \
           isinstance(list_outs, list)
    for out in list_outs:
        assert out.dtype == torch.half
    # input is a dict
    dict_input = {'test1': inputs, 'test2': inputs}
    dict_outs = cast_tensor_type(dict_input, dst_type=torch.half)
    assert len(dict_outs) == len(dict_input) and \
           isinstance(dict_outs, dict)

    # convert the input tensor to CPU and re-convert to GPU
    if torch.cuda.is_available():
        cpu_device = torch.empty(0).device
        gpu_device = inputs.device
        cpu_out = cast_tensor_type(inputs, dst_type=cpu_device)
        assert cpu_out.device == cpu_device

        gpu_out = cast_tensor_type(inputs, dst_type=gpu_device)
        assert gpu_out.device == gpu_device
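Condensing the GPU branch of `test_avoidoom` above into an interactive snippet (a CUDA device is assumed; `test=True` skips the raw and empty-cache attempts so that even small inputs exercise the FP16 and CPU fallback paths):

    import torch
    from mmdet.utils import AvoidOOM

    if torch.cuda.is_available():
        x = torch.rand(20, 20).cuda()
        out = AvoidOOM(test=True).retry_if_cuda_oom(torch.mm)(
            x, x.transpose(1, 0))
        # outputs are cast back to the dtype and device of the first input
        assert out.dtype == x.dtype and out.device == x.device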