Skip to content

Commit

Permalink
Merge main into branch
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Oct 17, 2024
1 parent 22caf65 commit 7ec34ff
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- `save.to_pyg` can now handle any number of 1D or 2D edge or node features when
converting pytorch-geometric `Data` objects to `torch.Tensor` objects.
[\#31](https://github.com/mllam/weather-model-graphs/pull/31)
@maxiimilian

### Changed

- Fix wrong number of mesh levels when grid is multiple of refinement factor
Expand Down
47 changes: 33 additions & 14 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from typing import List

import networkx
from loguru import logger
Expand All @@ -8,6 +9,7 @@

try:
import torch
import torch_geometric as pyg
import torch_geometric.utils.convert as pyg_convert

HAS_PYG = True
Expand All @@ -19,8 +21,8 @@ def to_pyg(
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features=["vdiff"],
node_features=["pos"],
edge_features: List[str] | None = None,
node_features: List[str] | None = None,
list_from_attribute=None,
):
"""
Expand Down Expand Up @@ -59,6 +61,13 @@ def to_pyg(
"install weather-mode-graphs[pytorch] to enable writing to torch files"
)

# Default values for arguments
if edge_features is None:
edge_features = ["len", "vdiff"]

if node_features is None:
node_features = ["pos"]

# check that the node labels are integers and unique so that they can be used as indices
if not all(isinstance(node, int) for node in graph.nodes):
node_types = set([type(node) for node in graph.nodes])
Expand All @@ -77,16 +86,22 @@ def to_pyg(
def _get_edge_indecies(pyg_g):
return pyg_g.edge_index

def _get_edge_features(pyg_g):
if edge_features != ["vdiff"]:
raise NotImplementedError(edge_features_values)
# TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff")
return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32)

def _get_node_features(pyg_g):
if node_features != ["pos"]:
raise NotImplementedError(node_features_values)
return pyg_g.pos.to(torch.float32)
def _concat_pyg_features(
pyg_g: "pyg.data.Data", features: List[str]
) -> torch.Tensor:
"""Convert features from pyg.Data object to torch.Tensor.
Each feature should be column in the resulting 2D tensor (n_edges or n_nodes, n_features).
Note, this function can handle node AND edge features.
"""
v_concat = []
for f in features:
v = pyg_g[f]
# Convert 1D features into 1xN tensor
if v.ndim == 1:
v = v.unsqueeze(1)
v_concat.append(v)

return torch.cat(v_concat, dim=1).to(torch.float32)

if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
Expand All @@ -104,9 +119,13 @@ def _get_node_features(pyg_g):
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]

edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_features_values = [
_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs
]
edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [
_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs
]

if list_from_attribute is None:
edge_features_values = edge_features_values[0]
Expand Down

0 comments on commit 7ec34ff

Please sign in to comment.