Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 14, 2023
1 parent a2a85fe commit dbf42c4
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions test/composite/test_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,26 @@

def test_logsumexp():
inputs = torch.tensor([
0.5, 0.5, 0.0, -2.1, 3.2, 7.0, -1.0, -100.0,
float('-inf'),
float('-inf'), 0.0
0.5,
0.5,
0.0,
-2.1,
3.2,
7.0,
-1.0,
-100.0,
])
inputs.requires_grad_()
index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4, 5, 6, 6])
splits = [2, 3, 1, 0, 2, 1, 2]
index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4])
splits = [2, 3, 1, 0, 2]

outputs = scatter_logsumexp(inputs, index)

for src, out in zip(inputs.split(splits), outputs.unbind()):
assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
if src.numel() > 0:
assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
else:
assert out.item() == 0.0

outputs.backward(torch.randn_like(outputs))

Expand Down

0 comments on commit dbf42c4

Please sign in to comment.