Skip to content

Commit

Permalink
Fix logsumexp when out is passed (#445)
Browse files Browse the repository at this point in the history
* update

* update

* update
Loading branch information
rusty1s authored May 27, 2024
1 parent 521d26f commit 96aa2e3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
9 changes: 9 additions & 0 deletions test/composite/test_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,12 @@ def test_logsumexp():

jit = torch.jit.script(scatter_logsumexp)
assert jit(inputs, index).tolist() == outputs.tolist()


def test_logsumexp_out():
    """A pre-filled `out` tensor participates in the reduction in place:
    slots dominated by scattered values are overwritten, while slots whose
    original entry dominates keep (approximately) their initial value."""
    values = torch.tensor([-1.0, -50.0])
    positions = torch.tensor([0, 0])
    result = torch.tensor([-10.0, -10.0])

    scatter_logsumexp(src=values, index=positions, out=result)

    expected = torch.tensor([-0.9999, -10.0])
    assert result.allclose(expected, atol=1e-4)
33 changes: 24 additions & 9 deletions torch_scatter/composite/logsumexp.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Optional

import torch
from torch_scatter import scatter_sum, scatter_max

from torch_scatter import scatter_max, scatter_sum
from torch_scatter.utils import broadcast


def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12) -> torch.Tensor:
def scatter_logsumexp(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12,
) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
Expand All @@ -24,18 +27,30 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,

size = list(src.size())
size[dim] = dim_size
max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
device=src.device)
max_value_per_index = torch.full(
size,
fill_value=float('-inf'),
dtype=src.dtype,
device=src.device,
)
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_score = src - max_per_src_element
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))

orig_out: Optional[torch.Tensor] = None
if out is not None:
orig_out = out.clone()
out = out.sub_(max_value_per_index).exp_()

sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
dim_size)

out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
return out.nan_to_num_(neginf=0.0)

if orig_out is None:
return out.nan_to_num_(neginf=0.0)

mask = ~out.isfinite()
out[mask] = orig_out[mask]
return out

0 comments on commit 96aa2e3

Please sign in to comment.