diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index bc07b7689..07e355746 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -111,10 +111,8 @@ def _maybe_make_param(tensor): def _maybe_make_param_or_buffer(tensor): - if ( - isinstance(tensor, (Tensor, ftdim.Tensor)) - and not isinstance(tensor, (nn.Parameter, Buffer)) - and tensor.dtype in (torch.float, torch.double, torch.half) + if isinstance(tensor, (Tensor, ftdim.Tensor)) and not isinstance( + tensor, (nn.Parameter, Buffer) ): if not tensor.requires_grad and not is_batchedtensor(tensor): # convert all non-parameters to buffers diff --git a/tensordict/utils.py b/tensordict/utils.py index 25516ea83..6f4fa239f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -68,12 +68,9 @@ except ImportError: _has_funcdim = False try: - from torch.compiler import assume_constant_result, is_dynamo_compiling + from torch.compiler import assume_constant_result except ImportError: # torch 2.0 - from torch._dynamo import ( - assume_constant_result, - is_compiling as is_dynamo_compiling, - ) + from torch._dynamo import assume_constant_result if TYPE_CHECKING: from tensordict.tensordict import TensorDictBase @@ -2825,7 +2822,8 @@ def _is_dataclass(obj): if isinstance(obj, type) and not isinstance(obj, GenericAlias) else type(obj) ) - return hasattr(cls, _FIELDS) + # return hasattr(cls, _FIELDS) + return getattr(cls, _FIELDS, None) is not None def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: