From b454a51e47325c21cbefa54d611f7d800080925c Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 6 Dec 2023 23:52:38 +0100 Subject: [PATCH] `mypy` support in `torch_geometric.utils`[3/n] (#8556) --- torch_geometric/utils/_coalesce.py | 51 ++++++++++++++++++++++-------- torch_geometric/utils/_grid.py | 6 ++-- torch_geometric/utils/_subgraph.py | 30 ++++++++++++++---- 3 files changed, 64 insertions(+), 23 deletions(-) diff --git a/torch_geometric/utils/_coalesce.py b/torch_geometric/utils/_coalesce.py index 981c54e778d8..66efc8d5ca2f 100644 --- a/torch_geometric/utils/_coalesce.py +++ b/torch_geometric/utils/_coalesce.py @@ -1,3 +1,4 @@ +import typing from typing import List, Optional, Tuple, Union import torch @@ -7,27 +8,47 @@ from torch_geometric.utils import index_sort, scatter from torch_geometric.utils.num_nodes import maybe_num_nodes +if typing.TYPE_CHECKING: + from typing import overload +else: + from torch.jit import _overload as overload + MISSING = '???' -@torch.jit._overload -def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, str, Optional[int], str, bool, bool) -> Tensor +@overload +def coalesce( + edge_index: Tensor, + edge_attr: str = MISSING, + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tensor: pass -@torch.jit._overload +@overload def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, Optional[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, Optional[Tensor]] # noqa + edge_index: Tensor, + edge_attr: OptTensor, + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tuple[Tensor, OptTensor]: pass -@torch.jit._overload +@overload def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, List[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, List[Tensor]] # noqa + edge_index: Tensor, + edge_attr: List[Tensor], + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tuple[Tensor, List[Tensor]]: pass @@ -35,7 +56,7 @@ def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: Union[OptTensor, List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, - reduce: str = 'add', + reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]: @@ -52,8 +73,8 @@ def coalesce( # noqa: F811 num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) reduce (str, optional): The reduce operation to use for merging edge - features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, - :obj:`"mul"`, :obj:`"any"`). (default: :obj:`"add"`) + features (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, + :obj:`"mul"`, :obj:`"any"`). (default: :obj:`"sum"`) is_sorted (bool, optional): If set to :obj:`True`, will expect :obj:`edge_index` to be already sorted row-wise. sort_by_row (bool, optional): If set to :obj:`False`, will sort @@ -117,7 +138,9 @@ def coalesce( # noqa: F811 # Only perform expensive merging in case there exists duplicates: if mask.all(): - if edge_attr is None or isinstance(edge_attr, (Tensor, list, tuple)): + if edge_attr is None or isinstance(edge_attr, Tensor): + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): return edge_index, edge_attr return edge_index diff --git a/torch_geometric/utils/_grid.py b/torch_geometric/utils/_grid.py index 7e6f26374418..4154da62dce7 100644 --- a/torch_geometric/utils/_grid.py +++ b/torch_geometric/utils/_grid.py @@ -49,8 +49,10 @@ def grid_index( ) -> Tensor: w = width - kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1] - kernel = torch.tensor(kernel, device=device) + kernel = torch.tensor( + [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1], + device=device, + ) row = torch.arange(height * width, dtype=torch.long, device=device) row = row.view(-1, 1).repeat(1, kernel.size(0)) diff --git a/torch_geometric/utils/_subgraph.py b/torch_geometric/utils/_subgraph.py index 711b1f666d5d..e0a03750cf1e 100644 --- a/torch_geometric/utils/_subgraph.py +++ b/torch_geometric/utils/_subgraph.py @@ -44,7 +44,6 @@ def subgraph( edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., - return_edge_mask: Literal[False] = ..., ) -> Tuple[Tensor, OptTensor]: pass @@ -56,7 +55,21 @@ def subgraph( edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., - return_edge_mask: Literal[True] = ..., + *, + return_edge_mask: Literal[False], +) -> Tuple[Tensor, OptTensor]: + pass + + +@overload +def subgraph( + subset: Union[Tensor, List[int]], + edge_index: Tensor, + edge_attr: OptTensor = ..., + relabel_nodes: bool = ..., + num_nodes: Optional[int] = ..., + *, + return_edge_mask: Literal[True], ) -> Tuple[Tensor, OptTensor, Tensor]: pass @@ -67,6 +80,7 @@ def subgraph( edge_attr: OptTensor = None, relabel_nodes: bool = False, num_nodes: Optional[int] = None, + *, return_edge_mask: bool = False, ) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]: r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)` @@ -314,8 +328,10 @@ def k_hop_subgraph( node_mask = row.new_empty(num_nodes, dtype=torch.bool) edge_mask = row.new_empty(row.size(0), dtype=torch.bool) - if isinstance(node_idx, (int, list, tuple)): - node_idx = torch.tensor([node_idx], device=row.device).flatten() + if isinstance(node_idx, int): + node_idx = torch.tensor([node_idx], device=row.device) + elif isinstance(node_idx, (list, tuple)): + node_idx = torch.tensor(node_idx, device=row.device) else: node_idx = node_idx.to(row.device) @@ -339,9 +355,9 @@ def k_hop_subgraph( edge_index = edge_index[:, edge_mask] if relabel_nodes: - node_idx = row.new_full((num_nodes, ), -1) - node_idx[subset] = torch.arange(subset.size(0), device=row.device) - edge_index = node_idx[edge_index] + mapping = row.new_full((num_nodes, ), -1) + mapping[subset] = torch.arange(subset.size(0), device=row.device) + edge_index = mapping[edge_index] return subset, edge_index, inv, edge_mask