Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 16, 2024
1 parent f5a514d commit 88691d0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
7 changes: 3 additions & 4 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2283,7 +2283,7 @@ def _transpose(self, dim0, dim1):
if dim1 == dim0 + 1:
return LazyStackedTensorDict(*self.tensordicts, stack_dim=dim1)
return LazyStackedTensorDict(
*map(lambda td: td.transpose(dim0, dim1 - 1), self.tensordicts),
*(td.transpose(dim0, dim1 - 1) for td in self.tensordicts),
stack_dim=dim1,
)
elif dim1 == self.stack_dim:
Expand All @@ -2292,14 +2292,14 @@ def _transpose(self, dim0, dim1):
if dim0 + 1 == dim1:
return LazyStackedTensorDict(*self.tensordicts, stack_dim=dim0)
return LazyStackedTensorDict(
*map(lambda td: td.transpose(dim0 + 1, dim1), self.tensordicts),
*(td.transpose(dim0 + 1, dim1) for td in self.tensordicts),
stack_dim=dim0,
)
else:
dim0 = dim0 if dim0 < self.stack_dim else dim0 - 1
dim1 = dim1 if dim1 < self.stack_dim else dim1 - 1
return LazyStackedTensorDict(
*map(lambda td: td.transpose(dim0, dim1), self.tensordicts),
*(td.transpose(dim0, dim1) for td in self.tensordicts),
stack_dim=self.stack_dim,
)

Expand Down Expand Up @@ -2485,7 +2485,6 @@ def _set_tuple(self, key, value, *, inplace: bool, validated: bool):
return self._set_str(key[0], value, inplace=inplace, validated=validated)
source = self._source._get_str(key[0], None)
if source is None:
print("self", self)
source = self._source._create_nested_str(key[0])
nested = type(self)(
source,
Expand Down
12 changes: 11 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3346,7 +3346,17 @@ def test_select_exception(self, td_name, device, strict):
assert len(list(td2.keys())) == 0

def test_set_lazy_legacy(self, td_name, device):
if td_name in ("sub_td", "sub_td2", "td_h5", "squeezed_td", "unsqueezed_td", "permute_td", "transpose_td", "nested_stacked_td", "stacked_td"):
if td_name in (
"sub_td",
"sub_td2",
"td_h5",
"squeezed_td",
"unsqueezed_td",
"permute_td",
"transpose_td",
"nested_stacked_td",
"stacked_td",
):
raiser = pytest.raises(RuntimeError)
else:
raiser = contextlib.nullcontext()
Expand Down

0 comments on commit 88691d0

Please sign in to comment.