Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 5, 2024
1 parent 4383e5f commit 31ed32a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7179,7 +7179,7 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None):
if alpha is not None:
vals = torch._foreach_sub(val, other_val, alpha=alpha)
else:
vals = torch._foreach_sub(vals, other_val)
vals = torch._foreach_sub(val, other_val)
items = dict(zip(keys, vals))
return self._fast_apply(
lambda name, val: items.get(name, val),
Expand Down
19 changes: 19 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,6 +2437,25 @@ def dummy_td_1(self):
def dummy_td_2(self):
return self.dummy_td_0.apply(lambda x: x + 2)

def test_ordering(self):

x0 = TensorDict(
{
"y": torch.zeros(3),
"x": torch.ones(3)
}
)

x1 = TensorDict(
{
"x": torch.ones(3),
"y": torch.zeros(3)
}
)
assert ((x0+x1)["x"] == 2).all()
assert ((x0*x1)["x"] == 1).all()
assert ((x0-x1)["x"] == 0).all()

@pytest.mark.parametrize("locked", [True, False])
def test_add(self, locked):
td = self.dummy_td_0
Expand Down

0 comments on commit 31ed32a

Please sign in to comment.