Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GRIB output #17

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
description: Check for spelling errors
language: system
entry: codespell
args: ['--ignore-words-list=laf']
args: ['--ignore-words-list=laf,pres']
- repo: local
hooks:
- id: black
Expand Down
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dependencies:
- Cartopy
- dask
- dask-jobqueue
- eccodes
- imageio
- ipython
- matplotlib
Expand All @@ -29,6 +30,7 @@ dependencies:
- xarray
- zarr
- pip:
- earthkit-data
- tueplots
- codespell>=2.0.0
- black>=21.9b0
Expand Down
29 changes: 29 additions & 0 deletions neural_lam/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,24 @@
"V_10M": 0,
}

GRIB_NAME = {
"PP": "pres",
"QV": "q",
"RELHUM": "r",
"T": "t",
"U": "u",
"V": "v",
"W": "wz",
"CLCT": "ccl",
"PMSL": "prmsl",
"PS": "sp",
"T_2M": "2t",
"TOT_PREC": "tp",
"U_10M": "10u",
"V_10M": "10v",
}


# Vertical level weights
# These were retrieved based on the pressure levels of
# https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5
Expand Down Expand Up @@ -183,6 +201,17 @@
EVAL_PLOT_VARS = ["T_2M"]
STORE_EXAMPLE_DATA = True
SELECTED_PROJ = ccrs.PlateCarree()
EXAMPLE_FILE = "data/cosmo/samples/train/data.zarr"
SAMPLE_GRIB = "/users/clechart/neural-lam/templates/lfff02180000"
clechartre marked this conversation as resolved.
Show resolved Hide resolved
SAMPLE_Z_GRIB = "/users/clechart/neural-lam/templates/lfff02180000z"
CHUNK_SIZE = 100
clechartre marked this conversation as resolved.
Show resolved Hide resolved
EVAL_DATETIME = ["2020100215"]
EVAL_PLOT_VARS = ["QV"]
STORE_EXAMPLE_DATA = False
COSMO_PROJ = ccrs.PlateCarree()
SELECTED_PROJ = COSMO_PROJ
POLLON = -170.0
clechartre marked this conversation as resolved.
Show resolved Hide resolved
POLLAT = 43.0
SMOOTH_BOUNDARIES = False

# Some constants useful for sub-classes 3 fluxes variables + 4 time-related
Expand Down
82 changes: 82 additions & 0 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timedelta

# Third-party
import earthkit.data
import imageio
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -823,6 +824,7 @@ def on_predict_epoch_end(self):
prediction_array = prediction_rescaled.cpu().numpy()
file_path = os.path.join(value_dir_path, f"prediction_{i}.npy")
np.save(file_path, prediction_array)
self.save_pred_as_grib(file_path, value_dir_path)

# For plots
for var_name, _ in self.selected_vars_units:
Expand All @@ -847,6 +849,86 @@ def on_predict_epoch_end(self):
for filename in images:
image = imageio.imread(filename)
writer.append_data(image)
self.spatial_loss_maps.clear()

def _generate_time_steps(self):
clechartre marked this conversation as resolved.
Show resolved Hide resolved
"""Generate a list with all time steps in inference."""
# Parse the times
base_time = constants.EVAL_DATETIMES[0]
if isinstance(base_time, str):
base_time = datetime.strptime(base_time, "%Y%m%d%H")
time_steps = {}
# Generate dates for each step
for i in range(constants.EVAL_HORIZON - 2):
# Compute the new date by adding the step interval in hours - 3
new_date = base_time + timedelta(hours=i * constants.TRAIN_HORIZON)
# Format the date back
time_steps[i] = new_date.strftime("%Y%m%d%H")

return time_steps

def save_pred_as_grib(self, file_path, value_dir_path):
"""Save the prediction values into GRIB format."""
# Initialize the lists to loop over
indices = self.precompute_variable_indices()
time_steps = self._generate_time_steps()
# Initialize final data object
final_data = earthkit.data.FieldList()
# Loop through all the time steps and all the variables
for time_idx, date_str in time_steps.items():
for variable, grib_code in constants.GRIB_NAME.items():
# here find the key of the cariable in constants.is_3D
# and if == 7, assign a cut of 7 on the reshape. Else 1
if constants.IS_3D[variable]:
shape_val = 13
vertical = constants.VERTICAL_LEVELS
else:
shape_val = 1
vertical = 1
# Find the value range to sample
value_range = indices[variable]

sample_file = constants.SAMPLE_GRIB
if variable == "RELHUM":
sample_file = constants.SAMPLE_Z_GRIB

# Load the sample grib file
original_data = earthkit.data.from_source("file", sample_file)

subset = original_data.sel(shortName=grib_code, level=vertical)
md = subset.metadata()

# Cut the datestring into date and time and then override all
# values in md
date = date_str[:8]
time = date_str[8:]

# Assuming md is a list of metadata dictionaries
for metadata in md:
metadata.override({"date": date, "time": time})

if len(md) > 0:
# Load the array to replace the values with
# We need to still save it as a .npy
# object and pass it on as an argument to this function
replacement_data = np.load(file_path)
original_cut = replacement_data[
0, time_idx, :, min(value_range) : max(value_range) + 1
].reshape(582, 390, shape_val)
clechartre marked this conversation as resolved.
Show resolved Hide resolved
cut_values = np.moveaxis(
original_cut, [-3, -2, -1], [-1, -2, -3]
)
# Can we stack Fieldlists?
data_new = earthkit.data.FieldList.from_array(
cut_values, md
)
final_data += data_new

# Create the modified GRIB file with the predicted data
grib_path = os.path.join(
value_dir_path, f"prediction_{date_str}_grib"
)
final_data.save(grib_path)

def on_load_checkpoint(self, checkpoint):
"""
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ wandb>=0.13.10
matplotlib>=3.7.0
dask
dask_jobqueue
earthkit-data
scipy>=1.10.0
pytorch-lightning>=2.0.3
shapely>=2.0.1
Expand Down
9 changes: 4 additions & 5 deletions slurm_predict.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#!/bin/bash -l
#SBATCH --job-name=NeurWPp
#SBATCH --account=s83
#SBATCH --partition=normal
#SBATCH --partition=pp-short
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --mem=444G
#SBATCH --time=00:59:00
#SBATCH --no-requeue
#SBATCH --output=lightning_logs/neurwp_pred_out.log
#SBATCH --error=lightning_logs/neurwp_pred_err.log

export PREPROCESS=false
export PREPROCESS=true
export NORMALIZE=false
export DATASET="cosmo"
export MODEL="hi_lam"
Expand Down Expand Up @@ -39,7 +38,7 @@ fi

echo "Predicting with model"
if [ "$MODEL" = "hi_lam" ]; then
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict"
srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict"
else
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict"
srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict"
fi
Loading