Avoid graph breaks by disabling sourceless calls in instrument_w_nvtx #7081

10 changes: 6 additions & 4 deletions deepspeed/utils/nvtx.py
@@ -4,19 +4,21 @@
 # DeepSpeed Team

 from deepspeed.accelerator import get_accelerator
+from deepspeed.runtime.compiler import is_compiling

 enable_nvtx = True


 def instrument_w_nvtx(func):
-    """decorator that causes an NVTX range to be recorded for the duration of the
-    function call."""
+    """Decorator that records an NVTX range for the duration of the function call.
+    Skips NVTX instrumentation when torch.compile is active to avoid graph breaks.
+    """

     def wrapped_fn(*args, **kwargs):
-        if enable_nvtx:
+        if enable_nvtx and not is_compiling():
             get_accelerator().range_push(func.__qualname__)
         ret_val = func(*args, **kwargs)
-        if enable_nvtx:
+        if enable_nvtx and not is_compiling():
             get_accelerator().range_pop()
         return ret_val

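The behavior this change aims for can be sketched without DeepSpeed or a GPU. The block below is a minimal illustration, not the project's code: `_StubAccelerator`, the `_compiling` flag, the `recorded` list, and the `run` helper are hypothetical stand-ins, and `is_compiling` here is a stub for the helper the diff imports from `deepspeed.runtime.compiler`. Only the decorator body mirrors the diff.

```python
enable_nvtx = True
_compiling = False  # hypothetical flag standing in for the real compile state
recorded = []       # hypothetical log of range events, for inspection only


def is_compiling():
    """Stub for the helper imported from deepspeed.runtime.compiler."""
    return _compiling


class _StubAccelerator:
    """Hypothetical accelerator that logs range calls instead of emitting NVTX."""

    def range_push(self, name):
        recorded.append(("push", name))

    def range_pop(self):
        recorded.append(("pop",))


_accel = _StubAccelerator()


def get_accelerator():
    return _accel


def instrument_w_nvtx(func):
    # Same control flow as the patched decorator: both the push and the pop
    # are skipped whenever is_compiling() reports True.
    def wrapped_fn(*args, **kwargs):
        if enable_nvtx and not is_compiling():
            get_accelerator().range_push(func.__qualname__)
        ret_val = func(*args, **kwargs)
        if enable_nvtx and not is_compiling():
            get_accelerator().range_pop()
        return ret_val

    return wrapped_fn


@instrument_w_nvtx
def add_one(x):
    return x + 1


def run(compiling):
    """Call the decorated function under a simulated compile state."""
    global _compiling
    _compiling = compiling
    recorded.clear()
    result = add_one(1)
    return result, list(recorded)
```

Calling `run(compiling=False)` logs one push and one pop around the call, while `run(compiling=True)` logs nothing and returns the same result. Because the guard is evaluated separately before and after the wrapped call, the push and pop stay balanced only as long as the compile state does not change during the call, which this sketch assumes.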