Skip to content

Commit

Permalink
fix scatter to handle empty graphs (#4193)
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 authored Mar 1, 2024
1 parent 93a627e commit c9d75bf
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ def example_scatter_data():
}

return src_feat, dst_indices, results


@pytest.fixture
def empty_scatter_data():
src_feat = torch.empty((0, 41))
dst_indices = torch.empty((0,))

return src_feat, dst_indices
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@

@pytest.mark.parametrize("reduce", ["sum", "mean", "prod", "amax", "amin"])
def test_scatter_reduce(example_scatter_data, reduce):
device = torch.device("cuda:0")
device = torch.device("cuda")
src, index, out_true = example_scatter_data
src = src.to(device)
index = index.to(device)

out = scatter_reduce(src, index, dim=0, dim_size=None, reduce=reduce)

assert torch.allclose(out.cpu(), out_true[reduce])


def test_scatter_reduce_empty(empty_scatter_data):
device = torch.device("cuda")
src, index = empty_scatter_data
src = src.to(device)
index = index.to(device)

out = scatter_reduce(src, index, dim=0, dim_size=None)

assert out.numel() == 0
assert out.size(1) == src.size(1)
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ def scatter_reduce(
size = list(src.size())

if dim_size is not None:
assert dim_size >= int(index.max()) + 1
size[dim] = dim_size
else:
size[dim] = int(index.max()) + 1
size[dim] = 0 if index.numel() == 0 else int(index.max()) + 1

out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_reduce_(dim, index, src, reduce, include_self=False)

0 comments on commit c9d75bf

Please sign in to comment.