From 5a1a29d8462df98320c2c4a1801ddf607bf4d50b Mon Sep 17 00:00:00 2001 From: Maximilian Pierzyna Date: Wed, 9 Oct 2024 14:03:15 -0400 Subject: [PATCH 1/5] hide_ticks causes errors, so remove it. Maybe old argument --- src/weather_model_graphs/visualise/plot_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weather_model_graphs/visualise/plot_2d.py b/src/weather_model_graphs/visualise/plot_2d.py index d433d69..9efdb4a 100644 --- a/src/weather_model_graphs/visualise/plot_2d.py +++ b/src/weather_model_graphs/visualise/plot_2d.py @@ -13,7 +13,7 @@ def nx_draw_with_pos(g, with_labels=False, **kwargs): if ax is None: _, ax = plt.subplots(figsize=(10, 10)) networkx.draw_networkx( - ax=ax, G=g, pos=pos, hide_ticks=False, with_labels=with_labels, **kwargs + ax=ax, G=g, pos=pos, with_labels=with_labels, **kwargs ) return ax From 7fa84facf0bd23a8980219674a41c4df463fdc3d Mon Sep 17 00:00:00 2001 From: Maximilian Pierzyna Date: Wed, 9 Oct 2024 14:23:33 -0400 Subject: [PATCH 2/5] Improvement: Make node/edge feature concat universal for 1D or 2D features. Bugfix: Remove mutable default arguments --- src/weather_model_graphs/save.py | 49 ++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/weather_model_graphs/save.py b/src/weather_model_graphs/save.py index 6af46c8..c8ac0e4 100644 --- a/src/weather_model_graphs/save.py +++ b/src/weather_model_graphs/save.py @@ -1,5 +1,6 @@ import pickle from pathlib import Path +from typing import List import networkx from loguru import logger @@ -9,6 +10,7 @@ try: import torch import torch_geometric.utils.convert as pyg_convert + import torch_geometric as pyg HAS_PYG = True except ImportError: @@ -16,12 +18,12 @@ def to_pyg( - graph: networkx.DiGraph, - output_directory: str, - name: str, - edge_features=["vdiff"], - node_features=["pos"], - list_from_attribute=None, + graph: networkx.DiGraph, + output_directory: str, + name: str, + edge_features: List[str] | None = None, + node_features: List[str] | None = None, + list_from_attribute=None, ): """ Save the networkx graph to PyTorch Geometric format that matches what the @@ -59,6 +61,13 @@ def to_pyg( "install weather-mode-graphs[pytorch] to enable writing to torch files" ) + # Default values for arguments + if edge_features is None: + edge_features = ["len", "vdiff"] + + if node_features is None: + node_features = ["pos"] + # check that the node labels are integers and unique so that they can be used as indices if not all(isinstance(node, int) for node in graph.nodes): node_types = set([type(node) for node in graph.nodes]) @@ -77,16 +86,20 @@ def to_pyg( def _get_edge_indecies(pyg_g): return pyg_g.edge_index - def _get_edge_features(pyg_g): - if edge_features != ["vdiff"]: - raise NotImplementedError(edge_features_values) - # TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff") - return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32) - - def _get_node_features(pyg_g): - if node_features != ["pos"]: - raise NotImplementedError(node_features_values) - return pyg_g.pos.to(torch.float32) + def _concat_pyg_features(pyg_g: "pyg.data.Data", features: List[str]) -> torch.Tensor: + """Convert features from pyg.Data object to torch.Tensor. + Each feature should be column in the resulting 2D tensor (n_features, n_edges or n_nodes). + Note, this function can handle node AND edge features. + """ + v_concat = [] + for f in features: + v = torch.tensor(pyg_g[f]) # make sure we have torch tensor + # Convert 1D features into 1xN tensor + if v.ndim == 1: + v = v.unsqueeze(1) + v_concat.append(v) + + return torch.cat(v_concat, dim=1).to(torch.float32) if list_from_attribute is not None: # create a list of graph objects by splitting the graph by the list_from_attribute @@ -104,9 +117,9 @@ def _get_node_features(pyg_g): else: pyg_graphs = [pyg_convert.from_networkx(graph)] - edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs] + edge_features_values = [_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs] edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs] - node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs] + node_features_values = [_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs] if list_from_attribute is None: edge_features_values = edge_features_values[0] From 746b700f8068eacf2488c72357ec828442bae5a4 Mon Sep 17 00:00:00 2001 From: Maximilian Pierzyna Date: Thu, 10 Oct 2024 21:22:21 -0400 Subject: [PATCH 3/5] Incorporating suggestions from #31. Linting now successful --- src/weather_model_graphs/save.py | 30 +++++++++++-------- src/weather_model_graphs/visualise/plot_2d.py | 2 +- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/weather_model_graphs/save.py b/src/weather_model_graphs/save.py index c8ac0e4..74e1a53 100644 --- a/src/weather_model_graphs/save.py +++ b/src/weather_model_graphs/save.py @@ -9,8 +9,8 @@ try: import torch - import torch_geometric.utils.convert as pyg_convert import torch_geometric as pyg + import torch_geometric.utils.convert as pyg_convert HAS_PYG = True except ImportError: @@ -18,12 +18,12 @@ def to_pyg( - graph: networkx.DiGraph, - output_directory: str, - name: str, - edge_features: List[str] | None = None, - node_features: List[str] | None = None, - list_from_attribute=None, + graph: networkx.DiGraph, + output_directory: str, + name: str, + edge_features: List[str] | None = None, + node_features: List[str] | None = None, + list_from_attribute=None, ): """ Save the networkx graph to PyTorch Geometric format that matches what the @@ -86,14 +86,16 @@ def to_pyg( def _get_edge_indecies(pyg_g): return pyg_g.edge_index - def _concat_pyg_features(pyg_g: "pyg.data.Data", features: List[str]) -> torch.Tensor: + def _concat_pyg_features( + pyg_g: "pyg.data.Data", features: List[str] + ) -> torch.Tensor: """Convert features from pyg.Data object to torch.Tensor. - Each feature should be column in the resulting 2D tensor (n_features, n_edges or n_nodes). + Each feature should be column in the resulting 2D tensor (n_edges or n_nodes, n_features). Note, this function can handle node AND edge features. """ v_concat = [] for f in features: - v = torch.tensor(pyg_g[f]) # make sure we have torch tensor + v = pyg_g[f] # Convert 1D features into 1xN tensor if v.ndim == 1: v = v.unsqueeze(1) @@ -117,9 +119,13 @@ def _concat_pyg_features(pyg_g: "pyg.data.Data", features: List[str]) -> torch.T else: pyg_graphs = [pyg_convert.from_networkx(graph)] - edge_features_values = [_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs] + edge_features_values = [ + _concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs + ] edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs] - node_features_values = [_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs] + node_features_values = [ + _concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs + ] if list_from_attribute is None: edge_features_values = edge_features_values[0] diff --git a/src/weather_model_graphs/visualise/plot_2d.py b/src/weather_model_graphs/visualise/plot_2d.py index 9efdb4a..d433d69 100644 --- a/src/weather_model_graphs/visualise/plot_2d.py +++ b/src/weather_model_graphs/visualise/plot_2d.py @@ -13,7 +13,7 @@ def nx_draw_with_pos(g, with_labels=False, **kwargs): if ax is None: _, ax = plt.subplots(figsize=(10, 10)) networkx.draw_networkx( - ax=ax, G=g, pos=pos, with_labels=with_labels, **kwargs + ax=ax, G=g, pos=pos, hide_ticks=False, with_labels=with_labels, **kwargs ) return ax From d8267f149eff9a5222e4c8b58270b68f0ef706a5 Mon Sep 17 00:00:00 2001 From: Maximilian Pierzyna Date: Thu, 10 Oct 2024 21:27:44 -0400 Subject: [PATCH 4/5] Add modification of function to changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 710cea3..c1c7b1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#19](https://github.com/mllam/weather-model-graphs/pull/19) @joeloskarsson +- `save.to_pyg` can now handle any number of 1D or 2D edge or node features when + converting pytorch-geometric `Data` objects to `torch.Tensor` objects. + ### Changed - Create different number of mesh nodes in x- and y-direction. From 915e9e9589400d1567bcc91d1e465cefb020af86 Mon Sep 17 00:00:00 2001 From: Maximilian Date: Wed, 16 Oct 2024 21:18:15 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md Co-authored-by: Joel Oskarsson --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c1c7b1c..460bfdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `save.to_pyg` can now handle any number of 1D or 2D edge or node features when converting pytorch-geometric `Data` objects to `torch.Tensor` objects. + [\#31](https://github.com/mllam/weather-model-graphs/pull/31) + @maxiimilian ### Changed