diff --git a/CHANGELOG.md b/CHANGELOG.md index 710cea3..460bfdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#19](https://github.com/mllam/weather-model-graphs/pull/19) @joeloskarsson +- `save.to_pyg` can now handle any number of 1D or 2D edge or node features when + converting pytorch-geometric `Data` objects to `torch.Tensor` objects. + [\#31](https://github.com/mllam/weather-model-graphs/pull/31) + @maxiimilian + ### Changed - Create different number of mesh nodes in x- and y-direction. diff --git a/src/weather_model_graphs/save.py b/src/weather_model_graphs/save.py index 6af46c8..74e1a53 100644 --- a/src/weather_model_graphs/save.py +++ b/src/weather_model_graphs/save.py @@ -1,5 +1,6 @@ import pickle from pathlib import Path +from typing import List import networkx from loguru import logger @@ -8,6 +9,7 @@ try: import torch + import torch_geometric as pyg import torch_geometric.utils.convert as pyg_convert HAS_PYG = True @@ -19,8 +21,8 @@ def to_pyg( graph: networkx.DiGraph, output_directory: str, name: str, - edge_features=["vdiff"], - node_features=["pos"], + edge_features: List[str] | None = None, + node_features: List[str] | None = None, list_from_attribute=None, ): """ @@ -59,6 +61,13 @@ def to_pyg( "install weather-mode-graphs[pytorch] to enable writing to torch files" ) + # Default values for arguments + if edge_features is None: + edge_features = ["len", "vdiff"] + + if node_features is None: + node_features = ["pos"] + # check that the node labels are integers and unique so that they can be used as indices if not all(isinstance(node, int) for node in graph.nodes): node_types = set([type(node) for node in graph.nodes]) @@ -77,16 +86,22 @@ def to_pyg( def _get_edge_indecies(pyg_g): return pyg_g.edge_index - def _get_edge_features(pyg_g): - if edge_features != ["vdiff"]: - raise NotImplementedError(edge_features_values) - # TODO: handle features of 
def _concat_pyg_features(
    pyg_g: "pyg.data.Data", features: List[str]
) -> "torch.Tensor":
    """Stack the named features of a graph object column-wise into a 2D tensor.

    The same routine serves node features and edge features: every name in
    `features` is looked up with `pyg_g[name]`. A 1D value contributes a
    single column; a 2D value contributes one column per entry along its
    second axis. Columns appear in the order the names are given.

    Parameters
    ----------
    pyg_g
        Object supporting `pyg_g[name]` lookup of feature values (a
        pytorch-geometric `Data` object in the calling code).
    features
        Names of the features to concatenate.

    Returns
    -------
    torch.Tensor
        float32 tensor with two dimensions; the first is shared by all
        features (number of nodes or edges), the second is the total number
        of feature columns.

    Raises
    ------
    ValueError
        If `features` is empty, or if a feature is neither 1D nor 2D.
    """
    if not features:
        # torch.cat rejects an empty list with an error that does not mention
        # features at all, so fail early with an actionable message
        raise ValueError("At least one feature name is required, got none")

    columns = []
    for name in features:
        # as_tensor leaves tensors untouched and converts plain numeric sequences
        values = torch.as_tensor(pyg_g[name])
        if values.ndim == 1:
            # Turn a 1D feature of length N into a single (N, 1) column
            values = values.unsqueeze(1)
        elif values.ndim != 2:
            # 0-D values cannot be concatenated along dim=1 and >2-D values
            # would produce a result that is not the documented 2D tensor
            raise ValueError(
                f"Feature '{name}' must be 1D or 2D, got {values.ndim} dimensions"
            )
        # Cast each piece before concatenating so the result does not depend
        # on torch.cat's dtype promotion when int and float features are mixed
        columns.append(values.to(torch.float32))

    return torch.cat(columns, dim=1)