diff --git a/torch_scatter/composite/logsumexp.py b/torch_scatter/composite/logsumexp.py index 69dc90dd..1d5ff9b6 100644 --- a/torch_scatter/composite/logsumexp.py +++ b/torch_scatter/composite/logsumexp.py @@ -38,7 +38,7 @@ def scatter_logsumexp( recentered_score = src - max_per_src_element recentered_score.masked_fill_(torch.isnan(recentered_score), float('-inf')) - orig_out = None + orig_out: Optional[torch.Tensor] = None if out is not None: orig_out = out.clone() out = out.sub_(max_value_per_index).exp_()