Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make save function more universal to accept any number of 1D or 2D node or edge features #31

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
from pathlib import Path
from typing import List

import networkx
from loguru import logger
Expand All @@ -9,19 +10,20 @@
try:
import torch
import torch_geometric.utils.convert as pyg_convert
import torch_geometric as pyg

HAS_PYG = True
except ImportError:
HAS_PYG = False


def to_pyg(
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features=["vdiff"],
node_features=["pos"],
list_from_attribute=None,
graph: networkx.DiGraph,
output_directory: str,
name: str,
edge_features: List[str] | None = None,
node_features: List[str] | None = None,
list_from_attribute=None,
):
"""
Save the networkx graph to PyTorch Geometric format that matches what the
Expand Down Expand Up @@ -59,6 +61,13 @@ def to_pyg(
"install weather-mode-graphs[pytorch] to enable writing to torch files"
)

# Default values for arguments
if edge_features is None:
edge_features = ["len", "vdiff"]
joeloskarsson marked this conversation as resolved.
Show resolved Hide resolved

if node_features is None:
node_features = ["pos"]

# check that the node labels are integers and unique so that they can be used as indices
if not all(isinstance(node, int) for node in graph.nodes):
node_types = set([type(node) for node in graph.nodes])
Expand All @@ -77,16 +86,20 @@ def to_pyg(
def _get_edge_indecies(pyg_g):
return pyg_g.edge_index

def _get_edge_features(pyg_g):
if edge_features != ["vdiff"]:
raise NotImplementedError(edge_features_values)
# TODO: handle features of different types more generally, i.e. both single ("len") values and tuples (like "vdiff")
return torch.cat((pyg_g.len.unsqueeze(1), pyg_g.vdiff), dim=1).to(torch.float32)

def _get_node_features(pyg_g):
if node_features != ["pos"]:
raise NotImplementedError(node_features_values)
return pyg_g.pos.to(torch.float32)
def _concat_pyg_features(pyg_g: "pyg.data.Data", features: List[str]) -> torch.Tensor:
"""Convert features from pyg.Data object to torch.Tensor.
Each feature should be column in the resulting 2D tensor (n_features, n_edges or n_nodes).
joeloskarsson marked this conversation as resolved.
Show resolved Hide resolved
Note, this function can handle node AND edge features.
"""
v_concat = []
for f in features:
v = torch.tensor(pyg_g[f]) # make sure we have torch tensor
joeloskarsson marked this conversation as resolved.
Show resolved Hide resolved
# Convert 1D features into 1xN tensor
if v.ndim == 1:
v = v.unsqueeze(1)
v_concat.append(v)

return torch.cat(v_concat, dim=1).to(torch.float32)

if list_from_attribute is not None:
# create a list of graph objects by splitting the graph by the list_from_attribute
Expand All @@ -104,9 +117,9 @@ def _get_node_features(pyg_g):
else:
pyg_graphs = [pyg_convert.from_networkx(graph)]

edge_features_values = [_get_edge_features(pyg_g) for pyg_g in pyg_graphs]
edge_features_values = [_concat_pyg_features(pyg_g, features=edge_features) for pyg_g in pyg_graphs]
edge_indecies = [_get_edge_indecies(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_get_node_features(pyg_g) for pyg_g in pyg_graphs]
node_features_values = [_concat_pyg_features(pyg_g, features=node_features) for pyg_g in pyg_graphs]

if list_from_attribute is None:
edge_features_values = edge_features_values[0]
Expand Down
2 changes: 1 addition & 1 deletion src/weather_model_graphs/visualise/plot_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def nx_draw_with_pos(g, with_labels=False, **kwargs):
if ax is None:
_, ax = plt.subplots(figsize=(10, 10))
networkx.draw_networkx(
ax=ax, G=g, pos=pos, hide_ticks=False, with_labels=with_labels, **kwargs
ax=ax, G=g, pos=pos, with_labels=with_labels, **kwargs
joeloskarsson marked this conversation as resolved.
Show resolved Hide resolved
)

return ax
Expand Down
Loading