Skip to content

Commit

Permalink
Add containing_rectangle graph connection method for m2g (#28)
Browse files Browse the repository at this point in the history
* Write algorithm for connecting with containing rectangle

* Fix dx/dy graph property also for hierarchical graph

* Make containing_rectangle option only for m2g connectivity

* Add description of containing_rectangle method to docstrings

* Add comment explaining edge filter used in containing_rectangle method

* Fix linting

* Add changelog entry

* Clarify rel_max_distance usage
  • Loading branch information
joeloskarsson authored Oct 23, 2024
1 parent d91571a commit c3ba212
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#31](https://github.com/mllam/weather-model-graphs/pull/31)
@maxiimilian

- Add containing_rectangle graph connection method for m2g edges
[\#28](https://github.com/mllam/weather-model-graphs/pull/28)
@joeloskarsson

### Changed

- Fix wrong number of mesh levels when grid is multiple of refinement factor
Expand Down
62 changes: 59 additions & 3 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,20 @@ def create_all_graph_components(
- "flat": Create a single-level 2D mesh graph with `grid_refinement_factor`,
similar to Keisler et al. (2022)
- "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`,
`grid_refinement_factor` and `mesh_refinement_factor`,
`grid_refinement_factor` and `level_refinement_factor`,
similar to GraphCast, Lam et al. (2023)
- "hierarchical": Create a hierarchical mesh graph with `max_num_levels`,
`grid_refinement_factor` and `mesh_refinement_factor`,
`grid_refinement_factor` and `level_refinement_factor`,
similar to Okcarsson et al. (2023)
m2g_connectivity:
- "nearest_neighbour": Find the nearest neighbour in mesh for each node in grid
- "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in mesh for each node in grid
- "within_radius": Find all neighbours in mesh within an absolute distance
of `max_dist` or relative distance of `rel_max_dist` from each node in grid
- "containing_rectangle": For each grid node, find the rectangle with 4 mesh nodes as corners
such that the grid node is contained within it. Connect these 4 (or less along edges)
mesh nodes to the grid node.
"""
graph_components: dict[networkx.DiGraph] = {}

Expand Down Expand Up @@ -177,8 +180,15 @@ def connect_nodes_across_graphs(
- "nearest_neighbour": Find the nearest neighbour in `G_target` for each node in `G_source`
- "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in `G_target` for each node in `G_source`
- "within_radius": Find all neighbours in `G_target` within a distance of `max_dist` from each node in `G_source`
- "containing_rectangle": For each node in `G_target`, find the rectangle in `G_source`
with 4 nodes as corners such that the `G_target` node is contained within it.
Connect these 4 (or less along edges) corner nodes to the `G_target` node.
Requires that `G_source` has dx and dy properties, i.e. is a quadrilateral mesh graph.
max_dist : float
Maximum distance to search for neighbours in `G_target` for each node in `G_source`
rel_max_dist : float
Maximum distance to search for neighbours in `G_target` for each node in `G_source`,
relative to longest edge in (bottom level of) `G_source` and `G_target`.
max_num_neighbours : int
Maximum number of neighbours to search for in `G_target` for each node in `G_source`
Expand All @@ -198,7 +208,53 @@ def connect_nodes_across_graphs(
# Determine method and perform checks once
# Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in
# loop later
if method == "nearest_neighbour":
if method == "containing_rectangle":
if (
max_dist is not None
or rel_max_dist is not None
or max_num_neighbours is not None
):
raise Exception(
"to use `containing_rectangle` you should not set `max_dist`, `rel_max_dist`or `max_num_neighbours`"
)
assert (
"dx" in G_source.graph and "dy" in G_source.graph
), "Source graph must have dx and dy properties to connect nodes using method containing_rectangle"

# Connect to all nodes that could potentially be close enough,
# which is at a relative distance of 1. This relative distance is equal
# to the diagonal of one rectangle.
rad_graph = connect_nodes_across_graphs(
G_source, G_target, method="within_radius", rel_max_dist=1.0
)

# Filter edges to those that fit within a rectangle of measurements dx,dy
mesh_node_dx = G_source.graph["dx"]
mesh_node_dy = G_source.graph["dy"]

if isinstance(mesh_node_dx, dict):
# In hierarchical graph these properties are dicts, in that case use
# values for bottom level.
mesh_node_dx = mesh_node_dx[0]
mesh_node_dy = mesh_node_dy[0]

# This function is a filter that applies to edges, represented as vectors (vx, vy) in R^ 2.
# The filter is True if |vx| < dx & |vy| < dy, where dx and dy are the distance between
# rows and columns in source quadrilateral graph.
def _edge_filter(edge_prop):
abs_diffs = np.abs(edge_prop["vdiff"])
return abs_diffs[0] < mesh_node_dx and abs_diffs[1] < mesh_node_dy

filtered_edges = [
(u, v)
for u, v, edge_prop in rad_graph.edges(data=True)
if _edge_filter(edge_prop)
]

filtered_graph = rad_graph.edge_subgraph(filtered_edges)
return filtered_graph

elif method == "nearest_neighbour":
if (
max_dist is not None
or rel_max_dist is not None
Expand Down
4 changes: 4 additions & 0 deletions src/weather_model_graphs/create/mesh/kinds/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,8 @@ def create_hierarchical_multiscale_mesh_graph(

G_m2m = networkx.compose_all([G_all_levels, G_up_all, G_down_all])

# add dx and dy to graph
for prop in ("dx", "dy"):
G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)}

return G_m2m
13 changes: 8 additions & 5 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_create_graph_archetype(kind):


# list the connectivity options for g2m and m2g and the kwargs to test
G2M_M2G_CONNECTIVITY_OPTIONS = dict(
G2M_CONNECTIVITY_OPTIONS = dict(
nearest_neighbour=[],
nearest_neighbours=[dict(max_num_neighbours=4), dict(max_num_neighbours=8)],
within_radius=[
Expand All @@ -59,6 +59,9 @@ def test_create_graph_archetype(kind):
dict(rel_max_dist=1.0),
],
)
# containing_rectangle option should only be used for m2g
M2G_CONNECTIVITY_OPTIONS = G2M_CONNECTIVITY_OPTIONS.copy()
M2G_CONNECTIVITY_OPTIONS["containing_rectangle"] = [dict()]

# list the connectivity options for m2m and the kwargs to test
M2M_CONNECTIVITY_OPTIONS = dict(
Expand All @@ -74,14 +77,14 @@ def test_create_graph_archetype(kind):
)


@pytest.mark.parametrize("g2m_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2g_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("g2m_connectivity", G2M_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2g_connectivity", M2G_CONNECTIVITY_OPTIONS.keys())
@pytest.mark.parametrize("m2m_connectivity", M2M_CONNECTIVITY_OPTIONS.keys())
def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivity):
xy = _create_fake_xy(N=32)

for g2m_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[g2m_connectivity]:
for m2g_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]:
for g2m_kwargs in G2M_CONNECTIVITY_OPTIONS[g2m_connectivity]:
for m2g_kwargs in M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]:
for m2m_kwargs in M2M_CONNECTIVITY_OPTIONS[m2m_connectivity]:
graph = wmg.create.create_all_graph_components(
xy=xy,
Expand Down

0 comments on commit c3ba212

Please sign in to comment.