From 1f6b6f7ea54ea42bd1d5dfdf33452c8cc9ee0afa Mon Sep 17 00:00:00 2001 From: Dautel Date: Wed, 29 Jan 2025 18:47:26 +0000 Subject: [PATCH 1/3] Allow point-based forecast in calculate_metrics.py --- stationbench/calculate_metrics.py | 53 +++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/stationbench/calculate_metrics.py b/stationbench/calculate_metrics.py index 1f43bcd..5fe1afc 100644 --- a/stationbench/calculate_metrics.py +++ b/stationbench/calculate_metrics.py @@ -3,13 +3,14 @@ from datetime import date, datetime from typing import Union +import numpy as np import xarray as xr from dask.distributed import Client, LocalCluster -from stationbench.utils.regions import region_dict, select_region_for_stations -from stationbench.utils.logging import init_logging from stationbench.utils.io import load_dataset +from stationbench.utils.logging import init_logging from stationbench.utils.metrics import AVAILABLE_METRICS +from stationbench.utils.regions import region_dict, select_region_for_stations logger = logging.getLogger(__name__) @@ -82,6 +83,7 @@ def prepare_forecast( "longitude": "auto", }, ) + is_point_based = "station_id" in forecast.dims # First handle time dimensions forecast = forecast.sel(time=slice(start_date, end_date)) @@ -112,7 +114,10 @@ def prepare_forecast( forecast = forecast.sortby("longitude") # Select region - forecast = forecast.sel(latitude=lat_slice, longitude=lon_slice) + if is_point_based: + forecast = select_region_for_stations(forecast, lat_slice, lon_slice) + else: + forecast = forecast.sel(latitude=lat_slice, longitude=lon_slice) # Rename variables if wind_speed_name: @@ -140,12 +145,42 @@ def prepare_forecast( def interpolate_to_stations(forecast: xr.Dataset, stations: xr.Dataset) -> xr.Dataset: """Interpolate forecast to station locations.""" logger.info("Interpolating forecast to station locations") - forecast_interp = forecast.interp( - latitude=stations.latitude, - longitude=stations.longitude, - method="linear", - ) - return forecast_interp + + is_point_based = "station_id" in forecast.dims + + if is_point_based: + logger.info("Detected point-based forecast format, no interpolation needed") + forecast_stations = forecast.station_id.values + stations_ids = stations.station_id.values + common_stations = np.intersect1d(forecast_stations, stations_ids) + + if len(common_stations) == 0: + raise ValueError("No common stations found between forecast and stations") + + logger.info("Found %d common stations", len(common_stations)) + + forecast = forecast.sel(station_id=common_stations) + stations_subset = stations.sel(station_id=common_stations) + + # Validate coordinates match within tolerance + coord_tolerance = 0.01 # ~1km at equator + lat_diff = np.abs(forecast.latitude.values - stations_subset.latitude.values) + lon_diff = np.abs(forecast.longitude.values - stations_subset.longitude.values) + + if np.any(lat_diff > coord_tolerance) or np.any(lon_diff > coord_tolerance): + raise ValueError( + f"Coordinate mismatch between forecast and stations exceeds tolerance of {coord_tolerance} degrees" + ) + + return forecast + else: + # Grid-based forecast - interpolate to station points + forecast_interp = forecast.interp( + latitude=stations.latitude, + longitude=stations.longitude, + method="linear", + ) + return forecast_interp def generate_benchmarks( From f90ab4ac9c119e6e197087442f85e170dfce319f Mon Sep 17 00:00:00 2001 From: Dautel Date: Wed, 29 Jan 2025 19:00:52 +0000 Subject: [PATCH 2/3] Refactor into own function for intersection --- stationbench/calculate_metrics.py | 67 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/stationbench/calculate_metrics.py b/stationbench/calculate_metrics.py index 5fe1afc..86aa082 100644 --- a/stationbench/calculate_metrics.py +++ b/stationbench/calculate_metrics.py @@ -142,45 +142,46 @@ def prepare_forecast( return forecast -def interpolate_to_stations(forecast: xr.Dataset, stations: xr.Dataset) -> xr.Dataset: - """Interpolate forecast to station locations.""" - logger.info("Interpolating forecast to station locations") - - is_point_based = "station_id" in forecast.dims - - if is_point_based: - logger.info("Detected point-based forecast format, no interpolation needed") - forecast_stations = forecast.station_id.values - stations_ids = stations.station_id.values - common_stations = np.intersect1d(forecast_stations, stations_ids) - - if len(common_stations) == 0: - raise ValueError("No common stations found between forecast and stations") +def intersect_stations( + forecast: xr.Dataset, + stations: xr.Dataset, + by: str = "station_id", + coord_tolerance: float | None = 0.01, +) -> xr.Dataset: + """Match point-based forecast with station locations.""" + logger.info("Matching point-based forecast with stations") + forecast_stations = forecast[by].values + stations_ids = stations[by].values + common_stations = np.intersect1d(forecast_stations, stations_ids) - logger.info("Found %d common stations", len(common_stations)) + if len(common_stations) == 0: + raise ValueError("No common stations found between forecast and stations") + logger.info("Found %d common stations", len(common_stations)) - forecast = forecast.sel(station_id=common_stations) - stations_subset = stations.sel(station_id=common_stations) + forecast = forecast.sel(**{by: common_stations}) + stations_subset = stations.sel(**{by: common_stations}) - # Validate coordinates match within tolerance - coord_tolerance = 0.01 # ~1km at equator + # Validate coordinates match within tolerance + if coord_tolerance is not None: lat_diff = np.abs(forecast.latitude.values - stations_subset.latitude.values) lon_diff = np.abs(forecast.longitude.values - stations_subset.longitude.values) - if np.any(lat_diff > coord_tolerance) or np.any(lon_diff > coord_tolerance): raise ValueError( f"Coordinate mismatch between forecast and stations exceeds tolerance of {coord_tolerance} degrees" ) - return forecast - else: - # Grid-based forecast - interpolate to station points - forecast_interp = forecast.interp( - latitude=stations.latitude, - longitude=stations.longitude, - method="linear", - ) - return forecast_interp + return forecast + + +def interpolate_to_stations(forecast: xr.Dataset, stations: xr.Dataset) -> xr.Dataset: + """Interpolate forecast to station locations.""" + logger.info("Interpolating forecast to station locations") + forecast_interp = forecast.interp( + latitude=stations.latitude, + longitude=stations.longitude, + method="linear", + ) + return forecast_interp def generate_benchmarks( @@ -302,7 +303,13 @@ def main(args=None) -> xr.Dataset: args.name_10m_wind_speed, args.name_2m_temperature, ) - forecast = interpolate_to_stations(forecast, stations) + + # Either match stations or interpolate based on forecast type + is_point_based = "station_id" in forecast.dims + if is_point_based: + forecast = intersect_stations(forecast, stations) + else: + forecast = interpolate_to_stations(forecast, stations) # Calculate benchmarks benchmarks_ds = generate_benchmarks( From 29f7838d9b5394374c2d7eb242603b1ef2a6e16c Mon Sep 17 00:00:00 2001 From: Dautel Date: Wed, 29 Jan 2025 19:25:21 +0000 Subject: [PATCH 3/3] Add point-based forecast test --- tests/test_calculate_metrics.py | 71 +++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 8 deletions(-) diff --git a/tests/test_calculate_metrics.py b/tests/test_calculate_metrics.py index fcb3374..687caa0 100644 --- a/tests/test_calculate_metrics.py +++ b/tests/test_calculate_metrics.py @@ -1,17 +1,17 @@ -import xarray as xr -import numpy as np -from datetime import datetime -import pytest -import pandas as pd import argparse +from datetime import datetime +import numpy as np +import pandas as pd +import pytest +import xarray as xr from stationbench.calculate_metrics import ( - prepare_forecast, - prepare_stations, - interpolate_to_stations, generate_benchmarks, + interpolate_to_stations, main, + prepare_forecast, + prepare_stations, ) @@ -44,6 +44,37 @@ def sample_forecast(): return ds +@pytest.fixture +def sample_point_forecast(): + """Create a sample point-based forecast dataset.""" + times = pd.date_range("2022-01-01", "2022-01-02", freq="24h") # Just 2 init times + lead_times = pd.timedelta_range("0h", "24h", freq="24h") # Just 2 lead times + stations = ["ST1", "ST2"] # Two stations + lats = [50.0, 51.0] + lons = [5.0, 6.0] + + ds = xr.Dataset( + data_vars={ + "2m_temperature": ( + ("time", "prediction_timedelta", "station_id"), + np.random.randn(len(times), len(lead_times), len(stations)), + ), + "10m_wind_speed": ( + ("time", "prediction_timedelta", "station_id"), + np.random.randn(len(times), len(lead_times), len(stations)), + ), + }, + coords={ + "time": times, + "prediction_timedelta": lead_times, + "station_id": stations, + "latitude": ("station_id", lats), + "longitude": ("station_id", lons), + }, + ) + return ds + + @pytest.fixture def sample_stations(): """Create a sample stations dataset.""" @@ -135,6 +166,30 @@ def test_full_pipeline(sample_forecast, sample_stations): assert set(benchmarks.data_vars) == {"10m_wind_speed", "2m_temperature"} +def test_full_pipeline_with_point_based_forecast( + sample_point_forecast, sample_stations +): + """Test the full pipeline.""" + args = argparse.Namespace( + forecast=sample_point_forecast, + stations=sample_stations, + region="europe", + start_date=datetime(2022, 1, 1), + end_date=datetime(2022, 1, 2), + name_10m_wind_speed="10m_wind_speed", + name_2m_temperature="2m_temperature", + use_dask=False, + output=None, + ) + + benchmarks = main(args) + assert isinstance(benchmarks, xr.Dataset) + + assert set(benchmarks.dims) == {"lead_time", "station_id", "metric"} + assert set(benchmarks.metric.values) == {"rmse", "mbe"} + assert set(benchmarks.data_vars) == {"10m_wind_speed", "2m_temperature"} + + def test_rmse_calculation_matches_manual(sample_forecast, sample_stations): """Test that the RMSE calculation matches a manual calculation for a simple case.""" # Prepare forecast with known values