diff --git a/CHANGELOG.md b/CHANGELOG.md index f2540ec..b2bcd96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#31](https://github.com/mllam/weather-model-graphs/pull/31) @maxiimilian +- Add containing_rectangle graph connection method for m2g edges + [\#28](https://github.com/mllam/weather-model-graphs/pull/28) + @joeloskarsson + ### Changed - Fix wrong number of mesh levels when grid is multiple of refinement factor diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index 313ae3f..5ce48f6 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -56,10 +56,10 @@ def create_all_graph_components( - "flat": Create a single-level 2D mesh graph with `grid_refinement_factor`, similar to Keisler et al. (2022) - "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`, - `grid_refinement_factor` and `mesh_refinement_factor`, + `grid_refinement_factor` and `level_refinement_factor`, similar to GraphCast, Lam et al. (2023) - "hierarchical": Create a hierarchical mesh graph with `max_num_levels`, - `grid_refinement_factor` and `mesh_refinement_factor`, + `grid_refinement_factor` and `level_refinement_factor`, similar to Okcarsson et al. (2023) m2g_connectivity: @@ -67,6 +67,9 @@ def create_all_graph_components( - "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in mesh for each node in grid - "within_radius": Find all neighbours in mesh within an absolute distance of `max_dist` or relative distance of `rel_max_dist` from each node in grid + - "containing_rectangle": For each grid node, find the rectangle with 4 mesh nodes as corners + such that the grid node is contained within it. Connect these 4 (or less along edges) + mesh nodes to the grid node. """ graph_components: dict[networkx.DiGraph] = {} @@ -177,8 +180,15 @@ def connect_nodes_across_graphs( - "nearest_neighbour": Find the nearest neighbour in `G_target` for each node in `G_source` - "nearest_neighbours": Find the `max_num_neighbours` nearest neighbours in `G_target` for each node in `G_source` - "within_radius": Find all neighbours in `G_target` within a distance of `max_dist` from each node in `G_source` + - "containing_rectangle": For each node in `G_target`, find the rectangle in `G_source` + with 4 nodes as corners such that the `G_target` node is contained within it. + Connect these 4 (or less along edges) corner nodes to the `G_target` node. + Requires that `G_source` has dx and dy properties, i.e. is a quadrilateral mesh graph. max_dist : float Maximum distance to search for neighbours in `G_target` for each node in `G_source` + rel_max_dist : float + Maximum distance to search for neighbours in `G_target` for each node in `G_source`, + relative to longest edge in (bottom level of) `G_source` and `G_target`. max_num_neighbours : int Maximum number of neighbours to search for in `G_target` for each node in `G_source` @@ -198,7 +208,53 @@ def connect_nodes_across_graphs( # Determine method and perform checks once # Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in # loop later - if method == "nearest_neighbour": + if method == "containing_rectangle": + if ( + max_dist is not None + or rel_max_dist is not None + or max_num_neighbours is not None + ): + raise Exception( + "to use `containing_rectangle` you should not set `max_dist`, `rel_max_dist`or `max_num_neighbours`" + ) + assert ( + "dx" in G_source.graph and "dy" in G_source.graph + ), "Source graph must have dx and dy properties to connect nodes using method containing_rectangle" + + # Connect to all nodes that could potentially be close enough, + # which is at a relative distance of 1. This relative distance is equal + # to the diagonal of one rectangle. + rad_graph = connect_nodes_across_graphs( + G_source, G_target, method="within_radius", rel_max_dist=1.0 + ) + + # Filter edges to those that fit within a rectangle of measurements dx,dy + mesh_node_dx = G_source.graph["dx"] + mesh_node_dy = G_source.graph["dy"] + + if isinstance(mesh_node_dx, dict): + # In hierarchical graph these properties are dicts, in that case use + # values for bottom level. + mesh_node_dx = mesh_node_dx[0] + mesh_node_dy = mesh_node_dy[0] + + # This function is a filter that applies to edges, represented as vectors (vx, vy) in R^ 2. + # The filter is True if |vx| < dx & |vy| < dy, where dx and dy are the distance between + # rows and columns in source quadrilateral graph. + def _edge_filter(edge_prop): + abs_diffs = np.abs(edge_prop["vdiff"]) + return abs_diffs[0] < mesh_node_dx and abs_diffs[1] < mesh_node_dy + + filtered_edges = [ + (u, v) + for u, v, edge_prop in rad_graph.edges(data=True) + if _edge_filter(edge_prop) + ] + + filtered_graph = rad_graph.edge_subgraph(filtered_edges) + return filtered_graph + + elif method == "nearest_neighbour": if ( max_dist is not None or rel_max_dist is not None diff --git a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py index cea557e..a3cc4a9 100644 --- a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py +++ b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py @@ -121,4 +121,8 @@ def create_hierarchical_multiscale_mesh_graph( G_m2m = networkx.compose_all([G_all_levels, G_up_all, G_down_all]) + # add dx and dy to graph + for prop in ("dx", "dy"): + G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)} + return G_m2m diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index f35b35f..bc0289c 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -49,7 +49,7 @@ def test_create_graph_archetype(kind): # list the connectivity options for g2m and m2g and the kwargs to test -G2M_M2G_CONNECTIVITY_OPTIONS = dict( +G2M_CONNECTIVITY_OPTIONS = dict( nearest_neighbour=[], nearest_neighbours=[dict(max_num_neighbours=4), dict(max_num_neighbours=8)], within_radius=[ @@ -59,6 +59,9 @@ def test_create_graph_archetype(kind): dict(rel_max_dist=1.0), ], ) +# containing_rectangle option should only be used for m2g +M2G_CONNECTIVITY_OPTIONS = G2M_CONNECTIVITY_OPTIONS.copy() +M2G_CONNECTIVITY_OPTIONS["containing_rectangle"] = [dict()] # list the connectivity options for m2m and the kwargs to test M2M_CONNECTIVITY_OPTIONS = dict( @@ -74,14 +77,14 @@ def test_create_graph_archetype(kind): ) -@pytest.mark.parametrize("g2m_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys()) -@pytest.mark.parametrize("m2g_connectivity", G2M_M2G_CONNECTIVITY_OPTIONS.keys()) +@pytest.mark.parametrize("g2m_connectivity", G2M_CONNECTIVITY_OPTIONS.keys()) +@pytest.mark.parametrize("m2g_connectivity", M2G_CONNECTIVITY_OPTIONS.keys()) @pytest.mark.parametrize("m2m_connectivity", M2M_CONNECTIVITY_OPTIONS.keys()) def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivity): xy = _create_fake_xy(N=32) - for g2m_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[g2m_connectivity]: - for m2g_kwargs in G2M_M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]: + for g2m_kwargs in G2M_CONNECTIVITY_OPTIONS[g2m_connectivity]: + for m2g_kwargs in M2G_CONNECTIVITY_OPTIONS[m2g_connectivity]: for m2m_kwargs in M2M_CONNECTIVITY_OPTIONS[m2m_connectivity]: graph = wmg.create.create_all_graph_components( xy=xy,