Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make save function more universal to accept any number of 1D or 2D node or edge features #31

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- `save.to_pyg` can now handle any number of 1D or 2D edge or node features when
converting pytorch-geometric `Data` objects to `torch.Tensor` objects.

maxiimilian marked this conversation as resolved.
Show resolved Hide resolved
### Changed

- Create different number of mesh nodes in x- and y-direction.
Expand Down
47 changes: 33 additions & 14 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from typing import List

import networkx
from loguru import logger
Expand All @@ -8,6 +9,7 @@

try:
import torch
import torch_geometric as pyg
import torch_geometric.utils.convert as pyg_convert

HAS_PYG = True
Expand All @@ -19,8 +21,8 @@ def to_pyg(
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features=["vdiff"],
node_features=["pos"],
edge_features: List[str] | None = None,
node_features: List[str] | None = None,
list_from_attribute=None,
):
"""
Expand Down Expand Up @@ -59,6 +61,13 @@ def to_pyg(
"install weather-mode-graphs[pytorch] to enable writing to torch files"
)

# Default values for arguments
if edge_features is None:
edge_features = ["len", "vdiff"]
joeloskarsson marked this conversation as resolved.
Show resolved Hide resolved

if node_features is None:
node_features = ["pos"]

# check that the node labels are integers and unique so that they can be used as indices
if not all(isinstance(node, int) for node in graph.nodes):
node_types = set([type(node) for node in graph.nodes])
Expand All @@ -77,16 +86,22 @@ def to_pyg(
def _get_edge_indecies(pyg_g):
return pyg_g.edge_index

def _get_edge_features(pyg_g):
if edge_features != ["vdiff"]:
raise NotImplementedError(edge_features_values)
# TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff")
return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32)

def _get_node_features(pyg_g):
if node_features != ["pos"]:
raise NotImplementedError(node_features_values)
return pyg_g.pos.to(torch.float32)
def _concat_pyg_features(
pyg_g: "pyg.data.Data", features: List[str]
) -> torch.Tensor:
"""Convert features from pyg.Data object to torch.Tensor.
Each feature should be column in the resulting 2D tensor (n_edges or n_nodes, n_features).
Note, this function can handle node AND edge features.
"""
v_concat = []
for f in features:
v = pyg_g[f]
# Convert 1D features into 1xN tensor
if v.ndim == 1:
v = v.unsqueeze(1)
v_concat.append(v)

return torch.cat(v_concat, dim=1).to(torch.float32)

if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
Expand All @@ -104,9 +119,13 @@ def _get_node_features(pyg_g):
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]

edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_features_values = [
_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs
]
edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [
_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs
]

if list_from_attribute is None:
edge_features_values = edge_features_values[0]
Expand Down
Loading