Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 22, 2024
1 parent a3f9500 commit bbff132
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,16 +1445,20 @@ def _set_max_batch_size(source: T, batch_dims=None):

tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)]

batch_size = []
if not tensor_data: # when source is empty
source.batch_size = batch_size
return

for val in tensor_data:
from tensordict.base import _is_tensor_collection

if _is_tensor_collection(val.__class__):
_set_max_batch_size(val, batch_dims=batch_dims)

batch_size = []
if not tensor_data: # when source is empty
if batch_dims:
source.batch_size = source.batch_size[:batch_dims]
return source
else:
return source

curr_dim = 0
while True:
if tensor_data[0].dim() > curr_dim:
Expand Down

0 comments on commit bbff132

Please sign in to comment.