Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/save intermediate ml diag data #200

Merged
merged 26 commits into from
Mar 30, 2020

Conversation

AnnaKwa
Copy link
Contributor

@AnnaKwa AnnaKwa commented Mar 25, 2020

Refactor for offline ML diagnostics workflow.

  • single dataset input to metrics and diagnostics functions
  • separation of "metrics" vs. "diagnostic" quantities
    • metrics: R^2 (global values for 2d quantities, pressure level profiles for 3d) and RMSE
    • diagnostics: ML dQ vs total maps, LTS, Vertical dQ2 profiles in wet/dry columns, diurnal cycle, time avg and snapshots of net precip and heating compared across datasets

@AnnaKwa AnnaKwa requested a review from nbren12 March 25, 2020 21:05
@@ -3,116 +3,75 @@
import numpy as np
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Github isn't showing this diff by default since it is large. This file originally contained all the plotting functions for metrics and diagnostics; the main changes are

  • some of the "metrics" plots got moved out
  • plotting functions take the same single common dataset input

@AnnaKwa
Copy link
Contributor Author

AnnaKwa commented Mar 25, 2020

sample format of saved metrics netcdf:

<xarray.Dataset>
Dimensions:                                (grid_x: 49, grid_xt: 48, grid_y: 49, grid_yt: 48, initialization_time: 4, pressure: 37, tile: 6)
Coordinates:
    time                                   object ...
    dataset                                object ...
  * tile                                   (tile) int64 0 1 2 3 4 5
  * grid_xt                                (grid_xt) float64 1.0 2.0 ... 48.0
  * grid_yt                                (grid_yt) float64 1.0 2.0 ... 48.0
  * grid_x                                 (grid_x) float64 1.0 2.0 ... 49.0
  * grid_y                                 (grid_y) float64 1.0 2.0 ... 49.0
  * pressure                               (pressure) float64 1.0 2.0 ... 1e+03
  * initialization_time                    (initialization_time) object 2016-08-05 09:45:00 ... 2016-08-05 10:30:00
Data variables:
    R2_global_net_heating_vs_target        float64 ...
    R2_global_net_heating_vs_hires         float64 ...
    R2_sea_net_heating_vs_target           float64 ...
    R2_sea_net_heating_vs_hires            float64 ...
    R2_land_net_heating_vs_target          float64 ...
    R2_land_net_heating_vs_hires           float64 ...
    R2_global_net_precipitation_vs_target  float64 ...
    R2_global_net_precipitation_vs_hires   float64 ...
    R2_sea_net_precipitation_vs_target     float64 ...
    R2_sea_net_precipitation_vs_hires      float64 ...
    R2_land_net_precipitation_vs_target    float64 ...
    R2_land_net_precipitation_vs_hires     float64 ...
    lat                                    (tile, grid_yt, grid_xt) float32 ...
    latb                                   (tile, grid_y, grid_x) float32 ...
    lon                                    (tile, grid_yt, grid_xt) float32 ...
    lonb                                   (tile, grid_y, grid_x) float32 ...
    area                                   (tile, grid_yt, grid_xt) float32 ...
    r2_dQ1_pressure_levels_global          (pressure) float64 ...
    r2_dQ2_pressure_levels_global          (pressure) float64 ...
    r2_dQ1_pressure_levels_sea             (pressure) float64 ...
    r2_dQ2_pressure_levels_sea             (pressure) float64 ...
    r2_dQ1_pressure_levels_land            (pressure) float64 ...
    r2_dQ2_pressure_levels_land            (pressure) float64 ...
    mse_net_precipitation_vs_fv3_target    (grid_xt, grid_yt, initialization_time, tile) float64 ...
    mse_net_precipitation_vs_shield        (grid_xt, grid_yt, initialization_time, tile) float64 ...
    mse_net_heating_vs_fv3_target          (grid_xt, grid_yt, initialization_time, tile) float64 ...
    mse_net_heating_vs_shield              (grid_xt, grid_yt, initialization_time, tile) float64 ...

Copy link
Contributor

@nbren12 nbren12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Anna, the main method is now nice and clean. I think some attention should be placed on removing the use of global constants in the metrics.py file. I think my suggested refactors will make this code both more reusable and robust to future changes in variables names and conventions.

@@ -126,3 +130,46 @@ def net_heating_from_dataset(ds: xr.Dataset, suffix: str = None) -> xr.DataArray
ds["PRATEsfc" + suffix],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hard code here. I expect this name will change in future versions.


# create and save metrics dataset
# metrics: r2 global values, r2 pressure level profiles, MSE at locations
ds_metrics = create_metrics_dataset(ds)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This high level structure is very good.

@AnnaKwa AnnaKwa requested a review from nbren12 March 27, 2020 00:41
@AnnaKwa
Copy link
Contributor Author

AnnaKwa commented Mar 27, 2020

Thanks for the review @nbren12 , ready for re-review.

Copy link
Contributor

@nbren12 nbren12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for all the changes. I have some very minor comments below which you don't necessarily have to address if you don't want to.

ds_test = ds.sel(dataset=DATASET_NAME_FV3_TARGET)
ds_hires = ds.sel(dataset=DATASET_NAME_SHIELD_HIRES)

ds_metrics = create_metrics_dataset(ds_pred, ds_test, ds_hires)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! this structure is pretty clear.

return ds_metrics


def plot_metrics(ds_metrics, output_dir, dpi_figures):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would consider moving these plotting routines to another module to separate the plotting and computation code even more.

@AnnaKwa AnnaKwa merged commit 2cceac3 into master Mar 30, 2020
@AnnaKwa AnnaKwa deleted the feature/save-intermediate-ml-diag-data branch March 30, 2020 18:55
spencerkclark pushed a commit that referenced this pull request May 7, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants