Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 18, 2024
1 parent 17e7d98 commit 8647e74
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,13 +846,12 @@ def empty():

tds = tuple(empty() for _ in range(self.batch_size[dim]))

def unbind(key_val, tds=tds):
key, val = key_val
def unbind(key, val, tds=tds):
for td, _val in zip(tds, val.unbind(dim)):
td._set_str(key, _val, validated=True, inplace=False)

with ThreadPoolExecutor(max_workers=16) as executor:
executor.map(unbind, self.items())
for key, val in self.items():
unbind(key, val)
return tds

def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
Expand Down

0 comments on commit 8647e74

Please sign in to comment.