Skip to content

Commit

Permalink
mypy support in torch_geometric.utils [3/n] (pyg-team#8556)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Dec 6, 2023
1 parent b7edc22 commit b454a51
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 23 deletions.
51 changes: 37 additions & 14 deletions torch_geometric/utils/_coalesce.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from typing import List, Optional, Tuple, Union

import torch
Expand All @@ -7,35 +8,55 @@
from torch_geometric.utils import index_sort, scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes

if typing.TYPE_CHECKING:
from typing import overload
else:
from torch.jit import _overload as overload

MISSING = '???'


@torch.jit._overload
def coalesce( # noqa: F811
edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row):
# type: (Tensor, str, Optional[int], str, bool, bool) -> Tensor
@overload
def coalesce(
edge_index: Tensor,
edge_attr: str = MISSING,
num_nodes: Optional[int] = None,
reduce: str = 'sum',
is_sorted: bool = False,
sort_by_row: bool = True,
) -> Tensor:
pass


@torch.jit._overload
@overload
def coalesce( # noqa: F811
edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row):
# type: (Tensor, Optional[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, Optional[Tensor]] # noqa
edge_index: Tensor,
edge_attr: OptTensor,
num_nodes: Optional[int] = None,
reduce: str = 'sum',
is_sorted: bool = False,
sort_by_row: bool = True,
) -> Tuple[Tensor, OptTensor]:
pass


@torch.jit._overload
@overload
def coalesce( # noqa: F811
edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row):
# type: (Tensor, List[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, List[Tensor]] # noqa
edge_index: Tensor,
edge_attr: List[Tensor],
num_nodes: Optional[int] = None,
reduce: str = 'sum',
is_sorted: bool = False,
sort_by_row: bool = True,
) -> Tuple[Tensor, List[Tensor]]:
pass


def coalesce( # noqa: F811
edge_index: Tensor,
edge_attr: Union[OptTensor, List[Tensor], str] = MISSING,
num_nodes: Optional[int] = None,
reduce: str = 'add',
reduce: str = 'sum',
is_sorted: bool = False,
sort_by_row: bool = True,
) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]:
Expand All @@ -52,8 +73,8 @@ def coalesce( # noqa: F811
num_nodes (int, optional): The number of nodes, *i.e.*
:obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)
reduce (str, optional): The reduce operation to use for merging edge
features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"mul"`, :obj:`"any"`). (default: :obj:`"add"`)
features (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"mul"`, :obj:`"any"`). (default: :obj:`"sum"`)
is_sorted (bool, optional): If set to :obj:`True`, will expect
:obj:`edge_index` to be already sorted row-wise.
sort_by_row (bool, optional): If set to :obj:`False`, will sort
Expand Down Expand Up @@ -117,7 +138,9 @@ def coalesce( # noqa: F811

# Only perform expensive merging in case there exists duplicates:
if mask.all():
if edge_attr is None or isinstance(edge_attr, (Tensor, list, tuple)):
if edge_attr is None or isinstance(edge_attr, Tensor):
return edge_index, edge_attr
if isinstance(edge_attr, (list, tuple)):
return edge_index, edge_attr
return edge_index

Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/utils/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ def grid_index(
) -> Tensor:

w = width
kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]
kernel = torch.tensor(kernel, device=device)
kernel = torch.tensor(
[-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1],
device=device,
)

row = torch.arange(height * width, dtype=torch.long, device=device)
row = row.view(-1, 1).repeat(1, kernel.size(0))
Expand Down
30 changes: 23 additions & 7 deletions torch_geometric/utils/_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def subgraph(
edge_attr: OptTensor = ...,
relabel_nodes: bool = ...,
num_nodes: Optional[int] = ...,
return_edge_mask: Literal[False] = ...,
) -> Tuple[Tensor, OptTensor]:
pass

Expand All @@ -56,7 +55,21 @@ def subgraph(
edge_attr: OptTensor = ...,
relabel_nodes: bool = ...,
num_nodes: Optional[int] = ...,
return_edge_mask: Literal[True] = ...,
*,
return_edge_mask: Literal[False],
) -> Tuple[Tensor, OptTensor]:
pass


@overload
def subgraph(
subset: Union[Tensor, List[int]],
edge_index: Tensor,
edge_attr: OptTensor = ...,
relabel_nodes: bool = ...,
num_nodes: Optional[int] = ...,
*,
return_edge_mask: Literal[True],
) -> Tuple[Tensor, OptTensor, Tensor]:
pass

Expand All @@ -67,6 +80,7 @@ def subgraph(
edge_attr: OptTensor = None,
relabel_nodes: bool = False,
num_nodes: Optional[int] = None,
*,
return_edge_mask: bool = False,
) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:
r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)`
Expand Down Expand Up @@ -314,8 +328,10 @@ def k_hop_subgraph(
node_mask = row.new_empty(num_nodes, dtype=torch.bool)
edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

if isinstance(node_idx, (int, list, tuple)):
node_idx = torch.tensor([node_idx], device=row.device).flatten()
if isinstance(node_idx, int):
node_idx = torch.tensor([node_idx], device=row.device)
elif isinstance(node_idx, (list, tuple)):
node_idx = torch.tensor(node_idx, device=row.device)
else:
node_idx = node_idx.to(row.device)

Expand All @@ -339,9 +355,9 @@ def k_hop_subgraph(
edge_index = edge_index[:, edge_mask]

if relabel_nodes:
node_idx = row.new_full((num_nodes, ), -1)
node_idx[subset] = torch.arange(subset.size(0), device=row.device)
edge_index = node_idx[edge_index]
mapping = row.new_full((num_nodes, ), -1)
mapping[subset] = torch.arange(subset.size(0), device=row.device)
edge_index = mapping[edge_index]

return subset, edge_index, inv, edge_mask

Expand Down

0 comments on commit b454a51

Please sign in to comment.