Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add point forecast #32

Merged
merged 3 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions stationbench/calculate_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from datetime import date, datetime
from typing import Union

import numpy as np
import xarray as xr
from dask.distributed import Client, LocalCluster

from stationbench.utils.regions import region_dict, select_region_for_stations
from stationbench.utils.logging import init_logging
from stationbench.utils.io import load_dataset
from stationbench.utils.logging import init_logging
from stationbench.utils.metrics import AVAILABLE_METRICS
from stationbench.utils.regions import region_dict, select_region_for_stations

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,6 +83,7 @@ def prepare_forecast(
"longitude": "auto",
},
)
is_point_based = "station_id" in forecast.dims

# First handle time dimensions
forecast = forecast.sel(time=slice(start_date, end_date))
Expand Down Expand Up @@ -112,7 +114,10 @@ def prepare_forecast(
forecast = forecast.sortby("longitude")

# Select region
forecast = forecast.sel(latitude=lat_slice, longitude=lon_slice)
if is_point_based:
forecast = select_region_for_stations(forecast, lat_slice, lon_slice)
else:
forecast = forecast.sel(latitude=lat_slice, longitude=lon_slice)

# Rename variables
if wind_speed_name:
Expand All @@ -137,6 +142,37 @@ def prepare_forecast(
return forecast


def intersect_stations(
forecast: xr.Dataset,
stations: xr.Dataset,
by: str = "station_id",
coord_tolerance: float | None = 0.01,
) -> xr.Dataset:
"""Match point-based forecast with station locations."""
logger.info("Matching point-based forecast with stations")
forecast_stations = forecast[by].values
stations_ids = stations[by].values
common_stations = np.intersect1d(forecast_stations, stations_ids)

if len(common_stations) == 0:
raise ValueError("No common stations found between forecast and stations")
logger.info("Found %d common stations", len(common_stations))

forecast = forecast.sel(**{by: common_stations})
stations_subset = stations.sel(**{by: common_stations})

# Validate coordinates match within tolerance
if coord_tolerance is not None:
lat_diff = np.abs(forecast.latitude.values - stations_subset.latitude.values)
lon_diff = np.abs(forecast.longitude.values - stations_subset.longitude.values)
if np.any(lat_diff > coord_tolerance) or np.any(lon_diff > coord_tolerance):
raise ValueError(
f"Coordinate mismatch between forecast and stations exceeds tolerance of {coord_tolerance} degrees"
)

return forecast


def interpolate_to_stations(forecast: xr.Dataset, stations: xr.Dataset) -> xr.Dataset:
"""Interpolate forecast to station locations."""
logger.info("Interpolating forecast to station locations")
Expand Down Expand Up @@ -267,7 +303,13 @@ def main(args=None) -> xr.Dataset:
args.name_10m_wind_speed,
args.name_2m_temperature,
)
forecast = interpolate_to_stations(forecast, stations)

# Either match stations or interpolate based on forecast type
is_point_based = "station_id" in forecast.dims
if is_point_based:
forecast = intersect_stations(forecast, stations)
else:
forecast = interpolate_to_stations(forecast, stations)

# Calculate benchmarks
benchmarks_ds = generate_benchmarks(
Expand Down
71 changes: 63 additions & 8 deletions tests/test_calculate_metrics.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import xarray as xr
import numpy as np
from datetime import datetime
import pytest
import pandas as pd
import argparse
from datetime import datetime

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from stationbench.calculate_metrics import (
prepare_forecast,
prepare_stations,
interpolate_to_stations,
generate_benchmarks,
interpolate_to_stations,
main,
prepare_forecast,
prepare_stations,
)


Expand Down Expand Up @@ -44,6 +44,37 @@ def sample_forecast():
return ds


@pytest.fixture
def sample_point_forecast():
"""Create a sample point-based forecast dataset."""
times = pd.date_range("2022-01-01", "2022-01-02", freq="24h") # Just 2 init times
lead_times = pd.timedelta_range("0h", "24h", freq="24h") # Just 2 lead times
stations = ["ST1", "ST2"] # Two stations
lats = [50.0, 51.0]
lons = [5.0, 6.0]

ds = xr.Dataset(
data_vars={
"2m_temperature": (
("time", "prediction_timedelta", "station_id"),
np.random.randn(len(times), len(lead_times), len(stations)),
),
"10m_wind_speed": (
("time", "prediction_timedelta", "station_id"),
np.random.randn(len(times), len(lead_times), len(stations)),
),
},
coords={
"time": times,
"prediction_timedelta": lead_times,
"station_id": stations,
"latitude": ("station_id", lats),
"longitude": ("station_id", lons),
},
)
return ds


@pytest.fixture
def sample_stations():
"""Create a sample stations dataset."""
Expand Down Expand Up @@ -135,6 +166,30 @@ def test_full_pipeline(sample_forecast, sample_stations):
assert set(benchmarks.data_vars) == {"10m_wind_speed", "2m_temperature"}


def test_full_pipeline_with_point_based_forecast(
sample_point_forecast, sample_stations
):
"""Test the full pipeline."""
args = argparse.Namespace(
forecast=sample_point_forecast,
stations=sample_stations,
region="europe",
start_date=datetime(2022, 1, 1),
end_date=datetime(2022, 1, 2),
name_10m_wind_speed="10m_wind_speed",
name_2m_temperature="2m_temperature",
use_dask=False,
output=None,
)

benchmarks = main(args)
assert isinstance(benchmarks, xr.Dataset)

assert set(benchmarks.dims) == {"lead_time", "station_id", "metric"}
assert set(benchmarks.metric.values) == {"rmse", "mbe"}
assert set(benchmarks.data_vars) == {"10m_wind_speed", "2m_temperature"}


def test_rmse_calculation_matches_manual(sample_forecast, sample_stations):
"""Test that the RMSE calculation matches a manual calculation for a simple case."""
# Prepare forecast with known values
Expand Down