Skip to content

Commit

Permalink
Fix evaluation example visualisation plots (#91)
Browse files Browse the repository at this point in the history
Fix bugs in recently introduced datastore functionality #66 (error in
calculation in `BaseDatastore.get_xy_extent()` and overlooked in-place
modification of config dict in `MDPDatastore.coords_projection`), and
also fix issue in `ARModel.plot_examples` by using newly introduced
(#66) `WeatherDataset.create_dataarray_from_tensor()` to create
`xr.DataArray` from prediction tensor and calling plot methods directly
on `xr.DataArray` rather than using bare numpy arrays with `matplotlib`.
  • Loading branch information
leifdenby authored Dec 4, 2024
1 parent c3c1722 commit 71cfdf9
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 39 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#66](https://github.com/mllam/neural-lam/pull/66)
@leifdenby @sadamov

### Fixed

- Fix bugs introduced with datastores functionality relating visualation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby

## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0)

### Added
Expand Down
9 changes: 7 additions & 2 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.
"""
xy = self.get_xy(category, stacked=False)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
xy = self.get_xy(category, stacked=True)
extent = [
xy[:, 0].min(),
xy[:, 0].max(),
xy[:, 1].min(),
xy[:, 1].max(),
]
return [float(v) for v in extent]

@property
Expand Down
5 changes: 4 additions & 1 deletion neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard library
import copy
import warnings
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection:

class_name = projection_info["class_name"]
ProjectionClass = getattr(ccrs, class_name)
kwargs = projection_info["kwargs"]
# need to copy otherwise we modify the dict stored in the dataclass
# in-place
kwargs = copy.deepcopy(projection_info["kwargs"])

globe_kwargs = kwargs.pop("globe", {})
if len(globe_kwargs) > 0:
Expand Down
81 changes: 72 additions & 9 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Standard library
import os
from typing import List, Union

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
import xarray as xr

# Local
from .. import metrics, vis
from ..config import NeuralLAMConfig
from ..datastore import BaseDatastore
from ..loss_weighting import get_state_feature_weighting
from ..weather_dataset import WeatherDataset


class ARModel(pl.LightningModule):
Expand Down Expand Up @@ -147,6 +150,44 @@ def __init__(
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []

def _create_dataarray_from_tensor(
self,
tensor: torch.Tensor,
time: Union[int, List[int]],
split: str,
category: str,
) -> xr.DataArray:
"""
Create an `xr.DataArray` from a tensor, with the correct dimensions and
coordinates to match the datastore used by the model. This function in
in effect is the inverse of what is returned by
`WeatherDataset.__getitem__`.
Parameters
----------
tensor : torch.Tensor
The tensor to convert to a `xr.DataArray` with dimensions [time,
grid_index, feature]. The tensor will be copied to the CPU if it is
not already there.
time : Union[int,List[int]]
The time index or indices for the data, given as integers or a list
of integers representing epoch time in nanoseconds. The ints will be
copied to the CPU memory if they are not already there.
split : str
The split of the data, either 'train', 'val', or 'test'
category : str
The category of the data, either 'state' or 'forcing'
"""
# TODO: creating an instance of WeatherDataset here on every call is
# not how this should be done but whether WeatherDataset should be
# provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
time = np.array(time.cpu(), dtype="datetime64[ns]")
da = weather_dataset.create_dataarray_from_tensor(
tensor=tensor.cpu().numpy(), time=time, category=category
)
return da

def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
Expand Down Expand Up @@ -406,10 +447,13 @@ def test_step(self, batch, batch_idx):
)

self.plot_examples(
batch, n_additional_examples, prediction=prediction
batch,
n_additional_examples,
prediction=prediction,
split="test",
)

def plot_examples(self, batch, n_examples, prediction=None):
def plot_examples(self, batch, n_examples, split, prediction=None):
"""
Plot the first n_examples forecasts from batch
Expand All @@ -422,18 +466,34 @@ def plot_examples(self, batch, n_examples, prediction=None):
prediction, target, _, _ = self.common_step(batch)

target = batch[1]
time = batch[3]

# Rescale to original data scale
prediction_rescaled = prediction * self.state_std + self.state_mean
target_rescaled = target * self.state_std + self.state_mean

# Iterate over the examples
for pred_slice, target_slice in zip(
prediction_rescaled[:n_examples], target_rescaled[:n_examples]
for pred_slice, target_slice, time_slice in zip(
prediction_rescaled[:n_examples],
target_rescaled[:n_examples],
time[:n_examples],
):
# Each slice is (pred_steps, num_grid_nodes, d_f)
self.plotted_examples += 1 # Increment already here

da_prediction = self._create_dataarray_from_tensor(
tensor=pred_slice,
time=time_slice,
split=split,
category="state",
).unstack("grid_index")
da_target = self._create_dataarray_from_tensor(
tensor=target_slice,
time=time_slice,
split=split,
category="state",
).unstack("grid_index")

var_vmin = (
torch.minimum(
pred_slice.flatten(0, 1).min(dim=0)[0],
Expand All @@ -453,18 +513,20 @@ def plot_examples(self, batch, n_examples, prediction=None):
var_vranges = list(zip(var_vmin, var_vmax))

# Iterate over prediction horizon time steps
for t_i, (pred_t, target_t) in enumerate(
zip(pred_slice, target_slice), start=1
):
for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
# Create one figure per variable at this time step
var_figs = [
vis.plot_prediction(
pred=pred_t[:, var_i],
target=target_t[:, var_i],
datastore=self._datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
da_prediction=da_prediction.isel(
state_feature=var_i, time=t_i - 1
).squeeze(),
da_target=da_target.isel(
state_feature=var_i, time=t_i - 1
).squeeze(),
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
Expand All @@ -476,6 +538,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
]

example_i = self.plotted_examples

wandb.log(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
Expand Down
42 changes: 18 additions & 24 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Local
from . import utils
Expand Down Expand Up @@ -65,9 +66,9 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
pred,
target,
datastore: BaseRegularGridDatastore,
da_prediction: xr.DataArray = None,
da_target: xr.DataArray = None,
title=None,
vrange=None,
):
Expand All @@ -79,19 +80,17 @@ def plot_prediction(
"""
# Get common scale for values
if vrange is None:
vmin = min(vals.min().cpu().item() for vals in (pred, target))
vmax = max(vals.max().cpu().item() for vals in (pred, target))
vmin = min(da_prediction.min(), da_target.min())
vmax = max(da_prediction.max(), da_target.max())
else:
vmin, vmax = vrange

extent = datastore.get_xy_extent("state")

# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
mask_values = np.invert(da_mask.values.astype(bool)).astype(float)
pixel_alpha = mask_values.clip(0.7, 1) # Faded border region

fig, axes = plt.subplots(
1,
Expand All @@ -101,28 +100,23 @@ def plot_prediction(
)

# Plot pred and target
for ax, data in zip(axes, (target, pred)):
for ax, da in zip(axes, (da_target, da_prediction)):
ax.coastlines() # Add coastline outlines
data_grid = (
data.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
.numpy()
)
im = ax.imshow(
data_grid,
da.plot.imshow(
ax=ax,
origin="lower",
x="x",
extent=extent,
alpha=pixel_alpha,
alpha=pixel_alpha.T,
vmin=vmin,
vmax=vmax,
cmap="plasma",
transform=datastore.coords_projection,
)

# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
axes[1].set_title("Prediction", size=15)
cbar = fig.colorbar(im, aspect=30)
cbar.ax.tick_params(labelsize=10)

if title:
fig.suptitle(title, size=20)
Expand Down Expand Up @@ -150,9 +144,7 @@ def plot_spatial_error(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region

fig, ax = plt.subplots(
figsize=(5, 4.8),
Expand All @@ -161,8 +153,10 @@ def plot_spatial_error(

ax.coastlines() # Add coastline outlines
error_grid = (
error.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
error.reshape(
[datastore.grid_shape_state.x, datastore.grid_shape_state.y]
)
.T.cpu()
.numpy()
)

Expand Down
5 changes: 3 additions & 2 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,8 @@ def create_dataarray_from_tensor(
tensor : torch.Tensor
The tensor to construct the DataArray from, this assumed to have
the same dimension ordering as returned by the __getitem__ method
(i.e. time, grid_index, {category}_feature).
(i.e. time, grid_index, {category}_feature). The tensor will be
copied to the CPU before constructing the DataArray.
time : datetime.datetime or list[datetime.datetime]
The time or times of the tensor.
category : str
Expand Down Expand Up @@ -581,7 +582,7 @@ def _is_listlike(obj):
coords["time"] = time

da = xr.DataArray(
tensor.numpy(),
tensor.cpu().numpy(),
dims=dims,
coords=coords,
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"torch>=2.3.0",
"torch-geometric==2.3.1",
"parse>=1.20.2",
"dataclass-wizard>=0.22.3",
"dataclass-wizard<0.31.0",
"mllam-data-prep>=0.5.0",
]
requires-python = ">=3.9"
Expand Down

0 comments on commit 71cfdf9

Please sign in to comment.