[cugraph-dgl] Add TransformerConv (#3501)
tingyu66 authored May 18, 2023
1 parent aad0a0f commit 649edb6
Showing 3 changed files with 254 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py
@@ -13,9 +13,11 @@
from .gatconv import GATConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .transformerconv import TransformerConv

__all__ = [
"GATConv",
"RelGraphConv",
"SAGEConv",
"TransformerConv",
]
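
With this export in place, the new layer is importable from the package namespace alongside the existing convolutions (a quick check, not part of the diff):

from cugraph_dgl.nn import GATConv, RelGraphConv, SAGEConv, TransformerConv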
176 changes: 176 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py
@@ -0,0 +1,176 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph.utilities.utils import import_optional

from pylibcugraphops.pytorch import BipartiteCSC, StaticCSC
from pylibcugraphops.pytorch.operators import mha_simple_n2n

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")


class TransformerConv(BaseConv):
r"""The graph transformer layer from the `"Masked Label Prediction:
Unified Message Passing Model for Semi-Supervised Classification"
<https://arxiv.org/abs/2009.03509>`_ paper.

    Parameters
----------
in_node_feats : int or pair of ints
Input feature size. A pair denotes feature sizes of source and
destination nodes.
out_node_feats : int
Output feature size.
num_heads : int
Number of multi-head-attentions.
concat : bool, optional
If False, the multi-head attentions are averaged instead of concatenated.
Default: ``True``.
    beta : bool, optional
        If True, use a gated residual connection. Default: ``False``.
edge_feats: int, optional
Edge feature size. Default: ``None``.
bias: bool, optional
If True, learns a bias term. Default: ``True``.
    root_weight: bool, optional
        If False, the layer will not learn an additional root weight matrix
        for the destination-node (skip) connection. Default: ``True``.
"""

def __init__(
self,
in_node_feats: Union[int, Tuple[int, int]],
out_node_feats: int,
num_heads: int,
concat: bool = True,
beta: bool = False,
edge_feats: Optional[int] = None,
bias: bool = True,
root_weight: bool = True,
):
super().__init__()

self.in_node_feats = in_node_feats
self.out_node_feats = out_node_feats
self.num_heads = num_heads
self.concat = concat
self.beta = beta
self.edge_feats = edge_feats
self.bias = bias
self.root_weight = root_weight

if isinstance(in_node_feats, int):
in_node_feats = (in_node_feats, in_node_feats)

self.lin_key = nn.Linear(in_node_feats[0], num_heads * out_node_feats)
self.lin_query = nn.Linear(in_node_feats[1], num_heads * out_node_feats)
self.lin_value = nn.Linear(in_node_feats[0], num_heads * out_node_feats)

if edge_feats is not None:
self.lin_edge = nn.Linear(
edge_feats, num_heads * out_node_feats, bias=False
)
else:
self.lin_edge = self.register_parameter("lin_edge", None)

if concat:
self.lin_skip = nn.Linear(
in_node_feats[1], num_heads * out_node_feats, bias=bias
)
if self.beta:
self.lin_beta = nn.Linear(3 * num_heads * out_node_feats, 1, bias=bias)
else:
self.lin_beta = self.register_parameter("lin_beta", None)
else:
self.lin_skip = nn.Linear(in_node_feats[1], out_node_feats, bias=bias)
if self.beta:
self.lin_beta = nn.Linear(3 * out_node_feats, 1, bias=False)
else:
self.lin_beta = self.register_parameter("lin_beta", None)

self.reset_parameters()

def reset_parameters(self):
self.lin_key.reset_parameters()
self.lin_query.reset_parameters()
self.lin_value.reset_parameters()
if self.lin_edge is not None:
self.lin_edge.reset_parameters()
if self.lin_skip is not None:
self.lin_skip.reset_parameters()
if self.lin_beta is not None:
self.lin_beta.reset_parameters()

def forward(
self,
g: dgl.DGLHeteroGraph,
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
efeat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward computation.

        Parameters
----------
g: DGLGraph
The graph.
nfeat: torch.Tensor or a pair of torch.Tensor
Node feature tensor. A pair denotes features for source and
destination nodes, respectively.
efeat: torch.Tensor, optional
Edge feature tensor. Default: ``None``.
"""
bipartite = not isinstance(nfeat, torch.Tensor)
offsets, indices, _ = g.adj_tensors("csc")

if bipartite:
src_feats, dst_feats = nfeat
_graph = BipartiteCSC(
offsets=offsets, indices=indices, num_src_nodes=g.num_src_nodes()
)
else:
src_feats = dst_feats = nfeat
if g.is_block:
offsets = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = StaticCSC(offsets=offsets, indices=indices)

query = self.lin_query(dst_feats)
key = self.lin_key(src_feats)
value = self.lin_value(src_feats)
if self.lin_edge is not None:
efeat = self.lin_edge(efeat)

out = mha_simple_n2n(
key_emb=key,
query_emb=query,
value_emb=value,
graph=_graph,
num_heads=self.num_heads,
concat_heads=self.concat,
edge_emb=efeat,
norm_by_dim=True,
score_bias=None,
)[: g.num_dst_nodes()]

if self.root_weight:
res = self.lin_skip(dst_feats[: g.num_dst_nodes()])
if self.lin_beta is not None:
beta = self.lin_beta(torch.cat([out, res, out - res], dim=-1))
beta = beta.sigmoid()
out = beta * res + (1 - beta) * out
else:
out = out + res

return out
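
For context, a minimal usage sketch of the layer on a homogeneous (non-bipartite) graph, mirroring the call pattern exercised by the test below. The toy graph and feature sizes are illustrative only, and a CUDA device is required by the underlying cugraph-ops kernels:

import dgl
import torch
from cugraph_dgl.nn import TransformerConv

device = "cuda"

# Small toy graph moved to the GPU.
src = torch.tensor([0, 1, 2, 3, 2])
dst = torch.tensor([1, 2, 3, 0, 0])
g = dgl.graph((src, dst), num_nodes=4).to(device)

conv = TransformerConv(
    in_node_feats=8,
    out_node_feats=4,
    num_heads=2,
    concat=True,   # output size is num_heads * out_node_feats
    edge_feats=3,  # enables the edge-feature projection
).to(device)

nfeat = torch.rand(g.num_nodes(), 8, device=device)
efeat = torch.rand(g.num_edges(), 3, device=device)

out = conv(g, nfeat, efeat)
print(out.shape)  # torch.Size([4, 8]) -- (num_dst_nodes, num_heads * out_node_feats)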
76 changes: 76 additions & 0 deletions python/cugraph-dgl/tests/nn/test_transformerconv.py
@@ -0,0 +1,76 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

try:
from cugraph_dgl.nn import TransformerConv
except ModuleNotFoundError:
pytest.skip("cugraph_dgl not available", allow_module_level=True)

from cugraph.utilities.utils import import_optional
from .common import create_graph1

torch = import_optional("torch")
dgl = import_optional("dgl")


@pytest.mark.parametrize("beta", [False, True])
@pytest.mark.parametrize("bipartite", [False, True])
@pytest.mark.parametrize("concat", [False, True])
@pytest.mark.parametrize("idtype_int", [False, True])
@pytest.mark.parametrize("num_heads", [1, 2, 3, 4])
@pytest.mark.parametrize("to_block", [False, True])
@pytest.mark.parametrize("use_edge_feats", [False, True])
def test_TransformerConv(
beta, bipartite, concat, idtype_int, num_heads, to_block, use_edge_feats
):
device = "cuda"
g = create_graph1().to(device)

if idtype_int:
g = g.int()

if to_block:
g = dgl.to_block(g)

if bipartite:
in_node_feats = (5, 3)
nfeat = (
torch.rand(g.num_src_nodes(), in_node_feats[0], device=device),
torch.rand(g.num_dst_nodes(), in_node_feats[1], device=device),
)
else:
in_node_feats = 3
nfeat = torch.rand(g.num_src_nodes(), in_node_feats, device=device)
out_node_feats = 2

if use_edge_feats:
edge_feats = 3
efeat = torch.rand(g.num_edges(), edge_feats, device=device)
else:
edge_feats = None
efeat = None

conv = TransformerConv(
in_node_feats,
out_node_feats,
num_heads=num_heads,
concat=concat,
beta=beta,
edge_feats=edge_feats,
).to(device)

out = conv(g, nfeat, efeat)
grad_out = torch.rand_like(out)
out.backward(grad_out)
