Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 17, 2024
1 parent 7639c79 commit 4550154
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5744,19 +5744,20 @@ def test_stack_keys(self):
assert "e" in td.keys() # now all tds have the key c
td.get("e")

def test_stack_memmap(self):
td = TensorDict({"a": [[1, 2]], "b": {"c": [[3, 4]]}}, [1, 2]).memmap_()
tdstack = torch.stack([td, td])
td_select = tdstack.select()
td_exclude = tdstack.exclude(*tdstack.keys(True))
td_exclude2 = tdstack.exclude(*tdstack.keys(True, True))
assert td_select.is_memmap()
assert td_select.is_locked
assert td_exclude.is_memmap()
assert td_exclude.is_locked
assert td_exclude2.is_memmap()
assert td_exclude2.is_locked
assert all(_td.is_locked for _td in td_exclude2.values(True))
# deprecated behaviour
# def test_stack_memmap(self):
# td = TensorDict({"a": [[1, 2]], "b": {"c": [[3, 4]]}}, [1, 2]).memmap_()
# tdstack = torch.stack([td, td])
# td_select = tdstack.select()
# td_exclude = tdstack.exclude(*tdstack.keys(True))
# td_exclude2 = tdstack.exclude(*tdstack.keys(True, True))
# assert td_select.is_memmap()
# assert td_select.is_locked
# assert td_exclude.is_memmap()
# assert td_exclude.is_locked
# assert td_exclude2.is_memmap()
# assert td_exclude2.is_locked
# assert all(_td.is_locked for _td in td_exclude2.values(True))

@pytest.mark.parametrize("unsqueeze_dim", [0, 1, -1, -2])
def test_stack_unsqueeze(self, unsqueeze_dim):
Expand Down

0 comments on commit 4550154

Please sign in to comment.