Skip to content

Commit

Permalink
Make save function more universal to accept any number of 1D or 2D no…
Browse files Browse the repository at this point in the history
…de or edge features (#31)

* hide_ticks causes errors, so remove it. Maybe old argument

* Improvement: Make node/edge feature concat universal for 1D or 2D features. Bugfix: Remove mutable default arguments

* Incorporating suggestions from #31. Linting now successful

* Add modification of  function to changelog

* Update CHANGELOG.md

Co-authored-by: Joel Oskarsson <[email protected]>

---------

Co-authored-by: Joel Oskarsson <[email protected]>
  • Loading branch information
maxiimilian and joeloskarsson authored Oct 17, 2024
1 parent 5550c53 commit d91571a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- `save.to_pyg` can now handle any number of 1D or 2D edge or node features when
converting pytorch-geometric `Data` objects to `torch.Tensor` objects.
[\#31](https://github.com/mllam/weather-model-graphs/pull/31)
@maxiimilian

### Changed

- Fix wrong number of mesh levels when grid is multiple of refinement factor
Expand Down
47 changes: 33 additions & 14 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from typing import List

import networkx
from loguru import logger
Expand All @@ -8,6 +9,7 @@

try:
import torch
import torch_geometric as pyg
import torch_geometric.utils.convert as pyg_convert

HAS_PYG = True
Expand All @@ -19,8 +21,8 @@ def to_pyg(
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features=["vdiff"],
node_features=["pos"],
edge_features: List[str] | None = None,
node_features: List[str] | None = None,
list_from_attribute=None,
):
"""
Expand Down Expand Up @@ -59,6 +61,13 @@ def to_pyg(
"install weather-mode-graphs[pytorch] to enable writing to torch files"
)

# Default values for arguments
if edge_features is None:
edge_features = ["len", "vdiff"]

if node_features is None:
node_features = ["pos"]

# check that the node labels are integers and unique so that they can be used as indices
if not all(isinstance(node, int) for node in graph.nodes):
node_types = set([type(node) for node in graph.nodes])
Expand All @@ -77,16 +86,22 @@ def to_pyg(
def _get_edge_indecies(pyg_g):
return pyg_g.edge_index

def _get_edge_features(pyg_g):
if edge_features != ["vdiff"]:
raise NotImplementedError(edge_features_values)
# TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff")
return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32)

def _get_node_features(pyg_g):
if node_features != ["pos"]:
raise NotImplementedError(node_features_values)
return pyg_g.pos.to(torch.float32)
def _concat_pyg_features(
pyg_g: "pyg.data.Data", features: List[str]
) -> torch.Tensor:
"""Convert features from pyg.Data object to torch.Tensor.
Each feature should be column in the resulting 2D tensor (n_edges or n_nodes, n_features).
Note, this function can handle node AND edge features.
"""
v_concat = []
for f in features:
v = pyg_g[f]
# Convert 1D features into 1xN tensor
if v.ndim == 1:
v = v.unsqueeze(1)
v_concat.append(v)

return torch.cat(v_concat, dim=1).to(torch.float32)

if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
Expand All @@ -104,9 +119,13 @@ def _get_node_features(pyg_g):
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]

edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_features_values = [
_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs
]
edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [
_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs
]

if list_from_attribute is None:
edge_features_values = edge_features_values[0]
Expand Down

0 comments on commit d91571a

Please sign in to comment.