From c3a060a8019d5c8ffcade0274f13201466208e57 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Tue, 23 Apr 2024 09:32:21 +0200 Subject: [PATCH 01/27] Start the npy to grib pipeline --- metadata_generator.py | 89 +++++++++++++++++++++++++++++++++++++++++++ slurm_metadata.sh | 19 +++++++++ 2 files changed, 108 insertions(+) create mode 100644 metadata_generator.py create mode 100644 slurm_metadata.sh diff --git a/metadata_generator.py b/metadata_generator.py new file mode 100644 index 00000000..cbaf5e94 --- /dev/null +++ b/metadata_generator.py @@ -0,0 +1,89 @@ +import json +import os + +import numpy as np +from earthkit.data import FieldList +import metview as mv + +# Need to shuffle the metadata the same way as with constants (see message from Simon) +# Or extract the metadata at the stage where that happens +# Sort according to the last dimension + +class GRIBMetadata: + def __init__(self, grib_data): + self.grib_data = grib_data + + def display(self): + # Display each variable and its metadata in a readable format + for var, metadata in self.grib_data.items(): + print(f"Variable: {var}") + for key, value in metadata.items(): + print(f" {key}: {value}") + + def save_to_file(self, filename): + # Save the metadata dictionary to a JSON file + with open(filename, 'w') as f: + json.dump(self.grib_data, f, indent=4) + + +def map_zarr_to_grib_metadata(zarr_metadata): + # Convert metadata from Zarr format to a GRIB-like format + grib_metadata = {} + for key, value in zarr_metadata['metadata'].items(): + if '/.zarray' in key: + var_name = key.split('/')[0] + array_info = zarr_metadata['metadata'][f'{var_name}/.zarray'] + attrs_info = zarr_metadata['metadata'][f'{var_name}/.zattrs'] + + # Rearrange zmetadata according to the shuffling of constants.py + # -> Makes a subselection of the variables in the zarr archive and shuffles the indices + + if '_ARRAY_DIMENSIONS' in attrs_info: + dimensions = attrs_info['_ARRAY_DIMENSIONS'] + ny = array_info['shape'][dimensions.index('y_1')] if 'y_1' in dimensions else None + nx = array_info['shape'][dimensions.index('x_1')] if 'x_1' in dimensions else None + + grib_metadata[var_name] = { + 'GRIB_paramName': var_name, + 'GRIB_units': attrs_info.get('units', ''), + 'GRIB_dataType': array_info['dtype'], + 'GRIB_totalNumber': array_info['shape'][0] if 'time' in dimensions else 1, + 'GRIB_gridType': 'regular_ll', + 'GRIB_Ny': ny, + 'GRIB_Nx': nx, + 'GRIB_missingValue': array_info['fill_value'] + } + return GRIBMetadata(grib_metadata) + +def extract_grib_metadata_from_zarr(zarr_path): + # Load the Zarr dataset's metadata from the .zmetadata JSON file + metadata_path = os.path.join(zarr_path, '.zmetadata') + with open(metadata_path, 'r') as file: + metadata = json.load(file) + return map_zarr_to_grib_metadata(metadata) + + +def complete_data_into_grib(grib_metadata_object): + data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") + # How do the pieces of code below work? + data = data.set_values(vals) + mv.write('recip.grib', data) + ds_new = FieldList.from_array(data, grib_metadata_object) + + print(ds_new) + return ds_new + +if __name__ == "__main__": + zarr_path = "/users/clechart/clechart/neural-lam/data/cosmo_old/samples/forecast/data_2020011017.zarr" + grib_metadata_object = extract_grib_metadata_from_zarr(zarr_path) + grib_metadata_object.display() # Display the extracted metadata + + # Save metadata to a file + grib_metadata_object.save_to_file('grib_metadata.json') + + # Reconstruct the GRIB file with the data + full_set = complete_data_into_grib(grib_metadata_object) + + +# How does one read an array as a grib? +# gribfile = xr.open_dataset(joinpath(path,filelist[1]),engine="cfgrib") \ No newline at end of file diff --git a/slurm_metadata.sh b/slurm_metadata.sh new file mode 100644 index 00000000..ac6c4230 --- /dev/null +++ b/slurm_metadata.sh @@ -0,0 +1,19 @@ +#!/bin/bash -l +#SBATCH --job-name=Metadata +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --partition=pp-short +#SBATCH --account=s83 +#SBATCH --output=lightning_logs/metadata_out.log +#SBATCH --error=lightning_logs/metadata_err.log +#SBATCH --time=00:03:00 +#SBATCH --no-requeue + +# Load necessary modules +conda activate neural-lam + + +ulimit -c 0 +export OMP_NUM_THREADS=16 + +srun -ul python metadata_generator.py From fb1e7f70303413dd2c41a7b7f44bbf2b89d17bfa Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 25 Apr 2024 09:43:00 +0200 Subject: [PATCH 02/27] progress with plotting from vis --- grib_modifyer.py | 102 ++++++++++++++++++++++++++++++++++++++++++++++ slurm_metadata.sh | 2 +- 2 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 grib_modifyer.py diff --git a/grib_modifyer.py b/grib_modifyer.py new file mode 100644 index 00000000..e4887d2b --- /dev/null +++ b/grib_modifyer.py @@ -0,0 +1,102 @@ +"""Modifying a GRIB from Python.""" + +import earthkit.data +from matplotlib import pyplot as plt +import cartopy.crs as ccrs +import cartopy.feature as cf +import numpy as np +import pygrib +from neural_lam import constants + + +def plot_data(grb, title, save_path): + """Plot the data using Cartopy and save it to a file.""" + lats, lons = grb.latlons() + data = grb.values + + # Setting up the map projection and the contour plot + fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()}) + ax.coastlines() + contour = ax.contourf(lons, lats, data, levels=np.linspace(data.min(), data.max(), 100), cmap='viridis') + ax.set_title(title) + plt.colorbar(contour, ax=ax, orientation='vertical') + + # Save the plot to a file + plt.savefig(save_path, bbox_inches='tight') + plt.close(fig) # Close the figure to free memory + + +def main(): + # Load the grib file + ds = earthkit.data.from_source("file", "/users/clechart/clechart/neural-lam/laf2024042400") + # subset = ds.sel(shortName=[x.lower() for x in constants.PARAM_NAMES_SHORT], level=constants.VERTICAL_LEVELS) + subset = ds.sel(shortName="u", level=constants.VERTICAL_LEVELS) + + # Load the array to replace the values with + replacement = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") + cut_values = replacement[0,1,:, 26:33].transpose() + + md = subset.metadata() + ds_new = earthkit.data.FieldList.from_array(cut_values, md) + ds_new.save("testinho") + + + +def pygrib_plotting(): + # FIXME fix those damn names + grbs = pygrib.open("/users/clechart/clechart/neural-lam/laf2024042400") + first_grb = grbs.select(shortName = "u", level = 1)[0] # Select the first one + # plot_data(first_grb, "Original grib file", "test.png") + + grps = pygrib.open("testinho") + first_grps = grps.select(shortName = "u", level = 1)[0] + second_grps = first_grps.values + # plot_data(second_grps, "Extracted values from the inference", "testinho.png") + + + """Plot the data using Cartopy and save it to a file.""" + lats, lons = first_grb.latlons() + original_data = first_grb.values + predicted_data = second_grps + vmin = original_data.min() + vmax = original_data.max() + + # Setting up the map projection and the contour plot + fig, axes = plt.subplots( + 2, + 1, + figsize=constants.FIG_SIZE, + subplot_kw={"projection": constants.SELECTED_PROJ}, + ) + for axes, data in zip(axes, (original_data, predicted_data)): + contour_set = axes.contourf( + lons, + lats, + data, + transform=constants.SELECTED_PROJ, + cmap="plasma", + levels=np.linspace(vmin, vmax, num=100), + ) + axes.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + axes.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") + axes.gridlines( + crs=constants.SELECTED_PROJ, + draw_labels=False, + linewidth=0.5, + alpha=0.5, + ) + + # Ticks and labels + # axes[0].set_title("Ground Truth", size=15) + # axes[1].set_title("Prediction", size=15) + cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) + cbar.ax.tick_params(labelsize=10) + + # Save the plot to a file + plt.savefig("megamind_metadata.png", bbox_inches='tight') + plt.close(fig) # Close the figure to free memory + + +if __name__ == "__main__": + pygrib_plotting() + # main() diff --git a/slurm_metadata.sh b/slurm_metadata.sh index ac6c4230..4fd7c1f9 100644 --- a/slurm_metadata.sh +++ b/slurm_metadata.sh @@ -16,4 +16,4 @@ conda activate neural-lam ulimit -c 0 export OMP_NUM_THREADS=16 -srun -ul python metadata_generator.py +srun -ul python grib_modifyer.py From 9bf478592a36136213ef815eb22560b55da3df72 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 25 Apr 2024 15:51:46 +0200 Subject: [PATCH 03/27] current advances --- grib_modifyer.py | 85 +++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/grib_modifyer.py b/grib_modifyer.py index e4887d2b..961ada5c 100644 --- a/grib_modifyer.py +++ b/grib_modifyer.py @@ -7,57 +7,69 @@ import numpy as np import pygrib from neural_lam import constants +from neural_lam.rotate_grid import unrot_lat, unrot_lon, unrotate_latlon -def plot_data(grb, title, save_path): - """Plot the data using Cartopy and save it to a file.""" - lats, lons = grb.latlons() - data = grb.values +# def plot_data(grb, title, save_path): +# """Plot the data using Cartopy and save it to a file.""" +# lats, lons = grb.latlons() +# data = grb.values - # Setting up the map projection and the contour plot - fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()}) - ax.coastlines() - contour = ax.contourf(lons, lats, data, levels=np.linspace(data.min(), data.max(), 100), cmap='viridis') - ax.set_title(title) - plt.colorbar(contour, ax=ax, orientation='vertical') +# # Setting up the map projection and the contour plot +# fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()}) +# ax.coastlines() +# contour = ax.contourf(lons, lats, data, levels=np.linspace(data.min(), data.max(), 100), cmap='viridis') +# ax.set_title(title) +# plt.colorbar(contour, ax=ax, orientation='vertical') - # Save the plot to a file - plt.savefig(save_path, bbox_inches='tight') - plt.close(fig) # Close the figure to free memory +# # Save the plot to a file +# plt.savefig(save_path, bbox_inches='tight') +# plt.close(fig) def main(): # Load the grib file - ds = earthkit.data.from_source("file", "/users/clechart/clechart/neural-lam/laf2024042400") - # subset = ds.sel(shortName=[x.lower() for x in constants.PARAM_NAMES_SHORT], level=constants.VERTICAL_LEVELS) - subset = ds.sel(shortName="u", level=constants.VERTICAL_LEVELS) + original_data = earthkit.data.from_source("file", "/users/clechart/clechart/neural-lam/laf2024042400") + # subset = original_data.sel(shortName=[x.lower() for x in constants.PARAM_NAMES_SHORT], level=constants.VERTICAL_LEVELS) + subset = original_data.sel(shortName="u", level=constants.VERTICAL_LEVELS) # Load the array to replace the values with - replacement = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") - cut_values = replacement[0,1,:, 26:33].transpose() - + replacement_data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") + cut_values = replacement_data[0,1,:, 26:33].transpose() + cut_values = np.flip(cut_values, axis=(0,1)) + # cut_values = cut_values.reshape(cut_values.shape[0], -1) + # cut_values = cut_values.reshape(-1, cut_values.shape[0]) + + # Ensure the dimensions match + assert cut_values.shape == subset.values.shape, "The shapes of the arrays don't match." + + + # Find indices where values are around -8 + close_to_minus_eight_subset = np.where((subset.values > -8.5) & (subset.values < -8.0)) + close_to_minus_eight_cut_values = np.where((cut_values > -8.5) & (cut_values < -8.0)) + + print(f"Indices in subset.values close to -8: {close_to_minus_eight_subset}") + print(f"Indices in cut_values close to -8: {close_to_minus_eight_cut_values}") + + # Save the overwritten data md = subset.metadata() - ds_new = earthkit.data.FieldList.from_array(cut_values, md) - ds_new.save("testinho") + data_new = earthkit.data.FieldList.from_array(cut_values, md) + data_new.save("testinho") def pygrib_plotting(): - # FIXME fix those damn names - grbs = pygrib.open("/users/clechart/clechart/neural-lam/laf2024042400") - first_grb = grbs.select(shortName = "u", level = 1)[0] # Select the first one - # plot_data(first_grb, "Original grib file", "test.png") - - grps = pygrib.open("testinho") - first_grps = grps.select(shortName = "u", level = 1)[0] - second_grps = first_grps.values - # plot_data(second_grps, "Extracted values from the inference", "testinho.png") + # Load original GRIB data + original_grib = pygrib.open("/users/clechart/clechart/neural-lam/laf2024042400") + grb = original_grib.select(shortName = "u", level = 1)[0] + lats, lons = grb.latlons() + original_data = grb.values + # Load transformed GRIB with inference output + inference_grib = pygrib.open("testinho") + inf = inference_grib.select(shortName = "u", level = 1)[0] + predicted_data = inf.values - """Plot the data using Cartopy and save it to a file.""" - lats, lons = first_grb.latlons() - original_data = first_grb.values - predicted_data = second_grps vmin = original_data.min() vmax = original_data.max() @@ -94,9 +106,8 @@ def pygrib_plotting(): # Save the plot to a file plt.savefig("megamind_metadata.png", bbox_inches='tight') - plt.close(fig) # Close the figure to free memory - + plt.close(fig) if __name__ == "__main__": + main() pygrib_plotting() - # main() From 79aad5605a7b485a083554e1331df382ef6e4410 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 25 Apr 2024 17:12:38 +0200 Subject: [PATCH 04/27] correct-ish reshaping --- grib_modifyer.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/grib_modifyer.py b/grib_modifyer.py index 961ada5c..4c302893 100644 --- a/grib_modifyer.py +++ b/grib_modifyer.py @@ -35,26 +35,30 @@ def main(): # Load the array to replace the values with replacement_data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") - cut_values = replacement_data[0,1,:, 26:33].transpose() - cut_values = np.flip(cut_values, axis=(0,1)) + original_cut = replacement_data[0,1,:,26:33].reshape(582,390, 7) + cut_values = np.moveaxis(original_cut, [-3,-2,-1], [-1,-2,-3]) + + # cut_values = np.flip(cut_values, axis=(0,1)) # cut_values = cut_values.reshape(cut_values.shape[0], -1) # cut_values = cut_values.reshape(-1, cut_values.shape[0]) - # Ensure the dimensions match - assert cut_values.shape == subset.values.shape, "The shapes of the arrays don't match." + # # Ensure the dimensions match + # assert cut_values.shape == subset.values.shape, "The shapes of the arrays don't match." - # Find indices where values are around -8 - close_to_minus_eight_subset = np.where((subset.values > -8.5) & (subset.values < -8.0)) - close_to_minus_eight_cut_values = np.where((cut_values > -8.5) & (cut_values < -8.0)) + # # Find indices where values are around -8 + # close_to_minus_eight_subset = np.where((subset.values > -8.5) & (subset.values < -8.0)) + # close_to_minus_eight_cut_values = np.where((cut_values > -8.5) & (cut_values < -8.0)) - print(f"Indices in subset.values close to -8: {close_to_minus_eight_subset}") - print(f"Indices in cut_values close to -8: {close_to_minus_eight_cut_values}") + # print(f"Indices in subset.values close to -8: {close_to_minus_eight_subset}") + # print(f"Indices in cut_values close to -8: {close_to_minus_eight_cut_values}") # Save the overwritten data md = subset.metadata() - data_new = earthkit.data.FieldList.from_array(cut_values, md) - data_new.save("testinho") + # ni= md[0]["Ni"] + # nj = md[0]["Nj"] + data_new = earthkit.data.FieldList.from_array(cut_values, md) # .reshape(7, nj, ni) + data_new.save("testinhoModif") @@ -66,7 +70,7 @@ def pygrib_plotting(): original_data = grb.values # Load transformed GRIB with inference output - inference_grib = pygrib.open("testinho") + inference_grib = pygrib.open("testinhoModif") inf = inference_grib.select(shortName = "u", level = 1)[0] predicted_data = inf.values From 472981b4164a02b3438d35cb0117e9f28e2247f9 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 26 Apr 2024 11:52:52 +0200 Subject: [PATCH 05/27] working plot for temperature --- grib_modifyer.py | 121 +++++++++++++---------------------------------- 1 file changed, 34 insertions(+), 87 deletions(-) diff --git a/grib_modifyer.py b/grib_modifyer.py index 4c302893..c5811ba2 100644 --- a/grib_modifyer.py +++ b/grib_modifyer.py @@ -1,117 +1,64 @@ -"""Modifying a GRIB from Python.""" - import earthkit.data +import numpy as np +import pygrib from matplotlib import pyplot as plt import cartopy.crs as ccrs import cartopy.feature as cf -import numpy as np -import pygrib from neural_lam import constants -from neural_lam.rotate_grid import unrot_lat, unrot_lon, unrotate_latlon - -# def plot_data(grb, title, save_path): -# """Plot the data using Cartopy and save it to a file.""" -# lats, lons = grb.latlons() -# data = grb.values +def plot_data(grb, title, ax, projection, vmin, vmax, color_map='plasma', num_contours=100): + """Plot the data using Cartopy with specified projection on given axis, using shared color scale.""" + lats, lons = grb.latlons() + data = grb.values -# # Setting up the map projection and the contour plot -# fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()}) -# ax.coastlines() -# contour = ax.contourf(lons, lats, data, levels=np.linspace(data.min(), data.max(), 100), cmap='viridis') -# ax.set_title(title) -# plt.colorbar(contour, ax=ax, orientation='vertical') + ax.add_feature(cf.BORDERS, linestyle='-', edgecolor='black') + ax.add_feature(cf.COASTLINE, linestyle='-', edgecolor='black') + ax.set_title(title) -# # Save the plot to a file -# plt.savefig(save_path, bbox_inches='tight') -# plt.close(fig) + contour = ax.contourf(lons, lats, data, transform=projection(), levels=np.linspace(vmin, vmax, num_contours), cmap=color_map) + return contour +def create_modified_grib(original_data, cut_values, modified_file_name): + # Save the overwritten data + md = original_data.metadata() + data_new = earthkit.data.FieldList.from_array(cut_values, md) + data_new.save(modified_file_name) def main(): - # Load the grib file + # Load the original grib file original_data = earthkit.data.from_source("file", "/users/clechart/clechart/neural-lam/laf2024042400") - # subset = original_data.sel(shortName=[x.lower() for x in constants.PARAM_NAMES_SHORT], level=constants.VERTICAL_LEVELS) subset = original_data.sel(shortName="u", level=constants.VERTICAL_LEVELS) # Load the array to replace the values with replacement_data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") original_cut = replacement_data[0,1,:,26:33].reshape(582,390, 7) cut_values = np.moveaxis(original_cut, [-3,-2,-1], [-1,-2,-3]) - - # cut_values = np.flip(cut_values, axis=(0,1)) - # cut_values = cut_values.reshape(cut_values.shape[0], -1) - # cut_values = cut_values.reshape(-1, cut_values.shape[0]) - # # Ensure the dimensions match - # assert cut_values.shape == subset.values.shape, "The shapes of the arrays don't match." + # Create the modified GRIB file with the predicted data + modified_grib_path = "testinhoModif" + create_modified_grib(subset, cut_values, modified_grib_path) - - # # Find indices where values are around -8 - # close_to_minus_eight_subset = np.where((subset.values > -8.5) & (subset.values < -8.0)) - # close_to_minus_eight_cut_values = np.where((cut_values > -8.5) & (cut_values < -8.0)) - - # print(f"Indices in subset.values close to -8: {close_to_minus_eight_subset}") - # print(f"Indices in cut_values close to -8: {close_to_minus_eight_cut_values}") - - # Save the overwritten data - md = subset.metadata() - # ni= md[0]["Ni"] - # nj = md[0]["Nj"] - data_new = earthkit.data.FieldList.from_array(cut_values, md) # .reshape(7, nj, ni) - data_new.save("testinhoModif") - - - -def pygrib_plotting(): - # Load original GRIB data + # Open the original and modified GRIB files original_grib = pygrib.open("/users/clechart/clechart/neural-lam/laf2024042400") - grb = original_grib.select(shortName = "u", level = 1)[0] - lats, lons = grb.latlons() - original_data = grb.values + grb_original = original_grib.select(shortName="u", level=constants.VERTICAL_LEVELS[0])[0] - # Load transformed GRIB with inference output - inference_grib = pygrib.open("testinhoModif") - inf = inference_grib.select(shortName = "u", level = 1)[0] - predicted_data = inf.values + predicted_grib = pygrib.open(modified_grib_path) + grb_predicted = predicted_grib.select(shortName="u", level=constants.VERTICAL_LEVELS[0])[0] - vmin = original_data.min() - vmax = original_data.max() + # Determine the global min and max values for the colorbar + vmin = min(grb_original.values.min(), grb_predicted.values.min()) + vmax = max(grb_original.values.max(), grb_predicted.values.max()) - # Setting up the map projection and the contour plot - fig, axes = plt.subplots( - 2, - 1, - figsize=constants.FIG_SIZE, - subplot_kw={"projection": constants.SELECTED_PROJ}, - ) - for axes, data in zip(axes, (original_data, predicted_data)): - contour_set = axes.contourf( - lons, - lats, - data, - transform=constants.SELECTED_PROJ, - cmap="plasma", - levels=np.linspace(vmin, vmax, num=100), - ) - axes.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") - axes.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") - axes.gridlines( - crs=constants.SELECTED_PROJ, - draw_labels=False, - linewidth=0.5, - alpha=0.5, - ) + fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 12), subplot_kw={'projection': ccrs.PlateCarree()}) + contour1 = plot_data(grb_original, "Original Data", ax1, ccrs.PlateCarree, vmin, vmax) + contour2 = plot_data(grb_predicted, "Predicted Data", ax2, ccrs.PlateCarree, vmin, vmax) - # Ticks and labels - # axes[0].set_title("Ground Truth", size=15) - # axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) - cbar.ax.tick_params(labelsize=10) + plt.subplots_adjust(hspace=0.1, wspace=0.05) + colorbar_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02]) # Position for the colorbar + fig.colorbar(contour1, cax=colorbar_ax, orientation='horizontal', shrink=0.5) - # Save the plot to a file - plt.savefig("megamind_metadata.png", bbox_inches='tight') + plt.savefig("combined_vertical_data_adjusted.png", bbox_inches='tight') plt.close(fig) if __name__ == "__main__": main() - pygrib_plotting() From 79a867e53c984d51f52ab0e6aa17374cc0d31c18 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 26 Apr 2024 12:30:28 +0200 Subject: [PATCH 06/27] abstraction and linting --- grib_modifier.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++ grib_modifyer.py | 64 ---------------------------- slurm_metadata.sh | 2 +- 3 files changed, 107 insertions(+), 65 deletions(-) create mode 100644 grib_modifier.py delete mode 100644 grib_modifyer.py diff --git a/grib_modifier.py b/grib_modifier.py new file mode 100644 index 00000000..460394fb --- /dev/null +++ b/grib_modifier.py @@ -0,0 +1,106 @@ +# Third-party +import cartopy.crs as ccrs +import cartopy.feature as cf +import earthkit.data +import numpy as np +import pygrib +from matplotlib import pyplot as plt + +# First-party +from neural_lam import constants + + +def plot_data( + grb, title, ax, projection, vmin, vmax, color_map="plasma", num_contours=100 +): + """Plot the data using Cartopy.""" + lats, lons = grb.latlons() + data = grb.values + + ax.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + ax.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") + ax.set_title(title) + + contour = ax.contourf( + lons, + lats, + data, + transform=projection(), + levels=np.linspace(vmin, vmax, num_contours), + cmap=color_map, + ) + return contour + + +def modify_data(): + """Fit the numpy values into GRIB format.""" + # Load the original grib file + original_data = earthkit.data.from_source( + "file", "/users/clechart/clechart/neural-lam/laf2024042400" + ) + subset = original_data.sel(shortName="u", level=constants.VERTICAL_LEVELS) + + # Load the array to replace the values with + replacement_data = np.load( + "/users/clechart/clechart/neural-lam/" + "wandb/run-20240417_104748-dxnil3vw/files/" + "results/inference/prediction_0.npy" + ) + original_cut = replacement_data[0, 1, :, 26:33].reshape(582, 390, 7) + cut_values = np.moveaxis(original_cut, [-3, -2, -1], [-1, -2, -3]) + + # Create the modified GRIB file with the predicted data + modified_grib_path = "/users/clechart/clechart/neural-lam/modified_grib" + md = subset.metadata() + data_new = earthkit.data.FieldList.from_array(cut_values, md) + data_new.save(modified_grib_path) + + +def generate_plot(): + """Plot the original GRIB entries against transformed inference.""" + # Open the original and modified GRIB files + original_grib = pygrib.open( + "/users/clechart/clechart/neural-lam/laf2024042400" + ) + grb_original = original_grib.select( + shortName="u", level=constants.VERTICAL_LEVELS[0] + )[0] + + predicted_grib = pygrib.open( + "/users/clechart/clechart/neural-lam/modified_grib" + ) + grb_predicted = predicted_grib.select( + shortName="u", level=constants.VERTICAL_LEVELS[0] + )[0] + + # Determine the global min and max values for the colorbar + vmin = min(grb_original.values.min(), grb_predicted.values.min()) + vmax = max(grb_original.values.max(), grb_predicted.values.max()) + + fig, (ax1, ax2) = plt.subplots( + nrows=2, + figsize=constants.FIG_SIZE, + subplot_kw={"projection": ccrs.PlateCarree()}, + ) + contour1 = plot_data( + grb_original, "Original Data", ax1, ccrs.PlateCarree, vmin, vmax + ) + _ = plot_data( + grb_predicted, "Predicted Data", ax2, ccrs.PlateCarree, vmin, vmax + ) + + plt.subplots_adjust(hspace=0.1, wspace=0.05) + colorbar_ax = fig.add_axes( + [0.15, 0.08, 0.7, 0.02] + ) # Position for the colorbar + fig.colorbar( + contour1, cax=colorbar_ax, orientation="horizontal", shrink=0.5 + ) + + plt.savefig("completed_grib.png", bbox_inches="tight") + plt.close(fig) + + +if __name__ == "__main__": + modify_data() + generate_plot() diff --git a/grib_modifyer.py b/grib_modifyer.py deleted file mode 100644 index c5811ba2..00000000 --- a/grib_modifyer.py +++ /dev/null @@ -1,64 +0,0 @@ -import earthkit.data -import numpy as np -import pygrib -from matplotlib import pyplot as plt -import cartopy.crs as ccrs -import cartopy.feature as cf -from neural_lam import constants - -def plot_data(grb, title, ax, projection, vmin, vmax, color_map='plasma', num_contours=100): - """Plot the data using Cartopy with specified projection on given axis, using shared color scale.""" - lats, lons = grb.latlons() - data = grb.values - - ax.add_feature(cf.BORDERS, linestyle='-', edgecolor='black') - ax.add_feature(cf.COASTLINE, linestyle='-', edgecolor='black') - ax.set_title(title) - - contour = ax.contourf(lons, lats, data, transform=projection(), levels=np.linspace(vmin, vmax, num_contours), cmap=color_map) - return contour - -def create_modified_grib(original_data, cut_values, modified_file_name): - # Save the overwritten data - md = original_data.metadata() - data_new = earthkit.data.FieldList.from_array(cut_values, md) - data_new.save(modified_file_name) - -def main(): - # Load the original grib file - original_data = earthkit.data.from_source("file", "/users/clechart/clechart/neural-lam/laf2024042400") - subset = original_data.sel(shortName="u", level=constants.VERTICAL_LEVELS) - - # Load the array to replace the values with - replacement_data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") - original_cut = replacement_data[0,1,:,26:33].reshape(582,390, 7) - cut_values = np.moveaxis(original_cut, [-3,-2,-1], [-1,-2,-3]) - - # Create the modified GRIB file with the predicted data - modified_grib_path = "testinhoModif" - create_modified_grib(subset, cut_values, modified_grib_path) - - # Open the original and modified GRIB files - original_grib = pygrib.open("/users/clechart/clechart/neural-lam/laf2024042400") - grb_original = original_grib.select(shortName="u", level=constants.VERTICAL_LEVELS[0])[0] - - predicted_grib = pygrib.open(modified_grib_path) - grb_predicted = predicted_grib.select(shortName="u", level=constants.VERTICAL_LEVELS[0])[0] - - # Determine the global min and max values for the colorbar - vmin = min(grb_original.values.min(), grb_predicted.values.min()) - vmax = max(grb_original.values.max(), grb_predicted.values.max()) - - fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 12), subplot_kw={'projection': ccrs.PlateCarree()}) - contour1 = plot_data(grb_original, "Original Data", ax1, ccrs.PlateCarree, vmin, vmax) - contour2 = plot_data(grb_predicted, "Predicted Data", ax2, ccrs.PlateCarree, vmin, vmax) - - plt.subplots_adjust(hspace=0.1, wspace=0.05) - colorbar_ax = fig.add_axes([0.15, 0.08, 0.7, 0.02]) # Position for the colorbar - fig.colorbar(contour1, cax=colorbar_ax, orientation='horizontal', shrink=0.5) - - plt.savefig("combined_vertical_data_adjusted.png", bbox_inches='tight') - plt.close(fig) - -if __name__ == "__main__": - main() diff --git a/slurm_metadata.sh b/slurm_metadata.sh index 4fd7c1f9..267ac812 100644 --- a/slurm_metadata.sh +++ b/slurm_metadata.sh @@ -16,4 +16,4 @@ conda activate neural-lam ulimit -c 0 export OMP_NUM_THREADS=16 -srun -ul python grib_modifyer.py +srun -ul python grib_modifier.py From 2424405b90867e2908233c55615b2071b6a8ad10 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 29 Apr 2024 15:16:27 +0200 Subject: [PATCH 07/27] integrating the GRIB transform to predict step --- neural_lam/constants.py | 4 +++ neural_lam/models/ar_model.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 8da779a2..39a369b9 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -140,6 +140,10 @@ # Plotting FIG_SIZE = (15, 10) EXAMPLE_FILE = "data/cosmo/samples/train/data_2015112800.zarr" +SAMPLE_GRIB = "/scratch/mch/sadamov/pyprojects_data/"\ + "neural_lam/data/cosmo/templates/lfff02180000" +SAMPLE_Z_GRIB = "/scratch/mch/sadamov/pyprojects_data/"\ + "neural_lam/data/cosmo/templates/lfff02180000z" CHUNK_SIZE = 100 EVAL_DATETIME = "2020100215" EVAL_PLOT_VARS = ["TQV"] diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d49184c..094494a1 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta # Third-party +import earthkit.data import imageio import matplotlib.pyplot as plt import numpy as np @@ -826,7 +827,9 @@ def on_predict_epoch_end(self): # Process and save the prediction prediction_array = prediction_rescaled.cpu().numpy() file_path = os.path.join(value_dir_path, f"prediction_{i}.npy") + grib_path = os.path.join(value_dir_path, f"prediction_{i}_grib") np.save(file_path, prediction_array) + self.save_pred_as_grib(prediction_rescaled, grib_path) # For plots for var_name, _ in self.selected_vars_units: @@ -852,6 +855,50 @@ def on_predict_epoch_end(self): image = imageio.imread(filename) writer.append_data(image) + def save_pred_as_grib(self, prediction, grib_path): + """Save the prediction values into GRIB format.""" + indices = self.precompute_variable_indices() + # Initialize final data object + final_data = earthkit.data.FieldList() + + # Loop through all of them one by one + # ATN between 3D and 2D - 7vs1 lvl (for reshaping) + for variable in constants.PARAM_NAMES_SHORT: + # here find the key of the cariable in constants.is_3D + # and if == 7, assign a cut of 7 on the reshape. Else 1 + shape_val = 7 if constants.IS_3D[variable] else 1 + # Find the value range to sample + value_range = indices[variable] + + sample_file = constants.SAMPLE_GRIB + if variable == "RELHUM": + sample_file = constants.SAMPLE_Z_GRIB + + # Load the sample grib file + original_data = earthkit.data.from_source("file", sample_file) + + subset = original_data.sel( + shortName=variable.lower(), level=constants.VERTICAL_LEVELS + ) + md = subset.metadata() + + if len(md) > 0: + # Load the array to replace the values with + # We need to still save it as a .npy + # object and pass it on as an argument to this function + original_cut = prediction[ + 0, 1, :, min(value_range) : max(value_range) + 1 + ].reshape(582, 390, shape_val) + cut_values = np.moveaxis( + original_cut, [-3, -2, -1], [-1, -2, -3] + ) + # Can we stack Fieldlists? + data_new = earthkit.data.FieldList.from_array(cut_values, md) + final_data += data_new + + # Create the modified GRIB file with the predicted data + final_data.save(grib_path) + def on_load_checkpoint(self, checkpoint): """ Perform any changes to state dict before loading checkpoint From b69895266b1ab00c6d58666e9c4e9ba2e77fb2fe Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 29 Apr 2024 16:28:20 +0200 Subject: [PATCH 08/27] Looping over time steps for grib generation --- neural_lam/models/ar_model.py | 97 +++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 39 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 094494a1..53252502 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -827,9 +827,8 @@ def on_predict_epoch_end(self): # Process and save the prediction prediction_array = prediction_rescaled.cpu().numpy() file_path = os.path.join(value_dir_path, f"prediction_{i}.npy") - grib_path = os.path.join(value_dir_path, f"prediction_{i}_grib") np.save(file_path, prediction_array) - self.save_pred_as_grib(prediction_rescaled, grib_path) + self.save_pred_as_grib(prediction_rescaled, value_dir_path) # For plots for var_name, _ in self.selected_vars_units: @@ -855,49 +854,69 @@ def on_predict_epoch_end(self): image = imageio.imread(filename) writer.append_data(image) - def save_pred_as_grib(self, prediction, grib_path): + def _generate_time_steps(self): + """Generate a list with all time steps in inference.""" + # Parse the times + base_time = constants.EVAL_DATETIME + time_steps = {} + # Generate dates for each step + for i in range(constants.EVAL_HORIZON - 2): + # Compute the new date by adding the step interval in hours - 3 + new_date = base_time + timedelta(hours=i * constants.TRAIN_HORIZON) + # Format the date back + time_steps[i] = new_date.strftime("%Y%m%d%H") + + return time_steps + + def save_pred_as_grib(self, prediction, value_dir_path): """Save the prediction values into GRIB format.""" + # Initialize the lists to loop over indices = self.precompute_variable_indices() + time_steps = self._generate_time_steps() # Initialize final data object final_data = earthkit.data.FieldList() - - # Loop through all of them one by one - # ATN between 3D and 2D - 7vs1 lvl (for reshaping) - for variable in constants.PARAM_NAMES_SHORT: - # here find the key of the cariable in constants.is_3D - # and if == 7, assign a cut of 7 on the reshape. Else 1 - shape_val = 7 if constants.IS_3D[variable] else 1 - # Find the value range to sample - value_range = indices[variable] - - sample_file = constants.SAMPLE_GRIB - if variable == "RELHUM": - sample_file = constants.SAMPLE_Z_GRIB - - # Load the sample grib file - original_data = earthkit.data.from_source("file", sample_file) - - subset = original_data.sel( - shortName=variable.lower(), level=constants.VERTICAL_LEVELS - ) - md = subset.metadata() - - if len(md) > 0: - # Load the array to replace the values with - # We need to still save it as a .npy - # object and pass it on as an argument to this function - original_cut = prediction[ - 0, 1, :, min(value_range) : max(value_range) + 1 - ].reshape(582, 390, shape_val) - cut_values = np.moveaxis( - original_cut, [-3, -2, -1], [-1, -2, -3] + # Loop through all the time steps and all the variables + for time_idx, date_str in time_steps.items(): + for variable in constants.PARAM_NAMES_SHORT: + # here find the key of the cariable in constants.is_3D + # and if == 7, assign a cut of 7 on the reshape. Else 1 + shape_val = 7 if constants.IS_3D[variable] else 1 + # Find the value range to sample + value_range = indices[variable] + + sample_file = constants.SAMPLE_GRIB + if variable == "RELHUM": + sample_file = constants.SAMPLE_Z_GRIB + + # Load the sample grib file + original_data = earthkit.data.from_source("file", sample_file) + + subset = original_data.sel( + shortName=variable.lower(), level=constants.VERTICAL_LEVELS ) - # Can we stack Fieldlists? - data_new = earthkit.data.FieldList.from_array(cut_values, md) - final_data += data_new + md = subset.metadata() + + if len(md) > 0: + # Load the array to replace the values with + # We need to still save it as a .npy + # object and pass it on as an argument to this function + original_cut = prediction[ + 0, time_idx, :, min(value_range) : max(value_range) + 1 + ].reshape(582, 390, shape_val) + cut_values = np.moveaxis( + original_cut, [-3, -2, -1], [-1, -2, -3] + ) + # Can we stack Fieldlists? + data_new = earthkit.data.FieldList.from_array( + cut_values, md + ) + final_data += data_new - # Create the modified GRIB file with the predicted data - final_data.save(grib_path) + # Create the modified GRIB file with the predicted data + grib_path = os.path.join( + value_dir_path, f"prediction_{date_str}_grib" + ) + final_data.save(grib_path) def on_load_checkpoint(self, checkpoint): """ From 3c1e27a6a0ed6ccae34ea7a68712fcab5b50d564 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 29 Apr 2024 17:17:51 +0200 Subject: [PATCH 09/27] overriding the time stamp in metadata --- neural_lam/models/ar_model.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f37e9641..fdd02377 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -854,6 +854,10 @@ def _generate_time_steps(self): """Generate a list with all time steps in inference.""" # Parse the times base_time = constants.EVAL_DATETIMES[0] + if isinstance(base_time, str): + base_time = datetime.strptime(base_time, "%Y%m%d%H") + else: + base_time = base_time time_steps = {} # Generate dates for each step for i in range(constants.EVAL_HORIZON - 2): @@ -892,6 +896,18 @@ def save_pred_as_grib(self, prediction, value_dir_path): ) md = subset.metadata() + # Cut the datestring into date and time and then override all + # values in md + date = date_str[:8] + time = date_str[8:] + + # Assuming md is a list of metadata dictionaries + for metadata in md: + metadata.override({ + "date": date, + "time": time + }) + if len(md) > 0: # Load the array to replace the values with # We need to still save it as a .npy From df8d7363f3a5da4093bfe603cfb8685c862ceeaf Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 10:21:25 +0200 Subject: [PATCH 10/27] generalized version --- generalized_modifier.py | 122 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 generalized_modifier.py diff --git a/generalized_modifier.py b/generalized_modifier.py new file mode 100644 index 00000000..7d11ca72 --- /dev/null +++ b/generalized_modifier.py @@ -0,0 +1,122 @@ +# Third-party +from datetime import datetime, timedelta +import earthkit.data +import numpy as np + +# First-party +from neural_lam import constants + + +def modify_data(prediction: np.array): + """Fit the numpy values into GRIB format.""" + + indices = precompute_variable_indices() + time_steps = generate_time_steps() + vertical_levels = [1, 5, 13, 22, 38, 41, 60] + + # Initialize final data object + final_data = earthkit.data.FieldList() + + # Loop through all the time steps and all the variables + for time_idx, date_str in time_steps.items(): + + # ATN between 3D and 2D - 7vs1 lvl (for reshaping) + for variable in constants.PARAM_NAMES_SHORT: + # here find the key of the cariable in constants.is_3D and if == 7, assign a cut of 7 on the + # reshape. Else 1 + shape_val = 7 if constants.IS_3D[variable] else 1 + # Find the value range to sample + value_range = indices[variable] + + sample_file = constants.SAMPLE_GRIB + if variable == "RELHUM": + variable = "r" + sample_file = constants.SAMPLE_Z_GRIB + + # Load the sample grib file + original_data = earthkit.data.from_source( + "file", sample_file + ) + + subset = original_data.sel(shortName= variable.lower(), level=vertical_levels) + md = subset.metadata() + + # Cut the datestring into date and time and then override all + # values in md + date = date_str[:8] + time = date_str[8:] + + # Assuming md is a list of metadata dictionaries + for metadata in md: + metadata.override({ + "date": date, + "time": time + }) + if len(md)>0: + # Load the array to replace the values with + # We need to still save it as a .npy object and pass it on as an argument to this function + replacement_data = np.load(prediction) + original_cut = replacement_data[0, time_idx, :, min(value_range):max(value_range)+1].reshape(582, 390, shape_val) + cut_values = np.moveaxis(original_cut, [-3, -2, -1], [-1, -2, -3]) + # Can we stack Fieldlists? + data_new = earthkit.data.FieldList.from_array(cut_values, md) + final_data += data_new + + # Create the modified GRIB file with the predicted data + modified_grib_path =f"lightning_logs/prediction_{date_str}" + final_data.save(modified_grib_path) + +# This function is taken from ar_model, need to just use self when I go +# put this function into on_predict_epoch_end() +def precompute_variable_indices(): + """ + Precompute indices for each variable in the input tensor + """ + variable_indices = {} + all_vars = [] + index = 0 + # Create a list of tuples for all variables, using level 0 for 2D + # variables + for var_name in constants.PARAM_NAMES_SHORT: + if constants.IS_3D[var_name]: + for level in constants.VERTICAL_LEVELS: + all_vars.append((var_name, level)) + else: + all_vars.append((var_name, 0)) # Use level 0 for 2D variables + + # Sort the variables based on the tuples + sorted_vars = sorted(all_vars) + + for var in sorted_vars: + var_name, level = var + if var_name not in variable_indices: + variable_indices[var_name] = [] + variable_indices[var_name].append(index) + index += 1 + + return variable_indices + + +def generate_time_steps(): + # Parse the times + base_time = constants.EVAL_DATETIMES[0] + if isinstance(base_time, str): + base_time = datetime.strptime(base_time, "%Y%m%d%H") + else: + base_time = base_time + time_steps = {} + # Generate dates for each step + for i in range(constants.EVAL_HORIZON - 2): + # Compute the new date by adding the step interval in hours - 3 + new_date = base_time + timedelta(hours=i * constants.TRAIN_HORIZON) + # Format the date back + time_steps[i] = new_date.strftime("%Y%m%d%H") + + return time_steps + + + +if __name__ == "__main__": + precompute_variable_indices() + time_steps = generate_time_steps() + modify_data(prediction="/users/clechart/neural-lam/templates/predictions.npy") From d6a5d0398da7b957ad14dcf716fb06f34ef0caff Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 11:54:14 +0200 Subject: [PATCH 11/27] files for running prediction --- environment.yml | 1 + neural_lam/constants.py | 6 ++---- slurm_predict.sh | 9 ++++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/environment.yml b/environment.yml index 94645e42..69424afb 100644 --- a/environment.yml +++ b/environment.yml @@ -29,6 +29,7 @@ dependencies: - xarray - zarr - pip: + - earthkit.data - tueplots - codespell>=2.0.0 - black>=21.9b0 diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 5a3e371a..5d8b893c 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -184,10 +184,8 @@ STORE_EXAMPLE_DATA = True SELECTED_PROJ = ccrs.PlateCarree() EXAMPLE_FILE = "data/cosmo/samples/train/data_2015112800.zarr" -SAMPLE_GRIB = "/scratch/mch/sadamov/pyprojects_data/"\ - "neural_lam/data/cosmo/templates/lfff02180000" -SAMPLE_Z_GRIB = "/scratch/mch/sadamov/pyprojects_data/"\ - "neural_lam/data/cosmo/templates/lfff02180000z" +SAMPLE_GRIB = "/users/clechart/neural-lam/templates/lfff02180000" +SAMPLE_Z_GRIB = "/users/clechart/neural-lam/templates/lfff02180000z" CHUNK_SIZE = 100 EVAL_DATETIME = "2020100215" EVAL_PLOT_VARS = ["TQV"] diff --git a/slurm_predict.sh b/slurm_predict.sh index 34527ea7..cda9d225 100644 --- a/slurm_predict.sh +++ b/slurm_predict.sh @@ -1,16 +1,15 @@ #!/bin/bash -l #SBATCH --job-name=NeurWPp #SBATCH --account=s83 -#SBATCH --partition=normal +#SBATCH --partition=pp-short #SBATCH --nodes=1 #SBATCH --ntasks-per-node=4 -#SBATCH --mem=444G #SBATCH --time=00:59:00 #SBATCH --no-requeue #SBATCH --output=lightning_logs/neurwp_pred_out.log #SBATCH --error=lightning_logs/neurwp_pred_err.log -export PREPROCESS=false +export PREPROCESS=true export NORMALIZE=false export DATASET="cosmo" export MODEL="hi_lam" @@ -39,7 +38,7 @@ fi echo "Predicting with model" if [ "$MODEL" = "hi_lam" ]; then - srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" else - srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" fi From 0235591ae3bcb07993bcd3ec9350a7b71762bb66 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:10:44 +0200 Subject: [PATCH 12/27] adapting to tsa --- environment.yml | 1 + neural_lam/models/ar_model.py | 21 +++++++++------------ slurm_predict.sh | 4 ++-- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/environment.yml b/environment.yml index 69424afb..353f3bc7 100644 --- a/environment.yml +++ b/environment.yml @@ -7,6 +7,7 @@ dependencies: - Cartopy - dask - dask-jobqueue + - eccodes - imageio - ipython - matplotlib diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index fdd02377..5c1db41f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -824,7 +824,7 @@ def on_predict_epoch_end(self): prediction_array = prediction_rescaled.cpu().numpy() file_path = os.path.join(value_dir_path, f"prediction_{i}.npy") np.save(file_path, prediction_array) - self.save_pred_as_grib(prediction_rescaled, value_dir_path) + self.save_pred_as_grib(file_path, value_dir_path) # For plots for var_name, _ in self.selected_vars_units: @@ -849,6 +849,7 @@ def on_predict_epoch_end(self): for filename in images: image = imageio.imread(filename) writer.append_data(image) + self.spatial_loss_maps.clear() def _generate_time_steps(self): """Generate a list with all time steps in inference.""" @@ -856,8 +857,6 @@ def _generate_time_steps(self): base_time = constants.EVAL_DATETIMES[0] if isinstance(base_time, str): base_time = datetime.strptime(base_time, "%Y%m%d%H") - else: - base_time = base_time time_steps = {} # Generate dates for each step for i in range(constants.EVAL_HORIZON - 2): @@ -868,7 +867,7 @@ def _generate_time_steps(self): return time_steps - def save_pred_as_grib(self, prediction, value_dir_path): + def save_pred_as_grib(self, file_path, value_dir_path): """Save the prediction values into GRIB format.""" # Initialize the lists to loop over indices = self.precompute_variable_indices() @@ -880,7 +879,7 @@ def save_pred_as_grib(self, prediction, value_dir_path): for variable in constants.PARAM_NAMES_SHORT: # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 - shape_val = 7 if constants.IS_3D[variable] else 1 + shape_val = 13 if constants.IS_3D[variable] else 1 # Find the value range to sample value_range = indices[variable] @@ -896,23 +895,21 @@ def save_pred_as_grib(self, prediction, value_dir_path): ) md = subset.metadata() - # Cut the datestring into date and time and then override all - # values in md + # Cut the datestring into date and time and then override all + # values in md date = date_str[:8] time = date_str[8:] # Assuming md is a list of metadata dictionaries for metadata in md: - metadata.override({ - "date": date, - "time": time - }) + metadata.override({"date": date, "time": time}) if len(md) > 0: # Load the array to replace the values with # We need to still save it as a .npy # object and pass it on as an argument to this function - original_cut = prediction[ + replacement_data = np.load(file_path) + original_cut = replacement_data[ 0, time_idx, :, min(value_range) : max(value_range) + 1 ].reshape(582, 390, shape_val) cut_values = np.moveaxis( diff --git a/slurm_predict.sh b/slurm_predict.sh index cda9d225..8f9f3778 100644 --- a/slurm_predict.sh +++ b/slurm_predict.sh @@ -38,7 +38,7 @@ fi echo "Predicting with model" if [ "$MODEL" = "hi_lam" ]; then - srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" else - srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" fi From a66a04c33d1259e0f9ef6a6ad88d802c09d0cdaf Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:14:11 +0200 Subject: [PATCH 13/27] remove standalone version --- grib_modifier.py | 106 ----------------------------------------------- 1 file changed, 106 deletions(-) delete mode 100644 grib_modifier.py diff --git a/grib_modifier.py b/grib_modifier.py deleted file mode 100644 index 460394fb..00000000 --- a/grib_modifier.py +++ /dev/null @@ -1,106 +0,0 @@ -# Third-party -import cartopy.crs as ccrs -import cartopy.feature as cf -import earthkit.data -import numpy as np -import pygrib -from matplotlib import pyplot as plt - -# First-party -from neural_lam import constants - - -def plot_data( - grb, title, ax, projection, vmin, vmax, color_map="plasma", num_contours=100 -): - """Plot the data using Cartopy.""" - lats, lons = grb.latlons() - data = grb.values - - ax.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") - ax.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") - ax.set_title(title) - - contour = ax.contourf( - lons, - lats, - data, - transform=projection(), - levels=np.linspace(vmin, vmax, num_contours), - cmap=color_map, - ) - return contour - - -def modify_data(): - """Fit the numpy values into GRIB format.""" - # Load the original grib file - original_data = earthkit.data.from_source( - "file", "/users/clechart/clechart/neural-lam/laf2024042400" - ) - subset = original_data.sel(shortName="u", level=constants.VERTICAL_LEVELS) - - # Load the array to replace the values with - replacement_data = np.load( - "/users/clechart/clechart/neural-lam/" - "wandb/run-20240417_104748-dxnil3vw/files/" - "results/inference/prediction_0.npy" - ) - original_cut = replacement_data[0, 1, :, 26:33].reshape(582, 390, 7) - cut_values = np.moveaxis(original_cut, [-3, -2, -1], [-1, -2, -3]) - - # Create the modified GRIB file with the predicted data - modified_grib_path = "/users/clechart/clechart/neural-lam/modified_grib" - md = subset.metadata() - data_new = earthkit.data.FieldList.from_array(cut_values, md) - data_new.save(modified_grib_path) - - -def generate_plot(): - """Plot the original GRIB entries against transformed inference.""" - # Open the original and modified GRIB files - original_grib = pygrib.open( - "/users/clechart/clechart/neural-lam/laf2024042400" - ) - grb_original = original_grib.select( - shortName="u", level=constants.VERTICAL_LEVELS[0] - )[0] - - predicted_grib = pygrib.open( - "/users/clechart/clechart/neural-lam/modified_grib" - ) - grb_predicted = predicted_grib.select( - shortName="u", level=constants.VERTICAL_LEVELS[0] - )[0] - - # Determine the global min and max values for the colorbar - vmin = min(grb_original.values.min(), grb_predicted.values.min()) - vmax = max(grb_original.values.max(), grb_predicted.values.max()) - - fig, (ax1, ax2) = plt.subplots( - nrows=2, - figsize=constants.FIG_SIZE, - subplot_kw={"projection": ccrs.PlateCarree()}, - ) - contour1 = plot_data( - grb_original, "Original Data", ax1, ccrs.PlateCarree, vmin, vmax - ) - _ = plot_data( - grb_predicted, "Predicted Data", ax2, ccrs.PlateCarree, vmin, vmax - ) - - plt.subplots_adjust(hspace=0.1, wspace=0.05) - colorbar_ax = fig.add_axes( - [0.15, 0.08, 0.7, 0.02] - ) # Position for the colorbar - fig.colorbar( - contour1, cax=colorbar_ax, orientation="horizontal", shrink=0.5 - ) - - plt.savefig("completed_grib.png", bbox_inches="tight") - plt.close(fig) - - -if __name__ == "__main__": - modify_data() - generate_plot() From 1ff2fda691082364e356d2a246da3891e30da0de Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:14:37 +0200 Subject: [PATCH 14/27] rm metadata generator --- metadata_generator.py | 89 ------------------------------------------- 1 file changed, 89 deletions(-) delete mode 100644 metadata_generator.py diff --git a/metadata_generator.py b/metadata_generator.py deleted file mode 100644 index cbaf5e94..00000000 --- a/metadata_generator.py +++ /dev/null @@ -1,89 +0,0 @@ -import json -import os - -import numpy as np -from earthkit.data import FieldList -import metview as mv - -# Need to shuffle the metadata the same way as with constants (see message from Simon) -# Or extract the metadata at the stage where that happens -# Sort according to the last dimension - -class GRIBMetadata: - def __init__(self, grib_data): - self.grib_data = grib_data - - def display(self): - # Display each variable and its metadata in a readable format - for var, metadata in self.grib_data.items(): - print(f"Variable: {var}") - for key, value in metadata.items(): - print(f" {key}: {value}") - - def save_to_file(self, filename): - # Save the metadata dictionary to a JSON file - with open(filename, 'w') as f: - json.dump(self.grib_data, f, indent=4) - - -def map_zarr_to_grib_metadata(zarr_metadata): - # Convert metadata from Zarr format to a GRIB-like format - grib_metadata = {} - for key, value in zarr_metadata['metadata'].items(): - if '/.zarray' in key: - var_name = key.split('/')[0] - array_info = zarr_metadata['metadata'][f'{var_name}/.zarray'] - attrs_info = zarr_metadata['metadata'][f'{var_name}/.zattrs'] - - # Rearrange zmetadata according to the shuffling of constants.py - # -> Makes a subselection of the variables in the zarr archive and shuffles the indices - - if '_ARRAY_DIMENSIONS' in attrs_info: - dimensions = attrs_info['_ARRAY_DIMENSIONS'] - ny = array_info['shape'][dimensions.index('y_1')] if 'y_1' in dimensions else None - nx = array_info['shape'][dimensions.index('x_1')] if 'x_1' in dimensions else None - - grib_metadata[var_name] = { - 'GRIB_paramName': var_name, - 'GRIB_units': attrs_info.get('units', ''), - 'GRIB_dataType': array_info['dtype'], - 'GRIB_totalNumber': array_info['shape'][0] if 'time' in dimensions else 1, - 'GRIB_gridType': 'regular_ll', - 'GRIB_Ny': ny, - 'GRIB_Nx': nx, - 'GRIB_missingValue': array_info['fill_value'] - } - return GRIBMetadata(grib_metadata) - -def extract_grib_metadata_from_zarr(zarr_path): - # Load the Zarr dataset's metadata from the .zmetadata JSON file - metadata_path = os.path.join(zarr_path, '.zmetadata') - with open(metadata_path, 'r') as file: - metadata = json.load(file) - return map_zarr_to_grib_metadata(metadata) - - -def complete_data_into_grib(grib_metadata_object): - data = np.load("/users/clechart/clechart/neural-lam/wandb/run-20240417_104748-dxnil3vw/files/results/inference/prediction_0.npy") - # How do the pieces of code below work? - data = data.set_values(vals) - mv.write('recip.grib', data) - ds_new = FieldList.from_array(data, grib_metadata_object) - - print(ds_new) - return ds_new - -if __name__ == "__main__": - zarr_path = "/users/clechart/clechart/neural-lam/data/cosmo_old/samples/forecast/data_2020011017.zarr" - grib_metadata_object = extract_grib_metadata_from_zarr(zarr_path) - grib_metadata_object.display() # Display the extracted metadata - - # Save metadata to a file - grib_metadata_object.save_to_file('grib_metadata.json') - - # Reconstruct the GRIB file with the data - full_set = complete_data_into_grib(grib_metadata_object) - - -# How does one read an array as a grib? -# gribfile = xr.open_dataset(joinpath(path,filelist[1]),engine="cfgrib") \ No newline at end of file From 7e922c1fa118d43b6edc2739068d12d6ed04e82b Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:15:12 +0200 Subject: [PATCH 15/27] rm metadata automation --- slurm_metadata.sh | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 slurm_metadata.sh diff --git a/slurm_metadata.sh b/slurm_metadata.sh deleted file mode 100644 index 267ac812..00000000 --- a/slurm_metadata.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -l -#SBATCH --job-name=Metadata -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=4 -#SBATCH --partition=pp-short -#SBATCH --account=s83 -#SBATCH --output=lightning_logs/metadata_out.log -#SBATCH --error=lightning_logs/metadata_err.log -#SBATCH --time=00:03:00 -#SBATCH --no-requeue - -# Load necessary modules -conda activate neural-lam - - -ulimit -c 0 -export OMP_NUM_THREADS=16 - -srun -ul python grib_modifier.py From 025b839d35c9db63bc82d05e6ff28e64a4608a0f Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:17:48 +0200 Subject: [PATCH 16/27] rm generalized modifier --- generalized_modifier.py | 122 ---------------------------------------- 1 file changed, 122 deletions(-) delete mode 100644 generalized_modifier.py diff --git a/generalized_modifier.py b/generalized_modifier.py deleted file mode 100644 index 7d11ca72..00000000 --- a/generalized_modifier.py +++ /dev/null @@ -1,122 +0,0 @@ -# Third-party -from datetime import datetime, timedelta -import earthkit.data -import numpy as np - -# First-party -from neural_lam import constants - - -def modify_data(prediction: np.array): - """Fit the numpy values into GRIB format.""" - - indices = precompute_variable_indices() - time_steps = generate_time_steps() - vertical_levels = [1, 5, 13, 22, 38, 41, 60] - - # Initialize final data object - final_data = earthkit.data.FieldList() - - # Loop through all the time steps and all the variables - for time_idx, date_str in time_steps.items(): - - # ATN between 3D and 2D - 7vs1 lvl (for reshaping) - for variable in constants.PARAM_NAMES_SHORT: - # here find the key of the cariable in constants.is_3D and if == 7, assign a cut of 7 on the - # reshape. Else 1 - shape_val = 7 if constants.IS_3D[variable] else 1 - # Find the value range to sample - value_range = indices[variable] - - sample_file = constants.SAMPLE_GRIB - if variable == "RELHUM": - variable = "r" - sample_file = constants.SAMPLE_Z_GRIB - - # Load the sample grib file - original_data = earthkit.data.from_source( - "file", sample_file - ) - - subset = original_data.sel(shortName= variable.lower(), level=vertical_levels) - md = subset.metadata() - - # Cut the datestring into date and time and then override all - # values in md - date = date_str[:8] - time = date_str[8:] - - # Assuming md is a list of metadata dictionaries - for metadata in md: - metadata.override({ - "date": date, - "time": time - }) - if len(md)>0: - # Load the array to replace the values with - # We need to still save it as a .npy object and pass it on as an argument to this function - replacement_data = np.load(prediction) - original_cut = replacement_data[0, time_idx, :, min(value_range):max(value_range)+1].reshape(582, 390, shape_val) - cut_values = np.moveaxis(original_cut, [-3, -2, -1], [-1, -2, -3]) - # Can we stack Fieldlists? - data_new = earthkit.data.FieldList.from_array(cut_values, md) - final_data += data_new - - # Create the modified GRIB file with the predicted data - modified_grib_path =f"lightning_logs/prediction_{date_str}" - final_data.save(modified_grib_path) - -# This function is taken from ar_model, need to just use self when I go -# put this function into on_predict_epoch_end() -def precompute_variable_indices(): - """ - Precompute indices for each variable in the input tensor - """ - variable_indices = {} - all_vars = [] - index = 0 - # Create a list of tuples for all variables, using level 0 for 2D - # variables - for var_name in constants.PARAM_NAMES_SHORT: - if constants.IS_3D[var_name]: - for level in constants.VERTICAL_LEVELS: - all_vars.append((var_name, level)) - else: - all_vars.append((var_name, 0)) # Use level 0 for 2D variables - - # Sort the variables based on the tuples - sorted_vars = sorted(all_vars) - - for var in sorted_vars: - var_name, level = var - if var_name not in variable_indices: - variable_indices[var_name] = [] - variable_indices[var_name].append(index) - index += 1 - - return variable_indices - - -def generate_time_steps(): - # Parse the times - base_time = constants.EVAL_DATETIMES[0] - if isinstance(base_time, str): - base_time = datetime.strptime(base_time, "%Y%m%d%H") - else: - base_time = base_time - time_steps = {} - # Generate dates for each step - for i in range(constants.EVAL_HORIZON - 2): - # Compute the new date by adding the step interval in hours - 3 - new_date = base_time + timedelta(hours=i * constants.TRAIN_HORIZON) - # Format the date back - time_steps[i] = new_date.strftime("%Y%m%d%H") - - return time_steps - - - -if __name__ == "__main__": - precompute_variable_indices() - time_steps = generate_time_steps() - modify_data(prediction="/users/clechart/neural-lam/templates/predictions.npy") From 00053b035ea8fdbacdb3c3348facd0a1f8bff420 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 1 May 2024 16:27:11 +0200 Subject: [PATCH 17/27] fixing the environment --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 353f3bc7..af62e8f9 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - xarray - zarr - pip: - - earthkit.data + - earthkit-data - tueplots - codespell>=2.0.0 - black>=21.9b0 From df9b595e8ae6996ad5848319ff9d0c37a30f7a69 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 2 May 2024 16:57:36 +0200 Subject: [PATCH 18/27] fixing relationship to metadata --- neural_lam/constants.py | 18 ++++++++++++++++++ neural_lam/models/ar_model.py | 4 ++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 5d8b893c..df1291cf 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -143,6 +143,24 @@ "V_10M": 0, } +GRIB_NAME = { + "PP": 'pres', + "QV": 'q', + "RELHUM": 'r', + "T": 't', + "U": 'u', + "V": 'v', + "W": 'wz', + "CLCT": 'ccl', + "PMSL": 'prmsl', + "PS": 'sp', + "T_2M": '2t', + "TOT_PREC": 'tp', + "U_10M":'10u', + "V_10M": '10v' +} + + # Vertical level weights # These were retrieved based on the pressure levels of # https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5 diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 5c1db41f..617ebe3c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -876,7 +876,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): final_data = earthkit.data.FieldList() # Loop through all the time steps and all the variables for time_idx, date_str in time_steps.items(): - for variable in constants.PARAM_NAMES_SHORT: + for variable, grib_code in constants.GRIB_NAME.items(): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 shape_val = 13 if constants.IS_3D[variable] else 1 @@ -891,7 +891,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): original_data = earthkit.data.from_source("file", sample_file) subset = original_data.sel( - shortName=variable.lower(), level=constants.VERTICAL_LEVELS + shortName=grib_code, level=constants.VERTICAL_LEVELS ) md = subset.metadata() From 0f5217a9632a24da358cbccbd6f781c8b315f91b Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 2 May 2024 17:14:50 +0200 Subject: [PATCH 19/27] works standalone --- neural_lam/models/ar_model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 617ebe3c..e3b421ee 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -879,7 +879,12 @@ def save_pred_as_grib(self, file_path, value_dir_path): for variable, grib_code in constants.GRIB_NAME.items(): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 - shape_val = 13 if constants.IS_3D[variable] else 1 + if constants.IS_3D[variable]: + shape_val = 13 + vertical = constants.VERTICAL_LEVELS + else: + shape_val = 1 + vertical = 1 # Find the value range to sample value_range = indices[variable] @@ -891,7 +896,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): original_data = earthkit.data.from_source("file", sample_file) subset = original_data.sel( - shortName=grib_code, level=constants.VERTICAL_LEVELS + shortName=grib_code, level=vertical ) md = subset.metadata() From c2bb7a115ae18b38520a74010a9d9c93ea9533e9 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Thu, 2 May 2024 17:31:41 +0200 Subject: [PATCH 20/27] stuff is broken I do not comprehend --- neural_lam/constants.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index df1291cf..728a49b3 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -201,12 +201,12 @@ EVAL_PLOT_VARS = ["T_2M"] STORE_EXAMPLE_DATA = True SELECTED_PROJ = ccrs.PlateCarree() -EXAMPLE_FILE = "data/cosmo/samples/train/data_2015112800.zarr" +EXAMPLE_FILE = "data/cosmo/samples/train/data.zarr" SAMPLE_GRIB = "/users/clechart/neural-lam/templates/lfff02180000" SAMPLE_Z_GRIB = "/users/clechart/neural-lam/templates/lfff02180000z" CHUNK_SIZE = 100 -EVAL_DATETIME = "2020100215" -EVAL_PLOT_VARS = ["TQV"] +EVAL_DATETIME = ["2020100215"] +EVAL_PLOT_VARS = ["QV"] STORE_EXAMPLE_DATA = False COSMO_PROJ = ccrs.PlateCarree() SELECTED_PROJ = COSMO_PROJ From 3b4baadb6039091c25f56589daf5409843db8930 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Sat, 4 May 2024 15:13:11 +0200 Subject: [PATCH 21/27] pre-commit fixes --- .pre-commit-config.yaml | 2 +- neural_lam/constants.py | 28 ++++++++++++++-------------- neural_lam/models/ar_model.py | 6 ++---- requirements.txt | 1 + 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a32ddc51..10fbf48b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: description: Check for spelling errors language: system entry: codespell - args: ['--ignore-words-list=laf'] + args: ['--ignore-words-list=laf,pres'] - repo: local hooks: - id: black diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 728a49b3..7155c769 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -144,20 +144,20 @@ } GRIB_NAME = { - "PP": 'pres', - "QV": 'q', - "RELHUM": 'r', - "T": 't', - "U": 'u', - "V": 'v', - "W": 'wz', - "CLCT": 'ccl', - "PMSL": 'prmsl', - "PS": 'sp', - "T_2M": '2t', - "TOT_PREC": 'tp', - "U_10M":'10u', - "V_10M": '10v' + "PP": "pres", + "QV": "q", + "RELHUM": "r", + "T": "t", + "U": "u", + "V": "v", + "W": "wz", + "CLCT": "ccl", + "PMSL": "prmsl", + "PS": "sp", + "T_2M": "2t", + "TOT_PREC": "tp", + "U_10M": "10u", + "V_10M": "10v", } diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index e3b421ee..a61a3ed0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -880,7 +880,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 if constants.IS_3D[variable]: - shape_val = 13 + shape_val = 13 vertical = constants.VERTICAL_LEVELS else: shape_val = 1 @@ -895,9 +895,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): # Load the sample grib file original_data = earthkit.data.from_source("file", sample_file) - subset = original_data.sel( - shortName=grib_code, level=vertical - ) + subset = original_data.sel(shortName=grib_code, level=vertical) md = subset.metadata() # Cut the datestring into date and time and then override all diff --git a/requirements.txt b/requirements.txt index a3af0d68..697f6ce9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ wandb>=0.13.10 matplotlib>=3.7.0 dask dask_jobqueue +earthkit-data scipy>=1.10.0 pytorch-lightning>=2.0.3 shapely>=2.0.1 From 88e2d46865b19e51517b9815c157e64ba153893c Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 6 May 2024 13:36:01 +0200 Subject: [PATCH 22/27] Adressing first comments --- neural_lam/constants.py | 9 ++------- neural_lam/models/ar_model.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 7155c769..93536a48 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -201,15 +201,10 @@ EVAL_PLOT_VARS = ["T_2M"] STORE_EXAMPLE_DATA = True SELECTED_PROJ = ccrs.PlateCarree() -EXAMPLE_FILE = "data/cosmo/samples/train/data.zarr" -SAMPLE_GRIB = "/users/clechart/neural-lam/templates/lfff02180000" +SAMPLE_GRIB = "neural-lam/templates/lfff02180000" +SAMPLE_Z_GRIB = "neural-lam/templates/lfff02180000z" SAMPLE_Z_GRIB = "/users/clechart/neural-lam/templates/lfff02180000z" -CHUNK_SIZE = 100 EVAL_DATETIME = ["2020100215"] -EVAL_PLOT_VARS = ["QV"] -STORE_EXAMPLE_DATA = False -COSMO_PROJ = ccrs.PlateCarree() -SELECTED_PROJ = COSMO_PROJ POLLON = -170.0 POLLAT = 43.0 SMOOTH_BOUNDARIES = False diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a61a3ed0..776b49cc 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -880,7 +880,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 if constants.IS_3D[variable]: - shape_val = 13 + shape_val = len(constants.VERTICAL_LEVELS) vertical = constants.VERTICAL_LEVELS else: shape_val = 1 @@ -914,7 +914,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): replacement_data = np.load(file_path) original_cut = replacement_data[ 0, time_idx, :, min(value_range) : max(value_range) + 1 - ].reshape(582, 390, shape_val) + ].reshape(constants.GRID_SHAPE[1], constants.GRID_SHAPE[0], shape_val) cut_values = np.moveaxis( original_cut, [-3, -2, -1], [-1, -2, -3] ) From 5c41464c8380cfebf250f4aa2e94eb1395a349bf Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Mon, 6 May 2024 16:15:49 +0200 Subject: [PATCH 23/27] solving issues wrt param_weights and slurm jobs specific to cosmo data on Tsa --- create_parameter_weights.py | 2 +- neural_lam/models/ar_model.py | 8 ++++++-- neural_lam/weather_dataset.py | 3 ++- slurm_param.sh | 1 - slurm_predict.sh | 6 +++--- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/create_parameter_weights.py b/create_parameter_weights.py index f1e03714..1ffde1d4 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -82,7 +82,7 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name batch_size=args.batch_size, num_workers=args.n_workers, ) - data_module.setup(stage="fit") + data_module.setup(stage="train") train_sampler = DistributedSampler( data_module.train_dataset, num_replicas=world_size, rank=rank diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 776b49cc..2b9d4277 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -880,7 +880,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 if constants.IS_3D[variable]: - shape_val = len(constants.VERTICAL_LEVELS) + shape_val = len(constants.VERTICAL_LEVELS) vertical = constants.VERTICAL_LEVELS else: shape_val = 1 @@ -914,7 +914,11 @@ def save_pred_as_grib(self, file_path, value_dir_path): replacement_data = np.load(file_path) original_cut = replacement_data[ 0, time_idx, :, min(value_range) : max(value_range) + 1 - ].reshape(constants.GRID_SHAPE[1], constants.GRID_SHAPE[0], shape_val) + ].reshape( + constants.GRID_SHAPE[1], + constants.GRID_SHAPE[0], + shape_val, + ) cut_values = np.moveaxis( original_cut, [-3, -2, -1], [-1, -2, -3] ) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 2cf532c2..9670a5c1 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -257,7 +257,7 @@ def __init__( self.predict_dataset = None def setup(self, stage=None): - if stage == "fit" or stage is None: + if stage == "train" or stage is None: self.train_dataset = WeatherDataset( dataset_name=self.dataset_name, split="train", @@ -265,6 +265,7 @@ def setup(self, stage=None): subset=self.subset, batch_size=self.batch_size, ) + elif stage == "val": self.val_dataset = WeatherDataset( dataset_name=self.dataset_name, split="val", diff --git a/slurm_param.sh b/slurm_param.sh index ecc98c0e..6eb869e1 100644 --- a/slurm_param.sh +++ b/slurm_param.sh @@ -4,7 +4,6 @@ #SBATCH --time=24:00:00 #SBATCH --nodes=2 #SBATCH --partition=postproc -#SBATCH --mem=444G #SBATCH --no-requeue #SBATCH --exclusive #SBATCH --output=lightning_logs/neurwp_param_out.log diff --git a/slurm_predict.sh b/slurm_predict.sh index 8f9f3778..039bddf0 100644 --- a/slurm_predict.sh +++ b/slurm_predict.sh @@ -9,7 +9,7 @@ #SBATCH --output=lightning_logs/neurwp_pred_out.log #SBATCH --error=lightning_logs/neurwp_pred_err.log -export PREPROCESS=true +export PREPROCESS=false export NORMALIZE=false export DATASET="cosmo" export MODEL="hi_lam" @@ -38,7 +38,7 @@ fi echo "Predicting with model" if [ "$MODEL" = "hi_lam" ]; then - srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 1 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict" else - srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 0 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" + srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 1 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict" fi From 2d02c4f854ff40cc265ea09292d4661d08c87b86 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Tue, 7 May 2024 12:04:51 +0200 Subject: [PATCH 24/27] fixing the grib generation --- neural_lam/constants.py | 5 ++--- neural_lam/models/ar_model.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 93536a48..4140b897 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -201,9 +201,8 @@ EVAL_PLOT_VARS = ["T_2M"] STORE_EXAMPLE_DATA = True SELECTED_PROJ = ccrs.PlateCarree() -SAMPLE_GRIB = "neural-lam/templates/lfff02180000" -SAMPLE_Z_GRIB = "neural-lam/templates/lfff02180000z" -SAMPLE_Z_GRIB = "/users/clechart/neural-lam/templates/lfff02180000z" +SAMPLE_GRIB = "templates/lfff02180000" +SAMPLE_Z_GRIB = "templates/lfff02180000z" EVAL_DATETIME = ["2020100215"] POLLON = -170.0 POLLAT = 43.0 diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 2b9d4277..ba3aa4f5 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -872,10 +872,10 @@ def save_pred_as_grib(self, file_path, value_dir_path): # Initialize the lists to loop over indices = self.precompute_variable_indices() time_steps = self._generate_time_steps() - # Initialize final data object - final_data = earthkit.data.FieldList() # Loop through all the time steps and all the variables for time_idx, date_str in time_steps.items(): + # Initialize final data object + final_data = earthkit.data.FieldList() for variable, grib_code in constants.GRIB_NAME.items(): # here find the key of the cariable in constants.is_3D # and if == 7, assign a cut of 7 on the reshape. Else 1 @@ -890,6 +890,7 @@ def save_pred_as_grib(self, file_path, value_dir_path): sample_file = constants.SAMPLE_GRIB if variable == "RELHUM": + variable = "r" sample_file = constants.SAMPLE_Z_GRIB # Load the sample grib file @@ -903,14 +904,12 @@ def save_pred_as_grib(self, file_path, value_dir_path): date = date_str[:8] time = date_str[8:] - # Assuming md is a list of metadata dictionaries - for metadata in md: - metadata.override({"date": date, "time": time}) - + for index, item in enumerate(md): + md[index] = item.override({"date": date}).override( + {"time": time} + ) if len(md) > 0: # Load the array to replace the values with - # We need to still save it as a .npy - # object and pass it on as an argument to this function replacement_data = np.load(file_path) original_cut = replacement_data[ 0, time_idx, :, min(value_range) : max(value_range) + 1 @@ -927,7 +926,6 @@ def save_pred_as_grib(self, file_path, value_dir_path): cut_values, md ) final_data += data_new - # Create the modified GRIB file with the predicted data grib_path = os.path.join( value_dir_path, f"prediction_{date_str}_grib" From 66fdd92a1d593560516d1d9f4be7d2cf3a9d5bc7 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 27 May 2024 16:39:31 +0200 Subject: [PATCH 25/27] retrieving 2M and 10M levels --- neural_lam/models/ar_model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index ba3aa4f5..f4dd8f2f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -883,8 +883,16 @@ def save_pred_as_grib(self, file_path, value_dir_path): shape_val = len(constants.VERTICAL_LEVELS) vertical = constants.VERTICAL_LEVELS else: - shape_val = 1 - vertical = 1 + # Special handling for T_2M and *_10M variables + if variable == "T_2M": + shape_val = 1 + vertical = 2 + elif variable.endswith("_10M"): + shape_val = 1 + vertical = 10 + else: + shape_val = 1 + vertical = 0 # Find the value range to sample value_range = indices[variable] From 17393df7e481604aa0a9e67c7f2153bddcb3f09e Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Mon, 27 May 2024 16:44:51 +0200 Subject: [PATCH 26/27] changing codespell issue --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 55c07c25..74132ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ line-length = 80 [tool.isort] -default_section = "THIRDPARTY" +default_section = "THIRD PARTY" profile = "black" # Headings import_heading_stdlib = "Standard library" From 19e5e82979838e726670226a6b462ce3b5bb8449 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Tue, 28 May 2024 09:02:19 +0200 Subject: [PATCH 27/27] reverting pyproject changes --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 74132ac9..55c07c25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ line-length = 80 [tool.isort] -default_section = "THIRD PARTY" +default_section = "THIRDPARTY" profile = "black" # Headings import_heading_stdlib = "Standard library"