From 96aa2e3587123ba4ef31820899d5e62141e9a4c2 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 27 May 2024 14:25:37 +0200 Subject: [PATCH] Fix `logsumexp` when `out` is passed (#445) * update * update * update --- test/composite/test_logsumexp.py | 9 ++++++++ torch_scatter/composite/logsumexp.py | 33 ++++++++++++++++++++-------- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/test/composite/test_logsumexp.py b/test/composite/test_logsumexp.py index 49e7a9c4..a6b3d160 100644 --- a/test/composite/test_logsumexp.py +++ b/test/composite/test_logsumexp.py @@ -29,3 +29,12 @@ def test_logsumexp(): jit = torch.jit.script(scatter_logsumexp) assert jit(inputs, index).tolist() == outputs.tolist() + + +def test_logsumexp_out(): + src = torch.tensor([-1.0, -50.0]) + index = torch.tensor([0, 0]) + out = torch.tensor([-10.0, -10.0]) + + scatter_logsumexp(src=src, index=index, out=out) + assert out.allclose(torch.tensor([-0.9999, -10.0]), atol=1e-4) diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index 355d0c0e..1d5ff9b6 100644 --- a/torch_scatter/composite/logsumexp.py +++ b/torch_scatter/composite/logsumexp.py @@ -1,15 +1,18 @@ from typing import Optional import torch -from torch_scatter import scatter_sum, scatter_max - +from torch_scatter import scatter_max, scatter_sum from torch_scatter.utils import broadcast -def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, - out: Optional[torch.Tensor] = None, - dim_size: Optional[int] = None, - eps: float = 1e-12) -> torch.Tensor: +def scatter_logsumexp( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + eps: float = 1e-12, +) -> torch.Tensor: if not torch.is_floating_point(src): raise ValueError('`scatter_logsumexp` can only be computed over ' 'tensors with floating point data types.') @@ -24,18 +27,30 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, size = list(src.size()) size[dim] = dim_size - max_value_per_index = torch.full(size, float('-inf'), dtype=src.dtype, - device=src.device) + max_value_per_index = torch.full( + size, + fill_value=float('-inf'), + dtype=src.dtype, + device=src.device, + ) scatter_max(src, index, dim, max_value_per_index, dim_size=dim_size)[0] max_per_src_element = max_value_per_index.gather(dim, index) recentered_score = src - max_per_src_element recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf')) + orig_out: Optional[torch.Tensor] = None if out is not None: + orig_out = out.clone() out = out.sub_(max_value_per_index).exp_() sum_per_index = scatter_sum(recentered_score.exp_(), index, dim, out, dim_size) out = sum_per_index.add_(eps).log_().add_(max_value_per_index) - return out.nan_to_num_(neginf=0.0) + + if orig_out is None: + return out.nan_to_num_(neginf=0.0) + + mask = ~out.isfinite() + out[mask] = orig_out[mask] + return out