Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 24, 2024
1 parent baa7fe1 commit ccd3ddc
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def test_getitem_nested(self):
assert sub_tensordict.shape == torch.Size([4, 5])
assert sub_sub_tensordict.shape == torch.Size([4, 5, 6])

@set_lazy_legacy(True)
def test_inferred_view_size(self):
td = TensorDict({"a": torch.randn(3, 4)}, [3, 4])
assert td.view(-1).view(-1, 4) is td
Expand Down Expand Up @@ -2539,12 +2540,13 @@ def test_inferred_view_size(self, td_name, device):
dim_size if dim_idx != i else -1
for dim_idx, dim_size in enumerate(td.shape)
]
if td_name in ("td_params",):
assert td.view(-1).view(*new_shape)._param_td is td._param_td
assert td.view(*new_shape)._param_td is td._param_td
else:
assert td.view(*new_shape) is td
assert td.view(-1).view(*new_shape) is td
if lazy_legacy():
if td_name in ("td_params",):
assert td.view(-1).view(*new_shape)._param_td is td._param_td
assert td.view(*new_shape)._param_td is td._param_td
else:
assert td.view(*new_shape) is td
assert td.view(-1).view(*new_shape) is td

def test_items_values_keys(self, td_name, device):
torch.manual_seed(1)
Expand Down Expand Up @@ -3102,6 +3104,10 @@ def test_nestedtensor_stack(self, td_name, device, dim, key):
# cloning is type-preserving: we can do that operation
td_stack.clone()

# This test fails on lazy tensordicts when lazy-legacy is False
# Deprecating lazy modules will make this decorator useless (the test should
# still run ok).
@set_lazy_legacy(True)
def test_non_tensor_data(self, td_name, device):
td = getattr(self, td_name)(device)
# check lock
Expand Down Expand Up @@ -3132,6 +3138,10 @@ def test_non_tensor_data(self, td_name, device):
assert isinstance(td.get(("this", "other", "tensor")), NonTensorData)
assert td.get_non_tensor(("this", "other", "tensor")) == "success"

# This test fails on lazy tensordicts when lazy-legacy is False
# Deprecating lazy modules will make this decorator useless (the test should
# still run ok).
@set_lazy_legacy(True)
def test_non_tensor_data_flatten_keys(self, td_name, device):
td = getattr(self, td_name)(device)
with td.unlock_():
Expand All @@ -3149,6 +3159,10 @@ def test_non_tensor_data_flatten_keys(self, td_name, device):
assert (td_flat.get("this.tensor") == 0).all()
assert td_flat.get_non_tensor("this.will") == "succeed"

# This test fails on lazy tensordicts when lazy-legacy is False
# Deprecating lazy modules will make this decorator useless (the test should
# still run ok).
@set_lazy_legacy(True)
def test_non_tensor_data_pickle(self, td_name, device, tmpdir):
td = getattr(self, td_name)(device)
with td.unlock_():
Expand Down Expand Up @@ -3484,6 +3498,7 @@ def test_select_exception(self, td_name, device, strict):
assert td2 is not td
assert len(list(td2.keys())) == 0

@set_lazy_legacy(True)
def test_set_lazy_legacy(self, td_name, device):
if td_name in (
"sub_td",
Expand All @@ -3493,16 +3508,19 @@ def test_set_lazy_legacy(self, td_name, device):
"unsqueezed_td",
"permute_td",
"transpose_td",
"nested_stacked_td",
"stacked_td",
):
raiser = pytest.raises(RuntimeError)
raiser_view = raiser
elif "stack" in td_name:
raiser = contextlib.nullcontext()
raiser_view = pytest.raises(RuntimeError)
else:
raiser = contextlib.nullcontext()
raiser_view = raiser

def test_not_id(td, td_name=td_name, raiser=raiser):
def test_not_id(td, td_name=td_name, raiser=raiser, raiser_view=raiser_view):
# view
with raiser:
with raiser_view:
td_view = td.view(-1).view(td.shape)
if td_name in ("td_params",):
assert isinstance(td_view, TensorDict)
Expand Down Expand Up @@ -4039,7 +4057,24 @@ def test_stack_onto(self, td_name, device, tmpdir):
else:
td1.apply_(lambda x: x.zero_() + 1)

td_out = td.unsqueeze(1).expand(td.shape[0], 2, *td.shape[1:]).clone()
is_lazy = td_name in (
"sub_td",
"sub_td2",
"permute_td",
"unsqueezed_td",
"squeezed_td",
"td_h5",
) and not lazy_legacy()
error_dec = (
pytest.raises(RuntimeError, match="Make it dense")
if is_lazy
else contextlib.nullcontext()
)
with error_dec:
td_out = td.unsqueeze(1)
if is_lazy:
return
td_out = td_out.expand(td.shape[0], 2, *td.shape[1:]).clone()
td_stack = LazyStackedTensorDict.lazy_stack([td0, td1], 1)
if td_name == "td_params":
with pytest.raises(RuntimeError, match="out.batch_size and stacked"):
Expand Down

0 comments on commit ccd3ddc

Please sign in to comment.