Add flexible layer kernels to GNN and GraphTransformer
jakob-schloer committed Dec 18, 2024
1 parent 2621880 commit 4e35033
Showing 6 changed files with 55 additions and 24 deletions.
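The commit threads a layer_kernels mapping (layer name to implementation) through the GNN and GraphTransformer stacks, as the docstring additions below describe: layer_kernels['Linear'] = "torch.nn.Linear", defined in config/models/<model>.yaml. The following minimal sketch is not part of the commit; it only illustrates the idea, and the resolve_kernel helper for turning dotted-path strings into classes is hypothetical.

import importlib

import torch


def resolve_kernel(spec, default):
    # Hypothetical helper: accept a class, a dotted-path string, or None.
    if spec is None:
        return default
    if isinstance(spec, str):
        module_name, _, class_name = spec.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)
    return spec


layer_kernels = {"Linear": "torch.nn.Linear", "LayerNorm": "torch.nn.LayerNorm"}
Linear = resolve_kernel(layer_kernels.get("Linear"), torch.nn.Linear)
LayerNorm = resolve_kernel(layer_kernels.get("LayerNorm"), torch.nn.LayerNorm)
layer = Linear(128, 256)  # behaves exactly like torch.nn.Linear here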
18 changes: 11 additions & 7 deletions src/anemoi/models/layers/block.py
@@ -159,6 +159,7 @@ def __init__(
self.conv = GraphConv(
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -192,11 +193,11 @@ def __init__(
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
num_chunks=num_chunks,
layer_kernels=layer_kernels,
**kwargs,
)

@@ -250,11 +251,11 @@ def __init__(
self,
in_channels=in_channels,
out_channels=out_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=update_src_nodes,
num_chunks=num_chunks,
layer_kernels=layer_kernels,
**kwargs,
)

@@ -365,18 +366,19 @@ def __init__(
LOGGER.error("Activation function %s not supported", activation)
raise RuntimeError from ae

self.layer_norm_attention = layerNorm(normalized_shape=in_channels)
self.layer_norm_mlp = layerNorm(normalized_shape=out_channels)

self.node_dst_mlp = nn.Sequential(
layerNorm(normalized_shape=out_channels),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
linear(hidden_dim, out_channels),
)

self.layer_norm_attention = layerNorm(normalized_shape=in_channels)

if self.update_src_nodes:
self.node_src_mlp = nn.Sequential(
layerNorm(normalized_shape=out_channels),
self.layer_norm_mlp,
linear(out_channels, hidden_dim),
act_func(),
linear(hidden_dim, out_channels),
@@ -516,6 +518,7 @@ def forward(
self.layer_norm_attention(x[0]),
self.layer_norm_attention_2(x[1]),
) # Why does this use layer_norm_attention_2? And is this only a mapper thing?

x_r = self.lin_self(x[1])
query = self.lin_query(x[1])
key = self.lin_key(x[0])
@@ -624,7 +627,8 @@ def __init__(
bias=bias,
activation=activation,
num_chunks=num_chunks,
update_src_nodes=update_src_nodes**kwargs,
update_src_nodes=update_src_nodes,
**kwargs,
)

def forward(
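To make the block.py changes concrete, here is a stripped-down, hypothetical block that pulls its Linear and LayerNorm implementations from layer_kernels with PyTorch fallbacks, mirroring the layerNorm/linear calls in the diff above and the layer_kernels.get("Linear", torch.nn.Linear) pattern that appears (commented out) in mapper.py below. It assumes layer_kernels already maps names to classes; it is an illustration, not the library code.

import torch
from torch import nn


class TinyGraphTransformerBlock(nn.Module):
    # Illustrative only: shows the kernel lookup, not the real attention block.
    def __init__(self, in_channels, out_channels, hidden_dim, layer_kernels=None):
        super().__init__()
        layer_kernels = layer_kernels or {}
        linear = layer_kernels.get("Linear", nn.Linear)
        layerNorm = layer_kernels.get("LayerNorm", nn.LayerNorm)

        self.layer_norm_attention = layerNorm(normalized_shape=in_channels)
        self.layer_norm_mlp = layerNorm(normalized_shape=out_channels)
        # The same LayerNorm instance is reused inside the MLP, as in the diff.
        self.node_dst_mlp = nn.Sequential(
            self.layer_norm_mlp,
            linear(out_channels, hidden_dim),
            nn.GELU(),
            linear(hidden_dim, out_channels),
        )


block = TinyGraphTransformerBlock(64, 64, 256)
print(block.node_dst_mlp(torch.randn(10, 64)).shape)  # torch.Size([10, 64])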
15 changes: 8 additions & 7 deletions src/anemoi/models/layers/chunk.py
@@ -72,8 +72,8 @@ def __init__(
self,
num_channels: int,
num_layers: int,
window_size: int,
layer_kernels: DotDict,
window_size: int,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
@@ -87,11 +87,11 @@ def __init__(
Number of channels
num_layers : int
Number of layers
window_size: int,
1/2 size of shifted window for attention computation
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
window_size: int,
1/2 size of shifted window for attention computation
num_heads: int
Number of heads to use, default 16
mlp_hidden_ratio: int
@@ -110,8 +110,8 @@ def __init__(
num_heads=num_heads,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
layer_kernels=layer_kernels,
dropout_p=dropout_p,
)

def forward(
@@ -165,6 +165,7 @@ def __init__(
in_features=edge_dim,
hidden_dim=num_channels,
out_features=num_channels,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -175,9 +176,9 @@ def __init__(
GraphConvProcessorBlock,
num_channels,
num_channels,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
layer_kernels=layer_kernels,
)

def forward(
@@ -239,10 +240,10 @@ def __init__(
in_channels=num_channels,
hidden_dim=mlp_hidden_ratio * num_channels,
out_channels=num_channels,
num_heads=num_heads,
edge_dim=edge_dim,
activation=activation,
num_heads=num_heads,
layer_kernels=layer_kernels,
activation=activation,
)

def forward(
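The chunk.py changes mostly reorder parameters (layer_kernels now precedes window_size and dropout_p), which matters for positional callers but not for keyword callers. Below is a hypothetical keyword-argument call with illustration values, using the same kwarg names that build_layers passes in processor.py further down; it assumes layer_kernels already holds classes.

import torch
from anemoi.models.layers.chunk import TransformerProcessorChunk

# Assumed to map names directly to classes for this illustration.
layer_kernels = {"Linear": torch.nn.Linear, "LayerNorm": torch.nn.LayerNorm}

chunk = TransformerProcessorChunk(
    num_channels=1024,
    num_layers=2,
    layer_kernels=layer_kernels,
    window_size=512,
    num_heads=16,
    mlp_hidden_ratio=4,
    activation="GELU",
    dropout_p=0.0,
)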
6 changes: 6 additions & 0 deletions src/anemoi/models/layers/conv.py
@@ -11,6 +11,7 @@
from typing import Optional

import torch
from anemoi.utils.config import DotDict
from torch import Tensor
from torch.nn.functional import dropout
from torch_geometric.nn.conv import MessagePassing
@@ -31,6 +32,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
layer_kernels: DotDict,
mlp_extra_layers: int = 0,
activation: str = "SiLU",
**kwargs,
@@ -43,6 +45,9 @@ def __init__(
Number of input channels.
out_channels : int
Number of output channels.
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
mlp_extra_layers : int, optional
Extra layers in MLP, by default 0
activation : str, optional
@@ -54,6 +59,7 @@ def __init__(
3 * in_channels,
out_channels,
out_channels,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)
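For illustration, a hypothetical construction of the updated GraphConv. The channel sizes are made up, and it assumes both that a DotDict can be built from a plain mapping and that the kernels are supplied as classes rather than the dotted-path strings mentioned in the docstring.

import torch
from anemoi.models.layers.conv import GraphConv
from anemoi.utils.config import DotDict

layer_kernels = DotDict({"Linear": torch.nn.Linear, "LayerNorm": torch.nn.LayerNorm})
conv = GraphConv(
    in_channels=128,
    out_channels=128,
    layer_kernels=layer_kernels,  # additional kernel entries may be needed depending on the MLP
    mlp_extra_layers=0,
    activation="SiLU",
)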
21 changes: 20 additions & 1 deletion src/anemoi/models/layers/mapper.py
@@ -227,7 +227,6 @@ def __init__(
num_chunks=num_chunks,
cpu_offload=cpu_offload,
activation=activation,
layer_kernels=layer_kernels,
)

# Linear = layer_kernels.get("Linear", torch.nn.Linear)
@@ -453,6 +452,7 @@ def __init__(
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
layer_kernels: DotDict = None,
) -> None:
"""Initialize GNNBaseMapper.
@@ -476,6 +476,9 @@ def __init__(
Whether to offload processing to CPU, by default False
out_channels_dst : Optional[int], optional
Output channels of the destination node, by default None
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
"""
super().__init__(
in_channels_src,
@@ -493,6 +496,7 @@ def __init__(
in_features=self.edge_dim,
hidden_dim=hidden_dim,
out_features=hidden_dim,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -557,6 +561,7 @@ def __init__(
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
layer_kernels: DotDict = None,
) -> None:
"""Initialize GNNForwardMapper.
@@ -580,6 +585,9 @@ def __init__(
Whether to offload processing to CPU, by default False
out_channels_dst : Optional[int], optional
Output channels of the destination node, by default None
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
"""
super().__init__(
in_channels_src,
@@ -595,11 +603,13 @@ def __init__(
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
layer_kernels=layer_kernels,
)

self.proc = GraphConvMapperBlock(
hidden_dim,
hidden_dim,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=True,
@@ -612,6 +622,7 @@ def __init__(
in_features=in_channels_src,
hidden_dim=hidden_dim,
out_features=hidden_dim,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -620,6 +631,7 @@ def __init__(
in_features=in_channels_dst,
hidden_dim=hidden_dim,
out_features=hidden_dim,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=activation,
)
@@ -643,6 +655,7 @@ def __init__(
sub_graph_edge_attributes: Optional[list[str]] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
layer_kernels: DotDict = None,
) -> None:
"""Initialize GNNBackwardMapper.
@@ -666,6 +679,9 @@ def __init__(
Whether to offload processing to CPU, by default False
out_channels_dst : Optional[int], optional
Output channels of the destination node, by default None
layer_kernels : DotDict
A dict of layer implementations e.g. layer_kernels['Linear'] = "torch.nn.Linear"
Defined in config/models/<model>.yaml
"""
super().__init__(
in_channels_src,
@@ -681,11 +697,13 @@ def __init__(
sub_graph_edge_attributes=sub_graph_edge_attributes,
src_grid_size=src_grid_size,
dst_grid_size=dst_grid_size,
layer_kernels=layer_kernels,
)

self.proc = GraphConvMapperBlock(
hidden_dim,
hidden_dim,
layer_kernels=layer_kernels,
mlp_extra_layers=mlp_extra_layers,
activation=activation,
update_src_nodes=False,
@@ -698,6 +716,7 @@ def __init__(
in_features=self.hidden_dim,
hidden_dim=self.hidden_dim,
out_features=self.out_channels_dst,
layer_kernels=layer_kernels,
n_extra_layers=mlp_extra_layers,
activation=self.activation,
layer_norm=False,
2 changes: 1 addition & 1 deletion src/anemoi/models/layers/mlp.py
@@ -89,7 +89,7 @@ def __init__(
mlp1.append(act_func())

if layer_norm:
mlp1.append(LayerNorm(out_features).as_type(out_features))
mlp1.append(LayerNorm(normalized_shape=out_features))

self.model = CheckpointWrapper(mlp1) if checkpoints else mlp1

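The mlp.py change replaces LayerNorm(out_features).as_type(out_features), which is not a valid call since LayerNorm modules have no as_type method, with the keyword form. A standalone check of the corrected call, independent of the library:

import torch

out_features = 64
norm = torch.nn.LayerNorm(normalized_shape=out_features)  # int normalized_shape is accepted
x = torch.randn(8, out_features)
print(norm(x).shape)  # torch.Size([8, 64])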
17 changes: 9 additions & 8 deletions src/anemoi/models/layers/processor.py
@@ -124,27 +124,26 @@ def __init__(
Dropout probability used for multi-head self attention, default 0.0
"""
super().__init__(
num_channels=num_channels,
num_layers=num_layers,
num_channels=num_channels,
window_size=window_size,
num_chunks=num_chunks,
activation=activation,
cpu_offload=cpu_offload,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
# layer_kernels=layer_kernels,
)

self.build_layers(
TransformerProcessorChunk,
num_channels=num_channels,
num_layers=self.chunk_size,
layer_kernels=layer_kernels,
mlp_hidden_ratio=mlp_hidden_ratio,
num_heads=num_heads,
num_layers=self.chunk_size,
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
layer_kernels=layer_kernels,
)

self.offload_layers(cpu_offload)
@@ -175,6 +174,7 @@ class GNNProcessor(GraphEdgeMixin, BaseProcessor):
def __init__(
self,
num_layers: int,
layer_kernels: DotDict,
*args,
trainable_size: int = 8,
num_channels: int = 128,
@@ -219,16 +219,15 @@ def __init__(
self.trainable = TrainableTensor(trainable_size=trainable_size, tensor_size=self.edge_attr.shape[0])

kwargs = {
"num_layers": self.chunk_size,
"mlp_extra_layers": mlp_extra_layers,
"activation": activation,
"edge_dim": None,
}

self.build_layers(GNNProcessorChunk, num_channels, **kwargs)
self.build_layers(GNNProcessorChunk, num_channels, self.chunk_size, layer_kernels, **kwargs)

kwargs["edge_dim"] = self.edge_dim # Edge dim for first layer
self.proc[0] = GNNProcessorChunk(num_channels, **kwargs)
self.proc[0] = GNNProcessorChunk(num_channels, self.chunk_size, layer_kernels, **kwargs)

self.offload_layers(cpu_offload)

@@ -263,6 +262,7 @@ class GraphTransformerProcessor(GraphEdgeMixin, BaseProcessor):
def __init__(
self,
num_layers: int,
layer_kernels: DotDict,
trainable_size: int = 8,
num_channels: int = 128,
num_chunks: int = 2,
@@ -296,8 +296,8 @@ def __init__(
Whether to offload processing to CPU, by default False
"""
super().__init__(
num_layers=num_layers,
num_channels=num_channels,
num_layers=num_layers,
num_chunks=num_chunks,
activation=activation,
cpu_offload=cpu_offload,
@@ -313,6 +313,7 @@ def __init__(
GraphTransformerProcessorChunk,
num_channels=num_channels,
num_layers=self.chunk_size,
layer_kernels=layer_kernels,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
activation=activation,
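Putting the processor changes together: GNNProcessor and GraphTransformerProcessor now accept layer_kernels and hand it to their chunk classes, which pass it on to blocks, convolutions, and MLPs. Below is a sketch of what the mapping described in the docstrings ("Defined in config/models/<model>.yaml") might look like once loaded; the key set and the use of classes instead of dotted-path strings are assumptions.

import torch
from anemoi.utils.config import DotDict

# Hypothetical contents; the real mapping comes from config/models/<model>.yaml.
layer_kernels = DotDict(
    {
        "Linear": torch.nn.Linear,
        "LayerNorm": torch.nn.LayerNorm,
    }
)

# Flow of the mapping through the stack, mirroring the diff above:
# GNNProcessor(num_layers, layer_kernels, ...)
#   -> build_layers(GNNProcessorChunk, num_channels, chunk_size, layer_kernels, **kwargs)
#     -> GraphConvProcessorBlock(..., layer_kernels=layer_kernels)
#       -> GraphConv / MLP(..., layer_kernels=layer_kernels)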
