-
Notifications
You must be signed in to change notification settings - Fork 311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[cugraph-pyg] Add TransformerConv and support for bipartite node features in GATConv #3489
Merged
Merged
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
b29edcc
set up nn conv
tingyu66 f687207
test backward accuracy
tingyu66 30883e1
test bipartite variant
tingyu66 082d69f
add self connection
tingyu66 d20e713
docstring
tingyu66 8f46dfc
support bipartite input features in GATConv
tingyu66 6f11cb8
Merge branch 'branch-23.06' into transformer_conv
tingyu66 389ef2d
fix edge feat computation
tingyu66 183416a
update test script
tingyu66 5bfa7c5
Merge branch 'branch-23.06' into transformer_conv
tingyu66 dee2fe5
add option to use edge feat in GATConv test
tingyu66 b05a2dc
Merge branch 'branch-23.06' into transformer_conv
tingyu66 72d145b
fix gatconv test
tingyu66 6e8e92b
add GATv2Conv
tingyu66 aba9237
Update ci/test_python.sh
tingyu66 00fc751
use pyg linear operator and init functions
tingyu66 6a79b17
assert linear operators in gatconv
tingyu66 0bdb03c
Merge branch 'branch-23.06' into transformer_conv
tingyu66 0a85d81
Merge branch 'branch-23.06' into transformer_conv
tingyu66 40df1b4
Merge branch 'branch-23.06' into transformer_conv
tingyu66 6044257
add back pylibcugraphops for cugraph_pyg testing
tingyu66 39954a2
Merge branch 'branch-23.06' into transformer_conv
tingyu66 f5a712e
remove assertions, raise RuntimeError instead
tingyu66 4ee2e9e
Merge branch 'branch-23.06' into transformer_conv
tingyu66 94579a8
empty commit
tingyu66 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .conv import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .gat_conv import GATConv | ||
from .gatv2_conv import GATv2Conv | ||
from .transformer_conv import TransformerConv | ||
|
||
__all__ = [ | ||
"GATConv", | ||
"GATv2Conv", | ||
"TransformerConv", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import warnings | ||
from typing import Any, Optional, Tuple, Union | ||
|
||
from cugraph.utilities.utils import import_optional | ||
|
||
torch = import_optional("torch") | ||
torch_geometric = import_optional("torch_geometric") | ||
|
||
try: # pragma: no cover | ||
from pylibcugraphops.pytorch import ( | ||
BipartiteCSC, | ||
SampledCSC, | ||
SampledHeteroCSC, | ||
StaticCSC, | ||
StaticHeteroCSC, | ||
) | ||
|
||
HAS_PYLIBCUGRAPHOPS = True | ||
except ImportError: | ||
HAS_PYLIBCUGRAPHOPS = False | ||
|
||
|
||
class BaseConv(torch.nn.Module): # pragma: no cover | ||
r"""An abstract base class for implementing cugraph-ops message passing layers.""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
if HAS_PYLIBCUGRAPHOPS is False: | ||
raise ModuleNotFoundError( | ||
f"'{self.__class__.__name__}' requires " f"'pylibcugraphops>=23.04'" | ||
) | ||
|
||
def reset_parameters(self): | ||
r"""Resets all learnable parameters of the module.""" | ||
pass | ||
|
||
@staticmethod | ||
def to_csc( | ||
edge_index: torch.Tensor, | ||
size: Optional[Tuple[int, int]] = None, | ||
edge_attr: Optional[torch.Tensor] = None, | ||
) -> Union[ | ||
Tuple[torch.Tensor, torch.Tensor, int], | ||
Tuple[Tuple[torch.Tensor, torch.Tensor, int], torch.Tensor], | ||
]: | ||
r"""Returns a CSC representation of an :obj:`edge_index` tensor to be | ||
used as input to cugraph-ops conv layers. | ||
|
||
Args: | ||
edge_index (torch.Tensor): The edge indices. | ||
size ((int, int), optional). The shape of :obj:`edge_index` in each | ||
dimension. (default: :obj:`None`) | ||
edge_attr (torch.Tensor, optional): The edge features. | ||
(default: :obj:`None`) | ||
""" | ||
if size is None: | ||
warnings.warn( | ||
f"Inferring the graph size from 'edge_index' causes " | ||
f"a decline in performance and does not work for " | ||
f"bipartite graphs. To suppress this warning, pass " | ||
f"the 'size' explicitly in '{__name__}.to_csc()'." | ||
) | ||
num_src_nodes = num_dst_nodes = int(edge_index.max()) + 1 | ||
else: | ||
num_src_nodes, num_dst_nodes = size | ||
|
||
row, col = edge_index | ||
col, perm = torch_geometric.utils.index_sort(col, max_value=num_dst_nodes) | ||
row = row[perm] | ||
|
||
colptr = torch_geometric.utils.sparse.index2ptr(col, num_dst_nodes) | ||
|
||
if edge_attr is not None: | ||
return (row, colptr, num_src_nodes), edge_attr[perm] | ||
|
||
return row, colptr, num_src_nodes | ||
|
||
def get_cugraph( | ||
self, | ||
csc: Tuple[torch.Tensor, torch.Tensor, int], | ||
bipartite: bool = False, | ||
max_num_neighbors: Optional[int] = None, | ||
) -> Any: | ||
r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation. | ||
Supports both bipartite and non-bipartite graphs. | ||
|
||
Args: | ||
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC | ||
representation of a graph, given as a tuple of | ||
:obj:`(row, colptr, num_src_nodes)`. Use the | ||
:meth:`to_csc` method to convert an :obj:`edge_index` | ||
representation to the desired format. | ||
bipartite (bool): If set to :obj:`True`, will create the bipartite | ||
structure in cugraph-ops. (default: :obj:`False`) | ||
max_num_neighbors (int, optional): The maximum number of neighbors | ||
of a target node. It is only effective when operating in a | ||
bipartite graph. When not given, will be computed on-the-fly, | ||
leading to slightly worse performance. (default: :obj:`None`) | ||
""" | ||
row, colptr, num_src_nodes = csc | ||
|
||
if not row.is_cuda: | ||
raise RuntimeError( | ||
f"'{self.__class__.__name__}' requires GPU-" | ||
f"based processing (got CPU tensor)" | ||
) | ||
|
||
if bipartite: | ||
return BipartiteCSC(colptr, row, num_src_nodes) | ||
|
||
if num_src_nodes != colptr.numel() - 1: | ||
if max_num_neighbors is None: | ||
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) | ||
|
||
return SampledCSC(colptr, row, max_num_neighbors, num_src_nodes) | ||
|
||
return StaticCSC(colptr, row) | ||
|
||
def get_typed_cugraph( | ||
self, | ||
csc: Tuple[torch.Tensor, torch.Tensor, int], | ||
edge_type: torch.Tensor, | ||
num_edge_types: Optional[int] = None, | ||
bipartite: bool = False, | ||
max_num_neighbors: Optional[int] = None, | ||
) -> Any: | ||
r"""Constructs a typed :obj:`cugraph` graph object from a CSC | ||
representation where each edge corresponds to a given edge type. | ||
Supports both bipartite and non-bipartite graphs. | ||
|
||
Args: | ||
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC | ||
representation of a graph, given as a tuple of | ||
:obj:`(row, colptr, num_src_nodes)`. Use the | ||
:meth:`to_csc` method to convert an :obj:`edge_index` | ||
representation to the desired format. | ||
edge_type (torch.Tensor): The edge type. | ||
num_edge_types (int, optional): The maximum number of edge types. | ||
When not given, will be computed on-the-fly, leading to | ||
slightly worse performance. (default: :obj:`None`) | ||
bipartite (bool): If set to :obj:`True`, will create the bipartite | ||
structure in cugraph-ops. (default: :obj:`False`) | ||
max_num_neighbors (int, optional): The maximum number of neighbors | ||
of a target node. It is only effective when operating in a | ||
bipartite graph. When not given, will be computed on-the-fly, | ||
leading to slightly worse performance. (default: :obj:`None`) | ||
""" | ||
if num_edge_types is None: | ||
num_edge_types = int(edge_type.max()) + 1 | ||
|
||
row, colptr, num_src_nodes = csc | ||
edge_type = edge_type.int() | ||
|
||
if bipartite: | ||
raise NotImplementedError | ||
|
||
if num_src_nodes != colptr.numel() - 1: | ||
if max_num_neighbors is None: | ||
max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) | ||
|
||
return SampledHeteroCSC( | ||
colptr, row, edge_type, max_num_neighbors, num_src_nodes, num_edge_types | ||
) | ||
|
||
return StaticHeteroCSC(colptr, row, edge_type, num_edge_types) | ||
|
||
def forward( | ||
self, | ||
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], | ||
csc: Tuple[torch.Tensor, torch.Tensor, int], | ||
) -> torch.Tensor: | ||
r"""Runs the forward pass of the module. | ||
|
||
Args: | ||
x (torch.Tensor): The node features. | ||
csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC | ||
representation of a graph, given as a tuple of | ||
:obj:`(row, colptr, num_src_nodes)`. Use the | ||
:meth:`to_csc` method to convert an :obj:`edge_index` | ||
representation to the desired format. | ||
""" | ||
raise NotImplementedError |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should also actually check for >= 23.04
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We did check this implicitly via
as these structures are not available <23.04.
Do you prefer to explicitly check the version?