diff --git a/tensordict/base.py b/tensordict/base.py index 6b7afd051..6c600b11f 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9873,15 +9873,14 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_dict(obj, auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) - if isinstance(obj, tuple): + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) if is_dataclass(obj): return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) - if is_namedtuple(obj): - return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) if _has_h5: import h5py