From eb4a56ed687cf7f15e4182645f1c35e738231664 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 09:28:46 -0800 Subject: [PATCH] [BugFix] Better comparison of tensorclasses ghstack-source-id: 8def6f01f2b6d09714319a56f96b166ac1fd49d5 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1137 --- tensordict/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensordict/utils.py b/tensordict/utils.py index ae6d3b44f..6a3184d83 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1548,6 +1548,11 @@ def assert_close( from tensordict._lazy import LazyStackedTensorDict + if is_tensorclass(actual): + actual = actual._tensordict + if is_tensorclass(expected): + expected = expected._tensordict + if isinstance(actual, LazyStackedTensorDict) and isinstance( expected, LazyStackedTensorDict ):