diff --git a/tensordict/utils.py b/tensordict/utils.py
index 6f4fa239f..0e370856f 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -68,9 +68,9 @@ except ImportError:
     _has_funcdim = False
 try:
-    from torch.compiler import assume_constant_result
+    from torch.compiler import assume_constant_result, is_compiling
 except ImportError:  # torch 2.0
-    from torch._dynamo import assume_constant_result
+    from torch._dynamo import assume_constant_result, is_compiling
 
 if TYPE_CHECKING:
     from tensordict.tensordict import TensorDictBase
@@ -861,7 +861,7 @@ def _is_tensorclass(cls: type) -> bool:
     out = _TENSORCLASS_MEMO.get(cls)
     if out is None:
         out = getattr(cls, "_is_tensorclass", False)
-        if not is_dynamo_compiling():
+        if not is_compiling():
             _TENSORCLASS_MEMO[cls] = out
     return out
@@ -1117,7 +1117,7 @@ def cache(fun):
 
     @wraps(fun)
     def newfun(_self: "TensorDictBase", *args, **kwargs):
-        if not _self.is_locked or is_dynamo_compiling():
+        if not _self.is_locked or is_compiling():
             return fun(_self, *args, **kwargs)
         cache = _self._cache
         if cache is None:
@@ -1357,7 +1357,7 @@ def _parse_to(*args, **kwargs):
     num_threads = kwargs.pop("num_threads", None)
     other = kwargs.pop("other", None)
     inplace = kwargs.pop("inplace", False)
-    if not is_dynamo_compiling():
+    if not is_compiling():
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
             *args, **kwargs
         )
@@ -1731,7 +1731,7 @@ def _check_keys(
         is_leaf=_is_leaf_nontensor,
     )
     # TODO: compile doesn't like set() over an arbitrary object
-    if is_dynamo_compiling():
+    if is_compiling():
         keys = {k for k in keys}  # noqa: C416
     else:
         keys: set[str] = set(keys)
@@ -1744,7 +1744,7 @@ def _check_keys(
         if not strict:
             keys = keys.intersection(k)
         else:
-            if is_dynamo_compiling():
+            if is_compiling():
                 k = {v for v in k}  # noqa: C416
             else:
                 k = set(k)
@@ -2013,7 +2013,7 @@ def _getitem_batch_size(batch_size, index):
             continue
         elif isinstance(idx, slice):
             batch = batch_size[count]
-            if is_dynamo_compiling():
+            if is_compiling():
                 out.append(len(range(*_slice_indices(idx, batch))))
             else:
                 out.append(len(range(*idx.indices(batch))))
@@ -2445,7 +2445,7 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     out = None
-    is_dynamo = is_dynamo_compiling()
+    is_dynamo = is_compiling()
     if not is_dynamo:
         out = _NON_TENSOR_MEMO.get(cls)
     if out is None:
@@ -2501,7 +2501,7 @@ def new_func(self):
 
 
 def _unravel_key_to_tuple(key):
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return _unravel_key_to_tuple_cpp(key)
     if isinstance(key, str):
         return (key,)
@@ -2522,7 +2522,7 @@ def unravel_key(key):
        ("a", "b")
 
     """
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_cpp(key)
     if isinstance(key, str):
         return key
@@ -2535,14 +2535,14 @@ def unravel_key(key):
 
 def unravel_keys(*keys):
     """Unravels a sequence of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_keys_cpp(*keys)
     return tuple(unravel_key(key) for key in keys)
 
 
 def unravel_key_list(keys):
     """Unravels a list of keys."""
-    if not is_dynamo_compiling():
+    if not is_compiling():
         return unravel_key_list_cpp(keys)
     return [unravel_key(key) for key in keys]
@@ -2865,11 +2865,11 @@ def __init__(self, default=None):
         self._lock = threading.Lock()
 
     def get_mode(self) -> Any | None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             return self._mode
 
     def set_mode(self, type: Any | None) -> None:
-        cm = self._lock if not is_dynamo_compiling() else nullcontext()
+        cm = self._lock if not is_compiling() else nullcontext()
         with cm:
             self._mode = type
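For context, here is a minimal standalone sketch (not part of the patch) of the pattern this diff standardizes on: import `is_compiling` from `torch.compiler` with a `torch._dynamo` fallback for older torch versions, then use it to bypass fast paths (C++ bindings, caches, lock acquisition) that Dynamo cannot trace. The `scale` function and its two branches are hypothetical stand-ins, not tensordict code.

import torch

try:
    from torch.compiler import is_compiling
except ImportError:  # older torch fallback, mirroring the patch
    from torch._dynamo import is_compiling


def scale(x: torch.Tensor) -> torch.Tensor:
    # Eagerly, take the "fast" branch (stand-in for a C++/caching path);
    # while torch.compile is tracing, take a graph-friendly Python branch.
    if not is_compiling():
        return x * 2  # stand-in for an optimized, non-traceable path
    return x + x  # traceable equivalent


eager_out = scale(torch.ones(3))  # hits the fast branch
compiled_out = torch.compile(scale)(torch.ones(3))  # hits the traceable branch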