diff --git a/tensordict/base.py b/tensordict/base.py index 159777e28..91b94e714 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9870,8 +9870,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): """ if is_tensor_collection(obj): - if is_non_tensor(obj): - return cls.from_any(obj.data, auto_batch_size=auto_batch_size) + # Conversions from non-tensor data must be done manually + # if is_non_tensor(obj): + # from tensordict.tensorclass import LazyStackedTensorDict + # if isinstance(obj, LazyStackedTensorDict): + # return obj + # return cls.from_any(obj.data, auto_batch_size=auto_batch_size) return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) diff --git a/test/test_nn.py b/test/test_nn.py index e1a334ff3..5c32ac6e3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -389,13 +389,13 @@ def test_nontensor(self): in_keys=[], out_keys=["out"], ) - assert tdm(TensorDict({}))["out"] == [1, 2] + assert tdm(TensorDict())["out"] == [1, 2] tdm = TensorDictModule( lambda: "a string!", in_keys=[], out_keys=["out"], ) - assert tdm(TensorDict({}))["out"] == "a string!" + assert tdm(TensorDict())["out"] == "a string!" @pytest.mark.parametrize( "out_keys", diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f7fe9a9ff..07eac0ec1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -814,13 +814,13 @@ def test_expand_with_singleton(self, device): @set_lazy_legacy(True) def test_filling_empty_tensordict(self, device, td_type, update): if td_type == "tensordict": - td = TensorDict({}, batch_size=[16], device=device) + td = TensorDict(batch_size=[16], device=device) elif td_type == "view": - td = TensorDict({}, batch_size=[4, 4], device=device).view(-1) + td = TensorDict(batch_size=[4, 4], device=device).view(-1) elif td_type == "unsqueeze": - td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1) + td = TensorDict(batch_size=[16], device=device).unsqueeze(-1) elif td_type == "squeeze": - td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1) + td = TensorDict(batch_size=[16, 1], device=device).squeeze(-1) elif td_type == "stack": td = LazyStackedTensorDict.lazy_stack( [TensorDict({}, [], device=device) for _ in range(16)], 0 @@ -2591,7 +2591,7 @@ def test_record_stream(self): @pytest.mark.parametrize("device", get_available_devices()) def test_subtensordict_construction(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5)) + td = TensorDict(batch_size=(4, 5)) val1 = torch.randn(4, 5, 1, device=device) val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device) val1_copy = val1.clone() @@ -2694,7 +2694,7 @@ def test_tensordict_error_messages(self, device): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_indexing(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5)) + td = TensorDict(batch_size=(4, 5)) td.set("key1", torch.randn(4, 5, 1, device=device)) td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double)) @@ -2736,7 +2736,7 @@ def test_tensordict_prealloc_nested(self): N = 3 B = 5 T = 4 - buffer = TensorDict({}, batch_size=[B, N]) + buffer = TensorDict(batch_size=[B, N]) td_0 = TensorDict( { @@ -2777,7 +2777,7 @@ def test_tensordict_prealloc_nested(self): @pytest.mark.parametrize("device", get_available_devices()) def test_tensordict_set(self, device): torch.manual_seed(1) - td = TensorDict({}, batch_size=(4, 5), device=device) + td = TensorDict(batch_size=(4, 5), device=device) td.set("key1", torch.randn(4, 5)) assert td.device == torch.device(device) # by default inplace: @@ -4235,7 +4235,7 @@ def test_flatten_unflatten_bis(self, td_name, device): def test_from_empty(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - new_td = TensorDict({}, batch_size=td.batch_size, device=device) + new_td = TensorDict(batch_size=td.batch_size, device=device) for key, item in td.items(): new_td.set(key, item) assert_allclose_td(td, new_td) @@ -4433,7 +4433,7 @@ def test_items_values_keys(self, td_name, device): items = list(td.items()) # Test td.items() - constructed_td1 = TensorDict({}, batch_size=td.shape) + constructed_td1 = TensorDict(batch_size=td.shape) for key, value in items: constructed_td1.set(key, value) @@ -4443,7 +4443,7 @@ def test_items_values_keys(self, td_name, device): # items = [key, value] should be verified assert len(values) == len(items) assert len(keys) == len(items) - constructed_td2 = TensorDict({}, batch_size=td.shape) + constructed_td2 = TensorDict(batch_size=td.shape) for key, value in list(zip(td.keys(), td.values())): constructed_td2.set(key, value) @@ -4464,7 +4464,7 @@ def test_items_values_keys(self, td_name, device): # Test td.items() # after adding the new element - constructed_td1 = TensorDict({}, batch_size=td.shape) + constructed_td1 = TensorDict(batch_size=td.shape) for key, value in items: constructed_td1.set(key, value) @@ -4476,7 +4476,7 @@ def test_items_values_keys(self, td_name, device): assert len(values) == len(items) assert len(keys) == len(items) - constructed_td2 = TensorDict({}, batch_size=td.shape) + constructed_td2 = TensorDict(batch_size=td.shape) for key, value in list(zip(td.keys(), td.values())): constructed_td2.set(key, value) @@ -9382,14 +9382,14 @@ def run_assertions(): class TestNamedDims(TestTensorDictsBase): def test_all(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tda = td.all(2) assert tda.names == ["a", "b", "d"] tda = td.any(2) assert tda.names == ["a", "b", "d"] def test_apply(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tda = td.apply(lambda x: x + 1) assert tda.names == ["a", "b", "c", "d"] tda = td.apply(lambda x: x.squeeze(2), batch_size=[3, 4, 6]) @@ -9397,15 +9397,15 @@ def test_apply(self): assert tda.names == [None] * 3 def test_cat(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) tdc = torch.cat([td, td], -1) assert tdc.names == [None] * 4 - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) tdc = torch.cat([td, td], -1) assert tdc.names == ["a", "b", "c", "d"] def test_change_batch_size(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) td.batch_size = [3, 4, 1, 6, 1] assert td.names == ["a", "b", "c", "z", None] td.batch_size = [] @@ -9417,7 +9417,7 @@ def test_change_batch_size(self): assert td.names == ["a"] def test_clone(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tdc = td.clone() assert tdc.names == ["a", "b", "c", "d"] @@ -9425,14 +9425,14 @@ def test_clone(self): assert tdc.names == ["a", "b", "c", "d"] def test_detach(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) td[""] = torch.zeros(td.shape, requires_grad=True) tdd = td.detach() assert tdd.names == ["a", "b", "c", "d"] def test_error_similar(self): with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"]) with pytest.raises(ValueError): td = TensorDict( {}, @@ -9446,16 +9446,16 @@ def test_error_similar(self): ) td.refine_names("a", "a", ...) with pytest.raises(ValueError): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"]) td.rename_(a="z") def test_expand(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tde = td.expand(2, 3, 4, 5, 6) assert tde.names == [None, "a", "b", "c", "d"] def test_flatten(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tdf = td.flatten(1, 3) assert tdf.names == ["a", None] tdu = tdf.unflatten(1, (4, 1, 6)) @@ -9470,11 +9470,11 @@ def test_flatten(self): assert tdu.names == [None, None, None, "d"] def test_fullname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td.names == ["a", "b", "c", "d"] def test_gather(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) idx = torch.randint(6, (3, 4, 1, 18)) tdg = td.gather(dim=-1, index=idx) assert tdg.names == ["a", "b", "c", "d"] @@ -9499,7 +9499,7 @@ def test_h5_td(self): assert td.names == list("abgd") def test_index(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td[0].names == ["b", "c", "d"] assert td[:, 0].names == ["a", "c", "d"] assert td[0, :].names == ["b", "c", "d"] @@ -9519,7 +9519,7 @@ def test_index(self): assert tdbool.ndim == 3 def test_masked_fill(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tdm = td.masked_fill(torch.zeros(3, 4, 1, dtype=torch.bool), 1.0) assert tdm.names == ["a", "b", "c", "d"] @@ -9543,16 +9543,16 @@ def test_memmap_td(self): assert td.clone().names == list("abgd") def test_nested(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) - td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td["a"] = TensorDict(batch_size=[3, 4, 1, 6]) assert td["a"].names == td.names - td["a"] = TensorDict({}, batch_size=[]) + td["a"] = TensorDict() assert td["a"].names == td.names - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=None) - td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=None) + td["a"] = TensorDict(batch_size=[3, 4, 1, 6]) td.names = ["a", "b", None, None] assert td["a"].names == td.names - td.set_("a", TensorDict({}, batch_size=[3, 4, 1, 6])) + td.set_("a", TensorDict(batch_size=[3, 4, 1, 6])) assert td["a"].names == td.names def test_nested_indexing(self): @@ -9602,15 +9602,15 @@ def test_nested_td(self): assert nested_td.contiguous()["my_nested_td"].names == list("abgd") def test_noname(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) assert td.names == [None] * 4 def test_partial_name(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", None, None, "d"]) assert td.names == ["a", None, None, "d"] def test_partial_set(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", None, None, "d"] assert td.names == ["a", None, None, "d"] td.names = ["a", "b", "c", "d"] @@ -9639,7 +9639,7 @@ def test_permute_td(self): td.names = list("abcd") def test_refine_names(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6]) + td = TensorDict(batch_size=[3, 4, 5, 6]) tdr = td.refine_names(None, None, None, "d") assert tdr.names == [None, None, None, "d"] tdr = tdr.refine_names(None, None, "c", "d") @@ -9654,7 +9654,7 @@ def test_refine_names(self): assert tdr.names == ["a", None, "c", "d"] def test_rename(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", None, None, "d"] td.rename_(a="c") assert td.names == ["c", None, None, "d"] @@ -9670,7 +9670,7 @@ def test_rename(self): assert td2.names == ["w", "x", "y", "z"] def test_select(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) tds = td.select() assert tds.names == ["a", "b", "c", "d"] tde = td.exclude() @@ -9707,11 +9707,11 @@ def test_split(self): # assert tdu.is_locked def test_squeeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tds = td.squeeze(0) assert tds.names == ["a", "b", "c", "d"] - td = TensorDict({}, batch_size=[3, 1, 5, 6], names=None) + td = TensorDict(batch_size=[3, 1, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tds = td.squeeze(1) assert tds.names == ["a", "c", "d"] @@ -9724,7 +9724,7 @@ def test_squeeze_td(self): td.names = list("abcd") def test_stack(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) tds = LazyStackedTensorDict.lazy_stack([td, td], 0) assert tds.names == [None, "a", "b", "c", "d"] tds = LazyStackedTensorDict.lazy_stack([td, td], -1) @@ -9762,7 +9762,7 @@ def test_sub_td(self): td.names = list("abcd") def test_subtd(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"]) assert td._get_sub_tensordict(0).names == ["b", "c", "d"] assert td._get_sub_tensordict((slice(None), 0)).names == ["a", "c", "d"] assert td._get_sub_tensordict((0, slice(None))).names == ["b", "c", "d"] @@ -9826,14 +9826,14 @@ def test_to(self, device, non_blocking_pin, num_threads, inplace): assert tdt is not td def test_unbind(self): - td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) + td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"]) *_, tdu = td.unbind(-1) assert tdu.names == ["a", "b", "c"] *_, tdu = td.unbind(-2) assert tdu.names == ["a", "b", "d"] def test_unsqueeze(self): - td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None) + td = TensorDict(batch_size=[3, 4, 5, 6], names=None) td.names = ["a", "b", "c", "d"] tdu = td.unsqueeze(0) assert tdu.names == [None, "a", "b", "c", "d"]