diff --git a/CHANGELOG.md b/CHANGELOG.md index 12cf54f..01d4cac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#66](https://github.com/mllam/neural-lam/pull/66) @leifdenby @sadamov +### Fixed + +- Fix bugs introduced with datastores functionality relating visualation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby + ## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0) ### Added diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e..b0055e3 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) - extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] return [float(v) for v in extent] @property diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a8..0d1aac7 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,4 +1,5 @@ # Standard library +import copy import warnings from functools import cached_property from pathlib import Path @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection: class_name = projection_info["class_name"] ProjectionClass = getattr(ccrs, class_name) - kwargs = projection_info["kwargs"] + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) globe_kwargs = kwargs.pop("globe", {}) if len(globe_kwargs) > 0: diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c671..44baf9c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,12 +8,14 @@ import pytorch_lightning as pl import torch import wandb +import xarray as xr # Local from .. import metrics, vis from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -147,6 +150,44 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. + split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time.cpu(), dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor.cpu().numpy(), time=time, category=category + ) + return da + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -406,10 +447,13 @@ def test_step(self, batch, batch_idx): ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch @@ -422,18 +466,34 @@ def plot_examples(self, batch, n_examples, prediction=None): prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[3] # Rescale to original data scale prediction_rescaled = prediction * self.state_std + self.state_mean target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): # Each slice is (pred_steps, num_grid_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -453,18 +513,20 @@ def plot_examples(self, batch, n_examples, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred=pred_t[:, var_i], - target=target_t[:, var_i], datastore=self._datastore, title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 + ).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -476,6 +538,7 @@ def plot_examples(self, batch, n_examples, prediction=None): ] example_i = self.plotted_examples + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b3..d6b57f8 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr # Local from . import utils @@ -65,9 +66,9 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, - target, datastore: BaseRegularGridDatastore, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, title=None, vrange=None, ): @@ -79,8 +80,8 @@ def plot_prediction( """ # Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange @@ -88,10 +89,8 @@ def plot_prediction( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + mask_values = np.invert(da_mask.values.astype(bool)).astype(float) + pixel_alpha = mask_values.clip(0.7, 1) # Faded border region fig, axes = plt.subplots( 1, @@ -101,28 +100,23 @@ def plot_prediction( ) # Plot pred and target - for ax, data in zip(axes, (target, pred)): + for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() - ) - im = ax.imshow( - data_grid, + da.plot.imshow( + ax=ax, origin="lower", + x="x", extent=extent, - alpha=pixel_alpha, + alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", + transform=datastore.coords_projection, ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -150,9 +144,7 @@ def plot_spatial_error( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region fig, ax = plt.subplots( figsize=(5, 4.8), @@ -161,8 +153,10 @@ def plot_spatial_error( ax.coastlines() # Add coastline outlines error_grid = ( - error.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() .numpy() ) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c9..b5f8558 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -529,7 +529,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. category : str @@ -581,7 +582,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, ) diff --git a/pyproject.toml b/pyproject.toml index f0bc085..fdcb7f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "torch>=2.3.0", "torch-geometric==2.3.1", "parse>=1.20.2", - "dataclass-wizard>=0.22.3", + "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9"