From 849270cfc6131371e26e3625ca4f3244a8ec7067 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 23 Nov 2024 15:33:29 +0100 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- tensordict/base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 6b7afd051..6c600b11f 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9873,15 +9873,14 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_dict(obj, auto_batch_size=auto_batch_size) if isinstance(obj, np.ndarray) and hasattr(obj.dtype, "names"): return cls.from_struct_array(obj, auto_batch_size=auto_batch_size) - if isinstance(obj, tuple): + if is_namedtuple(obj): + return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) if is_dataclass(obj): return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) - if is_namedtuple(obj): - return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) if _has_h5: import h5py