Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add containing_rectangle graph connection method for m2g #28

Merged
merged 10 commits into from
Oct 23, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

- Add containing_rectangle graph connection method for m2g edges
[\#28](https://github.com/mllam/weather-model-graphs/pull/28)
@joeloskarsson

### Changed

- Create different number of mesh nodes in x- and y-direction.
Expand Down
62 changes: 59 additions & 3 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,20 @@ def create_all_graph_components(
- "flat": Create a single-level 2D mesh graph with `grid_refinement_factor`,
similar to Keisler et al. (2022)
- "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`,
`grid_refinement_factor` and `mesh_refinement_factor`,
`grid_refinement_factor` and `level_refinement_factor`,
similar to GraphCast, Lam et al. (2023)
- "hierarchical": Create a hierarchical mesh graph with `max_num_levels`,
`grid_refinement_factor` and `mesh_refinement_factor`,
`grid_refinement_factor` and `level_refinement_factor`,
similar to Okcarsson et al. (2023)

m2g_connectivity:
- "nearest_neighbour": Find the nearest neighbour in mesh for each node in grid
- "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in mesh for each node in grid
- "within_radius": Find all neighbours in mesh within an absolute distance
of `max_dist` or relative distance of `rel_max_dist` from each node in grid
- "containing_rectangle": For each grid node, find the rectangle with 4 mesh nodes as corners
such that the grid node is contained within it. Connect these 4 (or less along edges)
mesh nodes to the grid node.
"""
graph_components: dict[networkx.DiGraph] = {}

Expand Down Expand Up @@ -177,8 +180,15 @@ def connect_nodes_across_graphs(
- "nearest_neighbour": Find the nearest neighbour in `G_target` for each node in `G_source`
- "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in `G_target` for each node in `G_source`
- "within_radius": Find all neighbours in `G_target` within a distance of `max_dist` from each node in `G_source`
- "containing_rectangle": For each node in `G_target`, find the rectangle in `G_source`
with 4 nodes as corners such that the `G_target` node is contained within it.
Connect these 4 (or less along edges) corner nodes to the `G_target` node.
Requires that `G_source` has dx and dy properties, i.e. is a quadrilateral mesh graph.
max_dist : float
Maximum distance to search for neighbours in `G_target` for each node in `G_source`
rel_max_dist : float
Maximum distance to search for neighbours in `G_target` for each node in `G_source`,
relative to longest edge in (bottom level of) `G_source` and `G_target`.
max_num_neighbours : int
Maximum number of neighbours to search for in `G_target` for each node in `G_source`

Expand All @@ -198,7 +208,53 @@ def connect_nodes_across_graphs(
# Determine method and perform checks once
# Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in
# loop later
if method == "nearest_neighbour":
if method == "containing_rectangle":
if (
max_dist is not None
or rel_max_dist is not None
or max_num_neighbours is not None
):
raise Exception(
"to use `containing_rectangle` you should not set `max_dist`, `rel_max_dist`or `max_num_neighbours`"
)
assert (
"dx" in G_source.graph and "dy" in G_source.graph
), "Source graph must have dx and dy properties to connect nodes using method containing_rectangle"

# Connect to all nodes that could potentially be close enough,
# which is at a relative distance of 1. This relative distance is equal
# to the diagonal of one rectangle.
rad_graph = connect_nodes_across_graphs(
G_source, G_target, method="within_radius", rel_max_dist=1.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this 1.0 and not srt(dx^2 + dy^2)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. This is relative distance, which is relative to the longest edge in (the bottom level of) the source and target graphs. These longest edges will be the diagonal edges in each rectangle in the mesh graph, which have the length sqrt(dx^2 + dy^2). So that is exactly the distance of rel_max_dist=1.0.

I added a bit of a clarification about this, and also realized that there was nothing describing the rel_max_dist argument in the docstring, so fixed that as well.

)

# Filter edges to those that fit within a rectangle of measurements dx,dy
mesh_node_dx = G_source.graph["dx"]
mesh_node_dy = G_source.graph["dy"]

if isinstance(mesh_node_dx, dict):
# In hierarchical graph these properties are dicts, in that case use
# values for bottom level.
mesh_node_dx = mesh_node_dx[0]
mesh_node_dy = mesh_node_dy[0]

# This function is a filter that applies to edges, represented as vectors (vx, vy) in R^ 2.
# The filter is True if |vx| < dx & |vy| < dy, where dx and dy are the distance between
# rows and columns in source quadrilateral graph.
def _edge_filter(edge_prop):
abs_diffs = np.abs(edge_prop["vdiff"])
return abs_diffs[0] < mesh_node_dx and abs_diffs[1] < mesh_node_dy

filtered_edges = [
(u, v)
for u, v, edge_prop in rad_graph.edges(data=True)
if _edge_filter(edge_prop)
]

filtered_graph = rad_graph.edge_subgraph(filtered_edges)
return filtered_graph

elif method == "nearest_neighbour":
if (
max_dist is not None
or rel_max_dist is not None
Expand Down
4 changes: 4 additions & 0 deletions src/weather_model_graphs/create/mesh/kinds/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,8 @@ def create_hierarchical_multiscale_mesh_graph(

G_m2m = networkx.compose_all([G_all_levels, G_up_all, G_down_all])

# add dx and dy to graph
for prop in ("dx", "dy"):
G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)}

return G_m2m
13 changes: 8 additions & 5 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_create_graph_archetype(kind):


# list the connectivity options for g2m and m2g and the kwargs to test
G2M_M2G_CONNECTIVITY_OPTIONS = dict(
G2M_CONNECTIVITY_OPTIONS = dict(
nearest_neighbour=[],
nearest_neighbours=[dict(max_num_neighbours=4), dict(max_num_neighbours=8)],
within_radius=[
Expand All @@ -59,6 +59,9 @@ def test_create_graph_archetype(kind):
dict(rel_max_dist=1.0),
],
)
# containing_rectangle option should only be used for m2g
M2G_CONNECTIVITY_OPTIONS = G2M_CONNECTIVITY_OPTIONS.copy()
M2G_CONNECTIVITY_OPTIONS["containing_rectangle"] = [dict()]

# list the connectivity options for m2m and the kwargs to test
M2M_CONNECTIVITY_OPTIONS = dict(
Expand All @@ -74,14 +77,14 @@ def test_create_graph_archetype(kind):
)


@pytest.mark.parametrize("g2m_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2g_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("g2m_connectivity", G2M_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2g_connectivity", M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2m_connectivity", M2M_CONNECTIVITY_OPTIONS.keys())
def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivity):
xy = _create_fake_xy(N=32)

for g2m_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[g2m_connectivity]:
for m2g_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]:
for g2m_kwargs in G2M_CONNECTIVITY_OPTIONS[g2m_connectivity]:
for m2g_kwargs in M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]:
for m2m_kwargs in M2M_CONNECTIVITY_OPTIONS[m2m_connectivity]:
graph = wmg.create.create_all_graph_components(
xy=xy,
Expand Down
Loading