Skip to content

Commit

Permalink
Merge branch 'general_coordinates' into decoding_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 26, 2024
2 parents 8fcf182 + 8e3c1cb commit 541054e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 46 deletions.
6 changes: 3 additions & 3 deletions docs/creating_the_graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"source": [
"# The grid nodes\n",
"\n",
"To get started we will create a set of fake grid nodes, which represent the geographical locations (x/y cartesian) where we have values for the physical fields."
"To get started we will create a set of fake grid nodes, which represent the geographical locations where we have values for the physical fields. We will here work with cartesian x/y coordinates. See [this page](./lat_lons.ipynb) for how to use lat/lon coordinates in weather-model-graphs."
]
},
{
Expand Down Expand Up @@ -80,7 +80,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets start with a simple mesh which only has nearest neighbour connections. At the moment `weather-model-graphs` creates a square mesh that sits within the spatial domain spanned by the grid nodes. Techniques for adding non-square meshes are in development."
"Lets start with a simple mesh which only has nearest neighbour connections. At the moment `weather-model-graphs` creates a rectangular mesh that sits within the spatial domain spanned by the grid nodes (specifically within the axis-aligned bounding box of the grid nodes). Techniques for adding non-square meshes are in development."
]
},
{
Expand Down Expand Up @@ -458,7 +458,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
21 changes: 12 additions & 9 deletions docs/lat_lons.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# Working with lat-lon coordinates\n",
"\n",
"In the previous sections we have considered grid point positions `coords` given as Cartesian coordinates. However, it is common that we have coordinates given as latitudes and longitudes. This notebook describes how we can constuct graphs directly using lat-lon coordinates. This is achieved by also providing a specific projection, used to project the lat-lons to a euclidean space where the graph can be construced."
"In the previous sections we have considered grid point positions `coords` given as Cartesian coordinates. However, it is common that we have coordinates given as latitudes and longitudes. This notebook describes how we can constuct graphs directly using lat-lon coordinates. This is achieved by specifying the Coordinate Reference System (CRS) of `coords` and the CRS that the graph construction should be carried out in. `coords` will then be projected to this new CRS before any calculations are carried out."
]
},
{
Expand Down Expand Up @@ -96,9 +96,9 @@
"metadata": {},
"source": [
"## Constructing a graph within a projection\n",
"For our example above, let's instead try to construct the graph based on first projecting our lat-lon coordinates to within a specific projection. This can be done by giving a `projection` argument to the graph creation functions. The projection should be an instance of `cartopy.crs.CRS`. See [the cartopy documentation](https://scitools.org.uk/cartopy/docs/latest/reference/projections.html) for a list of available projections. \n",
"For our example above, let's instead try to construct the graph based on first projecting our lat-lon coordinates to another CRS with 2-dimensional cartesian coordinates. This can be done by giving the `coords_crs` and `graph_crs` arguments to the graph creation functions. Theses arguments should both be instances of `pyproj.crs.CRS` ([pyproj docs.](https://pyproj4.github.io/pyproj/stable/api/crs/crs.html#pyproj.crs.CRS)). Nicely, they can be `cartopy.crs.CRS`, which provides easy ways to specify such CRS:s. For more advanced use cases a `pyproj.crs.CRS` can be specified directly. See [the cartopy documentation](https://scitools.org.uk/cartopy/docs/latest/reference/projections.html) for a list of readily available CRS:s to use for projecting the coordinates. \n",
"\n",
"We will here try the same thing as above, but using a Azimuthal equidistant projection centered at the pole:"
"We will here try the same thing as above, but using a Azimuthal equidistant projection centered at the pole. The CRS of our lat-lon coordinates will be `cartopy.crs.PlateCarree` and we want to project this to `cartopy.crs.AzimuthalEquidistant`:"
]
},
{
Expand All @@ -109,9 +109,10 @@
"outputs": [],
"source": [
"# Define our projection\n",
"ae_projection = ccrs.AzimuthalEquidistant(central_latitude=90)\n",
"coords_crs = ccrs.PlateCarree()\n",
"graph_crs = ccrs.AzimuthalEquidistant(central_latitude=90)\n",
"\n",
"fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": ae_projection})\n",
"fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": graph_crs})\n",
"ax.scatter(coords[:, 0], coords[:, 1], marker=\".\", transform=ccrs.PlateCarree())\n",
"_ = ax.coastlines()"
]
Expand All @@ -135,9 +136,9 @@
" 10**6\n",
") # Large euclidean distance in projection coordinates between mesh nodes\n",
"graph = wmg.create.archetype.create_keisler_graph(\n",
" coords, mesh_node_distance=mesh_distance, projection=ae_projection\n",
" coords, mesh_node_distance=mesh_distance, coords_crs=coords_crs, graph_crs=graph_crs\n",
") # Note that we here specify the projection argument\n",
"fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": ae_projection})\n",
"fig, ax = plt.subplots(figsize=(15, 9), subplot_kw={\"projection\": graph_crs})\n",
"wmg.visualise.nx_draw_with_pos_and_attr(\n",
" graph, ax=ax, node_size=30, edge_color_attr=\"component\", node_color_attr=\"type\"\n",
")\n",
Expand All @@ -149,7 +150,9 @@
"id": "f5acf83a-df33-4925-bb32-c106e27e51a4",
"metadata": {},
"source": [
"Now this looks like a more reasonable graph layout, that better respects the spatial relations between the grid points. There are still things that could be tweaked further (e.g. the large number of grid nodes connected to the center mesh node), but this ends our example of defining graphs using lat-lon coordinates."
"Now this looks like a more reasonable graph layout, that better respects the spatial relations between the grid points. There are still things that could be tweaked further (e.g. the large number of grid nodes connected to the center mesh node), but this ends our example of defining graphs using lat-lon coordinates.\n",
"\n",
"It can be noted that this projection between different CRS:s provides more general functionality than just handling lat-lon coordinates. It is entirely possible to transform from any `coords_crs` to any `graph_crs` using these arguments."
]
}
],
Expand All @@ -169,7 +172,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = [
"loguru>=0.7.2",
"networkx>=3.3",
"scipy>=1.13.0",
"cartopy>=0.24.1",
"pyproj>=3.7.0",
]
requires-python = ">=3.10"
readme = "README.md"
Expand All @@ -23,6 +23,7 @@ pytorch = [
visualisation = [
"matplotlib>=3.8.4",
"ipykernel>=6.29.4",
"cartopy>=0.24.1",
]
docs = [
"jupyter-book>=1.0.0",
Expand Down
21 changes: 15 additions & 6 deletions src/weather_model_graphs/create/archetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@


def create_keisler_graph(
coords, mesh_node_distance=3, projection=None, decode_mask=None
coords,
mesh_node_distance=3,
coords_crs=None,
graph_crs=None,
decode_mask=None,
):
"""
Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
Expand Down Expand Up @@ -51,7 +55,8 @@ def create_keisler_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
projection=projection,
coords_crs=coords_crs,
graph_crs=graph_crs,
decode_mask=decode_mask,
)

Expand All @@ -61,7 +66,8 @@ def create_graphcast_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
projection=None,
coords_crs=None,
graph_crs=None,
decode_mask=None,
):
"""
Expand Down Expand Up @@ -119,7 +125,8 @@ def create_graphcast_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
projection=projection,
coords_crs=coords_crs,
graph_crs=graph_crs,
decode_mask=decode_mask,
)

Expand All @@ -129,7 +136,8 @@ def create_oskarsson_hierarchical_graph(
mesh_node_distance=3,
level_refinement_factor=3,
max_num_levels=None,
projection=None,
coords_crs=None,
graph_crs=None,
decode_mask=None,
):
"""
Expand Down Expand Up @@ -190,6 +198,7 @@ def create_oskarsson_hierarchical_graph(
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
projection=projection,
coords_crs=coords_crs,
graph_crs=graph_crs,
decode_mask=decode_mask,
)
37 changes: 24 additions & 13 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from typing import Iterable

import cartopy.crs as ccrs
import networkx
import networkx as nx
import numpy as np
import pyproj
import scipy.spatial
from loguru import logger

Expand All @@ -39,7 +39,8 @@ def create_all_graph_components(
m2m_connectivity_kwargs={},
m2g_connectivity_kwargs={},
g2m_connectivity_kwargs={},
projection: ccrs.CRS | None = None,
coords_crs: pyproj.crs.CRS | None = None,
graph_crs: pyproj.crs.CRS | None = None,
decode_mask: Iterable | None = None,
):
"""
Expand Down Expand Up @@ -79,9 +80,11 @@ def create_all_graph_components(
such that the grid node is contained within it. Connect these 4 (or less along edges)
mesh nodes to the grid node.
`projection` should either be a cartopy.crs.CRS or None. This is the projection
instance used to transform given lat-lon coords to in-projection Cartesian coordinates.
If None the coords are assumed to already be Cartesian.
`coords_crs` and `graph_crs` should either be a pyproj.crs.CRS or None.
Note that this includes a cartopy.crs.CRS. If both are given the coordinates
will be transformed from their original Coordinate Reference System (`coords_crs`)
to the CRS where the graph creation should take place (`graph_crs`).
If any one of them is None the graph creation is corried out using the original coords.
`decode_mask` should be an Iterable of booleans, masking which grid positions should be
decoded to (included in the m2g subgraph), i.e. which positions should be output. It should have the same length as the number of
Expand All @@ -94,21 +97,29 @@ def create_all_graph_components(
len(coords.shape) == 2 and coords.shape[1] == 2
), "Grid node coordinates should be given as an array of shape [num_grid_nodes, 2]."

if projection is None:
# Translate between coordinate crs and crs to use for graph creation
if coords_crs is None and coords_crs is None:
logger.debug(
"No `projection` given: Assuming `coords` contains in-projection Cartesian coordinates."
"No `coords_crs` given: Assuming `coords` contains in-projection Cartesian coordinates."
)
xy = coords
elif (coords_crs is None) != (graph_crs is None): # xor, only one is None
logger.warning(
"Only one of `coords_crs` and `graph_crs` given. Both are needed to "
"transform coordinates to a different crs for constructing the graph: "
"Assuming `coords` contains in-projection Cartesian coordinates."
)
xy = coords
else:
logger.debug(
f"`projection` Proj({projection}) given, `coords` treated as lat-lons."
f"Projecting coords from CRS({coords_crs}) to CRS({graph_crs}) for graph creation."
)
# Convert lat-lon coords to Cartesian xy
xyz = projection.transform_points(
src_crs=ccrs.PlateCarree(), x=coords[:, 0], y=coords[:, 1]
# Convert from coords_crs to to graph_crs
coord_transformer = pyproj.Transformer.from_crs(
coords_crs, graph_crs, always_xy=True
)
# Remove z-dim
xy = xyz[:, :2]
xy_tuple = coord_transformer.transform(xx=coords[:, 0], yy=coords[:, 1])
xy = np.stack(xy_tuple, axis=1)

if m2m_connectivity == "flat":
graph_components["m2m"] = create_flat_singlescale_mesh_graph(
Expand Down
2 changes: 1 addition & 1 deletion src/weather_model_graphs/create/mesh/kinds/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_hierarchical_multiscale_mesh_graph(
if n_mesh_levels < 2:
raise ValueError(
"At least two mesh levels are required for hierarchical mesh graph. "
f"You may need to reduce the refinement factors"
"You may need to reduce the level refinement factor "
f"or increase the max number of levels {max_num_levels} "
f"or number of grid points {xy.shape[0]}."
)
Expand Down
10 changes: 5 additions & 5 deletions src/weather_model_graphs/create/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ def create_multirange_2d_mesh_graphs(
# Compute the size along x and y direction of area to cover with graph
# This is measured in the Cartesian coordiantes of xy
coord_extent = np.ptp(xy, axis=0)
extent_nodes_bottom_mesh = (coord_extent / mesh_node_distance).astype(int)
# Number of nodes that would fit on bottom level of hierarchy,
# in both directions
max_nodes_bottom = (coord_extent / mesh_node_distance).astype(int)

# Find the number of mesh levels possible in x- and y-direction,
# and the number of leaf nodes that would correspond to
# max_mesh_coord/(level_refinement_factor^mesh_levels) = 1
max_mesh_levels_float = np.log(extent_nodes_bottom_mesh) / np.log(
level_refinement_factor
)
# max_nodes_bottom/(level_refinement_factor^mesh_levels) = 1
max_mesh_levels_float = np.log(max_nodes_bottom) / np.log(level_refinement_factor)

max_mesh_levels = max_mesh_levels_float.astype(int) # (2,)
nleaf = level_refinement_factor**max_mesh_levels
Expand Down
47 changes: 39 additions & 8 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,53 @@ def test_create_exact_refinement(mesh_node_distance, level_refinement_factor):
)


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
def test_create_irregular_grid(kind):
@pytest.mark.parametrize(
"kind_and_num_mesh",
[
("keisler", 20**2), # 20 mesh nodes in bottom layer in each direction
("graphcast", 9**2), # Can only fit 9 x 9 with level_refinement_factor=3
(
"oskarsson_hierarchical",
9**2 + 3**2,
), # As above, with additional 3 x 3 layer
],
)
def test_create_irregular_grid(kind_and_num_mesh):
"""
Tests that graphs can be created for irregular layouts of grid points
"""
xy = test_utils.create_fake_irregular_coords(100)
kind, num_mesh = kind_and_num_mesh
num_grid = 100
xy = test_utils.create_fake_irregular_coords(num_grid - 4)

# Need to include corners if we want to know actual size of covered area
xy = np.concatenate(
(
xy,
np.array(
[[0.0, 0.0], [0.0, 1.0], [1.0, 0], [1.0, 1.0]]
), # Remaining 4 nodes
),
axis=0,
)

fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)

# ~= 20 mesh nodes in bottom layer in each direction
fn(coords=xy, mesh_node_distance=0.05)
graph = fn(coords=xy, mesh_node_distance=0.05)

assert len(graph.nodes) == num_grid + num_mesh


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
def test_create_lat_lon(kind):
"""
Tests that graphs can be created from lat-lon coordinates + projection
Tests that graphs can be created from lat-lon coordinates + projection spec.
"""
lon_coords = np.linspace(10, 30, 10)
lat_coords = np.linspace(35, 65, 10)
projection = ccrs.LambertConformal()
coords_crs = ccrs.PlateCarree()
graph_crs = ccrs.LambertConformal()
mesh_node_distance = 0.2 * 10**6

meshgridded = np.meshgrid(lon_coords, lat_coords)
Expand All @@ -160,7 +186,12 @@ def test_create_lat_lon(kind):
fn_name = f"create_{kind}_graph"
fn = getattr(wmg.create.archetype, fn_name)

fn(coords=coords, mesh_node_distance=mesh_node_distance, projection=projection)
fn(
coords=coords,
mesh_node_distance=mesh_node_distance,
coords_crs=coords_crs,
graph_crs=graph_crs,
)


@pytest.mark.parametrize("kind", ["graphcast", "keisler", "oskarsson_hierarchical"])
Expand Down

0 comments on commit 541054e

Please sign in to comment.