From c519322d591a5b3e7d0b6fd515c21e96d14c667f Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 6 Dec 2023 20:22:39 +0100 Subject: [PATCH 1/5] `mypy` support in `torch_geometric.utils` [1/n] (#8554) --- torch_geometric/utils/dropout.py | 11 ++++++---- torch_geometric/utils/map.py | 10 ++++----- torch_geometric/utils/num_nodes.py | 35 ++++++++++++++++++++---------- torch_geometric/utils/subgraph.py | 28 ++++++++++++++++++++++-- 4 files changed, 61 insertions(+), 23 deletions(-) diff --git a/torch_geometric/utils/dropout.py b/torch_geometric/utils/dropout.py index 2888def4a0f5..3009f1181026 100644 --- a/torch_geometric/utils/dropout.py +++ b/torch_geometric/utils/dropout.py @@ -143,10 +143,13 @@ def dropout_node( prob = torch.rand(num_nodes, device=edge_index.device) node_mask = prob > p - edge_index, _, edge_mask = subgraph(node_mask, edge_index, - num_nodes=num_nodes, - relabel_nodes=relabel_nodes, - return_edge_mask=True) + edge_index, _, edge_mask = subgraph( + node_mask, + edge_index, + relabel_nodes=relabel_nodes, + num_nodes=num_nodes, + return_edge_mask=True, + ) return edge_index, edge_mask, node_mask diff --git a/torch_geometric/utils/map.py b/torch_geometric/utils/map.py index 5201c72f104e..95731d0e3155 100644 --- a/torch_geometric/utils/map.py +++ b/torch_geometric/utils/map.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import Tensor @@ -9,7 +9,7 @@ def map_index( src: Tensor, index: Tensor, - max_index: Optional[int] = None, + max_index: Optional[Union[int, Tensor]] = None, inclusive: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: r"""Maps indices in :obj:`src` to the positional value of their @@ -63,7 +63,7 @@ def map_index( f"(got '{src.device}' and '{index.device}')") if max_index is None: - max_index = max(src.max(), index.max()) + max_index = torch.maximum(src.max(), index.max()) # If the `max_index` is in a reasonable range, we can accelerate this # operation by creating a helper vector to perform the mapping. 
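# A minimal sketch (for illustration) of the "helper vector" trick the comment
# above describes: a dense lookup tensor of length `max_index + 1` records the
# position of every value in `index`, so the whole mapping reduces to a single
# gather. The name `map_via_helper_vector` is hypothetical, not PyG's internal API.
import torch

def map_via_helper_vector(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    max_index = int(torch.maximum(src.max(), index.max()))
    assoc = src.new_full((max_index + 1, ), -1)  # -1 marks values absent from `index`
    assoc[index] = torch.arange(index.numel(), dtype=src.dtype, device=src.device)
    return assoc[src]

# map_via_helper_vector(torch.tensor([10, 30, 20, 30]), torch.tensor([10, 20, 30]))
# -> tensor([0, 2, 1, 2])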
@@ -72,9 +72,9 @@ def map_index( THRESHOLD = 40_000_000 if src.is_cuda else 10_000_000 if max_index <= THRESHOLD: if inclusive: - assoc = src.new_empty(max_index + 1) + assoc = src.new_empty(max_index + 1) # type: ignore else: - assoc = src.new_full((max_index + 1, ), -1) + assoc = src.new_full((max_index + 1, ), -1) # type: ignore assoc[index] = torch.arange(index.numel(), dtype=src.dtype, device=src.device) out = assoc[src] diff --git a/torch_geometric/utils/num_nodes.py b/torch_geometric/utils/num_nodes.py index 65a6d78ee6b7..35b892dcfb81 100644 --- a/torch_geometric/utils/num_nodes.py +++ b/torch_geometric/utils/num_nodes.py @@ -1,3 +1,4 @@ +import typing from copy import copy from typing import Dict, Optional, Tuple, Union @@ -7,22 +8,33 @@ import torch_geometric from torch_geometric.typing import EdgeType, NodeType, SparseTensor +if typing.TYPE_CHECKING: + from typing import overload +else: + from torch.jit import _overload as overload -@torch.jit._overload -def maybe_num_nodes(edge_index, num_nodes): # noqa: F811 - # type: (Tensor, Optional[int]) -> int + +@overload +def maybe_num_nodes( + edge_index: Tensor, + num_nodes: Optional[int], +) -> int: pass -@torch.jit._overload -def maybe_num_nodes(edge_index, num_nodes): # noqa: F811 - # type: (Tuple[Tensor, Tensor], Optional[int]) -> int +@overload +def maybe_num_nodes( # noqa: F811 + edge_index: Tuple[Tensor, Tensor], + num_nodes: Optional[int], +) -> int: pass -@torch.jit._overload -def maybe_num_nodes(edge_index, num_nodes): # noqa: F811 - # type: (SparseTensor, Optional[int]) -> int +@overload +def maybe_num_nodes( # noqa: F811 + edge_index: SparseTensor, + num_nodes: Optional[int], +) -> int: pass @@ -42,7 +54,7 @@ def maybe_num_nodes( # noqa: F811 edge_index.view(-1), edge_index.new_full((1, ), fill_value=-1) ]) - return tmp.max() + 1 + return tmp.max() + 1 # type: ignore return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 elif isinstance(edge_index, tuple): @@ -52,8 +64,7 @@ def maybe_num_nodes( # noqa: F811 ) elif isinstance(edge_index, SparseTensor): return max(edge_index.size(0), edge_index.size(1)) - else: - raise NotImplementedError + raise NotImplementedError def maybe_num_nodes_dict( diff --git a/torch_geometric/utils/subgraph.py b/torch_geometric/utils/subgraph.py index fe8c312eca93..711b1f666d5d 100644 --- a/torch_geometric/utils/subgraph.py +++ b/torch_geometric/utils/subgraph.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union, overload import torch from torch import Tensor @@ -37,6 +37,30 @@ def get_num_hops(model: torch.nn.Module) -> int: return num_hops +@overload +def subgraph( + subset: Union[Tensor, List[int]], + edge_index: Tensor, + edge_attr: OptTensor = ..., + relabel_nodes: bool = ..., + num_nodes: Optional[int] = ..., + return_edge_mask: Literal[False] = ..., +) -> Tuple[Tensor, OptTensor]: + pass + + +@overload +def subgraph( + subset: Union[Tensor, List[int]], + edge_index: Tensor, + edge_attr: OptTensor = ..., + relabel_nodes: bool = ..., + num_nodes: Optional[int] = ..., + return_edge_mask: Literal[True] = ..., +) -> Tuple[Tensor, OptTensor, Tensor]: + pass + + def subgraph( subset: Union[Tensor, List[int]], edge_index: Tensor, @@ -44,7 +68,7 @@ def subgraph( relabel_nodes: bool = False, num_nodes: Optional[int] = None, return_edge_mask: bool = False, -) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]: +) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]: r"""Returns the 
induced subgraph of :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. From 909e606a0811f617a3d0a4e4a72963fefba0ff59 Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Wed, 6 Dec 2023 11:29:05 -0800 Subject: [PATCH 2/5] update from EA to GA PyG NVIDIA container in multinode tutorial (#8553) Co-authored-by: rusty1s --- docs/source/tutorial/multi_node_multi_gpu_vanilla.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorial/multi_node_multi_gpu_vanilla.rst b/docs/source/tutorial/multi_node_multi_gpu_vanilla.rst index 8e7be735280d..ca9765108f77 100644 --- a/docs/source/tutorial/multi_node_multi_gpu_vanilla.rst +++ b/docs/source/tutorial/multi_node_multi_gpu_vanilla.rst @@ -61,9 +61,8 @@ Finally, to submit the :obj:`*.sbatch` file itself into the work queue, use the Using a cluster configured with pyxis-containers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -If your cluster supports the :obj:`pyxis` plugin developed by NVIDIA, you can use a ready-to-use :pyg:`PyG` container that is updated each month with the latest from NVIDIA and :pyg:`PyG`. -Currently it is not yet publically available, but you can sign up for early access `here `_. -The container should set up all necessary environment variables from which you can now directly run the example using :obj:`srun` from your command prompt: +If your cluster supports the :obj:`pyxis` plugin developed by NVIDIA, you can use a ready-to-use :pyg:`PyG` container that is updated each month with the latest from NVIDIA and :pyg:`PyG`, see `here `_ for more information. +The container sets up all necessary environment variables from which you can now directly run the example using :obj:`srun` from your command prompt: .. code-block:: console From b7edc223509d2b41c541f448a3f4ff45ad16d6a2 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 6 Dec 2023 20:44:05 +0100 Subject: [PATCH 3/5] `mypy` support in `torch_geometric.utils` [2/n] (#8555) --- test/nn/models/test_basic_gnn.py | 2 +- test/utils/test_negative_sampling.py | 2 +- test/utils/test_scatter.py | 2 +- test/utils/test_trim_to_layer.py | 2 +- .../contrib/explain/pgm_explainer.py | 2 +- torch_geometric/data/hypergraph_data.py | 2 +- torch_geometric/nn/models/basic_gnn.py | 2 +- torch_geometric/nn/models/tgn.py | 2 +- torch_geometric/utils/__init__.py | 56 +++++++++---------- .../{assortativity.py => _assortativity.py} | 2 +- .../utils/{coalesce.py => _coalesce.py} | 0 .../utils/{degree.py => _degree.py} | 0 torch_geometric/utils/{grid.py => _grid.py} | 2 +- .../utils/{homophily.py => _homophily.py} | 0 .../utils/{sort.py => _index_sort.py} | 0 .../utils/{lexsort.py => _lexsort.py} | 0 ...tive_sampling.py => _negative_sampling.py} | 0 .../{normalized_cut.py => _normalized_cut.py} | 0 .../utils/{one_hot.py => _one_hot.py} | 0 .../utils/{scatter.py => _scatter.py} | 0 .../utils/{segment.py => _segment.py} | 0 .../utils/{select.py => _select.py} | 0 .../utils/{softmax.py => _softmax.py} | 0 ...sort_edge_index.py => _sort_edge_index.py} | 0 torch_geometric/utils/{spmm.py => _spmm.py} | 0 .../utils/{subgraph.py => _subgraph.py} | 0 .../{to_dense_adj.py => _to_dense_adj.py} | 0 .../{to_dense_batch.py => _to_dense_batch.py} | 0 ...it_edges.py => _train_test_split_edges.py} | 0 ...ecomposition.py => _tree_decomposition.py} | 0 .../{trim_to_layer.py => _trim_to_layer.py} | 0 .../utils/{unbatch.py => _unbatch.py} | 0 torch_geometric/utils/hetero.py | 10 ++-- torch_geometric/utils/isolated.py | 2 +- .../utils/{get_laplacian.py => laplacian.py} | 0 
...et_mesh_laplacian.py => mesh_laplacian.py} | 0 36 files changed, 44 insertions(+), 44 deletions(-) rename torch_geometric/utils/{assortativity.py => _assortativity.py} (97%) rename torch_geometric/utils/{coalesce.py => _coalesce.py} (100%) rename torch_geometric/utils/{degree.py => _degree.py} (100%) rename torch_geometric/utils/{grid.py => _grid.py} (97%) rename torch_geometric/utils/{homophily.py => _homophily.py} (100%) rename torch_geometric/utils/{sort.py => _index_sort.py} (100%) rename torch_geometric/utils/{lexsort.py => _lexsort.py} (100%) rename torch_geometric/utils/{negative_sampling.py => _negative_sampling.py} (100%) rename torch_geometric/utils/{normalized_cut.py => _normalized_cut.py} (100%) rename torch_geometric/utils/{one_hot.py => _one_hot.py} (100%) rename torch_geometric/utils/{scatter.py => _scatter.py} (100%) rename torch_geometric/utils/{segment.py => _segment.py} (100%) rename torch_geometric/utils/{select.py => _select.py} (100%) rename torch_geometric/utils/{softmax.py => _softmax.py} (100%) rename torch_geometric/utils/{sort_edge_index.py => _sort_edge_index.py} (100%) rename torch_geometric/utils/{spmm.py => _spmm.py} (100%) rename torch_geometric/utils/{subgraph.py => _subgraph.py} (100%) rename torch_geometric/utils/{to_dense_adj.py => _to_dense_adj.py} (100%) rename torch_geometric/utils/{to_dense_batch.py => _to_dense_batch.py} (100%) rename torch_geometric/utils/{train_test_split_edges.py => _train_test_split_edges.py} (100%) rename torch_geometric/utils/{tree_decomposition.py => _tree_decomposition.py} (100%) rename torch_geometric/utils/{trim_to_layer.py => _trim_to_layer.py} (100%) rename torch_geometric/utils/{unbatch.py => _unbatch.py} (100%) rename torch_geometric/utils/{get_laplacian.py => laplacian.py} (100%) rename torch_geometric/utils/{get_mesh_laplacian.py => mesh_laplacian.py} (100%) diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index 3262f5ed261d..d67c8278e0f0 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -246,7 +246,7 @@ def test_packaging(): path = osp.join(torch.hub._get_torch_home(), 'pyg_test_package.pt') with torch.package.PackageExporter(path) as pe: pe.extern('torch_geometric.nn.**') - pe.extern('torch_geometric.utils.trim_to_layer') + pe.extern('torch_geometric.utils._trim_to_layer') pe.extern('_operator') pe.save_pickle('models', 'model.pkl', model) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index 126e9e565c39..709452fe60a5 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -9,7 +9,7 @@ structured_negative_sampling_feasible, to_undirected, ) -from torch_geometric.utils.negative_sampling import ( +from torch_geometric.utils._negative_sampling import ( edge_index_to_vector, vector_to_edge_index, ) diff --git a/test/utils/test_scatter.py b/test/utils/test_scatter.py index 658a7ac80f9e..c487247e1082 100644 --- a/test/utils/test_scatter.py +++ b/test/utils/test_scatter.py @@ -6,7 +6,7 @@ from torch_geometric.profile import benchmark from torch_geometric.testing import disableExtensions, withCUDA, withPackage from torch_geometric.utils import group_argsort, scatter -from torch_geometric.utils.scatter import scatter_argmax +from torch_geometric.utils._scatter import scatter_argmax @withPackage('torch>=1.12.0') diff --git a/test/utils/test_trim_to_layer.py b/test/utils/test_trim_to_layer.py index 884381dd8435..c1f1c0fbd038 100644 --- a/test/utils/test_trim_to_layer.py +++ 
b/test/utils/test_trim_to_layer.py @@ -10,7 +10,7 @@ from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import trim_to_layer -from torch_geometric.utils.trim_to_layer import trim_sparse_tensor +from torch_geometric.utils._trim_to_layer import trim_sparse_tensor @withPackage('torch_sparse') diff --git a/torch_geometric/contrib/explain/pgm_explainer.py b/torch_geometric/contrib/explain/pgm_explainer.py index fcfcf93a509a..f21801d5ea21 100644 --- a/torch_geometric/contrib/explain/pgm_explainer.py +++ b/torch_geometric/contrib/explain/pgm_explainer.py @@ -9,7 +9,7 @@ from torch_geometric.explain.config import ModelMode, ModelTaskLevel from torch_geometric.explain.explanation import Explanation from torch_geometric.utils import k_hop_subgraph -from torch_geometric.utils.subgraph import get_num_hops +from torch_geometric.utils._subgraph import get_num_hops class PGMExplainer(ExplainerAlgorithm): diff --git a/torch_geometric/data/hypergraph_data.py b/torch_geometric/data/hypergraph_data.py index ddc02ca5c28f..48511bada9a2 100644 --- a/torch_geometric/data/hypergraph_data.py +++ b/torch_geometric/data/hypergraph_data.py @@ -8,7 +8,7 @@ from torch_geometric.data import Data from torch_geometric.typing import EdgeType, NodeType, OptTensor from torch_geometric.utils import select -from torch_geometric.utils.subgraph import hyper_subgraph +from torch_geometric.utils._subgraph import hyper_subgraph class HyperGraphData(Data): diff --git a/torch_geometric/nn/models/basic_gnn.py b/torch_geometric/nn/models/basic_gnn.py index 350545b11448..d24d886c441b 100644 --- a/torch_geometric/nn/models/basic_gnn.py +++ b/torch_geometric/nn/models/basic_gnn.py @@ -26,7 +26,7 @@ normalization_resolver, ) from torch_geometric.typing import Adj, OptTensor, SparseTensor -from torch_geometric.utils.trim_to_layer import TrimToLayer +from torch_geometric.utils._trim_to_layer import TrimToLayer class BasicGNN(torch.nn.Module): diff --git a/torch_geometric/nn/models/tgn.py b/torch_geometric/nn/models/tgn.py index f654f0b6c8b6..41e0aecf34bf 100644 --- a/torch_geometric/nn/models/tgn.py +++ b/torch_geometric/nn/models/tgn.py @@ -7,7 +7,7 @@ from torch_geometric.nn.inits import zeros from torch_geometric.utils import scatter -from torch_geometric.utils.scatter import scatter_argmax +from torch_geometric.utils._scatter import scatter_argmax TGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]] diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index 6f8a86af7119..ab81e8961c01 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -2,41 +2,41 @@ import copy -from .scatter import scatter, group_argsort -from .segment import segment -from .sort import index_sort +from ._scatter import scatter, group_argsort +from ._segment import segment +from ._index_sort import index_sort from .functions import cumsum -from .degree import degree -from .softmax import softmax -from .sort_edge_index import sort_edge_index -from .lexsort import lexsort -from .coalesce import coalesce +from ._degree import degree +from ._softmax import softmax +from ._sort_edge_index import sort_edge_index +from ._lexsort import lexsort +from ._coalesce import coalesce from .undirected import is_undirected, to_undirected from .loop import (contains_self_loops, remove_self_loops, segregate_self_loops, add_self_loops, add_remaining_self_loops, get_self_loop_attr) from .isolated import contains_isolated_nodes, 
remove_isolated_nodes -from .subgraph import (get_num_hops, subgraph, k_hop_subgraph, - bipartite_subgraph) +from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph, + bipartite_subgraph) from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path -from .homophily import homophily -from .assortativity import assortativity -from .get_laplacian import get_laplacian -from .get_mesh_laplacian import get_mesh_laplacian +from ._homophily import homophily +from ._assortativity import assortativity +from .laplacian import get_laplacian +from .mesh_laplacian import get_mesh_laplacian from .mask import mask_select, index_to_mask, mask_to_index -from .select import select, narrow -from .to_dense_batch import to_dense_batch -from .to_dense_adj import to_dense_adj +from ._select import select, narrow +from ._to_dense_batch import to_dense_batch +from ._to_dense_adj import to_dense_adj from .nested import to_nested_tensor, from_nested_tensor from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor, to_torch_coo_tensor, to_torch_csr_tensor, to_torch_csc_tensor, to_torch_sparse_tensor, to_edge_index) -from .spmm import spmm -from .unbatch import unbatch, unbatch_edge_index -from .one_hot import one_hot -from .normalized_cut import normalized_cut -from .grid import grid +from ._spmm import spmm +from ._unbatch import unbatch, unbatch_edge_index +from ._one_hot import one_hot +from ._normalized_cut import normalized_cut +from ._grid import grid from .geodesic import geodesic_distance from .convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix from .convert import to_networkx, from_networkx @@ -47,15 +47,15 @@ from .smiles import from_smiles, to_smiles from .random import (erdos_renyi_graph, stochastic_blockmodel_graph, barabasi_albert_graph) -from .negative_sampling import (negative_sampling, batched_negative_sampling, - structured_negative_sampling, - structured_negative_sampling_feasible) +from ._negative_sampling import (negative_sampling, batched_negative_sampling, + structured_negative_sampling, + structured_negative_sampling_feasible) from .augmentation import shuffle_node, mask_feature, add_random_edge -from .tree_decomposition import tree_decomposition +from ._tree_decomposition import tree_decomposition from .embedding import get_embeddings -from .trim_to_layer import trim_to_layer +from ._trim_to_layer import trim_to_layer from .ppr import get_ppr -from .train_test_split_edges import train_test_split_edges +from ._train_test_split_edges import train_test_split_edges __all__ = [ 'scatter', diff --git a/torch_geometric/utils/assortativity.py b/torch_geometric/utils/_assortativity.py similarity index 97% rename from torch_geometric/utils/assortativity.py rename to torch_geometric/utils/_assortativity.py index 4e982f6e6d9b..9ca09e620b12 100644 --- a/torch_geometric/utils/assortativity.py +++ b/torch_geometric/utils/_assortativity.py @@ -3,7 +3,7 @@ from torch_geometric.typing import Adj, SparseTensor from torch_geometric.utils import coalesce, degree -from torch_geometric.utils.to_dense_adj import to_dense_adj +from torch_geometric.utils._to_dense_adj import to_dense_adj def assortativity(edge_index: Adj) -> float: diff --git a/torch_geometric/utils/coalesce.py b/torch_geometric/utils/_coalesce.py similarity index 100% rename from torch_geometric/utils/coalesce.py rename to torch_geometric/utils/_coalesce.py diff --git a/torch_geometric/utils/degree.py b/torch_geometric/utils/_degree.py similarity index 100% rename from torch_geometric/utils/degree.py 
rename to torch_geometric/utils/_degree.py diff --git a/torch_geometric/utils/grid.py b/torch_geometric/utils/_grid.py similarity index 97% rename from torch_geometric/utils/grid.py rename to torch_geometric/utils/_grid.py index 6db4274cb870..7e6f26374418 100644 --- a/torch_geometric/utils/grid.py +++ b/torch_geometric/utils/_grid.py @@ -3,7 +3,7 @@ import torch from torch import Tensor -from torch_geometric.utils.coalesce import coalesce +from torch_geometric.utils import coalesce def grid( diff --git a/torch_geometric/utils/homophily.py b/torch_geometric/utils/_homophily.py similarity index 100% rename from torch_geometric/utils/homophily.py rename to torch_geometric/utils/_homophily.py diff --git a/torch_geometric/utils/sort.py b/torch_geometric/utils/_index_sort.py similarity index 100% rename from torch_geometric/utils/sort.py rename to torch_geometric/utils/_index_sort.py diff --git a/torch_geometric/utils/lexsort.py b/torch_geometric/utils/_lexsort.py similarity index 100% rename from torch_geometric/utils/lexsort.py rename to torch_geometric/utils/_lexsort.py diff --git a/torch_geometric/utils/negative_sampling.py b/torch_geometric/utils/_negative_sampling.py similarity index 100% rename from torch_geometric/utils/negative_sampling.py rename to torch_geometric/utils/_negative_sampling.py diff --git a/torch_geometric/utils/normalized_cut.py b/torch_geometric/utils/_normalized_cut.py similarity index 100% rename from torch_geometric/utils/normalized_cut.py rename to torch_geometric/utils/_normalized_cut.py diff --git a/torch_geometric/utils/one_hot.py b/torch_geometric/utils/_one_hot.py similarity index 100% rename from torch_geometric/utils/one_hot.py rename to torch_geometric/utils/_one_hot.py diff --git a/torch_geometric/utils/scatter.py b/torch_geometric/utils/_scatter.py similarity index 100% rename from torch_geometric/utils/scatter.py rename to torch_geometric/utils/_scatter.py diff --git a/torch_geometric/utils/segment.py b/torch_geometric/utils/_segment.py similarity index 100% rename from torch_geometric/utils/segment.py rename to torch_geometric/utils/_segment.py diff --git a/torch_geometric/utils/select.py b/torch_geometric/utils/_select.py similarity index 100% rename from torch_geometric/utils/select.py rename to torch_geometric/utils/_select.py diff --git a/torch_geometric/utils/softmax.py b/torch_geometric/utils/_softmax.py similarity index 100% rename from torch_geometric/utils/softmax.py rename to torch_geometric/utils/_softmax.py diff --git a/torch_geometric/utils/sort_edge_index.py b/torch_geometric/utils/_sort_edge_index.py similarity index 100% rename from torch_geometric/utils/sort_edge_index.py rename to torch_geometric/utils/_sort_edge_index.py diff --git a/torch_geometric/utils/spmm.py b/torch_geometric/utils/_spmm.py similarity index 100% rename from torch_geometric/utils/spmm.py rename to torch_geometric/utils/_spmm.py diff --git a/torch_geometric/utils/subgraph.py b/torch_geometric/utils/_subgraph.py similarity index 100% rename from torch_geometric/utils/subgraph.py rename to torch_geometric/utils/_subgraph.py diff --git a/torch_geometric/utils/to_dense_adj.py b/torch_geometric/utils/_to_dense_adj.py similarity index 100% rename from torch_geometric/utils/to_dense_adj.py rename to torch_geometric/utils/_to_dense_adj.py diff --git a/torch_geometric/utils/to_dense_batch.py b/torch_geometric/utils/_to_dense_batch.py similarity index 100% rename from torch_geometric/utils/to_dense_batch.py rename to torch_geometric/utils/_to_dense_batch.py diff --git 
a/torch_geometric/utils/train_test_split_edges.py b/torch_geometric/utils/_train_test_split_edges.py similarity index 100% rename from torch_geometric/utils/train_test_split_edges.py rename to torch_geometric/utils/_train_test_split_edges.py diff --git a/torch_geometric/utils/tree_decomposition.py b/torch_geometric/utils/_tree_decomposition.py similarity index 100% rename from torch_geometric/utils/tree_decomposition.py rename to torch_geometric/utils/_tree_decomposition.py diff --git a/torch_geometric/utils/trim_to_layer.py b/torch_geometric/utils/_trim_to_layer.py similarity index 100% rename from torch_geometric/utils/trim_to_layer.py rename to torch_geometric/utils/_trim_to_layer.py diff --git a/torch_geometric/utils/unbatch.py b/torch_geometric/utils/_unbatch.py similarity index 100% rename from torch_geometric/utils/unbatch.py rename to torch_geometric/utils/_unbatch.py diff --git a/torch_geometric/utils/hetero.py b/torch_geometric/utils/hetero.py index 1e9b8c21704e..60cabc5119e4 100644 --- a/torch_geometric/utils/hetero.py +++ b/torch_geometric/utils/hetero.py @@ -105,12 +105,12 @@ def construct_bipartite_edge_index( if edge_attr_dict is not None: if isinstance(edge_attr_dict, ParameterDict): - edge_attr = edge_attr_dict['__'.join(edge_type)] + value = edge_attr_dict['__'.join(edge_type)] else: - edge_attr = edge_attr_dict[edge_type] - if edge_attr.size(0) != edge_index.size(1): - edge_attr = edge_attr.expand(edge_index.size(1), -1) - edge_attrs.append(edge_attr) + value = edge_attr_dict[edge_type] + if value.size(0) != edge_index.size(1): + value = value.expand(edge_index.size(1), -1) + edge_attrs.append(value) edge_index = torch.cat(edge_indices, dim=1) diff --git a/torch_geometric/utils/isolated.py b/torch_geometric/utils/isolated.py index e204ca1f5782..5c9a03065768 100644 --- a/torch_geometric/utils/isolated.py +++ b/torch_geometric/utils/isolated.py @@ -76,7 +76,7 @@ def remove_isolated_nodes( mask[edge_index.view(-1)] = 1 assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device) - assoc[mask] = torch.arange(mask.sum(), device=assoc.device) + assoc[mask] = torch.arange(mask.sum(), device=assoc.device) # type: ignore edge_index = assoc[edge_index] loop_mask = torch.zeros_like(mask) diff --git a/torch_geometric/utils/get_laplacian.py b/torch_geometric/utils/laplacian.py similarity index 100% rename from torch_geometric/utils/get_laplacian.py rename to torch_geometric/utils/laplacian.py diff --git a/torch_geometric/utils/get_mesh_laplacian.py b/torch_geometric/utils/mesh_laplacian.py similarity index 100% rename from torch_geometric/utils/get_mesh_laplacian.py rename to torch_geometric/utils/mesh_laplacian.py From b454a51e47325c21cbefa54d611f7d800080925c Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Wed, 6 Dec 2023 23:52:38 +0100 Subject: [PATCH 4/5] `mypy` support in `torch_geometric.utils`[3/n] (#8556) --- torch_geometric/utils/_coalesce.py | 51 ++++++++++++++++++++++-------- torch_geometric/utils/_grid.py | 6 ++-- torch_geometric/utils/_subgraph.py | 30 ++++++++++++++---- 3 files changed, 64 insertions(+), 23 deletions(-) diff --git a/torch_geometric/utils/_coalesce.py b/torch_geometric/utils/_coalesce.py index 981c54e778d8..66efc8d5ca2f 100644 --- a/torch_geometric/utils/_coalesce.py +++ b/torch_geometric/utils/_coalesce.py @@ -1,3 +1,4 @@ +import typing from typing import List, Optional, Tuple, Union import torch @@ -7,27 +8,47 @@ from torch_geometric.utils import index_sort, scatter from torch_geometric.utils.num_nodes import maybe_num_nodes +if 
typing.TYPE_CHECKING: + from typing import overload +else: + from torch.jit import _overload as overload + MISSING = '???' -@torch.jit._overload -def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, str, Optional[int], str, bool, bool) -> Tensor +@overload +def coalesce( + edge_index: Tensor, + edge_attr: str = MISSING, + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tensor: pass -@torch.jit._overload +@overload def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, Optional[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, Optional[Tensor]] # noqa + edge_index: Tensor, + edge_attr: OptTensor, + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tuple[Tensor, OptTensor]: pass -@torch.jit._overload +@overload def coalesce( # noqa: F811 - edge_index, edge_attr, num_nodes, reduce, is_sorted, sort_by_row): - # type: (Tensor, List[Tensor], Optional[int], str, bool, bool) -> Tuple[Tensor, List[Tensor]] # noqa + edge_index: Tensor, + edge_attr: List[Tensor], + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tuple[Tensor, List[Tensor]]: pass @@ -35,7 +56,7 @@ def coalesce( # noqa: F811 edge_index: Tensor, edge_attr: Union[OptTensor, List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, - reduce: str = 'add', + reduce: str = 'sum', is_sorted: bool = False, sort_by_row: bool = True, ) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]: @@ -52,8 +73,8 @@ def coalesce( # noqa: F811 num_nodes (int, optional): The number of nodes, *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) reduce (str, optional): The reduce operation to use for merging edge - features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, - :obj:`"mul"`, :obj:`"any"`). (default: :obj:`"add"`) + features (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, + :obj:`"mul"`, :obj:`"any"`). (default: :obj:`"sum"`) is_sorted (bool, optional): If set to :obj:`True`, will expect :obj:`edge_index` to be already sorted row-wise. 
sort_by_row (bool, optional): If set to :obj:`False`, will sort @@ -117,7 +138,9 @@ def coalesce( # noqa: F811 # Only perform expensive merging in case there exists duplicates: if mask.all(): - if edge_attr is None or isinstance(edge_attr, (Tensor, list, tuple)): + if edge_attr is None or isinstance(edge_attr, Tensor): + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): return edge_index, edge_attr return edge_index diff --git a/torch_geometric/utils/_grid.py b/torch_geometric/utils/_grid.py index 7e6f26374418..4154da62dce7 100644 --- a/torch_geometric/utils/_grid.py +++ b/torch_geometric/utils/_grid.py @@ -49,8 +49,10 @@ def grid_index( ) -> Tensor: w = width - kernel = [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1] - kernel = torch.tensor(kernel, device=device) + kernel = torch.tensor( + [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1], + device=device, + ) row = torch.arange(height * width, dtype=torch.long, device=device) row = row.view(-1, 1).repeat(1, kernel.size(0)) diff --git a/torch_geometric/utils/_subgraph.py b/torch_geometric/utils/_subgraph.py index 711b1f666d5d..e0a03750cf1e 100644 --- a/torch_geometric/utils/_subgraph.py +++ b/torch_geometric/utils/_subgraph.py @@ -44,7 +44,6 @@ def subgraph( edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., - return_edge_mask: Literal[False] = ..., ) -> Tuple[Tensor, OptTensor]: pass @@ -56,7 +55,21 @@ def subgraph( edge_attr: OptTensor = ..., relabel_nodes: bool = ..., num_nodes: Optional[int] = ..., - return_edge_mask: Literal[True] = ..., + *, + return_edge_mask: Literal[False], +) -> Tuple[Tensor, OptTensor]: + pass + + +@overload +def subgraph( + subset: Union[Tensor, List[int]], + edge_index: Tensor, + edge_attr: OptTensor = ..., + relabel_nodes: bool = ..., + num_nodes: Optional[int] = ..., + *, + return_edge_mask: Literal[True], ) -> Tuple[Tensor, OptTensor, Tensor]: pass @@ -67,6 +80,7 @@ def subgraph( edge_attr: OptTensor = None, relabel_nodes: bool = False, num_nodes: Optional[int] = None, + *, return_edge_mask: bool = False, ) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]: r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)` @@ -314,8 +328,10 @@ def k_hop_subgraph( node_mask = row.new_empty(num_nodes, dtype=torch.bool) edge_mask = row.new_empty(row.size(0), dtype=torch.bool) - if isinstance(node_idx, (int, list, tuple)): - node_idx = torch.tensor([node_idx], device=row.device).flatten() + if isinstance(node_idx, int): + node_idx = torch.tensor([node_idx], device=row.device) + elif isinstance(node_idx, (list, tuple)): + node_idx = torch.tensor(node_idx, device=row.device) else: node_idx = node_idx.to(row.device) @@ -339,9 +355,9 @@ def k_hop_subgraph( edge_index = edge_index[:, edge_mask] if relabel_nodes: - node_idx = row.new_full((num_nodes, ), -1) - node_idx[subset] = torch.arange(subset.size(0), device=row.device) - edge_index = node_idx[edge_index] + mapping = row.new_full((num_nodes, ), -1) + mapping[subset] = torch.arange(subset.size(0), device=row.device) + edge_index = mapping[edge_index] return subset, edge_index, inv, edge_mask From 447f5a5e29bf965d62ba20a4a8eda181bc3342b8 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 7 Dec 2023 00:56:49 +0100 Subject: [PATCH 5/5] `mypy` support in `torch_geometric.utils` [4/n] (#8558) --- torch_geometric/utils/_coalesce.py | 12 +++++ torch_geometric/utils/_homophily.py | 30 +++++++++++-- torch_geometric/utils/_negative_sampling.py | 17 ++++--- torch_geometric/utils/_select.py | 2 
+- torch_geometric/utils/_sort_edge_index.py | 49 ++++++++++++++++----- torch_geometric/utils/_spmm.py | 28 ++++++++---- torch_geometric/utils/_to_dense_adj.py | 4 +- torch_geometric/utils/geodesic.py | 19 ++++---- torch_geometric/utils/num_nodes.py | 8 ++-- torch_geometric/utils/sparse.py | 20 ++++----- 10 files changed, 135 insertions(+), 54 deletions(-) diff --git a/torch_geometric/utils/_coalesce.py b/torch_geometric/utils/_coalesce.py index 66efc8d5ca2f..5b026b872478 100644 --- a/torch_geometric/utils/_coalesce.py +++ b/torch_geometric/utils/_coalesce.py @@ -28,6 +28,18 @@ def coalesce( pass +@overload +def coalesce( # noqa: F811 + edge_index: Tensor, + edge_attr: Tensor, + num_nodes: Optional[int] = None, + reduce: str = 'sum', + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Tuple[Tensor, Tensor]: + pass + + @overload def coalesce( # noqa: F811 edge_index: Tensor, diff --git a/torch_geometric/utils/_homophily.py b/torch_geometric/utils/_homophily.py index 83daf4bdae61..e65fb291a6f3 100644 --- a/torch_geometric/utils/_homophily.py +++ b/torch_geometric/utils/_homophily.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, overload import torch from torch import Tensor @@ -7,8 +7,32 @@ from torch_geometric.utils import degree, scatter -def homophily(edge_index: Adj, y: Tensor, batch: OptTensor = None, - method: str = 'edge') -> Union[float, Tensor]: +@overload +def homophily( + edge_index: Adj, + y: Tensor, + batch: None = ..., + method: str = ..., +) -> float: + pass + + +@overload +def homophily( + edge_index: Adj, + y: Tensor, + batch: Tensor, + method: str = ..., +) -> Tensor: + pass + + +def homophily( + edge_index: Adj, + y: Tensor, + batch: OptTensor = None, + method: str = 'edge', +) -> Union[float, Tensor]: r"""The homophily of a graph characterizes how likely nodes with the same label are near each other in a graph. diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 5185ef1ee6ee..b15e9c883bcd 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -55,11 +55,16 @@ def negative_sampling( """ assert method in ['sparse', 'dense'] - size = num_nodes - bipartite = isinstance(size, (tuple, list)) - size = maybe_num_nodes(edge_index) if size is None else size - size = (size, size) if not bipartite else size - force_undirected = False if bipartite else force_undirected + if num_nodes is None: + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + if isinstance(num_nodes, int): + size = (num_nodes, num_nodes) + bipartite = False + else: + size = num_nodes + bipartite = True + force_undirected = False idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected) @@ -95,7 +100,7 @@ def negative_sampling( idx = idx.to('cpu') for _ in range(3): # Number of tries to sample negative indices. rnd = sample(population, sample_size, device='cpu') - mask = np.isin(rnd, idx) + mask = np.isin(rnd.numpy(), idx.numpy()) if neg_idx is not None: mask |= np.isin(rnd, neg_idx.to('cpu')) mask = torch.from_numpy(mask).to(torch.bool) diff --git a/torch_geometric/utils/_select.py b/torch_geometric/utils/_select.py index fede7fb2f1f9..f6f9eaaeeec2 100644 --- a/torch_geometric/utils/_select.py +++ b/torch_geometric/utils/_select.py @@ -52,7 +52,7 @@ def narrow(src: Union[Tensor, List[Any]], dim: int, start: int, start (int): The starting dimension. length (int): The distance to the ending dimension. 
""" - if is_torch_sparse_tensor(src): + if isinstance(src, Tensor) and is_torch_sparse_tensor(src): # TODO Sparse tensors in `torch.sparse` do not yet support `narrow`. index = torch.arange(start, start + length, device=src.device) return src.index_select(dim, index) diff --git a/torch_geometric/utils/_sort_edge_index.py b/torch_geometric/utils/_sort_edge_index.py index 44568502f6a3..b7b233197299 100644 --- a/torch_geometric/utils/_sort_edge_index.py +++ b/torch_geometric/utils/_sort_edge_index.py @@ -1,34 +1,61 @@ +import typing from typing import List, Optional, Tuple, Union -import torch from torch import Tensor from torch_geometric.typing import OptTensor from torch_geometric.utils import index_sort from torch_geometric.utils.num_nodes import maybe_num_nodes +if typing.TYPE_CHECKING: + from typing import overload +else: + from torch.jit import _overload as overload + MISSING = '???' -@torch.jit._overload -def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa - # type: (Tensor, str, Optional[int], bool) -> Tensor +@overload +def sort_edge_index( + edge_index: Tensor, + edge_attr: str = MISSING, + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Tensor: + pass + + +@overload +def sort_edge_index( # noqa: F811 + edge_index: Tensor, + edge_attr: Tensor, + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Tuple[Tensor, Tensor]: pass -@torch.jit._overload -def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa - # type: (Tensor, Optional[Tensor], Optional[int], bool) -> Tuple[Tensor, Optional[Tensor]] # noqa +@overload +def sort_edge_index( # noqa: F811 + edge_index: Tensor, + edge_attr: OptTensor, + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Tuple[Tensor, OptTensor]: pass -@torch.jit._overload -def sort_edge_index(edge_index, edge_attr, num_nodes, sort_by_row): # noqa - # type: (Tensor, List[Tensor], Optional[int], bool) -> Tuple[Tensor, List[Tensor]] # noqa +@overload +def sort_edge_index( # noqa: F811 + edge_index: Tensor, + edge_attr: List[Tensor], + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Tuple[Tensor, List[Tensor]]: pass -def sort_edge_index( # noqa +def sort_edge_index( # noqa: F811 edge_index: Tensor, edge_attr: Union[OptTensor, List[Tensor], str] = MISSING, num_nodes: Optional[int] = None, diff --git a/torch_geometric/utils/_spmm.py b/torch_geometric/utils/_spmm.py index f7dd01305aed..c48d01e8b849 100644 --- a/torch_geometric/utils/_spmm.py +++ b/torch_geometric/utils/_spmm.py @@ -1,3 +1,4 @@ +import typing import warnings import torch @@ -7,25 +8,36 @@ from torch_geometric.typing import Adj, SparseTensor, torch_sparse from torch_geometric.utils import is_torch_sparse_tensor, scatter +if typing.TYPE_CHECKING: + from typing import overload +else: + from torch.jit import _overload as overload -@torch.jit._overload -def spmm(src, other, reduce): # noqa: F811 - # type: (Tensor, Tensor, str) -> Tensor + +@overload +def spmm( + src: Tensor, + other: Tensor, + reduce: str = 'sum', +) -> Tensor: pass -@torch.jit._overload -def spmm(src, other, reduce): # noqa: F811 - # type: (SparseTensor, Tensor, str) -> Tensor +@overload +def spmm( # noqa: F811 + src: SparseTensor, + other: Tensor, + reduce: str = 'sum', +) -> Tensor: pass def spmm( # noqa: F811 src: Adj, other: Tensor, - reduce: str = "sum", + reduce: str = 'sum', ) -> Tensor: - """Matrix product of sparse matrix with dense matrix. + r"""Matrix product of sparse matrix with dense matrix. 
Args: src (torch.Tensor or torch_sparse.SparseTensor): The input sparse diff --git a/torch_geometric/utils/_to_dense_adj.py b/torch_geometric/utils/_to_dense_adj.py index f8701fb7a83f..5ca5758956e5 100644 --- a/torch_geometric/utils/_to_dense_adj.py +++ b/torch_geometric/utils/_to_dense_adj.py @@ -61,8 +61,8 @@ def to_dense_adj( [5., 0.]]]) """ if batch is None: - num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 - batch = edge_index.new_zeros(num_nodes) + max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + batch = edge_index.new_zeros(max_index) if batch_size is None: batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 diff --git a/torch_geometric/utils/geodesic.py b/torch_geometric/utils/geodesic.py index 211e99b37f94..0b5e787d4f5a 100644 --- a/torch_geometric/utils/geodesic.py +++ b/torch_geometric/utils/geodesic.py @@ -2,7 +2,6 @@ import warnings from typing import Optional -import numpy as np import torch from torch import Tensor @@ -73,9 +72,9 @@ def geodesic_distance( # noqa: D417 if norm: area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]]) - norm = (area.norm(p=2, dim=1) / 2).sum().sqrt().item() + scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt()) else: - norm = 1.0 + scale = 1.0 dtype = pos.dtype @@ -84,20 +83,22 @@ def geodesic_distance( # noqa: D417 if src is None and dst is None: out = gdist.local_gdist_matrix(pos, face, - max_distance * norm).toarray() / norm + max_distance * scale).toarray() / scale return torch.from_numpy(out).to(dtype) if src is None: - src = np.arange(pos.shape[0], dtype=np.int32) + src = torch.arange(pos.shape[0], dtype=torch.int).numpy() else: src = src.detach().cpu().to(torch.int).numpy() + assert src is not None dst = None if dst is None else dst.detach().cpu().to(torch.int).numpy() - def _parallel_loop(pos, face, src, dst, max_distance, norm, i, dtype): + def _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype): s = src[i:i + 1] d = None if dst is None else dst[i:i + 1] - out = gdist.compute_gdist(pos, face, s, d, max_distance * norm) / norm + out = gdist.compute_gdist(pos, face, s, d, max_distance * scale) + out = out / scale return torch.from_numpy(out).to(dtype) num_workers = mp.cpu_count() if num_workers <= -1 else num_workers @@ -105,11 +106,11 @@ def _parallel_loop(pos, face, src, dst, max_distance, norm, i, dtype): with mp.Pool(num_workers) as pool: outs = pool.starmap( _parallel_loop, - [(pos, face, src, dst, max_distance, norm, i, dtype) + [(pos, face, src, dst, max_distance, scale, i, dtype) for i in range(len(src))]) else: outs = [ - _parallel_loop(pos, face, src, dst, max_distance, norm, i, dtype) + _parallel_loop(pos, face, src, dst, max_distance, scale, i, dtype) for i in range(len(src)) ] diff --git a/torch_geometric/utils/num_nodes.py b/torch_geometric/utils/num_nodes.py index 35b892dcfb81..1bc337e7d294 100644 --- a/torch_geometric/utils/num_nodes.py +++ b/torch_geometric/utils/num_nodes.py @@ -17,7 +17,7 @@ @overload def maybe_num_nodes( edge_index: Tensor, - num_nodes: Optional[int], + num_nodes: Optional[int] = None, ) -> int: pass @@ -25,15 +25,15 @@ def maybe_num_nodes( @overload def maybe_num_nodes( # noqa: F811 edge_index: Tuple[Tensor, Tensor], - num_nodes: Optional[int], + num_nodes: Optional[int] = None, ) -> int: pass @overload def maybe_num_nodes( # noqa: F811 - edge_index: SparseTensor, - num_nodes: Optional[int], + edge_index: SparseTensor, + num_nodes: Optional[int] = None, ) -> int: pass diff --git 
a/torch_geometric/utils/sparse.py b/torch_geometric/utils/sparse.py index 253a511a55f0..a590817b291f 100644 --- a/torch_geometric/utils/sparse.py +++ b/torch_geometric/utils/sparse.py @@ -1,3 +1,4 @@ +import typing import warnings from typing import Any, List, Optional, Tuple, Union @@ -303,6 +304,8 @@ def to_torch_csc_tensor( """ if not torch_geometric.typing.WITH_PT112: + if typing.TYPE_CHECKING: + raise NotImplementedError return torch_geometric.typing.MockTorchCSCTensor( edge_index, edge_attr, size) @@ -434,10 +437,10 @@ def get_sparse_diag( def set_sparse_value(adj: Tensor, value: Tensor) -> Tensor: - size = adj.size() - if value.dim() > 1: - size = size + value.size()[1:] + size = adj.size() + value.size()[1:] + else: + size = adj.size() if adj.layout == torch.sparse_coo: return torch.sparse_coo_tensor( @@ -483,7 +486,7 @@ def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: # the individual sparse tensor layouts. assert dim in {0, 1, (0, 1)} - size = [0, 0] + size = (0, 0) edge_indices = [] edge_attrs = [] for tensor in tensors: @@ -493,17 +496,14 @@ def cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor: if dim == 0: edge_index[0] += size[0] - size[0] += tensor.size(0) - size[1] = max(size[1], tensor.size(1)) + size = (size[0] + tensor.size(0), max(size[1], tensor.size(1))) elif dim == 1: edge_index[1] += size[1] - size[0] = max(size[0], tensor.size(0)) - size[1] += tensor.size(1) + size = (max(size[0], tensor.size(0)), size[1] + tensor.size(1)) else: edge_index[0] += size[0] edge_index[1] += size[1] - size[0] += tensor.size(0) - size[1] += tensor.size(1) + size = (size[0] + tensor.size(0), size[1] + tensor.size(1)) edge_indices.append(edge_index) edge_attrs.append(edge_attr)
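A recurring pattern across these patches is the conditional import that gives mypy real `typing.overload` signatures while keeping `torch.jit._overload` available for TorchScript at runtime. The following is a minimal, self-contained sketch of that pattern under stated assumptions; the function name `pair_or_single` is made up for illustration and is not part of PyG.

import typing
from typing import Optional, Tuple, Union

from torch import Tensor

if typing.TYPE_CHECKING:
    from typing import overload  # what mypy sees
else:
    from torch.jit import _overload as overload  # what TorchScript sees


@overload
def pair_or_single(x: Tensor, y: None = None) -> Tensor:
    pass


@overload
def pair_or_single(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:  # noqa: F811
    pass


def pair_or_single(  # noqa: F811
    x: Tensor,
    y: Optional[Tensor] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    # Callers passing `y=None` get a plain Tensor back; callers passing a
    # Tensor get a tuple. The overloads above let mypy narrow the return
    # type per call site, mirroring `coalesce` and `sort_edge_index` above.
    if y is None:
        return x
    return x, y

For flags that change the return arity, such as `subgraph`'s `return_edge_mask`, the series instead uses plain `typing.overload` with `Literal[True]`/`Literal[False]` annotations, as shown in the `_subgraph.py` hunks above.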