Skip to content

Commit

Permalink
changed logsumexp to pass new edge-case tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rasmushaugaard committed Mar 15, 2024
1 parent a347cf6 commit 2860e81
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions torch_scatter/composite/logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
eps: float = 1e-12) -> torch.Tensor:
dim_size: Optional[int] = None) -> torch.Tensor:
if not torch.is_floating_point(src):
raise ValueError('`scatter_logsumexp` can only be computed over '
'tensors with floating point data types.')
Expand All @@ -24,18 +23,19 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,

size = list(src.size())
size[dim] = dim_size

max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype,
device=src.device)
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0]
scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)
max_value_per_index.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
max_per_src_element = max_value_per_index.gather(dim, index)
recentered_score = src - max_per_src_element
recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf'))


src_recentered = src - max_per_src_element
if out is not None:
out = out.sub_(max_value_per_index).exp_()

sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out,
sum_per_index = scatter_sum(src_recentered.exp_(), index, dim, out,
dim_size)

out = sum_per_index.add_(eps).log_().add_(max_value_per_index)
return out.nan_to_num_(neginf=0.0)
return sum_per_index.log_().add_(max_value_per_index)

0 comments on commit 2860e81

Please sign in to comment.