diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py index 7c566480a86a..72d7c863a33f 100644 --- a/deepspeed/utils/nvtx.py +++ b/deepspeed/utils/nvtx.py @@ -4,19 +4,21 @@ # DeepSpeed Team from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.compiler import is_compiling enable_nvtx = True def instrument_w_nvtx(func): - """decorator that causes an NVTX range to be recorded for the duration of the - function call.""" + """Decorator that records an NVTX range for the duration of the function call. + Skips NVTX instrumentation when torch.compile is active to avoid graph breaks. + """ def wrapped_fn(*args, **kwargs): - if enable_nvtx: + if enable_nvtx and not is_compiling(): get_accelerator().range_push(func.__qualname__) ret_val = func(*args, **kwargs) - if enable_nvtx: + if enable_nvtx and not is_compiling(): get_accelerator().range_pop() return ret_val