From bbff132a56816febe66a560d691cb885c58fd49d Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 21:03:24 +0000 Subject: [PATCH] amend --- tensordict/utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tensordict/utils.py b/tensordict/utils.py index afbfef172..9ff4ef93d 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1445,16 +1445,20 @@ def _set_max_batch_size(source: T, batch_dims=None): tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)] - batch_size = [] - if not tensor_data: # when source is empty - source.batch_size = batch_size - return - for val in tensor_data: from tensordict.base import _is_tensor_collection if _is_tensor_collection(val.__class__): _set_max_batch_size(val, batch_dims=batch_dims) + + batch_size = [] + if not tensor_data: # when source is empty + if batch_dims: + source.batch_size = source.batch_size[:batch_dims] + return source + else: + return source + curr_dim = 0 while True: if tensor_data[0].dim() > curr_dim: