From 31ed32a6858de8868983d6002daa3e64a35908be Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 12:17:40 +0100 Subject: [PATCH] amend --- tensordict/base.py | 2 +- test/test_tensordict.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 5a225ff3a..ccea68481 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7179,7 +7179,7 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None): if alpha is not None: vals = torch._foreach_sub(val, other_val, alpha=alpha) else: - vals = torch._foreach_sub(vals, other_val) + vals = torch._foreach_sub(val, other_val) items = dict(zip(keys, vals)) return self._fast_apply( lambda name, val: items.get(name, val), diff --git a/test/test_tensordict.py b/test/test_tensordict.py index b12517abd..54d1c4ec4 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2437,6 +2437,25 @@ def dummy_td_1(self): def dummy_td_2(self): return self.dummy_td_0.apply(lambda x: x + 2) + def test_ordering(self): + + x0 = TensorDict( + { + "y": torch.zeros(3), + "x": torch.ones(3) + } + ) + + x1 = TensorDict( + { + "x": torch.ones(3), + "y": torch.zeros(3) + } + ) + assert ((x0+x1)["x"] == 2).all() + assert ((x0*x1)["x"] == 1).all() + assert ((x0-x1)["x"] == 0).all() + @pytest.mark.parametrize("locked", [True, False]) def test_add(self, locked): td = self.dummy_td_0