From ccd3ddcf899b41be6147d0a787f25e6d399062e8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 24 Jan 2024 16:18:36 +0000 Subject: [PATCH] amend --- test/test_tensordict.py | 57 +++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 8f42d25b7..c0c80c99c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -752,6 +752,7 @@ def test_getitem_nested(self): assert sub_tensordict.shape == torch.Size([4, 5]) assert sub_sub_tensordict.shape == torch.Size([4, 5, 6]) + @set_lazy_legacy(True) def test_inferred_view_size(self): td = TensorDict({"a": torch.randn(3, 4)}, [3, 4]) assert td.view(-1).view(-1, 4) is td @@ -2539,12 +2540,13 @@ def test_inferred_view_size(self, td_name, device): dim_size if dim_idx != i else -1 for dim_idx, dim_size in enumerate(td.shape) ] - if td_name in ("td_params",): - assert td.view(-1).view(*new_shape)._param_td is td._param_td - assert td.view(*new_shape)._param_td is td._param_td - else: - assert td.view(*new_shape) is td - assert td.view(-1).view(*new_shape) is td + if lazy_legacy(): + if td_name in ("td_params",): + assert td.view(-1).view(*new_shape)._param_td is td._param_td + assert td.view(*new_shape)._param_td is td._param_td + else: + assert td.view(*new_shape) is td + assert td.view(-1).view(*new_shape) is td def test_items_values_keys(self, td_name, device): torch.manual_seed(1) @@ -3102,6 +3104,10 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): # cloning is type-preserving: we can do that operation td_stack.clone() + # This test fails on lazy tensordicts when lazy-legacy is False + # Deprecating lazy modules will make this decorator useless (the test should + # still run ok). + @set_lazy_legacy(True) def test_non_tensor_data(self, td_name, device): td = getattr(self, td_name)(device) # check lock @@ -3132,6 +3138,10 @@ def test_non_tensor_data(self, td_name, device): assert isinstance(td.get(("this", "other", "tensor")), NonTensorData) assert td.get_non_tensor(("this", "other", "tensor")) == "success" + # This test fails on lazy tensordicts when lazy-legacy is False + # Deprecating lazy modules will make this decorator useless (the test should + # still run ok). + @set_lazy_legacy(True) def test_non_tensor_data_flatten_keys(self, td_name, device): td = getattr(self, td_name)(device) with td.unlock_(): @@ -3149,6 +3159,10 @@ def test_non_tensor_data_flatten_keys(self, td_name, device): assert (td_flat.get("this.tensor") == 0).all() assert td_flat.get_non_tensor("this.will") == "succeed" + # This test fails on lazy tensordicts when lazy-legacy is False + # Deprecating lazy modules will make this decorator useless (the test should + # still run ok). + @set_lazy_legacy(True) def test_non_tensor_data_pickle(self, td_name, device, tmpdir): td = getattr(self, td_name)(device) with td.unlock_(): @@ -3484,6 +3498,7 @@ def test_select_exception(self, td_name, device, strict): assert td2 is not td assert len(list(td2.keys())) == 0 + @set_lazy_legacy(True) def test_set_lazy_legacy(self, td_name, device): if td_name in ( "sub_td", @@ -3493,16 +3508,19 @@ def test_set_lazy_legacy(self, td_name, device): "unsqueezed_td", "permute_td", "transpose_td", - "nested_stacked_td", - "stacked_td", ): raiser = pytest.raises(RuntimeError) + raiser_view = raiser + elif "stack" in td_name: + raiser = contextlib.nullcontext() + raiser_view = pytest.raises(RuntimeError) else: raiser = contextlib.nullcontext() + raiser_view = raiser - def test_not_id(td, td_name=td_name, raiser=raiser): + def test_not_id(td, td_name=td_name, raiser=raiser, raiser_view=raiser_view): # view - with raiser: + with raiser_view: td_view = td.view(-1).view(td.shape) if td_name in ("td_params",): assert isinstance(td_view, TensorDict) @@ -4039,7 +4057,24 @@ def test_stack_onto(self, td_name, device, tmpdir): else: td1.apply_(lambda x: x.zero_() + 1) - td_out = td.unsqueeze(1).expand(td.shape[0], 2, *td.shape[1:]).clone() + is_lazy = td_name in ( + "sub_td", + "sub_td2", + "permute_td", + "unsqueezed_td", + "squeezed_td", + "td_h5", + ) and not lazy_legacy() + error_dec = ( + pytest.raises(RuntimeError, match="Make it dense") + if is_lazy + else contextlib.nullcontext() + ) + with error_dec: + td_out = td.unsqueeze(1) + if is_lazy: + return + td_out = td_out.expand(td.shape[0], 2, *td.shape[1:]).clone() td_stack = LazyStackedTensorDict.lazy_stack([td0, td1], 1) if td_name == "td_params": with pytest.raises(RuntimeError, match="out.batch_size and stacked"):