Skip to content

Commit

Permalink
mypy support in torch_geometric.utils [4/n] (pyg-team#8558)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Dec 6, 2023
1 parent b454a51 commit 447f5a5
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 54 deletions.
12 changes: 12 additions & 0 deletions torch_geometric/utils/_coalesce.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def coalesce(
pass


@overload
def coalesce( # noqa: F811
edge_index: Tensor,
edge_attr: Tensor,
num_nodes: Optional[int] = None,
reduce: str = 'sum',
is_sorted: bool = False,
sort_by_row: bool = True,
) -> Tuple[Tensor, Tensor]:
pass


@overload
def coalesce( # noqa: F811
edge_index: Tensor,
Expand Down
30 changes: 27 additions & 3 deletions torch_geometric/utils/_homophily.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, overload

import torch
from torch import Tensor
Expand All @@ -7,8 +7,32 @@
from torch_geometric.utils import degree, scatter


def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None,
method: str = 'edge') -> Union[float, Tensor]:
@overload
def homophily(
edge_index: Adj,
y: Tensor,
batch: None = ...,
method: str = ...,
) -> float:
pass


@overload
def homophily(
edge_index: Adj,
y: Tensor,
batch: Tensor,
method: str = ...,
) -> Tensor:
pass


def homophily(
edge_index: Adj,
y: Tensor,
batch: OptTensor = None,
method: str = 'edge',
) -> Union[float, Tensor]:
r"""The homophily of a graph characterizes how likely nodes with the same
label are near each other in a graph.
Expand Down
17 changes: 11 additions & 6 deletions torch_geometric/utils/_negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ def negative_sampling(
"""
assert method in ['sparse', 'dense']

size = num_nodes
bipartite = isinstance(size, (tuple, list))
size = maybe_num_nodes(edge_index) if size is None else size
size = (size, size) if not bipartite else size
force_undirected = False if bipartite else force_undirected
if num_nodes is None:
num_nodes = maybe_num_nodes(edge_index, num_nodes)

if isinstance(num_nodes, int):
size = (num_nodes, num_nodes)
bipartite = False
else:
size = num_nodes
bipartite = True
force_undirected = False

idx, population = edge_index_to_vector(edge_index, size, bipartite,
force_undirected)
Expand Down Expand Up @@ -95,7 +100,7 @@ def negative_sampling(
idx = idx.to('cpu')
for _ in range(3): # Number of tries to sample negative indices.
rnd = sample(population, sample_size, device='cpu')
mask = np.isin(rnd, idx)
mask = np.isin(rnd.numpy(), idx.numpy())
if neg_idx is not None:
mask |= np.isin(rnd, neg_idx.to('cpu'))
mask = torch.from_numpy(mask).to(torch.bool)
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/utils/_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int,
start (int): The starting dimension.
length (int): The distance to the ending dimension.
"""
if is_torch_sparse_tensor(src):
if isinstance(src, Tensor) and is_torch_sparse_tensor(src):
# TODO Sparse tensors in `torch.sparse` do not yet support `narrow`.
index = torch.arange(start, start + length, device=src.device)
return src.index_select(dim, index)
Expand Down
49 changes: 38 additions & 11 deletions torch_geometric/utils/_sort_edge_index.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,61 @@
import typing
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.num_nodes import maybe_num_nodes

if typing.TYPE_CHECKING:
from typing import overload
else:
from torch.jit import _overload as overload

MISSING = '???'


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa
# type: (Tensor, str, Optional[int], bool) -> Tensor
@overload
def sort_edge_index(
edge_index: Tensor,
edge_attr: str = MISSING,
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Tensor:
pass


@overload
def sort_edge_index( # noqa: F811
edge_index: Tensor,
edge_attr: Tensor,
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Tuple[Tensor, Tensor]:
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa
# type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tuple[Tensor, Optional[Tensor]] # noqa
@overload
def sort_edge_index( # noqa: F811
edge_index: Tensor,
edge_attr: OptTensor,
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Tuple[Tensor, OptTensor]:
pass


@torch.jit._overload
def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa
# type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa
@overload
def sort_edge_index( # noqa: F811
edge_index: Tensor,
edge_attr: List[Tensor],
num_nodes: Optional[int] = None,
sort_by_row: bool = True,
) -> Tuple[Tensor, List[Tensor]]:
pass


def sort_edge_index( # noqa
def sort_edge_index( # noqa: F811
edge_index: Tensor,
edge_attr: Union[OptTensor, List[Tensor], str] = MISSING,
num_nodes: Optional[int] = None,
Expand Down
28 changes: 20 additions & 8 deletions torch_geometric/utils/_spmm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import warnings

import torch
Expand All @@ -7,25 +8,36 @@
from torch_geometric.typing import Adj, SparseTensor, torch_sparse
from torch_geometric.utils import is_torch_sparse_tensor, scatter

if typing.TYPE_CHECKING:
from typing import overload
else:
from torch.jit import _overload as overload

@torch.jit._overload
def spmm(src, other, reduce): # noqa: F811
# type: (Tensor, Tensor, str) -> Tensor

@overload
def spmm(
src: Tensor,
other: Tensor,
reduce: str = 'sum',
) -> Tensor:
pass


@torch.jit._overload
def spmm(src, other, reduce): # noqa: F811
# type: (SparseTensor, Tensor, str) -> Tensor
@overload
def spmm( # noqa: F811
src: SparseTensor,
other: Tensor,
reduce: str = 'sum',
) -> Tensor:
pass


def spmm( # noqa: F811
src: Adj,
other: Tensor,
reduce: str = "sum",
reduce: str = 'sum',
) -> Tensor:
"""Matrix product of sparse matrix with dense matrix.
r"""Matrix product of sparse matrix with dense matrix.
Args:
src (torch.Tensor or torch_sparse.SparseTensor): The input sparse
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/utils/_to_dense_adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def to_dense_adj(
[5., 0.]]])
"""
if batch is None:
num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
batch = edge_index.new_zeros(num_nodes)
max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
batch = edge_index.new_zeros(max_index)

if batch_size is None:
batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
Expand Down
19 changes: 10 additions & 9 deletions torch_geometric/utils/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings
from typing import Optional

import numpy as np
import torch
from torch import Tensor

Expand Down Expand Up @@ -73,9 +72,9 @@ def geodesic_distance( # noqa: D417

if norm:
area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]])
norm = (area.norm(p=2, dim=1) / 2).sum().sqrt().item()
scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt())
else:
norm = 1.0
scale = 1.0

dtype = pos.dtype

Expand All @@ -84,32 +83,34 @@ def geodesic_distance( # noqa: D417

if src is None and dst is None:
out = gdist.local_gdist_matrix(pos, face,
max_distance * norm).toarray() / norm
max_distance * scale).toarray() / scale
return torch.from_numpy(out).to(dtype)

if src is None:
src = np.arange(pos.shape[0], dtype=np.int32)
src = torch.arange(pos.shape[0], dtype=torch.int).numpy()
else:
src = src.detach().cpu().to(torch.int).numpy()
assert src is not None

dst = None if dst is None else dst.detach().cpu().to(torch.int).numpy()

def _parallel_loop(pos, face, src, dst, max_distance, norm, i, dtype):
def _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype):
s = src[i:i + 1]
d = None if dst is None else dst[i:i + 1]
out = gdist.compute_gdist(pos, face, s, d, max_distance * norm) / norm
out = gdist.compute_gdist(pos, face, s, d, max_distance * scale)
out = out / scale
return torch.from_numpy(out).to(dtype)

num_workers = mp.cpu_count() if num_workers <= -1 else num_workers
if num_workers > 0:
with mp.Pool(num_workers) as pool:
outs = pool.starmap(
_parallel_loop,
[(pos, face, src, dst, max_distance, norm, i, dtype)
[(pos, face, src, dst, max_distance, scale, i, dtype)
for i in range(len(src))])
else:
outs = [
_parallel_loop(pos, face, src, dst, max_distance, norm, i, dtype)
_parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype)
for i in range(len(src))
]

Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/utils/num_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@
@overload
def maybe_num_nodes(
edge_index: Tensor,
num_nodes: Optional[int],
num_nodes: Optional[int] = None,
) -> int:
pass


@overload
def maybe_num_nodes( # noqa: F811
edge_index: Tuple[Tensor, Tensor],
num_nodes: Optional[int],
num_nodes: Optional[int] = None,
) -> int:
pass


@overload
def maybe_num_nodes( # noqa: F811
edge_index: SparseTensor,
num_nodes: Optional[int],
edge_index: SparseTensor,
num_nodes: Optional[int] = None,
) -> int:
pass

Expand Down
20 changes: 10 additions & 10 deletions torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
import warnings
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -303,6 +304,8 @@ def to_torch_csc_tensor(
"""
if not torch_geometric.typing.WITH_PT112:
if typing.TYPE_CHECKING:
raise NotImplementedError
return torch_geometric.typing.MockTorchCSCTensor(
edge_index, edge_attr, size)

Expand Down Expand Up @@ -434,10 +437,10 @@ def get_sparse_diag(


def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:
size = adj.size()

if value.dim() > 1:
size = size + value.size()[1:]
size = adj.size() + value.size()[1:]
else:
size = adj.size()

if adj.layout == torch.sparse_coo:
return torch.sparse_coo_tensor(
Expand Down Expand Up @@ -483,7 +486,7 @@ def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:
# the individual sparse tensor layouts.
assert dim in {0, 1, (0, 1)}

size = [0, 0]
size = (0, 0)
edge_indices = []
edge_attrs = []
for tensor in tensors:
Expand All @@ -493,17 +496,14 @@ def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:

if dim == 0:
edge_index[0] += size[0]
size[0] += tensor.size(0)
size[1] = max(size[1], tensor.size(1))
size = (size[0] + tensor.size(0), max(size[1], tensor.size(1)))
elif dim == 1:
edge_index[1] += size[1]
size[0] = max(size[0], tensor.size(0))
size[1] += tensor.size(1)
size = (max(size[0], tensor.size(0)), size[1] + tensor.size(1))
else:
edge_index[0] += size[0]
edge_index[1] += size[1]
size[0] += tensor.size(0)
size[1] += tensor.size(1)
size = (size[0] + tensor.size(0), size[1] + tensor.size(1))

edge_indices.append(edge_index)
edge_attrs.append(edge_attr)
Expand Down

0 comments on commit 447f5a5

Please sign in to comment.