-
Notifications
You must be signed in to change notification settings - Fork 311
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cugraph-dgl] Add TransformerConv (#3501)
Fixes rapidsai/graph_dl#181 Authors: - Tingyu Wang (https://github.com/tingyu66) - Brad Rees (https://github.com/BradReesWork) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Vibhu Jawa (https://github.com/VibhuJawa) URL: #3501
- Loading branch information
Showing
3 changed files
with
254 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
176 changes: 176 additions & 0 deletions
176
python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional, Tuple, Union | ||
|
||
from cugraph_dgl.nn.conv.base import BaseConv | ||
from cugraph.utilities.utils import import_optional | ||
|
||
from pylibcugraphops.pytorch import BipartiteCSC, StaticCSC | ||
from pylibcugraphops.pytorch.operators import mha_simple_n2n | ||
|
||
dgl = import_optional("dgl") | ||
torch = import_optional("torch") | ||
nn = import_optional("torch.nn") | ||
|
||
|
||
class TransformerConv(BaseConv): | ||
r"""The graph transformer layer from the `"Masked Label Prediction: | ||
Unified Message Passing Model for Semi-Supervised Classification" | ||
<https://arxiv.org/abs/2009.03509>`_ paper. | ||
Parameters | ||
---------- | ||
in_node_feats : int or pair of ints | ||
Input feature size. A pair denotes feature sizes of source and | ||
destination nodes. | ||
out_node_feats : int | ||
Output feature size. | ||
num_heads : int | ||
Number of multi-head-attentions. | ||
concat : bool, optional | ||
If False, the multi-head attentions are averaged instead of concatenated. | ||
Default: ``True``. | ||
beta : bool, optional | ||
If True, use a gated residual connection. Default: ``True``. | ||
edge_feats: int, optional | ||
Edge feature size. Default: ``None``. | ||
bias: bool, optional | ||
If True, learns a bias term. Default: ``True``. | ||
root_weight: bool, optional | ||
If False, will skip to learn a root weight matrix. Default: ``True``. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
in_node_feats: Union[int, Tuple[int, int]], | ||
out_node_feats: int, | ||
num_heads: int, | ||
concat: bool = True, | ||
beta: bool = False, | ||
edge_feats: Optional[int] = None, | ||
bias: bool = True, | ||
root_weight: bool = True, | ||
): | ||
super().__init__() | ||
|
||
self.in_node_feats = in_node_feats | ||
self.out_node_feats = out_node_feats | ||
self.num_heads = num_heads | ||
self.concat = concat | ||
self.beta = beta | ||
self.edge_feats = edge_feats | ||
self.bias = bias | ||
self.root_weight = root_weight | ||
|
||
if isinstance(in_node_feats, int): | ||
in_node_feats = (in_node_feats, in_node_feats) | ||
|
||
self.lin_key = nn.Linear(in_node_feats[0], num_heads * out_node_feats) | ||
self.lin_query = nn.Linear(in_node_feats[1], num_heads * out_node_feats) | ||
self.lin_value = nn.Linear(in_node_feats[0], num_heads * out_node_feats) | ||
|
||
if edge_feats is not None: | ||
self.lin_edge = nn.Linear( | ||
edge_feats, num_heads * out_node_feats, bias=False | ||
) | ||
else: | ||
self.lin_edge = self.register_parameter("lin_edge", None) | ||
|
||
if concat: | ||
self.lin_skip = nn.Linear( | ||
in_node_feats[1], num_heads * out_node_feats, bias=bias | ||
) | ||
if self.beta: | ||
self.lin_beta = nn.Linear(3 * num_heads * out_node_feats, 1, bias=bias) | ||
else: | ||
self.lin_beta = self.register_parameter("lin_beta", None) | ||
else: | ||
self.lin_skip = nn.Linear(in_node_feats[1], out_node_feats, bias=bias) | ||
if self.beta: | ||
self.lin_beta = nn.Linear(3 * out_node_feats, 1, bias=False) | ||
else: | ||
self.lin_beta = self.register_parameter("lin_beta", None) | ||
|
||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
self.lin_key.reset_parameters() | ||
self.lin_query.reset_parameters() | ||
self.lin_value.reset_parameters() | ||
if self.lin_edge is not None: | ||
self.lin_edge.reset_parameters() | ||
if self.lin_skip is not None: | ||
self.lin_skip.reset_parameters() | ||
if self.lin_beta is not None: | ||
self.lin_beta.reset_parameters() | ||
|
||
def forward( | ||
self, | ||
g: dgl.DGLHeteroGraph, | ||
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], | ||
efeat: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
"""Forward computation. | ||
Parameters | ||
---------- | ||
g: DGLGraph | ||
The graph. | ||
nfeat: torch.Tensor or a pair of torch.Tensor | ||
Node feature tensor. A pair denotes features for source and | ||
destination nodes, respectively. | ||
efeat: torch.Tensor, optional | ||
Edge feature tensor. Default: ``None``. | ||
""" | ||
bipartite = not isinstance(nfeat, torch.Tensor) | ||
offsets, indices, _ = g.adj_tensors("csc") | ||
|
||
if bipartite: | ||
src_feats, dst_feats = nfeat | ||
_graph = BipartiteCSC( | ||
offsets=offsets, indices=indices, num_src_nodes=g.num_src_nodes() | ||
) | ||
else: | ||
src_feats = dst_feats = nfeat | ||
if g.is_block: | ||
offsets = self.pad_offsets(offsets, g.num_src_nodes() + 1) | ||
_graph = StaticCSC(offsets=offsets, indices=indices) | ||
|
||
query = self.lin_query(dst_feats) | ||
key = self.lin_key(src_feats) | ||
value = self.lin_value(src_feats) | ||
if self.lin_edge is not None: | ||
efeat = self.lin_edge(efeat) | ||
|
||
out = mha_simple_n2n( | ||
key_emb=key, | ||
query_emb=query, | ||
value_emb=value, | ||
graph=_graph, | ||
num_heads=self.num_heads, | ||
concat_heads=self.concat, | ||
edge_emb=efeat, | ||
norm_by_dim=True, | ||
score_bias=None, | ||
)[: g.num_dst_nodes()] | ||
|
||
if self.root_weight: | ||
res = self.lin_skip(dst_feats[: g.num_dst_nodes()]) | ||
if self.lin_beta is not None: | ||
beta = self.lin_beta(torch.cat([out, res, out - res], dim=-1)) | ||
beta = beta.sigmoid() | ||
out = beta * res + (1 - beta) * out | ||
else: | ||
out = out + res | ||
|
||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
|
||
try: | ||
from cugraph_dgl.nn import TransformerConv | ||
except ModuleNotFoundError: | ||
pytest.skip("cugraph_dgl not available", allow_module_level=True) | ||
|
||
from cugraph.utilities.utils import import_optional | ||
from .common import create_graph1 | ||
|
||
torch = import_optional("torch") | ||
dgl = import_optional("dgl") | ||
|
||
|
||
@pytest.mark.parametrize("beta", [False, True]) | ||
@pytest.mark.parametrize("bipartite", [False, True]) | ||
@pytest.mark.parametrize("concat", [False, True]) | ||
@pytest.mark.parametrize("idtype_int", [False, True]) | ||
@pytest.mark.parametrize("num_heads", [1, 2, 3, 4]) | ||
@pytest.mark.parametrize("to_block", [False, True]) | ||
@pytest.mark.parametrize("use_edge_feats", [False, True]) | ||
def test_TransformerConv( | ||
beta, bipartite, concat, idtype_int, num_heads, to_block, use_edge_feats | ||
): | ||
device = "cuda" | ||
g = create_graph1().to(device) | ||
|
||
if idtype_int: | ||
g = g.int() | ||
|
||
if to_block: | ||
g = dgl.to_block(g) | ||
|
||
if bipartite: | ||
in_node_feats = (5, 3) | ||
nfeat = ( | ||
torch.rand(g.num_src_nodes(), in_node_feats[0], device=device), | ||
torch.rand(g.num_dst_nodes(), in_node_feats[1], device=device), | ||
) | ||
else: | ||
in_node_feats = 3 | ||
nfeat = torch.rand(g.num_src_nodes(), in_node_feats, device=device) | ||
out_node_feats = 2 | ||
|
||
if use_edge_feats: | ||
edge_feats = 3 | ||
efeat = torch.rand(g.num_edges(), edge_feats, device=device) | ||
else: | ||
edge_feats = None | ||
efeat = None | ||
|
||
conv = TransformerConv( | ||
in_node_feats, | ||
out_node_feats, | ||
num_heads=num_heads, | ||
concat=concat, | ||
beta=beta, | ||
edge_feats=edge_feats, | ||
).to(device) | ||
|
||
out = conv(g, nfeat, efeat) | ||
grad_out = torch.rand_like(out) | ||
out.backward(grad_out) |