diff --git a/tensordict/utils.py b/tensordict/utils.py index ae6d3b44f..6a3184d83 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1548,6 +1548,11 @@ def assert_close( from tensordict._lazy import LazyStackedTensorDict + if is_tensorclass(actual): + actual = actual._tensordict + if is_tensorclass(expected): + expected = expected._tensordict + if isinstance(actual, LazyStackedTensorDict) and isinstance( expected, LazyStackedTensorDict ):