
Commit

Improve gcm driven single column calibrations: modify priors, improve plotting scripts, add rmse metrics, and parallelize cases over cpu cores
costachris committed Sep 25, 2024
1 parent ec15a81 commit eb05766
Showing 19 changed files with 1,673 additions and 646 deletions.
1,086 changes: 644 additions & 442 deletions calibration/experiments/gcm_driven_scm/Manifest.toml

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions calibration/experiments/gcm_driven_scm/Project.toml
@@ -1,11 +1,14 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ClimaAnalysis = "29b5916a-a76c-4e73-9657-3c8fd22e65e6"
ClimaAtmos = "b2c96348-7fb7-4fe0-8da9-78d88439e717"
ClimaCalibrate = "4347a170-ebd6-470c-89d3-5c705c0cacc2"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaParams = "5c42b081-d73a-476f-9059-fd94b934656c"
ClimaUtilities = "b3f4f4ca-9299-4f7f-bd9b-81e1242a7513"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EnsembleKalmanProcesses = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
@@ -17,5 +20,4 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
ClimaCalibrate = "=0.0.3"
EnsembleKalmanProcesses = "=1.1.7"
ClimaAtmos = "=0.27.2"
EnsembleKalmanProcesses = "2"
14 changes: 9 additions & 5 deletions calibration/experiments/gcm_driven_scm/README.md
@@ -2,20 +2,24 @@

This setup provides tools for calibrating both prognostic and diagnostic EDMF variants to LES profiles, given the same forcings and boundary conditions. The gcm-driven EDMF setup is employed in single-column mode, which uses both interactive radiation and surface fluxes. Forcing profiles include resolved eddy advection, horizontal advection, subsidence, and GCM state relaxation. The setup is run to the top of the atmosphere to compute radiation, but calibration statistics are computed only on the lower 4 km (`z_max`), where LES output is available.

LES profiles are available for different geolocations ("cfsites"), spanning seasons, forcing host models, and climates (AMIP, AMIP4K). A given LES simulation is referred to as a "configuration". Calibrations employ batching by default and stack multiple configurations (a number equal to the `batch_size`) in a given iteration. The observation vector for a single configuration is formed by concatenating profiles across calibration variables, where each geophysical variable is normalized to have approximately unit variance and zero mean. These variable-by-variable normalization factors are precomputed (`norm_factors_dict`) and applied to all observations. Following this operation, the spatiotemporal calibration window is applied and temporal means are computed to form the observation vector `y`. Because variables are normalized, a constant, diagonal noise matrix is used (configurable as `const_noise`).
LES profiles are available for different geolocations ("cfsites"), spanning seasons, forcing host models, and climates (AMIP, AMIP4K). A given LES simulation is referred to as a "configuration". Calibrations employ batching by default and stack multiple configurations (a number equal to the `batch_size`) in a given iteration. The observation vector for a single configuration is formed by concatenating profiles across calibration variables, where each geophysical variable is normalized to have approximately unit variance and zero mean. These variable-by-variable normalization factors are precomputed (`norm_factors_dict`) and applied to all observations. Following this operation, the spatiotemporal calibration window is applied and temporal means are computed to form the observation vector `y`. Because variables are normalized to have 0 mean and unit variance, a constant diagonal noise matrix is used (configurable as `const_noise`).
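As a rough sketch of how a single configuration's observation vector is assembled (a minimal illustration, not the pipeline's code; `build_observation_vector`, the `(mean, std)` layout of `norm_factors_dict`, and the noise value `5e-4` are assumptions):

```julia
using LinearAlgebra: I

# Illustrative sketch: z-score normalize each variable's time-averaged profile
# with precomputed factors, concatenate into one observation vector y, and
# pair it with a constant diagonal noise covariance.
function build_observation_vector(profiles, norm_factors_dict, var_names; const_noise = 5e-4)
    y = Float64[]
    for v in var_names
        μ, σ = norm_factors_dict[v]          # precomputed per-variable mean and std
        append!(y, (profiles[v] .- μ) ./ σ)  # ≈ zero mean, unit variance
    end
    Σ = const_noise * I(length(y))           # constant diagonal noise matrix
    return y, Σ
end
```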


## Getting Started

### Define calibration and model configurations:
- `experiment_config.yml` - Configuration of EKI hyperparameters and settings, spatiotemporal calibration window, required pipeline file paths
- `experiment_config.yml` - Configuration of calibration settings, including spatiotemporal calibration window and required pipeline file paths.
- `run_calibration.jl` - Run script for the calibration pipeline. EKI settings and hyperparameters can be modified where `CAL.initialize` is called.
- `model_config_**.yml` - Config file for the underlying ClimaAtmos single-column model
- `get_les_metadata.jl` - (Re)Define `get_les_calibration_library()` to specify which LES configurations to use
- `get_les_metadata.jl` - (Re)Define `get_les_calibration_library()` to specify which LES configurations to use. Set `batch_size` in the `experiment_config.yml` accordingly (<= the number of cases).
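A minimal sketch of a redefined case library (illustrative only: the placeholder paths and the second return element are assumptions; the two-element return simply mirrors how `edmf_ensemble_stats.jl` unpacks it via `ref_paths, _ = get_les_calibration_library()`):

```julia
# Illustrative sketch, not the shipped implementation: select the LES cases
# ("configurations") to calibrate against by returning their output paths.
function get_les_calibration_library()
    ref_paths = [  # placeholder paths to LES output directories
        "/path/to/les/cfsite17_amip",
        "/path/to/les/cfsite23_amip4k",
    ]
    case_ids = ["cfsite17_amip", "cfsite23_amip4k"]  # assumed second element
    return (ref_paths, case_ids)
end
```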

### Run with:
- `run_calibration.jl` - runs calibration end-to-end using HPC resources
- `sbatch run_calibration.sbatch` - schedules and runs calibration pipeline end-to-end using HPC resources
- `julia --project run_calibration.jl` - interactively runs calibration end-to-end using HPC resources, streaming to a Julia REPL

### Analyze output with:
- `plot_ensemble.jl` - plots vertical profiles of all ensemble members in a given iteration.
- `julia --project plot_ensemble.jl` - plots vertical profiles of all ensemble members in a given iteration, given path to calibration output
- `julia --project edmf_ensemble_stats.jl` - computes and plots metrics offline [i.e., root mean squared error (RMSE)] as a function of iteration, given path to calibration output.
- `julia --project plot_eki.jl` - plots EKI metrics [loss, variance-weighted loss] and the `y`, `g` vectors vs. iteration, and displays the best particles


326 changes: 326 additions & 0 deletions calibration/experiments/gcm_driven_scm/edmf_ensemble_stats.jl
@@ -0,0 +1,326 @@
#!/usr/bin/env julia

import ClimaComms
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

using ArgParse
using Distributed
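# Launch local worker processes (addprocs() with no arguments defaults to one worker per CPU thread), so iterations can be processed in parallel via pmap below.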
addprocs()


@everywhere begin
using EnsembleKalmanProcesses: TOMLInterface
import EnsembleKalmanProcesses as EKP
using EnsembleKalmanProcesses.ParameterDistributions
using ClimaCalibrate: observation_map, ExperimentConfig
using ClimaAnalysis
import ClimaAtmos as CA  # used below for CA.AtmosConfig
using Plots
using JLD2
using Statistics
using YAML
using DataFrames
using CSV

include("helper_funcs.jl")
include("observation_map.jl")
include("get_les_metadata.jl")
end

function parse_with_settings(s)
return ArgParse.parse_args(s)
end

function parse_args()
s = ArgParseSettings(description = "Process ensemble Kalman statistics")
@add_arg_table s begin
"--output_dir"
help = "Calibration output directory"
required = true
"--var_names"
help = "Variable names to process (comma-separated)"
default = "thetaa,hus,clw,arup,entr,detr,waup,tke"
"--reduction"
help = "Reduction method to use (default: inst)"
default = "inst"
"--iterations"
help = "Iterations to plot (e.g., 0:11), default is all iterations"
default = nothing
"--save_as_csv"
help = "Save results as CSV"
default = true
arg_type = Bool
"--load_from_csv"
help = "Load results from CSV"
default = false
arg_type = Bool
"--plot_dir"
help = "Directory to save plots (default: edmf_stats_plots)"
default = "edmf_stats_plots"
end
return parse_with_settings(s)
end
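# Example invocation (paths are illustrative):
#   julia --project edmf_ensemble_stats.jl --output_dir output/gcm_driven_scm --iterations 0:11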

function main()
args = parse_args()

output_dir = args["output_dir"]
var_names = map(String, split(args["var_names"], ","))
reduction = args["reduction"]
save_as_csv = args["save_as_csv"]
load_from_csv = args["load_from_csv"]
plot_dir = args["plot_dir"]

if isnothing(args["iterations"])
iterations = nothing
else
iterations = eval(Meta.parse(args["iterations"]))
end

mkpath(joinpath(output_dir, "plots", plot_dir))

# Load configuration data
config_dict =
YAML.load_file(joinpath(output_dir, "configs", "experiment_config.yml"))
n_vert_levels = config_dict["dims_per_var"]
z_max = config_dict["z_max"]
ensemble_size = config_dict["ensemble_size"]
cal_vars = config_dict["y_var_names"]
const_noise_by_var = config_dict["const_noise_by_var"]
n_iterations = config_dict["n_iterations"]
model_config_dict =
YAML.load_file(joinpath(output_dir, "configs", "model_config.yml"))

if isnothing(iterations)
iterations = 0:(n_iterations - 1)
end

ref_paths, _ = get_les_calibration_library()
comms_ctx = ClimaComms.SingletonCommsContext()
atmos_config = CA.AtmosConfig(model_config_dict; comms_ctx)
zc_model = get_z_grid(atmos_config, z_max = z_max)

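# Mean, max, and min of a profile column, ignoring NaNs from failed runs;
# returns (NaN, NaN, NaN) if every entry is NaN.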
@everywhere function calculate_statistics(y_var)
non_nan_values = y_var[.!isnan.(y_var)]
if length(non_nan_values) == 0
return NaN, NaN, NaN
end
col_mean = mean(non_nan_values)
col_max = maximum(non_nan_values)
col_min = minimum(non_nan_values)
return col_mean, col_max, col_min
end

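# Squared error accumulated over vertical levels (rows) for each ensemble member (columns);
# returns one value per ensemble member.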
@everywhere function compute_ensemble_squared_error(ensemble_data, y_true)
return vec(sum((ensemble_data .- y_true) .^ 2, dims = 1))
end

@everywhere function process_iteration(
iteration,
output_dir,
var_names,
n_vert_levels,
config_dict,
z_max,
cal_vars,
const_noise_by_var,
ref_paths,
zc_model,
reduction,
ensemble_size,
)
println("Processing Iteration: $iteration")
stats_df = DataFrame(
iteration = Int[],
var_name = String[],
mean = Float64[],
max = Float64[],
min = Float64[],
rmse = Union{Missing, Float64}[],
rmse_min = Union{Missing, Float64}[],
rmse_max = Union{Missing, Float64}[],
rmse_std = Union{Missing, Float64}[],
)
config_indices = get_batch_indicies_in_iteration(iteration, output_dir)
for var_name in var_names
means = Float64[]
maxs = Float64[]
mins = Float64[]
sum_squared_errors = zeros(Float64, ensemble_size)
for config_i in config_indices
data = ensemble_data(
process_profile_variable,
iteration,
config_i,
config_dict;
var_name = var_name,
reduction = reduction,
output_dir = output_dir,
z_max = z_max,
n_vert_levels = n_vert_levels,
)
for i in 1:size(data, 2)
y_var = data[:, i]
col_mean, col_max, col_min = calculate_statistics(y_var)
push!(means, col_mean)
push!(maxs, col_max)
push!(mins, col_min)
end
if in(var_name, cal_vars)
y_true, Σ_obs, norm_vec_obs = get_obs(
ref_paths[config_i],
[var_name],
zc_model;
ti = config_dict["y_t_start_sec"],
tf = config_dict["y_t_end_sec"],
Σ_const = const_noise_by_var,
z_score_norm = false,
)
sum_squared_errors +=
compute_ensemble_squared_error(data, y_true)
end
end
if in(var_name, cal_vars)
# RMSE per ensemble member: squared errors summed over all configurations in the batch, normalized by the number of vertical levels
rmse_per_member = sqrt.(sum_squared_errors / n_vert_levels)
# Filter out NaNs (failed simulations)
valid_rmse = rmse_per_member[.!isnan.(rmse_per_member)]
non_nan_simulation_count = length(valid_rmse)
mean_rmse = mean(valid_rmse)
min_rmse = minimum(valid_rmse)
max_rmse = maximum(valid_rmse)
rmse_std = std(valid_rmse)
else
mean_rmse = missing
min_rmse = missing
max_rmse = missing
rmse_std = missing
end
push!(
stats_df,
(
iteration,
var_name,
mean(means[.!isnan.(means)]),
maximum(maxs[.!isnan.(maxs)]),
minimum(mins[.!isnan.(mins)]),
mean_rmse,
min_rmse,
max_rmse,
rmse_std,
),
)
end
return stats_df
end

if !load_from_csv
iterations_list = collect(iterations)
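# Distribute per-iteration processing across the worker processes added by addprocs()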
stats_dfs = pmap(
iteration -> process_iteration(
iteration,
output_dir,
var_names,
n_vert_levels,
config_dict,
z_max,
cal_vars,
const_noise_by_var,
ref_paths,
zc_model,
reduction,
ensemble_size,
),
iterations_list,
)

stats_df = vcat(stats_dfs...)
if save_as_csv
CSV.write(joinpath(output_dir, "stats_df.csv"), stats_df)
end

else
@info "Loading existing stats from CSV"
stats_df = CSV.read(joinpath(output_dir, "stats_df.csv"), DataFrame)
end

rmse_df = dropmissing(stats_df, [:rmse, :rmse_min, :rmse_max, :rmse_std])
unique_vars = unique(rmse_df.var_name)
n_vars = length(unique_vars)

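# RMSE vs. iteration: ensemble-mean RMSE with a ±1 std ribbon and dashed min/max envelopes, one panel per calibration variable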
p = plot(layout = (n_vars, 1), size = (600, 400 * n_vars))

for (i, var_name) in enumerate(unique_vars)
df_var = rmse_df[rmse_df.var_name .== var_name, :]
Plots.plot!(
p[i],
df_var.iteration,
df_var.rmse,
label = "Mean RMSE",
lw = 2,
marker = :o,
color = :blue,
ribbon = 1 .* df_var.rmse_std,
fillalpha = 0.3,
fillcolor = :blue,
)
Plots.plot!(
p[i],
df_var.iteration,
df_var.rmse_min,
linestyle = :dash,
color = :black,
label = "",
)
Plots.plot!(
p[i],
df_var.iteration,
df_var.rmse_max,
linestyle = :dash,
color = :black,
label = "",
)
Plots.xlabel!("Iteration")
Plots.ylabel!("RMSE")
Plots.title!(p[i], var_name)
end
savefig(joinpath(output_dir, "plots", plot_dir, "rmse_vs_iteration.pdf"))

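# Profile statistics vs. iteration: ensemble mean with dashed min/max envelopes, one panel per variable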
n_vars = length(var_names)
p = plot(layout = (n_vars, 1), size = (800, 400 * n_vars))

for (i, var_name) in enumerate(var_names)
df_var = stats_df[stats_df.var_name .== var_name, :]
Plots.plot!(
p[i],
df_var.iteration,
df_var.mean,
label = "Mean RMSE",
lw = 2,
marker = :o,
color = :blue,
)
Plots.plot!(
p[i],
df_var.iteration,
df_var.min,
linestyle = :dash,
color = :black,
label = "",
)
Plots.plot!(
p[i],
df_var.iteration,
df_var.max,
linestyle = :dash,
color = :black,
label = "",
)
Plots.xlabel!("Iteration")
Plots.ylabel!("Ranges")
Plots.title!(p[i], var_name)
end
savefig(joinpath(output_dir, "plots", plot_dir, "stats_vs_iteration.pdf"))
end

main()
