diff --git a/docs/creating_the_graph.ipynb b/docs/creating_the_graph.ipynb index 55bffa5..8cbc913 100644 --- a/docs/creating_the_graph.ipynb +++ b/docs/creating_the_graph.ipynb @@ -26,7 +26,7 @@ "source": [ "# The grid nodes\n", "\n", - "To get started we will create a set of fake grid nodes, which represent the geographical locations (x/y cartesian) where we have values for the physical fields." + "To get started we will create a set of fake grid nodes, which represent the geographical locations where we have values for the physical fields. We will here work with cartesian x/y coordinates. See [this page](./lat_lons.ipynb) for how to use lat/lon coordinates in weather-model-graphs." ] }, { @@ -80,7 +80,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Lets start with a simple mesh which only has nearest neighbour connections. At the moment `weather-model-graphs` creates a square mesh that sits within the spatial domain spanned by the grid nodes. Techniques for adding non-square meshes are in development." + "Lets start with a simple mesh which only has nearest neighbour connections. At the moment `weather-model-graphs` creates a rectangular mesh that sits within the spatial domain spanned by the grid nodes (specifically within the axis-aligned bounding box of the grid nodes). Techniques for adding non-square meshes are in development." ] }, { @@ -458,7 +458,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/docs/lat_lons.ipynb b/docs/lat_lons.ipynb index c4170d4..5bd4dbf 100644 --- a/docs/lat_lons.ipynb +++ b/docs/lat_lons.ipynb @@ -7,7 +7,7 @@ "source": [ "# Working with lat-lon coordinates\n", "\n", - "In the previous sections we have considered grid point positions `coords` given as Cartesian coordinates. However, it is common that we have coordinates given as latitudes and longitudes. This notebook describes how we can constuct graphs directly using lat-lon coordinates. This is achieved by also providing a specific projection, used to project the lat-lons to a euclidean space where the graph can be construced." + "In the previous sections we have considered grid point positions `coords` given as Cartesian coordinates. However, it is common that we have coordinates given as latitudes and longitudes. This notebook describes how we can constuct graphs directly using lat-lon coordinates. This is achieved by specifying the Coordinate Reference System (CRS) of `coords` and the CRS that the graph construction should be carried out in. `coords` will then be projected to this new CRS before any calculations are carried out." ] }, { @@ -96,9 +96,9 @@ "metadata": {}, "source": [ "## Constructing a graph within a projection\n", - "For our example above, let's instead try to construct the graph based on first projecting our lat-lon coordinates to within a specific projection. This can be done by giving a `projection` argument to the graph creation functions. The projection should be an instance of `cartopy.crs.CRS`. See [the cartopy documentation](https://scitools.org.uk/cartopy/docs/latest/reference/projections.html) for a list of available projections. \n", + "For our example above, let's instead try to construct the graph based on first projecting our lat-lon coordinates to another CRS with 2-dimensional cartesian coordinates. This can be done by giving the `coords_crs` and `graph_crs` arguments to the graph creation functions. Theses arguments should both be instances of `pyproj.crs.CRS` ([pyproj docs.](https://pyproj4.github.io/pyproj/stable/api/crs/crs.html#pyproj.crs.CRS)). Nicely, they can be `cartopy.crs.CRS`, which provides easy ways to specify such CRS:s. For more advanced use cases a `pyproj.crs.CRS` can be specified directly. See [the cartopy documentation](https://scitools.org.uk/cartopy/docs/latest/reference/projections.html) for a list of readily available CRS:s to use for projecting the coordinates. \n", "\n", - "We will here try the same thing as above, but using a Azimuthal equidistant projection centered at the pole:" + "We will here try the same thing as above, but using a Azimuthal equidistant projection centered at the pole. The CRS of our lat-lon coordinates will be `cartopy.crs.PlateCarree` and we want to project this to `cartopy.crs.AzimuthalEquidistant`:" ] }, { @@ -109,9 +109,10 @@ "outputs": [], "source": [ "# Define our projection\n", - "ae_projection = ccrs.AzimuthalEquidistant(central_latitude=90)\n", + "coords_crs = ccrs.PlateCarree()\n", + "graph_crs = ccrs.AzimuthalEquidistant(central_latitude=90)\n", "\n", - "fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": ae_projection})\n", + "fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": graph_crs})\n", "ax.scatter(coords[:, 0], coords[:, 1], marker=\".\", transform=ccrs.PlateCarree())\n", "_ = ax.coastlines()" ] @@ -135,9 +136,9 @@ " 10**6\n", ") # Large euclidean distance in projection coordinates between mesh nodes\n", "graph = wmg.create.archetype.create_keisler_graph(\n", - " coords, mesh_node_distance=mesh_distance, projection=ae_projection\n", + " coords, mesh_node_distance=mesh_distance, coords_crs=coords_crs, graph_crs=graph_crs\n", ") # Note that we here specify the projection argument\n", - "fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": ae_projection})\n", + "fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": graph_crs})\n", "wmg.visualise.nx_draw_with_pos_and_attr(\n", " graph, ax=ax, node_size=30, edge_color_attr=\"component\", node_color_attr=\"type\"\n", ")\n", @@ -149,7 +150,9 @@ "id": "f5acf83a-df33-4925-bb32-c106e27e51a4", "metadata": {}, "source": [ - "Now this looks like a more reasonable graph layout, that better respects the spatial relations between the grid points. There are still things that could be tweaked further (e.g. the large number of grid nodes connected to the center mesh node), but this ends our example of defining graphs using lat-lon coordinates." + "Now this looks like a more reasonable graph layout, that better respects the spatial relations between the grid points. There are still things that could be tweaked further (e.g. the large number of grid nodes connected to the center mesh node), but this ends our example of defining graphs using lat-lon coordinates.\n", + "\n", + "It can be noted that this projection between different CRS:s provides more general functionality than just handling lat-lon coordinates. It is entirely possible to transform from any `coords_crs` to any `graph_crs` using these arguments." ] } ], @@ -169,7 +172,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 1865fd0..17a087a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dependencies = [ "loguru>=0.7.2", "networkx>=3.3", "scipy>=1.13.0", - "cartopy>=0.24.1", + "pyproj>=3.7.0", ] requires-python = ">=3.10" readme = "README.md" @@ -23,6 +23,7 @@ pytorch = [ visualisation = [ "matplotlib>=3.8.4", "ipykernel>=6.29.4", + "cartopy>=0.24.1", ] docs = [ "jupyter-book>=1.0.0", diff --git a/src/weather_model_graphs/create/archetype.py b/src/weather_model_graphs/create/archetype.py index e3d31b4..4498945 100644 --- a/src/weather_model_graphs/create/archetype.py +++ b/src/weather_model_graphs/create/archetype.py @@ -2,7 +2,11 @@ def create_keisler_graph( - coords, mesh_node_distance=3, projection=None, decode_mask=None + coords, + mesh_node_distance=3, + coords_crs=None, + graph_crs=None, + decode_mask=None, ): """ Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370) @@ -51,7 +55,8 @@ def create_keisler_graph( m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - projection=projection, + coords_crs=coords_crs, + graph_crs=graph_crs, decode_mask=decode_mask, ) @@ -61,7 +66,8 @@ def create_graphcast_graph( mesh_node_distance=3, level_refinement_factor=3, max_num_levels=None, - projection=None, + coords_crs=None, + graph_crs=None, decode_mask=None, ): """ @@ -119,7 +125,8 @@ def create_graphcast_graph( m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - projection=projection, + coords_crs=coords_crs, + graph_crs=graph_crs, decode_mask=decode_mask, ) @@ -129,7 +136,8 @@ def create_oskarsson_hierarchical_graph( mesh_node_distance=3, level_refinement_factor=3, max_num_levels=None, - projection=None, + coords_crs=None, + graph_crs=None, decode_mask=None, ): """ @@ -190,6 +198,7 @@ def create_oskarsson_hierarchical_graph( m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - projection=projection, + coords_crs=coords_crs, + graph_crs=graph_crs, decode_mask=decode_mask, ) diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index d05739d..ed2e4e2 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -11,10 +11,10 @@ from typing import Iterable -import cartopy.crs as ccrs import networkx import networkx as nx import numpy as np +import pyproj import scipy.spatial from loguru import logger @@ -39,7 +39,8 @@ def create_all_graph_components( m2m_connectivity_kwargs={}, m2g_connectivity_kwargs={}, g2m_connectivity_kwargs={}, - projection: ccrs.CRS | None = None, + coords_crs: pyproj.crs.CRS | None = None, + graph_crs: pyproj.crs.CRS | None = None, decode_mask: Iterable | None = None, ): """ @@ -79,9 +80,11 @@ def create_all_graph_components( such that the grid node is contained within it. Connect these 4 (or less along edges) mesh nodes to the grid node. - `projection` should either be a cartopy.crs.CRS or None. This is the projection - instance used to transform given lat-lon coords to in-projection Cartesian coordinates. - If None the coords are assumed to already be Cartesian. + `coords_crs` and `graph_crs` should either be a pyproj.crs.CRS or None. + Note that this includes a cartopy.crs.CRS. If both are given the coordinates + will be transformed from their original Coordinate Reference System (`coords_crs`) + to the CRS where the graph creation should take place (`graph_crs`). + If any one of them is None the graph creation is corried out using the original coords. `decode_mask` should be an Iterable of booleans, masking which grid positions should be decoded to (included in the m2g subgraph), i.e. which positions should be output. It should have the same length as the number of @@ -94,21 +97,29 @@ def create_all_graph_components( len(coords.shape) == 2 and coords.shape[1] == 2 ), "Grid node coordinates should be given as an array of shape [num_grid_nodes, 2]." - if projection is None: + # Translate between coordinate crs and crs to use for graph creation + if coords_crs is None and coords_crs is None: logger.debug( - "No `projection` given: Assuming `coords` contains in-projection Cartesian coordinates." + "No `coords_crs` given: Assuming `coords` contains in-projection Cartesian coordinates." + ) + xy = coords + elif (coords_crs is None) != (graph_crs is None): # xor, only one is None + logger.warning( + "Only one of `coords_crs` and `graph_crs` given. Both are needed to " + "transform coordinates to a different crs for constructing the graph: " + "Assuming `coords` contains in-projection Cartesian coordinates." ) xy = coords else: logger.debug( - f"`projection` Proj({projection}) given, `coords` treated as lat-lons." + f"Projecting coords from CRS({coords_crs}) to CRS({graph_crs}) for graph creation." ) - # Convert lat-lon coords to Cartesian xy - xyz = projection.transform_points( - src_crs=ccrs.PlateCarree(), x=coords[:, 0], y=coords[:, 1] + # Convert from coords_crs to to graph_crs + coord_transformer = pyproj.Transformer.from_crs( + coords_crs, graph_crs, always_xy=True ) - # Remove z-dim - xy = xyz[:, :2] + xy_tuple = coord_transformer.transform(xx=coords[:, 0], yy=coords[:, 1]) + xy = np.stack(xy_tuple, axis=1) if m2m_connectivity == "flat": graph_components["m2m"] = create_flat_singlescale_mesh_graph( diff --git a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py index d4f5115..b897693 100644 --- a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py +++ b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py @@ -49,7 +49,7 @@ def create_hierarchical_multiscale_mesh_graph( if n_mesh_levels < 2: raise ValueError( "At least two mesh levels are required for hierarchical mesh graph. " - f"You may need to reduce the refinement factors" + "You may need to reduce the level refinement factor " f"or increase the max number of levels {max_num_levels} " f"or number of grid points {xy.shape[0]}." ) diff --git a/src/weather_model_graphs/create/mesh/mesh.py b/src/weather_model_graphs/create/mesh/mesh.py index 9391769..75c82d6 100644 --- a/src/weather_model_graphs/create/mesh/mesh.py +++ b/src/weather_model_graphs/create/mesh/mesh.py @@ -110,14 +110,14 @@ def create_multirange_2d_mesh_graphs( # Compute the size along x and y direction of area to cover with graph # This is measured in the Cartesian coordiantes of xy coord_extent = np.ptp(xy, axis=0) - extent_nodes_bottom_mesh = (coord_extent / mesh_node_distance).astype(int) + # Number of nodes that would fit on bottom level of hierarchy, + # in both directions + max_nodes_bottom = (coord_extent / mesh_node_distance).astype(int) # Find the number of mesh levels possible in x- and y-direction, # and the number of leaf nodes that would correspond to - # max_mesh_coord/(level_refinement_factor^mesh_levels) = 1 - max_mesh_levels_float = np.log(extent_nodes_bottom_mesh) / np.log( - level_refinement_factor - ) + # max_nodes_bottom/(level_refinement_factor^mesh_levels) = 1 + max_mesh_levels_float = np.log(max_nodes_bottom) / np.log(level_refinement_factor) max_mesh_levels = max_mesh_levels_float.astype(int) # (2,) nleaf = level_refinement_factor**max_mesh_levels diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index dd2c0be..33b8d51 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -131,27 +131,53 @@ def test_create_exact_refinement(mesh_node_distance, level_refinement_factor): ) -@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"]) -def test_create_irregular_grid(kind): +@pytest.mark.parametrize( + "kind_and_num_mesh", + [ + ("keisler", 20**2), # 20 mesh nodes in bottom layer in each direction + ("graphcast", 9**2), # Can only fit 9 x 9 with level_refinement_factor=3 + ( + "oskarsson_hierarchical", + 9**2 + 3**2, + ), # As above, with additional 3 x 3 layer + ], +) +def test_create_irregular_grid(kind_and_num_mesh): """ Tests that graphs can be created for irregular layouts of grid points """ - xy = test_utils.create_fake_irregular_coords(100) + kind, num_mesh = kind_and_num_mesh + num_grid = 100 + xy = test_utils.create_fake_irregular_coords(num_grid - 4) + + # Need to include corners if we want to know actual size of covered area + xy = np.concatenate( + ( + xy, + np.array( + [[0.0, 0.0], [0.0, 1.0], [1.0, 0], [1.0, 1.0]] + ), # Remaining 4 nodes + ), + axis=0, + ) + fn_name = f"create_{kind}_graph" fn = getattr(wmg.create.archetype, fn_name) - # ~= 20 mesh nodes in bottom layer in each direction - fn(coords=xy, mesh_node_distance=0.05) + graph = fn(coords=xy, mesh_node_distance=0.05) + + assert len(graph.nodes) == num_grid + num_mesh @pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"]) def test_create_lat_lon(kind): """ - Tests that graphs can be created from lat-lon coordinates + projection + Tests that graphs can be created from lat-lon coordinates + projection spec. """ lon_coords = np.linspace(10, 30, 10) lat_coords = np.linspace(35, 65, 10) - projection = ccrs.LambertConformal() + coords_crs = ccrs.PlateCarree() + graph_crs = ccrs.LambertConformal() mesh_node_distance = 0.2 * 10**6 meshgridded = np.meshgrid(lon_coords, lat_coords) @@ -160,7 +186,12 @@ def test_create_lat_lon(kind): fn_name = f"create_{kind}_graph" fn = getattr(wmg.create.archetype, fn_name) - fn(coords=coords, mesh_node_distance=mesh_node_distance, projection=projection) + fn( + coords=coords, + mesh_node_distance=mesh_node_distance, + coords_crs=coords_crs, + graph_crs=graph_crs, + ) @pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])