Skip to content

Commit

Permalink
added logsumexp tests of edge-cases
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmushaugaard committed Mar 15, 2024
1 parent c095c62 commit a347cf6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 23 deletions.
59 changes: 36 additions & 23 deletions test/composite/test_logsumexp.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,44 @@
from itertools import product

import pytest
import torch
from torch_scatter import scatter_logsumexp
from torch_scatter.testing import float_dtypes, assert_equal

tests = [
[0.5, -2.1, 3.2],
[1e33, 0.5],
[-1e33, 0.5],
[-1e33],
[],
[float("nan"), 0.5],
[float("-inf"), 0.5],
[float("inf"), 0.5],
]


@pytest.mark.parametrize('src,dtype', product(tests, float_dtypes))
def test_logsumexp(src, dtype):
src = torch.tensor(src, dtype=dtype)
index = torch.zeros_like(src, dtype=torch.long)
out_scatter = scatter_logsumexp(src, index, dim_size=1)
out_torch = torch.logsumexp(src, dim=0, keepdim=True)
assert_equal(out_scatter, out_torch, equal_nan=True)


def test_logsumexp_parallel_jit():
splits = [len(src) for src in tests]
srcs = torch.tensor(sum(tests, start=[]))
index = torch.repeat_interleave(torch.tensor(splits))

srcs.requires_grad_()
outputs = scatter_logsumexp(srcs, index)

def test_logsumexp():
inputs = torch.tensor([
0.5,
0.5,
0.0,
-2.1,
3.2,
7.0,
-1.0,
-100.0,
])
inputs.requires_grad_()
index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4])
splits = [2, 3, 1, 0, 2]

outputs = scatter_logsumexp(inputs, index)

for src, out in zip(inputs.split(splits), outputs.unbind()):
if src.numel() > 0:
assert out.tolist() == torch.logsumexp(src, dim=0).tolist()
else:
assert out.item() == 0.0
for src, out_scatter in zip(srcs.split(splits), outputs.unbind()):
out_torch = torch.logsumexp(src, dim=0)
assert_equal(out_scatter, out_torch, equal_nan=True)

outputs.backward(torch.randn_like(outputs))

jit = torch.jit.script(scatter_logsumexp)
assert jit(inputs, index).tolist() == outputs.tolist()
assert_equal(jit(srcs, index), outputs, equal_nan=True)
5 changes: 5 additions & 0 deletions torch_scatter/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
torch.long
]
float_dtypes = list(filter(lambda x: x.is_floating_point, dtypes))
grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
Expand All @@ -17,3 +18,7 @@

def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, device=device).to(dtype)


def assert_equal(actual: torch.Tensor, expected: torch.Tensor, equal_nan=False):
torch.testing.assert_close(actual, expected, equal_nan=equal_nan, rtol=0, atol=0)

0 comments on commit a347cf6

Please sign in to comment.