From f6d12540f053fd372c7484f4251066593281dbf1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 18 Jan 2024 09:27:50 +0000 Subject: [PATCH] amend --- tensordict/_td.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 927e8d423..1737741f0 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -846,7 +846,8 @@ def empty(): def unbind(key, val, tds=tds): unbound = ( val.unbind(dim) - if not _is_tensor_collection(type(val)) + if not isinstance(val, TensorDictBase) + # tensorclass is also unbound using plain unbind else val._unbind(dim) ) for td, _val in zip(tds, unbound):